Skip to content

Commit

Permalink
Fix bug involving calling set on a template parameter within all bran…
Browse files Browse the repository at this point in the history
…ches of an if block (#1665)
  • Loading branch information
davidism authored Dec 21, 2024
2 parents fbc3a69 + 66587ce commit c8fdce1
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 10 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ Unreleased
filter. :issue:`1624`
- Using ``set`` for multiple assignment (``a, b = 1, 2``) does not fail when the
target is a namespace attribute. :issue:`1413`
- Using ``set`` in all branches of ``{% if %}{% elif %}{% else %}`` blocks
does not cause the variable to be considered initially undefined.
:issue:`1253`


Version 3.1.4
Expand Down
17 changes: 7 additions & 10 deletions src/jinja2/idtracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,23 +121,20 @@ def load(self, name: str) -> None:
self._define_ref(name, load=(VAR_LOAD_RESOLVE, name))

def branch_update(self, branch_symbols: t.Sequence["Symbols"]) -> None:
stores: t.Dict[str, int] = {}
stores: t.Set[str] = set()

for branch in branch_symbols:
for target in branch.stores:
if target in self.stores:
continue
stores[target] = stores.get(target, 0) + 1
stores.update(branch.stores)

stores.difference_update(self.stores)

for sym in branch_symbols:
self.refs.update(sym.refs)
self.loads.update(sym.loads)
self.stores.update(sym.stores)

for name, branch_count in stores.items():
if branch_count == len(branch_symbols):
continue

target = self.find_ref(name) # type: ignore
for name in stores:
target = self.find_ref(name)
assert target is not None, "should not happen"

if self.parent is not None:
Expand Down
10 changes: 10 additions & 0 deletions tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,16 @@ def is_foo(ctx, s):
assert tmpl.render() == "foo"


def test_load_parameter_when_set_in_all_if_branches(env):
tmpl = env.from_string(
"{% if True %}{{ a.b }}{% set a = 1 %}"
"{% elif False %}{% set a = 2 %}"
"{% else %}{% set a = 3 %}{% endif %}"
"{{ a }}"
)
assert tmpl.render(a={"b": 0}) == "01"


@pytest.mark.parametrize("unicode_char", ["\N{FORM FEED}", "\x85"])
def test_unicode_whitespace(env, unicode_char):
content = "Lorem ipsum\n" + unicode_char + "\nMore text"
Expand Down

0 comments on commit c8fdce1

Please sign in to comment.