Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
kerinin committed Aug 15, 2024
1 parent eb54a9f commit 92e4c74
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
29 changes: 22 additions & 7 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,7 @@ def mmr_traversal_search(
lambda_mult: float = 0.5,
score_threshold: float = float("-inf"),
metadata_filter: dict[str, Any] = {}, # noqa: B006
tag_filter: set[tuple[str, str]],
) -> Iterable[Node]:
"""Retrieve documents from this graph store using MMR-traversal.
Expand Down Expand Up @@ -398,6 +399,7 @@ def mmr_traversal_search(
score_threshold: Only documents with a score greater than or equal to
this threshold will be chosen. Defaults to -infinity.
metadata_filter: Optional metadata to filter the results.
tag_filter: Optional tags to filter graph edges to be traversed.
"""
query_embedding = self._embedding.embed_query(query)
helper = MmrHelper(
Expand Down Expand Up @@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None:
new_candidates = {}
for adjacent in adjacents:
if adjacent.target_content_id not in outgoing_tags:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
if tag_filter.len() == 0:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
else:
outgoing_tags[adjacent.target_content_id] = (
tag_filter.intersection(adjacent.target_link_to_tags)
)

new_candidates[adjacent.target_content_id] = (
adjacent.target_text_embedding
Expand Down Expand Up @@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None:
for row in fetched:
if row.content_id not in outgoing_tags:
candidates[row.content_id] = row.text_embedding
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
if tag_filter.len() == 0:
outgoing_tags[row.content_id] = set(row.link_to_tags or [])
else:
outgoing_tags[row.content_id] = tag_filter.intersection(set(row.link_to_tags or []))
helper.add_candidates(candidates)

if initial_roots:
Expand Down Expand Up @@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None:
new_candidates = {}
for adjacent in adjacents:
if adjacent.target_content_id not in outgoing_tags:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
if tag_filter.len() == 0:
outgoing_tags[adjacent.target_content_id] = (
adjacent.target_link_to_tags
)
else:
outgoing_tags[adjacent.target_content_id] = (
tag_filter.intersection(adjacent.target_link_to_tags)
)
new_candidates[adjacent.target_content_id] = (
adjacent.target_text_embedding
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,12 @@ def test_mmr_traversal(
results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"])
assert _result_ids(results) == ["v1", "v3", "v2"]

results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("explicit", "link")))
assert _result_ids(results) == ["v0", "v2"]

results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("no", "match")))
assert _result_ids(results) == []


def test_write_retrieve_keywords(
graph_store_factory: Callable[[MetadataIndexingType], GraphStore],
Expand Down

0 comments on commit 92e4c74

Please sign in to comment.