33
44from base64 import b64decode
55from datetime import date , datetime
6+ from enum import Enum
67from functools import partial
7- from typing import Any , Callable , TypeVar , Union
8+ from typing import Any , Callable , TypeVar , Union , get_type_hints
89
910from attrs import has as attrs_has
1011from attrs import resolve_types
1112from msgspec import Struct , convert , to_builtins
1213from msgspec .json import Encoder , decode
1314
14- from cattrs ._compat import fields , get_origin , has , is_bare , is_mapping , is_sequence
15+ from cattrs ._compat import (
16+ fields ,
17+ get_args ,
18+ get_origin ,
19+ has ,
20+ is_bare ,
21+ is_mapping ,
22+ is_sequence ,
23+ )
1524from cattrs .dispatch import UnstructureHook
1625from cattrs .fns import identity
1726
@@ -61,11 +70,13 @@ def configure_converter(converter: Converter) -> None:
6170
6271 * bytes are serialized as base64 strings, directly by msgspec
6372 * datetimes and dates are passed through to be serialized as RFC 3339 directly
73+ * enums are passed through to msgspec directly
6474 * union passthrough configured for str, bool, int, float and None
6575 """
6676 configure_passthroughs (converter )
6777
6878 converter .register_unstructure_hook (Struct , to_builtins )
79+ converter .register_unstructure_hook (Enum , to_builtins )
6980
7081 converter .register_structure_hook (Struct , convert )
7182 converter .register_structure_hook (bytes , lambda v , _ : b64decode (v ))
@@ -87,45 +98,45 @@ def configure_passthroughs(converter: Converter) -> None:
8798 A passthrough is when we let msgspec handle something automatically.
8899 """
89100 converter .register_unstructure_hook (bytes , to_builtins )
90- converter .register_unstructure_hook_factory (is_mapping )( mapping_unstructure_factory )
91- converter .register_unstructure_hook_factory (is_sequence )( seq_unstructure_factory )
92- converter .register_unstructure_hook_factory (has )( attrs_unstructure_factory )
93- converter .register_unstructure_hook_factory (is_namedtuple )(
94- namedtuple_unstructure_factory
101+ converter .register_unstructure_hook_factory (is_mapping , mapping_unstructure_factory )
102+ converter .register_unstructure_hook_factory (is_sequence , seq_unstructure_factory )
103+ converter .register_unstructure_hook_factory (has , attrs_unstructure_factory )
104+ converter .register_unstructure_hook_factory (
105+ is_namedtuple , namedtuple_unstructure_factory
95106 )
96107
97108
98109def seq_unstructure_factory (type , converter : BaseConverter ) -> UnstructureHook :
110+ """The msgspec unstructure hook factory for sequences."""
99111 if is_bare (type ):
100112 type_arg = Any
101113 handler = converter .get_unstructure_hook (type_arg , cache_result = False )
102- elif getattr (type , "__args__" , None ) not in (None , ()):
103- type_arg = type .__args__ [0 ]
104- handler = converter .get_unstructure_hook (type_arg , cache_result = False )
105114 else :
106- handler = None
115+ args = get_args (type )
116+ type_arg = args [0 ]
117+ handler = converter .get_unstructure_hook (type_arg , cache_result = False )
107118
108119 if handler in (identity , to_builtins ):
109120 return handler
110121 return converter .gen_unstructure_iterable (type )
111122
112123
113124def mapping_unstructure_factory (type , converter : BaseConverter ) -> UnstructureHook :
125+ """The msgspec unstructure hook factory for mappings."""
114126 if is_bare (type ):
115127 key_arg = Any
116128 val_arg = Any
117129 key_handler = converter .get_unstructure_hook (key_arg , cache_result = False )
118130 value_handler = converter .get_unstructure_hook (val_arg , cache_result = False )
119- elif (args := getattr (type , "__args__" , None )) not in (None , ()):
131+ else :
132+ args = get_args (type )
120133 if len (args ) == 2 :
121134 key_arg , val_arg = args
122135 else :
123136 # Probably a Counter
124137 key_arg , val_arg = args , Any
125138 key_handler = converter .get_unstructure_hook (key_arg , cache_result = False )
126139 value_handler = converter .get_unstructure_hook (val_arg , cache_result = False )
127- else :
128- key_handler = value_handler = None
129140
130141 if key_handler in (identity , to_builtins ) and value_handler in (
131142 identity ,
@@ -135,7 +146,7 @@ def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHo
135146 return converter .gen_unstructure_mapping (type )
136147
137148
138- def attrs_unstructure_factory (type : Any , converter : BaseConverter ) -> UnstructureHook :
149+ def attrs_unstructure_factory (type : Any , converter : Converter ) -> UnstructureHook :
139150 """Choose whether to use msgspec handling or our own."""
140151 origin = get_origin (type )
141152 attribs = fields (origin or type )
@@ -163,13 +174,13 @@ def namedtuple_unstructure_factory(
163174
164175 if all (
165176 converter .get_unstructure_hook (t ) in (identity , to_builtins )
166- for t in type . __annotations__ .values ()
177+ for t in get_type_hints ( type ) .values ()
167178 ):
168179 return identity
169180
170181 return make_hetero_tuple_unstructure_fn (
171182 type ,
172183 converter ,
173184 unstructure_to = tuple ,
174- type_args = tuple (type . __annotations__ .values ()),
185+ type_args = tuple (get_type_hints ( type ) .values ()),
175186 )
0 commit comments