|
6 | 6 | from typing import TYPE_CHECKING |
7 | 7 |
|
8 | 8 | from griffe._internal.exceptions import AliasResolutionError, CyclicAliasError |
| 9 | +from griffe._internal.expressions import Expr |
9 | 10 | from griffe._internal.logger import logger |
10 | 11 |
|
11 | 12 | if TYPE_CHECKING: |
| 13 | + from collections.abc import Sequence |
| 14 | + |
12 | 15 | from griffe._internal.models import Attribute, Class, Function, Module, Object, TypeAlias |
13 | 16 |
|
14 | 17 |
|
@@ -60,10 +63,43 @@ def _merge_stubs_overloads(obj: Module | Class, stubs: Module | Class) -> None: |
60 | 63 | for function_name, overloads in list(stubs.overloads.items()): |
61 | 64 | if overloads: |
62 | 65 | with suppress(KeyError): |
63 | | - obj.get_member(function_name).overloads = overloads |
| 66 | + _merge_overload_annotations(obj.get_member(function_name), overloads) |
64 | 67 | del stubs.overloads[function_name] |
65 | 68 |
|
66 | 69 |
|
| 70 | +def _merge_annotations(annotations: Sequence[Expr]) -> Expr | None: |
| 71 | + if len(annotations) == 1: |
| 72 | + return annotations[0] |
| 73 | + if annotations: |
| 74 | + return Expr._to_binop(annotations, op="|") |
| 75 | + return None |
| 76 | + |
| 77 | + |
| 78 | +def _merge_overload_annotations(function: Function, overloads: list[Function]) -> None: |
| 79 | + function.overloads = overloads |
| 80 | + for parameter in function.parameters: |
| 81 | + if parameter.annotation is None: |
| 82 | + seen = set() |
| 83 | + annotations = [] |
| 84 | + for overload in overloads: |
| 85 | + with suppress(KeyError): |
| 86 | + annotation = overload.parameters[parameter.name].annotation |
| 87 | + str_annotation = str(annotation) |
| 88 | + if isinstance(annotation, Expr) and str_annotation not in seen: |
| 89 | + annotations.append(annotation) |
| 90 | + seen.add(str_annotation) |
| 91 | + parameter.annotation = _merge_annotations(annotations) |
| 92 | + if function.returns is None: |
| 93 | + seen = set() |
| 94 | + return_annotations = [] |
| 95 | + for overload in overloads: |
| 96 | + str_annotation = str(overload.returns) |
| 97 | + if isinstance(overload.returns, Expr) and str_annotation not in seen: |
| 98 | + return_annotations.append(overload.returns) |
| 99 | + seen.add(str_annotation) |
| 100 | + function.returns = _merge_annotations(return_annotations) |
| 101 | + |
| 102 | + |
67 | 103 | def _merge_stubs_members(obj: Module | Class, stubs: Module | Class) -> None: |
68 | 104 | # Merge imports to later know if objects coming from the stubs were imported. |
69 | 105 | obj.imports.update(stubs.imports) |
|
0 commit comments