|
9 | 9 |
|
10 | 10 | from abc import ABC, abstractmethod |
11 | 11 | from dataclasses import dataclass, field |
12 | | -from typing import TYPE_CHECKING, Any |
| 12 | +from typing import TYPE_CHECKING, Any, cast |
13 | 13 |
|
14 | 14 | import libcst as cst |
15 | 15 | import libcst.matchers as m |
@@ -73,8 +73,8 @@ class LoopState: |
73 | 73 | uncheckpointed_before_break: set[Statement] = field(default_factory=set) |
74 | 74 |
|
75 | 75 | artificial_errors: set[cst.Return | cst.Yield] = field(default_factory=set) |
76 | | - nodes_needing_checkpoints: list[cst.Return | cst.Yield] = field( |
77 | | - default_factory=list |
| 76 | + nodes_needing_checkpoints: list[cst.Return | cst.Yield | ArtificialStatement] = ( |
| 77 | + field(default_factory=list) |
78 | 78 | ) |
79 | 79 |
|
80 | 80 | def copy(self): |
@@ -215,8 +215,10 @@ def __init__( |
215 | 215 | self, |
216 | 216 | nodes_needing_checkpoint: Sequence[cst.Yield | cst.Return], |
217 | 217 | library: tuple[str, ...], |
| 218 | + explicitly_imported: dict[str, bool], |
218 | 219 | ): |
219 | 220 | super().__init__() |
| 221 | + self.explicitly_imported_library = explicitly_imported |
220 | 222 | self.nodes_needing_checkpoint = nodes_needing_checkpoint |
221 | 223 | self.__library = library |
222 | 224 |
|
@@ -262,6 +264,7 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors): |
262 | 264 | "CancelScope with no guaranteed checkpoint. This makes it potentially " |
263 | 265 | "impossible to cancel." |
264 | 266 | ), |
| 267 | + "ASYNC913": ("Indefinite loop with no guaranteed checkpoint."), |
265 | 268 | "ASYNC100": ( |
266 | 269 | "{0}.{1} context contains no checkpoints, remove the context or add" |
267 | 270 | " `await {0}.lowlevel.checkpoint()`." |
@@ -382,6 +385,8 @@ def leave_FunctionDef( |
382 | 385 | indentedblock = updated_node.body.with_changes(body=new_body) |
383 | 386 | updated_node = updated_node.with_changes(body=indentedblock) |
384 | 387 |
|
| 388 | + self.ensure_imported_library() |
| 389 | + |
385 | 390 | self.restore_state(original_node) |
386 | 391 | return updated_node |
387 | 392 |
|
@@ -772,6 +777,18 @@ def leave_While_body(self, node: cst.For | cst.While): |
772 | 777 | if not any_error: |
773 | 778 | self.loop_state.nodes_needing_checkpoints = [] |
774 | 779 |
|
| 780 | + if ( |
| 781 | + self.loop_state.infinite_loop |
| 782 | + and not self.loop_state.has_break |
| 783 | + and ARTIFICIAL_STATEMENT in self.uncheckpointed_statements |
| 784 | + and self.error(node, error_code="ASYNC913") |
| 785 | + ): |
| 786 | + # We can override nodes_needing_checkpoints, as that's solely for checkpoints |
| 787 | + # that error because of the artificial statement injected at the start of |
| 788 | + # the loop. When inserting a checkpoint at the start of the loop, those |
| 789 | + # will be remedied |
| 790 | + self.loop_state.nodes_needing_checkpoints = [ARTIFICIAL_STATEMENT] |
| 791 | + |
775 | 792 | # replace artificial statements in else with prebody uncheckpointed statements |
776 | 793 | # non-artificial stmts before continue/break/at body end will already be in them |
777 | 794 | for stmts in ( |
@@ -832,15 +849,38 @@ def leave_While( |
832 | 849 | | cst.FlattenSentinel[cst.For | cst.While] |
833 | 850 | | cst.RemovalSentinel |
834 | 851 | ): |
835 | | - if self.loop_state.nodes_needing_checkpoints: |
| 852 | + # don't bother autofixing same-line loops |
| 853 | + if isinstance(updated_node.body, cst.SimpleStatementSuite): |
| 854 | + self.restore_state(original_node) |
| 855 | + return updated_node |
| 856 | + |
| 857 | + # ASYNC913, indefinite loop with no guaranteed checkpoint |
| 858 | + if self.loop_state.nodes_needing_checkpoints == [ARTIFICIAL_STATEMENT]: |
| 859 | + if self.should_autofix(original_node, code="ASYNC913"): |
| 860 | + # insert checkpoint at start of body |
| 861 | + new_body = list(updated_node.body.body) |
| 862 | + new_body.insert(0, self.checkpoint_statement()) |
| 863 | + indentedblock = updated_node.body.with_changes(body=new_body) |
| 864 | + updated_node = updated_node.with_changes(body=indentedblock) |
| 865 | + |
| 866 | + self.ensure_imported_library() |
| 867 | + elif self.loop_state.nodes_needing_checkpoints: |
| 868 | + assert ARTIFICIAL_STATEMENT not in self.loop_state.nodes_needing_checkpoints |
836 | 869 | transformer = InsertCheckpointsInLoopBody( |
837 | | - self.loop_state.nodes_needing_checkpoints, self.library |
| 870 | + cast( |
| 871 | + "list[cst.Yield | cst.Return]", |
| 872 | + self.loop_state.nodes_needing_checkpoints, |
| 873 | + ), |
| 874 | + self.library, |
| 875 | + self.explicitly_imported_library, |
838 | 876 | ) |
839 | 877 | # type of updated_node expanded to the return type |
840 | 878 | updated_node = updated_node.visit(transformer) # type: ignore |
841 | 879 |
|
| 880 | + # include any necessary import added |
| 881 | + self.add_import.update(transformer.add_import) |
| 882 | + |
842 | 883 | self.restore_state(original_node) |
843 | | - # https://github.com/afonasev/flake8-return/issues/133 |
844 | 884 | return updated_node |
845 | 885 |
|
846 | 886 | leave_For = leave_While |
|
0 commit comments