From 840ce5422c66e3f9446b4531e76a6a768e5c32de Mon Sep 17 00:00:00 2001
From: JieguangZhou
Date: Sat, 11 May 2024 16:18:25 +0800
Subject: [PATCH] Added docstrings

- Updated docstrings for superduperdb.components
- Updated docstrings for superduperdb.base
- Updated docstrings for superduperdb.jobs
- Updated docstrings for superduperdb.server
- Updated docstrings for superduperdb.rest
- Updated docstrings for superduperdb.vector_search
- Updated docstrings for superduperdb.misc
- Updated docstrings for superduperdb.cli
- Updated docstrings for superduperdb.cdc
- Updated docstrings for superduperdb.backends
- Updated docstrings for superduperdb.ext
---
 CHANGELOG.md | 2 +-
 pyproject.toml | 16 +-
 superduperdb/__main__.py | 6 +-
 superduperdb/backends/base/artifacts.py | 43 ++-
 superduperdb/backends/base/compute.py | 34 +--
 superduperdb/backends/base/data_backend.py | 55 ++--
 superduperdb/backends/base/metadata.py | 37 ++-
 superduperdb/backends/base/query.py | 193 +++++++++---
 superduperdb/backends/ibis/cdc/base.py | 7 +-
 superduperdb/backends/ibis/cdc/listener.py | 70 ++++-
 superduperdb/backends/ibis/cursor.py | 8 +-
 superduperdb/backends/ibis/data_backend.py | 53 +++-
 superduperdb/backends/ibis/db_helper.py | 60 +++-
 superduperdb/backends/ibis/field_types.py | 17 +-
 superduperdb/backends/ibis/query.py | 195 +++++++++++--
 superduperdb/backends/ibis/utils.py | 7 +-
 superduperdb/backends/local/artifacts.py | 25 +-
 superduperdb/backends/local/compute.py | 27 +-
 superduperdb/backends/mongodb/artifacts.py | 39 ++-
 superduperdb/backends/mongodb/cdc/base.py | 29 +-
 superduperdb/backends/mongodb/cdc/listener.py | 65 +++--
 superduperdb/backends/mongodb/data_backend.py | 52 +++-
 superduperdb/backends/mongodb/metadata.py | 127 +++++++-
 superduperdb/backends/mongodb/query.py | 274 +++++++++++++++---
 superduperdb/backends/mongodb/utils.py | 4 +-
 superduperdb/backends/query_dataset.py | 37 ++-
 superduperdb/backends/ray/compute.py | 30 +-
 superduperdb/backends/ray/serve.py | 15 +-
 superduperdb/backends/sqlalchemy/db_helper.py | 22 ++
 superduperdb/backends/sqlalchemy/metadata.py | 135 ++++++++-
 superduperdb/base/build.py | 7 +-
 superduperdb/base/code.py | 15 +
 superduperdb/base/config.py | 103 +++++--
 superduperdb/base/config_dicts.py | 16 +-
 superduperdb/base/configs.py | 11 +-
 superduperdb/base/cursor.py | 13 +-
 superduperdb/base/datalayer.py | 51 ++--
 superduperdb/base/decorators.py | 4 +
 superduperdb/base/document.py | 45 ++-
 superduperdb/base/enums.py | 4 +-
 superduperdb/base/exceptions.py | 40 +--
 superduperdb/base/leaf.py | 39 ++-
 superduperdb/base/logger.py | 29 ++
 superduperdb/base/serializable.py | 40 ++-
 superduperdb/base/superduper.py | 47 ++-
 superduperdb/cdc/app.py | 16 +-
 superduperdb/cdc/cdc.py | 109 +++++--
 superduperdb/cli/config.py | 4 +
 superduperdb/cli/info.py | 5 +
 superduperdb/cli/serve.py | 15 +
 superduperdb/cli/stack.py | 5 +
 superduperdb/components/__init__.py | 4 +-
 superduperdb/components/component.py | 90 ++----
 superduperdb/components/dataset.py | 21 +-
 superduperdb/components/datatype.py | 230 ++++++----
 superduperdb/components/graph.py | 70 ++---
 superduperdb/components/listener.py | 47 +--
 superduperdb/components/metric.py | 9 +-
 superduperdb/components/model.py | 180 +++++-------
 superduperdb/components/schema.py | 28 +-
 superduperdb/components/stack.py | 13 +-
 superduperdb/components/vector_index.py | 52 ++--
 superduperdb/ext/anthropic/model.py | 24 +-
 superduperdb/ext/cohere/model.py | 41 ++-
 superduperdb/ext/jina/client.py | 30 +-
 superduperdb/ext/jina/model.py | 23 +-
superduperdb/ext/llamacpp/model.py | 25 +- superduperdb/ext/llm/model.py | 59 ++-- superduperdb/ext/llm/prompter.py | 27 +- superduperdb/ext/numpy/encoder.py | 35 ++- superduperdb/ext/openai/model.py | 227 +++++++-------- superduperdb/ext/pillow/encoder.py | 19 +- .../ext/sentence_transformers/model.py | 26 ++ superduperdb/ext/sklearn/model.py | 37 +++ superduperdb/ext/torch/encoder.py | 39 ++- superduperdb/ext/torch/model.py | 89 +++++- superduperdb/ext/torch/training.py | 37 +++ superduperdb/ext/torch/utils.py | 2 +- superduperdb/ext/transformers/model.py | 105 ++++++- superduperdb/ext/transformers/training.py | 166 ++++++++--- superduperdb/ext/unstructured/encoder.py | 20 +- superduperdb/ext/utils.py | 26 +- superduperdb/ext/vllm/model.py | 20 +- superduperdb/jobs/job.py | 53 ++-- superduperdb/jobs/task_workflow.py | 18 +- superduperdb/jobs/tasks.py | 24 ++ superduperdb/misc/__init__.py | 9 +- superduperdb/misc/annotations.py | 32 +- superduperdb/misc/anonymize.py | 7 +- superduperdb/misc/archives.py | 8 +- superduperdb/misc/auto_schema.py | 39 ++- superduperdb/misc/colors.py | 2 + superduperdb/misc/compat.py | 12 +- superduperdb/misc/data.py | 4 +- superduperdb/misc/download.py | 71 +++-- superduperdb/misc/files.py | 3 + superduperdb/misc/hash.py | 14 +- superduperdb/misc/retry.py | 7 +- superduperdb/misc/run.py | 1 + superduperdb/misc/runnable/collection.py | 19 +- superduperdb/misc/runnable/queue_chunker.py | 9 +- superduperdb/misc/runnable/runnable.py | 15 +- superduperdb/misc/runnable/thread.py | 36 ++- superduperdb/misc/serialization.py | 6 +- superduperdb/misc/server.py | 8 + superduperdb/misc/special_dicts.py | 9 +- superduperdb/rest/app.py | 5 + superduperdb/rest/utils.py | 10 + superduperdb/server/app.py | 83 ++++-- superduperdb/server/cluster.py | 17 +- superduperdb/vector_search/atlas.py | 35 ++- superduperdb/vector_search/base.py | 99 +++++-- superduperdb/vector_search/in_memory.py | 23 ++ superduperdb/vector_search/interface.py | 13 +- superduperdb/vector_search/lance.py | 21 ++ superduperdb/vector_search/server/app.py | 45 +++ superduperdb/vector_search/server/service.py | 74 +++-- superduperdb/vector_search/update_tasks.py | 22 +- .../ext/openai/test_model_openai.py | 4 +- test/unittest/ext/llm/test_openai.py | 48 --- test/unittest/test_docstrings.py | 32 +- 121 files changed, 3600 insertions(+), 1460 deletions(-) delete mode 100644 test/unittest/ext/llm/test_openai.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 26d46330c5..3ea087e677 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 #### Changed defaults / behaviours -- Add docstrings in component classes and methods. - Run Tests from within the container - Add model dict output indexing in graph - Make lance upsert for added vectors @@ -20,6 +19,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - At the end of the test, drop the collection instead of the database - Force load vector indices during backfill - Fix pandas database (in-memory) +- Add docstrings in component classes and methods. 
#### New Features & Functionality - Add nightly image for pre-release testing in the cloud environment diff --git a/pyproject.toml b/pyproject.toml index 43a93930b1..d7c9eb7708 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -110,9 +110,21 @@ extend-select = [ #"W", # PyCode Warning "E", # PyCode Error #"N", # pep8-naming - #"D", # pydocstyle + "D", # pydocstyle +] +ignore = [ + "D100", # Missing docstring in public module + "D104", # Missing docstring in public package + "D107", # Missing docstring in __init__ + "D105", # Missing docstring in magic method + "D212", # Multi-line docstring summary should start at the first line + "D213", # Multi-line docstring summary should start at the second line + "D401", + "E402", ] -ignore = ["E402"] [tool.ruff.isort] combine-as-imports = true + +[tool.ruff.per-file-ignores] +"test/**" = ["D"] diff --git a/superduperdb/__main__.py b/superduperdb/__main__.py index 9c0f99b8d3..35e13dab47 100644 --- a/superduperdb/__main__.py +++ b/superduperdb/__main__.py @@ -9,9 +9,9 @@ def run(): - """ - Entrypoint for the CLI. This is the function that is called when the - user runs `python -m superduperdb`. + """Entrypoint for the CLI. + + This is the function that is called when the user runs `python -m superduperdb`. """ try: app(standalone_mode=False) diff --git a/superduperdb/backends/base/artifacts.py b/superduperdb/backends/base/artifacts.py index db5a17cefd..92a68543ec 100644 --- a/superduperdb/backends/base/artifacts.py +++ b/superduperdb/backends/base/artifacts.py @@ -29,28 +29,36 @@ def __init__( @property def serializers(self): + """Return the serializers.""" assert self._serializers is not None, 'Serializers not initialized!' return self._serializers @serializers.setter def serializers(self, value): + """Set the serializers. + + :param value: The serializers. + """ self._serializers = value @abstractmethod def url(self): - """ - Artifact store connection url - """ + """Artifact store connection url.""" pass @abstractmethod def _delete_artifact(self, file_id: str): - """ - Delete artifact from artifact store - :param file_id: File id uses to identify artifact in store + """Delete artifact from artifact store. + + :param file_id: File id uses to identify artifact in store. """ def delete(self, r: t.Dict): + """Delete artifact from artifact store. + + :param r: dictionary with mandatory fields + {'file_id'} + """ if '_content' in r and 'file_id' in r['_content']: return self._delete_artifact(r['_content']['file_id']) for v in r.values(): @@ -76,6 +84,12 @@ def exists( datatype: t.Optional[str] = None, uri: t.Optional[str] = None, ): + """Check if artifact exists in artifact store. + + :param file_id: file id of artifact in the store + :param datatype: Datatype of the artifact + :param uri: URI of the artifact + """ if file_id is None: assert uri is not None, "if file_id is None, uri can\'t be None" file_id = _construct_file_id_from_uri(uri) @@ -91,7 +105,7 @@ def _save_bytes(self, serialized: bytes, file_id: str): @abstractmethod def _save_file(self, file_path: str, file_id: str) -> str: - """Save file in artifact store and return file_id""" + """Save file in artifact store and return file_id.""" pass def save_artifact(self, r: t.Dict): @@ -147,7 +161,7 @@ def _load_bytes(self, file_id: str) -> bytes: @abstractmethod def _load_file(self, file_id: str) -> str: """ - Load file from artifact store and return path + Load file from artifact store and return path. 
:param file_id: Identifier of artifact in the store """ @@ -159,7 +173,6 @@ def load_artifact(self, r): :param r: Mandatory fields {'file_id', 'datatype'} """ - datatype = self.serializers[r['datatype']] file_id = r.get('file_id') if r.get('encodable') == 'file': @@ -174,9 +187,9 @@ def load_artifact(self, r): return datatype.decode_data(x) def save(self, r: t.Dict) -> t.Dict: - """ - Save list of artifacts and replace the artifacts with file reference - :param r: `dict` of artifacts + """Save list of artifacts and replace the artifacts with file reference. + + :param r: `dict` of artifacts. """ if isinstance(r, dict): if '_content' in r and r['_content']['leaf_type'] in { @@ -196,11 +209,11 @@ def save(self, r: t.Dict) -> t.Dict: @abstractmethod def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" pass class ArtifactSavingError(Exception): + """Error when saving artifact in artifact store fails.""" + pass diff --git a/superduperdb/backends/base/compute.py b/superduperdb/backends/base/compute.py index 122f71ec52..07de8fe96c 100644 --- a/superduperdb/backends/base/compute.py +++ b/superduperdb/backends/base/compute.py @@ -3,26 +3,20 @@ class ComputeBackend(ABC): - """ - Abstraction for sending jobs to a distributed compute platform. - """ + """Abstraction for sending jobs to a distributed compute platform.""" @abstractproperty def type(self) -> str: - """ - Return the type of compute engine - """ + """Return the type of compute engine.""" pass @abstractproperty def name(self) -> str: - """ - Return the name of current compute engine - """ + """Return the name of current compute engine.""" pass def get_local_client(self): - '''Returns a local version of self''' + """Returns a local version of self.""" pass @abstractmethod @@ -37,22 +31,18 @@ def submit(self, function: t.Callable, **kwargs) -> t.Any: @abstractproperty def tasks(self) -> t.Any: - """ - List for all tasks - """ + """List for all tasks.""" pass @abstractmethod def wait_all(self) -> None: - """ - Waits for all pending tasks to complete. - """ + """Waits for all pending tasks to complete.""" pass @abstractmethod def result(self, identifier: str) -> t.Any: - """ - Retrieves the result of a previously submitted task. + """Retrieves the result of a previously submitted task. + Note: This will block until the future is completed. :param identifier: The identifier of the submitted task. @@ -61,14 +51,10 @@ def result(self, identifier: str) -> t.Any: @abstractmethod def disconnect(self) -> None: - """ - Disconnect the client. - """ + """Disconnect the client.""" pass @abstractmethod def shutdown(self) -> None: - """ - Shuts down the compute cluster. - """ + """Shuts down the compute cluster.""" pass diff --git a/superduperdb/backends/base/data_backend.py b/superduperdb/backends/base/data_backend.py index ef2a1fb860..4a90f1dc45 100644 --- a/superduperdb/backends/base/data_backend.py +++ b/superduperdb/backends/base/data_backend.py @@ -6,6 +6,12 @@ class BaseDataBackend(ABC): + """Base data backend for the database. + + :param conn: The connection to the databackend database. + :param name: The name of the databackend. 
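# Illustrative sketch (hypothetical, not part of this patch): a minimal,
# in-process implementation of the `ComputeBackend` interface documented in the
# compute.py hunk above. The class name and the synchronous execution are
# assumptions made only to show how the abstract methods fit together; the
# project's real backends may differ.
import typing as t
import uuid

from superduperdb.backends.base.compute import ComputeBackend


class InProcessCompute(ComputeBackend):
    """Toy backend that runs every submitted callable immediately."""

    def __init__(self):
        self._results: t.Dict[str, t.Any] = {}

    @property
    def type(self) -> str:
        return 'local'

    @property
    def name(self) -> str:
        return 'in_process'

    def submit(self, function: t.Callable, **kwargs) -> str:
        # Run synchronously and keep the result under a generated task id.
        identifier = str(uuid.uuid4())
        self._results[identifier] = function(**kwargs)
        return identifier

    @property
    def tasks(self) -> t.Dict[str, t.Any]:
        return self._results

    def wait_all(self) -> None:
        # Nothing to wait for: submission already executed the task.
        pass

    def result(self, identifier: str) -> t.Any:
        return self._results[identifier]

    def disconnect(self) -> None:
        pass

    def shutdown(self) -> None:
        self._results.clear()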
+ """ + db_type = None def __init__(self, conn: t.Any, name: str): @@ -16,27 +22,22 @@ def __init__(self, conn: t.Any, name: str): @property def db(self): + """Return the datalayer.""" raise NotImplementedError @abstractmethod def url(self): - """ - Databackend connection url - """ + """Databackend connection url.""" pass @abstractmethod def build_metadata(self): - """ - Build a default metadata store based on current connection. - """ + """Build a default metadata store based on current connection.""" pass @abstractmethod def build_artifact_store(self): - """ - Build a default artifact store based on current connection. - """ + """Build a default artifact store based on current connection.""" pass @abstractmethod @@ -46,41 +47,57 @@ def create_output_dest( datatype: t.Union[None, DataType, FieldType], flatten: bool = False, ): + """Create an output destination for the database. + + :param predict_id: The predict id of the output destination. + :param datatype: The datatype of the output destination. + :param flatten: Whether to flatten the output destination. + """ pass @abstractmethod def check_output_dest(self, predict_id) -> bool: + """Check if the output destination exists. + + :param predict_id: The identifier of the output destination. + """ pass @abstractmethod def get_table_or_collection(self, identifier): + """Get a table or collection from the database. + + :param identifier: The identifier of the table or collection. + """ pass def set_content_bytes(self, r, key, bytes_): + """Set content bytes. + + :param r: The row. + :param key: The key. + :param bytes_: The bytes. + """ raise NotImplementedError @abstractmethod def drop(self, force: bool = False): - """ - Drop the databackend. + """Drop the databackend. + + :param force: If ``True``, don't ask for confirmation. """ @abstractmethod def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" @abstractmethod def list_tables_or_collections(self): - """ - List all tables or collections in the database. - """ + """List all tables or collections in the database.""" @staticmethod def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """ - Infer a schema from a given data object + """Infer a schema from a given data object. :param data: The data object :param identifier: The identifier for the schema, if None, it will be generated diff --git a/superduperdb/backends/base/metadata.py b/superduperdb/backends/base/metadata.py index 15e9c2b02c..e0a15da8e9 100644 --- a/superduperdb/backends/base/metadata.py +++ b/superduperdb/backends/base/metadata.py @@ -7,6 +7,8 @@ class NonExistentMetadataError(Exception): + """NonExistentMetadataError.""" + ... @@ -28,9 +30,7 @@ def __init__( @abstractmethod def url(self): - """ - Metadata store connection url - """ + """Metadata store connection url.""" pass @abstractmethod @@ -44,8 +44,8 @@ def create_component(self, info: t.Dict): @abstractmethod def create_job(self, info: t.Dict): - """ - Create a job in the metadata store. + """Create a job in the metadata store. + :param info: dictionary containing information about the job. """ pass @@ -126,8 +126,10 @@ def show_jobs( component_identifier: t.Optional[str], type_id: t.Optional[str], ): - """ - Show all jobs in the metadata store. + """Show all jobs in the metadata store. 
+ + :param component_identifier: identifier of component + :param type_id: type of component """ pass @@ -154,6 +156,11 @@ def show_component_versions(self, type_id: str, identifier: str): def get_indexing_listener_of_vector_index( self, identifier: str, version: t.Optional[int] = None ): + """Get the indexing listener of a vector index. + + :param identifier: identifier of vector index + :param version: version of vector index + """ info = self.get_component( 'vector_index', identifier=identifier, version=version ) @@ -355,19 +362,19 @@ def update_metadata(self, key, value): pass def add_query(self, query: 'Select', model: str): - """ - Add query id to query table + """Add query id to query table. + + :param query: query object + :param model: model identifier """ raise NotImplementedError def get_queries(self, model: str): - """ - Get all queries from query table corresponding - to the model. + """Get all queries from query table corresponding to the model. + + :param model: model identifier """ @abstractmethod def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" diff --git a/superduperdb/backends/base/query.py b/superduperdb/backends/base/query.py index 78fcaab837..4d49be4c53 100644 --- a/superduperdb/backends/base/query.py +++ b/superduperdb/backends/base/query.py @@ -33,28 +33,49 @@ def _check_illegal_attribute(name): raise AttributeError(f"Attempt to access illegal attribute '{name}'") +# TODO: Remove unused code @dc.dataclass(repr=False) class model(Serializable): + """Model. + + :param identifier: The identifier of the model. + """ + identifier: str def predict_one(self, *args, **kwargs): + """Predict one.""" return PredictOne(model=self.identifier, args=args, kwargs=kwargs) def predict(self, *args, **kwargs): + """Predict.""" raise NotImplementedError class Predict: + """Base class for all prediction queries.""" + ... @dc.dataclass(repr=False) class PredictOne(Predict, Serializable, ABC): + """A query to predict a single document. + + :param model: The model to use + :param args: The arguments to pass to the model + :param kwargs: The keyword arguments to pass to the model + """ + model: str args: t.Sequence = dc.field(default_factory=list) kwargs: t.Dict = dc.field(default_factory=dict) def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ m = db.models[self.model] out = m.predict_one(*self.args, **self.kwargs) if isinstance(m.datatype, DataType): @@ -68,16 +89,16 @@ def execute(self, db): @dc.dataclass(repr=False) class Select(Serializable, ABC): - """ - Base class for all select queries. - """ + """Base class for all select queries.""" @abstractproperty def id_field(self): + """Return the primary id of the table.""" pass @property def query_components(self): + """Return the query components of the query.""" return self.table_or_collection.query_components def model_update( @@ -93,8 +114,7 @@ def model_update( :param db: The DB instance to use :param ids: The ids to update - :param key: The key to update - :param model: The model to update + :param predict_id: The predict_id of the outputs :param outputs: The outputs to update """ return self.table_or_collection.model_update( @@ -107,32 +127,51 @@ def model_update( @abstractproperty def select_table(self): + """Return a select query for the table.""" pass @abstractmethod def add_fold(self, fold: str) -> 'Select': + """Add a fold to the query. 
+ + :param fold: The fold to add + """ pass @abstractmethod def select_using_ids(self, ids: t.Sequence[str]) -> 'Select': - pass + """Return a query that selects only the given ids. + + :param ids: The ids to select + """ @abstractproperty def select_ids(self) -> 'Select': + """Return a query that selects only the ids.""" pass @abstractmethod def select_ids_of_missing_outputs(self, predict_id: str) -> 'Select': + """Return a query that selects ids where outputs are missing. + + :param predict_id: The predict_id of the outputs + """ pass @abstractmethod def select_single_id(self, id: str) -> 'Select': + """Return a query that selects a single id. + + :param id: The id to select + """ pass @abstractmethod def execute(self, db, reference: bool = True): - """ - Execute the query on the DB instance. + """Execute the query on the DB instance. + + :param db: The datalayer instance + :param reference: Whether to return a reference to the data """ pass @@ -140,7 +179,7 @@ def execute(self, db, reference: bool = True): @dc.dataclass(repr=False) class Delete(Serializable, ABC): """ - Base class for all deletion queries + Base class for all deletion queries. :param table_or_collection: The table or collection that this query is linked to """ @@ -151,13 +190,17 @@ class Delete(Serializable, ABC): @abstractmethod def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ pass @dc.dataclass(repr=False) class Update(Serializable, ABC): """ - Base class for all update queries + Base class for all update queries. :param table_or_collection: The table or collection that this query is linked to """ @@ -166,17 +209,22 @@ class Update(Serializable, ABC): @abstractmethod def select_table(self): + """Return a select query for the table.""" pass @abstractmethod def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ pass @dc.dataclass(repr=False) class Write(Serializable, ABC): """ - Base class for all bulk write queries + Base class for all bulk write queries. :param table_or_collection: The table or collection that this query is linked to """ @@ -185,10 +233,15 @@ class Write(Serializable, ABC): @abstractmethod def select_table(self): + """Return a select query for the table.""" pass @abstractmethod def execute(self, db): + """Execute the query on the DB instance. + + :param db: The datalaer instance + """ pass @@ -205,7 +258,6 @@ class CompoundSelect(_ReprMixin, Select, ABC): (e.g. ``table.filter(...)....like(...)``) :param query_linker: The query linker that is responsible for linking the query chain. E.g. ``table.filter(...).select(...)``. - :param i: The index of the query in the query chain """ table_or_collection: 'TableOrCollection' @@ -215,17 +267,24 @@ class CompoundSelect(_ReprMixin, Select, ABC): @abstractproperty def output_fields(self): + """Return the output fields of the query.""" pass @property def id_field(self): + """Return the primary id of the table.""" return self.primary_id @property def primary_id(self): + """Return the primary id of the table.""" return self.table_or_collection.primary_id def add_fold(self, fold: str): + """Add a fold to the query. + + :param fold: The fold to add + """ assert self.pre_like is None assert self.post_like is None assert self.query_linker is not None @@ -236,10 +295,7 @@ def add_fold(self, fold: str): @property def select_ids(self): - """ - Query which selects the same documents/ rows but only ids. 
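# Illustrative sketch (hypothetical, not part of this patch): how the `Select`
# contract documented above is typically consumed when back-filling model
# outputs. `db` (a Datalayer), `select` (a concrete Select) and `model` are
# assumed to exist; the predict id and the use of `Document.unpack()` on the
# returned rows are example assumptions.
def backfill_outputs(db, select, model, predict_id: str = 'my-predict-id'):
    """Compute and store outputs of `model` for rows that do not have them yet."""
    missing = select.select_ids_of_missing_outputs(predict_id).execute(db)
    ids = [r[select.id_field] for r in missing]
    docs = select.select_using_ids(ids).execute(db)
    outputs = [model.predict_one(doc.unpack()) for doc in docs]
    select.model_update(db, ids=ids, predict_id=predict_id, outputs=outputs)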
- """ - + """Query which selects the same documents/ rows but only ids.""" assert self.pre_like is None assert self.post_like is None @@ -249,10 +305,10 @@ def select_ids(self): ) def select_ids_of_missing_outputs(self, predict_id: str): - """ - Query which selects ids where outputs are missing. - """ + """Query which selects ids where outputs are missing. + :param predict_id: The predict_id of the outputs + """ assert self.pre_like is None assert self.post_like is None assert self.query_linker is not None @@ -285,7 +341,6 @@ def select_using_ids(self, ids): :param ids: The ids to subset to. """ - assert self.pre_like is None assert self.post_like is None @@ -295,10 +350,7 @@ def select_using_ids(self, ids): ) def repr_(self): - """ - String representation of the query. - """ - + """String representation of the query.""" components = [] components.append(str(self.table_or_collection.identifier)) if self.pre_like: @@ -357,6 +409,7 @@ def __getattr__(self, name): ) def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Add a query to the query chain.""" assert self.post_like is None assert self.query_linker is not None return self._query_from_parts( @@ -371,9 +424,16 @@ def execute(self, db, load_hybrid: bool = True): Execute the compound query on the DB instance. :param db: The DB instance to use + :param load_hybrid: Whether to load hybrid fields """ def like(self, r: Document, vector_index: str, n: int = 10): + """Return a query that performs a vector search. + + :param r: The document to search for + :param vector_index: The vector index to use + :param n: The number of results to return + """ assert self.query_linker is not None assert self.pre_like is None return self._query_from_parts( @@ -391,10 +451,8 @@ class Insert(_ReprMixin, Serializable, ABC): :param table_or_collection: The table or collection that this query is linked to :param documents: The documents to insert - :param refresh: Whether to refresh the task-graph after inserting :param verbose: Whether to print the progress of the insert :param kwargs: Any additional keyword arguments to pass to the insert method - :param encoders: The encoders to use to encode the documents """ table_or_collection: 'TableOrCollection' @@ -403,6 +461,7 @@ class Insert(_ReprMixin, Serializable, ABC): kwargs: t.Dict = dc.field(default_factory=dict) def repr_(self): + """String representation of the query.""" documents_str = ( str(self.documents)[:25] + '...' if len(self.documents) > 25 @@ -412,6 +471,7 @@ def repr_(self): @abstractmethod def select_table(self): + """Return a select query for the inserted documents.""" pass @abstractmethod @@ -424,15 +484,17 @@ def execute(self, parent: t.Any): pass def to_select(self, ids=None): + """Return a select query for the inserted documents. + + :param ids: The ids to select + """ if ids is None: ids = [r['_id'] for r in self.documents] return self.table.find({'_id': ids}) class QueryType(str, enum.Enum): - """ - The type of a query. Either `query` or `attr`. - """ + """The type of a query. Either `query` or `attr`.""" QUERY = 'query' ATTR = 'attr' @@ -487,8 +549,8 @@ def _deep_flat_encode_impl(self, cache): @dc.dataclass(repr=False) class QueryComponent(Serializable): - """ - This is a representation of a single query object in ibis query chain. + """QueryComponent is a representation of a query object in ibis query chain. + This is used to build a query chain that can be executed on a database. Query will be executed in the order they are added to the chain. 
@@ -502,6 +564,7 @@ class QueryComponent(Serializable): :param type: The type of the query, either `query` or `attr` :param args: The arguments to pass to the query :param kwargs: The keyword arguments to pass to the query + :param _deep_flat_encode: The method to encode the query """ name: str @@ -512,6 +575,7 @@ class QueryComponent(Serializable): _deep_flat_encode = _deep_flat_encode_impl def repr_(self) -> str: + """String representation of the query.""" if self.type == QueryType.ATTR: return self.name @@ -533,6 +597,7 @@ def to_str(x): return f'{self.name}({joined})' def __call__(self, *args: Any, **kwargs: Any) -> Any: + """Add a query to the query chain.""" try: assert ( self.type == QueryType.ATTR @@ -548,6 +613,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any: ) def execute(self, parent: t.Any): + """Execute the query on the parent object. + + :param parent: The parent object to execute the query on. + """ if self.type == QueryType.ATTR: return getattr(parent, self.name) assert self.type == QueryType.QUERY @@ -557,7 +626,8 @@ def execute(self, parent: t.Any): @dc.dataclass(repr=False) class QueryLinker(_ReprMixin, Serializable, ABC): - """ + """QueryLinker is a representation of a query chain. + This class is responsible for linking together a query using `getattr` and `__call__`. @@ -584,9 +654,11 @@ class QueryLinker(_ReprMixin, Serializable, ABC): @property def query_components(self): + """Return the query components of the query chain.""" return self.table_or_collection.query_components def repr_(self) -> str: + """String representation of the query.""" return ( f'{self.table_or_collection.identifier}' + '.' @@ -615,22 +687,39 @@ def __getattr__(self, k): @property @abstractmethod def select_ids(self): + """Return a query that selects only the ids. + + This is used to select only the ids of the documents. + """ pass @abstractmethod def select_single_id(self, id): + """Return a query that selects a single id. + + :param id: The id to select + """ pass @abstractmethod def select_using_ids(self, ids): + """Return a query that selects only the given ids. + + :param ids: The ids to select + """ pass def __call__(self, *args, **kwargs): + """Add a query to the query chain.""" members = [*self.members[:-1], self.members[-1](*args, **kwargs)] return type(self)(table_or_collection=self.table_or_collection, members=members) @abstractmethod def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ pass @@ -642,6 +731,7 @@ class Like(Serializable): :param r: The item to be converted to a vector, to search with. :param vector_index: The vector index to use :param n: The number of results to return + :param _deep_flat_encode: The method to encode the query """ r: t.Union[t.Dict, Document] @@ -651,6 +741,11 @@ class Like(Serializable): _deep_flat_encode = _deep_flat_encode_impl def execute(self, db, ids: t.Optional[t.Sequence[str]] = None): + """Execute the query. + + :param db: The datalayer instance + :param ids: The ids to search for + """ return db.select_nearest( like=self.r, vector_index=self.vector_index, @@ -661,10 +756,11 @@ def execute(self, db, ids: t.Optional[t.Sequence[str]] = None): @dc.dataclass class TableOrCollection(Serializable, ABC): - """ - This is a representation of an SQL table in ibis. + """A base class for all tables and collections. - :param identifier: The name of the table + Defines the interface for all tables and collections. + + :param identifier: The identifier of the table or collection. 
""" query_components: t.ClassVar[t.Dict] = {} @@ -690,10 +786,22 @@ def model_update( flatten: bool = False, **kwargs, ): + """Update model outputs for a set of ids. + + :param db: The datalayer instance + :param ids: The ids to update + :param predict_id: The predict_id of outputs + :param outputs: The outputs to update + :param flatten: Whether to flatten the output + """ pass @abstractmethod def insert(self, documents: t.Sequence[Document], **kwargs) -> Insert: + """Return Insert query. + + :param documents: The documents to insert + """ pass @abstractmethod @@ -721,7 +829,8 @@ def like( vector_index: str, n: int = 10, ): - """ + """Return a query that performs a vector search. + This method appends a query to the query chain where the query is repsonsible for performing a vector search on the parent query chain inputs. @@ -752,10 +861,16 @@ def _insert( @dc.dataclass class RawQuery: + """A raw query object. + + :param query: The raw query to execute. + """ + query: t.Any @abstractmethod def execute(self, db): - ''' - A raw query method which executes the query and returns the result - ''' + """A raw query method which executes the query and returns the result. + + :param db: The datalayer instance + """ diff --git a/superduperdb/backends/ibis/cdc/base.py b/superduperdb/backends/ibis/cdc/base.py index deb38f15c9..7fd9c34e5f 100644 --- a/superduperdb/backends/ibis/cdc/base.py +++ b/superduperdb/backends/ibis/cdc/base.py @@ -9,8 +9,11 @@ @dc.dataclass class IbisDBPacket(Packet): - """ - A base packet to represent message in task queue. + """A base packet to represent message in task queue. + + :param ids: The ids of the rows. + :param query: The query to be executed. + :param event_type: The event type. """ ids: t.List[str] diff --git a/superduperdb/backends/ibis/cdc/listener.py b/superduperdb/backends/ibis/cdc/listener.py index a37aefce2a..25639a212b 100644 --- a/superduperdb/backends/ibis/cdc/listener.py +++ b/superduperdb/backends/ibis/cdc/listener.py @@ -15,6 +15,16 @@ class PollingStrategyIbis: + """PollingStrategyIbis. + + This is a base class for polling strategies for ibis backend. + + :param db: The datalayer instance. + :param table: The table on which the polling strategy is applied. + :param strategy: The strategy to use for polling. + :param primary_id: The primary id of the table. + """ + def __init__( self, db: 'Datalayer', table: 'Table', strategy, primary_id: str = 'id' ): @@ -28,12 +38,15 @@ def __init__( self._last_processed_id = -1 def fetch_ids(self): + """fetch_ids.""" raise NotImplementedError def post_handling(self): + """post_handling.""" time.sleep(self.frequency) def get_strategy(self): + """get_strategy.""" if self.increment_field: return PollingStrategyIbisByIncrement( self.db, self.table, self.strategy, primary_id=self.primary_id @@ -45,9 +58,16 @@ def get_strategy(self): class PollingStrategyIbisByIncrement(PollingStrategyIbis): + """PollingStrategyIbisByIncrement. + + This is a polling strategy for ibis backend which polls the table + based on the increment field. + """ + def fetch_ids( self, ): + """fetch_ids.""" assert self.increment_field _filter = self.table.__getattr__(self.increment_field) > self._last_processed_id query = self.table.select(self.primary_id).filter(_filter) @@ -58,19 +78,31 @@ def fetch_ids( class PollingStrategyIbisByID(PollingStrategyIbis): + """PollingStrategyIbisByID. + + This is a polling strategy for ibis backend which polls the table + based on the primary id. + """ + ... 
class IbisDatabaseListener(cdc.BaseDatabaseListener): """ - It is a class which helps capture data from ibis database and handle it - accordingly. + It is a class which helps capture data from ibis database and handle it accordingly. This class accepts options and db instance from user and starts a scheduler which could schedule a listening service to listen change stream. This class builds a workflow graph on each change observed. + :param db: It is a datalayer instance. + :param on: It is used to define a Collection on which CDC would be performed. + :param stop_event: A threading event flag to notify for stoppage. + :param identifier: A identifier to represent the listener service. + :param timeout: A timeout to stop the listener service. + :param strategy: Used to select strategy used for listening changes, + Options: [PollingStrategy, LogBasedStrategy] """ DEFAULT_ID: str = 'id' @@ -100,7 +132,6 @@ def __init__( superduperdb.cdc.cdc.PollingStrategy) LogBasedStrategy (Not implemented yet) """ - if not strategy: assert CFG.cluster.cdc self.strategy = CFG.cluster.cdc.strategy @@ -117,20 +148,33 @@ def __init__( ) def on_update(self, ids: t.Sequence, db: 'Datalayer', table: query.Table) -> None: + """on_update. + + :param ids: Changed row ids. + :param db: a datalayer instance. + :param table: The table on which change was observed. + """ raise NotImplementedError def on_delete(self, ids: t.Sequence, db: 'Datalayer', table: query.Table) -> None: + """on_delete. + + :param ids: Changed row ids. + :param db: a datalayer instance. + :param table: The table on which change was observed. + """ raise NotImplementedError def on_create(self, ids: t.Sequence, db: 'Datalayer', table: query.Table) -> None: """on_create. + A helper on create event handler which handles inserted document in the change stream. It basically extracts the change document and build the taskflow graph to execute. :param ids: Changed row ids. - :param db: a superduperdb instance. + :param db: a datalayer instance. :param table: The table on which change was observed. """ logging.debug('Triggered `on_create` handler.') @@ -139,9 +183,7 @@ def on_create(self, ids: t.Sequence, db: 'Datalayer', table: query.Table) -> Non ) def setup_cdc(self): - """ - Setup cdc change stream from user provided - """ + """Setup cdc change stream from user provided.""" if isinstance(self.strategy, PollingStrategy): self.stream = PollingStrategyIbis( self.db, @@ -156,8 +198,9 @@ def setup_cdc(self): return self.stream def next_cdc(self, stream) -> None: - """ - Get the next stream of change observed on the given `Collection`. + """Get the next stream of change observed on the given `Collection`. + + :param stream: The stream to get the next change. """ ids = stream.fetch_ids() if ids: @@ -168,6 +211,10 @@ def next_cdc(self, stream) -> None: def listen( self, ) -> None: + """Start listening cdc changes. + + This starts the corresponding scheduler as well. + """ try: self._stop_event.clear() if self._scheduler: @@ -191,8 +238,8 @@ def listen( raise def stop(self) -> None: - """ - Stop listening cdc changes. + """Stop listening cdc changes. + This stops the corresponding services as well. 
""" self._stop_event.set() @@ -200,4 +247,5 @@ def stop(self) -> None: self._scheduler.join() def running(self) -> bool: + """Check if the listener is running.""" return not self._stop_event.is_set() diff --git a/superduperdb/backends/ibis/cursor.py b/superduperdb/backends/ibis/cursor.py index ff987155d2..28c405feb0 100644 --- a/superduperdb/backends/ibis/cursor.py +++ b/superduperdb/backends/ibis/cursor.py @@ -8,12 +8,14 @@ @dc.dataclass class SuperDuperIbisResult(SuperDuperCursor): - ''' + """SuperDuperIbisResult class for ibis query results. + SuperDuperIbisResult represents ibis query results with options - to unroll results as i.e pandas - ''' + to unroll results as i.e pandas. + """ def as_pandas(self): + """Unroll the result as a pandas DataFrame.""" return pandas.DataFrame([Document(r).unpack() for r in self.raw_cursor]) def __getitem__(self, item): diff --git a/superduperdb/backends/ibis/data_backend.py b/superduperdb/backends/ibis/data_backend.py index 28806ae371..f234592293 100644 --- a/superduperdb/backends/ibis/data_backend.py +++ b/superduperdb/backends/ibis/data_backend.py @@ -23,6 +23,13 @@ class IbisDataBackend(BaseDataBackend): + """Ibis data backend for the database. + + :param conn: The connection to the database. + :param name: The name of the database. + :param in_memory: Whether to store the data in memory. + """ + db_type = DBType.SQL def __init__(self, conn: BaseBackend, name: str, in_memory: bool = False): @@ -32,18 +39,31 @@ def __init__(self, conn: BaseBackend, name: str, in_memory: bool = False): self.db_helper = get_db_helper(self.dialect) def url(self): + """Get the URL of the database.""" return self.conn.con.url + self.name def build_artifact_store(self): + """Build artifact store for the database.""" return FileSystemArtifactStore(conn='.superduperdb/artifacts/', name='ibis') def build_metadata(self): + """Build metadata for the database.""" return SQLAlchemyMetadata(conn=self.conn.con, name='ibis') def create_ibis_table(self, identifier: str, schema: Schema): + """Create a table in the database. + + :param identifier: The identifier of the table. + :param schema: The schema of the table. + """ self.conn.create_table(identifier, schema=schema) def insert(self, table_name, raw_documents): + """Insert data into the database. + + :param table_name: The name of the table. + :param raw_documents: The data to insert. + """ for doc in raw_documents: for k, v in doc.items(): doc[k] = self.db_helper.convert_data_format(v) @@ -71,6 +91,12 @@ def create_output_dest( datatype: t.Union[FieldType, DataType], flatten: bool = False, ): + """Create a table for the output of the model. + + :param predict_id: The identifier of the prediction. + :param datatype: The data type of the output. + :param flatten: Whether to flatten the output. + """ msg = ( "Model must have an encoder to create with the" f" {type(self).__name__} backend." @@ -103,6 +129,10 @@ def create_output_dest( ) def check_output_dest(self, predict_id) -> bool: + """Check if the output destination exists. + + :param predict_id: The identifier of the prediction. + """ try: self.conn.table(f'_outputs.{predict_id}') return True @@ -110,10 +140,11 @@ def check_output_dest(self, predict_id) -> bool: return False def create_table_and_schema(self, identifier: str, mapping: dict): - """ - Create a schema in the data-backend. - """ + """Create a schema in the data-backend. + :param identifier: The identifier of the table. + :param mapping: The mapping of the schema. 
+ """ try: mapping = self.db_helper.process_schema_types(mapping) t = self.conn.create_table(identifier, schema=ibis.schema(mapping)) @@ -126,27 +157,33 @@ def create_table_and_schema(self, identifier: str, mapping: dict): return t def drop(self, force: bool = False): + """Drop tables or collections in the database. + + :param force: Whether to force the drop. + """ raise NotImplementedError( "Dropping tables needs to be done in each DB natively" ) def get_table_or_collection(self, identifier): + """Get a table or collection from the database. + + :param identifier: The identifier of the table or collection. + """ return self.conn.table(identifier) def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # TODO: implement me def list_tables_or_collections(self): + """List all tables or collections in the database.""" return self.conn.list_tables() @staticmethod def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """ - Infer a schema from a given data object + """Infer a schema from a given data object. :param data: The data object :param identifier: The identifier for the schema, if None, it will be generated diff --git a/superduperdb/backends/ibis/db_helper.py b/superduperdb/backends/ibis/db_helper.py index da2a28b3c8..6609f4ddf0 100644 --- a/superduperdb/backends/ibis/db_helper.py +++ b/superduperdb/backends/ibis/db_helper.py @@ -6,22 +6,38 @@ class Base64Mixin: + """ + Mixin class for converting byte data to base64 format for storage in the database. + + This class is used to convert byte data to base64 format for storage in the + database. + """ + def convert_data_format(self, data): - """Convert byte data to base64 format for storage in the database.""" + """Convert byte data to base64 format for storage in the database. + + :param data: The data to convert. + """ if isinstance(data, bytes): return BASE64_PREFIX + base64.b64encode(data).decode('utf-8') else: return data def recover_data_format(self, data): - """Recover byte data from base64 format stored in the database.""" + """Recover byte data from base64 format stored in the database. + + :param data: The data to recover. + """ if isinstance(data, str) and data.startswith(BASE64_PREFIX): return base64.b64decode(data[len(BASE64_PREFIX) :]) else: return data def process_schema_types(self, schema_mapping): - """Convert bytes to string in the schema.""" + """Convert bytes to string in the schema. + + :param schema_mapping: The schema mapping to convert. + """ for key, value in schema_mapping.items(): if value == 'Bytes': schema_mapping[key] = 'String' @@ -29,33 +45,69 @@ def process_schema_types(self, schema_mapping): class DBHelper: + """Generic helper class for database. + + :param dialect: The dialect of the database. + """ + match_dialect = 'base' def __init__(self, dialect): self.dialect = dialect def process_before_insert(self, table_name, datas): + """Convert byte data to base64 format for storage in the database. + + :param table_name: The name of the table. + :param datas: The data to insert. + """ return table_name, pd.DataFrame(datas) def process_schema_types(self, schema_mapping): + """Convert bytes to string in the schema. + + :param schema_mapping: The schema mapping to convert. + """ return schema_mapping def convert_data_format(self, data): + """Convert data to the format for storage in the database. + + :param data: The data to convert. + """ return data def recover_data_format(self, data): + """Recover data from the format stored in the database. 
+ + :param data: The data to recover. + """ return data class ClickHouseHelper(Base64Mixin, DBHelper): + """Helper class for ClickHouse database. + + This class is used to convert byte data to base64 format for storage in the + database. + """ + match_dialect = 'clickhouse' def process_before_insert(self, table_name, datas): + """Convert byte data to base64 format for storage in the database. + + :param table_name: The name of the table. + :param datas: The data to insert. + """ return f'`{table_name}`', pd.DataFrame(datas) def get_db_helper(dialect) -> DBHelper: - """Get the insert processor for the given dialect.""" + """Get the insert processor for the given dialect. + + :param dialect: The dialect of the database. + """ for helper in DBHelper.__subclasses__(): if helper.match_dialect == dialect: return helper(dialect) diff --git a/superduperdb/backends/ibis/field_types.py b/superduperdb/backends/ibis/field_types.py index 0884c6221e..0648f02688 100644 --- a/superduperdb/backends/ibis/field_types.py +++ b/superduperdb/backends/ibis/field_types.py @@ -8,6 +8,14 @@ @dc.dataclass class FieldType(Serializable): + """Field type to represent the type of a field in a table. + + This is a wrapper around ibis.expr.datatypes.DataType to make it + serializable. + + :param identifier: The name of the data type. + """ + identifier: t.Union[str, DataType] def __post_init__(self): @@ -16,8 +24,9 @@ def __post_init__(self): def dtype(x): - ''' - Ibis dtype to represent basic data types in ibis - e.g int, str, etc - ''' + """Ibis dtype to represent basic data types in ibis. + + :param x: The data type + e.g int, str, etc. + """ return FieldType(_dtype(x)) diff --git a/superduperdb/backends/ibis/query.py b/superduperdb/backends/ibis/query.py index db085e461d..ea4af0e25c 100644 --- a/superduperdb/backends/ibis/query.py +++ b/superduperdb/backends/ibis/query.py @@ -102,7 +102,8 @@ def _model_update_impl( class IbisBackendError(DatabackendException): - """ + """Ibis backend error. + This error represents ibis query related errors i.e when there is an error while executing an ibis query, use this exception to represent the error. @@ -111,9 +112,7 @@ class IbisBackendError(DatabackendException): @dc.dataclass(repr=False) class IbisCompoundSelect(CompoundSelect): - """ - A query incorporating vector-search and a standard ``ibis`` query - """ + """A query incorporating vector-search and a standard ``ibis`` query.""" DB_TYPE: t.ClassVar[str] = 'SQL' @@ -121,6 +120,7 @@ class IbisCompoundSelect(CompoundSelect): @property def primary_id(self): + """Return the primary id of the query.""" if self.query_linker is None: return self.table_or_collection.primary_id return self.query_linker.primary_id @@ -167,6 +167,7 @@ def _get_query_linker( @property def output_fields(self): + """Return the output fields.""" return self.query_linker.output_fields def _get_query_component( @@ -183,13 +184,9 @@ def _get_query_component( return IbisQueryComponent(name, type=type, args=args, kwargs=kwargs) def outputs(self, *predict_ids): - """ - This method returns a query which joins a query with the outputs - for a table. + """Returns a query which joins a query with the outputs for a table. 
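# Illustrative sketch (hypothetical, not part of this patch): round-tripping
# binary data through the ClickHouse helper documented in the db_helper.py hunk
# above. Byte values are stored as base64-prefixed strings and recovered on read.
from superduperdb.backends.ibis.db_helper import get_db_helper

helper = get_db_helper('clickhouse')            # resolves to ClickHouseHelper
stored = helper.convert_data_format(b'\x00\x01raw bytes')
restored = helper.recover_data_format(stored)
assert restored == b'\x00\x01raw bytes'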
- :param key: The key on which the model was evaluated - :param model: The model identifier for which to get the outputs - :param version: The version of the model for which to get the outputs (optional) + :param *predict_ids: The predict ids of the outputs >>> q = t.filter(t.age > 25).outputs('txt', 'model_name') """ @@ -224,6 +221,7 @@ def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): return self.query_linker.compile(db, tables=tables) def get_all_tables(self): + """Get all tables in the query.""" tables = [self.table_or_collection.identifier] if self.query_linker is not None: tables.extend(self.query_linker.get_all_tables()) @@ -253,6 +251,7 @@ def _get_all_fields(self, db): @property def select_table(self): + """Return the select table.""" return self.table_or_collection def _execute_with_pre_like(self, db): @@ -298,11 +297,17 @@ def _execute(self, db): @property def renamings(self): + """Return the renamings.""" if self.query_linker is not None: return self.query_linker.renamings return {} def execute(self, db, reference: bool = False): + """Execute the query. + + :param db: The Datalayer instance + :param reference: Whether to return a reference to the query + """ # TODO handle load_hybrid for `ibis` output, scores = self._execute(db) fields = self._get_all_fields(db) @@ -329,10 +334,10 @@ def execute(self, db, reference: bool = False): ) def select_ids_of_missing_outputs(self, predict_id: str): - """ - Query which selects ids where outputs are missing. - """ + """Query which selects ids where outputs are missing. + :param predict_id: The identifier of the model + """ assert self.pre_like is None assert self.post_like is None assert self.query_linker is not None @@ -354,6 +359,15 @@ def model_update( # type: ignore[override] flatten: bool = False, document_embedded: t.Optional[bool] = None, ): + """Update the model outputs in the output table. + + :param db: The Datalayer instance + :param ids: The ids of the outputs + :param predict_id: The identifier of the model + :param outputs: The outputs of the model + :param flatten: Whether to flatten the outputs + :param document_embedded: Whether the outputs are document embedded + """ if document_embedded is True: logging.warn( "Ibis backend does not support document embedded parameter.", @@ -365,6 +379,10 @@ def model_update( # type: ignore[override] ) def add_fold(self, fold: str) -> Select: + """Add a fold to the query. + + :param fold: The fold to add + """ if self.query_linker is not None: # make sure we have a fold column in the query query_members = [ @@ -380,11 +398,12 @@ def add_fold(self, fold: str) -> Select: class _LogicalExprMixin: - ''' + """_LogicalExpr. + Mixin class which holds '__eq__', '__or__', '__gt__', etc arithmetic operators These methods are overloaded for ibis logical expression dynamic wrapping with superduperdb. - ''' + """ def _logical_expr(self, members, collection, k, other: t.Optional[t.Any] = None): if other is not None: @@ -427,6 +446,13 @@ def getitem(self, other, members, collection): @dc.dataclass(repr=False) class IbisQueryLinker(QueryLinker, _LogicalExprMixin): + """A query linker for ibis queries. + + This class is used to link multiple queries together in a chain. 
+ + :param primary_id: The primary id of the table + """ + primary_id: t.Union[str, t.List[str], None] = None def __post_init__(self): @@ -436,12 +462,14 @@ def __post_init__(self): @property def renamings(self): + """Return the renamings.""" out = {} for m in self.members: out.update(m.renamings) return out def repr_(self) -> str: + """Return the representation of the query.""" out = super().repr_() out = re.sub('\. ', ' ', out) out = re.sub('\.\[', '[', out) @@ -449,10 +477,15 @@ def repr_(self) -> str: @property def output_fields(self): + """Return the output fields.""" return self._output_fields @output_fields.setter def output_fields(self, value): + """Set the output fields. + + :param value: The output fields + """ self._output_fields = value def __eq__(self, other): @@ -496,15 +529,24 @@ def _get_query_component(self, k): @property def select_ids(self): + """Return a query which selects ids.""" return self.select(self.table_or_collection.primary_id) def select_single_id(self, id): + """Return a query which selects a single id. + + :param id: The id to select + """ return self.filter( self.table_or_collection.__getattr__(self.table_or_collection.primary_id) == id ) def select_using_ids(self, ids): + """Return a query which selects using the given ids. + + :param ids: The ids to select + """ return self.filter( self.__getattr__(self.table_or_collection.primary_id).isin(ids) ) @@ -522,6 +564,7 @@ def _select_ids_of_missing_outputs(self, predict_id: str): return out def get_all_tables(self): + """Get all tables in the query.""" out = [] for member in self.members: out.extend(member.get_all_tables()) @@ -541,6 +584,7 @@ def _outputs(self, *identifiers): return other_query def __call__(self, *args, **kwargs): + """Execute the query.""" primary_id = ( [self.primary_id] if isinstance(self.primary_id, str) @@ -583,6 +627,11 @@ def my_filter(item): ) def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): + """Compile the query. + + :param db: The Datalayer instance + :param tables: The tables to use for the query + """ table_id = self.table_or_collection.identifier if tables is None: tables = {} @@ -594,6 +643,10 @@ def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): return result, tables def execute(self, db): + """Execute the query. + + :param db: The Datalayer instance + """ native_query, _ = self.compile(db) try: result = native_query.execute() @@ -609,11 +662,12 @@ def execute(self, db): class QueryType(str, enum.Enum): - ''' + """Query type enum. + This class holds type of query query: This means Query and can be called attr: This means Attribute and cannot be called - ''' + """ QUERY = 'query' ATTR = 'attr' @@ -621,7 +675,8 @@ class QueryType(str, enum.Enum): @dc.dataclass(repr=False, kw_only=True) class Table(Component): - """ + """Table component. + This is a representation of an SQL table in ibis, saving the important meta-data associated with the table in the ``superduperdb`` meta-data store. @@ -646,6 +701,10 @@ def __post_init__(self, artifacts): assert self.primary_id != '_input_id', '"_input_id" is a reserved value' def pre_create(self, db: 'Datalayer'): + """Pre-create the table. + + :param db: The Datalayer instance + """ assert self.schema is not None, "Schema must be set" # TODO why? 
This is done already for e in self.schema.encoders: @@ -670,24 +729,44 @@ def pre_create(self, db: 'Datalayer'): @property def table_or_collection(self): + """Return the table or collection.""" return IbisQueryTable(self.identifier, primary_id=self.primary_id) def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): + """Compile the query. + + :param db: The Datalayer instance + :param tables: The tables to use for the query + """ return IbisQueryTable(self.identifier, primary_id=self.primary_id).compile( db, tables=tables ) def insert(self, documents, **kwargs): + """Return a query which inserts documents into the table. + + :param documents: The documents to insert + """ return IbisQueryTable( identifier=self.identifier, primary_id=self.primary_id ).insert(documents, **kwargs) def like(self, r: 'Document', vector_index: str, n: int = 10): + """Return a query which finds similar documents to the given document. + + :param r: The document to find similar documents to + :param vector_index: The vector index to use for the search + :param n: The number of similar documents to find + """ return IbisQueryTable( identifier=self.identifier, primary_id=self.primary_id ).like(r=r, vector_index=vector_index, n=n) def outputs(self, *predict_ids): + """Returns a query which joins a query with the model outputs. + + :param *predict_ids: The predict ids of the outputs + """ return IbisQueryTable( identifier=self.identifier, primary_id=self.primary_id ).outputs(*predict_ids) @@ -703,6 +782,7 @@ def __getitem__(self, item): ).__getitem__(item) def to_query(self): + """Return the query representation of the table.""" return IbisCompoundSelect( table_or_collection=IbisQueryTable( self.identifier, primary_id=self.primary_id @@ -717,9 +797,7 @@ def to_query(self): @dc.dataclass(repr=False) class IbisQueryTable(_ReprMixin, TableOrCollection, Select): - """ - This is a symbolic representation of a table - for building ``IbisCompoundSelect`` queries. + """A symbolic representation of a table for building ``IbisCompoundSelect`` queries. :param primary_id: The primary id of the table """ @@ -727,6 +805,11 @@ class IbisQueryTable(_ReprMixin, TableOrCollection, Select): primary_id: str = 'id' def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): + """Compile the query. + + :param db: The Datalayer instance + :param tables: The tables to use for the query + """ if tables is None: tables = {} if self.identifier not in tables: @@ -734,16 +817,20 @@ def compile(self, db: 'Datalayer', tables: t.Optional[t.Dict] = None): return tables[self.identifier], tables def repr_(self): + """Return the representation of the table.""" return self.identifier def add_fold(self, fold: str) -> Select: + """Add a fold to the query. + + :param fold: The fold to add + """ return self.filter(self.fold == fold) def outputs(self, *predict_ids): - """ - This method returns a query which joins a query with the model outputs. + """Returns a query which joins a query with the model outputs. 
- :param model: The model identifier for which to get the outputs + :param *predict_ids: The predict ids of the outputs >>> q = t.filter(t.age > 25).outputs('model_name', db) @@ -756,20 +843,31 @@ def outputs(self, *predict_ids): @property def id_field(self): + """Return the primary id of the table.""" return self.primary_id @property def select_table(self) -> Select: + """Return the select table.""" return self @property def select_ids(self) -> Select: + """Select the ids of the table.""" return self.select(self.primary_id) def select_using_ids(self, ids: t.Sequence[t.Any]) -> Select: + """Select using ids. + + :param ids: The ids to select + """ return self.filter(self[self.primary_id].isin(ids)) def select_ids_of_missing_outputs(self, predict_id: str) -> Select: + """Select ids where outputs are missing. + + :param predict_id: The predict id of the outputs + """ output_table = IbisQueryTable( identifier=f'_outputs.{predict_id}', primary_id='output_id', @@ -779,6 +877,10 @@ def select_ids_of_missing_outputs(self, predict_id: str) -> Select: ) def select_single_id(self, id): + """Select a single id from the table. + + :param id: The id to select + """ return self.filter(getattr(self, self.primary_id) == id) def __getitem__(self, item): @@ -821,12 +923,17 @@ def insert( *args, **kwargs, ): + """Insert data into the table.""" return self._insert(*args, **kwargs) def _delete(self, *args, **kwargs): return super()._delete(*args, **kwargs) def execute(self, db): + """Execute the query. + + :param db: The Datalayer instance + """ return db.databackend.conn.table(self.identifier).execute() def model_update( @@ -838,6 +945,14 @@ def model_update( flatten: bool = False, **kwargs, ): + """Update the model outputs in the output table. + + :param db: The Datalayer isinstance + :param ids: The ids of the input data + :param predict_id: The identifier of the model + :param outputs: The outputs of the model + :param flatten: Whether to flatten the outputs + """ return _model_update_impl( db, ids=ids, predict_id=predict_id, outputs=outputs, flatten=flatten ) @@ -877,7 +992,8 @@ def _get_all_tables(item): @dc.dataclass class IbisQueryComponent(QueryComponent): - """ + """Ibis query component. + This class represents a component of an ``ibis`` query. For example ``filter`` in ``t.filter(t.age > 25)``. """ @@ -886,6 +1002,7 @@ class IbisQueryComponent(QueryComponent): @property def primary_id(self): + """Return the primary id of the query component.""" assert self.type == QueryType.QUERY, 'can\'t get primary id of an attribute' primary_id = [] for a in self.args: @@ -906,6 +1023,7 @@ def primary_id(self): @property def renamings(self): + """Return the renamings of the query component.""" if self.name == 'rename': return self.args[0] elif self.name == 'relabel': @@ -927,7 +1045,8 @@ def renamings(self): return out def repr_(self) -> str: - """ + """Return the string representation of the query component. + >>> IbisQueryComponent('__eq__(2)', type=QueryType.QUERY, args=[1, 2]).repr_() """ out = super().repr_() @@ -946,6 +1065,12 @@ def repr_(self) -> str: def compile( self, parent: t.Any, db: 'Datalayer', tables: t.Optional[t.Dict] = None ): + """Compile the query component. 
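Editorial aside: the query helpers documented in this hunk (``filter``, ``outputs``, ``select_using_ids``, ``execute``) compose as in the minimal sketch below. The Datalayer, the ``users`` table, its ``age`` column and the predict id are all hypothetical; the call pattern follows the docstring example ``t.filter(t.age > 25)`` shown above.

from superduperdb.base.build import build_datalayer
from superduperdb.backends.ibis.query import IbisQueryTable

db = build_datalayer()                        # assumes a configured data backend
t = IbisQueryTable('users', primary_id='id')  # symbolic handle on the 'users' table

q = t.filter(t.age > 25).outputs('my-predict-id')  # join filtered rows with stored model outputs
subset = t.select_using_ids(['1', '2', '3'])       # restrict a query to specific primary ids
rows = q.execute(db)                               # compile to a native ibis expression and run it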
+ + :param parent: The parent query + :param db: The Datalayer instance + :param tables: The tables to use for the query + """ if self.type == QueryType.ATTR: return getattr(parent, self.name), tables args, tables = _compile_item(self.args, db, tables=tables) @@ -953,6 +1078,7 @@ def compile( return getattr(parent, self.name)(*args, **kwargs), tables def get_all_tables(self): + """Get all tables in the query.""" out = [] out.extend(_get_all_tables(self.args)) out.extend(_get_all_tables(self.kwargs)) @@ -961,6 +1087,8 @@ def get_all_tables(self): @dc.dataclass class IbisInsert(Insert): + """Insert query for ibis.""" + def __post_init__(self): if isinstance(self.documents, pandas.DataFrame): self.documents = [ @@ -971,6 +1099,10 @@ def _encode_documents(self, table: Table) -> t.List[t.Dict]: return [r.encode(table.schema) for r in self.documents] def execute(self, db): + """Execute the query. + + :param db: The Datalayer instance + """ table = db.load( 'table', self.table_or_collection.identifier, @@ -985,6 +1117,7 @@ def execute(self, db): @property def select_table(self): + """Return the table or collection to select from.""" return self.table_or_collection @@ -1004,10 +1137,20 @@ def __iter__(self): @dc.dataclass class RawSQL(RawQuery): + """Raw SQL query. + + :param query: The raw SQL query + :param id_field: The field to use as the primary id + """ + query: str id_field: str = 'id' def execute(self, db): + """Run the query. + + :param db: The DataLayer instance + """ cursor = db.databackend.conn.raw_sql(self.query) try: cursor = cursor.mappings().all() diff --git a/superduperdb/backends/ibis/utils.py b/superduperdb/backends/ibis/utils.py index 92efb13e0c..c9f81827a0 100644 --- a/superduperdb/backends/ibis/utils.py +++ b/superduperdb/backends/ibis/utils.py @@ -1,4 +1,9 @@ +# TODO: Remove the unused function def get_output_table_name(model_identifier, version): - """Get the output table name for the given model.""" + """Get the output table name for the given model. + + :param model_identifier: The identifier of the model. + :param version: The version of the model. + """ # use `_` to connect the model_identifier and version return f'_outputs_{model_identifier}_{version}' diff --git a/superduperdb/backends/local/artifacts.py b/superduperdb/backends/local/artifacts.py index 0ae8f9c383..bb2a0bd95f 100644 --- a/superduperdb/backends/local/artifacts.py +++ b/superduperdb/backends/local/artifacts.py @@ -34,11 +34,12 @@ def _exists(self, file_id: str): return os.path.exists(path) def url(self): + """Return the URL of the artifact store.""" return self.conn def _delete_artifact(self, file_id: str): - """ - Delete artifact from artifact store + """Delete artifact from artifact store. + :param file_id: File id uses to identify artifact in store """ path = os.path.join(self.conn, file_id) @@ -48,8 +49,11 @@ def _delete_artifact(self, file_id: str): os.remove(path) def drop(self, force: bool = False): - """ - Drop the artifact store. + """Drop the artifact store. + + Please use with caution as this will delete all data in the artifact store. + + :param force: Whether to force the drop. """ if not force: if not click.confirm( @@ -78,9 +82,12 @@ def _load_bytes(self, file_id: str) -> bytes: return f.read() def _save_file(self, file_path: str, file_id: str): - """ - Save file in artifact store and return the relative path + """Save file in artifact store and return the relative path. + return the relative path {file_id}/{name} + + :param file_path: The path to the file to be saved. 
+ :param file_id: The id of the file. """ path = Path(file_path) name = path.name @@ -96,13 +103,11 @@ def _save_file(self, file_path: str, file_id: str): return os.path.join(file_id, name) def _load_file(self, file_id: str) -> str: - """Return the path to the file in the artifact store""" + """Return the path to the file in the artifact store.""" logging.info(f"Loading file {file_id} from {self.conn}") return os.path.join(self.conn, file_id) def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # Not necessary since just local filesystem pass diff --git a/superduperdb/backends/local/compute.py b/superduperdb/backends/local/compute.py index e6ddaed699..63251481c8 100644 --- a/superduperdb/backends/local/compute.py +++ b/superduperdb/backends/local/compute.py @@ -6,9 +6,7 @@ class LocalComputeBackend(ComputeBackend): - """ - A mockup backend for running jobs locally. - """ + """A mockup backend for running jobs locally.""" def __init__( self, @@ -17,10 +15,12 @@ def __init__( @property def type(self) -> str: + """The type of the backend.""" return "local" @property def name(self) -> str: + """The name of the backend.""" return "local" def submit( @@ -30,6 +30,7 @@ def submit( Submits a function for local execution. :param function: The function to be executed. + :param compute_kwargs: Do not use this parameter. """ logging.info(f"Submitting job. function:{function}") future = function(*args, **kwargs) @@ -44,20 +45,16 @@ def submit( @property def tasks(self) -> t.Dict[str, t.Any]: - """ - List for all pending tasks - """ + """List for all pending tasks.""" return self.__outputs def wait_all(self) -> None: - """ - Waits for all pending tasks to complete. - """ + """Waits for all pending tasks to complete.""" pass def result(self, identifier: str) -> t.Any: - """ - Retrieves the result of a previously submitted task. + """Retrieves the result of a previously submitted task. + Note: This will block until the future is completed. :param identifier: The identifier of the submitted task. @@ -65,13 +62,9 @@ def result(self, identifier: str) -> t.Any: return self.__outputs[identifier] def disconnect(self) -> None: - """ - Disconnect the local client. - """ + """Disconnect the local client.""" pass def shutdown(self) -> None: - """ - Shuts down the local cluster. - """ + """Shuts down the local cluster.""" pass diff --git a/superduperdb/backends/mongodb/artifacts.py b/superduperdb/backends/mongodb/artifacts.py index b32c912dae..2450b7229f 100644 --- a/superduperdb/backends/mongodb/artifacts.py +++ b/superduperdb/backends/mongodb/artifacts.py @@ -25,9 +25,16 @@ def __init__(self, conn, name: str): self.filesystem = gridfs.GridFS(self.db) def url(self): + """Return the URL of the database.""" return self.conn.HOST + ':' + str(self.conn.PORT) + '/' + self.name def drop(self, force: bool = False): + """Drop the database. + + Please use with caution as this will delete all artifacts. 
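Editorial aside: a small sketch of the ``LocalComputeBackend`` API documented in this hunk. It assumes ``submit`` returns the task identifier later passed to ``result`` and that extra positional arguments are forwarded to the function, as the docstrings imply but the hunk does not show explicitly.

from superduperdb.backends.local.compute import LocalComputeBackend

backend = LocalComputeBackend()
job_id = backend.submit(lambda x, y: x + y, 1, 2)  # executed eagerly, in-process
print(backend.tasks)                               # mapping of completed task outputs
print(backend.result(job_id))                      # assumed to print 3
backend.wait_all()                                 # no-op: local jobs finish inside submit()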
+ + :param force: If True, will not prompt for confirmation + """ if not force: if not click.confirm( f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' @@ -56,7 +63,7 @@ def _load_bytes(self, file_id: str): return cur.read() def _save_file(self, file_path: str, file_id: str): - """Save file to GridFS""" + """Save file to GridFS.""" path = Path(file_path) if path.is_dir(): upload_folder(file_path, file_id, self.filesystem) @@ -65,8 +72,8 @@ def _save_file(self, file_path: str, file_id: str): return file_id def _load_file(self, file_id: str) -> str: - """ - Download file from GridFS and return the path + """Download file from GridFS and return the path. + The path is a temporary directory, {tmp_prefix}/{file_id}/{filename or folder} """ return download(file_id, self.filesystem) @@ -80,15 +87,18 @@ def _save_bytes(self, serialized: bytes, file_id: str): ) def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # TODO: implement me def upload_file(path, file_id, fs): - """Upload file to GridFS""" + """Upload file to GridFS. + + :param path: The path to the file to upload + :param file_id: The file_id of the file + :param fs: The GridFS object + """ logging.info(f"Uploading file {path} to GridFS with file_id {file_id}") path = Path(path) with open(path, 'rb') as file_to_upload: @@ -100,7 +110,13 @@ def upload_file(path, file_id, fs): def upload_folder(path, file_id, fs, parent_path=""): - """Upload folder to GridFS""" + """Upload folder to GridFS. + + :param path: The path to the folder to upload + :param file_id: The file_id of the folder + :param fs: The GridFS object + :param parent_path: The parent path of the folder + """ path = Path(path) if not parent_path: logging.info(f"Uploading folder {path} to GridFS with file_id {file_id}") @@ -128,8 +144,13 @@ def upload_folder(path, file_id, fs, parent_path=""): def download(file_id, fs): - """Download file or folder from GridFS and return the path""" + """Download file or folder from GridFS and return the path. + + The path is a temporary directory, {tmp_prefix}/{file_id}/{filename or folder} + :param file_id: The file_id of the file or folder to download + :param fs: The GridFS object + """ download_folder = CFG.downloads.folder if not download_folder: diff --git a/superduperdb/backends/mongodb/cdc/base.py b/superduperdb/backends/mongodb/cdc/base.py index 4ff83b470f..845011ddfe 100644 --- a/superduperdb/backends/mongodb/cdc/base.py +++ b/superduperdb/backends/mongodb/cdc/base.py @@ -12,6 +12,11 @@ class CachedTokens: + """A class to cache the CDC tokens in a file. + + This class is used to cache the CDC tokens in a file. + """ + token_path = os.path.join('.superduperdb', '.cdc.tokens') separate = '\n' @@ -21,12 +26,17 @@ def __init__(self): os.makedirs('.superduperdb', exist_ok=True) def append(self, token: TokenType) -> None: + """Append the token to the file. + + :param token: The token to be appended. + """ with open(CachedTokens.token_path, 'a') as fp: stoken = json.dumps(token) stoken = stoken + self.separate fp.write(stoken) def load(self) -> t.Sequence[TokenType]: + """Load the tokens from the file.""" with open(CachedTokens.token_path) as fp: tokens = fp.read().split(self.separate)[:-1] self._current_tokens = [TokenType(json.loads(t)) for t in tokens] @@ -34,12 +44,22 @@ def load(self) -> t.Sequence[TokenType]: class ObjectId(objectid.ObjectId): + """A class to represent the ObjectId. + + This class is a subclass of the `bson.objectid.ObjectId` class. + Use this class to validate the ObjectId. 
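Editorial aside: the GridFS helpers documented in this hunk (``upload_file``, ``upload_folder``, ``download``) can be exercised directly, as sketched below. The connection string, database name and file paths are hypothetical; ``download`` writes into the configured temporary downloads folder.

import gridfs
from pymongo import MongoClient
from superduperdb.backends.mongodb.artifacts import download, upload_file, upload_folder

fs = gridfs.GridFS(MongoClient('mongodb://localhost:27017')['_filesystem:test_db'])

upload_file('weights.bin', 'weights-file-id', fs)         # single file stored under the given file_id
upload_folder('checkpoints/', 'checkpoints-file-id', fs)  # walks the folder and uploads every file
local_path = download('weights-file-id', fs)              # materialises the artifact on local disk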
+ """ + @classmethod def __get_validators__(cls): yield cls.validate @classmethod def validate(cls, v): + """Validate the ObjectId. + + :param v: The value to be validated. + """ if not isinstance(v, objectid.ObjectId): raise TypeError('Id is required.') return str(v) @@ -47,8 +67,13 @@ def validate(cls, v): @dc.dataclass class MongoDBPacket(Packet): - """ - A base packet to represent message in task queue. + """A base packet to represent message in task queue. + + This class is a subclass of the `Packet` class. + + :param ids: The ids of the rows. + :param query: The query to be executed. + :param event_type: The event type. """ ids: t.List[t.Union[ObjectId, str]] diff --git a/superduperdb/backends/mongodb/cdc/listener.py b/superduperdb/backends/mongodb/cdc/listener.py index 936512471b..e4e2807888 100644 --- a/superduperdb/backends/mongodb/cdc/listener.py +++ b/superduperdb/backends/mongodb/cdc/listener.py @@ -21,9 +21,7 @@ class CDCKeys(str, Enum): - """ - A enum to represent mongo change document keys. - """ + """A enum to represent mongo change document keys.""" operation_type = 'operationType' document_key = 'documentKey' @@ -42,20 +40,21 @@ class CDCKeys(str, Enum): @dc.dataclass class MongoChangePipeline: - """`MongoChangePipeline` is a class to represent listen pipeline - in mongodb watch api. + """MongoChangePipeline. + + `MongoChangePipeline` is a class to represent listen pipeline in mongodb watch api. + + :param matching_operations: A list of operations to match. """ matching_operations: t.Sequence[str] = dc.field(default_factory=list) def validate(self): + """Validate.""" raise NotImplementedError def build_matching(self) -> t.Sequence[t.Dict]: - """A helper fxn to build a listen pipeline for mongo watch api. - - :param matching_operations: A list of operations to watch. - """ + """A helper fxn to build a listen pipeline for mongo watch api.""" if bad := [op for op in self.matching_operations if op not in cdc.DBEvent]: raise ValueError(f'Unknown operations: {bad}') @@ -63,7 +62,8 @@ def build_matching(self) -> t.Sequence[t.Dict]: class MongoDatabaseListener(cdc.BaseDatabaseListener): - """ + """A class handling change stream in mongodb. + It is a class which helps capture data from mongodb database and handle it accordingly. @@ -72,6 +72,12 @@ class MongoDatabaseListener(cdc.BaseDatabaseListener): This class builds a workflow graph on each change observed. + :param db: It is a datalayer instance. + :param on: It is used to define a Collection on which CDC would be performed. + :param stop_event: A threading event flag to notify for stoppage. + :param identifier: A identifier to represent the listener service. + :param timeout: A timeout to stop the listener service. + :param resume_token: A resume token is a token used to resume """ DEFAULT_ID: str = '_id' @@ -99,7 +105,6 @@ def __init__( :param resume_token: A resume token is a token used to resume the change stream in mongo. """ - self.tokens = CachedTokens() self.resume_token = None @@ -121,6 +126,7 @@ def on_create( self, ids: t.Sequence, db: 'Datalayer', collection: query.Collection ) -> None: """on_create. + A helper on create event handler which handles inserted document in the change stream. It basically extracts the change document and build the taskflow graph to @@ -199,17 +205,18 @@ def _get_reference_id(self, document: t.Dict) -> t.Optional[str]: def dump_token(self, change: t.Dict) -> None: """dump_token. + A helper utility to dump resume token from the changed document. 
- :param change: + :param change: A change document. """ token = change[self.DEFAULT_ID] self.tokens.append(token) def check_if_taskgraph_change(self, change: t.Dict) -> bool: - """ - A helper method to check if the cdc change is done - by taskgraph nodes. + """A helper method to check if the cdc change is done by taskgraph nodes. + + :param change: A change document. """ if change[CDCKeys.operation_type] == cdc.DBEvent.update: updates = change[CDCKeys.update_descriptions_key] @@ -220,9 +227,7 @@ def check_if_taskgraph_change(self, change: t.Dict) -> bool: return False def setup_cdc(self) -> CollectionChangeStream: - """ - Setup cdc change stream from user provided - """ + """Setup cdc change stream from user provided.""" if isinstance(self._change_pipeline, str): pipeline = self._get_stream_pipeline(self._change_pipeline) @@ -245,10 +250,10 @@ def setup_cdc(self) -> CollectionChangeStream: return stream_iterator def next_cdc(self, stream: CollectionChangeStream) -> None: - """ - Get the next stream of change observed on the given `Collection`. - """ + """Get the next stream of change observed on the given `Collection`. + :param stream: A change stream object. + """ change = stream.try_next() if change is not None: logging.debug(f'Database change encountered at {datetime.datetime.now()}') @@ -268,8 +273,9 @@ def next_cdc(self, stream: CollectionChangeStream) -> None: def set_change_pipeline( self, change_pipeline: t.Optional[t.Union[str, t.Sequence[t.Dict]]] ) -> None: - """ - Set the change pipeline for the listener. + """Set the change pipeline for the listener. + + :param change_pipeline: Change pipeline to listen to. """ if change_pipeline is None: change_pipeline = MongoChangePipelines.get('generic') @@ -280,7 +286,9 @@ def listen( self, change_pipeline: t.Optional[t.Union[str, t.Sequence[t.Dict]]] = None, ) -> None: - """Primary fxn to initiate listening of a database on the collection + """Listen to the database changes. + + Primary fxn to initiate listening of a database on the collection with defined `change_pipeline` by the user. :param change_pipeline: A mongo listen pipeline defined by the user @@ -310,14 +318,12 @@ def listen( raise def resume_tokens(self) -> t.Sequence[TokenType]: - """ - Get the resume tokens from the change stream. - """ + """Get the resume tokens from the change stream.""" return self.tokens.load() def stop(self) -> None: - """ - Stop listening cdc changes. + """Stop listening cdc changes. + This stops the corresponding services as well. 
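Editorial aside: a hedged sketch of driving the ``MongoDatabaseListener`` described in this hunk. The collection name and identifier are hypothetical, and the constructor is assumed to accept the parameters listed in its docstring.

import threading
from superduperdb.base.build import build_datalayer
from superduperdb.backends.mongodb.cdc.listener import MongoDatabaseListener
from superduperdb.backends.mongodb.query import Collection

db = build_datalayer()
listener = MongoDatabaseListener(
    db=db,
    on=Collection('documents'),   # the collection whose change stream is watched
    stop_event=threading.Event(),
    identifier='documents-cdc',
)
listener.listen()            # start consuming changes with the default 'generic' pipeline
assert listener.running()
listener.stop()              # sets the stop event and joins the scheduler thread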
""" self._stop_event.set() @@ -325,4 +331,5 @@ def stop(self) -> None: self._scheduler.join() def running(self) -> bool: + """Check if the listener is running or not.""" return not self._stop_event.is_set() diff --git a/superduperdb/backends/mongodb/data_backend.py b/superduperdb/backends/mongodb/data_backend.py index 1152a485bd..e6e122fe79 100644 --- a/superduperdb/backends/mongodb/data_backend.py +++ b/superduperdb/backends/mongodb/data_backend.py @@ -33,16 +33,20 @@ def __init__(self, conn: pymongo.MongoClient, name: str): self._db = self.conn[self.name] def url(self): + """Return the data backend connection url.""" return self.conn.HOST + ':' + str(self.conn.PORT) + '/' + self.name @property def db(self): + """Return the datalayer instance.""" return self._db def build_metadata(self): + """Build the metadata store for the data backend.""" return MongoMetaDataStore(self.conn, self.name) def build_artifact_store(self): + """Build the artifact store for the data backend.""" from mongomock import MongoClient as MockClient if isinstance(self.conn, MockClient): @@ -55,6 +59,12 @@ def build_artifact_store(self): return MongoArtifactStore(self.conn, f'_filesystem:{self.name}') def drop(self, force: bool = False): + """Drop the data backend. + + Please use with caution as you will lose all data. + :param force: Force the drop, default is False. + If False, a confirmation prompt will be displayed. + """ if not force: if not click.confirm( f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' @@ -66,15 +76,31 @@ def drop(self, force: bool = False): return self.db.client.drop_database(self.db.name) def get_table_or_collection(self, identifier): + """Get a table or collection from the data backend. + + :param identifier: table or collection identifier + """ return self._db[identifier] def set_content_bytes(self, r, key, bytes_): + """Set the content bytes in the data backend. + + :param r: dictionary containing information about the content + :param key: key to set + :param bytes_: content bytes + """ if not isinstance(r, MongoStyleDict): r = MongoStyleDict(r) r[f'{key}._content.bytes'] = bytes_ return r def exists(self, table_or_collection, id, key): + """Check if a document exists in the data backend. + + :param table_or_collection: table or collection identifier + :param id: document identifier + :param key: key to check + """ return ( self.db[table_or_collection].find_one( {'_id': id, f'{key}._content.bytes': {'$exists': 1}} @@ -82,7 +108,12 @@ def exists(self, table_or_collection, id, key): is not None ) + # TODO: Remove the unused function def unset_outputs(self, info: t.Dict): + """Unset the output field in the data backend. 
+ + :param info: dictionary containing information about the output field + """ select = Serializable.from_dict(info['select']) logging.info(f'unsetting output field _outputs.{info["key"]}.{info["model"]}') doc = {'$unset': {f'_outputs.{info["key"]}.{info["model"]}': 1}} @@ -90,6 +121,7 @@ def unset_outputs(self, info: t.Dict): return self.db[select.collection].update_many(update.filter, update.update) def list_vector_indexes(self): + """List all vector indexes in the data backend.""" indexes = [] for coll in self.db.list_collection_names(): i = self.db.command({'listSearchIndexes': coll}) @@ -102,6 +134,7 @@ def list_vector_indexes(self): return indexes def list_tables_or_collections(self): + """List all tables or collections in the data backend.""" return self.db.list_collection_names() def delete_vector_index(self, vector_index): @@ -124,9 +157,7 @@ def delete_vector_index(self, vector_index): ) def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # TODO: implement me @@ -136,15 +167,26 @@ def create_output_dest( datatype: t.Union[None, DataType, FieldType], flatten: bool = False, ): + """Create an output collection for a component. + + That will do nothing for MongoDB. + + :param predict_id: The predict id of the output destination + :param datatype: datatype of component + :param flatten: flatten the output + """ pass def check_output_dest(self, predict_id) -> bool: + """Check if the output destination exists. + + :param predict_id: identifier of the prediction + """ return True @staticmethod def infer_schema(data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None): - """ - Infer a schema from a given data object + """Infer a schema from a given data object. :param data: The data object :param identifier: The identifier for the schema, if None, it will be generated diff --git a/superduperdb/backends/mongodb/metadata.py b/superduperdb/backends/mongodb/metadata.py index b1b94a955d..6d838c9ad6 100644 --- a/superduperdb/backends/mongodb/metadata.py +++ b/superduperdb/backends/mongodb/metadata.py @@ -33,9 +33,15 @@ def __init__( self.conn = conn def url(self): + """Metadata store connection url.""" return self.conn.HOST + ':' + str(self.conn.PORT) + '/' + self.name def drop(self, force: bool = False): + """Drop all meta-data from the metadata store. + + Please always use with caution. This will drop all the meta-data collections. + :param force: whether to force the drop, defaults to False + """ if not force: if not click.confirm( f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' @@ -50,6 +56,11 @@ def drop(self, force: bool = False): self.db.drop_collection(self.parent_child_mappings.name) def create_parent_child(self, parent: str, child: str) -> None: + """Create a parent-child relationship between two components. + + :param parent: parent component + :param child: child component + """ self.parent_child_mappings.insert_one( { 'parent': parent, @@ -58,35 +69,72 @@ def create_parent_child(self, parent: str, child: str) -> None: ) def create_component(self, info: t.Dict) -> InsertOneResult: + """Create a component in the metadata store. + + :param info: dictionary containing information about the component. + """ if 'hidden' not in info: info['hidden'] = False return self.component_collection.insert_one(info) def create_job(self, info: t.Dict) -> InsertOneResult: + """Create a job in the metadata store. + + :param info: dictionary containing information about the job. 
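Editorial aside: since ``infer_schema`` documented in this hunk is a static helper, it can be called straight off the configured data backend. A minimal sketch, assuming ``db.databackend`` is the MongoDB data backend from this file; the sample document is hypothetical.

from superduperdb.base.build import build_datalayer

db = build_datalayer()
schema = db.databackend.infer_schema(
    {'score': 0.25, 'caption': 'a cat', 'embedding': [0.1, 0.2, 0.3]},
    identifier='my-schema',
)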
+ """ return self.job_collection.insert_one(info) def get_parent_child_relations(self): + """Get parent-child relations from the metadata store.""" c = self.parent_child_mappings.find() return [(r['parent'], r['child']) for r in c] def get_component_version_children(self, unique_id: str): + """Get the children of a component version. + + :param unique_id: unique identifier of component + """ return self.parent_child_mappings.distinct('child', {'parent': unique_id}) def get_job(self, identifier: str): + """Get a job from the metadata store. + + :param identifier: identifier of job + """ return self.job_collection.find_one({'identifier': identifier}) def create_metadata(self, key: str, value: str): + """Create metadata in the metadata store. + + :param key: key to be created + :param value: value to be created + """ return self.meta_collection.insert_one({'key': key, 'value': value}) def get_metadata(self, key: str): + """Get metadata from the metadata store. + + :param key: key to be retrieved + """ return self.meta_collection.find_one({'key': key})['value'] def update_metadata(self, key: str, value: str): + """Update metadata in the metadata store. + + :param key: key to be updated + :param value: value to be updated + """ return self.meta_collection.update_one({'key': key}, {'$set': {'value': value}}) def get_latest_version( self, type_id: str, identifier: str, allow_hidden: bool = False ) -> int: + """Get the latest version of a component. + + :param type_id: type of component + :param identifier: identifier of component + :param allow_hidden: whether to allow hidden components + """ try: if allow_hidden: return sorted( @@ -120,11 +168,21 @@ def f(): raise FileNotFoundError(f'Can\'t find {type_id}: {identifier} in metadata') def update_job(self, identifier: str, key: str, value: t.Any) -> UpdateResult: + """Update a job in the metadata store. + + :param identifier: identifier of job + :param key: key to be updated + :param value: value to be updated + """ return self.job_collection.update_one( {'identifier': identifier}, {'$set': {key: value}} ) def show_components(self, type_id: t.Optional[str] = None): + """Show components in the metadata store. + + :param type_id: type of component + """ # TODO: Should this be sorted? if type_id is not None: return self.component_collection.distinct( @@ -142,17 +200,30 @@ def show_components(self, type_id: t.Optional[str] = None): def show_component_versions( self, type_id: str, identifier: str ) -> t.List[t.Union[t.Any, int]]: + """Show component versions in the metadata store. + + :param type_id: type of component + :param identifier: identifier of component + """ return self.component_collection.distinct( 'version', {'type_id': type_id, 'identifier': identifier} ) def list_components_in_scope(self, scope: str): + """List components in a scope. + + :param scope: scope of components + """ out = [] for r in self.component_collection.find({'parent': scope}): out.append((r['type_id'], r['identifier'])) return out def show_job(self, job_id: str): + """Show a job in the metadata store. + + :param job_id: identifier of job + """ return self.job_collection.find_one({'identifier': job_id}) def show_jobs( @@ -160,6 +231,11 @@ def show_jobs( component_identifier: t.Optional[str] = None, type_id: t.Optional[str] = None, ): + """Show jobs in the metadata store. 
+ + :param component_identifier: identifier of component + :param type_id: type of component + """ filter_ = {} if component_identifier is not None: filter_['component_identifier'] = component_identifier @@ -170,6 +246,12 @@ def show_jobs( def _component_used( self, type_id: str, identifier: str, version: t.Optional[int] = None ) -> bool: + """Check if a component is used in other components. + + :param type_id: type of component + :param identifier: identifier of component + :param version: version of component + """ if version is None: members: t.Union[t.Dict, str] = {'$regex': f'^{identifier}/{type_id}'} else: @@ -178,18 +260,35 @@ def _component_used( return bool(self.component_collection.count_documents({'members': members})) def component_has_parents(self, type_id: str, identifier: str) -> int: + """Check if a component has parents. + + :param type_id: type of component + :param identifier: identifier of component + """ doc = {'child': {'$regex': f'^{type_id}/{identifier}/'}} return self.parent_child_mappings.count_documents(doc) def component_version_has_parents( self, type_id: str, identifier: str, version: int ) -> int: + """Check if a component version has parents. + + :param type_id: type of component + :param identifier: identifier of component + :param version: version of component + """ doc = {'child': Component.make_unique_id(type_id, identifier, version)} return self.parent_child_mappings.count_documents(doc) def delete_component_version( self, type_id: str, identifier: str, version: int ) -> DeleteResult: + """Delete a component version from the metadata store. + + :param type_id: type of component + :param identifier: identifier of component + :param version: version of component + """ if self._component_used(type_id, identifier, version=version): raise Exception('Component version already in use in other components!') @@ -242,6 +341,10 @@ def _get_component( return r def get_component_version_parents(self, unique_id: str) -> t.List[str]: + """Get the parents of a component version. + + :param unique_id: unique identifier of component + """ return [ r['parent'] for r in self.parent_child_mappings.find({'child': unique_id}) ] @@ -266,12 +369,26 @@ def _update_object( value: t.Any, version: int, ): + """Update a component in the metadata store. + + :param identifier: identifier of component + :param type_id: type of component + :param key: key to be updated + :param value: value to be updated + :param version: version of component + """ return self.component_collection.update_one( {'identifier': identifier, 'type_id': type_id, 'version': version}, {'$set': {key: value}}, ) def write_output_to_job(self, identifier, msg, stream): + """Write output to a job in the metadata store. + + :param identifier: identifier of job + :param msg: message to be written + :param stream: stream to be written to + """ if stream not in ('stdout', 'stderr'): raise ValueError(f'stream is "{stream}", should be stdout or stderr') self.job_collection.update_one( @@ -281,14 +398,18 @@ def write_output_to_job(self, identifier, msg, stream): def hide_component_version( self, type_id: str, identifier: str, version: int ) -> None: + """Hide a component version. 
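Editorial aside: a sketch of how the metadata-store methods documented in this hunk are typically combined. It assumes the store is reachable as ``db.metadata`` on a Datalayer (an assumption, not shown in this diff) and that a model named ``'my-model'`` has already been saved.

store = db.metadata   # assumed handle on the MongoMetaDataStore

print(store.show_components('model'))                     # identifiers of saved models
print(store.show_component_versions('model', 'my-model'))
latest = store.get_latest_version('model', 'my-model')
if not store.component_version_has_parents('model', 'my-model', latest):
    store.delete_component_version('model', 'my-model', latest)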
+ + :param type_id: type of component + :param identifier: identifier of component + :param version: version of component + """ self.component_collection.update_one( {'type_id': type_id, 'identifier': identifier, 'version': version}, {'$set': {'hidden': True}}, ) def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # TODO: implement me diff --git a/superduperdb/backends/mongodb/query.py b/superduperdb/backends/mongodb/query.py index 2e945f256f..d5034b1c34 100644 --- a/superduperdb/backends/mongodb/query.py +++ b/superduperdb/backends/mongodb/query.py @@ -29,14 +29,17 @@ class FindOne(QueryComponent): - """ - Wrapper around ``pymongo.Collection.find_one`` + """Wrapper around ``pymongo.Collection.find_one``. :param args: Positional arguments to ``pymongo.Collection.find_one`` :param kwargs: Named arguments to ``pymongo.Collection.find_one`` """ def select_using_ids(self, ids): + """Select documents using ids. + + :param ids: The ids to select + """ ids = [ObjectId(id) for id in ids] args = list(self.args)[:] if not args: @@ -50,12 +53,10 @@ def select_using_ids(self, ids): ) def add_fold(self, fold: str): - """ - Modify the query to add a fold to filter {'_fold': fold} + """Modify the query to add a fold to filter {'_fold': fold}. :param fold: The fold to add """ - args = self.args or [{}] args[0]['_fold'] = fold return FindOne( @@ -68,11 +69,9 @@ def add_fold(self, fold: str): @dc.dataclass class Find(QueryComponent): - """ - Wrapper around ``pymongo.Collection.find`` + """Wrapper around ``pymongo.Collection.find``. - :param args: Positional arguments to ``pymongo.Collection.find`` - :param kwargs: Named arguments to ``pymongo.Collection.find`` + :param output_fields: The output fields to project to """ output_fields: t.Optional[t.Dict[str, str]] = None @@ -93,6 +92,7 @@ def __post_init__(self): @property def select_ids(self): + """Select ids.""" args = list(self.args)[:] if not args: args = [{}] @@ -113,7 +113,7 @@ def outputs(self, *predict_ids): """ Join the query with the outputs for a table. - :param **kwargs: key=model/version or key=model pairs + :param *predict_ids: The ids to predict """ args = copy.deepcopy(list(self.args[:])) if not args: @@ -132,6 +132,10 @@ def outputs(self, *predict_ids): ) def select_using_ids(self, ids): + """Select documents using ids. + + :param ids: The ids to select + """ ids = [ObjectId(id) for id in ids] args = list(self.args)[:] if not args: @@ -145,6 +149,10 @@ def select_using_ids(self, ids): ) def select_ids_of_missing_outputs(self, predict_id: str): + """Select ids of missing outputs. + + :param predict_id: The predict id to select + """ assert self.type == QueryType.QUERY if self.args: args = [ @@ -172,6 +180,10 @@ def select_ids_of_missing_outputs(self, predict_id: str): ) def select_single_id(self, id): + """Select a single document by id. + + :param id: The id of the document to select + """ assert self.type == QueryType.QUERY args = list(self.args)[:] if not args: @@ -185,6 +197,10 @@ def select_single_id(self, id): ) def add_fold(self, fold: str): + """Add a fold to the query. + + :param fold: The fold to add + """ args = self.args if not self.args: args = [{}] @@ -199,8 +215,7 @@ def add_fold(self, fold: str): @dc.dataclass class Aggregate(Select): - """ - Wrapper around ``pymongo.Collection.aggregate`` + """Wrapper around ``pymongo.Collection.aggregate``. 
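Editorial aside: the ``Find``/``FindOne`` wrappers documented in this hunk are usually reached through ``Collection``. A brief sketch; the collection name, filter and ids are illustrative.

from superduperdb.backends.mongodb.query import Collection

docs = Collection('documents')
q = docs.find({'brand': 'Nike'})                           # wraps pymongo.Collection.find
train_q = q.add_fold('train')                              # adds {'_fold': 'train'} to the filter
by_id = q.select_using_ids(['64d2f0c1a2b3c4d5e6f70809'])   # restrict to the given ObjectIds
one = docs.find_one({'brand': 'Nike'})                     # wraps pymongo.Collection.find_one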
:param table_or_collection: The table or collection to perform the query on :param vector_index: The vector index to use @@ -215,26 +230,41 @@ class Aggregate(Select): @property def id_field(self): + """Return the id field.""" return self.table_or_collection.primary_id @property def select_table(self): + """Select the table to perform the query on.""" raise NotImplementedError def add_fold(self): + """Add a fold to the query.""" raise NotImplementedError def select_single_id(self, id: str): + """Select a single document by id. + + :param id: The id of the document to select + """ raise NotImplementedError @property def select_ids(self): + """Select ids.""" raise NotImplementedError def select_using_ids(self): + """Select documents using ids.""" raise NotImplementedError def select_ids_of_missing_outputs(self, key: str, model: str, version: int): + """Select ids of missing outputs. + + :param key: The key to select + :param model: The model to select + :param version: The version to select + """ raise NotImplementedError @staticmethod @@ -283,6 +313,11 @@ def _prepare_pipeline(pipeline, db, vector_index): return pipeline def execute(self, db, reference=False): + """Execute the query. + + :param db: The datalayer instance + :param reference: Not used + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -304,6 +339,8 @@ def execute(self, db, reference=False): @dc.dataclass(repr=False) class MongoCompoundSelect(CompoundSelect): + """CompoundSelect class to perform compound queries on a collection.""" + DB_TYPE: t.ClassVar[str] = 'MONGODB' def _deep_flat_encode(self, cache): @@ -339,14 +376,13 @@ def _get_query_linker(self, table_or_collection, members) -> 'QueryLinker': @property def output_fields(self): + """Return the output fields.""" return self.query_linker.output_fields def outputs(self, *predict_ids): - """ - This method returns a query which joins a query with the outputs - for a table. + """Returns a query which joins a query with the outputs for a table. - :param model: The model identifier for which to get the outputs + :param *predict_ids: The ids to predict >>> q = Collection(...).find(...).outputs('key', 'model_name') @@ -360,6 +396,7 @@ def outputs(self, *predict_ids): ) def change_stream(self, *args, **kwargs): + """Change stream for the query.""" return self.table_or_collection.change_stream(*args, **kwargs) def _execute(self, db): @@ -386,6 +423,11 @@ def _execute(self, db): return post_query_linker.execute(db), similar_scores def execute(self, db, reference=False): + """Execute the query. + + :param db: The datalayer instance + :param reference: If True, load the references + """ output, scores = self._execute(db) decode_function = _get_decode_function(db) if isinstance(output, (pymongo.cursor.Cursor, mongomock.collection.Cursor)): @@ -409,7 +451,7 @@ def download_update(self, db, id: str, key: str, bytes: bytearray) -> None: :param db: The db to query :param id: The id to filter on - :param key: + :param key: The key to update :param bytes: The bytes to update """ if self.collection is None: @@ -421,21 +463,30 @@ def download_update(self, db, id: str, key: str, bytes: bytearray) -> None: return collection.update_one({'_id': id}, update) def check_exists(self, db): + """Check if the query exists in the database. + + :param db: The datalayer instance + """ ... 
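Editorial aside: the ``Aggregate`` wrapper introduced above accepts an ordinary MongoDB pipeline; a vector index can additionally be supplied via ``vector_index`` for similarity search. The sketch below uses a plain, hypothetical aggregation and assumes ``db`` is an existing Datalayer.

from superduperdb.backends.mongodb.query import Collection

pipeline = [
    {'$match': {'brand': 'Nike'}},
    {'$group': {'_id': '$category', 'n': {'$sum': 1}}},
]
q = Collection('documents').aggregate(pipeline)  # wraps pymongo.Collection.aggregate
rows = q.execute(db)                             # runs the pipeline on the backend collection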
@property def select_table(self): + """Select the table to perform the query on.""" return self.table_or_collection.find() @dc.dataclass(repr=False) class MongoQueryLinker(QueryLinker): + """QueryLinker class to link queries together.""" + @property def query_components(self): + """Return the query components.""" return self.table_or_collection.query_components @property def output_fields(self): + """Return the output fields.""" out = {} for member in self.members: if hasattr(member, 'output_fields'): @@ -443,6 +494,10 @@ def output_fields(self): return out def add_fold(self, fold): + """Add a fold to the query. + + :param fold: The fold to add + """ new_members = [] for member in self.members: if hasattr(member, 'add_fold'): @@ -455,6 +510,7 @@ def add_fold(self, fold): ) def outputs(self, *predict_ids): + """Join the query with the outputs for a table.""" new_members = [] for member in self.members: if hasattr(member, 'outputs'): @@ -469,6 +525,7 @@ def outputs(self, *predict_ids): @property def select_ids(self): + """Select ids.""" new_members = [] for member in self.members: if hasattr(member, 'select_ids'): @@ -482,6 +539,10 @@ def select_ids(self): ) def select_using_ids(self, ids): + """Select documents using ids. + + :param ids: The ids to select + """ new_members = [] for member in self.members: if hasattr(member, 'select_using_ids'): @@ -507,6 +568,10 @@ def _select_ids_of_missing_outputs(self, predict_id: str): ) def select_single_id(self, id): + """Select a single document by id. + + :param id: The id of the document to select + """ assert ( len(self.members) == 1 and self.members[0].type == QueryType.QUERY @@ -518,6 +583,10 @@ def select_single_id(self, id): ) def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ parent = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -528,12 +597,18 @@ def execute(self, db): @dc.dataclass(repr=False) class MongoInsert(Insert): + """Insert class to insert a single document in the database. + + :param one: If True, only one document will be inserted + """ + one: bool = False def raw_query(self, db): - ''' - Returns a raw mongodb query for mongodb operation. - ''' + """Returns a raw mongodb query for mongodb operation. + + :param db: The datalayer instance + """ schema = self.kwargs.pop('schema', None) schema = get_schema(db, schema) if schema else None documents = [r.encode(schema) for r in self.documents] @@ -543,6 +618,10 @@ def raw_query(self, db): return documents def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -552,26 +631,38 @@ def execute(self, db): @property def select_table(self): + """Select collection to be inserted.""" return self.table_or_collection.find() @dc.dataclass(repr=False) class MongoDelete(Delete): + """Delete class to delete a single document in the database. + + :param one: If True, only one document will be deleted + """ + one: bool = False @property def collection(self): + """Return the collection from the database.""" return self.table_or_collection def to_operation(self, collection): - ''' - Returns a mongodb operation i.e `pymongo.InsertOne` - ''' + """Returns a mongodb operation i.e `pymongo.InsertOne`. 
+ + :param collection: The collection to perform the operation on + """ if self.one: return pymongo.DeleteOne(*self.args, **self.kwargs) return pymongo.DeleteMany(*self.args, **self.kwargs) def arg_ids(self, collection): + """Returns the ids of the documents to be deleted. + + :param collection: The collection to be deleted from + """ ids = [] if '_id' in self.kwargs: @@ -586,6 +677,10 @@ def arg_ids(self, collection): return ids def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -613,6 +708,15 @@ def execute(self, db): @dc.dataclass(repr=False) class MongoUpdate(Update): + """Update class to update a single document in the database. + + :param update: The update document + :param filter: The filter to apply + :param one: If True, only one document will be updated + :param args: Positional arguments to ``pymongo.Collection.update_one`` + :param kwargs: Named arguments to ``pymongo.Collection.update_one`` + """ + update: Document filter: t.Dict one: bool = False @@ -621,18 +725,24 @@ class MongoUpdate(Update): @property def select_table(self): + """Select collection to be updated.""" return self.table_or_collection.find() def to_operation(self, collection): - ''' - Returns a mongodb operation i.e `pymongo.InsertOne` - ''' + """Returns a mongodb operation i.e `pymongo.InsertOne`. + + :param collection: The collection to perform the operation on + """ filter, update = self.raw_query(collection) if self.one: return pymongo.UpdateOne(filter, update) return pymongo.UpdateMany(filter, update) def arg_ids(self, collection): + """Returns the ids of the documents to be updated. + + :param collection: The collection to be updated from + """ filter, _ = self.raw_query(collection) if self.one is True: return [filter['_id']] @@ -640,9 +750,10 @@ def arg_ids(self, collection): return filter['_id']['$in'] def raw_query(self, collection): - ''' - Returns a raw mongodb query for mongodb operation. - ''' + """Returns a raw mongodb query for mongodb operation. + + :param collection: The collection to perform the operation on + """ update = self.update if isinstance(self.update, Document): update = update.encode() @@ -655,6 +766,10 @@ def raw_query(self, collection): return {'_id': {'$in': ids}}, update def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -670,13 +785,17 @@ def execute(self, db): @dc.dataclass(repr=False) class MongoBulkWrite(Write): - ''' - MongoBulkWrite will help write multiple mongodb operations - to database at once. + """MongoBulkWrite will help write multiple mongodb operations to database at once. - example: + Example: + ------- MongoBulkWrite(operations= [MongoUpdate(...), MongoDelete(...)]) - ''' + + :param operations: List of operations to be performed + :param args: Positional arguments to ``pymongo.Collection.bulk_write`` + :param kwargs: Named arguments to ``pymongo.Collection.bulk_write`` + + """ operations: t.List[t.Union[MongoUpdate, MongoDelete]] args: t.Sequence = dc.field(default_factory=list) @@ -693,12 +812,14 @@ def __post_init__(self): @property def select_table(self): - ''' - Select collection to be bulk written - ''' + """Select collection to be bulk written.""" return self.table_or_collection.find() def execute(self, db): + """Execute the query. 
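Editorial aside: a sketch of batching several writes through ``MongoBulkWrite`` as the docstring above describes. It assumes, as the surrounding docstrings suggest, that ``Collection.update_one``/``delete_many`` return the operation objects that ``bulk_write`` accepts; names and filters are hypothetical, and ``db`` is an existing Datalayer.

from superduperdb.backends.mongodb.query import Collection

docs = Collection('documents')
operations = [
    docs.update_one({'name': 'alice'}, {'$set': {'active': True}}),
    docs.delete_many({'status': 'stale'}),
]
docs.bulk_write(operations).execute(db)  # a single pymongo bulk_write round-trip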
+ + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -738,6 +859,14 @@ def execute(self, db): @dc.dataclass(repr=False) class MongoReplaceOne(Update): + """Replace class to replace a single document in the database. + + :param replacement: The replacement document + :param filter: The filter to apply + :param args: Positional arguments to ``pymongo.Collection.replace_one`` + :param kwargs: Named arguments to ``pymongo.Collection.replace_one`` + """ + replacement: Document filter: t.Dict args: t.Sequence = dc.field(default_factory=list) @@ -745,13 +874,19 @@ class MongoReplaceOne(Update): @property def collection(self): + """Return the collection from the database.""" return self.table_or_collection @property def select_table(self): + """Return the table from the database.""" return self.table_or_collection.find() def execute(self, db): + """Execute the query. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection( self.table_or_collection.identifier ) @@ -767,7 +902,7 @@ def execute(self, db): @dc.dataclass class ChangeStream: - """Request a stream of changes from a db + """Change stream class to watch for changes in specified collection. :param collection: The collection to perform the query on :param args: Positional query arguments to ``pymongo.Collection.watch`` @@ -779,22 +914,33 @@ class ChangeStream: kwargs: t.Dict = dc.field(default_factory=dict) def __call__(self, db): + """Watch for changes in the database in specified collection. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection(self.collection) return collection.watch(**self.kwargs) @dc.dataclass(repr=False) class Collection(TableOrCollection): + """Collection class to perform queries on a collection.""" + query_components: t.ClassVar[t.Dict] = {'find': Find, 'find_one': FindOne} type_id: t.ClassVar[str] = 'collection' primary_id: t.ClassVar[str] = '_id' def get_table(self, db): + """Return the table from the database. + + :param db: The datalayer instance + """ collection = db.databackend.get_table_or_collection(self.collection.identifier) return collection def change_stream(self, *args, **kwargs): + """Request a stream of changes from the collection.""" return ChangeStream( collection=self.identifier, args=args, @@ -845,6 +991,10 @@ def _update(self, filter, update, *args, one: bool = False, **kwargs): def aggregate( self, *args, vector_index: t.Optional[str] = None, **kwargs ) -> Aggregate: + """Perform an aggregation on the collection. + + :param vector_index: The vector index to use + """ return Aggregate( args=args, kwargs=kwargs, @@ -853,12 +1003,19 @@ def aggregate( ) def delete_one(self, *args, **kwargs): + """Delete a single document in the database.""" return self._delete(*args, one=True, **kwargs) def delete_many(self, *args, **kwargs): + """Delete multiple documents in the database.""" return self._delete(*args, one=False, **kwargs) def replace_one(self, filter, replacement, *args, **kwargs): + """Replace a single document in the database. + + :param filter: The filter to apply + :param replacement: The replacement to apply + """ return MongoReplaceOne( filter=filter, replacement=replacement, @@ -868,21 +1025,47 @@ def replace_one(self, filter, replacement, *args, **kwargs): ) def update_one(self, filter, update, *args, **kwargs): + """Update a single document in the database. 
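Editorial aside: the ``ChangeStream`` wrapper documented in this hunk is obtained from a ``Collection`` and only opens the underlying ``watch()`` cursor once it is called with a Datalayer. A small sketch; the collection name is hypothetical.

from superduperdb.backends.mongodb.query import Collection

stream = Collection('documents').change_stream()  # builds the ChangeStream wrapper
cursor = stream(db)                               # calls collection.watch(...) on the backend
for change in cursor:                             # blocks while waiting for new changes
    print(change['operationType'])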
+ + :param filter: The filter to apply + :param update: The update to apply + """ return self._update(filter, update, *args, one=True, **kwargs) def update_many(self, filter, update, *args, **kwargs): + """Update multiple documents in the database. + + :param filter: The filter to apply + :param update: The update to apply + """ return self._update(filter, update, *args, one=False, **kwargs) def bulk_write(self, operations, *args, **kwargs): + """Bulk write multiple operations into the database. + + :param operations: The operations to perform + """ return self._bulk_write(operations, *args, **kwargs) def insert(self, *args, **kwargs): + """Insert multiple documents into the database. + + :param args: The documents to insert + """ return self.insert_many(*args, **kwargs) def insert_many(self, *args, **kwargs): + """Insert multiple documents into the database. + + :param args: The documents to insert + """ return self._insert(*args, **kwargs) def insert_one(self, document, *args, **kwargs): + """Insert a single document into the database. + + :param document: The document to insert + """ return self._insert([document], *args, **kwargs) def model_update( @@ -894,6 +1077,15 @@ def model_update( flatten: bool = False, **kwargs, ): + """Update the outputs of a model in the database. + + :param db: The datalaer instance + :param ids: The ids of the documents to update + :param predict_id: The predict_id of outputs to store + :param outputs: The outputs to store + :param flatten: Whether to flatten the outputs + :param kwargs: Additional arguments + """ document_embedded = kwargs.get('document_embedded', True) if not len(outputs): @@ -972,7 +1164,11 @@ def decode(output): def get_schema(db, schema: t.Union[Schema, str]) -> Schema: - """Handle schema caching and loading.""" + """Handle schema caching and loading. + + :param db: the Datalayer instance + :param schema: the schema to be loaded + """ if isinstance(schema, Schema): # If the schema is not in the db, it is added to the db. if schema.identifier not in db.show(Schema.type_id): diff --git a/superduperdb/backends/mongodb/utils.py b/superduperdb/backends/mongodb/utils.py index 200b9512b9..6a94aef2f0 100644 --- a/superduperdb/backends/mongodb/utils.py +++ b/superduperdb/backends/mongodb/utils.py @@ -6,8 +6,8 @@ def get_avaliable_conn(uri: str, **kwargs): - """ - Get an available connection to the database. + """Get an available connection to the database. + This can avoid some issues with database permission verification. 1. Try to connect to the database with the given URI. 2. Try to connect to the database with the base URI without database name. diff --git a/superduperdb/backends/query_dataset.py b/superduperdb/backends/query_dataset.py index 7a486b3c7f..1082f82dcf 100644 --- a/superduperdb/backends/query_dataset.py +++ b/superduperdb/backends/query_dataset.py @@ -11,6 +11,11 @@ class ExpiryCache(list): + """Expiry Cache for storing documents. + + The document will be removed from the cache after fetching it from the cache. + """ + def __getitem__(self, index): item = super().__getitem__(index) del self[index] @@ -18,19 +23,16 @@ def __getitem__(self, index): class QueryDataset: - """ - A dataset class which can be used to define a torch dataset class. + """Query Dataset for fetching documents from database. :param select: A select query object which defines the query to be executed. - :param keys: A list of keys to be returned from the dataset. + :param mapping: A mapping object to be used for the dataset. 
+ :param ids: A list of ids to be used for the dataset. :param fold: The fold to be used for the dataset. - :param suppress: A list of keys to be suppressed from the dataset. :param transform: A callable which can be used to transform the dataset. - :param db: A ``DB`` object to be used for the dataset. - :param ids: A list of ids to be used for the dataset. + :param db: A datalayer instance to be used for the dataset. :param in_memory: A boolean flag to indicate if the dataset should be loaded in memory. - :param extract: A key to be extracted from the dataset. """ def __init__( @@ -73,6 +75,7 @@ def __init__( @property def db(self): + """Return the datalayer instance.""" if self._db is None: from superduperdb.base.build import build_datalayer @@ -114,12 +117,22 @@ def __getitem__(self, item): class CachedQueryDataset(QueryDataset): - """ + """Cached Query Dataset for fetching documents from database. + This class which fetch the document corresponding to the given ``index``. This class prefetches documents from database and stores in the memory. This can drastically reduce database read operations and hence reduce the overall load on the database. + + :param select: A select query object which defines the query to be executed. + :param mapping: A mapping object to be used for the dataset. + :param ids: A list of ids to be used for the dataset. + :param fold: The fold to be used for the dataset. + :param transform: A callable which can be used to transform the dataset. + :param db: A datalayer instance to be used for the dataset. + :param in_memory: A boolean flag to indicate if the dataset should be loaded + :param prefetch_size: The number of documents to prefetch from the database. """ _BACKFILL_INDEX = 0.2 @@ -154,6 +167,7 @@ def _fetch_cache(self): @property def database(self): + """Return the database object.""" if self._database is None: from superduperdb.base.build import build_datalayer @@ -178,6 +192,13 @@ def __getitem__(self, index): def query_dataset_factory(**kwargs): + """Create a query dataset object. + + If ``data_prefetch`` is set to ``True``, then a ``CachedQueryDataset`` object is + created, otherwise a ``QueryDataset`` object is created. + + :param kwargs: Keyword arguments to be passed to the query dataset object. + """ if kwargs.get('data_prefetch', False): return CachedQueryDataset(**kwargs) kwargs = { diff --git a/superduperdb/backends/ray/compute.py b/superduperdb/backends/ray/compute.py index de39450660..1d6ad338f2 100644 --- a/superduperdb/backends/ray/compute.py +++ b/superduperdb/backends/ray/compute.py @@ -7,8 +7,7 @@ class RayComputeBackend(ComputeBackend): - """ - A client for interacting with a ray cluster. Initialize the ray client. + """A client for interacting with a ray cluster. Initialize the ray client. :param address: The address of the ray cluster. :param local: Set to True to create a local Dask cluster. (optional) @@ -30,10 +29,12 @@ def __init__( @property def type(self) -> str: + """The type of the compute backend.""" return "distributed" @property def name(self) -> str: + """The name of the compute backend.""" return f"ray://{self.address}" def submit( @@ -43,6 +44,7 @@ def submit( Submits a function to the ray server for execution. :param function: The function to be executed. + :param compute_kwargs: Additional keyword arguments to be passed to ray API. 
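Editorial aside: a rough sketch of ``query_dataset_factory`` as documented above; with ``data_prefetch=True`` it is expected to return the caching variant. The select query, fold and sizes are illustrative, and ``db`` is assumed to be an existing Datalayer.

from superduperdb.backends.mongodb.query import Collection
from superduperdb.backends.query_dataset import query_dataset_factory

dataset = query_dataset_factory(
    select=Collection('documents').find(),
    fold='train',
    db=db,
    data_prefetch=True,     # returns a CachedQueryDataset that prefetches documents
    prefetch_size=100,
)
sample = dataset[0]         # fetches (and, for the cached variant, evicts) one document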
""" def _dependable_remote_job(function, *args, **kwargs): @@ -71,30 +73,26 @@ def _dependable_remote_job(function, *args, **kwargs): @property def tasks(self) -> t.Dict[str, ray.ObjectRef]: - """ - List all pending tasks - """ + """List all pending tasks.""" return self._futures_collection def wait(self, identifier: str) -> None: - """ - Waits for task corresponding to identifier to complete. + """Waits for task corresponding to identifier to complete. + :param identifier: Future task id to wait """ ray.wait([self._futures_collection[identifier]]) def wait_all(self) -> None: - """ - Waits for all tasks to complete. - """ + """Waits for all tasks to complete.""" ray.wait( list(self._futures_collection.values()), num_returns=len(self._futures_collection), ) def result(self, identifier: str) -> t.Any: - """ - Retrieves the result of a previously submitted task. + """Retrieves the result of a previously submitted task. + Note: This will block until the future is completed. :param identifier: The identifier of the submitted task. @@ -103,13 +101,9 @@ def result(self, identifier: str) -> t.Any: return ray.get(future) def disconnect(self) -> None: - """ - Disconnect the ray client. - """ + """Disconnect the ray client.""" ray.shutdown() def shutdown(self) -> None: - """ - Shuts down the ray cluster. - """ + """Shuts down the ray cluster.""" ray.shutdown() diff --git a/superduperdb/backends/ray/serve.py b/superduperdb/backends/ray/serve.py index 630f871e6a..434d982150 100644 --- a/superduperdb/backends/ray/serve.py +++ b/superduperdb/backends/ray/serve.py @@ -17,16 +17,19 @@ def run( ray_actor_options: t.Dict = {}, route_prefix: str = '/', ): - ''' - Serve a superduperdb model on ray cluster - ''' + """Serve a superduperdb model on ray cluster. + + :param model: model identifier + :param version: model version + :param num_replicas: number of replicas + :param ray_actor_options: ray actor options + :param route_prefix: route prefix + """ @serve.deployment(ray_actor_options=ray_actor_options, num_replicas=num_replicas) @serve.ingress(app.app) class SuperDuperRayServe: - ''' - A ray deployment which serves a superduperdb model with default ingress - ''' + """A ray deployment which serves a superduperdb model with default ingress.""" def __init__(self, model_identifier: str, version: t.Optional[int]): from superduperdb.base.build import build_datalayer diff --git a/superduperdb/backends/sqlalchemy/db_helper.py b/superduperdb/backends/sqlalchemy/db_helper.py index 197675fe78..37f826d0f8 100644 --- a/superduperdb/backends/sqlalchemy/db_helper.py +++ b/superduperdb/backends/sqlalchemy/db_helper.py @@ -15,30 +15,47 @@ class JsonMixin: """Mixin for JSON type columns. + Converts dict to JSON strings before saving to database and converts JSON strings to dict when loading from database. """ def process_bind_param(self, value, dialect): + """Convert dict to JSON string. + + :param value: The dict to convert. + :param dialect: The dialect of the database. + """ if value is not None: value = json.dumps(value) return value def process_result_value(self, value, dialect): + """Convert JSON string to dict. + + :param value: The JSON string to convert. + :param dialect: The dialect of the database. 
+ """ if value is not None: value = json.loads(value) return value class JsonAsString(JsonMixin, TypeDecorator): + """JSON type column for short JSON strings.""" + impl = String(DEFAULT_LENGTH) class JsonAsText(JsonMixin, TypeDecorator): + """JSON type column for long JSON strings.""" + impl = Text class DefaultConfig: + """Default configuration for database types.""" + type_string = String(DEFAULT_LENGTH) type_json_as_string = JsonAsString type_json_as_text = JsonAsText @@ -54,6 +71,7 @@ class DefaultConfig: def create_clickhouse_config(): + """Create configuration for ClickHouse database.""" # lazy import try: from clickhouse_sqlalchemy import engines, types @@ -89,6 +107,10 @@ class JsonAsText(JsonMixin, TypeDecorator): def get_db_config(dialect): + """Get the configuration class for the specified dialect. + + :param dialect: The dialect of the database. + """ if dialect == 'clickhouse': return create_clickhouse_config() else: diff --git a/superduperdb/backends/sqlalchemy/metadata.py b/superduperdb/backends/sqlalchemy/metadata.py index 8dea606205..e081cc4d09 100644 --- a/superduperdb/backends/sqlalchemy/metadata.py +++ b/superduperdb/backends/sqlalchemy/metadata.py @@ -119,11 +119,13 @@ def _init_tables(self): metadata.create_all(self.conn) def url(self): + """Return the URL of the metadata store.""" return self.conn.url + self.name def drop(self, force: bool = False): - """ - Drop the metadata store. + """Drop the metadata store. + + :param force: whether to force the drop (without confirmation) """ if not force: if not click.confirm( @@ -141,6 +143,7 @@ def drop(self, force: bool = False): @contextmanager def session_context(self): + """Provide a transactional scope around a series of operations.""" sm = sessionmaker(bind=self.conn) session = sm() try: @@ -157,6 +160,12 @@ def session_context(self): def component_version_has_parents( self, type_id: str, identifier: str, version: int ): + """Check if a component version has parents. + + :param type_id: the type of the component + :param identifier: the identifier of the component + :param version: the version of the component + """ unique_id = _Component.make_unique_id(type_id, identifier, version) with self.session_context() as session: stmt = ( @@ -170,6 +179,10 @@ def component_version_has_parents( return len(res) > 0 def create_component(self, info: t.Dict): + """Create a component in the metadata store. + + :param info: the information to create + """ if 'hidden' not in info: info['hidden'] = False info['id'] = f'{info["type_id"]}/{info["identifier"]}/{info["version"]}' @@ -178,6 +191,11 @@ def create_component(self, info: t.Dict): session.execute(stmt) def create_parent_child(self, parent_id: str, child_id: str): + """Create a parent-child relationship between two components. + + :param parent_id: the parent component + :param child_id: the child component + """ with self.session_context() as session: stmt = insert(self.parent_child_association_table).values( parent_id=parent_id, child_id=child_id @@ -185,6 +203,12 @@ def create_parent_child(self, parent_id: str, child_id: str): session.execute(stmt) def delete_component_version(self, type_id: str, identifier: str, version: int): + """Delete a component from the metadata store. 
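`JsonMixin` relies on SQLAlchemy's `TypeDecorator` hooks: `process_bind_param` serializes a dict to a JSON string on the way into the database and `process_result_value` parses it back on the way out. A self-contained illustration of the same pattern against an in-memory SQLite database (assumes SQLAlchemy 1.4+; the table and column names are made up for the example):

```python
# Self-contained TypeDecorator example mirroring the JsonMixin pattern:
# dicts are stored as JSON strings and transparently decoded on read.
import json

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base
from sqlalchemy.types import TypeDecorator

Base = declarative_base()


class JsonAsString(TypeDecorator):
    impl = String(1024)
    cache_ok = True

    def process_bind_param(self, value, dialect):
        return None if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        return None if value is None else json.loads(value)


class Record(Base):
    __tablename__ = 'records'
    id = Column(Integer, primary_key=True)
    payload = Column(JsonAsString)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(Record(id=1, payload={'hidden': False, 'version': 0}))
    session.commit()
    print(session.get(Record, 1).payload['version'])   # 0
```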
+ + :param type_id: the type of the component + :param identifier: the identifier of the component + :param version: the version of the component + """ with self.session_context() as session: stmt = ( self.component_table.select() @@ -210,6 +234,13 @@ def _get_component( version: int, allow_hidden: bool = False, ): + """Get a component from the metadata store. + + :param type_id: the type of the component + :param identifier: the identifier of the component + :param version: the version of the component + :param allow_hidden: whether to allow hidden components + """ with self.session_context() as session: stmt = select(self.component_table).where( self.component_table.c.type_id == type_id, @@ -223,6 +254,10 @@ def _get_component( return res[0] if res else None def get_component_version_parents(self, unique_id: str): + """Get the parents of a component version. + + :param unique_id: the unique identifier of the component version + """ with self.session_context() as session: stmt = select(self.parent_child_association_table).where( self.parent_child_association_table.c.child_id == unique_id, @@ -234,6 +269,12 @@ def get_component_version_parents(self, unique_id: str): def get_latest_version( self, type_id: str, identifier: str, allow_hidden: bool = False ): + """Get the latest version of a component. + + :param type_id: the type of the component + :param identifier: the identifier of the component + :param allow_hidden: whether to allow hidden components + """ with self.session_context() as session: stmt = ( select(self.component_table) @@ -255,6 +296,12 @@ def get_latest_version( return versions[0] def hide_component_version(self, type_id: str, identifier: str, version: int): + """Hide a component in the metadata store. + + :param type_id: the type of the component + :param identifier: the identifier of the component + :param version: the version of the component + """ with self.session_context() as session: stmt = ( self.component_table.update() @@ -287,6 +334,13 @@ def replace_component( type_id: str, version: t.Optional[int] = None, ) -> None: + """Replace a component in the metadata store. + + :param info: the information to replace + :param identifier: the identifier of the component + :param type_id: the type of the component + :param version: the version of the component + """ if version is not None: version = self.get_latest_version(type_id, identifier) return self._replace_object( @@ -297,6 +351,10 @@ def replace_component( ) def show_components(self, type_id: t.Optional[str] = None): + """Show all components in the database. + + :param type_id: the type of the component + """ with self.session_context() as session: stmt = select(self.component_table) if type_id is not None: @@ -313,6 +371,11 @@ def show_components(self, type_id: t.Optional[str] = None): ] def show_component_versions(self, type_id: str, identifier: str): + """Show all versions of a component in the database. + + :param type_id: the type of the component + :param identifier: the identifier of the component + """ with self.session_context() as session: stmt = select(self.component_table).where( self.component_table.c.type_id == type_id, @@ -346,11 +409,19 @@ def _update_object( # --------------- JOBS ----------------- def create_job(self, info: t.Dict): + """Create a job with the given info. 
+ + :param info: The information used to create the job + """ with self.session_context() as session: stmt = insert(self.job_table).values(**info) session.execute(stmt) def get_job(self, job_id: str): + """Get the job with the given job_id. + + :param job_id: The identifier of the job + """ with self.session_context() as session: stmt = ( select(self.job_table) @@ -361,6 +432,10 @@ def get_job(self, job_id: str): return res[0] if res else None def listen_job(self, identifier: str): + """Listen a job. + + :param identifier: the identifier of the job + """ # Not supported currently raise NotImplementedError @@ -369,6 +444,11 @@ def show_jobs( component_identifier: t.Optional[str] = None, type_id: t.Optional[str] = None, ): + """Show all jobs in the database. + + :param component_identifier: the identifier of the component + :param type_id: the type of the component + """ with self.session_context() as session: # Start building the select statement stmt = select(self.job_table) @@ -389,6 +469,12 @@ def show_jobs( return [r['identifier'] for r in res] def update_job(self, job_id: str, key: str, value: t.Any): + """Update the job with the given key and value. + + :param job_id: The identifier of the job + :param key: The key to update + :param value: The value to update + """ with self.session_context() as session: stmt = ( self.job_table.update() @@ -398,17 +484,32 @@ def update_job(self, job_id: str, key: str, value: t.Any): session.execute(stmt) def write_output_to_job(self, identifier, msg, stream): + """Write output to the job. + + :param identifier: the identifier of the job + :param msg: the message to write + :param stream: the stream to write to + """ # Not supported currently raise NotImplementedError # --------------- METADATA ----------------- def create_metadata(self, key, value): + """Create metadata with the given key and value. + + :param key: The key to create + :param value: The value to create + """ with self.session_context() as session: stmt = insert(self.meta_table).values(key=key, value=value) session.execute(stmt) def get_metadata(self, key): + """Get the metadata with the given key. + + :param key: The key to retrieve + """ with self.session_context() as session: stmt = select(self.meta_table).where(self.meta_table.c.key == key).limit(1) res = self.query_results(self.meta_table, stmt, session) @@ -416,6 +517,11 @@ def get_metadata(self, key): return value def update_metadata(self, key, value): + """Update the metadata with the given key. + + :param key: The key to update + :param value: The updated value + """ with self.session_context() as session: stmt = ( self.meta_table.update() @@ -426,6 +532,11 @@ def update_metadata(self, key, value): # --------------- Query ID ----------------- def add_query(self, query: 'Select', model: str): + """Add a query to the query table. + + :param query: The query to add to the table. + :param model: The model to associate with the query. + """ query_hash = str(hash(query)) with self.session_context() as session: @@ -440,8 +551,9 @@ def add_query(self, query: 'Select', model: str): session.execute(stmt) def get_query(self, query_hash: str): - """ - Get the query from the query table corresponding to the query hash + """Get the query from the query table corresponding to the query hash. + + :param query_hash: The hash of the query to retrieve. 
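Every metadata method above goes through `session_context`, the classic "transactional scope" pattern: open a session, commit on success, roll back on error, always close. A generic standalone version of that context manager (the real one is bound to the metadata store's own connection):

```python
# Generic "transactional scope" context manager, the pattern behind
# session_context(): commit on success, roll back on error, always close.
from contextlib import contextmanager

from sqlalchemy import create_engine, text
from sqlalchemy.orm import sessionmaker

engine = create_engine('sqlite://')


@contextmanager
def session_context():
    session = sessionmaker(bind=engine)()
    try:
        yield session
        session.commit()
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()


with session_context() as session:
    # any statements issued inside the block are committed atomically
    print(session.execute(text('SELECT 1')).scalar())   # 1
```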
""" try: with self.session_context() as session: @@ -463,8 +575,9 @@ def get_query(self, query_hash: str): raise NonExistentMetadataError(f'Query hash {query_hash} does not exist') def get_model_queries(self, model: str): - """ - Get queries related to the given model. + """Get queries related to the given model. + + :param model: The name of the model to retrieve queries for. """ with self.session_context() as session: stmt = select(self.query_id_table).where( @@ -483,13 +596,17 @@ def get_model_queries(self, model: str): return unpacked_queries def disconnect(self): - """ - Disconnect the client - """ + """Disconnect the client.""" # TODO: implement me def query_results(self, table, statment, session): + """Query the database and return the results as a list of row datas. + + :param table: The table object to query, used to derive column names. + :param statment: The SQL statement to execute. + :param session: The database session within which the query is executed. + """ # Some databases don't support defining statment outside of session result = session.execute(statment) columns = [col.name for col in table.columns] diff --git a/superduperdb/base/build.py b/superduperdb/base/build.py index 2be45c4be9..43de3821f1 100644 --- a/superduperdb/base/build.py +++ b/superduperdb/base/build.py @@ -169,7 +169,6 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer: If None, use ``superduperdb.CFG.data_backend``. :pararm kwargs: keyword arguments to be adopted by the `CFG` """ - # Configuration # ------------------------------ # Use the provided configuration or fall back to the default configuration. @@ -197,11 +196,11 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer: def show_configuration(cfg): - """ - Show the configuration. + """Show the configuration. + Only show the important configuration values and anonymize the URLs. - : param cfg: The configuration object. + :param cfg: The configuration object. """ table = PrettyTable() table.field_names = ["Configuration", "Value"] diff --git a/superduperdb/base/code.py b/superduperdb/base/code.py index 6161c35f54..64c2e2771d 100644 --- a/superduperdb/base/code.py +++ b/superduperdb/base/code.py @@ -14,11 +14,22 @@ @dc.dataclass class Code(Serializable): + """A class to store remote code. + + This class stores remote code that can be executed on a remote server. + + :param code: The code to store. + """ + code: str default: t.ClassVar[str] = default @staticmethod def from_object(obj): + """Create a Code object from a callable object. + + :param obj: The object to create the Code object from. + """ code = inspect.getsource(obj) mini_module = template.format( @@ -39,4 +50,8 @@ def __post_init__(self): self.object = remote_code def unpack(self, db=None): + """Unpack the code object. + + :param db: Do not use this parameter, should be None. + """ return self.object diff --git a/superduperdb/base/config.py b/superduperdb/base/config.py index 23d1b821a6..cf9b0047d7 100644 --- a/superduperdb/base/config.py +++ b/superduperdb/base/config.py @@ -1,4 +1,5 @@ -""" +"""Configuration variables for SuperDuperDB. + The classes in this file define the configuration variables for SuperDuperDB, which means that this file gets imported before alost anything else, and canot contain any other imports from this project. @@ -28,7 +29,14 @@ def _dataclass_from_dict(data_class: t.Any, data: dict): @dc.dataclass class BaseConfig: + """A base class for configuration dataclasses. 
+ + This class allows for easy updating of configuration dataclasses + with a dictionary of parameters. + """ + def __call__(self, **kwargs): + """Update the configuration with the given parameters.""" parameters = self.dict() for k, v in kwargs.items(): if '__' in k: @@ -41,13 +49,13 @@ def __call__(self, **kwargs): return _dataclass_from_dict(type(self), parameters) def dict(self): + """Return the configuration as a dictionary.""" return dc.asdict(self) @dc.dataclass class Retry(BaseConfig): - """ - Describes how to retry using the `tenacity` library + """Describes how to retry using the `tenacity` library. :param stop_after_attempt: The number of attempts to make :param wait_max: The maximum time to wait between attempts @@ -63,13 +71,23 @@ class Retry(BaseConfig): @dc.dataclass class CDCStrategy: - '''Base CDC strategy dataclass''' + """Base CDC strategy dataclass. + + :param type: The type of CDC strategy + """ type: str @dc.dataclass class PollingStrategy(CDCStrategy): + """Describes a polling strategy for change data capture. + + :param auto_increment_field: The field to use for auto-incrementing + :param frequency: The frequency to poll for changes + :param type: The type of CDC strategy + """ + auto_increment_field: t.Optional[str] = None frequency: float = 3600 type: 'str' = 'incremental' @@ -77,18 +95,37 @@ class PollingStrategy(CDCStrategy): @dc.dataclass class LogBasedStrategy(CDCStrategy): + """Describes a log-based strategy for change data capture. + + :param resume_token: The resume token to use for log-based CDC + :param type: The type of CDC strategy + """ + resume_token: t.Optional[t.Dict[str, str]] = None type: str = 'logbased' @dc.dataclass class CDCConfig(BaseConfig): + """Describes the configuration for change data capture. + + :param uri: The URI for the CDC service + :param strategy: The strategy to use for CDC + """ + uri: t.Optional[str] = None # None implies local mode strategy: t.Optional[t.Union[PollingStrategy, LogBasedStrategy]] = None @dc.dataclass class VectorSearch(BaseConfig): + """Describes the configuration for vector search. + + :param uri: The URI for the vector search service + :param type: The type of vector search service + :param backfill_batch_size: The size of the backfill batch + """ + uri: t.Optional[str] = None # None implies local mode type: str = 'in_memory' # in_memory|lance backfill_batch_size: int = 100 @@ -96,19 +133,29 @@ class VectorSearch(BaseConfig): @dc.dataclass class Rest(BaseConfig): + """Describes the configuration for the REST service. + + :param uri: The URI for the REST service + """ + uri: t.Optional[str] = None @dc.dataclass class Compute(BaseConfig): + """Describes the configuration for distributed computing. + + :param uri: The URI for the compute service + :param compute_kwargs: The keyword arguments to pass to the compute service + """ + uri: t.Optional[str] = None # None implies local mode compute_kwargs: t.Dict = dc.field(default_factory=dict) @dc.dataclass class Cluster(BaseConfig): - """ - Describes a connection to distributed work via Dask + """Describes a connection to distributed work via Ray. :param compute: The URI for compute - None: run all jobs in local mode i.e. 
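`BaseConfig.__call__` lets callers override nested settings with double-underscore keys (for example `cluster__cdc__uri=...`) and then rebuilds the dataclass from the updated dictionary via `_dataclass_from_dict`. The exact rebuild logic is not reproduced here; the following is a simplified, standalone sketch of the key-splitting idea using a cut-down `Config`/`Cluster`/`CDCConfig` trio:

```python
# Simplified sketch of double-underscore overrides on nested config dataclasses.
# The real BaseConfig.__call__ delegates the rebuild to _dataclass_from_dict;
# here the nested dataclasses are reconstructed by hand.
import dataclasses as dc
import typing as t


@dc.dataclass
class CDCConfig:
    uri: t.Optional[str] = None


@dc.dataclass
class Cluster:
    cdc: CDCConfig = dc.field(default_factory=CDCConfig)


@dc.dataclass
class Config:
    data_backend: str = 'mongodb://localhost:27017/test_db'
    cluster: Cluster = dc.field(default_factory=Cluster)

    def __call__(self, **kwargs) -> 'Config':
        parameters = dc.asdict(self)
        for key, value in kwargs.items():
            target = parameters
            *path, last = key.split('__')
            for part in path:                  # walk down the nested sections
                target = target[part]
            target[last] = value
        return Config(
            data_backend=parameters['data_backend'],
            cluster=Cluster(cdc=CDCConfig(**parameters['cluster']['cdc'])),
        )


cfg = Config()(cluster__cdc__uri='http://localhost:8001')
print(cfg.cluster.cdc.uri)   # http://localhost:8001
```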
simple function call @@ -116,6 +163,7 @@ class Cluster(BaseConfig): :param vector_search: The URI for the vector search service None: Run vector search on local "http://:": Connect a remote vector search service + :param rest: The URI for the REST service :param cdc: The URI for the change data capture service (if "None" then no cdc assumed) None: Run cdc on local as a thread. @@ -129,9 +177,7 @@ class Cluster(BaseConfig): class LogLevel(str, Enum): - """ - Enumerate log severity level - """ + """Enumerate log severity level.""" DEBUG = 'DEBUG' INFO = 'INFO' @@ -141,9 +187,7 @@ class LogLevel(str, Enum): class LogType(str, Enum): - """ - Enumerate the standard logs - """ + """Enumerate the standard logs.""" # SYSTEM uses the systems STDOUT and STDERR for printing the logs. # DEBUG, INFO, and WARN go to STDOUT. @@ -155,12 +199,22 @@ class LogType(str, Enum): class BytesEncoding(str, Enum): + """Enumerate the encoding of bytes in the data backend.""" + BYTES = 'Bytes' BASE64 = 'Str' @dc.dataclass class Downloads(BaseConfig): + """Describes the configuration for downloading files. + + :param folder: The folder to download files to + :param n_workers: The number of workers to use for downloading + :param headers: The headers to use for downloading + :param timeout: The timeout for downloading + """ + folder: t.Optional[str] = None n_workers: int = 0 headers: t.Dict = dc.field(default_factory=lambda: {'User-Agent': 'me'}) @@ -169,11 +223,12 @@ class Downloads(BaseConfig): @dc.dataclass class Config(BaseConfig): - """ - The data class containing all configurable superduperdb values + """The data class containing all configurable superduperdb values. + :param envs: The envs datas :param data_backend: The URI for the data backend - :param vector_search: The configuration for the vector search {'in_memory', 'lance'} + :param lance_home: The home directory for the Lance vector indices, + Default: .superduperdb/vector_indices :param artifact_store: The URI for the artifact store :param metadata_store: The URI for the metadata store :param cluster: Settings distributed computing and change data capture @@ -212,20 +267,20 @@ def __post_init__(self, envs): @property def hybrid_storage(self): + """Whether to use hybrid storage.""" return self.downloads.folder is not None @property def comparables(self): - """ - A dict of `self` excluding some defined attributes. - """ + """A dict of `self` excluding some defined attributes.""" _dict = dc.asdict(self) list(map(_dict.pop, ('cluster', 'retries', 'downloads'))) return _dict def match(self, cfg: t.Dict): - """ - Match the target cfg dict with `self` comparables dict. + """Match the target cfg dict with `self` comparables dict. + + :param cfg: The target configuration dictionary. """ self_cfg = self.comparables self_hash = hash(json.dumps(self_cfg, sort_keys=True)) @@ -233,9 +288,14 @@ def match(self, cfg: t.Dict): return self_hash == cfg_hash def diff(self, cfg: t.Dict): + """Return the difference between `self` and the target cfg dict. + + :param cfg: The target configuration dictionary. + """ return _diff(self.dict(), cfg) def to_yaml(self): + """Return the configuration as a YAML string.""" import yaml def enum_representer(dumper, data): @@ -249,7 +309,8 @@ def enum_representer(dumper, data): def _diff(r1, r2): - """ + """Return the difference between two dictionaries. 
+ >>> _diff({'a': 1, 'b': 2}, {'a': 2, 'b': 2}) {'a': (1, 2)} >>> _diff({'a': {'c': 3}, 'b': 2}, {'a': 2, 'b': 2}) diff --git a/superduperdb/base/config_dicts.py b/superduperdb/base/config_dicts.py index 3b42aca324..141ffc7dba 100644 --- a/superduperdb/base/config_dicts.py +++ b/superduperdb/base/config_dicts.py @@ -1,4 +1,5 @@ -""" +"""Utility functions for combining and converting dictionaries. + Operations on dictionaries used to fill and combine config files and environment variables """ @@ -15,6 +16,10 @@ def combine_configs(dicts: t.Sequence[Dict]) -> Dict: + """Combine a sequence of dictionaries into a single dictionary. + + :param dicts: The dictionaries to combine. + """ result: Dict = {} for d in dicts: _combine_one(result, d) @@ -28,6 +33,15 @@ def environ_to_config_dict( err: t.Optional[t.TextIO] = sys.stderr, fail: bool = False, ): + """Convert environment variables to a configuration dictionary. + + :param prefix: The prefix to use for environment variables. + :param parent: The parent dictionary to use as a basis. + :param environ: The environment variables to read from. + :param err: The file to write errors to. + :param fail: Whether to raise an exception on error. + :return: The configuration dictionary. + """ env_dict = _environ_dict(prefix, environ) good, bad = _env_dict_to_config_dict(env_dict, parent) bad = {k: v for k, v in bad.items() if k != 'SUPERDUPERDB_CONFIG'} diff --git a/superduperdb/base/configs.py b/superduperdb/base/configs.py index 37e7a5acbd..c14ff35be5 100644 --- a/superduperdb/base/configs.py +++ b/superduperdb/base/configs.py @@ -24,13 +24,11 @@ class ConfigError(Exception): @dataclass(frozen=True) class ConfigSettings: - """ - A class that reads a dataclass class from a configuration file and - environment variables. + """Helper class to read a configuration from a dataclass. + + Reads a dataclass class from a configuration file and environment variables. :param cls: The Pydantic class to read. - :param default_files: The default config files to read. - :param prefix: The prefix to use for environment variables. :param environ: The environment variables to read from. """ @@ -39,8 +37,7 @@ class ConfigSettings: @cached_property def config(self) -> t.Any: - """Read a configuration using defaults as basis""" - + """Read a configuration using defaults as basis.""" parent = self.cls().dict() env = dict(os.environ if self.environ is None else self.environ) env = config_dicts.environ_to_config_dict(PREFIX, parent, env) diff --git a/superduperdb/base/cursor.py b/superduperdb/base/cursor.py index f003ce27c7..c225da736a 100644 --- a/superduperdb/base/cursor.py +++ b/superduperdb/base/cursor.py @@ -11,14 +11,18 @@ @dc.dataclass class SuperDuperCursor: - """ + """A wrapper around a raw cursor that adds some extra functionality. + A cursor that wraps a cursor and returns ``Document`` wrapping a dict including ``Encodable`` objects. :param raw_cursor: the cursor to wrap :param id_field: the field to use as the document id - :param encoders: a dict of encoders to use to decode the documents + :param db: the datalayer to use to decode the documents :param scores: a dict of scores to add to the documents + :param decode_function: a function to use to decode the documents + :param _it: an iterator to keep track of the current position in the cursor, + Default is 0. """ raw_cursor: t.Any @@ -30,9 +34,7 @@ class SuperDuperCursor: _it: int = 0 def limit(self, *args, **kwargs) -> 'SuperDuperCursor': - """ - Limit the number of results returned by the cursor. 
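The `_diff` doctests above pin down the helper's contract: for each key whose values differ, return an `(old, new)` pair, recursing only while both sides are dictionaries. A standalone re-implementation consistent with the first example (the nested-dict case is inferred; the project's helper may nest its output differently):

```python
# Standalone dict-diff sketch consistent with the _diff doctest shown above:
# differing values become (old, new) pairs; nested dicts are compared key-wise.
import typing as t


def dict_diff(r1: t.Dict, r2: t.Dict) -> t.Dict:
    out: t.Dict = {}
    for key, old in r1.items():
        new = r2.get(key)
        if isinstance(old, dict) and isinstance(new, dict):
            nested = dict_diff(old, new)
            if nested:
                out[key] = nested
        elif old != new:
            out[key] = (old, new)
    return out


assert dict_diff({'a': 1, 'b': 2}, {'a': 2, 'b': 2}) == {'a': (1, 2)}
# Under this sketch the second doctest input yields {'a': ({'c': 3}, 2)}:
print(dict_diff({'a': {'c': 3}, 'b': 2}, {'a': 2, 'b': 2}))
```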
- """ + """Limit the number of results returned by the cursor.""" return SuperDuperCursor( raw_cursor=self.raw_cursor.limit(*args, **kwargs), id_field=self.id_field, @@ -42,6 +44,7 @@ def limit(self, *args, **kwargs) -> 'SuperDuperCursor': ) def cursor_next(self): + """Get the next document from the cursor.""" if isinstance(self.raw_cursor, list): if self._it >= len(self.raw_cursor): raise StopIteration diff --git a/superduperdb/base/datalayer.py b/superduperdb/base/datalayer.py index 2d392f5278..58757cae1b 100644 --- a/superduperdb/base/datalayer.py +++ b/superduperdb/base/datalayer.py @@ -117,9 +117,7 @@ def __init__( @property def server_mode(self): - """ - Property for server mode. - """ + """Property for server mode.""" return self._server_mode @server_mode.setter @@ -207,7 +205,9 @@ def backfill_vector_search(self, vi, searcher): def set_compute(self, new: ComputeBackend): """ - Set a new compute engine at runtime. Use it only if you know what you are doing. + Set a new compute engine at runtime. + + Use it only if you know what you are doing. The standard procedure is to set the compute engine during initialization. :param new: New compute backend. @@ -225,9 +225,7 @@ def set_compute(self, new: ComputeBackend): self.compute = new def get_compute(self): - """ - Get compute. - """ + """Get compute.""" return self.compute def drop(self, force: bool = False): @@ -256,6 +254,7 @@ def show( ): """ Show available functionality which has been added using ``self.add``. + If the version is specified, then print full metadata. :param type_id: Type_id of component to show ['datatype', 'model', 'listener', @@ -315,7 +314,6 @@ def execute(self, query: ExecuteQuery, *args, **kwargs) -> ExecuteResult: :param query: The SQL query to execute, such as select, insert, delete, or update. """ - if isinstance(query, Delete): return self._delete(query, *args, **kwargs) if isinstance(query, Insert): @@ -349,7 +347,6 @@ def _delete(self, delete: Delete, refresh: bool = True) -> DeleteResult: :param delete: The delete query object specifying the data to be deleted. """ - result = delete.execute(self) if refresh and not self.cdc.running: return result, self.refresh_after_delete(delete, ids=result) @@ -365,7 +362,6 @@ def _insert( :param refresh: Boolean indicating whether to refresh the task group on insert. :param datatypes: List of datatypes in the insert documents. """ - for e in datatypes: self.add(e) @@ -400,7 +396,6 @@ def _select(self, select: Select, reference: bool = True) -> SelectResult: :param select: The select query object specifying the data to be retrieved. """ - if select.variables: select = select.set_variables(self) # type: ignore[assignment] return select.execute(self, reference=reference) @@ -419,7 +414,6 @@ def refresh_after_delete( :param ids: IDs that further reduce the scope of computations. :param verbose: Set to ``True`` to enable more detailed output. """ - task_workflow: TaskWorkflow = self._build_delete_task_workflow( query, ids=ids, @@ -444,7 +438,6 @@ def refresh_after_update_or_insert( :param verbose: Set to ``True`` to enable more detailed output. :param overwrite: If True, cascade the value to the 'predict_in_db' job. """ - task_workflow: TaskWorkflow = self._build_task_workflow( query.select_table, # TODO can be replaced by select_using_ids ids=ids, @@ -461,7 +454,6 @@ def _write(self, write: Write, refresh: bool = True) -> UpdateResult: :param write: The update query object specifying the data to be written. 
:param refresh: Boolean indicating whether to refresh the task group on write. """ - write_result, updated_ids, deleted_ids = write.execute(self) cdc_status = self.cdc.running or s.CFG.cluster.cdc.uri is not None @@ -533,8 +525,10 @@ def apply( dependencies: t.Sequence[Job] = (), ): """ - Add functionality in the form of components. Components are stored in the - configured artifact store and linked to the primary database through metadata. + Add functionality in the form of components. + + Components are stored in the configured artifact store + and linked to the primary database through metadata. :param object: Object to be stored. :param dependencies: List of jobs which should execute before component @@ -572,7 +566,6 @@ def remove( :param version: [Optional] Numerical version to remove. :param force: Force skip confirmation (use with caution). """ - # TODO: versions = [version] if version is not None else ... if version is not None: return self._remove_component_version( @@ -636,7 +629,6 @@ def load( of deprecated components. :param info_only: Toggle to ``True`` to return metadata only. """ - if type_id == 'encoder': logging.warn( '"encoder" has moved to "datatype" this functionality will not work' @@ -724,10 +716,7 @@ def _build_task_workflow( verbose: bool = True, overwrite: bool = False, ) -> TaskWorkflow: - """ - A helper function to build a task workflow for a query with dependencies. - """ - + """A helper function to build a task workflow for a query with dependencies.""" logging.debug(f"Building task workflow graph. Query:{query}") job_ids: t.Dict[str, t.Any] = defaultdict(lambda: []) @@ -1012,14 +1001,14 @@ def replace( upsert: bool = False, ): """ - (Use with caution!) Replace a model in the artifact store with - an updated object. + Replace a model in the artifact store with an updated object. + + (Use with caution!) :param object: The object to replace. :param upsert: Toggle to ``True`` to enable replacement even if the object doesn't exist yet. """ - try: info = self.metadata.get_component( object.type_id, object.identifier, version=object.version @@ -1059,7 +1048,6 @@ def select_nearest( :param outputs: (Optional) Seed outputs dictionary. :param n: Get top k results from vector search. """ - # TODO - make this un-ambiguous if not isinstance(like, Document): assert isinstance(like, dict) @@ -1076,9 +1064,7 @@ def select_nearest( return vi.get_nearest(like, db=self, ids=ids, n=n, outputs=outs) def close(self): - """ - Gracefully shutdown the Datalayer. - """ + """Gracefully shutdown the Datalayer.""" logging.info("Disconnect from Data Store") self.databackend.disconnect() @@ -1097,6 +1083,7 @@ def close(self): def _add_component_to_cache(self, component: Component): """ Add component to cache when it is added to the db. + Avoiding the need to load it from the db again. """ type_id = component.type_id @@ -1107,8 +1094,7 @@ def _add_component_to_cache(self, component: Component): def infer_schema( self, data: t.Mapping[str, t.Any], identifier: t.Optional[str] = None ) -> Schema: - """ - Infer a schema from a given data object + """Infer a schema from a given data object. :param data: The data object :param identifier: The identifier for the schema, if None, it will be generated @@ -1120,8 +1106,7 @@ def infer_schema( @dc.dataclass class LoadDict(dict): """ - Helper class to load component identifiers with on-demand - loading from the database. + Helper class to load component identifiers with on-demand loading from the database. :param database: Instance of Datalayer. 
:param field: (optional) Component type identifier. diff --git a/superduperdb/base/decorators.py b/superduperdb/base/decorators.py index 56c5cbc544..dbf35e694d 100644 --- a/superduperdb/base/decorators.py +++ b/superduperdb/base/decorators.py @@ -1,3 +1,7 @@ def code(my_callable): + """Decorator to mark a function as remote code. + + :param my_callable: The callable to mark as remote code. + """ my_callable.is_remote_code = True return my_callable diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py index 3b8f3b8556..b1742b89d1 100644 --- a/superduperdb/base/document.py +++ b/superduperdb/base/document.py @@ -35,11 +35,10 @@ class Document(MongoStyleDict): - """ - A wrapper around an instance of dict or a Encodable which may be used to dump - that resource to a mix of json-able content, ids and `bytes` + """A wrapper around an instance of dict or a Encodable. + + The document data is used to dump that resource to a mix of json-able content, ids and `bytes` - :param content: The content to wrap """ _DEFAULT_ID_KEY: str = '_id' @@ -49,22 +48,35 @@ def encode( schema: t.Optional['Schema'] = None, leaf_types_to_keep: t.Sequence[t.Type] = (), ) -> t.Dict: - """Make a copy of the content with all the Leaves encoded""" + """Make a copy of the content with all the Leaves encoded. + + :param schema: The schema to encode with. + :param leaf_types_to_keep: The types of leaves to keep. + """ if schema is not None: return _encode_with_schema(dict(self), schema) return _encode(dict(self), leaf_types_to_keep) def get_leaves(self, *leaf_types: str): + """Get all the leaves in the document. + + :param *leaf_types: The types of leaves to get. + """ keys, leaves = _find_leaves(self, *leaf_types) return dict(zip(keys, leaves)) @property def variables(self) -> t.List[str]: + """Return a list of variables in the object.""" from superduperdb.base.serializable import _find_variables return _find_variables(self) - def set_variables(self, db, **kwargs) -> 'Document': + def set_variables(self, db: 'Datalayer', **kwargs) -> 'Document': + """Set free variables of self. + + :param db: The datalayer to use. + """ from superduperdb.base.serializable import _replace_variables content = _replace_variables( @@ -73,10 +85,12 @@ def set_variables(self, db, **kwargs) -> 'Document': return Document(**content) @staticmethod - def decode( - r: t.Dict, - db: t.Optional['Datalayer'] = None, - ) -> t.Any: + def decode(r: t.Dict, db: t.Optional['Datalayer'] = None) -> t.Any: + """Decode the object from a encoded data. + + :param r: Encoded data. + :param db: Datalayer instance. + """ cache = {} if '_leaves' in r: r['_leaves'] = _build_leaves(r['_leaves'], db=db) @@ -90,7 +104,11 @@ def __repr__(self) -> str: return f'Document({repr(dict(self))})' def unpack(self, db=None, leaves_to_keep: t.Sequence = ()) -> t.Any: - """Returns the content, but with any encodables replaced by their contents""" + """Returns the content, but with any encodables replaced by their contents. + + :param db: The datalayer to use. + :param leaves_to_keep: The types of leaves to keep. + """ out = _unpack(self, db=db, leaves_to_keep=leaves_to_keep) if '_base' in out: out = out['_base'] @@ -219,6 +237,11 @@ def _unpack(item: t.Any, db=None, leaves_to_keep: t.Sequence = ()) -> t.Any: class NotBuiltError(Exception): + """Exception for when a leaf is not built. + + :param key: The key that was not built. 
+ """ + def __init__(self, *args, key, **kwargs): super().__init__(*args, **kwargs) self.key = key diff --git a/superduperdb/base/enums.py b/superduperdb/base/enums.py index 8b874244fb..6eb7ccd728 100644 --- a/superduperdb/base/enums.py +++ b/superduperdb/base/enums.py @@ -2,9 +2,7 @@ class DBType(str, Enum): - """ - DBType is an enumeration of the supported database types. - """ + """DBType is an enumeration of the supported database types.""" SQL = "SQL" MONGODB = "MONGODB" diff --git a/superduperdb/base/exceptions.py b/superduperdb/base/exceptions.py index 08747243b9..b6488a26b5 100644 --- a/superduperdb/base/exceptions.py +++ b/superduperdb/base/exceptions.py @@ -2,18 +2,22 @@ class ComponentInUseError(Exception): + """Exception raised when a component is already in use.""" + pass class ComponentInUseWarning(Warning): + """Warning raised when a component is already in use.""" + pass class BaseException(Exception): - ''' - BaseException which logs a message after - exception - ''' + """BaseException which logs a message after exception. + + :param msg: The message to log. + """ def __init__(self, msg): self.msg = msg @@ -24,42 +28,28 @@ def __str__(self): class RequiredPackageVersionsNotFound(ImportError): - ''' - Exception raised when one or more required packages are not found. - ''' + """Exception raised when one or more required packages are not found.""" class RequiredPackageVersionsWarning(ImportWarning): - ''' - Exception raised when one or more required packages are not found. - ''' + """Exception raised when one or more required packages are not found.""" class ServiceRequestException(BaseException): - ''' - ServiceRequestException - ''' + """ServiceRequestException.""" class QueryException(BaseException): - ''' - QueryException - ''' + """QueryException.""" class DatabackendException(BaseException): - ''' - DatabackendException - ''' + """DatabackendException.""" class MetadataException(BaseException): - ''' - MetadataException - ''' + """MetadataException.""" class ComponentException(BaseException): - ''' - ComponentException - ''' + """ComponentException.""" diff --git a/superduperdb/base/leaf.py b/superduperdb/base/leaf.py index 9705abf77d..4483304c2a 100644 --- a/superduperdb/base/leaf.py +++ b/superduperdb/base/leaf.py @@ -6,28 +6,37 @@ class Leaf(ABC): + """Base class for all leaf classes.""" + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) cls._register_class() @classmethod def handle_integration(cls, r): + """Method to handle integration. + + :param r: Encoded data. + """ return r @classmethod def _register_class(cls): - """ - Register class in the class registry and set the full import path - """ + """Register class in the class registry and set the full import path.""" full_import_path = f"{cls.__module__}.{cls.__name__}" cls.full_import_path = full_import_path _CLASS_REGISTRY[full_import_path] = cls @abstractproperty def unique_id(self): + """Unique identifier for the object.""" pass def unpack(self, db=None): + """Unpack object. + + :param db: Datalayer instance. + """ return self @abstractmethod @@ -35,17 +44,28 @@ def encode( self, leaf_types_to_keep: t.Sequence = (), ): - """Convert object to a saveable form""" + """Convert object to a saveable form. + + :param leaf_types_to_keep: Leaf types to keep. + """ pass @classmethod @abstractmethod def decode(cls, r, db=None): - """Decode object from a saveable form""" + """Decode object from a saveable form. + + :param r: Encoded data. + :param db: Datalayer instance. 
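`Leaf` records every subclass under its full import path via `__init_subclass__`, which is what allows `find_leaf_cls` to rebuild a class from a serialized path later on. A standalone sketch of that registry mechanism (the `Encodable` subclass is just an example):

```python
# Standalone sketch of the subclass registry behind Leaf: every subclass is
# recorded under "<module>.<ClassName>" at class-definition time.
import typing as t

_CLASS_REGISTRY: t.Dict[str, type] = {}


class Leaf:
    full_import_path: t.ClassVar[str]

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        cls.full_import_path = f'{cls.__module__}.{cls.__name__}'
        _CLASS_REGISTRY[cls.full_import_path] = cls


class Encodable(Leaf):
    pass


def find_leaf_cls(full_import_path: str) -> type:
    return _CLASS_REGISTRY[full_import_path]


print(find_leaf_cls(Encodable.full_import_path) is Encodable)   # True
```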
+ """ pass @classmethod def build(cls, r): + """Build object from an encoded data. + + :param r: Encoded data. + """ modified = { k: v for k, v in r.items() @@ -54,9 +74,16 @@ def build(cls, r): return cls(**modified) def init(self, db=None): + """Initialize object. + + :param db: Datalayer instance. + """ pass def find_leaf_cls(full_import_path) -> t.Type[Leaf]: - """Find leaf class by class full import path""" + """Find leaf class by class full import path. + + :param full_import_path: Full import path of the class. + """ return _CLASS_REGISTRY[full_import_path] diff --git a/superduperdb/base/logger.py b/superduperdb/base/logger.py index 734e402eef..8fc3ed4d52 100644 --- a/superduperdb/base/logger.py +++ b/superduperdb/base/logger.py @@ -14,6 +14,8 @@ class Logging: + """Logging class to handle logging for the SuperDuperDB.""" + if CFG.logging_type == LogType.LOKI: # Send logs to Loki custom_handler = LokiLoggerHandler( url=os.environ["LOKI_URI"], @@ -75,26 +77,53 @@ class Logging: # Example: logging.info("param 1", "param 2", ..) @staticmethod def multikey_debug(msg: str, *args): + """Log a message with the DEBUG level. + + :param msg: The message to log. + """ logger.opt(depth=1).debug(" ".join(map(str, (msg, *args)))) @staticmethod def multikey_info(msg: str, *args): + """Log a message with the INFO level. + + :param msg: The message to log. + """ logger.opt(depth=1).info(" ".join(map(str, (msg, *args)))) @staticmethod def multikey_success(msg: str, *args): + """Log a message with the SUCCESS level. + + :param msg: The message to log. + """ logger.opt(depth=1).success(" ".join(map(str, (msg, *args)))) @staticmethod def multikey_warn(msg: str, *args): + """Log a message with the WARNING level. + + :param msg: The message to log. + """ logger.opt(depth=1).warning(" ".join(map(str, (msg, *args)))) @staticmethod def multikey_error(msg: str, *args): + """Log a message with the ERROR level. + + :param msg: The message to log. + """ logger.opt(depth=1).error(" ".join(map(str, (msg, *args)))) @staticmethod def multikey_exception(msg: str, *args, e=None): + """Log a message with the ERROR level. + + e.g. logger.exception("An error occurred", e) + + :param msg: The message to log. + :param e: The exception to log. + """ logger.opt(depth=1, exception=e).error(" ".join(map(str, (msg, *args)))) debug = multikey_debug diff --git a/superduperdb/base/serializable.py b/superduperdb/base/serializable.py index f2d098b91f..6a9fb081e5 100644 --- a/superduperdb/base/serializable.py +++ b/superduperdb/base/serializable.py @@ -37,6 +37,8 @@ def _from_dict(r: t.Any, db: None = None) -> t.Any: class VariableError(Exception): + """Variable error.""" + ... @@ -91,8 +93,9 @@ def _replace_variables(x, db, **kwargs): @dc.dataclass class Serializable(Leaf): - """ - Base class for serializable objects. This class is used to serialize and + """Base class for serializable objects. + + This class is used to serialize and deserialize objects to and from JSON + Artifact instances. """ @@ -107,13 +110,19 @@ def _deep_flat_encode(self, cache): @property def unique_id(self): + """Return a unique id for the object.""" return hash(str(self.dict().encode())) @property def variables(self) -> t.List['Variable']: + """Return a list of variables in the object.""" return sorted(list(set(self.dict().variables)), key=lambda x: x.value) def getattr_with_path(self, path): + """Get attribute with path. + + :param path: Path to the attribute. 
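The `multikey_*` wrappers exist so callers can pass several positional values and have them joined into a single log line, while `opt(depth=1)` attributes the record to the caller rather than the wrapper itself. A runnable sketch of the same idea (assumes `loguru` is installed; the Loki handler setup is omitted):

```python
# Sketch of the multikey logging wrappers: join all positional arguments into
# one message; opt(depth=1) reports the caller's location, not the wrapper's.
# Assumes `pip install loguru`.
from loguru import logger


def info(msg: str, *args):
    logger.opt(depth=1).info(' '.join(map(str, (msg, *args))))


def error(msg: str, *args, e=None):
    logger.opt(depth=1, exception=e).error(' '.join(map(str, (msg, *args))))


info('Loaded component', 'my-model', 'version', 3)
try:
    1 / 0
except ZeroDivisionError as exc:
    error('Prediction failed for', 'my-model', e=exc)
```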
+ """ assert path item = self for x in path: @@ -125,6 +134,11 @@ def getattr_with_path(self, path): return item def setattr_with_path(self, path, value): + """Set attribute with path. + + :param path: Path to the attribute. + :param value: Value to set. + """ if len(path) == 1: return setattr(self, path[0], value) else: @@ -133,10 +147,9 @@ def setattr_with_path(self, path, value): return def set_variables(self, db, **kwargs) -> 'Serializable': - """ - Set free variables of self. + """Set free variables of self. - :param db: + :param db: Datalayer instance. """ r = self.encode(leaf_types_to_keep=(Variable,)) r = _replace_variables(r, db, **kwargs) @@ -146,27 +159,38 @@ def encode( self, leaf_types_to_keep: t.Sequence = (), ): + """Encode the object to a dictionary. + + :param leaf_types_to_keep: Leaf types to keep. + """ r = dict(self.dict().encode(leaf_types_to_keep=leaf_types_to_keep)) r['leaf_type'] = 'serializable' return {'_content': r} @classmethod def decode(cls, r, db: t.Optional[t.Any] = None): + """Decode the object from a encoded data. + + :param r: Encoded data. + :param db: Datalayer instance. + """ return _from_dict(r, db=db) def dict(self): + """Return a dictionary representation of the object.""" from superduperdb import Document return Document(asdict(self)) def copy(self): + """Return a deep copy of the object.""" return deepcopy(self) @dc.dataclass class Variable(Serializable): - """ - Mechanism for allowing "free variables" in a serializable object. + """Mechanism for allowing "free variables" in a serializable object. + The idea is to allow a variable to be set at runtime, rather than at object creation time. @@ -193,7 +217,7 @@ def set(self, db, **kwargs): Get the intended value from the values of the global variables. :param db: The datalayer instance. - :param kwargs: Variables to be used in the setter_callback + :param **kwargs: Variables to be used in the setter_callback or as formatting variables. >>> Variable('number').set(db, number=1.5, other='test') diff --git a/superduperdb/base/superduper.py b/superduperdb/base/superduper.py index ed7d689e1c..6ffdaf080d 100644 --- a/superduperdb/base/superduper.py +++ b/superduperdb/base/superduper.py @@ -7,13 +7,13 @@ def superduper(item: t.Optional[t.Any] = None, **kwargs) -> t.Any: - """ + """Superduper API to automatically wrap an object to a db or a component. + Attempts to automatically wrap an item in a superduperdb component by using duck typing to recognize it. :param item: A database or model """ - if item is None: item = CFG.data_backend @@ -65,9 +65,10 @@ def run(item: t.Any, **kwargs) -> t.Any: raise ValueError(f'{item} matched more than one type: {dts}') + # TODO: Does this item match the DuckType? @classmethod def accept(cls, item: t.Any) -> bool: - """Does this item match the DuckType? + """Check if an item matches the DuckType. The default implementation returns True if the number of attrs that the item has is exactly equal to self.count. @@ -77,7 +78,11 @@ def accept(cls, item: t.Any) -> bool: @classmethod def create(cls, item: t.Any, **kwargs) -> t.Any: - """Create a superduperdb component for an item that has already been accepted""" + """Create a component from the item. + + This method should be implemented by subclasses. + :param item: The item to create the component from. + """ raise NotImplementedError _DUCK_TYPES: t.List[t.Type] = [] @@ -88,15 +93,29 @@ def __init_subclass__(cls, **kwargs): class MongoDbTyper(_DuckTyper): + """A DuckTyper for MongoDB databases. 
+ + This DuckTyper is used to automatically wrap a MongoDB database in a + Datalayer. + """ + attrs = ('list_collection_names',) count = len(attrs) @classmethod def accept(cls, item: t.Any) -> bool: + """Check if an item is a MongoDB database. + + :param item: The item to check. + """ return super().accept(item) and item.__class__.__name__ == 'Database' @classmethod def create(cls, item: t.Any, **kwargs) -> t.Any: + """Create a Datalayer from a MongoDB database. + + :param item: A MongoDB database. + """ from mongomock.database import Database as MockDatabase from pymongo.database import Database @@ -120,11 +139,21 @@ def create(cls, item: t.Any, **kwargs) -> t.Any: class SklearnTyper(_DuckTyper): + """A DuckTyper for scikit-learn estimators. + + This DuckTyper is used to automatically wrap a scikit-learn estimator in + an Estimator. + """ + attrs = '_predict', 'fit', 'score', 'transform' count = 2 @classmethod def create(cls, item: t.Any, **kwargs) -> t.Any: + """Create an Estimator from a scikit-learn estimator. + + :param item: A scikit-learn estimator. + """ from sklearn.base import BaseEstimator from superduperdb.ext.sklearn.model import Estimator @@ -137,11 +166,21 @@ def create(cls, item: t.Any, **kwargs) -> t.Any: class TorchTyper(_DuckTyper): + """A DuckTyper for torch.nn.Module and torch.jit.ScriptModule. + + This DuckTyper is used to automatically wrap a torch.nn.Module or + torch.jit.ScriptModule in a TorchModel. + """ + attrs = 'forward', 'parameters', 'state_dict', '_load_from_state_dict' count = len(attrs) @classmethod def create(cls, item: t.Any, **kwargs) -> t.Any: + """Create a TorchModel from a torch.nn.Module or torch.jit.ScriptModule. + + :param item: A torch.nn.Module or torch.jit.ScriptModule. + """ from torch import jit, nn from superduperdb.ext.torch.model import TorchModel diff --git a/superduperdb/cdc/app.py b/superduperdb/cdc/app.py index 3658552dd4..11c7ab5086 100644 --- a/superduperdb/cdc/app.py +++ b/superduperdb/cdc/app.py @@ -12,13 +12,19 @@ @app.startup def cdc_startup(db: Datalayer): + """Start the cdc server. + + :param db: Datalayer instance. + """ db.cdc.start() @app.add('/listener/add', method='get') def add_listener(name: str, db: Datalayer = superduperapp.DatalayerDependency()): - """ - Endpoint for adding a listener to cdc + """Endpoint for adding a listener to cdc. + + :param name: Listener identifier. + :param db: Datalayer instance. """ listener = db.load('listener', name) assert isinstance(listener, Listener) @@ -27,8 +33,10 @@ def add_listener(name: str, db: Datalayer = superduperapp.DatalayerDependency()) @app.add('/listener/delete', method='get') def remove_listener(name: str, db: Datalayer = superduperapp.DatalayerDependency()): - """ - Endpoint for removing a listener from cdc + """Endpoint for removing a listener from cdc. + + :param name: Listener identifier. + :param db: Datalayer instance. """ listener = db.load('listener', name) assert isinstance(listener, Listener) diff --git a/superduperdb/cdc/cdc.py b/superduperdb/cdc/cdc.py index 7cae6c5d5c..a561784a8e 100644 --- a/superduperdb/cdc/cdc.py +++ b/superduperdb/cdc/cdc.py @@ -1,4 +1,5 @@ -""" +"""CDC module for superduperdb. + Change Data Capture (CDC) is a mechanism used in database systems to track and capture changes made to a table or collection in real-time. It allows applications to stay up-to-date with the latest changes in the database @@ -58,18 +59,27 @@ class DBEvent(str, Enum): @dc.dataclass class Packet: + """Packet to hold the cdc event data. + + :param ids: Document ids. 
+ :param query: Query to fetch the document. + :param event_type: CDC event type. + """ + ids: t.Any query: t.Optional['Serializable'] event_type: DBEvent = DBEvent.insert @property def is_delete(self) -> bool: + """Check if the event is delete.""" return self.event_type == DBEvent.delete @staticmethod def collate(packets: t.Sequence['Packet']) -> 'Packet': - """ - Collate a batch of packets into one + """Collate a batch of packets into one. + + :param packets: A list of packets. """ assert packets ids = [] @@ -87,8 +97,16 @@ def collate(packets: t.Sequence['Packet']) -> 'Packet': class BaseDatabaseListener(ABC): - """ - A Base class which defines basic functions to implement. + """A Base class which defines basic functions to implement. + + This class is responsible for defining the basic functions + that needs to be implemented by the database listener. + + :param db: A superduperdb instance. + :param on: A table or collection on which the listener is invoked. + :param stop_event: A threading event flag to notify for stoppage. + :param identifier: A identity given to the listener service. + :param timeout: A timeout for the listener. """ IDENTITY_SEP: str = '/' @@ -115,6 +133,7 @@ def __init__( @property def identity(self) -> str: + """Get the database listener identity.""" return self._identifier @classmethod @@ -127,9 +146,7 @@ def _build_identifier(cls, identifiers) -> str: return cls.IDENTITY_SEP.join(identifiers) def info(self) -> t.Dict: - """ - Get info on the current state of listener. - """ + """Get info on the current state of listener.""" info = {} info.update( { @@ -143,30 +160,40 @@ def info(self) -> t.Dict: @abstractmethod def listen(self): + """Start the database listener.""" raise NotImplementedError @abstractmethod def stop(self): + """Stop the database listener.""" raise NotImplementedError @abstractmethod def setup_cdc(self) -> CollectionChangeStream: + """setup_cdc.""" raise NotImplementedError @abstractmethod def on_create(self, *args, **kwargs): + """Handle the create event.""" raise NotImplementedError @abstractmethod def on_update(self, *args, **kwargs): + """Handle the update event.""" raise NotImplementedError @abstractmethod def on_delete(self, *args, **kwargs): + """Handle the delete event.""" raise NotImplementedError @abstractmethod def next_cdc(self, stream: CollectionChangeStream) -> None: + """next_cdc. + + :param stream: CollectionChangeStream + """ raise NotImplementedError def create_event( @@ -176,10 +203,11 @@ def create_event( table_or_collection: t.Union['Table', 'TableOrCollection'], event: DBEvent, ): - """ + """Create an event. + A helper to create packet based on the event type and put it on the cdc queue - :param change: The changed document. + :param ids: Document ids :param db: a superduperdb instance. :param table_or_collection: The collection on which change was observed. :param event: CDC event type @@ -192,7 +220,8 @@ def create_event( db.cdc.CDC_QUEUE.put_nowait(self.packet(ids, cdc_query, event)) def event_handler(self, ids: t.Sequence, event: DBEvent) -> None: - """event_handler. + """Handle the incoming change stream event. + A helper fxn to handle incoming changes from change stream on a collection. :param ids: Changed document ids @@ -212,6 +241,16 @@ def event_handler(self, ids: t.Sequence, event: DBEvent) -> None: class DatabaseListenerThreadScheduler(threading.Thread): + """DatabaseListenerThreadScheduler to listen to the cdc changes. 
+ + This class is responsible for listening to the cdc changes and + executing the following job. + + :param listener: A BaseDatabaseListener instance. + :param stop_event: A threading event flag to notify for stoppage. + :param start_event: A threading event flag to notify for start. + """ + def __init__( self, listener: BaseDatabaseListener, @@ -224,6 +263,7 @@ def __init__( self.listener = listener def run(self) -> None: + """Start to listen to the cdc changes.""" try: cdc_stream = self.listener.setup_cdc() self.start_event.set() @@ -235,10 +275,15 @@ def run(self) -> None: class CDCHandler(threading.Thread): - """ + """CDCHandler for handling CDC changes. + This class is responsible for handling the change by executing the taskflow. This class also extends the task graph by adding funcation job node which does post model executiong jobs, i.e `copy_vectors`. + + :param db: A superduperdb instance. + :param stop_event: A threading event flag to notify for stoppage. + :param queue: A queue to hold the cdc packets. """ def __init__(self, db: 'Datalayer', stop_event: Event, queue): @@ -247,7 +292,6 @@ def __init__(self, db: 'Datalayer', stop_event: Event, queue): :param db: a superduperdb instance. :param stop_event: A threading event flag to notify for stoppage. """ - self.db = db self._stop_event = stop_event self._is_running = False @@ -256,9 +300,11 @@ def __init__(self, db: 'Datalayer', stop_event: Event, queue): @property def is_running(self): + """Check if the cdc handler is running.""" return self._is_running def run(self): + """Run the cdc handler.""" self._is_running = True try: for c in queue_chunker(self.cdc_queue, self._stop_event): @@ -282,8 +328,12 @@ def run(self): class DatabaseListenerFactory(t.Generic[DBListenerType]): - """A Factory class to create instance of DatabaseListener corresponding to the - `db_type`. + """DatabaseListenerFactory to create listeners for different databases. + + This class is responsible for creating a DatabaseListener instance + based on the database type. + + :param db_type: Database type. """ SUPPORTED_LISTENERS: t.List[str] = ['mongodb', 'ibis'] @@ -294,6 +344,7 @@ def __init__(self, db_type: str = 'mongodb'): self.db_type = db_type def create(self, *args, **kwargs) -> DBListenerType: + """Create a DatabaseListener instance.""" stop_event = Event() kwargs['stop_event'] = stop_event if self.db_type == 'mongodb': @@ -315,7 +366,8 @@ def create(self, *args, **kwargs) -> DBListenerType: class DatabaseChangeDataCapture: - """ + """DatabaseChangeDataCapture (CDC). + DatabaseChangeDataCapture is a Python class that provides a flexible and extensible framework for capturing and managing data changes in a database. @@ -324,6 +376,8 @@ class DatabaseChangeDataCapture: This class is designed to simplify the process of tracking changes to database records,allowing you to monitor and respond to data modifications efficiently. + + :param db: A superduperdb datalayer instance. """ def __init__(self, db: 'Datalayer'): @@ -348,12 +402,11 @@ def __init__(self, db: 'Datalayer'): @property def running(self) -> bool: + """Check if the cdc service is running.""" return self._running or CFG.cluster.cdc.uri is not None def start(self): - """ - This method starts the cdc process on the database. - """ + """Start the cdc service.""" self._running = True # Listen to existing collection without cdc enabled @@ -367,11 +420,10 @@ def listen( *args, **kwargs, ): - """ - Starts cdc service on the provided collection + """Starts cdc service on the provided collection. 
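`CDCHandler` is essentially a worker thread that drains the CDC queue in chunks until its stop event fires, then triggers the downstream task flow for each batch of changed ids. A simplified, standalone sketch of that consumer loop (stdlib only; the real handler builds task workflows and vector-copy jobs rather than printing):

```python
# Simplified sketch of the CDC handler loop: a worker thread drains a queue of
# (event_type, ids) packets until the stop event is set. The real CDCHandler
# kicks off task workflows and vector-copy jobs instead of printing.
import queue
import threading
import time


class Handler(threading.Thread):
    def __init__(self, q: queue.Queue, stop_event: threading.Event):
        super().__init__(daemon=True)
        self.q, self.stop_event = q, stop_event

    def run(self):
        while not self.stop_event.is_set():
            try:
                event_type, ids = self.q.get(timeout=0.1)
            except queue.Empty:
                continue
            print(f'{event_type}: refresh computations for ids={ids}')


cdc_queue: queue.Queue = queue.Queue()
stop = threading.Event()
handler = Handler(cdc_queue, stop)
handler.start()

cdc_queue.put(('insert', ['id-1', 'id-2']))   # placeholder document ids
time.sleep(0.3)
stop.set()
handler.join()
```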
+ Not to be confused with ``superduperdb.container.listener.Listener``. - :param db: A superduperdb instance. :param on: Which collection/table listener service this be invoked on? :param identifier: A identity given to the listener service. """ @@ -401,8 +453,8 @@ def listen( return listener def stop(self, name: str = ''): - """ - Stop all registered listeners + """Stop all registered listeners. + :param name: Listener name """ try: @@ -422,17 +474,16 @@ def stop(self, name: str = ''): self.stop_handler() def stop_handler(self): - """ - Stop the cdc handler thread - """ + """Stop the cdc handler thread.""" self._cdc_stop_event.set() if self.cdc_change_handler: self.cdc_change_handler.join() self.cdc_change_handler = None def add(self, listener: 'Listener'): - """ - This method registered the given collection for cdc + """Register a listener to the cdc service. + + :param listener: A listener instance. """ collection = listener.select.table_or_collection if self.running and collection.identifier not in self._CDC_LISTENERS: diff --git a/superduperdb/cli/config.py b/superduperdb/cli/config.py index 7ab70cf395..ada2a672e8 100644 --- a/superduperdb/cli/config.py +++ b/superduperdb/cli/config.py @@ -13,6 +13,10 @@ def config( False, '--schema', '-s', help='If set, print the JSON schema for the model' ), ): + """Print all the SuperDuperDB configs as JSON. + + :param schema: If set, print the JSON schema for the model. + """ d = CFG.to_yaml() if schema else CFG.dict() if schema: print(CFG.to_yaml()) diff --git a/superduperdb/cli/info.py b/superduperdb/cli/info.py index e2bef83cdf..a624799202 100644 --- a/superduperdb/cli/info.py +++ b/superduperdb/cli/info.py @@ -17,6 +17,7 @@ @command(help='Print information about the current machine and installation') def info(): + """Print information about the current machine and installation.""" print('```') print(json.dumps(_get_info(), default=str, indent=2)) print('```') @@ -24,6 +25,10 @@ def info(): @command(help='Print information about the current machine and installation') def requirements(ext: t.List[str]): + """Print information about the current machine and installation. + + :param ext: Extensions to check. + """ out = [] for e in ext: try: diff --git a/superduperdb/cli/serve.py b/superduperdb/cli/serve.py index 0b579c4cdb..93b7a87046 100644 --- a/superduperdb/cli/serve.py +++ b/superduperdb/cli/serve.py @@ -6,6 +6,11 @@ @command(help='Start local cluster: server, ray and change data capture') def local_cluster(action: str, notebook_token: t.Optional[str] = None): + """Start local cluster: server, ray and change data capture. + + :param action: Action to perform (up, down, attach). + :param notebook_token: Notebook token. + """ from superduperdb.server.cluster import attach_cluster, down_cluster, up_cluster action = action.lower() @@ -20,6 +25,7 @@ def local_cluster(action: str, notebook_token: t.Optional[str] = None): @command(help='Start vector search server') def vector_search(): + """Start vector search server.""" from superduperdb.vector_search.server.app import app app.start() @@ -27,6 +33,7 @@ def vector_search(): @command(help='Start standalone change data capture') def cdc(): + """Start standalone change data capture.""" from superduperdb.cdc.app import app app.start() @@ -39,6 +46,13 @@ def ray_serve( ray_actor_options: str = '', num_replicas: int = 1, ): + """Serve a model on ray. + + :param model: Model name. + :param version: Model version. + :param ray_actor_options: Ray actor options. + :param num_replicas: Number of replicas. 
+ """ from superduperdb.backends.ray.serve import run run( @@ -51,6 +65,7 @@ def ray_serve( @command(help='Start FastAPI REST server') def rest(): + """Start FastAPI REST server.""" from superduperdb.rest.app import app app.start() diff --git a/superduperdb/cli/stack.py b/superduperdb/cli/stack.py index 853c74c7a9..28c48c7fe2 100644 --- a/superduperdb/cli/stack.py +++ b/superduperdb/cli/stack.py @@ -3,4 +3,9 @@ @command(help='Apply the stack tarball to the database') def apply(yaml_path: str, identifier: str): + """Apply the stack tarball to the database. + + :param yaml_path: Path to the stack tarball. + :param identifier: Stack identifier. + """ raise NotImplementedError diff --git a/superduperdb/components/__init__.py b/superduperdb/components/__init__.py index 9df1a80162..dca51df700 100644 --- a/superduperdb/components/__init__.py +++ b/superduperdb/components/__init__.py @@ -1,5 +1,5 @@ -""" -The core package provides the core functionality of SuperDuperDB. +"""The core package provides the core functionality of SuperDuperDB. + This includes the main wrappers and classes for communicating with the database and for defining AI functionality. """ diff --git a/superduperdb/components/component.py b/superduperdb/components/component.py index 62a0594fe1..b0a87bc97c 100644 --- a/superduperdb/components/component.py +++ b/superduperdb/components/component.py @@ -1,6 +1,4 @@ -""" -The component module provides the base class for all components in SuperDuperDB. -""" +"""The component module provides the base class for all components in SuperDuperDB.""" from __future__ import annotations @@ -26,8 +24,7 @@ def import_(r=None, path=None, db=None): - """ - Helper function for importing component JSONs, YAMLs, etc. + """Helper function for importing component JSONs, YAMLs, etc. :param r: Object to be imported. :param path: Components directory. @@ -52,8 +49,7 @@ def import_(r=None, path=None, db=None): def getdeepattr(obj, attr): - """ - Get nested attribute with dot notation. + """Get nested attribute with dot notation. :param obj: Object. :param attr: Attribute. @@ -68,9 +64,10 @@ def getdeepattr(obj, attr): @dc.dataclass class Component(Serializable, Leaf): - """ - Class to represent SuperDuperDB serializable entities that can be saved - into a database. + """Base class for all components in SuperDuperDB. + + Class to represent SuperDuperDB serializable entities + that can be saved into a database. :param identifier: A unique identifier for the component. :param artifacts: List of artifacts which represent entities that are @@ -88,8 +85,7 @@ class Component(Serializable, Leaf): @classmethod def handle_integration(cls, kwargs): - """ - Abstract method for handling integration. + """Abstract method for handling integration. :param kwargs: Integration kwargs. """ @@ -97,16 +93,12 @@ def handle_integration(cls, kwargs): @property def id(self): - """ - Returns the component identifier. - """ + """Returns the component identifier.""" return f'_component/{self.type_id}/{self.identifier}' @property def id_tuple(self): - """ - Returns an object as `ComponentTuple`. - """ + """Returns an object as `ComponentTuple`.""" return ComponentTuple(self.type_id, self.identifier, self.version) def __post_init__(self, artifacts): @@ -118,9 +110,7 @@ def __post_init__(self, artifacts): @classmethod def get_ui_schema(cls): - """ - Helper method to get the UI schema. 
- """ + """Helper method to get the UI schema.""" out = {} ancestors = cls.mro()[::-1] for a in ancestors: @@ -129,12 +119,10 @@ def get_ui_schema(cls): return list(out.values()) def set_variables(self, db, **kwargs): - """ - Set free variables of self. + """Set free variables of self. :param db: Datalayer instance. """ - r = self.dict() variables = _find_variables_with_path(r['dict']) for r in variables: @@ -144,15 +132,11 @@ def set_variables(self, db, **kwargs): @property def dependencies(self): - """ - Get dependencies on the component. - """ + """Get dependencies on the component.""" return () def init(self): - """ - Method to help initiate component field dependencies. - """ + """Method to help initiate component field dependencies.""" def _init(item): if isinstance(item, Component): @@ -178,9 +162,7 @@ def _init(item): @property def artifact_schema(self): - """ - Returns `Schema` representation for the serializers in the component. - """ + """Returns `Schema` representation for the serializers in the component.""" from superduperdb import Schema from superduperdb.components.datatype import dill_serializer @@ -202,15 +184,12 @@ def artifact_schema(self): @property def db(self) -> Datalayer: - """ - Datalayer instance. - """ + """Datalayer instance.""" return self._db @db.setter def db(self, value: Datalayer): - """ - Datalayer instance property setter. + """Datalayer instance property setter. :param value: Datalayer instance to set. @@ -226,6 +205,7 @@ def pre_create(self, db: Datalayer) -> None: def post_create(self, db: Datalayer) -> None: """Called after the first time this component is created. + Generally used if ``self.version`` is important in this logic. :param db: the db that creates the component. @@ -249,9 +229,7 @@ def _deep_flat_encode(self, cache): return self.id def deep_flat_encode(self): - """ - Encode cache with deep flattened structure. - """ + """Encode cache with deep flattened structure.""" cache = {} id = self._deep_flat_encode(cache) return {'_leaves': list(cache.values()), '_base': id} @@ -266,8 +244,8 @@ def _to_dict_and_bytes(self): return r, bytes def export(self, format=None): - """ - Method to export the component in the provided format. + """Method to export the component in the provided format. + If format is None, the method exports the component in a dictionary. :param format: `json` and `yaml`. @@ -301,9 +279,7 @@ def export(self, format=None): raise NotImplementedError(f'Format {format} not supported') def dict(self) -> 'Document': - """ - A dictionary representation of the component. - """ + """A dictionary representation of the component.""" from superduperdb import Document from superduperdb.components.datatype import Artifact, File @@ -326,8 +302,7 @@ def encode( self, leaf_types_to_keep: t.Sequence = (), ): - """ - Method to encode the component into a dictionary. + """Method to encode the component into a dictionary. :param leaf_types_to_keep: Leaf types to be excluded from encoding. """ @@ -339,8 +314,7 @@ def encode( @classmethod def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False): - """ - Decodes a dictionary component into a `Component` instance. + """Decodes a dictionary component into a `Component` instance. :param r: Object to be decoded. :param db: Datalayer instance. @@ -353,9 +327,7 @@ def decode(cls, r, db: t.Optional[t.Any] = None, reference: bool = False): @property def unique_id(self) -> str: - """ - Method to get a unique identifier for the component. 
- """ + """Method to get a unique identifier for the component.""" if getattr(self, 'version', None) is None: raise Exception('Version not yet set for component uniqueness') return f'{self.type_id}/{self.identifier}/{self.version}' @@ -365,8 +337,7 @@ def create_validation_job( validation_set: t.Union[str, Dataset], metrics: t.Sequence[str], ) -> ComponentJob: - """ - Method to create a validation job with `validation_set` and `metrics`. + """Method to create a validation job with `validation_set` and `metrics`. :param validation_set: Kwargs for the `predict` method of `Model`. :param metrics: Kwargs for the `predict` method of `Model` to set @@ -388,8 +359,7 @@ def schedule_jobs( db: Datalayer, dependencies: t.Sequence[Job] = (), ) -> t.Sequence[t.Any]: - """ - Run the job for this listener. + """Run the job for this listener. :param db: The db to process. :param dependencies: A sequence of dependencies. @@ -398,8 +368,7 @@ def schedule_jobs( @classmethod def make_unique_id(cls, type_id: str, identifier: str, version: int) -> str: - """ - Class method to create a unique identifier. + """Class method to create a unique identifier. :param type_id: Component type id. :param identifier: Unique identifier. @@ -414,8 +383,7 @@ def __setattr__(self, k, v): def ensure_initialized(func): - """ - Decorator to ensure that the model is initialized before calling the function. + """Decorator to ensure that the model is initialized before calling the function. :param func: Decorator function. """ diff --git a/superduperdb/components/dataset.py b/superduperdb/components/dataset.py index f491a8a831..bf300d95a0 100644 --- a/superduperdb/components/dataset.py +++ b/superduperdb/components/dataset.py @@ -50,8 +50,7 @@ class Dataset(Component): def __post_init__(self, artifacts): """Post-initialization method. - Args: - artifacts: Optional additional artifacts for initialization. + :param artifacts: Optional additional artifacts for initialization. """ self._data = None return super().__post_init__(artifacts) @@ -59,25 +58,19 @@ def __post_init__(self, artifacts): @property @ensure_initialized def data(self): - """ - Property representing the dataset's data. - """ + """Property representing the dataset's data.""" return self._data def init(self): - """ - Initialization method. - """ + """Initialization method.""" super().init() self._data = [Document.decode(r, self.db) for r in pickle_decode(self.raw_data)] @override def pre_create(self, db: 'Datalayer') -> None: - """ - Pre-create hook for database operations. + """Pre-create hook for database operations. - Args: - db: The Datalayer instance. + :param db: The database to use for the operation. """ if self.raw_data is None: if self.select is None: @@ -90,7 +83,5 @@ def pre_create(self, db: 'Datalayer') -> None: @cached_property def random(self): - """ - Cached property representing the random number generator. - """ + """Cached property representing the random number generator.""" return numpy.random.default_rng(seed=self.random_seed) diff --git a/superduperdb/components/datatype.py b/superduperdb/components/datatype.py index 1f375435ac..987429a566 100644 --- a/superduperdb/components/datatype.py +++ b/superduperdb/components/datatype.py @@ -29,17 +29,14 @@ class IntermidiaType: - """ - Intermidia data type - """ + """Intermidia data type.""" BYTES = 'bytes' STRING = 'string' def json_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> str: - """ - Encode the dict to a JSON string + """Encode the dict to a JSON string. 
:param object: The object to encode :param info: Optional information @@ -48,8 +45,7 @@ def json_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> str: def json_decode(b: str, info: t.Optional[t.Dict] = None) -> t.Any: - """ - Decode the JSON string to an dict + """Decode the JSON string to an dict. :param b: The JSON string to decode :param info: Optional information @@ -58,8 +54,7 @@ def json_decode(b: str, info: t.Optional[t.Dict] = None) -> t.Any: def pickle_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: - """ - Encodes an object using pickle. + """Encodes an object using pickle. :param object: The object to encode. :param info: Optional information. @@ -68,8 +63,7 @@ def pickle_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: def pickle_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: - """ - Decodes bytes using pickle. + """Decodes bytes using pickle. :param b: The bytes to decode. :param info: Optional information. @@ -78,8 +72,7 @@ def pickle_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: def dill_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: - """ - Encodes an object using dill. + """Encodes an object using dill. :param object: The object to encode. :param info: Optional information. @@ -88,8 +81,7 @@ def dill_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: def dill_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: - """ - Decodes bytes using dill. + """Decodes bytes using dill. :param b: The bytes to decode. :param info: Optional information. @@ -98,8 +90,7 @@ def dill_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: def file_check(path: t.Any, info: t.Optional[t.Dict] = None) -> str: - """ - Checks if a file path exists. + """Checks if a file path exists. :param path: The file path to check. :param info: Optional information. @@ -111,8 +102,7 @@ def file_check(path: t.Any, info: t.Optional[t.Dict] = None) -> str: def torch_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: - """ - Saves an object in torch format. + """Saves an object in torch format. :param object: The object to encode. :param info: Optional information. @@ -135,8 +125,7 @@ def torch_encode(object: t.Any, info: t.Optional[t.Dict] = None) -> bytes: def torch_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: - """ - Decodes bytes to a torch model. + """Decodes bytes to a torch model. :param b: The bytes to decode. :param info: Optional information. @@ -147,8 +136,7 @@ def torch_decode(b: bytes, info: t.Optional[t.Dict] = None) -> t.Any: def bytes_to_base64(bytes): - """ - Converts bytes to base64. + """Converts bytes to base64. :param bytes: The bytes to convert. """ @@ -156,8 +144,7 @@ def bytes_to_base64(bytes): def base64_to_bytes(encoded): - """ - Decodes a base64 encoded string. + """Decodes a base64 encoded string. :param encoded: The base64 encoded string. """ @@ -165,14 +152,12 @@ def base64_to_bytes(encoded): class DataTypeFactory: - """ - Abstract class for creating a DataType - """ + """Abstract class for creating a DataType.""" @abstractstaticmethod def check(data: t.Any) -> bool: - """ - Check if the data can be encoded by the DataType + """Check if the data can be encoded by the DataType. 
+ If the data can be encoded, return True, otherwise False :param data: The data to check @@ -181,8 +166,7 @@ def check(data: t.Any) -> bool: @abstractstaticmethod def create(data: t.Any) -> "DataType": - """ - Create a DataType for the data + """Create a DataType for the data. :param data: The data to create the DataType for """ @@ -192,8 +176,7 @@ def create(data: t.Any) -> "DataType": @public_api(stability='stable') @dc.dataclass(kw_only=True) class DataType(Component): - """ - A data type component that defines how data is encoded and decoded. + """A data type component that defines how data is encoded and decoded. {component_parameters} @@ -252,8 +235,7 @@ class DataType(Component): registered_types: t.ClassVar[t.Dict[str, "DataType"]] = {} def __post_init__(self, artifacts): - """ - Post-initialization hook. + """Post-initialization hook. :param artifacts: The artifacts. """ @@ -264,9 +246,7 @@ def __post_init__(self, artifacts): self.register_datatype(self) def dict(self): - """ - Get the dictionary representation of the object. - """ + """Get the dictionary representation of the object.""" r = super().dict() if hasattr(self.bytes_encoding, 'value'): r['dict']['bytes_encoding'] = str(self.bytes_encoding.value) @@ -275,8 +255,7 @@ def dict(self): def __call__( self, x: t.Optional[t.Any] = None, uri: t.Optional[str] = None ) -> '_BaseEncodable': - """ - Create an instance of the encodable class. + """Create an instance of the encodable class. :param x: The optional content. :param uri: The optional URI. @@ -288,8 +267,7 @@ def __call__( @ensure_initialized def encode_data(self, item, info: t.Optional[t.Dict] = None): - """ - Encode the item into bytes. + """Encode the item into bytes. :param item: The item to encode. :param info: The optional information dictionary. @@ -301,8 +279,7 @@ def encode_data(self, item, info: t.Optional[t.Dict] = None): @ensure_initialized def decode_data(self, item, info: t.Optional[t.Dict] = None): - """ - Decode the item from bytes. + """Decode the item from bytes. :param item: The item to decode. :param info: The optional information dictionary. @@ -312,8 +289,8 @@ def decode_data(self, item, info: t.Optional[t.Dict] = None): return self.decoder(item, info=info) def bytes_encoding_after_encode(self, data): - """ - Encode the data to base64, + """Encode the data to base64. + if the bytes_encoding is BASE64 and the intermidia_type is BYTES :param data: Encoded data @@ -326,8 +303,8 @@ def bytes_encoding_after_encode(self, data): return data def bytes_encoding_before_decode(self, data): - """ - Encode the data to base64, + """Encode the data to base64. + if the bytes_encoding is BASE64 and the intermidia_type is BYTES :param data: Decoded data @@ -341,8 +318,7 @@ def bytes_encoding_before_decode(self, data): @classmethod def register_datatype(cls, instance): - """ - Register a datatype. + """Register a datatype. :param instance: The datatype instance to register. """ @@ -350,8 +326,7 @@ def register_datatype(cls, instance): def encode_torch_state_dict(module, info): - """ - Encode torch state dictionary. + """Encode torch state dictionary. :param module: Module. :param info: Information. @@ -365,8 +340,7 @@ def encode_torch_state_dict(module, info): class DecodeTorchStateDict: - """ - Torch state dictionary decoder. + """Torch state dictionary decoder. :param cls: Torch state cls """ @@ -375,6 +349,11 @@ def __init__(self, cls): self.cls = cls def __call__(self, b: bytes, info: t.Dict): + """Decode the torch state dictionary. + + :param b: Bytes. 
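
To make the encoder/decoder flow concrete, here is a small hypothetical sketch of a custom `DataType` built on the pickle helpers documented above; the identifier is illustrative, and in practice the datatype may need to be attached to a Datalayer before use.

    # Hypothetical sketch of a pickle-backed DataType (not part of the patch).
    from superduperdb.components.datatype import DataType, pickle_encode, pickle_decode

    dt = DataType(
        identifier='my-pickle',
        encoder=pickle_encode,
        decoder=pickle_decode,
    )

    blob = dt.encode_data({'a': 1})    # bytes produced by pickle_encode
    restored = dt.decode_data(blob)    # {'a': 1} again
    item = dt({'a': 1})                # wraps the value in an encodable instance
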
+ :param info: Information. + """ import torch buffer = io.BytesIO(b) @@ -385,8 +364,7 @@ def __call__(self, b: bytes, info: t.Dict): # TODO: Remove this because this function is only used in test cases. def build_torch_state_serializer(module, info): - """ - Datatype for serializing torch state dict. + """Datatype for serializing torch state dict. :param module: Module. :param info: Information. @@ -400,8 +378,7 @@ def build_torch_state_serializer(module, info): def _find_descendants(cls): - """ - Find descendants of the given class. + """Find descendants of the given class. :param cls: The class to find descendants for. """ @@ -413,8 +390,9 @@ def _find_descendants(cls): @dc.dataclass(kw_only=True) class _BaseEncodable(Leaf): - """ - Data variable wrapping encode-able item. Encoding is controlled by the referred + """Data variable wrapping encode-able item. + + Encoding is controlled by the referred to ``Encoder`` instance. :param encoder: Instance of ``Encoder`` controlling encoding. @@ -428,8 +406,7 @@ class _BaseEncodable(Leaf): sha1: t.Optional[str] = None def _deep_flat_encode(self, cache): - """ - Deep flat encode the encodable item. + """Deep flat encode the encodable item. :param cache: Cache to store encoded items. """ @@ -450,16 +427,12 @@ def _deep_flat_encode(self, cache): @property def id(self): - """ - Get the ID of the encodable item. - """ + """Get the ID of the encodable item.""" assert self.file_id is not None return f'_{self.leaf_type}/{self.file_id}' def __post_init__(self): - """ - Post-initialization hook. - """ + """Post-initialization hook.""" if self.uri is not None and self.file_id is None: self.file_id = _construct_file_id_from_uri(self.uri) @@ -468,23 +441,18 @@ def __post_init__(self): @property def unique_id(self): - """ - Get the unique ID of the encodable item. - """ + """Get the unique ID of the encodable item.""" if self.file_id is not None: return self.file_id return str(id(self.x)) @property def reference(self): - """ - Get the reference to the datatype. - """ + """Get the reference to the datatype.""" return self.datatype.reference def unpack(self, db): - """ - Unpack the content of the `Encodable`. + """Unpack the content of the `Encodable`. :param db: Datalayer instance. """ @@ -492,8 +460,8 @@ def unpack(self, db): @classmethod def get_encodable_cls(cls, name, default=None): - """ - Get the subclass of the _BaseEncodable with the given name. + """Get the subclass of the _BaseEncodable with the given name. + All the registered subclasses must be subclasses of the _BaseEncodable. :param name: Name of the subclass. @@ -517,8 +485,7 @@ def get_encodable_cls(cls, name, default=None): @classmethod @abstractmethod def _get_object(cls, db, r): - """ - Get object from the given representation. + """Get object from the given representation. :param db: Datalayer instance. :param r: Representation of the object. @@ -528,8 +495,7 @@ def _get_object(cls, db, r): @classmethod @abstractmethod def decode(cls, r, db=None) -> '_BaseEncodable': - """ - Decode the representation to an instance of _BaseEncodable. + """Decode the representation to an instance of _BaseEncodable. :param r: Representation to decode. :param db: Datalayer instance. @@ -537,8 +503,7 @@ def decode(cls, r, db=None) -> '_BaseEncodable': pass def get_hash(self, data): - """ - Get the hash of the given data. + """Get the hash of the given data. :param data: Data to hash. """ @@ -553,21 +518,16 @@ def get_hash(self, data): class Empty: - """ - Sentinel class. 
- """ + """Sentinel class.""" def __repr__(self): - """ - Get the string representation of the Empty object. - """ + """Get the string representation of the Empty object.""" return '' @dc.dataclass class Encodable(_BaseEncodable): - """ - Class for encoding non-Python datatypes to the database. + """Class for encoding non-Python datatypes to the database. :param x: The encodable object. """ @@ -590,7 +550,8 @@ def _get_object(cls, db, r): @override def encode(self, leaf_types_to_keep: t.Sequence = ()): - """ + """Encode itself to a specific format. + Encode `self.x` to dictionary format which could be serialized to a database. @@ -614,16 +575,14 @@ def encode(self, leaf_types_to_keep: t.Sequence = ()): @classmethod def build(cls, r): - """ - Build an `Encodable` instance with the given parameters `r`. + """Build an `Encodable` instance with the given parameters `r`. :param r: Parameters for building the `Encodable` instance. """ return cls(**r) def init(self, db): - """ - Initialization method. + """Initialization method. :param db: The Datalayer instance. """ @@ -631,8 +590,7 @@ def init(self, db): @classmethod def decode(cls, r, db=None) -> 'Encodable': - """ - Decode the dictionary `r` to an `Encodable` instance. + """Decode the dictionary `r` to an `Encodable` instance. :param r: The dictionary to decode. :param db: The Datalayer instance. @@ -647,8 +605,7 @@ def decode(cls, r, db=None) -> 'Encodable': @classmethod def get_datatype(cls, db, r): - """ - Get the datatype of the object + """Get the datatype of the object. :param db: `Datalayer` instance to assist with :param r: The object to get the datatype from @@ -670,8 +627,7 @@ def get_datatype(cls, db, r): @dc.dataclass class Native(_BaseEncodable): - """ - Class for representing native data supported by the underlying database. + """Class for representing native data supported by the underlying database. :param x: The encodable object. """ @@ -685,8 +641,7 @@ def _get_object(cls, db, r): @override def encode(self, leaf_types_to_keep: t.Sequence = ()): - """ - Encode the object. + """Encode itself to a specific format. :param leaf_types_to_keep: Leaf nodes to keep from encoding. """ @@ -694,8 +649,7 @@ def encode(self, leaf_types_to_keep: t.Sequence = ()): @classmethod def decode(cls, r, db=None): - """ - Decode the object `r` to a `Native` instance. + """Decode the object `r` to a `Native` instance. :param r: The object to decode. :param db: The Datalayer instance. @@ -712,8 +666,7 @@ def save(self, artifact_store): @dc.dataclass class Artifact(_BaseEncodable, _ArtifactSaveMixin): - """ - Class for representing data to be saved on disk or in the artifact-store. + """Class for representing data to be saved on disk or in the artifact-store. :param x: The artifact object. :param artifact: Whether the object is an artifact. @@ -729,9 +682,7 @@ def _encode(self): return bytes_, sha1 def init(self, db): - """ - Initialization method to seed `x` with the actual object from - the artifact store. + """Initialize the x attribute with the actual value from the artifact store. :param db: The Datalayer instance. """ @@ -745,7 +696,8 @@ def init(self, db): @override def encode(self, leaf_types_to_keep: t.Sequence = ()): - """ + """Encode itself to a specific format. + Encode `self.x` to dictionary format which is later saved in an artifact store. @@ -784,8 +736,7 @@ def _get_object(cls, db, file_id, datatype, uri): return obj def unpack(self, db): - """ - Unpack the content of the `Encodable`. + """Unpack the content of the `Encodable`. 
:param db: The `Datalayer` instance to assist with unpacking. """ @@ -793,8 +744,7 @@ def unpack(self, db): return self.x def save(self, artifact_store): - """ - Save the encoded data into an artifact store. + """Save the encoded data into an artifact store. :param artifact_store: The artifact store for storing the encoded object. """ @@ -804,8 +754,7 @@ def save(self, artifact_store): @classmethod def decode(cls, r, db=None) -> 'Artifact': - """ - Decode the dictionary `r` into an instance of `Artifact`. + """Decode the dictionary `r` into an instance of `Artifact`. :param r: The dictionary to decode. :param db: The Datalayer instance. @@ -825,9 +774,7 @@ def decode(cls, r, db=None) -> 'Artifact': @dc.dataclass class LazyArtifact(Artifact): - """ - Data to be saved on disk or in the artifact store - and loaded only when needed. + """A class for loading an artifact only when needed. :param artifact: If the object is an artifact """ @@ -841,8 +788,7 @@ def __post_init__(self): @override def encode(self, leaf_types_to_keep: t.Sequence = ()): - """ - Encode `x` in dictionary format for the artifact store. + """Encode `x` in dictionary format for the artifact store. :param leaf_types_to_keep: Leaf nodes to exclude from encoding. @@ -850,8 +796,7 @@ def encode(self, leaf_types_to_keep: t.Sequence = ()): return super().encode(leaf_types_to_keep) def unpack(self, db): - """ - Unpack the content of the `Encodable`. + """Unpack the content of the `Encodable`. :param db: `Datalayer` instance to assist with """ @@ -859,8 +804,7 @@ def unpack(self, db): return self.x def save(self, artifact_store): - """ - Save the encoded data into the artifact store. + """Save the encoded data into the artifact store. :param artifact_store: Artifact store for saving the encoded object. @@ -871,8 +815,7 @@ def save(self, artifact_store): @classmethod def decode(cls, r, db=None) -> 'LazyArtifact': - """ - Decode data into a `LazyArtifact` instance. + """Decode data into a `LazyArtifact` instance. :param r: Object to decode :param db: Datalayer instance @@ -887,9 +830,7 @@ def decode(cls, r, db=None) -> 'LazyArtifact': @dc.dataclass class File(_BaseEncodable, _ArtifactSaveMixin): - """ - Data to be saved on disk and passed - as a file reference. + """Data to be saved on disk and passed as a file reference. :param x: File object """ @@ -898,9 +839,7 @@ class File(_BaseEncodable, _ArtifactSaveMixin): x: t.Any = Empty() def init(self, db): - """ - Initialize to load `x` with the actual file from - the artifact store. + """Initialize to load `x` with the actual file from the artifact store. :param db: A Datalayer instance """ @@ -909,8 +848,7 @@ def init(self, db): self.x = file def unpack(self, db): - """ - Unpack and get the original data. + """Unpack and get the original data. :param db: Datalayer instance. """ @@ -923,9 +861,7 @@ def _get_object(cls, db, r): @override def encode(self, leaf_types_to_keep: t.Sequence = ()): - """ - Encode `x` to a dictionary which is saved to - the artifact store later. + """Encode `x` to a dictionary which is saved to the artifact store later. :param leaf_types_to_keep: Leaf nodes to exclude from encoding @@ -945,8 +881,7 @@ def encode(self, leaf_types_to_keep: t.Sequence = ()): @classmethod def decode(cls, r, db=None) -> 'File': - """ - Decode data to a `File` instance. + """Decode data to a `File` instance. 
:param r: Object to decode :param db: Datalayer instance @@ -961,16 +896,13 @@ def decode(cls, r, db=None) -> 'File': class LazyFile(File): - """ - Class is used to load a file only when needed. - """ + """Class is used to load a file only when needed.""" leaf_type: t.ClassVar[str] = 'lazy_file' @classmethod def decode(cls, r, db=None) -> 'LazyFile': - """ - Decode a dictionary to a `LazyFile` instance. + """Decode a dictionary to a `LazyFile` instance. :param r: Object to decode :param db: Datalayer instance diff --git a/superduperdb/components/graph.py b/superduperdb/components/graph.py index 0e9133bf96..ace4339804 100644 --- a/superduperdb/components/graph.py +++ b/superduperdb/components/graph.py @@ -10,9 +10,7 @@ def input_node(*args): - """ - Create an IndexableNode for input. - """ + """Create an IndexableNode for input.""" return IndexableNode( model=Input(spec=args if len(args) > 1 else args[0]), parent_graph=nx.DiGraph(), @@ -21,16 +19,15 @@ def input_node(*args): def document_node(*args): - """ - Create an IndexableNode for document input. - """ + """Create an IndexableNode for document input.""" return IndexableNode( model=DocumentInput(spec=args), parent_graph=nx.DiGraph(), parent_models={} ) class IndexableNode: - """ + """IndexableNode class to index the model. + Create a model IndexableNode, which can be used to index the model while creating graph links. @@ -51,8 +48,7 @@ def __init__( self.identifier = identifier def __getitem__(self, item): - """ - Method for indexing the model. + """Method for indexing the model. :param item: Index """ @@ -65,8 +61,7 @@ def __getitem__(self, item): ) def to_graph(self, identifier: str): - """ - Helper method to get the graph form. + """Helper method to get the graph form. :param identifier: Unique identifier """ @@ -92,8 +87,7 @@ def _get_node(self, u): return self.parent_models[u] def to_listeners(self, select: CompoundSelect, identifier: str): - """ - Create listeners from the parent graph and models. + """Create listeners from the parent graph and models. :param select: CompoundSelect query :param identifier: Unique identifier @@ -136,8 +130,7 @@ def to_listeners(self, select: CompoundSelect, identifier: str): class OutputWrapper: - """ - OutputWrapper class for wrapping model outputs. + """OutputWrapper class for wrapping model outputs. :param r: Output :param keys: Output keys @@ -158,8 +151,7 @@ def __getitem__(self, item): @dc.dataclass(kw_only=True) class Input(Model): - """ - Root model of a graph. + """Root model of a graph. :param spec: Model specifications from `inspect` :param identifier: Unique identifier @@ -176,9 +168,7 @@ def __post_init__(self, artifacts): self.signature = 'singleton' def predict_one(self, *args): - """ - Single prediction. - """ + """Single prediction.""" if self.signature == 'singleton': return args[0] return OutputWrapper( @@ -186,8 +176,7 @@ def predict_one(self, *args): ) def predict(self, dataset): - """ - Predict on the dataset. + """Predict on the dataset. :param dataset: Series of datapoints """ @@ -196,8 +185,7 @@ def predict(self, dataset): @dc.dataclass(kw_only=True) class DocumentInput(Model): - """ - Document Input node of the graph. + """Document Input node of the graph. :param spec: Model specifications from `inspect` :param identifier: Unique identifier @@ -211,16 +199,14 @@ def __post_init__(self, artifacts): super().__post_init__(artifacts) def predict_one(self, r): - """ - Single prediction. + """Single prediction. 
:param r: Model input """ return {k: r[k] for k in self.spec} def predict(self, dataset): - """ - Predict on the dataset. + """Predict on the dataset. :param dataset: Series of datapoints """ @@ -235,8 +221,7 @@ def predict(self, dataset): @dc.dataclass(kw_only=True) class Graph(Model): - """ - Represents a directed acyclic graph composed of interconnected model nodes. + """Represents a directed acyclic graph composed of interconnected model nodes. This class enables the creation of complex predictive models by defining a computational graph structure where each node @@ -253,11 +238,13 @@ class Graph(Model): :param signature: Graph signature. Example: + ------- >> g = Graph( >> identifier='simple-graph', input=model1, outputs=[model2], signature='*args' >> ) >> g.connect(model1, model2) >> assert g.predict_one(1) == [(4, 2)] + """ ui_schema: t.ClassVar[t.List[t.Dict]] = [ @@ -314,7 +301,8 @@ def connect( on: t.Optional[t.Tuple[t.Union[int, str], str]] = None, update_edge: t.Optional[bool] = True, ): - """ + """Connect the relationship between two models. + Connects two nodes `u` and `v` on an edge, where the edge is a tuple with the first element describing output index (int or None) and the second describing input argument (str). @@ -326,7 +314,9 @@ def connect( :param update_edge: Bool to update edge. Note: + ---- Output index: None means all outputs of node u are connected to node v. + """ assert isinstance(u, Model) assert isinstance(v, Model) @@ -361,8 +351,7 @@ def connect( def fetch_output( self, output, index: t.Optional[t.Union[int, str]] = None, one: bool = False ): - """ - Get corresponding output from model outputs with respect to link weight. + """Get corresponding output from model outputs with respect to link weight. :param output: model output. :param index: index to select output. @@ -388,8 +377,7 @@ def fetch_output( return output def validate(self, node): - """ - Validates the graph for any disconnection. + """Validates the graph for any disconnection. :param node: Graph node. """ @@ -494,7 +482,8 @@ def _predict_on_node(self, *args, node=None, cache={}, one=True, **kwargs): return cache[node] def predict_one(self, *args, **kwargs): - """ + """Predict on single data point. + Single data point prediction passes the args and kwargs to the defined node flow in the graph. """ @@ -515,7 +504,8 @@ def predict_one(self, *args, **kwargs): ) def patch_dataset_to_args(self, dataset): - """ + """Get the dataset and patch it with args. + Patch the dataset with args type as default, since all corresponding nodes take args as input type. @@ -541,8 +531,7 @@ def mapping(x, signature): return args_dataset def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: - """ - Predict on dataset i.e. series of datapoints. + """Predict on dataset i.e. series of datapoints. :param dataset: Series of datapoints. """ @@ -569,8 +558,7 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: return outputs def encode_outputs(self, outputs): - """ - Encode outputs for serialization in the database. + """Encode outputs for serialization in the database. :param outputs: model outputs. """ diff --git a/superduperdb/components/listener.py b/superduperdb/components/listener.py index 1fe3a6a988..102fe11cb2 100644 --- a/superduperdb/components/listener.py +++ b/superduperdb/components/listener.py @@ -26,7 +26,8 @@ @public_api(stability='stable') @dc.dataclass(kw_only=True) class Listener(Component): - """ + """Listener component. 
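
The graph-building API documented in the graph.py hunks above (`input_node`, `IndexableNode.to_graph`), together with `Model.__call__` later in this patch, can be combined roughly as follows. This is a hypothetical sketch with illustrative model names, and the exact output shape depends on the chosen signatures.

    # Hypothetical sketch of the functional graph API (not part of the patch).
    from superduperdb.components.graph import input_node
    from superduperdb.components.model import ObjectModel

    m1 = ObjectModel(identifier='add_one', object=lambda x: x + 1)
    m2 = ObjectModel(identifier='double', object=lambda x: x * 2)

    in_ = input_node('x')               # root node of the graph
    out = m2(m1(in_))                   # chain the models on the indexable nodes
    graph = out.to_graph('my-graph')    # materialise a Graph component
    graph.predict_one(3)                # run the whole graph on a single input
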
+ Listener object which is used to process a column/key of a collection or table, and store the outputs. @@ -60,8 +61,7 @@ class Listener(Component): @classmethod def handle_integration(cls, kwargs): - """ - Method to handle integration. + """Method to handle integration. :param kwargs: Integration keyword arguments. """ @@ -89,16 +89,12 @@ def __post_init__(self, artifacts): @property def mapping(self): - """ - Mapping property. - """ + """Mapping property.""" return Mapping(self.key, signature=self.model.signature) @property def outputs(self): - """ - Get reference to outputs of listener model. - """ + """Get reference to outputs of listener model.""" if self.model.version is not None: return f'{_OUTPUTS_KEY}.{self.identifier}::{self.version}' else: @@ -113,9 +109,7 @@ def _callback(db, value, kwargs): @property def outputs_select(self): - """ - Get query reference to model outputs. - """ + """Get query reference to model outputs.""" if self.select.DB_TYPE == "SQL": return self.select.table_or_collection.outputs(self.predict_id) @@ -132,9 +126,7 @@ def outputs_select(self): @property def outputs_key(self): - """ - Model outputs key. - """ + """Model outputs key.""" if self.select.DB_TYPE == "SQL": return 'output' else: @@ -142,8 +134,7 @@ def outputs_key(self): @override def pre_create(self, db: "Datalayer") -> None: - """ - Pre-create hook. + """Pre-create hook. :param db: Data layer instance. """ @@ -168,8 +159,7 @@ def _set_key(db, key, **kwargs): @override def post_create(self, db: "Datalayer") -> None: - """ - Post-create hook. + """Post-create hook. :param db: Data layer instance. """ @@ -206,9 +196,7 @@ def create_output_dest(cls, db: "Datalayer", predict_id, model: Model): @property def dependencies(self) -> t.List[ComponentTuple]: - """ - Listener model dependencies. - """ + """Listener model dependencies.""" args, kwargs = self.mapping.mapping all_ = list(args) + list(kwargs.values()) out = [] @@ -221,9 +209,7 @@ def dependencies(self) -> t.List[ComponentTuple]: @property def predict_id(self): - """ - Get predict ID. - """ + """Get predict ID.""" return f'{self.identifier}::{self.version}' @classmethod @@ -234,15 +220,12 @@ def from_predict_id(cls, db: "Datalayer", predict_id) -> 'Listener': :param db: Data layer instance. :param predict_id: Predict ID. """ - identifier, version = predict_id.rsplit('::', 1) return t.cast(Listener, db.load('listener', identifier, version=int(version))) @property def id_key(self) -> str: - """ - Get identifier key. - """ + """Get identifier key.""" def _id_key(key) -> str: if isinstance(key, str): @@ -260,8 +243,7 @@ def _id_key(key) -> str: return _id_key(self.key) def depends_on(self, other: Component): - """ - Check if the listener depends on another component. + """Check if the listener depends on another component. :param other: Another component. """ @@ -280,8 +262,7 @@ def schedule_jobs( dependencies: t.Sequence[Job] = (), overwrite: bool = False, ) -> t.Sequence[t.Any]: - """ - Schedule jobs for the listener. + """Schedule jobs for the listener. :param db: Data layer instance to process. :param dependencies: A list of dependencies. diff --git a/superduperdb/components/metric.py b/superduperdb/components/metric.py index f0acd80002..4c26d17ecb 100644 --- a/superduperdb/components/metric.py +++ b/superduperdb/components/metric.py @@ -8,8 +8,8 @@ @public_api(stability='beta') @dc.dataclass(kw_only=True) class Metric(Component): - """ - Metric base object used to evaluate performance on a dataset. 
+ """Metric base object used to evaluate performance on a dataset. + These objects are callable and are applied row-wise to the data, and averaged. {component_parameters} @@ -24,4 +24,9 @@ class Metric(Component): object: t.Callable def __call__(self, x: t.Sequence[int], y: t.Sequence[int]) -> bool: + """Call the metric object on the x and y data. + + :param x: First sequence of data. + :param y: Second sequence of data. + """ return self.object(x, y) diff --git a/superduperdb/components/model.py b/superduperdb/components/model.py index 15e33c7b33..33d0fb2baa 100644 --- a/superduperdb/components/model.py +++ b/superduperdb/components/model.py @@ -49,7 +49,8 @@ def objectmodel( flatten: bool = False, output_schema: t.Optional[Schema] = None, ): - """ + """Decorator to wrap a function with `ObjectModel`. + When a function is wrapped with this decorator, the function comes out as an `ObjectModel`. @@ -116,7 +117,8 @@ def codemodel( flatten: bool = False, output_schema: t.Optional[Schema] = None, ): - """ + """Decorator to wrap a function with `CodeModel`. + When a function is wrapped with this decorator, the function comes out as a `CodeModel`. @@ -151,8 +153,7 @@ def decorated_function(item): class Inputs: - """ - Base class to represent the model args and kwargs. + """Base class to represent the model args and kwargs. :param params: List of parameters of the Model object """ @@ -167,8 +168,7 @@ def __getattr__(self, attr): return self.params[attr] def get_kwargs(self, args): - """ - Get keyword arguments from positional arguments. + """Get keyword arguments from positional arguments. :param args: Parameters to be converted """ @@ -178,13 +178,13 @@ def get_kwargs(self, args): return kwargs def __call__(self, *args, **kwargs): + """Get the model args and kwargs.""" tmp = self.get_kwargs(args) return {**tmp, **kwargs} class CallableInputs(Inputs): - """ - Class represents the model callable args and kwargs. + """Class represents the model callable args and kwargs. :param fn: Callable function :param predict_kwargs: (optional) predict_kwargs if provided in Model @@ -209,7 +209,8 @@ def __init__(self, fn, predict_kwargs: t.Dict = {}): @dc.dataclass(kw_only=True) class Trainer(Component): - """ + """Trainer component to train a model. + Training configuration object, containing all settings necessary for a particular learning task use-case to be serialized and initiated. The object is ``callable`` and returns a class which may be invoked to apply training. @@ -246,8 +247,7 @@ def fit( train_dataset: QueryDataset, valid_dataset: QueryDataset, ): - """ - Fit on the model on training dataset with `valid_dataset` for validation. + """Fit on the model on training dataset with `valid_dataset` for validation. :param model: Model to be fit :param db: The datalayer @@ -259,8 +259,7 @@ def fit( @dc.dataclass(kw_only=True) class Validation(Component): - """ - component which represents Validation definition. + """component which represents Validation definition. :param metrics: List of metrics for validation :param key: Model input type key @@ -278,8 +277,7 @@ class _Fittable: trainer: t.Optional[Trainer] = None def schedule_jobs(self, db, dependencies=()): - """ - Database hook for scheduling jobs. + """Database hook for scheduling jobs. :param db: Datalayer instance. :param dependencies: List of dependencies. @@ -301,8 +299,7 @@ def fit_in_db_job( db: Datalayer, dependencies: t.Sequence[Job] = (), ): - """ - Model fit job in database. + """Model fit job in database. :param db: Datalayer instance. 
:param dependencies: List of dependent jobs @@ -362,8 +359,7 @@ def fit( valid_dataset: QueryDataset, db: Datalayer, ): - """ - Fit the model on the training dataset with `valid_dataset` for validation. + """Fit the model on the training dataset with `valid_dataset` for validation. :param train_dataset: The training ``Dataset`` instances to use. :param valid_dataset: The validation ``Dataset`` instances to use. @@ -382,8 +378,7 @@ def fit( ) def fit_in_db(self, db: Datalayer): - """ - Fit the model on the given data. + """Fit the model on the given data. :param db: The datalayer """ @@ -407,8 +402,7 @@ def append_metrics(self, d: t.Dict[str, float]) -> None: class Mapping: - """ - Class to represent model inputs for mapping database collections or tables. + """Class to represent model inputs for mapping database collections or tables. :param mapping: Mapping that represents a collection or table map. :param signature: Signature for the model. @@ -420,9 +414,7 @@ def __init__(self, mapping: ModelInputType, signature: Signature): @property def id_key(self): - """ - Extract the output key for model outputs. - """ + """Extract the output key for model outputs.""" outputs = [] for arg in self.mapping[0]: outputs.append(arg) @@ -446,7 +438,8 @@ def _map_args_kwargs(mapping): return mapping def __call__(self, r): - """ + """Get the model inputs from the mapping. + >>> r = {'a': 1, 'b': 2} >>> self.mapping = [('a', 'b'), {}] >>> _Predictor._data_from_input_type(docs) @@ -483,8 +476,7 @@ def __call__(self, r): @dc.dataclass(kw_only=True) class Model(Component): - """ - Base class for components which can predict. + """Base class for components which can predict. :param signature: Model signature. :param datatype: DataType instance. @@ -523,14 +515,13 @@ def __post_init__(self, artifacts): @property def inputs(self) -> Inputs: - """ - Instance of `Inputs` to represent model params. - """ + """Instance of `Inputs` to represent model params.""" return Inputs(list(inspect.signature(self.predict_one).parameters.keys())) @abstractmethod def predict_one(self, *args, **kwargs) -> int: - """ + """Predict on a single data point. + Execute a single prediction on a data point given by positional and keyword arguments. """ @@ -538,8 +529,7 @@ def predict_one(self, *args, **kwargs) -> int: @abstractmethod def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: - """ - Execute on a series of data points defined in the dataset. + """Execute on a series of data points defined in the dataset. :param dataset: Series of data points to predict on. """ @@ -575,8 +565,9 @@ def predict_in_db_job( in_memory: bool = True, overwrite: bool = False, ): - """ - Execute a single prediction on a data point + """Run a prediction job in the database. + + Execute a single prediction on the data points given by positional and keyword arguments as a job. :param X: combination of input keys to be mapped to the model @@ -658,7 +649,8 @@ def predict_in_db( in_memory: bool = True, overwrite: bool = False, ) -> t.Any: - """ + """Predict on the data points in the database. + Execute a single prediction on a data point given by positional and keyword arguments as a job. @@ -738,8 +730,7 @@ def _prepare_inputs_from_select( @staticmethod def handle_input_type(data, signature): - """ - Method to transform data with respect to signature. + """Method to transform data with respect to signature. 
:param data: Data to be transformed :param signature: Data signature for transforming @@ -811,8 +802,7 @@ def _predict_with_select_and_ids( ) def encode_outputs(self, outputs): - """ - Method that encodes outputs of a model for saving in the database. + """Method that encodes outputs of a model for saving in the database. :param outputs: outputs to encode. """ @@ -867,8 +857,7 @@ def _infer_auto_schema(self, outputs, predict_id): self.db.replace(self) def encode_with_schema(self, outputs): - """ - Encode model outputs corresponding to the provided `output_schema`. + """Encode model outputs corresponding to the provided `output_schema`. :param outputs: Encode the outputs with the given schema. """ @@ -883,6 +872,12 @@ def encode_with_schema(self, outputs): return outputs def __call__(self, *args, outputs: t.Optional[str] = None, **kwargs): + """Connect the models to build a graph. + + :param args: Arguments to be passed to the model. + :param outputs: Identifier for the model outputs. + :param kwargs: Keyword arguments to be passed to the model. + """ from superduperdb.components.graph import IndexableNode if args: @@ -918,8 +913,7 @@ def to_listener( predict_kwargs: t.Optional[dict] = None, **kwargs, ): - """ - Convert the model to a listener. + """Convert the model to a listener. :param key: Key to be bound to the model :param select: Object for selecting which data is processed @@ -939,8 +933,7 @@ def to_listener( return listener def validate(self, X, dataset: Dataset, metrics: t.Sequence[Metric]): - """ - Validate `dataset` on metrics. + """Validate `dataset` on metrics. :param X: Define input map :param dataset: Dataset to run validation on. @@ -955,8 +948,7 @@ def validate(self, X, dataset: Dataset, metrics: t.Sequence[Metric]): return results def validate_in_db_job(self, db, dependencies: t.Sequence[Job] = ()): - """ - Perform a validation job. + """Perform a validation job. :param db: DataLayer instance :param dependencies: dependencies on the job @@ -971,8 +963,7 @@ def validate_in_db_job(self, db, dependencies: t.Sequence[Job] = ()): return job def validate_in_db(self, db): - """ - Validation job in database. + """Validation job in database. :param db: DataLayer instance. """ @@ -1018,9 +1009,7 @@ def __init__(self, position): @dc.dataclass class IndexableNode: - """ - Base indexable node for `ObjectModel`. - """ + """Base indexable node for `ObjectModel`.""" def __init__(self, types): self.types = types @@ -1045,24 +1034,18 @@ class _ObjectModel(Model, ABC): @property def outputs(self): - """ - Get an instance of ``IndexableNode`` to index outputs. - """ + """Get an instance of ``IndexableNode`` to index outputs.""" return IndexableNode([int]) @property def inputs(self): - """ - A method to get Model callable inputs. - """ + """A method to get Model callable inputs.""" kwargs = self.predict_kwargs if self.predict_kwargs else {} return CallableInputs(self.object, kwargs) @property def training_keys(self) -> t.List: - """ - Retrieve training keys. - """ + """Retrieve training keys.""" if isinstance(self.train_X, list): out = list(self.train_X) elif self.train_X is not None: @@ -1080,7 +1063,8 @@ def _wrapper(self, data): @ensure_initialized def predict_one(self, *args, **kwargs): - """ + """Predict on a single data point. + Method to execute ``Object`` on args and kwargs. This method is also used for debugging the Model. 
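
The `to_listener` helper documented above converts a model into a `Listener` bound to a key and a select query. A hypothetical sketch, assuming an existing Datalayer `db` and a MongoDB collection named `documents`.

    # Hypothetical sketch of Model.to_listener (not part of the patch).
    from superduperdb.backends.mongodb import Collection
    from superduperdb.components.model import ObjectModel

    m = ObjectModel(identifier='length', object=lambda text: len(text))
    listener = m.to_listener(
        key='text',                              # document field fed to the model
        select=Collection('documents').find(),   # which data is processed
        identifier='length-listener',
    )
    db.add(listener)    # `db` is an existing Datalayer (assumed)
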
""" @@ -1088,8 +1072,7 @@ def predict_one(self, *args, **kwargs): @ensure_initialized def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: - """ - Run the predict on series of Model inputs (dataset). + """Run the predict on series of Model inputs (dataset). :param dataset: series of data points. """ @@ -1109,8 +1092,7 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @public_api(stability='stable') @dc.dataclass(kw_only=True) class ObjectModel(_ObjectModel): - """ - Model component which wraps a Model to become serializable. + """Model component which wraps a Model to become serializable. {_object_model_params} """ @@ -1142,8 +1124,7 @@ class CodeModel(_ObjectModel): @classmethod def handle_integration(cls, kwargs): - """ - Handler integration from ui + """Handler integration from ui. :param kwargs: integration kwargs """ @@ -1157,7 +1138,9 @@ def handle_integration(cls, kwargs): @public_api(stability='beta') @dc.dataclass(kw_only=True) class APIBaseModel(Model): - """{component_params} + """APIBaseModel component which is used to make the type of API request. + + {component_params} {predictor_params} :param model: The Model to use, e.g. ``'text-embedding-ada-002'`` :param max_batch_size: Maximum batch size. @@ -1182,9 +1165,7 @@ def __post_init__(self, artifacts): def _multi_predict( self, dataset: t.Union[t.List, QueryDataset], *args, **kwargs ) -> t.List: - """ - Base method to batch generate text from a list of prompts using multithreading. - Handles exceptions in _generate method. + """Use multi-threading to predict on a series of data points. :param dataset: Series of data points. """ @@ -1202,11 +1183,14 @@ def _multi_predict( @dc.dataclass(kw_only=True) class APIModel(APIBaseModel): - """{component_params} + """APIModel component which is used to make the type of API request. + + {component_params} {predictor_params} {api_base_model_params} :param url: The url to use for the API request - :param postprocess: Postprocess function to use on the output of the API request""" + :param postprocess: Postprocess function to use on the output of the API request + """ __doc__ = __doc__.format( component_params=Component.__doc__, @@ -1219,9 +1203,7 @@ class APIModel(APIBaseModel): @property def inputs(self): - """ - Method to get ``Inputs`` instance for model inputs. - """ + """Method to get ``Inputs`` instance for model inputs.""" return Inputs(self.runtime_params) def __post_init__(self, artifacts): @@ -1234,15 +1216,15 @@ def __post_init__(self, artifacts): self.runtime_params = runtime_variables def build_url(self, params): - """ - Get url for the ``APIModel``. + """Get url for the ``APIModel``. :param params: url params. """ return self.url.format(**params, **{k: os.environ[k] for k in self.envs}) def predict_one(self, *args, **kwargs): - """ + """Predict on a single data point. + Method to requests to `url` on args and kwargs. This method is also used for debugging the model. """ @@ -1270,7 +1252,8 @@ def predict_one(self, *args, **kwargs): @public_api(stability='stable') @dc.dataclass(kw_only=True) class QueryModel(Model): - """ + """QueryModel component. + Model which can be used to query data and return those precomputed queries as Results. @@ -1302,8 +1285,7 @@ def _replace_variables(r): @classmethod def handle_integration(cls, kwargs): - """ - Handle integration from UI. + """Handle integration from UI. :param kwargs: Integration kwargs. 
""" @@ -1322,16 +1304,15 @@ def handle_integration(cls, kwargs): @property def inputs(self) -> Inputs: - """ - Instance of `Inputs` to represent model params. - """ + """Instance of `Inputs` to represent model params.""" if self.preprocess is not None: return CallableInputs(self.preprocess) return Inputs([x.value for x in self.select.variables]) @ensure_initialized def predict_one(self, *args, **kwargs): - """ + """Predict on a single data point. + Method to perform a single prediction on args and kwargs. This method is also used for debugging the model. """ @@ -1345,8 +1326,7 @@ def predict_one(self, *args, **kwargs): return out def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: - """ - Execute on a series of data points defined in the dataset. + """Execute on a series of data points defined in the dataset. :param dataset: Series of data points to predict on. """ @@ -1364,8 +1344,7 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @public_api(stability='stable') @dc.dataclass(kw_only=True) class SequentialModel(Model): - """ - Sequential model component which wraps a model to become serializable. + """Sequential model component which wraps a model to become serializable. {_model_params} :param models: A list of models to use @@ -1389,14 +1368,11 @@ def __post_init__(self, artifacts): @property def inputs(self) -> Inputs: - """ - Instance of `Inputs` to represent model params. - """ + """Instance of `Inputs` to represent model params.""" return self.models[0].inputs def post_create(self, db: Datalayer): - """ - Post create hook. + """Post create hook. :param db: Datalayer instance. """ @@ -1407,15 +1383,15 @@ def post_create(self, db: Datalayer): self.on_load(db) def predict_one(self, *args, **kwargs): - """ + """Predict on a single data point. + Method to do single prediction on args and kwargs. This method is also used for debugging the model. """ return self.predict([(args, kwargs)])[0] def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: - """ - Execute on series of data point defined in dataset. + """Execute on series of data point defined in dataset. :param dataset: Series of data point to predict on. """ diff --git a/superduperdb/components/schema.py b/superduperdb/components/schema.py index 2d731595ee..118788d7f6 100644 --- a/superduperdb/components/schema.py +++ b/superduperdb/components/schema.py @@ -12,8 +12,7 @@ @public_api(stability='beta') @dc.dataclass(kw_only=True) class Schema(Component): - """ - A component containing information about the types or encoders of a table. + """A component containing information about the types or encoders of a table. {component_parameters} :param fields: A mapping of field names to types or encoders. @@ -31,8 +30,7 @@ def __post_init__(self, artifacts): @override def pre_create(self, db) -> None: - """ - Database pre-create hook to add datatype to the database. + """Database pre-create hook to add datatype to the database. :param db: Datalayer instance. """ @@ -43,7 +41,8 @@ def pre_create(self, db) -> None: @property def raw(self): - """ + """Return the raw fields. + Get a dictionary of fields as keys and datatypes as values. This is used to create ibis tables. """ @@ -54,34 +53,26 @@ def raw(self): @cached_property def encoded_types(self): - """ - List of fields of type DataType. 
- """ + """List of fields of type DataType.""" return [k for k, v in self.fields.items() if isinstance(v, DataType)] @cached_property def trivial(self): - """ - Determine if the schema contains only trivial fields. - """ + """Determine if the schema contains only trivial fields.""" return not any([isinstance(v, DataType) for v in self.fields.values()]) @property def encoders(self): - """ - An iterable to list DataType fields. - """ + """An iterable to list DataType fields.""" for v in self.fields.values(): if isinstance(v, DataType): yield v def decode_data(self, data: dict[str, t.Any]) -> dict[str, t.Any]: - """ - Decode data using the schema's encoders. + """Decode data using the schema's encoders. :param data: Data to decode. """ - if self.trivial: return data @@ -94,8 +85,7 @@ def decode_data(self, data: dict[str, t.Any]) -> dict[str, t.Any]: return decoded def __call__(self, data: dict[str, t.Any]) -> dict[str, t.Any]: - """ - Encode data using the schema's encoders. + """Encode data using the schema's encoders. :param data: Data to encode. """ diff --git a/superduperdb/components/stack.py b/superduperdb/components/stack.py index 18420a556e..c2f68ba088 100644 --- a/superduperdb/components/stack.py +++ b/superduperdb/components/stack.py @@ -13,7 +13,8 @@ @public_api(stability='alpha') @dc.dataclass(kw_only=True) class Stack(Component): - """ + """Component to hold a list of components under a namespace and package. + A placeholder to hold a list of components under a namespace and package them as a tarball. This tarball can be retrieved back to a Stack instance with the @@ -31,15 +32,12 @@ class Stack(Component): @property def db(self): - """ - Datalayer property. - """ + """Datalayer property.""" return self._db @db.setter def db(self, value): - """ - Datalayer setter. + """Datalayer setter. :param value: Item to set the property. """ @@ -49,8 +47,7 @@ def db(self, value): @staticmethod def from_list(identifier, content, db: t.Optional['Datalayer'] = None): - """ - Helper method to create a Stack from a list `content`. + """Helper method to create a Stack from a list `content`. :param identifier: Unique identifier. :param content: Content to create a stack. diff --git a/superduperdb/components/vector_index.py b/superduperdb/components/vector_index.py index 24ced289d8..701b55192b 100644 --- a/superduperdb/components/vector_index.py +++ b/superduperdb/components/vector_index.py @@ -53,8 +53,9 @@ class VectorIndex(Component): @override def on_load(self, db: Datalayer) -> None: """ - On load hook to perform indexing and compatible listener - loading on loading of vector index from database. + On load hook to perform indexing and compatible listenernd compatible listener. + + Automatically loads the listeners if they are not already loaded. :param db: A DataLayer instance """ @@ -76,9 +77,10 @@ def get_vector( db: t.Any = None, outputs: t.Optional[t.Dict] = None, ): - """ + """Peform vector search. + Perform vector search with query `like` from outputs in db - on `self.identifier` vectori index. + on `self.identifier` vector index. :param like: The document to compare against :param models: List of models to retrieve outputs @@ -138,7 +140,8 @@ def get_nearest( ids: t.Optional[t.Sequence[str]] = None, n: int = 100, ) -> t.Tuple[t.List[str], t.List[float]]: - """ + """Get nearest results in this vector index. + Given a document, find the nearest results in this vector index, returned as two parallel lists of result IDs and scores. 
@@ -149,7 +152,6 @@ def get_nearest( :param ids: A list of ids to match :param n: Number of items to return """ - models, keys = self.models_keys if len(models) != len(keys): raise ValueError(f'len(model={models}) != len(keys={keys})') @@ -173,9 +175,7 @@ def get_nearest( @property def models_keys(self) -> t.Tuple[t.List[str], t.List[ModelInputType]]: - """ - Return a list of model and keys for each listener. - """ + """Return a list of model and keys for each listener.""" assert not isinstance(self.indexing_listener, str) assert not isinstance(self.compatible_listener, str) @@ -190,9 +190,9 @@ def models_keys(self) -> t.Tuple[t.List[str], t.List[ModelInputType]]: @property def dimensions(self) -> int: - """ - Get dimension for vector database. This dimension will be used to prepare - vectors in the vector database. + """Get dimension for vector database. + + This dimension will be used to prepare vectors in the vector database. """ assert not isinstance(self.indexing_listener, str) assert not isinstance(self.indexing_listener.model, str) @@ -206,8 +206,7 @@ def schedule_jobs( db: Datalayer, dependencies: t.Sequence['Job'] = (), ) -> t.Sequence[t.Any]: - """ - Schedule jobs for the listener. + """Schedule jobs for the listener. :param db: The DB instance to process :param dependencies: A list of dependencies @@ -228,8 +227,7 @@ def schedule_jobs( class EncodeArray: - """ - Class to encode an array. + """Class to encode an array. :param dtype: Datatype of array """ @@ -238,6 +236,11 @@ def __init__(self, dtype): self.dtype = dtype def __call__(self, x, info: t.Optional[t.Dict] = None): + """Encode an array. + + :param x: The array to encode + :param info: Optional info + """ x = np.asarray(x) if x.dtype != self.dtype: raise TypeError(f'dtype was {x.dtype}, expected {self.dtype}') @@ -245,8 +248,7 @@ def __call__(self, x, info: t.Optional[t.Dict] = None): class DecodeArray: - """ - Class to decode an array. + """Class to decode an array. :param dtype: Datatype of array """ @@ -255,6 +257,11 @@ def __init__(self, dtype): self.dtype = dtype def __call__(self, bytes, info: t.Optional[t.Dict] = None): + """Decode an array. + + :param bytes: The bytes to decode + :param info: Optional info + """ return np.frombuffer(bytes, dtype=self.dtype).tolist() @@ -263,8 +270,7 @@ def __call__(self, bytes, info: t.Optional[t.Dict] = None): {'name': 'identifier', 'type': 'str'}, ) def vector(shape, identifier: t.Optional[str] = None): - """ - Create an encoder for a vector (list of ints/ floats) of a given shape. + """Create an encoder for a vector (list of ints/ floats) of a given shape. :param shape: The shape of the vector :param identifier: The identifier of the vector @@ -283,9 +289,9 @@ def vector(shape, identifier: t.Optional[str] = None): def sqlvector(shape): - """ - Create an encoder for a vector (list of ints/ floats) of a given shape - compatible with sql databases. + """Create an encoder for a vector (list of ints/ floats) of a given shape. + + This is used for compatibility with SQL databases, as the default vector :param shape: The shape of the vector """ diff --git a/superduperdb/ext/anthropic/model.py b/superduperdb/ext/anthropic/model.py index 5d9cb7512d..051abd8140 100644 --- a/superduperdb/ext/anthropic/model.py +++ b/superduperdb/ext/anthropic/model.py @@ -21,7 +21,10 @@ @dc.dataclass(kw_only=True) class Anthropic(APIBaseModel): - """Anthropic predictor.""" + """Anthropic predictor. + + :param client_kwargs: The keyword arguments to pass to the client. 
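A quick round trip through the helpers documented above may help. Treat it as a sketch: the exact byte format returned by `EncodeArray` is an assumption, mirroring the numpy encoder elsewhere in this patch.

    import numpy as np
    from superduperdb.components.vector_index import DecodeArray, EncodeArray, vector

    enc = EncodeArray('float32')
    dec = DecodeArray('float32')

    blob = enc(np.arange(4, dtype='float32'))   # dtype is checked, then serialised
    values = dec(blob)                          # back to a plain list of floats

    vec_dt = vector(shape=(4,))                 # DataType for a 4-dimensional vector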
+ """ client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict) @@ -37,13 +40,19 @@ def __post_init__(self, artifacts): class AnthropicCompletions(Anthropic): """Cohere completions (chat) predictor. - :param takes_context: Whether the model takes context into account. :param prompt: The prompt to use to seed the response. """ prompt: str = '' def pre_create(self, db: Datalayer) -> None: + """Pre create method for the model. + + If the datalayer is Ibis, the datatype will be set to the appropriate + SQL datatype. + + :param db: The datalayer to use for the model. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('str') @@ -55,6 +64,13 @@ def predict_one( context: t.Optional[t.List[str]] = None, **kwargs, ): + """Generate text from a single input. + + :param X: The input to generate text from. + :param context: The context to use for the prompt. + :param kwargs: The keyword arguments to pass to the prompt function and + the llm model. + """ if isinstance(X, str): if context is not None: X = format_prompt(X, self.prompt, context=context) @@ -75,4 +91,8 @@ def predict_one( return message.content[0].text def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict the embeddings of a dataset. + + :param dataset: The dataset to predict the embeddings of. + """ return [self.predict_one(dataset[i]) for i in range(len(dataset))] diff --git a/superduperdb/ext/cohere/model.py b/superduperdb/ext/cohere/model.py index 10c7147fff..9882fbc7c5 100644 --- a/superduperdb/ext/cohere/model.py +++ b/superduperdb/ext/cohere/model.py @@ -21,7 +21,10 @@ @dc.dataclass(kw_only=True) class Cohere(APIBaseModel): - """Cohere predictor""" + """Cohere predictor. + + :param client_kwargs: The keyword arguments to pass to the client. + """ client_kwargs: t.Dict[str, t.Any] = dc.field(default_factory=dict) @@ -32,9 +35,10 @@ def __post_init__(self, artifacts): @dc.dataclass(kw_only=True) class CohereEmbed(Cohere): - """Cohere embedding predictor + """Cohere embedding predictor. :param shape: The shape as ``tuple`` of the embedding. + :param batch_size: The batch size to use for the predictor. """ signature: t.ClassVar[str] = 'singleton' @@ -48,6 +52,13 @@ def __post_init__(self, artifacts): self.shape = self.shapes[self.identifier] def pre_create(self, db): + """Pre create method for the model. + + If the datalayer is Ibis, the datatype will be set to the appropriate + SQL datatype. + + :param db: The datalayer to use for the model. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend): if self.datatype is None: @@ -57,6 +68,10 @@ def pre_create(self, db): @retry def predict_one(self, X: str): + """Predict the embedding of a single text. + + :param X: The text to predict the embedding of. + """ client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) e = client.embed(texts=[X], model=self.identifier, **self.predict_kwargs) return e.embeddings[0] @@ -68,6 +83,10 @@ def _predict_a_batch(self, texts: t.List[str]): return [r for r in out.embeddings] def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict the embeddings of a dataset. + + :param dataset: The dataset to predict the embeddings of. 
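For orientation, a hypothetical `CohereEmbed` call might look as follows; the model name and the assumption that the API key is read from the environment are illustrative, not part of this patch.

    from superduperdb.ext.cohere.model import CohereEmbed

    # the identifier doubles as the Cohere model name (see __post_init__ above)
    embedder = CohereEmbed(identifier='embed-english-v2.0')

    one = embedder.predict_one('a single sentence')          # -> one embedding
    many = embedder.predict(['first text', 'second text'])   # batched internally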
+ """ out = [] for i in tqdm.tqdm(range(0, len(dataset), self.batch_size)): out.extend( @@ -80,7 +99,7 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @dc.dataclass(kw_only=True) class CohereGenerate(Cohere): - """Cohere realistic text generator (chat predictor) + """Cohere realistic text generator (chat predictor). :param takes_context: Whether the model takes context into account. :param prompt: The prompt to use to seed the response. @@ -91,12 +110,24 @@ class CohereGenerate(Cohere): prompt: str = '' def pre_create(self, db: Datalayer) -> None: + """Pre create method for the model. + + If the datalayer is Ibis, the datatype will be set to the appropriate + SQL datatype. + + :param db: The datalayer to use for the model. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('str') @retry def predict_one(self, prompt: str, context: t.Optional[t.List[str]] = None): + """Predict the generation of a single prompt. + + :param prompt: The prompt to generate text from. + :param context: The context to use for the prompt. + """ if context is not None: prompt = format_prompt(prompt, self.prompt, context=context) client = cohere.Client(get_key(KEY_NAME), **self.client_kwargs) @@ -107,4 +138,8 @@ def predict_one(self, prompt: str, context: t.Optional[t.List[str]] = None): @retry def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict the generations of a dataset. + + :param dataset: The dataset to predict the generations of. + """ return [self.predict_one(dataset[i]) for i in range(len(dataset))] diff --git a/superduperdb/ext/jina/client.py b/superduperdb/ext/jina/client.py index 4a5909f9ea..317e84e75b 100644 --- a/superduperdb/ext/jina/client.py +++ b/superduperdb/ext/jina/client.py @@ -15,21 +15,23 @@ class JinaAPIClient: + """A client for the Jina Embedding platform. + + Create a JinaAPIClient to provide an interface to encode using + Jina Embedding platform sync and async. + + :param api_key: The Jina API key. + It can be explicitly provided or automatically read + from the environment variable JINA_API_KEY (recommended). + :param model_name: The name of the Jina model to use. + Check the list of available models on `https://jina.ai/embeddings/` + """ + def __init__( self, api_key: Optional[str] = None, model_name: str = 'jina-embeddings-v2-base-en', ): - """ - Create a JinaAPIClient to provide an interface to encode using - Jina Embedding platform sync and async. - - :param api_key: The Jina API key. - It can be explicitly provided or automatically read - from the environment variable JINA_API_KEY (recommended). - :param model_name: The name of the Jina model to use. - Check the list of available models on `https://jina.ai/embeddings/` - """ # if the user does not provide the API key, # check if it is set in the environment variable if api_key is None: @@ -46,6 +48,10 @@ def __init__( @retry def encode_batch(self, texts: List[str]) -> List[List[float]]: + """Encode a batch of texts synchronously. + + :param texts: The list of texts to encode. + """ response = self._session.post( JINA_API_URL, json={"input": texts, "model": self.model_name} ).json() @@ -59,6 +65,10 @@ def encode_batch(self, texts: List[str]) -> List[List[float]]: @retry async def aencode_batch(self, texts: List[str]) -> List[List[float]]: + """Encode a batch of texts asynchronously. + + :param texts: The list of texts to encode. 
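The client documented above can be exercised directly; the only assumption here is that JINA_API_KEY is set in the environment.

    from superduperdb.ext.jina.client import JinaAPIClient

    client = JinaAPIClient(model_name='jina-embeddings-v2-base-en')
    vectors = client.encode_batch(['hello', 'world'])   # one embedding per input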
+ """ async with aiohttp.ClientSession() as session: payload = { 'model': self.model_name, diff --git a/superduperdb/ext/jina/model.py b/superduperdb/ext/jina/model.py index df2016b50e..8a4747e5bb 100644 --- a/superduperdb/ext/jina/model.py +++ b/superduperdb/ext/jina/model.py @@ -12,7 +12,10 @@ @dc.dataclass(kw_only=True) class Jina(APIBaseModel): - """Cohere predictor""" + """Cohere predictor. + + :param api_key: The API key to use for the predicto + """ api_key: t.Optional[str] = None @@ -24,8 +27,9 @@ def __post_init__(self, artifacts): @dc.dataclass(kw_only=True) class JinaEmbedding(Jina): - """Jina embedding predictor + """Jina embedding predictor. + :param batch_size: The batch size to use for the predictor. :param shape: The shape of the embedding as ``tuple``. If not provided, it will be obtained by sending a simple query to the API """ @@ -41,6 +45,13 @@ def __post_init__(self, artifacts): self.shape = (len(self.client.encode_batch(['shape'])[0]),) def pre_create(self, db): + """Pre create method for the model. + + If the datalayer is Ibis, the datatype will be set to the appropriate + SQL datatype. + + :param db: The datalayer to use for the model. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend): if self.datatype is None: @@ -49,12 +60,20 @@ def pre_create(self, db): self.datatype = vector(self.shape) def predict_one(self, X: str): + """Predict the embedding of a single text. + + :param X: The text to predict the embedding of. + """ return self.client.encode_batch([X])[0] def _predict_a_batch(self, texts: t.List[str]): return self.client.encode_batch(texts) def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict the embeddings of a dataset. + + :param dataset: The dataset to predict the embeddings of. + """ out = [] for i in tqdm.tqdm(range(0, len(dataset), self.batch_size)): batch = [ diff --git a/superduperdb/ext/llamacpp/model.py b/superduperdb/ext/llamacpp/model.py index bec59289ed..313fe7ecc2 100644 --- a/superduperdb/ext/llamacpp/model.py +++ b/superduperdb/ext/llamacpp/model.py @@ -10,8 +10,7 @@ # TODO use core downloader already implemented def download_uri(uri, save_path): - """ - Download file + """Download file. :param uri: URI to download :param save_path: place to save @@ -26,13 +25,11 @@ def download_uri(uri, save_path): @dc.dataclass(kw_only=True) class LlamaCpp(BaseLLM): - """ - Llama.cpp connector + """Llama.cpp connector. :param model_name_or_path: path or name of model :param model_kwargs: dictionary of init-kwargs :param download_dir: local caching directory - :param signature: s """ signature: t.ClassVar[str] = 'singleton' @@ -42,6 +39,10 @@ class LlamaCpp(BaseLLM): download_dir: str = '.llama_cpp' def init(self): + """Initialize the model. + + If the model_name_or_path is a uri, download it to the download_dir. + """ if self.model_name_or_path.startswith('http'): # Download the uri os.makedirs(self.download_dir, exist_ok=True) @@ -56,8 +57,10 @@ def init(self): self._model = Llama(self.model_name_or_path, **self.model_kwargs) def _generate(self, prompt: str, **kwargs) -> str: - """ - Generate text from a prompt. + """Generate text from a prompt. + + :param prompt: The prompt to generate text from. + :param kwargs: The keyword arguments to pass to the llm model. 
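A minimal `LlamaCpp` sketch, assuming a local GGUF file (the path is a placeholder) and the `predict_one` interface inherited from `BaseLLM`, shown just below:

    from superduperdb.ext.llamacpp.model import LlamaCpp

    llm = LlamaCpp(
        identifier='local-llama',
        model_name_or_path='./models/llama.gguf',   # an http(s) URL would be downloaded
        model_kwargs={'n_ctx': 2048},               # forwarded to llama_cpp.Llama
    )
    print(llm.predict_one('Explain vector search in one sentence.'))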
""" out = self._model.create_completion(prompt, **self.predict_kwargs, **kwargs) return out['choices'][0]['text'] @@ -65,9 +68,13 @@ def _generate(self, prompt: str, **kwargs) -> str: @dc.dataclass class LlamaCppEmbedding(LlamaCpp): + """Llama.cpp connector for embeddings.""" + def _generate(self, prompt: str, **kwargs) -> str: - """ - Generate embedding from a prompt. + """Generate embedding from a prompt. + + :param prompt: The prompt to generate the embedding from. + :param kwargs: The keyword arguments to pass to the llm model. """ return self._model.create_embedding( prompt, embedding=True, **self.predict_kwargs, **kwargs diff --git a/superduperdb/ext/llm/model.py b/superduperdb/ext/llm/model.py index aa0bc2d917..b30c2de46b 100644 --- a/superduperdb/ext/llm/model.py +++ b/superduperdb/ext/llm/model.py @@ -21,11 +21,11 @@ @dc.dataclass(kw_only=True) class BaseLLM(Model, metaclass=abc.ABCMeta): - """ - :param prompt_template: The template to use for the prompt. + """Base class for LLM models. + + :param prompt: The template to use for the prompt. :param prompt_func: The function to use for the prompt. :param max_batch_size: The maximum batch size to use for batch generation. - :param predict_kwargs: Parameters used during inference. """ prompt: str = "{input}" @@ -38,10 +38,11 @@ def __post_init__(self, artifacts): self.takes_context = True self.identifier = self.identifier.replace("/", "-") - def to_call(self, X, *args, **kwargs): - raise NotImplementedError - def post_create(self, db: "Datalayer") -> None: + """Post create method for the model. + + :param db: The datalayer to use for the model. + """ # TODO: Do not make sense to add this logic here, # Need a auto DataType to handle this from superduperdb.backends.ibis.data_backend import IbisDataBackend @@ -54,36 +55,54 @@ def post_create(self, db: "Datalayer") -> None: @abc.abstractmethod def init(self): + """Initialize the model.""" ... def _generate(self, prompt: str, **kwargs: t.Any): raise NotImplementedError def _batch_generate(self, prompts: t.List[str], **kwargs) -> t.List[str]: - """ - Base method to batch generate text from a list of prompts. + """Base method to batch generate text from a list of prompts. + If the model can run batch generation efficiently, pls override this method. + + :param prompts: The list of prompts to generate text from. + :param kwargs: The keyword arguments to pass to the prompt function and + the llm model. """ return [self._generate(prompt, **self.predict_kwargs) for prompt in prompts] @ensure_initialized def predict_one(self, X: t.Union[str, dict[str, str]], context=None, **kwargs): + """Generate text from a single input. + + :param X: The input to generate text from. + :param context: The context to use for the prompt. + :param kwargs: The keyword arguments to pass to the prompt function and + the llm model. + """ x = self.prompter(X, context=context, **kwargs) return self._generate(x, **kwargs) @ensure_initialized def predict(self, dataset: t.Union[t.List, QueryDataset], **kwargs) -> t.Sequence: + """Generate text from a dataset. + + :param dataset: The dataset to generate text from. + :param kwargs: The keyword arguments to pass to the prompt function and + the llm model. 
+ + """ xs = [self.prompter(dataset[i], **kwargs) for i in range(len(dataset))] kwargs.pop("context", None) return self._batch_generate(xs, **kwargs) - def get_kwargs(self, func, *kwargs_list): - """ - Get kwargs and object attributes that are in the function signature - :param func (Callable): function to get kwargs for - :param kwargs (list of dict): kwargs to filter - """ + def get_kwargs(self, func: t.Callable, *kwargs_list): + """Get kwargs and object attributes that are in the function signature. + :param func: function to get kwargs for + :param *kwargs_list: kwargs to filter + """ total_kwargs = reduce(lambda x, y: {**y, **x}, [self.dict(), *kwargs_list]) sig = inspect.signature(func) new_kwargs = {} @@ -94,12 +113,14 @@ def get_kwargs(self, func, *kwargs_list): @property def prompter(self): + """Return a prompter for the model.""" return Prompter(self.prompt, self.prompt_func) @dc.dataclass class BaseLLMAPI(BaseLLM): - """ + """Base class for LLM models that use an API. + :param api_url: The URL for the API. {parent_doc} """ @@ -109,12 +130,11 @@ class BaseLLMAPI(BaseLLM): api_url: str = dc.field(default="") def init(self): + """Initialize the model.""" pass def _generate_wrapper(self, prompt: str, **kwargs: t.Any) -> str: - """ - Wrapper for the _generate method to handle exceptions. - """ + """Wrapper for the _generate method to handle exceptions.""" try: return self._generate(prompt, **kwargs) except Exception as e: @@ -124,7 +144,10 @@ def _generate_wrapper(self, prompt: str, **kwargs: t.Any) -> str: def _batch_generate(self, prompts: t.List[str], **kwargs: t.Any) -> t.List[str]: """ Base method to batch generate text from a list of prompts using multi-threading. + Handles exceptions in _generate method. + + :param prompts: The list of prompts to generate text from. """ with concurrent.futures.ThreadPoolExecutor( max_workers=self.max_batch_size diff --git a/superduperdb/ext/llm/prompter.py b/superduperdb/ext/llm/prompter.py index e80749b8be..317dec0794 100644 --- a/superduperdb/ext/llm/prompter.py +++ b/superduperdb/ext/llm/prompter.py @@ -8,10 +8,25 @@ @dc.dataclass class Prompter: + """Prompt the user for input. + + This function prompts the user for input based on a + template string and a function which formats the + prompt. + + :param prompt_template: The template string for the prompt. + :param prompt_func: The function which formats the prompt. + """ + prompt_template: str = "{input}" prompt_func: t.Optional[t.Callable] = dc.field(default=None) def __call__(self, x: t.Any, **kwargs): + """Format the prompt. + + :param x: The input to format the prompt. + :param kwargs: The keyword arguments to pass to the prompt function. + """ if self.prompt_func is not None: sig = inspect.signature(self.prompt_func) new_kwargs = {} @@ -38,10 +53,15 @@ def __call__(self, x: t.Any, **kwargs): @dc.dataclass(kw_only=True) class RetrievalPrompt(QueryModel): - """ + """Retrieve a prompt based on data recalled from the database. + This function creates a prompt based on data recalled from the database and a pre-specified question: + + :param prompt_explanation: The explanation of the prompt. + :param prompt_introduction: The introduction of the prompt. + :param join: The string to join the facts. """ prompt_explanation: str = PROMPT_EXPLANATION @@ -60,9 +80,14 @@ def __post_init__(self, artifacts): @property def inputs(self): + """The inputs of the model.""" return super().inputs def predict_one(self, prompt): + """Predict the answer to the question based on the prompt. 
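A small `Prompter` sketch, assuming the default behaviour of formatting the template with the input when no `prompt_func` is given:

    from superduperdb.ext.llm.prompter import Prompter

    prompter = Prompter(prompt_template='Q: {input}\nA:')
    print(prompter('What is a vector index?'))   # -> 'Q: What is a vector index?\nA:'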
+ + :param prompt: The prompt to answer the question. + """ out = super().predict_one(prompt=prompt) prompt = ( self.prompt_explanation diff --git a/superduperdb/ext/numpy/encoder.py b/superduperdb/ext/numpy/encoder.py index 38ef6dc82e..a1304ab8a8 100644 --- a/superduperdb/ext/numpy/encoder.py +++ b/superduperdb/ext/numpy/encoder.py @@ -7,21 +7,42 @@ class EncodeArray: + """Encode a numpy array to bytes. + + :param dtype: The dtype of the array. + """ + def __init__(self, dtype): self.dtype = dtype def __call__(self, x, info: t.Optional[t.Dict] = None): + """Encode the numpy array to bytes. + + :param x: The numpy array. + :param info: The info of the encoding. + """ if x.dtype != self.dtype: raise TypeError(f'dtype was {x.dtype}, expected {self.dtype}') return memoryview(x).tobytes() class DecodeArray: + """Decode a numpy array from bytes. + + :param dtype: The dtype of the array. + :param shape: The shape of the array. + """ + def __init__(self, dtype, shape): self.dtype = dtype self.shape = shape def __call__(self, bytes, info: t.Optional[t.Dict] = None): + """Decode the numpy array from bytes. + + :param bytes: The bytes to decode. + :param info: The info of the encoding. + """ return numpy.frombuffer(bytes, dtype=self.dtype).reshape(self.shape) @@ -36,6 +57,8 @@ def array( :param dtype: The dtype of the array. :param shape: The shape of the array. + :param bytes_encoding: The bytes encoding to use. + :param encodable: The encodable to use. """ return DataType( identifier=f'numpy.{dtype}[{str_shape(shape)}]', @@ -48,18 +71,22 @@ def array( class NumpyDataTypeFactory(DataTypeFactory): + """A factory for numpy arrays.""" + @staticmethod def check(data: t.Any) -> bool: - """ - Check if the data is a numpy array. + """Check if the data is a numpy array. + It's used for registering the auto schema. + :param data: The data to check. """ return isinstance(data, numpy.ndarray) @staticmethod def create(data: t.Any) -> DataType: - """ - Create a numpy array datatype. + """Create a numpy array datatype. + It's used for registering the auto schema. + :param data: The numpy array. """ return array(data.dtype, data.shape) diff --git a/superduperdb/ext/openai/model.py b/superduperdb/ext/openai/model.py index 8a75434a01..281537dab0 100644 --- a/superduperdb/ext/openai/model.py +++ b/superduperdb/ext/openai/model.py @@ -3,7 +3,6 @@ import json import os import typing as t -from typing import Any import requests import tqdm @@ -22,7 +21,6 @@ from superduperdb.base.datalayer import Datalayer from superduperdb.components.model import APIBaseModel, Inputs from superduperdb.components.vector_index import sqlvector, vector -from superduperdb.ext.llm.model import BaseLLMAPI from superduperdb.misc.compat import cache from superduperdb.misc.retry import Retry @@ -45,9 +43,10 @@ def _available_models(skwargs): @dc.dataclass(kw_only=True) class _OpenAI(APIBaseModel): - ''' + """Base class for OpenAI models. + :param client_kwargs: The kwargs to be passed to OpenAI - ''' + """ openai_api_key: t.Optional[str] = None openai_api_base: t.Optional[str] = None @@ -99,10 +98,11 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @dc.dataclass(kw_only=True) class OpenAIEmbedding(_OpenAI): - """ - OpenAI embedding predictor + """OpenAI embedding predictor. + {_openai_parameters} :param shape: The shape as ``tuple`` of the embedding. + :param batch_size: The batch size to use. 
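The numpy datatype factory above supports a simple round trip; the `.encoder`/`.decoder` attribute names on the returned `DataType` are assumptions based on how `array` constructs it.

    import numpy as np
    from superduperdb.ext.numpy.encoder import array

    dt = array('float32', (3,))
    blob = dt.encoder(np.ones(3, dtype='float32'))   # EncodeArray -> bytes
    restored = dt.decoder(blob)                      # DecodeArray -> ndarray of shape (3,)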
""" __doc__ = __doc__.format(_openai_parameters=_OpenAI.__doc__) @@ -114,6 +114,7 @@ class OpenAIEmbedding(_OpenAI): @property def inputs(self): + """The inputs of the model.""" return Inputs(['input']) def __post_init__(self, artifacts): @@ -122,6 +123,13 @@ def __post_init__(self, artifacts): self.shape = self.shapes[self.model] def pre_create(self, db): + """Pre creates the model. + + If the datatype is not set and the datalayer is an IbisDataBackend, + the datatype is set to ``sqlvector`` or ``vector``. + + :param db: The datalayer instance. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend): if self.datatype is None: @@ -131,6 +139,10 @@ def pre_create(self, db): @retry def predict_one(self, X: str): + """Generates embeddings from text. + + :param X: The text to generate embeddings for. + """ e = self.syncClient.embeddings.create( input=X, model=self.model, **self.predict_kwargs ) @@ -147,7 +159,9 @@ def _predict_a_batch(self, texts: t.List[t.Dict]): @dc.dataclass(kw_only=True) class OpenAIChatCompletion(_OpenAI): """OpenAI chat completion predictor. + {_openai_parameters} + :param batch_size: The batch size to use. :param prompt: The prompt to use to seed the response. """ @@ -166,12 +180,24 @@ def _format_prompt(self, context, X): return prompt + X def pre_create(self, db: Datalayer) -> None: + """Pre creates the model. + + If the datatype is not set and the datalayer is an IbisDataBackend, + the datatype is set to ``dtype('str')``. + + :param db: The datalayer instance. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('str') @retry def predict_one(self, X: str, context: t.Optional[str] = None, **kwargs): + """Generates text completions from prompts. + + :param X: The prompt. + :param context: The context to use for the prompt. + """ if context is not None: X = self._format_prompt(context, X) return ( @@ -185,6 +211,10 @@ def predict_one(self, X: str, context: t.Optional[str] = None, **kwargs): ) def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Generates text completions from prompts. + + :param dataset: The dataset of prompts. + """ out = [] for i in range(len(dataset)): args, kwargs = self.handle_input_type( @@ -197,9 +227,12 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @dc.dataclass(kw_only=True) class OpenAIImageCreation(_OpenAI): """OpenAI image creation predictor. + {_openai_parameters} :param takes_context: Whether the model takes context into account. :param prompt: The prompt to use to seed the response. + :param n: The number of images to generate. + :param response_format: The response format to use. """ signature: t.ClassVar[str] = 'singleton' @@ -212,6 +245,13 @@ class OpenAIImageCreation(_OpenAI): response_format: str = 'b64_json' def pre_create(self, db: Datalayer) -> None: + """Pre creates the model. + + If the datatype is not set and the datalayer is an IbisDataBackend, + the datatype is set to ``dtype('bytes')``. + + :param db: The datalayer instance. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('bytes') @@ -222,6 +262,10 @@ def _format_prompt(self, context, X): @retry def predict_one(self, X: str): + """Generates images from text prompts. + + :param X: The text prompt. 
+ """ if self.response_format == 'b64_json': resp = self.syncClient.images.generate( prompt=X, @@ -243,6 +287,10 @@ def predict_one(self, X: str): return requests.get(url).content def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Generates images from text prompts. + + :param dataset: The dataset of text prompts. + """ out = [] for i in range(len(dataset)): args, kwargs = self.handle_input_type( @@ -255,9 +303,12 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @dc.dataclass(kw_only=True) class OpenAIImageEdit(_OpenAI): """OpenAI image edit predictor. + {_openai_parameters} :param takes_context: Whether the model takes context into account. :param prompt: The prompt to use to seed the response. + :param response_format: The response format to use. + :param n: The number of images to generate. """ __doc__ = __doc__.format(_openai_parameters=_OpenAI.__doc__) @@ -272,6 +323,13 @@ def _format_prompt(self, context): return prompt def pre_create(self, db: Datalayer) -> None: + """Pre creates the model. + + If the datatype is not set and the datalayer is an IbisDataBackend, + the datatype is set to ``dtype('bytes')``. + + :param db: The datalayer instance. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('bytes') @@ -283,6 +341,12 @@ def predict_one( mask: t.Optional[t.BinaryIO] = None, context: t.Optional[t.List[str]] = None, ): + """Edits an image. + + :param image: The image to edit. + :param mask: The mask to apply to the image. + :param context: The context to use for the prompt. + """ if context is not None: self.prompt = self._format_prompt(context) @@ -318,6 +382,10 @@ def predict_one( return out def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predicts the output for a dataset of images. + + :param dataset: The dataset of images. + """ out = [] for i in range(len(dataset)): args, kwargs = self.handle_input_type( @@ -330,6 +398,7 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: @dc.dataclass(kw_only=True) class OpenAIAudioTranscription(_OpenAI): """OpenAI audio transcription predictor. + {_openai_parameters} :param takes_context: Whether the model takes context into account. :param prompt: The prompt to guide the model's style. Should contain ``{context}``. @@ -341,13 +410,24 @@ class OpenAIAudioTranscription(_OpenAI): prompt: str = '' def pre_create(self, db: Datalayer) -> None: + """Pre creates the model. + + If the datatype is not set and the datalayer is an IbisDataBackend, + the datatype is set to ``dtype('str')``. + + :param db: The datalayer instance. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('str') @retry def predict_one(self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None): - "Converts a file-like Audio recording to text." + """Converts a file-like Audio recording to text. + + :param file: The file-like Audio recording to transcribe. + :param context: The context to use for the prompt. + """ if context is not None: self.prompt = self.prompt.format(context='\n'.join(context)) return self.syncClient.audio.transcriptions.create( @@ -359,7 +439,7 @@ def predict_one(self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None) @retry def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): - "Converts multiple file-like Audio recordings to text." 
+ """Converts multiple file-like Audio recordings to text.""" resps = [ self.syncClient.audio.transcriptions.create( file=file, model=self.model, **self.predict_kwargs @@ -372,9 +452,11 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO], **kwargs): @dc.dataclass(kw_only=True) class OpenAIAudioTranslation(_OpenAI): """OpenAI audio translation predictor. + {_openai_parameters} :param takes_context: Whether the model takes context into account. :param prompt: The prompt to guide the model's style. Should contain ``{context}``. + :param batch_size: The batch size to use. """ signature: t.ClassVar[str] = 'singleton' @@ -386,6 +468,10 @@ class OpenAIAudioTranslation(_OpenAI): batch_size: int = 1 def pre_create(self, db: Datalayer) -> None: + """Translates a file-like Audio recording to English. + + :param db: The datalayer to use for the model. + """ super().pre_create(db) if isinstance(db.databackend, IbisDataBackend) and self.datatype is None: self.datatype = dtype('str') @@ -396,7 +482,11 @@ def predict_one( file: t.BinaryIO, context: t.Optional[t.List[str]] = None, ): - "Translates a file-like Audio recording to English." + """Translates a file-like Audio recording to English. + + :param file: The file-like Audio recording to translate. + :param context: The context to use for the prompt. + """ if context is not None: self.prompt = self.prompt.format(context='\n'.join(context)) return ( @@ -410,7 +500,7 @@ def predict_one( @retry def _predict_a_batch(self, files: t.List[t.BinaryIO]): - "Translates multiple file-like Audio recordings to English." + """Translates multiple file-like Audio recordings to English.""" # TODO use async or threads resps = [ self.syncClient.audio.translations.create( @@ -419,120 +509,3 @@ def _predict_a_batch(self, files: t.List[t.BinaryIO]): for file in files ] return [resp.text for resp in resps] - - -@dc.dataclass -class BaseOpenAILLM(BaseLLMAPI): - """ - :param openai_api_base: The base URL for the OpenAI API. - :param openai_api_key: The API key to use for the OpenAI API. - :param model_name: The name of the model to use. - :param chat: Whether to use the chat API or the completion API. Defaults to False. - :param system_prompt: The prompt to use for the system. - :param user_role: The role to use for the user. - :param system_role: The role to use for the system. 
- {parent_doc} - """ - - __doc__ = __doc__.format(parent_doc=BaseLLMAPI.__doc__) - - identifier: str = dc.field(default="") - openai_api_base: str = "https://api.openai.com/v1" - openai_api_key: t.Optional[str] = None - model_name: str = "gpt-3.5-turbo" - chat: bool = True - system_prompt: t.Optional[str] = None - user_role: str = "user" - system_role: str = "system" - - def __post_init__(self, artifacts): - self.api_url = self.openai_api_base - self.identifier = self.identifier or self.model_name - super().__post_init__(artifacts) - - def init(self): - try: - from openai import OpenAI - except ImportError: - raise Exception("You must install openai with command 'pip install openai'") - - params = { - "api_key": self.openai_api_key, - "base_url": self.openai_api_base, - } - - self.client = OpenAI(**params) - model_set = self.get_model_set() - assert ( - self.model_name in model_set - ), f"model_name {self.model_name} is not in model_set {model_set}" - - def get_model_set(self): - model_list = self.client.models.list() - return sorted({model.id for model in model_list.data}) - - def _generate(self, prompt: str, **kwargs: Any) -> str: - if self.chat: - return self._chat_generate(prompt, **kwargs) - else: - return self._prompt_generate(prompt, **kwargs) - - def _prompt_generate(self, prompt: str, **kwargs: Any) -> str: - """ - Generate a completion for a given prompt with prompt format. - """ - completion = self.client.completions.create( - model=self.model_name, - prompt=prompt, - **self.get_kwargs( - self.client.completions.create, kwargs, self.predict_kwargs - ), - ) - return completion.choices[0].text - - def _chat_generate(self, content: str, **kwargs: Any) -> str: - """ - Generate a completion for a given prompt with chat format. - :param prompt: The prompt to generate a completion for. - :param kwargs: Any additional arguments to pass to the API. - """ - messages = kwargs.get("messages", []) - - if self.system_prompt: - messages = [ - {"role": self.system_role, "content": self.system_prompt} - ] + messages - - messages.append({"role": self.user_role, "content": content}) - completion = self.client.chat.completions.create( - messages=messages, - model=self.model_name, - **self.get_kwargs( - self.client.chat.completions.create, kwargs, self.predict_kwargs - ), - ) - return completion.choices[0].message.content - - -@dc.dataclass -class OpenAILLM(BaseOpenAILLM): - """ - OpenAI chat completion predictor. - {parent_doc} - """ - - __doc__ = __doc__.format(parent_doc=BaseOpenAILLM.__doc__) - - def __post_init__(self, artifacts): - """Set model name.""" - # only support chat mode - self.chat = True - super().__post_init__(artifacts) - - @retry - def get_model_set(self): - return super().get_model_set() - - @retry - def _generate(self, *args, **kwargs) -> str: - return super()._generate(*args, **kwargs) diff --git a/superduperdb/ext/pillow/encoder.py b/superduperdb/ext/pillow/encoder.py index 0faa3f6c15..b7af9da6aa 100644 --- a/superduperdb/ext/pillow/encoder.py +++ b/superduperdb/ext/pillow/encoder.py @@ -12,14 +12,18 @@ def encode_pil_image(x, info: t.Optional[t.Dict] = None): + """Encode a `PIL.Image` to bytes. + + :param x: The image to encode. + :param info: Additional information. + """ buffer = io.BytesIO() x.save(buffer, 'png') return buffer.getvalue() class DecoderPILImage: - """ - Decoder to convert `bytes` back into a `PIL.Image` class + """Decoder to convert `bytes` back into a `PIL.Image` class. 
:param handle_exceptions: return a blank image if failure """ @@ -28,6 +32,11 @@ def __init__(self, handle_exceptions: bool = True): self.handle_exceptions = handle_exceptions def __call__(self, bytes, info: t.Optional[t.Dict] = None): + """Decode a `PIL.Image` from bytes. + + :param bytes: The bytes to decode. + :param info: Additional information. + """ try: return PIL.Image.open(io.BytesIO(bytes)) except Exception as e: @@ -77,6 +86,12 @@ def __call__(self, bytes, info: t.Optional[t.Dict] = None): def image_type( identifier: str, encodable: str = 'lazy_artifact', media_type: str = 'image/png' ): + """Create a `DataType` for an image. + + :param identifier: The identifier for the data type. + :param encodable: The encodable type. + :param media_type: The media type. + """ return DataType( identifier, encoder=encode_pil_image, diff --git a/superduperdb/ext/sentence_transformers/model.py b/superduperdb/ext/sentence_transformers/model.py index 922b3c4d75..68686e7a40 100644 --- a/superduperdb/ext/sentence_transformers/model.py +++ b/superduperdb/ext/sentence_transformers/model.py @@ -16,6 +16,16 @@ @dc.dataclass(kw_only=True) class SentenceTransformer(Model, _DeviceManaged): + """A model for sentence embeddings using `sentence-transformers`. + + :param object: The SentenceTransformer object to use. + :param model: The model name, e.g. 'all-MiniLM-L6-v2'. + :param device: The device to use, e.g. 'cpu' or 'cuda'. + :param preprocess: The preprocessing function to apply to the input. + :param postprocess: The postprocessing function to apply to the output. + :param signature: The signature of the model. + """ + _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = ( ('object', dill_lazy), ) @@ -36,6 +46,7 @@ class SentenceTransformer(Model, _DeviceManaged): @classmethod def handle_integration(cls, kwargs): + """Handle integration of the model.""" if isinstance(kwargs.get('preprocess'), str): kwargs['preprocess'] = Code(kwargs['preprocess']) if isinstance(kwargs.get('postprocess'), str): @@ -52,10 +63,15 @@ def __post_init__(self, artifacts): self.object = _SentenceTransformer(self.model, device=self.device) def init(self): + """Initialize the model.""" super().init() self.to(self.device) def to(self, device): + """Move the model to a device. + + :param device: The device to move to, e.g. 'cpu' or 'cuda'. + """ self.object = self.object.to(device) self.object._target_device = device @@ -73,6 +89,12 @@ def _deep_flat_encode(self, cache): @ensure_initialized def predict_one(self, X, *args, **kwargs): + """Predict on a single input. + + :param X: The input to predict on. + :param args: Additional positional arguments, which are passed to the model. + :param kwargs: Additional keyword arguments, which are passed to the model. + """ if self.preprocess is not None: X = self.preprocess(X) @@ -84,6 +106,10 @@ def predict_one(self, X, *args, **kwargs): @ensure_initialized def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict on a dataset. + + :param dataset: The dataset to predict on. + """ if self.preprocess is not None: dataset = list(map(self.preprocess, dataset)) # type: ignore[arg-type] assert self.object is not None diff --git a/superduperdb/ext/sklearn/model.py b/superduperdb/ext/sklearn/model.py index 468e36cd1a..ba74f1e258 100644 --- a/superduperdb/ext/sklearn/model.py +++ b/superduperdb/ext/sklearn/model.py @@ -21,6 +21,13 @@ @dc.dataclass(kw_only=True) class SklearnTrainer(Trainer): + """A trainer for `sklearn` models. 
+ + :param fit_params: The parameters to pass to `fit`. + :param predict_params: The parameters to pass to `predict + :param y_preprocess: The preprocessing function to use for the target. + """ + fit_params: t.Dict = dc.field(default_factory=dict) predict_params: t.Dict = dc.field(default_factory=dict) y_preprocess: t.Optional[t.Callable] = None @@ -64,6 +71,13 @@ def fit( train_dataset: QueryDataset, valid_dataset: QueryDataset, ): + """Fit the model. + + :param model: Model + :param db: Datalayer + :param train_dataset: Training dataset + :param valid_dataset: Validation dataset + """ train_X, train_y = self._get_data_from_dataset( dataset=train_dataset, X=self.key ) @@ -76,6 +90,16 @@ def fit( @dc.dataclass(kw_only=True) class Estimator(Model, _Fittable): + """Estimator model. + + This is a model that can be trained and used for prediction. + + :param object: The estimator object from `sklearn`. + :param trainer: The trainer to use. + :param preprocess: The preprocessing function to use. + :param postprocess: The postprocessing function to use. + """ + _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( ('object', pickle_serializer), ) @@ -96,6 +120,11 @@ def schedule_jobs( db: Datalayer, dependencies: t.Sequence[Job] = (), ) -> t.Sequence[t.Any]: + """Schedule jobs for the model. + + :param db: The datalayer to use. + :param dependencies: The dependencies to wait for. + """ jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies) if self.validation is not None: jobs = self.validation.schedule_jobs( @@ -104,6 +133,10 @@ def schedule_jobs( return jobs def predict_one(self, X): + """Predict on a single input. + + :param X: The input to predict on. + """ X = X[None, :] if self.preprocess is not None: X = self.preprocess(X) @@ -113,6 +146,10 @@ def predict_one(self, X): return X def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict on a dataset. + + :param dataset: The dataset to predict on. + """ if self.preprocess is not None: inputs = [] for i in range(len(dataset)): diff --git a/superduperdb/ext/torch/encoder.py b/superduperdb/ext/torch/encoder.py index 0f853395bf..e18b07623c 100644 --- a/superduperdb/ext/torch/encoder.py +++ b/superduperdb/ext/torch/encoder.py @@ -8,21 +8,42 @@ class EncodeTensor: + """Encode a tensor to bytes. + + :param dtype: The dtype of the tensor, eg. torch.float32 + """ + def __init__(self, dtype): self.dtype = dtype def __call__(self, x, info: t.Optional[t.Dict] = None): + """Encode a tensor to bytes. + + :param x: The tensor to encode. + :param info: Additional information. + """ if x.dtype != self.dtype: raise TypeError(f"dtype was {x.dtype}, expected {self.dtype}") return memoryview(x.numpy()).tobytes() class DecodeTensor: + """Decode a tensor from bytes. + + :param dtype: The dtype of the tensor, eg. torch.float32 + :param shape: The shape of the tensor, eg. (3, 4) + """ + def __init__(self, dtype, shape): self.dtype = torch.randn(1).type(dtype).numpy().dtype self.shape = shape def __call__(self, bytes, info: t.Optional[t.Dict] = None): + """Decode a tensor from bytes. + + :param bytes: The bytes to decode. + :param info: Additional information. + """ array = numpy.frombuffer(bytes, dtype=self.dtype).reshape(self.shape) return torch.from_numpy(array) @@ -33,6 +54,7 @@ def tensor(dtype, shape: t.Sequence, bytes_encoding: t.Optional[str] = None): :param dtype: The dtype of the tensor. :param shape: The shape of the tensor. + :param bytes_encoding: The bytes encoding to use. 
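The torch encoder/decoder pair mirrors the numpy one; a short round trip (CPU tensors only, since the encoder calls `.numpy()`), with the `.encoder`/`.decoder` attributes assumed as in the numpy example earlier:

    import torch
    from superduperdb.ext.torch.encoder import tensor

    dt = tensor(torch.float32, (2, 2))
    blob = dt.encoder(torch.zeros(2, 2))   # EncodeTensor -> bytes
    restored = dt.decoder(blob)            # DecodeTensor -> tensor of shape (2, 2)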
""" return DataType( identifier=f"{str(dtype)}[{str_shape(shape)}]", @@ -44,18 +66,27 @@ def tensor(dtype, shape: t.Sequence, bytes_encoding: t.Optional[str] = None): class TorchDataTypeFactory(DataTypeFactory): + """Factory for torch datatypes. + + It's used for registering the auto schema. + """ + @staticmethod def check(data: t.Any) -> bool: - """ - Check if the data is a torch tensor. + """Check if the data is a torch tensor. + It's used for registering the auto schema. + + :param data: Data to check """ return isinstance(data, torch.Tensor) @staticmethod def create(data: t.Any) -> DataType: - """ - Create a torch tensor datatype. + """Create a torch tensor datatype. + It's used for registering the auto schema. + + :param data: Data to create the datatype from """ return tensor(data.dtype, data.shape) diff --git a/superduperdb/ext/torch/model.py b/superduperdb/ext/torch/model.py index caa96c7700..9a53910f0b 100644 --- a/superduperdb/ext/torch/model.py +++ b/superduperdb/ext/torch/model.py @@ -29,12 +29,13 @@ from superduperdb.jobs.job import Job -def torchmodel(cls): - """ +def torchmodel(class_obj): + """A decorator to convert a `torch.nn.Module` into a `TorchModel`. + Decorate a `torch.nn.Module` so that when it is invoked, the result is a `TorchModel`. - :param cls: Class to decorate + :param class_obj: Class to decorate """ def factory( @@ -54,7 +55,7 @@ def factory( ): return TorchModel( identifier=identifier, - object=cls(*args, **kwargs), + object=class_obj(*args, **kwargs), preprocess=preprocess, postprocess=postprocess, collate_fn=collate_fn, @@ -72,10 +73,11 @@ def factory( class BasicDataset(data.Dataset): """ - Basic database iterating over a list of documents and applying a transformation + Basic database iterating over a list of documents and applying a transformation. - :param documents: documents - :param transform: function + :param items: items, typically documents + :param transform: function, typically a preprocess function + :param signature: signature of the transform function """ def __init__(self, items, transform, signature): @@ -99,6 +101,27 @@ def __getitem__(self, item): @dc.dataclass(kw_only=True) class TorchModel(Model, _Fittable, _DeviceManaged): + """Torch model. + + This class is a wrapper around a PyTorch model. + + :param object: Torch model, e.g. 
`torch.nn.Module` + :param preprocess: Preprocess function, the function to apply to the input + :param preprocess_signature: The signature of the preprocess function + :param postprocess: The postprocess function, the function to apply to the output + :param postprocess_signature: The signature of the postprocess function + :param forward_method: The forward method, the method to call on the model + :param forward_signature: The signature of the forward method + :param train_forward_method: Train forward method, the method to call on the model + :param train_forward_signature: The signature of the train forward method + :param train_preprocess: Train preprocess function, + the function to apply to the input + :param train_preprocess_signature: The signature of the train preprocess function + :param collate_fn: The collate function for the dataloader + :param optimizer_state: The optimizer state + :param loader_kwargs: The kwargs for the dataloader + """ + _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( ('object', dill_serializer), ) @@ -128,12 +151,17 @@ def __post_init__(self, artifacts): @property def signature(self): + """Get the signature of the model.""" if self.preprocess: return self.preprocess_signature return self.forward_signature @signature.setter def signature(self, signature): + """Set the signature of the model. + + :param signature: Signature + """ if self.preprocess: self.preprocess_signature = signature else: @@ -144,40 +172,66 @@ def schedule_jobs( db: Datalayer, dependencies: t.Sequence['Job'] = (), ) -> t.Sequence[t.Any]: + """Schedule jobs for the model. + + :param db: Datalayer + :param dependencies: Dependencies + """ jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies) return jobs @property def inputs(self) -> CallableInputs: + """Get the inputs callable for the model.""" return CallableInputs( self.object.forward if not self.preprocess else self.preprocess, {} ) def to(self, device): + """Move the model to a device. + + :param device: Device + """ self.object.to(device) def save(self, db: Datalayer): + """Save the model to the database. + + :param db: Datalayer + """ with self.saving(): db.replace(object=self, upsert=True) @contextmanager def evaluating(self): + """Context manager for evaluating the model. + + This context manager ensures that the model is in evaluation mode + """ yield eval(self) def train(self): + """Set the model to training mode.""" return self.object.train() def eval(self): + """Set the model to evaluation mode.""" return self.object.eval() def parameters(self): + """Get the model parameters.""" return self.object.parameters() def state_dict(self): + """Get the model state dict.""" return self.object.state_dict() @contextmanager def saving(self): + """Context manager for saving the model. + + This context manager ensures that the model is in evaluation mode + """ was_training = self.object.training try: self.object.eval() @@ -208,6 +262,11 @@ def __setstate__(self, state): @ensure_initialized def predict_one(self, *args, **kwargs): + """Predict on a single input. + + :param args: Input arguments + :param kwargs: Input keyword arguments + """ with torch.no_grad(), eval(self.object): if self.preprocess is not None: out = self.preprocess(*args, **kwargs) @@ -226,6 +285,10 @@ def predict_one(self, *args, **kwargs): @ensure_initialized def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict on a dataset. 
+ + :param dataset: Dataset + """ with torch.no_grad(), eval(self.object): inputs = BasicDataset( items=dataset, @@ -253,6 +316,11 @@ def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: return out def train_forward(self, X, y=None): + """The forward method for training. + + :param X: Input + :param y: Target + """ X = X.to(self.device) if y is not None: y = y.to(self.device) @@ -271,8 +339,7 @@ def train_forward(self, X, y=None): def unpack_batch(args): - """ - Unpack a batch into lines of tensor output. + """Unpack a batch into lines of tensor output. :param args: a batch of model outputs @@ -296,7 +363,6 @@ def unpack_batch(args): >>> out[1]['a']['b'].shape torch.Size([10]) """ - if isinstance(args, torch.Tensor): return [args[i] for i in range(args.shape[0])] @@ -314,8 +380,7 @@ def unpack_batch(args): def create_batch(args): - """ - Create a singleton batch in a manner similar to the PyTorch dataloader + """Create a singleton batch in a manner similar to the PyTorch dataloader. :param args: single data point for batching diff --git a/superduperdb/ext/torch/training.py b/superduperdb/ext/torch/training.py index c5b4ee5d0f..ef15709df2 100644 --- a/superduperdb/ext/torch/training.py +++ b/superduperdb/ext/torch/training.py @@ -28,6 +28,8 @@ class TorchTrainer(Trainer): :param optimizer_cls: Optimizer class :param optimizer_kwargs: Kwargs for the optimizer :param optimizer_state: Latest state of the optimizer for contined training + :param collate_fn: Collate function for the dataloader + :param metric_values: Metric values """ objective: t.Callable @@ -44,6 +46,10 @@ class TorchTrainer(Trainer): metric_values: t.Dict = dc.field(default_factory=dict) def get_optimizers(self, model): + """Get the optimizers for the model. + + :param model: Model + """ cls_ = getattr(torch.optim, self.optimizer_cls) optimizer = cls_(model.parameters(), **self.optimizer_kwargs) if self.optimizer_state is not None: @@ -65,6 +71,13 @@ def fit( train_dataset: QueryDataset, valid_dataset: QueryDataset, ): + """Fit the model. + + :param model: Model + :param db: Datalayer + :param train_dataset: Training dataset + :param valid_dataset: Validation dataset + """ train_dataloader = self._create_loader(train_dataset) valid_dataloader = self._create_loader(valid_dataset) return self._fit_with_dataloaders( @@ -75,6 +88,12 @@ def fit( ) def take_step(self, model, batch, optimizers): + """Take a step in the optimization. + + :param model: Model + :param batch: Batch of data + :param optimizers: Optimizers + """ if self.signature == '*args': outputs = model.train_forward(*batch) elif self.signature == 'singleton': @@ -92,6 +111,11 @@ def take_step(self, model, batch, optimizers): return objective_value def compute_validation_objective(self, model, valid_dataloader): + """Compute the validation objective. + + :param model: Model + :param valid_dataloader: Validation dataloader to use + """ objective_values = [] with model.evaluating(), torch.no_grad(): for batch in valid_dataloader: @@ -142,11 +166,19 @@ def _fit_with_dataloaders( iteration += 1 def append_metrics(self, d: t.Dict[str, float]) -> None: + """Append metrics to the metric_values dict. + + :param d: Metrics to append + """ if self.metric_values is not None: for k, v in d.items(): self.metric_values.setdefault(k, []).append(v) def stopping_criterion(self, iteration): + """Check if the training should stop. 
+ + :param iteration: Current iteration + """ max_iterations = self.max_iterations no_improve_then_stop = self.no_improve_then_stop if isinstance(max_iterations, int) and iteration >= max_iterations: @@ -162,6 +194,7 @@ def stopping_criterion(self, iteration): return False def saving_criterion(self): + """Check if the model should be saved.""" if self.listen == 'objective': to_listen = [-x for x in self.metric_values['objective']] else: @@ -171,6 +204,10 @@ def saving_criterion(self): return False def log(self, **kwargs): + """Log the training progress. + + :param kwargs: Key-value pairs to log + """ out = '' for k, v in kwargs.items(): if isinstance(v, dict): diff --git a/superduperdb/ext/torch/utils.py b/superduperdb/ext/torch/utils.py index ea3854783a..50dd6fe17f 100644 --- a/superduperdb/ext/torch/utils.py +++ b/superduperdb/ext/torch/utils.py @@ -12,7 +12,7 @@ def device_of(module: Module) -> t.Union[_device, str]: """ Get device of a model. - :param model: PyTorch model + :param module: PyTorch model """ try: return next(iter(module.state_dict().values())).device diff --git a/superduperdb/ext/transformers/model.py b/superduperdb/ext/transformers/model.py index d23699761c..eea3c53199 100644 --- a/superduperdb/ext/transformers/model.py +++ b/superduperdb/ext/transformers/model.py @@ -51,6 +51,18 @@ def _save_checkpoint(self, model, trial, metrics=None): @dc.dataclass(kw_only=True) class TransformersTrainer(TrainingArguments, Trainer): + """Trainer for transformers models. + + It's used to train the transformers models. + + :param signature: signature, default is '**kwargs' + :param output_dir: output directory + :param data_collator: data collator + :param callbacks: callbacks for training + :param optimizers: optimizers for training + :param preprocess_logits_for_metrics: preprocess logits for metrics + """ + _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = ( ('data_collator', dill_serializer), ('callbacks', dill_serializer), @@ -78,6 +90,7 @@ def __post_init__(self, artifacts): @property def native_arguments(self): + """Get native arguments of TrainingArguments.""" _TRAINING_DEFAULTS = { k: v for k, v in TrainingArguments('_tmp').to_dict().items() @@ -136,6 +149,13 @@ def fit( train_dataset: QueryDataset, valid_dataset: QueryDataset, ): + """Fit the model. + + :param model: model + :param db: Datalayer instance + :param train_dataset: training dataset + :param valid_dataset: validation dataset + """ trainer = self._build_trainer( model=model, db=db, @@ -147,12 +167,21 @@ def fit( @dc.dataclass(kw_only=True) class TextClassificationPipeline(Model, _Fittable, _DeviceManaged): - """ - A wrapper for ``transformers.Pipeline`` + """A wrapper for ``transformers.Pipeline``. + + :param tokenizer_name: tokenizer name + :param tokenizer_cls: tokenizer class, e.g. ``transformers.AutoTokenizer`` + :param tokenizer_kwargs: tokenizer kwargs, will pass to ``tokenizer_cls`` + :param model_name: model name, will pass to ``model_cls`` + :param model_cls: model class, e.g. ``AutoModelForSequenceClassification`` + :param model_kwargs: model kwargs, will pass to ``model_cls`` + :param pipeline: pipeline instance, default is None, will build when None + :param task: task of the pipeline + + Example: + ------- + >>> model = TextClassificationPipeline(...) - >>> model = TextClassificationPipeline(...) 
# 123456 - >>> ,sd.s,d.s,ds - >>> """ _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, 'DataType']]] = ( @@ -185,9 +214,17 @@ def __post_init__(self, artifacts): super().__post_init__(artifacts) def predict_one(self, text: str): + """Predict the class of a single text. + + :param text: a text + """ return self.pipeline(text)[0] def predict(self, dataset: t.Union[t.List, QueryDataset]) -> t.List: + """Predict the class of a list of text. + + :param dataset: a list of text + """ text = [dataset[i] for i in range(len(dataset))] return self.pipeline(text) @@ -199,14 +236,13 @@ class LLM(BaseLLM, _Fittable): :param identifier: model identifier :param model_name_or_path: model name or path - :param bits: quantization bits, [4, 8], default is None :param adapter_id: adapter id, default is None Add a adapter to the base model for inference. When model_name_or_path, bits, model_kwargs, tokenizer_kwargs are the same, will share the same base model and tokenizer cache. :param model_kwargs: model kwargs, all the kwargs will pass to `transformers.AutoModelForCausalLM.from_pretrained` - :param tokenizer_kwagrs: tokenizer kwargs, + :param tokenizer_kwargs: tokenizer kwargs, all the kwargs will pass to `transformers.AutoTokenizer.from_pretrained` :param prompt_template: prompt template, default is "{input}" :param prompt_func: prompt function, default is None @@ -215,12 +251,11 @@ class LLM(BaseLLM, _Fittable): identifier: str = "" model_name_or_path: t.Optional[str] = None adapter_id: t.Optional[t.Union[str, Checkpoint]] = None - object: t.Optional[transformers.Trainer] = None model_kwargs: t.Dict = dc.field(default_factory=dict) tokenizer_kwargs: t.Dict = dc.field(default_factory=dict) prompt_template: str = "{input}" prompt_func: t.Optional[t.Callable] = None - signature: str = 'singleton' + signature: t.ClassVar[Signature] = 'singleton' # Save models and tokenizers cache for sharing when using multiple models _model_cache: t.ClassVar[dict] = {} @@ -252,10 +287,17 @@ def from_pretrained( predict_kwargs=None, **kwargs, ): - """ - A new function to create a LLM model from from_pretrained function. + """A new function to create a LLM model from from_pretrained function. + Allow the user to directly replace: AutoModelForCausalLM.from_pretrained -> LLM.from_pretrained + + :param model_name_or_path: model name or path + :param identifier: model identifier + :param prompt_template: prompt template, default is "{input}" + :param prompt_func: prompt function, default is None + :param predict_kwargs: predict kwargs, default is None + :param kwargs: additional keyword arguments, all the kwargs will pass to `LLM` """ model_kwargs = kwargs.copy() tokenizer_kwargs = {} @@ -273,6 +315,11 @@ def from_pretrained( def init_pipeline( self, adapter_id: t.Optional[str] = None, load_adapter_directly: bool = False ): + """Initialize pipeline. + + :param adapter_id: adapter id + :param load_adapter_directly: load adapter directly + """ # Do not update model state here model_kwargs = self.model_kwargs.copy() tokenizer_kwargs = self.tokenizer_kwargs.copy() @@ -334,6 +381,10 @@ def init_pipeline( return pipeline("text-generation", model=model, tokenizer=tokenizer) def init(self): + """Initialize the model. + + If adapter_id is provided, will load the adapter to the model. + """ db = self.db real_adapter_id = None @@ -364,6 +415,10 @@ def init(self): self.pipeline = self.init_pipeline(real_adapter_id) def handle_chekpoint(self, db): + """Handle checkpoint identifier. 
+ + :param db: Datalayer instance + """ if isinstance(self.adapter_id, str): # match checkpoint:/// if Checkpoint.check_uri(self.adapter_id): @@ -376,6 +431,11 @@ def handle_chekpoint(self, db): @ensure_initialized def predict_one(self, X, **kwargs): + """Generate text from a single prompt. + + :param X: a prompt + :param kwargs: additional keyword arguments + """ X = self._process_inputs(X, **kwargs) kwargs.pop("context", None) results = self._batch_generate([X], **kwargs) @@ -383,6 +443,11 @@ def predict_one(self, X, **kwargs): @ensure_initialized def predict(self, dataset: t.Union[t.List, QueryDataset], **kwargs) -> t.List: + """Generate text from a list of prompts. + + :param dataset: a list of prompts + :param kwargs: additional keyword arguments + """ dataset = [ self._process_inputs(dataset[i], **kwargs) for i in range(len(dataset)) ] @@ -395,8 +460,8 @@ def _process_inputs(self, X: t.Any, **kwargs) -> str: return X def _batch_generate(self, prompts: t.List[str], **kwargs) -> t.List[str]: - """ - Generate text. + """Generate text. + Can overwrite this method to support more inference methods. """ kwargs = {**self.predict_kwargs, **kwargs.copy()} @@ -418,6 +483,11 @@ def schedule_jobs( db: Datalayer, dependencies: t.Sequence[Job] = (), ) -> t.Sequence[t.Any]: + """Schedule jobs for LLM model. + + :param db: Datalayer instance + :param dependencies: dependencies + """ jobs = _Fittable.schedule_jobs(self, db, dependencies=dependencies) if self.validation is not None: jobs = self.validation.schedule_jobs( @@ -426,6 +496,11 @@ def schedule_jobs( return jobs def add_adapter(self, model_id, adapter_name: str): + """Add adapter to the model. + + :param model_id: model id + :param adapter_name: adapter name + """ # TODO: Support lora checkpoint from s3 try: from peft import PeftModel @@ -448,6 +523,10 @@ def add_adapter(self, model_id, adapter_name: str): self.model.load_adapter(model_id, adapter_name) def post_create(self, db: "Datalayer") -> None: + """Post create hook for LLM model. + + :param db: Datalayer instance + """ # TODO: Do not make sense to add this logic here, # Need a auto DataType to handle this from superduperdb.backends.ibis.data_backend import IbisDataBackend diff --git a/superduperdb/ext/transformers/training.py b/superduperdb/ext/transformers/training.py index f225cc5e00..3ff136a447 100644 --- a/superduperdb/ext/transformers/training.py +++ b/superduperdb/ext/transformers/training.py @@ -34,6 +34,12 @@ @dc.dataclass(kw_only=True) class Checkpoint(Component): + """Checkpoint component for saving the model checkpoint. + + :param path: The path to the checkpoint. + :param step: The step of the checkpoint. + """ + path: t.Optional[str] step: int _artifacts: t.ClassVar[t.Sequence[t.Tuple[str, DataType]]] = (("path", file_lazy),) @@ -45,14 +51,23 @@ def __post_init__(self, artifacts): @property def uri(self): + """Get the uri of the checkpoint.""" return f"checkpoint://{self.identifier}/{self.step}" @staticmethod def check_uri(uri): + """Check if the uri is a valid checkpoint uri. + + :param uri: The uri to check. + """ return re.match(r"^checkpoint://.*?/\d+$", uri) is not None @staticmethod def parse_uri(uri): + """Parse the uri to get the identifier and step. + + :param uri: The uri to parse. + """ if not Checkpoint.check_uri(uri): raise ValueError(f"Invalid uri: {uri}") *_, identifier, step = uri.split("/") @@ -60,6 +75,18 @@ def parse_uri(uri): class LLMCallback(TrainerCallback): + """LLM Callback for logging training process to db. 
+ + This callback will save the checkpoint to db after each epoch. + If the save_total_limit is set, will remove the oldest checkpoint. + + :param cfg: The configuration to use. + :param identifier: The identifier to use. + :param db: The datalayer to use. + :param llm: The LLM model to use. + :param experiment_id: The experiment id to use. + """ + def __init__( self, cfg: t.Optional["Config"] = None, @@ -87,7 +114,13 @@ def __init__( ) def on_save(self, args, state, control, **kwargs): - """Event called after a checkpoint save.""" + """Event called after a checkpoint save. + + :param args: The training arguments from transformers. + :param state: The training state from transformers. + :param control: The training control from transformers. + :param kwargs: Other keyword arguments from transformers. + """ if not state.is_world_process_zero: return @@ -116,7 +149,13 @@ def on_save(self, args, state, control, **kwargs): self.db.remove("checkpoint", self.experiment_id, version, force=True) def on_evaluate(self, args, state, control, **kwargs): - """Event called after an evaluation.""" + """Event called after an evaluation. + + :param args: The training arguments from transformers. + :param state: The training state from transformers. + :param control: The training control from transformers. + :param kwargs: Other keyword arguments from transformers. + """ if not state.is_world_process_zero: return @@ -124,6 +163,13 @@ def on_evaluate(self, args, state, control, **kwargs): self.llm.append_metrics(state.log_history[-1]) def on_train_end(self, args, state, control, **kwargs): + """Event called after training ends. + + :param args: The training arguments from transformers. + :param state: The training state from transformers. + :param control: The training control from transformers. + :param kwargs: Other keyword arguments from transformers. + """ self.check_init() # update the llm to db after training, will save the adapter_id and metrics @@ -139,6 +185,7 @@ def on_train_end(self, args, state, control, **kwargs): self.db.replace(self.llm) def check_init(self): + """Check the initialization of the callback.""" # Only check this in the world_rank 0 process # Rebuild datalayer for the new process if self.db is None: @@ -150,35 +197,25 @@ def check_init(self): @dc.dataclass(kw_only=True) class LLMTrainer(TrainingArguments, SuperDuperTrainer): - """ - LLM Training Arguments. + """LLM Training Arguments. + Inherits from :class:`transformers.TrainingArguments`. {training_arguments_doc} - use_lora (`bool`, *optional*, defaults to True): - Whether to use LoRA training. - lora_r (`int`, *optional*, defaults to 8): - Lora R dimension. - - lora_alpha (`int`, *optional*, defaults to 16): - Lora alpha. - - lora_dropout (`float`, *optional*, defaults to 0.05): - Lora dropout. - - lora_target_modules (`List[str]`, *optional*, defaults to None): - Lora target modules. If None, will be automatically inferred. - - lora_bias (`str`, *optional*, defaults to "none"): - Lora bias. - - max_seq_length (`int`, *optional*, defaults to 512): - Maximum source sequence length during training. - log_to_db (`bool`, *optional*, defaults to True): - Log training to db. - If True, will log checkpoint to superduperdb, - but need ray cluster can access to db. - If can't access to db, please set it to False. + :param output_dir: The output directory to use. + :param use_lora: Whether to use LoRA training. + :param lora_r: Lora R dimension. + :param lora_alpha: Lora alpha. + :param lora_dropout: Lora dropout. 
+ :param lora_target_modules: Lora target modules. + :param lora_bias: Lora bias. + :param bits: The bits to use. + :param max_seq_length: Maximum source sequence length during training. + :param setup_chat_format: Whether to setup chat format. + :param log_to_db: Whether to log training to db. + :param training_kwargs: The training kwargs to use, will be passed to Trainer. + :param num_gpus: The number of GPUs to use, if None, will use all GPUs. + :param ray_configs: The ray configs to use. """ __doc__ = __doc__.format(training_arguments_doc=TrainingArguments.__doc__) @@ -205,9 +242,11 @@ def __post_init__(self, artifacts): return SuperDuperTrainer.__post_init__(self, artifacts) def build(self): + """Build the training arguments.""" super().__post_init__() def build_training_args(self): + """Build the training arguments.""" _TRAINING_DEFAULTS = { k: v for k, v in TrainingArguments('_tmp').to_dict().items() @@ -218,6 +257,13 @@ def build_training_args(self): @staticmethod def get_compute_metrics(metrics): + """Get the compute metrics function. + + :param metrics: List of callable metric functions. + Each function should take logits and labels as input + and return a metric value. + + """ if not metrics: return None @@ -231,6 +277,11 @@ def compute_metrics(eval_preds): return compute_metrics def prepare_dataset(self, model, dataset: QueryDataset): + """Prepare the dataset for training. + + :param model: The model to use. + :param dataset: The dataset to prepare. + """ if isinstance(self.key, str): dataset.transform = lambda x: {self.key: x} @@ -241,6 +292,13 @@ def fit( train_dataset: t.Union[QueryDataset, NativeDataset], valid_dataset: t.Union[QueryDataset, NativeDataset], ): + """Fit the model on the training dataset. + + :param model: The model to fit. + :param db: The datalayer to use. + :param train_dataset: The training dataset to use. + :param valid_dataset: The validation dataset to use. + """ if isinstance(train_dataset, QueryDataset): self.prepare_dataset(model, train_dataset) train_dataset = NativeDataset.from_list( @@ -293,11 +351,18 @@ def fit( @property def experiment_id(self): + """Get the experiment id.""" return getattr(self, "_experiment_id", None) def tokenize(tokenizer, example, X, y): - """Function to tokenize the example.""" + """Function to tokenize the example. + + :param tokenizer: The tokenizer to use. + :param example: The example to tokenize. + :param X: The input key. + :param y: The output key. + """ prompt = example[X] prompt = prompt + tokenizer.eos_token @@ -322,8 +387,8 @@ def train( ray_configs: t.Optional[dict] = None, **kwargs, ): - """ - Train LLM model on specified dataset. + """Train LLM model on specified dataset. + The training process can be run on these following modes: - Local node without ray, but only support single GPU - Local node with ray, support multi-nodes and multi-GPUs @@ -337,21 +402,17 @@ def train( Will rebuild the db and llm for the new process that can access to db. The ray cluster must can access to db. 
- Parameters: - :param training_config: training config for LLMTrainingArguments + :param training_args: training Arguments, see LLMTrainingArguments :param train_dataset: training dataset :param eval_datasets: evaluation dataset, can be a dict of datasets :param model_kwargs: model kwargs for AutoModelForCausalLM :param tokenizer_kwargs: tokenizer kwargs for AutoTokenizer - :param X: column name for input - :param y: column name for output :param db: datalayer, used for creating LLMCallback :param llm: llm model, used for creating LLMCallback - :param on_ray: whether to use ray, if True, will use ray_train - :param ray_address: ray address, if not None, will run on ray cluster :param ray_configs: ray configs, must provide if using ray - """ + :param **kwargs: other kwargs for Trainer + """ on_ray = bool(ray_configs) # Auto detect multi-GPUs and use ray to run data parallel training @@ -419,9 +480,13 @@ def train( def handle_ray_results(db, llm, results): - """ - Handle the ray results. + """Handle the ray results. + Will save the checkpoint to db if db and llm provided. + + :param db: datalayer, used for saving the checkpoint + :param llm: llm model, used for saving the checkpoint + :param results: the ray training results, contains the checkpoint """ checkpoint = results.checkpoint if checkpoint is None: @@ -455,8 +520,8 @@ def train_func( callbacks=None, **kwargs, ): - """ - Base training function for LLM model. + """Base training function for LLM model. + :param training_args: training Arguments, see LLMTrainingArguments :param train_dataset: training dataset, can be huggingface datasets.Dataset or ray.data.Dataset @@ -550,8 +615,8 @@ def ray_train( ray_configs: t.Optional[t.Dict[str, t.Any]] = None, **kwargs, ): - """ - Ray training function for LLM model. + """Ray training function for LLM model. + The ray train function will handle the following logic: - Prepare the datasets for ray - Build the training_loop_func for ray @@ -563,7 +628,6 @@ def ray_train( can be huggingface datasets.Dataset :param eval_datasets: evaluation dataset, Must be a Huggingface datasets.Dataset - :param ray_address: ray address, if not None, will run on ray cluster :param ray_configs: ray configs, must provide if using ray_configs :param **kwargs: other kwargs for Trainer """ @@ -647,9 +711,12 @@ def ray_train_func(train_loop_config): def prepare_lora_training(model, config: LLMTrainer): - """ - Prepare LoRA training for the model. + """Prepare LoRA training for the model. + Get the LoRA target modules and convert the model to peft model. + + :param model: The model to prepare for LoRA training. + :param config: The configuration to use. """ try: from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training @@ -687,7 +754,10 @@ def prepare_lora_training(model, config: LLMTrainer): def create_quantization_config(config: LLMTrainer): - """Create quantization config for LLM training.""" + """Create quantization config for LLM training. + + :param config: The configuration to use. + """ compute_dtype = ( torch.float16 if config.fp16 diff --git a/superduperdb/ext/unstructured/encoder.py b/superduperdb/ext/unstructured/encoder.py index 075b830da6..0b57533e88 100644 --- a/superduperdb/ext/unstructured/encoder.py +++ b/superduperdb/ext/unstructured/encoder.py @@ -9,11 +9,11 @@ def link2elements(link, unstructure_kwargs): - """ - Convert a link to a list of elements + """Convert a link to a list of elements. 
+ Use unstructured to parse the link - param link: str, file path or url - param unstructure_kwargs: kwargs for unstructured + :param link: str, file path or url + :param unstructure_kwargs: kwargs for unstructured """ if link.startswith("file://"): link = link[7:] @@ -27,6 +27,11 @@ def link2elements(link, unstructure_kwargs): def create_encoder(unstructure_kwargs): + """Create an encoder for unstructured data. + + :param unstructure_kwargs: kwargs for unstructured + """ + def encoder(x: t.Union[str, t.List[Element]], info: t.Optional[t.Dict] = None): if isinstance(x, str): elements = link2elements(x, unstructure_kwargs) @@ -40,6 +45,8 @@ def encoder(x: t.Union[str, t.List[Element]], info: t.Optional[t.Dict] = None): def create_decoder(): + """Create a decoder for unstructured data.""" + def decoder(b: bytes, info: t.Optional[t.Dict] = None): try: return pickle.loads(b) @@ -62,6 +69,11 @@ def decoder(b: bytes, info: t.Optional[t.Dict] = None): def create_unstructured_encoder(identifier, **unstructure_kwargs): + """Create an unstructured encoder with the given identifier and unstructure kwargs. + + :param identifier: The identifier to use. + :param *unstructure_kwargs: The unstructure kwargs to use. + """ assert ( isinstance(identifier, str) and identifier != "unstructured" ), 'identifier must be a string and not "unstructured"' diff --git a/superduperdb/ext/utils.py b/superduperdb/ext/utils.py index 91e0edbb6c..a34c4a0e92 100644 --- a/superduperdb/ext/utils.py +++ b/superduperdb/ext/utils.py @@ -5,16 +5,25 @@ import numpy as np if t.TYPE_CHECKING: + from superduperdb.base.datalayer import LoadDict from superduperdb.components.datatype import DataType def str_shape(shape: t.Sequence[int]) -> str: + """Convert a shape to a string. + + :param shape: The shape to convert. + """ if not shape: raise ValueError('Shape was empty') return 'x'.join(str(x) for x in shape) def get_key(key_name: str) -> str: + """Get an environment variable. + + :param key_name: The name of the environment variable to get. + """ try: return os.environ[key_name] except KeyError: @@ -22,6 +31,12 @@ def get_key(key_name: str) -> str: def format_prompt(X: str, prompt: str, context: t.Optional[t.List[str]] = None) -> str: + """Format a prompt with the given input and context. + + :param X: The input to format the prompt with. + :param prompt: The prompt to format. + :param context: The context to format the prompt with. + """ format_params = {} if '{input}' in prompt: format_params['input'] = X @@ -41,6 +56,10 @@ def format_prompt(X: str, prompt: str, context: t.Optional[t.List[str]] = None) def superduperencode(object): + """Encode an object using superduper. + + :param object: The object to encode. + """ if isinstance(object, np.ndarray): from superduperdb.ext.numpy import array @@ -51,7 +70,12 @@ def superduperencode(object): return object -def superduperdecode(r: t.Any, encoders: t.List['DataType']): +def superduperdecode(r: t.Any, encoders: t.Union[t.Dict[str, 'DataType'], 'LoadDict']): + """Decode a superduper encoded object. + + :param r: The object to decode. + :param encoders: The encoders to use. 
+ """ if isinstance(r, dict): encoder = encoders[r['_content']['datatype']] b = base64.b64decode(r['_content']['bytes']) diff --git a/superduperdb/ext/vllm/model.py b/superduperdb/ext/vllm/model.py index eea0638d65..87dc4a5130 100644 --- a/superduperdb/ext/vllm/model.py +++ b/superduperdb/ext/vllm/model.py @@ -38,8 +38,8 @@ @public_api(stability='beta') @dc.dataclass class VllmAPI(BaseLLMAPI): - """ - Wrapper for requesting the vLLM API service + """Wrapper for requesting the vLLM API service. + (API Server format, started by vllm.entrypoints.api_server) {parent_doc} """ @@ -47,9 +47,7 @@ class VllmAPI(BaseLLMAPI): __doc__ = __doc__.format(parent_doc=BaseLLMAPI.__doc__) def _generate(self, prompt: str, **kwargs) -> t.Union[str, t.List[str]]: - """ - Batch generate text from a prompt. - """ + """Batch generate text from a prompt.""" post_data = self.build_post_data(prompt, **kwargs) response = requests.post(self.api_url, json=post_data) results = [] @@ -61,6 +59,11 @@ def _generate(self, prompt: str, **kwargs) -> t.Union[str, t.List[str]]: def build_post_data( self, prompt: str, **kwargs: dict[str, t.Any] ) -> dict[str, t.Any]: + """Build the post data for the API request. + + :param prompt: The prompt to use. + :param kwargs: The keyword arguments to use. + """ total_kwargs = {} for key, value in {**self.predict_kwargs, **kwargs}.items(): if key in VLLM_INFERENCE_PARAMETERS_LIST: @@ -100,8 +103,12 @@ class VllmModel(BaseLLM): Load a large language model from VLLM. :param model_name: The name of the model to use. + :param tensor_parallel_size: The number of tensor parallelism. :param trust_remote_code: Whether to trust remote code. - :param dtype: The data type to use. + :param vllm_kwargs: Additional arguments to pass to the VLLM + :param on_ray: Whether to use Ray for parallelism. + :param ray_address: The address of the Ray cluster. + :param ray_config: The configuration for Ray. {parent_doc} """ @@ -129,6 +136,7 @@ def __post_init__(self, artifacts): super().__post_init__(artifacts) def init(self): + """Initialize the model.""" if self.on_ray: import ray diff --git a/superduperdb/jobs/job.py b/superduperdb/jobs/job.py index 224079e9c9..b5560a82e1 100644 --- a/superduperdb/jobs/job.py +++ b/superduperdb/jobs/job.py @@ -7,8 +7,17 @@ from superduperdb import CFG from superduperdb.jobs.tasks import callable_job, method_job +if t.TYPE_CHECKING: + from superduperdb.base.datalayer import Datalayer + def job(f): + """ + Decorator to create a job from a function. + + :param f: function to be decorated + """ + def wrapper( *args, db: t.Any = None, @@ -27,10 +36,6 @@ class Job: :param args: positional arguments to be passed to the function or method :param kwargs: keyword arguments to be passed to the function or method - :param identifier: unique identifier - :param callable: function or method to be called - :param db: DB instance to be used - :param future: future object returned by dask :param compute_kwargs: Arguments to use for model predict computation """ @@ -54,15 +59,12 @@ def __init__( self.compute_kwargs = compute_kwargs or CFG.cluster.compute.compute_kwargs def watch(self): - """ - Watch the stdout of the job. - """ + """Watch the stdout of the job.""" return self.db.metadata.watch_job(identifier=self.identifier) @abstractmethod def submit(self, compute, dependencies=()): - """ - Submit job for execution + """Submit job for execution. 
:param compute: compute engine :param dependencies: list of dependencies @@ -70,6 +72,7 @@ def submit(self, compute, dependencies=()): raise NotImplementedError def dict(self): + """Return a dictionary representation of the job.""" return { 'identifier': self.identifier, 'time': self.time, @@ -92,9 +95,7 @@ def __call__(self, db: t.Any = None, dependencies=()): class FunctionJob(Job): - """ - Job for running a function. - on a dask cluster. + """Job for running a function. :param callable: function to be called :param args: positional arguments to be passed to the function @@ -113,13 +114,13 @@ def __init__( self.callable = callable def dict(self): + """Return a dictionary representation of the job.""" d = super().dict() d['cls'] = 'FunctionJob' return d def submit(self, dependencies=()): - """ - Submit job for execution + """Submit job for execution. :param dependencies: list of dependencies """ @@ -136,7 +137,12 @@ def submit(self, dependencies=()): return - def __call__(self, db: t.Any = None, dependencies=()): + def __call__(self, db: t.Union['Datalayer', None], dependencies=()): + """Run the job. + + :param db: Datalayer instance to use + :param dependencies: list of dependencies + """ if db is None: from superduperdb.base.build import build_datalayer @@ -179,16 +185,21 @@ def __init__( @property def component(self): + """Get the component.""" return self._component @component.setter def component(self, value): + """Set the component. + + :param value: component to set + """ self._component = value self.callable = getattr(self._component, self.method_name) def submit(self, dependencies=()): - """ - Submit job for execution + """Submit job for execution. + :param dependencies: list of dependencies """ self.future, self.job_id = self.db.compute.submit( @@ -206,7 +217,12 @@ def submit(self, dependencies=()): ) return self - def __call__(self, db: t.Any = None, dependencies=()): + def __call__(self, db: t.Union['Datalayer', None] = None, dependencies=()): + """Run the job. + + :param db: Datalayer instance to use + :param dependencies: list of dependencies + """ if db is None: from superduperdb.base.build import build_datalayer @@ -221,6 +237,7 @@ def __call__(self, db: t.Any = None, dependencies=()): return self def dict(self): + """Return a dictionary representation of the job.""" d = super().dict() d.update( { diff --git a/superduperdb/jobs/task_workflow.py b/superduperdb/jobs/task_workflow.py index 3aa49a980e..96b5971ca6 100644 --- a/superduperdb/jobs/task_workflow.py +++ b/superduperdb/jobs/task_workflow.py @@ -14,7 +14,8 @@ @dc.dataclass class TaskWorkflow: - """ + """Task workflow class. + Keep a graph of jobs that need to be performed and their dependencies, and perform them when called. @@ -26,24 +27,35 @@ class TaskWorkflow: G: DiGraph = dc.field(default_factory=DiGraph) def add_edge(self, node1: str, node2: str) -> None: + """Add an edge to the graph. + + :param node1: name of the first node + :param node2: name of the second node + """ self.G.add_edge(node1, node2) @property def nodes(self): + """Return the nodes of the graph.""" return self.G.nodes() def add_node(self, node: str, job: t.Union[FunctionJob, ComponentJob]) -> None: + """Add a node to the graph. 
+ + :param node: name of the node + :param job: job to be performed + """ self.G.add_node(node, job=job) def watch(self) -> None: - """Watch the stdout of each job in this workflow in topological order""" + """Watch the stdout of each job in this workflow in topological order.""" for node in list(networkx.topological_sort(self.G)): self.G.nodes[node]['job'].watch() def run_jobs( self, ): - """Run all the jobs in this workflow""" + """Run all the jobs in this workflow.""" pred = self.G.predecessors current_group = [n for n in self.G.nodes if not ancestors(self.G, n)] done = set() diff --git a/superduperdb/jobs/tasks.py b/superduperdb/jobs/tasks.py index 8c906192d7..2260a79247 100644 --- a/superduperdb/jobs/tasks.py +++ b/superduperdb/jobs/tasks.py @@ -27,6 +27,7 @@ def method_job( :param kwargs: keyword arguments to pass to the method :param job_id: unique identifier for this job :param dependencies: other jobs that this job depends on + :param db: datalayer to use """ import sys @@ -57,18 +58,31 @@ def method_job( db.metadata.update_job(job_id, 'status', 'success') +# TODO: Is this class used? class Logger: + """Logger class for writing to the database. + + :param database: database to write to + :param id_: job id + :param stream: stream to write to + """ + def __init__(self, database, id_, stream='stdout'): self.database = database self.id_ = id_ self.stream = stream def write(self, message): + """Write a message to the database. + + :param message: message to write + """ self.database.metadata.write_output_to_job( self.id_, message, stream=self.stream ) def flush(self): + """Flush something.""" pass @@ -81,6 +95,16 @@ def callable_job( dependencies=(), db: t.Optional['Datalayer'] = None, ): + """Run a function in the database. + + :param cfg: configuration + :param function_to_call: function to call + :param args: positional arguments to pass to the function + :param kwargs: keyword arguments to pass to the function + :param job_id: unique identifier for this job + :param dependencies: other jobs that this job depends on + :param db: datalayer to use + """ import sys from superduperdb import CFG diff --git a/superduperdb/misc/__init__.py b/superduperdb/misc/__init__.py index 09952c2201..1b6d4e399e 100644 --- a/superduperdb/misc/__init__.py +++ b/superduperdb/misc/__init__.py @@ -1,6 +1,13 @@ # https://stackoverflow.com/questions/39969064/how-to-print-a-message-box-in-python +# TODO: Remove the unused functions def border_msg(msg, indent=1, width=None, title=None): - """Print message-box with optional title.""" + """Print message-box with optional title. + + :param msg: Message to print + :param indent: Indentation of the box + :param width: Width of the box + :param title: Title of the box + """ lines = msg.split('\n') space = " " * indent if not width: diff --git a/superduperdb/misc/annotations.py b/superduperdb/misc/annotations.py index 520b3d539b..271a63d787 100644 --- a/superduperdb/misc/annotations.py +++ b/superduperdb/misc/annotations.py @@ -59,12 +59,13 @@ def _compare_versions(package, lower_bound, upper_bound, install_name): def requires_packages(*packages, warn=False): - """ - Require the packages to be installed - :param packages: list of tuples of packages + """Require the packages to be installed. + + :param *packages: list of tuples of packages each tuple of the form (import_name, lower_bound/None, upper_bound/None, install_name/None) + :param warn: if True, warn instead of raising an exception E.g. 
('sklearn', '0.1.0', '0.2.0', 'scikit-learn') """ @@ -86,10 +87,11 @@ def requires_packages(*packages, warn=False): def _requires_packages( import_module, lower_bound=None, upper_bound=None, install_module=None ): - ''' + """Compare the versions of the required packages. + A utility function to check that a required package for a module in superduperdb.ext is installed. - ''' + """ import_module, lower_bound, upper_bound, install_module = _normalize_module( import_module, lower_bound, @@ -100,6 +102,13 @@ def _requires_packages( def deprecated(f): + """Decorator to mark a function as deprecated. + + This will result in a warning being emitted when the function is used. + + :param f: function to deprecate + """ + @functools.wraps(f) def decorated(*args, **kwargs): logging.warn( @@ -110,7 +119,7 @@ def decorated(*args, **kwargs): return decorated -# TODO add deprecated also +# TODO: add deprecated also def public_api(stability: str = 'stable'): """Annotation for documenting public APIs. @@ -122,8 +131,9 @@ def public_api(stability: str = 'stable'): If ``stability="stable"``, the APIs will remain backwards compatible across minor releases. - """ + :param stability: stability of the API + """ assert stability in ["stable", "beta", "alpha"] def wrap(obj): @@ -139,7 +149,7 @@ def wrap(obj): class SuperDuperDBDeprecationWarning(DeprecationWarning): - """Specialized Deprecation Warning for fine grained filtering control""" + """Specialized Deprecation Warning for fine grained filtering control.""" pass @@ -180,6 +190,12 @@ def _get_indent(docstring: str) -> int: def ui(*schema: t.Dict, handle_integration: t.Callable = lambda x: x): + """Annotation for documenting UI schemas. + + :param *schema: list of dictionaries representing the UI schema + :param handle_integration: function to handle the integration of the UI schema + """ + def decorated(f): f.get_ui_schema = lambda: schema f.build = lambda r: f(**r) diff --git a/superduperdb/misc/anonymize.py b/superduperdb/misc/anonymize.py index 62e2db8938..48b18c8ab7 100644 --- a/superduperdb/misc/anonymize.py +++ b/superduperdb/misc/anonymize.py @@ -6,11 +6,12 @@ def anonymize_url(url): - """ - Anonymize a URL by replacing the username and password with a mask token. + """Anonymize a URL by replacing the username and password with a mask token. + Change the username and password to *** keeping one character before and after each. - """ + :param url: Database URL + """ if not url: return url diff --git a/superduperdb/misc/archives.py b/superduperdb/misc/archives.py index 48834a2850..ad3e4b7408 100644 --- a/superduperdb/misc/archives.py +++ b/superduperdb/misc/archives.py @@ -2,11 +2,12 @@ import tarfile +# TODO: Remove the unused functions def to_tarball(folder_path: str, output_path: str): - """ - Create a tarball (compressed archive) from a folder. + """Create a tarball (compressed archive) from a folder. :param folder_path: Path to the folder to be archived. + :param output_path: Path to the output tarball file. """ try: with tarfile.open(output_path + '.tar.gz', "w:gz") as tar: @@ -18,8 +19,7 @@ def to_tarball(folder_path: str, output_path: str): def from_tarball(tarball_path: str): - """ - Extract the contents of stack tarball + """Extract the contents of stack tarball. :param tarball_path: Path to the tarball file. 
""" diff --git a/superduperdb/misc/auto_schema.py b/superduperdb/misc/auto_schema.py index 93e5343644..be840d0f05 100644 --- a/superduperdb/misc/auto_schema.py +++ b/superduperdb/misc/auto_schema.py @@ -13,10 +13,10 @@ def register_module(module_name): - """ - Register a module for datatype inference. - Only modules with a check and create function will be registered + """Register a module for datatype inference. + Only modules with a check and create function will be registered + :param module_name: The module name, e.g. "superduperdb.ext.numpy.encoder" """ try: importlib.import_module(module_name) @@ -46,14 +46,13 @@ def register_module(module_name): def infer_datatype(data: t.Any) -> t.Optional[t.Union[DataType, type]]: - """ - Infer the datatype of a given data object + """Infer the datatype of a given data object. + If the data object is a base type, return None, Otherwise, return the inferred datatype :param data: The data object """ - datatype = None if isinstance(data, BASE_TYPES): @@ -80,14 +79,14 @@ def infer_schema( identifier: t.Optional[str] = None, ibis=False, ) -> Schema: - """ - Infer a schema from a given data object + """Infer a schema from a given data object. :param data: The data object :param identifier: The identifier for the schema, if None, it will be generated + :param ibis: If True, the schema will be updated for the Ibis backend, + otherwise for MongoDB :return: The inferred schema """ - assert isinstance(data, dict), "Data must be a dictionary" schema_data = {} @@ -122,6 +121,12 @@ def infer_schema( def updated_schema_data_for_ibis( schema_data, ) -> t.Dict[str, DataType]: + """Update the schema data for Ibis backend. + + Convert the basic data types to Ibis data types. + + :param schema_data: The schema data + """ from superduperdb.backends.ibis.field_types import dtype for k, v in schema_data.items(): @@ -132,6 +137,12 @@ def updated_schema_data_for_ibis( def updated_schema_data_for_mongodb(schema_data) -> t.Dict[str, DataType]: + """Update the schema data for MongoDB backend. + + Only keep the data types that can be stored directly in MongoDB. + + :param schema_data: The schema data + """ schema_data = {k: v for k, v in schema_data.items() if isinstance(v, DataType)} # MongoDB can store dict directly, so we don't need to serialize it. @@ -141,8 +152,14 @@ def updated_schema_data_for_mongodb(schema_data) -> t.Dict[str, DataType]: class JsonDataTypeFactory(DataTypeFactory): + """A factory for JSON datatypes.""" + @staticmethod def check(data: t.Any) -> bool: + """Check if the data is able to be encoded by the JSON serializer. + + :param data: The data object + """ try: json_serializer.encode_data(data) return True @@ -151,4 +168,8 @@ def check(data: t.Any) -> bool: @staticmethod def create(data: t.Any) -> DataType: + """Create a JSON datatype. 
+ + :param data: The data object + """ return json_serializer diff --git a/superduperdb/misc/colors.py b/superduperdb/misc/colors.py index 3d414961a8..6c0b3f4ebd 100644 --- a/superduperdb/misc/colors.py +++ b/superduperdb/misc/colors.py @@ -1,4 +1,6 @@ class Colors: + """Colors list for terminal output.""" + BLACK = '\033[30m' RED = '\033[31m' GREEN = '\033[32m' diff --git a/superduperdb/misc/compat.py b/superduperdb/misc/compat.py index 90809c8e93..516aebcccf 100644 --- a/superduperdb/misc/compat.py +++ b/superduperdb/misc/compat.py @@ -1,6 +1,4 @@ -""" -Functions from later standard libraries not available in Python 3.8 -""" +"""Functions from later standard libraries not available in Python 3.8.""" from functools import lru_cache @@ -9,4 +7,12 @@ # Implements functools.cache from Python 3.9 def cache(user_function, /): + """Simple cache decorator. + + This is a simple cache decorator that can be used to cache the results of + a function call. It does not have any of the advanced features of + functools.lru_cache. + + :param user_function: Function to cache + """ return lru_cache(maxsize=None)(user_function) diff --git a/superduperdb/misc/data.py b/superduperdb/misc/data.py index c8427c7e8f..567b8fbf75 100644 --- a/superduperdb/misc/data.py +++ b/superduperdb/misc/data.py @@ -5,13 +5,11 @@ def ibatch(iterable: t.Iterable[T], batch_size: int) -> t.Iterator[t.List[T]]: - """ - Batch an iterable into chunks of size `batch_size` + """Batch an iterable into chunks of size `batch_size`. :param iterable: the iterable to batch :param batch_size: the number of groups to write """ - iterator = iter(iterable) while True: batch = list(itertools.islice(iterator, batch_size)) diff --git a/superduperdb/misc/download.py b/superduperdb/misc/download.py index 57c1081666..10182275a9 100644 --- a/superduperdb/misc/download.py +++ b/superduperdb/misc/download.py @@ -20,15 +20,26 @@ class TimeoutException(Exception): + """Timeout exception.""" + ... def timeout_handler(signum, frame): + """Timeout handler to raise an TimeoutException. + + :param signum: signal number + :param frame: frame + """ raise TimeoutException() @contextmanager def timeout(seconds): + """Context manager to set a timeout. + + :param seconds: seconds until timeout + """ old_handler = signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(seconds) try: @@ -39,8 +50,7 @@ def timeout(seconds): class Fetcher: - """ - Fetches data from a URI + """Fetches data from a URI. :param headers: headers to be used for download :param n_workers: number of download workers @@ -76,8 +86,7 @@ def _download_from_uri(self, uri): return self.request_session.get(uri, headers=self.headers).content def __call__(self, uri: str): - """ - Download data from a URI + """Download data from a URI. :param uri: uri to download from """ @@ -92,8 +101,7 @@ def __call__(self, uri: str): class BaseDownloader: - """ - Base class for downloading files + """Base class for downloading files. :param uris: list of uris/ file names to fetch :param n_workers: number of multiprocessing workers @@ -119,8 +127,8 @@ def __init__( self.results: t.Dict = {} def go(self): - """ - Download all files + """Download all files. + Uses a :py:class:`multiprocessing.pool.ThreadPool` to parallelize connections. """ @@ -185,11 +193,24 @@ def _sequential_go(self, f): class Updater: + """Updater class to update the artifact. 
+ + :param db: Datalayer instance + :param query: query to be executed + """ + def __init__(self, db, query): self.db = db self.query = query def exists(self, uri, key, id, datatype): + """Check if the artifact exists. + + :param uri: uri to download from + :param key: key in the document + :param id: id of the document + :param datatype: datatype of the document + """ if self.db.datatypes[datatype].encodable == 'artifact': out = self.db.artifact_store.exists(uri=uri, datatype=datatype) else: @@ -206,6 +227,14 @@ def __call__( datatype, bytes_, ): + """Run the updater. + + :param uri: uri to download from + :param key: key in the document + :param id: id of the document + :param datatype: datatype of the document + :param bytes_: bytes to insert + """ if self.db.datatypes[datatype].encodable == 'artifact': self.db.artifact_store.save_artifact( { @@ -228,6 +257,7 @@ class Downloader(BaseDownloader): :param update_one: function to call to insert data into table :param ids: list of ids of rows/ documents to update :param keys: list of keys in rows/ documents to insert to + :param datatypes: list of datatypes of rows/ documents to insert to :param n_workers: number of multiprocessing workers :param headers: dictionary of request headers passed to``requests`` package :param skip_existing: if ``True`` then don't bother getting already present data @@ -286,8 +316,7 @@ def _download(self, i): def gather_uris( documents: t.Sequence[Document], gather_ids: bool = True ) -> t.Tuple[t.List[str], t.List[str], t.List[t.Any], t.List[str]]: - """ - Get the uris out of all documents as denoted by ``{"_content": ...}`` + """Get the uris out of all documents as denoted by ``{"_content": ...}``. :param documents: list of dictionaries :param gather_ids: if ``True`` then gather ids of documents @@ -309,7 +338,8 @@ def gather_uris( def _gather_uris_for_document(r: Document, id_field: str = '_id'): - ''' + """Get the uris out of a single document as denoted by ``{"_content": ...}``. + >>> _gather_uris_for_document({'a': {'_content': {'uri': 'test'}}}) (['test'], ['a']) >>> d = {'b': {'a': {'_content': {'uri': 'test'}}}} @@ -318,7 +348,7 @@ def _gather_uris_for_document(r: Document, id_field: str = '_id'): >>> d = {'b': {'a': {'_content': {'uri': 'test', 'bytes': b'abc'}}}} >>> _gather_uris_for_document(d) ([], []) - ''' + """ uris = [] keys = [] datatypes = [] @@ -341,8 +371,9 @@ def download_content( raises: bool = True, n_workers: t.Optional[int] = None, ) -> t.Optional[t.Sequence[Document]]: - """ - Download content contained in uploaded data. Items to be downloaded are identifier + """Download content contained in uploaded data. + + Items to be downloaded are identifier via the subdocuments in the form exemplified below. By default items are downloaded to the database, unless a ``download_update`` function is provided. 
@@ -350,12 +381,8 @@ def download_content( :param query: query to be executed :param ids: ids to be downloaded :param documents: documents to be downloaded - :param timeout: timeout for download :param raises: whether to raise errors - :param n_download_workers: number of download workers - :param headers: headers to be used for download - :param download_update: function to be used for updating the database - :param **kwargs: additional keyword arguments + :param n_workers: number of download workers >>> d = {"_content": {"uri": "", "encoder": ""}} >>> def update(key, id, bytes): @@ -408,6 +435,12 @@ def download_content( def download_from_one(r: Document): + """Download content from a single document. + + This function will find all URIs in the document and download them. + + :param r: document to download from + """ uris, keys, _, _ = gather_uris([r]) if not uris: return diff --git a/superduperdb/misc/files.py b/superduperdb/misc/files.py index 94ace62f27..4f0df3e96a 100644 --- a/superduperdb/misc/files.py +++ b/superduperdb/misc/files.py @@ -13,6 +13,8 @@ def get_file_from_uri(uri): 'test.txt' >>> _get_file('http://test.txt') '414388bd5644669b8a92e45a96318890f6e8de54' + + :param uri: The uri to get the file from """ if uri.startswith('file://'): file = uri[7:] @@ -38,6 +40,7 @@ def load_uris( Load ``"bytes"`` into ``"_content"`` from ``"uri"`` inside ``r``. :param r: The dict to load the bytes into + :param datatypes: The datatypes to use for encoding :param root: The root directory to load the bytes from :param raises: Whether to raise an error if the file is not found diff --git a/superduperdb/misc/hash.py b/superduperdb/misc/hash.py index a5ccb0ee37..32c8016a26 100644 --- a/superduperdb/misc/hash.py +++ b/superduperdb/misc/hash.py @@ -3,10 +3,19 @@ def hash_string(string: str): + """Hash a string. + + :param string: string to hash + """ return hashlib.sha256(string.encode()).hexdigest() def hash_dict(data: dict): + """Hash a dictionary. + + :param data: dictionary to hash + """ + def process(d): if isinstance(d, dict): return sorted((k, process(v)) for k, v in d.items()) @@ -20,10 +29,7 @@ def process(d): def random_sha1(): - """ - Generate random sha1 values - Can be used to generate file_id and other values - """ + """Generate random sha1 values.""" random_data = os.urandom(256) sha1 = hashlib.sha1() sha1.update(random_data) diff --git a/superduperdb/misc/retry.py b/superduperdb/misc/retry.py index b73840dea0..0b155a730b 100644 --- a/superduperdb/misc/retry.py +++ b/superduperdb/misc/retry.py @@ -15,13 +15,18 @@ class Retry: This is a thin wrapper around the tenacity retry library, using our configs. :param exception_types: The exception types to retry on. - :param cfg: The retry config. + :param cfg: The retry config. If None, uses the default config. """ exception_types: ExceptionTypes cfg: t.Optional[s.config.Retry] = None def __call__(self, f: t.Callable) -> t.Any: + """Decorate a function to retry on exceptions. + + Uses the exception types and config provided to the constructor. + :param f: The function to decorate. + """ cfg = self.cfg or s.CFG.retries retry = tenacity.retry_if_exception_type(self.exception_types) stop = tenacity.stop_after_attempt(cfg.stop_after_attempt) diff --git a/superduperdb/misc/run.py b/superduperdb/misc/run.py index ef7c36427a..1335591ac6 100644 --- a/superduperdb/misc/run.py +++ b/superduperdb/misc/run.py @@ -23,6 +23,7 @@ def run( :param args: The command to run. :param text: Whether to use text mode. 
:param check: Whether to raise an error if the command fails. + :param verbose: Whether to print the command. :param **kwargs: Additional arguments to pass to ``subprocess.run``. """ if verbose: diff --git a/superduperdb/misc/runnable/collection.py b/superduperdb/misc/runnable/collection.py index 0c604a3727..0c3efb57ba 100644 --- a/superduperdb/misc/runnable/collection.py +++ b/superduperdb/misc/runnable/collection.py @@ -10,26 +10,33 @@ class HasRunnables(Runnable): - """Collect zero or more Runnable into one""" + """Collect zero or more Runnable into one.""" runnables: t.Sequence[Runnable] def start(self): + """Start all runnables.""" for r in self.runnables: r.running.on_set.append(self._on_start) r.stopped.on_set.append(self._on_stop) r.start() def stop(self): + """Stop all runnables.""" self.running.clear() for r in self.runnables: r.stop() def finish(self): + """Finish all runnables.""" for r in self.runnables: r.finish() def join(self, timeout: t.Optional[float] = None): + """Join all runnables. + + :param timeout: Timeout in seconds + """ for r in self.runnables: r.join(timeout) @@ -49,6 +56,13 @@ class ThreadQueue(HasRunnables): There is a special `finish_message` value, which when received shuts down that consumer. ThreadQueue.finish() puts one `self.finish_message` onto the queue for each consumer. + + :param callback: The callback to run for each item in the queue. + :param error: The error callback. + :param maxsize: The maximum size of the queue. + :param name: The name of the queue. + :param thread_count: The number of threads to run. + :param timeout: The timeout for getting an item from the queue. """ callback: t.Callable[[t.Any], None] @@ -64,10 +78,11 @@ def __post_init__(self): @cached_property def queue(self) -> Queue: + """Return a new queue.""" return Queue(self.maxsize) def finish(self) -> None: - """Put an empty message into the queue for each listener""" + """Put an empty message into the queue for each listener.""" for _ in self.runnables: self.queue.put(_SENTINEL_MESSAGE) diff --git a/superduperdb/misc/runnable/queue_chunker.py b/superduperdb/misc/runnable/queue_chunker.py index 4bedee9eab..4fafc97e7a 100644 --- a/superduperdb/misc/runnable/queue_chunker.py +++ b/superduperdb/misc/runnable/queue_chunker.py @@ -8,7 +8,8 @@ @dc.dataclass class QueueChunker: - """Chunk a queue into lists of length at most `chunk_size` within time `timeout` + """Chunk a queue into lists of length at most `chunk_size` within time `timeout`. + :param chunk_size: Maximum number of entries in a chunk :param timeout: Maximum amount of time to block :param accumulate_timeouts: If accumulate timeouts is True, then `timeout` is @@ -21,6 +22,12 @@ class QueueChunker: accumulate_timeouts: bool = False def __call__(self, queue: Queue, stop_event: Event) -> t.Iterator[t.List]: + """Chunk the queue. + + :param queue: Queue to chunk + :param stop_event: Event to stop the chunking + """ + def chunk(): start = self.accumulate_timeouts and time.time() diff --git a/superduperdb/misc/runnable/runnable.py b/superduperdb/misc/runnable/runnable.py index be9f6c65bf..b3a3db1289 100644 --- a/superduperdb/misc/runnable/runnable.py +++ b/superduperdb/misc/runnable/runnable.py @@ -5,12 +5,14 @@ class Event(threading.Event): - """ + """An Event that calls a list of callbacks when set or cleared. + A threading.Event that also calls back to zero or more functions when its state is set or reset, and has a __bool__ method. 
Note that the callback might happen on some completely different thread, - so these functions cannot block""" + so these functions cannot block + """ on_set: t.List[Callback] @@ -19,10 +21,12 @@ def __init__(self, *on_set: Callback): super().__init__() def set(self): + """Set the flag to True and call all the callbacks.""" super().set() [c() for c in self.on_set] def clear(self): + """Clear the flag to False and call all the callbacks.""" super().clear() [c() for c in self.on_set] @@ -31,7 +35,7 @@ def __bool__(self): class Runnable: - """A base class for things that start, run, finish, stop and join + """A base class for things that start, run, finish, stop and join. Stopping is requesting immediate termination: finishing is saying that there is no more work to be done, finish what you are doing. @@ -95,7 +99,10 @@ def finish(self): self.stop() def join(self, timeout: t.Optional[float] = None): - """Join this thread or process. Might block indefinitely, might do nothing""" + """Join this thread or process. Might block indefinitely, might do nothing. + + :param timeout: Timeout in seconds + """ def __enter__(self): self.start() diff --git a/superduperdb/misc/runnable/thread.py b/superduperdb/misc/runnable/thread.py index ac0bd95a64..86d2d026c3 100644 --- a/superduperdb/misc/runnable/thread.py +++ b/superduperdb/misc/runnable/thread.py @@ -10,6 +10,10 @@ def none(x): + """No-op function. + + :param x: Any + """ pass @@ -51,10 +55,12 @@ def __str__(self): @_debug def pre_run(self): + """Pre-run the thread.""" pass @_debug(after=True) def run(self): + """Run the thread.""" self.pre_run() self.running.set() @@ -76,15 +82,17 @@ def run(self): @_debug def stop(self): + """Stop the thread.""" self.running.clear() @_debug def finish(self): + """Finish the thread.""" pass class IsThread(ThreadBase, Thread): - """This ThreadBase inherits from threading.Thread. + """IsThread is a thread that inherits from threading.Thread. To use IsThread, derive from it and override either or both of self.callback() and self.pre_run() @@ -95,23 +103,40 @@ def __init__(self, *args, **kwargs): Thread.__init__(self, daemon=self.daemon) def callback(self): + """The callback to run in the thread.""" pass def error(self, item: Exception) -> None: + """Handle an error. + + :param item: Exception + """ pass @_debug(after=True) def join(self, timeout: t.Optional[float] = None): + """Join the thread. + + :param timeout: Timeout in seconds + """ Thread.join(self, timeout) @_debug def start(self): + """Start the thread.""" Thread.start(self) @dc.dataclass class HasThread(ThreadBase): - """This ThreadBase contains a thread, and is constructed with a callback""" + """HasThread contains a thread, and is constructed with a callback. + + :param callback: The callback to run in the thread. + :param daemon: Whether the thread is a daemon. + :param error: The error callback. + :param looping: Whether the thread should loop. + :param name: The name of the thread. + """ callback: Callback = print daemon: bool = False @@ -124,15 +149,22 @@ def __post_init__(self): @_debug(after=True) def join(self, timeout: t.Optional[float] = None): + """Join the thread. 
+ + :param timeout: Timeout in seconds + """ self.thread.join(timeout) @_debug def start(self): + """Start the thread.""" self.thread.start() def new_thread(self) -> Thread: + """Return a new thread.""" return Thread(target=self.run, daemon=self.daemon) @cached_property def thread(self) -> Thread: + """Return the thread.""" return self.new_thread() diff --git a/superduperdb/misc/serialization.py b/superduperdb/misc/serialization.py index 09d987a877..171971ee4a 100644 --- a/superduperdb/misc/serialization.py +++ b/superduperdb/misc/serialization.py @@ -28,9 +28,13 @@ def asdict(obj, *, copy_method=copy.copy) -> t.Dict[str, t.Any]: - """ + """Convert the dataclass instance to a dict. + Custom ``asdict`` function which exports a dataclass object into a dict, with a option to choose for nested non atomic objects copy strategy. + + :param obj: The dataclass instance to + :param copy_method: The copy method to use for non atomic objects """ if not dc.is_dataclass(obj): raise TypeError("asdict() should be called on dataclass instances") diff --git a/superduperdb/misc/server.py b/superduperdb/misc/server.py index 3458834b02..73ecd6adff 100644 --- a/superduperdb/misc/server.py +++ b/superduperdb/misc/server.py @@ -58,6 +58,14 @@ def _request_server( def request_server( service: str = 'vector_search', data=None, endpoint='add', args={}, type='post' ): + """Request server with data. + + :param service: Service name + :param data: Data to send + :param endpoint: Endpoint to hit + :param args: Arguments to pass + :param type: Type of request + """ _handshake(service) return _request_server( service=service, data=data, endpoint=endpoint, args=args, type=type diff --git a/superduperdb/misc/special_dicts.py b/superduperdb/misc/special_dicts.py index 05f95c4362..71e251594f 100644 --- a/superduperdb/misc/special_dicts.py +++ b/superduperdb/misc/special_dicts.py @@ -51,7 +51,10 @@ def __setitem__(self, key: str, value: t.Any) -> None: self[parent] = parent_item +# TODO: Is this an unused class? class ArgumentDefaultDict(defaultdict): + """ArgumentDefaultDict.""" + def __getitem__(self, item): if item not in self: self[item] = self.default_factory(item) @@ -59,11 +62,15 @@ def __getitem__(self, item): def diff(r1, r2): - """ + """Get the difference between two dictionaries. + >>> _diff({'a': 1, 'b': 2}, {'a': 2, 'b': 2}) {'a': (1, 2)} >>> _diff({'a': {'c': 3}, 'b': 2}, {'a': 2, 'b': 2}) {'a': ({'c': 3}, 2)} + + :param r1: Dict + :param r2: Dict """ d = _diff_impl(r1, r2) out = {} diff --git a/superduperdb/rest/app.py b/superduperdb/rest/app.py index 9a91101b84..d8b2c93fde 100644 --- a/superduperdb/rest/app.py +++ b/superduperdb/rest/app.py @@ -37,6 +37,11 @@ @dc.dataclass(kw_only=True) class MyBoolean(Component): + """A simple boolean component. + + :param my_bool: a boolean value + """ + type_id: t.ClassVar[str] = 'bool' my_bool: bool diff --git a/superduperdb/rest/utils.py b/superduperdb/rest/utils.py index 4ad3bf483f..7cb691fad5 100644 --- a/superduperdb/rest/utils.py +++ b/superduperdb/rest/utils.py @@ -44,6 +44,12 @@ def _parse_query_part(part, documents, query, db: t.Optional[t.Any] = None): def parse_query(query, documents, db: t.Optional[t.Any] = None): + """Parse a query string into a query object. 
+ + :param query: query string to parse + :param documents: documents to use in the query + :param db: datalayer instance + """ if isinstance(query, str): query = [x.strip() for x in query.split('\n') if x.strip()] for i, q in enumerate(query): @@ -52,6 +58,10 @@ def parse_query(query, documents, db: t.Optional[t.Any] = None): def strip_artifacts(r: t.Any): + """Strip artifacts for the data. + + :param r: the data to strip artifacts from + """ if isinstance(r, dict): if '_content' in r: return f'_artifact/{r["_content"]["file_id"]}', [r["_content"]["file_id"]] diff --git a/superduperdb/server/app.py b/superduperdb/server/app.py index f2af3f7b07..bd916f421a 100644 --- a/superduperdb/server/app.py +++ b/superduperdb/server/app.py @@ -2,6 +2,7 @@ import sys import threading import time +import typing as t from functools import cached_property from traceback import format_exc @@ -15,13 +16,21 @@ from superduperdb import logging from superduperdb.base.build import build_datalayer +from superduperdb.base.config import Config from superduperdb.base.datalayer import Datalayer # --------------- Create exception handler middleware----------------- class ExceptionHandlerMiddleware(BaseHTTPMiddleware): + """Middleware to handle exceptions and log them.""" + async def dispatch(self, request: Request, call_next): + """Dispatch the request and handle exceptions. + + :param request: request to dispatch + :param call_next: next call to make + """ try: return await call_next(request) except Exception as e: @@ -50,9 +59,14 @@ async def dispatch(self, request: Request, call_next): class SuperDuperApp: - """ - This is a wrapper class that prepares helper functions used to create a - fastapi application in the realm of superduperdb. + """A wrapper class for creating a fastapi application. + + The class provides a simple interface for creating a fastapi application + with custom endpoints. + + :param service: name of the service + :param port: port to run the service on + :param db: datalayer instance """ def __init__(self, service='vector_search', port=8000, db: Datalayer = None): @@ -83,19 +97,27 @@ def __init__(self, service='vector_search', port=8000, db: Datalayer = None): @cached_property def app(self): + """Return the application instance.""" self._app.include_router(self.router) return self._app def raise_error(self, msg: str, code: int): + """Raise an error with the given message and code. + + :param msg: message to raise + :param code: code to raise + """ raise HTTPException(code, detail=msg) @cached_property def db(self): + """Return the database instance from the app state.""" return self._app.state.pool def add(self, *args, method='post', **kwargs): - """ - Register an endpoint with this method. + """Register an endpoint with this method. + + :param method: method to use """ def decorator(function): @@ -106,8 +128,10 @@ def decorator(function): return decorator def add_default_endpoints(self): - """ - Add a list of default endpoints which comes out of the box with `SuperDuperApp` + """Add default endpoints to the application. 
+ + - /health: Health check endpoint + - /handshake/config: Handshake endpoint """ @self.router.get('/health') @@ -131,6 +155,7 @@ def handshake(cfg: str): ) def print_routes(self): + """Print the routes of the application.""" table = PrettyTable() # Define the table headers @@ -142,7 +167,11 @@ def print_routes(self): logging.info(f"Routes for '{self.service}' app: \n{table}") - def pre_start(self, cfg=None): + def pre_start(self, cfg: t.Union[Config, None] = None): + """Pre-start the application. + + :param cfg: Configurations to use + """ self.add_default_endpoints() if not self._user_startup: @@ -152,6 +181,7 @@ def pre_start(self, cfg=None): assert self.app def run(self): + """Run the application.""" uvicorn.run( self._app, host=self.app_host, @@ -159,16 +189,20 @@ def run(self): ) def start(self): - """ - This method is used to start the application server - """ + """Start the application.""" self.pre_start() self.print_routes() self.run() - def startup(self, function=None, cfg=None): - """ - This method is used to register a startup function + def startup( + self, + function: t.Union[t.Callable, None] = None, + cfg: t.Union[Config, None] = None, + ): + """Startup the application. + + :param function: function to run on startup + :param cfg: Configurations to use """ self._user_startup = True @@ -186,9 +220,10 @@ def startup_db_client(): return - def shutdown(self, function=None): - """ - This method is used to register a shutdown function + def shutdown(self, function: t.Union[t.Callable, None] = None): + """Shutdown the application. + + :param function: function to run on shutdown """ self._user_shutdown = True @@ -203,22 +238,27 @@ def shutdown_db_client(): def database(request: Request) -> Datalayer: + """Return the database instance from the app state. + + :param request: request object + """ return request.app.state.pool def DatalayerDependency(): - """ - A helper method to be used for injecting datalayer instance - into endpoint implementation - """ + """Dependency for injecting datalayer instance into endpoint implementation.""" return Depends(database) class Server(uvicorn.Server): + """Custom server class.""" + def install_signal_handlers(self): + """Install signal handlers.""" pass def run_in_thread(self): + """Run the server in a separate thread.""" self._thread = threading.Thread(target=self.run) self._thread.start() @@ -226,5 +266,6 @@ def run_in_thread(self): time.sleep(1e-3) def stop(self): + """Stop the server.""" self.should_exit = True self._thread.join() diff --git a/superduperdb/server/cluster.py b/superduperdb/server/cluster.py index d89d59ff8c..8d25e90b7c 100644 --- a/superduperdb/server/cluster.py +++ b/superduperdb/server/cluster.py @@ -8,10 +8,11 @@ def create_tmux_session(session_name, commands): - ''' - Create a tmux local cluster - ''' + """Create a tmux local cluster. + :param session_name: name of the tmux session + :param commands: list of commands to run + """ for ix, cmd in enumerate(commands, start=0): window_name = f'{session_name}:{ix}' run_tmux_command(['send-keys', '-t', window_name, cmd, 'C-m']) @@ -22,11 +23,19 @@ def create_tmux_session(session_name, commands): def run_tmux_command(command): + """Run a tmux command. + + :param command: command to run + """ print('tmux ' + ' '.join(command)) subprocess.run(["tmux"] + command, check=True) def up_cluster(notebook_token: t.Optional[str] = None): + """Start the local cluster. 
+ + :param notebook_token: token to use for the jupyter notebook + """ print('Starting the local cluster...') CFG = os.environ.get('SUPERDUPERDB_CONFIG') @@ -85,10 +94,12 @@ def up_cluster(notebook_token: t.Optional[str] = None): def down_cluster(): + """Stop the local cluster.""" print('Stopping the local cluster...') run_tmux_command(['kill-session', '-t', SESSION_NAME]) print('local cluster stopped') def attach_cluster(): + """Attach to the tmux session.""" run_tmux_command(['attach-session', '-t', SESSION_NAME]) diff --git a/superduperdb/vector_search/atlas.py b/superduperdb/vector_search/atlas.py index 30790f5898..0aee591687 100644 --- a/superduperdb/vector_search/atlas.py +++ b/superduperdb/vector_search/atlas.py @@ -15,10 +15,13 @@ class MongoAtlasVectorSearcher(BaseVectorSearcher): - """ - Implementation of atlas vector search + """Vector searcher implementation of atlas vector search. :param identifier: Unique string identifier of index + :param collection: Collection name + :param dimensions: Dimension of the vector embeddings + :param measure: measure to assess similarity + :param output_path: Path to the output """ def __init__( @@ -49,10 +52,15 @@ def __len__(self): @cached_property def index(self): + """Return the index collection.""" return self.database[self.collection] @classmethod def from_component(cls, vi: 'VectorIndex'): + """Create a vector searcher from a vector index. + + :param vi: VectorIndex instance + """ from superduperdb.components.listener import Listener from superduperdb.components.model import ObjectModel @@ -129,18 +137,38 @@ def _find(self, h, n=100): return ids, scores def find_nearest_from_id(self, id: str, n=100, within_ids=None): + """Find the nearest vectors to the given ID. + + :param id: ID of the vector + :param n: number of nearest vectors to return + :param within_ids: list of IDs to search within + """ h = self.index.find_one({'id': id}) return self.find_nearest_from_array(h, n=n, within_ids=within_ids) def find_nearest_from_array(self, h, n=100, within_ids=None): + """Find the nearest vectors to the given vector. + + :param h: vector + :param n: number of nearest vectors to return + :param within_ids: list of IDs to search within + """ return self._find(h, n=n) def add(self, items): + """Add vectors to the index. + + :param items: List of vectors to add + """ items = list(map(lambda x: x.to_dict(), items)) if not CFG.cluster.vector_search == CFG.data_backend: self.index.insert_many(items) def delete(self, items): + """Delete vectors from the index. + + :param items: List of vectors to delete + """ ids = list(map(lambda x: x.id, items)) if not CFG.cluster.vector_search == CFG.data_backend: self.index.delete_many({'id': {'$in': ids}}) @@ -149,7 +177,8 @@ def _create_index(self, collection: str, output_path: str): """ Create a vector index in the data backend if an Atlas deployment. 
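For context, the tmux helpers compose roughly as follows; the session name and per-window commands are invented, and this assumes the elided part of `create_tmux_session` has created the session windows being targeted:

    commands = [
        'make run_vector_search',   # hypothetical per-window commands
        'make run_cdc',
    ]
    create_tmux_session('superduper_local', commands)               # sends one command per tmux window
    run_tmux_command(['attach-session', '-t', 'superduper_local'])  # same call attach_cluster() makes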
- :param vector_index: vector index to create + :param collection: Collection name + :param output_path: Path to the output """ _, key, model, version = output_path.split('.') if re.match('^_outputs\.[A-Za-z0-9_]+\.[A-Za-z0-9_]+', key): diff --git a/superduperdb/vector_search/base.py b/superduperdb/vector_search/base.py index b8c1bfdc8c..79bc1e0f24 100644 --- a/superduperdb/vector_search/base.py +++ b/superduperdb/vector_search/base.py @@ -13,11 +13,14 @@ class BaseVectorSearcher(ABC): - @classmethod - def from_component(cls, vi: 'VectorIndex'): - return cls( - identifier=vi.identifier, dimensions=vi.dimensions, measure=vi.measure - ) + """Base class for vector searchers. + + :param identifier: Unique string identifier of index + :param dimensions: Dimension of the vector embeddings + :param h: Seed vectors ``numpy.ndarray`` + :param index: list of IDs + :param measure: measure to assess similarity + """ @abstractmethod def __init__( @@ -30,12 +33,26 @@ def __init__( ): pass + @classmethod + def from_component(cls, vi: 'VectorIndex'): + """Create a vector searcher from a vector index. + + :param vi: VectorIndex instance + """ + return cls( + identifier=vi.identifier, dimensions=vi.dimensions, measure=vi.measure + ) + @abstractmethod def __len__(self): pass @staticmethod def to_numpy(h): + """Converts a vector to a numpy array. + + :param h: vector, numpy.ndarray, or list + """ if isinstance(h, numpy.ndarray): return h if hasattr(h, 'numpy'): @@ -46,6 +63,10 @@ def to_numpy(h): @staticmethod def to_list(h): + """Converts a vector to a list. + + :param h: vector + """ if hasattr(h, 'tolist'): return h.tolist() if isinstance(h, list): @@ -62,8 +83,7 @@ def add(self, items: t.Sequence[VectorItem]) -> None: @abstractmethod def delete(self, ids: t.Sequence[str]) -> None: - """ - Remove items from the index + """Remove items from the index. :param ids: t.Sequence of ids of vectors. """ @@ -80,6 +100,7 @@ def find_nearest_from_id( :param _id: id of the vector :param n: number of nearest vectors to return + :param within_ids: list of ids to search within """ @abstractmethod @@ -94,16 +115,20 @@ def find_nearest_from_array( :param h: vector :param n: number of nearest vectors to return + :param within_ids: list of ids to search within """ def post_create(self): - """ + """Post create method. + This method is used for searchers which requires to perform a task after all vectors have been added """ class VectorIndexMeasureType(str, enum.Enum): + """Enum for vector index measure types.""" + cosine = 'cosine' css = 'css' dot = 'dot' @@ -112,10 +137,13 @@ class VectorIndexMeasureType(str, enum.Enum): @dataclass(frozen=True) class VectorSearchConfig: - ''' - Represents search config which helps initiate a vector - searcher class. - ''' + """Represents search config which helps initiate a vector searcher class. + + :param id: Identifier for the vector index + :param dimensions: Dimensions of the vector + :param measure: Measure for vector search + :param parameters: Additional parameters for vector search + """ id: str dimensions: int @@ -125,11 +153,11 @@ class VectorSearchConfig: @dataclass(frozen=True) class VectorItem: - ''' - Class for representing a vector in vector search with - id and vector. + """Class for representing a vector in vector search with id and vector. 
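A small sketch of the coercion helpers and the frozen config dataclass defined above (the values are arbitrary, and the elided tail of `to_numpy` is assumed to convert plain lists):

    import numpy

    h = BaseVectorSearcher.to_numpy([0.1, 0.2, 0.3])   # lists assumed converted to ndarray by the elided branch
    back = BaseVectorSearcher.to_list(h)                # ndarrays expose .tolist()

    cfg = VectorSearchConfig(id='my_index', dimensions=3, measure='cosine')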
- ''' + :param id: ID of the vector + :param vector: Vector of the item + """ id: str vector: numpy.ndarray @@ -141,41 +169,54 @@ def create( id: str, vector: numpy.typing.ArrayLike, ) -> VectorItem: + """Creates a vector item from id and vector. + + :param id: ID of the vector + :param vector: Vector of the item + """ return VectorItem(id=id, vector=BaseVectorSearcher.to_numpy(vector)) def to_dict(self) -> t.Dict: + """Converts the vector item to a dictionary.""" return {'id': self.id, 'vector': self.vector} @dataclass(frozen=True) class VectorSearchResult: - ''' - Dataclass for representing vector search results with - `id` and `score`. - ''' + """Dataclass for representing vector search results with `id` and `score`. + + :param id: ID of the vector + :param score: Similarity score of the vector + """ id: str score: float def l2(x, y): - ''' - L2 function for vector similarity search - ''' + """L2 function for vector similarity search. + + :param x: numpy.ndarray + :param y: numpy.ndarray + """ return numpy.array([-numpy.linalg.norm(x - y, axis=1)]) def dot(x, y): - ''' - Dot function for vector similarity search - ''' + """Dot function for vector similarity search. + + :param x: numpy.ndarray + :param y: numpy.ndarray + """ return numpy.dot(x, y.T) def cosine(x, y): - ''' - Cosine similarity function for vector search - ''' + """Cosine similarity function for vector search. + + :param x: numpy.ndarray + :param y: numpy.ndarray, y should be normalized! + """ x = x / numpy.linalg.norm(x, axis=1)[:, None] # y which implies all vectors in vectordatabase # has normalized vectors. diff --git a/superduperdb/vector_search/in_memory.py b/superduperdb/vector_search/in_memory.py index 43b38fbce8..cc356fc781 100644 --- a/superduperdb/vector_search/in_memory.py +++ b/superduperdb/vector_search/in_memory.py @@ -11,6 +11,7 @@ class InMemoryVectorSearcher(BaseVectorSearcher): Simple hash-set for looking up with vector similarity. :param identifier: Unique string identifier of index + :param dimensions: Dimension of the vector embeddings :param h: array/ tensor of vectors :param index: list of IDs :param measure: measure to assess similarity @@ -63,10 +64,21 @@ def _setup(self, h, index): self.lookup = dict(zip(index, range(len(index)))) def find_nearest_from_id(self, _id, n=100): + """Find the nearest vectors to the given ID. + + :param _id: ID of the vector + :param n: number of nearest vectors to return + """ self.post_create() return self.find_nearest_from_array(self.h[self.lookup[_id]], n=n) def find_nearest_from_array(self, h, n=100, within_ids=None): + """Find the nearest vectors to the given vector. + + :param h: vector + :param n: number of nearest vectors to return + :param within_ids: list of IDs to search within + """ self.post_create() if self.h is None: @@ -98,6 +110,12 @@ def find_nearest_from_array(self, h, n=100, within_ids=None): return _ids, scores def add(self, items: t.Sequence[VectorItem]) -> None: + """Add vectors to the index. + + Only adds to cache if cache is not full. 
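The `cosine` helper normalizes only the query side and relies on the database vectors already being unit length, as the comment above notes; a tiny worked example with hand-picked values:

    import numpy

    x = numpy.array([[3.0, 4.0]])                  # query, shape (1, d)
    y = numpy.array([[1.0, 0.0], [0.6, 0.8]])      # stored vectors, already unit-normalized
    x = x / numpy.linalg.norm(x, axis=1)[:, None]  # -> [[0.6, 0.8]]
    scores = numpy.dot(x, y.T)                     # -> [[0.6, 1.0]]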
+ + :param items: List of vectors to add + """ if len(self._cache) < self._CACHE_SIZE: for item in items: self._cache.append(item) @@ -106,6 +124,7 @@ def add(self, items: t.Sequence[VectorItem]) -> None: self._cache = [] def post_create(self): + """Post create method to incorporate remaining vectors to be added in cache.""" if self._cache: self._add(self._cache) self._cache = [] @@ -123,6 +142,10 @@ def _add(self, items: t.Sequence[VectorItem]) -> None: return self._setup(h, index) def delete(self, ids): + """Delete vectors from the index. + + :param ids: List of IDs to delete + """ self.post_create() ix = list(map(self.lookup.__getitem__, ids)) h = numpy.delete(self.h, ix, axis=0) diff --git a/superduperdb/vector_search/interface.py b/superduperdb/vector_search/interface.py index 3b9e6c54db..cfce0c86e7 100644 --- a/superduperdb/vector_search/interface.py +++ b/superduperdb/vector_search/interface.py @@ -11,6 +11,13 @@ class FastVectorSearcher(BaseVectorSearcher): + """Fast vector searcher implementation using the server. + + :param db: Datalayer instance + :param vector_searcher: Vector searcher instance + :param vector_index: Vector index name + """ + def __init__(self, db: 'Datalayer', vector_searcher, vector_index: str): self.searcher = vector_searcher self.vector_index = vector_index @@ -50,8 +57,7 @@ def add(self, items: t.Sequence[VectorItem]) -> None: return self.searcher.add(items) def delete(self, ids: t.Sequence[str]) -> None: - """ - Remove items from the index + """Remove items from the index. :param ids: t.Sequence of ids of vectors. """ @@ -79,6 +85,7 @@ def find_nearest_from_id( :param _id: id of the vector :param n: number of nearest vectors to return + :param within_ids: list of ids to search within """ if CFG.cluster.vector_search.uri is not None: response = request_server( @@ -101,6 +108,7 @@ def find_nearest_from_array( :param h: vector :param n: number of nearest vectors to return + :param within_ids: list of ids to search within """ if CFG.cluster.vector_search.uri is not None: response = request_server( @@ -114,6 +122,7 @@ def find_nearest_from_array( return self.searcher.find_nearest_from_array(h=h, n=n, within_ids=within_ids) def post_create(self): + """Post create method for vector searcher.""" if CFG.cluster.is_remote_vector_search: request_server( service='vector_search', diff --git a/superduperdb/vector_search/lance.py b/superduperdb/vector_search/lance.py index 8a27c4cc01..91c2ccd109 100644 --- a/superduperdb/vector_search/lance.py +++ b/superduperdb/vector_search/lance.py @@ -44,6 +44,7 @@ def __init__( @property def dataset(self): + """Return the Lance dataset.""" if not os.path.exists(self.dataset_path): self._create_or_append_to_dataset([], [], mode='create') return lance.dataset(self.dataset_path) @@ -69,11 +70,19 @@ def _create_or_append_to_dataset(self, vectors, ids, mode: str = 'upsert'): lance.write_dataset(_table, self.dataset_path, mode=mode) def add(self, items: t.Sequence[VectorItem]) -> None: + """Add vectors to the index. + + :param items: List of vectors to add + """ ids = [item.id for item in items] vectors = [item.vector for item in items] self._create_or_append_to_dataset(vectors, ids, mode='append') def delete(self, ids: t.Sequence[str]) -> None: + """Delete vectors from the index. 
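Since `add` above only buffers into `_cache` until the cache limit is reached, `post_create` is what folds any remaining cached vectors into the index; a hedged usage sketch (constructor keywords follow the class docstring, and the measure name is an assumption):

    import numpy
    from superduperdb.vector_search.base import VectorItem

    searcher = InMemoryVectorSearcher(identifier='demo', dimensions=3, measure='cosine')
    searcher.add([
        VectorItem.create(id='a', vector=numpy.array([1.0, 0.0, 0.0])),
        VectorItem.create(id='b', vector=numpy.array([0.0, 1.0, 0.0])),
    ])
    searcher.post_create()   # flush the cache into the index
    ids, scores = searcher.find_nearest_from_array(numpy.array([1.0, 0.0, 0.0]), n=1)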
+ + :param ids: List of IDs to delete + """ to_remove = ", ".join(f"'{str(id)}'" for id in ids) self.dataset.delete(f"id IN ({to_remove})") @@ -83,6 +92,12 @@ def find_nearest_from_id( n: int = 100, within_ids: t.Sequence[str] = (), ) -> t.Tuple[t.List[str], t.List[float]]: + """Find the nearest vectors to a given ID. + + :param _id: ID to search + :param n: Number of results to return + :param within_ids: List of IDs to search within + """ # The ``lance`` file format has been specifically designed for fast # random access. The logic to take advantage of this is implemented # by the ``.take`` method. @@ -101,6 +116,12 @@ def find_nearest_from_array( n: int = 100, within_ids: t.Sequence[str] = (), ) -> t.Tuple[t.List[str], t.List[float]]: + """Find the nearest vectors to a given vector. + + :param h: Vector to search + :param n: Number of results to return + :param within_ids: List of IDs to search within + """ # NOTE: filter is currently applied AFTER vector-search # See https://lancedb.github.io/lance/api/python/lance.html#lance.dataset.LanceDataset.scanner if within_ids: diff --git a/superduperdb/vector_search/server/app.py b/superduperdb/vector_search/server/app.py index 07d679b951..bbab7882c3 100644 --- a/superduperdb/vector_search/server/app.py +++ b/superduperdb/vector_search/server/app.py @@ -17,18 +17,33 @@ class VectorItem(BaseModel): + """A vector item model.""" + id: str vector: service.ListVectorType @app.add("/create/search", status_code=200, method='get') def create_search(vector_index: str, db: Datalayer = DatalayerDependency()): + """Create a vector index. + + :param vector_index: Vector index to create + :param db: Datalayer instance + """ service.create_search(vector_index=vector_index, db=db) return {'message': 'Vector index created successfully'} @app.add("/create/post_create", status_code=200, method='get') def post_create(vector_index: str, db: Datalayer = DatalayerDependency()): + """Post create method for vector searcher. + + Performs post create method of vector searcher to incorporate remaining vectors + to be added in cache. + + :param vector_index: Vector index to post create + :param db: Datalayer instance + """ service.post_create(vector_index=vector_index, db=db) return {'message': 'Post create executed successfully'} @@ -37,6 +52,13 @@ def post_create(vector_index: str, db: Datalayer = DatalayerDependency()): def query_search_by_id( id: str, vector_index: str, n: int = 100, db: Datalayer = DatalayerDependency() ): + """Query the vector index with an id. + + :param id: Id to query + :param vector_index: Vector index to query + :param n: Number of results to return + :param db: Datalayer instance + """ ids, scores = service.query_search_from_id( id, vector_index=vector_index, n=n, db=db ) @@ -56,6 +78,13 @@ def query_search_by_array( n: int = 100, db: Datalayer = DatalayerDependency(), ): + """Query the vector index with a vector. + + :param vector: Vector to query + :param vector_index: Vector index to query + :param n: Number of results to return + :param db: Datalayer instance + """ ids, scores = service.query_search_from_array( vector, vector_index=vector_index, n=n, db=db ) @@ -75,6 +104,12 @@ def add_search( vector_index: str, db: Datalayer = DatalayerDependency(), ): + """Add vectors to the vector index. 
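The routes above are thin HTTP wrappers around the service module; a rough client-side sketch against a locally running vector-search service (host and port are assumed, and only routes visible in this diff are used):

    import requests

    base = 'http://localhost:8000'
    requests.get(f'{base}/create/search', params={'vector_index': 'my_index'})
    requests.get(f'{base}/create/post_create', params={'vector_index': 'my_index'})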
+ + :param vectors: List of vectors to add + :param vector_index: Vector index to add to + :param db: Datalayer instance + """ logging.info(f'Adding {len(vectors)} to search') service.add_search(vectors, vector_index=vector_index, db=db) return {'message': 'Added vectors successfully'} @@ -84,10 +119,20 @@ def add_search( def delete_search( ids: t.List[str], vector_index: str, db: Datalayer = DatalayerDependency() ): + """Delete vectors from the vector index. + + :param ids: List of ids to delete + :param vector_index: Vector index to delete from + :param db: Datalayer instance + """ service.delete_search(ids, vector_index=vector_index, db=db) return {'message': 'Ids deleted successfully'} @app.add("/list/search") def list_search(db: Datalayer = DatalayerDependency()): + """List all the vector indices in the database. + + :param db: Datalayer instance + """ return service.list_search(db) diff --git a/superduperdb/vector_search/server/service.py b/superduperdb/vector_search/server/service.py index e9909d72ec..742f16675f 100644 --- a/superduperdb/vector_search/server/service.py +++ b/superduperdb/vector_search/server/service.py @@ -16,8 +16,8 @@ def _vector_search( x: t.Union[str, ListVectorType], n: int, vector_index: str, + db: Datalayer, by_array: bool = True, - db=None, ) -> VectorSearchResultType: vi = db.fast_vector_searchers[vector_index] if by_array: @@ -30,79 +30,93 @@ def _vector_search( def database(request: Request) -> Datalayer: + """Helper function to get the database instance from the app state. + + :param request: request object + """ return request.app.state.pool def list_search(db: Datalayer): - ''' - Helper functon for listing all vector search indices. + """Helper functon for listing all vector search indices. + :param db: Datalayer instance - ''' + """ return db.show('vector_index') -def post_create(vector_index: str, db=None): - ''' - Performs post create method of vector searcher to - incorporate remaining vectors to be added in cache. +def post_create(vector_index: str, db: Datalayer): + """Post create method for vector searcher. + + Performs post create method of vector searcher to incorporate remaining vectors + to be added in cache. + :param vector_index: Vector class to initiate :param db: Datalayer instance - ''' + """ vi = db.fast_vector_searchers[vector_index] vi.post_create() -def create_search(vector_index: str, db=None): - ''' - Initiates a vector search class corresponding to `vector_index` +def create_search(vector_index: str, db: Datalayer): + """Initiates a vector search class corresponding to `vector_index`. + :param vector_index: Vector class to initiate :param db: Datalayer instance - ''' + """ db.fast_vector_searchers.update( {vector_index: db.initialize_vector_searcher(vector_index)} ) def query_search_from_array( - array: ListVectorType, vector_index: str, n: int = 100, db=None + array: ListVectorType, + vector_index: str, + db: Datalayer, + n: int = 100, ) -> VectorSearchResultType: - ''' - Perform a vector search with an array + """Perform a vector search with an array. + :param array: Array to perform vector search on index. 
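The same operations can be driven in-process through the service functions, given an already-built `Datalayer` instance `db` (the index name is arbitrary):

    create_search('my_index', db=db)   # register a fast vector searcher for the index
    post_create('my_index', db=db)     # flush any vectors the searcher still has cached
    ids, scores = query_search_from_array([0.1, 0.2, 0.3], vector_index='my_index', db=db, n=5)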
:param vector_index: Vector search index :param db: Datalayer instance - ''' + :param n: Number of nearest neighbors to be returned + """ return _vector_search(array, n=n, vector_index=vector_index, db=db) def query_search_from_id( - id: str, vector_index: str, n: int = 100, db=None + id: str, vector_index: str, db: Datalayer, n: int = 100 ) -> VectorSearchResultType: - ''' - Perform a vector search with an id + """Perform a vector search with an id. + :param id: Identifier for vector :param vector_index: Vector search index :param db: Datalayer instance - ''' + :param n: Number of nearest neighbors to be returned + """ return _vector_search(id, n=n, vector_index=vector_index, db=db, by_array=False) -def add_search(vector, vector_index: str, db=None): - ''' - Adds a vector in vector index `vector_index` +def add_search(vector, vector_index: str, db: Datalayer): + """Adds a vector in vector index `vector_index`. + :param vector: Vector to be added. :param vector_index: Vector index where vector needs to be added. - - ''' + :param db: Datalayer instance + """ vector = [VectorItem(id=v.id, vector=v.vector) for v in vector] vi = db.fast_vector_searchers[vector_index] vi.searcher.add(vector) -def delete_search(ids: t.List[str], vector_index: str, db=None): - ''' - Deletes a vector corresponding to `id` - ''' +def delete_search(ids: t.List[str], vector_index: str, db: Datalayer): + """Deletes a vector corresponding to `id`. + + :param ids: List of ids to be deleted. + :param vector_index: Vector index where vector needs to be deleted. + :param db: Datalayer instance + """ vi = db.fast_vector_searchers[vector_index] vi.searcher.delete(ids) diff --git a/superduperdb/vector_search/update_tasks.py b/superduperdb/vector_search/update_tasks.py index 31876dab32..f830216f8a 100644 --- a/superduperdb/vector_search/update_tasks.py +++ b/superduperdb/vector_search/update_tasks.py @@ -7,19 +7,20 @@ from superduperdb.misc.special_dicts import MongoStyleDict from superduperdb.vector_search.base import VectorItem +if t.TYPE_CHECKING: + from superduperdb.base.datalayer import Datalayer + def delete_vectors( vector_index: str, ids: t.Sequence[str], - db=None, + db=t.Optional['Datalayer'], ): - """ - A helper fxn to delete vectors of a ``VectorIndex`` component - in the fast_vector_search backend. + """Delete vectors of a ``VectorIndex`` component in the fast_vector_search backend. :param vector_index: A identifier of vector-index. :param ids: List of ids which were observed as deleted documents. - :param db: A ``DB`` instance. + :param db: Datalayer instance. """ return db.fast_vector_searchers[vector_index].delete(ids) @@ -28,18 +29,15 @@ def copy_vectors( vector_index: str, query: t.Union[t.Dict, CompoundSelect], ids: t.Sequence[str], - db=None, + db=t.Optional['Datalayer'], ): - """ - A helper fxn to copy vectors of a ``VectorIndex`` component - from the databackend to the fast_vector_search backend. + """Copy vectors of a ``VectorIndex`` component from the databackend to the fast_vector_search backend. - :param vector-index: A identifier of the vector-index. + :param vector_index: A identifier of the vector-index. :param query: A query which was used by `db._build_task_workflow` method :param ids: List of ids which were observed as added/updated documents. - :param db: A ``DB`` instance. + :param db: Datalayer instance. 
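`delete_vectors` is a thin wrapper over the fast searcher's `delete`; reacting to deleted documents might look like the following, again with `db` assumed to be a live `Datalayer`:

    delete_vectors('my_index', ids=['doc-1', 'doc-2'], db=db)
    # equivalent to db.fast_vector_searchers['my_index'].delete(['doc-1', 'doc-2'])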
""" - vi = db.vector_indices[vector_index] if isinstance(query, dict): # ruff: noqa: E501 diff --git a/test/integration/ext/openai/test_model_openai.py b/test/integration/ext/openai/test_model_openai.py index 6c68acaa23..072cf45054 100644 --- a/test/integration/ext/openai/test_model_openai.py +++ b/test/integration/ext/openai/test_model_openai.py @@ -43,11 +43,11 @@ def _make_vcr_request(httpx_request, **kwargs): def before_record_response(response): - ''' + """ VCR filter function to only record the PNG signature in the response. This is necessary because the response is a PNG which can be quite large. - ''' + """ if 'body' not in response: return response if PNG_BYTE_SIGNATURE in response['body']['string']: diff --git a/test/unittest/ext/llm/test_openai.py b/test/unittest/ext/llm/test_openai.py deleted file mode 100644 index c02feb72eb..0000000000 --- a/test/unittest/ext/llm/test_openai.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from test.db_config import DBConfig -from test.unittest.ext.llm.utils import check_llm_as_listener_model, check_predict - -import pytest -import vcr - -from superduperdb.ext.openai.model import OpenAILLM - -CASSETTE_DIR = "test/unittest/ext/cassettes/llm/openai" - - -@pytest.fixture -def openai_mock(monkeypatch): - if os.getenv("OPENAI_API_KEY") is None: - monkeypatch.setenv("OPENAI_API_KEY", "sk-TopSecret") - - -@pytest.mark.skip("Skip until openai connection failures fixed") -@vcr.use_cassette( - f"{CASSETTE_DIR}/test_predict.yaml", - filter_headers=["authorization"], - record_on_exception=False, - ignore_localhost=True, -) -@pytest.mark.parametrize( - "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True -) -def test_predict(db, openai_mock): - """Test chat.""" - check_predict(db, OpenAILLM(model_name="gpt-3.5-turbo")) - check_predict( - db, OpenAILLM(identifier="chat-llm", model_name="gpt-3.5-turbo", chat=True) - ) - - -@pytest.mark.skip("Skip until openai connection failures fixed") -@vcr.use_cassette( - f"{CASSETTE_DIR}/test_llm_as_listener_model.yaml", - filter_headers=["authorization"], - record_on_exception=False, - ignore_localhost=True, -) -@pytest.mark.parametrize( - "db", [DBConfig.mongodb_empty, DBConfig.sqldb_empty], indirect=True -) -def test_llm_as_listener_model(db, openai_mock): - check_llm_as_listener_model(db, OpenAILLM(model_name="gpt-3.5-turbo")) diff --git a/test/unittest/test_docstrings.py b/test/unittest/test_docstrings.py index 101eb87afa..3c6d329ac8 100644 --- a/test/unittest/test_docstrings.py +++ b/test/unittest/test_docstrings.py @@ -60,8 +60,9 @@ def get_function_params(node): params = [] args = node.args.args args += node.args.kwonlyargs or [] + args += node.args.posonlyargs or [] for arg in args: - if arg.arg not in ['self', 'cls']: + if arg.arg not in ['self', 'cls', 'args', 'kwargs']: params.append(arg.arg) return params @@ -69,7 +70,7 @@ def get_function_params(node): def get_method_params(node): params = [] for arg in node.args.args: - if arg.arg != 'self': + if arg.arg not in ['self', 'cls', 'args' 'kwargs']: params.append(arg.arg) return params @@ -87,6 +88,9 @@ def get_dataclass_init_params(node): if isinstance(target, ast.Name): field_name = target.id init_params.append(field_name) + elif isinstance(item, (ast.FunctionDef,)): + # Ignore the assign after function + break init_params = [p for p in init_params if p != '__doc__'] return init_params @@ -94,6 +98,12 @@ def get_dataclass_init_params(node): def get_doc_string_params(doc_string): param_pattern = r":param (\w+): (.+)" params = 
re.findall(param_pattern, doc_string) + params = [ + param + for param in params + if not param[0].startswith('*') + if param[0] not in ['args', 'kwargs'] + ] return {param[0]: param[1] for param in params} @@ -107,13 +117,17 @@ def check_class_docstring(file_path, node, dataclass=False): params = get_dataclass_init_params(node) else: params = get_class_init_params(node) + params = [p for p in params if p not in ['args', 'kwargs']] doc_params = get_doc_string_params(doc_string) if len(doc_params) != len(params): raise MismatchingDocParameters( file_path=file_path, node=node, - msg=f'Got {len(params)} parameters but doc-string has {len(doc_params)}.', + msg=( + f'Got {len(params)} parameters but doc-string has {len(doc_params)}. ' + f'Diffs: {set(params) ^ set(doc_params.keys())}' + ), ) for i, (p, (dp, expl)) in enumerate(zip(params, doc_params.items())): @@ -142,7 +156,10 @@ def check_method_docstring(file_path, parent_class, node): raise MismatchingDocParameters( file_path=file_path, node=node, - msg=f'Got {len(params)} parameters but doc-string has {len(doc_params)}.', + msg=( + f'Got {len(params)} parameters but doc-string has {len(doc_params)}. ' + f'Diffs: {set(params) ^ set(doc_params.keys())}' + ), parent=parent_class, ) @@ -176,7 +193,10 @@ def check_function_doc_string(file_path, node): raise MismatchingDocParameters( file_path=file_path, node=node, - msg=f'Got {len(params)} parameters but doc-string has {len(doc_params)}.', + msg=( + f'Got {len(params)} parameters but doc-string has {len(doc_params)}. ' + f'Diffs: {set(params) ^ set(doc_params.keys())}' + ), ) for i, (p, (dp, expl)) in enumerate(zip(params, doc_params.items())): @@ -260,7 +280,7 @@ def _extract(file_path): CLASS_TEST_CASES, METHOD_TEST_CASES, FUNCTION_TEST_CASES = extract_docstrings( - './superduperdb/components/', './superduperdb/base/datalayer.py' + './superduperdb', )
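For reference, the `:param` extraction used by `get_doc_string_params` above reduces to one regular expression; a standalone illustration with an invented docstring:

    import re

    doc = '''Add vectors to the index.

    :param items: List of vectors to add
    :param db: Datalayer instance
    '''
    params = dict(re.findall(r':param (\w+): (.+)', doc))
    # {'items': 'List of vectors to add', 'db': 'Datalayer instance'}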