Skip to content

Commit

Permalink
fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
epinzur committed Jul 17, 2024
1 parent 7f80794 commit e7ef77f
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 65 deletions.
105 changes: 70 additions & 35 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# ruff: noqa: B006

import json
import re
import secrets
Expand All @@ -16,7 +18,7 @@
cast,
)

from cassandra.cluster import ConsistencyLevel, Session, ResponseFuture
from cassandra.cluster import ConsistencyLevel, Session
from cassio.config import check_resolve_keyspace, check_resolve_session

from ._mmr_helper import MmrHelper
Expand All @@ -27,9 +29,14 @@

CONTENT_ID = "content_id"

# Default column list selected by search queries (the `columns` default of
# `_get_search_cql`).
CONTENT_COLUMNS = (
    "content_id, kind, text_content, attributes_blob, metadata_s, links_blob"
)

# Skeleton of every SELECT statement; the where/order/limit clauses may be
# empty strings, in which case only extra whitespace is left behind.
SELECT_CQL_TEMPLATE = (
    "SELECT {columns} FROM {table_name} {where_clause} {order_clause} {limit_clause};"
)

@dataclass
class Node:
Expand All @@ -52,20 +59,25 @@ class SetupMode(Enum):
ASYNC = 2
OFF = 3


class MetadataIndexingMode(Enum):
    """Mode used to index metadata.

    The mode decides how the field set of a metadata indexing policy is
    read: as the fields that are searchable, or as the fields that are not.
    """

    # Only the fields listed in the policy are indexed.
    DEFAULT_TO_UNSEARCHABLE = 1
    # Every field is indexed except the ones listed in the policy.
    DEFAULT_TO_SEARCHABLE = 2


MetadataIndexingPolicy = Tuple[MetadataIndexingMode, Set[str]]


def _is_metadata_field_indexed(field_name: str, policy: MetadataIndexingPolicy) -> bool:
    """Tell whether `field_name` is indexed under the given policy.

    Args:
        field_name: name of the metadata field to check.
        policy: a `(mode, fields)` pair.

    Returns:
        With DEFAULT_TO_UNSEARCHABLE, True only for fields listed in the
        policy; with DEFAULT_TO_SEARCHABLE, True for every field that is
        not listed.

    Raises:
        ValueError: if the policy carries an unknown mode.
    """
    p_mode, p_fields = policy
    if p_mode == MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE:
        return field_name in p_fields
    if p_mode == MetadataIndexingMode.DEFAULT_TO_SEARCHABLE:
        return field_name not in p_fields
    raise ValueError(f"Unexpected metadata indexing mode {p_mode}")


def _serialize_metadata(md: Dict[str, Any]) -> str:
if isinstance(md.get("links"), Set):
Expand Down Expand Up @@ -112,7 +124,9 @@ def _row_to_node(row: Any) -> Node:
if metadata_s is None:
metadata_s = {}
attributes_blob = row.attributes_blob
attributes_dict = _deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
attributes_dict = (
_deserialize_metadata(attributes_blob) if attributes_blob is not None else {}
)
links = _deserialize_links(row.links_blob)
return Node(
id=row.content_id,
Expand Down Expand Up @@ -237,6 +251,7 @@ def __init__(
)

def table_name(self) -> str:
    """Return the node table name qualified with its keyspace.

    Returns:
        A string of the form ``<keyspace>.<node_table>``.
    """
    return f"{self._keyspace}.{self._node_table}"

def _apply_schema(self) -> None:
Expand Down Expand Up @@ -281,7 +296,9 @@ def _apply_schema(self) -> None:
def _concurrent_queries(self) -> ConcurrentQueries:
    """Create a new `ConcurrentQueries` helper bound to this store's session."""
    return ConcurrentQueries(self._session)

def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str, Dict[str,str]]:
def _parse_metadata(
self, metadata: Dict[str, Any], is_query: bool
) -> Tuple[str, Dict[str, str]]:
attributes_dict = {
k: self._coerce_string(v)
for k, v in metadata.items()
Expand All @@ -296,10 +313,11 @@ def _parse_metadata(self, metadata: Dict[str, Any], is_query: bool) -> Tuple[str
for k, v in metadata.items()
if _is_metadata_field_indexed(k, self._metadata_indexing_policy)
}
metadata_s = {k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()}
metadata_s = {
k: self._coerce_string(v) for k, v in metadata_indexed_dict.items()
}
return (attributes_blob, metadata_s)


# TODO: Async (aadd_nodes)
def add_nodes(
self,
Expand Down Expand Up @@ -335,7 +353,9 @@ def add_nodes(
if tag.direction in {"out", "bidir"}:
link_to_tags.add((tag.kind, tag.tag))

attributes_blob, metadata_s = self._parse_metadata(metadata=metadata, is_query=False)
attributes_blob, metadata_s = self._parse_metadata(
metadata=metadata, is_query=False
)

links_blob = _serialize_links(links)
cq.execute(
Expand Down Expand Up @@ -440,7 +460,7 @@ def fetch_initial_candidates() -> None:
limit=fetch_k,
columns="content_id, text_embedding, link_to_tags",
metadata=metadata,
embedding=query_embedding
embedding=query_embedding,
)

fetched = self._session.execute(query=query, parameters=params)
Expand Down Expand Up @@ -515,7 +535,12 @@ def fetch_initial_candidates() -> None:
return self._nodes_with_ids(helper.selected_ids)

def traversal_search(
self, query: str, *, k: int = 4, depth: int = 1, metadata: Optional[Dict[str, Any]] = [],
self,
query: str,
*,
k: int = 4,
depth: int = 1,
metadata: Optional[Dict[str, Any]] = [],
) -> Iterable[Node]:
"""Retrieve documents from this knowledge store.
Expand Down Expand Up @@ -634,21 +659,26 @@ def similarity_search(
k: int = 4,
metadata: Optional[Dict[str, Any]] = [],
) -> Iterable[Node]:
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata"""
query, params = self._get_search_cql(embedding=embedding, limit=k, metadata=metadata)
"""Retrieve nodes similar to the given embedding, optionally filtered by metadata.""" # noqa: E501
query, params = self._get_search_cql(
embedding=embedding, limit=k, metadata=metadata
)

for row in self._session.execute(query, params):
yield _row_to_node(row)

def metadata_search(
    self, metadata: Optional[Dict[str, Any]] = None, n: Optional[int] = 5
) -> Iterable[Node]:
    """Retrieve nodes based on their metadata.

    Args:
        metadata: field/value pairs the returned nodes must match. `None`
            (the default) and an empty dict both mean "no filter".
        n: value passed as the query limit.

    Yields:
        One `Node` per row returned by the query.
    """
    # A `None` default replaces the shared mutable `{}`; normalising here also
    # keeps an explicit `metadata=None` from failing further down.
    query, params = self._get_search_cql(metadata=metadata or {}, limit=n)

    for row in self._session.execute(query, params):
        yield _row_to_node(row)

def get_node(self, content_id: str) -> Node:
    """Get a node by its id.

    Args:
        content_id: id of the node to fetch.

    Returns:
        The first node returned by the id lookup.

    Raises:
        IndexError: if the lookup returns no node.
    """
    return self._nodes_with_ids(ids=[content_id])[0]

def _get_outgoing_tags(
self,
Expand Down Expand Up @@ -723,7 +753,7 @@ def add_targets(rows: Iterable[Any]) -> None:

@staticmethod
def _normalize_metadata_indexing_policy(
metadata_indexing: Union[Tuple[str, Iterable[str]], str]
metadata_indexing: Union[Tuple[str, Iterable[str]], str],
) -> MetadataIndexingPolicy:
mode: MetadataIndexingMode
fields: Set[str]
Expand All @@ -738,7 +768,10 @@ def _normalize_metadata_indexing_policy(
f"Unsupported metadata_indexing value '{metadata_indexing}'"
)
else:
assert len(metadata_indexing) == 2
if len(metadata_indexing) != 2: # noqa: PLR2004
raise ValueError(
f"Unsupported metadata_indexing value '{metadata_indexing}'."
)
# it's a 2-tuple (mode, fields) still to normalize
_mode, _field_spec = metadata_indexing
fields = {_field_spec} if isinstance(_field_spec, str) else set(_field_spec)
Expand Down Expand Up @@ -766,25 +799,21 @@ def _normalize_metadata_indexing_policy(
def _coerce_string(value: Any) -> str:
if isinstance(value, str):
return value
elif isinstance(value, bool):
if isinstance(value, bool):
# bool MUST come before int in this chain of ifs!
return json.dumps(value)
elif isinstance(value, int):
if isinstance(value, int):
# we don't want to store '1' and '1.0' differently
# for the sake of metadata-filtered retrieval:
return json.dumps(float(value))
elif isinstance(value, float):
if isinstance(value, float) or value is None:
return json.dumps(value)
elif value is None:
return json.dumps(value)
else:
# when all else fails ...
return str(value)
# when all else fails ...
return str(value)

def _extract_where_clause_blocks(
self, metadata: Dict[str, Any]
) -> Tuple[str, List[Any]]:

_, metadata_s = self._parse_metadata(metadata=metadata, is_query=True)

if len(metadata_s) == 0:
Expand All @@ -800,13 +829,20 @@ def _extract_where_clause_blocks(
where_clause = "WHERE " + " AND ".join(wc_blocks)
return where_clause, vals_list


def _get_search_cql(self, limit: int, columns: Optional[str] = CONTENT_COLUMNS, metadata: Optional[Dict[str, Any]] = {}, embedding: Optional[List[float]] = None) -> Tuple[str, Tuple[Any, ...]]:
where_clause, get_cql_vals = self._extract_where_clause_blocks(metadata=metadata)
def _get_search_cql(
self,
limit: int,
columns: Optional[str] = CONTENT_COLUMNS,
metadata: Optional[Dict[str, Any]] = {},
embedding: Optional[List[float]] = None,
) -> Tuple[str, Tuple[Any, ...]]:
where_clause, get_cql_vals = self._extract_where_clause_blocks(
metadata=metadata
)
limit_clause = "LIMIT ?"
limit_cql_vals = [limit]

order_clause=""
order_clause = ""
order_cql_vals = []
if embedding is not None:
order_clause = "ORDER BY text_embedding ANN OF ?"
Expand All @@ -819,7 +855,6 @@ def _get_search_cql(self, limit: int, columns: Optional[str] = CONTENT_COLUMNS,
where_clause=where_clause,
order_clause=order_clause,
limit_clause=limit_clause,

)
prepared_query = self._session.prepare(select_cql)
prepared_query.consistency_level = ConsistencyLevel.ONE
Expand Down
40 changes: 23 additions & 17 deletions libs/knowledge-store/tests/integration_tests/test_graph_store.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# ruff: noqa: PT011, RUF015

import secrets
from typing import Callable, Iterator, List, Optional

Expand Down Expand Up @@ -71,27 +73,28 @@ def test_graph_store_creation(graph_store_factory: Callable[[str], GraphStore])
"""
graph_store_factory()


def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore]) -> None:
gs = graph_store_factory()

gs.add_nodes([Node(text="bb1", id="row1")])
gotten1 = gs.get_node(id="row1")
gotten1 = gs.get_node(content_id="row1")
assert gotten1 == Node(text="bb1", id="row1", metadata={})

gs.add_nodes([Node(text=None, id="row2", metadata={})])
gotten2 = gs.get_node(id="row2")
gotten2 = gs.get_node(content_id="row2")
assert gotten2 == Node(text=None, id="row2", metadata={})

md3 = {"a": 1, "b": "Bee", "c": True}
md3_string = {"a": "1.0", "b": "Bee", "c": "true"}
gs.add_nodes([Node(text=None, id="row3", metadata=md3)])
gotten3 = gs.get_node(id="row3")
gotten3 = gs.get_node(content_id="row3")
assert gotten3 == Node(text=None, id="row3", metadata=md3_string)

md4 = {"c1": True, "c2": True, "c3": True}
md4_string = {"c1": "true", "c2": "true", "c3": "true"}
gs.add_nodes([Node(text=None, id="row4", metadata=md4)])
gotten4 = gs.get_node(id="row4")
gotten4 = gs.get_node(content_id="row4")
assert gotten4 == Node(text=None, id="row4", metadata=md4_string)

# metadata searches:
Expand All @@ -103,13 +106,15 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
assert md_gotten4 == gotten4

# 'search' proper
gs.add_nodes([
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1})
])
gs.add_nodes(
[
Node(text=None, id="twin_a", metadata={"twin": True, "index": 0}),
Node(text=None, id="twin_b", metadata={"twin": True, "index": 1}),
]
)
md_twins_gotten = sorted(
list(gs.metadata_search(metadata={"twin": True})),
key=lambda res: int(float(res.metadata["index"]))
gs.metadata_search(metadata={"twin": True}),
key=lambda res: int(float(res.metadata["index"])),
)
expected = [
Node(text=None, id="twin_a", metadata={"twin": "true", "index": "0.0"}),
Expand All @@ -118,24 +123,25 @@ def test_graph_store_metadata(graph_store_factory: Callable[[str], GraphStore])
assert md_twins_gotten == expected
assert list(gs.metadata_search(metadata={"fake": True})) == []

def test_graph_store_metadata_routing(graph_store_factory: Callable[[str], GraphStore]) -> None:

def test_graph_store_metadata_routing(
graph_store_factory: Callable[[str], GraphStore],
) -> None:
test_md = {"mds": "string", "mdn": 255, "mdb": True}
test_md_string = {"mds": "string", "mdn": "255.0", "mdb": "true"}

gs_all = graph_store_factory(metadata_indexing="all")
gs_all.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
gotten_all = list(gs_all.metadata_search(metadata={"mds": "string"}))[0]
assert gotten_all.metadata == test_md_string
#
gs_none = graph_store_factory(metadata_indexing="none")
gs_none.add_nodes([Node(id="row1", text="bb1", metadata=test_md)])
with pytest.raises(ValueError):
# querying on non-indexed metadata fields:
list(gs_none.metadata_search(metadata={"mds": "string"}))
gotten_none = gs_none.get_node(id="row1")
gotten_none = gs_none.get_node(content_id="row1")
assert gotten_none is not None
assert gotten_none.metadata == test_md_string
#
test_md_allowdeny = {
"mdas": "MDAS",
"mdds": "MDDS",
Expand All @@ -152,14 +158,14 @@ def test_graph_store_metadata_routing(graph_store_factory: Callable[[str], Graph
"mdab": "true",
"mddb": "true",
}
#
gs_allow = graph_store_factory(metadata_indexing=("allow", {"mdas", "mdan", "mdab"}))
gs_allow = graph_store_factory(
metadata_indexing=("allow", {"mdas", "mdan", "mdab"})
)
gs_allow.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
with pytest.raises(ValueError):
list(gs_allow.metadata_search(metadata={"mdds": "MDDS"}))
gotten_allow = list(gs_allow.metadata_search(metadata={"mdas": "MDAS"}))[0]
assert gotten_allow.metadata == test_md_allowdeny_string
#
gs_deny = graph_store_factory(metadata_indexing=("deny", {"mdds", "mddn", "mddb"}))
gs_deny.add_nodes([Node(id="row1", text="bb1", metadata=test_md_allowdeny)])
with pytest.raises(ValueError):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
# ruff: noqa: SLF001
"""
Normalization of metadata policy specification options
"""

from ragstack_knowledge_store.graph_store import MetadataIndexingMode, GraphStore
from ragstack_knowledge_store.graph_store import GraphStore, MetadataIndexingMode


class TestNormalizeMetadataPolicy:
def test_normalize_metadata_policy(self) -> None:
#
mdp1 = GraphStore._normalize_metadata_indexing_policy("all")
assert mdp1 == (MetadataIndexingMode.DEFAULT_TO_SEARCHABLE, set())
#
mdp2 = GraphStore._normalize_metadata_indexing_policy("none")
assert mdp2 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, set())
#
mdp3 = GraphStore._normalize_metadata_indexing_policy(
("default_to_Unsearchable", ["x", "y"]),
)
assert mdp3 == (MetadataIndexingMode.DEFAULT_TO_UNSEARCHABLE, {"x", "y"})
#
mdp4 = GraphStore._normalize_metadata_indexing_policy(
("DenyList", ["z"]),
)
Expand Down
Loading

0 comments on commit e7ef77f

Please sign in to comment.