Skip to content

Commit 4fcd15b

Browse files
Tinchepre-commit-ci[bot]hynek
authored
First pass over generics (#1079)
* First pass over generics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reformat comment * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * More work on generics * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add test case * Tweak condition * Remove redundant code * Add test --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Hynek Schlawack <hs@ox.cx>
1 parent 9cf2ed5 commit 4fcd15b

5 files changed

Lines changed: 110 additions & 3 deletions

File tree

src/attr/_compat.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import warnings
1010

1111
from collections.abc import Mapping, Sequence # noqa
12+
from typing import _GenericAlias
1213

1314

1415
PYPY = platform.python_implementation() == "PyPy"
@@ -174,3 +175,10 @@ def func():
174175
# don't have a direct reference to the thread-local in their globals dict.
175176
# If they have such a reference, it breaks cloudpickle.
176177
repr_context = threading.local()
178+
179+
180+
def get_generic_base(cl):
181+
"""If this is a generic class (A[str]), return the generic base for it."""
182+
if cl.__class__ is _GenericAlias:
183+
return cl.__origin__
184+
return None

src/attr/_funcs.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import copy
55

6+
from ._compat import get_generic_base
67
from ._make import NOTHING, _obj_setattr, fields
78
from .exceptions import AttrsAttributeNotFoundError
89

@@ -296,7 +297,19 @@ def has(cls):
296297
297298
:rtype: bool
298299
"""
299-
return getattr(cls, "__attrs_attrs__", None) is not None
300+
attrs = getattr(cls, "__attrs_attrs__", None)
301+
if attrs is not None:
302+
return True
303+
304+
# No attrs, maybe it's a specialized generic (A[str])?
305+
generic_base = get_generic_base(cls)
306+
if generic_base is not None:
307+
generic_attrs = getattr(generic_base, "__attrs_attrs__", None)
308+
if generic_attrs is not None:
309+
# Stick it on here for speed next time.
310+
cls.__attrs_attrs__ = generic_attrs
311+
return generic_attrs is not None
312+
return False
300313

301314

302315
def assoc(inst, **changes):

src/attr/_make.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212
# We need to import _compat itself in addition to the _compat members to avoid
1313
# having the thread-local in the globals here.
1414
from . import _compat, _config, setters
15-
from ._compat import PY310, _AnnotationExtractor, set_closure_cell
15+
from ._compat import (
16+
PY310,
17+
_AnnotationExtractor,
18+
get_generic_base,
19+
set_closure_cell,
20+
)
1621
from .exceptions import (
1722
DefaultAlreadySetError,
1823
FrozenInstanceError,
@@ -1918,12 +1923,26 @@ def fields(cls):
19181923
19191924
.. versionchanged:: 16.2.0 Returned tuple allows accessing the fields
19201925
by name.
1926+
.. versionchanged:: 23.1.0 Add support for generic classes.
19211927
"""
1922-
if not isinstance(cls, type):
1928+
generic_base = get_generic_base(cls)
1929+
1930+
if generic_base is None and not isinstance(cls, type):
19231931
raise TypeError("Passed object must be a class.")
1932+
19241933
attrs = getattr(cls, "__attrs_attrs__", None)
1934+
19251935
if attrs is None:
1936+
if generic_base is not None:
1937+
attrs = getattr(generic_base, "__attrs_attrs__", None)
1938+
if attrs is not None:
1939+
# Even though this is global state, stick it on here to speed
1940+
# it up. We rely on `cls` being cached for this to be
1941+
# efficient.
1942+
cls.__attrs_attrs__ = attrs
1943+
return attrs
19261944
raise NotAnAttrsClassError(f"{cls!r} is not an attrs-decorated class.")
1945+
19271946
return attrs
19281947

19291948

tests/test_funcs.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77

88
from collections import OrderedDict
9+
from typing import Generic, TypeVar
910

1011
import pytest
1112

@@ -418,6 +419,37 @@ def test_negative(self):
418419
"""
419420
assert not has(object)
420421

422+
def test_generics(self):
423+
"""
424+
Works with generic classes.
425+
"""
426+
T = TypeVar("T")
427+
428+
@attr.define
429+
class A(Generic[T]):
430+
a: T
431+
432+
assert has(A)
433+
434+
assert has(A[str])
435+
# Verify twice, since there's caching going on.
436+
assert has(A[str])
437+
438+
def test_generics_negative(self):
439+
"""
440+
Returns `False` on non-decorated generic classes.
441+
"""
442+
T = TypeVar("T")
443+
444+
class A(Generic[T]):
445+
a: T
446+
447+
assert not has(A)
448+
449+
assert not has(A[str])
450+
# Verify twice, since there's caching going on.
451+
assert not has(A[str])
452+
421453

422454
class TestAssoc:
423455
"""

tests/test_make.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414

1515
from operator import attrgetter
16+
from typing import Generic, TypeVar
1617

1718
import pytest
1819

@@ -1114,6 +1115,22 @@ def test_handler_non_attrs_class(self):
11141115
f"{object!r} is not an attrs-decorated class."
11151116
) == e.value.args[0]
11161117

1118+
def test_handler_non_attrs_generic_class(self):
1119+
"""
1120+
Raises `ValueError` if passed a non-*attrs* generic class.
1121+
"""
1122+
T = TypeVar("T")
1123+
1124+
class B(Generic[T]):
1125+
pass
1126+
1127+
with pytest.raises(NotAnAttrsClassError) as e:
1128+
fields(B[str])
1129+
1130+
assert (
1131+
f"{B[str]!r} is not an attrs-decorated class."
1132+
) == e.value.args[0]
1133+
11171134
@given(simple_classes())
11181135
def test_fields(self, C):
11191136
"""
@@ -1129,6 +1146,24 @@ def test_fields_properties(self, C):
11291146
for attribute in fields(C):
11301147
assert getattr(fields(C), attribute.name) is attribute
11311148

1149+
def test_generics(self):
1150+
"""
1151+
Fields work with generic classes.
1152+
"""
1153+
T = TypeVar("T")
1154+
1155+
@attr.define
1156+
class A(Generic[T]):
1157+
a: T
1158+
1159+
assert len(fields(A)) == 1
1160+
assert fields(A).a.name == "a"
1161+
assert fields(A).a.default is attr.NOTHING
1162+
1163+
assert len(fields(A[str])) == 1
1164+
assert fields(A[str]).a.name == "a"
1165+
assert fields(A[str]).a.default is attr.NOTHING
1166+
11321167

11331168
class TestFieldsDict:
11341169
"""

0 commit comments

Comments
 (0)