#!/usr/bin/env python3
#
# reformat_generics.py
"""
Formats generics (``List[...]``, ``Union[...]`` etc.) at the module and class level.
Example output, with a line length of ``100``:
.. code-block:: python
ParamsMappingValueType = Union[str, bytes, int, float, Iterable[Union[str, bytes, int, float]]]
Data = Union[None, str, bytes, MutableMapping[str, Any], Iterable[Tuple[str, Optional[str]]], IO]
ParamsType = Union[
Mapping[Union[str, bytes, int, float], ParamsMappingValueType],
Union[str, bytes],
Tuple[Union[str, bytes, int, float], ParamsMappingValueType],
None
]
"""
#
# Copyright © 2020-2021 Dominic Davis-Foster <dominic@davis-foster.co.uk>
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM,
# DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
# OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE
# OR OTHER DEALINGS IN THE SOFTWARE.
#
# stdlib
import ast
import sys
import typing
from collections.abc import Collection
from io import StringIO
from textwrap import indent as indent_string
# 3rd party
import astatine
import asttokens
from domdf_python_tools.stringlist import DelimitedList, StringList
from domdf_python_tools.words import TAB
__all__ = ("reformat_generics", "Generic", "List")
collection_types = {"Union", "List", "Tuple", "Set", "Dict", "Callable", "Optional", "Literal"}
subclasses = [Collection]
# pylint: disable=loop-global-usage
while subclasses:
subclass = subclasses.pop(0)
if subclass.__module__ == "collections.abc":
collection_types.add(subclass.__name__)
subclasses.extend(subclass.__subclasses__())
# pylint: enable=loop-global-usage
def get_slice_value(ast_slice): # py36: ast.Index # py39: ast.Tuple # noqa: MAN001,MAN002
"""
Return the value of the slice in a Python version-independent way.
:param ast_slice:
"""
if sys.version_info < (3, 9): # pragma: no cover (py39+)
return ast_slice.value
else:
return ast_slice # pragma: no cover (<py39)
[docs]class Generic:
"""
Represents a :class:`typing.Generic`, :py:obj:`typing.Union`, :py:obj:`typing.Callable` etc.
:param name: The name of the Generic
:param elements: The ``__class_getitem__`` elements of the Generic.
"""
def __init__(self, name: str, elements: typing.Sequence[typing.Union[str, "Generic", "List"]]):
self.name = str(name)
self.elements: DelimitedList[typing.Union[str, Generic, List]] = DelimitedList(elements)
def __repr__(self) -> str:
return f"{self.name}[{self.elements:, }]"
[docs]class List:
"""
Represents a list of elements, most often used within a :py:obj:`typing.Callable`.
:param elements:
"""
def __init__(self, elements: typing.Sequence[typing.Union[str, Generic, "List"]]):
self.elements = DelimitedList(elements)
def __repr__(self) -> str:
return f"[{self.elements:, }]"
class Visitor(ast.NodeVisitor):
in_class = False
def __init__(self):
self.unions: typing.List[typing.Tuple[ast.Subscript, Generic, bool]] = []
def visit_Subscript(self, node: ast.Subscript) -> None:
if isinstance(node.value, ast.Name) and node.value.id in collection_types:
union = Generic(node.value.id, UnionVisitor().visit(get_slice_value(node.slice)))
self.unions.append((node, union, self.in_class))
return
elif isinstance(node.value, ast.Attribute):
if isinstance(node.value.value, ast.Name) and node.value.value.id in {"typing", "typing_extensions"}:
if node.value.attr in collection_types:
union = Generic(
'.'.join([node.value.value.id, node.value.attr]),
UnionVisitor().visit(get_slice_value(node.slice)),
)
self.unions.append((node, union, self.in_class))
return
self.generic_visit(node)
def visit(self, node: ast.AST) -> typing.List[typing.Tuple[ast.Subscript, Generic, bool]]:
super().visit(node)
return self.unions
@staticmethod
def visit_FunctionDef(node: ast.FunctionDef) -> None:
return None
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.unions.extend(ClassVisitor().visit(node))
@staticmethod
def visit_AsyncFunctionDef(node: ast.AsyncFunctionDef) -> None:
return None
class ClassVisitor(Visitor):
in_class = True
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.generic_visit(node)
class UnionVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
self.structure: typing.List[typing.Union[str, Generic, List]] = []
def generic_visit(self, node: ast.AST) -> None:
super().generic_visit(node)
def visit_Name(self, node: ast.Name) -> None:
self.structure.append(f"{node.id}")
def visit_Attribute(self, node: ast.Attribute) -> None:
parts: DelimitedList[str] = DelimitedList()
value: typing.Union[ast.Name, ast.expr] = node.value
NameNode, AttributeNode = ast.Name, ast.Attribute
while True:
if isinstance(value, NameNode):
parts.append(value.id)
break
elif isinstance(value, AttributeNode):
parts.append(value.value.id) # type: ignore[attr-defined]
value = value.attr # type: ignore[assignment]
elif isinstance(value, str):
parts.append(value)
break
else:
raise NotImplementedError(f"Unsupported value type {type(value)}")
self.structure.append(f"{parts:.}.{node.attr}")
def visit_Subscript(self, node: ast.Subscript) -> None:
union = Generic(
'.'.join(astatine.get_attribute_name(node.value)),
UnionVisitor().visit(get_slice_value(node.slice)),
)
self.structure.append(union)
def visit_List(self, node: ast.List) -> None:
elements = []
for child in node.elts:
elements.extend(UnionVisitor().visit(child))
self.structure.append(List(elements))
@staticmethod
def visit_Load(node: ast.Load) -> None:
return None
@staticmethod
def visit_FunctionDef(node: ast.FunctionDef) -> None:
return None
def visit_ClassDef(self, node: ast.ClassDef) -> None:
self.generic_visit(node)
@staticmethod
def visit_AsyncFunctionDef(node: ast.AsyncFunctionDef) -> None:
return None
if sys.version_info[:2] < (3, 8): # pragma: no cover (py38+)
def visit_NameConstant(self, node: ast.NameConstant) -> None:
self.structure.append(node.value)
def visit_Str(self, node: ast.Str) -> None:
self.structure.append(f'"{node.s}"')
def visit_Ellipsis(self, node: ast.Ellipsis) -> None:
self.structure.append("...")
else: # pragma: no cover (<py38)
def visit_Constant(self, node: ast.Constant) -> None:
if isinstance(node.value, str):
self.structure.append(f'"{node.value}"')
elif node.value is Ellipsis:
self.structure.append("...")
elif node.value is None or isinstance(node.value, bool):
self.structure.append(str(node.value))
else:
print(node, node.value)
self.generic_visit(node)
def visit(self, node: ast.AST) -> typing.List[typing.Union[str, Generic, List]]:
super().visit(node)
return self.structure