Skip to content
Open
37 changes: 35 additions & 2 deletions Lib/annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def evaluate(
"""
match format:
case Format.STRING:
return self.__forward_arg__
return self.__resolved_forward_str__
case Format.VALUE:
is_forwardref_format = False
case Format.FORWARDREF:
Expand Down Expand Up @@ -258,6 +258,26 @@ def __forward_arg__(self):
"Attempted to access '__forward_arg__' on an uninitialized ForwardRef"
)

@property
def __resolved_forward_str__(self):
# __forward_arg__ with any names from __extra_names__ replaced
# with the type_repr of the value they represent
resolved_str = self.__forward_arg__
names = self.__extra_names__

if names:
# identifiers can be replaced directly
if resolved_str.isidentifier():
if (name_obj := names.get(resolved_str), _sentinel) is not _sentinel:
resolved_str = type_repr(name_obj)
else:
visitor = _ExtraNameFixer(names)
ast_expr = ast.parse(resolved_str, mode="eval").body
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might need to cache this, it's probably pretty slow. What do you think?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is in fact, not very fast. Definitely worth caching if it's going to be accessed frequently.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

>>> a_anno
ForwardRef('unknown | str | int | list[str] | tuple[int, ...]', is_class=True, owner=<class '__main__.Example'>)
>>> b_anno
ForwardRef('unknown', is_class=True, owner=<class '__main__.Example'>)

>>> a = timeit(lambda: a_anno.__resolved_forward_str__, number=10_000)
>>> b = timeit(lambda: b_anno.__resolved_forward_str__, number=10_000)
>>> a
0.18359258100008446
>>> b
0.0018204959997092374
>>> a / b
100.8475607907994

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've added a cache and ended up shortening the name of the property and the cache. Adding the cache did mean adding an extra slot to both classes.

Not completely sold on the names I have if you have something better.

node = visitor.visit(ast_expr)
resolved_str = ast.unparse(node)

return resolved_str

@property
def __forward_code__(self):
if self.__code__ is not None:
Expand Down Expand Up @@ -321,7 +341,7 @@ def __repr__(self):
extra.append(", is_class=True")
if self.__owner__ is not None:
extra.append(f", owner={self.__owner__!r}")
return f"ForwardRef({self.__forward_arg__!r}{''.join(extra)})"
return f"ForwardRef({self.__resolved_forward_str__!r}{''.join(extra)})"


_Template = type(t"")
Expand Down Expand Up @@ -1163,3 +1183,16 @@ def _get_dunder_annotations(obj):
if not isinstance(ann, dict):
raise ValueError(f"{obj!r}.__annotations__ is neither a dict nor None")
return ann


class _ExtraNameFixer(ast.NodeTransformer):
"""Fixer for __extra_names__ items in ForwardRef __repr__ and string evaluation"""
def __init__(self, extra_names):
self.extra_names = extra_names

def visit_Name(self, node: ast.Name):
if (new_name := self.extra_names.get(node.id, _sentinel)) is not _sentinel:
new_node = ast.Name(id=type_repr(new_name))
ast.copy_location(node, new_node)
Comment thread
DavidCEllis marked this conversation as resolved.
Outdated
node = new_node
return node
19 changes: 19 additions & 0 deletions Lib/test/test_annotationlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -1961,6 +1961,11 @@ def test_forward_repr(self):
"typing.List[ForwardRef('int', owner='class')]",
)

def test_forward_repr_extra_names(self):
Comment thread
DavidCEllis marked this conversation as resolved.
fr = ForwardRef("__annotationlib_name_1__")
fr.__extra_names__ = {"__annotationlib_name_1__": list[str]}
self.assertEqual(repr(fr), "ForwardRef('list[str]')")

def test_forward_recursion_actually(self):
def namespace1():
a = ForwardRef("A")
Expand Down Expand Up @@ -2037,6 +2042,20 @@ def test_evaluate_string_format(self):
fr = ForwardRef("set[Any]")
self.assertEqual(fr.evaluate(format=Format.STRING), "set[Any]")

def test_evaluate_string_format_extra_names(self):
# Test that internal extra_names are replaced when evaluating as strings

# As identifier
fr = ForwardRef("__annotationlib_name_1__")
fr.__extra_names__ = {"__annotationlib_name_1__": str}
self.assertEqual(fr.evaluate(format=Format.STRING), "str")

# Via AST visitor
def f(a: ref | str): ...

fr = get_annotations(f, format=Format.FORWARDREF)['a']
self.assertEqual(fr.evaluate(format=Format.STRING), "ref | str")

def test_evaluate_forwardref_format(self):
fr = ForwardRef("undef")
evaluated = fr.evaluate(format=Format.FORWARDREF)
Expand Down
Loading