Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: support async vector search from a collection #949

Merged
merged 3 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions google/cloud/firestore_v1/async_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
async_aggregation,
async_document,
async_query,
async_vector_query,
transaction,
)
from google.cloud.firestore_v1.base_collection import (
Expand Down Expand Up @@ -81,6 +82,14 @@ def _aggregation_query(self) -> async_aggregation.AsyncAggregationQuery:
"""
return async_aggregation.AsyncAggregationQuery(self._query())

def _vector_query(self) -> async_vector_query.AsyncVectorQuery:
    """Build an async vector query on top of this collection's base query.

    Returns:
        :class:`~google.cloud.firestore_v1.async_vector_query.AsyncVectorQuery`:
        a new vector query wrapping the result of ``self._query()``.
    """
    base_query = self._query()
    return async_vector_query.AsyncVectorQuery(base_query)

async def _chunkify(self, chunk_size: int):
async for page in self._query()._chunkify(chunk_size):
yield page
Expand Down
46 changes: 46 additions & 0 deletions tests/system/test_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,29 @@ def on_snapshot(docs, changes, read_time):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection(client, database):
    """Unfiltered nearest-neighbour search on a collection returns one document."""
    # The documents and indexes queried here are created as a manual step by
    # util/bootstrap_vector_index.py.
    nearest = client.collection("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )

    results = nearest.get()

    assert isinstance(results, list)
    assert len(results) == 1
    expected = {"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}
    assert results[0].to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)

Expand All @@ -198,6 +221,29 @@ def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group(client, database):
    """Unfiltered nearest-neighbour search on a collection group returns one document."""
    # The documents and indexes queried here are created as a manual step by
    # util/bootstrap_vector_index.py.
    nearest = client.collection_group("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )

    results = nearest.get()

    assert isinstance(results, list)
    assert len(results) == 1
    expected = {"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}
    assert results[0].to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand Down
45 changes: 45 additions & 0 deletions tests/system/test_system_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,28 @@ async def test_document_update_w_int_field(client, cleanup, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection(client, database):
    """Unfiltered async nearest-neighbour search on a collection returns one document."""
    # The documents and indexes queried here are created as a manual step by
    # util/bootstrap_vector_index.py.
    nearest = client.collection("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        limit=1,
        distance_measure=DistanceMeasure.EUCLIDEAN,
    )

    results = await nearest.get()

    assert isinstance(results, list)
    assert len(results) == 1
    expected = {"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}
    assert results[0].to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection = client.collection(collection_id)
vector_query = collection.where("color", "==", "red").find_nearest(
Expand All @@ -362,6 +384,29 @@ async def test_vector_search_collection(client, database):
@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group(client, database):
    """Unfiltered async nearest-neighbour search on a collection group returns one document."""
    # The documents and indexes queried here are created as a manual step by
    # util/bootstrap_vector_index.py.
    nearest = client.collection_group("vector_search").find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=DistanceMeasure.EUCLIDEAN,
        limit=1,
    )

    results = await nearest.get()

    assert isinstance(results, list)
    assert len(results) == 1
    expected = {"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"}
    assert results[0].to_dict() == expected


@pytest.mark.skipif(FIRESTORE_EMULATOR, reason="Require index and seed data")
@pytest.mark.parametrize("database", [None, FIRESTORE_OTHER_DB], indirect=True)
async def test_vector_search_collection_group_with_filter(client, database):
# Documents and Indexes are a manual step from util/bootstrap_vector_index.py
collection_id = "vector_search"
collection_group = client.collection_group(collection_id)

Expand Down
42 changes: 38 additions & 4 deletions tests/system/util/bootstrap_vector_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""A script to bootstrap vector data and vector index for system tests."""
from google.api_core.client_options import ClientOptions
from google.cloud.client import ClientWithProject # type: ignore

from google.cloud.firestore import Client
Expand Down Expand Up @@ -60,6 +61,21 @@ def _init_admin_api(self):
return firestore_admin_client.FirestoreAdminClient(transport=self._transport)

def create_vector_index(self, parent):
self._firestore_admin_api.create_index(
parent=parent,
index=Index(
query_scope=Index.QueryScope.COLLECTION,
fields=[
Index.IndexField(
field_path="embedding",
vector_config=Index.IndexField.VectorConfig(
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
),
),
],
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
Expand All @@ -79,6 +95,21 @@ def create_vector_index(self, parent):
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
query_scope=Index.QueryScope.COLLECTION_GROUP,
fields=[
Index.IndexField(
field_path="embedding",
vector_config=Index.IndexField.VectorConfig(
dimension=3, flat=Index.IndexField.VectorConfig.FlatIndex()
),
),
],
),
)

self._firestore_admin_api.create_index(
parent=parent,
index=Index(
Expand All @@ -103,13 +134,16 @@ def create_vector_documents(client, collection_id):
document1 = client.document(collection_id, "doc1")
document2 = client.document(collection_id, "doc2")
document3 = client.document(collection_id, "doc3")
document1.create({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
document2.create({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
document3.create({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})
document1.set({"embedding": Vector([1.0, 2.0, 3.0]), "color": "red"})
document2.set({"embedding": Vector([2.0, 2.0, 3.0]), "color": "red"})
document3.set({"embedding": Vector([3.0, 4.0, 5.0]), "color": "yellow"})


def main():
client = Client(project=PROJECT_ID, database=DATABASE_ID)
client_options = ClientOptions(api_endpoint=TARGET_HOSTNAME)
client = Client(
project=PROJECT_ID, database=DATABASE_ID, client_options=client_options
)
create_vector_documents(client=client, collection_id=COLLECTION_ID)
admin_client = FirestoreAdminClient(project=PROJECT_ID)
admin_client.create_vector_index(
Expand Down
70 changes: 68 additions & 2 deletions tests/unit/v1/test_async_vector_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,72 @@ def _expected_pb(parent, vector_field, vector, distance_type, limit):
return expected_pb


@pytest.mark.parametrize(
    "distance_measure, expected_distance",
    [
        (
            DistanceMeasure.EUCLIDEAN,
            StructuredQuery.FindNearest.DistanceMeasure.EUCLIDEAN,
        ),
        (DistanceMeasure.COSINE, StructuredQuery.FindNearest.DistanceMeasure.COSINE),
        (
            DistanceMeasure.DOT_PRODUCT,
            StructuredQuery.FindNearest.DistanceMeasure.DOT_PRODUCT,
        ),
    ],
)
@pytest.mark.asyncio
async def test_async_vector_query(distance_measure, expected_distance):
    """Check ``find_nearest`` called directly on an async collection reference.

    Verifies both the documents returned by ``get()`` and the exact
    ``run_query`` request issued, for each parametrized distance measure.
    """
    # Stand-in GAPIC that exposes nothing but ``run_query``.
    gapic = AsyncMock(spec=["run_query"])
    client = make_async_client()
    client._firestore_api_internal = gapic

    # A real collection reference acts as the query parent.
    collection = client.collection("dee")
    parent_path, expected_prefix = collection._parent_info()

    doc_data = {"snooze": 10, "embedding": Vector([1.0, 2.0, 3.0])}
    single_response = _make_query_response(
        name="{}/test_doc".format(expected_prefix), data=doc_data
    )
    rpc_kwargs = make_retry_timeout_kwargs(retry=None, timeout=None)

    # The fake RPC streams back exactly one document.
    gapic.run_query.return_value = AsyncIter([single_response])

    nearest = collection.find_nearest(
        vector_field="embedding",
        query_vector=Vector([1.0, 2.0, 3.0]),
        distance_measure=distance_measure,
        limit=5,
    )
    results = await nearest.get(transaction=_transaction(client), **rpc_kwargs)

    assert isinstance(results, list)
    assert len(results) == 1
    assert results[0].to_dict() == doc_data

    # The request must carry the vector query and the transaction id.
    expected_request = {
        "parent": parent_path,
        "structured_query": _expected_pb(
            parent=collection,
            vector_field="embedding",
            vector=Vector([1.0, 2.0, 3.0]),
            distance_type=expected_distance,
            limit=5,
        ),
        "transaction": _TXN_ID,
    }
    gapic.run_query.assert_called_once_with(
        request=expected_request,
        metadata=client._rpc_metadata,
        **rpc_kwargs,
    )


@pytest.mark.parametrize(
"distance_measure, expected_distance",
[
Expand Down Expand Up @@ -84,14 +150,14 @@ async def test_async_vector_query_with_filter(distance_measure, expected_distanc
# Execute the vector query and check the response.
firestore_api.run_query.return_value = AsyncIter([response_pb1, response_pb2])

vector_async__query = query.where("snooze", "==", 10).find_nearest(
vector_async_query = query.where("snooze", "==", 10).find_nearest(
vector_field="embedding",
query_vector=Vector([1.0, 2.0, 3.0]),
distance_measure=distance_measure,
limit=5,
)

returned = await vector_async__query.get(transaction=_transaction(client), **kwargs)
returned = await vector_async_query.get(transaction=_transaction(client), **kwargs)
assert isinstance(returned, list)
assert len(returned) == 2
assert returned[0].to_dict() == data
Expand Down
Loading