Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ Changelog

`CalVer, YY.month.patch <https://calver.org/>`_

Unreleased
==========
- Autofix for :ref:`ASYNC910 <async910>` / :ref:`ASYNC911 <async911>` no longer inserts checkpoints inside ``except`` clauses (which would trigger :ref:`ASYNC120 <async120>`); instead the checkpoint is added at the top of the function or of the enclosing loop. `(issue #403) <https://github.com/python-trio/flake8-async/issues/403>`_

25.7.1
======
- :ref:`ASYNC102 <async102>` no longer triggered for asyncio due to different cancellation semantics it uses.
Expand Down
149 changes: 138 additions & 11 deletions flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,11 @@ class LoopState:
nodes_needing_checkpoints: list[cst.Return | cst.Yield | ArtificialStatement] = (
field(default_factory=list[cst.Return | cst.Yield | ArtificialStatement])
)
# If a missing checkpoint was detected for a return/yield inside an except
# clause, inserting the checkpoint there would trigger ASYNC120. Instead,
# mark the innermost loop so a checkpoint is inserted at the top of the
# loop body.
needs_checkpoint_at_loop_start: bool = False

def copy(self):
return LoopState(
Expand All @@ -182,6 +187,7 @@ def copy(self):
uncheckpointed_before_break=self.uncheckpointed_before_break.copy(),
artificial_errors=self.artificial_errors.copy(),
nodes_needing_checkpoints=self.nodes_needing_checkpoints.copy(),
needs_checkpoint_at_loop_start=self.needs_checkpoint_at_loop_start,
)


Expand Down Expand Up @@ -341,6 +347,12 @@ def __init__(
self.explicitly_imported_library = explicitly_imported
self.nodes_needing_checkpoint = nodes_needing_checkpoint
self.__library = library
# Depth of except handlers we're currently inside, and a flag set if
# we detected a node that would have been fixed inside an except
# clause — in that case the caller should insert a checkpoint at the
# top of the loop body instead.
self.except_depth = 0
self.needs_checkpoint_at_loop_start = False

@property
def library(self) -> tuple[str, ...]:
Expand All @@ -349,6 +361,22 @@ def library(self) -> tuple[str, ...]:
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
return not self.noautofix

def visit_ExceptHandler(
self, node: cst.ExceptHandler | cst.ExceptStarHandler
) -> None:
self.except_depth += 1

def leave_ExceptHandler(
self,
original_node: cst.ExceptHandler | cst.ExceptStarHandler,
updated_node: cst.ExceptHandler | cst.ExceptStarHandler,
) -> Any:
self.except_depth -= 1
return updated_node

visit_ExceptStarHandler = visit_ExceptHandler
leave_ExceptStarHandler = leave_ExceptHandler

def leave_Yield(
self,
original_node: cst.Yield,
Expand All @@ -359,7 +387,15 @@ def leave_Yield(
if original_node in self.nodes_needing_checkpoint and self.should_autofix(
original_node
):
self.add_statement = checkpoint_statement(self.library[0])
if self.except_depth > 0:
# Inserting inside an except clause would trigger ASYNC120.
# Signal to the caller to insert at the top of the loop body.
# ensure_imported_library is called by the caller at the
# insertion site so we don't add an import for a checkpoint
# that ends up suppressed.
self.needs_checkpoint_at_loop_start = True
else:
self.add_statement = checkpoint_statement(self.library[0])
return updated_node

# returns handled same as yield, but ofc needs to ignore types
Expand Down Expand Up @@ -410,6 +446,15 @@ def __init__(self, *args: Any, **kwargs: Any):
self.try_state = TryState()
self.match_state = MatchState()

# Depth of except handlers we're currently inside. Used to avoid
# inserting checkpoints inside an except clause (which would trigger
# ASYNC120).
self.except_depth = 0
# Set when a missing checkpoint is detected inside an except clause
# for a return/yield that is not inside a loop. We then insert a
# single checkpoint at the top of the function body instead.
self.add_checkpoint_at_function_start = False

# ASYNC100
self.has_checkpoint_stack: list[ContextManager] = []
self.taskgroup_has_start_soon: dict[str, bool] = {}
Expand Down Expand Up @@ -510,6 +555,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
# node_dict is cleaned up and don't need to be saved
"taskgroup_has_start_soon",
"suppress_imported_as", # a copy is saved, but state is not reset
"except_depth",
"add_checkpoint_at_function_start",
copy=True,
)
self.uncheckpointed_statements = set()
Expand All @@ -518,6 +565,8 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
self.loop_state = LoopState()
# try_state is reset upon entering try
self.taskgroup_has_start_soon = {}
self.except_depth = 0
self.add_checkpoint_at_function_start = False

self.async_function = (
node.asynchronous is not None
Expand Down Expand Up @@ -563,6 +612,19 @@ def leave_FunctionDef(

self.ensure_imported_library()

if (
self.add_checkpoint_at_function_start
and self.new_body is not None
and isinstance(self.new_body, cst.IndentedBlock)
):
# insert checkpoint at the top of body (for missing checkpoints
# detected on return/yield inside an except clause)
new_body_block = list(self.new_body.body)
new_body_block.insert(0, self.checkpoint_statement())
self.new_body = self.new_body.with_changes(body=new_body_block)

self.ensure_imported_library()

if self.new_body is not None:
updated_node = updated_node.with_changes(body=self.new_body)
self.restore_state(original_node)
Expand Down Expand Up @@ -598,7 +660,15 @@ def check_function_exit(
if len(self.uncheckpointed_statements) == 1 and self.should_autofix(
original_node
):
self.loop_state.nodes_needing_checkpoints.append(original_node)
if self.except_depth > 0:
# Inserting a checkpoint inside the except clause would
# trigger ASYNC120. Instead mark the innermost loop so a
# checkpoint is inserted at the top of the loop body.
# ensure_imported_library is called at the actual
# insertion site in leave_While.
self.loop_state.needs_checkpoint_at_loop_start = True
else:
self.loop_state.nodes_needing_checkpoints.append(original_node)
return False

any_errors = False
Expand All @@ -618,7 +688,7 @@ def leave_Return(
if self.check_function_exit(original_node) and self.should_autofix(
original_node
):
self.add_statement = self.checkpoint_statement()
self._set_missing_checkpoint_fix()
# avoid duplicate error messages
# but don't see it as a cancel point for ASYNC100
self.checkpoint_schedule_point()
Expand All @@ -627,6 +697,49 @@ def leave_Return(
assert original_node.deep_equals(updated_node)
return original_node

def _set_missing_checkpoint_fix(self) -> None:
"""Record where to insert a fix for a detected missing checkpoint.

Normally we insert the checkpoint right before the offending
return/yield, but if the statement is inside an `except` clause that
would just trigger ASYNC120 (checkpoint inside except). Instead we
insert a checkpoint at a safe location: the top of the innermost
enclosing loop (so yields between iterations are also covered), or
the top of the function body otherwise.

If the uncheckpointed statements include a prior yield (so the error
would be "yield since prior yield"), neither alternative actually
fixes the path, so we fall back to the old behavior of inserting
before the statement. The resulting checkpoint may trigger ASYNC120
as a secondary diagnostic, but that is a pre-existing limitation of
the autofix in such contorted cases.
"""
if self.except_depth > 0 and self._can_redirect_except_fix():
if ARTIFICIAL_STATEMENT in self.uncheckpointed_statements:
# we're inside a loop
self.loop_state.needs_checkpoint_at_loop_start = True
else:
self.add_checkpoint_at_function_start = True
# ensure_imported_library is called at the actual insertion site
# (leave_FunctionDef / leave_While) so we don't add an import for
# a checkpoint that ends up suppressed by noqa.
else:
self.add_statement = self.checkpoint_statement()

def _can_redirect_except_fix(self) -> bool:
"""Check if the missing checkpoint can be fixed at a safe top location.

Redirecting to the top of the function or of the enclosing loop only
actually fixes the uncheckpointed path if every uncheckpointed
statement is "function definition" or the artificial loop-start
marker. A previous yield indicates a path that the redirected
checkpoint would not cover.
"""
return all(
isinstance(stmt, ArtificialStatement) or stmt.name == "function definition"
for stmt in self.uncheckpointed_statements
)

def error_91x(
self,
node: cst.Return | cst.FunctionDef | cst.Yield,
Expand Down Expand Up @@ -838,7 +951,7 @@ def leave_Yield(
if self.check_function_exit(original_node) and self.should_autofix(
original_node
):
self.add_statement = self.checkpoint_statement()
self._set_missing_checkpoint_fix()

# mark as requiring checkpoint after
pos = self.get_metadata(PositionProvider, original_node).start # type: ignore
Expand Down Expand Up @@ -883,6 +996,7 @@ def visit_ExceptHandler(self, node: cst.ExceptHandler | cst.ExceptStarHandler):
self.uncheckpointed_statements = (
self.try_state.body_uncheckpointed_statements.copy()
)
self.except_depth += 1

def leave_ExceptHandler(
self,
Expand All @@ -892,6 +1006,7 @@ def leave_ExceptHandler(
self.try_state.except_uncheckpointed_statements.update(
self.uncheckpointed_statements
)
self.except_depth -= 1
return updated_node

def visit_Try_orelse(self, node: cst.Try | cst.TryStar):
Expand Down Expand Up @@ -1103,6 +1218,10 @@ def leave_While_body(self, node: cst.For | cst.While):
# the potential checkpoints
if not any_error:
self.loop_state.nodes_needing_checkpoints = []
# Also clear the loop-top insertion flag; it may have been set by
# return/yield nodes inside an except clause whose errors were
# suppressed via noqa.
self.loop_state.needs_checkpoint_at_loop_start = False

if (
self.loop_state.infinite_loop
Expand Down Expand Up @@ -1181,16 +1300,12 @@ def leave_While(
self.restore_state(original_node)
return updated_node

insert_at_loop_start = self.loop_state.needs_checkpoint_at_loop_start

# ASYNC913, indefinite loop with no guaranteed checkpoint
if self.loop_state.nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT]:
if self.should_autofix(original_node, code="ASYNC913"):
# insert checkpoint at start of body
new_body = list(updated_node.body.body)
new_body.insert(0, self.checkpoint_statement())
indentedblock = updated_node.body.with_changes(body=new_body)
updated_node = updated_node.with_changes(body=indentedblock)

self.ensure_imported_library()
insert_at_loop_start = True
elif self.loop_state.nodes_needing_checkpoints:
assert ARTIFICIAL_STATEMENT not in self.loop_state.nodes_needing_checkpoints
transformer = InsertCheckpointsInLoopBody(
Expand All @@ -1207,6 +1322,18 @@ def leave_While(
# include any necessary import added
self.add_import.update(transformer.add_import)

if transformer.needs_checkpoint_at_loop_start:
insert_at_loop_start = True

if insert_at_loop_start:
# insert checkpoint at start of body
new_body = list(updated_node.body.body)
new_body.insert(0, self.checkpoint_statement())
indentedblock = updated_node.body.with_changes(body=new_body)
updated_node = updated_node.with_changes(body=indentedblock)

self.ensure_imported_library()

self.restore_state(original_node)
return updated_node

Expand Down
4 changes: 2 additions & 2 deletions tests/autofix_files/async911.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,10 @@ async def foo_try_1(): # error: 0, "exit", Statement("function definition", lin

# no checkpoint after yield in ValueError
async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
await trio.lowlevel.checkpoint()
try:
await foo()
except ValueError:
await trio.lowlevel.checkpoint()
# try might not have checkpointed
yield # error: 8, "yield", Statement("function definition", lineno-5)
except:
Expand All @@ -398,10 +398,10 @@ async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)


async def foo_try_4(): # safe
await trio.lowlevel.checkpoint()
try:
...
except:
await trio.lowlevel.checkpoint()
yield # error: 8, "yield", Statement("function definition", lineno-4)
finally:
await foo()
Expand Down
17 changes: 6 additions & 11 deletions tests/autofix_files/async911.py.diff
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@


async def foo_while_endless_3():
@@ x,9 x,11 @@
@@ x,13 x,16 @@
# try
async def foo_try_1(): # error: 0, "exit", Statement("function definition", lineno) # error: 0, "exit", Statement("yield", lineno+2)
try:
Expand All @@ -263,22 +263,20 @@


# no checkpoint after yield in ValueError
@@ x,12 x,14 @@
async def foo_try_2(): # error: 0, "exit", Statement("yield", lineno+5)
+ await trio.lowlevel.checkpoint()
try:
await foo()
except ValueError:
+ await trio.lowlevel.checkpoint()
# try might not have checkpointed
yield # error: 8, "yield", Statement("function definition", lineno-5)
except:
@@ x,6 x,7 @@
await foo()
else:
pass
+ await trio.lowlevel.checkpoint()


async def foo_try_3(): # error: 0, "exit", Statement("yield", lineno+6)
@@ x,13 x,16 @@
@@ x,10 x,13 @@
except:
await foo()
else:
Expand All @@ -288,13 +286,10 @@


async def foo_try_4(): # safe
+ await trio.lowlevel.checkpoint()
try:
...
except:
+ await trio.lowlevel.checkpoint()
yield # error: 8, "yield", Statement("function definition", lineno-4)
finally:
await foo()
@@ x,6 x,7 @@
try:
await foo()
Expand Down
Loading
Loading