Skip to content

Commit

Permalink
Document features (#1631)
Browse files Browse the repository at this point in the history
  • Loading branch information
fazlulkarimweb authored Jan 2, 2024
1 parent 68ddb39 commit 90cdb4f
Show file tree
Hide file tree
Showing 3 changed files with 85 additions and 21 deletions.
48 changes: 33 additions & 15 deletions superduperdb/base/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from bson.objectid import ObjectId

from superduperdb import CFG
from superduperdb.base.config import BytesEncoding
from superduperdb.components.encoder import Encodable, Encoder
from superduperdb.components.schema import Schema
from superduperdb.misc.files import get_file_from_uri
Expand Down Expand Up @@ -34,11 +35,16 @@ def dump_bson(self) -> bytes:
"""Dump this document into BSON and encode as bytes"""
return bson.encode(self.encode())

def encode(self, schema: t.Optional[Schema] = None) -> t.Any:
def encode(
self,
schema: t.Optional[Schema] = None,
bytes_encoding: t.Optional[BytesEncoding] = None,
) -> t.Any:
"""Make a copy of the content with all the Encodables encoded"""
bytes_encoding = bytes_encoding or CFG.bytes_encoding
if schema is not None:
return _encode_with_schema(self.content, schema)
return _encode(self.content)
return _encode_with_schema(self.content, schema, bytes_encoding)
return _encode(self.content, bytes_encoding)

@property
def variables(self) -> t.List[str]:
Expand Down Expand Up @@ -71,11 +77,15 @@ def outputs(self, key: str, model: str, version: t.Optional[int] = None) -> t.An
return document

@staticmethod
def decode(r: t.Dict, encoders: t.Dict) -> t.Any:
def decode(
r: t.Dict, encoders: t.Dict, bytes_encoding: t.Optional[BytesEncoding] = None
) -> t.Any:
bytes_encoding = bytes_encoding or CFG.bytes_encoding

if isinstance(r, Document):
return Document(_decode(r, encoders))
return Document(_decode(r, encoders, bytes_encoding))
elif isinstance(r, dict):
return _decode(r, encoders)
return _decode(r, encoders, bytes_encoding)
raise NotImplementedError(f'type {type(r)} is not supported')

def __repr__(self) -> str:
Expand Down Expand Up @@ -123,7 +133,10 @@ def load_bsons(content: t.ByteString, encoders: t.Dict) -> t.List[Document]:
return [Document(Document.decode(r, encoders=encoders)) for r in documents]


def _decode(r: t.Dict, encoders: t.Dict) -> t.Any:
def _decode(
r: t.Dict, encoders: t.Dict, bytes_encoding: t.Optional[BytesEncoding] = None
) -> t.Any:
bytes_encoding = bytes_encoding or CFG.bytes_encoding
if isinstance(r, dict) and '_content' in r:
encoder = encoders[r['_content']['encoder']]
try:
Expand All @@ -137,31 +150,36 @@ def _decode(r: t.Dict, encoders: t.Dict) -> t.Any:
elif isinstance(r, dict):
for k in r:
if k in encoders:
r[k] = encoders[k].decode(r[k]).x
r[k] = encoders[k].decode(r[k], bytes_encoding).x
else:
r[k] = _decode(r[k], encoders)
r[k] = _decode(r[k], encoders, bytes_encoding)
return r


def _encode(r: t.Any) -> t.Any:
def _encode(r: t.Any, bytes_encoding: t.Optional[BytesEncoding] = None) -> t.Any:
bytes_encoding = bytes_encoding or CFG.bytes_encoding

if isinstance(r, dict):
return {k: _encode(v) for k, v in r.items()}
return {k: _encode(v, bytes_encoding) for k, v in r.items()}
if isinstance(r, Encodable):
return r.encode()
return r.encode(bytes_encoding=bytes_encoding)
return r


def _encode_with_schema(r: t.Any, schema: Schema) -> t.Any:
def _encode_with_schema(
r: t.Any, schema: Schema, bytes_encoding: t.Optional[BytesEncoding] = None
) -> t.Any:
bytes_encoding = bytes_encoding or CFG.bytes_encoding
if isinstance(r, dict):
out = {
k: schema.fields[k].encode(v, wrap=False) # type: ignore[call-arg]
if isinstance(schema.fields[k], Encoder)
else _encode_with_schema(v, schema)
else _encode_with_schema(v, schema, bytes_encoding)
for k, v in r.items()
}
return out
if isinstance(r, Encodable):
return r.encode()
return r.encode(bytes_encoding=bytes_encoding)
return r


Expand Down
26 changes: 21 additions & 5 deletions superduperdb/components/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,16 @@ def __call__(
) -> 'Encodable':
return Encodable(self, x=x, uri=uri)

def decode(self, b: t.Union[bytes, str]) -> t.Any:
def decode(
self, b: t.Union[bytes, str], bytes_encoding: t.Optional[BytesEncoding] = None
) -> t.Any:
assert isinstance(self.decoder, Artifact)
if CFG.bytes_encoding == BytesEncoding.BASE64:
bytes_encoding = bytes_encoding or CFG.bytes_encoding

if (
CFG.bytes_encoding == BytesEncoding.BASE64
or bytes_encoding == BytesEncoding.BASE64
):
b = self.from_base64(b)
return self(self.decoder.artifact(b))

Expand All @@ -94,12 +101,16 @@ def encode(
x: t.Optional[t.Any] = None,
uri: t.Optional[str] = None,
wrap: bool = True,
bytes_encoding: t.Optional[BytesEncoding] = None,
) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]:
# TODO clarify what is going on here

def _encode(x):
bytes_ = self.encoder.artifact(x)
if CFG.bytes_encoding == BytesEncoding.BASE64:
if (
CFG.bytes_encoding == BytesEncoding.BASE64
or bytes_encoding == BytesEncoding.BASE64
):
bytes_ = self.to_base64(bytes_)
return bytes_

Expand Down Expand Up @@ -145,9 +156,14 @@ class Encodable:
x: t.Optional[t.Any] = None
uri: t.Optional[str] = None

def encode(self) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]:
def encode(
self, bytes_encoding: t.Optional[BytesEncoding] = None
) -> t.Union[t.Optional[str], t.Dict[str, t.Any]]:
bytes_encoding = bytes_encoding or CFG.bytes_encoding
assert hasattr(self.encoder, 'encode')
return self.encoder.encode(x=self.x, uri=self.uri)
return self.encoder.encode(
x=self.x, uri=self.uri, bytes_encoding=bytes_encoding
)


default_encoder = Encoder(identifier='_default')
32 changes: 31 additions & 1 deletion test/unittest/backends/mongodb/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def test_replace(db):


@pytest.mark.skipif(not torch, reason='Torch not installed')
@pytest.mark.parametrize("db", [DBConfig.mongodb_empty], indirect=True)
@pytest.mark.parametrize('db', [DBConfig.mongodb_empty], indirect=True)
def test_insert_from_uris(db, image_url):
import PIL

Expand Down Expand Up @@ -100,6 +100,36 @@ def test_insert_from_uris(db, image_url):
assert isinstance(r['other']['item'].x, PIL.Image.Image)


@pytest.mark.skipif(not torch, reason='Torch not installed')
@pytest.mark.parametrize('db', [DBConfig.mongodb_empty], indirect=True)
def test_insert_from_uris_bytes_encoding(db, image_url):
import PIL

from superduperdb.base.config import BytesEncoding
from superduperdb.ext.pillow.encoder import pil_image

db.add(pil_image)

if image_url.startswith('file://'):
image_url = image_url[7:]

collection = Collection('documents')
to_insert = [
Document(
{
'img': pil_image(PIL.Image.open(image_url)).encode(
bytes_encoding=BytesEncoding.BASE64
)
}
)
]

db.execute(collection.insert_many(to_insert))

r = db.execute(collection.find_one())
assert isinstance(r['img'].x, PIL.Image.Image)


@pytest.mark.skipif(not torch, reason='Torch not installed')
def test_update_many(db):
collection = Collection('documents')
Expand Down

0 comments on commit 90cdb4f

Please sign in to comment.