Skip to content

Commit

Permalink
Refactored ResultSet to lazy loading (#624)
Browse files Browse the repository at this point in the history
* lazy load added, tests fixed

* tests added

* lint

* lint

* integration tests fixed

* displaylimit added to integration tests

* code cleaned
  • Loading branch information
yafimvo authored Jun 22, 2023
1 parent e5c341e commit 4b938ed
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 46 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
* [Doc] Document --persist-replace in API section (#539)
* [Fix] Fixed CI issue by updating `invalid_connection_string_duckdb` in `test_magic.py` (#631)

* [Fix] Refactored `ResultSet` to lazy loading (#470)

## 0.7.9 (2023-06-19)

* [Feature] Modified `histogram` command to support data with NULL values ([#176](https://github.com/ploomber/jupysql/issues/176))
Expand Down
147 changes: 112 additions & 35 deletions src/sql/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,41 +113,14 @@ def __init__(self, sqlaproxy, config):
self.config = config
self.keys = {}
self._results = []
self.truncated = False
self.sqlaproxy = sqlaproxy

# https://peps.python.org/pep-0249/#description
is_dbapi_results = hasattr(sqlaproxy, "description")
self.is_dbapi_results = hasattr(sqlaproxy, "description")

self.pretty = None

if is_dbapi_results:
should_try_fetch_results = True
else:
should_try_fetch_results = sqlaproxy.returns_rows

if should_try_fetch_results:
# sql alchemy results
if not is_dbapi_results:
self.keys = sqlaproxy.keys()
elif isinstance(sqlaproxy.description, Iterable):
self.keys = [i[0] for i in sqlaproxy.description]
else:
self.keys = []

if len(self.keys) > 0:
if isinstance(config.autolimit, int) and config.autolimit > 0:
self._results = sqlaproxy.fetchmany(size=config.autolimit)
else:
self._results = sqlaproxy.fetchall()

self.field_names = unduplicate_field_names(self.keys)

_style = None

self.pretty = PrettyTable(self.field_names)

if isinstance(config.style, str):
_style = prettytable.__dict__[config.style.upper()]
self.pretty.set_style(_style)

def _repr_html_(self):
_cell_with_spaces_pattern = re.compile(r"(<td>)( {2,})")
if self.pretty:
Expand All @@ -156,17 +129,17 @@ def _repr_html_(self):
# to create clickable links
result = html.unescape(result)
result = _cell_with_spaces_pattern.sub(_nonbreaking_spaces, result)
if len(self) > self.pretty.row_count:
if self.truncated:
HTML = (
'%s\n<span style="font-style:italic;text-align:center;">'
"%d rows, truncated to displaylimit of %d</span>"
"Truncated to displaylimit of %d</span>"
"<br>"
'<span style="font-style:italic;text-align:center;">'
"If you want to see more, please visit "
'<a href="https://jupysql.ploomber.io/en/latest/api/configuration.html#displaylimit">displaylimit</a>' # noqa: E501
" configuration</span>"
)
result = HTML % (result, len(self), self.pretty.row_count)
result = HTML % (result, self.pretty.row_count)
return result
else:
return None
Expand All @@ -175,7 +148,9 @@ def __len__(self):
return len(self._results)

def __iter__(self):
for result in self._results:
results = self._fetch_query_results(fetch_all=True)

for result in results:
yield result

def __str__(self, *arg, **kwarg):
Expand Down Expand Up @@ -364,6 +339,105 @@ def csv(self, filename=None, **format_params):
else:
return outfile.getvalue()

def fetch_results(self, fetch_all=False):
"""
Returns a limited representation of the query results.
Parameters
----------
fetch_all : bool default False
Return all query rows
"""
is_dbapi_results = self.is_dbapi_results
sqlaproxy = self.sqlaproxy
config = self.config

if is_dbapi_results:
should_try_fetch_results = True
else:
should_try_fetch_results = sqlaproxy.returns_rows

if should_try_fetch_results:
# sql alchemy results
if not is_dbapi_results:
self.keys = sqlaproxy.keys()
elif isinstance(sqlaproxy.description, Iterable):
self.keys = [i[0] for i in sqlaproxy.description]
else:
self.keys = []

if len(self.keys) > 0:
self._results = self._fetch_query_results(fetch_all=fetch_all)

self.field_names = unduplicate_field_names(self.keys)

_style = None

self.pretty = PrettyTable(self.field_names)

if isinstance(config.style, str):
_style = prettytable.__dict__[config.style.upper()]
self.pretty.set_style(_style)

return self

def _fetch_query_results(self, fetch_all=False):
"""
Returns rows of a query result as a list of tuples.
Parameters
----------
fetch_all : bool default False
Return all query rows
"""
sqlaproxy = self.sqlaproxy
config = self.config
_should_try_lazy_fetch = hasattr(sqlaproxy, "_soft_closed")

_should_fetch_all = (
(config.displaylimit == 0 or not config.displaylimit)
or fetch_all
or not _should_try_lazy_fetch
)

is_autolimit = isinstance(config.autolimit, int) and config.autolimit > 0
is_connection_closed = (
sqlaproxy._soft_closed if _should_try_lazy_fetch else False
)

should_return_results = is_connection_closed or (
len(self._results) > 0 and is_autolimit
)

if should_return_results:
# this means we already loaded all
# the results to self._results or we use
# autolimit and shouldn't fetch more
results = self._results
else:
if is_autolimit:
results = sqlaproxy.fetchmany(size=config.autolimit)
else:
if _should_fetch_all:
all_results = sqlaproxy.fetchall()
results = self._results + all_results
self._results = results
else:
results = sqlaproxy.fetchmany(size=config.displaylimit)

if _should_try_lazy_fetch:
# Try to fetch an extra row to find out
# if there are more results to fetch
row = sqlaproxy.fetchone()
if row is not None:
results += [row]

# Check if we have more rows to show
if config.displaylimit > 0:
self.truncated = len(results) > config.displaylimit

return results


def display_affected_rowcount(rowcount):
if rowcount > 0:
Expand Down Expand Up @@ -557,6 +631,9 @@ def run(conn, sql, config):
return df
else:
resultset = ResultSet(result, config)

# lazy load
resultset.fetch_results()
return select_df_type(resultset, config)


Expand Down
2 changes: 2 additions & 0 deletions src/tests/integration/test_generic_db_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def test_create_table_with_indexed_df(
ip_with_dynamic_db = request.getfixturevalue(ip_with_dynamic_db)
# Clean up

ip_with_dynamic_db.run_cell("%config SqlMagic.displaylimit = 0")

ip_with_dynamic_db.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
)
Expand Down
3 changes: 3 additions & 0 deletions src/tests/integration/test_mssql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ def test_cte(ip_with_MSSQL, test_table_name_dict):

def test_create_table_with_indexed_df(ip_with_MSSQL, test_table_name_dict):
# MSSQL gives error if DB doesn't exist

ip_with_MSSQL.run_cell("%config SqlMagic.displaylimit = 0")

try:
ip_with_MSSQL.run_cell(
f"%sql DROP TABLE {test_table_name_dict['new_table_from_df']}"
Expand Down
2 changes: 2 additions & 0 deletions src/tests/integration/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def test_query_count(ip_with_oracle, test_table_name_dict):

@pytest.mark.xfail(reason="Some issue with checking isidentifier part in persist")
def test_create_table_with_indexed_df(ip_with_oracle, test_table_name_dict):
ip_with_oracle.run_cell("%config SqlMagic.displaylimit = 0")

# Prepare DF
ip_with_oracle.run_cell(
f"""results = %sql SELECT * FROM {test_table_name_dict['taxi']} \
Expand Down
11 changes: 4 additions & 7 deletions src/tests/test_magic.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def test_displaylimit_default(ip):
ip.run_cell("%sql INSERT INTO number_table VALUES (4, 3)")

out = runsql(ip, "SELECT * FROM number_table;")
assert "truncated to displaylimit of 10" in out._repr_html_()
assert "Truncated to displaylimit of 10" in out._repr_html_()


def test_displaylimit(ip):
Expand All @@ -577,7 +577,7 @@ def test_displaylimit_enabled_truncated_length(ip, config_value, expected_length

ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
assert f"truncated to displaylimit of {expected_length}" in out._repr_html_()
assert f"Truncated to displaylimit of {expected_length}" in out._repr_html_()


@pytest.mark.parametrize("config_value", [(None), (0)])
Expand All @@ -591,7 +591,7 @@ def test_displaylimit_enabled_no_limit(

ip.run_cell(f"%config SqlMagic.displaylimit = {config_value}")
out = runsql(ip, "SELECT * FROM number_table;")
assert "truncated to displaylimit of " not in out._repr_html_()
assert "Truncated to displaylimit of " not in out._repr_html_()


@pytest.mark.parametrize(
Expand Down Expand Up @@ -645,10 +645,7 @@ def test_displaylimit_with_conditional_clause(
out = runsql(ip, query_clause)

if expected_truncated_length:
assert (
f"{expected_truncated_length} rows, truncated to displaylimit of 10"
in out._repr_html_()
)
assert "Truncated to displaylimit of 10" in out._repr_html_()


def test_column_local_vars(ip):
Expand Down
Loading

0 comments on commit 4b938ed

Please sign in to comment.