@@ -42,13 +42,17 @@ def __init__(self, node: ast.Call, funcname: str, _):
4242
4343 if self .funcname == "CancelScope" :
4444 self .has_timeout = False
45+ for kw in node .keywords :
46+ # note: sets to True even if timeout is explicitly set to inf
47+ if kw .arg == "deadline" :
48+ self .has_timeout = True
49+
50+ # trio 0.27 adds shield parameter to all scope helpers
51+ if self .funcname in cancel_scope_names :
4552 for kw in node .keywords :
4653 # Only accepts constant values
4754 if kw .arg == "shield" and isinstance (kw .value , ast .Constant ):
4855 self .shielded = kw .value .value
49- # sets to True even if timeout is explicitly set to inf
50- if kw .arg == "deadline" :
51- self .has_timeout = True
5256
5357 def __init__ (self , * args : Any , ** kwargs : Any ):
5458 super ().__init__ (* args , ** kwargs )
@@ -109,7 +113,12 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
109113
110114 # Check for a `with trio.<scope_creator>`
111115 for item in node .items :
112- call = get_matching_call (item .context_expr , * cancel_scope_names )
116+ call = get_matching_call (
117+ item .context_expr ,
118+ "open_nursery" ,
119+ "create_task_group" ,
120+ * cancel_scope_names ,
121+ )
113122 if call is None :
114123 continue
115124
@@ -122,7 +131,18 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
122131 break
123132
124133 def visit_AsyncWith (self , node : ast .AsyncWith ):
125- self .async_call_checker (node )
134+ # trio.open_nursery and anyio.create_task_group are not cancellation points
135+ # so only treat this as an async call if it contains a call that does not match.
136+ # asyncio.TaskGroup() appears to be a source of cancellation when exiting.
137+ for item in node .items :
138+ if not (
139+ get_matching_call (item .context_expr , "open_nursery" , base = "trio" )
140+ or get_matching_call (
141+ item .context_expr , "create_task_group" , base = "anyio"
142+ )
143+ ):
144+ self .async_call_checker (node )
145+ break
126146 self .visit_With (node )
127147
128148 def visit_Try (self , node : ast .Try ):
@@ -160,18 +180,31 @@ def visit_ExceptHandler(self, node: ast.ExceptHandler):
160180
161181 def visit_Assign (self , node : ast .Assign ):
162182 # checks for <scopename>.shield = [True/False]
183+ # and <scopename>.cancel_scope.shield
184+ # We don't care to differentiate between them depending on if the scope is
185+ # a nursery or not, so e.g. `cs.cancel_scope.shield`/`nursery.shield` will "work"
163186 if self ._trio_context_managers and len (node .targets ) == 1 :
164- last_scope = self ._trio_context_managers [- 1 ]
165187 target = node .targets [0 ]
166- if (
167- last_scope .variable_name is not None
168- and isinstance (target , ast .Attribute )
169- and isinstance (target .value , ast .Name )
170- and target .value .id == last_scope .variable_name
171- and target .attr == "shield"
172- and isinstance (node .value , ast .Constant )
173- ):
174- last_scope .shielded = node .value .value
188+ for scope in reversed (self ._trio_context_managers ):
189+ if (
190+ scope .variable_name is not None
191+ and isinstance (node .value , ast .Constant )
192+ and isinstance (target , ast .Attribute )
193+ and target .attr == "shield"
194+ and (
195+ (
196+ isinstance (target .value , ast .Name )
197+ and target .value .id == scope .variable_name
198+ )
199+ or (
200+ isinstance (target .value , ast .Attribute )
201+ and target .value .attr == "cancel_scope"
202+ and isinstance (target .value .value , ast .Name )
203+ and target .value .value .id == scope .variable_name
204+ )
205+ )
206+ ):
207+ scope .shielded = node .value .value
175208
176209 def visit_FunctionDef (
177210 self , node : ast .FunctionDef | ast .AsyncFunctionDef | ast .Lambda
0 commit comments