|
2 | 2 | from collections import OrderedDict, defaultdict |
3 | 3 | from functools import reduce |
4 | 4 | from operator import or_ |
5 | | -from typing import Any, Callable, Dict, Mapping, Optional, Type, Union |
| 5 | +from typing import Any, Callable, Dict, Mapping, Optional, Set, Type, Union |
6 | 6 |
|
7 | | -from attr import NOTHING, fields, fields_dict |
| 7 | +from attrs import NOTHING, fields, fields_dict |
8 | 8 |
|
9 | | -from cattrs._compat import get_args, get_origin, is_literal |
| 9 | +from ._compat import get_args, get_origin, has, is_literal, is_union_type |
| 10 | + |
| 11 | +__all__ = ("is_supported_union", "create_default_dis_func") |
| 12 | + |
| 13 | +NoneType = type(None) |
| 14 | + |
| 15 | + |
| 16 | +def is_supported_union(typ: Type) -> bool: |
| 17 | + """Whether the type is a union of attrs classes.""" |
| 18 | + return is_union_type(typ) and all( |
| 19 | + e is NoneType or has(get_origin(e) or e) for e in typ.__args__ |
| 20 | + ) |
10 | 21 |
|
11 | 22 |
|
12 | 23 | def create_default_dis_func( |
13 | | - *classes: Type[Any], |
| 24 | + *classes: Type[Any], use_literals: bool = True |
14 | 25 | ) -> Callable[[Mapping[Any, Any]], Optional[Type[Any]]]: |
15 | | - """Given attr classes, generate a disambiguation function. |
| 26 | + """Given attrs classes, generate a disambiguation function. |
| 27 | +
|
| 28 | + The function is based on unique fields or unique values. |
16 | 29 |
|
17 | | - The function is based on unique fields or unique values.""" |
| 30 | + :param use_literals: Whether to try using fields annotated as literals for |
| 31 | + disambiguation. |
| 32 | + """ |
18 | 33 | if len(classes) < 2: |
19 | 34 | raise ValueError("At least two classes required.") |
20 | 35 |
|
21 | 36 | # first, attempt for unique values |
| 37 | + if use_literals: |
| 38 | + # requirements for a discriminator field: |
| 39 | + # (... TODO: a single fallback is OK) |
| 40 | + # - it must always be enumerated |
| 41 | + cls_candidates = [ |
| 42 | + {at.name for at in fields(get_origin(cl) or cl) if is_literal(at.type)} |
| 43 | + for cl in classes |
| 44 | + ] |
| 45 | + |
| 46 | + # literal field names common to all members |
| 47 | + discriminators: Set[str] = cls_candidates[0] |
| 48 | + for possible_discriminators in cls_candidates: |
| 49 | + discriminators &= possible_discriminators |
| 50 | + |
| 51 | + best_result = None |
| 52 | + best_discriminator = None |
| 53 | + for discriminator in discriminators: |
| 54 | + # maps Literal values (strings, ints...) to classes |
| 55 | + mapping = defaultdict(list) |
| 56 | + |
| 57 | + for cl in classes: |
| 58 | + for key in get_args( |
| 59 | + fields_dict(get_origin(cl) or cl)[discriminator].type |
| 60 | + ): |
| 61 | + mapping[key].append(cl) |
| 62 | + |
| 63 | + if best_result is None or max(len(v) for v in mapping.values()) <= max( |
| 64 | + len(v) for v in best_result.values() |
| 65 | + ): |
| 66 | + best_result = mapping |
| 67 | + best_discriminator = discriminator |
| 68 | + |
| 69 | + if ( |
| 70 | + best_result |
| 71 | + and best_discriminator |
| 72 | + and max(len(v) for v in best_result.values()) != len(classes) |
| 73 | + ): |
| 74 | + final_mapping = { |
| 75 | + k: v[0] if len(v) == 1 else Union[tuple(v)] |
| 76 | + for k, v in best_result.items() |
| 77 | + } |
22 | 78 |
|
23 | | - # requirements for a discriminator field: |
24 | | - # (... TODO: a single fallback is OK) |
25 | | - # - it must be *required* |
26 | | - # - it must always be enumerated |
27 | | - cls_candidates = [ |
28 | | - { |
29 | | - at.name |
30 | | - for at in fields(get_origin(cl) or cl) |
31 | | - if at.default is NOTHING and is_literal(at.type) |
32 | | - } |
33 | | - for cl in classes |
34 | | - ] |
35 | | - |
36 | | - discriminators = cls_candidates[0] |
37 | | - for possible_discriminators in cls_candidates: |
38 | | - discriminators &= possible_discriminators |
39 | | - |
40 | | - best_result = None |
41 | | - best_discriminator = None |
42 | | - for discriminator in discriminators: |
43 | | - mapping = defaultdict(list) |
44 | | - |
45 | | - for cl in classes: |
46 | | - for key in get_args(fields_dict(get_origin(cl) or cl)[discriminator].type): |
47 | | - mapping[key].append(cl) |
| 79 | + def dis_func(data: Mapping[Any, Any]) -> Optional[Type]: |
| 80 | + if not isinstance(data, Mapping): |
| 81 | + raise ValueError("Only input mappings are supported.") |
| 82 | + return final_mapping[data[best_discriminator]] |
48 | 83 |
|
49 | | - if best_result is None or max(len(v) for v in mapping.values()) <= max( |
50 | | - len(v) for v in best_result.values() |
51 | | - ): |
52 | | - best_result = mapping |
53 | | - best_discriminator = discriminator |
54 | | - |
55 | | - if ( |
56 | | - best_result |
57 | | - and best_discriminator |
58 | | - and max(len(v) for v in best_result.values()) != len(classes) |
59 | | - ): |
60 | | - final_mapping = { |
61 | | - k: v[0] if len(v) == 1 else Union[tuple(v)] for k, v in best_result.items() |
62 | | - } |
63 | | - |
64 | | - def dis_func(data: Mapping[Any, Any]) -> Optional[Type]: |
65 | | - if not isinstance(data, Mapping): |
66 | | - raise ValueError("Only input mappings are supported.") |
67 | | - return final_mapping[data[best_discriminator]] |
68 | | - |
69 | | - return dis_func |
| 84 | + return dis_func |
70 | 85 |
|
71 | 86 | # next, attempt for unique keys |
72 | 87 |
|
|
0 commit comments