Skip to content

Commit 9dfc868

Browse files
committed
add match-case support to visitor91x
1 parent 3a0cbdb commit 9dfc868

File tree

5 files changed

+271
-1
lines changed

5 files changed

+271
-1
lines changed

flake8_async/visitors/visitor103_104.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,6 @@ def visit_If(self, node: ast.If):
202202
def visit_Match(self, node: ast.Match): # type: ignore[name-defined]
203203
if not self.unraised:
204204
return
205-
self.visit(node.subject) # this doesn't matter for 103/104, idr if it matters
206205
all_cases_raise = True
207206
has_fallback = False
208207
for case in node.cases:

flake8_async/visitors/visitor91x.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,22 @@ def copy(self):
198198
)
199199

200200

201+
@dataclass
202+
class MatchState:
203+
# TryState, LoopState, and MatchState all do fairly similar things. It would be nice
204+
# to harmonize them and share logic.
205+
base_uncheckpointed_statements: set[Statement] = field(default_factory=set)
206+
case_uncheckpointed_statements: set[Statement] = field(default_factory=set)
207+
has_fallback: bool = False
208+
209+
def copy(self):
210+
return MatchState(
211+
base_uncheckpointed_statements=self.base_uncheckpointed_statements.copy(),
212+
case_uncheckpointed_statements=self.case_uncheckpointed_statements.copy(),
213+
has_fallback=self.has_fallback,
214+
)
215+
216+
201217
def checkpoint_statement(library: str) -> cst.SimpleStatementLine:
202218
# logic before this should stop code from wanting to insert the non-existing
203219
# asyncio.lowlevel.checkpoint
@@ -373,6 +389,7 @@ def __init__(self, *args: Any, **kwargs: Any):
373389

374390
self.loop_state = LoopState()
375391
self.try_state = TryState()
392+
self.match_state = MatchState()
376393

377394
# ASYNC100
378395
self.has_checkpoint_stack: list[bool] = []
@@ -894,6 +911,53 @@ def visit_IfExp(self, node: cst.IfExp) -> bool:
894911
self.leave_If(node, node) # type: ignore
895912
return False # libcst shouldn't visit subnodes again
896913

914+
def leave_Match_subject(self, node: cst.Match) -> None:
915+
# We start the match logic after parsing the subject, instead of visit_Match,
916+
# since the subject is always executed and might checkpoint.
917+
if not self.async_function:
918+
return
919+
self.save_state(node, "match_state", copy=True)
920+
self.match_state = MatchState(self.uncheckpointed_statements.copy())
921+
922+
def visit_MatchCase(self, node: cst.MatchCase) -> None:
923+
# enter each case from the state after parsing the subject
924+
self.uncheckpointed_statements = self.match_state.base_uncheckpointed_statements
925+
926+
def leave_MatchCase_guard(self, node: cst.MatchCase) -> None:
927+
# `case _:` is no pattern and no guard, which means we know body is executed.
928+
# But we also know that `case _ if <guard>:` is guaranteed to execute the guard,
929+
# so for later logic we can treat them the same *if* there's no pattern and that
930+
# guard checkpoints.
931+
if (
932+
isinstance(node.pattern, cst.MatchAs)
933+
and node.pattern.pattern is None
934+
and (node.guard is None or not self.uncheckpointed_statements)
935+
):
936+
self.match_state.has_fallback = True
937+
938+
def leave_MatchCase(
939+
self, original_node: cst.MatchCase, updated_node: cst.MatchCase
940+
) -> cst.MatchCase:
941+
# collect the state at the end of each case
942+
self.match_state.case_uncheckpointed_statements.update(
943+
self.uncheckpointed_statements
944+
)
945+
return updated_node
946+
947+
def leave_Match(
948+
self, original_node: cst.Match, updated_node: cst.Match
949+
) -> cst.Match:
950+
# leave the Match with the worst-case of all branches
951+
self.uncheckpointed_statements = self.match_state.case_uncheckpointed_statements
952+
# if no fallback, also add the state at entering the match (after parsing subject)
953+
if not self.match_state.has_fallback:
954+
self.uncheckpointed_statements.update(
955+
self.match_state.base_uncheckpointed_statements
956+
)
957+
958+
self.restore_state(original_node)
959+
return updated_node
960+
897961
def visit_While(self, node: cst.While | cst.For):
898962
self.save_state(
899963
node,
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
2+
# AUTOFIX
3+
# ASYNCIO_NO_AUTOFIX
4+
import trio
5+
6+
7+
async def foo(): ...
8+
9+
10+
async def match_subject() -> None:
11+
match await foo():
12+
case False:
13+
pass
14+
15+
16+
async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
17+
None
18+
):
19+
match foo():
20+
case 1:
21+
...
22+
case _:
23+
await foo()
24+
await trio.lowlevel.checkpoint()
25+
26+
27+
async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
28+
None
29+
):
30+
match foo():
31+
case 1:
32+
await foo()
33+
case 2:
34+
await foo()
35+
case _ if True:
36+
await foo()
37+
await trio.lowlevel.checkpoint()
38+
39+
40+
async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
41+
None
42+
):
43+
match foo():
44+
case 1:
45+
await foo()
46+
case 2:
47+
await foo()
48+
case _ if foo():
49+
await foo()
50+
await trio.lowlevel.checkpoint()
51+
52+
53+
async def match_all_cases() -> None:
54+
match foo():
55+
case 1:
56+
await foo()
57+
case 2:
58+
await foo()
59+
case _:
60+
await foo()
61+
62+
63+
async def match_fallback_await_in_guard() -> None:
64+
# The case guard is only executed if the pattern matches, so we can mostly treat
65+
# it as part of the body, except for a special case for fallback+checkpointing guard.
66+
match foo():
67+
case 1 if await foo():
68+
...
69+
case _ if await foo():
70+
...
71+
72+
73+
async def match_checkpoint_guard() -> None:
74+
# The above pattern is quite cursed, but this seems fairly reasonable to do.
75+
match foo():
76+
case 1 if await foo():
77+
...
78+
case _:
79+
await foo()
80+
81+
82+
async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
83+
None
84+
):
85+
match foo():
86+
case 1:
87+
...
88+
case _ if await foo():
89+
...
90+
await trio.lowlevel.checkpoint()
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
---
2+
+++
3+
@@ x,6 x,7 @@
4+
...
5+
case _:
6+
await foo()
7+
+ await trio.lowlevel.checkpoint()
8+
9+
10+
async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
11+
@@ x,6 x,7 @@
12+
await foo()
13+
case _ if True:
14+
await foo()
15+
+ await trio.lowlevel.checkpoint()
16+
17+
18+
async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
19+
@@ x,6 x,7 @@
20+
await foo()
21+
case _ if foo():
22+
await foo()
23+
+ await trio.lowlevel.checkpoint()
24+
25+
26+
async def match_all_cases() -> None:
27+
@@ x,3 x,4 @@
28+
...
29+
case _ if await foo():
30+
...
31+
+ await trio.lowlevel.checkpoint()

tests/eval_files/async91x_py310.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# ARG --enable=ASYNC910,ASYNC911,ASYNC913
2+
# AUTOFIX
3+
# ASYNCIO_NO_AUTOFIX
4+
import trio
5+
6+
7+
async def foo(): ...
8+
9+
10+
async def match_subject() -> None:
11+
match await foo():
12+
case False:
13+
pass
14+
15+
16+
async def match_not_all_cases() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
17+
None
18+
):
19+
match foo():
20+
case 1:
21+
...
22+
case _:
23+
await foo()
24+
25+
26+
async def match_no_fallback() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
27+
None
28+
):
29+
match foo():
30+
case 1:
31+
await foo()
32+
case 2:
33+
await foo()
34+
case _ if True:
35+
await foo()
36+
37+
38+
async def match_fallback_is_guarded() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
39+
None
40+
):
41+
match foo():
42+
case 1:
43+
await foo()
44+
case 2:
45+
await foo()
46+
case _ if foo():
47+
await foo()
48+
49+
50+
async def match_all_cases() -> None:
51+
match foo():
52+
case 1:
53+
await foo()
54+
case 2:
55+
await foo()
56+
case _:
57+
await foo()
58+
59+
60+
async def match_fallback_await_in_guard() -> None:
61+
# The case guard is only executed if the pattern matches, so we can mostly treat
62+
# it as part of the body, except for a special case for fallback+checkpointing guard.
63+
match foo():
64+
case 1 if await foo():
65+
...
66+
case _ if await foo():
67+
...
68+
69+
70+
async def match_checkpoint_guard() -> None:
71+
# The above pattern is quite cursed, but this seems fairly reasonable to do.
72+
match foo():
73+
case 1 if await foo():
74+
...
75+
case _:
76+
await foo()
77+
78+
79+
async def match_not_checkpoint_in_all_guards() -> ( # ASYNC910: 0, "exit", Statement("function definition", lineno)
80+
None
81+
):
82+
match foo():
83+
case 1:
84+
...
85+
case _ if await foo():
86+
...

0 commit comments

Comments
 (0)