Skip to content

Commit

Permalink
Configure the db fixture in test cases using the predefined DBConfig.
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Nov 29, 2023
1 parent f966cc9 commit ee83876
Show file tree
Hide file tree
Showing 15 changed files with 123 additions and 66 deletions.
11 changes: 2 additions & 9 deletions superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,9 @@ def build(uri, mapping):
conn = mongomock.MongoClient()
return mapping['mongodb'](conn, name)
else:
name = uri.split('//')[0]
conn = ibis.connect(uri)
db_name = uri.split('/')[-1]
if 'ibis' in mapping:
cls_ = mapping['ibis']
elif 'sqlalchemy' in mapping:
cls_ = mapping['sqlalchemy']
conn = conn.con
else:
raise ValueError('No ibis or sqlalchemy backend specified')
return cls_(conn, db_name)
return mapping['ibis'](conn, name)


def build_compute(cfg):
Expand Down
32 changes: 19 additions & 13 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from superduperdb.components.vector_index import VectorIndex
from superduperdb.ext.pillow.encoder import pil_image

from .db_config import DBConfig

_config._CONFIG_IMMUTABLE = False


Expand All @@ -40,10 +42,6 @@
GLOBAL_TEST_N_DATA_POINTS = 250
LOCAL_TEST_N_DATA_POINTS = 5

MONGOMOCK_URI = 'mongomock:///test_db'
SQLITE_URI = 'sqlite://:memory:'


_sleep = time.sleep

SCOPE = 'function'
Expand All @@ -54,6 +52,8 @@

SDDB_USE_MONGOMOCK = 'SDDB_USE_MONGOMOCK' in os.environ
SDDB_INSTRUMENT_TIME = 'SDDB_INSTRUMENT_TIME' in os.environ
RANDOM_SEED = 42
random.seed(RANDOM_SEED)


@pytest.fixture(autouse=SDDB_INSTRUMENT_TIME, scope=SCOPE)
Expand Down Expand Up @@ -264,9 +264,9 @@ def create_db(CFG, **kwargs):
return db

add_encoders(db)
n_data = kwargs.get('n_data', LOCAL_TEST_N_DATA_POINTS)

# prepare data
n_data = kwargs.get('n_data', LOCAL_TEST_N_DATA_POINTS)
is_mongodb_bachend = isinstance(db.databackend, MongoDataBackend)
if is_mongodb_bachend:
add_random_data_to_mongo_db(db, number_data_points=n_data)
Expand All @@ -287,18 +287,24 @@ def create_db(CFG, **kwargs):
@pytest.fixture
def db(request, monkeypatch) -> Iterator[Datalayer]:
# TODO: Use pre-defined config instead of dict here
db_type, setup_config = (
request.param if hasattr(request, 'param') else ("mongodb", None)
)
setup_config = setup_config or {}
if db_type == "mongodb":
monkeypatch.setattr(CFG, 'data_backend', MONGOMOCK_URI)
elif db_type == "sqldb":
monkeypatch.setattr(CFG, 'data_backend', SQLITE_URI)
param = request.param if hasattr(request, 'param') else DBConfig.mongodb
if isinstance(param, dict):
# e.g. @pytest.mark.parametrize("db", [DBConfig.sqldb], indirect=True)
setup_config = param.copy()
elif isinstance(param, tuple):
# e.g. @pytest.mark.parametrize(
# "db", [(DBConfig.sqldb, {'n_data': 10})], indirect=True)
setup_config = param[0].copy()
setup_config.update(param[1] or {})
else:
raise ValueError(f'Unsupported param type: {type(param)}')

monkeypatch.setattr(CFG, 'data_backend', setup_config['data_backend'])

db = create_db(CFG, **setup_config)
yield db

db_type = setup_config.get('db_type', 'mongodb')
if db_type == "mongodb":
db.drop(force=True)
elif db_type == "sqldb":
Expand Down
41 changes: 41 additions & 0 deletions test/db_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
MONGOMOCK_URI = 'mongomock:///test_db'
SQLITE_URI = 'sqlite://:memory:'
N_DATA_POINTS = 5


class DBConfig:
# Common configuration parameters
COMMON_CONFIG = {
'empty': False,
'add_encoders': True,
'add_data': True,
'add_models': True,
'add_vector_index': True,
'n_data': N_DATA_POINTS,
}

# Base configuration for MongoDB and SQL databases
_mongodb_base = {
'db_type': 'mongodb',
'data_backend': MONGOMOCK_URI,
**COMMON_CONFIG,
}
_sqldb_base = {'db_type': 'sqldb', 'data_backend': SQLITE_URI, **COMMON_CONFIG}

# Configurations for an empty database
mongodb_empty = {**_mongodb_base, 'empty': True}
sqldb_empty = {**_sqldb_base, 'empty': True}

# Full database configurations including encoder, data, model, and vector_index
mongodb = {**_mongodb_base}
sqldb = {**_sqldb_base}

# Configurations without vector_index
mongodb_no_vector_index = {**_mongodb_base, 'add_vector_index': False}
sqldb_no_vector_index = {**_sqldb_base, 'add_vector_index': False}

# Configurations with only encoder and data
mongodb_data = {**_mongodb_base, 'add_models': False, 'add_vector_index': False}
sqldb_data = {**_sqldb_base, 'add_models': False, 'add_vector_index': False}

# Additional frequently used presets can be added as needed...
7 changes: 4 additions & 3 deletions test/unittest/backends/base/test_query.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
from test.db_config import DBConfig

import pytest

from superduperdb.backends.mongodb.query import Collection
from superduperdb.base.document import Document


@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True)
@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_execute_insert_and_find(db):
collection = Collection('documents')
collection.insert_many([Document({'this': 'is a test'})]).execute(db)
r = collection.find_one().execute(db)
assert r['this'] == 'is a test'


@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True)
@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_execute_complex_query(db):
collection = Collection('documents')
collection.insert_many(
Expand All @@ -25,7 +27,6 @@ def test_execute_complex_query(db):
assert sorted(cur_this) == sorted(expected)


@pytest.mark.parametrize("db", [('mongodb', None)], indirect=True)
def test_execute_like_queries(db):
collection = Collection('documents')
# get a data point for testing
Expand Down
3 changes: 2 additions & 1 deletion test/unittest/backends/mongodb/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
torch = None

import random
from test.db_config import DBConfig

from superduperdb.backends.mongodb.query import Collection
from superduperdb.base.document import Document
Expand Down Expand Up @@ -63,7 +64,7 @@ def test_replace(db):


@pytest.mark.skipif(not torch, reason='Torch not installed')
@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True)
@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
def test_insert_from_uris(db, image_url):
import PIL

Expand Down
Loading

0 comments on commit ee83876

Please sign in to comment.