Fix the bug when using sqlite as metadata store
jieguangzhou committed Nov 27, 2023
1 parent b327a52 commit e8b91df
Showing 10 changed files with 135 additions and 61 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -189,3 +189,4 @@ test/sleep.json
.vscode

/tmp
.superduperdb
21 changes: 16 additions & 5 deletions superduperdb/backends/mongodb/metadata.py
@@ -223,18 +223,29 @@ def delete_component_version(
if self._component_used(type_id, identifier, version=version):
raise Exception('Component version already in use in other components!')

self.parent_child_mappings.delete_many(
{'parent': Component.make_unique_id(type_id, identifier, version)}
)

return self.component_collection.delete_many(
delete_result = self.component_collection.delete_many(
{
'identifier': identifier,
'type_id': type_id,
'version': version,
}
)

parent_unique_id = Component.make_unique_id(type_id, identifier, version)

# TODO: Do we need to delete the child component here?
# We delete the child component in SQLAlchemyMetadata,
# but not in MongoMetaDataStore
children_unique_ids = [
r['child']
for r in self.parent_child_mappings.find({'parent': parent_unique_id})
]
for child_unique_id in children_unique_ids:
type_id, identifier, version = Component.parse_unique_id(child_unique_id)
self.delete_component_version(type_id, identifier, version)

return delete_result

def _get_component(
self,
type_id: str,
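
The MongoDB store now also cascades the delete to child components: it collects the children of the deleted version from parent_child_mappings and recurses on each of them (the TODO above flags that this behaviour is still being aligned with SQLAlchemyMetadata). A minimal sketch of that recursion, using plain Python containers as stand-ins for the MongoDB collections and helpers that mirror Component.make_unique_id / parse_unique_id from this commit:

def make_unique_id(type_id, identifier, version):
    return f'{type_id}/{identifier}/{version}'


def parse_unique_id(unique_id):
    type_id, identifier, version = unique_id.split('/')
    return type_id, identifier, int(version)


def delete_component_version(components, parent_child_mappings, type_id, identifier, version):
    """Delete one component version, then recurse into its children."""
    unique_id = make_unique_id(type_id, identifier, version)
    components.discard(unique_id)
    children = [child for parent, child in parent_child_mappings if parent == unique_id]
    for child in children:
        delete_component_version(components, parent_child_mappings, *parse_unique_id(child))


components = {'model/my-model/0', 'encoder/my-encoder/0'}
mappings = [('model/my-model/0', 'encoder/my-encoder/0')]
delete_component_version(components, mappings, 'model', 'my-model', 0)
assert components == set()  # the child version was removed along with its parent
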
95 changes: 60 additions & 35 deletions superduperdb/backends/sqlalchemy/metadata.py
@@ -3,6 +3,7 @@
from contextlib import contextmanager

import click
from bson import ObjectId
from sqlalchemy import JSON, Boolean, Column, DateTime, Integer, String
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship, sessionmaker
@@ -11,12 +11,14 @@
from superduperdb.backends.base.metadata import MetaDataStore, NonExistentMetadataError
from superduperdb.base import exceptions
from superduperdb.base.serializable import Serializable
from superduperdb.components.component import Component as _Component
from superduperdb.misc.colors import Colors

if t.TYPE_CHECKING:
from superduperdb.backends.base.query import Select

Base = declarative_base()
DEFAULT_LENGTH = 255


class DictMixin:
@@ -27,63 +27,63 @@ def as_dict(self):
class QueryID(Base): # type: ignore[valid-type, misc]
__tablename__ = 'query_id_table'

query_id = Column(String, primary_key=True)
query_id = Column(String(DEFAULT_LENGTH), primary_key=True)
query = Column(JSON)
model = Column(String)
model = Column(String(DEFAULT_LENGTH))


class Job(Base, DictMixin): # type: ignore[valid-type, misc]
__tablename__ = 'job'

identifier = Column(String, primary_key=True)
component_identifier = Column(String)
type_id = Column(String)
identifier = Column(String(DEFAULT_LENGTH), primary_key=True)
component_identifier = Column(String(DEFAULT_LENGTH))
type_id = Column(String(DEFAULT_LENGTH))
info = Column(JSON)
time = Column(DateTime)
status = Column(String)
status = Column(String(DEFAULT_LENGTH))
args = Column(JSON)
kwargs = Column(JSON)
method_name = Column(String)
method_name = Column(String(DEFAULT_LENGTH))
stdout = Column(JSON)
stderr = Column(JSON)
cls = Column(String)
cls = Column(String(DEFAULT_LENGTH))


class ParentChildAssociation(Base): # type: ignore[valid-type, misc]
__tablename__ = 'parent_child_association'

parent_id = Column(String, primary_key=True)
child_id = Column(String, primary_key=True)
parent_id = Column(String(DEFAULT_LENGTH), primary_key=True)
child_id = Column(String(DEFAULT_LENGTH), primary_key=True)


class Component(Base, DictMixin): # type: ignore[valid-type, misc]
__tablename__ = 'component'

id = Column(String, primary_key=True)
identifier = Column(String)
id = Column(String(DEFAULT_LENGTH), primary_key=True)
identifier = Column(String(DEFAULT_LENGTH))
version = Column(Integer)
hidden = Column(Boolean)
type_id = Column(String)
cls = Column(String)
module = Column(String)
type_id = Column(String(DEFAULT_LENGTH))
cls = Column(String(DEFAULT_LENGTH))
module = Column(String(DEFAULT_LENGTH))
dict = Column(JSON)

# Define the parent-child relationship
parents = relationship(
children = relationship(
"Component",
secondary=ParentChildAssociation.__table__,
primaryjoin=id == ParentChildAssociation.parent_id,
secondaryjoin=id == ParentChildAssociation.child_id,
backref="children",
backref="parents",
cascade="all, delete",
)


class Meta(Base, DictMixin): # type: ignore[valid-type, misc]
__tablename__ = 'meta'

key = Column(String, primary_key=True)
value = Column(String)
key = Column(String(DEFAULT_LENGTH), primary_key=True)
value = Column(String(DEFAULT_LENGTH))


class SQLAlchemyMetadata(MetaDataStore):
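
Switching every column from String to String(DEFAULT_LENGTH) is what lets these tables be created on databases other than SQLite: SQLite happily accepts a length-less VARCHAR, but dialects such as MySQL reject it at DDL time. A quick sketch of the difference, assuming only that SQLAlchemy is installed:

from sqlalchemy import String
from sqlalchemy.dialects import mysql, sqlite
from sqlalchemy.exc import CompileError

DEFAULT_LENGTH = 255

print(String().compile(dialect=sqlite.dialect()))               # VARCHAR
print(String(DEFAULT_LENGTH).compile(dialect=mysql.dialect()))  # VARCHAR(255)

try:
    String().compile(dialect=mysql.dialect())  # no length: MySQL refuses
except CompileError as exc:
    print(exc)  # VARCHAR requires a length on dialect mysql
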
@@ -145,16 +148,14 @@ def session_context(self):
def component_version_has_parents(
self, type_id: str, identifier: str, version: int
):
unique_id = _Component.make_unique_id(type_id, identifier, version)
with self.session_context() as session:
return (
session.query(Component)
session.query(ParentChildAssociation)
.filter(
Component.type_id == type_id,
Component.identifier == identifier,
Component.version == version,
ParentChildAssociation.child_id == unique_id,
)
.first()
.parent_id
is not None
)

@@ -226,24 +227,25 @@ def _get_component(
.first()
)

return res.as_dict()
return res.as_dict() if res else None

def get_component_version_parents(self, unique_id: str):
with self.session_context() as session:
components = (
session.query(Component)
associations = (
session.query(ParentChildAssociation)
.filter(
Component.id == unique_id,
ParentChildAssociation.child_id == unique_id,
)
.all()
)
return sum([c.parents for c in components], [])
parents = [a.parent_id for a in associations]
return parents

def get_latest_version(
self, type_id: str, identifier: str, allow_hidden: bool = False
):
with self.session_context() as session:
return (
component = (
session.query(Component)
.filter(
Component.type_id == type_id,
@@ -252,8 +254,12 @@ def get_latest_version(
)
.order_by(Component.version.desc())
.first()
.version
)
if component is None:
raise FileNotFoundError(
f'Can\'t find {type_id}: {identifier} in metadata'
)
return component.version

def hide_component_version(self, type_id: str, identifier: str, version: int):
with self.session_context() as session:
@@ -269,7 +275,7 @@ def _replace_object(self, info, identifier, type_id, version):
Component.type_id == type_id,
Component.identifier == identifier,
Component.version == version,
).update({'dict': info})
).update(info)

def replace_component(
self,
@@ -290,15 +296,17 @@ def replace_component(
def show_components(self, type_id: t.Optional[str] = None, **kwargs):
if type_id is not None:
with self.session_context() as session:
return [
identifiers = [
c.identifier
for c in session.query(Component)
.filter(Component.type_id == type_id)
.all()
]
else:
with self.session_context() as session:
return [c.identifier for c in session.query(Component).all()]
identifiers = [c.identifier for c in session.query(Component).all()]
identifiers = sorted(set(identifiers), key=lambda x: identifiers.index(x))
return identifiers
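
show_components now de-duplicates identifiers while keeping first-seen order. As an aside (not part of the commit), dict.fromkeys gives the same result in a single pass on Python 3.7+:

identifiers = ['model-a', 'model-b', 'model-a', 'model-c']
print(sorted(set(identifiers), key=identifiers.index))  # pattern used in the diff
print(list(dict.fromkeys(identifiers)))                 # equivalent, single pass
# both print ['model-a', 'model-b', 'model-c']
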

def show_component_versions(self, type_id: str, identifier: str):
with self.session_context() as session:
@@ -331,7 +339,9 @@ def _update_object(
def create_job(self, info: t.Dict):
try:
with self.session_context() as session:
session.add(Job(**info))
job = Job(**info)
convert_object_id_type(job)
session.add(job)
except Exception as e:
raise exceptions.MetaDataStoreJobException(
'Error while creating job in metadata store'
@@ -447,3 +457,18 @@ def get_model_queries(self, model: str):
{'query_id': id, 'query': query, 'sql': query.repr_()}
)
return unpacked_queries


def convert_object_id_type(job: Job):
kwargs = job.kwargs
ids = kwargs.get('ids', [])
if not isinstance(ids, list):
return

new_ids = []
for id in ids:
if isinstance(id, ObjectId):
new_ids.append(str(id))
else:
new_ids.append(id)
kwargs['ids'] = new_ids
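
convert_object_id_type exists because bson.ObjectId values are not JSON-serializable, so they cannot be written into the JSON-typed kwargs column as-is. A small demonstration of the failure mode and the stringification, assuming pymongo/bson is available:

import json

from bson import ObjectId

ids = [ObjectId(), 'already-a-string']

try:
    json.dumps({'ids': ids})
except TypeError as exc:
    print(exc)  # Object of type ObjectId is not JSON serializable

safe_ids = [str(i) if isinstance(i, ObjectId) else i for i in ids]
print(json.dumps({'ids': safe_ids}))  # serializes cleanly
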
16 changes: 14 additions & 2 deletions superduperdb/base/build.py
@@ -49,9 +49,21 @@ def build(uri, mapping):
conn = mongomock.MongoClient()
return mapping['mongodb'](conn, name)
else:
name = uri.split('//')[0]
conn = ibis.connect(uri)
return mapping['ibis'](conn, name)
db_name = conn.current_database
if 'ibis' in mapping:
cls_ = mapping['ibis']
elif 'sqlalchemy' in mapping:
cls_ = mapping['sqlalchemy']
conn = conn.con
else:
raise ValueError('No ibis or sqlalchemy backend specified')

# if ':memory:' in uri:
# conn = uri
# db_name = None

return cls_(conn, db_name)


def build_datalayer(cfg=None, **kwargs) -> Datalayer:
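
The rewritten else-branch connects through ibis, reads the database name from conn.current_database, and then picks a backend class from the mapping: an 'ibis' backend receives the ibis connection itself, while a 'sqlalchemy' metadata store is handed the underlying connection that ibis SQL backends expose as conn.con. A condensed sketch of that selection with stand-in objects (the SimpleNamespace connection and lambda backend are illustrative assumptions, not superduperdb APIs):

from types import SimpleNamespace


def select_backend(conn, mapping):
    if 'ibis' in mapping:
        return mapping['ibis'](conn, conn.current_database)
    if 'sqlalchemy' in mapping:
        # ibis SQL backends keep the raw SQLAlchemy connection on `.con`
        return mapping['sqlalchemy'](conn.con, conn.current_database)
    raise ValueError('No ibis or sqlalchemy backend specified')


fake_conn = SimpleNamespace(current_database='test_db', con='<sqlalchemy connection>')
print(select_backend(fake_conn, {'sqlalchemy': lambda conn, name: (conn, name)}))
# ('<sqlalchemy connection>', 'test_db')
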
6 changes: 3 additions & 3 deletions superduperdb/base/datalayer.py
@@ -1037,19 +1037,19 @@ def _download_content( # TODO: duplicated function
n_download_workers = self.metadata.get_metadata(
key='n_download_workers'
)
except TypeError:
except exceptions.MetadatastoreException:
n_download_workers = 0

if headers is None:
try:
headers = self.metadata.get_metadata(key='headers')
except TypeError:
except exceptions.MetadatastoreException:
headers = 0

if timeout is None:
try:
timeout = self.metadata.get_metadata(key='download_timeout')
except TypeError:
except exceptions.MetadatastoreException:
timeout = None

def download_update(key, id, bytes):
5 changes: 5 additions & 0 deletions superduperdb/components/component.py
@@ -99,3 +99,8 @@ def schedule_jobs(
@classmethod
def make_unique_id(cls, type_id: str, identifier: str, version: int) -> str:
return f'{type_id}/{identifier}/{version}'

@classmethod
def parse_unique_id(cls, unique_id: str) -> t.Tuple[str, str, int]:
type_id, identifier, version = unique_id.split('/')
return type_id, identifier, int(version)
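
parse_unique_id is the inverse of make_unique_id and assumes that type_id and identifier never contain '/'. A quick round-trip check:

from superduperdb.components.component import Component

uid = Component.make_unique_id('model', 'my-model', 3)
print(uid)                             # model/my-model/3
print(Component.parse_unique_id(uid))  # ('model', 'my-model', 3)
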
22 changes: 19 additions & 3 deletions test/conftest.py
@@ -229,6 +229,22 @@ def local_db(request) -> Datalayer:


@pytest.fixture
def local_empty_db(request) -> Datalayer:
db = build_datalayer(CFG, data_backend=MONGOMOCK_URI)
return db
def local_empty_db(request, monkeypatch) -> Datalayer:
for key, value in DB_CONFIGS.items():
monkeypatch.setattr(CFG, key, value)
db = build_datalayer(CFG)
yield db
db.drop(force=True)


DB_CONFIGS = {
'data_backend': MONGOMOCK_URI,
# 'metadata_store': "sqlite://:memory:",
# 'data_backend': "sqlite://:memory:",
# 'data_backend': "mongodb://testmongodbuser:testmongodbpassword@localhost:27018/test_db",
# 'metadata_store': "mysql://root:root123@localhost:3306/test_db",
'metadata_store': "sqlite://:memory:",
# 'data_backend': "mysql://root:root123@localhost:3306/test_db",
# 'metadata_store': "sqlite://mydb.sqlite",
'artifact_store': 'filesystem:///tmp/superduperdb_test',
}