diff --git a/libcst/metadata/expression_context_provider.py b/libcst/metadata/expression_context_provider.py index b06ba1132..d9dffa129 100644 --- a/libcst/metadata/expression_context_provider.py +++ b/libcst/metadata/expression_context_provider.py @@ -84,6 +84,13 @@ def visit_AugAssign(self, node: cst.AugAssign) -> bool: node.value.visit(self) return False + def visit_NamedExpr(self, node: cst.NamedExpr) -> bool: + node.target.visit( + ExpressionContextVisitor(self.provider, ExpressionContext.STORE) + ) + node.value.visit(self) + return False + def visit_Name(self, node: cst.Name) -> bool: self.provider.set_metadata(node, self.context) return False diff --git a/libcst/metadata/tests/test_expression_context_provider.py b/libcst/metadata/tests/test_expression_context_provider.py index 25cc1d0da..91008df69 100644 --- a/libcst/metadata/tests/test_expression_context_provider.py +++ b/libcst/metadata/tests/test_expression_context_provider.py @@ -411,3 +411,23 @@ def test_function(self) -> None: }, ) ) + + def test_walrus(self) -> None: + code = """ + if x := y: + pass + """ + wrapper = MetadataWrapper( + parse_module( + dedent(code), config=cst.PartialParserConfig(python_version="3.8") + ) + ) + wrapper.visit( + DependentVisitor( + test=self, + name_to_context={ + "x": ExpressionContext.STORE, + "y": ExpressionContext.LOAD, + }, + ) + ) diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index d1566aac9..27a8f495c 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. +import sys from textwrap import dedent from typing import Mapping, Tuple, cast @@ -1609,6 +1610,27 @@ def test_no_out_of_order_references_in_global_scope(self) -> None: ), ) + def test_walrus_accesses(self) -> None: + if sys.version_info < (3, 8): + self.skipTest("This python version doesn't support :=") + m, scopes = get_scope_metadata_provider( + """ + if x := y: + y = 1 + x + """ + ) + for scope in scopes.values(): + for acc in scope.accesses: + self.assertEqual( + len(acc.referents), + 1 if getattr(acc.node, "value") == "x" else 0, + msg=( + "Access for node has incorrect number of referents: " + + f"{acc.node}" + ), + ) + def test_cast(self) -> None: with self.assertRaises(cst.ParserSyntaxError): m, scopes = get_scope_metadata_provider(