From ee838766af76e160a11c222d82d4adf1d5346c94 Mon Sep 17 00:00:00 2001 From: JieguangZhou Date: Wed, 29 Nov 2023 11:30:45 +0800 Subject: [PATCH] Configure the db fixture in test cases using the predefined DBConfig. --- superduperdb/base/build.py | 11 +--- test/conftest.py | 32 ++++++----- test/db_config.py | 41 ++++++++++++++ test/unittest/backends/base/test_query.py | 7 ++- .../unittest/backends/mongodb/test_queries.py | 3 +- test/unittest/base/test_datalayer.py | 55 ++++++++++--------- test/unittest/base/test_documents.py | 4 +- test/unittest/component/test_model.py | 3 +- test/unittest/component/test_serialization.py | 4 +- test/unittest/ext/test_openai.py | 3 +- test/unittest/ext/test_sklearn.py | 3 +- test/unittest/ext/test_torch.py | 4 +- test/unittest/ext/test_transformers.py | 6 +- test/unittest/ext/test_vanilla.py | 10 ++-- test/unittest/misc/test_downloaders.py | 3 +- 15 files changed, 123 insertions(+), 66 deletions(-) create mode 100644 test/db_config.py diff --git a/superduperdb/base/build.py b/superduperdb/base/build.py index bf40b4aa96..7ac8f6e6d8 100644 --- a/superduperdb/base/build.py +++ b/superduperdb/base/build.py @@ -55,16 +55,9 @@ def build(uri, mapping): conn = mongomock.MongoClient() return mapping['mongodb'](conn, name) else: + name = uri.split('//')[0] conn = ibis.connect(uri) - db_name = uri.split('/')[-1] - if 'ibis' in mapping: - cls_ = mapping['ibis'] - elif 'sqlalchemy' in mapping: - cls_ = mapping['sqlalchemy'] - conn = conn.con - else: - raise ValueError('No ibis or sqlalchemy backend specified') - return cls_(conn, db_name) + return mapping['ibis'](conn, name) def build_compute(cfg): diff --git a/test/conftest.py b/test/conftest.py index 45ec089753..5ddb4b36b2 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -26,6 +26,8 @@ from superduperdb.components.vector_index import VectorIndex from superduperdb.ext.pillow.encoder import pil_image +from .db_config import DBConfig + _config._CONFIG_IMMUTABLE = False @@ -40,10 +42,6 @@ GLOBAL_TEST_N_DATA_POINTS = 250 LOCAL_TEST_N_DATA_POINTS = 5 -MONGOMOCK_URI = 'mongomock:///test_db' -SQLITE_URI = 'sqlite://:memory:' - - _sleep = time.sleep SCOPE = 'function' @@ -54,6 +52,8 @@ SDDB_USE_MONGOMOCK = 'SDDB_USE_MONGOMOCK' in os.environ SDDB_INSTRUMENT_TIME = 'SDDB_INSTRUMENT_TIME' in os.environ +RANDOM_SEED = 42 +random.seed(RANDOM_SEED) @pytest.fixture(autouse=SDDB_INSTRUMENT_TIME, scope=SCOPE) @@ -264,9 +264,9 @@ def create_db(CFG, **kwargs): return db add_encoders(db) - n_data = kwargs.get('n_data', LOCAL_TEST_N_DATA_POINTS) # prepare data + n_data = kwargs.get('n_data', LOCAL_TEST_N_DATA_POINTS) is_mongodb_bachend = isinstance(db.databackend, MongoDataBackend) if is_mongodb_bachend: add_random_data_to_mongo_db(db, number_data_points=n_data) @@ -287,18 +287,24 @@ def create_db(CFG, **kwargs): @pytest.fixture def db(request, monkeypatch) -> Iterator[Datalayer]: # TODO: Use pre-defined config instead of dict here - db_type, setup_config = ( - request.param if hasattr(request, 'param') else ("mongodb", None) - ) - setup_config = setup_config or {} - if db_type == "mongodb": - monkeypatch.setattr(CFG, 'data_backend', MONGOMOCK_URI) - elif db_type == "sqldb": - monkeypatch.setattr(CFG, 'data_backend', SQLITE_URI) + param = request.param if hasattr(request, 'param') else DBConfig.mongodb + if isinstance(param, dict): + # e.g. @pytest.mark.parametrize("db", [DBConfig.sqldb], indirect=True) + setup_config = param.copy() + elif isinstance(param, tuple): + # e.g. @pytest.mark.parametrize( + # "db", [(DBConfig.sqldb, {'n_data': 10})], indirect=True) + setup_config = param[0].copy() + setup_config.update(param[1] or {}) + else: + raise ValueError(f'Unsupported param type: {type(param)}') + + monkeypatch.setattr(CFG, 'data_backend', setup_config['data_backend']) db = create_db(CFG, **setup_config) yield db + db_type = setup_config.get('db_type', 'mongodb') if db_type == "mongodb": db.drop(force=True) elif db_type == "sqldb": diff --git a/test/db_config.py b/test/db_config.py new file mode 100644 index 0000000000..e896869857 --- /dev/null +++ b/test/db_config.py @@ -0,0 +1,41 @@ +MONGOMOCK_URI = 'mongomock:///test_db' +SQLITE_URI = 'sqlite://:memory:' +N_DATA_POINTS = 5 + + +class DBConfig: + # Common configuration parameters + COMMON_CONFIG = { + 'empty': False, + 'add_encoders': True, + 'add_data': True, + 'add_models': True, + 'add_vector_index': True, + 'n_data': N_DATA_POINTS, + } + + # Base configuration for MongoDB and SQL databases + _mongodb_base = { + 'db_type': 'mongodb', + 'data_backend': MONGOMOCK_URI, + **COMMON_CONFIG, + } + _sqldb_base = {'db_type': 'sqldb', 'data_backend': SQLITE_URI, **COMMON_CONFIG} + + # Configurations for an empty database + mongodb_empty = {**_mongodb_base, 'empty': True} + sqldb_empty = {**_sqldb_base, 'empty': True} + + # Full database configurations including encoder, data, model, and vector_index + mongodb = {**_mongodb_base} + sqldb = {**_sqldb_base} + + # Configurations without vector_index + mongodb_no_vector_index = {**_mongodb_base, 'add_vector_index': False} + sqldb_no_vector_index = {**_sqldb_base, 'add_vector_index': False} + + # Configurations with only encoder and data + mongodb_data = {**_mongodb_base, 'add_models': False, 'add_vector_index': False} + sqldb_data = {**_sqldb_base, 'add_models': False, 'add_vector_index': False} + + # Additional frequently used presets can be added as needed... diff --git a/test/unittest/backends/base/test_query.py b/test/unittest/backends/base/test_query.py index c6c0591999..cb37a0865a 100644 --- a/test/unittest/backends/base/test_query.py +++ b/test/unittest/backends/base/test_query.py @@ -1,10 +1,12 @@ +from test.db_config import DBConfig + import pytest from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_execute_insert_and_find(db): collection = Collection('documents') collection.insert_many([Document({'this': 'is a test'})]).execute(db) @@ -12,7 +14,7 @@ def test_execute_insert_and_find(db): assert r['this'] == 'is a test' -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_execute_complex_query(db): collection = Collection('documents') collection.insert_many( @@ -25,7 +27,6 @@ def test_execute_complex_query(db): assert sorted(cur_this) == sorted(expected) -@pytest.mark.parametrize("db", [('mongodb', None)], indirect=True) def test_execute_like_queries(db): collection = Collection('documents') # get a data point for testing diff --git a/test/unittest/backends/mongodb/test_queries.py b/test/unittest/backends/mongodb/test_queries.py index 45c020c8d9..1161bf8c12 100644 --- a/test/unittest/backends/mongodb/test_queries.py +++ b/test/unittest/backends/mongodb/test_queries.py @@ -6,6 +6,7 @@ torch = None import random +from test.db_config import DBConfig from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document @@ -63,7 +64,7 @@ def test_replace(db): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_insert_from_uris(db, image_url): import PIL diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index d56e73dc40..4ea597aaf3 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -12,6 +12,7 @@ import dataclasses as dc +from test.db_config import DBConfig from unittest.mock import MagicMock, patch from superduperdb.backends.ibis.field_types import dtype @@ -98,7 +99,7 @@ def add_fake_model(db: Datalayer): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_add_version(db): # Check the component functions are called @@ -139,7 +140,7 @@ def test_add_version(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_add_component_with_bad_artifact(db): artifact = Artifact({'data': lambda x: x}, serializer='pickle') @@ -149,7 +150,7 @@ def test_add_component_with_bad_artifact(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_add_artifact_auto_replace(db): # Check artifact is automatically replaced to metadata @@ -162,7 +163,7 @@ def test_add_artifact_auto_replace(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_add_child(db): child_component = TestComponent(identifier='child') @@ -191,7 +192,7 @@ def test_add_child(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_add(db): component = TestComponent(identifier='test') @@ -212,7 +213,7 @@ def test_add(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_component_version(db): db.add( @@ -239,7 +240,7 @@ def test_remove_component_version(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_component_with_parent(db): # Can not remove the child component if the parent exists @@ -258,7 +259,7 @@ def test_remove_component_with_parent(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_component_with_clean_up(db): # Test clean up @@ -272,7 +273,7 @@ def test_remove_component_with_clean_up(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_component_from_data_layer_dict(db): # Test component is deleted from datalayer @@ -284,7 +285,7 @@ def test_remove_component_from_data_layer_dict(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_component_with_artifact(db): # Test artifact is deleted from artifact store @@ -304,7 +305,7 @@ def test_remove_component_with_artifact(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_one_version(db): db.add( @@ -320,7 +321,7 @@ def test_remove_one_version(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_multi_version(db): db.add( @@ -336,7 +337,7 @@ def test_remove_multi_version(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_remove_not_exist_component(db): with pytest.raises(exceptions.ComponentException) as e: @@ -347,7 +348,7 @@ def test_remove_not_exist_component(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_show(db): db.add( @@ -386,7 +387,7 @@ def test_show(db): @pytest.mark.skipif(not torch, reason='Torch not installed') @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_predict(db: Datalayer): models = [ @@ -424,7 +425,7 @@ def test_predict(db: Datalayer): @pytest.mark.skipif(not torch, reason='Torch not installed') @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_predict_context(db: Datalayer): db.add( @@ -449,7 +450,7 @@ def test_predict_context(db: Datalayer): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_get_context(db): from superduperdb.backends.base.query import Select @@ -476,7 +477,7 @@ def test_get_context(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_load(db): db.add( @@ -513,7 +514,7 @@ def test_load(db): assert 'e1' in db.encoders -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_insert_mongo_db(db): add_fake_model(db) inserted_ids, _ = db.insert( @@ -530,7 +531,7 @@ def test_insert_mongo_db(db): assert sorted(result) == ['0', '1', '2', '3', '4'] -@pytest.mark.parametrize("db", [('sqldb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.sqldb_empty], indirect=True) def test_insert_sql_db(db): add_fake_model(db) table = db.load('table', 'documents') @@ -545,7 +546,7 @@ def test_insert_sql_db(db): assert sorted(result) == ['0', '1', '2', '3', '4'] -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_update_db(db): # TODO: test update sql db after the update method is implemented add_fake_model(db) @@ -568,7 +569,7 @@ def test_update_db(db): @pytest.mark.parametrize( "db", [ - ('mongodb', {'n_data': 6}), + (DBConfig.mongodb_data, {'n_data': 6}), ], indirect=True, ) @@ -580,7 +581,7 @@ def test_delete(db): @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_replace(db): model = Model( @@ -609,7 +610,7 @@ def test_replace(db): @pytest.mark.skipif(not torch, reason='Torch not installed') @pytest.mark.parametrize( - "db", [('mongodb', {'empty': True}), ('sqldb', {'empty': True})], indirect=True + "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_compound_component(db): m = TorchModel( @@ -650,7 +651,7 @@ def test_compound_component(db): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [('mongodb', None), ('sqldb', None)], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb, DBConfig.sqldb], indirect=True) def test_reload_dataset(db): from superduperdb.components.dataset import Dataset @@ -674,8 +675,8 @@ def test_reload_dataset(db): @pytest.mark.parametrize( "db", [ - ('mongodb', {'add_vector_index': False, 'n_data': 500}), - ('sqldb', {'add_vector_index': False, 'n_data': 500}), + (DBConfig.mongodb_no_vector_index, {'n_data': 500}), + (DBConfig.sqldb_no_vector_index, {'n_data': 500}), ], indirect=True, ) diff --git a/test/unittest/base/test_documents.py b/test/unittest/base/test_documents.py index f4a72be0cd..3ffc90ee13 100644 --- a/test/unittest/base/test_documents.py +++ b/test/unittest/base/test_documents.py @@ -7,6 +7,8 @@ except ImportError: torch = None +from test.db_config import DBConfig + from superduperdb.base.document import Document @@ -32,7 +34,7 @@ def test_document_outputs(document): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [('mongodb', None), ('sqldb', None)], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb, DBConfig.sqldb], indirect=True) def test_only_uri(db): r = Document( Document.decode( diff --git a/test/unittest/component/test_model.py b/test/unittest/component/test_model.py index 75bd9b582f..a956234060 100644 --- a/test/unittest/component/test_model.py +++ b/test/unittest/component/test_model.py @@ -1,4 +1,5 @@ import random +from test.db_config import DBConfig from unittest.mock import MagicMock, patch import numpy as np @@ -204,7 +205,7 @@ def test_pm_create_predict_job(predict_mixin): @patch.object(Datalayer, 'add') -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_pm_predict_and_listen(mock_add, predict_mixin, db): X = 'x' select = MagicMock(CompoundSelect) diff --git a/test/unittest/component/test_serialization.py b/test/unittest/component/test_serialization.py index 63e4976e59..0676d47f79 100644 --- a/test/unittest/component/test_serialization.py +++ b/test/unittest/component/test_serialization.py @@ -6,6 +6,8 @@ from superduperdb.ext.torch.encoder import tensor except ImportError: torch = None +from test.db_config import DBConfig + from sklearn.svm import SVC from superduperdb.base.artifact import Artifact @@ -27,7 +29,7 @@ def test_model(): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_sklearn(db): m = Estimator( identifier='test', diff --git a/test/unittest/ext/test_openai.py b/test/unittest/ext/test_openai.py index 0d5c6d9ff1..66c0b976d7 100644 --- a/test/unittest/ext/test_openai.py +++ b/test/unittest/ext/test_openai.py @@ -1,5 +1,6 @@ import json import os +from test.db_config import DBConfig import openai import pytest @@ -45,7 +46,7 @@ def open_ai_with_rhymes(db, monkeypatch): filter_headers=['authorization'], record_on_exception=False, ) -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_retrieve_with_similar_context(open_ai_with_rhymes): db = open_ai_with_rhymes m = OpenAIChatCompletion( diff --git a/test/unittest/ext/test_sklearn.py b/test/unittest/ext/test_sklearn.py index ccdd9921db..70e5c70123 100644 --- a/test/unittest/ext/test_sklearn.py +++ b/test/unittest/ext/test_sklearn.py @@ -1,4 +1,5 @@ import random +from test.db_config import DBConfig import numpy import pytest @@ -60,7 +61,7 @@ def test_fit_predict_classic(self, pipeline, X, y): output = pipeline.predict(X) assert len(output) == len(X) - @pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) + @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_fit_db(self, pipeline, data_in_db): pipeline.fit( X='X', diff --git a/test/unittest/ext/test_torch.py b/test/unittest/ext/test_torch.py index f9e7f22bc3..d4e0b98668 100644 --- a/test/unittest/ext/test_torch.py +++ b/test/unittest/ext/test_torch.py @@ -7,6 +7,8 @@ except ImportError: torch = None +from test.db_config import DBConfig + from superduperdb.backends.mongodb.query import Collection from superduperdb.components.metric import Metric @@ -50,7 +52,7 @@ def acc(x, y): @pytest.mark.skipif(not torch, reason='Torch not installed') @pytest.mark.parametrize( 'db', - [('mongodb', {'add_vector_index': False, 'add_models': False, 'n_data': 500})], + [(DBConfig.mongodb_data, {'n_data': 500})], indirect=True, ) def test_fit(db, valid_dataset): diff --git a/test/unittest/ext/test_transformers.py b/test/unittest/ext/test_transformers.py index 9d49476889..b664ddcd7b 100644 --- a/test/unittest/ext/test_transformers.py +++ b/test/unittest/ext/test_transformers.py @@ -7,6 +7,8 @@ except ImportError: torch = None +from test.db_config import DBConfig + from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document as D from superduperdb.components.dataset import Dataset @@ -17,7 +19,7 @@ @pytest.fixture -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def transformers_model(db): from transformers import AutoModelForSequenceClassification, AutoTokenizer @@ -57,7 +59,7 @@ def td(): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_tranformers_fit(transformers_model, db, td): repo_name = td training_args = TransformersTrainerConfiguration( diff --git a/test/unittest/ext/test_vanilla.py b/test/unittest/ext/test_vanilla.py index ee889a28e9..abb7b95259 100644 --- a/test/unittest/ext/test_vanilla.py +++ b/test/unittest/ext/test_vanilla.py @@ -1,3 +1,5 @@ +from test.db_config import DBConfig + import pytest from superduperdb.backends.mongodb.query import Collection @@ -27,7 +29,7 @@ def test_function_predict(): assert function.predict([1, 1]) == [1, 1] -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_document_embedded(data_in_db): function = Model( object=lambda x: x, @@ -41,7 +43,7 @@ def test_function_predict_with_document_embedded(data_in_db): assert [o['_outputs']['X']['test']['0'] for o in out] == [1, 2, 3, 4, 5] -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_without_document_embedded(data_in_db): function = Model(object=lambda x: x, identifier='test') function.predict( @@ -51,7 +53,7 @@ def test_function_predict_without_document_embedded(data_in_db): assert [o['_outputs']['X']['test']['0'] for o in out] == [1, 2, 3, 4, 5] -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_flatten_outputs(data_in_db): function = Model( object=lambda x: [x, x, x] if x > 2 else [x, x], @@ -92,7 +94,7 @@ def test_function_predict_with_flatten_outputs(data_in_db): @pytest.mark.skip -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_mix_flatten_outputs(data_in_db): function = Model( object=lambda x: x if x < 2 else [x, x, x], diff --git a/test/unittest/misc/test_downloaders.py b/test/unittest/misc/test_downloaders.py index e5ca916dcc..986b6bd87c 100644 --- a/test/unittest/misc/test_downloaders.py +++ b/test/unittest/misc/test_downloaders.py @@ -1,6 +1,7 @@ import os import tempfile import uuid +from test.db_config import DBConfig import pytest @@ -27,7 +28,7 @@ def patch_cfg_downloads(monkeypatch): yield -@pytest.mark.parametrize("db", [('mongodb', {'empty': True})], indirect=True) +@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_file_blobs(db, patch_cfg_downloads, image_url): to_insert = [ Document(