Skip to content

Commit

Permalink
Add UT for token expiry reconnection test
Browse files (browse the repository at this point in the history)
  • Loading branch information
kartik4949 committed Jun 18, 2024
1 parent 6dbf7ad commit 564aadc
Show file tree
Hide file tree
Showing 11 changed files with 75 additions and 22 deletions.
2 changes: 1 addition & 1 deletion superduperdb/backends/base/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def wrapper(*args, **kwargs):
return attr(*args, **kwargs)
except Exception as e:
error_message = str(e).lower()
if 'expire' in error_message or 'token' in error_message:
if 'expire' in error_message and 'token' in error_message:
logging.warn(
f"Authentication expiry detected: {e}. "
"Attempting to reconnect..."
Expand Down
8 changes: 7 additions & 1 deletion superduperdb/backends/base/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,15 @@ class MetaDataStore(ABC):
:param uri: URI to the databackend database.
:param flavour: Flavour of the databackend.
:param callback: Optional callback to create connection.
"""

def __init__(self, uri: str, flavour: t.Optional[str] = None):
def __init__(
self,
uri: t.Optional[str] = None,
flavour: t.Optional[str] = None,
callback: t.Optional[t.Callable] = None,
):
self.uri = uri
self.flavour = flavour

Expand Down
6 changes: 5 additions & 1 deletion superduperdb/backends/ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,11 @@ def build_artifact_store(self):

def build_metadata(self):
"""Build metadata for the database."""
return MetaDataStoreProxy(SQLAlchemyMetadata(self.uri))

def callback():
return self.conn.con, self.name

return MetaDataStoreProxy(SQLAlchemyMetadata(callback=callback))

def insert(self, table_name, raw_documents):
"""Insert data into the database.
Expand Down
4 changes: 1 addition & 3 deletions superduperdb/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pandas

from superduperdb import Document
from superduperdb import Document, logging
from superduperdb.backends.base.query import (
Query,
applies_to,
Expand Down Expand Up @@ -379,8 +379,6 @@ def drop_outputs(self, predict_id: str, embedded=False):
:param predict_ids: The ids of the predictions to select.
"""
if embedded:
from superduperdb import logging

logging.warn(
'Outputs cannot be emebedded in sql, dropping the entire output table.'
)
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/backends/mongodb/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def db(self):

def build_metadata(self):
"""Build the metadata store for the data backend."""
return MetaDataStoreProxy(MongoMetaDataStore(self.uri))
return MetaDataStoreProxy(MongoMetaDataStore(callback=self.connection_callback))

def build_artifact_store(self):
"""Build the artifact store for the data backend."""
Expand Down
20 changes: 16 additions & 4 deletions superduperdb/backends/mongodb/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,26 @@ class MongoMetaDataStore(MetaDataStore):
:param conn: MongoDB client connection
:param name: Name of database to host filesystem
:param callback: Optional callback to create connection.
"""

def __init__(self, uri: str, flavour: t.Optional[str] = None):
def __init__(
self,
uri: t.Optional[str] = None,
flavour: t.Optional[str] = None,
callback: t.Optional[t.Callable] = None,
):
super().__init__(uri=uri, flavour=flavour)
from .data_backend import _connection_callback

self.conn, self.name = _connection_callback(uri, flavour)
self.connection_callback = lambda: _connection_callback(uri, flavour)
if callback:
self.connection_callback = callback
else:
assert uri
from .data_backend import _connection_callback

self.connection_callback = lambda: _connection_callback(uri, flavour)

self.conn, self.name = self.connection_callback()
self._setup()

def _setup(self):
Expand Down
19 changes: 15 additions & 4 deletions superduperdb/backends/sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,25 @@ class SQLAlchemyMetadata(MetaDataStore):
:param conn: connection to the meta-data store
:param name: Name to identify DB using the connection
:param callback: Optional callback to create connection.
"""

def __init__(self, uri: str, flavour: t.Optional[str] = None):
def __init__(
self,
uri: t.Optional[str] = None,
flavour: t.Optional[str] = None,
callback: t.Optional[t.Callable] = None,
):
super().__init__(uri=uri, flavour=flavour)
assert isinstance(uri, str)

sql_conn = create_engine(uri)
name = uri.split('//')[0]
if callback:
self.connection_callback = callback
else:
assert isinstance(uri, str)
name = uri.split('//')[0]
self.connection_callback = lambda: (create_engine(uri), name)

sql_conn, name = self.connection_callback()

self.name = name
self.conn = sql_conn
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def create(cls, uri, mapping: t.Dict):


class _DataBackendMatcher(_MetaDataMatcher):
patterns = {**_MetaDataMatcher.patterns, r'\.csv$': ('ibis', 'pandas')}
patterns = {**_MetaDataMatcher.patterns, r'.*\.csv$': ('ibis', 'pandas')}

@classmethod
def create(cls, uri, mapping: t.Dict):
Expand Down
4 changes: 3 additions & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,9 @@ def valid_dataset(db):
select = MongoQuery(table='documents').find({'_fold': 'valid'})
else:
table = db['documents']
select = table.select('id', 'x', 'y', 'z').filter(table._fold == 'valid')
select = table.select('id', '_fold', 'x', 'y', 'z').filter(
table._fold == 'valid'
)
d = Dataset(
identifier='my_valid',
select=select,
Expand Down
27 changes: 24 additions & 3 deletions test/unittest/base/test_datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import dataclasses as dc
from test.db_config import DBConfig
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from superduperdb.backends.ibis.field_types import dtype
from superduperdb.backends.mongodb.data_backend import MongoDataBackend
Expand Down Expand Up @@ -667,6 +667,7 @@ def test_reload_dataset(db):
"db",
[
(DBConfig.sqldb_no_vector_index, {'n_data': n_data_points}),
(DBConfig.mongodb_no_vector_index, {'n_data': n_data_points}),
],
indirect=True,
)
Expand All @@ -675,7 +676,9 @@ def test_dataset(db):
select = MongoQuery(table='documents').find({'_fold': 'valid'})
else:
table = db['documents']
select = table.select('id', 'x', 'y', 'z').filter(table._fold == 'valid')
select = table.select('id', '_fold', 'x', 'y', 'z').filter(
table._fold == 'valid'
)

d = Dataset(
identifier='test_dataset',
Expand All @@ -687,4 +690,22 @@ def test_dataset(db):
assert len(dataset.data) == len(list(db.execute(dataset.select)))


# TODO: add UT for task workflow
@pytest.mark.parametrize(
    "db",
    [(DBConfig.mongodb_data, {'n_data': 6})],
    indirect=True,
)
def test_retry_on_token_expiry(db):
    """Check that a token-expiry style failure leads to exactly one reconnect.

    ``auto_create_table_schema`` is stubbed so that its first call raises an
    exception whose message mentions both "token" and "expired", and its
    second call returns ``None``. ``reconnect`` is stubbed as well so that
    the number of reconnect attempts can be counted after the insert.
    """
    # NOTE(review): presumably auto-schema is what routes the insert through
    # ``auto_create_table_schema`` -- confirm against the datalayer.
    db.cfg.auto_schema = True

    # Replace the reconnect hook with a mock so its calls can be counted.
    db.databackend.reconnect = MagicMock()

    # First call fails with an expiry-looking error, second call succeeds.
    expiry_error = Exception("The connection token has been expired already")
    db.databackend.auto_create_table_schema = MagicMock(
        side_effect=[expiry_error, None]
    )

    # Run a single-document insert through the datalayer.
    insert_query = MongoQuery(table='documents').insert([{'x': 1}])
    db.execute(insert_query)

    assert db.databackend.reconnect.call_count == 1
3 changes: 1 addition & 2 deletions test/unittest/ext/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,8 +79,7 @@ def test_fit(db, valid_dataset, model):
if isinstance(db.databackend.type, MongoDataBackend):
select = MongoQuery(table='documents').find()
else:
select = db['documents'].select('id', 'x', 'y', 'z', '_fold')

select = db['documents'].select('id', '_fold', 'x', 'y', 'z')
trainer = TorchTrainer(
key=('x', 'y'),
select=select,
Expand Down

0 comments on commit 564aadc

Please sign in to comment.