# Source code for formate.dynamic_quotes

#!/usr/bin/env python3
#
#  dynamic_quotes.py
r"""
Applies "dynamic quotes" to Python source code.

The rules are:

* Use double quotes ``"`` where possible.
* Use single quotes ``'`` for empty strings and single characters (``a``, ``\n`` etc.).
* Leave the quotes unchanged for multiline strings, f-strings and raw strings.
"""
#
#  Copyright © 2020-2021 Dominic Davis-Foster <dominic@davis-foster.co.uk>
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
#  EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
#  MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
#  IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
#  DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
#  OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
#  OR OTHER DEALINGS IN THE SOFTWARE.
#

# stdlib
import ast
import re
import sys
from typing import Mapping, Union

# 3rd party
from domdf_python_tools.utils import double_repr_string

# this package
from formate.utils import Rewriter

__all__ = ("dynamic_quotes", )

# Node type(s) that represent a string literal in the parsed AST.
# Older interpreters may still hand us the legacy ``ast.Str`` node, so the alias
# is widened to a Union there; on 3.12+ only ``ast.Constant`` is referenced.
if sys.version_info < (3, 12):
	StrOrConstant = Union[ast.Str, ast.Constant]
else:
	StrOrConstant = ast.Constant


class QuoteRewriter(Rewriter):
	"""
	AST visitor that records quote-style replacements for string literals.

	The first expression statement of each class / (async) function body is flagged by
	:meth:`visit_definition` and later skipped; every other ``str`` literal is handed to
	:meth:`rewrite_quotes_for_node`.
	"""

	if sys.version_info[:2] < (3, 8):  # pragma: no cover (py38+)

		def visit_Str(self, node: ast.Str) -> None:
			self.rewrite_quotes_for_node(node)
	else:  # pragma: no cover (<py38)

		def visit_Constant(self, node: ast.Constant) -> None:
			# Only ``str`` constants are candidates; other constants are just traversed.
			if isinstance(node.value, str):
				self.rewrite_quotes_for_node(node)
			else:
				self.generic_visit(node)

	def visit_definition(self, node: Union[ast.ClassDef, ast.FunctionDef, ast.AsyncFunctionDef]) -> None:
		"""
		Mark the docstring of the function or class to identify it later.

		The value of the first statement is flagged with ``is_docstring`` when that statement
		is an expression statement. Non-string expressions get the flag too, but it is only
		consulted for string nodes in :meth:`rewrite_quotes_for_node`, so that is harmless.
		This visitor has no module-level counterpart, so a module docstring is treated like
		any other string literal.

		:param node: The class or function definition.
		"""

		if node.body and isinstance(node.body[0], ast.Expr):
			doc_node = node.body[0].value
			doc_node.is_docstring = True  # type: ignore[attr-defined]

		self.generic_visit(node)

	def visit_ClassDef(self, node: ast.ClassDef) -> None:
		self.visit_definition(node)

	def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
		self.visit_definition(node)

	def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> None:
		self.visit_definition(node)

	def rewrite_quotes_for_node(self, node: StrOrConstant) -> None:
		"""
		Mark the area for rewriting quotes in the given node.

		Only literals whose source text starts with a quote character are considered
		(so prefixed literals such as ``r''`` or ``u''`` are left alone). Of those:

		* empty strings become ``''``;
		* single-character strings are replaced by their :func:`repr`;
		* literals spanning several source lines, or whose value contains a newline or a
		  literal backslash followed by ``n``, are left unchanged;
		* everything else is rewritten with :func:`double_repr_string`, with lone
		  surrogates escaped.

		Flagged docstrings are never rewritten.

		:param node: The string literal node.
		"""

		text_range = self.tokens.get_text_range(node)

		if text_range == (0, 0):
			# No source text is associated with this node; nothing to rewrite.
			return

		if getattr(node, "is_docstring", False):
			# TODO: format docstring with triple quotes and correct indentation
			return

		string = self.source[text_range[0]:text_range[1]]

		if sys.version_info >= (3, 12):  # pragma: no cover (<py312)
			value = node.value
		else:  # pragma: no cover (py312+)
			value = node.s

		if string in {'""', "''"}:
			self.record_replacement(text_range, "''")
		elif not re.match("^[\"']", string):
			# Prefixed literal (raw, unicode, ...): keep it exactly as written.
			return
		elif len(value) == 1:
			self.record_replacement(text_range, repr(value))
		elif '\n' in string:
			# The literal spans several source lines.
			return
		elif not value:
			# Empty value written in another single-line form, e.g. """""" or "" "".
			# Normalise to '' so all empty strings follow the same rule as above.
			self.record_replacement(text_range, "''")
		elif '\n' in value or "\\n" in value:
			return
		else:
			self.record_replacement(
					text_range,
					double_repr_string(value).translate(_surrogate_translator),
					)


def dynamic_quotes(source: str) -> str:
	"""
	Reformats quotes in the given source, and returns the reformatted source.

	:param source: The source to reformat.

	:returns: The reformatted source.
	"""

	return QuoteRewriter(source).rewrite()
class _LazyTranslate(Mapping): """ Escapes surrogates in the range U+D800 to U+DFFF, so they are left unchanged in the source. """ def __iter__(self): # noqa: MAN002 raise NotImplementedError def __len__(self): # noqa: MAN002 raise NotImplementedError def __getitem__(self, item: int) -> str: if item in range(55296, 57343): return repr(chr(item)).strip("'") else: return chr(item) _surrogate_translator = _LazyTranslate()