Skip to content
Open
43 changes: 41 additions & 2 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Format(enum.IntEnum):
"__cell__",
"__owner__",
"__stringifier_dict__",
"__resolved_str_cache__",
)


Expand Down Expand Up @@ -94,6 +95,7 @@ def __init__(
# value later.
self.__code__ = None
self.__ast_node__ = None
self.__resolved_str_cache__ = None

def __init_subclass__(cls, /, *args, **kwds):
raise TypeError("Cannot subclass ForwardRef")
Expand All @@ -113,7 +115,7 @@ def evaluate(
"""
match format:
case Format.STRING:
return self.__forward_arg__
return self.__resolved_str__
case Format.VALUE:
is_forwardref_format = False
case Format.FORWARDREF:
Expand Down Expand Up @@ -258,6 +260,29 @@ def __forward_arg__(self):
"Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
)

@property
def __resolved_str__(self):
# __forward_arg__ with any names from __extra_names__ replaced
# with the type_repr of the value they represent
if self.__resolved_str_cache__ is None:
resolved_str = self.__forward_arg__
names = self.__extra_names__

if names:
# identifiers can be replaced directly
if resolved_str.isidentifier():
if (name_obj := names.get(resolved_str), _sentinel) is not _sentinel:
Comment thread
DavidCEllis marked this conversation as resolved.
Outdated
resolved_str = type_repr(name_obj)
else:
visitor = _ExtraNameFixer(names)
ast_expr = ast.parse(resolved_str, mode="eval").body
node = visitor.visit(ast_expr)
resolved_str = ast.unparse(node)

self.__resolved_str_cache__ = resolved_str

return self.__resolved_str_cache__

@property
def __forward_code__(self):
if self.__code__ is not None:
Expand Down Expand Up @@ -321,7 +346,7 @@ def __repr__(self):
extra.append(", is_class=True")
if self.__owner__ is not None:
extra.append(f", owner={self.__owner__!r}")
return f"ForwardRef({self.__forward_arg__!r}{''.join(extra)})"
return f"ForwardRef({self.__resolved_str__!r}{''.join(extra)})"


_Template = type(t"")
Expand Down Expand Up @@ -357,6 +382,7 @@ def __init__(
self.__cell__ = cell
self.__owner__ = owner
self.__stringifier_dict__ = stringifier_dict
self.__resolved_str_cache__ = None # Needed for ForwardRef

def __convert_to_ast(self, other):
if isinstance(other, _Stringifier):
Expand Down Expand Up @@ -1163,3 +1189,16 @@ def _get_dunder_annotations(obj):
if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
return ann


class _ExtraNameFixer(ast.NodeTransformer):
"""Fixer for __extra_names__ items in ForwardRef __repr__ and string evaluation"""
def __init__(self, extra_names):
self.extra_names = extra_names

def visit_Name(self, node: ast.Name):
if (new_name := self.extra_names.get(node.id, _sentinel)) is not _sentinel:
new_node = ast.Name(id=type_repr(new_name))
ast.copy_location(node, new_node)
Comment thread
DavidCEllis marked this conversation as resolved.
Outdated
node = new_node
return node
27 changes: 27 additions & 0 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,15 @@ def test_forward_repr(self):
"typing.List[ForwardRef('int', owner='class')]",
)

def test_forward_repr_extra_names(self):
Comment thread
DavidCEllis marked this conversation as resolved.
def f(a: undefined | str): ...

annos = get_annotations(f, format=Format.FORWARDREF)

self.assertRegex(
repr(annos['a']), r"ForwardRef\('undefined \| str'.*\)"
)

def test_forward_recursion_actually(self):
def namespace1():
a = ForwardRef("A")
Expand Down Expand Up @@ -2037,6 +2046,24 @@ def test_evaluate_string_format(self):
fr = ForwardRef("set[Any]")
self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")

def test_evaluate_string_format_extra_names(self):
# Test that internal extra_names are replaced when evaluating as strings

# As identifier
fr = ForwardRef("__annotationlib_name_1__")
fr.__extra_names__ = {"__annotationlib_name_1__": str}
self.assertEqual(fr.evaluate(format=Format.STRING), "str")

# Via AST visitor
def f(a: ref | str): ...

fr = get_annotations(f, format=Format.FORWARDREF)['a']
# Test the cache is not populated before access
self.assertIsNone(fr.__resolved_str_cache__)

self.assertEqual(fr.evaluate(format=Format.STRING), "ref | str")
self.assertEqual(fr.__resolved_str_cache__, "ref | str")

def test_evaluate_forwardref_format(self):
fr = ForwardRef("undef")
evaluated = fr.evaluate(format=Format.FORWARDREF)
Expand Down
Loading