Add ruff rule for Error Messages (EM)
cbornet committed Aug 23, 2024
1 parent 9e25b78 commit 3e056e5
Showing 27 changed files with 126 additions and 101 deletions.
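
The pattern repeated in every file below comes from ruff's EM ("flake8-errmsg") rules: EM101 flags a raw string literal passed directly to an exception constructor, EM102 an f-string, and EM103 a `.format()` call. The rationale is that when the exception goes unhandled, Python's traceback prints the raising line and then the message, so an inline string appears twice; binding the message to a variable first keeps the traceback readable. A minimal before/after sketch (the `pyproject.toml` hunk enabling the rule is among the files not shown on this page, so the config lines are an assumption):

# Assumed pyproject.toml entry enabling the rule set:
# [tool.ruff.lint]
# extend-select = ["EM"]

# Before: EM102 — f-string inside the exception constructor
def check_positive(n: int) -> int:
    if n <= 0:
        raise ValueError(f"expected a positive value, got {n}")
    return n

# After: bind the message to a variable, then raise
def check_positive_em(n: int) -> int:
    if n <= 0:
        msg = f"expected a positive value, got {n}"
        raise ValueError(msg)
    return n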
21 changes: 10 additions & 11 deletions examples/evaluation/tru_shared.py
@@ -129,7 +129,8 @@ def get_recorder(
         feedbacks=feedbacks,
         feedback_mode=feedback_mode,
     )
-    raise ValueError(f"Unknown framework: {framework} specified for get_recorder()")
+    msg = f"Unknown framework: {framework} specified for get_recorder()"
+    raise ValueError(msg)


 def get_azure_chat_model(
@@ -151,7 +152,8 @@ def get_azure_chat_model(
         model_version=model_version,
         temperature=temperature,
     )
-    raise ValueError(f"Unknown framework: {framework} specified for getChatModel()")
+    msg = f"Unknown framework: {framework} specified for getChatModel()"
+    raise ValueError(msg)


 def get_azure_embeddings_model(framework: Framework):
@@ -167,9 +169,8 @@ def get_azure_embeddings_model(framework: Framework):
         api_version="2023-05-15",
         temperature=temperature,
     )
-    raise ValueError(
-        f"Unknown framework: {framework} specified for getEmbeddingsModel()"
-    )
+    msg = f"Unknown framework: {framework} specified for getEmbeddingsModel()"
+    raise ValueError(msg)


 def get_astra_vector_store(framework: Framework, collection_name: str):
@@ -187,9 +188,8 @@ def get_astra_vector_store(framework: Framework, collection_name: str):
         token=os.getenv("ASTRA_DB_APPLICATION_TOKEN"),
         embedding_dimension=1536,
     )
-    raise ValueError(
-        f"Unknown framework: {framework} specified for get_astra_vector_store()"
-    )
+    msg = f"Unknown framework: {framework} specified for get_astra_vector_store()"
+    raise ValueError(msg)


 def execute_query(framework: Framework, pipeline, query) -> None:
@@ -198,9 +198,8 @@ def execute_query(framework: Framework, pipeline, query) -> None:
     elif framework == Framework.LLAMA_INDEX:
         pipeline.query(query)
     else:
-        raise ValueError(
-            f"Unknown framework: {framework} specified for execute_query()"
-        )
+        msg = f"Unknown framework: {framework} specified for execute_query()"
+        raise ValueError(msg)


 # runs the pipeline across all queries in all known datasets
3 changes: 2 additions & 1 deletion examples/notebooks/advancedRAG.ipynb
@@ -162,7 +162,8 @@
 "if uploaded:\n",
 "    SAMPLEDATA = uploaded\n",
 "else:\n",
-"    raise ValueError(\"Cannot proceed without Sample Data. Please re-run the cell.\")\n",
+"    msg = \"Cannot proceed without Sample Data. Please re-run the cell.\"\n",
+"    raise ValueError(msg)\n",
 "\n",
 "print(\"Please make sure to change your queries to match the contents of your file!\")"
 ]
6 changes: 4 additions & 2 deletions examples/notebooks/conftest.py
@@ -8,10 +8,12 @@

 def get_required_env(name) -> str:
     if name not in os.environ:
-        raise ValueError(f"Missing required environment variable: {name}")
+        msg = f"Missing required environment variable: {name}"
+        raise ValueError(msg)
     value = os.environ[name]
     if not value:
-        raise ValueError(f"Empty required environment variable: {name}")
+        msg = f"Empty required environment variable: {name}"
+        raise ValueError(msg)
     return value

3 changes: 2 additions & 1 deletion examples/notebooks/langchain_evaluation.ipynb
@@ -193,7 +193,8 @@
 "if uploaded:\n",
 "    SAMPLEDATA = uploaded\n",
 "else:\n",
-"    raise ValueError(\"Cannot proceed without Sample Data. Please re-run the cell.\")\n",
+"    msg = \"Cannot proceed without Sample Data. Please re-run the cell.\"\n",
+"    raise ValueError(msg)\n",
 "\n",
 "print(\"Please make sure to change your queries to match the contents of your file!\")"
 ]
20 changes: 12 additions & 8 deletions libs/colbert/ragstack_colbert/cassandra_database.py
@@ -45,10 +45,11 @@ class CassandraDatabase(BaseDatabase):
     _table: ClusteredMetadataVectorCassandraTable

     def __new__(cls) -> Self:  # noqa: D102
-        raise ValueError(
+        msg = (
             "This class cannot be instantiated directly. "
             "Please use the `from_astra()` or `from_session()` class methods."
         )
+        raise ValueError(msg)

     @classmethod
     def from_astra(
@@ -173,10 +174,11 @@ def add_chunks(self, chunks: list[Chunk]) -> list[tuple[str, int]]:
                 success_chunks.append((doc_id, chunk_id))

         if len(failed_chunks) > 0:
-            raise CassandraDatabaseError(
+            msg = (
                 f"add failed for these chunks: {failed_chunks}. "
                 f"See error logs for more info."
             )
+            raise CassandraDatabaseError(msg)

         return success_chunks

@@ -273,10 +275,11 @@ async def aadd_chunks(
                 failed_chunks.append((doc_id, chunk_id))

         if len(failed_chunks) > 0:
-            raise CassandraDatabaseError(
+            msg = (
                 f"add failed for these chunks: {failed_chunks}. "
                 f"See error logs for more info."
             )
+            raise CassandraDatabaseError(msg)

         return outputs

@@ -292,8 +295,9 @@ def delete_chunks(self, doc_ids: list[str]) -> bool:
                 failed_docs.append(doc_id)

         if len(failed_docs) > 0:
+            msg = "delete failed for these docs: %s. See error logs for more info."
             raise CassandraDatabaseError(
-                "delete failed for these docs: %s. See error logs for more info.",
+                msg,
                 failed_docs,
             )

@@ -340,10 +344,11 @@ async def adelete_chunks(
                 failed_docs.append(doc_id)

         if len(failed_docs) > 0:
-            raise CassandraDatabaseError(
+            msg = (
                 f"delete failed for these docs: {failed_docs}. "
                 f"See error logs for more info."
             )
+            raise CassandraDatabaseError(msg)

         return success

@@ -379,9 +384,8 @@ async def get_chunk_data(
         row = await self._table.aget(partition_id=doc_id, row_id=row_id)

         if row is None:
-            raise CassandraDatabaseError(
-                f"no chunk found for doc_id: {doc_id} chunk_id: {chunk_id}"
-            )
+            msg = f"no chunk found for doc_id: {doc_id} chunk_id: {chunk_id}"
+            raise CassandraDatabaseError(msg)

         if include_embedding is True:
             embedded_chunk = await self.get_chunk_embedding(
8 changes: 4 additions & 4 deletions libs/colbert/ragstack_colbert/colbert_vector_store.py
@@ -46,9 +46,8 @@ def __init__(

     def _validate_embedding_model(self) -> BaseEmbeddingModel:
         if self._embedding_model is None:
-            raise AttributeError(
-                "To use this method, `embedding_model` must be set on class creation."
-            )
+            msg = "To use this method, `embedding_model` must be set on class creation."
+            raise AttributeError(msg)
         return self._embedding_model

     def _build_chunks(
@@ -60,7 +59,8 @@ def _build_chunks(
         embedding_model = self._validate_embedding_model()

         if metadatas is not None and len(texts) != len(metadatas):
-            raise ValueError("Length of texts and metadatas must match.")
+            msg = "Length of texts and metadatas must match."
+            raise ValueError(msg)

         if doc_id is None:
             doc_id = str(uuid.uuid4())
6 changes: 4 additions & 2 deletions libs/e2e-tests/e2e_tests/conftest.py
@@ -55,7 +55,8 @@ def get_required_env(name) -> str:

 vector_database_type = os.environ.get("VECTOR_DATABASE_TYPE", "astradb")
 if vector_database_type not in ["astradb", "local-cassandra"]:
-    raise ValueError(f"Invalid VECTOR_DATABASE_TYPE: {vector_database_type}")
+    msg = f"Invalid VECTOR_DATABASE_TYPE: {vector_database_type}"
+    raise ValueError(msg)

 is_astra = vector_database_type == "astradb"

@@ -67,7 +68,8 @@ def get_vector_store_handler(
         return AstraDBVectorStoreHandler(implementation)
     if vector_database_type == "local-cassandra":
         return CassandraVectorStoreHandler(implementation)
-    raise ValueError("Invalid vector store implementation")
+    msg = "Invalid vector store implementation"
+    raise ValueError(msg)


 failed_report_lines = []
3 changes: 2 additions & 1 deletion libs/e2e-tests/e2e_tests/langchain/test_compatibility_rag.py
@@ -363,7 +363,8 @@ def _run_test(
             vector_store=vector_store, config=resolved_llm["nemo_config"]
         )
     else:
-        raise ValueError(f"Unknown test case: {test_case}")
+        msg = f"Unknown test case: {test_case}"
+        raise ValueError(msg)


 @pytest.fixture()
5 changes: 3 additions & 2 deletions libs/knowledge-store/ragstack_knowledge_store/_mmr_helper.py
@@ -131,11 +131,12 @@ def _pop_candidate(self, candidate_id: str) -> NDArray[np.float32]:
        """
        # Get the embedding for the id.
        index = self.candidate_id_to_index.pop(candidate_id)
-        if not self.candidates[index].id == candidate_id:
-            raise ValueError(
+        if self.candidates[index].id != candidate_id:
+            msg = (
                "ID in self.candidate_id_to_index doesn't match the ID of the "
                "corresponding index in self.candidates"
            )
+            raise ValueError(msg)
        embedding: NDArray[np.float32] = self.candidate_embeddings[index].copy()

        # Swap that index with the last index in the candidates and
3 changes: 2 additions & 1 deletion libs/knowledge-store/ragstack_knowledge_store/_utils.py
@@ -16,7 +16,8 @@
 # This is equivalent to `itertools.batched`, but that is only available in 3.12
 def batched(iterable: Iterable[T], n: int) -> Iterator[tuple[T, ...]]:
     if n < 1:
-        raise ValueError("n must be at least one")
+        msg = "n must be at least one"
+        raise ValueError(msg)
     it = iter(iterable)
     while batch := tuple(islice(it, n)):
         yield batch
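
The `batched` helper above is a backport of `itertools.batched` (Python 3.12+): it yields successive n-sized tuples, with a shorter final tuple when the iterable does not divide evenly. For example:

>>> list(batched("ABCDEFG", 3))
[('A', 'B', 'C'), ('D', 'E', 'F'), ('G',)]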
41 changes: 20 additions & 21 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
@@ -17,6 +17,7 @@

 from cassandra.cluster import ConsistencyLevel, PreparedStatement, Session
 from cassio.config import check_resolve_keyspace, check_resolve_session
+from typing_extensions import assert_never

 from ._mmr_helper import MmrHelper
 from .concurrency import ConcurrentQueries
@@ -76,7 +77,7 @@ def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy)
         return field_name in p_fields
     if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
         return field_name not in p_fields
-    raise ValueError(f"Unexpected metadata indexing mode {p_mode}")
+    assert_never(p_mode)


 def _serialize_metadata(md: dict[str, Any]) -> str:
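
Besides the EM rewrites, this file replaces two unreachable `raise ValueError(...)` branches with `typing_extensions.assert_never`, imported in the first hunk above. `assert_never` is typed to accept only `Never`, so a static type checker verifies that every enum member (or union arm) was handled before the call is reached; at runtime it raises if execution ever gets there. A rough sketch of the idea (the enum here is illustrative, not the one from this module):

from enum import Enum

from typing_extensions import assert_never


class IndexMode(Enum):
    SEARCHABLE = "searchable"
    UNSEARCHABLE = "unsearchable"


def is_indexed(mode: IndexMode) -> bool:
    if mode is IndexMode.SEARCHABLE:
        return True
    if mode is IndexMode.UNSEARCHABLE:
        return False
    # The type checker narrows `mode` to Never here; adding a new member
    # to IndexMode without a matching branch makes this call a type error.
    assert_never(mode)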
@@ -170,10 +171,12 @@ def __init__(
             keyspace = check_resolve_keyspace(keyspace)

         if not _CQL_IDENTIFIER_PATTERN.fullmatch(keyspace):
-            raise ValueError(f"Invalid keyspace: {keyspace}")
+            msg = f"Invalid keyspace: {keyspace}"
+            raise ValueError(msg)

         if not _CQL_IDENTIFIER_PATTERN.fullmatch(node_table):
-            raise ValueError(f"Invalid node table name: {node_table}")
+            msg = f"Invalid node table name: {node_table}"
+            raise ValueError(msg)

         self._embedding = embedding
         self._node_table = node_table
@@ -188,10 +191,11 @@ def __init__(
         if setup_mode == SetupMode.SYNC:
             self._apply_schema()
         elif setup_mode != SetupMode.OFF:
-            raise ValueError(
+            msg = (
                 f"Invalid setup mode {setup_mode.name}. "
                 "Only SYNC and OFF are supported at the moment"
             )
+            raise ValueError(msg)

         # TODO: Parent ID / source ID / etc.
         self._insert_passage = session.prepare(
@@ -350,7 +354,8 @@ def node_callback(rows: Iterable[Any]) -> None:

         def get_result(node_id: str) -> Node:
             if (result := results[node_id]) is None:
-                raise ValueError(f"No node with ID '{node_id}'")
+                msg = f"No node with ID '{node_id}'"
+                raise ValueError(msg)
             return result

         return [get_result(node_id) for node_id in ids]
@@ -800,14 +805,11 @@ def _normalize_metadata_indexing_policy(
         elif metadata_indexing.lower() == "none":
             mode, fields = (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
         else:
-            raise ValueError(
-                f"Unsupported metadata_indexing value '{metadata_indexing}'"
-            )
+            msg = f"Unsupported metadata_indexing value '{metadata_indexing}'"
+            raise ValueError(msg)
     else:
         if len(metadata_indexing) != 2:  # noqa: PLR2004
-            raise ValueError(
-                f"Unsupported metadata_indexing value '{metadata_indexing}'."
-            )
+            assert_never(metadata_indexing)
         # it's a 2-tuple (mode, fields) still to normalize
         _mode, _field_spec = metadata_indexing
         fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
@@ -826,10 +828,9 @@ def _normalize_metadata_indexing_policy(
         }:
             mode = MetadataIndexingMode.DEFAULT_TO_SEARCHABLE
         else:
-            raise ValueError(
-                f"Unsupported metadata indexing mode specification '{_mode}'"
-            )
-        return (mode, fields)
+            msg = f"Unsupported metadata indexing mode specification '{_mode}'"
+            raise ValueError(msg)
+        return mode, fields

     @staticmethod
     def _coerce_string(value: Any) -> str:
@@ -865,9 +866,8 @@ def _extract_where_clause_cql(
             if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
                 wc_blocks.append(f"metadata_s['{key}'] = ?")
             else:
-                raise ValueError(
-                    "Non-indexed metadata fields cannot be used in queries."
-                )
+                msg = "Non-indexed metadata fields cannot be used in queries."
+                raise ValueError(msg)

         if len(wc_blocks) == 0:
             return ""
@@ -889,9 +889,8 @@ def _extract_where_clause_params(
             if _is_metadata_field_indexed(key, self._metadata_indexing_policy):
                 params.append(self._coerce_string(value=value))
             else:
-                raise ValueError(
-                    "Non-indexed metadata fields cannot be used in queries."
-                )
+                msg = "Non-indexed metadata fields cannot be used in queries."
+                raise ValueError(msg)

         return params
3 changes: 2 additions & 1 deletion libs/knowledge-store/ragstack_knowledge_store/math.py
@@ -22,10 +22,11 @@ def cosine_similarity(x: Matrix, y: Matrix) -> NDArray[np.float32]:
     x = np.array(x)
     y = np.array(y)
     if x.shape[1] != y.shape[1]:
-        raise ValueError(
+        msg = (
             f"Number of columns in X and Y must be the same. X has shape {x.shape} "
             f"and Y has shape {y.shape}."
         )
+        raise ValueError(msg)
     try:
         import simsimd as simd
     except ImportError:
5 changes: 3 additions & 2 deletions libs/langchain/ragstack_langchain/colbert/__init__.py
@@ -1,10 +1,11 @@
 try:
     from ragstack_colbert.base_retriever import BaseRetriever  # noqa: F401
 except (ImportError, ModuleNotFoundError) as e:
-    raise ImportError(
+    msg = (
         "Could not import ragstack-ai-colbert. "
         "Please install it with `pip install ragstack-ai-langchain[colbert]`."
-    ) from e
+    )
+    raise ImportError(msg) from e

 from .colbert_retriever import ColbertRetriever
 from .colbert_vector_store import ColbertVectorStore
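
The hunk above also preserves the common optional-dependency guard: probe-import the extra at package import time and re-raise as `ImportError` with `from e`, so the original failure stays chained in the traceback while the user sees an actionable install hint. A generic sketch of the pattern (package and extra names are placeholders):

try:
    import optional_backend  # noqa: F401  # probe for the optional extra
except (ImportError, ModuleNotFoundError) as e:
    msg = (
        "Could not import optional_backend. "
        "Please install it with `pip install mypackage[backend]`."
    )
    # `raise ... from e` keeps the original import error as __cause__.
    raise ImportError(msg) from e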