Skip to content

Commit 9a49703

Browse files
Zac-HDclaudepre-commit-ci[bot]
authored
Fix autofix checkpoint placement in except clauses (#442)
* Avoid placing autofix checkpoints inside except clauses The autofix for missing checkpoints (ASYNC910/ASYNC911) used to insert `await ...checkpoint()` right before the offending return/yield, which would be inside an except handler and trigger ASYNC120. Detect that case and redirect the insertion to the top of the innermost enclosing loop (if any) or to the top of the function body instead. Fall back to the previous location only for contorted cases where neither of those safer locations would actually break the uncheckpointed path. Fixes #403. * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent b4637bb commit 9a49703

File tree

7 files changed

+287
-25
lines changed

7 files changed

+287
-25
lines changed

docs/changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Changelog
44

55
`CalVer, YY.month.patch <https://calver.org/>`_
66

7+
Unreleased
8+
==========
9+
- Autofix for :ref:`ASYNC910 <async910>` / :ref:`ASYNC911 <async911>` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 <async120>`); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) <https://github.com/python-trio/flake8-async/issues/403>`_
10+
711
25.7.1
812
======
913
- :ref:`ASYNC102 <async102>` no longer triggered for asyncio due to different cancellation semantics it uses.

flake8_async/visitors/visitor91x.py

Lines changed: 138 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,11 @@ class LoopState:
172172
nodes_needing_checkpoints: list[cst.Return | cst.Yield | ArtificialStatement] = (
173173
field(default_factory=list[cst.Return | cst.Yield | ArtificialStatement])
174174
)
175+
# If a missing checkpoint was detected for a return/yield inside an except
176+
# clause, inserting the checkpoint there would trigger ASYNC120. Instead,
177+
# mark the innermost loop so a checkpoint is inserted at the top of the
178+
# loop body.
179+
needs_checkpoint_at_loop_start: bool = False
175180

176181
def copy(self):
177182
return LoopState(
@@ -182,6 +187,7 @@ def copy(self):
182187
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
183188
artificial_errors=self.artificial_errors.copy(),
184189
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
190+
needs_checkpoint_at_loop_start=self.needs_checkpoint_at_loop_start,
185191
)
186192

187193

@@ -341,6 +347,12 @@ def __init__(
341347
self.explicitly_imported_library = explicitly_imported
342348
self.nodes_needing_checkpoint = nodes_needing_checkpoint
343349
self.__library = library
350+
# Depth of except handlers we're currently inside, and a flag set if
351+
# we detected a node that would have been fixed inside an except
352+
# clause — in that case the caller should insert a checkpoint at the
353+
# top of the loop body instead.
354+
self.except_depth = 0
355+
self.needs_checkpoint_at_loop_start = False
344356

345357
@property
346358
def library(self) -> tuple[str, ...]:
@@ -349,6 +361,22 @@ def library(self) -> tuple[str, ...]:
349361
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
350362
return not self.noautofix
351363

364+
def visit_ExceptHandler(
365+
self, node: cst.ExceptHandler | cst.ExceptStarHandler
366+
) -> None:
367+
self.except_depth += 1
368+
369+
def leave_ExceptHandler(
370+
self,
371+
original_node: cst.ExceptHandler | cst.ExceptStarHandler,
372+
updated_node: cst.ExceptHandler | cst.ExceptStarHandler,
373+
) -> Any:
374+
self.except_depth -= 1
375+
return updated_node
376+
377+
visit_ExceptStarHandler = visit_ExceptHandler
378+
leave_ExceptStarHandler = leave_ExceptHandler
379+
352380
def leave_Yield(
353381
self,
354382
original_node: cst.Yield,
@@ -359,7 +387,15 @@ def leave_Yield(
359387
if original_node in self.nodes_needing_checkpoint and self.should_autofix(
360388
original_node
361389
):
362-
self.add_statement = checkpoint_statement(self.library[0])
390+
if self.except_depth > 0:
391+
# Inserting inside an except clause would trigger ASYNC120.
392+
# Signal to the caller to insert at the top of the loop body.
393+
# ensure_imported_library is called by the caller at the
394+
# insertion site so we don't add an import for a checkpoint
395+
# that ends up suppressed.
396+
self.needs_checkpoint_at_loop_start = True
397+
else:
398+
self.add_statement = checkpoint_statement(self.library[0])
363399
return updated_node
364400

365401
# returns handled same as yield, but ofc needs to ignore types
@@ -410,6 +446,15 @@ def __init__(self, *args: Any, **kwargs: Any):
410446
self.try_state = TryState()
411447
self.match_state = MatchState()
412448

449+
# Depth of except handlers we're currently inside. Used to avoid
450+
# inserting checkpoints inside an except clause (which would trigger
451+
# ASYNC120).
452+
self.except_depth = 0
453+
# Set when a missing checkpoint is detected inside an except clause
454+
# for a return/yield that is not inside a loop. We then insert a
455+
# single checkpoint at the top of the function body instead.
456+
self.add_checkpoint_at_function_start = False
457+
413458
# ASYNC100
414459
self.has_checkpoint_stack: list[ContextManager] = []
415460
self.taskgroup_has_start_soon: dict[str, bool] = {}
@@ -510,6 +555,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
510555
# node_dict is cleaned up and don't need to be saved
511556
"taskgroup_has_start_soon",
512557
"suppress_imported_as", # a copy is saved, but state is not reset
558+
"except_depth",
559+
"add_checkpoint_at_function_start",
513560
copy=True,
514561
)
515562
self.uncheckpointed_statements = set()
@@ -518,6 +565,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
518565
self.loop_state = LoopState()
519566
# try_state is reset upon entering try
520567
self.taskgroup_has_start_soon = {}
568+
self.except_depth = 0
569+
self.add_checkpoint_at_function_start = False
521570

522571
self.async_function = (
523572
node.asynchronous is not None
@@ -563,6 +612,19 @@ def leave_FunctionDef(
563612

564613
self.ensure_imported_library()
565614

615+
if (
616+
self.add_checkpoint_at_function_start
617+
and self.new_body is not None
618+
and isinstance(self.new_body, cst.IndentedBlock)
619+
):
620+
# insert checkpoint at the top of body (for missing checkpoints
621+
# detected on return/yield inside an except clause)
622+
new_body_block = list(self.new_body.body)
623+
new_body_block.insert(0, self.checkpoint_statement())
624+
self.new_body = self.new_body.with_changes(body=new_body_block)
625+
626+
self.ensure_imported_library()
627+
566628
if self.new_body is not None:
567629
updated_node = updated_node.with_changes(body=self.new_body)
568630
self.restore_state(original_node)
@@ -598,7 +660,15 @@ def check_function_exit(
598660
if len(self.uncheckpointed_statements) == 1 and self.should_autofix(
599661
original_node
600662
):
601-
self.loop_state.nodes_needing_checkpoints.append(original_node)
663+
if self.except_depth > 0:
664+
# Inserting a checkpoint inside the except clause would
665+
# trigger ASYNC120. Instead mark the innermost loop so a
666+
# checkpoint is inserted at the top of the loop body.
667+
# ensure_imported_library is called at the actual
668+
# insertion site in leave_While.
669+
self.loop_state.needs_checkpoint_at_loop_start = True
670+
else:
671+
self.loop_state.nodes_needing_checkpoints.append(original_node)
602672
return False
603673

604674
any_errors = False
@@ -618,7 +688,7 @@ def leave_Return(
618688
if self.check_function_exit(original_node) and self.should_autofix(
619689
original_node
620690
):
621-
self.add_statement = self.checkpoint_statement()
691+
self._set_missing_checkpoint_fix()
622692
# avoid duplicate error messages
623693
# but don't see it as a cancel point for ASYNC100
624694
self.checkpoint_schedule_point()
@@ -627,6 +697,49 @@ def leave_Return(
627697
assert original_node.deep_equals(updated_node)
628698
return original_node
629699

700+
def _set_missing_checkpoint_fix(self) -> None:
701+
"""Record where to insert a fix for a detected missing checkpoint.
702+
703+
Normally we insert the checkpoint right before the offending
704+
return/yield, but if the statement is inside an `except` clause that
705+
would just trigger ASYNC120 (checkpoint inside except). Instead we
706+
insert a checkpoint at a safe location: the top of the innermost
707+
enclosing loop (so yields between iterations are also covered), or
708+
the top of the function body otherwise.
709+
710+
If the uncheckpointed statements include a prior yield (so the error
711+
would be "yield since prior yield"), neither alternative actually
712+
fixes the path, so we fall back to the old behavior of inserting
713+
before the statement. The resulting checkpoint may trigger ASYNC120
714+
as a secondary diagnostic, but that is a pre-existing limitation of
715+
the autofix in such contorted cases.
716+
"""
717+
if self.except_depth > 0 and self._can_redirect_except_fix():
718+
if ARTIFICIAL_STATEMENT in self.uncheckpointed_statements:
719+
# we're inside a loop
720+
self.loop_state.needs_checkpoint_at_loop_start = True
721+
else:
722+
self.add_checkpoint_at_function_start = True
723+
# ensure_imported_library is called at the actual insertion site
724+
# (leave_FunctionDef / leave_While) so we don't add an import for
725+
# a checkpoint that ends up suppressed by noqa.
726+
else:
727+
self.add_statement = self.checkpoint_statement()
728+
729+
def _can_redirect_except_fix(self) -> bool:
730+
"""Check if the missing checkpoint can be fixed at a safe top location.
731+
732+
Redirecting to the top of the function or of the enclosing loop only
733+
actually fixes the uncheckpointed path if every uncheckpointed
734+
statement is "function definition" or the artificial loop-start
735+
marker. A previous yield indicates a path that the redirected
736+
checkpoint would not cover.
737+
"""
738+
return all(
739+
isinstance(stmt, ArtificialStatement) or stmt.name == "function definition"
740+
for stmt in self.uncheckpointed_statements
741+
)
742+
630743
def error_91x(
631744
self,
632745
node: cst.Return | cst.FunctionDef | cst.Yield,
@@ -838,7 +951,7 @@ def leave_Yield(
838951
if self.check_function_exit(original_node) and self.should_autofix(
839952
original_node
840953
):
841-
self.add_statement = self.checkpoint_statement()
954+
self._set_missing_checkpoint_fix()
842955

843956
# mark as requiring checkpoint after
844957
pos = self.get_metadata(PositionProvider, original_node).start # type: ignore
@@ -883,6 +996,7 @@ def visit_ExceptHandler(self, node: cst.ExceptHandler | cst.ExceptStarHandler):
883996
self.uncheckpointed_statements = (
884997
self.try_state.body_uncheckpointed_statements.copy()
885998
)
999+
self.except_depth += 1
8861000

8871001
def leave_ExceptHandler(
8881002
self,
@@ -892,6 +1006,7 @@ def leave_ExceptHandler(
8921006
self.try_state.except_uncheckpointed_statements.update(
8931007
self.uncheckpointed_statements
8941008
)
1009+
self.except_depth -= 1
8951010
return updated_node
8961011

8971012
def visit_Try_orelse(self, node: cst.Try | cst.TryStar):
@@ -1103,6 +1218,10 @@ def leave_While_body(self, node: cst.For | cst.While):
11031218
# the potential checkpoints
11041219
if not any_error:
11051220
self.loop_state.nodes_needing_checkpoints = []
1221+
# Also clear the loop-top insertion flag; it may have been set by
1222+
# return/yield nodes inside an except clause whose errors were
1223+
# suppressed via noqa.
1224+
self.loop_state.needs_checkpoint_at_loop_start = False
11061225

11071226
if (
11081227
self.loop_state.infinite_loop
@@ -1181,16 +1300,12 @@ def leave_While(
11811300
self.restore_state(original_node)
11821301
return updated_node
11831302

1303+
insert_at_loop_start = self.loop_state.needs_checkpoint_at_loop_start
1304+
11841305
# ASYNC913, indefinite loop with no guaranteed checkpoint
11851306
if self.loop_state.nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT]:
11861307
if self.should_autofix(original_node, code="ASYNC913"):
1187-
# insert checkpoint at start of body
1188-
new_body = list(updated_node.body.body)
1189-
new_body.insert(0, self.checkpoint_statement())
1190-
indentedblock = updated_node.body.with_changes(body=new_body)
1191-
updated_node = updated_node.with_changes(body=indentedblock)
1192-
1193-
self.ensure_imported_library()
1308+
insert_at_loop_start = True
11941309
elif self.loop_state.nodes_needing_checkpoints:
11951310
assert ARTIFICIAL_STATEMENT not in self.loop_state.nodes_needing_checkpoints
11961311
transformer = InsertCheckpointsInLoopBody(
@@ -1207,6 +1322,18 @@ def leave_While(
12071322
# include any necessary import added
12081323
self.add_import.update(transformer.add_import)
12091324

1325+
if transformer.needs_checkpoint_at_loop_start:
1326+
insert_at_loop_start = True
1327+
1328+
if insert_at_loop_start:
1329+
# insert checkpoint at start of body
1330+
new_body = list(updated_node.body.body)
1331+
new_body.insert(0, self.checkpoint_statement())
1332+
indentedblock = updated_node.body.with_changes(body=new_body)
1333+
updated_node = updated_node.with_changes(body=indentedblock)
1334+
1335+
self.ensure_imported_library()
1336+
12101337
self.restore_state(original_node)
12111338
return updated_node
12121339

tests/autofix_files/async911.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -373,10 +373,10 @@ async def foo_try_1(): # error: 0, "exit", Statement("function definition", lin
373373

374374
# no checkpoint after yield in ValueError
375375
async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
376+
await trio.lowlevel.checkpoint()
376377
try:
377378
await foo()
378379
except ValueError:
379-
await trio.lowlevel.checkpoint()
380380
# try might not have checkpointed
381381
yield # error: 8, "yield", Statement("function definition", lineno-5)
382382
except:
@@ -398,10 +398,10 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)
398398

399399

400400
async def foo_try_4(): # safe
401+
await trio.lowlevel.checkpoint()
401402
try:
402403
...
403404
except:
404-
await trio.lowlevel.checkpoint()
405405
yield # error: 8, "yield", Statement("function definition", lineno-4)
406406
finally:
407407
await foo()

tests/autofix_files/async911.py.diff

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,7 @@
251251

252252

253253
async def foo_while_endless_3():
254-
@@ x,9 x,11 @@
254+
@@ x,13 x,16 @@
255255
# try
256256
async def foo_try_1(): # error: 0, "exit", Statement("function definition", lineno) # error: 0, "exit", Statement("yield", lineno+2)
257257
try:
@@ -263,22 +263,20 @@
263263

264264

265265
# no checkpoint after yield in ValueError
266-
@@ x,12 x,14 @@
266+
async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
267+
+ await trio.lowlevel.checkpoint()
267268
try:
268269
await foo()
269270
except ValueError:
270-
+ await trio.lowlevel.checkpoint()
271-
# try might not have checkpointed
272-
yield # error: 8, "yield", Statement("function definition", lineno-5)
273-
except:
271+
@@ x,6 x,7 @@
274272
await foo()
275273
else:
276274
pass
277275
+ await trio.lowlevel.checkpoint()
278276

279277

280278
async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)
281-
@@ x,13 x,16 @@
279+
@@ x,10 x,13 @@
282280
except:
283281
await foo()
284282
else:
@@ -288,13 +286,10 @@
288286

289287

290288
async def foo_try_4(): # safe
289+
+ await trio.lowlevel.checkpoint()
291290
try:
292291
...
293292
except:
294-
+ await trio.lowlevel.checkpoint()
295-
yield # error: 8, "yield", Statement("function definition", lineno-4)
296-
finally:
297-
await foo()
298293
@@ x,6 x,7 @@
299294
try:
300295
await foo()

0 commit comments

Comments
 (0)