diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 02110a62e..068d7ee0f 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -367,6 +367,7 @@ def mmr_traversal_search( lambda_mult: float = 0.5, score_threshold: float = float("-inf"), metadata_filter: dict[str, Any] = {}, # noqa: B006 + tag_filter: set[tuple[str, str]], ) -> Iterable[Node]: """Retrieve documents from this graph store using MMR-traversal. @@ -398,6 +399,7 @@ def mmr_traversal_search( score_threshold: Only documents with a score greater than or equal this threshold will be chosen. Defaults to -infinity. metadata_filter: Optional metadata to filter the results. + tag_filter: Optional tags to filter graph edges to be traversed. """ query_embedding = self._embedding.embed_query(query) helper = MmrHelper( @@ -444,9 +446,14 @@ def fetch_neighborhood(neighborhood: Sequence[str]) -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if tag_filter.len() == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding @@ -474,7 +481,10 @@ def fetch_initial_candidates() -> None: for row in fetched: if row.content_id not in outgoing_tags: candidates[row.content_id] = row.text_embedding - outgoing_tags[row.content_id] = set(row.link_to_tags or []) + if tag_filter.len() == 0: + outgoing_tags[row.content_id] = set(row.link_to_tags or []) + else: + outgoing_tags[row.content_id] = tag_filter.intersection(set(row.link_to_tags or [])) helper.add_candidates(candidates) if initial_roots: @@ -522,9 +532,14 @@ def fetch_initial_candidates() -> None: new_candidates = {} for adjacent in adjacents: if adjacent.target_content_id not in outgoing_tags: - outgoing_tags[adjacent.target_content_id] = ( - adjacent.target_link_to_tags - ) + if tag_filter.len() == 0: + outgoing_tags[adjacent.target_content_id] = ( + adjacent.target_link_to_tags + ) + else: + outgoing_tags[adjacent.target_content_id] = ( + tag_filter.intersection(adjacent.target_link_to_tags) + ) new_candidates[adjacent.target_content_id] = ( adjacent.target_text_embedding ) diff --git a/libs/knowledge-store/tests/integration_tests/test_graph_store.py b/libs/knowledge-store/tests/integration_tests/test_graph_store.py index 17d5a7a77..84b66abce 100644 --- a/libs/knowledge-store/tests/integration_tests/test_graph_store.py +++ b/libs/knowledge-store/tests/integration_tests/test_graph_store.py @@ -211,6 +211,12 @@ def test_mmr_traversal( results = gs.mmr_traversal_search("0.0", fetch_k=2, k=4, initial_roots=["v0"]) assert _result_ids(results) == ["v1", "v3", "v2"] + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("explicit", "link"))) + assert _result_ids(results) == ["v0", "v2"] + + results = gs.mmr_traversal_search("0.0", k=2, fetch_k=2, tag_filter=set(("no", "match"))) + assert _result_ids(results) == [] + def test_write_retrieve_keywords( graph_store_factory: Callable[[MetadataIndexingType], GraphStore],