|
12 | 12 | from msgspec.json import Encoder, decode |
13 | 13 |
|
14 | 14 | from cattrs._compat import fields, get_origin, has, is_bare, is_mapping, is_sequence |
15 | | -from cattrs.dispatch import HookFactory, UnstructureHook |
| 15 | +from cattrs.dispatch import UnstructureHook |
16 | 16 | from cattrs.fns import identity |
17 | 17 |
|
18 | | -from ..converters import Converter |
| 18 | +from ..converters import BaseConverter, Converter |
| 19 | +from ..gen import make_hetero_tuple_unstructure_fn |
19 | 20 | from ..strategies import configure_union_passthrough |
| 21 | +from ..tuples import is_namedtuple |
20 | 22 | from . import wrap |
21 | 23 |
|
22 | 24 | T = TypeVar("T") |
@@ -85,86 +87,89 @@ def configure_passthroughs(converter: Converter) -> None: |
85 | 87 | A passthrough is when we let msgspec handle something automatically. |
86 | 88 | """ |
87 | 89 | converter.register_unstructure_hook(bytes, to_builtins) |
88 | | - converter.register_unstructure_hook_factory( |
89 | | - is_mapping, make_unstructure_mapping_factory(converter) |
90 | | - ) |
91 | | - converter.register_unstructure_hook_factory( |
92 | | - is_sequence, make_unstructure_seq_factory(converter) |
93 | | - ) |
94 | | - converter.register_unstructure_hook_factory( |
95 | | - has, make_attrs_unstruct_factory(converter) |
| 90 | + converter.register_unstructure_hook_factory(is_mapping)(mapping_unstructure_factory) |
| 91 | + converter.register_unstructure_hook_factory(is_sequence)(seq_unstructure_factory) |
| 92 | + converter.register_unstructure_hook_factory(has)(attrs_unstructure_factory) |
| 93 | + converter.register_unstructure_hook_factory(is_namedtuple)( |
| 94 | + namedtuple_unstructure_factory |
96 | 95 | ) |
97 | 96 |
|
98 | 97 |
|
99 | | -def make_unstructure_seq_factory(converter: Converter) -> HookFactory[UnstructureHook]: |
100 | | - def unstructure_seq_factory(type) -> UnstructureHook: |
101 | | - if is_bare(type): |
102 | | - type_arg = Any |
103 | | - handler = converter.get_unstructure_hook(type_arg, cache_result=False) |
104 | | - elif getattr(type, "__args__", None) not in (None, ()): |
105 | | - type_arg = type.__args__[0] |
106 | | - handler = converter.get_unstructure_hook(type_arg, cache_result=False) |
107 | | - else: |
108 | | - handler = None |
109 | | - |
110 | | - if handler in (identity, to_builtins): |
111 | | - return handler |
112 | | - return converter.gen_unstructure_iterable(type) |
113 | | - |
114 | | - return unstructure_seq_factory |
115 | | - |
116 | | - |
117 | | -def make_unstructure_mapping_factory( |
118 | | - converter: Converter, |
119 | | -) -> HookFactory[UnstructureHook]: |
120 | | - def unstructure_mapping_factory(type) -> UnstructureHook: |
121 | | - if is_bare(type): |
122 | | - key_arg = Any |
123 | | - val_arg = Any |
124 | | - key_handler = converter.get_unstructure_hook(key_arg, cache_result=False) |
125 | | - value_handler = converter.get_unstructure_hook(val_arg, cache_result=False) |
126 | | - elif (args := getattr(type, "__args__", None)) not in (None, ()): |
127 | | - if len(args) == 2: |
128 | | - key_arg, val_arg = args |
129 | | - else: |
130 | | - # Probably a Counter |
131 | | - key_arg, val_arg = args, Any |
132 | | - key_handler = converter.get_unstructure_hook(key_arg, cache_result=False) |
133 | | - value_handler = converter.get_unstructure_hook(val_arg, cache_result=False) |
| 98 | +def seq_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook: |
| 99 | + if is_bare(type): |
| 100 | + type_arg = Any |
| 101 | + handler = converter.get_unstructure_hook(type_arg, cache_result=False) |
| 102 | + elif getattr(type, "__args__", None) not in (None, ()): |
| 103 | + type_arg = type.__args__[0] |
| 104 | + handler = converter.get_unstructure_hook(type_arg, cache_result=False) |
| 105 | + else: |
| 106 | + handler = None |
| 107 | + |
| 108 | + if handler in (identity, to_builtins): |
| 109 | + return handler |
| 110 | + return converter.gen_unstructure_iterable(type) |
| 111 | + |
| 112 | + |
| 113 | +def mapping_unstructure_factory(type, converter: BaseConverter) -> UnstructureHook: |
| 114 | + if is_bare(type): |
| 115 | + key_arg = Any |
| 116 | + val_arg = Any |
| 117 | + key_handler = converter.get_unstructure_hook(key_arg, cache_result=False) |
| 118 | + value_handler = converter.get_unstructure_hook(val_arg, cache_result=False) |
| 119 | + elif (args := getattr(type, "__args__", None)) not in (None, ()): |
| 120 | + if len(args) == 2: |
| 121 | + key_arg, val_arg = args |
134 | 122 | else: |
135 | | - key_handler = value_handler = None |
| 123 | + # Probably a Counter |
| 124 | + key_arg, val_arg = args, Any |
| 125 | + key_handler = converter.get_unstructure_hook(key_arg, cache_result=False) |
| 126 | + value_handler = converter.get_unstructure_hook(val_arg, cache_result=False) |
| 127 | + else: |
| 128 | + key_handler = value_handler = None |
| 129 | + |
| 130 | + if key_handler in (identity, to_builtins) and value_handler in ( |
| 131 | + identity, |
| 132 | + to_builtins, |
| 133 | + ): |
| 134 | + return to_builtins |
| 135 | + return converter.gen_unstructure_mapping(type) |
| 136 | + |
136 | 137 |
|
137 | | - if key_handler in (identity, to_builtins) and value_handler in ( |
138 | | - identity, |
139 | | - to_builtins, |
140 | | - ): |
141 | | - return to_builtins |
142 | | - return converter.gen_unstructure_mapping(type) |
| 138 | +def attrs_unstructure_factory(type: Any, converter: BaseConverter) -> UnstructureHook: |
| 139 | + """Choose whether to use msgspec handling or our own.""" |
| 140 | + origin = get_origin(type) |
| 141 | + attribs = fields(origin or type) |
| 142 | + if attrs_has(type) and any(isinstance(a.type, str) for a in attribs): |
| 143 | + resolve_types(type) |
| 144 | + attribs = fields(origin or type) |
143 | 145 |
|
144 | | - return unstructure_mapping_factory |
| 146 | + if any( |
| 147 | + attr.name.startswith("_") |
| 148 | + or ( |
| 149 | + converter.get_unstructure_hook(attr.type, cache_result=False) |
| 150 | + not in (identity, to_builtins) |
| 151 | + ) |
| 152 | + for attr in attribs |
| 153 | + ): |
| 154 | + return converter.gen_unstructure_attrs_fromdict(type) |
145 | 155 |
|
| 156 | + return to_builtins |
146 | 157 |
|
147 | | -def make_attrs_unstruct_factory(converter: Converter) -> HookFactory[UnstructureHook]: |
148 | | - """Short-circuit attrs and dataclass handling if it matches msgspec.""" |
149 | 158 |
|
150 | | - def attrs_factory(type: Any) -> UnstructureHook: |
151 | | - """Choose whether to use msgspec handling or our own.""" |
152 | | - origin = get_origin(type) |
153 | | - attribs = fields(origin or type) |
154 | | - if attrs_has(type) and any(isinstance(a.type, str) for a in attribs): |
155 | | - resolve_types(type) |
156 | | - attribs = fields(origin or type) |
157 | | - |
158 | | - if any( |
159 | | - attr.name.startswith("_") |
160 | | - or ( |
161 | | - converter.get_unstructure_hook(attr.type, cache_result=False) |
162 | | - not in (identity, to_builtins) |
163 | | - ) |
164 | | - for attr in attribs |
165 | | - ): |
166 | | - return converter.gen_unstructure_attrs_fromdict(type) |
| 159 | +def namedtuple_unstructure_factory( |
| 160 | + type: type[tuple], converter: BaseConverter |
| 161 | +) -> UnstructureHook: |
| 162 | + """A hook factory for unstructuring namedtuples, modified for msgspec.""" |
167 | 163 |
|
168 | | - return to_builtins |
| 164 | + if all( |
| 165 | + converter.get_unstructure_hook(t) in (identity, to_builtins) |
| 166 | + for t in type.__annotations__.values() |
| 167 | + ): |
| 168 | + return identity |
169 | 169 |
|
170 | | - return attrs_factory |
| 170 | + return make_hetero_tuple_unstructure_fn( |
| 171 | + type, |
| 172 | + converter, |
| 173 | + unstructure_to=tuple, |
| 174 | + type_args=tuple(type.__annotations__.values()), |
| 175 | + ) |
0 commit comments