From 90cdb4fdfcbf61d9e4ece56696ee6906ac111956 Mon Sep 17 00:00:00 2001 From: Md Fazlul Karim <26186000+fazlulkarimweb@users.noreply.github.com> Date: Tue, 2 Jan 2024 22:32:55 +0600 Subject: [PATCH] Document features (#1631) --- superduperdb/base/document.py | 48 +++++++++++++------ superduperdb/components/encoder.py | 26 ++++++++-- .../unittest/backends/mongodb/test_queries.py | 32 ++++++++++++- 3 files changed, 85 insertions(+), 21 deletions(-) diff --git a/superduperdb/base/document.py b/superduperdb/base/document.py index e8c8469066..caee2cd5a2 100644 --- a/superduperdb/base/document.py +++ b/superduperdb/base/document.py @@ -4,6 +4,7 @@ from bson.objectid import ObjectId from superduperdb import CFG +from superduperdb.base.config import BytesEncoding from superduperdb.components.encoder import Encodable, Encoder from superduperdb.components.schema import Schema from superduperdb.misc.files import get_file_from_uri @@ -34,11 +35,16 @@ def dump_bson(self) -> bytes: """Dump this document into BSON and encode as bytes""" return bson.encode(self.encode()) - def encode(self, schema: t.Optional[Schema] = None) -> t.Any: + def encode( + self, + schema: t.Optional[Schema] = None, + bytes_encoding: t.Optional[BytesEncoding] = None, + ) -> t.Any: """Make a copy of the content with all the Encodables encoded""" + bytes_encoding = bytes_encoding or CFG.bytes_encoding if schema is not None: - return _encode_with_schema(self.content, schema) - return _encode(self.content) + return _encode_with_schema(self.content, schema, bytes_encoding) + return _encode(self.content, bytes_encoding) @property def variables(self) -> t.List[str]: @@ -71,11 +77,15 @@ def outputs(self, key: str, model: str, version: t.Optional[int] = None) -> t.An return document @staticmethod - def decode(r: t.Dict, encoders: t.Dict) -> t.Any: + def decode( + r: t.Dict, encoders: t.Dict, bytes_encoding: t.Optional[BytesEncoding] = None + ) -> t.Any: + bytes_encoding = bytes_encoding or CFG.bytes_encoding + if isinstance(r, Document): - return Document(_decode(r, encoders)) + return Document(_decode(r, encoders, bytes_encoding)) elif isinstance(r, dict): - return _decode(r, encoders) + return _decode(r, encoders, bytes_encoding) raise NotImplementedError(f'type {type(r)} is not supported') def __repr__(self) -> str: @@ -123,7 +133,10 @@ def load_bsons(content: t.ByteString, encoders: t.Dict) -> t.List[Document]: return [Document(Document.decode(r, encoders=encoders)) for r in documents] -def _decode(r: t.Dict, encoders: t.Dict) -> t.Any: +def _decode( + r: t.Dict, encoders: t.Dict, bytes_encoding: t.Optional[BytesEncoding] = None +) -> t.Any: + bytes_encoding = bytes_encoding or CFG.bytes_encoding if isinstance(r, dict) and '_content' in r: encoder = encoders[r['_content']['encoder']] try: @@ -137,31 +150,36 @@ def _decode(r: t.Dict, encoders: t.Dict) -> t.Any: elif isinstance(r, dict): for k in r: if k in encoders: - r[k] = encoders[k].decode(r[k]).x + r[k] = encoders[k].decode(r[k], bytes_encoding).x else: - r[k] = _decode(r[k], encoders) + r[k] = _decode(r[k], encoders, bytes_encoding) return r -def _encode(r: t.Any) -> t.Any: +def _encode(r: t.Any, bytes_encoding: t.Optional[BytesEncoding] = None) -> t.Any: + bytes_encoding = bytes_encoding or CFG.bytes_encoding + if isinstance(r, dict): - return {k: _encode(v) for k, v in r.items()} + return {k: _encode(v, bytes_encoding) for k, v in r.items()} if isinstance(r, Encodable): - return r.encode() + return r.encode(bytes_encoding=bytes_encoding) return r -def _encode_with_schema(r: t.Any, schema: Schema) -> t.Any: +def _encode_with_schema( + r: t.Any, schema: Schema, bytes_encoding: t.Optional[BytesEncoding] = None +) -> t.Any: + bytes_encoding = bytes_encoding or CFG.bytes_encoding if isinstance(r, dict): out = { k: schema.fields[k].encode(v, wrap=False) # type: ignore[call-arg] if isinstance(schema.fields[k], Encoder) - else _encode_with_schema(v, schema) + else _encode_with_schema(v, schema, bytes_encoding) for k, v in r.items() } return out if isinstance(r, Encodable): - return r.encode() + return r.encode(bytes_encoding=bytes_encoding) return r diff --git a/superduperdb/components/encoder.py b/superduperdb/components/encoder.py index 24d1cc264a..d923c32c6c 100644 --- a/superduperdb/components/encoder.py +++ b/superduperdb/components/encoder.py @@ -72,9 +72,16 @@ def __call__( ) -> 'Encodable': return Encodable(self, x=x, uri=uri) - def decode(self, b: t.Union[bytes, str]) -> t.Any: + def decode( + self, b: t.Union[bytes, str], bytes_encoding: t.Optional[BytesEncoding] = None + ) -> t.Any: assert isinstance(self.decoder, Artifact) - if CFG.bytes_encoding == BytesEncoding.BASE64: + bytes_encoding = bytes_encoding or CFG.bytes_encoding + + if ( + CFG.bytes_encoding == BytesEncoding.BASE64 + or bytes_encoding == BytesEncoding.BASE64 + ): b = self.from_base64(b) return self(self.decoder.artifact(b)) @@ -94,12 +101,16 @@ def encode( x: t.Optional[t.Any] = None, uri: t.Optional[str] = None, wrap: bool = True, + bytes_encoding: t.Optional[BytesEncoding] = None, ) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]: # TODO clarify what is going on here def _encode(x): bytes_ = self.encoder.artifact(x) - if CFG.bytes_encoding == BytesEncoding.BASE64: + if ( + CFG.bytes_encoding == BytesEncoding.BASE64 + or bytes_encoding == BytesEncoding.BASE64 + ): bytes_ = self.to_base64(bytes_) return bytes_ @@ -145,9 +156,14 @@ class Encodable: x: t.Optional[t.Any] = None uri: t.Optional[str] = None - def encode(self) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]: + def encode( + self, bytes_encoding: t.Optional[BytesEncoding] = None + ) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]: + bytes_encoding = bytes_encoding or CFG.bytes_encoding assert hasattr(self.encoder, 'encode') - return self.encoder.encode(x=self.x, uri=self.uri) + return self.encoder.encode( + x=self.x, uri=self.uri, bytes_encoding=bytes_encoding + ) default_encoder = Encoder(identifier='_default') diff --git a/test/unittest/backends/mongodb/test_queries.py b/test/unittest/backends/mongodb/test_queries.py index 1161bf8c12..0ee28a1334 100644 --- a/test/unittest/backends/mongodb/test_queries.py +++ b/test/unittest/backends/mongodb/test_queries.py @@ -64,7 +64,7 @@ def test_replace(db): @pytest.mark.skipif(not torch, reason='Torch not installed') -@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True) +@pytest.mark.parametrize('db', [DBConfig.mongodb_empty], indirect=True) def test_insert_from_uris(db, image_url): import PIL @@ -100,6 +100,36 @@ def test_insert_from_uris(db, image_url): assert isinstance(r['other']['item'].x, PIL.Image.Image) +@pytest.mark.skipif(not torch, reason='Torch not installed') +@pytest.mark.parametrize('db', [DBConfig.mongodb_empty], indirect=True) +def test_insert_from_uris_bytes_encoding(db, image_url): + import PIL + + from superduperdb.base.config import BytesEncoding + from superduperdb.ext.pillow.encoder import pil_image + + db.add(pil_image) + + if image_url.startswith('file://'): + image_url = image_url[7:] + + collection = Collection('documents') + to_insert = [ + Document( + { + 'img': pil_image(PIL.Image.open(image_url)).encode( + bytes_encoding=BytesEncoding.BASE64 + ) + } + ) + ] + + db.execute(collection.insert_many(to_insert)) + + r = db.execute(collection.find_one()) + assert isinstance(r['img'].x, PIL.Image.Image) + + @pytest.mark.skipif(not torch, reason='Torch not installed') def test_update_many(db): collection = Collection('documents')