Skip to content

Commit 66fae25

Browse files
authored
Improve msgspec coverage (#501)
* Improve msgspec coverage * Test for enums and literals
1 parent 3c4572f commit 66fae25

2 files changed

Lines changed: 51 additions & 18 deletions

File tree

src/cattrs/preconf/msgspec.py

Lines changed: 28 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@
33

44
from base64 import b64decode
55
from datetime import date, datetime
6+
from enum import Enum
67
from functools import partial
7-
from typing import Any, Callable, TypeVar, Union
8+
from typing import Any, Callable, TypeVar, Union, get_type_hints
89

910
from attrs import has as attrs_has
1011
from attrs import resolve_types
1112
from msgspec import Struct, convert, to_builtins
1213
from msgspec.json import Encoder, decode
1314

14-
from cattrs._compat import fields, get_origin, has, is_bare, is_mapping, is_sequence
15+
from cattrs._compat import (
16+
fields,
17+
get_args,
18+
get_origin,
19+
has,
20+
is_bare,
21+
is_mapping,
22+
is_sequence,
23+
)
1524
from cattrs.dispatch import UnstructureHook
1625
from cattrs.fns import identity
1726

@@ -61,11 +70,13 @@ def configure_converter(converter: Converter) -> None:
6170
6271
* bytes are serialized as base64 strings, directly by msgspec
6372
* datetimes and dates are passed through to be serialized as RFC 3339 directly
73+
* enums are passed through to msgspec directly
6474
* union passthrough configured for str, bool, int, float and None
6575
"""
6676
configure_passthroughs(converter)
6777

6878
converter.register_unstructure_hook(Struct, to_builtins)
79+
converter.register_unstructure_hook(Enum, to_builtins)
6980

7081
converter.register_structure_hook(Struct, convert)
7182
converter.register_structure_hook(bytes, lambda v, _: b64decode(v))
@@ -87,45 +98,45 @@ def configure_passthroughs(converter: Converter) -> None:
8798
A passthrough is when we let msgspec handle something automatically.
8899
"""
89100
converter.register_unstructure_hook(bytes, to_builtins)
90-
converter.register_unstructure_hook_factory(is_mapping)(mapping_unstructure_factory)
91-
converter.register_unstructure_hook_factory(is_sequence)(seq_unstructure_factory)
92-
converter.register_unstructure_hook_factory(has)(attrs_unstructure_factory)
93-
converter.register_unstructure_hook_factory(is_namedtuple)(
94-
namedtuple_unstructure_factory
101+
converter.register_unstructure_hook_factory(is_mapping, mapping_unstructure_factory)
102+
converter.register_unstructure_hook_factory(is_sequence, seq_unstructure_factory)
103+
converter.register_unstructure_hook_factory(has, attrs_unstructure_factory)
104+
converter.register_unstructure_hook_factory(
105+
is_namedtuple, namedtuple_unstructure_factory
95106
)
96107

97108

98109
def seq_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook:
110+
"""The msgspec unstructure hook factory for sequences."""
99111
if is_bare(type):
100112
type_arg = Any
101113
handler = converter.get_unstructure_hook(type_arg, cache_result=False)
102-
elif getattr(type, "__args__", None) not in (None, ()):
103-
type_arg = type.__args__[0]
104-
handler = converter.get_unstructure_hook(type_arg, cache_result=False)
105114
else:
106-
handler = None
115+
args = get_args(type)
116+
type_arg = args[0]
117+
handler = converter.get_unstructure_hook(type_arg, cache_result=False)
107118

108119
if handler in (identity, to_builtins):
109120
return handler
110121
return converter.gen_unstructure_iterable(type)
111122

112123

113124
def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook:
125+
"""The msgspec unstructure hook factory for mappings."""
114126
if is_bare(type):
115127
key_arg = Any
116128
val_arg = Any
117129
key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
118130
value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)
119-
elif (args := getattr(type, "__args__", None)) not in (None, ()):
131+
else:
132+
args = get_args(type)
120133
if len(args) == 2:
121134
key_arg, val_arg = args
122135
else:
123136
# Probably a Counter
124137
key_arg, val_arg = args, Any
125138
key_handler = converter.get_unstructure_hook(key_arg, cache_result=False)
126139
value_handler = converter.get_unstructure_hook(val_arg, cache_result=False)
127-
else:
128-
key_handler = value_handler = None
129140

130141
if key_handler in (identity, to_builtins) and value_handler in (
131142
identity,
@@ -135,7 +146,7 @@ def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHo
135146
return converter.gen_unstructure_mapping(type)
136147

137148

138-
def attrs_unstructure_factory(type: Any, converter: BaseConverter) -> UnstructureHook:
149+
def attrs_unstructure_factory(type: Any, converter: Converter) -> UnstructureHook:
139150
"""Choose whether to use msgspec handling or our own."""
140151
origin = get_origin(type)
141152
attribs = fields(origin or type)
@@ -163,13 +174,13 @@ def namedtuple_unstructure_factory(
163174

164175
if all(
165176
converter.get_unstructure_hook(t) in (identity, to_builtins)
166-
for t in type.__annotations__.values()
177+
for t in get_type_hints(type).values()
167178
):
168179
return identity
169180

170181
return make_hetero_tuple_unstructure_fn(
171182
type,
172183
converter,
173184
unstructure_to=tuple,
174-
type_args=tuple(type.__annotations__.values()),
185+
type_args=tuple(get_type_hints(type).values()),
175186
)

tests/preconf/test_msgspec_cpython.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
"""Tests for msgspec functionality."""
2+
from __future__ import annotations
3+
4+
from enum import Enum
25
from typing import (
36
Any,
47
Callable,
58
Dict,
69
List,
10+
Literal,
711
Mapping,
812
MutableMapping,
913
MutableSequence,
@@ -59,6 +63,10 @@ class NC(NamedTuple):
5963
a: C
6064

6165

66+
class E(Enum):
67+
TEST = "test"
68+
69+
6270
@fixture
6371
def converter() -> Conv:
6472
return make_converter()
@@ -75,10 +83,17 @@ def test_unstructure_passthrough(converter: Conv):
7583
assert converter.get_unstructure_hook(str) == identity
7684
assert is_passthrough(converter.get_unstructure_hook(bytes))
7785
assert converter.get_unstructure_hook(None) == identity
86+
assert is_passthrough(converter.get_unstructure_hook(Literal[1]))
87+
assert is_passthrough(converter.get_unstructure_hook(E))
7888

7989
# Any is special-cased, and we cannot know if it'll match
8090
# the msgspec behavior.
8191
assert not is_passthrough(converter.get_unstructure_hook(List))
92+
assert not is_passthrough(converter.get_unstructure_hook(Sequence))
93+
assert not is_passthrough(converter.get_unstructure_hook(MutableSequence))
94+
assert not is_passthrough(converter.get_unstructure_hook(List[Any]))
95+
assert not is_passthrough(converter.get_unstructure_hook(Sequence))
96+
assert not is_passthrough(converter.get_unstructure_hook(MutableSequence))
8297

8398
assert is_passthrough(converter.get_unstructure_hook(List[int]))
8499
assert is_passthrough(converter.get_unstructure_hook(Sequence[int]))
@@ -101,9 +116,13 @@ def test_unstructure_pt_mappings(converter: Conv):
101116
assert is_passthrough(converter.get_unstructure_hook(Dict[str, str]))
102117
assert is_passthrough(converter.get_unstructure_hook(Dict[int, int]))
103118

104-
assert is_passthrough(converter.get_unstructure_hook(Dict[int, A]))
119+
assert not is_passthrough(converter.get_unstructure_hook(Dict))
120+
assert not is_passthrough(converter.get_unstructure_hook(dict))
105121
assert not is_passthrough(converter.get_unstructure_hook(Dict[int, B]))
122+
assert not is_passthrough(converter.get_unstructure_hook(Mapping))
123+
assert not is_passthrough(converter.get_unstructure_hook(MutableMapping))
106124

125+
assert is_passthrough(converter.get_unstructure_hook(Dict[int, A]))
107126
assert is_passthrough(converter.get_unstructure_hook(Mapping[int, int]))
108127
assert is_passthrough(converter.get_unstructure_hook(MutableMapping[int, int]))
109128

@@ -113,6 +132,9 @@ def test_dump_hook(converter: Conv):
113132
assert converter.get_dumps_hook(A) == converter.encoder.encode
114133
assert converter.get_dumps_hook(Dict[str, str]) == converter.encoder.encode
115134

135+
# msgspec cannot handle these, so cattrs does.
136+
assert converter.get_dumps_hook(B) == converter.dumps
137+
116138

117139
def test_get_loads_hook(converter: Conv):
118140
"""`Converter.get_loads_hook` works."""

0 commit comments

Comments
 (0)