From 547c6d0e114eb364447cd78ba1c4f78126b244fc Mon Sep 17 00:00:00 2001 From: Nicholas Chittle Date: Mon, 12 Aug 2024 21:57:43 -0400 Subject: [PATCH] Fix async vector search from a collection --- google/cloud/firestore_v1/async_collection.py | 9 +++ tests/system/test_system.py | 46 ++++++++++++ tests/system/test_system_async.py | 45 ++++++++++++ tests/system/util/bootstrap_vector_index.py | 42 +++++++++-- tests/unit/v1/test_async_vector_query.py | 70 ++++++++++++++++++- 5 files changed, 206 insertions(+), 6 deletions(-) diff --git a/google/cloud/firestore_v1/async_collection.py b/google/cloud/firestore_v1/async_collection.py index 7032b1bdc..77761f2ad 100644 --- a/google/cloud/firestore_v1/async_collection.py +++ b/google/cloud/firestore_v1/async_collection.py @@ -23,6 +23,7 @@ async_aggregation, async_document, async_query, + async_vector_query, transaction, ) from google.cloud.firestore_v1.base_collection import ( @@ -81,6 +82,14 @@ def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery: """ return async_aggregation.AsyncAggregationQuery(self._query()) + def _vector_query(self) -> async_vector_query.AsyncVectorQuery: + """AsyncVectorQuery factory. 
+ + Returns: + :class:`~google.cloud.firestore_v1.async_vector_query.AsyncVectorQuery` + """ + return async_vector_query.AsyncVectorQuery(self._query()) + async def _chunkify(self, chunk_size: int): async for page in self._query()._chunkify(chunk_size): yield page diff --git a/tests/system/test_system.py b/tests/system/test_system.py index 87cd89d3e..67fab710c 100644 --- a/tests/system/test_system.py +++ b/tests/system/test_system.py @@ -177,6 +177,29 @@ def on_snapshot(docs, changes, read_time): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_vector_search_collection(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_with_filter(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) @@ -198,6 +221,29 @@ def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) def test_vector_search_collection_group(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + 
collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +def test_vector_search_collection_group_with_filter(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) diff --git a/tests/system/test_system_async.py b/tests/system/test_system_async.py index 696f5a6f7..4f021a1b4 100644 --- a/tests/system/test_system_async.py +++ b/tests/system/test_system_async.py @@ -342,6 +342,28 @@ async def test_document_update_w_int_field(client, cleanup, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_vector_search_collection(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection = client.collection(collection_id) + vector_query = collection.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + limit=1, + distance_measure=DistanceMeasure.EUCLIDEAN, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") 
+@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_with_filter(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection = client.collection(collection_id) vector_query = collection.where("color", "==", "red").find_nearest( @@ -362,6 +384,29 @@ async def test_vector_search_collection(client, database): @pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") @pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) async def test_vector_search_collection_group(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py + collection_id = "vector_search" + collection_group = client.collection_group(collection_id) + + vector_query = collection_group.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=DistanceMeasure.EUCLIDEAN, + limit=1, + ) + returned = await vector_query.get() + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == { + "embedding": Vector([1.0, 2.0, 3.0]), + "color": "red", + } + + +@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data") +@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True) +async def test_vector_search_collection_group_with_filter(client, database): + # Documents and Indexes are a manual step from util/bootstrap_vector_index.py collection_id = "vector_search" collection_group = client.collection_group(collection_id) diff --git a/tests/system/util/bootstrap_vector_index.py b/tests/system/util/bootstrap_vector_index.py index 1e88202b5..b5542534d 100644 --- a/tests/system/util/bootstrap_vector_index.py +++ b/tests/system/util/bootstrap_vector_index.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""A script to bootstrap vector data and vector index for system tests.""" +from google.api_core.client_options import ClientOptions from google.cloud.client import ClientWithProject # type: ignore from google.cloud.firestore import Client @@ -60,6 +61,21 @@ def _init_admin_api(self): return firestore_admin_client.FirestoreAdminClient(transport=self._transport) def create_vector_index(self, parent): + self._firestore_admin_api.create_index( + parent=parent, + index=Index( + query_scope=Index.QueryScope.COLLECTION, + fields=[ + Index.IndexField( + field_path="embedding", + vector_config=Index.IndexField.VectorConfig( + dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex() + ), + ), + ], + ), + ) + self._firestore_admin_api.create_index( parent=parent, index=Index( @@ -79,6 +95,21 @@ def create_vector_index(self, parent): ), ) + self._firestore_admin_api.create_index( + parent=parent, + index=Index( + query_scope=Index.QueryScope.COLLECTION_GROUP, + fields=[ + Index.IndexField( + field_path="embedding", + vector_config=Index.IndexField.VectorConfig( + dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex() + ), + ), + ], + ), + ) + self._firestore_admin_api.create_index( parent=parent, index=Index( @@ -103,13 +134,16 @@ def create_vector_documents(client, collection_id): document1 = client.document(collection_id, "doc1") document2 = client.document(collection_id, "doc2") document3 = client.document(collection_id, "doc3") - document1.create({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}) - document2.create({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"}) - document3.create({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"}) + document1.set({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}) + document2.set({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"}) + document3.set({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"}) def main(): - client = Client(project=PROJECT_ID, database=DATABASE_ID) + client_options = 
ClientOptions(api_endpoint=TARGET_HOSTNAME) + client = Client( + project=PROJECT_ID, database=DATABASE_ID, client_options=client_options + ) create_vector_documents(client=client, collection_id=COLLECTION_ID) admin_client = FirestoreAdminClient(project=PROJECT_ID) admin_client.create_vector_index( diff --git a/tests/unit/v1/test_async_vector_query.py b/tests/unit/v1/test_async_vector_query.py index 69e855b53..8b2a95a26 100644 --- a/tests/unit/v1/test_async_vector_query.py +++ b/tests/unit/v1/test_async_vector_query.py @@ -45,6 +45,72 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit): return expected_pb +@pytest.mark.parametrize( + "distance_measure, expected_distance", + [ + ( + DistanceMeasure.EUCLIDEAN, + StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN, + ), + (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE), + ( + DistanceMeasure.DOT_PRODUCT, + StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT, + ), + ], +) +@pytest.mark.asyncio +async def test_async_vector_query(distance_measure, expected_distance): + # Create a minimal fake GAPIC. + firestore_api = AsyncMock(spec=["run_query"]) + client = make_async_client() + client._firestore_api_internal = firestore_api + + # Make a **real** collection reference as parent. + parent = client.collection("dee") + parent_path, expected_prefix = parent._parent_info() + + data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])} + response_pb1 = _make_query_response( + name="{}/test_doc".format(expected_prefix), data=data + ) + + kwargs = make_retry_timeout_kwargs(retry=None, timeout=None) + + # Execute the vector query and check the response. 
+ firestore_api.run_query.return_value = AsyncIter([response_pb1]) + + vector_async_query = parent.find_nearest( + vector_field="embedding", + query_vector=Vector([1.0, 2.0, 3.0]), + distance_measure=distance_measure, + limit=5, + ) + + returned = await vector_async_query.get(transaction=_transaction(client), **kwargs) + assert isinstance(returned, list) + assert len(returned) == 1 + assert returned[0].to_dict() == data + + expected_pb = _expected_pb( + parent=parent, + vector_field="embedding", + vector=Vector([1.0, 2.0, 3.0]), + distance_type=expected_distance, + limit=5, + ) + + firestore_api.run_query.assert_called_once_with( + request={ + "parent": parent_path, + "structured_query": expected_pb, + "transaction": _TXN_ID, + }, + metadata=client._rpc_metadata, + **kwargs, + ) + + @pytest.mark.parametrize( "distance_measure, expected_distance", [ @@ -84,14 +150,14 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc # Execute the vector query and check the response. firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2]) - vector_async__query = query.where("snooze", "==", 10).find_nearest( + vector_async_query = query.where("snooze", "==", 10).find_nearest( vector_field="embedding", query_vector=Vector([1.0, 2.0, 3.0]), distance_measure=distance_measure, limit=5, ) - returned = await vector_async__query.get(transaction=_transaction(client), **kwargs) + returned = await vector_async_query.get(transaction=_transaction(client), **kwargs) assert isinstance(returned, list) assert len(returned) == 2 assert returned[0].to_dict() == data