Skip to content

Commit

Permalink
Add encoders in openai model
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Dec 2, 2023
1 parent 78f8255 commit 2feb387
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
33 changes: 22 additions & 11 deletions superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import superduperdb as s
from superduperdb import logging
from superduperdb.backends.base.backends import data_backends, metadata_stores
from superduperdb.backends.base.data_backend import BaseDataBackend
from superduperdb.backends.dask.compute import DaskComputeBackend
from superduperdb.backends.local.artifacts import FileSystemArtifactStore
from superduperdb.backends.local.compute import LocalComputeBackend
Expand Down Expand Up @@ -51,7 +52,7 @@ def build(uri, mapping, type: str = 'data_backend'):

if re.match('^mongodb:\/\/', uri) is not None:
name = uri.split('/')[-1]
conn = pymongo.MongoClient(
conn: pymongo.MongoClient = pymongo.MongoClient(
uri,
serverSelectionTimeoutMS=5000,
)
Expand All @@ -74,21 +75,27 @@ def build(uri, mapping, type: str = 'data_backend'):
raise ValueError('Cannot build metadata from a CSV file.')

import glob
csv_files = glob.glob(uri)
tables = {
re.match('^.*/(.*)\.csv$', csv_file).groups()[0]: pandas.read_csv(csv_file)
for csv_file in csv_files
}
conn = ibis.pandas.connect(tables)
return mapping['ibis'](conn, uri.split('/')[0])

tables = {}

for csv_file in glob.glob(uri):
match = re.match('^.*/(.*)\.csv$', csv_file)
if match:
table = match.groups()[0]
df = pandas.read_csv(csv_file)
tables.update({table: df})

ibis_conn = ibis.pandas.connect(tables)
return mapping['ibis'](ibis_conn, uri.split('/')[0])
else:
name = uri.split('//')[0]
if type == 'data_backend':
conn = ibis.connect(uri)
return mapping['ibis'](conn, name)
ibis_conn = ibis.connect(uri)
return mapping['ibis'](ibis_conn, name)
else:
assert type == 'metadata'
from sqlalchemy import create_engine

conn = create_engine(uri)
return mapping['sqlalchemy'](conn, name)

Expand All @@ -107,7 +114,9 @@ def build_compute(compute):
return LocalComputeBackend()


def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer:
def build_datalayer(
cfg=None, databackend: t.Optional[BaseDataBackend] = None, **kwargs
) -> Datalayer:
"""
Build a Datalayer object as per ``db = superduper(db)`` from configuration.
Expand All @@ -130,6 +139,8 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer:
try:
if not databackend:
databackend = build(cfg.data_backend, data_backends)

assert isinstance(databackend, BaseDataBackend)
logging.info("Data Client is ready.", databackend.conn)
except Exception as e:
# Exit quickly if a connection fails.
Expand Down
23 changes: 22 additions & 1 deletion superduperdb/ext/openai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,13 @@
RateLimitError,
)

from superduperdb.backends.ibis.field_types import FieldType, dtype
from superduperdb.base.datalayer import Datalayer
from superduperdb.components.component import Component
from superduperdb.components.encoder import Encoder
from superduperdb.components.model import Predictor
from superduperdb.components.vector_index import vector
from superduperdb.ext.pillow import pil_image
from superduperdb.misc.compat import cache
from superduperdb.misc.retry import Retry

Expand All @@ -47,7 +49,7 @@ class OpenAI(Component, Predictor):
identifier: str = ''
version: t.Optional[int] = None
takes_context: bool = False
encoder: t.Union[Encoder, str, None] = None
encoder: t.Union[FieldType, Encoder, str, None] = None
model_update_kwargs: dict = dc.field(default_factory=dict)

@property
Expand Down Expand Up @@ -151,6 +153,10 @@ class OpenAIChatCompletion(OpenAI):
takes_context: bool = True
prompt: str = ''

def __post_init__(self):
    """Fill in a default encoder when the caller did not provide one.

    An encoder that is already set is left untouched; otherwise the
    field type returned by ``dtype('str')`` is used.
    """
    if self.encoder is not None:
        return
    self.encoder = dtype('str')

def _format_prompt(self, context, X):
    """Build the final prompt from the template, the context and the input.

    The items of ``context`` are joined with newlines and substituted for
    the ``{context}`` placeholder of ``self.prompt``; ``X`` is then appended
    to the result.

    :param context: Iterable of strings providing the context.
    :param X: String appended after the formatted template.
    """
    joined_context = '\n'.join(context)
    return self.prompt.format(context=joined_context) + X
Expand Down Expand Up @@ -215,6 +221,9 @@ class OpenAIImageCreation(OpenAI):
takes_context: bool = True
prompt: str = ''

def __post_init__(self):
    """Fall back to the ``pil_image`` encoder when none was supplied.

    The encoder is only defaulted when it is ``None``, matching the other
    OpenAI model classes in this module; an encoder passed explicitly by
    the caller is preserved instead of being silently overwritten.
    """
    if self.encoder is None:
        self.encoder = pil_image

def _format_prompt(self, context, X):
    """Return the prompt template filled with the context, followed by ``X``.

    ``context`` is joined with newline characters and inserted at the
    ``{context}`` placeholder of ``self.prompt``.

    :param context: Iterable of strings providing the context.
    :param X: String appended after the formatted template.
    """
    context_block = '\n'.join(context)
    filled_template = self.prompt.format(context=context_block)
    return filled_template + X
Expand Down Expand Up @@ -316,6 +325,10 @@ class OpenAIImageEdit(OpenAI):
takes_context: bool = True
prompt: str = ''

def __post_init__(self):
    """Use ``pil_image`` as the encoder unless one is already set."""
    if self.encoder is not None:
        return
    self.encoder = pil_image

def _format_prompt(self, context):
    """Return ``self.prompt`` with ``{context}`` replaced by the joined context.

    :param context: Iterable of strings; joined with newlines before
        being substituted into the template.
    """
    return self.prompt.format(context='\n'.join(context))
Expand Down Expand Up @@ -454,6 +467,10 @@ class OpenAIAudioTranscription(OpenAI):
takes_context: bool = True
prompt: str = ''

def __post_init__(self):
    """Default the encoder to ``dtype('str')`` when none was given.

    A caller-supplied encoder is kept as is.
    """
    if self.encoder is not None:
        return
    self.encoder = dtype('str')

@retry
def _predict_one(
self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs
Expand Down Expand Up @@ -549,6 +566,10 @@ class OpenAIAudioTranslation(OpenAI):
takes_context: bool = True
prompt: str = ''

def __post_init__(self):
    """Assign the ``dtype('str')`` field type if no encoder is configured.

    Nothing happens when ``self.encoder`` already holds a value.
    """
    if self.encoder is not None:
        return
    self.encoder = dtype('str')

@retry
def _predict_one(
self, file: t.BinaryIO, context: t.Optional[t.List[str]] = None, **kwargs
Expand Down

0 comments on commit 2feb387

Please sign in to comment.