Standardize fixture creation using configuration settings, with support for MongoDB and SQL databases.
jieguangzhou committed Nov 28, 2023
1 parent e8b91df commit fa9a02d
Showing 17 changed files with 514 additions and 340 deletions.
20 changes: 1 addition & 19 deletions superduperdb/backends/sqlalchemy/metadata.py
@@ -3,7 +3,6 @@
 from contextlib import contextmanager
 
 import click
-from bson import ObjectId
 from sqlalchemy import JSON, Boolean, Column, DateTime, Integer, String
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import relationship, sessionmaker
@@ -339,9 +338,7 @@ def _update_object(
     def create_job(self, info: t.Dict):
         try:
             with self.session_context() as session:
-                job = Job(**info)
-                convert_object_id_type(job)
-                session.add(job)
+                session.add(Job(**info))
         except Exception as e:
             raise exceptions.MetaDataStoreJobException(
                 'Error while creating job in metadata store'
@@ -457,18 +454,3 @@ def get_model_queries(self, model: str):
                 {'query_id': id, 'query': query, 'sql': query.repr_()}
             )
         return unpacked_queries
-
-
-def convert_object_id_type(job: Job):
-    kwargs = job.kwargs
-    ids = kwargs.get('ids', [])
-    if not isinstance(ids, list):
-        return
-
-    new_ids = []
-    for id in ids:
-        if isinstance(id, ObjectId):
-            new_ids.append(str(id))
-        else:
-            new_ids.append(id)
-    kwargs['ids'] = new_ids
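For reference, the deleted convert_object_id_type helper only stringified bson.ObjectId values inside a job's kwargs['ids']; the simplified create_job above drops that normalization, presumably because job IDs now arrive as plain strings under the standardized fixtures. A minimal standalone sketch of the removed behavior (dict-based here rather than operating on a Job row):

    from bson import ObjectId

    def convert_object_id_type(kwargs: dict) -> None:
        # Replace each ObjectId in kwargs['ids'] with its 24-character hex string.
        ids = kwargs.get('ids', [])
        if not isinstance(ids, list):
            return
        kwargs['ids'] = [str(i) if isinstance(i, ObjectId) else i for i in ids]

    kwargs = {'ids': [ObjectId('5f9f1b9b9c9d1b1b9c9d1b1b'), 'already-a-string']}
    convert_object_id_type(kwargs)
    print(kwargs['ids'])  # ['5f9f1b9b9c9d1b1b9c9d1b1b', 'already-a-string']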
7 changes: 1 addition & 6 deletions superduperdb/base/build.py
@@ -50,19 +50,14 @@ def build(uri, mapping):
         return mapping['mongodb'](conn, name)
     else:
         conn = ibis.connect(uri)
-        db_name = conn.current_database
+        db_name = uri.split('/')[-1]
         if 'ibis' in mapping:
             cls_ = mapping['ibis']
         elif 'sqlalchemy' in mapping:
             cls_ = mapping['sqlalchemy']
             conn = conn.con
         else:
             raise ValueError('No ibis or sqlalchemy backend specified')
 
-    # if ':memory:' in uri:
-    #     conn = uri
-    #     db_name = None
-
     return cls_(conn, db_name)
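Note the trade-off in the replaced line: conn.current_database asks the ibis backend for its active database, while the new code derives the name from the connection URI itself. The parsing is just the last path segment; a quick illustration (the MySQL URI is borrowed from the commented-out config entries in test/conftest.py below):

    for uri in ['sqlite://:memory:', 'mysql://root:root123@localhost:3306/test_db']:
        print(uri.split('/')[-1])
    # :memory:
    # test_db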
140 changes: 99 additions & 41 deletions test/conftest.py
@@ -10,6 +10,9 @@
 
 import superduperdb as s
 from superduperdb import CFG, logging
+from superduperdb.backends.ibis.field_types import dtype
+from superduperdb.backends.ibis.query import Table
+from superduperdb.backends.mongodb.data_backend import MongoDataBackend
 from superduperdb.backends.mongodb.query import Collection
 
 # ruff: noqa: E402
@@ -19,6 +22,7 @@
 from superduperdb.base.document import Document
 from superduperdb.components.dataset import Dataset
 from superduperdb.components.listener import Listener
+from superduperdb.components.schema import Schema
 from superduperdb.components.vector_index import VectorIndex
 from superduperdb.ext.pillow.encoder import pil_image

Expand All @@ -37,6 +41,8 @@
LOCAL_TEST_N_DATA_POINTS = 5

MONGOMOCK_URI = 'mongomock:///test_db'
SQLITE_URI = 'sqlite://:memory:'


_sleep = time.sleep

@@ -120,13 +126,54 @@ def valid_dataset():
     return d
 
 
-def add_random_data(
+def add_random_data_to_sql_db(
+    db: Datalayer,
+    table_name: str = 'documents',
+    number_data_points: int = GLOBAL_TEST_N_DATA_POINTS,
+):
+    float_tensor = db.encoders['torch.float32[32]']
+    data = []
+
+    schema = Schema(
+        identifier=table_name,
+        fields={
+            'id': dtype('str'),
+            'x': float_tensor,
+            'y': dtype('int32'),
+            'z': float_tensor,
+        },
+    )
+    t = Table(identifier=table_name, schema=schema)
+    db.add(t)
+
+    for i in range(number_data_points):
+        x = torch.randn(32)
+        y = int(random.random() > 0.5)
+        z = torch.randn(32)
+        data.append(
+            Document(
+                {
+                    'id': str(i),
+                    'x': x,
+                    'y': y,
+                    'z': z,
+                }
+            )
+        )
+    db.execute(
+        t.insert(data),
+        refresh=False,
+    )
+
+
+def add_random_data_to_mongo_db(
     db: Datalayer,
     collection_name: str = 'documents',
     number_data_points: int = GLOBAL_TEST_N_DATA_POINTS,
 ):
     float_tensor = db.encoders['torch.float32[32]']
     data = []
 
     for i in range(number_data_points):
         x = torch.randn(32)
         y = int(random.random() > 0.5)
@@ -141,11 +188,10 @@
             )
         )
 
-    if data:
-        db.execute(
-            Collection(collection_name).insert_many(data),
-            refresh=False,
-        )
+    db.execute(
+        Collection(collection_name).insert_many(data),
+        refresh=False,
+    )
 
 
 def add_encoders(db: Datalayer):
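The asymmetry between the two helpers is the core of this commit: MongoDB collections are schemaless, so insert_many works directly, while the ibis/SQL path must register a typed table first. A condensed sketch of that SQL-side registration, reusing the identifiers from the hunk above (assumes an existing Datalayer instance db):

    from superduperdb.backends.ibis.field_types import dtype
    from superduperdb.backends.ibis.query import Table
    from superduperdb.components.schema import Schema

    # Declare the column types up front, then register the table on the Datalayer.
    schema = Schema(
        identifier='documents',
        fields={'id': dtype('str'), 'y': dtype('int32')},
    )
    db.add(Table(identifier='documents', schema=schema))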
@@ -174,16 +220,25 @@ def add_vector_index(
     db: Datalayer, collection_name='documents', identifier='test_vector_search'
 ):
     # TODO: Support configurable key and model
+    is_mongodb_backend = isinstance(db.databackend, MongoDataBackend)
+    if is_mongodb_backend:
+        select_x = Collection(collection_name).find()
+        select_z = Collection(collection_name).find()
+    else:
+        table = db.load('table', collection_name).to_query()
+        select_x = table.select('id', 'x')
+        select_z = table.select('id', 'z')
+
     db.add(
         Listener(
-            select=Collection(collection_name).find(),
+            select=select_x,
             key='x',
             model='linear_a',
         )
     )
     db.add(
         Listener(
-            select=Collection(collection_name).find(),
+            select=select_z,
             key='z',
             model='linear_a',
         )
@@ -202,49 +257,52 @@ def image_url():
     return f'file://{path}'
 
 
-def setup_db(db, **kwargs):
+def create_db(CFG, **kwargs):
     # TODO: support more parameters to control the setup
+    db = build_datalayer(CFG)
     if kwargs.get('empty', False):
         return db
 
     add_encoders(db)
-    n_data = kwargs.get('n_data', GLOBAL_TEST_N_DATA_POINTS)
-    add_random_data(db, number_data_points=n_data)
+    n_data = kwargs.get('n_data', LOCAL_TEST_N_DATA_POINTS)
+
+    # prepare data
+    is_mongodb_backend = isinstance(db.databackend, MongoDataBackend)
+    if is_mongodb_backend:
+        add_random_data_to_mongo_db(db, number_data_points=n_data)
+    else:
+        add_random_data_to_sql_db(db, number_data_points=n_data)
 
+    # prepare models
     if kwargs.get('add_models', True):
         add_models(db)
 
+    # prepare vector index
     if kwargs.get('add_vector_index', True):
         add_vector_index(db)


-@pytest.fixture(scope='session')
-def db() -> Datalayer:
-    db = build_datalayer(CFG, data_backend=MONGOMOCK_URI)
-    setup_db(db)
-    return db
-
-
-@pytest.fixture
-def local_db(request) -> Datalayer:
-    db = build_datalayer(CFG, data_backend=MONGOMOCK_URI)
-    setup_config = getattr(request, 'param', {'n_data': LOCAL_TEST_N_DATA_POINTS})
-    setup_db(db, **setup_config)
-    return db
-
-
 @pytest.fixture
-def local_empty_db(request, monkeypatch) -> Datalayer:
+def db(request, monkeypatch) -> Iterator[Datalayer]:
+    # TODO: Use pre-defined config instead of dict here
+    db_type, setup_config = (
+        request.param if hasattr(request, 'param') else ("mongodb", None)
+    )
+    setup_config = setup_config or {}
+    if db_type == "mongodb":
+        monkeypatch.setattr(CFG, 'data_backend', MONGOMOCK_URI)
+    elif db_type == "sqldb":
+        monkeypatch.setattr(CFG, 'data_backend', SQLITE_URI)
+
     for key, value in DB_CONFIGS.items():
         monkeypatch.setattr(CFG, key, value)
-    db = build_datalayer(CFG)
+    db = create_db(CFG, **setup_config)
     yield db
-    db.drop(force=True)
+
+    if db_type == "mongodb":
+        db.drop(force=True)
+    elif db_type == "sqldb":
+        db.artifact_store.drop(force=True)
+        tables = db.databackend.conn.list_tables()
+        for table in tables:
+            db.databackend.conn.drop_table(table, force=True)
 
 
 DB_CONFIGS = {
-    'data_backend': MONGOMOCK_URI,
-    # 'metadata_store': "sqlite://:memory:",
     # 'data_backend': "sqlite://:memory:",
     # 'data_backend': "mongodb://testmongodbuser:testmongodbpassword@localhost:27018/test_db",
     # 'metadata_store': "mysql://root:root123@localhost:3306/test_db",
+    'metadata_store': "sqlite://:memory:",
     # 'data_backend': "mysql://root:root123@localhost:3306/test_db",
     # 'metadata_store': "sqlite://mydb.sqlite",
     'artifact_store': 'filesystem:///tmp/superduperdb_test',
 }
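With the old db/local_db/local_empty_db trio collapsed into this single fixture, a test now selects its backend and setup through indirect parametrization, as the test_query.py changes below show. A minimal sketch of a test that runs against both backends (test body hypothetical; 'n_data' is one of the knobs create_db reads):

    import pytest

    @pytest.mark.parametrize(
        "db",
        [("mongodb", {"n_data": 5}), ("sqldb", {"n_data": 5})],
        indirect=True,
    )
    def test_smoke_both_backends(db):
        # The fixture has already built the Datalayer and inserted n_data
        # random documents; teardown runs after the test finishes.
        assert db is not None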
17 changes: 11 additions & 6 deletions test/unittest/backends/base/test_query.py
@@ -1,26 +1,31 @@
+import pytest
+
 from superduperdb.backends.mongodb.query import Collection
 from superduperdb.base.document import Document
 
 
-def test_execute_insert_and_find(local_empty_db):
+@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True)
+def test_execute_insert_and_find(db):
     collection = Collection('documents')
-    collection.insert_many([Document({'this': 'is a test'})]).execute(local_empty_db)
-    r = collection.find_one().execute(local_empty_db)
+    collection.insert_many([Document({'this': 'is a test'})]).execute(db)
+    r = collection.find_one().execute(db)
     assert r['this'] == 'is a test'
 
 
-def test_execute_complex_query(local_empty_db):
+@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True)
+def test_execute_complex_query(db):
     collection = Collection('documents')
     collection.insert_many(
         [Document({'this': f'is a test {i}'}) for i in range(100)]
-    ).execute(local_empty_db)
+    ).execute(db)
 
-    cur = collection.find().limit(10).sort('this', -1).execute(local_empty_db)
+    cur = collection.find().limit(10).sort('this', -1).execute(db)
     expected = [f'is a test {i}' for i in range(99, 89, -1)]
     cur_this = [r['this'] for r in cur]
     assert sorted(cur_this) == sorted(expected)
 
 
+@pytest.mark.parametrize("db", [('mongodb', None)], indirect=True)
 def test_execute_like_queries(db):
     collection = Collection('documents')
     # get a data point for testing