Skip to content

Commit 9c6ed47

Browse files
committed
Use the same annotation logic for VALUE, FORWARDREF and STRING
Also split out the single large test into multiple smaller tests
1 parent 864305d commit 9c6ed47

2 files changed

Lines changed: 71 additions & 20 deletions

File tree

Lib/dataclasses.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -537,36 +537,34 @@ def _make_annotate_function(cls, annotations):
537537
# annotations should be in FORWARDREF format at this stage
538538

539539
def __annotate__(format, /):
540+
Format = annotationlib.Format
540541
match format:
541-
case annotationlib.Format.VALUE | annotationlib.Format.FORWARDREF:
542-
return {
543-
k: v.evaluate(format=format)
544-
if isinstance(v, annotationlib.ForwardRef) else v
545-
for k, v in annotations.items()
546-
}
547-
548-
case annotationlib.Format.STRING:
542+
case Format.VALUE | Format.FORWARDREF | Format.STRING:
549543
cls_annotations = {}
550544
for base in reversed(cls.__mro__):
551545
cls_annotations.update(
552546
annotationlib.get_annotations(base, format=format)
553547
)
554548

555-
string_annos = {}
549+
new_annotations = {}
556550
for k, v in annotations.items():
557551
try:
558-
string_annos[k] = cls_annotations[k]
552+
new_annotations[k] = cls_annotations[k]
559553
except KeyError:
560554
# This should be the return value
561-
string_annos[k] = annotationlib.type_repr(v)
562-
return string_annos
555+
if format == Format.STRING:
556+
new_annotations[k] = annotationlib.type_repr(v)
557+
else:
558+
new_annotations[k] = v
559+
560+
return new_annotations
563561

564562
case _:
565563
raise NotImplementedError(format)
566564

567565
# This is a flag for _add_slots to know it needs to regenerate this method
568566
# In order to remove references to the original class when it is replaced
569-
__annotate__.__generated_by_dataclasses = True
567+
__annotate__._generated_by_dataclasses = True
570568

571569
return __annotate__
572570

@@ -1399,7 +1397,7 @@ def _add_slots(cls, is_frozen, weakref_slot, defined_fields):
13991397

14001398
# Fix references in generated __annotate__ methods
14011399
method = getattr(newcls, "__init__")
1402-
update_annotations = getattr(method.__annotate__, "__generated_by_dataclasses", False)
1400+
update_annotations = getattr(method.__annotate__, "_generated_by_dataclasses", False)
14031401

14041402
if update_annotations:
14051403
new_annotations = {}

Lib/test/test_dataclasses/__init__.py

Lines changed: 59 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,10 +2470,13 @@ def __init__(self, a):
24702470

24712471
self.assertEqual(D(5).a, 10)
24722472

2473+
2474+
class TestInitAnnotate(unittest.TestCase):
2475+
# Tests for the generated __annotate__ function for __init__
2476+
# See: https://github.com/python/cpython/issues/137530
2477+
24732478
def test_annotate_function(self):
2474-
# Test that the __init__ function has correct annotate function
2475-
# See: https://github.com/python/cpython/issues/137530
2476-
# With no forward references
2479+
# No forward references
24772480
@dataclass
24782481
class A:
24792482
a: int
@@ -2486,6 +2489,9 @@ class A:
24862489
self.assertEqual(forwardref_annos, {'a': int, 'return': None})
24872490
self.assertEqual(string_annos, {'a': 'int', 'return': 'None'})
24882491

2492+
self.assertTrue(getattr(A.__init__.__annotate__, "_generated_by_dataclasses"))
2493+
2494+
def test_annotate_function_forwardref(self):
24892495
# With forward references
24902496
@dataclass
24912497
class B:
@@ -2512,26 +2518,53 @@ class B:
25122518
self.assertEqual(forwardref_annos, {'b': int, 'return': None})
25132519
self.assertEqual(string_annos, {'b': 'undefined', 'return': 'None'})
25142520

2515-
del undefined # Remove so we can use the name in later examples
2516-
2521+
def test_annotate_function_init_false(self):
25172522
# Check `init=False` attributes don't get into the annotations of the __init__ function
25182523
@dataclass
25192524
class C:
25202525
c: str = field(init=False)
25212526

25222527
self.assertEqual(annotationlib.get_annotations(C.__init__), {'return': None})
25232528

2524-
2529+
def test_annotate_function_contains_forwardref(self):
25252530
# Check string annotations on objects containing a ForwardRef
25262531
@dataclass
25272532
class D:
25282533
d: list[undefined]
25292534

2535+
with self.assertRaises(NameError):
2536+
annotationlib.get_annotations(D.__init__)
2537+
2538+
self.assertEqual(
2539+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
2540+
{"d": list[support.EqualToForwardRef("undefined", is_class=True, owner=D)], "return": None}
2541+
)
2542+
2543+
self.assertEqual(
2544+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
2545+
{"d": "list[undefined]", "return": "None"}
2546+
)
2547+
2548+
# Now test when it is defined
2549+
undefined = str
2550+
2551+
# VALUE should now resolve
2552+
self.assertEqual(
2553+
annotationlib.get_annotations(D.__init__),
2554+
{"d": list[str], "return": None}
2555+
)
2556+
2557+
self.assertEqual(
2558+
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.FORWARDREF),
2559+
{"d": list[str], "return": None}
2560+
)
2561+
25302562
self.assertEqual(
25312563
annotationlib.get_annotations(D.__init__, format=annotationlib.Format.STRING),
25322564
{"d": "list[undefined]", "return": "None"}
25332565
)
25342566

2567+
def test_annotate_function_not_replaced(self):
25352568
# Check that __annotate__ is not replaced on non-generated __init__ functions
25362569
@dataclass(slots=True)
25372570
class E:
@@ -2543,6 +2576,26 @@ def __init__(self, x: int) -> None:
25432576
annotationlib.get_annotations(E.__init__), {"x": int, "return": None}
25442577
)
25452578

2579+
self.assertFalse(hasattr(E.__init__.__annotate__, "_generated_by_dataclasses"))
2580+
2581+
def test_init_false_forwardref(self):
2582+
# Currently this raises a NameError even though the ForwardRef
2583+
# is not in the __init__ method
2584+
2585+
@dataclass
2586+
class F:
2587+
not_in_init: list[undefined] = field(init=False, default=None)
2588+
in_init: int
2589+
2590+
annos = annotationlib.get_annotations(F.__init__, format=annotationlib.Format.FORWARDREF)
2591+
self.assertEqual(
2592+
annos,
2593+
{"in_init": int, "return": None},
2594+
)
2595+
2596+
with self.assertRaises(NameError):
2597+
annos = annotationlib.get_annotations(F.__init__) # NameError on not_in_init
2598+
25462599

25472600
class TestRepr(unittest.TestCase):
25482601
def test_repr(self):

0 commit comments

Comments
 (0)