Skip to content

Commit 6e43e68

Browse files
authored
Merge pull request #1 from evanwilson2123/kim-14324-raisesgroup
Kim 14324 raisesgroup
2 parents 2a74cdf + 503c50b commit 6e43e68

File tree

3 files changed

+44
-6
lines changed

3 files changed

+44
-6
lines changed

changelog/14324.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fix ``pytest.RaisesGroup`` incorrectly calling the ``check`` callback with contained exceptions instead of only the exception group.

src/_pytest/raises.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,14 +1207,35 @@ def matches(
12071207
reason = (
12081208
cast(str, self._fail_reason) + f" on the {type(exception).__name__}"
12091209
)
1210+
1211+
suggest_subexception_check = False
12101212
if (
1211-
len(actual_exceptions) == len(self.expected_exceptions) == 1
1213+
self.check is not None
1214+
and len(actual_exceptions) == len(self.expected_exceptions) == 1
12121215
and isinstance(expected := self.expected_exceptions[0], type)
1213-
# we explicitly break typing here :)
1214-
and self._check_check(actual_exceptions[0]) # type: ignore[arg-type]
1216+
and isinstance(actual_exceptions[0], expected)
12151217
):
1218+
annotations = getattr(self.check, "__annotations__", {})
1219+
param_names = [name for name in annotations if name != "return"]
1220+
if param_names:
1221+
param_annotation = annotations[param_names[0]]
1222+
1223+
if isinstance(param_annotation, str):
1224+
suggest_subexception_check = (
1225+
"ExceptionGroup" not in param_annotation
1226+
and "BaseExceptionGroup" not in param_annotation
1227+
)
1228+
else:
1229+
origin = get_origin(param_annotation) or param_annotation
1230+
if isinstance(origin, type):
1231+
suggest_subexception_check = not issubclass(
1232+
origin, BaseExceptionGroup
1233+
)
1234+
1235+
if suggest_subexception_check:
12161236
self._fail_reason = reason + (
1217-
f", but did return True for the expected {self._repr_expected(expected)}."
1237+
f", but the single contained exception matches the expected "
1238+
f"{self._repr_expected(expected)}."
12181239
f" You might want RaisesGroup(RaisesExc({expected.__name__}, check=<...>))"
12191240
)
12201241
else:

testing/python/raises_group.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,9 +412,12 @@ def is_exc(e: ExceptionGroup[ValueError]) -> bool:
412412
return e is exc
413413

414414
is_exc_repr = repr_callable(is_exc)
415+
416+
# this should pass (same object)
415417
with RaisesGroup(ValueError, check=is_exc):
416418
raise exc
417419

420+
# this should fail WITHOUT suggestion
418421
with (
419422
fails_raises_group(
420423
f"check {is_exc_repr} did not return True on the ExceptionGroup"
@@ -426,16 +429,29 @@ def is_exc(e: ExceptionGroup[ValueError]) -> bool:
426429
def is_value_error(e: BaseException) -> bool:
427430
return isinstance(e, ValueError)
428431

429-
# helpful suggestion if the user thinks the check is for the sub-exception
432+
# this should fail WITH suggestion (because check looks like it's for inner exception)
430433
with (
431434
fails_raises_group(
432-
f"check {is_value_error} did not return True on the ExceptionGroup, but did return True for the expected ValueError. You might want RaisesGroup(RaisesExc(ValueError, check=<...>))"
435+
f"check {is_value_error} did not return True on the ExceptionGroup, but the single contained exception matches the expected ValueError. You might want RaisesGroup(RaisesExc(ValueError, check=<...>))"
433436
),
434437
RaisesGroup(ValueError, check=is_value_error),
435438
):
436439
raise ExceptionGroup("", (ValueError(),))
437440

438441

442+
def test_check_called_only_with_group() -> None:
443+
seen = []
444+
445+
def check(exc_group: ExceptionGroup[ValueError]) -> bool:
446+
seen.append(type(exc_group))
447+
return len(exc_group.exceptions) == 1
448+
449+
with RaisesGroup(ValueError, match="Main message", check=check):
450+
raise ExceptionGroup("Main message", [ValueError("foo")])
451+
452+
assert seen == [ExceptionGroup]
453+
454+
439455
def test_unwrapped_match_check() -> None:
440456
def my_check(e: object) -> bool: # pragma: no cover
441457
return True

0 commit comments

Comments
 (0)