Skip to content

Commit

Permalink
Provide STORE for {Class,Function}Def.name in ExpressionContextProvid…
Browse files Browse the repository at this point in the history
…er (#394)

* Add failing test cases

* mark *Def names as STORE

* Update libcst/metadata/expression_context_provider.py

Co-authored-by: Jimmy Lai <[email protected]>

* Fix lint

* Visit annotations and params

* Fix and extend tests

Co-authored-by: Jimmy Lai <[email protected]>
  • Loading branch information
cdonovick and jimmylai authored Sep 29, 2020
1 parent efe0fdb commit 34c1826
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 8 deletions.
40 changes: 36 additions & 4 deletions libcst/matchers/tests/test_findall.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,15 @@ def foo(bar: int) -> bool:
meta.ExpressionContextProvider, meta.ExpressionContext.STORE
),
)
self.assertNodeSequenceEqual(booleans, [cst.Name("a"), cst.Name("b")])
self.assertNodeSequenceEqual(
booleans,
[
cst.Name("a"),
cst.Name("b"),
cst.Name("foo"),
cst.Name("bar"),
],
)

# Test that we can provide an explicit resolver and tree
booleans = findall(
Expand All @@ -85,7 +93,15 @@ def foo(bar: int) -> bool:
),
metadata_resolver=wrapper,
)
self.assertNodeSequenceEqual(booleans, [cst.Name("a"), cst.Name("b")])
self.assertNodeSequenceEqual(
booleans,
[
cst.Name("a"),
cst.Name("b"),
cst.Name("foo"),
cst.Name("bar"),
],
)

# Test that failing to provide metadata leads to no match
booleans = findall(
Expand Down Expand Up @@ -127,7 +143,15 @@ def foo(bar: int) -> bool:
wrapper = meta.MetadataWrapper(module)
visitor = TestVisitor()
wrapper.visit(visitor)
self.assertNodeSequenceEqual(visitor.results, [cst.Name("a"), cst.Name("b")])
self.assertNodeSequenceEqual(
visitor.results,
[
cst.Name("a"),
cst.Name("b"),
cst.Name("foo"),
cst.Name("bar"),
],
)

def test_findall_with_transformers(self) -> None:
# Find all assignments in a tree
Expand Down Expand Up @@ -160,7 +184,15 @@ def foo(bar: int) -> bool:
wrapper = meta.MetadataWrapper(module)
visitor = TestTransformer()
wrapper.visit(visitor)
self.assertNodeSequenceEqual(visitor.results, [cst.Name("a"), cst.Name("b")])
self.assertNodeSequenceEqual(
visitor.results,
[
cst.Name("a"),
cst.Name("b"),
cst.Name("foo"),
cst.Name("bar"),
],
)


class MatchersExtractAllTest(UnitTest):
Expand Down
8 changes: 4 additions & 4 deletions libcst/matchers/tests/test_matchers_with_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def bar() -> int:
visitor = TestVisitor()
module.visit(visitor)

self.assertEqual(visitor.match_names, {"a", "b", "c"})
self.assertEqual(visitor.match_names, {"a", "b", "c", "foo", "bar"})

def test_matches_on_transformers(self) -> None:
# Set up a simple visitor that has a metadata dependency, try to use it in matchers.
Expand Down Expand Up @@ -533,7 +533,7 @@ def bar() -> int:
visitor = TestTransformer()
module.visit(visitor)

self.assertEqual(visitor.match_names, {"a", "b", "c"})
self.assertEqual(visitor.match_names, {"a", "b", "c", "foo", "bar"})

def test_matches_decorator_on_visitors(self) -> None:
# Set up a simple visitor that has a metadata dependency, try to use it in matchers.
Expand Down Expand Up @@ -573,7 +573,7 @@ def bar() -> int:
visitor = TestVisitor()
module.visit(visitor)

self.assertEqual(visitor.match_names, {"a", "b", "c"})
self.assertEqual(visitor.match_names, {"a", "b", "c", "foo", "bar"})

def test_matches_decorator_on_transformers(self) -> None:
# Set up a simple visitor that has a metadata dependency, try to use it in matchers.
Expand Down Expand Up @@ -613,4 +613,4 @@ def bar() -> int:
visitor = TestTransformer()
module.visit(visitor)

self.assertEqual(visitor.match_names, {"a", "b", "c"})
self.assertEqual(visitor.match_names, {"a", "b", "c", "foo", "bar"})
38 changes: 38 additions & 0 deletions libcst/metadata/expression_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,44 @@ def visit_List(self, node: cst.List) -> Optional[bool]:
def visit_StarredElement(self, node: cst.StarredElement) -> Optional[bool]:
self.provider.set_metadata(node, self.context)

def visit_ClassDef(self, node: cst.ClassDef) -> Optional[bool]:
node.name.visit(
ExpressionContextVisitor(self.provider, ExpressionContext.STORE)
)
node.body.visit(self)
for base in node.bases:
base.visit(self)
for keyword in node.keywords:
keyword.visit(self)
for decorator in node.decorators:
decorator.visit(self)
return False

def visit_FunctionDef(self, node: cst.FunctionDef) -> Optional[bool]:
node.name.visit(
ExpressionContextVisitor(self.provider, ExpressionContext.STORE)
)
node.params.visit(self)
node.body.visit(self)
for decorator in node.decorators:
decorator.visit(self)
returns = node.returns
if returns:
returns.visit(self)
return False

def visit_Param(self, node: cst.Param) -> Optional[bool]:
node.name.visit(
ExpressionContextVisitor(self.provider, ExpressionContext.STORE)
)
annotation = node.annotation
if annotation:
annotation.visit(self)
default = node.default
if default:
default.visit(self)
return False


class ExpressionContextProvider(BatchableMetadataProvider[Optional[ExpressionContext]]):
"""
Expand Down
35 changes: 35 additions & 0 deletions libcst/metadata/tests/test_expression_context_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


from textwrap import dedent
from typing import Dict, Optional, cast

import libcst as cst
Expand Down Expand Up @@ -376,3 +377,37 @@ def test_for(self) -> None:
},
)
)

def test_class(self) -> None:
code = """
class Foo(Bar):
x = y
"""
wrapper = MetadataWrapper(parse_module(dedent(code)))
wrapper.visit(
DependentVisitor(
test=self,
name_to_context={
"Foo": ExpressionContext.STORE,
"Bar": ExpressionContext.LOAD,
"x": ExpressionContext.STORE,
"y": ExpressionContext.LOAD,
},
)
)

def test_function(self) -> None:
code = """def foo(x: int = y) -> None: pass"""
wrapper = MetadataWrapper(parse_module(code))
wrapper.visit(
DependentVisitor(
test=self,
name_to_context={
"foo": ExpressionContext.STORE,
"x": ExpressionContext.STORE,
"int": ExpressionContext.LOAD,
"y": ExpressionContext.LOAD,
"None": ExpressionContext.LOAD,
},
)
)

0 comments on commit 34c1826

Please sign in to comment.