Skip to content

Commit 5bed1be

Browse files
authored
refactor: Improve handling of typed dicts
Issue-284: #284 Issue-mkdocstrings-python-207: mkdocstrings/python#207 PR-414: #414
1 parent 708fd84 commit 5bed1be

2 files changed

Lines changed: 357 additions & 54 deletions

File tree

Lines changed: 181 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,188 @@
1+
# TODO: Support `extra_items=type`.
2+
# TODO: Support `closed=True/False`.
3+
14
from __future__ import annotations
25

3-
from typing import TYPE_CHECKING, Any
6+
import ast
7+
from itertools import chain
8+
from typing import TYPE_CHECKING, Any, TypedDict
49

5-
from griffe._internal.docstrings.models import DocstringParameter, DocstringSectionParameters
10+
from griffe._internal.docstrings.models import (
11+
DocstringParameter,
12+
DocstringSectionParameters,
13+
)
614
from griffe._internal.enumerations import DocstringSectionKind, ParameterKind
715
from griffe._internal.expressions import Expr, ExprSubscript
816
from griffe._internal.extensions.base import Extension
917
from griffe._internal.models import Class, Docstring, Function, Parameter, Parameters
1018

1119
if TYPE_CHECKING:
12-
from collections.abc import Iterable
20+
from collections.abc import Iterable, Iterator
21+
22+
23+
class _TypedDictAttr(TypedDict):
24+
name: str
25+
annotation: str | Expr | None
26+
docstring: Docstring | None
27+
28+
29+
def _unwrap_annotation(annotation: str | Expr | None, *, default_required: bool) -> tuple[str | Expr | None, bool]:
30+
required = default_required
31+
32+
# Annotations can be written ReadOnly[Required[T]] or Required[ReadOnly[T]],
33+
# so we unwrap a first time here and a second time at the end.
34+
if isinstance(annotation, ExprSubscript) and annotation.canonical_path in {
35+
"typing.ReadOnly",
36+
"typing_extensions.ReadOnly",
37+
}:
38+
annotation = annotation.slice
39+
40+
# Unwrap `Required` and `NotRequired`, set `required` accordingly.
41+
if isinstance(annotation, ExprSubscript):
42+
if annotation.canonical_path in {
43+
"typing.Required",
44+
"typing_extensions.Required",
45+
}:
46+
annotation = annotation.slice
47+
required = True
48+
elif annotation.canonical_path in {
49+
"typing.NotRequired",
50+
"typing_extensions.NotRequired",
51+
}:
52+
annotation = annotation.slice
53+
required = False
54+
55+
# Unwrap `ReadOnly` a second time here.
56+
if isinstance(annotation, ExprSubscript) and annotation.canonical_path in {
57+
"typing.ReadOnly",
58+
"typing_extensions.ReadOnly",
59+
}:
60+
annotation = annotation.slice
61+
62+
return annotation, required
63+
64+
65+
def _get_or_set_attrs(cls: Class) -> tuple[list[_TypedDictAttr], list[_TypedDictAttr]]:
66+
if (attrs := cls.extra.get("unpack_typeddict", {}).get("_attributes")) is not None:
67+
return attrs
68+
69+
# Inspect `total` keyword argument to determine default requiredness.
70+
default_required = True
71+
for arg, value in cls.keywords.items():
72+
if arg == "total":
73+
try:
74+
total = ast.literal_eval(str(value))
75+
except (ValueError, SyntaxError):
76+
break
77+
if total is True:
78+
default_required = True
79+
elif total is False:
80+
default_required = False
81+
break
82+
83+
# Extract attributes.
84+
required_attrs = []
85+
optional_attrs = []
86+
for attr in cls.attributes.values():
87+
annotation, required = _unwrap_annotation(attr.annotation, default_required=default_required)
88+
if required:
89+
required_attrs.append(
90+
_TypedDictAttr(
91+
name=attr.name,
92+
annotation=annotation,
93+
docstring=attr.docstring,
94+
),
95+
)
96+
else:
97+
optional_attrs.append(
98+
_TypedDictAttr(
99+
name=attr.name,
100+
annotation=annotation,
101+
docstring=attr.docstring,
102+
),
103+
)
13104

105+
cls.extra["unpack_typeddict"]["_attributes"] = (required_attrs, optional_attrs)
106+
return (required_attrs, optional_attrs)
14107

15-
def _update_docstring(func: Function, parameters: Iterable[Parameter], kwparam: Parameter | None = None) -> None:
108+
109+
def _update_docstring(
110+
func: Function,
111+
required: Iterable[_TypedDictAttr],
112+
optional: Iterable[_TypedDictAttr],
113+
kwparam: Parameter | None = None,
114+
) -> None:
16115
if not func.docstring:
17116
func.docstring = Docstring("", parent=func)
117+
118+
params_section = None
18119
sections = func.docstring.parsed
120+
121+
# Find existing "Parameters" section.
19122
section_gen = (section for section in sections if section.kind is DocstringSectionKind.parameters)
20-
if kwparam and (params_section := next(section_gen, None)):
21-
# Remove the `**kwargs` entry.
123+
params_section = next(section_gen, None)
124+
125+
# Pop original variadic keyword parameter from section.
126+
if kwparam and params_section is not None:
22127
param_gen = (i for i, arg in enumerate(params_section.value) if arg.name.lstrip("*") == kwparam.name)
23128
if (kwarg_pos := next(param_gen, None)) is not None:
24129
params_section.value.pop(kwarg_pos)
25-
else:
26-
# Create a parameters section if none exists.
27-
params_section = DocstringSectionParameters([])
28-
func.docstring.parsed.append(params_section)
29-
# Add entries for all parameters.
30-
for param in parameters:
31-
if param.name != "self":
130+
131+
# If we have required parameters, add them to the "Parameters" section.
132+
if required:
133+
# Create a "Parameters" section if none exists.
134+
if params_section is None:
135+
params_section = DocstringSectionParameters([])
136+
func.docstring.parsed.append(params_section)
137+
138+
# Add required parameters to the section.
139+
for attr in required:
32140
params_section.value.append(
33141
DocstringParameter(
34-
name=param.name,
35-
description=param.docstring.value if param.docstring else "",
36-
annotation=param.annotation,
37-
value=param.default,
142+
name=attr["name"],
143+
description=attr["docstring"].value if attr["docstring"] else "",
144+
annotation=attr["annotation"],
38145
),
39146
)
40147

148+
# If we have optional parameters, add them to the "Parameters" section too,
149+
# with a default value of `...`.
150+
if optional:
151+
# Create a "Parameters" section if none exists.
152+
if params_section is None:
153+
params_section = DocstringSectionParameters([])
154+
func.docstring.parsed.append(params_section)
41155

42-
def _params_from_attrs(attrs: Iterable[Any]) -> Parameters:
43-
return Parameters(
44-
Parameter(name="self", kind=ParameterKind.positional_or_keyword),
45-
*(
46-
Parameter(
47-
name=attr.name,
48-
annotation=attr.annotation,
49-
kind=ParameterKind.keyword_only,
50-
default=attr.value,
51-
docstring=attr.docstring,
156+
# Add optional parameters to the section.
157+
for attr in optional:
158+
params_section.value.append(
159+
DocstringParameter(
160+
name=attr["name"],
161+
description=attr["docstring"].value if attr["docstring"] else "",
162+
annotation=attr["annotation"],
163+
value="...",
164+
),
52165
)
53-
for attr in attrs
54-
),
55-
)
166+
167+
# TODO: Add `**kwargs` parameter if extra items are allowed.
168+
169+
170+
def _params_from_attrs(required: Iterable[_TypedDictAttr], optional: Iterable[_TypedDictAttr]) -> Iterator[Parameter]:
171+
for attr in required:
172+
yield Parameter(
173+
name=attr["name"],
174+
annotation=attr["annotation"],
175+
kind=ParameterKind.keyword_only,
176+
docstring=attr["docstring"],
177+
)
178+
for attr in optional:
179+
yield Parameter(
180+
name=attr["name"],
181+
annotation=attr["annotation"],
182+
kind=ParameterKind.keyword_only,
183+
default="...",
184+
docstring=attr["docstring"],
185+
)
56186

57187

58188
class UnpackTypedDictExtension(Extension):
@@ -67,19 +197,23 @@ def on_class(self, *, cls: Class, **kwargs: Any) -> None: # noqa: ARG002
67197
else:
68198
return
69199

70-
attributes = cls.attributes.values()
200+
required, optional = _get_or_set_attrs(cls)
71201

72202
if "__init__" not in cls.members:
73203
# Build the `__init__` method and add it to the class.
74-
parameters = _params_from_attrs(attributes)
204+
parameters = Parameters(
205+
Parameter(name="self", kind=ParameterKind.positional_or_keyword),
206+
*_params_from_attrs(required, optional),
207+
)
208+
# TODO: Add `**kwargs` parameter if extra items are allowed.
75209
init = Function(name="__init__", parameters=parameters, returns="None")
76210
cls.set_member("__init__", init)
77211
# Update the `__init__` docstring.
78-
_update_docstring(init, parameters)
212+
_update_docstring(init, required, optional)
79213

80214
# Remove attributes from the class, as they are now in the `__init__` method.
81-
for attr in attributes:
82-
cls.del_member(attr.name)
215+
for attr in chain(required, optional):
216+
cls.del_member(attr["name"])
83217

84218
def on_function(self, *, func: Function, **kwargs: Any) -> None: # noqa: ARG002
85219
"""Expand `**kwargs: Unpack[TypedDict]` in function signatures."""
@@ -102,26 +236,22 @@ def on_function(self, *, func: Function, **kwargs: Any) -> None: # noqa: ARG002
102236
else:
103237
return
104238

105-
if "__init__" in typed_dict.members:
106-
# The `__init__` was already generated: use its parameters.
107-
parameters = typed_dict["__init__"].parameters
108-
else:
109-
# Fallback to building parameters from attributes.
110-
parameters = _params_from_attrs(typed_dict.attributes.values())
239+
required, optional = _get_or_set_attrs(typed_dict)
111240

112241
# Update any parameter section in the docstring.
113242
# We do this before updating the signature so that
114243
# parsing the docstring doesn't emit warnings.
115-
_update_docstring(func, parameters, parameter)
244+
_update_docstring(func, required, optional, parameter)
116245

117246
# Update the function parameters.
118247
del func.parameters[parameter.name]
119-
for param in parameters:
120-
if param.name != "self":
121-
func.parameters[param.name] = Parameter(
122-
name=param.name,
123-
annotation=param.annotation,
124-
kind=ParameterKind.keyword_only,
125-
default=param.default,
126-
docstring=param.docstring,
127-
)
248+
for param in _params_from_attrs(required, optional):
249+
func.parameters[param.name] = Parameter(
250+
name=param.name,
251+
annotation=param.annotation,
252+
kind=ParameterKind.keyword_only,
253+
default=param.default,
254+
docstring=param.docstring,
255+
)
256+
257+
# TODO: Add `**kwargs` parameter if extra items are allowed.

0 commit comments

Comments
 (0)