@@ -421,23 +421,35 @@ def __repr__(self) -> str:
421421 def __eq__ (self , actual ) -> bool :
422422 """Return whether the given value is equal to the expected value
423423 within the pre-specified tolerance."""
424+
425+ def is_bool (val : Any ) -> bool :
426+ # Check if `val` is a native bool or numpy bool.
427+ if isinstance (val , bool ):
428+ return True
429+ try :
430+ import numpy as np
431+
432+ return isinstance (val , np .bool_ )
433+ except ImportError :
434+ return False
435+
424436 asarray = _as_numpy_array (actual )
425437 if asarray is not None :
426438 # Call ``__eq__()`` manually to prevent infinite-recursion with
427439 # numpy<1.13. See #3748.
428440 return all (self .__eq__ (a ) for a in asarray .flat )
429441
430- # Short-circuit exact equality, except for bool
431- if isinstance (self .expected , bool ) and not isinstance (actual , bool ):
442+ # Short-circuit exact equality, except for bool and np.bool_
443+ if is_bool (self .expected ) and not is_bool (actual ):
432444 return False
433445 elif actual == self .expected :
434446 return True
435447
436448 # If either type is non-numeric, fall back to strict equality.
437449 # NB: we need Complex, rather than just Number, to ensure that __abs__,
438450 # __sub__, and __float__ are defined. Also, consider bool to be
439- # nonnumeric , even though it has the required arithmetic.
440- if isinstance (self .expected , bool ) or not (
451+ # non-numeric , even though it has the required arithmetic.
452+ if is_bool (self .expected ) or not (
441453 isinstance (self .expected , (Complex , Decimal ))
442454 and isinstance (actual , (Complex , Decimal ))
443455 ):
0 commit comments