diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 384f587c4..ade610d93 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -67,6 +67,7 @@ class MetadataIndexingMode(Enum): DEFAULT_TO_SEARCHABLE = 2 +MetadataIndexingType = Union[Tuple[str, Iterable[str]], str] MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]] @@ -166,7 +167,7 @@ def __init__( session: Optional[Session] = None, keyspace: Optional[str] = None, setup_mode: SetupMode = SetupMode.SYNC, - metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all", + metadata_indexing: MetadataIndexingType = "all", ): session = check_resolve_session(session) keyspace = check_resolve_keyspace(keyspace) @@ -414,7 +415,7 @@ def mmr_traversal_search( adjacent_k: int = 10, lambda_mult: float = 0.5, score_threshold: float = float("-inf"), - metadata: Optional[Dict[str, Any]] = [], + metadata: Dict[str, Any] = {}, ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -540,7 +541,7 @@ def traversal_search( *, k: int = 4, depth: int = 1, - metadata: Optional[Dict[str, Any]] = [], + metadata: Dict[str, Any] = {}, ) -> Iterable[Node]: """Retrieve documents from this knowledge store. @@ -657,7 +658,7 @@ def similarity_search( self, embedding: List[float], k: int = 4, - metadata: Optional[Dict[str, Any]] = [], + metadata: Dict[str, Any] = {}, ) -> Iterable[Node]: """Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501 query, params = self._get_search_cql( @@ -668,7 +669,7 @@ def similarity_search( yield _row_to_node(row) def metadata_search( - self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5 + self, metadata: Dict[str, Any] = {}, n: int = 5 ) -> Iterable[Node]: """Retrieve nodes based on their metadata.""" query, params = self._get_search_cql(metadata=metadata, limit=n) @@ -833,7 +834,7 @@ def _get_search_cql( self, limit: int, columns: Optional[str] = CONTENT_COLUMNS, - metadata: Optional[Dict[str, Any]] = {}, + metadata: Dict[str, Any] = {}, embedding: Optional[List[float]] = None, ) -> Tuple[str, Tuple[Any, ...]]: where_clause, get_cql_vals = self._extract_where_clause_blocks( diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 43f6f82f9..2b4dd9502 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -1,12 +1,12 @@ # ruff: noqa: PT011, RUF015 import secrets -from typing import Callable, Iterator, List, Optional +from typing import Callable, Iterator, List import pytest from dotenv import load_dotenv from ragstack_knowledge_store import EmbeddingModel -from ragstack_knowledge_store.graph_store import GraphStore, Node +from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingType, Node from ragstack_tests_utils import LocalCassandraTestStore load_dotenv() @@ -49,7 +49,7 @@ def graph_store_factory( embedding = DummyEmbeddingModel() - def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore: + def _make_graph_store(metadata_indexing: str = "all") -> GraphStore: name = secrets.token_hex(8) node_table = f"nodes_{name}" @@ -66,36 +66,40 @@ def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore: session.shutdown() -def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore]) -> None: +def test_graph_store_creation( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: """Test that a graph store can be created. This verifies the schema can be applied and the queries prepared. """ - graph_store_factory() + graph_store_factory("all") -def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None: - gs = graph_store_factory() +def test_graph_store_metadata( + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], +) -> None: + gs = graph_store_factory("all") gs.add_nodes([Node(text="bb1", id="row1")]) gotten1 = gs.get_node(content_id="row1") assert gotten1 == Node(text="bb1", id="row1", metadata={}) - gs.add_nodes([Node(text=None, id="row2", metadata={})]) + gs.add_nodes([Node(text="", id="row2", metadata={})]) gotten2 = gs.get_node(content_id="row2") - assert gotten2 == Node(text=None, id="row2", metadata={}) + assert gotten2 == Node(text="", id="row2", metadata={}) md3 = {"a": 1, "b": "Bee", "c": True} md3_string = {"a": "1.0", "b": "Bee", "c": "true"} - gs.add_nodes([Node(text=None, id="row3", metadata=md3)]) + gs.add_nodes([Node(text="", id="row3", metadata=md3)]) gotten3 = gs.get_node(content_id="row3") - assert gotten3 == Node(text=None, id="row3", metadata=md3_string) + assert gotten3 == Node(text="", id="row3", metadata=md3_string) md4 = {"c1": True, "c2": True, "c3": True} md4_string = {"c1": "true", "c2": "true", "c3": "true"} - gs.add_nodes([Node(text=None, id="row4", metadata=md4)]) + gs.add_nodes([Node(text="", id="row4", metadata=md4)]) gotten4 = gs.get_node(content_id="row4") - assert gotten4 == Node(text=None, id="row4", metadata=md4_string) + assert gotten4 == Node(text="", id="row4", metadata=md4_string) # metadata searches: md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0] @@ -108,8 +112,8 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) # 'search' proper gs.add_nodes( [ - Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}), - Node(text=None, id="twin_b", metadata={"twin": True, "index": 1}), + Node(text="", id="twin_a", metadata={"twin": True, "index": 0}), + Node(text="", id="twin_b", metadata={"twin": True, "index": 1}), ] ) md_twins_gotten = sorted( @@ -117,24 +121,24 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) key=lambda res: int(float(res.metadata["index"])), ) expected = [ - Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}), - Node(text=None, id="twin_b", metadata={"twin": "true", "index": "1.0"}), + Node(text="", id="twin_a", metadata={"twin": "true", "index": "0.0"}), + Node(text="", id="twin_b", metadata={"twin": "true", "index": "1.0"}), ] assert md_twins_gotten == expected assert list(gs.metadata_search(metadata={"fake": True})) == [] def test_graph_store_metadata_routing( - graph_store_factory: Callable[[str], GraphStore], + graph_store_factory: Callable[[MetadataIndexingType], GraphStore], ) -> None: test_md = {"mds": "string", "mdn": 255, "mdb": True} test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"} - gs_all = graph_store_factory(metadata_indexing="all") + gs_all = graph_store_factory("all") gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)]) gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0] assert gotten_all.metadata == test_md_string - gs_none = graph_store_factory(metadata_indexing="none") + gs_none = graph_store_factory("none") gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)]) with pytest.raises(ValueError): # querying on non-indexed metadata fields: @@ -158,15 +162,13 @@ def test_graph_store_metadata_routing( "mdab": "true", "mddb": "true", } - gs_allow = graph_store_factory( - metadata_indexing=("allow", {"mdas", "mdan", "mdab"}) - ) + gs_allow = graph_store_factory(("allow", {"mdas", "mdan", "mdab"})) gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)]) with pytest.raises(ValueError): list(gs_allow.metadata_search(metadata={"mdds": "MDDS"})) gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0] assert gotten_allow.metadata == test_md_allowdeny_string - gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"})) + gs_deny = graph_store_factory(("deny", {"mdds", "mddn", "mddb"})) gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)]) with pytest.raises(ValueError): list(gs_deny.metadata_search(metadata={"mdds": "MDDS"}))