Skip to content

Commit

Permalink
Fix async vector search from a collection
Browse files Browse the repository at this point in the history
  • Loading branch information
NickChittle committed Aug 13, 2024
1 parent d4956f4 commit 0cc2429
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 6 deletions.
9 changes: 9 additions & 0 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
async_aggregation,
async_document,
async_query,
async_vector_query,
transaction,
)
from google.cloud.firestore_v1.base_collection import (
Expand Down Expand Up @@ -81,6 +82,14 @@ def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery:
"""
return async_aggregation.AsyncAggregationQuery(self._query())

def _vector_query(self) -> async_vector_query.AsyncVectorQuery:
    """Build an async vector query on top of this collection's query.

    Returns:
        :class:`~google.cloud.firestore_v1.async_vector_query.AsyncVectorQuery`:
        A new vector query wrapping the object returned by ``self._query()``.
    """
    base_query = self._query()
    return async_vector_query.AsyncVectorQuery(base_query)

async def _chunkify(self, chunk_size: int):
async for page in self._query()._chunkify(chunk_size):
yield page
Expand Down
46 changes: 46 additions & 0 deletions tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ def on_snapshot(docs, changes, read_time):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection(client, database):
    """Run an unfiltered ``find_nearest`` directly on a collection reference.

    Expects exactly one result (``limit=1``) whose contents equal the
    ``[1.0, 2.0, 3.0]`` / ``"red"`` record.
    """
    # Documents and indexes are a manual step: they must be created
    # beforehand with util/bootstrap_vector_index.py.
    expected_doc = {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
    }

    nearest_query = client.collection("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )
    results = nearest_query.get()

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == expected_doc


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

Expand All @@ -198,6 +221,29 @@ def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group(client, database):
    """Run an unfiltered ``find_nearest`` on a collection group.

    Expects exactly one result (``limit=1``) whose contents equal the
    ``[1.0, 2.0, 3.0]`` / ``"red"`` record.
    """
    # Documents and indexes are a manual step: they must be created
    # beforehand with util/bootstrap_vector_index.py.
    expected_doc = {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
    }

    group = client.collection_group("vector_search")
    nearest_query = group.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )
    results = nearest_query.get()

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == expected_doc


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand Down
45 changes: 45 additions & 0 deletions tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,28 @@ async def test_document_update_w_int_field(client, cleanup, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection(client, database):
    """Run an unfiltered ``find_nearest`` directly on an async collection.

    Expects exactly one result (``limit=1``) whose contents equal the
    ``[1.0, 2.0, 3.0]`` / ``"red"`` record.
    """
    # Documents and indexes are a manual step: they must be created
    # beforehand with util/bootstrap_vector_index.py.
    expected_doc = {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
    }

    nearest_query = client.collection("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        limit=1,
        distance_measure=DistanceMeasure.EUCLIDEAN,
    )
    results = await nearest_query.get()

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == expected_doc


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.where("color", "==", "red").find_nearest(
Expand All @@ -362,6 +384,29 @@ async def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group(client, database):
    """Run an unfiltered ``find_nearest`` on an async collection group.

    Expects exactly one result (``limit=1``) whose contents equal the
    ``[1.0, 2.0, 3.0]`` / ``"red"`` record.
    """
    # Documents and indexes are a manual step: they must be created
    # beforehand with util/bootstrap_vector_index.py.
    expected_doc = {
        "embedding": Vector([1.0, 2.0, 3.0]),
        "color": "red",
    }

    group = client.collection_group("vector_search")
    nearest_query = group.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )
    results = await nearest_query.get()

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == expected_doc


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand Down
42 changes: 38 additions & 4 deletions tests/system/util/bootstrap_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to bootstrap vector data and vector index for system tests."""
from google.api_core.client_options import ClientOptions
from google.cloud.client import ClientWithProject # type: ignore

from google.cloud.firestore import Client
Expand Down Expand Up @@ -60,6 +61,21 @@ def _init_admin_api(self):
return firestore_admin_client.FirestoreAdminClient(transport=self._transport)

def create_vector_index(self, parent):
self._firestore_admin_api.create_index(
parent=parent,
index=Index(
query_scope=Index.QueryScope.COLLECTION,
fields=[
Index.IndexField(
field_path="embedding",
vector_config=Index.IndexField.VectorConfig(
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
),
),
],
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
Expand All @@ -79,6 +95,21 @@ def create_vector_index(self, parent):
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
query_scope=Index.QueryScope.COLLECTION_GROUP,
fields=[
Index.IndexField(
field_path="embedding",
vector_config=Index.IndexField.VectorConfig(
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
),
),
],
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
Expand All @@ -103,13 +134,16 @@ def create_vector_documents(client, collection_id):
document1 = client.document(collection_id, "doc1")
document2 = client.document(collection_id, "doc2")
document3 = client.document(collection_id, "doc3")
document1.create({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
document2.create({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
document3.create({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})
document1.set({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
document2.set({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
document3.set({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})


def main():
client = Client(project=PROJECT_ID, database=DATABASE_ID)
client_options = ClientOptions(api_endpoint=TARGET_HOSTNAME)
client = Client(
project=PROJECT_ID, database=DATABASE_ID, client_options=client_options
)
create_vector_documents(client=client, collection_id=COLLECTION_ID)
admin_client = FirestoreAdminClient(project=PROJECT_ID)
admin_client.create_vector_index(
Expand Down
70 changes: 68 additions & 2 deletions tests/unit/v1/test_async_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,72 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit):
return expected_pb


@pytest.mark.parametrize(
    "distance_measure, expected_distance",
    [
        (
            DistanceMeasure.EUCLIDEAN,
            StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        ),
        (
            DistanceMeasure.COSINE,
            StructuredQuery.FindNearest.DistanceMeasure.COSINE,
        ),
        (
            DistanceMeasure.DOT_PRODUCT,
            StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
        ),
    ],
)
@pytest.mark.asyncio
async def test_async_vector_query(distance_measure, expected_distance):
    """Check ``find_nearest`` called straight on an async collection reference.

    For each distance measure, the query is executed against a fake GAPIC
    and the test verifies both the returned documents and the exact
    ``run_query`` request that was issued.
    """
    # Minimal fake GAPIC: ``run_query`` is the only attribute in its spec.
    fake_api = AsyncMock(spec=["run_query"])
    client = make_async_client()
    client._firestore_api_internal = fake_api

    # The parent is a **real** collection reference, not a mock.
    collection = client.collection("dee")
    parent_path, expected_prefix = collection._parent_info()

    doc_data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
    response = _make_query_response(
        name="{}/test_doc".format(expected_prefix), data=doc_data
    )
    rpc_kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

    # The fake API streams back a single document for the query.
    fake_api.run_query.return_value = AsyncIter([response])

    vector_query = collection.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=distance_measure,
        limit=5,
    )
    results = await vector_query.get(transaction=_transaction(client), **rpc_kwargs)

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == doc_data

    # The request sent to the backend must carry the expected structured query.
    expected_request = {
        "parent": parent_path,
        "structured_query": _expected_pb(
            parent=collection,
            vector_field="embedding",
            vector=Vector([1.0, 2.0, 3.0]),
            distance_type=expected_distance,
            limit=5,
        ),
        "transaction": _TXN_ID,
    }
    fake_api.run_query.assert_called_once_with(
        request=expected_request,
        metadata=client._rpc_metadata,
        **rpc_kwargs,
    )


@pytest.mark.parametrize(
"distance_measure, expected_distance",
[
Expand Down Expand Up @@ -84,14 +150,14 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc
# Execute the vector query and check the response.
firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2])

vector_async__query = query.where("snooze", "==", 10).find_nearest(
vector_async_query = query.where("snooze", "==", 10).find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=distance_measure,
limit=5,
)

returned = await vector_async__query.get(transaction=_transaction(client), **kwargs)
returned = await vector_async_query.get(transaction=_transaction(client), **kwargs)
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == data
Expand Down

0 comments on commit 0cc2429

Please sign in to comment.