Skip to content

Commit cd6f816

Browse files
committed
feat: Support merging overload annotations into implementation
Issue-442: #442
1 parent 59266a9 commit cd6f816

File tree

3 files changed

+75
-17
lines changed

3 files changed

+75
-17
lines changed

packages/griffelib/src/griffe/_internal/expressions.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,22 @@ def is_generator(self) -> bool:
251251
"""Whether this expression is a generator."""
252252
return isinstance(self, ExprSubscript) and self.canonical_name == "Generator"
253253

254+
@staticmethod
255+
def _to_binop(elements: Sequence[Expr], op: str) -> ExprBinOp:
256+
if len(elements) == 2: # noqa: PLR2004
257+
left, right = elements
258+
if isinstance(left, Expr):
259+
left = left.modernize()
260+
if isinstance(right, Expr):
261+
right = right.modernize()
262+
return ExprBinOp(left=left, operator=op, right=right)
263+
264+
left = ExprSubscript._to_binop(elements[:-1], op=op)
265+
right = elements[-1]
266+
if isinstance(right, Expr):
267+
right = right.modernize()
268+
return ExprBinOp(left=left, operator=op, right=right)
269+
254270

255271
@dataclass(eq=True, slots=True)
256272
class ExprAttribute(Expr):
@@ -888,22 +904,6 @@ def iterate(self, *, flat: bool = True) -> Iterator[str | Expr]:
888904
yield from _yield(self.slice, flat=flat, outer_precedence=_OperatorPrecedence.NONE)
889905
yield "]"
890906

891-
@staticmethod
892-
def _to_binop(elements: Sequence[Expr], op: str) -> ExprBinOp:
893-
if len(elements) == 2: # noqa: PLR2004
894-
left, right = elements
895-
if isinstance(left, Expr):
896-
left = left.modernize()
897-
if isinstance(right, Expr):
898-
right = right.modernize()
899-
return ExprBinOp(left=left, operator=op, right=right)
900-
901-
left = ExprSubscript._to_binop(elements[:-1], op=op)
902-
right = elements[-1]
903-
if isinstance(right, Expr):
904-
right = right.modernize()
905-
return ExprBinOp(left=left, operator=op, right=right)
906-
907907
def modernize(self) -> ExprBinOp | ExprSubscript:
908908
if self.canonical_path == "typing.Union":
909909
return self._to_binop(self.slice.elements, op="|") # type: ignore[union-attr]

packages/griffelib/src/griffe/_internal/merger.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
from typing import TYPE_CHECKING
77

88
from griffe._internal.exceptions import AliasResolutionError, CyclicAliasError
9+
from griffe._internal.expressions import Expr
910
from griffe._internal.logger import logger
1011

1112
if TYPE_CHECKING:
13+
from collections.abc import Sequence
14+
1215
from griffe._internal.models import Attribute, Class, Function, Module, Object, TypeAlias
1316

1417

@@ -60,10 +63,43 @@ def _merge_stubs_overloads(obj: Module | Class, stubs: Module | Class) -> None:
6063
for function_name, overloads in list(stubs.overloads.items()):
6164
if overloads:
6265
with suppress(KeyError):
63-
obj.get_member(function_name).overloads = overloads
66+
_merge_overload_annotations(obj.get_member(function_name), overloads)
6467
del stubs.overloads[function_name]
6568

6669

70+
def _merge_annotations(annotations: Sequence[Expr]) -> Expr | None:
71+
if len(annotations) == 1:
72+
return annotations[0]
73+
if annotations:
74+
return Expr._to_binop(annotations, op="|")
75+
return None
76+
77+
78+
def _merge_overload_annotations(function: Function, overloads: list[Function]) -> None:
79+
function.overloads = overloads
80+
for parameter in function.parameters:
81+
if parameter.annotation is None:
82+
seen = set()
83+
annotations = []
84+
for overload in overloads:
85+
with suppress(KeyError):
86+
annotation = overload.parameters[parameter.name].annotation
87+
str_annotation = str(annotation)
88+
if isinstance(annotation, Expr) and str_annotation not in seen:
89+
annotations.append(annotation)
90+
seen.add(str_annotation)
91+
parameter.annotation = _merge_annotations(annotations)
92+
if function.returns is None:
93+
seen = set()
94+
return_annotations = []
95+
for overload in overloads:
96+
str_annotation = str(overload.returns)
97+
if isinstance(overload.returns, Expr) and str_annotation not in seen:
98+
return_annotations.append(overload.returns)
99+
seen.add(str_annotation)
100+
function.returns = _merge_annotations(return_annotations)
101+
102+
67103
def _merge_stubs_members(obj: Module | Class, stubs: Module | Class) -> None:
68104
# Merge imports to later know if objects coming from the stubs were imported.
69105
obj.imports.update(stubs.imports)

tests/test_merger.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,3 +73,25 @@ def test_merge_attribute_values() -> None:
7373
},
7474
) as pkg:
7575
assert str(pkg["__all__"].value) == "['hello']"
76+
77+
78+
def test_merge_overload_annotations() -> None:
79+
"""Assert that overload annotations are merged correctly."""
80+
with temporary_visited_package(
81+
"package",
82+
{
83+
"mod.py": "def func(x): ...",
84+
"mod.pyi": """
85+
from typing import overload
86+
87+
@overload
88+
def func(x: int) -> int: ...
89+
90+
@overload
91+
def func(x: float) -> float: ...
92+
""",
93+
},
94+
) as pkg:
95+
func = pkg["mod.func"]
96+
assert str(func.parameters["x"].annotation) == "int | float"
97+
assert str(func.returns) == "int | float"

0 commit comments

Comments
 (0)