File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change 99import warnings
1010
1111from collections .abc import Mapping , Sequence # noqa
12+ from typing import _GenericAlias
1213
1314
1415PYPY = platform .python_implementation () == "PyPy"
@@ -174,3 +175,10 @@ def func():
174175# don't have a direct reference to the thread-local in their globals dict.
175176# If they have such a reference, it breaks cloudpickle.
176177repr_context = threading .local ()
178+
179+
180+ def get_generic_base (cl ):
181+ """If this is a generic class (A[str]), return the generic base for it."""
182+ if cl .__class__ is _GenericAlias :
183+ return cl .__origin__
184+ return None
Original file line number Diff line number Diff line change 33
44import copy
55
6+ from ._compat import get_generic_base
67from ._make import NOTHING , _obj_setattr , fields
78from .exceptions import AttrsAttributeNotFoundError
89
@@ -296,7 +297,19 @@ def has(cls):
296297
297298 :rtype: bool
298299 """
299- return getattr (cls , "__attrs_attrs__" , None ) is not None
300+ attrs = getattr (cls , "__attrs_attrs__" , None )
301+ if attrs is not None :
302+ return True
303+
304+ # No attrs, maybe it's a specialized generic (A[str])?
305+ generic_base = get_generic_base (cls )
306+ if generic_base is not None :
307+ generic_attrs = getattr (generic_base , "__attrs_attrs__" , None )
308+ if generic_attrs is not None :
309+ # Stick it on here for speed next time.
310+ cls .__attrs_attrs__ = generic_attrs
311+ return generic_attrs is not None
312+ return False
300313
301314
302315def assoc (inst , ** changes ):
Original file line number Diff line number Diff line change 1212# We need to import _compat itself in addition to the _compat members to avoid
1313# having the thread-local in the globals here.
1414from . import _compat , _config , setters
15- from ._compat import PY310 , _AnnotationExtractor , set_closure_cell
15+ from ._compat import (
16+ PY310 ,
17+ _AnnotationExtractor ,
18+ get_generic_base ,
19+ set_closure_cell ,
20+ )
1621from .exceptions import (
1722 DefaultAlreadySetError ,
1823 FrozenInstanceError ,
@@ -1918,12 +1923,26 @@ def fields(cls):
19181923
19191924 .. versionchanged:: 16.2.0 Returned tuple allows accessing the fields
19201925 by name.
1926+ .. versionchanged:: 23.1.0 Add support for generic classes.
19211927 """
1922- if not isinstance (cls , type ):
1928+ generic_base = get_generic_base (cls )
1929+
1930+ if generic_base is None and not isinstance (cls , type ):
19231931 raise TypeError ("Passed object must be a class." )
1932+
19241933 attrs = getattr (cls , "__attrs_attrs__" , None )
1934+
19251935 if attrs is None :
1936+ if generic_base is not None :
1937+ attrs = getattr (generic_base , "__attrs_attrs__" , None )
1938+ if attrs is not None :
1939+ # Even though this is global state, stick it on here to speed
1940+ # it up. We rely on `cls` being cached for this to be
1941+ # efficient.
1942+ cls .__attrs_attrs__ = attrs
1943+ return attrs
19261944 raise NotAnAttrsClassError (f"{ cls !r} is not an attrs-decorated class." )
1945+
19271946 return attrs
19281947
19291948
Original file line number Diff line number Diff line change 66
77
88from collections import OrderedDict
9+ from typing import Generic , TypeVar
910
1011import pytest
1112
@@ -418,6 +419,37 @@ def test_negative(self):
418419 """
419420 assert not has (object )
420421
422+ def test_generics (self ):
423+ """
424+ Works with generic classes.
425+ """
426+ T = TypeVar ("T" )
427+
428+ @attr .define
429+ class A (Generic [T ]):
430+ a : T
431+
432+ assert has (A )
433+
434+ assert has (A [str ])
435+ # Verify twice, since there's caching going on.
436+ assert has (A [str ])
437+
438+ def test_generics_negative (self ):
439+ """
440+ Returns `False` on non-decorated generic classes.
441+ """
442+ T = TypeVar ("T" )
443+
444+ class A (Generic [T ]):
445+ a : T
446+
447+ assert not has (A )
448+
449+ assert not has (A [str ])
450+ # Verify twice, since there's caching going on.
451+ assert not has (A [str ])
452+
421453
422454class TestAssoc :
423455 """
Original file line number Diff line number Diff line change 1313import sys
1414
1515from operator import attrgetter
16+ from typing import Generic , TypeVar
1617
1718import pytest
1819
@@ -1114,6 +1115,22 @@ def test_handler_non_attrs_class(self):
11141115 f"{ object !r} is not an attrs-decorated class."
11151116 ) == e .value .args [0 ]
11161117
1118+ def test_handler_non_attrs_generic_class (self ):
1119+ """
1120+ Raises `ValueError` if passed a non-*attrs* generic class.
1121+ """
1122+ T = TypeVar ("T" )
1123+
1124+ class B (Generic [T ]):
1125+ pass
1126+
1127+ with pytest .raises (NotAnAttrsClassError ) as e :
1128+ fields (B [str ])
1129+
1130+ assert (
1131+ f"{ B [str ]!r} is not an attrs-decorated class."
1132+ ) == e .value .args [0 ]
1133+
11171134 @given (simple_classes ())
11181135 def test_fields (self , C ):
11191136 """
@@ -1129,6 +1146,24 @@ def test_fields_properties(self, C):
11291146 for attribute in fields (C ):
11301147 assert getattr (fields (C ), attribute .name ) is attribute
11311148
1149+ def test_generics (self ):
1150+ """
1151+ Fields work with generic classes.
1152+ """
1153+ T = TypeVar ("T" )
1154+
1155+ @attr .define
1156+ class A (Generic [T ]):
1157+ a : T
1158+
1159+ assert len (fields (A )) == 1
1160+ assert fields (A ).a .name == "a"
1161+ assert fields (A ).a .default is attr .NOTHING
1162+
1163+ assert len (fields (A [str ])) == 1
1164+ assert fields (A [str ]).a .name == "a"
1165+ assert fields (A [str ]).a .default is attr .NOTHING
1166+
11321167
11331168class TestFieldsDict :
11341169 """
You can’t perform that action at this time.
0 commit comments