Skip to content

Commit

Permalink
Improve _Predictor developer contract
Browse files Browse the repository at this point in the history
  • Loading branch information
blythed committed Feb 22, 2024
1 parent fb6a9f6 commit 1c509a6
Show file tree
Hide file tree
Showing 59 changed files with 1,643 additions and 2,539 deletions.
3 changes: 1 addition & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

**Before you create a Pull Request, remember to update the Changelog with your changes.**



## Changes Since Last Release

#### Changed defaults / behaviours
Expand All @@ -17,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
#### New Features & Functionality
- CI fails if CHANGELOG.md is not updated on PRs
- Update Menu structure and renamed use-cases
- Change and simplify the contract for writing new `_Predictor` descendants (`.predict_one`, `.predict`)

#### Bug Fixes
- LLM CI random errors
Expand Down
5 changes: 1 addition & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,8 @@ fix-and-test: ## Lint the code before testing
# Linter and code formatting
ruff check --fix $(DIRECTORIES)
# Linting
rm -rf .mypy_cache/
mypy superduperdb
# Unit testing
pytest $(PYTEST_ARGUMENTS)
# Check for missing docstrings
interrogate superduperdb



Expand Down
2 changes: 1 addition & 1 deletion examples/llm_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,4 @@
prompt = "### Human: Who are you? ### Assistant: "

# Automatically load lora model for prediction, default use the latest checkpoint
print(llm.predict(prompt, max_new_tokens=100, do_sample=True))
print(llm.predict_in_db(prompt, max_new_tokens=100, do_sample=True))
3 changes: 2 additions & 1 deletion superduperdb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from .components.datatype import DataType, Encoder
from .components.listener import Listener
from .components.metric import Metric
from .components.model import Model
from .components.model import Model, ObjectModel
from .components.schema import Schema
from .components.vector_index import VectorIndex, vector

Expand All @@ -31,6 +31,7 @@
'DataType',
'Encoder',
'Document',
'ObjectModel',
'Model',
'Listener',
'VectorIndex',
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/backends/base/data_backend.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import typing as t
from abc import ABC, abstractmethod

from superduperdb.components.model import APIModel, Model
from superduperdb.components.model import APIModel, ObjectModel


class BaseDataBackend(ABC):
Expand Down Expand Up @@ -35,7 +35,7 @@ def build_artifact_store(self):
"""
pass

def create_model_table_or_collection(self, model: t.Union[Model, APIModel]):
def create_model_table_or_collection(self, model: t.Union[ObjectModel, APIModel]):
pass

@abstractmethod
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/backends/ibis/data_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from superduperdb.backends.ibis.utils import get_output_table_name
from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata
from superduperdb.components.model import APIModel, Model
from superduperdb.components.model import APIModel, ObjectModel
from superduperdb.components.schema import Schema

BASE64_PREFIX = 'base64:'
Expand Down Expand Up @@ -50,7 +50,7 @@ def insert(self, table_name, raw_documents):
else:
self.conn.create_table(table_name, pandas.DataFrame(raw_documents))

def create_model_table_or_collection(self, model: t.Union[Model, APIModel]):
def create_model_table_or_collection(self, model: t.Union[ObjectModel, APIModel]):
msg = (
"Model must have an encoder to create with the"
f" {type(self).__name__} backend."
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/backends/mongodb/artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _load_bytes(self, file_id: str):
cur = self.filesystem.find_one({'filename': file_id})
if cur is None:
raise FileNotFoundError(f'File not found in {file_id}')
return next(cur)
return cur.read()

def _save_bytes(self, serialized: bytes, file_id: str):
return self.filesystem.put(serialized, filename=file_id)
Expand Down
77 changes: 43 additions & 34 deletions superduperdb/backends/query_dataset.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import inspect
import random
import typing as t

from superduperdb.backends.base.query import Select
from superduperdb.misc.special_dicts import MongoStyleDict

if t.TYPE_CHECKING:
from superduperdb.components.model import Mapping


class ExpiryCache(list):
def __getitem__(self, index):
Expand Down Expand Up @@ -31,51 +35,48 @@ class QueryDataset:
def __init__(
self,
select: Select,
keys: t.Optional[t.List[str]] = None,
mapping: t.Optional['Mapping'] = None,
ids: t.Optional[t.List[str]] = None,
fold: t.Union[str, None] = 'train',
suppress: t.Sequence[str] = (),
transform: t.Optional[t.Callable] = None,
db=None,
ids: t.Optional[t.List[str]] = None,
in_memory: bool = True,
extract: t.Optional[str] = None,
**kwargs,
):
self._database = db
self.keys = keys
self._db = db

self.transform = transform if transform else lambda x: x
self.transform = transform
if fold is not None:
self.select = select.add_fold(fold)
else:
self.select = select

self.in_memory = in_memory
if self.in_memory:
if ids is None:
self._documents = list(self.database.execute(self.select))
self._documents = list(self.db.execute(self.select))
else:
self._documents = list(
self.database.execute(self.select.select_using_ids(ids))
self.db.execute(self.select.select_using_ids(ids))
)
else:
if ids is None:
self._ids = [
r[self.select.id_field]
for r in self.database.execute(self.select.select_ids)
for r in self.db.execute(self.select.select_ids)
]
else:
self._ids = ids
self.select_one = self.select.select_single_id
self.suppress = suppress
self.extract = extract

self.mapping = mapping

@property
def database(self):
if self._database is None:
def db(self):
if self._db is None:
from superduperdb.base.build import build_datalayer

self._database = build_datalayer()
return self._database
self._db = build_datalayer()
return self._db

def __len__(self):
if self.in_memory:
Expand All @@ -88,22 +89,25 @@ def __getitem__(self, item):
input = self._documents[item]
else:
input = self.select_one(
self._ids[item], self.database, encoders=self.database.datatypes
self._ids[item], self.db, encoders=self.db.datatypes
)
r = MongoStyleDict(input.unpack())
s = MongoStyleDict({})

if self.keys is not None:
for k in self.keys:
if k == '_base':
s[k] = r
else:
s[k] = r[k]
else:
s = r
out = self.transform(s)
if self.extract:
out = out[self.extract]
input = MongoStyleDict(input.unpack(db=self.db))
from superduperdb.components.model import Signature

out = input
if self.mapping is not None:
out = self.mapping(out)
if self.transform is not None and self.mapping is not None:
if self.mapping.signature == Signature.args_kwargs:
out = self.transform(*out[0], **out[1])
elif self.mapping.signature == Signature.args:
out = self.transform(*out)
elif self.mapping.signature == Signature.kwargs:
out = self.transform(**out)
elif self.mapping.signature == Signature.singleton:
out = self.transform(out)
elif self.transform is not None:
out = self.transform(out)
return out


Expand Down Expand Up @@ -204,7 +208,12 @@ def __getitem__(self, index):
return self.transform(s)


def query_dataset_factory(**kwargs):
    """Build a dataset over a database query.

    Returns a ``CachedQueryDataset`` when the ``data_prefetch`` keyword is
    truthy; otherwise returns a plain ``QueryDataset``, forwarding only the
    subset of ``kwargs`` that ``QueryDataset.__init__`` actually accepts
    (extra keys such as ``data_prefetch`` itself are silently dropped).
    """
    if kwargs.get('data_prefetch', False):
        return CachedQueryDataset(**kwargs)
    # Filter kwargs down to the parameters QueryDataset.__init__ declares.
    accepted = inspect.signature(QueryDataset.__init__).parameters
    filtered = {key: value for key, value in kwargs.items() if key in accepted}
    return QueryDataset(**filtered)
29 changes: 17 additions & 12 deletions superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from superduperdb.base.datalayer import Datalayer


def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):
def _build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):
# Connect to metadata store.
# ------------------------------
# 1. try to connect to the metadata store specified in the configuration.
Expand All @@ -28,7 +28,9 @@ def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):
if cfg.metadata_store is not None:
# try to connect to the metadata store specified in the configuration.
logging.info("Connecting to Metadata Client:", cfg.metadata_store)
return build(cfg.metadata_store, metadata_stores, type='metadata')
return _build_databackend_impl(
cfg.metadata_store, metadata_stores, type='metadata'
)
else:
try:
# try to connect to the data backend engine.
Expand All @@ -45,19 +47,21 @@ def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):
try:
# try to connect to the data backend uri.
logging.info("Connecting to Metadata Client with URI: ", cfg.data_backend)
return build(cfg.data_backend, metadata_stores, type='metadata')
return _build_databackend_impl(
cfg.data_backend, metadata_stores, type='metadata'
)
except Exception as e:
# Exit quickly if a connection fails.
logging.error("Error initializing to Metadata Client:", str(e))
sys.exit(1)


def build_databackend(cfg, databackend=None):
def _build_databackend(cfg, databackend=None):
# Connect to data backend.
# ------------------------------
try:
if not databackend:
databackend = build(cfg.data_backend, data_backends)
databackend = _build_databackend_impl(cfg.data_backend, data_backends)
logging.info("Data Client is ready.", databackend.conn)
except Exception as e:
# Exit quickly if a connection fails.
Expand All @@ -66,7 +70,7 @@ def build_databackend(cfg, databackend=None):
return databackend


def build_artifact_store(
def _build_artifact_store(
artifact_store: t.Optional[str] = None,
databackend: t.Optional['BaseDataBackend'] = None,
):
Expand All @@ -90,7 +94,7 @@ def build_artifact_store(


# Helper function to build a data backend based on the URI.
def build(uri, mapping, type: str = 'data_backend'):
def _build_databackend_impl(uri, mapping, type: str = 'data_backend'):
logging.debug(f"Parsing data connection URI:{uri}")

if re.match('^mongodb:\/\/', uri) is not None:
Expand Down Expand Up @@ -140,7 +144,7 @@ def build(uri, mapping, type: str = 'data_backend'):
return mapping['sqlalchemy'](sql_conn, name)


def build_compute(compute):
def _build_compute(compute):
logging.info("Connecting to compute client:", compute)

if compute == 'local' or compute is None:
Expand Down Expand Up @@ -170,6 +174,7 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer:
:param cfg: Configuration to use. If None, use ``superduperdb.CFG``.
    :param databackend: Databackend to use.
If None, use ``superduperdb.CFG.data_backend``.
    :param kwargs: keyword arguments to be adopted by the `CFG`
"""

# Configuration
Expand All @@ -185,17 +190,17 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer:
cfg.force_set(k, v)

# Build databackend
databackend = build_databackend(cfg, databackend)
databackend = _build_databackend(cfg, databackend)

# Build metadata store
metadata = build_metadata(cfg, databackend)
metadata = _build_metadata(cfg, databackend)
assert metadata

# Build artifact store
artifact_store = build_artifact_store(cfg.artifact_store, databackend)
artifact_store = _build_artifact_store(cfg.artifact_store, databackend)

# Build compute
compute = build_compute(cfg.cluster.compute)
compute = _build_compute(cfg.cluster.compute)

# Build DataLayer
# ------------------------------
Expand Down
1 change: 0 additions & 1 deletion superduperdb/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,6 @@ class Config(BaseConfig):
"""

data_backend: str = 'mongodb://superduper:superduper@localhost:27017/test_db'

lance_home: str = os.path.join('.superduperdb', 'vector_indices')

artifact_store: t.Optional[str] = None
Expand Down
Loading

0 comments on commit 1c509a6

Please sign in to comment.