Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ASYNC912: timeout/cancelscope with only conditional checkpoints #242

Merged
merged 6 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Changelog
*[CalVer, YY.month.patch](https://calver.org/)*

## 24.5.1
- Add ASYNC912: timeout/cancel scope `with` statement in which no checkpoint is guaranteed to run.

## 24.4.1
- ASYNC91X fix internal error caused by multiple `try/except` incorrectly sharing state.
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ Note: 22X, 23X and 24X has not had asyncio-specific suggestions written.
- **ASYNC910**: Exit or `return` from async function with no guaranteed checkpoint or exception since function definition. You might want to enable this on a codebase to make it easier to reason about checkpoints, and make the logic of ASYNC911 correct.
- **ASYNC911**: Exit, `yield` or `return` from async iterable with no guaranteed checkpoint since possible function entry (yield or function definition)
Checkpoints are `await`, `async for`, and `async with` (on one of enter/exit).
- **ASYNC912**: TODO: write

### Removed Warnings
- **TRIOxxx**: All error codes are now renamed ASYNCxxx
Expand Down
2 changes: 1 addition & 1 deletion flake8_async/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@


# CalVer: YY.month.patch, e.g. first release of July 2022 == "22.7.1"
__version__ = "24.4.1"
__version__ = "24.5.1"


# taken from https://github.com/Zac-HD/shed
Expand Down
1 change: 0 additions & 1 deletion flake8_async/visitors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from . import (
visitor2xx,
visitor91x,
visitor100,
visitor101,
visitor102,
visitor103_104,
Expand Down
12 changes: 10 additions & 2 deletions flake8_async/visitors/flake8asyncvisitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ def error(
), "No error code defined, but class has multiple codes"
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
elif strip_error_subidentifier(error_code) not in self.options.enabled_codes:
elif (
(ec_no_sub := strip_error_subidentifier(error_code))
not in self.options.enabled_codes
and ec_no_sub not in self.options.autofix_codes
):
Comment on lines -101 to +105
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not great that the logic is duplicated between the two visitors (it confused me for a bit while developing). I originally expected all visitors to be rewritten to use libcst, but given that's not going to happen anytime soon (or at all), I should probably refactor these two and put the common code in a parent class.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Makes sense to me 👍

return

self.__state.problems.append(
Expand Down Expand Up @@ -217,7 +221,11 @@ def error(
error_code = next(iter(self.error_codes))
# don't emit an error if this code is disabled in a multi-code visitor
# TODO: write test for only one of 910/911 enabled/autofixed
elif strip_error_subidentifier(error_code) not in self.options.enabled_codes:
elif (
(ec_no_sub := strip_error_subidentifier(error_code))
not in self.options.enabled_codes
and ec_no_sub not in self.options.autofix_codes
):
return False # pragma: no cover

if self.is_noqa(node, error_code):
Expand Down
4 changes: 4 additions & 0 deletions flake8_async/visitors/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def disabled_by_default(error_class: type[T_EITHER]) -> type[T_EITHER]:
return error_class


def disable_codes_by_default(*codes: str) -> None:
    """Append each of the given error codes to ``default_disabled_error_codes``."""
    default_disabled_error_codes.extend(codes)


def utility_visitor(c: type[T]) -> type[T]:
assert not hasattr(c, "error_codes")
c.error_codes = {}
Expand Down
90 changes: 0 additions & 90 deletions flake8_async/visitors/visitor100.py

This file was deleted.

108 changes: 78 additions & 30 deletions flake8_async/visitors/visitor91x.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@
from ..base import Statement
from .flake8asyncvisitor import Flake8AsyncVisitor_cst
from .helpers import (
AttributeCall,
cancel_scope_names,
disabled_by_default,
disable_codes_by_default,
error_class_cst,
flatten_preserving_comments,
fnmatch_qualified_name_cst,
func_has_decorator,
iter_guaranteed_once_cst,
Expand All @@ -31,8 +33,11 @@
from collections.abc import Mapping, Sequence


class ArtificialStatement(Statement): ...


# Statement injected at the start of loops to track missed checkpoints.
ARTIFICIAL_STATEMENT = Statement("artificial", -1)
ARTIFICIAL_STATEMENT = ArtificialStatement("artificial", -1)


def func_empty_body(node: cst.FunctionDef) -> bool:
Expand Down Expand Up @@ -233,8 +238,10 @@ def leave_Yield(
leave_Return = leave_Yield # type: ignore


disable_codes_by_default("ASYNC910", "ASYNC911", "ASYNC912")


@error_class_cst
@disabled_by_default
class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
error_codes: Mapping[str, str] = {
"ASYNC910": (
Expand All @@ -249,6 +256,10 @@ class Visitor91X(Flake8AsyncVisitor_cst, CommonVisitors):
"CancelScope with no guaranteed checkpoint. This makes it potentially "
"impossible to cancel."
),
"ASYNC100": (
"{0}.{1} context contains no checkpoints, remove the context or add"
" `await {0}.lowlevel.checkpoint()`."
),
}

def __init__(self, *args: Any, **kwargs: Any):
Expand All @@ -262,15 +273,24 @@ def __init__(self, *args: Any, **kwargs: Any):
self.loop_state = LoopState()
self.try_state = TryState()

# ASYNC100
self.has_checkpoint_stack: list[bool] = []
self.node_dict: dict[cst.With, list[AttributeCall]] = {}

Comment on lines +282 to +285
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The primary ASYNC100 logic is simply copy-pasted from `visitor100.py`.

def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
if code is None:
code = "ASYNC911" if self.has_yield else "ASYNC910"

return (
not self.noautofix
and super().should_autofix(
node, "ASYNC911" if self.has_yield else "ASYNC910"
)
and super().should_autofix(node, code)
and self.library != ("asyncio",)
)
Copy link
Member Author

@jakkdl jakkdl May 5, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was needed in a previous implementation. ASYNC100 now works on asyncio and does not care about `self.noautofix` (I still haven't figured out why that was introduced), so it no longer uses this method at all. It is still good for the method to respect an explicitly passed code, which it didn't before.


def checkpoint(self) -> None:
    """Record that a checkpoint has occurred.

    Resets ``uncheckpointed_statements`` to an empty set and replaces
    ``has_checkpoint_stack`` with a list of the same length in which every
    entry is ``True``.
    """
    self.uncheckpointed_statements = set()
    self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack)

def checkpoint_statement(self) -> cst.SimpleStatementLine:
    """Return the non-method ``checkpoint_statement`` helper's result for ``self.library[0]``."""
    return checkpoint_statement(self.library[0])

Expand All @@ -289,9 +309,11 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
"uncheckpointed_statements",
"loop_state",
"try_state",
"has_checkpoint_stack",
copy=True,
)
self.uncheckpointed_statements = set()
self.has_checkpoint_stack = []
self.has_yield = self.safe_decorator = False
self.loop_state = LoopState()

Expand Down Expand Up @@ -365,7 +387,7 @@ def check_function_exit(
any_errors = False
# raise the actual errors
for statement in self.uncheckpointed_statements:
if statement == ARTIFICIAL_STATEMENT:
if isinstance(statement, ArtificialStatement):
continue
any_errors |= self.error_91x(original_node, statement)

Expand All @@ -382,6 +404,7 @@ def leave_Return(
self.add_statement = self.checkpoint_statement()
# avoid duplicate error messages
self.uncheckpointed_statements = set()
# we don't treat it as a checkpoint for ASYNC100
jakkdl marked this conversation as resolved.
Show resolved Hide resolved

# return original node to avoid problems with identity equality
assert original_node.deep_equals(updated_node)
Expand All @@ -392,7 +415,7 @@ def error_91x(
node: cst.Return | cst.FunctionDef | cst.Yield,
statement: Statement,
) -> bool:
assert statement != ARTIFICIAL_STATEMENT
assert not isinstance(statement, ArtificialStatement)

if isinstance(node, cst.FunctionDef):
msg = "exit"
Expand All @@ -413,7 +436,7 @@ def leave_Await(
# so only set checkpoint after the await node

# all nodes are now checkpointed
self.uncheckpointed_statements = set()
self.checkpoint()
return updated_node

# raising exception means we don't need to checkpoint so we can treat it as one
Expand All @@ -425,27 +448,49 @@ def leave_Await(
# missing-checkpoint warning when there might in fact be one (i.e. a false alarm).
def visit_With_body(self, node: cst.With):
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
if with_has_call(node, *cancel_scope_names) or with_has_call(
node, "timeout", "timeout_at", base="asyncio"
self.checkpoint()
if res := (
with_has_call(node, *cancel_scope_names)
or with_has_call(node, "timeout", "timeout_at", base="asyncio")
):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
self.uncheckpointed_statements.add(Statement("with", line, column))
# self.uncheckpointed_statements.add(res[0])

def leave_With_body(self, node: cst.With):
pos = self.get_metadata(PositionProvider, node).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
s = Statement("with", line, column)
if s in self.uncheckpointed_statements:
self.error(node, error_code="ASYNC912")
self.uncheckpointed_statements.remove(s)

if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
self.uncheckpointed_statements.add(
ArtificialStatement("with", line, column)
)
self.node_dict[node] = res
self.has_checkpoint_stack.append(False)
else:
self.has_checkpoint_stack.append(True)

def leave_With(self, original_node: cst.With, updated_node: cst.With):
# ASYNC100
if not self.has_checkpoint_stack.pop():
autofix = len(updated_node.items) == 1
for res in self.node_dict[original_node]:
# bypass 910 & 911's should_autofix logic, which excludes asyncio
# (TODO: and uses self.noautofix ... which I don't remember what it's for)
autofix &= self.error(
res.node, res.base, res.function, error_code="ASYNC100"
) and super().should_autofix(res.node, code="ASYNC100")

if autofix:
return flatten_preserving_comments(updated_node)
# ASYNC912
else:
pos = self.get_metadata( # pyright: ignore
PositionProvider, original_node
).start # pyright: ignore
line: int = pos.line # pyright: ignore
column: int = pos.column # pyright: ignore
s = ArtificialStatement("with", line, column)
if s in self.uncheckpointed_statements:
self.error(original_node, error_code="ASYNC912")
self.uncheckpointed_statements.remove(s)
if getattr(original_node, "asynchronous", None):
self.checkpoint()
return updated_node

# error if no checkpoint since earlier yield or function entry
def leave_Yield(
Expand All @@ -455,6 +500,9 @@ def leave_Yield(
return updated_node
self.has_yield = True

# Treat as a checkpoint for ASYNC100
self.has_checkpoint_stack = [True] * len(self.has_checkpoint_stack)

if self.check_function_exit(original_node) and self.should_autofix(
original_node
):
Expand Down Expand Up @@ -629,7 +677,7 @@ def visit_While_body(self, node: cst.For | cst.While):
# appropriate errors if the loop doesn't checkpoint

if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
self.checkpoint()
else:
self.uncheckpointed_statements = {ARTIFICIAL_STATEMENT}

Expand Down Expand Up @@ -675,7 +723,7 @@ def leave_While_body(self, node: cst.For | cst.While):
# AsyncFor guarantees checkpoint on running out of iterable
# so reset checkpoint state at end of loop. (but not state at break)
if getattr(node, "asynchronous", None):
self.uncheckpointed_statements = set()
self.checkpoint()
else:
# enter orelse with worst case:
# loop body might execute fully before entering orelse
Expand All @@ -699,7 +747,7 @@ def leave_While_orelse(self, node: cst.For | cst.While):
# if this is an infinite loop, with no break in it, don't raise
# alarms about the state after it.
if self.loop_state.infinite_loop and not self.loop_state.has_break:
self.uncheckpointed_statements = set()
self.checkpoint()
else:
# We may exit from:
# orelse (covering: no body, body until continue, and all body)
Expand Down Expand Up @@ -804,7 +852,7 @@ def visit_CompFor(self, node: cst.CompFor):

# if async comprehension, checkpoint
if node.asynchronous:
self.uncheckpointed_statements = set()
self.checkpoint()
self.comp_unknown = False
return False

Expand Down
Loading