Skip to content

Commit 43f7d0f

Browse files
authored
Implement a disambiguator for discriminated unions (#392)
* Implement a disambiguator for discriminated unions

  Another way at "tagged unions", except backwards compatible! Also, a ton of services provide data in this fashion. Debatably, cattrs already does something similar, in that you don't need to register a structure hook for a `Literal[1, 2, 3, 4]` or whatever.
* PR feedback
* Stop supporting Python 3.7
* Remove forgotten `is_py37` reference...
* Fix ruff error + one isort error
* Don't add `isort` to CI
* Oops, don't fix ruff and then error out
1 parent c1e93e9 commit 43f7d0f

6 files changed

Lines changed: 191 additions & 23 deletions

File tree

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
([#393](https://github.com/python-attrs/cattrs/issues/393))
4848
- Remove some unused lines in the unstructuring code.
4949
([#416](https://github.com/python-attrs/cattrs/pull/416))
50+
- Disambiguate a union of attrs classes where there's a `typing.Literal` tag of some sort.
51+
([#391](https://github.com/python-attrs/cattrs/pull/391))
5052

5153
## 23.1.2 (2023-06-02)
5254

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ clean-test: ## remove test and coverage artifacts
4747
rm -f .coverage
4848
rm -fr htmlcov/
4949

50-
lint: ## check style with flake8
50+
lint: ## check style with ruff and black
5151
pdm run ruff src/ tests
5252
pdm run black --check src tests docs/conf.py
5353

src/cattrs/converters.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@
5555
is_typeddict,
5656
is_union_type,
5757
)
58-
from .disambiguators import create_uniq_field_dis_func
58+
from .disambiguators import create_default_dis_func
5959
from .dispatch import MultiStrategyDispatch
6060
from .errors import (
6161
IterableValidationError,
@@ -729,13 +729,16 @@ def _get_dis_func(union: Any) -> Callable[[Any], Type]:
729729
e for e in union_types if e is not NoneType # type: ignore
730730
)
731731

732+
# TODO: technically both disambiguators could support TypedDicts and
733+
# dataclasses...
732734
if not all(has(get_origin(e) or e) for e in union_types):
733735
raise StructureHandlerNotFoundError(
734736
"Only unions of attrs classes supported "
735737
"currently. Register a loads hook manually.",
736738
type_=union,
737739
)
738-
return create_uniq_field_dis_func(*union_types)
740+
741+
return create_default_dis_func(*union_types)
739742

740743
def __deepcopy__(self, _) -> "BaseConverter":
741744
return self.copy()

src/cattrs/disambiguators.py

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,77 @@
11
"""Utilities for union (sum type) disambiguation."""
2-
from collections import OrderedDict
2+
from collections import OrderedDict, defaultdict
33
from functools import reduce
44
from operator import or_
5-
from typing import Any, Callable, Dict, Mapping, Optional, Type
5+
from typing import Any, Callable, Dict, Mapping, Optional, Type, Union
66

7-
from attr import NOTHING, fields
7+
from attr import NOTHING, fields, fields_dict
88

9-
from cattrs._compat import get_origin
9+
from cattrs._compat import get_args, get_origin, is_literal
1010

1111

12-
def create_uniq_field_dis_func(
12+
def create_default_dis_func(
1313
*classes: Type[Any],
1414
) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]:
1515
"""Given attr classes, generate a disambiguation function.
1616
17-
The function is based on unique fields."""
17+
The function is based on unique fields or unique values."""
1818
if len(classes) < 2:
1919
raise ValueError("At least two classes required.")
20+
21+
# first, attempt for unique values
22+
23+
# requirements for a discriminator field:
24+
# (... TODO: a single fallback is OK)
25+
# - it must be *required*
26+
# - it must always be enumerated
27+
cls_candidates = [
28+
{
29+
at.name
30+
for at in fields(get_origin(cl) or cl)
31+
if at.default is NOTHING and is_literal(at.type)
32+
}
33+
for cl in classes
34+
]
35+
36+
discriminators = cls_candidates[0]
37+
for possible_discriminators in cls_candidates:
38+
discriminators &= possible_discriminators
39+
40+
best_result = None
41+
best_discriminator = None
42+
for discriminator in discriminators:
43+
mapping = defaultdict(list)
44+
45+
for cl in classes:
46+
for key in get_args(fields_dict(get_origin(cl) or cl)[discriminator].type):
47+
mapping[key].append(cl)
48+
49+
if best_result is None or max(len(v) for v in mapping.values()) <= max(
50+
len(v) for v in best_result.values()
51+
):
52+
best_result = mapping
53+
best_discriminator = discriminator
54+
55+
if (
56+
best_result
57+
and best_discriminator
58+
and max(len(v) for v in best_result.values()) != len(classes)
59+
):
60+
final_mapping = {
61+
k: v[0] if len(v) == 1 else Union[tuple(v)] for k, v in best_result.items()
62+
}
63+
64+
def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
65+
if not isinstance(data, Mapping):
66+
raise ValueError("Only input mappings are supported.")
67+
return final_mapping[data[best_discriminator]]
68+
69+
return dis_func
70+
71+
# next, attempt for unique keys
72+
73+
# NOTE: This could just as well work with just field availability and not
74+
# uniqueness, returning Unions ... it doesn't do that right now.
2075
cls_and_attrs = [
2176
(cl, {at.name for at in fields(get_origin(cl) or cl)}) for cl in classes
2277
]
@@ -57,3 +112,6 @@ def dis_func(data: Mapping[Any, Any]) -> Optional[Type]:
57112
return fallback
58113

59114
return dis_func
115+
116+
117+
create_uniq_field_dis_func = create_default_dis_func

tests/test_disambigutors.py

Lines changed: 93 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
"""Tests for auto-disambiguators."""
2-
from typing import Any
2+
from typing import Any, Literal, Union
33

44
import attr
55
import pytest
6-
from attr import NOTHING
6+
from attrs import NOTHING, define
77
from hypothesis import HealthCheck, assume, given, settings
88

9-
from cattrs.disambiguators import create_uniq_field_dis_func
9+
from cattrs.disambiguators import create_default_dis_func, create_uniq_field_dis_func
1010

1111
from .untyped import simple_classes
1212

@@ -22,6 +22,9 @@ class A:
2222
# Can't generate for only one class.
2323
create_uniq_field_dis_func(A)
2424

25+
with pytest.raises(ValueError):
26+
create_default_dis_func(A)
27+
2528
@attr.s
2629
class B:
2730
pass
@@ -30,6 +33,9 @@ class B:
3033
# No fields on either class.
3134
create_uniq_field_dis_func(A, B)
3235

36+
with pytest.raises(ValueError):
37+
create_default_dis_func(A, B)
38+
3339
@attr.s
3440
class C:
3541
a = attr.ib()
@@ -42,6 +48,10 @@ class D:
4248
# No unique fields on either class.
4349
create_uniq_field_dis_func(C, D)
4450

51+
with pytest.raises(ValueError):
52+
# No discriminator candidates
53+
create_default_dis_func(C, D)
54+
4555
@attr.s
4656
class E:
4757
pass
@@ -54,6 +64,18 @@ class F:
5464
# no usable non-default attributes
5565
create_uniq_field_dis_func(E, F)
5666

67+
@define()
68+
class G:
69+
x: Literal[1]
70+
71+
@define()
72+
class H:
73+
x: Literal[1]
74+
75+
with pytest.raises(ValueError):
76+
# The discriminator chosen does not actually help
77+
create_default_dis_func(G, H)  # G and H both declare `x: Literal[1]`, so the tag cannot separate them
78+
5779

5880
@given(simple_classes(defaults=False))
5981
def test_fallback(cl_and_vals):
@@ -99,3 +121,71 @@ def test_disambiguation(cl_and_vals_a, cl_and_vals_b):
99121
fn = create_uniq_field_dis_func(cl_a, cl_b)
100122

101123
assert fn(attr.asdict(cl_a(*vals_a, **kwargs_a))) is cl_a
124+
125+
126+
# not too sure of properties of `create_default_dis_func`
127+
def test_disambiguate_from_discriminated_enum():
128+
# can it find any discriminator?
129+
@define()
130+
class A:
131+
a: Literal[0]
132+
133+
@define()
134+
class B:
135+
a: Literal[1]
136+
137+
fn = create_default_dis_func(A, B)
138+
assert fn({"a": 0}) is A
139+
assert fn({"a": 1}) is B
140+
141+
# can it find the better discriminator?
142+
@define()
143+
class C:
144+
a: Literal[0]
145+
b: Literal[1]
146+
147+
@define()
148+
class D:
149+
a: Literal[0]
150+
b: Literal[0]
151+
152+
fn = create_default_dis_func(C, D)
153+
assert fn({"a": 0, "b": 1}) is C
154+
assert fn({"a": 0, "b": 0}) is D
155+
156+
# can it handle multiple tiers of discriminators?
157+
# (example inspired by Discord's gateway's discriminated union)
158+
@define()
159+
class E:
160+
op: Literal[1]
161+
162+
@define()
163+
class F:
164+
op: Literal[0]
165+
t: Literal["MESSAGE_CREATE"]
166+
167+
@define()
168+
class G:
169+
op: Literal[0]
170+
t: Literal["MESSAGE_UPDATE"]
171+
172+
fn = create_default_dis_func(E, F, G)
173+
assert fn({"op": 1}) is E
174+
assert fn({"op": 0, "t": "MESSAGE_CREATE"}) is Union[F, G]
175+
176+
# can it handle multiple literals?
177+
@define()
178+
class H:
179+
a: Literal[1]
180+
181+
@define()
182+
class J:
183+
a: Literal[0, 1]
184+
185+
@define()
186+
class K:
187+
a: Literal[0]
188+
189+
fn = create_default_dis_func(H, J, K)
190+
assert fn({"a": 1}) is Union[H, J]
191+
assert fn({"a": 0}) is Union[J, K]

tests/test_structure_attrs.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Loading of attrs classes."""
22
from enum import Enum
33
from ipaddress import IPv4Address, IPv6Address, ip_address
4-
from typing import Union
4+
from typing import Literal, Union
55
from unittest.mock import Mock
66

77
import pytest
@@ -139,8 +139,6 @@ def dis(obj, _):
139139
@pytest.mark.parametrize("converter_cls", [BaseConverter, Converter])
140140
def test_structure_literal(converter_cls):
141141
"""Structuring a class with a literal field works."""
142-
from typing import Literal
143-
144142
converter = converter_cls()
145143

146144
@define
@@ -155,8 +153,6 @@ class ClassWithLiteral:
155153
@pytest.mark.parametrize("converter_cls", [BaseConverter, Converter])
156154
def test_structure_literal_enum(converter_cls):
157155
"""Structuring a class with a literal field works."""
158-
from typing import Literal
159-
160156
converter = converter_cls()
161157

162158
class Foo(Enum):
@@ -175,8 +171,6 @@ class ClassWithLiteral:
175171
@pytest.mark.parametrize("converter_cls", [BaseConverter, Converter])
176172
def test_structure_literal_multiple(converter_cls):
177173
"""Structuring a class with a literal field works."""
178-
from typing import Literal
179-
180174
converter = converter_cls()
181175

182176
class Foo(Enum):
@@ -210,8 +204,6 @@ class ClassWithLiteral:
210204
@pytest.mark.parametrize("converter_cls", [BaseConverter, Converter])
211205
def test_structure_literal_error(converter_cls):
212206
"""Structuring a class with a literal field can raise an error."""
213-
from typing import Literal
214-
215207
converter = converter_cls()
216208

217209
@define
@@ -225,8 +217,6 @@ class ClassWithLiteral:
225217
@pytest.mark.parametrize("converter_cls", [BaseConverter, Converter])
226218
def test_structure_literal_multiple_error(converter_cls):
227219
"""Structuring a class with a literal field can raise an error."""
228-
from typing import Literal
229-
230220
converter = converter_cls()
231221

232222
@define
@@ -302,3 +292,28 @@ def test_structure_prefers_attrib_converters(converter_type):
302292

303293
attrib_converter.assert_any_call(5)
304294
assert inst.z == "5"
295+
296+
297+
@pytest.mark.parametrize("converter_type", [BaseConverter, Converter])
298+
def test_structure_multitier_discriminator_union(converter_type):
299+
converter = converter_type()
300+
301+
@define()
302+
class E:
303+
op: Literal[1]
304+
305+
@define()
306+
class F:
307+
op: Literal[0]
308+
t: Literal["MESSAGE_CREATE"]
309+
310+
@define()
311+
class G:
312+
op: Literal[0]
313+
t: Literal["MESSAGE_UPDATE"]
314+
315+
inst = converter.structure({"op": 1}, Union[E, F, G])
316+
assert isinstance(inst, E)
317+
318+
inst = converter.structure({"op": 0, "t": "MESSAGE_CREATE"}, Union[E, F, G])
319+
assert isinstance(inst, F)

0 commit comments

Comments
 (0)