Add Llama cpp test
kartik4949 committed Feb 22, 2024
1 parent 7874183 commit fb6a9f6
Showing 7 changed files with 105 additions and 81 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

#### New Features & Functionality

- Add Llama cpp model in extensions.
- Simplify the testing of SQL databases using containerized databases
- Integrate Monitoring(cadvisor/Prometheus) and Logging (promtail/Loki) with Grafana, in the `testenv`
- Add `QueryModel` and `SequentialModel` to make chaining searches and models easier.
1 change: 1 addition & 0 deletions deploy/images/superduperdb/Dockerfile
@@ -28,6 +28,7 @@ RUN apt-get update \
    libglib2.0-0 libgl1-mesa-glx \
    # Required for PostgreSQL \
    libpq-dev \
    build-essential \
    # Purge apt cache
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*
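
(The new `build-essential` line supplies a C/C++ toolchain, presumably so that `llama_cpp_python`, added to `pyproject.toml` below, can compile its native llama.cpp backend during `pip install` inside the image.)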
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -67,7 +67,8 @@ dependencies = [
"PyYAML>=6.0.0",
"prettytable",
"python-dotenv",
"ray[default]>=2.8.1"
"ray[default]>=2.8.1",
"llama_cpp_python>=0.2.39"
]
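
(Note that `llama_cpp_python` is added to the core `dependencies` list here, so every installation of the package pulls it in, rather than gating it behind an optional extra.)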

# ---------------------------------
80 changes: 0 additions & 80 deletions superduperdb/ext/llama_cpp/model.py

This file was deleted.

File renamed without changes.
56 changes: 56 additions & 0 deletions superduperdb/ext/llamacpp/model.py
@@ -0,0 +1,56 @@
import dataclasses as dc
import os
import typing as t

import requests
from llama_cpp import Llama

from superduperdb.ext.llm.base import _BaseLLM


def download_uri(uri, save_path):
    """Download the file at ``uri`` and write it to ``save_path``."""
    response = requests.get(uri)
    if response.status_code == 200:
        with open(save_path, 'wb') as file:
            file.write(response.content)
    else:
        raise Exception(f"Error while downloading uri {uri}")


@dc.dataclass
class LlamaCpp(_BaseLLM):
    model_name_or_path: str = "facebook/opt-125m"
    object: t.Optional[Llama] = None
    model_kwargs: t.Dict = dc.field(default_factory=dict)
    download_dir: str = '.llama_cpp'

    def init(self):
        if self.model_name_or_path.startswith('http'):
            # Download the checkpoint referenced by the URI into the
            # local download directory.
            os.makedirs(self.download_dir, exist_ok=True)
            saved_path = os.path.join(self.download_dir, f'{self.identifier}.gguf')

            download_uri(self.model_name_or_path, saved_path)
            self.model_name_or_path = saved_path

        if self.predict_kwargs is None:
            self.predict_kwargs = {}

        self._model = Llama(self.model_name_or_path, **self.model_kwargs)

    def _generate(self, prompt: str, **kwargs) -> str:
        """Generate text from a prompt."""
        return self._model.create_completion(prompt, **self.predict_kwargs, **kwargs)


@dc.dataclass
class LlamaCppEmbedding(LlamaCpp):
    def _generate(self, prompt: str, **kwargs) -> str:
        """Generate an embedding from a prompt."""
        return self._model.create_embedding(
            prompt, embedding=True, **self.predict_kwargs, **kwargs
        )
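
For context, a minimal usage sketch of the new wrapper (not part of this commit; the GGUF path and `model_kwargs` are hypothetical, and the `predict(..., one=True)` call mirrors the test below):

from superduperdb.ext.llamacpp.model import LlamaCpp

# Hypothetical local checkpoint; an http(s) URI would instead be
# downloaded to `.llama_cpp/<identifier>.gguf` by `init`.
llama = LlamaCpp(
    identifier='myllama',
    model_name_or_path='./checkpoints/llama-2-7b.Q4_K_M.gguf',
    model_kwargs={'n_ctx': 2048},
)
output = llama.predict('What is SuperDuperDB?', one=True)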
45 changes: 45 additions & 0 deletions test/unittest/ext/test_llama_cpp.py
@@ -0,0 +1,45 @@
from superduperdb.ext.llamacpp.model import LlamaCpp, LlamaCppEmbedding


class _MockedLlama:
    def create_completion(self, *args, **kwargs):
        return 'tested'

    def create_embedding(self, *args, **kwargs):
        return [1]


def test_llama():
    def mocked_init(self):
        self._model = _MockedLlama()
        self.predict_kwargs = {}

    # Stub out init so the test never loads real model weights.
    LlamaCpp.init = mocked_init

    llama = LlamaCpp(
        identifier='myllama',
        model_name_or_path='some_model',
        model_kwargs={'vocab_only': True},
    )

    text = 'testing prompt'
    output = llama.predict(text, one=True)
    assert output == 'tested'


def test_llama_embedding():
    def mocked_init(self):
        self._model = _MockedLlama()
        self.predict_kwargs = {}

    LlamaCppEmbedding.init = mocked_init

    llama = LlamaCppEmbedding(
        identifier='myllama',
        model_name_or_path='some_model',
        model_kwargs={'vocab_only': True},
    )

    text = 'testing prompt'
    output = llama.predict(text, one=True)
    assert output == [1]
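
As a side note, assigning to `LlamaCpp.init` mutates the class for the rest of the test session; a sketch of the same test using pytest's `monkeypatch` fixture, which reverts the patch automatically (reuses `_MockedLlama` and the import above):

def test_llama_monkeypatched(monkeypatch):
    def mocked_init(self):
        self._model = _MockedLlama()
        self.predict_kwargs = {}

    # The patch is undone automatically when the test finishes.
    monkeypatch.setattr(LlamaCpp, 'init', mocked_init)

    llama = LlamaCpp(
        identifier='myllama',
        model_name_or_path='some_model',
        model_kwargs={'vocab_only': True},
    )
    assert llama.predict('testing prompt', one=True) == 'tested'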
