Skip to content

Commit ad2a044

Browse files
authored
Tin/more-typeddict-coverage (#492)
* TypedDict coverage * Improve typeddicts coverage * Remove dead code * TypedDict fix * Remove dead code
1 parent 0ad5cae commit ad2a044

5 files changed

Lines changed: 91 additions & 48 deletions

File tree

src/cattrs/_compat.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -412,19 +412,22 @@ def is_counter(type):
412412
or getattr(type, "__origin__", None) is Counter
413413
)
414414

415-
def is_generic(obj) -> bool:
416-
"""Whether obj is a generic type."""
415+
def is_generic(type) -> bool:
416+
"""Whether `type` is a generic type."""
417417
# Inheriting from protocol will inject `Generic` into the MRO
418418
# without `__orig_bases__`.
419-
return isinstance(obj, (_GenericAlias, GenericAlias)) or (
420-
is_subclass(obj, Generic) and hasattr(obj, "__orig_bases__")
419+
return isinstance(type, (_GenericAlias, GenericAlias)) or (
420+
is_subclass(type, Generic) and hasattr(type, "__orig_bases__")
421421
)
422422

423423
def copy_with(type, args):
424424
"""Replace a generic type's arguments."""
425425
if is_annotated(type):
426426
# typing.Annotated requires a special case.
427427
return Annotated[args]
428+
if isinstance(args, tuple) and len(args) == 1:
429+
# Some annotations can't handle 1-tuples.
430+
args = args[0]
428431
return type.__origin__[args]
429432

430433
def get_full_type_hints(obj, globalns=None, localns=None):

src/cattrs/gen/_generics.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66

77

88
def generate_mapping(cl: type, old_mapping: dict[str, type] = {}) -> dict[str, type]:
9-
mapping = {}
9+
"""Generate a mapping of typevars to actual types for a generic class."""
10+
mapping = dict(old_mapping)
1011

1112
origin = get_origin(cl)
1213

@@ -25,8 +26,6 @@ def generate_mapping(cl: type, old_mapping: dict[str, type] = {}) -> dict[str, t
2526
continue
2627
mapping[p.__name__] = t
2728

28-
if not mapping:
29-
return dict(old_mapping)
3029
elif is_generic(cl):
3130
# Origin is None, so this may be a subclass of a generic class.
3231
orig_bases = cl.__orig_bases__

src/cattrs/gen/typeddicts.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def get_annots(cl) -> dict[str, Any]:
5353
if TYPE_CHECKING: # pragma: no cover
5454
from typing_extensions import Literal
5555

56-
from cattr.converters import BaseConverter
56+
from ..converters import BaseConverter
5757

5858
__all__ = ["make_dict_unstructure_fn", "make_dict_structure_fn"]
5959

@@ -209,7 +209,7 @@ def make_dict_unstructure_fn(
209209
# No default or no override.
210210
lines.append(f" res['{kn}'] = {invoke}")
211211
else:
212-
lines.append(f" if '{kn}' in instance: res['{kn}'] = {invoke}")
212+
lines.append(f" if '{attr_name}' in instance: res['{kn}'] = {invoke}")
213213

214214
internal_arg_line = ", ".join([f"{i}={i}" for i in internal_arg_parts])
215215
if internal_arg_line:
@@ -340,9 +340,7 @@ def make_dict_structure_fn(
340340
if nrb is not NOTHING:
341341
t = nrb
342342

343-
if isinstance(t, TypeVar):
344-
t = mapping.get(t.__name__, t)
345-
elif is_generic(t) and not is_bare(t) and not is_annotated(t):
343+
if is_generic(t) and not is_bare(t) and not is_annotated(t):
346344
t = deep_copy_with(t, mapping)
347345

348346
# For each attribute, we try resolving the type here and now.
@@ -554,8 +552,12 @@ def _required_keys(cls: type) -> set[str]:
554552
elif sys.version_info >= (3, 9):
555553
from typing_extensions import Annotated, NotRequired, Required, get_args
556554

555+
# Note that there is no `typing.Required` on 3.9 and 3.10, only in
556+
# `typing_extensions`. Therefore, `typing.TypedDict` will not honor this
557+
# annotation, only `typing_extensions.TypedDict`.
558+
557559
def _required_keys(cls: type) -> set[str]:
558-
"""Own own processor for required keys."""
560+
"""Our own processor for required keys."""
559561
if _is_extensions_typeddict(cls):
560562
return cls.__required_keys__
561563

@@ -564,6 +566,9 @@ def _required_keys(cls: type) -> set[str]:
564566
own_annotations = cls.__dict__.get("__annotations__", {})
565567
required_keys = set()
566568
for base in cls.__mro__[1:]:
569+
if base in (object, dict):
570+
# These have no required keys for sure.
571+
continue
567572
required_keys |= _required_keys(base)
568573
for key in getattr(cls, "__required_keys__", []):
569574
annotation_type = own_annotations[key]
@@ -574,9 +579,7 @@ def _required_keys(cls: type) -> set[str]:
574579
annotation_type = annotation_args[0]
575580
annotation_origin = get_origin(annotation_type)
576581

577-
if annotation_origin is Required:
578-
required_keys.add(key)
579-
elif annotation_origin is NotRequired:
582+
if annotation_origin is NotRequired:
580583
pass
581584
elif cls.__total__:
582585
required_keys.add(key)
@@ -588,7 +591,7 @@ def _required_keys(cls: type) -> set[str]:
588591
# On 3.8, typing.TypedDicts do not have __required_keys__.
589592

590593
def _required_keys(cls: type) -> set[str]:
591-
"""Own own processor for required keys."""
594+
"""Our own processor for required keys."""
592595
if _is_extensions_typeddict(cls):
593596
return cls.__required_keys__
594597

tests/test_typeddicts.py

Lines changed: 47 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
"""Tests for TypedDict un/structuring."""
22
from datetime import datetime, timezone
3-
from typing import Dict, Generic, Set, Tuple, TypedDict, TypeVar
3+
from typing import Dict, Generic, NewType, Set, Tuple, TypedDict, TypeVar
44

55
import pytest
6+
from attrs import NOTHING
67
from hypothesis import assume, given
78
from hypothesis.strategies import booleans
89
from pytest import raises
910
from typing_extensions import NotRequired
1011

1112
from cattrs import BaseConverter, Converter
12-
from cattrs._compat import ExtensionsTypedDict, is_generic
13+
from cattrs._compat import ExtensionsTypedDict, get_notrequired_base, is_generic
1314
from cattrs.errors import (
1415
ClassValidationError,
1516
ForbiddenExtraKeysError,
@@ -51,17 +52,27 @@ def get_annot(t) -> dict:
5152
args = t.__args__
5253
params = origin.__parameters__
5354
param_to_args = dict(zip(params, args))
54-
return {
55-
k: param_to_args[v] if v in param_to_args else v
56-
for k, v in origin_annotations.items()
57-
}
55+
res = {}
56+
for k, v in origin_annotations.items():
57+
if (nrb := get_notrequired_base(v)) is not NOTHING:
58+
res[k] = (
59+
NotRequired[param_to_args[nrb]] if nrb in param_to_args else v
60+
)
61+
else:
62+
res[k] = param_to_args[v] if v in param_to_args else v
63+
return res
5864

5965
# Origin is `None`, so this is a subclass for a generic typeddict.
6066
mapping = generate_mapping(t)
61-
return {
62-
k: mapping[v.__name__] if v.__name__ in mapping else v
63-
for k, v in get_annots(t).items()
64-
}
67+
res = {}
68+
for k, v in get_annots(t).items():
69+
if (nrb := get_notrequired_base(v)) is not NOTHING:
70+
res[k] = (
71+
NotRequired[mapping[nrb.__name__]] if nrb.__name__ in mapping else v
72+
)
73+
else:
74+
res[k] = mapping[v.__name__] if v.__name__ in mapping else v
75+
return res
6576
return get_annots(t)
6677

6778

@@ -196,6 +207,27 @@ class GenericTypedDict(TypedDict, Generic[T]):
196207
c.structure({"a": 1}, GenericTypedDict)
197208

198209

210+
@pytest.mark.skipif(not is_py311_plus, reason="3.11+ only")
211+
@given(detailed_validation=...)
212+
def test_deep_generics(detailed_validation: bool):
213+
c = mk_converter(detailed_validation=detailed_validation)
214+
215+
Int = NewType("Int", int)
216+
217+
c.register_unstructure_hook_func(lambda t: t is Int, lambda v: v - 1)
218+
219+
T = TypeVar("T")
220+
T1 = TypeVar("T1")
221+
222+
class GenericParent(TypedDict, Generic[T]):
223+
a: T
224+
225+
class GenericChild(GenericParent[Int], Generic[T1]):
226+
b: T1
227+
228+
assert c.unstructure({"b": 2, "a": 2}, GenericChild[Int]) == {"a": 1, "b": 1}
229+
230+
199231
@given(simple_typeddicts(total=True, not_required=True), booleans())
200232
def test_not_required(
201233
cls_and_instance: Tuple[type, Dict], detailed_validation: bool
@@ -273,36 +305,25 @@ def test_omit(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) ->
273305
assert restructured == instance
274306

275307

276-
@given(simple_typeddicts(min_attrs=1, total=True), booleans())
308+
@given(simple_typeddicts(min_attrs=1, total=True, not_required=True), booleans())
277309
def test_rename(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) -> None:
278310
"""`override(rename=...)` works."""
279311
c = mk_converter(detailed_validation=detailed_validation)
280312

281313
cls, instance = cls_and_instance
282314
key = next(iter(get_annot(cls)))
283315
c.register_unstructure_hook(
284-
cls,
285-
make_dict_unstructure_fn(
286-
cls,
287-
c,
288-
_cattrs_detailed_validation=detailed_validation,
289-
**{key: override(rename="renamed")},
290-
),
316+
cls, make_dict_unstructure_fn(cls, c, **{key: override(rename="renamed")})
291317
)
292318

293319
unstructured = c.unstructure(instance, unstructure_as=cls)
294320

295321
assert key not in unstructured
296-
assert "renamed" in unstructured
322+
if key in instance:
323+
assert "renamed" in unstructured
297324

298325
c.register_structure_hook(
299-
cls,
300-
make_dict_structure_fn(
301-
cls,
302-
c,
303-
_cattrs_detailed_validation=detailed_validation,
304-
**{key: override(rename="renamed")},
305-
),
326+
cls, make_dict_structure_fn(cls, c, **{key: override(rename="renamed")})
306327
)
307328
restructured = c.structure(unstructured, cls)
308329

tests/typeddicts.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Strategies for typed dicts."""
22
from datetime import datetime, timezone
33
from string import ascii_lowercase
4-
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, TypeVar
4+
from typing import Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar
55

66
from attrs import NOTHING
7+
from hypothesis import note
78
from hypothesis.strategies import (
89
DrawFn,
910
SearchStrategy,
@@ -49,7 +50,7 @@ def gen_typeddict_attr_names():
4950
@composite
5051
def int_attributes(
5152
draw: DrawFn, total: bool = True, not_required: bool = False
52-
) -> Tuple[int, SearchStrategy, SearchStrategy]:
53+
) -> Tuple[Type[int], SearchStrategy, SearchStrategy]:
5354
if total:
5455
if not_required and draw(booleans()):
5556
return (NotRequired[int], integers() | just(NOTHING), text(ascii_lowercase))
@@ -176,6 +177,15 @@ def simple_typeddicts(
176177
else typeddict_cls
177178
)("HypTypedDict", attrs_dict, total=total)
178179

180+
note(
181+
"\n".join(
182+
[
183+
"class HypTypedDict(TypedDict):",
184+
*[f" {n}: {a}" for n, a in attrs_dict.items()],
185+
]
186+
)
187+
)
188+
179189
if draw(booleans()):
180190

181191
class InheritedTypedDict(cls):
@@ -240,9 +250,8 @@ def generic_typeddicts(
240250
generics.append(typevar)
241251
if total and draw(booleans()):
242252
# We might decide to make these NotRequired
243-
actual_types.append(NotRequired[attr_type])
244-
else:
245-
actual_types.append(attr_type)
253+
typevar = NotRequired[typevar]
254+
actual_types.append(attr_type)
246255
attrs_dict[attr_name] = typevar
247256

248257
cls = make_typeddict(
@@ -282,6 +291,14 @@ def make_typeddict(
282291
lines.append(f" {n}: _{trimmed}_type")
283292

284293
script = "\n".join(lines)
294+
295+
note_lines = script
296+
for n, t in globs.items():
297+
if n == "TypedDict":
298+
continue
299+
note_lines = note_lines.replace(n, repr(t))
300+
note(note_lines)
301+
285302
eval(compile(script, "name", "exec"), globs)
286303

287304
return globs[cls_name]

0 commit comments

Comments
 (0)