Skip to content

Commit

Permalink
Fix junk cdc testing errors
Browse files (browse the repository at this point in the history)
  • Loading branch information
kartik4949 committed Dec 1, 2023
1 parent 83e9360 commit b7de69f
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 76 deletions.
1 change: 1 addition & 0 deletions superduperdb/cdc/cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,7 @@ def stop(self, name: str = ''):
for _, listener in self._CDC_LISTENERS.items():
listener.stop()
finally:
self._running = False
self._CDC_LISTENERS = {}
self.stop_handler()

Expand Down
117 changes: 41 additions & 76 deletions test/integration/test_cdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,20 +37,7 @@


@pytest.fixture
def listener_and_collection_name(database_with_default_encoders_and_model):
db = database_with_default_encoders_and_model
collection_name = 'documents'
db.cdc._cdc_existing_collections = []
listener = db.cdc.listen(on=Collection(collection_name), timeout=LISTEN_TIMEOUT)
db.cdc.cdc_change_handler._QUEUE_BATCH_SIZE = 1

yield listener, collection_name, db

db.cdc.stop()


@pytest.fixture
def database_listener_with_lance_searcher(database_with_default_encoders_and_model):
def database_with_cdc(database_with_default_encoders_and_model):
db = database_with_default_encoders_and_model

from superduperdb import CFG
Expand All @@ -61,26 +48,18 @@ def database_listener_with_lance_searcher(database_with_default_encoders_and_mod
searcher_type='lance',
)

db.cdc._cdc_existing_collections = ['documents']
db.cdc.listen(on=Collection('documents'), timeout=LISTEN_TIMEOUT)
db.cdc._cdc_existing_collections = []
listener = db.cdc.listen(on=Collection('documents'), timeout=LISTEN_TIMEOUT)
db.cdc.cdc_change_handler._QUEUE_BATCH_SIZE = 1

yield db, 'documents'
yield listener, 'documents', db

CFG.force_set('vector_search', 'in_memory')
db.cdc.stop()
shutil.rmtree('.superduperdb/vector_indices')


@pytest.fixture
def listener_without_cdc_handler_and_collection_name(
database_with_default_encoders_and_model,
):
db = database_with_default_encoders_and_model
collection_name = 'documents'
db.cdc._cdc_existing_collections = []
listener = db.cdc.listen(on=Collection(collection_name), timeout=LISTEN_TIMEOUT)
yield listener, collection_name, db
db.cdc.stop()
try:
shutil.rmtree('.superduperdb/vector_indices')
except FileNotFoundError:
pass


def retry_state_check(state_check):
Expand All @@ -98,23 +77,23 @@ def retry_state_check(state_check):


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_smoke(listener_without_cdc_handler_and_collection_name):
def test_smoke(database_with_cdc):
"""Health-check before we test stateful database changes"""
_, name, db = listener_without_cdc_handler_and_collection_name
_, name, db = database_with_cdc
assert isinstance(name, str)


@pytest.mark.parametrize('op_type', ['insert'])
@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_task_workflow(
listener_and_collection_name,
database_with_cdc,
fake_inserts,
fake_updates,
op_type,
):
"""Test that task graph executed on `insert`"""

_, name, db = listener_and_collection_name
_, name, db = database_with_cdc

with add_and_cleanup_listeners(db, name) as database_with_listeners:
# `refresh=False` to ensure `_outputs` not produced after `Insert` refresh.
Expand Down Expand Up @@ -153,10 +132,14 @@ def state_check_2():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_vector_database_sync_with_delete(
database_listener_with_lance_searcher,
database_with_cdc,
fake_inserts,
):
db, name = database_listener_with_lance_searcher
(
_,
name,
db,
) = database_with_cdc

inserted_ids, _ = db.execute(
Collection(name).insert_many(fake_inserts[:2]),
Expand All @@ -183,10 +166,10 @@ def state_check_2():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_vector_database_sync(
database_listener_with_lance_searcher,
database_with_cdc,
fake_inserts,
):
db, name = database_listener_with_lance_searcher
_, name, db = database_with_cdc
db.execute(
Collection(name).insert_many([fake_inserts[0]]),
refresh=False,
Expand All @@ -202,10 +185,10 @@ def state_check():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_single_insert(
listener_without_cdc_handler_and_collection_name,
database_with_cdc,
fake_inserts,
):
listener, name, db = listener_without_cdc_handler_and_collection_name
listener, name, db = database_with_cdc
db.execute(
Collection(name).insert_many([fake_inserts[0]]),
refresh=False,
Expand All @@ -219,10 +202,10 @@ def state_check():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_many_insert(
listener_without_cdc_handler_and_collection_name,
database_with_cdc,
fake_inserts,
):
listener, name, db = listener_without_cdc_handler_and_collection_name
listener, name, db = database_with_cdc
db.execute(
Collection(name).insert_many(fake_inserts),
refresh=False,
Expand All @@ -236,62 +219,44 @@ def state_check():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_delete_one(
listener_without_cdc_handler_and_collection_name,
database_with_cdc,
fake_inserts,
):
listener, name, db = listener_without_cdc_handler_and_collection_name
listener, name, db = database_with_cdc
db.cdc.stop()
inserted_ids, _ = db.execute(
Collection(name).insert_many(fake_inserts),
refresh=False,
)
listener = db.cdc.listen(on=Collection('documents'), timeout=LISTEN_TIMEOUT)

db.execute(Collection(name).delete_one({'_id': inserted_ids[0]}))
db.execute(Collection(name).delete_one({'_id': inserted_ids[0]}), refresh=False)

def state_check():
assert listener.info()["deletes"] == 1

retry_state_check(state_check)


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_single_update(
listener_without_cdc_handler_and_collection_name,
fake_updates,
):
listener, name, db = listener_without_cdc_handler_and_collection_name
inserted_ids, _ = db.execute(
Collection(name).insert_many(fake_updates),
refresh=False,
)
encoder = db.encoders['torch.float32[32]']
db.execute(
Collection(name).update_many(
{"_id": inserted_ids[0]},
Document({'$set': {'x': encoder(torch.randn(32))}}),
)
)

def state_check():
assert listener.info()["updates"] == 1

retry_state_check(state_check)


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_many_update(
listener_without_cdc_handler_and_collection_name,
database_with_cdc,
fake_updates,
):
listener, name, db = listener_without_cdc_handler_and_collection_name
listener, name, db = database_with_cdc
db.cdc.stop()
inserted_ids, _ = db.execute(
Collection(name).insert_many(fake_updates), refresh=False
)
encoder = db.encoders['torch.float32[32]']
listener = db.cdc.listen(on=Collection('documents'), timeout=LISTEN_TIMEOUT)

db.execute(
Collection(name).update_many(
{"_id": {"$in": inserted_ids[:5]}},
Document({'$set': {'x': encoder(torch.randn(32))}}),
)
),
refresh=False,
)

def state_check():
Expand All @@ -302,11 +267,11 @@ def state_check():

@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_insert_without_cdc_handler(
listener_without_cdc_handler_and_collection_name,
database_with_cdc,
fake_inserts,
):
"""Test that `insert` without CDC handler does not execute task graph"""
_, name, db = listener_without_cdc_handler_and_collection_name
_, name, db = database_with_cdc
inserted_ids, _ = db.execute(
Collection(name).insert_many(fake_inserts),
refresh=False,
Expand All @@ -316,9 +281,9 @@ def test_insert_without_cdc_handler(


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_cdc_stop(listener_and_collection_name):
def test_cdc_stop(database_with_cdc):
"""Test that CDC listen service stopped properly"""
listener, _, _ = listener_and_collection_name
listener, _, _ = database_with_cdc
listener.stop()

def state_check():
Expand Down

0 comments on commit b7de69f

Please sign in to comment.