diff --git a/CHANGELOG.md b/CHANGELOG.md
index 2a7cad5333..c93153c03e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -38,6 +38,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 #### New Features & Functionality
 
+- Add Llama cpp model in extensions.
 - Simplify the testing of SQL databases using containerized databases
 - Integrate Monitoring(cadvisor/Prometheus) and Logging (promtail/Loki) with Grafana, in the `testenv`
 - Add `QueryModel` and `SequentialModel` to make chaining searches and models easier.
diff --git a/deploy/images/superduperdb/Dockerfile b/deploy/images/superduperdb/Dockerfile
index 60b24e3312..32311ca976 100644
--- a/deploy/images/superduperdb/Dockerfile
+++ b/deploy/images/superduperdb/Dockerfile
@@ -28,6 +28,7 @@ RUN apt-get update \
         libglib2.0-0 libgl1-mesa-glx \
         # Required for PostgreSQL \
         libpq-dev \
+        build-essential \
         # Purge apt cache
         && apt-get clean \
         && rm -rf /var/lib/apt/lists/*
diff --git a/pyproject.toml b/pyproject.toml
index d36789e8a5..a89a422666 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -67,7 +67,8 @@ dependencies = [
     "PyYAML>=6.0.0",
     "prettytable",
     "python-dotenv",
-    "ray[default]>=2.8.1"
+    "ray[default]>=2.8.1",
+    "llama_cpp_python>=0.2.39"
 ]
 
 # ---------------------------------
diff --git a/superduperdb/ext/llama_cpp/model.py b/superduperdb/ext/llama_cpp/model.py
deleted file mode 100644
index 499dad444e..0000000000
--- a/superduperdb/ext/llama_cpp/model.py
+++ /dev/null
@@ -1,80 +0,0 @@
-import dataclasses as dc
-import functools
-import os
-import typing as t
-
-import requests
-from llama_cpp import Llama
-
-from superduperdb.components.model import Model
-
-
-def download_uri(uri, save_path):
-    response = requests.get(uri)
-    if response.status_code == 200:
-        with open(save_path, 'wb') as file:
-            file.write(response.content)
-    else:
-        raise Exception(f"Error while downloading uri {uri}")
-
-
-@dc.dataclass
-class LlamaCpp(Model):
-    model_name_or_path: str = "facebook/opt-125m"
-    object: t.Optional[Llama] = None
-    model_kwargs: t.Dict = dc.field(default_factory=dict)
-    download_dir: str = '.llama_cpp'
-
-    def __post_init__(self):
-        if self.model_name_or_path.startswith('http'):
-            # Download the uri
-            os.makedirs(self.download_dir, exist_ok=True)
-            saved_path = os.path.join(self.download_dir, f'{self.identifier}.gguf')
-
-            download_uri(self.model_name_or_path, saved_path)
-            self.model_name_or_path = saved_path
-
-        if self.predict_kwargs is None:
-            self.predict_kwargs = {}
-
-        self._model = Llama(self.model_name_or_path, **self.model_kwargs)
-        super().__post_init__()
-
-    def _predict(
-        self,
-        X: t.Union[str, t.List[str], t.List[dict[str, str]]],
-        one: bool = False,
-        **kwargs: t.Any,
-    ):
-        one = isinstance(X, str)
-
-        assert isinstance(self.predict_kwargs, dict)
-        to_call = functools.partial(
-            self._model.create_completion, **self.predict_kwargs
-        )
-        if one:
-            return to_call(X)
-        else:
-            return list(map(to_call, X))
-
-
-@dc.dataclass
-class LlamaCppEmbedding(LlamaCpp):
-    def __post_init__(self):
-        self.model_kwargs['embedding'] = True
-        super().__post_init__()
-
-    def _predict(
-        self,
-        X: t.Union[str, t.List[str], t.List[dict[str, str]]],
-        one: bool = False,
-        **kwargs: t.Any,
-    ):
-        one = isinstance(X, str)
-        assert isinstance(self.predict_kwargs, dict)
-
-        to_call = functools.partial(self._model.create_embedding, **self.predict_kwargs)
-        if one:
-            return to_call(X)
-        else:
-            return list(map(to_call, X))
diff --git a/superduperdb/ext/llama_cpp/__init__.py b/superduperdb/ext/llamacpp/__init__.py
similarity index 100%
rename from superduperdb/ext/llama_cpp/__init__.py
rename to superduperdb/ext/llamacpp/__init__.py
diff --git a/superduperdb/ext/llamacpp/model.py b/superduperdb/ext/llamacpp/model.py
new file mode 100644
index 0000000000..b746a4e802
--- /dev/null
+++ b/superduperdb/ext/llamacpp/model.py
@@ -0,0 +1,56 @@
+import dataclasses as dc
+import os
+import typing as t
+
+import requests
+from llama_cpp import Llama
+
+from superduperdb.ext.llm.base import _BaseLLM
+
+
+def download_uri(uri, save_path):
+    response = requests.get(uri)
+    if response.status_code == 200:
+        with open(save_path, 'wb') as file:
+            file.write(response.content)
+    else:
+        raise Exception(f"Error while downloading uri {uri}")
+
+
+@dc.dataclass
+class LlamaCpp(_BaseLLM):
+    model_name_or_path: str = "facebook/opt-125m"
+    object: t.Optional[Llama] = None
+    model_kwargs: t.Dict = dc.field(default_factory=dict)
+    download_dir: str = '.llama_cpp'
+
+    def init(self):
+        if self.model_name_or_path.startswith('http'):
+            # Download the uri
+            os.makedirs(self.download_dir, exist_ok=True)
+            saved_path = os.path.join(self.download_dir, f'{self.identifier}.gguf')
+
+            download_uri(self.model_name_or_path, saved_path)
+            self.model_name_or_path = saved_path
+
+        if self.predict_kwargs is None:
+            self.predict_kwargs = {}
+
+        self._model = Llama(self.model_name_or_path, **self.model_kwargs)
+
+    def _generate(self, prompt: str, **kwargs) -> str:
+        """
+        Generate text from a prompt.
+        """
+        return self._model.create_completion(prompt, **self.predict_kwargs, **kwargs)
+
+
+@dc.dataclass
+class LlamaCppEmbedding(LlamaCpp):
+    def _generate(self, prompt: str, **kwargs) -> str:
+        """
+        Generate embedding from a prompt.
+        """
+        return self._model.create_embedding(
+            prompt, embedding=True, **self.predict_kwargs, **kwargs
+        )
diff --git a/test/unittest/ext/test_llama_cpp.py b/test/unittest/ext/test_llama_cpp.py
new file mode 100644
index 0000000000..cd8323b9b0
--- /dev/null
+++ b/test/unittest/ext/test_llama_cpp.py
@@ -0,0 +1,45 @@
+from superduperdb.ext.llamacpp.model import LlamaCpp, LlamaCppEmbedding
+
+
+class _MockedLlama:
+    def create_completion(self, *args, **kwargs):
+        return 'tested'
+
+    def create_embedding(self, *args, **kwargs):
+        return [1]
+
+
+def test_llama():
+    def mocked_init(self):
+        self._model = _MockedLlama()
+        self.predict_kwargs = {}
+
+    LlamaCpp.init = mocked_init
+
+    llama = LlamaCpp(
+        identifier='myllama',
+        model_name_or_path='some_model',
+        model_kwargs={'vocab_only': True},
+    )
+
+    text = 'testing prompt'
+    output = llama.predict(text, one=True)
+    assert output == 'tested'
+
+
+def test_llama_embedding():
+    def mocked_init(self):
+        self._model = _MockedLlama()
+        self.predict_kwargs = {}
+
+    LlamaCppEmbedding.init = mocked_init
+
+    llama = LlamaCppEmbedding(
+        identifier='myllama',
+        model_name_or_path='some_model',
+        model_kwargs={'vocab_only': True},
+    )
+
+    text = 'testing prompt'
+    output = llama.predict(text, one=True)
+    assert output == [1]
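
For reference, a minimal usage sketch of the renamed `superduperdb.ext.llamacpp` extension (not part of the diff). The `predict(..., one=True)` call pattern mirrors the unit tests above; the GGUF path, identifiers, and `n_ctx` value are illustrative assumptions rather than values from this PR.

```python
from superduperdb.ext.llamacpp.model import LlamaCpp, LlamaCppEmbedding

# Hypothetical local GGUF file; an http(s) URI would instead be downloaded
# into `download_dir` by `init()` and saved as '<identifier>.gguf'.
llm = LlamaCpp(
    identifier='my-llama',
    model_name_or_path='./models/llama-2-7b.Q4_K_M.gguf',  # illustrative path
    model_kwargs={'n_ctx': 512},  # forwarded verbatim to llama_cpp.Llama(...)
)
completion = llm.predict('What is SuperDuperDB?', one=True)

# The embedding subclass routes the same call through Llama.create_embedding.
embedder = LlamaCppEmbedding(
    identifier='my-llama-embedding',
    model_name_or_path='./models/llama-2-7b.Q4_K_M.gguf',
)
vector = embedder.predict('What is SuperDuperDB?', one=True)
```

Note that the tests patch `init` with a mock and then call `predict` directly, which indicates the underlying `Llama` object is constructed lazily on first use rather than in the dataclass constructor.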