Skip to content

Commit

Permalink
Fix duckdb leading comments (#895)
Browse files Browse the repository at this point in the history
  • Loading branch information
marshallwhiteorg authored Oct 11, 2023
1 parent 973406b commit 3b0b1c0
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* [Fix] Comments added in SQL query to be stripped before saved as snippet (#886)
* [Fix] Fixed bug passing :NUMBER while string slicing in query (#901)
* [Fix] Disabled CTE generation when snippets are detected in a non-SELECT type query. (#651, #652)
* [Fix] Fix empty result in certain duckdb `SELECT` and `SUMMARIZE` queries with leading comments (#892)

## 0.10.2 (2023-09-22)

Expand Down
2 changes: 2 additions & 0 deletions src/sql/connection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
PLOOMBER_DOCS_LINK_STR,
default_alias_for_engine,
ResultSetCollection,
detect_duckdb_summarize_or_select,
)


Expand All @@ -17,4 +18,5 @@
"PLOOMBER_DOCS_LINK_STR",
"default_alias_for_engine",
"ResultSetCollection",
"detect_duckdb_summarize_or_select",
]
51 changes: 31 additions & 20 deletions src/sql/connection/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,34 +717,21 @@ def _connection_execute(self, query, parameters=None):
Parameters to use in the query (:variable format)
"""
parameters = parameters or {}

# we do not support multiple statements
if len(sqlparse.split(query)) > 1:
raise NotImplementedError("Only one statement is supported.")

words = query.split()

if words:
first_word_statement = words[0].lower()
else:
first_word_statement = ""

# NOTE: in duckdb db "from TABLE_NAME" is valid
# TODO: we can parse the query to ensure that it's a SELECT statement
# for example, it might start with WITH but the final statement might
# not be a SELECT
# `summarize` is added to support %sql SUMMARIZE table in duckdb
is_select = first_word_statement in {"select", "with", "from", "summarize"}

operation = partial(self._execute_with_parameters, query, parameters)
out = self._execute_with_error_handling(operation)

if self._requires_manual_commit:
# calling connection.commit() when using duckdb-engine will yield
# empty results if we commit after a SELECT statement
# see: https://github.com/Mause/duckdb_engine/issues/734
if is_select and self.dialect == "duckdb":
return out
# Calling connection.commit() when using duckdb-engine will yield
# empty results if we commit after a SELECT or SUMMARIZE statement,
# see: https://github.com/Mause/duckdb_engine/issues/734.
if self.dialect == "duckdb":
no_commit = detect_duckdb_summarize_or_select(query)
if no_commit:
return out

# in sqlalchemy 1.x, connection has no commit attribute
if IS_SQLALCHEMY_ONE:
Expand Down Expand Up @@ -1187,4 +1174,28 @@ def set_sqlalchemy_isolation_level(conn):
return False


def detect_duckdb_summarize_or_select(query):
"""
Checks if the SQL query is a DuckDB SELECT or SUMMARIZE statement.
Note:
Assumes there is only one SQL statement in the query.
"""
statements = sqlparse.parse(query)
if statements:
if len(statements) > 1:
raise NotImplementedError("Multiple statements are not supported")
stype = statements[0].get_type()
if stype == "SELECT":
return True
elif stype == "UNKNOWN":
# Further analysis is required
sql_stripped = sqlparse.format(query, strip_comments=True)
words = sql_stripped.split()
return len(words) > 0 and (
words[0].lower() == "from" or words[0].lower() == "summarize"
)
return False


atexit.register(ConnectionManager.close_all, verbose=True)
61 changes: 61 additions & 0 deletions src/tests/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
is_pep249_compliant,
default_alias_for_engine,
ResultSetCollection,
detect_duckdb_summarize_or_select,
)
from sql.warnings import JupySQLRollbackPerformed
from sql.connection import error_handling
Expand Down Expand Up @@ -1184,3 +1185,63 @@ def test_database_in_directory_that_doesnt_exist(tmp_empty, uri, expected):
SQLAlchemyConnection(engine=create_engine(uri))

assert expected in str(excinfo.value)


@pytest.mark.parametrize(
"query, expected_output",
[
("SELECT * FROM table", True),
("SUMMARIZE table", True),
("FROM table SELECT *", True),
("UPDATE table SET column=value", False),
("INSERT INTO table (column) VALUES (value)", False),
("INSERT INTO table SELECT * FROM table2", False),
(
"UPDATE table SET column=10 WHERE column IN (SELECT column FROM table2)",
False,
),
("WITH x AS (SELECT * FROM table) SELECT * FROM x", True),
("WITH x AS (SELECT * FROM table) INSERT INTO y SELECT * FROM x", False),
("", False),
("DELETE FROM table", False),
("WITH summarize AS (SELECT * FROM table) SELECT * FROM summarize", True),
(
"""
WITH summarize AS (SELECT * FROM table)
INSERT INTO y SELECT * FROM summarize
""",
False,
),
("UPDATE table SET column='SELECT'", False),
("CREATE TABLE SELECT (id INT)", False),
("CREATE TABLE x (SELECT VARCHAR(100))", False),
('INSTALL "x"', False),
("SELECT SUM(column) FILTER (WHERE column > 10) FROM table", True),
("SELECT column FROM (SELECT * FROM table WHERE column = 'SELECT') AS x", True),
# Invalid SQL returns false
("INSERT INTO table (column) VALUES ('SELECT')", False),
# Comments have no effect
("-- SELECT * FROM table", False),
("-- SELECT * FROM table\nSELECT * FROM table", True),
("-- SELECT * FROM table\nINSERT INTO table SELECT * FROM table2", False),
("-- FROM table SELECT *", False),
("-- FROM table SELECT *\n/**/FROM/**/ table SELECT */**/", True),
("-- FROM table SELECT *\nINSERT INTO table FROM table2 SELECT *", False),
(
"""
-- INSERT INTO table SELECT * FROM table2
SELECT /**/ * FROM tbl /**/
""",
True,
),
(
"""
-- INSERT INTO table SELECT * FROM table2
/**/SUMMARIZE/**/ /**//**/tbl/**/
""",
True,
),
],
)
def test_detect_duckdb_summarize_or_select(query, expected_output):
assert detect_duckdb_summarize_or_select(query) == expected_output
66 changes: 66 additions & 0 deletions src/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1989,6 +1989,72 @@ def test_accessing_previously_nonexisting_file(ip_empty, tmp_empty, capsys):
assert expected in out


expected_summarize = {
"column_name": ("memid",),
"column_type": ("BIGINT",),
"min": ("1",),
"max": ("8",),
"approx_unique": ("5",),
"avg": ("3.8",),
"std": ("2.7748873851023217",),
"q25": ("2",),
"q50": ("3",),
"q75": ("6",),
"count": (5,),
"null_percentage": ("0.0%",),
}
expected_select = {"memid": (1, 2, 3, 5, 8)}


@pytest.mark.parametrize(
"cell, expected_output",
[
("%sql /* x */ SUMMARIZE df", expected_summarize),
("%sql /*x*//*x*/ SUMMARIZE /*x*/ df", expected_summarize),
(
"""%%sql
/*x*/
SUMMARIZE df
""",
expected_summarize,
),
(
"""%%sql
/*x*/
/*x*/
-- comment
SUMMARIZE df
/*x*/
""",
expected_summarize,
),
(
"""%%sql
/*x*/
SELECT * FROM df
""",
expected_select,
),
(
"""%%sql
/*x*/
FROM df SELECT *
""",
expected_select,
),
],
)
def test_comments_in_duckdb_select_summarize(ip_empty, cell, expected_output):
ip_empty.run_cell("%sql duckdb://")
df = pd.DataFrame( # noqa: F841
data=dict(
memid=[1, 2, 3, 5, 8],
),
)
out = ip_empty.run_cell(cell).result
assert out.dict() == expected_output


@pytest.mark.parametrize(
"sql_snippet, sql_query, expected_result, raises",
[
Expand Down

0 comments on commit 3b0b1c0

Please sign in to comment.