diff --git a/libs/colbert/ragstack_colbert/base_database.py b/libs/colbert/ragstack_colbert/base_database.py index b5bb7b2c5..2870139e3 100644 --- a/libs/colbert/ragstack_colbert/base_database.py +++ b/libs/colbert/ragstack_colbert/base_database.py @@ -5,10 +5,13 @@ models. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Tuple +from typing import TYPE_CHECKING -from .objects import Chunk, Vector +if TYPE_CHECKING: + from .objects import Chunk, Vector class BaseDatabase(ABC): @@ -24,7 +27,7 @@ class BaseDatabase(ABC): """ @abstractmethod - def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: + def add_chunks(self, chunks: list[Chunk]) -> list[tuple[str, int]]: """Stores a list of embedded text chunks in the vector store. Args: @@ -35,7 +38,7 @@ def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: """ @abstractmethod - def delete_chunks(self, doc_ids: List[str]) -> bool: + def delete_chunks(self, doc_ids: list[str]) -> bool: """Deletes chunks from the vector store based on their document id. Args: @@ -48,8 +51,8 @@ def delete_chunks(self, doc_ids: List[str]) -> bool: @abstractmethod async def aadd_chunks( - self, chunks: List[Chunk], concurrent_inserts: int = 100 - ) -> List[Tuple[str, int]]: + self, chunks: list[Chunk], concurrent_inserts: int = 100 + ) -> list[tuple[str, int]]: """Stores a list of embedded text chunks in the vector store. Args: @@ -63,7 +66,7 @@ async def aadd_chunks( @abstractmethod async def adelete_chunks( - self, doc_ids: List[str], concurrent_deletes: int = 100 + self, doc_ids: list[str], concurrent_deletes: int = 100 ) -> bool: """Deletes chunks from the vector store based on their document id. @@ -78,7 +81,7 @@ async def adelete_chunks( """ @abstractmethod - async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: + async def search_relevant_chunks(self, vector: Vector, n: int) -> list[Chunk]: """Retrieves 'n' ANN results for an embedded token vector. Returns: diff --git a/libs/colbert/ragstack_colbert/base_embedding_model.py b/libs/colbert/ragstack_colbert/base_embedding_model.py index 0e58285d5..07004d964 100644 --- a/libs/colbert/ragstack_colbert/base_embedding_model.py +++ b/libs/colbert/ragstack_colbert/base_embedding_model.py @@ -4,10 +4,13 @@ embeddings for text. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Optional +from typing import TYPE_CHECKING -from .objects import Embedding +if TYPE_CHECKING: + from .objects import Embedding class BaseEmbeddingModel(ABC): @@ -21,7 +24,7 @@ class BaseEmbeddingModel(ABC): """ @abstractmethod - def embed_texts(self, texts: List[str]) -> List[Embedding]: + def embed_texts(self, texts: list[str]) -> list[Embedding]: """Embeds a list of texts into their vector embedding representations. Args: @@ -36,7 +39,7 @@ def embed_query( self, query: str, full_length_search: bool = False, - query_maxlen: Optional[int] = None, + query_maxlen: int | None = None, ) -> Embedding: """Embeds a single query text into its vector representation. diff --git a/libs/colbert/ragstack_colbert/base_retriever.py b/libs/colbert/ragstack_colbert/base_retriever.py index a0486ffcf..e5894e19f 100644 --- a/libs/colbert/ragstack_colbert/base_retriever.py +++ b/libs/colbert/ragstack_colbert/base_retriever.py @@ -5,10 +5,13 @@ models. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, List, Optional, Tuple +from typing import TYPE_CHECKING, Any -from .objects import Chunk, Embedding +if TYPE_CHECKING: + from .objects import Chunk, Embedding class BaseRetriever(ABC): @@ -24,10 +27,10 @@ class BaseRetriever(ABC): def embedding_search( self, query_embedding: Embedding, - k: Optional[int] = None, + k: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: """Search for relevant text chunks based on a query embedding. Retrieves a list of text chunks relevant to a given query from the vector @@ -53,10 +56,10 @@ def embedding_search( async def aembedding_search( self, query_embedding: Embedding, - k: Optional[int] = None, + k: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: """Search for relevant text chunks based on a query embedding. Retrieves a list of text chunks relevant to a given query from the vector @@ -82,11 +85,11 @@ async def aembedding_search( def text_search( self, query_text: str, - k: Optional[int] = None, - query_maxlen: Optional[int] = None, + k: int | None = None, + query_maxlen: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: """Search for relevant text chunks based on a query text. Retrieves a list of text chunks relevant to a given query from the vector @@ -113,11 +116,11 @@ def text_search( async def atext_search( self, query_text: str, - k: Optional[int] = None, - query_maxlen: Optional[int] = None, + k: int | None = None, + query_maxlen: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: """Search for relevant text chunks based on a query text. Retrieves a list of text chunks relevant to a given query from the vector diff --git a/libs/colbert/ragstack_colbert/base_vector_store.py b/libs/colbert/ragstack_colbert/base_vector_store.py index 7f2a42c73..fc5251045 100644 --- a/libs/colbert/ragstack_colbert/base_vector_store.py +++ b/libs/colbert/ragstack_colbert/base_vector_store.py @@ -5,11 +5,14 @@ and can be used to create a LangChain or LlamaIndex ColBERT vector store. """ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING -from .base_retriever import BaseRetriever -from .objects import Chunk, Metadata +if TYPE_CHECKING: + from .base_retriever import BaseRetriever + from .objects import Chunk, Metadata # LlamaIndex Node (chunk) has ids, text, embedding, metadata # VectorStore.add(nodes: List[Node]) -> List[str](ids): embeds texts OUTside add # noqa: E501 @@ -37,7 +40,7 @@ class BaseVectorStore(ABC): # handles LlamaIndex add @abstractmethod - def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: + def add_chunks(self, chunks: list[Chunk]) -> list[tuple[str, int]]: """Stores a list of embedded text chunks in the vector store. Args: @@ -51,10 +54,10 @@ def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: @abstractmethod def add_texts( self, - texts: List[str], - metadatas: Optional[List[Metadata]], - doc_id: Optional[str] = None, - ) -> List[Tuple[str, int]]: + texts: list[str], + metadatas: list[Metadata] | None, + doc_id: str | None = None, + ) -> list[tuple[str, int]]: """Adds text chunks to the vector store. Embeds and stores a list of text chunks and optional metadata into the vector @@ -73,7 +76,7 @@ def add_texts( # handles LangChain and LlamaIndex delete @abstractmethod - def delete_chunks(self, doc_ids: List[str]) -> bool: + def delete_chunks(self, doc_ids: list[str]) -> bool: """Deletes chunks from the vector store based on their document id. Args: @@ -87,8 +90,8 @@ def delete_chunks(self, doc_ids: List[str]) -> bool: # handles LlamaIndex add @abstractmethod async def aadd_chunks( - self, chunks: List[Chunk], concurrent_inserts: int = 100 - ) -> List[Tuple[str, int]]: + self, chunks: list[Chunk], concurrent_inserts: int = 100 + ) -> list[tuple[str, int]]: """Stores a list of embedded text chunks in the vector store. Args: @@ -104,11 +107,11 @@ async def aadd_chunks( @abstractmethod async def aadd_texts( self, - texts: List[str], - metadatas: Optional[List[Metadata]], - doc_id: Optional[str] = None, + texts: list[str], + metadatas: list[Metadata] | None, + doc_id: str | None = None, concurrent_inserts: int = 100, - ) -> List[Tuple[str, int]]: + ) -> list[tuple[str, int]]: """Adds text chunks to the vector store. Embeds and stores a list of text chunks and optional metadata into the vector @@ -130,7 +133,7 @@ async def aadd_texts( # handles LangChain and LlamaIndex delete @abstractmethod async def adelete_chunks( - self, doc_ids: List[str], concurrent_deletes: int = 100 + self, doc_ids: list[str], concurrent_deletes: int = 100 ) -> bool: """Deletes chunks from the vector store based on their document id. diff --git a/libs/colbert/ragstack_colbert/cassandra_database.py b/libs/colbert/ragstack_colbert/cassandra_database.py index 2942cae47..67b53f7fc 100644 --- a/libs/colbert/ragstack_colbert/cassandra_database.py +++ b/libs/colbert/ragstack_colbert/cassandra_database.py @@ -6,13 +6,14 @@ facilitating scalable and high-relevancy retrieval operations. """ +from __future__ import annotations + import asyncio import logging from collections import defaultdict -from typing import Any, Awaitable, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Awaitable import cassio -from cassandra.cluster import Session from cassio.table.query import Predicate, PredicateOperator from cassio.table.tables import ClusteredMetadataVectorCassandraTable from typing_extensions import Self, override @@ -21,6 +22,9 @@ from .constant import DEFAULT_COLBERT_DIM from .objects import Chunk, Vector +if TYPE_CHECKING: + from cassandra.cluster import Session + class CassandraDatabaseError(Exception): """Exception raised for errors in the CassandraDatabase class.""" @@ -51,9 +55,9 @@ def from_astra( cls, database_id: str, astra_token: str, - keyspace: Optional[str] = "default_keyspace", + keyspace: str | None = "default_keyspace", table_name: str = "colbert", - timeout: Optional[int] = 300, + timeout: int | None = 300, ) -> Self: """Creates a CassandraVectorStore using AstraDB connection info.""" cassio.init(token=astra_token, database_id=database_id, keyspace=keyspace) @@ -68,7 +72,7 @@ def from_astra( def from_session( cls, session: Session, - keyspace: Optional[str] = "default_keyspace", + keyspace: str | None = "default_keyspace", table_name: str = "colbert", ) -> Self: """Creates a CassandraVectorStore using an existing session.""" @@ -79,7 +83,7 @@ def from_session( def _initialize( self, session: Session, - keyspace: Optional[str], + keyspace: str | None, table_name: str, ) -> None: """Initializes a new instance of the CassandraVectorStore. @@ -129,9 +133,9 @@ def _log_insert_error( ) @override - def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: - failed_chunks: List[Tuple[str, int]] = [] - success_chunks: List[Tuple[str, int]] = [] + def add_chunks(self, chunks: list[Chunk]) -> list[tuple[str, int]]: + failed_chunks: list[tuple[str, int]] = [] + success_chunks: list[tuple[str, int]] = [] for chunk in chunks: doc_id = chunk.doc_id @@ -182,10 +186,10 @@ async def _limited_put( doc_id: str, chunk_id: int, embedding_id: int = -1, - text: Optional[str] = None, - metadata: Optional[Dict[str, Any]] = None, - vector: Optional[Vector] = None, - ) -> Tuple[str, int, int, Optional[Exception]]: + text: str | None = None, + metadata: dict[str, Any] | None = None, + vector: Vector | None = None, + ) -> tuple[str, int, int, Exception | None]: row_id = (chunk_id, embedding_id) async with sem: try: @@ -206,11 +210,11 @@ async def _limited_put( @override async def aadd_chunks( - self, chunks: List[Chunk], concurrent_inserts: int = 100 - ) -> List[Tuple[str, int]]: + self, chunks: list[Chunk], concurrent_inserts: int = 100 + ) -> list[tuple[str, int]]: semaphore = asyncio.Semaphore(concurrent_inserts) - all_tasks: List[Awaitable[Tuple[str, int, int, Optional[Exception]]]] = [] - tasks_per_chunk: Dict[Tuple[str, int], int] = defaultdict(int) + all_tasks: list[Awaitable[tuple[str, int, int, Exception | None]]] = [] + tasks_per_chunk: dict[tuple[str, int], int] = defaultdict(int) for chunk in chunks: doc_id = chunk.doc_id @@ -259,8 +263,8 @@ async def aadd_chunks( exp=exp, ) - outputs: List[Tuple[str, int]] = [] - failed_chunks: List[Tuple[str, int]] = [] + outputs: list[tuple[str, int]] = [] + failed_chunks: list[tuple[str, int]] = [] for doc_id, chunk_id in tasks_per_chunk: if tasks_per_chunk[(doc_id, chunk_id)] == 0: @@ -277,8 +281,8 @@ async def aadd_chunks( return outputs @override - def delete_chunks(self, doc_ids: List[str]) -> bool: - failed_docs: List[str] = [] + def delete_chunks(self, doc_ids: list[str]) -> bool: + failed_docs: list[str] = [] for doc_id in doc_ids: try: @@ -299,7 +303,7 @@ async def _limited_delete( self, sem: asyncio.Semaphore, doc_id: str, - ) -> Tuple[str, Optional[Exception]]: + ) -> tuple[str, Exception | None]: async with sem: try: await self._table.adelete_partition(partition_id=doc_id) @@ -309,7 +313,7 @@ async def _limited_delete( @override async def adelete_chunks( - self, doc_ids: List[str], concurrent_deletes: int = 100 + self, doc_ids: list[str], concurrent_deletes: int = 100 ) -> bool: semaphore = asyncio.Semaphore(concurrent_deletes) all_tasks = [ @@ -323,7 +327,7 @@ async def adelete_chunks( results = await asyncio.gather(*all_tasks, return_exceptions=True) success = True - failed_docs: List[str] = [] + failed_docs: list[str] = [] for result in results: if isinstance(result, BaseException): @@ -344,8 +348,8 @@ async def adelete_chunks( return success @override - async def search_relevant_chunks(self, vector: Vector, n: int) -> List[Chunk]: - chunks: Set[Chunk] = set() + async def search_relevant_chunks(self, vector: Vector, n: int) -> list[Chunk]: + chunks: set[Chunk] = set() # TODO: only return partition_id and row_id after cassio supports this rows = await self._table.aann_search(vector=vector, n=n) diff --git a/libs/colbert/ragstack_colbert/colbert_embedding_model.py b/libs/colbert/ragstack_colbert/colbert_embedding_model.py index e76cb4274..87c9d2348 100644 --- a/libs/colbert/ragstack_colbert/colbert_embedding_model.py +++ b/libs/colbert/ragstack_colbert/colbert_embedding_model.py @@ -10,7 +10,7 @@ with support for both CPU and GPU computing environments. """ -from typing import List, Optional +from __future__ import annotations from colbert.infra import ColBERTConfig from typing_extensions import override @@ -43,7 +43,7 @@ def __init__( nbits: int = 2, kmeans_niters: int = 4, nranks: int = -1, - query_maxlen: Optional[int] = None, + query_maxlen: int | None = None, verbose: int = 3, # 3 is the default on ColBERT checkpoint chunk_batch_size: int = 640, ): @@ -84,7 +84,7 @@ def __init__( self._encoder = TextEncoder(config=colbert_config, verbose=verbose) @override - def embed_texts(self, texts: List[str]) -> List[Embedding]: + def embed_texts(self, texts: list[str]) -> list[Embedding]: chunks = [ Chunk(doc_id="dummy", chunk_id=i, text=t) for i, t in enumerate(texts) ] @@ -104,7 +104,7 @@ def embed_query( self, query: str, full_length_search: bool = False, - query_maxlen: Optional[int] = None, + query_maxlen: int | None = None, ) -> Embedding: if query_maxlen is None: query_maxlen = -1 diff --git a/libs/colbert/ragstack_colbert/colbert_retriever.py b/libs/colbert/ragstack_colbert/colbert_retriever.py index ee3d891ab..0008fc48c 100644 --- a/libs/colbert/ragstack_colbert/colbert_retriever.py +++ b/libs/colbert/ragstack_colbert/colbert_retriever.py @@ -11,18 +11,22 @@ hardware environments. """ +from __future__ import annotations + import asyncio import logging import math -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any import torch from typing_extensions import override -from .base_database import BaseDatabase -from .base_embedding_model import BaseEmbeddingModel from .base_retriever import BaseRetriever -from .objects import Chunk, Embedding, Vector + +if TYPE_CHECKING: + from .base_database import BaseDatabase + from .base_embedding_model import BaseEmbeddingModel + from .objects import Chunk, Embedding, Vector def all_gpus_support_fp16(is_cuda: bool = False) -> bool: @@ -149,9 +153,9 @@ def __init__( async def _query_relevant_chunks( self, query_embedding: Embedding, top_k: int - ) -> Set[Chunk]: + ) -> set[Chunk]: """Queries for the top_k most relevant chunks for each query token.""" - chunks: Set[Chunk] = set() + chunks: set[Chunk] = set() # Collect all tasks tasks = [ self._database.search_relevant_chunks(vector=v, n=top_k) @@ -171,7 +175,7 @@ async def _query_relevant_chunks( return chunks - async def _get_chunk_embeddings(self, chunks: Set[Chunk]) -> List[Chunk]: + async def _get_chunk_embeddings(self, chunks: set[Chunk]) -> list[Chunk]: """Retrieves Chunks with `doc_id`, `chunk_id`, and `embedding` set.""" # Collect all tasks tasks = [ @@ -194,8 +198,8 @@ async def _get_chunk_embeddings(self, chunks: Set[Chunk]) -> List[Chunk]: return chunk_embeddings def _score_chunks( - self, query_embedding: Embedding, chunk_embeddings: List[Chunk] - ) -> Dict[Chunk, float]: + self, query_embedding: Embedding, chunk_embeddings: list[Chunk] + ) -> dict[Chunk, float]: """Process the retrieved chunk data to calculate scores.""" chunk_scores = {} for chunk in chunk_embeddings: @@ -214,9 +218,9 @@ def _score_chunks( async def _get_chunk_data( self, - chunks: List[Chunk], + chunks: list[Chunk], include_embedding: bool = False, - ) -> List[Chunk]: + ) -> list[Chunk]: """Fetches text and metadata for each chunk. Returns: @@ -250,11 +254,11 @@ async def _get_chunk_data( async def atext_search( self, query_text: str, - k: Optional[int] = 5, - query_maxlen: Optional[int] = None, + k: int | None = 5, + query_maxlen: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: query_embedding = self._embedding_model.embed_query( query=query_text, query_maxlen=query_maxlen ) @@ -270,10 +274,10 @@ async def atext_search( async def aembedding_search( self, query_embedding: Embedding, - k: Optional[int] = 5, + k: int | None = 5, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: if k is None: k = 5 top_k = max(math.floor(len(query_embedding) / 2), 16) @@ -285,28 +289,28 @@ async def aembedding_search( ) # search for relevant chunks (only with `doc_id` and `chunk_id` set) - relevant_chunks: Set[Chunk] = await self._query_relevant_chunks( + relevant_chunks: set[Chunk] = await self._query_relevant_chunks( query_embedding=query_embedding, top_k=top_k ) # get the embedding for each chunk # (with `doc_id`, `chunk_id`, and `embedding` set) - chunk_embeddings: List[Chunk] = await self._get_chunk_embeddings( + chunk_embeddings: list[Chunk] = await self._get_chunk_embeddings( chunks=relevant_chunks ) # score the chunks using max_similarity - chunk_scores: Dict[Chunk, float] = self._score_chunks( + chunk_scores: dict[Chunk, float] = self._score_chunks( query_embedding=query_embedding, chunk_embeddings=chunk_embeddings, ) # only keep the top k sorted results - top_k_chunks: List[Chunk] = sorted( + top_k_chunks: list[Chunk] = sorted( chunk_scores, key=lambda c: chunk_scores.get(c, 0), reverse=True )[:k] - chunks: List[Chunk] = await self._get_chunk_data( + chunks: list[Chunk] = await self._get_chunk_data( chunks=top_k_chunks, include_embedding=include_embedding ) @@ -316,11 +320,11 @@ async def aembedding_search( def text_search( self, query_text: str, - k: Optional[int] = 5, - query_maxlen: Optional[int] = None, + k: int | None = 5, + query_maxlen: int | None = None, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: return asyncio.run( self.atext_search( query_text=query_text, @@ -335,10 +339,10 @@ def text_search( def embedding_search( self, query_embedding: Embedding, - k: Optional[int] = 5, + k: int | None = 5, include_embedding: bool = False, **kwargs: Any, - ) -> List[Tuple[Chunk, float]]: + ) -> list[tuple[Chunk, float]]: return asyncio.run( self.aembedding_search( query_embedding=query_embedding, diff --git a/libs/colbert/ragstack_colbert/colbert_vector_store.py b/libs/colbert/ragstack_colbert/colbert_vector_store.py index ca00465fa..490098f16 100644 --- a/libs/colbert/ragstack_colbert/colbert_vector_store.py +++ b/libs/colbert/ragstack_colbert/colbert_vector_store.py @@ -7,18 +7,22 @@ operations. """ +from __future__ import annotations + import uuid -from typing import List, Optional, Tuple +from typing import TYPE_CHECKING from typing_extensions import override -from .base_database import BaseDatabase -from .base_embedding_model import BaseEmbeddingModel -from .base_retriever import BaseRetriever from .base_vector_store import BaseVectorStore from .colbert_retriever import ColbertRetriever from .objects import Chunk, Metadata +if TYPE_CHECKING: + from .base_database import BaseDatabase + from .base_embedding_model import BaseEmbeddingModel + from .base_retriever import BaseRetriever + class ColbertVectorStore(BaseVectorStore): """A vector store implementation for ColBERT. @@ -30,12 +34,12 @@ class ColbertVectorStore(BaseVectorStore): """ _database: BaseDatabase - _embedding_model: Optional[BaseEmbeddingModel] + _embedding_model: BaseEmbeddingModel | None def __init__( self, database: BaseDatabase, - embedding_model: Optional[BaseEmbeddingModel] = None, + embedding_model: BaseEmbeddingModel | None = None, ): self._database = database self._embedding_model = embedding_model @@ -49,10 +53,10 @@ def _validate_embedding_model(self) -> BaseEmbeddingModel: def _build_chunks( self, - texts: List[str], - metadatas: Optional[List[Metadata]] = None, - doc_id: Optional[str] = None, - ) -> List[Chunk]: + texts: list[str], + metadatas: list[Metadata] | None = None, + doc_id: str | None = None, + ) -> list[Chunk]: embedding_model = self._validate_embedding_model() if metadatas is not None and len(texts) != len(metadatas): @@ -63,7 +67,7 @@ def _build_chunks( embeddings = embedding_model.embed_texts(texts=texts) - chunks: List[Chunk] = [] + chunks: list[Chunk] = [] for i, text in enumerate(texts): chunks.append( Chunk( @@ -77,27 +81,27 @@ def _build_chunks( return chunks @override - def add_chunks(self, chunks: List[Chunk]) -> List[Tuple[str, int]]: + def add_chunks(self, chunks: list[Chunk]) -> list[tuple[str, int]]: return self._database.add_chunks(chunks=chunks) @override def add_texts( self, - texts: List[str], - metadatas: Optional[List[Metadata]] = None, - doc_id: Optional[str] = None, - ) -> List[Tuple[str, int]]: + texts: list[str], + metadatas: list[Metadata] | None = None, + doc_id: str | None = None, + ) -> list[tuple[str, int]]: chunks = self._build_chunks(texts=texts, metadatas=metadatas, doc_id=doc_id) return self._database.add_chunks(chunks=chunks) @override - def delete_chunks(self, doc_ids: List[str]) -> bool: + def delete_chunks(self, doc_ids: list[str]) -> bool: return self._database.delete_chunks(doc_ids=doc_ids) @override async def aadd_chunks( - self, chunks: List[Chunk], concurrent_inserts: int = 100 - ) -> List[Tuple[str, int]]: + self, chunks: list[Chunk], concurrent_inserts: int = 100 + ) -> list[tuple[str, int]]: return await self._database.aadd_chunks( chunks=chunks, concurrent_inserts=concurrent_inserts ) @@ -105,11 +109,11 @@ async def aadd_chunks( @override async def aadd_texts( self, - texts: List[str], - metadatas: Optional[List[Metadata]] = None, - doc_id: Optional[str] = None, + texts: list[str], + metadatas: list[Metadata] | None = None, + doc_id: str | None = None, concurrent_inserts: int = 100, - ) -> List[Tuple[str, int]]: + ) -> list[tuple[str, int]]: chunks = self._build_chunks(texts=texts, metadatas=metadatas, doc_id=doc_id) return await self._database.aadd_chunks( chunks=chunks, concurrent_inserts=concurrent_inserts @@ -117,7 +121,7 @@ async def aadd_texts( @override async def adelete_chunks( - self, doc_ids: List[str], concurrent_deletes: int = 100 + self, doc_ids: list[str], concurrent_deletes: int = 100 ) -> bool: return await self._database.adelete_chunks( doc_ids=doc_ids, concurrent_deletes=concurrent_deletes diff --git a/libs/colbert/ragstack_colbert/objects.py b/libs/colbert/ragstack_colbert/objects.py index baed6a1d2..e75839b55 100644 --- a/libs/colbert/ragstack_colbert/objects.py +++ b/libs/colbert/ragstack_colbert/objects.py @@ -4,6 +4,8 @@ stages of processing within the ColBERT retrieval system. """ +from __future__ import annotations + from typing import Any, Dict, List, Optional from pydantic import BaseModel, Field @@ -34,7 +36,7 @@ class Chunk(BaseModel): metadata: Metadata = Field( default_factory=dict, description="flat metadata of the chunk" ) - embedding: Optional[Embedding] = Field( + embedding: Optional[Embedding] = Field( # noqa: UP007 default=None, description="embedding of the chunk" ) diff --git a/libs/colbert/ragstack_colbert/text_encoder.py b/libs/colbert/ragstack_colbert/text_encoder.py index 7b520bad0..59efbcad5 100644 --- a/libs/colbert/ragstack_colbert/text_encoder.py +++ b/libs/colbert/ragstack_colbert/text_encoder.py @@ -8,17 +8,21 @@ chunks. """ +from __future__ import annotations + import logging -from typing import List, Optional, cast +from typing import TYPE_CHECKING, cast import torch -from colbert.infra import ColBERTConfig from colbert.modeling.checkpoint import Checkpoint from .objects import Chunk, Embedding +if TYPE_CHECKING: + from colbert.infra import ColBERTConfig + -def calculate_query_maxlen(tokens: List[List[str]]) -> int: +def calculate_query_maxlen(tokens: list[list[str]]) -> int: """Calculates maximum query length. Calculates an appropriate maximum query length for token embeddings, @@ -53,7 +57,7 @@ class TextEncoder: verbose (int): The level of logging to use """ - def __init__(self, config: ColBERTConfig, verbose: Optional[int] = 3) -> None: + def __init__(self, config: ColBERTConfig, verbose: int | None = 3) -> None: logging.info("Cuda enabled GPU available: %s", torch.cuda.is_available()) self._checkpoint = Checkpoint( @@ -61,7 +65,7 @@ def __init__(self, config: ColBERTConfig, verbose: Optional[int] = 3) -> None: ) self._use_cpu = config.total_visible_gpus == 0 - def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chunk]: + def encode_chunks(self, chunks: list[Chunk], batch_size: int = 640) -> list[Chunk]: """Encodes a list of chunks into embeddings. Encodes a list of chunks into embeddings, processing in batches to @@ -78,7 +82,7 @@ def encode_chunks(self, chunks: List[Chunk], batch_size: int = 640) -> List[Chun """ logging.debug("#> Encoding %s chunks..", len(chunks)) - embedded_chunks: List[Chunk] = [] + embedded_chunks: list[Chunk] = [] if len(chunks) == 0: return embedded_chunks diff --git a/libs/colbert/tests/integration_tests/test_embedding_retrieval.py b/libs/colbert/tests/integration_tests/test_embedding_retrieval.py index e7a47ef7e..f233d49ec 100644 --- a/libs/colbert/tests/integration_tests/test_embedding_retrieval.py +++ b/libs/colbert/tests/integration_tests/test_embedding_retrieval.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging -from typing import List +from typing import TYPE_CHECKING import pytest -from cassandra.cluster import Session from ragstack_colbert import ( CassandraDatabase, ColbertEmbeddingModel, @@ -10,6 +11,9 @@ ) from ragstack_tests_utils import TestData +if TYPE_CHECKING: + from cassandra.cluster import Session + @pytest.mark.parametrize("session", ["cassandra", "astra_db"], indirect=["session"]) def test_embedding_cassandra_retriever(session: Session) -> None: @@ -20,7 +24,7 @@ def test_embedding_cassandra_retriever(session: Session) -> None: overlap_size = 50 # Function to generate chunks with the specified size and overlap - def chunk_texts(text: str, chunk_size: int, overlap_size: int) -> List[str]: + def chunk_texts(text: str, chunk_size: int, overlap_size: int) -> list[str]: texts = [] start = 0 end = chunk_size diff --git a/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py b/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py index 027037574..80721d672 100644 --- a/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py +++ b/libs/colbert/tests/unit_tests/test_colbert_baseline_embeddings.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import logging -from typing import List import torch from colbert.indexing.collection_encoder import CollectionEncoder @@ -42,7 +43,7 @@ # a uility function to evaluate similarity of two embeddings at per token level -def are_they_similar(embedded_chunks: List[Embedding], tensors: List[Tensor]) -> None: +def are_they_similar(embedded_chunks: list[Embedding], tensors: list[Tensor]) -> None: n = 0 pdist = torch.nn.PairwiseDistance(p=2) for embedding in embedded_chunks: @@ -78,7 +79,7 @@ def test_embeddings_with_baseline() -> None: please add to the model and implementions resultsed euclidian and cosine threshold change 2024-04-08 default model - https://huggingface.co/colbert-ir/colbertv2.0 """ # noqa: E501 - embeddings: List[Embedding] = colbert.embed_texts(arctic_botany_chunks) + embeddings: list[Embedding] = colbert.embed_texts(arctic_botany_chunks) pdist = torch.nn.PairwiseDistance(p=2) embedded_tensors = [] diff --git a/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py b/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py index 2a700ef52..6ae4fb40d 100644 --- a/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py +++ b/libs/e2e-tests/e2e_tests/langchain/nemo_guardrails.py @@ -1,8 +1,9 @@ -from langchain.llms.base import BaseLLM +from __future__ import annotations + +from typing import TYPE_CHECKING + from langchain.prompts import PromptTemplate from langchain.schema.output_parser import StrOutputParser -from langchain.schema.retriever import BaseRetriever -from langchain.schema.vectorstore import VectorStore from nemoguardrails import LLMRails, RailsConfig from nemoguardrails.actions.actions import ActionResult @@ -11,6 +12,11 @@ SAMPLE_DATA, ) +if TYPE_CHECKING: + from langchain.llms.base import BaseLLM + from langchain.schema.retriever import BaseRetriever + from langchain.schema.vectorstore import VectorStore + def _config(engine, model) -> str: return f""" diff --git a/libs/e2e-tests/e2e_tests/langchain/rag_application.py b/libs/e2e-tests/e2e_tests/langchain/rag_application.py index 09053baf9..ca11d703a 100644 --- a/libs/e2e-tests/e2e_tests/langchain/rag_application.py +++ b/libs/e2e-tests/e2e_tests/langchain/rag_application.py @@ -1,7 +1,9 @@ +from __future__ import annotations + import logging import time from operator import itemgetter -from typing import Callable, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, Callable, Sequence from langchain import callbacks from langchain.chains import ConversationalRetrievalChain @@ -9,24 +11,26 @@ ConversationSummaryMemory, ) from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder, PromptTemplate -from langchain.schema import Document -from langchain.schema.language_model import BaseLanguageModel from langchain.schema.messages import AIMessage, HumanMessage from langchain.schema.output_parser import StrOutputParser -from langchain.schema.retriever import BaseRetriever from langchain.schema.runnable import ( Runnable, RunnableBranch, RunnableLambda, RunnableMap, ) -from langchain.schema.vectorstore import VectorStore -from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.tracers import ConsoleCallbackHandler from pydantic import BaseModel from e2e_tests.test_utils.tracing import record_langsmith_sharelink +if TYPE_CHECKING: + from langchain.schema import Document + from langchain.schema.language_model import BaseLanguageModel + from langchain.schema.retriever import BaseRetriever + from langchain.schema.vectorstore import VectorStore + from langchain_core.chat_history import BaseChatMessageHistory + BASIC_QA_PROMPT = """ Answer the question based only on the supplied context. If you don't know the answer, say the following: "I don't know the answer". Context: {context} @@ -86,7 +90,7 @@ class ChatRequest(BaseModel): question: str - chat_history: Optional[List[Dict[str, str]]] + chat_history: list[dict[str, str]] | None def create_retriever_chain( diff --git a/libs/e2e-tests/e2e_tests/langchain/test_astra.py b/libs/e2e-tests/e2e_tests/langchain/test_astra.py index b42e3ec4b..348b65750 100644 --- a/libs/e2e-tests/e2e_tests/langchain/test_astra.py +++ b/libs/e2e-tests/e2e_tests/langchain/test_astra.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import json import logging -from typing import List +from typing import TYPE_CHECKING import pytest from astrapy.api import APIRequestError @@ -9,7 +11,6 @@ from langchain_astradb import AstraDBVectorStore from langchain_core.documents import Document from langchain_core.runnables import RunnableConfig -from langchain_core.vectorstores import VectorStore from e2e_tests.conftest import ( is_astra, @@ -18,6 +19,9 @@ from e2e_tests.test_utils.astradb_vector_store_handler import AstraDBVectorStoreHandler from e2e_tests.test_utils.vector_store_handler import VectorStoreImplementation +if TYPE_CHECKING: + from langchain_core.vectorstores import VectorStore + MINIMUM_ACCEPTABLE_SCORE = 0.1 @@ -427,11 +431,11 @@ def __init__(self): def mock_embedding(text: str): return [len(text) / 2, len(text) / 5, len(text) / 10] - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: self.embedded_documents = texts return [self.mock_embedding(text) for text in texts] - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: self.embedded_query = text return self.mock_embedding(text) diff --git a/libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py b/libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py index 11b95594d..21872aecc 100644 --- a/libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py +++ b/libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import logging -from typing import List import pytest from langchain import callbacks @@ -417,11 +418,11 @@ def test_multimodal(vector_store, embedding, llm, request, record_property): class FakeEmbeddings(Embeddings): @override - def embed_documents(self, texts: List[str]) -> List[List[float]]: + def embed_documents(self, texts: list[str]) -> list[list[float]]: return [[0.0] * embedding_size] * len(texts) @override - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: return [0.0] * embedding_size enhanced_vector_store = request.getfixturevalue( diff --git a/libs/e2e-tests/e2e_tests/langchain/trulens.py b/libs/e2e-tests/e2e_tests/langchain/trulens.py index 6d8923b88..b096b9b19 100644 --- a/libs/e2e-tests/e2e_tests/langchain/trulens.py +++ b/libs/e2e-tests/e2e_tests/langchain/trulens.py @@ -1,11 +1,11 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import numpy as np from langchain.prompts import PromptTemplate -from langchain.schema.language_model import BaseLanguageModel from langchain.schema.output_parser import StrOutputParser -from langchain.schema.runnable import Runnable -from langchain.schema.vectorstore import VectorStore from langchain_core.runnables import RunnablePassthrough -from langchain_core.vectorstores import VectorStoreRetriever from trulens_eval import Feedback, Tru, TruChain from trulens_eval.app import App from trulens_eval.feedback.provider import Langchain @@ -16,6 +16,12 @@ format_docs, ) +if TYPE_CHECKING: + from langchain.schema.language_model import BaseLanguageModel + from langchain.schema.runnable import Runnable + from langchain.schema.vectorstore import VectorStore + from langchain_core.vectorstores import VectorStoreRetriever + def _feedback_functions(chain: Runnable, llm: BaseLanguageModel) -> list[Feedback]: provider = Langchain(chain=llm) diff --git a/libs/e2e-tests/e2e_tests/langchain_llamaindex/test_astra.py b/libs/e2e-tests/e2e_tests/langchain_llamaindex/test_astra.py index 23533c958..4a892b278 100644 --- a/libs/e2e-tests/e2e_tests/langchain_llamaindex/test_astra.py +++ b/libs/e2e-tests/e2e_tests/langchain_llamaindex/test_astra.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from uuid import uuid4 import langchain_core.documents diff --git a/libs/e2e-tests/e2e_tests/llama_index/test_astra.py b/libs/e2e-tests/e2e_tests/llama_index/test_astra.py index cdf34189e..c030accc3 100644 --- a/libs/e2e-tests/e2e_tests/llama_index/test_astra.py +++ b/libs/e2e-tests/e2e_tests/llama_index/test_astra.py @@ -1,11 +1,12 @@ +from __future__ import annotations + import logging -from typing import List +from typing import TYPE_CHECKING import pytest from httpx import ConnectError, HTTPStatusError from llama_index.core import ServiceContext, StorageContext, VectorStoreIndex from llama_index.core.embeddings import BaseEmbedding -from llama_index.core.llms import LLM from llama_index.core.node_parser import SimpleNodeParser from llama_index.core.schema import Document, NodeWithScore from llama_index.core.vector_stores import ( @@ -23,6 +24,9 @@ from e2e_tests.test_utils.astradb_vector_store_handler import AstraDBVectorStoreHandler from e2e_tests.test_utils.vector_store_handler import VectorStoreImplementation +if TYPE_CHECKING: + from llama_index.core.llms import LLM + class Environment: def __init__( @@ -216,13 +220,13 @@ def environment() -> Environment: class MockEmbeddings(BaseEmbedding): - def _get_query_embedding(self, query: str) -> List[float]: + def _get_query_embedding(self, query: str) -> list[float]: return self.mock_embedding(query) - async def _aget_query_embedding(self, query: str) -> List[float]: + async def _aget_query_embedding(self, query: str) -> list[float]: return self.mock_embedding(query) - def _get_text_embedding(self, text: str) -> List[float]: + def _get_text_embedding(self, text: str) -> list[float]: return self.mock_embedding(text) @staticmethod diff --git a/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py b/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py index de1dd41be..3cc973a23 100644 --- a/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py +++ b/libs/e2e-tests/e2e_tests/test_utils/astradb_vector_store_handler.py @@ -1,15 +1,16 @@ +from __future__ import annotations + import concurrent import logging import os import threading import time from dataclasses import dataclass, field -from typing import Callable, List +from typing import TYPE_CHECKING, Callable import cassio from langchain_astradb import AstraDBVectorStore as LangChainVectorStore from langchain_community.chat_message_histories import AstraDBChatMessageHistory -from langchain_core.chat_history import BaseChatMessageHistory try: from llama_index.vector_stores import AstraDBVectorStore @@ -36,6 +37,9 @@ VectorStoreTestContext, ) +if TYPE_CHECKING: + from langchain_core.chat_history import BaseChatMessageHistory + @dataclass class AstraRef: @@ -94,7 +98,7 @@ class EnhancedAstraDBLangChainVectorStore( EnhancedLangChainVectorStore, LangChainVectorStore ): def put_document( - self, doc_id: str, document: str, metadata: dict, vector: List[float] + self, doc_id: str, document: str, metadata: dict, vector: list[float] ) -> None: self.collection.insert_one( { @@ -105,7 +109,7 @@ def put_document( } ) - def search_documents(self, vector: List[float], limit: int) -> List[str]: + def search_documents(self, vector: list[float], limit: int) -> list[str]: return [ result["document"] for result in self.collection.vector_find( @@ -119,7 +123,7 @@ class EnhancedAstraDBLlamaIndexVectorStore( AstraDBVectorStore, EnhancedLlamaIndexVectorStore ): def put_document( - self, doc_id: str, document: str, metadata: dict, vector: List[float] + self, doc_id: str, document: str, metadata: dict, vector: list[float] ) -> None: self.client.insert_one( { @@ -130,7 +134,7 @@ def put_document( } ) - def search_documents(self, vector: List[float], limit: int) -> List[str]: + def search_documents(self, vector: list[float], limit: int) -> list[str]: return [ result["document"] for result in self.client.vector_find( @@ -141,7 +145,7 @@ def search_documents(self, vector: List[float], limit: int) -> List[str]: class AstraDBVectorStoreTestContext(VectorStoreTestContext): - def __init__(self, handler: "AstraDBVectorStoreHandler"): + def __init__(self, handler: AstraDBVectorStoreHandler): super().__init__() self.handler = handler self.test_id = "test_id" + random_string() diff --git a/libs/e2e-tests/e2e_tests/test_utils/cassandra_vector_store_handler.py b/libs/e2e-tests/e2e_tests/test_utils/cassandra_vector_store_handler.py index 2ac76795b..26f0506b5 100644 --- a/libs/e2e-tests/e2e_tests/test_utils/cassandra_vector_store_handler.py +++ b/libs/e2e-tests/e2e_tests/test_utils/cassandra_vector_store_handler.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import logging import os -from typing import List +from typing import TYPE_CHECKING import cassio from cassandra.auth import PlainTextAuthProvider @@ -10,7 +12,6 @@ CassandraChatMessageHistory, ) from langchain_community.vectorstores.cassandra import Cassandra -from langchain_core.chat_history import BaseChatMessageHistory from llama_index.core.schema import TextNode from llama_index.core.vector_stores.types import ( VectorStoreQuery, @@ -30,6 +31,9 @@ VectorStoreTestContext, ) +if TYPE_CHECKING: + from langchain_core.chat_history import BaseChatMessageHistory + class CassandraVectorStoreHandler(VectorStoreHandler): cassandra_container = None @@ -76,7 +80,7 @@ def before_test(self) -> VectorStoreTestContext: class EnhancedCassandraLangChainVectorStore(EnhancedLangChainVectorStore, Cassandra): def put_document( - self, doc_id: str, document: str, metadata: dict, vector: List[float] + self, doc_id: str, document: str, metadata: dict, vector: list[float] ) -> None: if isinstance(self.table, MetadataVectorCassandraTable): self.table.put( @@ -93,7 +97,7 @@ def put_document( metadata=metadata or {}, ) - def search_documents(self, vector: List[float], limit: int) -> List[str]: + def search_documents(self, vector: list[float], limit: int) -> list[str]: if isinstance(self.table, MetadataVectorCassandraTable): return [ result["body_blob"] @@ -109,13 +113,13 @@ class EnhancedCassandraLlamaIndexVectorStore( EnhancedLlamaIndexVectorStore, CassandraVectorStore ): def put_document( - self, doc_id: str, document: str, metadata: dict, vector: List[float] + self, doc_id: str, document: str, metadata: dict, vector: list[float] ) -> None: self.add( [TextNode(text=document, metadata=metadata, id_=doc_id, embedding=vector)] ) - def search_documents(self, vector: List[float], limit: int) -> List[str]: + def search_documents(self, vector: list[float], limit: int) -> list[str]: return self.query( VectorStoreQuery(query_embedding=vector, similarity_top_k=limit) ).ids diff --git a/libs/e2e-tests/e2e_tests/test_utils/vector_store_handler.py b/libs/e2e-tests/e2e_tests/test_utils/vector_store_handler.py index e06eb4c84..ad8901d30 100644 --- a/libs/e2e-tests/e2e_tests/test_utils/vector_store_handler.py +++ b/libs/e2e-tests/e2e_tests/test_utils/vector_store_handler.py @@ -1,12 +1,16 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from enum import Enum -from typing import List +from typing import TYPE_CHECKING -from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.vectorstores import VectorStore as LangChainVectorStore from e2e_tests.test_utils import skip_test_due_to_implementation_not_supported +if TYPE_CHECKING: + from langchain_core.chat_history import BaseChatMessageHistory + class VectorStoreImplementation(Enum): ASTRADB = "astradb" @@ -16,12 +20,12 @@ class VectorStoreImplementation(Enum): class EnhancedVectorStore(ABC): @abstractmethod def put_document( - self, doc_id: str, document: str, metadata: dict, vector: List[float] + self, doc_id: str, document: str, metadata: dict, vector: list[float] ) -> None: """Put a document""" @abstractmethod - def search_documents(self, vector: List[float], limit: int) -> List[str]: + def search_documents(self, vector: list[float], limit: int) -> list[str]: """Search documents""" @@ -53,7 +57,7 @@ class VectorStoreHandler(ABC): def __init__( self, implementation: VectorStoreImplementation, - supported_implementations: List[VectorStoreImplementation], + supported_implementations: list[VectorStoreImplementation], ): self.implementation = implementation self.supported_implementations = supported_implementations diff --git a/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py b/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py index bce94d7f3..c6b72d58c 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py +++ b/libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py @@ -1,13 +1,17 @@ +from __future__ import annotations + import dataclasses -from typing import Dict, Iterable, List, Optional +from typing import TYPE_CHECKING, Iterable import numpy as np -from numpy.typing import NDArray from ragstack_knowledge_store.math import cosine_similarity +if TYPE_CHECKING: + from numpy.typing import NDArray + -def _emb_to_ndarray(embedding: List[float]) -> NDArray[np.float32]: +def _emb_to_ndarray(embedding: list[float]) -> NDArray[np.float32]: emb_array = np.array(embedding, dtype=np.float32) if emb_array.ndim == 1: emb_array = np.expand_dims(emb_array, axis=0) @@ -63,14 +67,14 @@ class MmrHelper: score_threshold: float """Only documents with a score greater than or equal to this will be chosen.""" - selected_ids: List[str] + selected_ids: list[str] """List of selected IDs (in selection order).""" selected_embeddings: NDArray[np.float32] """(N, dim) ndarray with a row for each selected node.""" - candidate_id_to_index: Dict[str, int] + candidate_id_to_index: dict[str, int] """Dictionary of candidate IDs to indices in candidates and candidate_embeddings.""" - candidates: List[_Candidate] + candidates: list[_Candidate] """List containing information about candidates. Same order as rows in `candidate_embeddings`. @@ -79,12 +83,12 @@ class MmrHelper: """(N, dim) ndarray with a row for each candidate.""" best_score: float - best_id: Optional[str] + best_id: str | None def __init__( self, k: int, - query_embedding: List[float], + query_embedding: list[float], lambda_mult: float = 0.5, score_threshold: float = NEG_INF, ) -> None: @@ -154,7 +158,7 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]: return embedding - def pop_best(self) -> Optional[str]: + def pop_best(self) -> str | None: """Select and pop the best item being considered. Updates the consideration set based on it. @@ -191,7 +195,7 @@ def pop_best(self) -> Optional[str]: return selected_id - def add_candidates(self, candidates: Dict[str, List[float]]) -> None: + def add_candidates(self, candidates: dict[str, list[float]]) -> None: """Add candidates to the consideration set.""" # Determine the keys to actually include. # These are the candidates that aren't already selected diff --git a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py index 45ec75f32..838c39a36 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/concurrency.py +++ b/libs/knowledge-store/ragstack_knowledge_store/concurrency.py @@ -1,21 +1,23 @@ +from __future__ import annotations + import contextlib import logging import threading -from types import TracebackType from typing import ( + TYPE_CHECKING, Any, Callable, Literal, NamedTuple, - Optional, Protocol, Sequence, - Tuple, - Type, ) -from cassandra.cluster import ResponseFuture, Session -from cassandra.query import PreparedStatement +if TYPE_CHECKING: + from types import TracebackType + + from cassandra.cluster import ResponseFuture, Session + from cassandra.query import PreparedStatement logger = logging.getLogger(__name__) @@ -31,13 +33,13 @@ def __init__(self, session: Session) -> None: self._session = session self._completion = threading.Condition() self._pending = 0 - self._error: Optional[BaseException] = None + self._error: BaseException | None = None def _handle_result( self, result: Sequence[NamedTuple], future: ResponseFuture, - callback: Optional[Callable[[Sequence[NamedTuple]], Any]], + callback: Callable[[Sequence[NamedTuple]], Any] | None, ) -> None: if callback is not None: callback(result) @@ -63,8 +65,8 @@ def _handle_error(self, error: BaseException, future: ResponseFuture) -> None: def execute( self, query: PreparedStatement, - parameters: Optional[Tuple[Any, ...]] = None, - callback: Optional[_Callback] = None, + parameters: tuple[Any, ...] | None = None, + callback: _Callback | None = None, ) -> None: """Execute a query concurrently. @@ -97,14 +99,11 @@ def execute( }, ) - def __enter__(self) -> "ConcurrentQueries": - return super().__enter__() - def __exit__( self, - _exc_type: Optional[Type[BaseException]], - _exc_inst: Optional[BaseException], - _exc_traceback: Optional[TracebackType], + _exc_type: type[BaseException] | None, + _exc_inst: BaseException | None, + _exc_traceback: TracebackType | None, ) -> Literal[False]: with self._completion: while self._error is None and self._pending > 0: diff --git a/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py b/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py index 549814976..88cf0866c 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py +++ b/libs/knowledge-store/ragstack_knowledge_store/embedding_model.py @@ -1,22 +1,23 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import List class EmbeddingModel(ABC): """Embedding model.""" @abstractmethod - def embed_texts(self, texts: List[str]) -> List[List[float]]: + def embed_texts(self, texts: list[str]) -> list[list[float]]: """Embed texts.""" @abstractmethod - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """Embed query text.""" @abstractmethod - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + async def aembed_texts(self, texts: list[str]) -> list[list[float]]: """Embed texts.""" @abstractmethod - async def aembed_query(self, text: str) -> List[float]: + async def aembed_query(self, text: str) -> list[float]: """Embed query text.""" diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index b29accbbf..c8ed3c540 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,11 +7,11 @@ from dataclasses import asdict, dataclass, field, is_dataclass from enum import Enum from typing import ( + TYPE_CHECKING, Any, Dict, Iterable, List, - Optional, Sequence, Set, Tuple, @@ -23,9 +25,11 @@ from ._mmr_helper import MmrHelper from .concurrency import ConcurrentQueries from .content import Kind -from .embedding_model import EmbeddingModel from .links import Link +if TYPE_CHECKING: + from .embedding_model import EmbeddingModel + logger = logging.getLogger(__name__) CONTENT_ID = "content_id" @@ -43,11 +47,11 @@ class Node: text: str """Text contained by the node.""" - id: Optional[str] = None + id: str | None = None """Unique ID for the node. Will be generated by the GraphStore if not set.""" - metadata: Dict[str, Any] = field(default_factory=dict) + metadata: dict[str, Any] = field(default_factory=dict) """Metadata for the node.""" - links: Set[Link] = field(default_factory=set) + links: set[Link] = field(default_factory=set) """Links for the node.""" @@ -79,14 +83,14 @@ def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) raise ValueError(f"Unexpected metadata indexing mode {p_mode}") -def _serialize_metadata(md: Dict[str, Any]) -> str: +def _serialize_metadata(md: dict[str, Any]) -> str: if isinstance(md.get("links"), Set): md = md.copy() md["links"] = list(md["links"]) return json.dumps(md) -def _serialize_links(links: Set[Link]) -> str: +def _serialize_links(links: set[Link]) -> str: class SetAndLinkEncoder(json.JSONEncoder): def default(self, obj: Any) -> Any: if is_dataclass(obj) and not isinstance(obj, type): @@ -104,13 +108,13 @@ def default(self, obj: Any) -> Any: return json.dumps(list(links), cls=SetAndLinkEncoder) -def _deserialize_metadata(json_blob: Optional[str]) -> Dict[str, Any]: +def _deserialize_metadata(json_blob: str | None) -> dict[str, Any]: # We don't need to convert the links list back to a set -- it will be # converted when accessed, if needed. return cast(Dict[str, Any], json.loads(json_blob or "")) -def _deserialize_links(json_blob: Optional[str]) -> Set[Link]: +def _deserialize_links(json_blob: str | None) -> set[Link]: return { Link(kind=link["kind"], direction=link["direction"], tag=link["tag"]) for link in cast(List[Dict[str, Any]], json.loads(json_blob or "")) @@ -134,8 +138,8 @@ def _row_to_node(row: Any) -> Node: @dataclass class _Edge: target_content_id: str - target_text_embedding: List[float] - target_link_to_tags: Set[Tuple[str, str]] + target_text_embedding: list[float] + target_link_to_tags: set[tuple[str, str]] class GraphStore: @@ -156,8 +160,8 @@ def __init__( *, node_table: str = "graph_nodes", targets_table: str = "", - session: Optional[Session] = None, - keyspace: Optional[str] = None, + session: Session | None = None, + keyspace: str | None = None, setup_mode: SetupMode = SetupMode.SYNC, metadata_indexing: MetadataIndexingType = "all", ): @@ -180,7 +184,7 @@ def __init__( self._node_table = node_table self._session = session self._keyspace = keyspace - self._prepared_query_cache: Dict[str, PreparedStatement] = {} + self._prepared_query_cache: dict[str, PreparedStatement] = {} self._metadata_indexing_policy = self._normalize_metadata_indexing_policy( metadata_indexing=metadata_indexing, @@ -280,10 +284,10 @@ def add_nodes( nodes: Iterable[Node], ) -> Iterable[str]: """Add nodes to the graph store.""" - node_ids: List[str] = [] - texts: List[str] = [] - metadatas: List[Dict[str, Any]] = [] - nodes_links: List[Set[Link]] = [] + node_ids: list[str] = [] + texts: list[str] = [] + metadatas: list[dict[str, Any]] = [] + nodes_links: list[set[Link]] = [] for node in nodes: if not node.id: node_ids.append(secrets.token_hex(8)) @@ -336,8 +340,8 @@ def add_nodes( def _nodes_with_ids( self, ids: Iterable[str], - ) -> List[Node]: - results: Dict[str, Optional[Node]] = {} + ) -> list[Node]: + results: dict[str, Node | None] = {} with self._concurrent_queries() as cq: def node_callback(rows: Iterable[Any]) -> None: @@ -373,7 +377,7 @@ def mmr_traversal_search( adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata_filter: Dict[str, Any] = {}, # noqa: B006 + metadata_filter: dict[str, Any] = {}, # noqa: B006 ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -410,7 +414,7 @@ def mmr_traversal_search( ) # For each unvisited node, stores the outgoing tags. - outgoing_tags: Dict[str, Set[Tuple[str, str]]] = {} + outgoing_tags: dict[str, set[tuple[str, str]]] = {} # Fetch the initial candidates and add them to the helper and # outgoing_tags. @@ -450,7 +454,7 @@ def fetch_initial_candidates() -> None: # Select the best item, K times. depths = {candidate_id: 0 for candidate_id in helper.candidate_ids()} - visited_tags: Set[Tuple[str, str]] = set() + visited_tags: set[tuple[str, str]] = set() for _ in range(k): selected_id = helper.pop_best() @@ -518,7 +522,7 @@ def traversal_search( *, k: int = 4, depth: int = 1, - metadata_filter: Dict[str, Any] = {}, # noqa: B006 + metadata_filter: dict[str, Any] = {}, # noqa: B006 ) -> Iterable[Node]: """Retrieve documents from this knowledge store. @@ -562,11 +566,11 @@ def traversal_search( with self._concurrent_queries() as cq: # Map from visited ID to depth - visited_ids: Dict[str, int] = {} + visited_ids: dict[str, int] = {} # Map from visited tag `(kind, tag)` to depth. Allows skipping queries # for tags that we've already traversed. - visited_tags: Dict[Tuple[str, str], int] = {} + visited_tags: dict[tuple[str, str], int] = {} def visit_nodes(d: int, nodes: Sequence[Any]) -> None: nonlocal visited_ids @@ -646,9 +650,9 @@ def visit_targets(d: int, targets: Sequence[Any]) -> None: def similarity_search( self, - embedding: List[float], + embedding: list[float], k: int = 4, - metadata_filter: Dict[str, Any] = {}, # noqa: B006 + metadata_filter: dict[str, Any] = {}, # noqa: B006 ) -> Iterable[Node]: """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 query, params = self._get_search_cql_and_params( @@ -660,7 +664,7 @@ def similarity_search( def metadata_search( self, - metadata: Dict[str, Any] = {}, # noqa: B006 + metadata: dict[str, Any] = {}, # noqa: B006 n: int = 5, ) -> Iterable[Node]: """Retrieve nodes based on their metadata.""" @@ -676,7 +680,7 @@ def get_node(self, content_id: str) -> Node: def _get_outgoing_tags( self, source_ids: Iterable[str], - ) -> Set[Tuple[str, str]]: + ) -> set[tuple[str, str]]: """Return the set of outgoing tags for the given source ID(s). Args: @@ -698,11 +702,11 @@ def add_sources(rows: Iterable[Any]) -> None: def _get_adjacent( self, - tags: Set[Tuple[str, str]], + tags: set[tuple[str, str]], adjacent_query: PreparedStatement, - query_embedding: List[float], - k_per_tag: Optional[int] = None, - metadata_filter: Dict[str, Any] = {}, # noqa: B006 + query_embedding: list[float], + k_per_tag: int | None = None, + metadata_filter: dict[str, Any] = {}, # noqa: B006 ) -> Iterable[_Edge]: """Return the target nodes with incoming links from any of the given tags. @@ -716,7 +720,7 @@ def _get_adjacent( Returns: List of adjacent edges. """ - targets: Dict[str, _Edge] = {} + targets: dict[str, _Edge] = {} def add_targets(rows: Iterable[Any]) -> None: # TODO: Figure out how to use the "kind" on the edge. @@ -752,10 +756,10 @@ def add_targets(rows: Iterable[Any]) -> None: @staticmethod def _normalize_metadata_indexing_policy( - metadata_indexing: Union[Tuple[str, Iterable[str]], str], + metadata_indexing: tuple[str, Iterable[str]] | str, ) -> MetadataIndexingPolicy: mode: MetadataIndexingMode - fields: Set[str] + fields: set[str] # metadata indexing policy normalization: if isinstance(metadata_indexing, str): if metadata_indexing.lower() == "all": @@ -812,10 +816,10 @@ def _coerce_string(value: Any) -> str: def _extract_where_clause_cql( self, - metadata_keys: List[str] = [], # noqa: B006 + metadata_keys: list[str] = [], # noqa: B006 has_link_from_tags: bool = False, ) -> str: - wc_blocks: List[str] = [] + wc_blocks: list[str] = [] if has_link_from_tags: wc_blocks.append("link_from_tags CONTAINS (?, ?)") @@ -835,10 +839,10 @@ def _extract_where_clause_cql( def _extract_where_clause_params( self, - metadata: Dict[str, Any], - link_from_tags: Optional[Tuple[str, str]] = None, - ) -> List[Any]: - params: List[Any] = [] + metadata: dict[str, Any], + link_from_tags: tuple[str, str] | None = None, + ) -> list[Any]: + params: list[Any] = [] if link_from_tags is not None: params.append(link_from_tags[0]) @@ -857,8 +861,8 @@ def _extract_where_clause_params( def _get_search_cql( self, has_limit: bool = False, - columns: Optional[str] = CONTENT_COLUMNS, - metadata_keys: List[str] = [], # noqa: B006 + columns: str | None = CONTENT_COLUMNS, + metadata_keys: list[str] = [], # noqa: B006 has_embedding: bool = False, has_link_from_tags: bool = False, ) -> PreparedStatement: @@ -887,11 +891,11 @@ def _get_search_cql( def _get_search_params( self, - limit: Optional[int] = None, - metadata: Dict[str, Any] = {}, # noqa: B006 - embedding: Optional[List[float]] = None, - link_from_tags: Optional[Tuple[str, str]] = None, - ) -> Tuple[PreparedStatement, Tuple[Any, ...]]: + limit: int | None = None, + metadata: dict[str, Any] = {}, # noqa: B006 + embedding: list[float] | None = None, + link_from_tags: tuple[str, str] | None = None, + ) -> tuple[PreparedStatement, tuple[Any, ...]]: where_params = self._extract_where_clause_params( metadata=metadata, link_from_tags=link_from_tags ) @@ -903,12 +907,12 @@ def _get_search_params( def _get_search_cql_and_params( self, - limit: Optional[int] = None, - columns: Optional[str] = CONTENT_COLUMNS, - metadata: Dict[str, Any] = {}, # noqa: B006 - embedding: Optional[List[float]] = None, - link_from_tags: Optional[Tuple[str, str]] = None, - ) -> Tuple[PreparedStatement, Tuple[Any, ...]]: + limit: int | None = None, + columns: str | None = CONTENT_COLUMNS, + metadata: dict[str, Any] = {}, # noqa: B006 + embedding: list[float] | None = None, + link_from_tags: tuple[str, str] | None = None, + ) -> tuple[PreparedStatement, tuple[Any, ...]]: query = self._get_search_cql( has_limit=limit is not None, columns=columns, diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index b69a51c39..ae8a3efe3 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import math import secrets -from typing import Callable, Iterable, Iterator, List +from typing import Callable, Iterable, Iterator import numpy as np import pytest @@ -17,23 +19,23 @@ vector_size = 52 -def text_to_embedding(text: str) -> List[float]: +def text_to_embedding(text: str) -> list[float]: """Embeds text using a simple ascii conversion algorithm""" embedding = np.zeros(vector_size) for i, char in enumerate(text): if i >= vector_size - 2: break embedding[i + 2] = ord(char) / 255 # Normalize ASCII value - vector: List[float] = embedding.tolist() + vector: list[float] = embedding.tolist() return vector -def angle_to_embedding(angle: float) -> List[float]: +def angle_to_embedding(angle: float) -> list[float]: """Embeds angles onto a circle""" embedding = np.zeros(vector_size) embedding[0] = math.cos(angle * math.pi) embedding[1] = math.sin(angle * math.pi) - vector: List[float] = embedding.tolist() + vector: list[float] = embedding.tolist() return vector @@ -43,13 +45,13 @@ class SimpleEmbeddingModel(EmbeddingModel): a circle, and other text into a simple 50-dimension vector. """ - def embed_texts(self, texts: List[str]) -> List[List[float]]: + def embed_texts(self, texts: list[str]) -> list[list[float]]: """ Make a list of texts into a list of embedding vectors. """ return [self.embed_query(text) for text in texts] - def embed_query(self, text: str) -> List[float]: + def embed_query(self, text: str) -> list[float]: """ Convert input text to a 'vector' (list of floats). If the text is a number, use it as the angle for the @@ -63,13 +65,13 @@ def embed_query(self, text: str) -> List[float]: # Assume: just test string return text_to_embedding(text) - async def aembed_texts(self, texts: List[str]) -> List[List[float]]: + async def aembed_texts(self, texts: list[str]) -> list[list[float]]: """ Make a list of texts into a list of embedding vectors. """ return self.embed_texts(texts=texts) - async def aembed_query(self, text: str) -> List[float]: + async def aembed_query(self, text: str) -> list[float]: """ Convert input text to a 'vector' (list of floats). If the text is a number, use it as the angle for the @@ -115,7 +117,7 @@ def _make_graph_store( session.shutdown() -def _result_ids(nodes: Iterable[Node]) -> List[str]: +def _result_ids(nodes: Iterable[Node]) -> list[str]: return [n.id for n in nodes if n.id is not None] diff --git a/libs/knowledge-store/tests/unit_tests/test_graph_store.py b/libs/knowledge-store/tests/unit_tests/test_graph_store.py index ef86d1b5c..9c102e989 100644 --- a/libs/knowledge-store/tests/unit_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/unit_tests/test_graph_store.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, Set +from __future__ import annotations + +from typing import Any from ragstack_knowledge_store.graph_store import ( _deserialize_links, @@ -10,7 +12,7 @@ def test_metadata_serialization() -> None: - def assert_roundtrip(metadata: Dict[str, Any]) -> None: + def assert_roundtrip(metadata: dict[str, Any]) -> None: serialized = _serialize_metadata(metadata) deserialized = _deserialize_metadata(serialized) assert metadata == deserialized @@ -20,7 +22,7 @@ def assert_roundtrip(metadata: Dict[str, Any]) -> None: def test_links_serialization() -> None: - def assert_roundtrip(links: Set[Link]) -> None: + def assert_roundtrip(links: set[Link]) -> None: serialized = _serialize_links(links) deserialized = _deserialize_links(serialized) assert links == deserialized diff --git a/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py b/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py index d8a8c3320..3171a133e 100644 --- a/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py +++ b/libs/knowledge-store/tests/unit_tests/test_mmr_helper.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import math -from typing import List from ragstack_knowledge_store._mmr_helper import MmrHelper @@ -29,7 +30,7 @@ def test_mmr_helper_pop_best() -> None: assert helper.pop_best() is None -def angular_embedding(angle: float) -> List[float]: +def angular_embedding(angle: float) -> list[float]: return [math.cos(angle * math.pi), math.sin(angle * math.pi)] diff --git a/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py b/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py index f2b3bc472..df8998287 100644 --- a/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py +++ b/libs/langchain/ragstack_langchain/colbert/colbert_retriever.py @@ -1,16 +1,18 @@ +from __future__ import annotations + from typing import TYPE_CHECKING, List, Optional, Tuple -from langchain_core.callbacks.manager import ( - AsyncCallbackManagerForRetrieverRun, - CallbackManagerForRetrieverRun, -) from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever -from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever from typing_extensions import override if TYPE_CHECKING: + from langchain_core.callbacks.manager import ( + AsyncCallbackManagerForRetrieverRun, + CallbackManagerForRetrieverRun, + ) from ragstack_colbert import Chunk + from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever class ColbertRetriever(BaseRetriever): diff --git a/libs/langchain/ragstack_langchain/colbert/embedding.py b/libs/langchain/ragstack_langchain/colbert/embedding.py index c2a932872..776465058 100644 --- a/libs/langchain/ragstack_langchain/colbert/embedding.py +++ b/libs/langchain/ragstack_langchain/colbert/embedding.py @@ -1,10 +1,14 @@ -from typing import List, Optional +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional from langchain_core.embeddings import Embeddings from ragstack_colbert import DEFAULT_COLBERT_MODEL, ColbertEmbeddingModel -from ragstack_colbert.base_embedding_model import BaseEmbeddingModel from typing_extensions import Self, override +if TYPE_CHECKING: + from ragstack_colbert.base_embedding_model import BaseEmbeddingModel + class TokensEmbeddings(Embeddings): """Adapter for token-based embedding models and the LangChain Embeddings.""" diff --git a/libs/langchain/tests/integration_tests/test_colbert.py b/libs/langchain/tests/integration_tests/test_colbert.py index 38c244934..2e1109bd6 100644 --- a/libs/langchain/tests/integration_tests/test_colbert.py +++ b/libs/langchain/tests/integration_tests/test_colbert.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging -from typing import List +from typing import TYPE_CHECKING, List import pytest -from cassandra.cluster import Session from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_core.documents import Document from ragstack_colbert import CassandraDatabase @@ -11,6 +12,9 @@ from ragstack_tests_utils import TestData from transformers import BertTokenizer +if TYPE_CHECKING: + from cassandra.cluster import Session + logging.getLogger("cassandra").setLevel(logging.ERROR) test_data = {} diff --git a/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py b/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py index 5b8e2ef7c..09238a849 100644 --- a/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py +++ b/libs/llamaindex/ragstack_llamaindex/colbert/colbert_retriever.py @@ -1,13 +1,15 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple +from __future__ import annotations + +from typing import TYPE_CHECKING, Any -from llama_index.core.callbacks.base import CallbackManager from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.core.retrievers import BaseRetriever from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode -from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever if TYPE_CHECKING: + from llama_index.core.callbacks.base import CallbackManager from ragstack_colbert import Chunk + from ragstack_colbert.base_retriever import BaseRetriever as ColbertBaseRetriever class ColbertRetriever(BaseRetriever): @@ -20,14 +22,14 @@ class ColbertRetriever(BaseRetriever): _retriever: ColbertBaseRetriever _k: int - _query_maxlen: Optional[int] + _query_maxlen: int | None def __init__( self, retriever: ColbertBaseRetriever, similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K, - callback_manager: Optional[CallbackManager] = None, - object_map: Optional[Dict[str, Any]] = None, + callback_manager: CallbackManager | None = None, + object_map: dict[str, Any] | None = None, verbose: bool = False, query_maxlen: int = -1, ) -> None: @@ -44,8 +46,8 @@ def __init__( def _retrieve( self, query_bundle: QueryBundle, - ) -> List[NodeWithScore]: - chunk_scores: List[Tuple[Chunk, float]] = self._retriever.text_search( + ) -> list[NodeWithScore]: + chunk_scores: list[tuple[Chunk, float]] = self._retriever.text_search( query_text=query_bundle.query_str, k=self._k, query_maxlen=self._query_maxlen, diff --git a/libs/llamaindex/tests/integration_tests/test_colbert.py b/libs/llamaindex/tests/integration_tests/test_colbert.py index c23df3e1c..1cf06876d 100644 --- a/libs/llamaindex/tests/integration_tests/test_colbert.py +++ b/libs/llamaindex/tests/integration_tests/test_colbert.py @@ -1,8 +1,9 @@ +from __future__ import annotations + import logging -from typing import Dict, List, Tuple +from typing import TYPE_CHECKING import pytest -from cassandra.cluster import Session from llama_index.core import Settings, get_response_synthesizer from llama_index.core.ingestion import IngestionPipeline from llama_index.core.llms import MockLLM @@ -18,10 +19,13 @@ from ragstack_llamaindex.colbert import ColbertRetriever from ragstack_tests_utils import TestData +if TYPE_CHECKING: + from cassandra.cluster import Session + logging.getLogger("cassandra").setLevel(logging.ERROR) -def validate_retrieval(results: List[NodeWithScore], key_value: str) -> bool: +def validate_retrieval(results: list[NodeWithScore], key_value: str) -> bool: passed = False for result in results: if key_value in result.text: @@ -65,7 +69,7 @@ def test_sync(session: Session) -> None: nodes = pipeline.run(documents=docs) - docs2: Dict[str, Tuple[List[str], List[Metadata]]] = {} + docs2: dict[str, tuple[list[str], list[Metadata]]] = {} for node in nodes: doc_id = node.metadata["name"] diff --git a/libs/ragulate/colbert_chunk_size_and_k.py b/libs/ragulate/colbert_chunk_size_and_k.py index a58660415..74aa6f65e 100644 --- a/libs/ragulate/colbert_chunk_size_and_k.py +++ b/libs/ragulate/colbert_chunk_size_and_k.py @@ -1,9 +1,11 @@ # ruff: noqa: D103, INP001, T201 +from __future__ import annotations + import logging import os import time from pathlib import Path -from typing import Any, List +from typing import Any from langchain.text_splitter import RecursiveCharacterTextSplitter from langchain_community.document_loaders import UnstructuredFileLoader @@ -114,7 +116,7 @@ async def ingest(file_path: str, chunk_size: int, **_: Any) -> None: await colbert_vector_store.adelete_chunks(doc_ids=[doc_id]) - chunks: List[Chunk] = [] + chunks: list[Chunk] = [] for i, doc in enumerate(chunked_docs): chunks.append( Chunk( diff --git a/libs/ragulate/ragstack_ragulate/analysis.py b/libs/ragulate/ragstack_ragulate/analysis.py index c62f90450..e2ea76cf0 100644 --- a/libs/ragulate/ragstack_ragulate/analysis.py +++ b/libs/ragulate/ragstack_ragulate/analysis.py @@ -1,4 +1,6 @@ -from typing import Any, Dict, List, Tuple +from __future__ import annotations + +from typing import Any import matplotlib.pyplot as plt import numpy as np @@ -13,11 +15,11 @@ class Analysis: """Analysis class.""" - def get_all_data(self, recipes: List[str]) -> Tuple[pd.DataFrame, List[str]]: + def get_all_data(self, recipes: list[str]) -> tuple[pd.DataFrame, list[str]]: """Get all data from the recipes.""" df_all = pd.DataFrame() - all_metrics: List[str] = [] + all_metrics: list[str] = [] for recipe in recipes: tru = get_tru(recipe_name=recipe) @@ -55,10 +57,10 @@ def get_all_data(self, recipes: List[str]) -> Tuple[pd.DataFrame, List[str]]: return reset_df, list(set(all_metrics)) def calculate_statistics( - self, df: pd.DataFrame, metrics: List[str] - ) -> Dict[str, Any]: + self, df: pd.DataFrame, metrics: list[str] + ) -> dict[str, Any]: """Calculate statistics.""" - stats: Dict[str, Any] = {} + stats: dict[str, Any] = {} for recipe in df["recipe"].unique(): stats[recipe] = {} for metric in metrics: @@ -77,7 +79,7 @@ def calculate_statistics( } return stats - def output_box_plots_by_dataset(self, df: pd.DataFrame, metrics: List[str]) -> None: + def output_box_plots_by_dataset(self, df: pd.DataFrame, metrics: list[str]) -> None: """Output box plots by dataset.""" stats = self.calculate_statistics(df, metrics) recipes = sorted(df["recipe"].unique(), key=lambda x: x.lower()) @@ -160,7 +162,7 @@ def output_box_plots_by_dataset(self, df: pd.DataFrame, metrics: List[str]) -> N write_image(fig, f"./{dataset}_box_plot.png") def output_histograms_by_dataset( - self, df: pd.DataFrame, metrics: List[str] + self, df: pd.DataFrame, metrics: list[str] ) -> None: """Output histograms by dataset.""" # Append "latency" to the metrics list @@ -186,7 +188,7 @@ def output_histograms_by_dataset( sns.set_theme(style="darkgrid") # Custom function to set bin ranges and filter invalid values - def custom_hist(data: Dict[str, Any], **kws: Any) -> None: + def custom_hist(data: dict[str, Any], **kws: Any) -> None: metric = data["metric"].iloc[0] data = data[ np.isfinite(data["value"]) @@ -253,7 +255,7 @@ def custom_hist(data: Dict[str, Any], **kws: Any) -> None: # Close the plot to avoid displaying it plt.close() - def compare(self, recipes: List[str], output: str = "box-plots") -> None: + def compare(self, recipes: list[str], output: str = "box-plots") -> None: """Compare results from 2 (or more) recipes.""" df, metrics = self.get_all_data(recipes=recipes) if output == "box-plots": diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/compare.py b/libs/ragulate/ragstack_ragulate/cli_commands/compare.py index 7d1dcf9c3..15fb8175d 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/compare.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/compare.py @@ -1,10 +1,14 @@ -from argparse import ArgumentParser, _SubParsersAction -from typing import Any, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from ragstack_ragulate.analysis import Analysis from .utils import remove_sqlite_extension +if TYPE_CHECKING: + from argparse import ArgumentParser, _SubParsersAction + def setup_compare(subparsers: _SubParsersAction[ArgumentParser]) -> None: """Setup the compare command.""" @@ -30,7 +34,7 @@ def setup_compare(subparsers: _SubParsersAction[ArgumentParser]) -> None: def call_compare( - recipe: List[str], + recipe: list[str], output: str = "box-plots", **_: Any, ) -> None: diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py b/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py index e83ea0771..4103180e2 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/ingest.py @@ -1,10 +1,14 @@ -from argparse import ArgumentParser, _SubParsersAction -from typing import Any, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from ragstack_ragulate.datasets import find_dataset from ragstack_ragulate.pipelines import IngestPipeline from ragstack_ragulate.utils import convert_vars_to_ingredients +if TYPE_CHECKING: + from argparse import ArgumentParser, _SubParsersAction + def setup_ingest(subparsers: _SubParsersAction[ArgumentParser]) -> None: """Setup the ingest command.""" @@ -56,9 +60,9 @@ def call_ingest( name: str, script_path: str, method_name: str, - var_name: List[str], - var_value: List[str], - dataset: List[str], + var_name: list[str], + var_value: list[str], + dataset: list[str], **_: Any, ) -> None: """Run an ingest pipeline.""" diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/query.py b/libs/ragulate/ragstack_ragulate/cli_commands/query.py index 4965554e9..2edca9b63 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/query.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/query.py @@ -1,10 +1,14 @@ -from argparse import ArgumentParser, _SubParsersAction -from typing import Any, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from ragstack_ragulate.datasets import find_dataset from ragstack_ragulate.pipelines import QueryPipeline from ragstack_ragulate.utils import convert_vars_to_ingredients +if TYPE_CHECKING: + from argparse import ArgumentParser, _SubParsersAction + def setup_query(subparsers: _SubParsersAction[ArgumentParser]) -> None: """Setup the query command.""" @@ -102,10 +106,10 @@ def call_query( name: str, script: str, method: str, - var_name: List[str], - var_value: List[str], - dataset: List[str], - subset: List[str], + var_name: list[str], + var_value: list[str], + dataset: list[str], + subset: list[str], sample: float, seed: int, restart: bool, diff --git a/libs/ragulate/ragstack_ragulate/cli_commands/run.py b/libs/ragulate/ragstack_ragulate/cli_commands/run.py index 2dc52ccb4..cb0fd0128 100644 --- a/libs/ragulate/ragstack_ragulate/cli_commands/run.py +++ b/libs/ragulate/ragstack_ragulate/cli_commands/run.py @@ -1,11 +1,15 @@ -from argparse import ArgumentParser, _SubParsersAction -from typing import Any, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any from ragstack_ragulate.analysis import Analysis from ragstack_ragulate.config import ConfigParser from ragstack_ragulate.logging_config import logger from ragstack_ragulate.pipelines import IngestPipeline, QueryPipeline +if TYPE_CHECKING: + from argparse import ArgumentParser, _SubParsersAction + def setup_run(subparsers: _SubParsersAction[ArgumentParser]) -> None: """Setup the run command.""" @@ -28,8 +32,8 @@ def call_run(config_file: str, **_: Any) -> None: config_parser = ConfigParser.from_file(file_path=config_file) config = config_parser.get_config() - ingest_pipelines: List[IngestPipeline] = [] - query_pipelines: List[QueryPipeline] = [] + ingest_pipelines: list[IngestPipeline] = [] + query_pipelines: list[QueryPipeline] = [] for dataset in config.datasets.values(): dataset.download_dataset() diff --git a/libs/ragulate/ragstack_ragulate/config/base_config_schema.py b/libs/ragulate/ragstack_ragulate/config/base_config_schema.py index 681e9bd2b..bb9f03cf1 100644 --- a/libs/ragulate/ragstack_ragulate/config/base_config_schema.py +++ b/libs/ragulate/ragstack_ragulate/config/base_config_schema.py @@ -1,7 +1,10 @@ +from __future__ import annotations + from abc import ABC, abstractmethod -from typing import Any, Dict +from typing import TYPE_CHECKING, Any -from .objects import Config +if TYPE_CHECKING: + from .objects import Config class BaseConfigSchema(ABC): @@ -12,9 +15,9 @@ def version(self) -> float: """Returns the config file version.""" @abstractmethod - def schema(self) -> Dict[str, Any]: + def schema(self) -> dict[str, Any]: """Returns the config file schema.""" @abstractmethod - def parse_document(self, document: Dict[str, Any]) -> Config: + def parse_document(self, document: dict[str, Any]) -> Config: """Parses a validated config file and returns a Config object.""" diff --git a/libs/ragulate/ragstack_ragulate/config/config_parser.py b/libs/ragulate/ragstack_ragulate/config/config_parser.py index 5f5826779..527fac8fe 100644 --- a/libs/ragulate/ragstack_ragulate/config/config_parser.py +++ b/libs/ragulate/ragstack_ragulate/config/config_parser.py @@ -1,12 +1,16 @@ -from typing import Any, Dict +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import yaml from cerberus import Validator -from .base_config_schema import BaseConfigSchema from .config_schema_0_1 import _VERSION_0_1, ConfigSchema0Dot1 from .objects import Config +if TYPE_CHECKING: + from .base_config_schema import BaseConfigSchema + class ConfigParser: """Config parser.""" @@ -16,7 +20,7 @@ class ConfigParser: _errors: Any _document: Any - def __init__(self, config_schema: BaseConfigSchema, config: Dict[str, Any]): + def __init__(self, config_schema: BaseConfigSchema, config: dict[str, Any]): self._config_schema = config_schema validator = Validator(config_schema.schema()) self.is_valid = validator.validate(config) @@ -30,7 +34,7 @@ def get_config(self) -> Config: return self._config_schema.parse_document(self._document) @classmethod - def from_file(cls, file_path: str) -> "ConfigParser": + def from_file(cls, file_path: str) -> ConfigParser: """Create a ConfigParser from a file.""" with open(file_path) as file: config = yaml.safe_load(file) diff --git a/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py b/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py index 36f75cc65..d07edd525 100644 --- a/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py +++ b/libs/ragulate/ragstack_ragulate/config/config_schema_0_1.py @@ -1,4 +1,6 @@ -from typing import Any, Dict +from __future__ import annotations + +from typing import Any from typing_extensions import override @@ -19,7 +21,7 @@ def version(self) -> float: return _VERSION_0_1 @override - def schema(self) -> Dict[str, Any]: + def schema(self) -> dict[str, Any]: step_list = { "type": "list", "schema": { @@ -142,10 +144,10 @@ def schema(self) -> Dict[str, Any]: } @override - def parse_document(self, document: Dict[str, Any]) -> Config: - ingest_steps: Dict[str, Step] = {} - query_steps: Dict[str, Step] = {} - cleanup_steps: Dict[str, Step] = {} + def parse_document(self, document: dict[str, Any]) -> Config: + ingest_steps: dict[str, Step] = {} + query_steps: dict[str, Step] = {} + cleanup_steps: dict[str, Step] = {} step_map = { "ingest": ingest_steps, @@ -168,12 +170,12 @@ def parse_document(self, document: Dict[str, Any]) -> Config: name=doc_name, script=doc_script, method=doc_method ) - recipes: Dict[str, Recipe] = {} + recipes: dict[str, Recipe] = {} doc_recipes = document.get("recipes", {}) for doc_recipe in doc_recipes: doc_ingredients = doc_recipe.get("ingredients", {}) - ingredients: Dict[str, Any] = {} + ingredients: dict[str, Any] = {} for doc_ingredient in doc_ingredients: for key, value in doc_ingredient.items(): @@ -195,7 +197,7 @@ def parse_document(self, document: Dict[str, Any]) -> Config: else: recipe_name = doc_name - recipe_steps: Dict[str, Step] = {} + recipe_steps: dict[str, Step] = {} for step_kind in step_map: doc_recipe_step = doc_recipe.get(step_kind, None) @@ -224,7 +226,7 @@ def parse_document(self, document: Dict[str, Any]) -> Config: ingredients=ingredients, ) - datasets: Dict[str, BaseDataset] = {} + datasets: dict[str, BaseDataset] = {} for doc_dataset in document.get("datasets", []): if isinstance(doc_dataset, str): diff --git a/libs/ragulate/ragstack_ragulate/config/objects.py b/libs/ragulate/ragstack_ragulate/config/objects.py index 8c44d7e39..dce4d660d 100644 --- a/libs/ragulate/ragstack_ragulate/config/objects.py +++ b/libs/ragulate/ragstack_ragulate/config/objects.py @@ -1,8 +1,10 @@ -from typing import Any, Dict +from __future__ import annotations + +from typing import Any from pydantic import BaseModel -from ragstack_ragulate.datasets import BaseDataset +from ragstack_ragulate.datasets import BaseDataset # noqa: TCH001 class Step(BaseModel): @@ -20,7 +22,7 @@ class Recipe(BaseModel): ingest: Step | None query: Step cleanup: Step | None - ingredients: Dict[str, Any] + ingredients: dict[str, Any] class Config(BaseModel): @@ -31,5 +33,5 @@ class Config: arbitrary_types_allowed = True - recipes: Dict[str, Recipe] = {} - datasets: Dict[str, BaseDataset] = {} + recipes: dict[str, Recipe] = {} + datasets: dict[str, BaseDataset] = {} diff --git a/libs/ragulate/ragstack_ragulate/config/utils.py b/libs/ragulate/ragstack_ragulate/config/utils.py index e8f2c4923..e71ee8d03 100644 --- a/libs/ragulate/ragstack_ragulate/config/utils.py +++ b/libs/ragulate/ragstack_ragulate/config/utils.py @@ -1,7 +1,9 @@ -from typing import Any, Dict +from __future__ import annotations +from typing import Any -def dict_to_string(d: Dict[str, Any]) -> str: + +def dict_to_string(d: dict[str, Any]) -> str: """Convert dictionary to string.""" parts = [] diff --git a/libs/ragulate/ragstack_ragulate/dashboard.py b/libs/ragulate/ragstack_ragulate/dashboard.py index 3511e7828..4eedb75be 100644 --- a/libs/ragulate/ragstack_ragulate/dashboard.py +++ b/libs/ragulate/ragstack_ragulate/dashboard.py @@ -1,9 +1,9 @@ -from typing import Optional +from __future__ import annotations from .utils import get_tru -def run_dashboard(recipe_name: str, port: Optional[int] = 8501) -> None: +def run_dashboard(recipe_name: str, port: int | None = 8501) -> None: """Runs the TruLens dashboard.""" tru = get_tru(recipe_name=recipe_name) tru.run_dashboard(port=port) diff --git a/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py index e65bee535..d93a4a6ed 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/base_pipeline.py @@ -1,12 +1,16 @@ +from __future__ import annotations + import importlib.util import inspect import logging import sys from abc import ABC, abstractmethod -from types import ModuleType -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from types import ModuleType -from ragstack_ragulate.datasets import BaseDataset + from ragstack_ragulate.datasets import BaseDataset def load_module(file_path: str, name: str) -> ModuleType: @@ -27,17 +31,17 @@ def get_method(script_path: str, pipeline_type: str, method_name: str) -> Any: return getattr(module, method_name) -def get_method_params(method: Any) -> List[str]: +def get_method_params(method: Any) -> list[str]: """Return the parameters of a method.""" signature = inspect.signature(method) return list(signature.parameters.keys()) def get_ingredients( - method_params: List[str], - reserved_params: List[str], - passed_ingredients: Dict[str, Any], -) -> Dict[str, Any]: + method_params: list[str], + reserved_params: list[str], + passed_ingredients: dict[str, Any], +) -> dict[str, Any]: """Return ingredients for the given method params.""" ingredients = {} for method_param in method_params: @@ -59,10 +63,10 @@ class BasePipeline(ABC): script_path: str method_name: str _method: Any - _method_params: List[str] - _passed_ingredients: Dict[str, Any] - ingredients: Dict[str, Any] - datasets: List[BaseDataset] + _method_params: list[str] + _passed_ingredients: dict[str, Any] + ingredients: dict[str, Any] + datasets: list[BaseDataset] @property @abstractmethod @@ -71,7 +75,7 @@ def pipeline_type(self) -> str: @property @abstractmethod - def get_reserved_params(self) -> List[str]: + def get_reserved_params(self) -> list[str]: """Get the list of reserved parameter names for this pipeline type.""" def __init__( @@ -79,8 +83,8 @@ def __init__( recipe_name: str, script_path: str, method_name: str, - ingredients: Dict[str, Any], - datasets: List[BaseDataset], + ingredients: dict[str, Any], + datasets: list[BaseDataset], ): self.recipe_name = recipe_name self.script_path = script_path @@ -115,7 +119,7 @@ def get_method(self) -> Any: """Return the pipeline method.""" return self._method - def dataset_names(self) -> List[str]: + def dataset_names(self) -> list[str]: """Return the names of the datasets.""" return [d.name for d in self.datasets] diff --git a/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py b/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py index 5b21fb161..6313c3cc5 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/feedbacks.py @@ -1,11 +1,15 @@ -from typing import Any, Dict, List +from __future__ import annotations + +from typing import TYPE_CHECKING, Any import numpy as np from trulens_eval import Feedback from trulens_eval.app import App from trulens_eval.feedback import GroundTruthAgreement -from trulens_eval.feedback.provider.base import LLMProvider -from trulens_eval.utils.serial import Lens + +if TYPE_CHECKING: + from trulens_eval.feedback.provider.base import LLMProvider + from trulens_eval.utils.serial import Lens class Feedbacks: @@ -47,7 +51,7 @@ def context_relevance(self) -> Feedback: .aggregate(np.mean) ) - def answer_correctness(self, golden_set: List[Dict[str, str]]) -> Feedback: + def answer_correctness(self, golden_set: list[dict[str, str]]) -> Feedback: """Return answer correctness feedback.""" # GroundTruth for comparing the Answer to the Ground-Truth Answer ground_truth_collection = GroundTruthAgreement( diff --git a/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py index e8e4fbb87..43dd07146 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/ingest_pipeline.py @@ -1,5 +1,6 @@ +from __future__ import annotations + import asyncio -from typing import List from tqdm import tqdm from typing_extensions import override @@ -19,7 +20,7 @@ def pipeline_type(self) -> str: @property @override - def get_reserved_params(self) -> List[str]: + def get_reserved_params(self) -> list[str]: return ["file_path"] def ingest(self) -> None: diff --git a/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py b/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py index c34fecc93..4360e0a68 100644 --- a/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py +++ b/libs/ragulate/ragstack_ragulate/pipelines/query_pipeline.py @@ -1,8 +1,10 @@ # ruff: noqa: T201 +from __future__ import annotations + import random import signal import time -from typing import Any, Dict, List, Optional +from typing import TYPE_CHECKING, Any from tqdm import tqdm from trulens_eval import Tru, TruChain @@ -10,13 +12,15 @@ from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus from typing_extensions import override -from ragstack_ragulate.datasets import BaseDataset from ragstack_ragulate.logging_config import logger from ragstack_ragulate.utils import get_tru from .base_pipeline import BasePipeline from .feedbacks import Feedbacks +if TYPE_CHECKING: + from ragstack_ragulate.datasets import BaseDataset + class QueryPipeline(BasePipeline): """Query pipeline.""" @@ -26,8 +30,8 @@ class QueryPipeline(BasePipeline): _tru: Tru _name: str _progress: tqdm - _queries: Dict[str, List[str]] - _golden_sets: Dict[str, List[Dict[str, str]]] + _queries: dict[str, list[str]] + _golden_sets: dict[str, list[dict[str, str]]] _total_queries: int = 0 _total_feedbacks: int = 0 _finished_feedbacks: int = 0 @@ -41,7 +45,7 @@ def pipeline_type(self) -> str: @property @override - def get_reserved_params(self) -> List[str]: + def get_reserved_params(self) -> list[str]: return [] def __init__( @@ -49,13 +53,13 @@ def __init__( recipe_name: str, script_path: str, method_name: str, - ingredients: Dict[str, Any], - datasets: List[BaseDataset], + ingredients: dict[str, Any], + datasets: list[BaseDataset], sample_percent: float = 1.0, - random_seed: Optional[int] = None, + random_seed: int | None = None, restart_pipeline: bool = False, llm_provider: str = "OpenAI", - model_name: Optional[str] = None, + model_name: str | None = None, ): self._queries = {} self._golden_sets = {} diff --git a/libs/ragulate/ragstack_ragulate/utils.py b/libs/ragulate/ragstack_ragulate/utils.py index 15f780633..d62df66ed 100644 --- a/libs/ragulate/ragstack_ragulate/utils.py +++ b/libs/ragulate/ragstack_ragulate/utils.py @@ -1,5 +1,7 @@ +from __future__ import annotations + import re -from typing import Any, Dict, List, Union +from typing import Any from trulens_eval import Tru @@ -14,16 +16,16 @@ def get_tru(recipe_name: str) -> Tru: def convert_vars_to_ingredients( - var_names: List[str], var_values: List[str] -) -> Dict[str, Any]: + var_names: list[str], var_values: list[str] +) -> dict[str, Any]: """Convert variables to ingredients.""" - params: Dict[str, Any] = {} + params: dict[str, Any] = {} for i, name in enumerate(var_names): params[name] = _convert_string(var_values[i]) return params -def _convert_string(s: str) -> Union[str, int, float]: +def _convert_string(s: str) -> str | int | float: s = s.strip() if re.match(r"^\d+$", s): return int(s) diff --git a/libs/tests-utils/ragstack_tests_utils/test_store.py b/libs/tests-utils/ragstack_tests_utils/test_store.py index c7b193672..1abff93fa 100644 --- a/libs/tests-utils/ragstack_tests_utils/test_store.py +++ b/libs/tests-utils/ragstack_tests_utils/test_store.py @@ -1,7 +1,8 @@ +from __future__ import annotations + import logging import os from abc import ABC, abstractmethod -from typing import Optional import cassio from cassandra.cluster import Cluster, PlainTextAuthProvider, Session @@ -46,8 +47,8 @@ def create_cassandra_session(self) -> Session: class AstraDBTestStore(TestStore): - token: Optional[str] - database_id: Optional[str] + token: str | None + database_id: str | None env: str def __init__(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 717c360df..11efcdde8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -87,6 +87,7 @@ select = [ "E", "EXE", "F", + "FA", "FLY", "FURB", "G", diff --git a/scripts/generate-testspace-report.py b/scripts/generate-testspace-report.py index 74872fa11..9024675cd 100755 --- a/scripts/generate-testspace-report.py +++ b/scripts/generate-testspace-report.py @@ -1,11 +1,11 @@ #!/usr/bin/env python +from __future__ import annotations import json import os.path import sys import xml from dataclasses import dataclass -from typing import List from xml.etree.ElementTree import Element, ElementTree, SubElement, parse @@ -43,15 +43,15 @@ class TestCase: name: str passed: bool time: str - links: List[Link] - failures: List[Failure] + links: list[Link] + failures: list[Failure] @dataclass class TestSuite: name: str - test_cases: List[TestCase] - links: List[Link] + test_cases: list[TestCase] + links: list[Link] def unsafe_escape_data(text):