From 1326a0ee642db32c547540c971e1b6f79dcfbdc1 Mon Sep 17 00:00:00 2001 From: Zsolt Dollenstein Date: Tue, 1 Dec 2020 13:01:29 +0000 Subject: [PATCH] fix assignment/access ordering in comprehensions (#423) --- libcst/metadata/scope_provider.py | 4 + libcst/metadata/tests/test_scope_provider.py | 102 +++++++++++++++++++ 2 files changed, 106 insertions(+) diff --git a/libcst/metadata/scope_provider.py b/libcst/metadata/scope_provider.py index c54317963..00ae536b4 100644 --- a/libcst/metadata/scope_provider.py +++ b/libcst/metadata/scope_provider.py @@ -36,6 +36,8 @@ ) +# Comprehensions are handled separately in _visit_comp_alike due to +# the complexity of the semantics _ASSIGNMENT_LIKE_NODES = ( cst.AnnAssign, cst.AsName, @@ -976,6 +978,8 @@ def _visit_comp_alike( self.provider.set_metadata(for_in, self.scope) with self._new_scope(ComprehensionScope, node): for_in.target.visit(self) + # Things from here on can refer to the target. + self.scope._assignment_count += 1 for condition in for_in.ifs: condition.visit(self) inner_for_in = for_in.inner_for_in diff --git a/libcst/metadata/tests/test_scope_provider.py b/libcst/metadata/tests/test_scope_provider.py index 11351ff1f..0903ebf80 100644 --- a/libcst/metadata/tests/test_scope_provider.py +++ b/libcst/metadata/tests/test_scope_provider.py @@ -1465,3 +1465,105 @@ def f(a): b_global_refs = list(b_global_assignment.references) self.assertEqual(len(b_global_refs), 1) self.assertEqual(b_global_refs[0].node, second_print.args[0].value) + + def test_ordering_comprehension(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + def f(a): + [a for a in [] for b in a] + [b for a in [] for b in a] + [a for a in [] for a in []] + a = 1 + """ + ) + f = cst.ensure_type(m.body[0], cst.FunctionDef) + a_param = f.params.params[0].name + a_param_assignment = list(scopes[a_param]["a"])[0] + a_param_refs = list(a_param_assignment.references) + self.assertEqual(a_param_refs, []) + first_comp = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.ListComp, + ) + a_comp_assignment = list(scopes[first_comp.elt]["a"])[0] + self.assertEqual(len(a_comp_assignment.references), 2) + self.assertIn( + first_comp.elt, [ref.node for ref in a_comp_assignment.references] + ) + + second_comp = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[1], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.ListComp, + ) + b_comp_assignment = list(scopes[second_comp.elt]["b"])[0] + self.assertEqual(len(b_comp_assignment.references), 1) + a_second_comp_assignment = list(scopes[second_comp.elt]["a"])[0] + self.assertEqual(len(a_second_comp_assignment.references), 1) + + third_comp = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[2], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.ListComp, + ) + a_third_comp_assignments = list(scopes[third_comp.elt]["a"]) + self.assertEqual(len(a_third_comp_assignments), 2) + a_third_comp_access = list(scopes[third_comp.elt].accesses)[0] + self.assertEqual(a_third_comp_access.node, third_comp.elt) + # We record both assignments because it's impossible to know which one + # the access refers to without running the program + self.assertEqual(len(a_third_comp_access.referents), 2) + inner_for_in = third_comp.for_in.inner_for_in + self.assertIsNotNone(inner_for_in) + if inner_for_in: + self.assertIn( + inner_for_in.target, + { + ref.node + for ref in a_third_comp_access.referents + if isinstance(ref, Assignment) + }, + ) + + a_global = ( + cst.ensure_type( + cst.ensure_type(m.body[1], cst.SimpleStatementLine).body[0], cst.Assign + ) + .targets[0] + .target + ) + a_global_assignment = list(scopes[a_global]["a"])[0] + a_global_refs = list(a_global_assignment.references) + self.assertEqual(a_global_refs, []) + + def test_ordering_comprehension_confusing(self) -> None: + m, scopes = get_scope_metadata_provider( + """ + def f(a): + [a for a in a] + a = 1 + """ + ) + f = cst.ensure_type(m.body[0], cst.FunctionDef) + a_param = f.params.params[0].name + a_param_assignment = list(scopes[a_param]["a"])[0] + a_param_refs = list(a_param_assignment.references) + self.assertEqual(len(a_param_refs), 1) + comp = cst.ensure_type( + cst.ensure_type( + cst.ensure_type(f.body.body[0], cst.SimpleStatementLine).body[0], + cst.Expr, + ).value, + cst.ListComp, + ) + a_comp_assignment = list(scopes[comp.elt]["a"])[0] + self.assertEqual(list(a_param_refs)[0].node, comp.for_in.iter) + self.assertEqual(len(a_comp_assignment.references), 1) + self.assertEqual(list(a_comp_assignment.references)[0].node, comp.elt)