Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove redundant arguments from components using kw_only=True #1495

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci_code.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
fail-fast: false
matrix:
os: [ "ubuntu-latest" ] # TODO: add "windows-latest", "macos-latest" when Docker removed
python-version: ["3.8", "3.11"] # Due to cache limitations, check only the earliest and the latest.
python-version: ["3.10", "3.11"] # Due to cache limitations, check only the earliest and the latest.

steps:
- name: Check out repository
Expand Down
18 changes: 3 additions & 15 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Update connection uris in `sql_examples.ipynb` to include snippets for Embedded, Cloud, and Distributed databases.

## [0.1.0](https://github.com/SuperDuperDB/superduperdb/compare/0.0.20...0.1.0]) (2023-Dec-05)
#### Refactorings

#### Changed defaults / behaviours
- Added `kw_only` to most `@dataclass` decorators to simplify

- ...
## [0.1.0](https://github.com/SuperDuperDB/superduperdb/compare/0.0.20...0.1.0]) (2023-Dec-05)

#### New Features & Functionality

- Introduced Chinese version of README


#### Bug Fixes

- Updated paths for docker-compose.


## [0.0.20](https://github.com/SuperDuperDB/superduperdb/compare/0.0.10...0.0.20]) (2023-Dec-04)

#### Changed defaults / behaviours

- Chop down large files from the history to reduce the size of the repo.

#### New Features & Functionality

- ...


#### Bug Fixes

- ...


## [0.0.19](https://github.com/SuperDuperDB/superduperdb/compare/0.0.15...0.0.19]) (2023-Dec-04)

Expand All @@ -74,7 +63,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add a `Component.post_create` hook to enable logic to incorporate model versions
- Fix multiple issues with `ibis`/ SQL code


#### New Features & Functionality

- Add support for selecting whether logs will be redirected to the system output or directly to Loki
Expand Down
2,821 changes: 39 additions & 2,782 deletions examples/question_the_docs.ipynb

Large diffs are not rendered by default.

13 changes: 5 additions & 8 deletions superduperdb/backends/ibis/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,24 +527,21 @@ class QueryType(str, enum.Enum):
ATTR = 'attr'


@dc.dataclass(repr=False)
@dc.dataclass(repr=False, kw_only=True)
class Table(Component):
"""
This is a representation of an SQL table in ibis,
saving the important meta-data associated with the table
in the ``superduperdb`` meta-data store.

:param identifier: The name of the table
:param schema: The schema of the table
{component_params}:param schema: The schema of the table
:param primary_id: The primary id of the table
:param version: The version of the table
"""

identifier: str
type_id: t.ClassVar[str] = 'table'
__doc__ = __doc__.format(component_params=Component.__doc__)

schema: Schema
primary_id: str = 'id'
version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'table'

def pre_create(self, db: 'Datalayer'):
assert self.schema is not None, "Schema must be set"
Expand Down
3 changes: 2 additions & 1 deletion superduperdb/base/datalayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,7 +976,8 @@ def _remove_component_version(

if hasattr(component, 'artifact_attributes'):
for a in component.artifact_attributes:
self.artifact_store.delete(info['dict'][a]['file_id'])
if info['dict'][a] is not None:
self.artifact_store.delete(info['dict'][a]['file_id'])
self.metadata.delete_component_version(type_id, identifier, version=version)

def _download_content( # TODO: duplicated function
Expand Down
10 changes: 9 additions & 1 deletion superduperdb/base/serializable.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ def _deserialize(r: t.Any, db: None = None) -> t.Any:
if 'db' in inspect.signature(component_cls.__init__).parameters:
kwargs.update(db=db)

return component_cls(**kwargs)
instance = component_cls(**{k: v for k, v in kwargs.items() if k != 'version'})

# special handling of Component.version
from superduperdb.components.component import Component

if issubclass(component_cls, Component):
instance.version = r.get('version', None)

return instance


def _serialize(item: t.Any) -> t.Dict[str, t.Any]:
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/cli/stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@


@command(help='Apply the stack tarball to the database')
def apply(yaml_path: str):
def apply(yaml_path: str, identifier: str):
db = build_datalayer()
stack = Stack()
stack = Stack(identifier=identifier)
stack.load(yaml_path)
db.add(stack)
14 changes: 8 additions & 6 deletions superduperdb/components/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from __future__ import annotations

import dataclasses as dc
import typing as t

from superduperdb.backends.base.artifact import ArtifactStore
Expand All @@ -16,16 +17,18 @@
from superduperdb.components.dataset import Dataset


@dc.dataclass
class Component(Serializable):
"""
Base component which model, listeners, learning tasks etc. inherit from.
"""
:param identifier: A unique identifier for the component"""

type_id: t.ClassVar[str]
identifier: str

if t.TYPE_CHECKING:
identifier: t.Optional[str]
version: t.Optional[int]
def __post_init__(self) -> None:
# set version in `__post_init__` so that is
# cannot be set in `__init__` and is always set
self.version: t.Optional[int] = None
Comment on lines +28 to +31
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add self.type_id = self.__class__.__name__ in the post_init

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need to discuss this -- can be done later. (The code is old, just moved around a bit, creating this diff.)


def pre_create(self, db: Datalayer) -> None:
"""Called the first time this component is created
Expand All @@ -41,7 +44,6 @@ def post_create(self, db: Datalayer) -> None:
:param db: the db that creates the component
"""
assert db
assert db

def on_load(self, db: Datalayer) -> None:
"""Called when this component is loaded from the data store
Expand Down
17 changes: 7 additions & 10 deletions superduperdb/components/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,27 @@


@public_api(stability='stable')
@dc.dataclass
@dc.dataclass(kw_only=True)
class Dataset(Component):
"""A dataset is an immutable collection of documents that used for training

:param identifier: A unique identifier for the dataset
"""A dataset is an immutable collection of documents.
{component_params}
:param select: A query to select the documents for the dataset
:param sample_size: The number of documents to sample from the query
:param random_seed: The random seed to use for sampling
:param creation_date: The date the dataset was created
:param raw_data: The raw data for the dataset
:param version: The version of the dataset
"""

identifier: str
__doc__ = __doc__.format(component_params=Component.__doc__)

type_id: t.ClassVar[str] = 'dataset'

select: t.Optional[Select] = None
sample_size: t.Optional[int] = None
random_seed: t.Optional[int] = None
creation_date: t.Optional[str] = None
raw_data: t.Optional[t.Union[Artifact, t.Any]] = None

# Don't set these manually
version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'dataset'

@override
def pre_create(self, db: 'Datalayer') -> None:
if self.raw_data is None:
Expand Down
15 changes: 8 additions & 7 deletions superduperdb/components/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,27 @@ def _pickle_encoder(x: t.Any) -> bytes:


@public_api(stability='stable')
@dc.dataclass
@dc.dataclass(kw_only=True)
class Encoder(Component):
"""
Storeable ``Component`` allowing byte encoding of primary data,
i.e. data inserted using ``db.base.db.Datalayer.insert``

{component_parameters}
:param identifier: Unique identifier
:param decoder: callable converting a ``bytes`` string to a ``Encodable``
of this ``Encoder``
:param encoder: Callable converting an ``Encodable`` of this ``Encoder``
to ``bytes``
:param shape: Shape of the data
:param version: Version of the encoder (don't use this)
:param load_hybrid: Whether to load the data from the URI or return the URI in
`CFG.hybrid` mode
"""

__doc__ = __doc__.format(component_parameters=Component.__doc__)

type_id: t.ClassVar[str] = 'encoder'

artifact_artibutes: t.ClassVar[t.Sequence[str]] = ['decoder', 'encoder']
identifier: str
decoder: t.Union[t.Callable, Artifact] = dc.field(
default_factory=lambda: Artifact(artifact=_pickle_decoder)
)
Expand All @@ -50,13 +52,12 @@ class Encoder(Component):
shape: t.Optional[t.Sequence] = None
load_hybrid: bool = True

# Don't set this manually
version: t.Optional[int] = None
type_id: t.ClassVar[str] = 'encoder'
# TODO what's this for?
encoders: t.ClassVar[t.List] = []

def __post_init__(self):
super().__post_init__()

self.encoders.append(self.identifier)
if isinstance(self.decoder, t.Callable):
self.decoder = Artifact(artifact=self.decoder)
Expand Down
30 changes: 15 additions & 15 deletions superduperdb/components/listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,39 @@


@public_api(stability='stable')
@dc.dataclass
@dc.dataclass(kw_only=True)
class Listener(Component):
"""
Listener object which is used to process a column/ key of a collection or table,
and store the outputs.

{component_parameters}
:param key: Key to be bound to model
:param model: Model for processing data
:param select: Object for selecting which data is processed
:param active: Toggle to ``False`` to deactivate change data triggering
:param identifier: A string used to identify the model.
:param active: Toggle to ``False`` to deactivate change data triggering
:param predict_kwargs: Keyword arguments to self.model.predict
:param version: Version number of the model(?)
"""

__doc__ = __doc__.format(component_parameters=Component.__doc__)

key: str
model: t.Union[str, Model]
select: CompoundSelect
identifier: t.Optional[str] = None # type: ignore[assignment]
active: bool = True
identifier: t.Optional[str] = None
predict_kwargs: t.Optional[t.Dict] = dc.field(default_factory=dict)

# Don't set this manually
version: t.Optional[int] = None

type_id: t.ClassVar[str] = 'listener'

def __post_init__(self):
super().__post_init__()
if self.identifier is None and self.model is not None:
if isinstance(self.model, str):
self.identifier = f'{self.model}/{self.id_key}'
else:
self.identifier = f'{self.model.identifier}/{self.id_key}'

@property
def outputs(self):
return f'{_OUTPUTS_KEY}.{self.key}.{self.model.identifier}.{self.model.version}'
Expand All @@ -56,6 +62,7 @@ def child_components(self) -> t.Sequence[t.Tuple[str, str]]:
def pre_create(self, db: Datalayer) -> None:
if isinstance(self.model, str):
self.model = t.cast(Model, db.load('model', self.model))

if self.select is not None and self.select.variables:
self.select = t.cast(CompoundSelect, self.select.set_variables(db))

Expand Down Expand Up @@ -88,13 +95,6 @@ def id_key(self) -> str:
return self.key.split('.')[1]
return self.key

def __post_init__(self):
if self.identifier is None and self.model is not None:
if isinstance(self.model, str):
self.identifier = f'{self.model}/{self.id_key}'
else:
self.identifier = f'{self.model.identifier}/{self.id_key}'

@override
def schedule_jobs(
self,
Expand Down
13 changes: 6 additions & 7 deletions superduperdb/components/metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@


@public_api(stability='beta')
@dc.dataclass
@dc.dataclass(kw_only=True)
class Metric(Component):
"""
Metric base object with which to evaluate performance on a data-set.
These objects are ``callable`` and are applied row-wise to the data, and averaged.

:param identifier: unique identifier
{component_parameters}
:param object: callable or ``Artifact`` to be applied to the data
:param version: version of the ``Metric``
"""

identifier: str
object: t.Union[Artifact, t.Callable, None] = None
version: t.Optional[int] = None
__doc__ = __doc__.format(component_parameters=Component.__doc__)

artifacts: t.ClassVar[t.List[str]] = ['object']
type_id: t.ClassVar[str] = 'metric'

object: t.Union[Artifact, t.Callable]

def __post_init__(self) -> None:
super().__post_init__()
if self.object and not isinstance(self.object, Artifact):
self.object = Artifact(artifact=self.object)

Expand Down
Loading