Skip to content

Commit

Permalink
Fix CTE generation when the snippets have trailing semicolons (#544)
Browse files Browse the repository at this point in the history
  • Loading branch information
edublancas authored May 30, 2023
1 parent 461a5b5 commit 1b2713e
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 3 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# CHANGELOG

## 0.7.7dev
* [Fix] Fix CTE generation when the snippets have trailing semicolons

## 0.7.6 (2023-05-29)

Expand Down
14 changes: 11 additions & 3 deletions src/sql/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ def __str__(self) -> str:
` (backtick)
"""
with_clause_template = Template(
"""WITH{% for name in with_ %} {{name}} AS ({{saved[name]._query}})\
"""WITH{% for name in with_ %} {{name}} AS ({{rts(saved[name]._query)}})\
{{ "," if not loop.last }}{% endfor %}{{query}}"""
)
with_clause_template_backtick = Template(
"""WITH{% for name in with_ %} `{{name}}` AS ({{saved[name]._query}})\
"""WITH{% for name in with_ %} `{{name}}` AS ({{rts(saved[name]._query)}})\
{{ "," if not loop.last }}{% endfor %}{{query}}"""
)
is_use_backtick = sql.connection.Connection.current.is_use_backtick_template()
Expand All @@ -118,10 +118,18 @@ def __str__(self) -> str:
with_clause_template_backtick if is_use_backtick else with_clause_template
)
return template.render(
query=self._query, saved=self._store._data, with_=with_all
query=self._query,
saved=self._store._data,
with_=with_all,
rts=_remove_trailing_semicolon,
)


def _remove_trailing_semicolon(query):
query_ = query.rstrip()
return query_[:-1] if query_[-1] == ";" else query


def _get_dependencies(store, keys):
"""Get a list of all dependencies to reconstruct the CTEs in keys"""
# get the dependencies for each key
Expand Down
34 changes: 34 additions & 0 deletions src/tests/test_magic_cte.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
def test_trailing_semicolons_removed_from_cte(ip):
ip.run_cell(
"""%%sql --save positive_x
SELECT * FROM number_table WHERE x > 0;
"""
)

ip.run_cell(
"""%%sql --save positive_y
SELECT * FROM number_table WHERE y > 0;
"""
)

cell_execution = ip.run_cell(
"""%%sql --save final --with positive_x --with positive_y
SELECT * FROM positive_x
UNION
SELECT * FROM positive_y;
"""
)

cell_final_query = ip.run_cell(
"%sqlrender final --with positive_x --with positive_y"
)

assert cell_execution.success
assert cell_final_query.result == (
"WITH `positive_x` AS (\nSELECT * "
"FROM number_table WHERE x > 0), `positive_y` AS (\nSELECT * "
"FROM number_table WHERE y > 0)\nSELECT * FROM positive_x\n"
"UNION\nSELECT * FROM positive_y;"
)

0 comments on commit 1b2713e

Please sign in to comment.