From 1d2b97b515383f400a80e140e09a036d7eb9ca66 Mon Sep 17 00:00:00 2001 From: SangGyu An Date: Thu, 6 Jul 2023 09:40:32 -0700 Subject: [PATCH] snippets display improvement --- CHANGELOG.md | 1 + src/sql/cmd/snippets.py | 9 +++-- src/tests/test_magic_cmd.py | 67 +++++++++++++++++++++++++++++++++++-- 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a28ea3025..4773429f1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,7 @@ * [Feature] Support flexible spacing `myvar=<<` operator ([#525](https://github.com/ploomber/jupysql/issues/525)) * [Feature] Added a line under `ResultSet` to distinguish it from data frame and error message when invalid operations are performed (#468) * [Feature] Moved `%sqlrender` feature to `%sqlcmd snippets` (#647) +* [Feature] Added tables listing stored snippets when `%sqlcmd snippets` is called (#648) * [Doc] Modified integrations content to ensure they're all consistent (#523) * [Doc] Document --persist-replace in API section (#539) diff --git a/src/sql/cmd/snippets.py b/src/sql/cmd/snippets.py index 08f6a5720..fe6e19eb9 100644 --- a/src/sql/cmd/snippets.py +++ b/src/sql/cmd/snippets.py @@ -2,6 +2,7 @@ from sql.exceptions import UsageError from sql.cmd.cmd_utils import CmdParser from sql.store import store +from sql.display import Table, Message def _modify_display_msg(key, remaining_keys, dependent_keys=None): @@ -63,8 +64,8 @@ def snippets(others): help="Force delete all stored snippets", required=False, ) + all_snippets = util.get_all_keys() if len(others) == 1: - all_snippets = util.get_all_keys() if others[0] in all_snippets: return str(store[others[0]]) @@ -80,7 +81,11 @@ def snippets(others): args = parser.parse_args(others) SNIPPET_ARGS = [args.delete, args.delete_force, args.delete_force_all] if SNIPPET_ARGS.count(None) == len(SNIPPET_ARGS): - return ", ".join(util.get_all_keys()) + if len(all_snippets) == 0: + return Message("No snippets stored") + else: + return Table(["Stored snippets"], [[snippet] for snippet in all_snippets]) + if args.delete: deps = util.get_key_dependents(args.delete) if deps: diff --git a/src/tests/test_magic_cmd.py b/src/tests/test_magic_cmd.py index 1e5a1c7be..1a6adc0e1 100644 --- a/src/tests/test_magic_cmd.py +++ b/src/tests/test_magic_cmd.py @@ -8,6 +8,7 @@ from sql.connection import Connection from sql.store import store from sql.inspect import _is_numeric +from sql.display import Table, Message VALID_COMMANDS_MESSAGE = ( @@ -62,6 +63,14 @@ def ip_snippets(ip): yield ip +@pytest.fixture +def test_snippet_ip(ip): + for key in list(store): + del store[key] + ip.run_cell("%sql sqlite://") + yield ip + + @pytest.mark.parametrize( "cell, error_type, error_message", [ @@ -416,9 +425,61 @@ def test_test_error(ip, cell, error_type, error_message): assert str(out.error_in_exec) == error_message -def test_snippet(ip_snippets): - out = ip_snippets.run_cell("%sqlcmd snippets").result - assert "high_price, high_price_a, high_price_b" in out +@pytest.mark.parametrize( + "cmds, result", + [ + (["%sqlcmd snippets"], Message("No snippets stored")), + ( + [ + """%%sql --save test_snippet --no-execute +SELECT * FROM "test_store" WHERE price >= 1.50 +""", + "%sqlcmd snippets", + ], + Table( + ["Stored snippets"], + [["test_snippet"]], + ), + ), + ( + [ + """%%sql --save test_snippet --no-execute +SELECT * FROM "test_store" WHERE price >= 1.50 +""", + """%%sql --save test_snippet_a --no-execute +SELECT * FROM "test_snippet" WHERE symbol == 'a' +""", + "%sqlcmd snippets", + ], + Table( + ["Stored snippets"], + [["test_snippet"], ["test_snippet_a"]], + ), + ), + ( + [ + """%%sql --save test_snippet --no-execute +SELECT * FROM "test_store" WHERE price >= 1.50 +""", + """%%sql --save test_snippet_a --no-execute +SELECT * FROM "test_snippet" WHERE symbol == 'a' +""", + """%%sql --save test_snippet_b --no-execute +SELECT * FROM "test_snippet" WHERE symbol == 'b' +""", + "%sqlcmd snippets", + ], + Table( + ["Stored snippets"], + [["test_snippet"], ["test_snippet_a"], ["test_snippet_b"]], + ), + ), + ], +) +def test_snippet(test_snippet_ip, cmds, result): + out = [test_snippet_ip.run_cell(cmd) for cmd in cmds][-1].result + assert str(out) == str(result) + assert isinstance(out, type(result)) @pytest.mark.parametrize(