diff --git a/bdpy/dataform/kvs.py b/bdpy/dataform/kvs.py index 3387099..ea593f9 100644 --- a/bdpy/dataform/kvs.py +++ b/bdpy/dataform/kvs.py @@ -168,19 +168,20 @@ def delete(self, **kwargs) -> None: if key_group_id is None: return None - # Delete from key_value_store - sql = f"DELETE FROM key_value_store WHERE id = {key_group_id}" - cursor = self._conn.cursor() - cursor.execute(sql) + # Delete from key_group_members and key_value_store + sqls = [ + f""" + DELETE FROM key_group_members WHERE key_value_store_id = {key_group_id} + """, + f""" + DELETE FROM key_value_store WHERE id = {key_group_id} + """, + ] + self._conn.execute("BEGIN TRANSACTION;") + for sql in sqls: + self._conn.execute(sql) self._conn.commit() - cursor.close() - # Delete from key_group_members - sql = f"DELETE FROM key_group_members WHERE key_value_store_id = {key_group_id}" - cursor = self._conn.cursor() - cursor.execute(sql) - self._conn.commit() - cursor.close() return None def _get_key_group_id(self, **kwargs) -> Optional[int]: diff --git a/tests/dataform/test_kvs.py b/tests/dataform/test_kvs.py index f3970de..d33afbb 100644 --- a/tests/dataform/test_kvs.py +++ b/tests/dataform/test_kvs.py @@ -153,7 +153,7 @@ def test_set_get(self): kvs.set(np.array([]), layer="conv1", subject="sub04", roi="LOC", metric="accuracy") val = kvs.get(layer="conv1", subject="sub04", roi="LOC", metric="accuracy") np.testing.assert_array_equal(val, np.array([])) - + # Found (np.nan) kvs.set(np.array([np.nan]), layer="conv1", subject="sub04", roi="FFA", metric="accuracy") val = kvs.get(layer="conv1", subject="sub04", roi="FFA", metric="accuracy") @@ -175,7 +175,22 @@ def test_update(self): kvs.set(np.array([10, 20, 30, 40]), layer="conv1", subject="sub03", roi="PPA", metric="accuracy") val = kvs.get(layer="conv1", subject="sub03", roi="PPA", metric="accuracy") np.testing.assert_array_equal(val, np.array([10, 20, 30, 40])) - + + def test_delete(self): + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_3304.db") + self._init_test_db(db_path) + + kvs = SQLite3KeyValueStore(db_path) + + kvs.set(np.array([ 1, 2, 3, 4]), layer="conv1", subject="sub03", roi="LOC", metric="accuracy") + kvs.set(np.array([ 5, 6, 7, 8]), layer="conv1", subject="sub03", roi="FFA", metric="accuracy") + kvs.set(np.array([np.nan]), layer="conv1", subject="sub03", roi="PPA", metric="accuracy") + + kvs.delete(layer="conv1", subject="sub03", roi="PPA", metric="accuracy") + np.testing.assert_(kvs.exists(layer="conv1", subject="sub03", roi="LOC", metric="accuracy")) + np.testing.assert_(~kvs.exists(layer="conv1", subject="sub03", roi="PPA", metric="accuracy"), + 'AssertionError: Failed to delete the record.') if __name__ == "__main__": unittest.main()