Commit

Fix SQL multimodal example
blythed committed Nov 29, 2023
1 parent d951858 commit 9025aa4
Showing 10 changed files with 112 additions and 4,255 deletions.
4,230 changes: 16 additions & 4,214 deletions examples/multimodal_image_search_clip.ipynb

Large diffs are not rendered by default.

24 changes: 22 additions & 2 deletions examples/sql-example.ipynb
@@ -62,6 +62,16 @@
"The initial step in any `superduperdb` workflow is to connect to your datastore. To connect to a different datastore, simply add a different `URI`, for example, `postgres://...`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "587bffd3-7202-4e12-a6de-435d4b16892d",
"metadata": {},
"outputs": [],
"source": [
"!rm .superduperdb/test.ddb"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -78,6 +88,16 @@
"db = superduper('duckdb://.superduperdb/test.ddb')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0a5252a3-91a7-4789-85b7-0b8279c228a3",
"metadata": {},
"outputs": [],
"source": [
"db.show('vector_index')"
]
},
{
"cell_type": "markdown",
"id": "b8794451",
@@ -108,7 +128,7 @@
"!mkdir -p data/coco\n",
"\n",
"# Move the 'images_small' directory to 'data/coco/images'\n",
"!mv images_small data/coco/images"
"!mv images_tiny data/coco/images"
]
},
{
@@ -224,7 +244,7 @@
"source": [
"## Build SuperDuperDB `Model` Instances\n",
"\n",
"This use-case utilizes the `superduperdb.ext.torch` extension. Both models used output `torch` tensors, which are encoded with `tensor`:"
"This use-case utilizes the `superduperdb.ext.torch` extension. Both models use `torch` tensors in their output, which are encoded with `tensor`:"
]
},
{
57 changes: 36 additions & 21 deletions examples/transfer_learning.ipynb
@@ -4,7 +4,10 @@
"cell_type": "markdown",
"id": "fe6fd0ab0e1ad844",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"# Transfer Learning with Sentence Transformers and Scikit-Learn"
@@ -14,7 +17,10 @@
"cell_type": "markdown",
"id": "8dcde44d942793ff",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Introduction\n",
@@ -26,7 +32,10 @@
"cell_type": "markdown",
"id": "1809feca8a8dca5a",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"## Prerequisites\n",
@@ -41,7 +50,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install superduperdb\n",
"# !pip install superduperdb\n",
"!pip install ipython numpy datasets sentence-transformers"
]
},
@@ -57,7 +66,10 @@
"cell_type": "markdown",
"id": "5379007991707d17",
"metadata": {
"collapsed": false
"collapsed": false,
"jupyter": {
"outputs_hidden": false
}
},
"source": [
"First, we need to establish a connection to a MongoDB datastore via SuperDuperDB. You can configure the `MongoDB_URI` based on your specific setup. \n",
@@ -84,7 +96,7 @@
"\n",
"# SuperDuperDB, now handles your MongoDB database\n",
"# It just super dupers your database\n",
"db = superduper(mongodb_uri)\n",
"db = superduper(mongodb_uri, artifact_store='filesystem://./data/')\n",
"\n",
"# Reference a collection called transfer\n",
"collection = Collection('transfer')"
@@ -169,10 +181,21 @@
" X='text',\n",
" db=db,\n",
" select=collection.find(),\n",
" listen=True\n",
" listen=True,\n",
" show_progress_bar=True,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "aed95196-4454-4d88-beb8-dc6dddfcbe33",
"metadata": {},
"outputs": [],
"source": [
"db.execute(collection.find_one())"
]
},
{
"cell_type": "markdown",
"id": "68fefc17",
@@ -201,10 +224,10 @@
"\n",
"# Train the model on 'text' data with corresponding labels\n",
"model.fit(\n",
" X='text',\n",
" X='_outputs.text.all-MiniLM-L6-v2.0',\n",
" y='label',\n",
" db=db,\n",
" select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),\n",
" select=collection.find(),\n",
")\n"
]
},
@@ -227,9 +250,9 @@
"source": [
"# Make predictions on 'text' data with the trained SuperDuperDB model\n",
"model.predict(\n",
" X='text',\n",
" X='_outputs.text.all-MiniLM-L6-v2.0',\n",
" db=db,\n",
" select=collection.find().featurize({'text': 'all-MiniLM-L6-v2'}),\n",
" select=collection.find(),\n",
" listen=True,\n",
")"
]
@@ -258,16 +281,8 @@
"print(r['text'][:100])\n",
"\n",
"# Print the prediction made by the SVC model stored in '_outputs'\n",
"print(r['_outputs']['text']['svc'])"
"print(r['_outputs']['text']['svc']['0'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2174a0a3-c32e-4481-a301-370b569ba30c",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
@@ -286,7 +301,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.6"
}
},
"nbformat": 4,
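The retraining keys in this file (`_outputs.text.all-MiniLM-L6-v2.0`, `_outputs.text.svc.0`) follow a per-key, per-model, per-version layout. A minimal sketch of the assumed document shape after the listener has run (illustrative values, not the exact SuperDuperDB schema):

```python
# Assumed layout: model outputs nested under _outputs.<key>.<model>.<version>,
# which is why the notebook now trains on '_outputs.text.all-MiniLM-L6-v2.0'
# instead of featurizing 'text' on the fly.
doc = {
    'text': 'an example review',
    'label': 1,
    '_outputs': {
        'text': {
            'all-MiniLM-L6-v2': {'0': [0.12, -0.53, 0.88]},  # embedding (illustrative)
            'svc': {'0': 1},                                 # classifier prediction
        }
    },
}

def get_output(doc: dict, key: str, model: str, version: int):
    """Resolve a versioned model output from a document (sketch)."""
    return doc['_outputs'][key][model][str(version)]

assert get_output(doc, 'text', 'svc', 0) == 1
```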
16 changes: 8 additions & 8 deletions examples/voice_memos.ipynb
@@ -158,15 +158,15 @@
"db.execute(voice_collection.insert_many([\n",
" # Create a SuperDuperDB Document for each audio sample\n",
" Document({'audio': enc(r['audio']['array'])}) for r in data\n",
"]))\n"
"]))"
]
},
{
"cell_type": "markdown",
"id": "721f31f4626881e0",
"metadata": {},
"source": [
"## Install Pre-Trained Model (LibriSpeech) into Database\n",
"## Install Pre-Trained Model (LibriSpeech) with Database\n",
"\n",
"Apply a pre-trained `transformers` model to the data:"
]
@@ -225,7 +225,7 @@
" db=db, # Provide the SuperDuperDB instance\n",
" select=voice_collection.find(), # Specify the collection of audio data to transcribe\n",
" max_chunk_size=10 # Set the maximum chunk size for processing audio data\n",
")\n"
")"
]
},
{
@@ -256,7 +256,7 @@
" identifier='my-index', # Set a unique identifier for the VectorIndex\n",
" indexing_listener=Listener(\n",
" model=OpenAIEmbedding(model='text-embedding-ada-002'), # Use OpenAIEmbedding for audio transcriptions\n",
" key='_outputs.audio.transcription', # Specify the key for indexing the transcriptions in the output\n",
" key='_outputs.audio.transcription.0', # Specify the key for indexing the transcriptions in the output\n",
" select=voice_collection.find(), # Select the collection of audio data to index\n",
" ),\n",
" )\n",
Expand Down Expand Up @@ -290,7 +290,7 @@
"search_results = list(\n",
" db.execute(\n",
" voice_collection.like(\n",
" {'_outputs.audio.transcription': search_term},\n",
" {'_outputs.audio.transcription.0': search_term},\n",
" n=num_results,\n",
" vector_index='my-index', # Use the 'my-index' VectorIndex for similarity search\n",
" ).find({}, {'_outputs.audio.transcription': 1}) # Retrieve only the 'transcription' field in the results\n",
@@ -370,11 +370,11 @@
" \n",
" # Select relevant context for the model from the SuperDuperDB collection of audio transcriptions\n",
" context_select=voice_collection.like(\n",
" Document({'_outputs.audio.transcription': question}), vector_index='my-index'\n",
" Document({'_outputs.audio.transcription.0': question}), vector_index='my-index'\n",
" ).find(),\n",
" \n",
" # Specify the key in the context used by the model\n",
" context_key='_outputs.audio.transcription',\n",
" context_key='_outputs.audio.transcription.0',\n",
")[0].content\n",
"\n",
"# Print the response obtained from the chat completion model\n",
@@ -398,7 +398,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
"version": "3.11.6"
}
},
"nbformat": 4,
7 changes: 5 additions & 2 deletions superduperdb/base/datalayer.py
@@ -168,7 +168,7 @@ def backfill_vector_search(self, vi, searcher):

id = record[self.databackend.id_field]
assert not isinstance(vi.indexing_listener.model, str)
h = record.outputs(key, vi.indexing_listener.model.identifier)
h = record.outputs(
key,
vi.indexing_listener.model.identifier,
version=vi.indexing_listener.model.version,
)
if isinstance(h, Encodable):
h = h.x
items.append(VectorItem.create(id=str(id), vector=h))
@@ -915,7 +919,6 @@ def _add(
object.on_load(self)
return object.schedule_jobs(self, dependencies=dependencies), object
except Exception as e:

raise exceptions.DatalayerException(
f'Error while adding object with id: {object.identifier}'
) from e
6 changes: 4 additions & 2 deletions superduperdb/base/document.py
@@ -7,6 +7,7 @@
from superduperdb.components.encoder import Encodable, Encoder
from superduperdb.components.schema import Schema
from superduperdb.misc.files import get_file_from_uri
from superduperdb.misc.special_dicts import MongoStyleDict

ContentType = t.Union[t.Dict, Encodable]
ItemType = t.Union[t.Dict[str, t.Any], Encodable, ObjectId]
@@ -46,10 +47,11 @@ def outputs(self, key: str, model: str, version: t.Optional[int] = None) -> t.An
:param key: Document key to get outputs from.
:param model: Model name to get outputs from.
"""
r = MongoStyleDict(self.unpack())
if version is not None:
document = self.unpack()[_OUTPUTS_KEY][key][model][version]
document = r[f'{_OUTPUTS_KEY}.{key}.{model}.{version}']
else:
tmp = self.unpack()[_OUTPUTS_KEY][key][model]
tmp = r[f'{_OUTPUTS_KEY}.{key}.{model}']
version = max(list(tmp.keys()))
return tmp[version]
return document
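The rewritten `Document.outputs` above swaps chained indexing for a single dotted lookup on a `MongoStyleDict`. A rough sketch of the dotted-key semantics being relied on (assuming keys split on `.`; `DottedDict` is a stand-in, not the real class):

```python
# Minimal dotted-key reader showing that r['_outputs.text.svc.0'] and
# r['_outputs']['text']['svc']['0'] resolve to the same value.
class DottedDict(dict):
    def __getitem__(self, key):
        if '.' not in key:
            return super().__getitem__(key)
        head, rest = key.split('.', 1)
        # Recurse into the child dict with the remainder of the path.
        return DottedDict(super().__getitem__(head))[rest]

r = DottedDict({'_outputs': {'text': {'svc': {'0': 'positive'}}}})
assert r['_outputs.text.svc.0'] == r['_outputs']['text']['svc']['0'] == 'positive'
```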
20 changes: 16 additions & 4 deletions superduperdb/components/model.py
@@ -34,6 +34,15 @@
ObjectsArg = t.Sequence[t.Union[t.Any, Artifact]]


class _to_call:
def __init__(self, callable, **kwargs):
self.callable = callable
self.kwargs = kwargs

def __call__(self, X):
return self.callable(X, **self.kwargs)


@dc.dataclass
class _TrainingConfiguration(Component):
"""
@@ -133,20 +142,23 @@ def _predict_one(self, X: t.Any, **kwargs) -> int:

return output

def _forward(self, X: t.Sequence[int], num_workers: int = 0) -> t.Sequence[int]:
def _forward(
self, X: t.Sequence[int], num_workers: int = 0, **kwargs
) -> t.Sequence[int]:
if self.batch_predict:
return self.to_call(X)
return self.to_call(X, **kwargs)

outputs = []
if num_workers:
to_call = _to_call(self.to_call, **kwargs)
pool = multiprocessing.Pool(processes=num_workers)
for r in pool.map(self.to_call, X):
for r in pool.map(to_call, X):
outputs.append(r)
pool.close()
pool.join()
else:
for r in X:
outputs.append(self.to_call(r))
outputs.append(self.to_call(r, **kwargs))
return outputs

def _predict(
1 change: 1 addition & 0 deletions superduperdb/misc/download.py
@@ -379,6 +379,7 @@ def download_content(
pass

if CFG.hybrid_storage:
assert isinstance(CFG.downloads_folder, str)
_download_update = SaveFile(CFG.downloads_folder)
else:

5 changes: 4 additions & 1 deletion superduperdb/misc/special_dicts.py
@@ -40,6 +40,9 @@ def __setitem__(self, key: str, value: t.Any) -> None:
else:
parent = key.split('.')[0]
child = '.'.join(key.split('.')[1:])
parent_item = MongoStyleDict(self[parent])
try:
parent_item = MongoStyleDict(self[parent])
except KeyError:
parent_item = MongoStyleDict({})
parent_item[child] = value
self[parent] = parent_item
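With the `try`/`except` added above, assigning to a dotted key whose parent does not yet exist creates the parent instead of raising `KeyError`. A self-contained sketch of the fixed behaviour (mirroring the diff hunk, not the full class):

```python
import typing as t

class MongoStyleDict(dict):
    """Dict accepting dotted keys on assignment (sketch of the fixed method)."""
    def __setitem__(self, key: str, value: t.Any) -> None:
        if '.' not in key:
            super().__setitem__(key, value)
        else:
            parent = key.split('.')[0]
            child = '.'.join(key.split('.')[1:])
            try:
                parent_item = MongoStyleDict(self[parent])
            except KeyError:  # the fix: missing parents are now created
                parent_item = MongoStyleDict({})
            parent_item[child] = value
            self[parent] = parent_item

d = MongoStyleDict({})
d['a.b.c'] = 1  # previously raised KeyError('a')
assert d == {'a': {'b': {'c': 1}}}
```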
1 change: 0 additions & 1 deletion test/unittest/misc/test_downloaders.py
@@ -20,7 +20,6 @@ def test_s3_and_web():

@pytest.fixture
def patch_cfg_downloads(monkeypatch):
monkeypatch.setattr(CFG, 'hybrid_storage', True)
td = str(uuid.uuid4())
with tempfile.TemporaryDirectory() as td:
monkeypatch.setattr(CFG, 'downloads_folder', td)
