Skip to content

Commit 9789a56

Browse files
authored
Disambiguate dataclasses too (#477)
1 parent 68081f4 commit 9789a56

5 files changed

Lines changed: 85 additions & 25 deletions

File tree

HISTORY.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@ Our backwards-compatibility policy can be found [here](https://github.com/python
1919
([#432](https://github.com/python-attrs/cattrs/issues/432) [#472](https://github.com/python-attrs/cattrs/pull/472))
2020
- The default union handler now properly takes renamed fields into account.
2121
([#472](https://github.com/python-attrs/cattrs/pull/472))
22+
- The default union handler now also handles dataclasses.
23+
([#](https://github.com/python-attrs/cattrs/pull/))
2224
- Add support for [PEP 695](https://peps.python.org/pep-0695/) type aliases.
2325
([#452](https://github.com/python-attrs/cattrs/pull/452))
2426
- The `include_subclasses` strategy now fetches the member hooks from the converter (making use of converter defaults) if overrides are not provided, instead of generating new hooks with no overrides.

docs/unions.md

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
# Handling Unions
22

3-
_cattrs_ is able to handle simple unions of _attrs_ classes [automatically](#default-union-strategy).
3+
_cattrs_ is able to handle simple unions of _attrs_ classes and dataclasses [automatically](#default-union-strategy).
44
More complex cases require converter customization (since there are many ways of handling unions).
55

6-
_cattrs_ also comes with a number of strategies to help handle unions:
6+
_cattrs_ also comes with a number of optional strategies to help handle unions:
77

88
- [tagged unions strategy](strategies.md#tagged-unions-strategy) mentioned below
99
- [union passthrough strategy](strategies.md#union-passthrough), which is preapplied to all the [preconfigured](preconf.md) converters
@@ -12,10 +12,10 @@ _cattrs_ also comes with a number of strategies to help handle unions:
1212

1313
For convenience, _cattrs_ includes a default union structuring strategy which is a little more opinionated.
1414

15-
Given a union of several _attrs_ classes, the default union strategy will attempt to handle it in several ways.
15+
Given a union of several _attrs_ classes and/or dataclasses, the default union strategy will attempt to handle it in several ways.
1616

1717
First, it will look for `Literal` fields.
18-
If all members of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.
18+
If _all members_ of the union contain a literal field, _cattrs_ will generate a disambiguation function based on the field.
1919

2020
```python
2121
from typing import Literal
@@ -68,6 +68,10 @@ The field `field_with_default` will not be considered since it has a default val
6868
Literals can now be potentially used to disambiguate.
6969
```
7070

71+
```{versionchanged} 24.1.0
72+
Dataclasses are now supported in addition to _attrs_ classes.
73+
```
74+
7175
## Unstructuring Unions with Extra Metadata
7276

7377
```{note}

src/cattrs/_compat.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import deque
33
from collections.abc import MutableSet as AbcMutableSet
44
from collections.abc import Set as AbcSet
5-
from dataclasses import MISSING, is_dataclass
5+
from dataclasses import MISSING, Field, is_dataclass
66
from dataclasses import fields as dataclass_fields
77
from typing import AbstractSet as TypingAbstractSet
88
from typing import (
@@ -18,6 +18,7 @@
1818
Protocol,
1919
Tuple,
2020
Type,
21+
Union,
2122
get_args,
2223
get_origin,
2324
get_type_hints,
@@ -31,9 +32,11 @@
3132

3233
from attrs import NOTHING, Attribute, Factory, resolve_types
3334
from attrs import fields as attrs_fields
35+
from attrs import fields_dict as attrs_fields_dict
3436

3537
__all__ = [
3638
"adapted_fields",
39+
"fields_dict",
3740
"ExceptionGroup",
3841
"ExtensionsTypedDict",
3942
"get_type_alias_base",
@@ -119,6 +122,13 @@ def fields(type):
119122
raise Exception("Not an attrs or dataclass class.") from None
120123

121124

125+
def fields_dict(type) -> Dict[str, Union[Attribute, Field]]:
126+
"""Return the fields_dict for attrs and dataclasses."""
127+
if is_dataclass(type):
128+
return {f.name: f for f in dataclass_fields(type)}
129+
return attrs_fields_dict(type)
130+
131+
122132
def adapted_fields(cl) -> List[Attribute]:
123133
"""Return the attrs format of `fields()` for attrs and dataclasses."""
124134
if is_dataclass(cl):

src/cattrs/disambiguators.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,23 @@
22
from __future__ import annotations
33

44
from collections import defaultdict
5+
from dataclasses import MISSING
56
from functools import reduce
67
from operator import or_
78
from typing import TYPE_CHECKING, Any, Callable, Literal, Mapping, Union
89

9-
from attrs import NOTHING, Attribute, AttrsInstance, fields, fields_dict
10-
11-
from ._compat import NoneType, get_args, get_origin, has, is_literal, is_union_type
10+
from attrs import NOTHING, Attribute, AttrsInstance
11+
12+
from ._compat import (
13+
NoneType,
14+
adapted_fields,
15+
fields_dict,
16+
get_args,
17+
get_origin,
18+
has,
19+
is_literal,
20+
is_union_type,
21+
)
1222
from .gen import AttributeOverride
1323

1424
if TYPE_CHECKING:
@@ -31,13 +41,16 @@ def create_default_dis_func(
3141
overrides: dict[str, AttributeOverride]
3242
| Literal["from_converter"] = "from_converter",
3343
) -> Callable[[Mapping[Any, Any]], type[Any] | None]:
34-
"""Given attrs classes, generate a disambiguation function.
44+
"""Given attrs classes or dataclasses, generate a disambiguation function.
3545
3646
The function is based on unique fields without defaults or unique values.
3747
3848
:param use_literals: Whether to try using fields annotated as literals for
3949
disambiguation.
4050
:param overrides: Attribute overrides to apply.
51+
52+
.. versionchanged:: 24.1.0
53+
Dataclasses are now supported.
4154
"""
4255
if len(classes) < 2:
4356
raise ValueError("At least two classes required.")
@@ -55,7 +68,11 @@ def create_default_dis_func(
5568
# (... TODO: a single fallback is OK)
5669
# - it must always be enumerated
5770
cls_candidates = [
58-
{at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)}
71+
{
72+
at.name
73+
for at in adapted_fields(get_origin(cl) or cl)
74+
if is_literal(at.type)
75+
}
5976
for cl in classes
6077
]
6178

@@ -128,10 +145,10 @@ def dis_func(data: Mapping[Any, Any]) -> type | None:
128145
uniq = cl_reqs - other_reqs
129146

130147
# We want a unique attribute with no default.
131-
cl_fields = fields(get_origin(cl) or cl)
148+
cl_fields = fields_dict(get_origin(cl) or cl)
132149
for maybe_renamed_attr_name in uniq:
133150
orig_name = back_map[maybe_renamed_attr_name]
134-
if getattr(cl_fields, orig_name).default is NOTHING:
151+
if cl_fields[orig_name].default in (NOTHING, MISSING):
135152
break
136153
else:
137154
if fallback is None:
@@ -173,13 +190,13 @@ def _overriden_name(at: Attribute, override: AttributeOverride | None) -> str:
173190

174191

175192
def _usable_attribute_names(
176-
cl: type[AttrsInstance], overrides: dict[str, AttributeOverride]
193+
cl: type[Any], overrides: dict[str, AttributeOverride]
177194
) -> tuple[set[str], dict[str, str]]:
178195
"""Return renamed fields and a mapping to original field names."""
179196
res = set()
180197
mapping = {}
181198

182-
for at in fields(get_origin(cl) or cl):
199+
for at in adapted_fields(get_origin(cl) or cl):
183200
res.add(n := _overriden_name(at, overrides.get(at.name)))
184201
mapping[n] = at.name
185202

tests/test_disambiguators.py

Lines changed: 38 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
"""Tests for auto-disambiguators."""
2+
from dataclasses import dataclass
23
from functools import partial
34
from typing import Literal, Union
45

@@ -7,11 +8,7 @@
78
from hypothesis import HealthCheck, assume, given, settings
89

910
from cattrs import Converter
10-
from cattrs.disambiguators import (
11-
create_default_dis_func,
12-
create_uniq_field_dis_func,
13-
is_supported_union,
14-
)
11+
from cattrs.disambiguators import create_default_dis_func, is_supported_union
1512
from cattrs.gen import make_dict_structure_fn, override
1613

1714
from .untyped import simple_classes
@@ -27,7 +24,7 @@ class A:
2724

2825
with pytest.raises(ValueError):
2926
# Can't generate for only one class.
30-
create_uniq_field_dis_func(c, A)
27+
create_default_dis_func(c, A)
3128

3229
with pytest.raises(ValueError):
3330
create_default_dis_func(c, A)
@@ -38,7 +35,7 @@ class B:
3835

3936
with pytest.raises(TypeError):
4037
# No fields on either class.
41-
create_uniq_field_dis_func(c, A, B)
38+
create_default_dis_func(c, A, B)
4239

4340
@define
4441
class C:
@@ -50,7 +47,7 @@ class D:
5047

5148
with pytest.raises(TypeError):
5249
# No unique fields on either class.
53-
create_uniq_field_dis_func(c, C, D)
50+
create_default_dis_func(c, C, D)
5451

5552
with pytest.raises(TypeError):
5653
# No discriminator candidates
@@ -66,7 +63,7 @@ class F:
6663

6764
with pytest.raises(TypeError):
6865
# no usable non-default attributes
69-
create_uniq_field_dis_func(c, E, F)
66+
create_default_dis_func(c, E, F)
7067

7168
@define
7269
class G:
@@ -93,7 +90,7 @@ def test_fallback(cl_and_vals):
9390
class A:
9491
pass
9592

96-
fn = create_uniq_field_dis_func(c, A, cl)
93+
fn = create_default_dis_func(c, A, cl)
9794

9895
assert fn({}) is A
9996
assert fn(asdict(cl(*vals, **kwargs))) is cl
@@ -124,7 +121,7 @@ def test_disambiguation(cl_and_vals_a, cl_and_vals_b):
124121
for attr_name in req_b - req_a:
125122
assume(getattr(fields(cl_b), attr_name).default is NOTHING)
126123

127-
fn = create_uniq_field_dis_func(c, cl_a, cl_b)
124+
fn = create_default_dis_func(c, cl_a, cl_b)
128125

129126
assert fn(asdict(cl_a(*vals_a, **kwargs_a))) is cl_a
130127

@@ -271,3 +268,33 @@ class B:
271268

272269
assert converter.structure({"a": 1}, Union[A, B]) == A(1)
273270
assert converter.structure({"b": 1}, Union[A, B]) == B(1)
271+
272+
273+
def test_dataclasses(converter):
274+
"""The default strategy works for dataclasses too."""
275+
276+
@define
277+
class A:
278+
a: int
279+
280+
@dataclass
281+
class B:
282+
b: int
283+
284+
assert converter.structure({"a": 1}, Union[A, B]) == A(1)
285+
assert converter.structure({"b": 1}, Union[A, B]) == B(1)
286+
287+
288+
def test_dataclasses_literals(converter):
289+
"""The default strategy works for dataclasses too."""
290+
291+
@define
292+
class A:
293+
a: Literal["a"] = "a"
294+
295+
@dataclass
296+
class B:
297+
b: Literal["b"]
298+
299+
assert converter.structure({"a": "a"}, Union[A, B]) == A()
300+
assert converter.structure({"b": "b"}, Union[A, B]) == B("b")

0 commit comments

Comments
 (0)