diff --git a/.github/workflows/ci_code.yml b/.github/workflows/ci_code.yml index 157c7e1007..647f2dfa2c 100644 --- a/.github/workflows/ci_code.yml +++ b/.github/workflows/ci_code.yml @@ -66,10 +66,6 @@ jobs: # Once the local file database is complete, we may need to update this section. python -m pip install plugins/mongodb - - name: Install DevKit (docs, testing, etc) - run: | - make install_devkit - - name: Lint and type-check run: | make lint-and-type-check diff --git a/.github/workflows/ci_plugins.yaml b/.github/workflows/ci_plugins.yaml index 96fdd65e82..a9cd535228 100644 --- a/.github/workflows/ci_plugins.yaml +++ b/.github/workflows/ci_plugins.yaml @@ -6,7 +6,7 @@ on: - main jobs: - prepare_matrix: + plugin_update_check: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 @@ -38,11 +38,11 @@ jobs: matrix: ${{ steps.set-matrix.outputs.matrix }} test_plugin: - needs: prepare_matrix + needs: plugin_update_check runs-on: ubuntu-latest strategy: fail-fast: false - matrix: ${{fromJson(needs.prepare_matrix.outputs.matrix)}} + matrix: ${{fromJson(needs.plugin_update_check.outputs.matrix)}} steps: - name: Checkout repository uses: actions/checkout@v4 @@ -66,27 +66,22 @@ jobs: # Install core and testsuite dependencies on the cached python environment. python -m pip install '.[test]' - - name: Install DevKit (docs, testing, etc) - run: | - make install_devkit - - name: Lint and type-check run: | make lint-and-type-check DIRECTORIES="plugins/${{ matrix.plugin }}" - name: Install Plugin run: | - python -m pip install 'plugins/${{ matrix.plugin }}[test]' + echo "Installing local plugin dependencies..." + grep -o '#\s*:CI:\s*plugins/.*' plugins/${{ matrix.plugin }}/pyproject.toml | while read line; do + dep_path=${line##*# :CI: } + dep_path=${dep_path%%[[:space:]]*} + echo "Installing $dep_path for testing..." + python -m pip install "$dep_path[test]" + done + echo "Installing plugin..." 
+ python -m pip install "plugins/${{ matrix.plugin }}[test]" - - name: Optionally run custom CI script - run: | - if [ -f "plugins/${{ matrix.plugin }}/.ci_extend.sh" ]; then - echo "Running custom CI script..." - bash ./plugins/${{ matrix.plugin }}/.ci_extend.sh - else - echo "No custom CI script found, skipping..." - fi - - name: Plugin Testing run: | export PYTHONPATH=./ diff --git a/Makefile b/Makefile index b40b7cfc4a..bcf1e9214c 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ DIRECTORIES ?= superduper test SUPERDUPER_CONFIG ?= test/configs/default.yaml PYTEST_ARGUMENTS ?= +PLUGIN_NAME ?= # Export directories for data and artifacts export SUPERDUPER_DATA_DIR ?= ~/.cache/superduper/test_data @@ -47,49 +48,6 @@ new_release: ## Release a new version of superduper.io # Push the specific tag git push origin $(RELEASE_VERSION) - -install_devkit: ## Add essential development tools - # Add pre-commit hooks to ensure that no strange stuff are being committed. - # https://stackoverflow.com/questions/3462955/putting-git-hooks-into-a-repository - python -m pip install pre-commit - - @echo "Download Code Quality dependencies" - python -m pip install black==23.3 ruff==0.4.4 mypy types-PyYAML types-requests interrogate - - @echo "Download Code Testing dependencies" - python -m pip install pytest pytest-cov "nbval>=0.10.0" - -install_plugin: - @if [ "$(filter-out $@,$(MAKECMDGOALS))" = "all" ]; then \ - $(MAKE) install_all_plugins; \ - else \ - echo "Error: Cannot find plugin '$(filter-out $@,$(MAKECMDGOALS))'"; \ - fi - -install_named_plugin: - @if [ -f "plugins/$(PLUGIN)/pyproject.toml" ]; then \ - python -m pip install -e "plugins/$(PLUGIN)"; \ - else \ - echo "Error: Plugin '$(PLUGIN)' not found."; \ - fi - -install_all_plugins: - @plugins=""; \ - for plugin in $$(ls plugins); do \ - if [ "$$plugin" != "template" -a -d "plugins/$$plugin" -a -f "plugins/$$plugin/pyproject.toml" ]; then \ - plugins="$$plugins $$plugin"; \ - fi \ - done; \ - echo "Found 
plugins:$$plugins"; \ - for plugin in $$plugins; do \ - echo "Installing $$plugin..."; \ - python -m pip install -e "plugins/$$plugin"; \ - done -%: - @: - - - ##@ Code Quality gen_docs: ## Generate Docs and API diff --git a/plugins/ibis/pyproject.toml b/plugins/ibis/pyproject.toml index fe3ec2a8a8..e24feef1d8 100644 --- a/plugins/ibis/pyproject.toml +++ b/plugins/ibis/pyproject.toml @@ -25,6 +25,8 @@ dependencies = [ [project.optional-dependencies] test = [ "ibis-framework[sqlite]>=5.1.0", + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/sqlalchemy ] [project.urls] diff --git a/plugins/pillow/.ci_extend.sh b/plugins/pillow/.ci_extend.sh deleted file mode 100644 index e917db1b00..0000000000 --- a/plugins/pillow/.ci_extend.sh +++ /dev/null @@ -1 +0,0 @@ -python -m pip install plugins/mongodb diff --git a/plugins/pillow/pyproject.toml b/plugins/pillow/pyproject.toml index c03ad53887..8f439bbff2 100644 --- a/plugins/pillow/pyproject.toml +++ b/plugins/pillow/pyproject.toml @@ -22,6 +22,13 @@ dependencies = [ "pillow>=10.2.0", ] +[project.optional-dependencies] +test = [ + "ibis-framework[sqlite]>=5.1.0", + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/sqlalchemy +] + [project.urls] homepage = "https://www.superduper.com/" documentation = "https://docs.superduper.com" diff --git a/plugins/sklearn/.ci_extend.sh b/plugins/sklearn/.ci_extend.sh deleted file mode 100644 index e917db1b00..0000000000 --- a/plugins/sklearn/.ci_extend.sh +++ /dev/null @@ -1 +0,0 @@ -python -m pip install plugins/mongodb diff --git a/plugins/sklearn/pyproject.toml b/plugins/sklearn/pyproject.toml index 6abbf7703b..5d6b0fe9f5 100644 --- a/plugins/sklearn/pyproject.toml +++ b/plugins/sklearn/pyproject.toml @@ -21,6 +21,12 @@ dynamic = ["version"] dependencies = [ ] +[project.optional-dependencies] +test = [ + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/mongodb +] + [project.urls] homepage = 
"https://www.superduper.com/" documentation = "https://docs.superduper.com" diff --git a/plugins/torch/.ci_extend.sh b/plugins/torch/.ci_extend.sh deleted file mode 100644 index e917db1b00..0000000000 --- a/plugins/torch/.ci_extend.sh +++ /dev/null @@ -1 +0,0 @@ -python -m pip install plugins/mongodb diff --git a/plugins/torch/pyproject.toml b/plugins/torch/pyproject.toml index 9bf8c2628d..95d7ae6f2f 100644 --- a/plugins/torch/pyproject.toml +++ b/plugins/torch/pyproject.toml @@ -23,6 +23,12 @@ dependencies = [ "torchvision>=0.17.1", ] +[project.optional-dependencies] +test = [ + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/mongodb +] + [project.urls] homepage = "https://www.superduper.com/" documentation = "https://docs.superduper.com" diff --git a/plugins/transformers/.ci_extend.sh b/plugins/transformers/.ci_extend.sh deleted file mode 100644 index e917db1b00..0000000000 --- a/plugins/transformers/.ci_extend.sh +++ /dev/null @@ -1 +0,0 @@ -python -m pip install plugins/mongodb diff --git a/plugins/transformers/pyproject.toml b/plugins/transformers/pyproject.toml index 266fb63149..1e986297e4 100644 --- a/plugins/transformers/pyproject.toml +++ b/plugins/transformers/pyproject.toml @@ -27,6 +27,8 @@ dependencies = [ test = [ "peft>=0.10.0", "trl>=0.8.0", + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/mongodb ] [project.urls] diff --git a/plugins/vllm/.ci_extend.sh b/plugins/vllm/.ci_extend.sh deleted file mode 100644 index e917db1b00..0000000000 --- a/plugins/vllm/.ci_extend.sh +++ /dev/null @@ -1 +0,0 @@ -python -m pip install plugins/mongodb diff --git a/plugins/vllm/pyproject.toml b/plugins/vllm/pyproject.toml index c5addcb3e1..e36d1b5092 100644 --- a/plugins/vllm/pyproject.toml +++ b/plugins/vllm/pyproject.toml @@ -24,6 +24,8 @@ dependencies = [ [project.optional-dependencies] test = [ "vcrpy>=5.1.0", + # Annotation plugin dependencies will be installed in CI + # :CI: plugins/mongodb ] [project.urls] 
diff --git a/pyproject.toml b/pyproject.toml index df6a5d8702..04c1b76cbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,6 +61,16 @@ dependencies = [ test = [ "scikit-learn>=1.1.3", "pandas", + "pre-commit", + "black==23.3", + "ruff==0.4.4", + "mypy", + "types-PyYAML", + "types-requests", + "interrogate", + "pytest", + "pytest-cov", + "nbval>=0.10.0", ] [project.urls] diff --git a/superduper/backends/ibis/__init__.py b/superduper/backends/ibis/__init__.py index 5e49b79324..e86d221215 100644 --- a/superduper/backends/ibis/__init__.py +++ b/superduper/backends/ibis/__init__.py @@ -1 +1,5 @@ from superduper_ibis import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('ibis') diff --git a/superduper/backends/mongodb/__init__.py b/superduper/backends/mongodb/__init__.py index 0473f09a34..ff1f16152e 100644 --- a/superduper/backends/mongodb/__init__.py +++ b/superduper/backends/mongodb/__init__.py @@ -1 +1,5 @@ from superduper_mongodb import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('mongodb') diff --git a/superduper/backends/sqlalchemy/__init__.py b/superduper/backends/sqlalchemy/__init__.py new file mode 100644 index 0000000000..c613a3956f --- /dev/null +++ b/superduper/backends/sqlalchemy/__init__.py @@ -0,0 +1,5 @@ +from superduper_sqlalchemy import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('sqlalchemy') diff --git a/superduper/backends/sqlalchemy/db_helper.py b/superduper/backends/sqlalchemy/db_helper.py index f48fec45f2..156845c5ff 100644 --- a/superduper/backends/sqlalchemy/db_helper.py +++ b/superduper/backends/sqlalchemy/db_helper.py @@ -1,119 +1 @@ -import json -from typing import Tuple - -from sqlalchemy import ( - Boolean, - DateTime, - Integer, - String, - Text, - TypeDecorator, -) - -DEFAULT_LENGTH = 255 - - -class JsonMixin: - """Mixin for JSON type columns. 
- - Converts dict to JSON strings before saving to database - and converts JSON strings to dict when loading from database. - - # noqa - """ - - def process_bind_param(self, value, dialect): - """Convert dict to JSON string. - - :param value: The dict to convert. - :param dialect: The dialect of the database. - """ - if value is not None: - value = json.dumps(value) - return value - - def process_result_value(self, value, dialect): - """Convert JSON string to dict. - - :param value: The JSON string to convert. - :param dialect: The dialect of the database. - """ - if value is not None: - value = json.loads(value) - return value - - -class JsonAsString(JsonMixin, TypeDecorator): - """JSON type column for short JSON strings # noqa.""" - - impl = String(DEFAULT_LENGTH) - - -class JsonAsText(JsonMixin, TypeDecorator): - """JSON type column for long JSON strings # noqa.""" - - impl = Text - - -class DefaultConfig: - """Default configuration for database types # noqa.""" - - type_string = String(DEFAULT_LENGTH) - type_json_as_string = JsonAsString - type_json_as_text = JsonAsText - type_integer = Integer - type_datetime = DateTime - type_boolean = Boolean - - query_id_table_args: Tuple = tuple() - job_table_args: Tuple = tuple() - parent_child_association_table_args: Tuple = tuple() - component_table_args: Tuple = tuple() - meta_table_args: Tuple = tuple() - - -def create_clickhouse_config(): - """Create configuration for ClickHouse database.""" - # lazy import - try: - from clickhouse_sqlalchemy import engines, types - except ImportError: - raise ImportError( - 'The clickhouse_sqlalchemy package is required to use the ' - 'clickhouse dialect. 
Please install it with pip install ' - 'clickhouse-sqlalchemy' - ) - - class ClickHouseConfig: - class JsonAsString(JsonMixin, TypeDecorator): - impl = types.String - - class JsonAsText(JsonMixin, TypeDecorator): - impl = types.String - - type_string = types.String - type_json_as_string = JsonAsString - type_json_as_text = JsonAsText - type_integer = types.Int32 - type_datetime = types.DateTime - type_boolean = types.Boolean - - # clickhouse need engine args to create table - query_id_table_args = (engines.MergeTree(order_by='query_id'),) - job_table_args = (engines.MergeTree(order_by='identifier'),) - parent_child_association_table_args = (engines.MergeTree(order_by='parent_id'),) - component_table_args = (engines.MergeTree(order_by='id'),) - meta_table_args = (engines.MergeTree(order_by='key'),) - - return ClickHouseConfig - - -def get_db_config(dialect): - """Get the configuration class for the specified dialect. - - :param dialect: The dialect of the database. - """ - if dialect == 'clickhouse': - return create_clickhouse_config() - else: - return DefaultConfig +from superduper_sqlalchemy.db_helper import * # noqa diff --git a/superduper/backends/sqlalchemy/metadata.py b/superduper/backends/sqlalchemy/metadata.py index d4b2939fdd..0e86ba058d 100644 --- a/superduper/backends/sqlalchemy/metadata.py +++ b/superduper/backends/sqlalchemy/metadata.py @@ -1,574 +1 @@ -import threading -import typing as t -from contextlib import contextmanager - -import click -from sqlalchemy import ( - Column, - MetaData, - Table, - and_, - create_engine, - delete, - insert, - select, -) -from sqlalchemy.exc import ProgrammingError -from sqlalchemy.orm import sessionmaker - -from superduper import logging -from superduper.backends.base.metadata import MetaDataStore, NonExistentMetadataError -from superduper.backends.sqlalchemy.db_helper import get_db_config -from superduper.misc.colors import Colors - - -class SQLAlchemyMetadata(MetaDataStore): - """ - Abstraction for storing 
meta-data separately from primary data. - - :param uri: URI to the databackend database. - :param flavour: Flavour of the databackend. - :param callback: Optional callback to create connection. - """ - - def __init__( - self, - uri: t.Optional[str] = None, - flavour: t.Optional[str] = None, - callback: t.Optional[t.Callable] = None, - ): - super().__init__(uri=uri, flavour=flavour) - - if callback: - self.connection_callback = callback - else: - assert isinstance(uri, str) - name = uri.split('//')[0] - self.connection_callback = lambda: (create_engine(uri), name) - - sql_conn, name = self.connection_callback() - - self.name = name - self.conn = sql_conn - self.dialect = sql_conn.dialect.name - self._init_tables() - - self._lock = threading.Lock() - - def reconnect(self): - """Reconnect to sqlalchmey metadatastore.""" - sql_conn = create_engine(self.uri) - self.conn = sql_conn - - # TODO: is it required to init after - # a reconnect. - self._init_tables() - - def _init_tables(self): - # Get the DB config for the given dialect - DBConfig = get_db_config(self.dialect) - - type_string = DBConfig.type_string - type_json_as_string = DBConfig.type_json_as_string - type_json_as_text = DBConfig.type_json_as_text - type_integer = DBConfig.type_integer - type_datetime = DBConfig.type_datetime - type_boolean = DBConfig.type_boolean - - job_table_args = DBConfig.job_table_args - parent_child_association_table_args = ( - DBConfig.parent_child_association_table_args - ) - component_table_args = DBConfig.component_table_args - - metadata = MetaData() - - self.job_table = Table( - 'JOB', - metadata, - Column('identifier', type_string, primary_key=True), - Column('component_identifier', type_string), - Column('type_id', type_string), - Column('info', type_json_as_string), - Column('time', type_datetime), - Column('status', type_string), - Column('args', type_json_as_string), - Column('kwargs', type_json_as_text), - Column('method_name', type_string), - Column('stdout', 
type_json_as_string), - Column('stderr', type_json_as_string), - Column('_path', type_string), - Column('job_id', type_string), - *job_table_args, - ) - - self.parent_child_association_table = Table( - 'PARENT_CHILD_ASSOCIATION', - metadata, - Column('parent_id', type_string, primary_key=True), - Column('child_id', type_string, primary_key=True), - *parent_child_association_table_args, - ) - - self.component_table = Table( - 'COMPONENT', - metadata, - Column('id', type_string, primary_key=True), - Column('identifier', type_string), - Column('version', type_integer), - Column('hidden', type_boolean), - Column('type_id', type_string), - Column('_path', type_string), - Column('dict', type_json_as_text), - *component_table_args, - ) - - metadata.create_all(self.conn) - - def url(self): - """Return the URL of the metadata store.""" - return self.conn.url + self.name - - def drop(self, force: bool = False): - """Drop the metadata store. - - :param force: whether to force the drop (without confirmation) - """ - if not force: - if not click.confirm( - f'{Colors.RED}[!!!WARNING USE WITH CAUTION AS YOU ' - f'WILL LOSE ALL DATA!!!]{Colors.RESET} ' - 'Are you sure you want to drop all meta-data? 
', - default=False, - ): - logging.warn('Aborting...') - try: - self.job_table.drop(self.conn) - except ProgrammingError as e: - logging.warn(f'Error dropping job table: {e}') - - try: - self.parent_child_association_table.drop(self.conn) - except ProgrammingError as e: - logging.warn(f'Error dropping parent-child association table: {e}') - - try: - self.component_table.drop(self.conn) - except ProgrammingError as e: - logging.warn(f'Error dropping component table {e}') - - @contextmanager - def session_context(self): - """Provide a transactional scope around a series of operations.""" - sm = sessionmaker(bind=self.conn) - session = sm() - try: - yield session - session.commit() - except Exception: - session.rollback() - raise - finally: - session.close() - - # --------------- COMPONENTS ----------------- - - def _get_component_uuid(self, type_id: str, identifier: str, version: int): - with self.session_context() as session: - stmt = ( - select(self.component_table) - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - .limit(1) - ) - res = self.query_results(self.component_table, stmt, session) - return res[0]['id'] if res else None - - def _get_all_component_info(self): - with self.session_context() as session: - res = self.query_results( - self.component_table, - select(self.component_table), - session=session, - ) - return list(res) - - def component_version_has_parents( - self, type_id: str, identifier: str, version: int - ): - """Check if a component version has parents. 
- - :param type_id: the type of the component - :param identifier: the identifier of the component - :param version: the version of the component - """ - uuid = self._get_component_uuid(type_id, identifier, version) - with self.session_context() as session: - stmt = ( - select(self.parent_child_association_table) - .where( - self.parent_child_association_table.c.child_id == uuid, - ) - .limit(1) - ) - res = self.query_results(self.parent_child_association_table, stmt, session) - return len(res) > 0 - - def create_component(self, info: t.Dict): - """Create a component in the metadata store. - - :param info: the information to create the component - """ - new_info = self._refactor_component_info(info) - with self.session_context() as session: - stmt = insert(self.component_table).values(**new_info) - session.execute(stmt) - - def delete_parent_child(self, parent_id: str, child_id: str): - """ - Delete parent-child relationships between two components. - - :param parent: parent component uuid - :param child: child component uuid - """ - with self.session_context() as session: - stmt = delete(self.parent_child_association_table).where( - self.parent_child_association_table.c.parent_id == parent_id, - self.parent_child_association_table.c.child_id == child_id, - ) - session.execute(stmt) - - def create_parent_child(self, parent_id: str, child_id: str): - """Create a parent-child relationship between two components. - - :param parent_id: the parent component - :param child_id: the child component - """ - with self.session_context() as session: - stmt = insert(self.parent_child_association_table).values( - parent_id=parent_id, child_id=child_id - ) - session.execute(stmt) - - def delete_component_version(self, type_id: str, identifier: str, version: int): - """Delete a component from the metadata store. 
- - :param type_id: the type of the component - :param identifier: the identifier of the component - :param version: the version of the component - """ - with self.session_context() as session: - stmt = ( - self.component_table.select() - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - .limit(1) - ) - res = self.query_results(self.component_table, stmt, session) - cv = res[0] if res else None - if cv: - stmt_delete = delete(self.component_table).where( - self.component_table.c.id == cv['id'] - ) - session.execute(stmt_delete) - - def get_component_by_uuid(self, uuid: str, allow_hidden: bool = False): - """Get a component by UUID. - - :param uuid: UUID of component - :param allow_hidden: whether to load hidden components - """ - with self.session_context() as session: - stmt = ( - select(self.component_table) - .where( - self.component_table.c.id == uuid, - ) - .limit(1) - ) - res = self.query_results(self.component_table, stmt, session) - try: - r = res[0] - except IndexError: - raise NonExistentMetadataError( - f'Table with uuid: {uuid} does not exist' - ) - - return self._get_component( - type_id=r['type_id'], - identifier=r['identifier'], - version=r['version'], - allow_hidden=allow_hidden, - ) - - def _get_component( - self, - type_id: str, - identifier: str, - version: int, - allow_hidden: bool = False, - ): - """Get a component from the metadata store. 
- - :param type_id: the type of the component - :param identifier: the identifier of the component - :param version: the version of the component - :param allow_hidden: whether to allow hidden components - """ - with self.session_context() as session: - stmt = select(self.component_table).where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - if not allow_hidden: - stmt = stmt.where(self.component_table.c.hidden == allow_hidden) - - res = self.query_results(self.component_table, stmt, session) - if res: - res = res[0] - dict_ = res['dict'] - del res['dict'] - res = {**res, **dict_} - return res - - def get_component_version_parents(self, uuid: str): - """Get the parents of a component version. - - :param uuid: the unique identifier of the component version - """ - with self.session_context() as session: - stmt = select(self.parent_child_association_table).where( - self.parent_child_association_table.c.child_id == uuid, - ) - res = self.query_results(self.parent_child_association_table, stmt, session) - parents = [r['parent_id'] for r in res] - return parents - - @classmethod - def _refactor_component_info(cls, info): - if 'hidden' not in info: - info['hidden'] = False - component_fields = ['identifier', 'version', 'hidden', 'type_id', '_path'] - new_info = {k: info[k] for k in component_fields} - new_info['dict'] = {k: info[k] for k in info if k not in component_fields} - new_info['id'] = new_info['dict']['uuid'] - return new_info - - def get_latest_version( - self, type_id: str, identifier: str, allow_hidden: bool = False - ): - """Get the latest version of a component. 
- - :param type_id: the type of the component - :param identifier: the identifier of the component - :param allow_hidden: whether to allow hidden components - """ - with self.session_context() as session: - stmt = ( - select(self.component_table) - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.hidden == allow_hidden, - ) - .order_by(self.component_table.c.version.desc()) - .limit(1) - ) - res = session.execute(stmt) - res = self.query_results(self.component_table, stmt, session) - versions = [r['version'] for r in res] - if len(versions) == 0: - raise FileNotFoundError( - f'Can\'t find {type_id}: {identifier} in metadata' - ) - return versions[0] - - def hide_component_version(self, type_id: str, identifier: str, version: int): - """Hide a component in the metadata store. - - :param type_id: the type of the component - :param identifier: the identifier of the component - :param version: the version of the component - """ - with self.session_context() as session: - stmt = ( - self.component_table.update() - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - .values(hidden=True) - ) - session.execute(stmt) - - def _replace_object(self, info, identifier, type_id, version): - info = self._refactor_component_info(info) - with self.session_context() as session: - stmt = ( - self.component_table.update() - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - .values(**info) - ) - session.execute(stmt) - - def _show_components(self, type_id: t.Optional[str] = None): - """Show all components in the database. 
- - :param type_id: the type of the component - """ - with self.session_context() as session: - stmt = select(self.component_table) - if type_id is not None: - stmt = stmt.where(self.component_table.c.type_id == type_id) - res = self.query_results(self.component_table, stmt, session) - return res - - def show_component_versions(self, type_id: str, identifier: str): - """Show all versions of a component in the database. - - :param type_id: the type of the component - :param identifier: the identifier of the component - """ - with self.session_context() as session: - stmt = select(self.component_table).where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - ) - res = self.query_results(self.component_table, stmt, session) - versions = [data['version'] for data in res] - versions = sorted(set(versions), key=lambda x: versions.index(x)) - return versions - - def _update_object( - self, - identifier: str, - type_id: str, - key: str, - value: t.Any, - version: int, - ): - with self.session_context() as session: - stmt = ( - self.component_table.update() - .where( - self.component_table.c.type_id == type_id, - self.component_table.c.identifier == identifier, - self.component_table.c.version == version, - ) - .values({key: value}) - ) - session.execute(stmt) - - # --------------- JOBS ----------------- - - def create_job(self, info: t.Dict): - """Create a job with the given info. - - :param info: The information used to create the job - """ - with self.session_context() as session: - stmt = insert(self.job_table).values(**info) - session.execute(stmt) - - def get_job(self, job_id: str): - """Get the job with the given job_id. 
- - :param job_id: The identifier of the job - """ - with self.session_context() as session: - stmt = ( - select(self.job_table) - .where(self.job_table.c.identifier == job_id) - .limit(1) - ) - res = self.query_results(self.job_table, stmt, session) - return res[0] if res else None - - def show_jobs( - self, - component_identifier: t.Optional[str] = None, - type_id: t.Optional[str] = None, - ): - """Show all jobs in the database. - - :param component_identifier: the identifier of the component - :param type_id: the type of the component - """ - with self.session_context() as session: - # Start building the select statement - stmt = select(self.job_table) - - # If a component_identifier is provided, add a where clause to filter by it - if component_identifier is not None: - stmt = stmt.where( - and_( - self.job_table.c.component_identifier == component_identifier, - self.job_table.c.type_id == type_id, - ) - ) - - # Execute the query and collect results - res = self.query_results(self.job_table, stmt, session) - - return res - - def update_job(self, job_id: str, key: str, value: t.Any): - """Update the job with the given key and value. - - :param job_id: The identifier of the job - :param key: The key to update - :param value: The value to update - """ - with self.session_context() as session: - stmt = ( - self.job_table.update() - .where(self.job_table.c.identifier == job_id) - .values({key: value}) - ) - session.execute(stmt) - - # --------------- Query ID ----------------- - - def disconnect(self): - """Disconnect the client.""" - - # TODO: implement me - - def query_results(self, table, statment, session): - """Query the database and return the results as a list of row datas. - - :param table: The table object to query, used to derive column names. - :param statment: The SQL statement to execute. - :param session: The database session within which the query is executed. 
- """ - # Some databases don't support defining statment outside of session - try: - result = session.execute(statment) - columns = [col.name for col in table.columns] - results = [] - for row in result: - if len(row) != len(columns): - raise ValueError( - f'Number of columns in result ({row}) does not match ' - f'number of columns in table ({columns})' - ) - results.append(dict(zip(columns, row))) - except ProgrammingError: - # Known ProgrammingErrors: - # - EmptyResults: Duckdb don't support return empty results - # - NotExist: SnowFlake returns error if a component does not exist - return [] - - return results +from superduper_sqlalchemy.metadata import * # noqa diff --git a/superduper/ext/__init__.py b/superduper/ext/__init__.py index e69de29bb2..5dd483f5c2 100644 --- a/superduper/ext/__init__.py +++ b/superduper/ext/__init__.py @@ -0,0 +1,11 @@ +from superduper import logging + + +def _warn_plugin_deprecated(name): + message = ( + f'superduper.ext.{name} is deprecated ' + 'and will be removed in a future release.' + f'Please insteall superduper_{name} and use' + f'from superduper_{name} import * instead.' 
+ ) + logging.warn(message) diff --git a/superduper/ext/anthropic/__init__.py b/superduper/ext/anthropic/__init__.py index 22382a7af5..b9e00bd57b 100644 --- a/superduper/ext/anthropic/__init__.py +++ b/superduper/ext/anthropic/__init__.py @@ -1 +1,5 @@ from superduper_anthropic import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('anthropic') diff --git a/superduper/ext/cohere/__init__.py b/superduper/ext/cohere/__init__.py index b9ef9727bc..2a3396d233 100644 --- a/superduper/ext/cohere/__init__.py +++ b/superduper/ext/cohere/__init__.py @@ -1 +1,5 @@ from superduper_cohere import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('cohere') diff --git a/superduper/ext/jina/__init__.py b/superduper/ext/jina/__init__.py index 67d167f2cf..d03dcea798 100644 --- a/superduper/ext/jina/__init__.py +++ b/superduper/ext/jina/__init__.py @@ -1 +1,5 @@ from superduper_jina import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('jina') diff --git a/superduper/ext/llamacpp/__init__.py b/superduper/ext/llamacpp/__init__.py index cc158dcbe4..ca342fc4c8 100644 --- a/superduper/ext/llamacpp/__init__.py +++ b/superduper/ext/llamacpp/__init__.py @@ -1 +1,5 @@ from superduper_llamacpp import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('llamacpp') diff --git a/superduper/ext/openai/__init__.py b/superduper/ext/openai/__init__.py index 1229db4a69..db8321d413 100644 --- a/superduper/ext/openai/__init__.py +++ b/superduper/ext/openai/__init__.py @@ -1 +1,5 @@ from superduper_openai import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('openai') diff --git a/superduper/ext/pillow/__init__.py b/superduper/ext/pillow/__init__.py index f7e9f7f6b5..da40b78666 100644 --- a/superduper/ext/pillow/__init__.py +++ 
b/superduper/ext/pillow/__init__.py @@ -1 +1,5 @@ from superduper_pillow import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('pillow') diff --git a/superduper/ext/sentence_transformers/__init__.py b/superduper/ext/sentence_transformers/__init__.py index 22d1ee1171..8906883fd0 100644 --- a/superduper/ext/sentence_transformers/__init__.py +++ b/superduper/ext/sentence_transformers/__init__.py @@ -1 +1,5 @@ from superduper_sentence_transformers import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('sentence_transformers') diff --git a/superduper/ext/sklearn/__init__.py b/superduper/ext/sklearn/__init__.py index 4e564cb63a..002c39bb20 100644 --- a/superduper/ext/sklearn/__init__.py +++ b/superduper/ext/sklearn/__init__.py @@ -1 +1,5 @@ from superduper_sklearn import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('sklearn') diff --git a/superduper/ext/torch/__init__.py b/superduper/ext/torch/__init__.py index d5533d82fd..ca8e5c7e0c 100644 --- a/superduper/ext/torch/__init__.py +++ b/superduper/ext/torch/__init__.py @@ -1 +1,5 @@ from superduper_torch import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('torch') diff --git a/superduper/ext/transformers/__init__.py b/superduper/ext/transformers/__init__.py index c1978972c3..99854b03fc 100644 --- a/superduper/ext/transformers/__init__.py +++ b/superduper/ext/transformers/__init__.py @@ -1 +1,5 @@ from superduper_transformers import * # noqa + +from superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('transformers') diff --git a/superduper/ext/vllm/__init__.py b/superduper/ext/vllm/__init__.py index 5d16a1be9d..6fb800cce9 100644 --- a/superduper/ext/vllm/__init__.py +++ b/superduper/ext/vllm/__init__.py @@ -1 +1,5 @@ from superduper_vllm import * # noqa + +from 
superduper.misc.annotations import warn_plugin_deprecated + +warn_plugin_deprecated('vllm') diff --git a/superduper/misc/annotations.py b/superduper/misc/annotations.py index f564280387..e23438ad3e 100644 --- a/superduper/misc/annotations.py +++ b/superduper/misc/annotations.py @@ -321,5 +321,19 @@ def replace_parameters(doc, placeholder: str = '!!!'): return '\n'.join(lines) +def warn_plugin_deprecated(name): + """Warn that a plugin is deprecated. + + :param name: name of the plugin + """ + message = ( + f'`superduper.ext.{name}` is deprecated ' + 'and will be removed in a future release. ' + f'Please install `superduper_{name}` and use ' + f'`from superduper_{name} import *` instead.' + ) + logging.warn(message) + + if __name__ == '__main__': print(replace_parameters(extract_parameters.__doc__))