Skip to content

Commit 4f4a6e9

Browse files
authored
Tin/better union hooks (#499)
* Improve union structure hook handling * Improve typeddict coverage * Skip test on 3.9 and 3.10
1 parent 066ace9 commit 4f4a6e9

7 files changed

Lines changed: 55 additions & 46 deletions

File tree

src/cattrs/converters.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@
6565
IterableValidationNote,
6666
StructureHandlerNotFoundError,
6767
)
68-
from .fns import identity, raise_error
68+
from .fns import Predicate, identity, raise_error
6969
from .gen import (
7070
AttributeOverride,
7171
DictStructureFn,
@@ -174,6 +174,7 @@ def __init__(
174174
self._prefer_attrib_converters = prefer_attrib_converters
175175

176176
self.detailed_validation = detailed_validation
177+
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}
177178

178179
# Create a per-instance cache.
179180
if unstruct_strat is UnstructureStrategy.AS_DICT:
@@ -246,7 +247,8 @@ def __init__(
246247
(is_supported_union, self._gen_attrs_union_structure, True),
247248
(
248249
lambda t: is_union_type(t) and t in self._union_struct_registry,
249-
self._structure_union,
250+
self._union_struct_registry.__getitem__,
251+
True,
250252
),
251253
(is_optional, self._structure_optional),
252254
(has, self._structure_attrs),
@@ -266,9 +268,6 @@ def __init__(
266268

267269
self._dict_factory = dict_factory
268270

269-
# Unions are instances now, not classes. We use different registries.
270-
self._union_struct_registry: dict[Any, Callable[[Any, type[T]], T]] = {}
271-
272271
self._unstruct_copy_skip = self._unstructure_func.get_num_fns()
273272
self._struct_copy_skip = self._structure_func.get_num_fns()
274273

@@ -330,7 +329,7 @@ def register_unstructure_hook(
330329
return None
331330

332331
def register_unstructure_hook_func(
333-
self, check_func: Callable[[Any], bool], func: UnstructureHook
332+
self, check_func: Predicate, func: UnstructureHook
334333
) -> None:
335334
"""Register a class-to-primitive converter function for a class, using
336335
a function to check if it's a match.
@@ -339,25 +338,25 @@ def register_unstructure_hook_func(
339338

340339
@overload
341340
def register_unstructure_hook_factory(
342-
self, predicate: Callable[[Any], bool]
341+
self, predicate: Predicate
343342
) -> Callable[[UnstructureHookFactory], UnstructureHookFactory]:
344343
...
345344

346345
@overload
347346
def register_unstructure_hook_factory(
348-
self, predicate: Callable[[Any], bool]
347+
self, predicate: Predicate
349348
) -> Callable[[ExtendedUnstructureHookFactory], ExtendedUnstructureHookFactory]:
350349
...
351350

352351
@overload
353352
def register_unstructure_hook_factory(
354-
self, predicate: Callable[[Any], bool], factory: UnstructureHookFactory
353+
self, predicate: Predicate, factory: UnstructureHookFactory
355354
) -> UnstructureHookFactory:
356355
...
357356

358357
@overload
359358
def register_unstructure_hook_factory(
360-
self, predicate: Callable[[Any], bool], factory: ExtendedUnstructureHookFactory
359+
self, predicate: Predicate, factory: ExtendedUnstructureHookFactory
361360
) -> ExtendedUnstructureHookFactory:
362361
...
363362

@@ -473,7 +472,7 @@ def register_structure_hook(
473472
self._structure_func.register_cls_list([(cl, func)])
474473

475474
def register_structure_hook_func(
476-
self, check_func: Callable[[type[T]], bool], func: StructureHook
475+
self, check_func: Predicate, func: StructureHook
477476
) -> None:
478477
"""Register a class-to-primitive converter function for a class, using
479478
a function to check if it's a match.
@@ -482,25 +481,25 @@ def register_structure_hook_func(
482481

483482
@overload
484483
def register_structure_hook_factory(
485-
self, predicate: Callable[[Any, bool]]
484+
self, predicate: Predicate
486485
) -> Callable[[StructureHookFactory, StructureHookFactory]]:
487486
...
488487

489488
@overload
490489
def register_structure_hook_factory(
491-
self, predicate: Callable[[Any, bool]]
490+
self, predicate: Predicate
492491
) -> Callable[[ExtendedStructureHookFactory, ExtendedStructureHookFactory]]:
493492
...
494493

495494
@overload
496495
def register_structure_hook_factory(
497-
self, predicate: Callable[[Any], bool], factory: StructureHookFactory
496+
self, predicate: Predicate, factory: StructureHookFactory
498497
) -> StructureHookFactory:
499498
...
500499

501500
@overload
502501
def register_structure_hook_factory(
503-
self, predicate: Callable[[Any], bool], factory: ExtendedStructureHookFactory
502+
self, predicate: Predicate, factory: ExtendedStructureHookFactory
504503
) -> ExtendedStructureHookFactory:
505504
...
506505

@@ -903,11 +902,6 @@ def _structure_optional(self, obj, union):
903902
# We can't actually have a Union of a Union, so this is safe.
904903
return self._structure_func.dispatch(other)(obj, other)
905904

906-
def _structure_union(self, obj, union):
907-
"""Deal with structuring a union."""
908-
handler = self._union_struct_registry[union]
909-
return handler(obj, union)
910-
911905
def _structure_tuple(self, obj: Any, tup: type[T]) -> T:
912906
"""Deal with structuring into a tuple."""
913907
tup_params = None if tup in (Tuple, tuple) else tup.__args__

src/cattrs/dispatch.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,11 @@
66
from attrs import Factory, define
77

88
from ._compat import TypeAlias
9+
from .fns import Predicate
910

1011
if TYPE_CHECKING:
1112
from .converters import BaseConverter
1213

13-
T = TypeVar("T")
14-
1514
TargetType: TypeAlias = Any
1615
UnstructuredValue: TypeAlias = Any
1716
StructuredValue: TypeAlias = Any
@@ -46,12 +45,12 @@ class FunctionDispatch:
4645

4746
_converter: BaseConverter
4847
_handler_pairs: list[
49-
tuple[Callable[[Any], bool], Callable[[Any, Any], Any], bool, bool]
48+
tuple[Predicate, Callable[[Any, Any], Any], bool, bool]
5049
] = Factory(list)
5150

5251
def register(
5352
self,
54-
predicate: Callable[[Any], bool],
53+
predicate: Predicate,
5554
func: Callable[..., Any],
5655
is_generator=False,
5756
takes_converter=False,
@@ -148,13 +147,9 @@ def register_cls_list(self, cls_and_handler, direct: bool = False) -> None:
148147
def register_func_list(
149148
self,
150149
pred_and_handler: list[
151-
tuple[Callable[[Any], bool], Any]
152-
| tuple[Callable[[Any], bool], Any, bool]
153-
| tuple[
154-
Callable[[Any], bool],
155-
Callable[[Any, BaseConverter], Any],
156-
Literal["extended"],
157-
]
150+
tuple[Predicate, Any]
151+
| tuple[Predicate, Any, bool]
152+
| tuple[Predicate, Callable[[Any, BaseConverter], Any], Literal["extended"]]
158153
],
159154
):
160155
"""

src/cattrs/fns.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
"""Useful internal functions."""
2-
from typing import NoReturn, Type, TypeVar
2+
from typing import Any, Callable, NoReturn, Type, TypeVar
33

4+
from ._compat import TypeAlias
45
from .errors import StructureHandlerNotFoundError
56

67
T = TypeVar("T")
78

9+
Predicate: TypeAlias = Callable[[Any], bool]
10+
"""A predicate function determines if a type can be handled."""
11+
812

913
def identity(obj: T) -> T:
1014
"""The identity function."""

src/cattrs/gen/typeddicts.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -565,11 +565,9 @@ def _required_keys(cls: type) -> set[str]:
565565
# gathering required keys. *sigh*
566566
own_annotations = cls.__dict__.get("__annotations__", {})
567567
required_keys = set()
568-
for base in cls.__mro__[1:]:
569-
if base in (object, dict):
570-
# These have no required keys for sure.
571-
continue
572-
required_keys |= _required_keys(base)
568+
# On 3.8 - 3.10, typing.TypedDict doesn't put typeddict superclasses
569+
# in the MRO, therefore we cannot handle non-required keys properly
570+
# in some situations. Oh well.
573571
for key in getattr(cls, "__required_keys__", []):
574572
annotation_type = own_annotations[key]
575573
annotation_origin = get_origin(annotation_type)
@@ -597,13 +595,7 @@ def _required_keys(cls: type) -> set[str]:
597595

598596
own_annotations = cls.__dict__.get("__annotations__", {})
599597
required_keys = set()
600-
superclass_keys = set()
601-
for base in cls.__mro__[1:]:
602-
required_keys |= _required_keys(base)
603-
superclass_keys |= base.__dict__.get("__annotations__", {}).keys()
604598
for key in own_annotations:
605-
if key in superclass_keys:
606-
continue
607599
annotation_type = own_annotations[key]
608600

609601
if is_annotated(annotation_type):

tests/_compat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import sys
22

33
is_py38 = sys.version_info[:2] == (3, 8)
4+
is_py39 = sys.version_info[:2] == (3, 9)
45
is_py39_plus = sys.version_info >= (3, 9)
6+
is_py310 = sys.version_info[:2] == (3, 10)
57
is_py310_plus = sys.version_info >= (3, 10)
68
is_py311_plus = sys.version_info >= (3, 11)
79
is_py312_plus = sys.version_info >= (3, 12)

tests/test_typeddicts.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from hypothesis import assume, given
88
from hypothesis.strategies import booleans
99
from pytest import raises
10-
from typing_extensions import NotRequired
10+
from typing_extensions import NotRequired, Required
1111

1212
from cattrs import BaseConverter, Converter
1313
from cattrs._compat import ExtensionsTypedDict, get_notrequired_base, is_generic
@@ -24,7 +24,7 @@
2424
make_dict_unstructure_fn,
2525
)
2626

27-
from ._compat import is_py38, is_py311_plus
27+
from ._compat import is_py38, is_py39, is_py310, is_py311_plus
2828
from .typeddicts import (
2929
generic_typeddicts,
3030
simple_typeddicts,
@@ -263,6 +263,28 @@ def test_required(
263263
assert restructured == instance
264264

265265

266+
@pytest.mark.skipif(is_py39 or is_py310, reason="Sigh")
267+
def test_required_keys() -> None:
268+
"""We don't support the full gamut of functionality on 3.8.
269+
270+
When using `typing.TypedDict` we have only partial functionality;
271+
this test tests only a subset of this.
272+
"""
273+
c = mk_converter()
274+
275+
class Base(TypedDict, total=False):
276+
a: Required[datetime]
277+
278+
class Sub(Base):
279+
b: int
280+
281+
fn = make_dict_unstructure_fn(Sub, c)
282+
283+
with raises(KeyError):
284+
# This needs to raise since 'a' is missing, and it's Required.
285+
fn({"b": 1})
286+
287+
266288
@given(simple_typeddicts(min_attrs=1, total=True), booleans())
267289
def test_omit(cls_and_instance: Tuple[type, Dict], detailed_validation: bool) -> None:
268290
"""`override(omit=True)` works."""

tests/typeddicts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ def simple_typeddicts(
180180
note(
181181
"\n".join(
182182
[
183-
"class HypTypedDict(TypedDict):",
183+
f"class HypTypedDict(TypedDict{'' if total else ', total=False'}):",
184184
*[f" {n}: {a}" for n, a in attrs_dict.items()],
185185
]
186186
)

0 commit comments

Comments
 (0)