Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Weaviate memory adapter #95

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
8 changes: 8 additions & 0 deletions llama_stack/providers/adapters/memory/weaviate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from .config import WeaviateConfig


async def get_adapter_impl(config: WeaviateConfig, _deps):
    """Instantiate and initialize the Weaviate memory adapter.

    Args:
        config: Adapter configuration; must be a ``WeaviateConfig``.
        _deps: Provider dependencies (unused by this adapter).

    Returns:
        An initialized ``WeaviateMemoryAdapter``.
    """
    from .weaviate import WeaviateMemoryAdapter

    # Fail early with a clear message if the registry wired up the wrong
    # config class for this provider.
    assert isinstance(
        config, WeaviateConfig
    ), f"Unexpected config type: {type(config)}"

    impl = WeaviateMemoryAdapter(config)
    await impl.initialize()
    return impl
18 changes: 18 additions & 0 deletions llama_stack/providers/adapters/memory/weaviate/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field

class WeaviateRequestProviderData(BaseModel):
    """Per-request credentials for connecting to Weaviate Cloud.

    This data travels with each request (it is not part of the static
    adapter config), so a fresh client is built from it per call.
    """

    # if there _is_ provider data, it must specify the API KEY
    # if you want it to be optional, use Optional[str]
    weaviate_api_key: str
    weaviate_cluster_url: str

@json_schema_type
class WeaviateConfig(BaseModel):
    """Static configuration for the Weaviate memory adapter.

    Credentials are intentionally absent here -- they arrive per-request
    via ``WeaviateRequestProviderData``.
    """

    # Name of the default Weaviate collection used for memory chunks.
    collection: str = Field(default="MemoryBank")
192 changes: 192 additions & 0 deletions llama_stack/providers/adapters/memory/weaviate/weaviate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
import json
import uuid
from typing import List, Optional, Dict, Any
from numpy.typing import NDArray

import weaviate
import weaviate.classes as wvc
from weaviate.classes.init import Auth

from llama_stack.apis.memory import *
from llama_stack.distribution.request_headers import get_request_provider_data
from llama_stack.providers.utils.memory.vector_store import (
BankWithIndex,
EmbeddingIndex,
)

from .config import WeaviateConfig, WeaviateRequestProviderData

class WeaviateIndex(EmbeddingIndex):
    """EmbeddingIndex backed by a single Weaviate collection.

    Chunks are stored as objects carrying a ``chunk_content`` property and a
    caller-supplied vector; retrieval uses Weaviate near-vector search.
    """

    def __init__(self, client: weaviate.Client, collection: str):
        self.client = client
        self.collection = collection

    async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
        """Insert ``chunks`` with their pre-computed ``embeddings``.

        Args:
            chunks: Parsed document chunks to store.
            embeddings: One embedding row per chunk, same order.
        """
        assert len(chunks) == len(
            embeddings
        ), f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        assert self.collection is not None, "Collection name must be specified"

        # NOTE(review): `chunk` is a project model, not a plain dict/str --
        # confirm Weaviate can serialize it, or store chunk.json() instead.
        data_objects = [
            wvc.data.DataObject(
                properties={"chunk_content": chunk},
                vector=embeddings[i].tolist(),
            )
            for i, chunk in enumerate(chunks)
        ]

        # Inserting chunks into a prespecified Weaviate collection.
        my_collection = self.client.collections.get(self.collection)

        # Fix: the v4 sync client is not awaitable; the original `await`
        # raised "object can't be used in 'await' expression".
        my_collection.data.insert_many(data_objects)

    async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
        """Return the ``k`` chunks nearest to ``embedding`` with scores.

        Scores are the inverse of the reported vector distance (higher is
        closer); an exact match maps to ``inf`` instead of dividing by zero.
        """
        assert self.collection is not None, "Collection name must be specified"

        my_collection = self.client.collections.get(self.collection)

        results = my_collection.query.near_vector(
            near_vector=embedding.tolist(),
            limit=k,
            # Fix: the v4 API spells this `return_metadata`
            # (the original `return_meta_data` is not a valid kwarg).
            return_metadata=wvc.query.MetadataQuery(distance=True),
        )

        chunks = []
        scores = []
        for doc in results.objects:
            try:
                chunks.append(doc.properties["chunk_content"])
                distance = doc.metadata.distance
                # Fix: guard against ZeroDivisionError on an exact match.
                scores.append(1.0 / distance if distance else float("inf"))
            except Exception as e:
                # Best-effort: skip unparsable documents but keep going.
                import traceback

                traceback.print_exc()
                print(f"Failed to parse document: {e}")

        return QueryDocumentsResponse(chunks=chunks, scores=scores)


class WeaviateMemoryAdapter(Memory):
    def __init__(self, config: WeaviateConfig) -> None:
        # Static adapter configuration (default collection name, etc.).
        self.config = config
        # Weaviate client; built per-request from provider data, see _get_client().
        self.client = None
        # bank_id -> BankWithIndex cache to avoid re-fetching collections.
        self.cache = {}

def _get_client(self) -> weaviate.Client:
request_provider_data = get_request_provider_data()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

provider data isn't present at this point -- it is only provided at the time of the request. you should initialize the client on every client call and if we need a cache of clients then, we'd need to build that.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

modified the method to return an initialized client. following this I initialize the client on every client call using the method.


if request_provider_data is not None:
assert isinstance(request_provider_data, WeaviateRequestProviderData)

# Connect to Weaviate Cloud
return weaviate.connect_to_weaviate_cloud(
cluster_url = request_provider_data.weaviate_cluster_url,
auth_credentials = Auth.api_key(request_provider_data.weaviate_api_key),
)

async def initialize(self) -> None:
try:
self.client = self._get_client()

# Create collection if it doesn't exist
if not self.client.collections.exists(self.config.collection):
self.client.collections.create(
name = self.config.collection,
vectorizer_config = wvc.config.Configure.Vectorizer.none(),
properties=[
wvc.config.Property(
name="chunk_content",
data_type=wvc.config.DataType.TEXT,
),
]
)

except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to Weaviate server") from e

async def shutdown(self) -> None:
self.client = self._get_client()

if self.client:
self.client.close()

async def create_memory_bank(
self,
name: str,
config: MemoryBankConfig,
url: Optional[URL] = None,
) -> MemoryBank:
bank_id = str(uuid.uuid4())
bank = MemoryBank(
bank_id=bank_id,
name=name,
config=config,
url=url,
)
self.client = self._get_client()

# Store the bank as a new collection in Weaviate
self.client.collections.create(
name=bank_id
)

index = BankWithIndex(
bank=bank,
index=WeaviateIndex(cleint = self.client, collection = bank_id),
)
self.cache[bank_id] = index
return bank

async def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]:
bank_index = await self._get_and_cache_bank_index(bank_id)
if bank_index is None:
return None
return bank_index.bank

async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:

self.client = self._get_client()

if bank_id in self.cache:
return self.cache[bank_id]

collections = await self.client.collections.list_all().keys()

for collection in collections:
if collection == bank_id:
bank = MemoryBank(**json.loads(collection.metadata["bank"]))
index = BankWithIndex(
bank=bank,
index=WeaviateIndex(self.client, collection),
)
self.cache[bank_id] = index
return index

return None

async def insert_documents(
self,
bank_id: str,
documents: List[MemoryBankDocument],
) -> None:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")

await index.insert_documents(documents)

async def query_documents(
self,
bank_id: str,
query: InterleavedTextMedia,
params: Optional[Dict[str, Any]] = None,
) -> QueryDocumentsResponse:
index = await self._get_and_cache_bank_index(bank_id)
if not index:
raise ValueError(f"Bank {bank_id} not found")

return await index.query_documents(query, params)
9 changes: 9 additions & 0 deletions llama_stack/providers/registry/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.adapters.memory.pgvector.PGVectorConfig",
),
),
remote_provider_spec(
Api.memory,
AdapterSpec(
adapter_id="weaviate",
pip_packages=EMBEDDING_DEPS + ["weaviate-client"],
module="llama_stack.providers.adapters.memory.weaviate",
provider_data_validator="llama_stack.providers.adapters.memory.weaviate.WeaviateRequestProviderData",
),
),
remote_provider_spec(
api=Api.memory,
adapter=AdapterSpec(
Expand Down