Skip to content

Commit

Permalink
Walrus operator's left hand side now has STORE expression context (#433)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored Dec 16, 2020
1 parent 753a4f5 commit 1571cdd
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 0 deletions.
7 changes: 7 additions & 0 deletions libcst/metadata/expression_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def visit_AugAssign(self, node: cst.AugAssign) -> bool:
node.value.visit(self)
return False

def visit_NamedExpr(self, node: cst.NamedExpr) -> bool:
node.target.visit(
ExpressionContextVisitor(self.provider, ExpressionContext.STORE)
)
node.value.visit(self)
return False

def visit_Name(self, node: cst.Name) -> bool:
self.provider.set_metadata(node, self.context)
return False
Expand Down
20 changes: 20 additions & 0 deletions libcst/metadata/tests/test_expression_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,3 +411,23 @@ def test_function(self) -> None:
},
)
)

def test_walrus(self) -> None:
code = """
if x := y:
pass
"""
wrapper = MetadataWrapper(
parse_module(
dedent(code), config=cst.PartialParserConfig(python_version="3.8")
)
)
wrapper.visit(
DependentVisitor(
test=self,
name_to_context={
"x": ExpressionContext.STORE,
"y": ExpressionContext.LOAD,
},
)
)
22 changes: 22 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import sys
from textwrap import dedent
from typing import Mapping, Tuple, cast

Expand Down Expand Up @@ -1609,6 +1610,27 @@ def test_no_out_of_order_references_in_global_scope(self) -> None:
),
)

def test_walrus_accesses(self) -> None:
if sys.version_info < (3, 8):
self.skipTest("This python version doesn't support :=")
m, scopes = get_scope_metadata_provider(
"""
if x := y:
y = 1
x
"""
)
for scope in scopes.values():
for acc in scope.accesses:
self.assertEqual(
len(acc.referents),
1 if getattr(acc.node, "value") == "x" else 0,
msg=(
"Access for node has incorrect number of referents: "
+ f"{acc.node}"
),
)

def test_cast(self) -> None:
with self.assertRaises(cst.ParserSyntaxError):
m, scopes = get_scope_metadata_provider(
Expand Down

0 comments on commit 1571cdd

Please sign in to comment.