Skip to content

Commit

Permalink
Add UT for llm api: OpenAI, vLLM
Browse files Browse the repository at this point in the history
  • Loading branch information
jieguangzhou committed Dec 29, 2023
1 parent de6ed6c commit 2a90a24
Show file tree
Hide file tree
Showing 14 changed files with 3,634 additions and 4 deletions.
22 changes: 22 additions & 0 deletions superduperdb/ext/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import asyncio
import dataclasses as dc
import inspect
import typing
from functools import wraps
from logging import WARNING, getLogger
from typing import Any, Callable, List, Optional, Union
Expand All @@ -11,6 +12,9 @@
from superduperdb.components.model import _Predictor
from superduperdb.ext.utils import format_prompt

if typing.TYPE_CHECKING:
from superduperdb.base.datalayer import Datalayer

# Disable httpx info level logging
getLogger("httpx").setLevel(WARNING)

Expand Down Expand Up @@ -51,8 +55,26 @@ class _BaseLLM(Component, _Predictor, metaclass=abc.ABCMeta):
def __post_init__(self):
    """Validate and normalise the LLM component after dataclass init.

    - Marks the model as context-taking so prompts can be enriched with context.
    - Replaces "/" in the identifier (unsafe in storage/collection keys).
    - Ensures the prompt template contains the ``{input}`` placeholder.

    Raises:
        ValueError: if ``prompt_template`` lacks the ``{input}`` placeholder.
    """
    super().__post_init__()
    self.takes_context = True
    # "/" would break identifiers later used as path-like keys downstream.
    self.identifier = self.identifier.replace("/", "-")
    # Raise instead of `assert`: assertions are stripped under `python -O`,
    # which would silently disable this validation.
    if "{input}" not in self.prompt_template:
        raise ValueError("Template must contain {input}")

def to_call(self, X, *args, **kwargs):
    """Direct-call prediction hook; concrete LLM backends must override it."""
    raise NotImplementedError

def post_create(self, db: "Datalayer") -> None:
    """Register this model's output storage with the datalayer.

    For Ibis backends without an explicit encoder, default the output
    encoder to a plain string dtype, then create (and add) the output
    table/collection for this model if the backend produces one.
    """
    # TODO: it does not make sense to keep this logic here;
    # an automatic DataType should handle it instead.
    from superduperdb.backends.ibis.data_backend import IbisDataBackend
    from superduperdb.backends.ibis.field_types import dtype

    backend = db.databackend
    if self.encoder is None and isinstance(backend, IbisDataBackend):
        self.encoder = dtype('str')

    # since then the `.add` clause is not necessary
    component = backend.create_model_table_or_collection(self)  # type: ignore[arg-type]
    if component is not None:
        db.add(component)

@abc.abstractmethod
def init(self):
    """Initialise/load the underlying model; must be implemented by subclasses."""
Expand Down
4 changes: 2 additions & 2 deletions superduperdb/ext/llm/vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def _generate(self, prompt: str, **kwargs) -> str:
"""
post_data = self.build_post_data(prompt, **kwargs)
response = requests.post(self.api_url, json=post_data)
return response.json()["text"]
return response.json()["text"][0]

def _batch_generate(self, prompts: List[str], **kwargs: Any) -> List[str]:
"""
Expand All @@ -53,7 +53,7 @@ async def _async_generate(self, session, semaphore, prompt: str, **kwargs) -> st
try:
async with session.post(self.api_url, json=post_data) as response:
response_json = await response.json()
return response_json["text"]
return response_json["text"][0]
except aiohttp.ClientError as e:
logging.error(f"HTTP request failed: {e}. Prompt: {prompt}")
return ""
Expand Down
1,203 changes: 1,203 additions & 0 deletions test/unittest/ext/cassettes/llm/openai/test_llm_as_listener_model.yaml

Large diffs are not rendered by default.

984 changes: 984 additions & 0 deletions test/unittest/ext/cassettes/llm/openai/test_predict.yaml

Large diffs are not rendered by default.

Loading

0 comments on commit 2a90a24

Please sign in to comment.