Skip to content

Commit

Permalink
fix assignment/access ordering in comprehensions (#423)
Browse files Browse the repository at this point in the history
  • Loading branch information
zsol authored Dec 1, 2020
1 parent 2485d5a commit 1326a0e
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
4 changes: 4 additions & 0 deletions libcst/metadata/scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
)


# Comprehensions are handled separately in _visit_comp_alike due to
# the complexity of the semantics
_ASSIGNMENT_LIKE_NODES = (
cst.AnnAssign,
cst.AsName,
Expand Down Expand Up @@ -976,6 +978,8 @@ def _visit_comp_alike(
self.provider.set_metadata(for_in, self.scope)
with self._new_scope(ComprehensionScope, node):
for_in.target.visit(self)
# Things from here on can refer to the target.
self.scope._assignment_count += 1
for condition in for_in.ifs:
condition.visit(self)
inner_for_in = for_in.inner_for_in
Expand Down
102 changes: 102 additions & 0 deletions libcst/metadata/tests/test_scope_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -1465,3 +1465,105 @@ def f(a):
b_global_refs = list(b_global_assignment.references)
self.assertEqual(len(b_global_refs), 1)
self.assertEqual(b_global_refs[0].node, second_print.args[0].value)

def test_ordering_comprehension(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
[a for a in [] for b in a]
[b for a in [] for b in a]
[a for a in [] for a in []]
a = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
self.assertEqual(a_param_refs, [])
first_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_comp_assignment = list(scopes[first_comp.elt]["a"])[0]
self.assertEqual(len(a_comp_assignment.references), 2)
self.assertIn(
first_comp.elt, [ref.node for ref in a_comp_assignment.references]
)

second_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
b_comp_assignment = list(scopes[second_comp.elt]["b"])[0]
self.assertEqual(len(b_comp_assignment.references), 1)
a_second_comp_assignment = list(scopes[second_comp.elt]["a"])[0]
self.assertEqual(len(a_second_comp_assignment.references), 1)

third_comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[2], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_third_comp_assignments = list(scopes[third_comp.elt]["a"])
self.assertEqual(len(a_third_comp_assignments), 2)
a_third_comp_access = list(scopes[third_comp.elt].accesses)[0]
self.assertEqual(a_third_comp_access.node, third_comp.elt)
# We record both assignments because it's impossible to know which one
# the access refers to without running the program
self.assertEqual(len(a_third_comp_access.referents), 2)
inner_for_in = third_comp.for_in.inner_for_in
self.assertIsNotNone(inner_for_in)
if inner_for_in:
self.assertIn(
inner_for_in.target,
{
ref.node
for ref in a_third_comp_access.referents
if isinstance(ref, Assignment)
},
)

a_global = (
cst.ensure_type(
cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign
)
.targets[0]
.target
)
a_global_assignment = list(scopes[a_global]["a"])[0]
a_global_refs = list(a_global_assignment.references)
self.assertEqual(a_global_refs, [])

def test_ordering_comprehension_confusing(self) -> None:
m, scopes = get_scope_metadata_provider(
"""
def f(a):
[a for a in a]
a = 1
"""
)
f = cst.ensure_type(m.body[0], cst.FunctionDef)
a_param = f.params.params[0].name
a_param_assignment = list(scopes[a_param]["a"])[0]
a_param_refs = list(a_param_assignment.references)
self.assertEqual(len(a_param_refs), 1)
comp = cst.ensure_type(
cst.ensure_type(
cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0],
cst.Expr,
).value,
cst.ListComp,
)
a_comp_assignment = list(scopes[comp.elt]["a"])[0]
self.assertEqual(list(a_param_refs)[0].node, comp.for_in.iter)
self.assertEqual(len(a_comp_assignment.references), 1)
self.assertEqual(list(a_comp_assignment.references)[0].node, comp.elt)

0 comments on commit 1326a0e

Please sign in to comment.