From 1c509a6ae9fb302e0d95ca53bc95ec69efec4c93 Mon Sep 17 00:00:00 2001 From: Duncan Blythe Date: Tue, 20 Feb 2024 16:30:31 +0100 Subject: [PATCH] Improve _Predictor developer contract --- CHANGELOG.md | 3 +- Makefile | 5 +- examples/llm_finetune.py | 2 +- superduperdb/__init__.py | 3 +- superduperdb/backends/base/data_backend.py | 4 +- superduperdb/backends/ibis/data_backend.py | 4 +- superduperdb/backends/mongodb/artifacts.py | 2 +- superduperdb/backends/query_dataset.py | 77 +- superduperdb/base/build.py | 29 +- superduperdb/base/config.py | 1 - superduperdb/base/datalayer.py | 162 +-- superduperdb/base/document.py | 38 +- superduperdb/base/leaf.py | 6 + superduperdb/base/serializable.py | 5 + superduperdb/components/component.py | 19 +- superduperdb/components/datatype.py | 65 +- superduperdb/components/listener.py | 32 +- superduperdb/components/model.py | 1259 ++++++++--------- superduperdb/components/schema.py | 3 + superduperdb/components/vector_index.py | 22 +- superduperdb/ext/anthropic/model.py | 40 +- superduperdb/ext/cohere/model.py | 83 +- superduperdb/ext/jina/model.py | 36 +- superduperdb/ext/llamacpp/model.py | 19 +- superduperdb/ext/llm/base.py | 35 +- superduperdb/ext/llm/model.py | 33 +- superduperdb/ext/openai/model.py | 473 ++----- .../ext/sentence_transformers/model.py | 51 +- superduperdb/ext/sklearn/model.py | 118 +- superduperdb/ext/torch/model.py | 184 +-- superduperdb/ext/transformers/model.py | 78 +- superduperdb/misc/special_dicts.py | 2 + superduperdb/server/app.py | 2 +- superduperdb/vector_search/atlas.py | 4 +- test/conftest.py | 20 +- test/integration/conftest.py | 18 +- .../ext/anthropic/test_model_anthropic.py | 31 +- .../ext/cohere/test_model_cohere.py | 68 +- test/integration/ext/jina/test_model_jina.py | 35 +- .../ext/openai/test_model_openai.py | 269 +--- test/integration/test_atlas.py | 4 +- test/integration/test_cdc.py | 6 +- test/integration/test_end2end.py | 33 +- test/integration/test_ibis.py | 4 +- 
test/integration/test_ray.py | 8 +- test/unittest/backends/ibis/test_query.py | 6 +- test/unittest/backends/test_query_dataset.py | 18 +- test/unittest/base/test_datalayer.py | 103 +- test/unittest/base/test_serializable.py | 10 +- test/unittest/component/test_component.py | 54 +- test/unittest/component/test_listener.py | 4 +- test/unittest/component/test_model.py | 505 +++---- test/unittest/component/test_serialization.py | 4 +- test/unittest/ext/llm/utils.py | 9 +- test/unittest/ext/test_llama_cpp.py | 4 +- test/unittest/ext/test_torch.py | 36 +- test/unittest/ext/test_transformers.py | 3 +- test/unittest/ext/test_vanilla.py | 25 +- test/unittest/test_quality.py | 6 +- 59 files changed, 1643 insertions(+), 2539 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c93153c03e..ab01664d7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 **Before you create a Pull Request, remember to update the Changelog with your changes.** - - ## Changes Since Last Release #### Changed defaults / behaviours @@ -17,6 +15,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### New Features & Functionality - CI fails if CHANGELOG.md is not updated on PRs - Update Menu structure and renamed use-cases +- Change and simplify the contract for writing new `_Predictor` descendants (`.predict_one`, `.predict`) #### Bug Fixes - LLM CI random errors diff --git a/Makefile b/Makefile index c1e98c43ed..79b22324d8 100644 --- a/Makefile +++ b/Makefile @@ -105,11 +105,8 @@ fix-and-test: ## Lint the code before testing # Linter and code formatting ruff check --fix $(DIRECTORIES) # Linting + rm -rf .mypy_cache/ mypy superduperdb - # Unit testing - pytest $(PYTEST_ARGUMENTS) - # Check for missing docstrings - interrogate superduperdb diff --git a/examples/llm_finetune.py b/examples/llm_finetune.py index 87715615cd..1fcc51e0be 100644 --- a/examples/llm_finetune.py +++ 
b/examples/llm_finetune.py @@ -60,4 +60,4 @@ prompt = "### Human: Who are you? ### Assistant: " # Automatically load lora model for prediction, default use the latest checkpoint -print(llm.predict(prompt, max_new_tokens=100, do_sample=True)) +print(llm.predict_in_db(prompt, max_new_tokens=100, do_sample=True)) diff --git a/superduperdb/__init__.py b/superduperdb/__init__.py index 71873059fe..bc58720343 100644 --- a/superduperdb/__init__.py +++ b/superduperdb/__init__.py @@ -17,7 +17,7 @@ from .components.datatype import DataType, Encoder from .components.listener import Listener from .components.metric import Metric -from .components.model import Model +from .components.model import Model, ObjectModel from .components.schema import Schema from .components.vector_index import VectorIndex, vector @@ -31,6 +31,7 @@ 'DataType', 'Encoder', 'Document', + 'ObjectModel', 'Model', 'Listener', 'VectorIndex', diff --git a/superduperdb/backends/base/data_backend.py b/superduperdb/backends/base/data_backend.py index 71c02dbfa7..8267d9b49e 100644 --- a/superduperdb/backends/base/data_backend.py +++ b/superduperdb/backends/base/data_backend.py @@ -1,7 +1,7 @@ import typing as t from abc import ABC, abstractmethod -from superduperdb.components.model import APIModel, Model +from superduperdb.components.model import APIModel, ObjectModel class BaseDataBackend(ABC): @@ -35,7 +35,7 @@ def build_artifact_store(self): """ pass - def create_model_table_or_collection(self, model: t.Union[Model, APIModel]): + def create_model_table_or_collection(self, model: t.Union[ObjectModel, APIModel]): pass @abstractmethod diff --git a/superduperdb/backends/ibis/data_backend.py b/superduperdb/backends/ibis/data_backend.py index e70654dd14..814802103f 100644 --- a/superduperdb/backends/ibis/data_backend.py +++ b/superduperdb/backends/ibis/data_backend.py @@ -12,7 +12,7 @@ from superduperdb.backends.ibis.utils import get_output_table_name from superduperdb.backends.local.artifacts import 
FileSystemArtifactStore from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata -from superduperdb.components.model import APIModel, Model +from superduperdb.components.model import APIModel, ObjectModel from superduperdb.components.schema import Schema BASE64_PREFIX = 'base64:' @@ -50,7 +50,7 @@ def insert(self, table_name, raw_documents): else: self.conn.create_table(table_name, pandas.DataFrame(raw_documents)) - def create_model_table_or_collection(self, model: t.Union[Model, APIModel]): + def create_model_table_or_collection(self, model: t.Union[ObjectModel, APIModel]): msg = ( "Model must have an encoder to create with the" f" {type(self).__name__} backend." diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py index 37b66325d1..c2db132797 100644 --- a/superduperdb/backends/mongodb/artifacts.py +++ b/superduperdb/backends/mongodb/artifacts.py @@ -46,7 +46,7 @@ def _load_bytes(self, file_id: str): cur = self.filesystem.find_one({'filename': file_id}) if cur is None: raise FileNotFoundError(f'File not found in {file_id}') - return next(cur) + return cur.read() def _save_bytes(self, serialized: bytes, file_id: str): return self.filesystem.put(serialized, filename=file_id) diff --git a/superduperdb/backends/query_dataset.py b/superduperdb/backends/query_dataset.py index 4a17dc23dc..5b332909bf 100644 --- a/superduperdb/backends/query_dataset.py +++ b/superduperdb/backends/query_dataset.py @@ -1,9 +1,13 @@ +import inspect import random import typing as t from superduperdb.backends.base.query import Select from superduperdb.misc.special_dicts import MongoStyleDict +if t.TYPE_CHECKING: + from superduperdb.components.model import Mapping + class ExpiryCache(list): def __getitem__(self, index): @@ -31,51 +35,48 @@ class QueryDataset: def __init__( self, select: Select, - keys: t.Optional[t.List[str]] = None, + mapping: t.Optional['Mapping'] = None, + ids: t.Optional[t.List[str]] = None, fold: t.Union[str, 
None] = 'train', - suppress: t.Sequence[str] = (), transform: t.Optional[t.Callable] = None, db=None, - ids: t.Optional[t.List[str]] = None, in_memory: bool = True, - extract: t.Optional[str] = None, - **kwargs, ): - self._database = db - self.keys = keys + self._db = db - self.transform = transform if transform else lambda x: x + self.transform = transform if fold is not None: self.select = select.add_fold(fold) else: self.select = select + self.in_memory = in_memory if self.in_memory: if ids is None: - self._documents = list(self.database.execute(self.select)) + self._documents = list(self.db.execute(self.select)) else: self._documents = list( - self.database.execute(self.select.select_using_ids(ids)) + self.db.execute(self.select.select_using_ids(ids)) ) else: if ids is None: self._ids = [ r[self.select.id_field] - for r in self.database.execute(self.select.select_ids) + for r in self.db.execute(self.select.select_ids) ] else: self._ids = ids self.select_one = self.select.select_single_id - self.suppress = suppress - self.extract = extract + + self.mapping = mapping @property - def database(self): - if self._database is None: + def db(self): + if self._db is None: from superduperdb.base.build import build_datalayer - self._database = build_datalayer() - return self._database + self._db = build_datalayer() + return self._db def __len__(self): if self.in_memory: @@ -88,22 +89,25 @@ def __getitem__(self, item): input = self._documents[item] else: input = self.select_one( - self._ids[item], self.database, encoders=self.database.datatypes + self._ids[item], self.db, encoders=self.db.datatypes ) - r = MongoStyleDict(input.unpack()) - s = MongoStyleDict({}) - - if self.keys is not None: - for k in self.keys: - if k == '_base': - s[k] = r - else: - s[k] = r[k] - else: - s = r - out = self.transform(s) - if self.extract: - out = out[self.extract] + input = MongoStyleDict(input.unpack(db=self.db)) + from superduperdb.components.model import Signature + + out = input + if 
self.mapping is not None: + out = self.mapping(out) + if self.transform is not None and self.mapping is not None: + if self.mapping.signature == Signature.args_kwargs: + out = self.transform(*out[0], **out[1]) + elif self.mapping.signature == Signature.args: + out = self.transform(*out) + elif self.mapping.signature == Signature.kwargs: + out = self.transform(**out) + elif self.mapping.signature == Signature.singleton: + out = self.transform(out) + elif self.transform is not None: + out = self.transform(out) return out @@ -204,7 +208,12 @@ def __getitem__(self, index): return self.transform(s) -def query_dataset_factory(data_prefetch: bool = False, **kwargs): - if data_prefetch: +def query_dataset_factory(**kwargs): + if kwargs.get('data_prefetch', False): return CachedQueryDataset(**kwargs) + kwargs = { + k: v + for k, v in kwargs.items() + if k in inspect.signature(QueryDataset.__init__).parameters + } return QueryDataset(**kwargs) diff --git a/superduperdb/base/build.py b/superduperdb/base/build.py index e338121b06..f4c6a8b5b0 100644 --- a/superduperdb/base/build.py +++ b/superduperdb/base/build.py @@ -19,7 +19,7 @@ from superduperdb.base.datalayer import Datalayer -def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None): +def _build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None): # Connect to metadata store. # ------------------------------ # 1. try to connect to the metadata store specified in the configuration. @@ -28,7 +28,9 @@ def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None): if cfg.metadata_store is not None: # try to connect to the metadata store specified in the configuration. logging.info("Connecting to Metadata Client:", cfg.metadata_store) - return build(cfg.metadata_store, metadata_stores, type='metadata') + return _build_databackend_impl( + cfg.metadata_store, metadata_stores, type='metadata' + ) else: try: # try to connect to the data backend engine. 
@@ -45,19 +47,21 @@ def build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None): try: # try to connect to the data backend uri. logging.info("Connecting to Metadata Client with URI: ", cfg.data_backend) - return build(cfg.data_backend, metadata_stores, type='metadata') + return _build_databackend_impl( + cfg.data_backend, metadata_stores, type='metadata' + ) except Exception as e: # Exit quickly if a connection fails. logging.error("Error initializing to Metadata Client:", str(e)) sys.exit(1) -def build_databackend(cfg, databackend=None): +def _build_databackend(cfg, databackend=None): # Connect to data backend. # ------------------------------ try: if not databackend: - databackend = build(cfg.data_backend, data_backends) + databackend = _build_databackend_impl(cfg.data_backend, data_backends) logging.info("Data Client is ready.", databackend.conn) except Exception as e: # Exit quickly if a connection fails. @@ -66,7 +70,7 @@ def build_databackend(cfg, databackend=None): return databackend -def build_artifact_store( +def _build_artifact_store( artifact_store: t.Optional[str] = None, databackend: t.Optional['BaseDataBackend'] = None, ): @@ -90,7 +94,7 @@ def build_artifact_store( # Helper function to build a data backend based on the URI. -def build(uri, mapping, type: str = 'data_backend'): +def _build_databackend_impl(uri, mapping, type: str = 'data_backend'): logging.debug(f"Parsing data connection URI:{uri}") if re.match('^mongodb:\/\/', uri) is not None: @@ -140,7 +144,7 @@ def build(uri, mapping, type: str = 'data_backend'): return mapping['sqlalchemy'](sql_conn, name) -def build_compute(compute): +def _build_compute(compute): logging.info("Connecting to compute client:", compute) if compute == 'local' or compute is None: @@ -170,6 +174,7 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer: :param cfg: Configuration to use. If None, use ``superduperdb.CFG``. :param databackend: Databacked to use. 
If None, use ``superduperdb.CFG.data_backend``. + :param kwargs: keyword arguments to be adopted by the `CFG` """ # Configuration @@ -185,17 +190,17 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer: cfg.force_set(k, v) # Build databackend - databackend = build_databackend(cfg, databackend) + databackend = _build_databackend(cfg, databackend) # Build metadata store - metadata = build_metadata(cfg, databackend) + metadata = _build_metadata(cfg, databackend) assert metadata # Build artifact store - artifact_store = build_artifact_store(cfg.artifact_store, databackend) + artifact_store = _build_artifact_store(cfg.artifact_store, databackend) # Build compute - compute = build_compute(cfg.cluster.compute) + compute = _build_compute(cfg.cluster.compute) # Build DataLayer # ------------------------------ diff --git a/superduperdb/base/config.py b/superduperdb/base/config.py index 69b21a3057..4c876f9fe0 100644 --- a/superduperdb/base/config.py +++ b/superduperdb/base/config.py @@ -179,7 +179,6 @@ class Config(BaseConfig): """ data_backend: str = 'mongodb://superduper:superduper@localhost:27017/test_db' - lance_home: str = os.path.join('.superduperdb', 'vector_indices') artifact_store: t.Optional[str] = None diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index c0e16ec92c..1c524aa90d 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -26,7 +26,7 @@ from superduperdb.cdc.cdc import DatabaseChangeDataCapture from superduperdb.components.component import Component from superduperdb.components.datatype import DataType, Encodable, serializers -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.components.schema import Schema from superduperdb.jobs.job import ComponentJob, FunctionJob, Job from superduperdb.jobs.task_workflow import TaskWorkflow @@ -104,11 +104,11 @@ def rebuild(self, cfg=None): cfg = cfg or s.CFG - 
self.databackend = build.build_databackend(cfg) - self.compute = build.build_compute(cfg.cluster.compute) + self.databackend = build._build_databackend(cfg) + self.compute = build._build_compute(cfg.cluster.compute) - self.metadata = build.build_metadata(cfg, self.databackend) - self.artifact_store = build.build_artifact_store( + self.metadata = build._build_metadata(cfg, self.databackend) + self.artifact_store = build._build_artifact_store( cfg.artifact_store, self.databackend ) self.artifact_store.serializers = self.serializers @@ -250,7 +250,7 @@ def validate( # TODO: never called component = self.load(type_id, identifier) metric_list = [self.load('metric', m) for m in metrics] - assert isinstance(component, Model) + assert isinstance(component, ObjectModel) return component.validate( self, validation_set, @@ -296,7 +296,7 @@ def show( def _get_context( self, model, context_select: t.Optional[Select], context_key: t.Optional[str] ): - assert model.takes_context, 'model does not take context' + assert 'context' in model.inputs.params, 'model does not take context' assert context_select is not None sources = list(self.execute(context_select)) context = sources[:] @@ -309,93 +309,6 @@ def _get_context( ] return context, sources - async def apredict( - self, - model_name: str, - input: t.Union[Document, t.Any], - context_select: t.Optional[Select] = None, - context_key: str = '_base', - **kwargs, - ): - """ - Apply model to input using asyncio. - - :param model_name: model identifier - :param input: input to be passed to the model. 
- Must be possible to encode with registered datatypes - :param context_select: select query object to provide context - :param context_key: key to use to extract context from context_select - """ - model = self.models[model_name] - context = None - sources: t.List[Document] = [] - - if context_select is not None: - context, sources = self._get_context(model, context_select, context_key) - - out = await model.apredict( - input.unpack() if isinstance(input, Document) else input, - one=True, - context=context, - **kwargs, - ) - - if model.datatype is not None: - out = model.datatype(out) - - if context is not None: - return Document(out), sources - return Document(out), [] - - def predict( - self, - model_name: str, - input: t.Union[Document, t.Any], - context_select: t.Optional[t.Union[str, Select]] = None, - context_key: t.Optional[str] = None, - **kwargs, - ) -> t.Tuple[Document, t.List[Document]]: - """ - Apply model to input. - - :param model_name: model identifier - :param input: input to be passed to the model. 
- Must be possible to encode with registered datatypes - :param context_select: select query object to provide context - :param context_key: key to use to extract context from context_select - """ - model = self.models[model_name] - context = None - sources: t.List[Document] = [] - - if context_select is not None: - if isinstance(context_select, Select): - context, sources = self._get_context(model, context_select, context_key) - elif isinstance(context_select, str): - context = context_select - else: - raise TypeError("context_select should be either Select or str") - - out = model.predict( - input.unpack() if isinstance(input, Document) else input, - one=True, - context=[r.unpack() for r in context] if context else None, - **kwargs, - ) - - if isinstance(model.datatype, DataType): - out = model.datatype(out) - elif isinstance(model.output_schema, Schema): - # TODO make Schema callable - out = model.output_schema.encode(out) - - if not isinstance(out, dict): - out = {'_base': out} - - if context is not None: - return Document(out), sources - return Document(out), [] - def execute(self, query: ExecuteQuery, *args, **kwargs) -> ExecuteResult: """ Execute a query on the db. 
@@ -749,6 +662,7 @@ def _build_task_workflow( listener_selects = {} for identifier in listeners: + # TODO reload listener here (with lazy loading) info = self.metadata.get_component('listener', identifier) listener_query = Document.decode(info['dict']['select'], None) listener_select = serializable.Serializable.decode(listener_query) @@ -763,7 +677,7 @@ def _build_task_workflow( model, _, key = identifier.rpartition('/') G.add_node( - f'{model}.predict({key})', + f'{model}.predict_in_db({key})', job=ComponentJob( component_identifier=model, args=[key], @@ -772,7 +686,7 @@ def _build_task_workflow( 'select': listener_query.dict().encode(), **info['dict']['predict_kwargs'], }, - method_name='predict', + method_name='predict_in_db', type_id='model', ), ) @@ -789,14 +703,14 @@ def _build_task_workflow( model, _, key = identifier.rpartition('/') G.add_edge( f'{download_content.__name__}()', - f'{model}.predict({key})', + f'{model}.predict_in_db({key})', ) deps = self._get_dependencies_for_listener(identifier) for dep in deps: dep_model, _, dep_key = dep.rpartition('/') G.add_edge( - f'{dep_model}.predict({dep_key})', - f'{model}.predict({key})', + f'{dep_model}.predict_in_db({dep_key})', + f'{model}.predict_in_db({key})', ) if s.CFG.self_hosted_vector_search: @@ -835,7 +749,7 @@ def _build_task_workflow( model = vi.indexing_listener.model.identifier key = vi.indexing_listener.key G.add_edge( - f'{model}.predict({key})', f'{identifier}.{copy_vectors.__name__}' + f'{model}.predict_in_db({key})', f'{identifier}.{copy_vectors.__name__}' ) return G @@ -1116,6 +1030,54 @@ def select_nearest( logging.info(str(outs)) return vi.get_nearest(like, db=self, ids=ids, n=n, outputs=outs) + # TODO deprecate in favour of remote model calls + def predict( + self, + model_name: str, + input: t.Union[Document, t.Any], + context_select: t.Optional[t.Union[str, Select]] = None, + context_key: t.Optional[str] = None, + ) -> t.Tuple[Document, t.List[Document]]: + """ + Apply model to input. 
+ + :param model_name: model identifier + :param input: input to be passed to the model. + Must be possible to encode with registered datatypes + :param context_select: select query object to provide context + :param context_key: key to use to extract context from context_select + """ + model = self.models[model_name] + context = None + sources: t.List[Document] = [] + + if context_select is not None: + assert 'context' in model.inputs.params + if isinstance(context_select, Select): + context, sources = self._get_context(model, context_select, context_key) + elif isinstance(context_select, str): + context = context_select + else: + raise TypeError("context_select should be either Select or str") + context = [r.unpack() for r in context] + else: + sources = [] + + if 'context' in model.inputs.params: + out = model.predict_one(input, context=context) + else: + out = model.predict_one(input) + + if isinstance(model.datatype, DataType): + out = model.datatype(out) + elif isinstance(model.output_schema, Schema): + # TODO make Schema callable + out = model.output_schema.encode(out) + + if not isinstance(out, dict): + out = {'_base': out} + return Document(out), sources + def close(self): """ Gracefully shutdown the Datalayer diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py index c6b010fb05..8408f5b9a7 100644 --- a/superduperdb/base/document.py +++ b/superduperdb/base/document.py @@ -14,6 +14,7 @@ from superduperdb.misc.special_dicts import MongoStyleDict if t.TYPE_CHECKING: + from superduperdb.base.datalayer import Datalayer from superduperdb.components.schema import Schema @@ -87,12 +88,14 @@ def outputs(self, key: str, model: str, version: t.Optional[int] = None) -> t.An @staticmethod def decode( r: t.Dict, - db, + db: t.Optional['Datalayer'] = None, bytes_encoding: t.Optional[BytesEncoding] = None, reference: bool = False, ) -> t.Any: bytes_encoding = bytes_encoding or CFG.bytes_encoding - decoded = _decode(dict(r), db, bytes_encoding, 
reference=reference) + decoded = _decode( + dict(r), db=db, bytes_encoding=bytes_encoding, reference=reference + ) if isinstance(decoded, dict): return Document(decoded) return decoded @@ -100,13 +103,12 @@ def decode( def __repr__(self) -> str: return f'Document({repr(dict(self))})' - def unpack(self) -> t.Any: - """Returns the content, but with any encodables replacecs by their contents""" - if '_base' in self: - r = self['_base'] - else: - r = dict(self) - return _unpack(r) + def unpack(self, db=None) -> t.Any: + """Returns the content, but with any encodables replaced by their contents""" + out = _unpack(self, db=db) + if '_base' in out: + out = out['_base'] + return out def _find_leaves(r: t.Any, leaf_type: t.Optional[str] = None, pop: bool = False): @@ -147,16 +149,17 @@ def _decode( bytes_encoding = bytes_encoding or CFG.bytes_encoding if isinstance(r, dict) and '_content' in r: return _LEAF_TYPES[r['_content']['leaf_type']].decode( - r, db, reference=reference + r, db=db, reference=reference ) elif isinstance(r, list): return [ - _decode(x, db, bytes_encoding=bytes_encoding, reference=reference) + _decode(x, db=db, bytes_encoding=bytes_encoding, reference=reference) for x in r ] elif isinstance(r, dict): return { - k: _decode(v, db, bytes_encoding, reference=reference) for k, v in r.items() + k: _decode(v, db=db, bytes_encoding=bytes_encoding, reference=reference) + for k, v in r.items() } else: return r @@ -166,6 +169,8 @@ def _decode( class Reference(Serializable): identifier: str leaf_type: str + path: t.Optional[str] = None + db: t.Optional['Datalayer'] = None def _encode_with_references(r: t.Any, references: t.Dict): @@ -231,17 +236,18 @@ def _encode_with_schema( return r -def _unpack(item: t.Any) -> t.Any: +def _unpack(item: t.Any, db=None) -> t.Any: if isinstance(item, Encodable): + # TODO move logic into Encodable if item.reference: file_id = _construct_file_id_from_uri(item.uri) if item.datatype.directory: file_id = 
os.path.join(item.datatype.directory, file_id) return file_id - return item.x + return item.unpack(db=db) elif isinstance(item, dict): - return {k: _unpack(v) for k, v in item.items()} + return {k: _unpack(v, db=db) for k, v in item.items()} elif isinstance(item, list): - return [_unpack(x) for x in item] + return [_unpack(x, db=db) for x in item] else: return item diff --git a/superduperdb/base/leaf.py b/superduperdb/base/leaf.py index 19ae13c0ca..6c68b09108 100644 --- a/superduperdb/base/leaf.py +++ b/superduperdb/base/leaf.py @@ -9,6 +9,9 @@ class Leaf(ABC): def unique_id(self): pass + def unpack(self, db=None): + return self + @abstractmethod def encode( self, @@ -21,3 +24,6 @@ def encode( @abstractmethod def decode(cls, r, db): pass + + def init(self, db=None): + pass diff --git a/superduperdb/base/serializable.py b/superduperdb/base/serializable.py index fc24eb6ac8..3ba861b188 100644 --- a/superduperdb/base/serializable.py +++ b/superduperdb/base/serializable.py @@ -87,6 +87,11 @@ def variables(self) -> t.List['Variable']: return sorted(list(out.values()), key=lambda x: x.value) def set_variables(self, db, **kwargs) -> 'Serializable': + """ + Set free variables of self. 
+ + :param db: + """ r = self.encode(leaf_types_to_keep=(Variable,)) r = _replace_variables(r, db, **kwargs) return self.decode(r) diff --git a/superduperdb/components/component.py b/superduperdb/components/component.py index 5af8ffffd2..0ab6598a97 100644 --- a/superduperdb/components/component.py +++ b/superduperdb/components/component.py @@ -34,7 +34,6 @@ class Component(Serializable, Leaf): leaf_type: t.ClassVar[str] = 'component' _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = () set_post_init: t.ClassVar[t.Sequence] = ('version',) - identifier: str artifacts: dc.InitVar[t.Optional[t.Dict]] = None @@ -45,6 +44,23 @@ def __post_init__(self, artifacts): if not self.identifier: raise ValueError('identifier cannot be empty or None') + def init(self, db=None): + from superduperdb.base.document import Document + from superduperdb.components.datatype import Encodable + + for f in dc.fields(self): + item = getattr(self, f.name) + if isinstance(item, Component): + item.init(db=db) + if isinstance(item, dict): + setattr(self, f.name, Document(item).unpack(db=self.db or db)) + if isinstance(item, list): + unpacked = Document({'_base': item}).unpack(db=self.db or db) + setattr(self, f.name, unpacked) + if isinstance(item, Encodable): + item.init(db=db) + setattr(self, f.name, item.x) + @cached_property def artifact_schema(self): from superduperdb import Schema @@ -155,7 +171,6 @@ def schedule_jobs( self, db: Datalayer, dependencies: t.Sequence[Job] = (), - verbose: bool = False, ) -> t.Sequence[t.Any]: """Run the job for this listener diff --git a/superduperdb/components/datatype.py b/superduperdb/components/datatype.py index b00d1cb154..d237ab102f 100644 --- a/superduperdb/components/datatype.py +++ b/superduperdb/components/datatype.py @@ -143,11 +143,6 @@ def build_torch_state_serializer(module, info): ) -@dc.dataclass -class LazyLoader: - info: t.Dict - - @dc.dataclass class Encodable(Leaf): """ @@ -163,6 +158,7 @@ class Encodable(Leaf): datatype: 
DataType x: t.Optional[t.Any] = None uri: t.Optional[str] = None + file_id: t.Optional[str] = None @property def unique_id(self): @@ -176,11 +172,27 @@ def artifact(self): def reference(self): return self.datatype.reference + def init(self, db): + self.x = db.artifact_store.load_artifact(self.file_id) + + def unpack(self, db): + """ + Unpack the content of the `Encodable` + + :param db: `Datalayer` instance to assist with + """ + if self.x is None: + self.init() + return self.x + def encode( self, bytes_encoding: t.Optional[BytesEncoding] = None, leaf_types_to_keep: t.Sequence = (), ) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]: + """ + :param bytes_encoding: + """ from superduperdb.backends.base.artifact import ArtifactSavingError def _encode(x): @@ -211,20 +223,37 @@ def _encode(x): } @classmethod - def decode(cls, r, db, reference: bool = False): - datatype = db.datatypes[r['_content']['datatype']] - # TODO tidy up this logic - if datatype.artifact and not datatype.reference and not reference: - object = db.artifact_store.load_artifact(r['_content']) - elif datatype.artifact and datatype.reference: - return Encodable(x=None, datatype=datatype, uri=r['_content']['uri']) - elif 'bytes' not in r['_content'] and reference: - assert ( - 'uri' in r['_content'] - ), 'If load by reference, need a valid URI for data, found "None"' - return Encodable(x=None, datatype=datatype, uri=r['_content']['uri']) - else: + def decode(cls, r, db=None, reference: bool = False): + # TODO tidy up this logic by creating different subclasses of datatype + # Idea + if 'bytes' in r['_content']: + if db is None: + try: + from superduperdb.components.datatype import serializers + + datatype = serializers[r['_content']['datatype']] + except KeyError: + raise Exception( + f'You specified a serializer which doesn\'t have a' + f' default value: {r["_content"]["datatype"]}' + ) + else: + datatype = db.datatypes[r['_content']['datatype']] + object = datatype.decoder(r['_content']['bytes'], 
info=datatype.info) + else: + datatype = db.datatypes[r['_content']['datatype']] + if datatype.artifact and not datatype.reference and not reference: + object = db.artifact_store.load_artifact(r['_content']) + elif datatype.artifact and datatype.reference: + return Encodable(x=None, datatype=datatype, uri=r['_content']['uri']) + elif 'bytes' not in r['_content'] and reference: + assert ( + 'uri' in r['_content'] + ), 'If load by reference, need a valid URI for data, found "None"' + return Encodable(x=None, datatype=datatype, uri=r['_content']['uri']) + else: + raise Exception('Incorrect argument combination.') return Encodable( x=object, datatype=datatype, diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index 8a73f0055a..2409771369 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -7,12 +7,13 @@ from superduperdb.backends.base.query import CompoundSelect from superduperdb.base.datalayer import Datalayer from superduperdb.base.document import _OUTPUTS_KEY +from superduperdb.components.model import Mapping from superduperdb.misc.annotations import public_api from superduperdb.misc.server import request_server from ..jobs.job import Job from .component import Component -from .model import Model +from .model import ModelInputType, _Predictor @public_api(stability='stable') @@ -32,37 +33,42 @@ class Listener(Component): __doc__ = __doc__.format(component_parameters=Component.__doc__) - key: t.Union[str, t.List, t.Dict] - model: t.Union[str, Model] + key: ModelInputType + model: _Predictor select: CompoundSelect - identifier: t.Optional[str] = None # type: ignore[assignment] active: bool = True predict_kwargs: t.Optional[t.Dict] = dc.field(default_factory=dict) + identifier: t.Optional[str] = None # type: ignore[assignment] type_id: t.ClassVar[str] = 'listener' def __post_init__(self, artifacts): - if self.identifier is None and self.model is not None: - if isinstance(self.model, str): - 
self.identifier = f'{self.model}/{self.id_key}' - else: - self.identifier = f'{self.model.identifier}/{self.id_key}' + identifier = f'{self.model.identifier}/{self.mapping.id_key}' + if self.identifier != identifier: + assert self.identifier is None, 'Don\'t set manually' + self.identifier = identifier super().__post_init__(artifacts) + @property + def mapping(self): + return Mapping(self.key, signature=self.model.signature) + @property def outputs(self): return ( - f'{_OUTPUTS_KEY}.{self.id_key}.{self.model.identifier}.{self.model.version}' + f'{_OUTPUTS_KEY}.{self.mapping.id_key}' + f'.{self.model.identifier}.{self.model.version}' ) @override def pre_create(self, db: Datalayer) -> None: if isinstance(self.model, str): - self.model = t.cast(Model, db.load('model', self.model)) + self.model = t.cast(_Predictor, db.load('model', self.model)) if self.select is not None and self.select.variables: self.select = t.cast(CompoundSelect, self.select.set_variables(db)) + @override def post_create(self, db: Datalayer) -> None: # Start cdc service if enabled if self.select is not None and self.active and not db.server_mode: @@ -123,7 +129,6 @@ def schedule_jobs( self, db: Datalayer, dependencies: t.Sequence[Job] = (), - verbose: bool = False, ) -> t.Sequence[t.Any]: """ Schedule jobs for the listener @@ -137,12 +142,11 @@ def schedule_jobs( assert not isinstance(self.model, str) out = [ - self.model.predict( + self.model.predict_in_db_job( X=self.key, db=db, select=self.select.copy(), dependencies=dependencies, - **(self.predict_kwargs or {}), ) ] return out diff --git a/superduperdb/components/model.py b/superduperdb/components/model.py index a8ccb03bb1..fb73686f74 100644 --- a/superduperdb/components/model.py +++ b/superduperdb/components/model.py @@ -8,22 +8,21 @@ from functools import wraps import tqdm -from overrides import override from sklearn.pipeline import Pipeline from superduperdb import logging from superduperdb.backends.base.metadata import 
NonExistentMetadataError -from superduperdb.backends.base.query import CompoundSelect, Select, TableOrCollection +from superduperdb.backends.base.query import CompoundSelect, Select from superduperdb.backends.ibis.field_types import FieldType from superduperdb.backends.ibis.query import IbisCompoundSelect, Table from superduperdb.backends.query_dataset import QueryDataset +from superduperdb.base.document import Document from superduperdb.base.serializable import Serializable from superduperdb.components.component import Component from superduperdb.components.datatype import DataType, dill_serializer from superduperdb.components.metric import Metric from superduperdb.components.schema import Schema from superduperdb.jobs.job import ComponentJob, Job -from superduperdb.misc import border_msg from superduperdb.misc.annotations import public_api from superduperdb.misc.special_dicts import MongoStyleDict @@ -31,30 +30,14 @@ from superduperdb.base.datalayer import Datalayer from superduperdb.components.dataset import Dataset -EncoderArg = t.Union[DataType, FieldType, str, None] -XType = t.Union[t.Any, t.List, t.Dict] - -class _to_call: - def __init__(self, callable, **kwargs): - self.callable = callable - self.kwargs = kwargs - - def __call__(self, X): - return self.callable(X, **self.kwargs) +EncoderArg = t.Union[DataType, FieldType, None] +ModelInputType = t.Union[str, t.List[str], t.Tuple[t.List[str], t.Dict[str, str]]] class Inputs: - def __init__(self, fn, predict_kwargs: t.Dict = {}): - sig = inspect.signature(fn) - sig_keys = list(sig.parameters.keys()) - params = [] - for k in sig_keys: - if k in predict_kwargs or (k == 'kwargs' and sig.parameters[k].kind == 4): - continue - params.append(k) - - self.params = {p: p for p in params} + def __init__(self, params): + self.params = params def __len__(self): return len(self.params) @@ -69,6 +52,18 @@ def get_kwargs(self, args): return kwargs +class CallableInputs(Inputs): + def __init__(self, fn, predict_kwargs: t.Dict 
= {}): + sig = inspect.signature(fn) + sig_keys = list(sig.parameters.keys()) + params = [] + for k in sig_keys: + if k in predict_kwargs or (k == 'kwargs' and sig.parameters[k].kind == 4): + continue + params.append(k) + self.params = params + + @dc.dataclass(kw_only=True) class _TrainingConfiguration(Component): """ @@ -90,361 +85,364 @@ def get(self, k, default=None): @dc.dataclass(kw_only=True) -class _Predictor: - # Mixin class for components which can predict. - """:param encoder: Encoder instance - :param output_schema: Output schema (mapping of encoders) - :param flatten: Flatten the model outputs - :param preprocess: Preprocess function - :param postprocess: Postprocess function - :param collate_fn: Collate function - :param batch_predict: Whether to batch predict - :param takes_context: Whether the model takes context into account - :param metrics: The metrics to evaluate on - :param model_update_kwargs: The kwargs to use for model update - :param validation_sets: The validation ``Dataset`` instances to use - :param predict_X: The key of the input data to use for .predict - :param predict_select: The select to use for .predict - :param predict_max_chunk_size: The max chunk size to use for .predict - :param predict_kwargs: The kwargs to use for .predict""" - - type_id: t.ClassVar[str] = 'model' +class _Fittable: + training_configuration: t.Union[str, _TrainingConfiguration, None] = None + train_X: t.Optional[ModelInputType] = None + train_y: t.Optional[str] = None + train_select: t.Optional[CompoundSelect] = None + metric_values: t.Dict = dc.field(default_factory=lambda: {}) - datatype: EncoderArg = None - output_schema: t.Optional[Schema] = None - flatten: bool = False - preprocess: t.Optional[t.Callable] = None - postprocess: t.Optional[t.Callable] = None - collate_fn: t.Optional[t.Callable] = None - batch_predict: bool = False - takes_context: bool = False - metrics: t.Sequence[t.Union[str, Metric, None]] = () - model_update_kwargs: t.Dict = 
dc.field(default_factory=dict) - validation_sets: t.Optional[t.Sequence[t.Union[str, Dataset]]] = None + def post_create(self, db: Datalayer) -> None: + if isinstance(self.training_configuration, str): + self.training_configuration = db.load( + 'training_configuration', self.training_configuration + ) # type: ignore[assignment] + # TODO is this necessary - should be handled by `db.add` automatically? - predict_X: t.Optional[str] = None - predict_select: t.Optional[CompoundSelect] = None - predict_max_chunk_size: t.Optional[int] = None - predict_kwargs: t.Optional[t.Dict] = None + def schedule_jobs(self, db, dependencies=()): + jobs = [] + if self.train_X is not None: + assert ( + isinstance(self.training_configuration, _TrainingConfiguration) + or self.training_configuration is None + ) + assert self.train_select is not None + jobs.append( + self.fit( + X=self.train_X, + y=self.train_y, + configuration=self.training_configuration, + select=self.train_select, + db=db, + dependencies=dependencies, + metrics=self.metrics, + validation_sets=self.validation_sets, + ) + ) + return jobs - @abstractmethod - def to_call(self, X, *args, **kwargs): - """ - The method to use to call prediction. Should be implemented - by the child class. 
- """ + def _validate( + self, + db: Datalayer, + validation_set: t.Union[Dataset, str], + metrics: t.Sequence[Metric], + ): + if isinstance(validation_set, str): + from superduperdb.components.dataset import Dataset - @property - def inputs(self): - kwargs = self.predict_kwargs if self.predict_kwargs else {} - return Inputs(self.preprocess or self.object, kwargs) + validation_set = t.cast(Dataset, db.load('dataset', validation_set)) - def setup_required_inputs(self, X): - if isinstance(X, (tuple, list)): - required_args = len(self.inputs) - assert len(X) == required_args - X = self.inputs.get_kwargs(X) + mdicts = [MongoStyleDict(r.unpack()) for r in validation_set.data] + assert self.train_X is not None + mapping = Mapping(self.train_X, self.signature) + dataset = list(map(mapping, mdicts)) + prediction = self.predict(dataset) + assert self.train_y is not None + target = [d[self.train_y] for d in mdicts] + assert isinstance(prediction, list) + assert isinstance(target, list) + results = {} - elif isinstance(X, dict): - required_args = len(self.inputs) - assert len(X) == required_args - else: - X = self.inputs.get_kwargs([X]) - return X + for m in metrics: + out = m(prediction, target) + results[f'{validation_set.identifier}/{m.identifier}'] = out + return results - def create_predict_job( + def create_fit_job( self, - X: XType, + X: t.Union[str, t.Sequence[str]], select: t.Optional[Select] = None, - ids: t.Optional[t.Sequence[str]] = None, - max_chunk_size: t.Optional[int] = None, + y: t.Optional[str] = None, **kwargs, ): return ComponentJob( component_identifier=self.identifier, - method_name='predict', + method_name='fit', type_id='model', args=[X], kwargs={ + 'y': y, 'select': select.dict().encode() if select else None, - 'ids': ids, - 'max_chunk_size': max_chunk_size, **kwargs, }, ) - async def _apredict_one(self, X: t.Any, **kwargs): - raise NotImplementedError + @abstractmethod + def _fit( + self, + X: t.Any, + y: t.Optional[t.Any] = None, + configuration: 
t.Optional[_TrainingConfiguration] = None, + data_prefetch: bool = False, + db: t.Optional[Datalayer] = None, + metrics: t.Optional[t.Sequence[Metric]] = None, + select: t.Optional[Select] = None, + validation_sets: t.Optional[t.Sequence[Dataset]] = None, + ): + pass - async def _apredict(self, X: t.Any, one: bool = False, **kwargs): - raise NotImplementedError + def fit( + self, + X: t.Any, + y: t.Optional[t.Any] = None, + configuration: t.Optional[_TrainingConfiguration] = None, + data_prefetch: bool = False, + db: t.Optional[Datalayer] = None, + dependencies: t.Sequence[Job] = (), + metrics: t.Optional[t.Sequence[Metric]] = None, + select: t.Optional[Select] = None, + validation_sets: t.Optional[t.Sequence[Dataset]] = None, + **kwargs, + ) -> t.Optional[Pipeline]: + """ + Fit the model on the given data. - def _predict_one(self, X: t.Any, **kwargs) -> int: - if self.preprocess: - X = self.setup_required_inputs(X) - X = self.preprocess(**X) - output = self.to_call(X, **kwargs) - if self.postprocess: - output = self.postprocess(output) - return output + :param X: The key of the input data to use for training + :param y: The key of the target data to use for training + :param configuration: The training configuration (optional) + :param data_prefetch: Whether to prefetch the data (optional) + :param db: The datalayer (optional) + :param dependencies: The dependencies (optional) + :param metrics: The metrics to evaluate on (optional) + :param select: The select to use for training (optional) + :param validation_sets: The validation ``Dataset`` instances to use (optional) + """ + if isinstance(select, dict): + # TODO replace with Document.decode(select) + select = Serializable.from_dict(select) - def _forward( - self, X: t.Sequence[int], num_workers: int = 0, **kwargs - ) -> t.Sequence[int]: - if self.batch_predict: - return self.to_call(X, **kwargs) + if validation_sets: + from superduperdb.components.dataset import Dataset - outputs = [] - if num_workers: - to_call 
= _to_call(self.to_call, **kwargs) - pool = multiprocessing.Pool(processes=num_workers) - for r in pool.map(to_call, X): - outputs.append(r) - pool.close() - pool.join() + validation_sets = list(validation_sets) + for i, vs in enumerate(validation_sets): + if isinstance(vs, Dataset): + assert db is not None + db.add(vs) + validation_sets[i] = vs + + self.training_configuration = configuration or self.training_configuration + + if db is not None: + db.add(self) + + if db is not None and db.compute.type == 'distributed': + return self.create_fit_job( + X, + select=select, + y=y, + **kwargs, + )(db=db, dependencies=dependencies) else: - for r in X: - outputs.append(self.to_call(r, **kwargs)) - return outputs + return self._fit( + X, + y=y, + configuration=configuration, + data_prefetch=data_prefetch, + db=db, + metrics=metrics, + select=select, + validation_sets=validation_sets, + **kwargs, + ) - def _predict(self, X: t.Any, one: bool = False, **predict_kwargs): - if one: - return self._predict_one(X) + def append_metrics(self, d: t.Dict[str, float]) -> None: + if self.metric_values is not None: + for k, v in d.items(): + self.metric_values.setdefault(k, []).append(v) - if self.preprocess: - preprocessed_X = [] - for r in X: - r = self.setup_required_inputs(r) - preprocessed_X.append(self.preprocess(**r)) - X = preprocessed_X - elif self.preprocess is not None: - raise ValueError('Bad preprocess') - if self.collate_fn: - X = self.collate_fn(X) - outputs = self._forward(X, **predict_kwargs) +@wraps(_TrainingConfiguration) +def TrainingConfiguration(identifier: str, **kwargs): + return _TrainingConfiguration(identifier=identifier, kwargs=kwargs) - if self.postprocess: - outputs = [self.postprocess(o) for o in outputs] - elif self.postprocess is not None: - raise ValueError('Bad postprocess') - return outputs +@dc.dataclass +class Signature: + singleton: t.ClassVar[str] = 'singleton' + args: t.ClassVar[str] = '*args' + kwargs: t.ClassVar[str] = '**kwargs' + args_kwargs: 
t.ClassVar[str] = '*args,**kwargs' - def validate_keys(self, X, key, one=False): - if isinstance(key, str): - if one: - assert isinstance(X, dict) - X = X[key] - else: - X = [r[key] for r in X] - elif isinstance(key, list): - X = ( - ((X[k] for k in key) if one else [(r[k] for k in key) for r in X]) - if key - else X, - ) - elif isinstance(key, dict): - X = ( - ( - (X[k] for k in key.values()) - if one - else [(r[k] for k in key.values()) for r in X] - ) - if key - else X, - ) +class Mapping: + def __init__(self, mapping: ModelInputType, signature: str): + self.mapping = self._map_args_kwargs(mapping) + self.signature = signature + + @property + def id_key(self): + out = [] + for arg in self.mapping[0]: + out.append(arg) + for k, v in self.mapping[1]: + if k.startswith('_outputs.'): + k = k.split('.')[1] + out.append(f'{k}={v}') + return ','.join(out) + + @staticmethod + def _map_args_kwargs(mapping): + if isinstance(mapping, str): + return ([mapping], {}) + elif isinstance(mapping, (list, tuple)) and isinstance(mapping[0], str): + return (mapping, {}) + elif isinstance(mapping, dict): + return ((), mapping) else: - if key is not None: - raise TypeError - return X + assert isinstance(mapping[0], (list, tuple)) + assert isinstance(mapping[1], dict) + return mapping - def predict( - self, - X: XType, - db: t.Optional[Datalayer] = None, - select: t.Optional[CompoundSelect] = None, - ids: t.Optional[t.List[str]] = None, - max_chunk_size: t.Optional[int] = None, - dependencies: t.Sequence[Job] = (), - listen: bool = False, - one: bool = False, - context: t.Optional[t.Dict] = None, - insert_to: t.Optional[t.Union[TableOrCollection, str]] = None, - key: t.Optional[t.Union[t.Dict, t.List, str]] = None, - in_memory: bool = True, - overwrite: bool = False, - **kwargs, - ) -> t.Any: - was_added = self.db is not None - db = self.db or db + def __call__(self, r): + """ + >>> r = {'a': 1, 'b': 2} + >>> self.mapping = [('a', 'b'), {}] + >>> _Predictor._data_from_input_type(docs) 
+ ([1, 2], {}) + >>> self.mapping = [('a',), {'b': 'X'}] + >>> _Predictor._data_from_input_type(docs) + ([1], {'X': 2}) + """ + args = [] + kwargs = {} + for key in self.mapping[0]: + args.append(r[key]) + for k, v in self.mapping[1].items(): + kwargs[v] = r[k] + args = Document({'_base': args}).unpack() + kwargs = Document(kwargs).unpack() + + if self.signature == Signature.kwargs: + return kwargs + elif self.signature == Signature.args: + return (*args, *list(kwargs.values())) + elif self.signature == Signature.singleton: + if args: + assert not kwargs + assert len(args) == 1 + return args[0] + else: + assert kwargs + assert len(kwargs) == 1 + return next(kwargs.values()) + return args, kwargs - if one: - assert select is None, 'select must be None when ``one=True`` (direct call)' - if isinstance(select, dict): - select = Serializable.decode(select) +@dc.dataclass(kw_only=True) +class _Predictor(Component): + # Mixin class for components which can predict. + """:param datatype: DataType instance + :param output_schema: Output schema (mapping of encoders) + :param flatten: Flatten the model outputs + :param collate_fn: Collate function + :param model_update_kwargs: The kwargs to use for model update + :param metrics: The metrics to evaluate on + :param validation_sets: The validation ``Dataset`` instances to use + :param predict_kwargs: Additional arguments to use at prediction time + """ - if isinstance(select, Table): - select = select.to_query() + type_id: t.ClassVar[str] = 'model' + signature: t.ClassVar[str] = Signature.args_kwargs - if db is not None: - if isinstance(select, IbisCompoundSelect): - from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata + datatype: EncoderArg = None + output_schema: t.Optional[Schema] = None + flatten: bool = False + model_update_kwargs: t.Dict = dc.field(default_factory=dict) + metrics: t.Sequence[Metric] = () + validation_sets: t.Optional[t.Sequence[Dataset]] = None + predict_kwargs: t.Dict = 
dc.field(default_factory=lambda: {}) - assert isinstance(db.metadata, SQLAlchemyMetadata) - try: - _ = db.metadata.get_query(str(hash(select))) - except NonExistentMetadataError: - logging.info(f'Query {select} not found in metadata, adding...') - db.metadata.add_query(select, self.identifier) - logging.info('Done') - - if not was_added: - logging.info(f'Adding model {self.identifier} to db') - assert isinstance(self, Component) - db.add(self) - - if listen: - assert db is not None - assert select is not None - return self._predict_and_listen( - X=X, - db=db, - select=select, - max_chunk_size=max_chunk_size, - **kwargs, - ) + def post_create(self, db): + output_component = db.databackend.create_model_table_or_collection(self) + if output_component is not None: + db.add(output_component) - # TODO: tidy up this logic - if select is not None and db is not None and db.compute.type == 'distributed': - return self.create_predict_job( - X, - select=select, - ids=ids, - max_chunk_size=max_chunk_size, - overwrite=overwrite, - **kwargs, - )(db=db, dependencies=dependencies) - else: - if select is not None and ids is None: - assert db is not None - return self._predict_with_select( - X=X, - select=select, - db=db, - in_memory=in_memory, - max_chunk_size=max_chunk_size, - overwrite=overwrite, - **kwargs, - ) - elif select is not None and ids is not None: - assert db is not None - return self._predict_with_select_and_ids( - X=X, - select=select, - ids=ids, - db=db, - max_chunk_size=max_chunk_size, - in_memory=in_memory, - **kwargs, - ) - else: - if self.takes_context: - kwargs['context'] = context + @property + def inputs(self) -> Inputs: + return Inputs(list(inspect.signature(self.predict_one).parameters.keys())) - X_predict = self.validate_keys(X, key, one=one) + @abstractmethod + def predict_one(self, *args, **kwargs) -> int: + """ + Execute a single prediction on a datapoint + given by positional and keyword arguments. 
- output = self._predict( - X_predict, - one=one, - **kwargs, - ) - if insert_to is not None: - msg = ( - '`self.db` has not been set; this is necessary if' - ' `insert_to` is not None; use `db.add(self)`' - ) - - from superduperdb.base.datalayer import Datalayer - - assert isinstance(db, Datalayer), msg - if isinstance(insert_to, str): - insert_to = db.load( - 'table', - insert_to, - ) # type: ignore[assignment] - if one: - output = [output] - - assert isinstance(insert_to, TableOrCollection) - if one: - X = [X] - - inserted_ids, _ = db.execute(insert_to.insert(X)) # type: ignore[arg-type] - inserted_ids = t.cast(t.List[t.Any], inserted_ids) - assert isinstance(key, str) - - insert_to.model_update( - db=db, - model=self.identifier, - outputs=output, - key=key, - version=self.version, - ids=inserted_ids, - flatten=self.flatten, - **self.model_update_kwargs, - ) - return output - - async def apredict( - self, - X: t.Any, - context: t.Optional[t.Dict] = None, - one: bool = False, - **kwargs, - ): - if self.takes_context: - kwargs['context'] = context - return await self._apredict(X, one=one, **kwargs) + :param args: arguments handled by model + :param kwargs: key-word arguments handled by model + """ + pass - def _predict_and_listen( - self, - X: t.Any, - select: CompoundSelect, - db: Datalayer, - in_memory: bool = True, - max_chunk_size: t.Optional[int] = None, - dependencies: t.Sequence[Job] = (), - **kwargs, - ): - from superduperdb.components.listener import Listener + @abstractmethod + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """ + Execute a single prediction on a datapoint + given by positional and keyword arguments. 
- return db.add( - Listener( - key=X, - model=t.cast(Model, self), - select=select, - predict_kwargs={ - **kwargs, - 'in_memory': in_memory, - 'max_chunk_size': max_chunk_size, - }, - ), - dependencies=dependencies, - )[0] - - def _predict_with_select( + :param args: arguments handled by model + :param kwargs: key-word arguments handled by model + """ + pass + + def _prepare_select_for_predict(self, select, db): + if isinstance(select, dict): + select = Serializable.decode(select) + # TODO logic in the wrong place + if isinstance(select, Table): + select = select.to_query() + if isinstance(select, IbisCompoundSelect): + from superduperdb.backends.sqlalchemy.metadata import SQLAlchemyMetadata + + assert isinstance(db.metadata, SQLAlchemyMetadata) + try: + _ = db.metadata.get_query(str(hash(select))) + except NonExistentMetadataError: + logging.info(f'Query {select} not found in metadata, adding...') + db.metadata.add_query(select, self.identifier) + logging.info('Done') + return select + + def predict_in_db_job( self, - X: t.Any, - select: Select, + X: ModelInputType, db: Datalayer, + select: t.Optional[CompoundSelect], + ids: t.Optional[t.List[str]] = None, max_chunk_size: t.Optional[int] = None, + dependencies: t.Sequence[Job] = (), in_memory: bool = True, overwrite: bool = False, - **kwargs, ): + """ + Execute a single prediction on a datapoint + given by positional and keyword arguments as a job. 
+ + :param X: combination of input keys to be mapped to the model + :param db: SuperDuperDB instance + :param select: CompoundSelect query + :param ids: Iterable of ids + :param max_chunk_size: Chunks of data + :param dependencies: List of dependencies (jobs) + :param in_memory: Load data into memory or not + :param overwrite: Overwrite all documents or only new documents + """ + job = ComponentJob( + component_identifier=self.identifier, + method_name='predict_in_db', + type_id='model', + args=[X], + kwargs={ + 'select': select.dict().encode() if select else None, + 'ids': ids, + 'max_chunk_size': max_chunk_size, + 'in_memory': in_memory, + 'overwrite': overwrite, + }, + ) + job(db, dependencies=dependencies) + return job + + def _get_ids_from_select(self, X, select, db, overwrite: bool = False): ids = [] if not overwrite: query = select.select_ids_of_missing_outputs( @@ -454,34 +452,128 @@ def _predict_with_select( ) else: query = select.select_ids - try: id_field = db.databackend.id_field except AttributeError: id_field = query.table_or_collection.primary_id - for r in tqdm.tqdm(db.execute(query)): ids.append(str(r[id_field])) + return ids + + def predict_in_db( + self, + X: ModelInputType, + db: Datalayer, + select: CompoundSelect, + ids: t.Optional[t.List[str]] = None, + max_chunk_size: t.Optional[int] = None, + in_memory: bool = True, + overwrite: bool = False, + ) -> t.Any: + """ + Execute a single prediction on a datapoint + given by positional and keyword arguments as a job. 
+ + :param X: combination of input keys to be mapped to the model + :param db: SuperDuperDB instance + :param select: CompoundSelect query + :param ids: Iterable of ids + :param max_chunk_size: Chunks of data + :param dependencies: List of dependencies (jobs) + :param in_memory: Load data into memory or not + :param overwrite: Overwrite all documents or only new documents + """ + if isinstance(select, dict): + select = Serializable.decode(select) + if isinstance(select, Table): + select = select.to_query() + + self._prepare_select_for_predict(select, db) + if self.identifier not in db.show('model'): + logging.info(f'Adding model {self.identifier} to db') + assert isinstance(self, Component) + db.add(self) + assert isinstance( + self.version, int + ), 'Something has gone wrong setting `self.version`' + + if ids is None: + ids = self._get_ids_from_select( + X, select=select, db=db, overwrite=overwrite + ) return self._predict_with_select_and_ids( X=X, - db=db, - ids=ids, select=select, + ids=ids, + db=db, max_chunk_size=max_chunk_size, in_memory=in_memory, - **kwargs, ) + def _prepare_inputs_from_select( + self, + X: ModelInputType, + db: Datalayer, + select: CompoundSelect, + ids, + in_memory: bool = True, + ): + X_data: t.Any + mapping = Mapping(X, self.signature) + if in_memory: + if db is None: + raise ValueError('db cannot be None') + docs = list(db.execute(select.select_using_ids(ids))) + # TODO add signature to Mapping.__call__ + X_data = list(map(lambda x: mapping(x), docs)) + else: + # TODO above logic missing in case of not a string + # idea: add the concept of tuple and dictionary strings to `Document` + X_data = QueryDataset( + select=select, + ids=ids, + fold=None, + db=db, + in_memory=False, + mapping=mapping, + ) + if len(X_data) > len(ids): + raise Exception( + 'You\'ve specified more documents than unique ids;' + f' Is it possible that {select.table_or_collection.primary_id}' + f' isn\'t uniquely identifying?' 
+ ) + return X_data, mapping + + @staticmethod + def handle_input_type(data, signature): + if signature == Signature.singleton: + return (data,), {} + elif signature == Signature.args: + return data, {} + elif signature == Signature.kwargs: + return (), data + elif signature == Signature.args_kwargs: + return data[0], data[1] + else: + raise ValueError( + f'Unexpected signature {data}: ' + f'Possible values {Signature.args_kwargs},' + f'{Signature.kwargs}, ' + f'{Signature.args}, ' + f'{Signature.singleton}.' + ) + raise Exception('Unexpected signature') + def _predict_with_select_and_ids( self, X: t.Any, db: Datalayer, - select: Select, + select: CompoundSelect, ids: t.List[str], in_memory: bool = True, max_chunk_size: t.Optional[int] = None, - **kwargs, ): if max_chunk_size is not None: it = 0 @@ -494,50 +586,34 @@ def _predict_with_select_and_ids( select=select, max_chunk_size=None, in_memory=in_memory, - **kwargs, ) it += 1 return - X_data: t.Any - if in_memory: - if db is None: - raise ValueError('db cannot be None') - docs = list(db.execute(select.select_using_ids(ids))) - if X == '_base': - X_data = [r.unpack() for r in docs] - elif isinstance(X, str): - X_data = [MongoStyleDict(r.unpack())[X] for r in docs] - else: - X_data = [] - for doc in docs: - doc = MongoStyleDict(doc.unpack()) - if isinstance(X, (tuple, list)): - X_data.append([doc[k] for k in X]) - elif isinstance(X, dict): - X_data.append([doc[k] for k in X.values()]) - else: - raise TypeError + dataset, mapping = self._prepare_inputs_from_select( + X=X, db=db, select=select, ids=ids, in_memory=in_memory + ) + outputs = self.predict(dataset) + outputs = self._encode_outputs(outputs) - else: - X_data = QueryDataset( - select=select, - ids=ids, - fold=None, - db=db, - in_memory=False, - keys=[X], - ) + logging.info(f'Adding {len(outputs)} model outputs to `db`') - if len(X_data) > len(ids): - raise Exception( - 'You\'ve specified more documents than unique ids;' - f' Is it possible that 
{select.table_or_collection.primary_id}' - f' isn\'t unique identifying?' - ) + assert isinstance( + self.version, int + ), 'Version has not been set, can\'t save outputs...' - outputs = self.predict(X=X_data, one=False, **kwargs) + select.model_update( + db=db, + model=self.identifier, + outputs=outputs, + key=mapping.id_key, + version=self.version, + ids=ids, + flatten=self.flatten, + **self.model_update_kwargs, + ) + def _encode_outputs(self, outputs): if isinstance(self.datatype, DataType): if self.flatten: outputs = [ @@ -546,160 +622,83 @@ def _predict_with_select_and_ids( else: outputs = [self.datatype(x).encode() for x in outputs] elif isinstance(self.output_schema, Schema): - encoded_ouputs = [] + encoded_outputs = [] for output in outputs: if isinstance(output, dict): - encoded_ouputs.append(self.output_schema(output)) + encoded_outputs.append(self.output_schema(output)) elif self.flatten: encoded_output = [self.output_schema(x) for x in output] - encoded_ouputs.append(encoded_output) - outputs = encoded_ouputs if encoded_ouputs else outputs + encoded_outputs.append(encoded_output) + outputs = encoded_outputs if encoded_outputs else outputs + return outputs - assert isinstance(self.version, int) - logging.info(f'Adding {len(outputs)} model outputs to `db`') - key = X - if isinstance(X, (tuple, list)): - key = ','.join(X) - elif isinstance(X, dict): - key = ','.join(list(X.values())) +@dc.dataclass(kw_only=True) +class _DeviceManaged: + preferred_devices: t.Sequence[str] = ('cuda', 'mps', 'cpu') + device: t.Optional[str] = None - select.model_update( - db=db, - model=self.identifier, - outputs=outputs, - key=key, - version=self.version, - ids=ids, - flatten=self.flatten, - **self.model_update_kwargs, - ) + def on_load(self, db: Datalayer) -> None: + if self.preferred_devices: + for i, device in enumerate(self.preferred_devices): + try: + self.to(device) + self.device = device + return + except Exception: + if i == len(self.preferred_devices) - 1: + raise + 
logging.info(f'Successfully mapped to {self.device}') + + @abstractmethod + def to(self, device): + pass + + +class Node: + def __init__(self, position): + self.position = position + + +@dc.dataclass +class IndexableNode: + def __init__(self, types): + self.types = types + + def __getitem__(self, item): + assert type(item) in self.types + return Node(item) @public_api(stability='stable') @dc.dataclass(kw_only=True) -class Model(_Predictor, Component): +class ObjectModel(_Predictor): """Model component which wraps a model to become serializable - {component_params} {_predictor_params} :param object: Model object, e.g. sklearn model, etc.. - :param model_to_device_method: The method to transfer the model to a device - :param metric_values: The metric values - :param predict_method: The method to use for prediction - :param model_update_kwargs: The kwargs to use for model update - :param serializer: Serializer to store model to artifact store - :param device: The device to use - :param preferred_devices: The preferred devices to use - :param training_configuration: The training configuration - :param train_X: The key of the input data to use for training - :param train_y: The key of the target data to use for training - :param train_select: The select to use for training + :param num_workers: Number of workers """ - __doc__ = __doc__.format( - component_params=Component.__doc__, - _predictor_params=_Predictor.__doc__, - ) + type_id: t.ClassVar[str] = 'model' + + __doc__ = __doc__.format(_predictor_params=_Predictor.__doc__) _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = ( ('object', dill_serializer), ) object: t.Any - model_to_device_method: t.Optional[str] = None - metric_values: t.Optional[t.Dict] = dc.field(default_factory=dict) - predict_method: t.Optional[str] = None - model_update_kwargs: dict = dc.field(default_factory=dict) - device: str = "cpu" - preferred_devices: t.Union[None, t.Sequence[str]] = ("cuda", "mps", "cpu") - - 
training_configuration: t.Union[str, _TrainingConfiguration, None] = None - train_X: t.Optional[str] = None - train_y: t.Optional[str] = None - train_select: t.Optional[CompoundSelect] = None - - type_id: t.ClassVar[str] = 'model' - - def __post_init__(self, artifacts): - super().__post_init__(artifacts) - self._artifact_method = None - if self.model_to_device_method is not None: - self._artifact_method = getattr(self, self.model_to_device_method) - - def __repr__(self): - s = f'Identifier: {self.identifier}\nParams: {self.inputs.params}\nPreferred Devices: {self.preferred_devices}' - return border_msg(s, title='SuperDuper Model') - - def to_call(self, X, *args, **kwargs): - X = self.setup_required_inputs(X) - if self.predict_method is None: - return self.object(*args, **X, **kwargs) - out = getattr(self.object, self.predict_method)(*args, **X, **kwargs) - return out - - def post_create(self, db: Datalayer) -> None: - if isinstance(self.training_configuration, str): - self.training_configuration = db.load( - 'training_configuration', self.training_configuration - ) # type: ignore[assignment] - # TODO is this necessary - should be handled by `db.add` automatically - if isinstance(self.output_schema, Schema): - db.add(self.output_schema) - output_component = db.databackend.create_model_table_or_collection(self) - if output_component is not None: - db.add(output_component) + num_workers: int = 0 + signature: str = Signature.args_kwargs # type: ignore[misc] - # TODO - bring inside post_create - @override - def schedule_jobs( - self, - db: Datalayer, - dependencies: t.Sequence[Job] = (), - verbose: bool = False, - ) -> t.Sequence[t.Any]: - jobs = [] - if self.train_X is not None: - assert ( - isinstance(self.training_configuration, _TrainingConfiguration) - or self.training_configuration is None - ) - assert self.train_select is not None - jobs.append( - self.fit( - X=self.train_X, - y=self.train_y, - configuration=self.training_configuration, - 
select=self.train_select, - db=db, - dependencies=dependencies, - metrics=self.metrics, # type: ignore[arg-type] - validation_sets=self.validation_sets, - ) - ) - if self.predict_X is not None: - assert self.predict_select is not None - jobs.append( - self.predict( - X=self.predict_X, - select=self.predict_select, - max_chunk_size=self.predict_max_chunk_size, - db=db, - **(self.predict_kwargs or {}), - ) - ) - return jobs + @property + def outputs(self): + return IndexableNode([int]) - def on_load(self, db: Datalayer) -> None: - logging.debug(f'Calling on_load method of {self}') - if self._artifact_method and self.preferred_devices: - for i, device in enumerate(self.preferred_devices): - try: - self._artifact_method(device) - self.device = device - return - except Exception: - if i == len(self.preferred_devices) - 1: - raise + @property + def inputs(self): + kwargs = self.predict_kwargs if self.predict_kwargs else {} + return CallableInputs(self.object, kwargs) @property def training_keys(self) -> t.List: @@ -722,6 +721,12 @@ def append_metrics(self, d: t.Dict[str, float]) -> None: def validate( self, db, validation_set: t.Union[Dataset, str], metrics: t.Sequence[Metric] ): + """ + Validate model on `db` and validation set. + + :param db: `db` SuperDuperDB instance + :param validation_set: Dataset on which to validate. + """ db.add(self) out = self._validate(db, validation_set, metrics) if self.metric_values is None: @@ -735,144 +740,33 @@ def validate( value=self.metric_values, ) - def pre_create(self, db: Datalayer): - # TODO this kind of thing should come from an enum component_types.datatype - # that will make refactors etc. 
easier - if isinstance(self.datatype, str): - # ruff: noqa: E501 - self.datatype = db.load('datatype', self.datatype) # type: ignore[assignment] - - def _validate( - self, - db: Datalayer, - validation_set: t.Union[Dataset, str], - metrics: t.Sequence[Metric], - ): - if isinstance(validation_set, str): - from superduperdb.components.dataset import Dataset - - validation_set = t.cast(Dataset, db.load('dataset', validation_set)) - - mdicts = [MongoStyleDict(r.unpack()) for r in validation_set.data] - assert self.train_X is not None - prediction = self._predict([d[self.train_X] for d in mdicts]) - assert self.train_y is not None - target = [d[self.train_y] for d in mdicts] - assert isinstance(prediction, list) - assert isinstance(target, list) - results = {} - - for m in metrics: - out = m(prediction, target) - results[f'{validation_set.identifier}/{m.identifier}'] = out - return results - - def create_fit_job( - self, - X: t.Union[str, t.Sequence[str]], - select: t.Optional[Select] = None, - y: t.Optional[str] = None, - **kwargs, - ): - return ComponentJob( - component_identifier=self.identifier, - method_name='fit', - type_id='model', - args=[X], - kwargs={ - 'y': y, - 'select': select.dict().encode() if select else None, - **kwargs, - }, - ) - - def _fit( - self, - X: t.Any, - y: t.Optional[t.Any] = None, - configuration: t.Optional[_TrainingConfiguration] = None, - data_prefetch: bool = False, - db: t.Optional[Datalayer] = None, - metrics: t.Optional[t.Sequence[Metric]] = None, - select: t.Optional[Select] = None, - validation_sets: t.Optional[t.Sequence[t.Union[str, Dataset]]] = None, - ): - raise NotImplementedError - - def fit( - self, - X: t.Any, - y: t.Optional[t.Any] = None, - configuration: t.Optional[_TrainingConfiguration] = None, - data_prefetch: bool = False, - db: t.Optional[Datalayer] = None, - dependencies: t.Sequence[Job] = (), - metrics: t.Optional[t.Sequence[Metric]] = None, - select: t.Optional[Select] = None, - validation_sets: 
t.Optional[t.Sequence[t.Union[str, Dataset]]] = None, - **kwargs, - ) -> t.Optional[Pipeline]: - """ - Fit the model on the given data. - - :param X: The key of the input data to use for training - :param y: The key of the target data to use for training - :param configuration: The training configuration (optional) - :param data_prefetch: Whether to prefetch the data (optional) - :param db: The datalayer (optional) - :param dependencies: The dependencies (optional) - :param metrics: The metrics to evaluate on (optional) - :param select: The select to use for training (optional) - :param validation_sets: The validation ``Dataset`` instances to use (optional) - """ - if isinstance(select, dict): - # TODO replace with Document.decode(select) - select = Serializable.from_dict(select) - - if validation_sets: - from superduperdb.components.dataset import Dataset - - validation_sets = list(validation_sets) - for i, vs in enumerate(validation_sets): - if isinstance(vs, Dataset): - assert db is not None - db.add(vs) - validation_sets[i] = vs.identifier - - self.training_configuration = configuration or self.training_configuration + def _wrapper(self, data): + args, kwargs = self.handle_input_type(data, self.signature) + return self.object(*args, **kwargs) - if db is not None: - db.add(self) + def predict_one(self, *args, **kwargs): + return self.object(*args, **kwargs) - if db is not None and db.compute.type == 'distributed': - return self.create_fit_job( - X, - select=select, - y=y, - **kwargs, - )(db=db, dependencies=dependencies) + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + outputs = [] + if self.num_workers: + pool = multiprocessing.Pool(processes=self.num_workers) + for r in pool.map(self._wrapper, dataset): # type: ignore[arg-type] + outputs.append(r) + pool.close() + pool.join() else: - return self._fit( - X, - y=y, - configuration=configuration, - data_prefetch=data_prefetch, - db=db, - metrics=metrics, - select=select, - 
validation_sets=validation_sets, - **kwargs, - ) + for i in range(len(dataset)): + outputs.append(self._wrapper(dataset[i])) + return outputs -@wraps(_TrainingConfiguration) -def TrainingConfiguration(identifier: str, **kwargs): - return _TrainingConfiguration(identifier=identifier, kwargs=kwargs) +Model = ObjectModel @public_api(stability='beta') @dc.dataclass(kw_only=True) -class APIModel(Component, _Predictor): +class APIModel(_Predictor): '''{component_params} {predictor_params} :param model: The model to use, e.g. ``'text-embedding-ada-002'``''' @@ -900,31 +794,10 @@ def post_create(self, db: Datalayer) -> None: if output_component is not None: db.add(output_component) - @override - def schedule_jobs( - self, - db: Datalayer, - dependencies: t.Sequence[Job] = (), - verbose: bool = False, - ) -> t.Sequence[t.Any]: - jobs = [] - if self.predict_X is not None: - assert self.predict_select is not None - jobs.append( - self.predict( - X=self.predict_X, - select=self.predict_select, - max_chunk_size=self.predict_max_chunk_size, - db=db, - **(self.predict_kwargs or {}), - ) - ) - return jobs - @public_api(stability='stable') @dc.dataclass(kw_only=True) -class QueryModel(Component, _Predictor): +class QueryModel(_Predictor): """ Model which can be used to query data and return those results as pre-computed queries. 
@@ -932,78 +805,47 @@ class QueryModel(Component, _Predictor): :param select: query used to find data (can include `like`) """ + preprocess: t.Optional[t.Callable] = None + postprocess: t.Optional[t.Callable] = None select: CompoundSelect - def schedule_jobs( - self, - db: Datalayer, - dependencies: t.Sequence[Job] = (), - verbose: bool = False, - ) -> t.Sequence[t.Any]: - jobs = [] - if self.predict_X is not None: - assert self.predict_select is not None - jobs.append( - self.predict( - X=self.predict_X, - select=self.predict_select, - max_chunk_size=self.predict_max_chunk_size, - db=db, - **(self.predict_kwargs or {}), - ) - ) - return jobs - - def _predict_one(self, X: t.Any, **kwargs): - select = self.select.set_variables(db=self.db, X=X) + @property + def inputs(self) -> Inputs: + if self.preprocess is not None: + return CallableInputs(self.preprocess) + return Inputs([x.value for x in self.select.variables]) + + def predict_one(self, X: t.Dict): + if self.preprocess is not None: + X = self.preprocess(X) + select = self.select.set_variables(db=self.db, **X) out = self.db.execute(select) if self.postprocess is not None: return self.postprocess(out) return out - def _predict(self, X: t.Any, one: bool = False, **predict_kwargs): - if one: - return self._predict_one(X, **predict_kwargs) - return [self._predict_one(x, **predict_kwargs) for x in X] + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + return [self.predict_one(dataset[i]) for i in range(len(dataset))] @public_api(stability='stable') @dc.dataclass(kw_only=True) -class SequentialModel(Component, _Predictor): +class SequentialModel(_Predictor): """ Sequential model component which wraps a model to become serializable - {component_params} {_predictor_params} :param predictors: A list of predictors to use """ __doc__ = __doc__.format( - component_params=Component.__doc__, _predictor_params=_Predictor.__doc__, ) - predictors: t.List[t.Union[str, Model, APIModel]] + predictors: 
t.List[_Predictor] - @override - def schedule_jobs( - self, - db: Datalayer, - dependencies: t.Sequence[Job] = (), - verbose: bool = False, - ) -> t.Sequence[t.Any]: - jobs = [] - if self.predict_X is not None: - assert self.predict_select is not None - jobs.append( - self.predict( - X=self.predict_X, - select=self.predict_select, - max_chunk_size=self.predict_max_chunk_size, - db=db, - **(self.predict_kwargs or {}), - ) - ) - return jobs + @property + def inputs(self) -> Inputs: + return self.predictors[0].inputs def post_create(self, db: Datalayer): for p in self.predictors: @@ -1015,11 +857,16 @@ def post_create(self, db: Datalayer): def on_load(self, db: Datalayer): for i, p in enumerate(self.predictors): if isinstance(p, str): - self.predictors[i] = db.load('model', p) # type: ignore[call-overload] + self.predictors[i] = db.load('model', p) - def _predict(self, X: t.Any, one: bool = False, **predict_kwargs): - out = X - for p in self.predictors: + def predict_one(self, *args, **kwargs): + return self.predict([(args, kwargs)])[0] + + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + for i, p in enumerate(self.predictors): assert isinstance(p, _Predictor) - out = p._predict(out, one=one, **predict_kwargs) + if i == 0: + out = p.predict(dataset) + else: + out = p.predict(out) return out diff --git a/superduperdb/components/schema.py b/superduperdb/components/schema.py index 3e48a4f68a..5925ee86f4 100644 --- a/superduperdb/components/schema.py +++ b/superduperdb/components/schema.py @@ -2,6 +2,8 @@ import typing as t from functools import cached_property +from overrides import override + from superduperdb.base.configs import CFG from superduperdb.components.component import Component from superduperdb.components.datatype import DataType @@ -29,6 +31,7 @@ def __post_init__(self, artifacts): assert self.identifier is not None, 'Schema must have an identifier' assert self.fields is not None, 'Schema must have fields' + @override def 
pre_create(self, db) -> None: for v in self.fields.values(): if isinstance(v, DataType): diff --git a/superduperdb/components/vector_index.py b/superduperdb/components/vector_index.py index ba264fb087..f83f444a9b 100644 --- a/superduperdb/components/vector_index.py +++ b/superduperdb/components/vector_index.py @@ -9,6 +9,7 @@ from superduperdb.components.component import Component from superduperdb.components.datatype import DataType from superduperdb.components.listener import Listener +from superduperdb.components.model import Mapping, ModelInputType from superduperdb.ext.utils import str_shape from superduperdb.misc.annotations import public_api from superduperdb.misc.special_dicts import MongoStyleDict @@ -33,8 +34,8 @@ class VectorIndex(Component): type_id: t.ClassVar[str] = 'vector_index' - indexing_listener: t.Union[Listener, str] - compatible_listener: t.Union[None, Listener, str] = None + indexing_listener: Listener + compatible_listener: t.Optional[Listener] = None measure: VectorIndexMeasureType = VectorIndexMeasureType.cosine metric_values: t.Optional[t.Dict] = dc.field(default_factory=dict) @@ -93,18 +94,11 @@ def get_vector( f'VectorIndex keys: {keys}, with model: {models}' ) - if isinstance(key, str): - model_input = document[key] - elif isinstance(key, (tuple, list)): - model_input = [document[k] for k in list(key)] - elif isinstance(key, dict): - model_input = [document[k] for k in key.values()] - else: - model_input = document - model = db.models[model_name] + data = Mapping(key, model.signature)(document) + args, kwargs = model.handle_input_type(data, model.signature) return ( - model.predict(model_input, one=True), + model.predict_one(*args, **kwargs), model.identifier, key, ) @@ -150,7 +144,7 @@ def get_nearest( ) @property - def models_keys(self) -> t.Tuple[t.List[str], t.List[KeyType]]: + def models_keys(self) -> t.Tuple[t.List[str], t.List[ModelInputType]]: """ Return a list of model and keys for each listener """ @@ -162,7 +156,7 @@ def 
models_keys(self) -> t.Tuple[t.List[str], t.List[KeyType]]: else: listeners = [self.indexing_listener] - models = [w.model.identifier for w in listeners] # type: ignore[union-attr] + models = [w.model.identifier for w in listeners] keys = [w.key for w in listeners] return models, keys diff --git a/superduperdb/ext/anthropic/model.py b/superduperdb/ext/anthropic/model.py index 5a155de707..2c04def07f 100644 --- a/superduperdb/ext/anthropic/model.py +++ b/superduperdb/ext/anthropic/model.py @@ -6,6 +6,7 @@ from superduperdb.backends.ibis.data_backend import IbisDataBackend from superduperdb.backends.ibis.field_types import dtype +from superduperdb.backends.query_dataset import QueryDataset from superduperdb.base.datalayer import Datalayer from superduperdb.components.model import APIModel from superduperdb.ext.utils import format_prompt, get_key @@ -45,39 +46,14 @@ def pre_create(self, db: Datalayer) -> None: self.datatype = dtype('str') @retry - def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): + def predict_one( + self, prompt: str, context: t.Optional[t.List[str]] = None, **kwargs + ): if context is not None: - X = format_prompt(X, self.prompt, context=context) + prompt = format_prompt(prompt, self.prompt, context=context) client = anthropic.Anthropic(api_key=get_key(KEY_NAME), **self.client_kwargs) - resp = client.completions.create(prompt=X, model=self.identifier, **kwargs) - return resp.completion - - @retry - async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): - if context is not None: - X = format_prompt(X, self.prompt, context=context) - client = anthropic.AsyncAnthropic( - api_key=get_key(KEY_NAME), **self.client_kwargs - ) - resp = await client.completions.create( - prompt=X, model=self.identifier, **kwargs - ) + resp = client.completions.create(prompt=prompt, model=self.identifier, **kwargs) return resp.completion - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = 
None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one(X, context=context, **kwargs) - return [self._predict_one(msg) for msg in X] - - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one(X, context=context, **kwargs) - return [await self._apredict_one(msg) for msg in X] + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + return [self.predict_one(dataset[i]) for i in range(len(dataset))] diff --git a/superduperdb/ext/cohere/model.py b/superduperdb/ext/cohere/model.py index 173bddc5a2..ef78b756fd 100644 --- a/superduperdb/ext/cohere/model.py +++ b/superduperdb/ext/cohere/model.py @@ -7,6 +7,7 @@ from superduperdb.backends.ibis.data_backend import IbisDataBackend from superduperdb.backends.ibis.field_types import dtype +from superduperdb.backends.query_dataset import QueryDataset from superduperdb.base.datalayer import Datalayer from superduperdb.components.model import APIModel from superduperdb.components.vector_index import sqlvector, vector @@ -36,8 +37,10 @@ class CohereEmbed(Cohere): :param shape: The shape as ``tuple`` of the embedding. 
""" + signature: t.ClassVar[str] = 'singleton' shapes: t.ClassVar[t.Dict] = {'embed-english-v2.0': (4096,)} shape: t.Optional[t.Sequence[int]] = None + batch_size: int = 100 def __post_init__(self, artifacts): super().__post_init__(artifacts) @@ -53,47 +56,25 @@ def pre_create(self, db): self.datatype = vector(self.shape) @retry - def _predict_one(self, X: str, **kwargs): + def predict_one(self, X: str): client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) - e = client.embed(texts=[X], model=self.identifier, **kwargs) + e = client.embed(texts=[X], model=self.identifier, **self.predict_kwargs) return e.embeddings[0] @retry - async def _apredict_one(self, X: str, **kwargs): - client = cohere.AsyncClient(get_key(KEY_NAME), **self.client_kwargs) - e = await client.embed(texts=[X], model=self.identifier, **kwargs) - await client.close() - return e.embeddings[0] - - @retry - def _predict_a_batch(self, texts: t.List[str], **kwargs): + def _predict_a_batch(self, texts: t.List[str]): client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) - out = client.embed(texts=texts, model=self.identifier, **kwargs) - return [r for r in out.embeddings] - - @retry - async def _apredict_a_batch(self, texts: t.List[str], **kwargs): - client = cohere.AsyncClient(get_key(KEY_NAME), **self.client_kwargs) - out = await client.embed(texts=texts, model=self.identifier, **kwargs) - await client.close() + out = client.embed(texts=texts, model=self.identifier, **self.predict_kwargs) return [r for r in out.embeddings] - def _predict(self, X, one=False, **kwargs): - if isinstance(X, str): - return self._predict_one(X) + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: out = [] - batch_size = kwargs.pop('batch_size', 100) - for i in tqdm.tqdm(range(0, len(X), batch_size)): - out.extend(self._predict_a_batch(X[i : i + batch_size], **kwargs)) - return out - - async def _apredict(self, X, one=False, **kwargs): - if isinstance(X, str): - return await 
self._apredict_one(X) - out = [] - batch_size = kwargs.pop('batch_size', 100) - for i in range(0, len(X), batch_size): - out.extend(await self._apredict_a_batch(X[i : i + batch_size], **kwargs)) + for i in tqdm.tqdm(range(0, len(dataset), self.batch_size)): + out.extend( + self._predict_a_batch( + dataset[i : i + self.batch_size], **self.predict_kwargs + ) + ) return out @@ -105,6 +86,7 @@ class CohereGenerate(Cohere): :param prompt: The prompt to use to seed the response. """ + signature: t.ClassVar[str] = '*args,**kwargs' takes_context: bool = True prompt: str = '' @@ -114,36 +96,15 @@ def pre_create(self, db: Datalayer) -> None: self.datatype = dtype('str') @retry - def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): + def predict_one(self, prompt: str, context: t.Optional[t.List[str]] = None): if context is not None: - X = format_prompt(X, self.prompt, context=context) + prompt = format_prompt(prompt, self.prompt, context=context) client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) - resp = client.generate(prompt=X, model=self.identifier, **kwargs) + resp = client.generate( + prompt=prompt, model=self.identifier, **self.predict_kwargs + ) return resp.generations[0].text @retry - async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): - if context is not None: - X = format_prompt(X, self.prompt, context=context) - client = cohere.AsyncClient(get_key(KEY_NAME), **self.client_kwargs) - resp = await client.generate(prompt=X, model=self.identifier, **kwargs) - await client.close() - return resp.generations[0].text - - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one(X, context=context, **kwargs) - return [self._predict_one(msg) for msg in X] - - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - 
if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one(X, context=context, **kwargs) - return [await self._apredict_one(msg) for msg in X] + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + return [self.predict_one(dataset[i]) for i in range(len(dataset))] diff --git a/superduperdb/ext/jina/model.py b/superduperdb/ext/jina/model.py index 9e0e263fbb..65e6b73444 100644 --- a/superduperdb/ext/jina/model.py +++ b/superduperdb/ext/jina/model.py @@ -4,6 +4,7 @@ import tqdm from superduperdb.backends.ibis.data_backend import IbisDataBackend +from superduperdb.backends.query_dataset import QueryDataset from superduperdb.components.model import APIModel from superduperdb.components.vector_index import sqlvector, vector from superduperdb.ext.jina.client import JinaAPIClient @@ -29,6 +30,9 @@ class JinaEmbedding(Jina): If not provided, it will be obtained by sending a simple query to the API """ + batch_size: int = 100 + signature: t.ClassVar[str] = 'singleton' + shape: t.Optional[t.Sequence[int]] = None def __post_init__(self, artifacts): @@ -44,33 +48,17 @@ def pre_create(self, db): elif self.datatype is None: self.datatype = vector(self.shape) - def _predict_one(self, X: str, **kwargs): + def predict_one(self, X: str): return self.client.encode_batch([X])[0] - async def _apredict_one(self, X: str, **kwargs): - embeddings = await self.client.aencode_batch([X]) - return embeddings[0] - - def _predict_a_batch(self, texts: t.List[str], **kwargs): + def _predict_a_batch(self, texts: t.List[str]): return self.client.encode_batch(texts) - async def _apredict_a_batch(self, texts: t.List[str], **kwargs): - return await self.client.aencode_batch(texts) - - def _predict(self, X, one=False, **kwargs): - if isinstance(X, str): - return self._predict_one(X) - out = [] - batch_size = kwargs.pop('batch_size', 100) - for i in tqdm.tqdm(range(0, len(X), batch_size)): - out.extend(self._predict_a_batch(X[i : i + 
batch_size], **kwargs)) - return out - - async def _apredict(self, X, one=False, **kwargs): - if isinstance(X, str): - return await self._apredict_one(X) + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: out = [] - batch_size = kwargs.pop('batch_size', 100) - for i in range(0, len(X), batch_size): - out.extend(await self._apredict_a_batch(X[i : i + batch_size], **kwargs)) + for i in tqdm.tqdm(range(0, len(dataset), self.batch_size)): + batch = [ + dataset[i] for i in range(i, min(i + self.batch_size, len(dataset))) + ] + out.extend(self._predict_a_batch(batch)) return out diff --git a/superduperdb/ext/llamacpp/model.py b/superduperdb/ext/llamacpp/model.py index b746a4e802..bd1ed1265c 100644 --- a/superduperdb/ext/llamacpp/model.py +++ b/superduperdb/ext/llamacpp/model.py @@ -8,7 +8,14 @@ from superduperdb.ext.llm.base import _BaseLLM +# TODO use core downloader already implemented def download_uri(uri, save_path): + """ + Download file + + :param uri: URI to download + :param save_path: place to save + """ response = requests.get(uri) if response.status_code == 200: with open(save_path, 'wb') as file: @@ -19,8 +26,18 @@ def download_uri(uri, save_path): @dc.dataclass class LlamaCpp(_BaseLLM): + """ + Llama.cpp connector + + :param model_name_or_path: path or name of model + :param model_kwargs: dictionary of init-kwargs + :param download_dir: local caching directory + :param signature: s + """ + + signature: t.ClassVar[str] = 'singleton' + model_name_or_path: str = "facebook/opt-125m" - object: t.Optional[Llama] = None model_kwargs: t.Dict = dc.field(default_factory=dict) download_dir: str = '.llama_cpp' diff --git a/superduperdb/ext/llm/base.py b/superduperdb/ext/llm/base.py index 5ebedd542b..9d280a94dc 100644 --- a/superduperdb/ext/llm/base.py +++ b/superduperdb/ext/llm/base.py @@ -5,10 +5,10 @@ import typing from functools import reduce from logging import WARNING, getLogger -from typing import Any, Callable, List, Optional, Union +from 
typing import Any, Callable, List, Optional, Sequence, Union from superduperdb import logging -from superduperdb.components.component import Component +from superduperdb.backends.query_dataset import QueryDataset from superduperdb.components.model import _Predictor from superduperdb.ext.llm.utils import Prompter from superduperdb.ext.utils import ensure_initialized @@ -21,7 +21,7 @@ @dc.dataclass -class _BaseLLM(Component, _Predictor, metaclass=abc.ABCMeta): +class _BaseLLM(_Predictor, metaclass=abc.ABCMeta): """ :param prompt_template: The template to use for the prompt. :param prompt_func: The function to use for the prompt. @@ -69,31 +69,22 @@ def init(self): def _generate(self, prompt: str, **kwargs: Any) -> str: ... - def _batch_generate(self, prompts: List[str], **kwargs: Any) -> List[str]: + def _batch_generate(self, prompts: List[str]) -> List[str]: """ Base method to batch generate text from a list of prompts. If the model can run batch generation efficiently, pls override this method. 
""" - return [self._generate(prompt, **kwargs) for prompt in prompts] + return [self._generate(prompt, **self.predict_kwargs) for prompt in prompts] @ensure_initialized - def _predict( - self, - X: Union[str, List[str], List[dict[str, str]]], - one: bool = False, - **kwargs: Any, - ): - # support string and dialog format - one = isinstance(X, str) - if not one and isinstance(X, list): - one = isinstance(X[0], dict) - - if one: - x = self.prompter(X, **kwargs) - return self._generate(x, **kwargs) - else: - xs = [self.prompter(x, **kwargs) for x in X] - return self._batch_generate(xs, **kwargs) + def predict_one(self, X: Union[str, dict[str, str]], **kwargs): + x = self.prompter(X) + return self._generate(x, **kwargs) + + @ensure_initialized + def predict(self, dataset: Union[List, QueryDataset]) -> Sequence: + xs = [self.prompter(dataset[i]) for i in range(len(dataset))] + return self._batch_generate(xs) def get_kwargs(self, func, *kwargs_list): """ diff --git a/superduperdb/ext/llm/model.py b/superduperdb/ext/llm/model.py index 815f482853..2da334fe55 100644 --- a/superduperdb/ext/llm/model.py +++ b/superduperdb/ext/llm/model.py @@ -10,10 +10,11 @@ ) from superduperdb import logging -from superduperdb.backends.query_dataset import query_dataset_factory +from superduperdb.backends.query_dataset import QueryDataset, query_dataset_factory from superduperdb.components.dataset import Dataset as _Dataset from superduperdb.components.model import ( - Model, + _Fittable, + _Predictor, _TrainingConfiguration, ) from superduperdb.ext.llm import training @@ -35,7 +36,7 @@ def LLMTrainingConfiguration(identifier: str, **kwargs) -> _TrainingConfiguratio @dc.dataclass -class LLM(Model): +class LLM(_Predictor, _Fittable): """ LLM model based on `transformers` library. 
Parameters: @@ -65,6 +66,7 @@ class LLM(Model): tokenizer_kwargs: t.Dict = dc.field(default_factory=dict) prompt_template: str = "{input}" prompt_func: t.Optional[t.Callable] = None + signature: str = 'singleton' # type: ignore[misc] # Save models and tokenizers cache for sharing when using multiple models _model_cache: t.ClassVar[dict] = {} @@ -122,7 +124,7 @@ def init(self): self.prompter = Prompter(self.prompt_template, self.prompt_func) self.model, self.tokenizer = self.init_model_and_tokenizer() if self.adapter_id is not None: - self.add_adapter(self.adapter_id.artifact, self.adapter_id.artifact) + self.add_adapter(self.adapter_id, self.adapter_id) def _fit( self, @@ -185,21 +187,15 @@ def compute_metrics(eval_preds): return compute_metrics @ensure_initialized - def _predict( - self, - X: t.Union[str, t.List[str], t.List[dict[str, str]]], - one: bool = False, - **kwargs: t.Any, - ): - # support string and dialog format - one = isinstance(X, str) - if not one and isinstance(X, list): - one = isinstance(X[0], dict) + def predict_one(self, X): + X = self.prompter(X) + results = self._generate([X], **self.predict_kwargs) + return results[0] - xs = [X] if one else X - xs = [self.prompter(x, **kwargs) for x in xs] - results = self._generate(xs, **kwargs) - return results[0] if one else results + @ensure_initialized + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + X = [self.prompter(dataset[i]) for i in range(len(dataset))] + return self._generate(X, **self.predict_kwargs) def _generate(self, X: t.Any, adapter_name=None, **kwargs): """ @@ -262,6 +258,7 @@ def add_adapter(self, model_id, adapter_name: str): # Update cache model self._model_cache[hash(self.model_kwargs)] = self.model else: + # TODO where does this come from? 
self.model.load_adapter(model_id, adapter_name) def get_datasets( diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py index a2c3a3cca4..a08be7f90c 100644 --- a/superduperdb/ext/openai/model.py +++ b/superduperdb/ext/openai/model.py @@ -1,27 +1,25 @@ -import asyncio import base64 import dataclasses as dc -import itertools import json import os import typing as t -import aiohttp import requests import tqdm from httpx import ResponseNotRead from openai import ( APITimeoutError, - AsyncOpenAI, InternalServerError, OpenAI as SyncOpenAI, RateLimitError, ) +from openai._types import NOT_GIVEN from superduperdb.backends.ibis.data_backend import IbisDataBackend from superduperdb.backends.ibis.field_types import dtype +from superduperdb.backends.query_dataset import QueryDataset from superduperdb.base.datalayer import Datalayer -from superduperdb.components.model import APIModel +from superduperdb.components.model import APIModel, Inputs from superduperdb.components.vector_index import sqlvector, vector from superduperdb.misc.compat import cache from superduperdb.misc.retry import Retry @@ -47,7 +45,6 @@ def _available_models(skwargs): class _OpenAI(APIModel): ''' :param client_kwargs: The kwargs to be passed to OpenAI - ''' client_kwargs: t.Optional[dict] = dc.field(default_factory=dict) @@ -66,7 +63,6 @@ def __post_init__(self, artifacts): raise ValueError(msg) self.syncClient = SyncOpenAI(**self.client_kwargs) - self.asyncClient = AsyncOpenAI(**self.client_kwargs) if 'OPENAI_API_KEY' not in os.environ and ( 'api_key' not in self.client_kwargs.keys() and self.client_kwargs @@ -76,6 +72,15 @@ def __post_init__(self, artifacts): 'nor in `client_kwargs`' ) + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + out = [] + for i in tqdm.tqdm(range(0, len(dataset), self.batch_size)): + batch = [ + dataset[i] for i in range(i, min(len(dataset), i + self.batch_size)) + ] + out.extend(self._predict_a_batch(batch)) + return out + 
@dc.dataclass(kw_only=True) class OpenAIEmbedding(_OpenAI): @@ -89,11 +94,17 @@ class OpenAIEmbedding(_OpenAI): shape: t.Optional[t.Sequence[int]] = None shapes: t.ClassVar[t.Dict] = {'text-embedding-ada-002': (1536,)} + signature: t.ClassVar[str] = 'singleton' + batch_size: int = 100 + + @property + def inputs(self): + return Inputs(['input']) def __post_init__(self, artifacts): super().__post_init__(artifacts) if self.shape is None: - self.shape = self.shapes[self.identifier] + self.shape = self.shapes[self.model] def pre_create(self, db): super().pre_create(db) @@ -104,48 +115,19 @@ def pre_create(self, db): self.datatype = vector(self.shape) @retry - def _predict_one(self, X: str, **kwargs): - e = self.syncClient.embeddings.create(input=X, model=self.model, **kwargs) - return e.data[0].embedding - - @retry - async def _apredict_one(self, X: str, **kwargs): - e = await self.asyncClient.embeddings.create( - input=X, model=self.model, **kwargs + def predict_one(self, X: str): + e = self.syncClient.embeddings.create( + input=X, model=self.model, **self.predict_kwargs ) return e.data[0].embedding @retry - def _predict_a_batch(self, texts: t.List[str], **kwargs): - out = self.syncClient.embeddings.create(input=texts, model=self.model, **kwargs) - return [r.embedding for r in out.data] - - @retry - async def _apredict_a_batch(self, texts: t.List[str], **kwargs): - out = await self.asyncClient.embeddings.create( - input=texts, model=self.model, **kwargs + def _predict_a_batch(self, texts: t.List[t.Dict]): + out = self.syncClient.embeddings.create( + input=texts, model=self.model, **self.predict_kwargs ) return [r.embedding for r in out.data] - def _predict(self, X, one: bool = False, **kwargs): - if isinstance(X, str): - return self._predict_one(X) - out = [] - batch_size = kwargs.pop('batch_size', 100) - for i in tqdm.tqdm(range(0, len(X), batch_size)): - out.extend(self._predict_a_batch(X[i : i + batch_size], **kwargs)) - return out - - async def _apredict(self, X, 
one: bool = False, **kwargs): - if isinstance(X, str): - return await self._apredict_one(X) - out = [] - batch_size = kwargs.pop('batch_size', 100) - # Note: we submit the async requests in serial to avoid rate-limiting - for i in range(0, len(X), batch_size): - out.extend(await self._apredict_a_batch(X[i : i + batch_size], **kwargs)) - return out - @dc.dataclass(kw_only=True) class OpenAIChatCompletion(_OpenAI): @@ -154,10 +136,16 @@ class OpenAIChatCompletion(_OpenAI): :param prompt: The prompt to use to seed the response. """ + signature: t.ClassVar[str] = 'singleton' __doc__ = __doc__.format(_openai_parameters=_OpenAI.__doc__) + batch_size: int = 1 prompt: str = '' + @property + def inputs(self): + return Inputs(['content', 'context']) + def __post_init__(self, artifacts): super().__post_init__(artifacts) self.takes_context = True @@ -172,52 +160,27 @@ def pre_create(self, db: Datalayer) -> None: self.datatype = dtype('str') @retry - def _predict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): + def predict_one(self, X: str, context: t.Optional[str] = None): if context is not None: X = self._format_prompt(context, X) return ( self.syncClient.chat.completions.create( messages=[{'role': 'user', 'content': X}], model=self.model, - **kwargs, + **self.predict_kwargs, ) .choices[0] .message.content ) - @retry - async def _apredict_one(self, X, context: t.Optional[t.List[str]] = None, **kwargs): - if context is not None: - X = self._format_prompt(context, X) - return ( - ( - await self.asyncClient.chat.completions.create( - messages=[{'role': 'user', 'content': X}], - model=self.model, - **kwargs, - ) + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + out = [] + for i in range(len(dataset)): + args, kwargs = self.handle_input_type( + data=dataset[i], signature=self.signature ) - .choices[0] - .message.content - ) - - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - 
assert one, 'context only works with ``one=True``' - if one: - return self._predict_one(X, context=context, **kwargs) - return [self._predict_one(msg) for msg in X] - - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one(X, context=context, **kwargs) - return [await self._apredict_one(msg) for msg in X] + out.append(self.predict_one(*args, **kwargs)) + return out @dc.dataclass(kw_only=True) @@ -228,10 +191,14 @@ class OpenAIImageCreation(_OpenAI): :param prompt: The prompt to use to seed the response. """ + signature: t.ClassVar[str] = 'singleton' + __doc__ = __doc__.format(_openai_parameters=_OpenAI.__doc__) takes_context: bool = True prompt: str = '' + n: int = 1 + response_format: str = 'b64_json' def pre_create(self, db: Datalayer) -> None: super().pre_create(db) @@ -243,89 +210,35 @@ def _format_prompt(self, context, X): return prompt + X @retry - def _predict_one( - self, - X, - n: int, - response_format: str, - context: t.Optional[t.List[str]] = None, - **kwargs, - ): - if context is not None: - X = self._format_prompt(context, X) - if response_format == 'b64_json': - b64_json = ( - self.syncClient.images.generate( - prompt=X, n=n, response_format='b64_json' - ) - .data[0] - .b64_json - ) - return base64.b64decode(b64_json) - else: - url = self.syncClient.images.generate(prompt=X, n=n, **kwargs).data[0].url - return requests.get(url).content - - @retry - async def _apredict_one( - self, - X, - n: int, - response_format: str, - context: t.Optional[t.List[str]] = None, - **kwargs, - ): - if context is not None: - X = self._format_prompt(context, X) - if response_format == 'b64_json': - b64_json = ( - ( - await self.asyncClient.images.generate( - prompt=X, n=n, response_format='b64_json' - ) - ) - .data[0] - .b64_json + def predict_one(self, X: str): + if self.response_format == 'b64_json': + resp = 
self.syncClient.images.generate( + prompt=X, + n=self.n, + response_format='b64_json', + **self.predict_kwargs, ) + b64_json = resp.data[0].b64_json + assert b64_json is not None return base64.b64decode(b64_json) else: url = ( - (await self.asyncClient.images.generate(prompt=X, n=n, **kwargs)) + self.syncClient.images.generate( + prompt=X, n=self.n, **self.predict_kwargs + ) .data[0] .url ) - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: - return await resp.read() - - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - response_format = kwargs.pop('response_format', 'b64_json') - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one( - X, n=1, response_format=response_format, context=context, **kwargs - ) - return [ - self._predict_one(msg, n=1, response_format=response_format) for msg in X - ] + return requests.get(url).content - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - response_format = kwargs.pop('response_format', 'b64_json') - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one( - X, context=context, n=1, response_format=response_format, **kwargs + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + out = [] + for i in range(len(dataset)): + args, kwargs = self.handle_input_type( + data=dataset[i], signature=self.signature ) - return [ - await self._apredict_one(msg, n=1, response_format=response_format) - for msg in X - ] + out.append(self.predict_one(*args, **kwargs)) + return out @dc.dataclass(kw_only=True) @@ -340,6 +253,8 @@ class OpenAIImageEdit(_OpenAI): takes_context: bool = True prompt: str = '' + response_format: str = 'b64_json' + n: int = 1 def _format_prompt(self, context): prompt = self.prompt.format(context='\n'.join(context)) @@ -351,126 +266,54 @@ def pre_create(self, 
db: Datalayer) -> None: self.datatype = dtype('bytes') @retry - def _predict_one( + def predict_one( self, image: t.BinaryIO, - n: int, - response_format: str, + mask: t.Optional[t.BinaryIO] = None, context: t.Optional[t.List[str]] = None, - mask_png_path: t.Optional[str] = None, - **kwargs, ): if context is not None: self.prompt = self._format_prompt(context) - if mask_png_path is not None: - with open(mask_png_path, 'rb') as f: - mask = f.read() - kwargs['mask'] = mask + maybe_mask = mask or NOT_GIVEN - if response_format == 'b64_json': + if self.response_format == 'b64_json': b64_json = ( self.syncClient.images.edit( image=image, + mask=maybe_mask, prompt=self.prompt, - n=n, + n=self.n, response_format='b64_json', - **kwargs, + **self.predict_kwargs, ) .data[0] .b64_json ) - return base64.b64decode(b64_json) + out = base64.b64decode(b64_json) else: url = ( self.syncClient.images.edit( - image=image, prompt=self.prompt, n=n, **kwargs - ) - .data[0] - .url - ) - return requests.get(url).content - - @retry - async def _apredict_one( - self, - image: t.BinaryIO, - n: int, - response_format: str, - context: t.Optional[t.List[str]] = None, - mask_png_path: t.Optional[str] = None, - **kwargs, - ): - if context is not None: - self.prompt = self._format_prompt(context) - - if mask_png_path is not None: - with open(mask_png_path, 'rb') as f: - mask = f.read() - kwargs['mask'] = mask - - if response_format == 'b64_json': - b64_json = ( - ( - await self.asyncClient.images.edit( - image=image, - prompt=self.prompt, - n=n, - response_format='b64_json', - **kwargs, - ) - ) - .data[0] - .b64_json - ) - return base64.b64decode(b64_json) - else: - url = ( - ( - await self.asyncClient.images.edit( - image=image, prompt=self.prompt, n=n, **kwargs - ) + image=image, + mask=maybe_mask, + prompt=self.prompt, + n=self.n, + **self.predict_kwargs, ) .data[0] .url ) - async with aiohttp.ClientSession() as session: - async with session.get(url) as resp: - return await resp.read() - - def 
_predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - response_format = kwargs.pop('response_format', 'b64_json') - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one( - image=X, n=1, response_format=response_format, context=context, **kwargs - ) - return [ - self._predict_one( - image=image, n=1, response_format=response_format, **kwargs - ) - for image in X - ] + out = requests.get(url).content + return out - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - response_format = kwargs.pop('response_format', 'b64_json') - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one( - image=X, context=context, n=1, response_format=response_format, **kwargs - ) - return [ - await self._apredict_one( - image=image, n=1, response_format=response_format, **kwargs + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + out = [] + for i in range(len(dataset)): + args, kwargs = self.handle_input_type( + data=dataset[i], signature=self.signature ) - for image in X - ] + out.append(self.predict_one(*args, **kwargs)) + return out @dc.dataclass(kw_only=True) @@ -492,9 +335,7 @@ def pre_create(self, db: Datalayer) -> None: self.datatype = dtype('str') @retry - def _predict_one( - self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs - ): + def predict_one(self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None): "Converts a file-like Audio recording to text." if context is not None: self.prompt = self.prompt.format(context='\n'.join(context)) @@ -502,23 +343,7 @@ def _predict_one( file=file, model=self.model, prompt=self.prompt, - **kwargs, - ).text - - @retry - async def _apredict_one( - self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs - ): - "Converts a file-like Audio recording to text." 
- if context is not None: - self.prompt = self.prompt.format(context='\n'.join(context)) - return ( - await self.asyncClient.audio.transcriptions.create( - file=file, - model=self.model, - prompt=self.prompt, - **kwargs, - ) + **self.predict_kwargs, ).text @retry @@ -526,54 +351,12 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): "Converts multiple file-like Audio recordings to text." resps = [ self.syncClient.audio.transcriptions.create( - file=file, model=self.model, **kwargs + file=file, model=self.model, **self.predict_kwargs ) for file in files ] return [resp.text for resp in resps] - @retry - async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): - "Converts multiple file-like Audio recordings to text." - resps = await asyncio.gather( - *[ - self.asyncClient.audio.transcriptions.create( - file=file, model=self.model, **kwargs - ) - for file in files - ] - ) - return [resp.text for resp in resps] - - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one(X, context=context, **kwargs) - out = [] - batch_size = kwargs.pop('batch_size', 10) - for i in tqdm.tqdm(range(0, len(X), batch_size)): - out.extend(self._predict_a_batch(X[i : i + batch_size], **kwargs)) - return out - - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one(X, context=context, **kwargs) - batch_size = kwargs.pop('batch_size', 10) - list_of_lists = await asyncio.gather( - *[ - self._apredict_a_batch(X[i : i + batch_size], **kwargs) - for i in range(0, len(X), batch_size) - ] - ) - return list(itertools.chain(*list_of_lists)) - @dc.dataclass(kw_only=True) class OpenAIAudioTranslation(_OpenAI): @@ -583,10 +366,13 @@ class OpenAIAudioTranslation(_OpenAI): 
:param prompt: The prompt to guide the model's style. Should contain ``{context}``. """ + signature: t.ClassVar[str] = 'singleton' + __doc__ = __doc__.format(_openai_parameters=_OpenAI.__doc__, context='{context}') takes_context: bool = True prompt: str = '' + batch_size: int = 1 def pre_create(self, db: Datalayer) -> None: super().pre_create(db) @@ -594,8 +380,10 @@ def pre_create(self, db: Datalayer) -> None: self.datatype = dtype('str') @retry - def _predict_one( - self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs + def predict_one( + self, + file: t.BinaryIO, + context: t.Optional[t.List[str]] = None, ): "Translates a file-like Audio recording to English." if context is not None: @@ -605,75 +393,18 @@ def _predict_one( file=file, model=self.model, prompt=self.prompt, - **kwargs, + **self.predict_kwargs, ) ).text @retry - async def _apredict_one( - self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs - ): - "Translates a file-like Audio recording to English." - if context is not None: - self.prompt = self.prompt.format(context='\n'.join(context)) - return ( - await self.asyncClient.audio.translations.create( - file=file, - model=self.model, - prompt=self.prompt, - **kwargs, - ) - ).text - - @retry - def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): + def _predict_a_batch(self, files: t.List[t.BinaryIO]): "Translates multiple file-like Audio recordings to English." + # TODO use async or threads resps = [ self.syncClient.audio.translations.create( - file=file, model=self.model, **kwargs + file=file, model=self.model, **self.predict_kwargs ) for file in files ] return [resp.text for resp in resps] - - @retry - async def _apredict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): - "Translates multiple file-like Audio recordings to English." 
- resps = await asyncio.gather( - *[ - self.asyncClient.audio.translations.create( - file=file, model=self.model, **kwargs - ) - for file in files - ] - ) - return [resp.text for resp in resps] - - def _predict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return self._predict_one(X, context=context, **kwargs) - out = [] - batch_size = kwargs.pop('batch_size', 10) - for i in tqdm.tqdm(range(0, len(X), batch_size)): - out.extend(self._predict_a_batch(X[i : i + batch_size], **kwargs)) - return out - - async def _apredict( - self, X, one: bool = True, context: t.Optional[t.List[str]] = None, **kwargs - ): - if context: - assert one, 'context only works with ``one=True``' - if one: - return await self._apredict_one(X, context=context, **kwargs) - batch_size = kwargs.pop('batch_size', 10) - list_of_lists = await asyncio.gather( - *[ - self._apredict_a_batch(X[i : i + batch_size], **kwargs) - for i in range(0, len(X), batch_size) - ] - ) - return list(itertools.chain(*list_of_lists)) diff --git a/superduperdb/ext/sentence_transformers/model.py b/superduperdb/ext/sentence_transformers/model.py index da508b4ec0..a349c2d054 100644 --- a/superduperdb/ext/sentence_transformers/model.py +++ b/superduperdb/ext/sentence_transformers/model.py @@ -1,32 +1,53 @@ import dataclasses as dc import typing as t -from superduperdb.components.model import Model +from overrides import override +from sentence_transformers import SentenceTransformer as _SentenceTransformer + +from superduperdb.backends.query_dataset import QueryDataset +from superduperdb.components.datatype import DataType, dill_serializer +from superduperdb.components.model import _DeviceManaged, _Predictor @dc.dataclass(kw_only=True) -class SentenceTransformer(Model): - _encodables: t.ClassVar[t.Sequence[str]] = ('object',) +class SentenceTransformer(_Predictor, _DeviceManaged): + _artifacts: 
t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = ( + ('object', dill_serializer), + ) - object: t.Optional[t.Callable] = None + object: t.Optional[_SentenceTransformer] = None model: t.Optional[str] = None + device: str = 'cpu' + preprocess: t.Optional[t.Callable] = None + + def __post_init__(self, artifacts): + super().__post_init__(artifacts) - def __post_init__(self): - super().__post_init__() if self.model is None: self.model = self.identifier if self.object is None: - import sentence_transformers + self.object = _SentenceTransformer(self.identifier, device=self.device) - self.object = sentence_transformers.SentenceTransformer( - self.identifier, device=self.device - ) - self.object = self.object.to(self.device) - self.model_to_device_method = '_to' - self.predict_method = 'encode' - self.batch_predict = True + self.to(self.device) - def _to(self, device): + def to(self, device): self.object = self.object.to(device) self.object._target_device = device + + @override + def predict_one(self, X) -> int: + if self.preprocess is not None: + X = self.preprocess(X) + + assert self.object is not None + return self.object.encode(X, self.predict_kwargs) + + @override + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + if self.preprocess is not None: + dataset = list(map(self.preprocess, dataset)) # type: ignore[arg-type] + return self.object.encode( # type: ignore[union-attr] + dataset, + **self.predict_kwargs, + ) diff --git a/superduperdb/ext/sklearn/model.py b/superduperdb/ext/sklearn/model.py index 1aaa42783c..e8e39a5d33 100644 --- a/superduperdb/ext/sklearn/model.py +++ b/superduperdb/ext/sklearn/model.py @@ -2,13 +2,21 @@ import typing as t import numpy +from sklearn.base import BaseEstimator from tqdm import tqdm from superduperdb.backends.base.query import Select from superduperdb.backends.query_dataset import QueryDataset from superduperdb.base.datalayer import Datalayer +from superduperdb.components.datatype import DataType, 
pickle_serializer from superduperdb.components.metric import Metric -from superduperdb.components.model import Model, _TrainingConfiguration +from superduperdb.components.model import ( + Mapping, + _Fittable, + _Predictor, + _TrainingConfiguration, +) +from superduperdb.jobs.job import Job def _get_data_from_query( @@ -19,34 +27,40 @@ def _get_data_from_query( y_preprocess: t.Optional[t.Callable] = None, preprocess: t.Optional[t.Callable] = None, ): - def transform(r): - out = {} - if X == '_base': - out.update(**preprocess(r)) - else: - out[X] = preprocess(r[X]) - if y is not None: - out[y] = y_preprocess(r[y]) if y_preprocess else r[y] - return out + if y is None: + data = QueryDataset( + select=select, + mapping=Mapping([X], signature='singleton'), + transform=preprocess, + db=db, + ) + else: + y_preprocess = y_preprocess or (lambda x: x) + preprocess = preprocess or (lambda x: x) + data = QueryDataset( + select=select, + mapping=Mapping([X, y], signature='*args'), + transform=lambda x, y: (preprocess(x), y_preprocess(y)), + db=db, + ) - data = QueryDataset( - select=select, - keys=[X] if y is None else [X, y], - transform=transform, - db=db, - ) - documents = [] + rows = [] for i in tqdm(range(len(data))): - r = data[i] - documents.append(r) - X_arr = [r[X] for r in documents] + rows.append(data[i]) + if y is not None: + X_arr = [x[0] for x in rows] + y_arr = [x[1] for x in rows] + else: + X_arr = rows + if isinstance(X[0], numpy.ndarray): X_arr = numpy.stack(X_arr) if y is not None: - y_arr = [r[y] for r in documents] - if isinstance(y[0], numpy.ndarray): + y_arr = [r[1] for r in rows] + if isinstance(y_arr[0], numpy.ndarray): y_arr = numpy.stack(y_arr) - return X_arr, y_arr + return X_arr, y_arr + return X_arr, None @dc.dataclass @@ -56,13 +70,27 @@ class SklearnTrainingConfiguration(_TrainingConfiguration): y_preprocess: t.Optional[t.Callable] = None -@dc.dataclass -class Estimator(Model): - def __post_init__(self, artifacts): - if self.predict_method is 
not None: - assert self.predict_method == 'predict' - self.predict_method = 'predict' - super().__post_init__(artifacts) +@dc.dataclass(kw_only=True) +class Estimator(_Predictor, _Fittable): + _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( + ('object', pickle_serializer), + ) + signature: t.ClassVar[str] = 'singleton' + object: BaseEstimator + preprocess: t.Optional[t.Callable] = None + postprocess: t.Optional[t.Callable] = None + + def schedule_jobs( + self, + db: Datalayer, + dependencies: t.Sequence[Job] = (), + verbose: bool = False, + ) -> t.Sequence[t.Any]: + jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies) + jobs.extend( + _Predictor.schedule_jobs(self, db, dependencies=[*dependencies, *jobs]) + ) + return jobs def __getattr__(self, item): if item in ['transform', 'predict_proba', 'score']: @@ -70,8 +98,29 @@ def __getattr__(self, item): else: return super().__getattribute__(item) - def _forward(self, X, **kwargs): - return self.object.predict(X, **kwargs) + def predict_one(self, X): + X = X[None, :] + if self.preprocess is not None: + X = self.preprocess(X) + X = self.object.predict(X, **self.predict_kwargs)[0] + if self.postprocess is not None: + X = self.postprocess(X) + return X + + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + if self.preprocess is not None: + inputs = [] + for i in range(len(dataset)): + args, kwargs = dataset[i] + inputs.append(self.preprocess(*args, **kwargs)) + dataset = inputs + else: + dataset = [dataset[i] for i in range(len(dataset))] + out = self.object.predict(dataset, **self.predict_kwargs) + if self.postprocess is not None: + out = list(out) + out = list(map(self.postprocess, out)) + return out def _fit( # type: ignore[override] self, @@ -100,14 +149,12 @@ def _fit( # type: ignore[override] if select is None or db is None: raise ValueError('Neither select nor db can be None') - preprocess = self.preprocess or (lambda x: x) - X, y = _get_data_from_query( 
select=select, db=db, X=X, y=y, - preprocess=preprocess, + preprocess=self.preprocess, y_preprocess=y_preprocess, ) if self.training_configuration is not None: @@ -134,6 +181,7 @@ def _fit( # type: ignore[override] ) ) self.append_metrics(results) + if db is not None: db.replace(self, upsert=True) return to_return diff --git a/superduperdb/ext/torch/model.py b/superduperdb/ext/torch/model.py index 970a26d64a..b420acb3d5 100644 --- a/superduperdb/ext/torch/model.py +++ b/superduperdb/ext/torch/model.py @@ -11,23 +11,29 @@ from torch.utils.data import DataLoader from tqdm import tqdm -import superduperdb as s from superduperdb import logging from superduperdb.backends.base.query import Select from superduperdb.backends.query_dataset import QueryDataset from superduperdb.base.datalayer import Datalayer -from superduperdb.base.document import Document from superduperdb.base.serializable import Serializable from superduperdb.components.dataset import Dataset from superduperdb.components.datatype import ( DataType, - Encodable, dill_serializer, torch_serializer, ) from superduperdb.components.metric import Metric -from superduperdb.components.model import Model, _TrainingConfiguration +from superduperdb.components.model import ( + CallableInputs, + Mapping, + Signature, + _DeviceManaged, + _Fittable, + _Predictor, + _TrainingConfiguration, +) from superduperdb.ext.torch.utils import device_of, eval, to_device +from superduperdb.jobs.job import Job class BasicDataset(data.Dataset): @@ -38,21 +44,23 @@ class BasicDataset(data.Dataset): :param transform: function """ - def __init__(self, documents, transform): + def __init__(self, items, transform, signature): super().__init__() - self.documents = documents + self.items = items self.transform = transform + self.signature = signature def __len__(self): - return len(self.documents) + return len(self.items) def __getitem__(self, item): - document = self.documents[item] - if isinstance(document, Document): - document = 
document.unpack() - elif isinstance(document, Encodable): - document = document.x - return self.transform(document) + out = self.items[item] + if self.transform is not None: + from superduperdb.components.model import Model + + args, kwargs = Model.handle_input_type(out, self.signature) + return self.transform(*args, **kwargs) + return out @dc.dataclass @@ -87,36 +95,53 @@ class TorchTrainerConfiguration(_TrainingConfiguration): optimizer_cls: t.Any = dc.field(default_factory=lambda: torch.optim.Adam) optimizer_kwargs: t.Dict = dc.field(default_factory=dict) target_preprocessors: t.Optional[t.Dict] = None + pass_kwargs: bool = False @dc.dataclass(kw_only=True) -class TorchModel(Model): +class TorchModel(_Predictor, _Fittable, _DeviceManaged): _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( ('object', dill_serializer), ('optimizer_state', torch_serializer), ) - num_directions: int = 2 + object: torch.nn.Module + preprocess: t.Optional[t.Callable] = None + postprocess: t.Optional[t.Callable] = None + collate_fn: t.Optional[t.Callable] = None optimizer_state: t.Optional[t.Any] = None forward_method: str = '__call__' train_forward_method: str = '__call__' + loader_kwargs: t.Dict = dc.field(default_factory=lambda: {}) + signature: str = Signature.singleton # type: ignore[misc] + forward_signature: str = Signature.singleton + postprocess_signature: str = Signature.singleton def __post_init__(self, artifacts): super().__post_init__(artifacts=artifacts) - if self.model_to_device_method: - s.logging.debug( - f'{self.model_to_device_method} will be overriden with `to`' - ) - - self.model_to_device_method = 'to' - self.object.serializer = 'torch' - if self.optimizer_state is not None: self.optimizer.load_state_dict(self.optimizer_state) self._validation_set_cache = {} + @property + def inputs(self) -> CallableInputs: + return CallableInputs( + self.object.forward if not self.preprocess else self.preprocess, {} + ) + + def schedule_jobs( + self, + db: Datalayer, 
+ dependencies: t.Sequence[Job] = (), + ) -> t.Sequence[t.Any]: + jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies) + jobs.extend( + _Predictor.schedule_jobs(self, db, dependencies=[*dependencies, *jobs]) + ) + return jobs + def to(self, device): self.object.to(device) @@ -181,36 +206,47 @@ def __setstate__(self, state): io.BytesIO(state.pop('object_bytes')) ) - def _predict_one(self, x): + def predict_one(self, *args, **kwargs): with torch.no_grad(), eval(self.object): if self.preprocess is not None: - x = self.preprocess(x) - x = to_device(x, device_of(self.object)) - singleton_batch = create_batch(x) + out = self.preprocess(*args, **kwargs) + args, kwargs = self.handle_input_type(out, self.forward_signature) + + args, kwargs = to_device((args, kwargs), self.device) + args, kwargs = create_batch((args, kwargs)) + method = getattr(self.object, self.forward_method) - output = method(singleton_batch) + output = method(*args, **kwargs) output = to_device(output, 'cpu') args = unpack_batch(output)[0] if self.postprocess is not None: args = self.postprocess(args) return args - def _predict(self, x, one: bool = False, **kwargs): # type: ignore[override] + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: with torch.no_grad(), eval(self.object): - if one: - return self._predict_one(x) - inputs = BasicDataset(x, self.preprocess or (lambda x: x)) - loader = torch.utils.data.DataLoader(inputs, **kwargs) + inputs = BasicDataset( + items=dataset, + transform=self.preprocess, + signature=self.signature, + ) + loader = torch.utils.data.DataLoader( + inputs, **self.loader_kwargs, collate_fn=self.collate_fn + ) out = [] for batch in tqdm(loader, total=len(loader)): batch = to_device(batch, device_of(self.object)) - + args, kwargs = self.handle_input_type(batch, self.forward_signature) method = getattr(self.object, self.forward_method) - tmp = method(batch) + tmp = method(*args, **kwargs, **self.predict_kwargs) tmp = to_device(tmp, 'cpu') tmp = 
unpack_batch(tmp) if self.postprocess: - tmp = [self.postprocess(t) for t in tmp] + tmp = [ + self.handle_input_type(x, self.postprocess_signature) + for x in tmp + ] + tmp = [self.postprocess(*x[0], **x[1]) for x in tmp] out.extend(tmp) return out @@ -266,7 +302,7 @@ def _fit( db: t.Optional[Datalayer] = None, metrics: t.Optional[t.Sequence[Metric]] = None, select: t.Optional[t.Union[Select, t.Dict]] = None, - validation_sets: t.Optional[t.Sequence[t.Union[str, Dataset]]] = None, + validation_sets: t.Optional[t.Sequence[Dataset]] = None, ): if configuration is not None: self.training_configuration = configuration @@ -311,20 +347,9 @@ def forward(self, X): return self.object(X) def extract_batch_key(self, batch, key: t.Union[t.List[str], str]): - if isinstance(key, str): - return batch[key] return [batch[k] for k in key] - def extract_batch(self, batch): - if self.train_y is not None: - return [ - self.extract_batch_key(batch, self.train_X), - self.extract_batch_key(batch, self.train_y), - ] - return [self.extract_batch_key(batch, self.train_X)] - def take_step(self, batch, optimizers): - batch = self.extract_batch(batch) outputs = self.train_forward(*batch) objective_value = self.training_configuration.objective(*outputs) for opt in optimizers: @@ -338,7 +363,6 @@ def compute_validation_objective(self, valid_dataloader): objective_values = [] with self.evaluating(), torch.no_grad(): for batch in valid_dataloader: - batch = self.extract_batch(batch) objective_values.append( self.training_configuration.objective( *self.train_forward(*batch) @@ -389,56 +413,40 @@ def _fit_with_dataloaders( return iteration += 1 - def train_preprocess(self): - preprocessors = {} - if isinstance(self.train_X, str): - preprocessors[self.train_X] = ( - self.preprocess if self.preprocess else lambda x: x - ) - else: - for model, X in zip(self.models, self.train_X): - preprocessors[X] = ( - model.preprocess if model.preprocess is not None else lambda x: x - ) - if self.train_y is not None: 
- if ( - isinstance(self.train_y, str) - and self.training_configuration.target_preprocessors - ): - preprocessors[ - self.train_y - ] = self.training_configuration.target_preprocessors.get( - self.train_y, lambda x: x - ) - elif isinstance(self.train_y, str): - preprocessors[self.train_y] = lambda x: x - elif ( - isinstance(self.train_y, list) - and self.training_configuration.target_preprocessors - ): - for y in self.train_y: - preprocessors[ - y - ] = self.training_configuration.target_preprocessors.get( - y, lambda x: x - ) - return lambda r: {k: preprocessors[k](r[k]) for k in preprocessors} - def _get_data(self, db: t.Optional[Datalayer]): if self.training_select is None: raise ValueError('self.training_select cannot be None') + preprocess = self.preprocess or (lambda x: x) train_data = QueryDataset( select=self.training_select, - keys=self.training_keys, + mapping=Mapping( + [self.train_X, self.train_y] # type: ignore[list-item] + if self.train_y + else self.train_X, # type: ignore[arg-type] + signature='*args', + ), fold='train', - transform=self.train_preprocess(), + transform=( + preprocess + if not self.train_y + else lambda x, y: (preprocess(x), y) # type: ignore[misc] + ), db=db, ) valid_data = QueryDataset( select=self.training_select, - keys=self.training_keys, + mapping=Mapping( + [self.train_X, self.train_y] # type: ignore[list-item] + if self.train_y + else self.train_X, # type: ignore[arg-type] + signature='*args', + ), fold='valid', - transform=self.train_preprocess(), + transform=( + preprocess + if not self.train_y + else lambda x, y: (preprocess(x), y) # type: ignore[misc] + ), db=db, ) return train_data, valid_data diff --git a/superduperdb/ext/transformers/model.py b/superduperdb/ext/transformers/model.py index 52d62cdd78..4f2eae7e49 100644 --- a/superduperdb/ext/transformers/model.py +++ b/superduperdb/ext/transformers/model.py @@ -14,10 +14,17 @@ from superduperdb import logging from superduperdb.backends.base.query import Select -from 
superduperdb.backends.query_dataset import query_dataset_factory +from superduperdb.backends.query_dataset import QueryDataset, query_dataset_factory from superduperdb.base.datalayer import Datalayer +from superduperdb.components.dataset import Dataset from superduperdb.components.metric import Metric -from superduperdb.components.model import Model, _TrainingConfiguration +from superduperdb.components.model import ( + Signature, + _DeviceManaged, + _Fittable, + _Predictor, + _TrainingConfiguration, +) from superduperdb.misc.special_dicts import MongoStyleDict _DEFAULT_PREFETCH_SIZE: int = 100 @@ -32,19 +39,28 @@ def TransformersTrainerConfiguration(identifier: str, *args, **kwargs): return _TrainingConfiguration(identifier=identifier, kwargs=cfg.to_dict()) +# TODO refactor @dc.dataclass -class Pipeline(Model): +class Pipeline(_Predictor, _Fittable, _DeviceManaged): """A wrapper for ``transformers.Pipeline`` + :param object: The object + :param postprocess: The postprocessor :param preprocess_type: The type of preprocessing to use {'tokenizer'} :param preprocess_kwargs: The type of preprocessing to use. Currently only + :param postprocessor: The postprocessing function :param postprocess_kwargs: The type of postprocessing to use. :param task: The task to use for the pipeline. 
""" + signature: t.ClassVar[str] = Signature.singleton + object: t.Optional[t.Callable] = None + preprocess: t.Optional[t.Callable] = None preprocess_type: str = 'tokenizer' preprocess_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict) + postprocess: t.Optional[t.Callable] = None postprocess_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict) + collate_fn: t.Optional[t.Callable] = None task: str = 'text-classification' def __post_init__(self, artifacts): @@ -57,6 +73,13 @@ def __post_init__(self, artifacts): raise NotImplementedError( 'Only tokenizer is supported for now in pipeline mode' ) + + if ( + self.collate_fn is None + and self.preprocess is not None + and self.preprocess_type == 'tokenizer' + ): + self.collate_fn = DataCollatorWithPadding(self.preprocess) self.object = self.object.model self.task = self.object.task if ( @@ -81,6 +104,7 @@ def pipeline(self): else: warnings.warn('Only tokenizer is supported for now in pipeline mode') + # TODO very confusing... def _predict_with_preprocess_object_post(self, X, **kwargs): X = self.preprocess(X, **self.preprocess_kwargs) X = getattr(self.object, self.predict_method)(**X, **kwargs) @@ -95,35 +119,32 @@ def training_arguments(self): def _get_data( self, db, + X: str, valid_data=None, data_prefetch: bool = False, prefetch_size: int = 100, - X: str = '', ): - def transform_function(r): - text = r[X] - r.update(**self.preprocess(text, **self.preprocess_kwargs)) + def preprocess(r): + if self.preprocess: + r.update(self.preprocess(r[X], **self.preprocess_kwargs)) return r train_data = query_dataset_factory( select=self.training_select, - keys=self.training_keys, fold='train', - transform=transform_function, + transform=preprocess, data_prefetch=data_prefetch, prefetch_size=prefetch_size, db=db, ) valid_data = query_dataset_factory( select=self.training_select, - keys=self.training_keys, fold='valid', - transform=transform_function, + transform=preprocess, data_prefetch=data_prefetch, 
prefetch_size=prefetch_size, db=db, ) - return train_data, valid_data def _fit( # type: ignore[override] @@ -136,7 +157,7 @@ def _fit( # type: ignore[override] metrics: t.Optional[t.Sequence[Metric]] = None, prefetch_size: int = _DEFAULT_PREFETCH_SIZE, select: t.Optional[Select] = None, - validation_sets: t.Optional[t.Sequence[str]] = None, + validation_sets: t.Optional[t.Sequence[Dataset]] = None, **kwargs, ) -> None: if configuration is not None: @@ -150,21 +171,20 @@ def _fit( # type: ignore[override] self.train_X = X self.train_y = y - if isinstance(X, str): - train_data, valid_data = self._get_data( - db, - valid_data=validation_sets, - data_prefetch=data_prefetch, - prefetch_size=prefetch_size, - X=X, - ) + train_data, valid_data = self._get_data( + db, + X, + valid_data=validation_sets, + data_prefetch=data_prefetch, + prefetch_size=prefetch_size, + ) def compute_metrics(eval_pred): output = {} for vs in validation_sets: vs = db.load('dataset', vs) unpacked = [MongoStyleDict(r.unpack()) for r in vs.data] - predictions = self._predict([r[X] for r in unpacked]) + predictions = self.predict([r[X] for r in unpacked]) targets = [r[y] for r in unpacked] for m in metrics: output[f'{vs.identifier}/{m.identifier}'] = m(predictions, targets) @@ -186,17 +206,21 @@ def compute_metrics(eval_pred): ) trainer.train() - def _predict(self, X, one: bool = False, **kwargs): + def predict_one(self, X: t.Any): + return self.predict([X])[0] + + def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: if self.pipeline is not None: - out = self.pipeline(X, **self.preprocess_kwargs, **kwargs) + X = [dataset[i] for i in range(len(dataset))] + out = self.pipeline(X, **self.preprocess_kwargs, **self.predict_kwargs) out = [r['label'] for r in out] for i, p in enumerate(out): if re.match(r'^LABEL_[0-9]+', p): out[i] = int(p[6:]) else: - out = self._predict_with_preprocess_object_post(X, **kwargs) - if one: - return out[0] + out = self._predict_with_preprocess_object_post( + 
dataset, **self.predict_kwargs + ) return out diff --git a/superduperdb/misc/special_dicts.py b/superduperdb/misc/special_dicts.py index d3faaa31f6..f5f34cd8ae 100644 --- a/superduperdb/misc/special_dicts.py +++ b/superduperdb/misc/special_dicts.py @@ -23,6 +23,8 @@ class MongoStyleDict(t.Dict[str, t.Any]): """ def __getitem__(self, key: str) -> t.Any: + if key == '_base': + return self if '.' not in key: return super().__getitem__(key) else: diff --git a/superduperdb/server/app.py b/superduperdb/server/app.py index fdc1c3ad35..250d9a5a17 100644 --- a/superduperdb/server/app.py +++ b/superduperdb/server/app.py @@ -106,7 +106,7 @@ def handshake(cfg: str): return JSONResponse( status_code=400, - content={'error': 'Config is not match'}, + content={'error': 'Config doesn\'t match'}, ) def print_routes(self): diff --git a/superduperdb/vector_search/atlas.py b/superduperdb/vector_search/atlas.py index 2ae78c0656..3d28332eda 100644 --- a/superduperdb/vector_search/atlas.py +++ b/superduperdb/vector_search/atlas.py @@ -52,7 +52,7 @@ def index(self): @classmethod def from_component(cls, vi: 'VectorIndex'): from superduperdb.components.listener import Listener - from superduperdb.components.model import Model + from superduperdb.components.model import ObjectModel assert isinstance(vi.indexing_listener, Listener) collection = vi.indexing_listener.select.table_or_collection.identifier @@ -63,7 +63,7 @@ def from_component(cls, vi: 'VectorIndex'): ), 'Only single key is support for atlas search' if indexing_key.startswith('_outputs'): indexing_key = indexing_key.split('.')[1] - assert isinstance(vi.indexing_listener.model, Model) or isinstance( + assert isinstance(vi.indexing_listener.model, ObjectModel) or isinstance( vi.indexing_listener.model, APIModel ) assert isinstance(collection, str), 'Collection is required to be a string' diff --git a/test/conftest.py b/test/conftest.py index 27930a54e2..3f4666cbc0 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -213,8 +213,8 
@@ def add_datatypes(db: Datalayer): def add_models(db: Datalayer): # identifier, weight_shape, encoder params = [ - ['linear_a', (32, 16), 'torch.float32[16]'], - ['linear_b', (16, 8), 'torch.float32[8]'], + ['linear_a', (32, 16), db.load('datatype', 'torch.float32[16]')], + ['linear_b', (16, 8), db.load('datatype', 'torch.float32[8]')], ] for identifier, weight_shape, datatype in params: m = TorchModel( @@ -228,7 +228,7 @@ def add_models(db: Datalayer): def add_vector_index( db: Datalayer, collection_name='documents', identifier='test_vector_search' ): - # TODO: Support configurable key and model + # TODO: Support configurable key and mode is_mongodb_backend = isinstance(db.databackend, MongoDataBackend) if is_mongodb_backend: select_x = Collection(collection_name).find() @@ -238,25 +238,27 @@ def add_vector_index( select_x = table.select('id', 'x') select_z = table.select('id', 'z') - db.add( + model = db.load('model', 'linear_a') + + _, i_list = db.add( Listener( select=select_x, key='x', - model='linear_a', + model=model, ) ) - db.add( + _, c_list = db.add( Listener( select=select_z, key='z', - model='linear_a', + model=model, ) ) vi = VectorIndex( identifier=identifier, - indexing_listener='linear_a/x', - compatible_listener='linear_a/z', + indexing_listener=i_list, + compatible_listener=c_list, ) db.add(vi) diff --git a/test/integration/conftest.py b/test/integration/conftest.py index 4b13c5f171..dab78536ba 100644 --- a/test/integration/conftest.py +++ b/test/integration/conftest.py @@ -40,32 +40,32 @@ def add_models_encoders(test_db): test_db.add(tensor(torch.float, shape=(32,))) - test_db.add(tensor(torch.float, shape=(16,))) - test_db.add( + _, dt_16 = test_db.add(tensor(torch.float, shape=(16,))) + _, model = test_db.add( TorchModel( object=torch.nn.Linear(32, 16), identifier='model_linear_a', - datatype='torch.float32[16]', + datatype=dt_16, ) ) - test_db.add( + _, indexing_listener = test_db.add( Listener( 
select=Collection(identifier='documents').find(), key='x', - model='model_linear_a', + model=model, ) ) - test_db.add( + _, compatible_listener = test_db.add( Listener( select=Collection(identifier='documents').find(), key='z', - model='model_linear_a', + model=model, ) ) vi = VectorIndex( identifier='test_index', - indexing_listener='model_linear_a/x', - compatible_listener='model_linear_a/z', + indexing_listener=indexing_listener, + compatible_listener=compatible_listener, ) test_db.add(vi) return test_db diff --git a/test/integration/ext/anthropic/test_model_anthropic.py b/test/integration/ext/anthropic/test_model_anthropic.py index c0b04a4337..2ca81b1bdf 100644 --- a/test/integration/ext/anthropic/test_model_anthropic.py +++ b/test/integration/ext/anthropic/test_model_anthropic.py @@ -13,7 +13,7 @@ ) def test_completions(): e = AnthropicCompletions(model='claude-2', prompt='Hello, {context}') - resp = e.predict('', one=True, context=['world!']) + resp = e.predict_in_db('', one=True, context=['world!']) assert isinstance(resp, str) @@ -25,34 +25,7 @@ def test_completions(): ) def test_batch_completions(): e = AnthropicCompletions(model='claude-2') - resp = e.predict(['Hello, world!'], one=False) - - assert isinstance(resp, list) - assert isinstance(resp[0], str) - - -@pytest.mark.skip(reason="API is not publically available yet") -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_completions_async.yaml', - filter_headers=['authorization'], -) -@pytest.mark.asyncio -async def test_completions_async(): - e = AnthropicCompletions(model='claude-2', prompt='Hello, {context}') - resp = await e.apredict('', one=True, context=['world!']) - - assert isinstance(resp, str) - - -@pytest.mark.skip(reason="API is not publically available yet") -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_batch_completions_async.yaml', - filter_headers=['authorization'], -) -@pytest.mark.asyncio -async def test_batch_completions_async(): - e = AnthropicCompletions(model='claude-2') - resp = await 
e.apredict(['Hello, world!'], one=False) + resp = e.predict_in_db(['Hello, world!'], one=False) assert isinstance(resp, list) assert isinstance(resp[0], str) diff --git a/test/integration/ext/cohere/test_model_cohere.py b/test/integration/ext/cohere/test_model_cohere.py index 1d5a489766..26410af9a7 100644 --- a/test/integration/ext/cohere/test_model_cohere.py +++ b/test/integration/ext/cohere/test_model_cohere.py @@ -19,7 +19,7 @@ ) def test_embed_one(): embed = CohereEmbed(identifier='embed-english-v2.0') - resp = embed.predict('Hello world') + resp = embed.predict_one('Hello world') assert len(resp) == embed.shape[0] assert isinstance(resp, list) @@ -31,37 +31,8 @@ def test_embed_one(): filter_headers=['authorization'], ) def test_embed_batch(): - embed = CohereEmbed(identifier='embed-english-v2.0') - resp = embed.predict(['Hello', 'world'], batch_size=1) - - assert len(resp) == 2 - assert len(resp[0]) == embed.shape[0] - assert isinstance(resp[0], list) - assert all(isinstance(x, float) for x in resp[0]) - - -@pytest.mark.asyncio -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_async_embed_one.yaml', - filter_headers=['authorization'], -) -async def test_async_embed_one(): - embed = CohereEmbed(identifier='embed-english-v2.0') - resp = await embed.apredict('Hello world') - - assert len(resp) == embed.shape[0] - assert isinstance(resp, list) - assert all(isinstance(x, float) for x in resp) - - -@pytest.mark.asyncio -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_async_embed_batch.yaml', - filter_headers=['authorization'], -) -async def test_async_embed_batch(): - embed = CohereEmbed(identifier='embed-english-v2.0') - resp = await embed.apredict(['Hello', 'world'], batch_size=1) + embed = CohereEmbed(identifier='embed-english-v2.0', batch_size=1) + resp = embed.predict(['Hello', 'world']) assert len(resp) == 2 assert len(resp[0]) == embed.shape[0] @@ -75,7 +46,7 @@ async def test_async_embed_batch(): ) def test_generate(): e = CohereGenerate(identifier='base-light', 
prompt='Hello, {context}') - resp = e.predict('', one=True, context=['world!']) + resp = e.predict_one('', context=['world!']) assert isinstance(resp, str) @@ -86,32 +57,11 @@ def test_generate(): ) def test_batch_generate(): e = CohereGenerate(identifier='base-light') - resp = e.predict(['Hello, world!'], one=False) - - assert isinstance(resp, list) - assert isinstance(resp[0], str) - - -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_generate_async.yaml', - filter_headers=['authorization'], -) -@pytest.mark.asyncio -async def test_chat_async(): - e = CohereGenerate(identifier='base-light', prompt='Hello, {context}') - resp = await e.apredict('', one=True, context=['world!']) - - assert isinstance(resp, str) - - -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_batch_chat_async.yaml', - filter_headers=['authorization'], -) -@pytest.mark.asyncio -async def test_batch_chat_async(): - e = CohereGenerate(identifier='base-light') - resp = await e.apredict(['Hello, world!'], one=False) + resp = e.predict( + [ + (('Hello, world!',), {}), + ] + ) assert isinstance(resp, list) assert isinstance(resp[0], str) diff --git a/test/integration/ext/jina/test_model_jina.py b/test/integration/ext/jina/test_model_jina.py index e587940e6b..ca0d6d7a70 100644 --- a/test/integration/ext/jina/test_model_jina.py +++ b/test/integration/ext/jina/test_model_jina.py @@ -19,7 +19,7 @@ ) def test_embed_one(): embed = JinaEmbedding(identifier='jina-embeddings-v2-base-en') - resp = embed.predict('Hello world') + resp = embed.predict_one('Hello world') assert len(resp) == embed.shape[0] assert isinstance(resp, list) @@ -31,37 +31,8 @@ def test_embed_one(): filter_headers=['Authorization'], ) def test_embed_batch(): - embed = JinaEmbedding(identifier='jina-embeddings-v2-base-en') - resp = embed.predict(['Hello', 'world', 'I', 'am', 'here'], batch_size=3) - - assert len(resp) == 5 - assert len(resp[0]) == embed.shape[0] - assert isinstance(resp[0], list) - assert all(isinstance(x, float) for x in resp[0]) - - 
-@pytest.mark.asyncio -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_async_embed_one.yaml', - filter_headers=['authorization'], -) -async def test_async_embed_one(): - embed = JinaEmbedding(identifier='jina-embeddings-v2-base-en') - resp = await embed.apredict('Hello world') - - assert len(resp) == embed.shape[0] - assert isinstance(resp, list) - assert all(isinstance(x, float) for x in resp) - - -@pytest.mark.asyncio -@vcr.use_cassette( - f'{CASSETTE_DIR}/test_async_embed_batch.yaml', - filter_headers=['authorization'], -) -async def test_async_embed_batch(): - embed = JinaEmbedding(identifier='jina-embeddings-v2-base-en') - resp = await embed.apredict(['Hello', 'world', 'I', 'am', 'here'], batch_size=3) + embed = JinaEmbedding(identifier='jina-embeddings-v2-base-en', batch_size=3) + resp = embed.predict(['Hello', 'world', 'I', 'am', 'here']) assert len(resp) == 5 assert len(resp[0]) == embed.shape[0] diff --git a/test/integration/ext/openai/test_model_openai.py b/test/integration/ext/openai/test_model_openai.py index 6c5d55df0c..6c68acaa23 100644 --- a/test/integration/ext/openai/test_model_openai.py +++ b/test/integration/ext/openai/test_model_openai.py @@ -91,7 +91,7 @@ def mock_lru_cache(): @vcr.use_cassette() def test_embed(): e = OpenAIEmbedding(identifier='text-embedding-ada-002') - resp = e.predict('Hello, world!') + resp = e.predict_one('Hello, world!') assert len(resp) == e.shape[0] assert all(isinstance(x, float) for x in resp) @@ -99,29 +99,8 @@ def test_embed(): @vcr.use_cassette() def test_batch_embed(): - e = OpenAIEmbedding(identifier='text-embedding-ada-002') - resp = e.predict(['Hello', 'world!'], batch_size=1) - - assert len(resp) == 2 - assert all(len(x) == e.shape[0] for x in resp) - assert all(isinstance(x, float) for y in resp for x in y) - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_embed_async(): - e = OpenAIEmbedding(identifier='text-embedding-ada-002') - resp = await e.apredict('Hello, world!') - - assert len(resp) == 
e.shape[0] - assert all(isinstance(x, float) for x in resp) - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_batch_embed_async(): - e = OpenAIEmbedding(identifier='text-embedding-ada-002') - resp = await e.apredict(['Hello', 'world!'], batch_size=1) + e = OpenAIEmbedding(identifier='text-embedding-ada-002', batch_size=1) + resp = e.predict(['Hello', 'world!']) assert len(resp) == 2 assert all(len(x) == e.shape[0] for x in resp) @@ -131,7 +110,7 @@ async def test_batch_embed_async(): @vcr.use_cassette() def test_chat(): e = OpenAIChatCompletion(identifier='gpt-3.5-turbo', prompt='Hello, {context}') - resp = e.predict('', one=True, context=['world!']) + resp = e.predict_one('', context=['world!']) assert isinstance(resp, str) @@ -139,26 +118,7 @@ def test_chat(): @vcr.use_cassette() def test_batch_chat(): e = OpenAIChatCompletion(identifier='gpt-3.5-turbo') - resp = e.predict(['Hello, world!'], one=False) - - assert isinstance(resp, list) - assert isinstance(resp[0], str) - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_chat_async(): - e = OpenAIChatCompletion(identifier='gpt-3.5-turbo', prompt='Hello, {context}') - resp = await e.apredict('', one=True, context=['world!']) - - assert isinstance(resp, str) - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_batch_chat_async(): - e = OpenAIChatCompletion(identifier='gpt-3.5-turbo') - resp = await e.apredict(['Hello, world!'], one=False) + resp = e.predict([(('Hello, world!',), {})]) assert isinstance(resp, list) assert isinstance(resp[0], str) @@ -169,8 +129,9 @@ def test_create_url(): e = OpenAIImageCreation( identifier='dall-e', prompt='a close up, studio photographic portrait of a {context}', + response_format='url', ) - resp = e.predict('', one=True, response_format='url', context=['cat']) + resp = e.predict_one('cat') # PNG 8-byte signature assert resp[0:16] == PNG_BYTE_SIGNATURE @@ -178,51 +139,12 @@ def test_create_url(): @vcr.use_cassette() def test_create_url_batch(): - e 
= OpenAIImageCreation( - identifier='dall-e', prompt='a close up, studio photographic portrait of a' - ) - resp = e.predict(['cat', 'dog'], response_format='url') - - for img in resp: - # PNG 8-byte signature - assert img[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_create_async(): e = OpenAIImageCreation( identifier='dall-e', - prompt='a close up, studio photographic portrait of a {context}', + prompt='a close up, studio photographic portrait of a', + response_format='url', ) - resp = await e.apredict('', one=True, context=['cat']) - - assert isinstance(resp, bytes) - - # PNG 8-byte signature - assert resp[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_create_url_async(): - e = OpenAIImageCreation( - identifier='dall-e', - prompt='a close up, studio photographic portrait of a {context}', - ) - resp = await e.apredict('', one=True, response_format='url', context=['cat']) - - # PNG 8-byte signature - assert resp[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_create_url_async_batch(): - e = OpenAIImageCreation( - identifier='dall-e', prompt='a close up, studio photographic portrait of a' - ) - resp = await e.apredict(['cat', 'dog'], response_format='url') + resp = e.predict(['cat', 'dog']) for img in resp: # PNG 8-byte signature @@ -232,11 +154,13 @@ async def test_create_url_async_batch(): @vcr.use_cassette() def test_edit_url(): e = OpenAIImageEdit( - identifier='dall-e', prompt='A celebration party at the launch of {context}' + identifier='dall-e', + prompt='A celebration party at the launch of {context}', + response_format='url', ) with open('test/material/data/rickroll.png', 'rb') as f: buffer = io.BytesIO(f.read()) - resp = e.predict(buffer, one=True, response_format='url', context=['superduperdb']) + resp = e.predict_one(buffer, context=['superduperdb']) buffer.close() # PNG 8-byte signature @@ -246,69 +170,21 @@ def 
test_edit_url(): @vcr.use_cassette() def test_edit_url_batch(): e = OpenAIImageEdit( - identifier='dall-e', prompt='A celebration party at the launch of superduperdb' + identifier='dall-e', + prompt='A celebration party at the launch of superduperdb', + response_format='url', ) with open('test/material/data/rickroll.png', 'rb') as f: buffer_one = io.BytesIO(f.read()) with open('test/material/data/rickroll.png', 'rb') as f: buffer_two = io.BytesIO(f.read()) - resp = e.predict([buffer_one, buffer_two], response_format='url') - - buffer_one.close() - buffer_two.close() - - for img in resp: - # PNG 8-byte signature - assert img[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_edit_async(): - e = OpenAIImageEdit( - identifier='dall-e', prompt='A celebration party at the launch of {context}' - ) - with open('test/material/data/rickroll.png', 'rb') as f: - buffer = io.BytesIO(f.read()) - resp = await e.apredict(buffer, one=True, context=['superduperdb']) - buffer.close() - - assert isinstance(resp, bytes) - - # PNG 8-byte signature - assert resp[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_edit_url_async(): - e = OpenAIImageEdit( - identifier='dall-e', prompt='A celebration party at the launch of {context}' - ) - with open('test/material/data/rickroll.png', 'rb') as f: - buffer = io.BytesIO(f.read()) - resp = await e.apredict( - buffer, one=True, response_format='url', context=['superduperdb'] - ) - buffer.close() - - # PNG 8-byte signature - assert resp[0:16] == PNG_BYTE_SIGNATURE - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_edit_url_async_batch(): - e = OpenAIImageEdit( - identifier='dall-e', prompt='A celebration party at the launch of superduperdb' + resp = e.predict( + [ + ((buffer_one,), {}), + ((buffer_two,), {}), + ] ) - with open('test/material/data/rickroll.png', 'rb') as f: - buffer_one = io.BytesIO(f.read()) - with open('test/material/data/rickroll.png', 
'rb') as f: - buffer_two = io.BytesIO(f.read()) - - resp = await e.apredict([buffer_one, buffer_two], response_format='url') buffer_one.close() buffer_two.close() @@ -328,68 +204,12 @@ def test_transcribe(): 'only make an exception for the following words: {context}' ) e = OpenAIAudioTranscription(identifier='whisper-1', prompt=prompt) - resp = e.predict(buffer, one=True, context=['United States']) + resp = e.predict_one(buffer, context=['United States']) buffer.close() assert 'United States' in resp -@vcr.use_cassette() -def test_batch_transcribe(): - with open('test/material/data/test.wav', 'rb') as f: - buffer = io.BytesIO(f.read()) - buffer.name = 'test.wav' - - with open('test/material/data/test.wav', 'rb') as f: - buffer2 = io.BytesIO(f.read()) - buffer2.name = 'test.wav' - - e = OpenAIAudioTranscription(identifier='whisper-1') - resp = e.predict([buffer, buffer2], one=False, batch_size=1) - buffer.close() - - assert len(resp) == 2 - assert resp[0] == resp[1] - assert 'United States' in resp[0] - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_transcribe_async(): - with open('test/material/data/test.wav', 'rb') as f: - buffer = io.BytesIO(f.read()) - buffer.name = 'test.wav' - prompt = ( - 'i have some advice for you. write all text in lower-case.' 
- 'only make an exception for the following words: {context}' - ) - e = OpenAIAudioTranscription(identifier='whisper-1', prompt=prompt) - resp = await e.apredict(buffer, one=True, context=['United States']) - buffer.close() - - assert 'United States' in resp - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_batch_transcribe_async(): - with open('test/material/data/test.wav', 'rb') as f: - buffer = io.BytesIO(f.read()) - buffer.name = 'test1.wav' - - with open('test/material/data/test.wav', 'rb') as f: - buffer2 = io.BytesIO(f.read()) - buffer2.name = 'test1.wav' - e = OpenAIAudioTranscription(identifier='whisper-1') - - resp = await e.apredict([buffer, buffer2], one=False, batch_size=1) - buffer.close() - - assert len(resp) == 2 - assert resp[0] == resp[1] - assert 'United States' in resp[0] - - @vcr.use_cassette() def test_translate(): with open('test/material/data/german.wav', 'rb') as f: @@ -400,7 +220,7 @@ def test_translate(): 'only make an exception for the following words: {context}' ) e = OpenAIAudioTranslation(identifier='whisper-1', prompt=prompt) - resp = e.predict(buffer, one=True, context=['Emmerich']) + resp = e.predict_one(buffer, context=['Emmerich']) buffer.close() assert 'station' in resp @@ -416,45 +236,8 @@ def test_batch_translate(): buffer2 = io.BytesIO(f.read()) buffer2.name = 'test.wav' - e = OpenAIAudioTranslation(identifier='whisper-1') - resp = e.predict([buffer, buffer2], one=False, batch_size=1) - buffer.close() - - assert len(resp) == 2 - assert resp[0] == resp[1] - assert 'station' in resp[0] - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_translate_async(): - with open('test/material/data/german.wav', 'rb') as f: - buffer = io.BytesIO(f.read()) - buffer.name = 'test.wav' - prompt = ( - 'i have some advice for you. write all text in lower-case.' 
- 'only make an exception for the following words: {context}' - ) - e = OpenAIAudioTranslation(identifier='whisper-1', prompt=prompt) - resp = await e.apredict(buffer, one=True, context=['Emmerich']) - buffer.close() - - assert 'station' in resp - - -@vcr.use_cassette() -@pytest.mark.asyncio -async def test_batch_translate_async(): - with open('test/material/data/german.wav', 'rb') as f: - buffer = io.BytesIO(f.read()) - buffer.name = 'test1.wav' - - with open('test/material/data/german.wav', 'rb') as f: - buffer2 = io.BytesIO(f.read()) - buffer2.name = 'test1.wav' - e = OpenAIAudioTranslation(identifier='whisper-1') - - resp = await e.apredict([buffer, buffer2], one=False, batch_size=1) + e = OpenAIAudioTranslation(identifier='whisper-1', batch_size=1) + resp = e.predict([buffer, buffer2]) buffer.close() assert len(resp) == 2 diff --git a/test/integration/test_atlas.py b/test/integration/test_atlas.py index c6cf813589..16809e387e 100644 --- a/test/integration/test_atlas.py +++ b/test/integration/test_atlas.py @@ -10,7 +10,7 @@ from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document from superduperdb.components.listener import Listener -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.components.vector_index import VectorIndex, vector ATLAS_VECTOR_URI = os.environ.get('ATLAS_VECTOR_URI') @@ -30,7 +30,7 @@ def atlas_search_config(): @pytest.mark.skipif(ATLAS_VECTOR_URI is None, reason='Only atlas deployments relevant.') def test_setup_atlas_vector_search(atlas_search_config): - model = Model( + model = ObjectModel( identifier='test-model', object=random_vector_model, encoder=vector(shape=(16,)) ) client = pymongo.MongoClient(ATLAS_VECTOR_URI) diff --git a/test/integration/test_cdc.py b/test/integration/test_cdc.py index 434811e148..3b5c5bed6c 100644 --- a/test/integration/test_cdc.py +++ b/test/integration/test_cdc.py @@ -440,15 +440,17 @@ def 
state_check(): def add_and_cleanup_listeners(database, select): """Add listeners to the database and remove them after the test""" + m = database.load('model', 'model_linear_a') + listener_x = Listener( key='x', - model='model_linear_a', + model=m, select=select, ) listener_z = Listener( key='z', - model='model_linear_a', + model=m, select=select, ) diff --git a/test/integration/test_end2end.py b/test/integration/test_end2end.py index 8a8afc9b83..13759dc045 100644 --- a/test/integration/test_end2end.py +++ b/test/integration/test_end2end.py @@ -19,7 +19,7 @@ def fixture(self, *args, **kwargs): from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document from superduperdb.components.listener import Listener -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.components.schema import Schema from superduperdb.components.vector_index import VectorIndex from superduperdb.ext.pillow.encoder import pil_image @@ -48,14 +48,6 @@ def predict(self, x): } ] * random.randint(1, 2) - @staticmethod - def preprocess(x): - return x - - @staticmethod - def postprocess(x): - return x - class Model2: def predict(self, x): @@ -63,14 +55,6 @@ def predict(self, x): return np.asarray([int(x) * 100] * 10) return x - @staticmethod - def preprocess(x): - return x - - @staticmethod - def postprocess(x): - return x - def _wait_for_keys(db, collection='_outputs.int.model1', n=10, key=''): retry_left = 5 @@ -163,31 +147,26 @@ def test_advance_setup(distributed_db, image_url): from superduperdb.ext.numpy import array - model1 = Model( + model1 = ObjectModel( identifier='model1', - object=Model1(), - preprocess=Model1.preprocess, - postprocess=Model1.postprocess, + object=Model1().predict, flatten=True, model_update_kwargs={'document_embedded': False}, output_schema=Schema( identifier='myschema', fields={'image': pil_image, 'int': array('int64', (10,))}, ), - 
predict_method='predict', ) db.add(model1) e = array('int64', (10,)) - model2 = Model( + model2 = ObjectModel( identifier='model2', - object=Model2(), - preprocess=Model2.preprocess, - postprocess=Model2.postprocess, - predict_method='predict', + object=Model2().predict, datatype=e, ) + db.add(model2) listener1 = Listener( diff --git a/test/integration/test_ibis.py b/test/integration/test_ibis.py index 5c27c7bf09..fbad79ff12 100644 --- a/test/integration/test_ibis.py +++ b/test/integration/test_ibis.py @@ -133,7 +133,7 @@ def postprocess(x): ) # Apply the torchvision model - resnet.predict( + resnet.predict_in_db( X='image', db=db, select=t.select('id', 'image'), @@ -150,7 +150,7 @@ def postprocess(x): ) # apply to the table - vectorize.predict( + vectorize.predict_in_db( X='image', db=db, select=t.select('id', 'image'), diff --git a/test/integration/test_ray.py b/test/integration/test_ray.py index cf6425b464..e4f7a77dad 100644 --- a/test/integration/test_ray.py +++ b/test/integration/test_ray.py @@ -19,9 +19,10 @@ @contextmanager def add_and_cleanup_listener(database, collection_name): """Add listener to the database and remove it after the test""" + m = database.load('model', 'model_linear_a') listener_x = Listener( key='x', - model='model_linear_a', + model=m, select=Collection(identifier=collection_name).find(), ) @@ -133,11 +134,14 @@ def test_node_2(*args, **kwargs): def test_model_job_logs(distributed_db, fake_updates): # Set Collection Listener # ------------------------------ + collection = Collection(identifier=str(uuid.uuid4())) + m = distributed_db('model', 'model_linear_a') + listener_x = Listener( key='x', - model='model_linear_a', + model=m, select=collection.find(), ) jobs, _ = distributed_db.add(listener_x) diff --git a/test/unittest/backends/ibis/test_query.py b/test/unittest/backends/ibis/test_query.py index 0ec9a0179a..f1510c3939 100644 --- a/test/unittest/backends/ibis/test_query.py +++ b/test/unittest/backends/ibis/test_query.py @@ -9,7 +9,7 
@@ from superduperdb.backends.ibis.field_types import dtype from superduperdb.backends.ibis.query import IbisQueryTable, Table from superduperdb.base.serializable import Serializable -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.components.schema import Schema from superduperdb.ext.numpy.encoder import array from superduperdb.ext.pillow.encoder import pil_image @@ -71,12 +71,12 @@ def duckdb(monkeypatch): ) ) - model = Model( + model = ObjectModel( object=lambda _: numpy.random.randn(32), identifier='test', datatype=array('float64', shape=(32,)), ) - model.predict('x', select=t, db=db) + model.predict_in_db('x', select=t, db=db) _, s = db.add( Table( diff --git a/test/unittest/backends/test_query_dataset.py b/test/unittest/backends/test_query_dataset.py index 00728b4562..375fc60163 100644 --- a/test/unittest/backends/test_query_dataset.py +++ b/test/unittest/backends/test_query_dataset.py @@ -2,6 +2,7 @@ from superduperdb.backends.mongodb.query import Collection from superduperdb.backends.query_dataset import QueryDataset +from superduperdb.components.model import Mapping try: import torch @@ -13,12 +14,14 @@ def test_query_dataset(db): train_data = QueryDataset( db=db, + mapping=Mapping('_base', signature='singleton'), select=Collection('documents').find( {}, {'_id': 0, 'x': 1, '_fold': 1, '_outputs': 1} ), fold='train', ) r = train_data[0] + assert '_id' not in r assert r['_fold'] == 'train' assert 'y' not in r @@ -28,7 +31,7 @@ def test_query_dataset(db): train_data = QueryDataset( db=db, select=Collection('documents').find(), - keys=['x', 'y'], + mapping=Mapping({'x': 'x', 'y': 'y'}, signature='**kwargs'), fold='train', ) @@ -36,22 +39,15 @@ def test_query_dataset(db): assert '_id' not in r assert set(r.keys()) == {'x', 'y'} - _ = QueryDataset( - db=db, - select=Collection('documents').find(), - fold='valid', - ) - @pytest.mark.skipif(not torch, reason='Torch not installed') def 
test_query_dataset_base(db): train_data = QueryDataset( db=db, select=Collection('documents').find({}, {'_id': 0}), - keys=['_base', 'y'], + mapping=Mapping(['_base', 'y'], signature='*args'), fold='train', ) r = train_data[0] - assert '_id' not in r - assert set(r.keys()) == {'_base', 'y'} - assert r['_base']['_fold'] == 'train' + assert isinstance(r, tuple) + assert len(r) == 2 diff --git a/test/unittest/base/test_datalayer.py b/test/unittest/base/test_datalayer.py index 16f1df1aee..8e4d37eeed 100644 --- a/test/unittest/base/test_datalayer.py +++ b/test/unittest/base/test_datalayer.py @@ -29,12 +29,11 @@ from superduperdb.components.dataset import Dataset from superduperdb.components.datatype import ( DataType, - Encodable, dill_serializer, pickle_serializer, ) from superduperdb.components.listener import Listener -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.components.schema import Schema n_data_points = 250 @@ -76,7 +75,7 @@ def child_components(self): def add_fake_model(db: Datalayer): - model = Model( + model = ObjectModel( object=lambda x: str(x), identifier='fake_model', datatype=DataType(identifier='base'), @@ -97,7 +96,7 @@ def add_fake_model(db: Datalayer): select = db.load('table', 'documents').to_query().select('id', 'x') db.add( Listener( - model='fake_model', + model=model, select=select, key='x', ), @@ -219,7 +218,7 @@ def test_add(db): @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_add_with_artifact(db): - m = Model( + m = ObjectModel( identifier='test', object=lambda x: x + 2, datatype=dill_serializer, @@ -228,12 +227,8 @@ def test_add_with_artifact(db): db.add(m) m = db.load('model', m.identifier) - import pprint - - pprint.pprint(m) - - # assert m.object is not None - # assert callable(m.object) + assert m.object is not None + assert callable(m.object) @pytest.mark.parametrize("db", [DBConfig.sqldb_empty], indirect=True) @@ -424,72 
+419,6 @@ def test_show(db): assert db.show('test-component', 'b', -1)['version'] == 2 -@pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize( - "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True -) -def test_predict(db: Datalayer): - models = [ - TorchModel( - object=torch.nn.Linear(16, 2), - identifier='model1', - datatype=tensor(torch.float32, shape=(4, 2)), - ), - TorchModel( - object=torch.nn.Linear(16, 3), - identifier='model2', - datatype=tensor(torch.float32, shape=(4, 3)), - ), - TorchModel( - object=torch.nn.Linear(16, 3), - identifier='model3', - datatype=DataType( - identifier='test-datatype', - encoder=lambda x: torch.argmax(x, dim=1), - ), - ), - ] - db.add(models) - - # test model selection - x = torch.randn(4, 16) - assert db.predict('model1', x)[0]['_base'].x.shape == torch.Size([4, 2]) - assert db.predict('model2', x)[0]['_base'].x.shape == torch.Size([4, 3]) - - -@pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize( - "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True -) -def test_predict_context(db: Datalayer): - db.add( - TorchModel( - object=torch.nn.Linear(16, 2), - identifier='model', - datatype=tensor(torch.float32, shape=(4, 2)), - ) - ) - - y, context_out = db.predict('model', torch.randn(4, 16)) - assert not context_out - - with patch.object(db, '_get_context') as mock_get_context: - mock_get_context.return_value = ( - [Document({'_base': 'test'}), Document({'_base': 'test'})], - [ - Document({'_base': Encodable(datatype=None, x=torch.randn(4, 2))}), - Document({'_base': Encodable(datatype=None, x=torch.randn(4, 3))}), - ], - ) - y, context_out = db.predict( - 'model', - torch.randn(4, 16), - context_select=Collection('context_collection').find({}), - ) - assert context_out[0]['_base'].x.shape == torch.Size([4, 2]) - assert context_out[1]['_base'].x.shape == torch.Size([4, 3]) - - @pytest.mark.parametrize( "db", [DBConfig.mongodb_empty, 
DBConfig.sqldb_empty], indirect=True ) @@ -498,7 +427,7 @@ def test_get_context(db): fake_contexts = [Document({'text': f'hello world {i}'}) for i in range(10)] - model = Model(object=lambda x: x, identifier='model', takes_context=True) + model = ObjectModel(object=lambda x, context: x, identifier='model') context_select = MagicMock(spec=Select) context_select.variables = [] context_select.execute.return_value = fake_contexts @@ -516,7 +445,7 @@ def test_get_context(db): ] # Testing models that cannot accept context - model = Model(object=lambda x: x, identifier='model', takes_context=False) + model = ObjectModel(object=lambda x: x, identifier='model') with pytest.raises(AssertionError): db._get_context(model, context_select, context_key=None) @@ -525,13 +454,13 @@ def test_get_context(db): "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_load(db): - m1 = Model(object=lambda x: x, identifier='m1', datatype=dtype('int32')) + m1 = ObjectModel(object=lambda x: x, identifier='m1', datatype=dtype('int32')) db.add( [ DataType(identifier='e1'), DataType(identifier='e2'), m1, - Model(object=lambda x: x, identifier='m1', datatype=dtype('int32')), + ObjectModel(object=lambda x: x, identifier='m1', datatype=dtype('int32')), m1, ] ) @@ -627,7 +556,7 @@ def test_delete(db): "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True ) def test_replace(db): - model = Model( + model = ObjectModel( object=lambda x: x + 1, identifier='m', datatype=DataType(identifier='base'), @@ -638,17 +567,21 @@ def test_replace(db): db.replace(model, upsert=True) - assert db.load('model', 'm').predict([1]) == [2] + assert db.load('model', 'm').predict_one(1) == 2 # replace the 0 version of the model - new_model = Model(object=lambda x: x + 2, identifier='m') + new_model = ObjectModel( + object=lambda x: x + 2, identifier='m', signature='singleton' + ) new_model.version = 0 db.replace(new_model) time.sleep(0.1) assert db.load('model', 'm').predict([1]) == [3] 
# replace the last version of the model - new_model = Model(object=lambda x: x + 3, identifier='m') + new_model = ObjectModel( + object=lambda x: x + 3, identifier='m', signature='singleton' + ) db.replace(new_model) time.sleep(0.1) assert db.load('model', 'm').predict([1]) == [4] diff --git a/test/unittest/base/test_serializable.py b/test/unittest/base/test_serializable.py index de93353e01..1497b0f04e 100644 --- a/test/unittest/base/test_serializable.py +++ b/test/unittest/base/test_serializable.py @@ -2,7 +2,7 @@ import typing as t from pprint import pprint -from superduperdb import Model +from superduperdb import ObjectModel from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document from superduperdb.base.serializable import Serializable, Variable @@ -26,8 +26,8 @@ class TestSubModel(Component): type_id: t.ClassVar[str] = 'test-sub-model' a: int b: t.Union[str, Variable] - c: Model - d: t.List[Model] + c: ObjectModel + d: t.List[ObjectModel] e: OtherSer f: t.Callable @@ -91,8 +91,8 @@ def test_component_with_document(): identifier='test-1', a=2, b='test', - c=Model('test-2', object=lambda x: x + 2), - d=[Model('test-3', object=lambda x: x + 2)], + c=ObjectModel('test-2', object=lambda x: x + 2), + d=[ObjectModel('test-3', object=lambda x: x + 2)], e=OtherSer(d='test'), f=lambda x: x, ) diff --git a/test/unittest/component/test_component.py b/test/unittest/component/test_component.py index 2463c7cb05..f2cf348465 100644 --- a/test/unittest/component/test_component.py +++ b/test/unittest/component/test_component.py @@ -1,11 +1,14 @@ +import dataclasses as dc import os import shutil +import typing as t import pytest -from superduperdb import Model +from superduperdb import ObjectModel +from superduperdb.base.document import Document from superduperdb.components.component import Component -from superduperdb.components.datatype import dill_serializer +from superduperdb.components.datatype import Encodable, 
dill_serializer @pytest.fixture @@ -16,10 +19,53 @@ def cleanup(): def test_compile_decompile(cleanup): - m = Model('test_export', object=lambda x: x, datatype=dill_serializer) + m = ObjectModel('test_export', object=lambda x: x, datatype=dill_serializer) m.version = 0 m.datatype.version = 0 m.export() assert os.path.exists('test_export.tar.gz') m_reload = Component.import_('test_export.tar.gz') - assert isinstance(m_reload, Model) + assert isinstance(m_reload, ObjectModel) + + +@dc.dataclass(kw_only=True) +class MyComponent(Component): + _lazy_fields: t.ClassVar[t.Sequence[str]] = ('my_dict',) + my_dict: t.Dict + nested_list: t.List + a: t.Callable + + +def test_init(monkeypatch): + from unittest.mock import MagicMock + + def unpack(self, db): + if '_base' in self.keys(): + return [lambda x: x + 1, lambda x: x + 2] + return {'a': lambda x: x + 1} + + monkeypatch.setattr(Document, 'unpack', unpack) + + e = Encodable(x=None, file_id='123', datatype=None) + a = Encodable(x=None, file_id='456', datatype=None) + + def side_effect(*args, **kwargs): + a.x = lambda x: x + 1 + + a.init = MagicMock() + a.init.side_effect = side_effect + + list_ = [e, a] + + c = MyComponent('test', my_dict={'a': a}, a=a, nested_list=list_) + + c.init() + + assert callable(c.my_dict['a']) + assert c.my_dict['a'](1) == 2 + + assert callable(c.a) + assert c.a(1) == 2 + + assert callable(c.nested_list[1]) + assert c.nested_list[1](1) == 3 diff --git a/test/unittest/component/test_listener.py b/test/unittest/component/test_listener.py index a5bee9aa95..267fab9c96 100644 --- a/test/unittest/component/test_listener.py +++ b/test/unittest/component/test_listener.py @@ -1,12 +1,12 @@ from superduperdb.backends.mongodb.query import Collection from superduperdb.components.listener import Listener -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel def test_listener_serializes_properly(): q = Collection('test').find({}, {}) listener = Listener( - 
model=Model('test', object=lambda x: x), + model=ObjectModel('test', object=lambda x: x), select=q, key='test', ) diff --git a/test/unittest/component/test_model.py b/test/unittest/component/test_model.py index 2fc8aafad5..6261d8e990 100644 --- a/test/unittest/component/test_model.py +++ b/test/unittest/component/test_model.py @@ -1,4 +1,4 @@ -import inspect +import dataclasses as dc import random from test.db_config import DBConfig from unittest.mock import MagicMock, patch @@ -8,28 +8,32 @@ import pytest from sklearn.metrics import accuracy_score, f1_score -from superduperdb.backends.base.query import CompoundSelect, Select +from superduperdb.backends.base.query import Select from superduperdb.backends.local.compute import LocalComputeBackend from superduperdb.backends.mongodb.query import Collection from superduperdb.base.datalayer import Datalayer from superduperdb.base.document import Document from superduperdb.base.serializable import Variable -from superduperdb.components.component import Component from superduperdb.components.dataset import Dataset from superduperdb.components.datatype import DataType -from superduperdb.components.listener import Listener from superduperdb.components.metric import Metric from superduperdb.components.model import ( - Model, + ObjectModel, QueryModel, SequentialModel, + Signature, + _Fittable, _Predictor, _TrainingConfiguration, ) + # ------------------------------------------ # Test the _TrainingConfiguration class (tc) # ------------------------------------------ +@dc.dataclass +class Validator(_Fittable, ObjectModel): + ... 
def test_tc_type_id(): @@ -55,7 +59,7 @@ def test_tc_get_method(): # -------------------------------- -# Test the PredictMixin class (pm) +# Test the _Predictor class (pm) # -------------------------------- @@ -68,225 +72,89 @@ def return_self_multikey(x, y, z): def to_call(x): - if isinstance(x, list): - return [to_call(i) for i in x] return x * 5 -def to_call_multi(x): - if isinstance(x[0], list): - return [1] * len(x) - return 1 - - -def preprocess(x): - return x + 1 - - -def preprocess_multi(x, y): - return x + y - - -def postprocess(x): - return x + 0.1 - - -def mock_forward(self, x, **kwargs): - return to_call(x) - - -def mock_forward_multi(self, x, **kwargs): - return to_call_multi(x) - - -class TestModel(Component, _Predictor): - batch_predict: bool = False +def to_call_multi(x, y): + return x @pytest.fixture -def predict_mixin(request) -> _Predictor: - cls_ = getattr(request, 'param', _Predictor) - - if 'identifier' in inspect.signature(cls_).parameters: - predict_mixin = cls_(identifier='test') - else: - predict_mixin = cls_() - predict_mixin.identifier = 'test' - predict_mixin.to_call = to_call - predict_mixin.preprocess = preprocess - predict_mixin.postprocess = postprocess - predict_mixin.takes_context = False - predict_mixin.output_schema = None - predict_mixin.datatype = None - predict_mixin.model_update_kwargs = {} +def predict_mixin() -> _Predictor: + predict_mixin = ObjectModel('test', object=to_call) predict_mixin.version = 0 return predict_mixin @pytest.fixture -def predict_mixin_multikey(request) -> _Predictor: - cls_ = getattr(request, 'param', _Predictor) - - if 'identifier' in inspect.signature(cls_).parameters: - predict_mixin = cls_(identifier='test') - else: - predict_mixin = cls_() - predict_mixin.identifier = 'test' - predict_mixin.to_call = to_call_multi - predict_mixin.preprocess = preprocess_multi - predict_mixin.postprocess = postprocess - predict_mixin.takes_context = False - predict_mixin.output_schema = None - 
predict_mixin.datatype = None - predict_mixin.model_update_kwargs = {} +def predict_mixin_multikey() -> _Predictor: + predict_mixin = ObjectModel('test', object=to_call_multi) predict_mixin.version = 0 return predict_mixin def test_pm_predict_one(predict_mixin): X = np.random.randn(5) - - # preprocess -> to_call -> postprocess - expect = postprocess(to_call(preprocess(X))) - assert np.allclose(predict_mixin._predict_one(X), expect) - - # to_call -> postprocess - with patch.object(predict_mixin, 'preprocess', None): - expect = postprocess(to_call(X)) - assert np.allclose(predict_mixin._predict_one(X), expect) - - # preprocess -> to_call - with patch.object(predict_mixin, 'postprocess', None): - expect = to_call(preprocess(X)) - assert np.allclose(predict_mixin._predict_one(X), expect) - - -@pytest.mark.parametrize( - 'batch_predict, num_workers, expect_type', - [ - [True, 0, np.ndarray], - [False, 0, list], - [False, 1, list], - [False, 5, list], - ], -) -def test_pm_forward(batch_predict, num_workers, expect_type): - predict_mixin = _Predictor() - X = np.random.randn(4, 5) - - predict_mixin.to_call = to_call - predict_mixin.batch_predict = batch_predict - - output = predict_mixin._forward(X, num_workers=num_workers) - assert isinstance(output, expect_type) - assert np.allclose(output, to_call(X)) + expect = to_call(X) + assert np.allclose(predict_mixin.predict_one(X), expect) -@patch.object(_Predictor, '_forward', mock_forward_multi) def test_predict_core_multikey(predict_mixin_multikey): X = 1 Y = 2 - Z = 2 + expect = to_call_multi(X, Y) + output = predict_mixin_multikey.predict_one(X, Y) + assert output == expect - # Multi key with preprocess - # As list - predict_mixin_multikey.preprocess = None - expect = postprocess(to_call_multi([X, Y, Z])) - output = predict_mixin_multikey._predict([[X, Y, Z], [X, Y, Z]]) - assert isinstance(output, list) - assert np.allclose(output, expect) + output = predict_mixin_multikey.predict_one(x=X, y=Y) + assert output == expect + 
with pytest.raises(TypeError): + predict_mixin_multikey.predict(X, Y) -@patch.object(_Predictor, '_forward', mock_forward) -def test_predict_core_multikey_dict(predict_mixin_multikey): - X = 1 - Y = 2 - # As Dict - predict_mixin_multikey.preprocess = preprocess_multi - output = predict_mixin_multikey._predict([{'x': X, 'y': Y}]) + output = predict_mixin_multikey.predict([((X, Y), {}), ((X, Y), {})]) assert isinstance(output, list) - assert np.allclose(output, 15.1) - - -@patch.object(_Predictor, '_forward', mock_forward) -def test_predict_preprocess_multikey(predict_mixin_multikey): - X = 1 - Y = 2 - # Multi key with preprocess - predict_mixin_multikey.to_call = to_call - expect = postprocess(to_call(preprocess_multi(X, Y))) - output = predict_mixin_multikey._predict([[X, Y], [X, Y]]) + predict_mixin_multikey.num_workers = 2 + output = predict_mixin_multikey.predict([((X, Y), {}), ((X, Y), {})]) assert isinstance(output, list) - assert np.allclose(output, expect) -@patch.object(_Predictor, '_forward', mock_forward) def test_pm_core_predict(predict_mixin): - X = np.random.randn(4, 5) - # make sure _predict_one is called - with patch.object(predict_mixin, '_predict_one', return_self): - assert predict_mixin._predict(5, one=True) == return_self(5) - - expect = postprocess(to_call(preprocess(X))) - output = predict_mixin._predict(X) - assert isinstance(output, list) - assert np.allclose(output, expect) - - # to_call -> postprocess - with patch.object(predict_mixin, 'preprocess', None): - expect = postprocess(to_call(X)) - output = predict_mixin._predict(X) - assert isinstance(output, list) - assert np.allclose(output, expect) - - # preprocess -> to_call - with patch.object(predict_mixin, 'postprocess', None): - output = predict_mixin._predict(X) - expect = to_call(preprocess(X)) - assert isinstance(output, list) - assert np.allclose(output, expect) - - -def test_pm_create_predict_job(predict_mixin): - select = MagicMock(spec=Select) - X = 'x' - ids = [1, 2, 3] - 
max_chunk_size = 2 - job = predict_mixin.create_predict_job(X, select, ids, max_chunk_size) - assert job.component_identifier == predict_mixin.identifier - assert job.method_name == 'predict' - assert job.args == [X] - assert job.kwargs['max_chunk_size'] == max_chunk_size - assert job.kwargs['ids'] == ids - - -@patch.object(Datalayer, 'add') -@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) -def test_pm_predict_and_listen(mock_add, predict_mixin, db): - X = 'x' - select = MagicMock(CompoundSelect) - - in_memory = False - max_chunk_size = 2 - predict_mixin._predict_and_listen( - X, - select, - db=db, - max_chunk_size=max_chunk_size, - in_memory=in_memory, + with patch.object(predict_mixin, 'predict_one', return_self): + assert predict_mixin.predict_one(5) == return_self(5) + + +@patch('superduperdb.components.model.ComponentJob') +def test_pm_create_predict_job(mock_job, predict_mixin): + mock_db = MagicMock() + mock_select = MagicMock() + mock_select.dict().encode.return_value = b'encoded_select' + X = 'model_input' + ids = ['id1', 'id2'] + max_chunk_size = 100 + in_memory = True + overwrite = False + predict_mixin.predict_in_db_job( + X=X, db=mock_db, select=mock_select, ids=ids, max_chunk_size=max_chunk_size + ) + mock_job.assert_called_once_with( + component_identifier=predict_mixin.identifier, # Adjust according to your setup + method_name='predict_in_db', + type_id='model', + args=[X], + kwargs={ + 'select': b'encoded_select', + 'ids': ids, + 'max_chunk_size': max_chunk_size, + 'in_memory': in_memory, + 'overwrite': overwrite, + }, ) - listener = mock_add.call_args[0][0] - - # Check whether create a correct listener - assert isinstance(listener, Listener) - assert listener.model == predict_mixin - assert listener.predict_kwargs['in_memory'] == in_memory - assert listener.predict_kwargs['max_chunk_size'] == max_chunk_size -@pytest.mark.parametrize('predict_mixin', [TestModel], indirect=True) +# @pytest.mark.parametrize('predict_mixin', 
[TestModel], indirect=True) def test_pm_predict(predict_mixin): # Check the logic of predict method, the mock method will be tested below db = MagicMock(spec=Datalayer) @@ -295,77 +163,73 @@ def test_pm_predict(predict_mixin): select = MagicMock(spec=Select) select.table_or_collection = MagicMock() - with patch.object(predict_mixin, '_predict_and_listen') as predict_func: - predict_mixin.predict('x', db, select, listen=True) + with patch.object(predict_mixin, 'predict') as predict_func: + predict_mixin.predict_in_db('x', db=db, select=select) predict_func.assert_called_once() - with patch.object(predict_mixin, '_predict') as predict_func: - predict_mixin.predict('x') - predict_func.assert_called_once() +def test_pm_predict_with_select_ids(monkeypatch, predict_mixin): + xs = [np.random.randn(4) for _ in range(10)] -def test_pm_predict_with_select(predict_mixin): - # Check the logic about overwrite in _predict_with_select + docs = [Document({'x': x}) for x in xs] X = 'x' - all_ids = ['1', '2', '3'] - ids_of_missing_outputs = ['1', '2'] - select = MagicMock(spec=Select) - select.select_ids_of_missing_outputs.return_value = 'missing' - - def return_value(select_type): - ids = ids_of_missing_outputs if select_type == 'missing' else all_ids - query_result = [ - ( - { - 'id_field': id, - } - ) - for id in ids - ] - return query_result + ids = [i for i in range(10)] + select = MagicMock(spec=Select) db = MagicMock(spec=Datalayer) - db.execute.side_effect = return_value - db.databackend = MagicMock() - db.databackend.id_field = 'id_field' - - # overwrite = True - with patch.object(predict_mixin, '_predict_with_select_and_ids') as mock_predict: - predict_mixin._predict_with_select(X, select, db, overwrite=True) - _, kwargs = mock_predict.call_args - assert kwargs.get('ids') == all_ids - - # overwrite = False - with patch.object(predict_mixin, '_predict_with_select_and_ids') as mock_predict: - predict_mixin._predict_with_select( - X, select, db, overwrite=False, 
max_chunk_size=None, in_memory=True - ) - _, kwargs = mock_predict.call_args - assert kwargs.get('ids') == ids_of_missing_outputs - - -def test_model_on_create(): - db = MagicMock(spec=Datalayer) - db.databackend = MagicMock() + db.execute.return_value = docs - # Check the encoder is loaded if encoder is string - model = Model('test', object=object(), datatype='test_encoder') - with patch.object(db, 'load') as db_load: - model.pre_create(db) - db_load.assert_called_with('datatype', 'test_encoder') + with patch.object(predict_mixin, 'object') as my_object: + my_object.return_value = 2 + # Check the base predict function + predict_mixin.db = db + with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( + select, 'model_update' + ) as model_update: + predict_mixin._predict_with_select_and_ids(X, db, select, ids) + select_using_ids.assert_called_once_with(ids) + _, kwargs = model_update.call_args + # make sure the outputs are set + assert kwargs.get('outputs') == [2] * 10 + + with ( + patch.object(predict_mixin, 'object') as my_object, + patch.object(select, 'select_using_ids') as select_using_id, + patch.object(select, 'model_update') as model_update, + ): + my_object.return_value = 2 + + monkeypatch.setattr(predict_mixin, 'datatype', DataType(identifier='test')) + predict_mixin._predict_with_select_and_ids(X, db, select, ids) + select_using_id.assert_called_once_with(ids) + _, kwargs = model_update.call_args + datatype = predict_mixin.datatype + assert kwargs.get('outputs') == [datatype(2).encode() for _ in range(10)] + + with patch.object(predict_mixin, 'object') as my_object: + my_object.return_value = {'out': 2} + # Check the base predict function with output_schema + from superduperdb.components.schema import Schema - # Check the output_component table is added by datalayer - model = Model('test', object=object(), datatype=DataType(identifier='test')) - output_component = MagicMock() - 
db.databackend.create_model_table_or_collection.return_value = output_component - with patch.object(db, 'add') as db_load: - model.post_create(db) - db_load.assert_called_with(output_component) + predict_mixin.datatype = None + predict_mixin.output_schema = schema = MagicMock(spec=Schema) + schema.side_effect = str + with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( + select, 'model_update' + ) as model_update: + predict_mixin._predict_with_select_and_ids(X, db, select, ids) + select_using_ids.assert_called_once_with(ids) + _, kwargs = model_update.call_args + assert kwargs.get('outputs') == [str({'out': 2}) for _ in range(10)] def test_model_append_metrics(): - model = Model('test', object=object()) + @dc.dataclass + class _Tmp(ObjectModel, _Fittable): + ... + + model = _Tmp('test', object=object()) metric_values = {'acc': 0.5, 'loss': 0.5} @@ -380,11 +244,11 @@ def test_model_append_metrics(): assert model.metric_values.get('loss') == [0.5, 0.4] -@patch.object(Model, '_validate') +@patch.object(Validator, '_validate') def test_model_validate(mock_validate): # Check the metadadata recieves the correct values mock_validate.return_value = {'acc': 0.5, 'loss': 0.5} - model = Model('test', object=object()) + model = Validator('test', object=object()) db = MagicMock(spec=Datalayer) db.metadata = MagicMock() with patch.object(db, 'add') as db_add, patch.object( @@ -397,7 +261,7 @@ def test_model_validate(mock_validate): assert kwargs.get('value') == {'acc': 0.5, 'loss': 0.5} -@patch.object(Model, '_predict') +@patch.object(ObjectModel, 'predict') @pytest.mark.parametrize( "db", [ @@ -409,8 +273,10 @@ def test_model_validate(mock_validate): def test_model_core_validate(model_predict, valid_dataset, db): # Check the validation is done correctly db.add(valid_dataset) - model = Model('test', object=object(), train_X='x', train_y='y') - model_predict.side_effect = lambda x: [random.randint(0, 1) for _ in range(len(x))] + model = 
Validator('test', object=object(), train_X='x', train_y='y') + model_predict.side_effect = lambda dataset: [ + random.randint(0, 1) for _ in range(len(dataset)) + ] metrics = [ Metric('f1', object=f1_score), Metric('acc', object=accuracy_score), @@ -428,30 +294,31 @@ def test_model_core_validate(model_predict, valid_dataset, db): def test_model_create_fit_job(): # Check the fit job is created correctly - model = Model('test', object=object()) + model = Validator('test', object=object()) job = model.create_fit_job('x') assert job.component_identifier == model.identifier assert job.method_name == 'fit' assert job.args == ['x'] +@patch.object(Validator, '_fit') def test_model_fit(valid_dataset): # Check the logic of the fit method, the mock method was tested above - model = Model('test', object=object()) - with patch.object(model, '_fit') as model_fit: - model.fit('x') - model_fit.assert_called_once() - with patch.object(model, '_fit') as model_fit: - db = MagicMock(spec=Datalayer) - db.compute = MagicMock(spec=LocalComputeBackend) - model.fit( - valid_dataset, - db=db, - validation_sets=[valid_dataset], - ) - _, kwargs = model_fit.call_args - assert kwargs.get('validation_sets') == [valid_dataset.identifier] + Validator._fit.return_value = 'done' + model = Validator('test', object=object()) + model.fit('x') + model._fit.assert_called_once() + + db = MagicMock(spec=Datalayer) + db.compute = MagicMock(spec=LocalComputeBackend) + model.fit( + valid_dataset, + db=db, + validation_sets=[valid_dataset], + ) + _, kwargs = model._fit.call_args + assert kwargs.get('validation_sets')[0].identifier == valid_dataset.identifier @pytest.mark.parametrize( @@ -468,7 +335,8 @@ def test_query_model(db): .find_one({}, {'_id': 1}) ) - # check = q.set_variables(db, X='test') + check = q.set_variables(db, X='test') + assert not check.variables m = QueryModel( identifier='test-query-model', @@ -479,11 +347,11 @@ def test_query_model(db): import torch - out = m.predict(X=torch.randn(32), 
one=True) + out = m.predict_one({'X': torch.randn(32)}) assert isinstance(out, bson.ObjectId) - out = m.predict(X=torch.randn(4, 32)) + out = m.predict([{'X': torch.randn(32)} for _ in range(4)]) assert len(out) == 4 @@ -497,35 +365,31 @@ def test_sequential_model(): m = SequentialModel( identifier='test-sequential-model', predictors=[ - Model( + ObjectModel( identifier='test-predictor-1', object=lambda x: x + 2, ), - Model( + ObjectModel( identifier='test-predictor-2', object=lambda x: x + 1, + signature=Signature.singleton, ), ], ) - assert m.predict(X=1, one=True) == 4 - assert m.predict(X=[1, 1, 1, 1]) == [4, 4, 4, 4] + assert m.predict_one(x=1) == 4 + assert m.predict([((1,), {}) for _ in range(4)]) == [4, 4, 4, 4] + + +def test_pm_predict_with_select_ids_multikey(monkeypatch, predict_mixin_multikey): + xs = [np.random.randn(4) for _ in range(10)] + def func(x, y): + return 2 -@patch.object(_Predictor, '_predict') -def test_pm_predict_with_select_ids( - predict_mock, predict_mixin_multikey, predict_mixin -): - def _test(multi_key, predict_mixin_multikey): - xs = [np.random.randn(4) for _ in range(10)] - ys = [int(random.random() > 0.5) for i in range(10)] - if multi_key: - docs = [Document({'x': x, 'y': x, 'z': x}) for x in xs] - X = ['x', 'y', 'z'] - else: - docs = [Document({'x': x}) for x in xs] - X = 'x' + monkeypatch.setattr(predict_mixin_multikey, 'object', func) + def _test(X, docs): ids = [i for i in range(10)] select = MagicMock(spec=Select) @@ -533,7 +397,6 @@ def _test(multi_key, predict_mixin_multikey): db.execute.return_value = docs # Check the base predict function - predict_mock.return_value = ys predict_mixin_multikey.db = db with patch.object(select, 'select_using_ids') as select_using_ids, patch.object( select, 'model_update' @@ -542,60 +405,16 @@ def _test(multi_key, predict_mixin_multikey): select_using_ids.assert_called_once_with(ids) _, kwargs = model_update.call_args # make sure the outputs are set - assert kwargs.get('outputs') == ys + 
assert kwargs.get('outputs') == [2] * 10 - # Check the base predict function with encoder - from superduperdb.components.datatype import DataType + # TODO - I don't know how this works given that the `_outputs` field + # should break... + docs = [Document({'x': x, 'y': x}) for x in xs] + X = ('x', 'y') - predict_mixin_multikey.datatype = DataType(identifier='test') - with patch.object(select, 'model_update') as model_update: - predict_mixin_multikey._predict_with_select_and_ids(X, db, select, ids) - select_using_ids.assert_called_once_with(ids) - _, kwargs = model_update.call_args - # make sure encoder is used - datatype = predict_mixin_multikey.datatype - assert kwargs.get('outputs') == [datatype(y).encode() for y in ys] - - # Check the base predict function with output_schema - from superduperdb.components.schema import Schema - - predict_mixin_multikey.datatype = None - predict_mixin_multikey.output_schema = schema = MagicMock(spec=Schema) - schema.side_effect = str - predict_mock.return_value = [{'y': y} for y in ys] - with patch.object(select, 'model_update') as model_update: - predict_mixin_multikey._predict_with_select_and_ids(X, db, select, ids) - select_using_ids.assert_called_once_with(ids) - _, kwargs = model_update.call_args - assert kwargs.get('outputs') == [str({'y': y}) for y in ys] - - # Test multikey - _test(1, predict_mixin_multikey) - - # Test single key - _test(0, predict_mixin) + _test(X, docs) - -@pytest.mark.parametrize( - "db", - [ - (DBConfig.mongodb_empty, {}), - ], - indirect=True, -) -def test_predict_insert(db): - # Check that when `insert_to` is specified, then the input - # and output of the prediction are saved in the database - - m = Model( - identifier='test-predictor-1', - object=lambda x: x + 2, - ) - - db.add(m) - m.predict( - X=Document({'x': 1}), key='x', one=True, insert_to=Collection('documents') - ) - r = db.execute(Collection('documents').find_one()) - out = r['_outputs']['x'][m.identifier]['0'] - assert out == 3 + # TODO 
this should also work + # docs = [Document({'a': x, 'b': x}) for x in xs] + # X = {'a': 'x', 'b': 'y'} + # _test(X, docs) diff --git a/test/unittest/component/test_serialization.py b/test/unittest/component/test_serialization.py index 4db24d808c..360977bd95 100644 --- a/test/unittest/component/test_serialization.py +++ b/test/unittest/component/test_serialization.py @@ -10,13 +10,13 @@ from sklearn.svm import SVC -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel from superduperdb.ext.sklearn.model import Estimator @pytest.mark.skipif(not torch, reason='Torch not installed') def test_model(): - m = Model( + m = ObjectModel( identifier='test', datatype=tensor(torch.float, shape=(32,)), object=torch.nn.Linear(13, 18), diff --git a/test/unittest/ext/llm/utils.py b/test/unittest/ext/llm/utils.py index 3534861095..3d52b947f9 100644 --- a/test/unittest/ext/llm/utils.py +++ b/test/unittest/ext/llm/utils.py @@ -17,11 +17,6 @@ def check_predict(db, llm): result = db.predict(llm.identifier, "1+1=")[0].unpack() assert isinstance(result, str) - results = db.predict(llm.identifier, ["1+1=", "2+2="])[0].unpack() - - assert isinstance(results, list) - assert len(results) == 2 - def check_llm_as_listener_model(db, llm): """Test whether the model can predict the data in the database normally""" @@ -44,13 +39,11 @@ def check_llm_as_listener_model(db, llm): select = table.select("id", "question") output_select = table.select("id", "question").outputs(question=llm.identifier) - db.add(llm) - db.add( Listener( select=select, key="question", - model=llm.identifier, + model=llm, ) ) diff --git a/test/unittest/ext/test_llama_cpp.py b/test/unittest/ext/test_llama_cpp.py index cd8323b9b0..02fedc67ee 100644 --- a/test/unittest/ext/test_llama_cpp.py +++ b/test/unittest/ext/test_llama_cpp.py @@ -23,7 +23,7 @@ def mocked_init(self): ) text = 'testing prompt' - output = llama.predict(text, one=True) + output = llama.predict_one(text) assert 
output == 'tested' @@ -41,5 +41,5 @@ def mocked_init(self): ) text = 'testing prompt' - output = llama.predict(text, one=True) + output = llama.predict_one(text) assert output == [1] diff --git a/test/unittest/ext/test_torch.py b/test/unittest/ext/test_torch.py index a1196d357e..0c1687ab43 100644 --- a/test/unittest/ext/test_torch.py +++ b/test/unittest/ext/test_torch.py @@ -51,19 +51,9 @@ def acc(x, y): return x == y -@pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize( - 'db', - [ - (DBConfig.mongodb_data, {'n_data': 500}), - (DBConfig.sqldb_data, {'n_data': 500}), - ], - indirect=True, -) -def test_fit(db, valid_dataset): - db.add(valid_dataset) - - m = TorchModel( +@pytest.fixture +def model(): + return TorchModel( object=torch.nn.Linear(32, 1), identifier='test', training_configuration=TorchTrainerConfiguration( @@ -78,6 +68,21 @@ def test_fit(db, valid_dataset): datatype=DataType(identifier='base'), ) + +@pytest.mark.skipif(not torch, reason='Torch not installed') +@pytest.mark.parametrize( + 'db', + [ + (DBConfig.mongodb_data, {'n_data': 500}), + (DBConfig.sqldb_data, {'n_data': 500}), + ], + indirect=True, +) +def test_fit(db, valid_dataset, model): + db.add(valid_dataset) + + m = model + if isinstance(db.databackend, MongoDataBackend): select = Collection('documents').find() else: @@ -92,3 +97,8 @@ def test_fit(db, valid_dataset): metrics=[Metric(identifier='acc', object=acc)], validation_sets=['my_valid'], ) + + +def test_predict(): + # Check that the pre-process etc. has been called + ... 
diff --git a/test/unittest/ext/test_transformers.py b/test/unittest/ext/test_transformers.py index cceafe505b..2e1eef2c77 100644 --- a/test/unittest/ext/test_transformers.py +++ b/test/unittest/ext/test_transformers.py @@ -48,7 +48,7 @@ def transformers_model(db): @pytest.mark.skipif(not torch, reason='Torch not installed') def test_transformer_predict(transformers_model): - one_prediction = transformers_model.predict('this is a test', one=True) + one_prediction = transformers_model.predict_one('this is a test') assert isinstance(one_prediction, int) predictions = transformers_model.predict(['this is a test', 'this is another']) assert isinstance(predictions, list) @@ -88,5 +88,4 @@ def test_transformer_fit(transformers_model, db, td): select=Collection('train_documents').find({'_fold': 'valid'}), ) ], - data_prefetch=False, ) diff --git a/test/unittest/ext/test_vanilla.py b/test/unittest/ext/test_vanilla.py index cbddeffee8..ac20e9b480 100644 --- a/test/unittest/ext/test_vanilla.py +++ b/test/unittest/ext/test_vanilla.py @@ -4,7 +4,7 @@ from superduperdb.backends.mongodb.query import Collection from superduperdb.base.document import Document -from superduperdb.components.model import Model +from superduperdb.components.model import ObjectModel @pytest.fixture() @@ -20,24 +20,24 @@ def data_in_db(db): def test_function_predict_one(): - function = Model(object=lambda x: x, identifier='test') - assert function.predict(1, one=True) == 1 + function = ObjectModel(object=lambda x: x, identifier='test') + assert function.predict_one(1) == 1 def test_function_predict(): - function = Model(object=lambda x: x, identifier='test') + function = ObjectModel(object=lambda x: x, identifier='test', signature='singleton') assert function.predict([1, 1]) == [1, 1] # TODO: use table to test the sqldb @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_document_embedded(data_in_db): - function = Model( + function = ObjectModel( 
object=lambda x: x, identifier='test', model_update_kwargs={'document_embedded': False}, ) - function.predict( + function.predict_in_db( X='X', db=data_in_db, select=Collection(identifier='documents').find() ) out = data_in_db.execute(Collection(identifier='_outputs.X.test').find({})) @@ -46,8 +46,8 @@ def test_function_predict_with_document_embedded(data_in_db): @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_without_document_embedded(data_in_db): - function = Model(object=lambda x: x, identifier='test') - function.predict( + function = ObjectModel(object=lambda x: x, identifier='test') + function.predict_in_db( X='X', db=data_in_db, select=Collection(identifier='documents').find() ) out = data_in_db.execute(Collection(identifier='documents').find({})) @@ -56,13 +56,13 @@ def test_function_predict_without_document_embedded(data_in_db): @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_flatten_outputs(data_in_db): - function = Model( + function = ObjectModel( object=lambda x: [x, x, x] if x > 2 else [x, x], identifier='test', model_update_kwargs={'document_embedded': False}, flatten=True, ) - function.predict( + function.predict_in_db( X='X', db=data_in_db, select=Collection(identifier='documents').find() ) out = data_in_db.execute(Collection(identifier='_outputs.X.test').find({})) @@ -94,16 +94,15 @@ def test_function_predict_with_flatten_outputs(data_in_db): assert [o['_source'] for o in out] == source_ids -@pytest.mark.skip @pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) def test_function_predict_with_mix_flatten_outputs(data_in_db): - function = Model( + function = ObjectModel( object=lambda x: x if x < 2 else [x, x, x], identifier='test', flatten=True, model_update_kwargs={'document_embedded': False}, ) - function.predict( + function.predict_in_db( X='X', db=data_in_db, select=Collection(identifier='documents').find() ) out = 
data_in_db.execute(Collection(identifier='_outputs.X.test').find({})) diff --git a/test/unittest/test_quality.py b/test/unittest/test_quality.py index d1dae7d30d..747c548c23 100644 --- a/test/unittest/test_quality.py +++ b/test/unittest/test_quality.py @@ -17,9 +17,9 @@ # over time. If you have decreased the number of defects, change it here, # and take a bow! ALLOWABLE_DEFECTS = { - 'cast': 14, # Try to keep this down - 'noqa': 8, # This should never change - 'type_ignore': 35, # This should only ever increase in obscure edge cases + 'cast': 12, # Try to keep this down + 'noqa': 7, # This should never change + 'type_ignore': 40, # This should only ever increase in obscure edge cases }