diff --git a/CHANGELOG.md b/CHANGELOG.md index bd68642de..e0899067b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ * [Fix] Comments added in SQL query to be stripped before saved as snippet (#886) * [Fix] Fixed bug passing :NUMBER while string slicing in query (#901) * [Fix] Disabled CTE generation when snippets are detected in a non-SELECT type query. (#651, #652) +* [Fix] Fix empty result in certain duckdb `SELECT` and `SUMMARIZE` queries with leading comments (#892) ## 0.10.2 (2023-09-22) diff --git a/src/sql/connection/__init__.py b/src/sql/connection/__init__.py index 213ae0f15..7c48e624b 100644 --- a/src/sql/connection/__init__.py +++ b/src/sql/connection/__init__.py @@ -6,6 +6,7 @@ PLOOMBER_DOCS_LINK_STR, default_alias_for_engine, ResultSetCollection, + detect_duckdb_summarize_or_select, ) @@ -17,4 +18,5 @@ "PLOOMBER_DOCS_LINK_STR", "default_alias_for_engine", "ResultSetCollection", + "detect_duckdb_summarize_or_select", ] diff --git a/src/sql/connection/connection.py b/src/sql/connection/connection.py index 96bc4d463..1d98c4cde 100644 --- a/src/sql/connection/connection.py +++ b/src/sql/connection/connection.py @@ -717,34 +717,21 @@ def _connection_execute(self, query, parameters=None): Parameters to use in the query (:variable format) """ parameters = parameters or {} - # we do not support multiple statements if len(sqlparse.split(query)) > 1: raise NotImplementedError("Only one statement is supported.") - words = query.split() - - if words: - first_word_statement = words[0].lower() - else: - first_word_statement = "" - - # NOTE: in duckdb db "from TABLE_NAME" is valid - # TODO: we can parse the query to ensure that it's a SELECT statement - # for example, it might start with WITH but the final statement might - # not be a SELECT - # `summarize` is added to support %sql SUMMARIZE table in duckdb - is_select = first_word_statement in {"select", "with", "from", "summarize"} - operation = partial(self._execute_with_parameters, query, parameters) out = self._execute_with_error_handling(operation) if self._requires_manual_commit: - # calling connection.commit() when using duckdb-engine will yield - # empty results if we commit after a SELECT statement - # see: https://github.com/Mause/duckdb_engine/issues/734 - if is_select and self.dialect == "duckdb": - return out + # Calling connection.commit() when using duckdb-engine will yield + # empty results if we commit after a SELECT or SUMMARIZE statement, + # see: https://github.com/Mause/duckdb_engine/issues/734. + if self.dialect == "duckdb": + no_commit = detect_duckdb_summarize_or_select(query) + if no_commit: + return out # in sqlalchemy 1.x, connection has no commit attribute if IS_SQLALCHEMY_ONE: @@ -1187,4 +1174,28 @@ def set_sqlalchemy_isolation_level(conn): return False +def detect_duckdb_summarize_or_select(query): + """ + Checks if the SQL query is a DuckDB SELECT or SUMMARIZE statement. + + Note: + Assumes there is only one SQL statement in the query. + """ + statements = sqlparse.parse(query) + if statements: + if len(statements) > 1: + raise NotImplementedError("Multiple statements are not supported") + stype = statements[0].get_type() + if stype == "SELECT": + return True + elif stype == "UNKNOWN": + # Further analysis is required + sql_stripped = sqlparse.format(query, strip_comments=True) + words = sql_stripped.split() + return len(words) > 0 and ( + words[0].lower() == "from" or words[0].lower() == "summarize" + ) + return False + + atexit.register(ConnectionManager.close_all, verbose=True) diff --git a/src/tests/test_connection.py b/src/tests/test_connection.py index 6fc3d2ac9..1329eebca 100644 --- a/src/tests/test_connection.py +++ b/src/tests/test_connection.py @@ -22,6 +22,7 @@ is_pep249_compliant, default_alias_for_engine, ResultSetCollection, + detect_duckdb_summarize_or_select, ) from sql.warnings import JupySQLRollbackPerformed from sql.connection import error_handling @@ -1184,3 +1185,63 @@ def test_database_in_directory_that_doesnt_exist(tmp_empty, uri, expected): SQLAlchemyConnection(engine=create_engine(uri)) assert expected in str(excinfo.value) + + +@pytest.mark.parametrize( + "query, expected_output", + [ + ("SELECT * FROM table", True), + ("SUMMARIZE table", True), + ("FROM table SELECT *", True), + ("UPDATE table SET column=value", False), + ("INSERT INTO table (column) VALUES (value)", False), + ("INSERT INTO table SELECT * FROM table2", False), + ( + "UPDATE table SET column=10 WHERE column IN (SELECT column FROM table2)", + False, + ), + ("WITH x AS (SELECT * FROM table) SELECT * FROM x", True), + ("WITH x AS (SELECT * FROM table) INSERT INTO y SELECT * FROM x", False), + ("", False), + ("DELETE FROM table", False), + ("WITH summarize AS (SELECT * FROM table) SELECT * FROM summarize", True), + ( + """ + WITH summarize AS (SELECT * FROM table) + INSERT INTO y SELECT * FROM summarize + """, + False, + ), + ("UPDATE table SET column='SELECT'", False), + ("CREATE TABLE SELECT (id INT)", False), + ("CREATE TABLE x (SELECT VARCHAR(100))", False), + ('INSTALL "x"', False), + ("SELECT SUM(column) FILTER (WHERE column > 10) FROM table", True), + ("SELECT column FROM (SELECT * FROM table WHERE column = 'SELECT') AS x", True), + # Invalid SQL returns false + ("INSERT INTO table (column) VALUES ('SELECT')", False), + # Comments have no effect + ("-- SELECT * FROM table", False), + ("-- SELECT * FROM table\nSELECT * FROM table", True), + ("-- SELECT * FROM table\nINSERT INTO table SELECT * FROM table2", False), + ("-- FROM table SELECT *", False), + ("-- FROM table SELECT *\n/**/FROM/**/ table SELECT */**/", True), + ("-- FROM table SELECT *\nINSERT INTO table FROM table2 SELECT *", False), + ( + """ + -- INSERT INTO table SELECT * FROM table2 + SELECT /**/ * FROM tbl /**/ + """, + True, + ), + ( + """ + -- INSERT INTO table SELECT * FROM table2 + /**/SUMMARIZE/**/ /**//**/tbl/**/ + """, + True, + ), + ], +) +def test_detect_duckdb_summarize_or_select(query, expected_output): + assert detect_duckdb_summarize_or_select(query) == expected_output diff --git a/src/tests/test_magic.py b/src/tests/test_magic.py index 27902d829..1ae8ca3b7 100644 --- a/src/tests/test_magic.py +++ b/src/tests/test_magic.py @@ -1989,6 +1989,72 @@ def test_accessing_previously_nonexisting_file(ip_empty, tmp_empty, capsys): assert expected in out +expected_summarize = { + "column_name": ("memid",), + "column_type": ("BIGINT",), + "min": ("1",), + "max": ("8",), + "approx_unique": ("5",), + "avg": ("3.8",), + "std": ("2.7748873851023217",), + "q25": ("2",), + "q50": ("3",), + "q75": ("6",), + "count": (5,), + "null_percentage": ("0.0%",), +} +expected_select = {"memid": (1, 2, 3, 5, 8)} + + +@pytest.mark.parametrize( + "cell, expected_output", + [ + ("%sql /* x */ SUMMARIZE df", expected_summarize), + ("%sql /*x*//*x*/ SUMMARIZE /*x*/ df", expected_summarize), + ( + """%%sql + /*x*/ + SUMMARIZE df + """, + expected_summarize, + ), + ( + """%%sql + /*x*/ + /*x*/ + -- comment + SUMMARIZE df + /*x*/ + """, + expected_summarize, + ), + ( + """%%sql + /*x*/ + SELECT * FROM df + """, + expected_select, + ), + ( + """%%sql + /*x*/ + FROM df SELECT * + """, + expected_select, + ), + ], +) +def test_comments_in_duckdb_select_summarize(ip_empty, cell, expected_output): + ip_empty.run_cell("%sql duckdb://") + df = pd.DataFrame( # noqa: F841 + data=dict( + memid=[1, 2, 3, 5, 8], + ), + ) + out = ip_empty.run_cell(cell).result + assert out.dict() == expected_output + + @pytest.mark.parametrize( "sql_snippet, sql_query, expected_result, raises", [