Skip to content

Commit

Permalink
fix type check
Browse files (browse the repository at this point in the history)
  • Loading branch information
epinzur committed Jul 19, 2024
1 parent c5eafb9 commit 9deb78e
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
13 changes: 7 additions & 6 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class MetadataIndexingMode(Enum):
DEFAULT_TO_SEARCHABLE = 2


MetadataIndexingType = Union[Tuple[str, Iterable[str]], str]
MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]


Expand Down Expand Up @@ -166,7 +167,7 @@ def __init__(
session: Optional[Session] = None,
keyspace: Optional[str] = None,
setup_mode: SetupMode = SetupMode.SYNC,
metadata_indexing: Union[Tuple[str, Iterable[str]], str] = "all",
metadata_indexing: MetadataIndexingType = "all",
):
session = check_resolve_session(session)
keyspace = check_resolve_keyspace(keyspace)
Expand Down Expand Up @@ -414,7 +415,7 @@ def mmr_traversal_search(
adjacent_k: int = 10,
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
metadata: Optional[Dict[str, Any]] = [],
metadata: Dict[str, Any] = {},
) -> Iterable[Node]:
"""Retrieve documents from this graph store using MMR-traversal.
Expand Down Expand Up @@ -540,7 +541,7 @@ def traversal_search(
*,
k: int = 4,
depth: int = 1,
metadata: Optional[Dict[str, Any]] = [],
metadata: Dict[str, Any] = {},
) -> Iterable[Node]:
"""Retrieve documents from this knowledge store.
Expand Down Expand Up @@ -657,7 +658,7 @@ def similarity_search(
self,
embedding: List[float],
k: int = 4,
metadata: Optional[Dict[str, Any]] = [],
metadata: Dict[str, Any] = {},
) -> Iterable[Node]:
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
query, params = self._get_search_cql(
Expand All @@ -668,7 +669,7 @@ def similarity_search(
yield _row_to_node(row)

def metadata_search(
self, metadata: Dict[str, Any] = {}, n: Optional[int] = 5
self, metadata: Dict[str, Any] = {}, n: int = 5
) -> Iterable[Node]:
"""Retrieve nodes based on their metadata."""
query, params = self._get_search_cql(metadata=metadata, limit=n)
Expand Down Expand Up @@ -833,7 +834,7 @@ def _get_search_cql(
self,
limit: int,
columns: Optional[str] = CONTENT_COLUMNS,
metadata: Optional[Dict[str, Any]] = {},
metadata: Dict[str, Any] = {},
embedding: Optional[List[float]] = None,
) -> Tuple[str, Tuple[Any, ...]]:
where_clause, get_cql_vals = self._extract_where_clause_blocks(
Expand Down
50 changes: 26 additions & 24 deletions libs/knowledge-store/tests/integration_tests/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# ruff: noqa: PT011, RUF015

import secrets
from typing import Callable, Iterator, List, Optional
from typing import Callable, Iterator, List

import pytest
from dotenv import load_dotenv
from ragstack_knowledge_store import EmbeddingModel
from ragstack_knowledge_store.graph_store import GraphStore, Node
from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingType, Node
from ragstack_tests_utils import LocalCassandraTestStore

load_dotenv()
Expand Down Expand Up @@ -49,7 +49,7 @@ def graph_store_factory(

embedding = DummyEmbeddingModel()

def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
def _make_graph_store(metadata_indexing: str = "all") -> GraphStore:
name = secrets.token_hex(8)

node_table = f"nodes_{name}"
Expand All @@ -66,36 +66,40 @@ def _make_graph_store(metadata_indexing: Optional[str] = "all") -> GraphStore:
session.shutdown()


def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore]) -> None:
def test_graph_store_creation(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
) -> None:
"""Test that a graph store can be created.
This verifies the schema can be applied and the queries prepared.
"""
graph_store_factory()
graph_store_factory("all")


def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
gs = graph_store_factory()
def test_graph_store_metadata(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
) -> None:
gs = graph_store_factory("all")

gs.add_nodes([Node(text="bb1", id="row1")])
gotten1 = gs.get_node(content_id="row1")
assert gotten1 == Node(text="bb1", id="row1", metadata={})

gs.add_nodes([Node(text=None, id="row2", metadata={})])
gs.add_nodes([Node(text="", id="row2", metadata={})])
gotten2 = gs.get_node(content_id="row2")
assert gotten2 == Node(text=None, id="row2", metadata={})
assert gotten2 == Node(text="", id="row2", metadata={})

md3 = {"a": 1, "b": "Bee", "c": True}
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
gs.add_nodes([Node(text=None, id="row3", metadata=md3)])
gs.add_nodes([Node(text="", id="row3", metadata=md3)])
gotten3 = gs.get_node(content_id="row3")
assert gotten3 == Node(text=None, id="row3", metadata=md3_string)
assert gotten3 == Node(text="", id="row3", metadata=md3_string)

md4 = {"c1": True, "c2": True, "c3": True}
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
gs.add_nodes([Node(text=None, id="row4", metadata=md4)])
gs.add_nodes([Node(text="", id="row4", metadata=md4)])
gotten4 = gs.get_node(content_id="row4")
assert gotten4 == Node(text=None, id="row4", metadata=md4_string)
assert gotten4 == Node(text="", id="row4", metadata=md4_string)

# metadata searches:
md_gotten3a = list(gs.metadata_search(metadata={"a": 1}))[0]
Expand All @@ -108,33 +112,33 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
# 'search' proper
gs.add_nodes(
[
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1}),
Node(text="", id="twin_a", metadata={"twin": True, "index": 0}),
Node(text="", id="twin_b", metadata={"twin": True, "index": 1}),
]
)
md_twins_gotten = sorted(
gs.metadata_search(metadata={"twin": True}),
key=lambda res: int(float(res.metadata["index"])),
)
expected = [
Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}),
Node(text=None, id="twin_b", metadata={"twin": "true", "index": "1.0"}),
Node(text="", id="twin_a", metadata={"twin": "true", "index": "0.0"}),
Node(text="", id="twin_b", metadata={"twin": "true", "index": "1.0"}),
]
assert md_twins_gotten == expected
assert list(gs.metadata_search(metadata={"fake": True})) == []


def test_graph_store_metadata_routing(
graph_store_factory: Callable[[str], GraphStore],
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
) -> None:
test_md = {"mds": "string", "mdn": 255, "mdb": True}
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}

gs_all = graph_store_factory(metadata_indexing="all")
gs_all = graph_store_factory("all")
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
assert gotten_all.metadata == test_md_string
gs_none = graph_store_factory(metadata_indexing="none")
gs_none = graph_store_factory("none")
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
with pytest.raises(ValueError):
# querying on non-indexed metadata fields:
Expand All @@ -158,15 +162,13 @@ def test_graph_store_metadata_routing(
"mdab": "true",
"mddb": "true",
}
gs_allow = graph_store_factory(
metadata_indexing=("allow", {"mdas", "mdan", "mdab"})
)
gs_allow = graph_store_factory(("allow", {"mdas", "mdan", "mdab"}))
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
with pytest.raises(ValueError):
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
assert gotten_allow.metadata == test_md_allowdeny_string
gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"}))
gs_deny = graph_store_factory(("deny", {"mdds", "mddn", "mddb"}))
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
with pytest.raises(ValueError):
list(gs_deny.metadata_search(metadata={"mdds": "MDDS"}))
Expand Down

0 comments on commit 9deb78e

Please sign in to comment.