diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
new file mode 100644
index 00000000..55c13bdd
--- /dev/null
+++ b/.github/workflows/main.yml
@@ -0,0 +1,36 @@
+name: Run Pytest
+
+on:
+  push:
+    branches:
+      - master
+  pull_request:
+    branches:
+      - master
+
+jobs:
+  test:
+    env:
+      ASTRA_DB_ID: ${{ secrets.ASTRA_DB_ID }}
+      ASTRA_DB_REGION: ${{ secrets.ASTRA_DB_REGION }}
+      ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v2
+
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: 3.11
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install pytest
+          pip install -r requirements.txt
+
+      - name: Run pytest
+        run: |
+          PYTHONPATH=. pytest -s
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
deleted file mode 100644
index a3d8ee3c..00000000
--- a/.github/workflows/tests.yml
+++ /dev/null
@@ -1,62 +0,0 @@
-# Copyright DataStax, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-name: Tests
-on:
-  pull_request:
-    branches:
-      - master
-  schedule:
-    - cron: "0 0 * * *"
-jobs:
-  tests:
-    env:
-      ASTRA_DB_KEYSPACE: ${{ secrets.ASTRA_DB_KEYSPACE }}
-      ASTRA_DB_ID: ${{ secrets.ASTRA_DB_ID }}
-      ASTRA_DB_REGION: ${{ secrets.ASTRA_DB_REGION }}
-      ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.ASTRA_DB_APPLICATION_TOKEN }}
-      STARGATE_BASE_URL: ${{ secrets.STARGATE_BASE_URL }}
-      STARGATE_AUTH_URL: ${{ secrets.STARGATE_AUTH_URL }}
-      STARGATE_USERNAME: ${{ secrets.STARGATE_USERNAME }}
-      STARGATE_PASSWORD: ${{ secrets.STARGATE_PASSWORD }}
-    runs-on: ubuntu-latest
-    services:
-      stargate:
-        image: stargateio/stargate-dse-68:v1.0.52
-        env:
-          CLUSTER_NAME: stargate
-          CLUSTER_VERSION: 6.8
-          DEVELOPER_MODE: true
-          DSE: 1
-        ports:
-          - 8080:8080
-          - 8081:8081
-          - 8082:8082
-          - 9042:9042
-    steps:
-      - uses: actions/checkout@v2
-      - name: Set up Python 3.7
-        uses: actions/setup-python@v2
-        with:
-          python-version: 3.9
-      - name: Install Dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install -r requirements.txt
-      - name: wait for stargate
-        run: |
-          sleep 1m
-      - name: Test with PyTest
-        run: |
-          pytest tests/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..107e5ea3
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,7 @@
+repos:
+- repo: https://github.com/psf/black
+  rev: main # Replace with your desired version or use 'stable'
+  hooks:
+    - id: black
+      args: ['--quiet']
+
diff --git a/README.md b/README.md
index 65ceccc1..1114e24c 100644
--- a/README.md
+++ b/README.md
@@ -1,144 +1,127 @@
 ## AstraPy
-[![Actions Status](https://github.com/datastax/astrapy/workflows/Tests/badge.svg)](https://github.com/datastax/astrapy/actions)
+[![Actions Status](https://github.com/datastax/astrapy/workflows/Tests/badge.svg)](https://github.com/datastax/astrapy/actions)

 AstraPy is a Pythonic SDK for [DataStax Astra](https://astra.datastax.com) and
 [Stargate](https://stargate.io/)

 ### Resources
+
 - [DataStax Astra](https://astra.datastax.com)
 - [Stargate](https://stargate.io/)
-
 ### Getting Started
+
 Install AstraPy
+
 ```shell
 pip install astrapy
 ```

 Setup your Astra client
-```python
-from astrapy.client import create_astra_client
-
-astra_client = create_astra_client(astra_database_id=ASTRA_DB_ID,
-                                   astra_database_region=ASTRA_DB_REGION,
-                                   astra_application_token=ASTRA_DB_APPLICATION_TOKEN)
-```
-Take a look at the [client tests](https://github.com/datastax/astrapy/blob/master/tests/astrapy/test_client.py) and the [collection tests](https://github.com/datastax/astrapy/blob/master/tests/astrapy/test_collections.py) for specific endpoint examples.
+Create a .env file with the appropriate values, or use the 'astra' cli to do the same.
-#### Using the Ops Client
-You can use the Ops client to work the with Astra DevOps API. [API Reference](https://docs.datastax.com/en/astra/docs/_attachments/devopsv2.html)
-```python
-# astra_client created above
-# create a keyspace using the Ops API
-astra_client.ops.create_keyspace(database=ASTRA_DB_ID, keyspace=KEYSPACE_NAME)
+```bash
+ASTRA_DB_KEYSPACE=""
+ASTRA_DB_APPLICATION_TOKEN=""
+ASTRA_DB_REGION=""
+ASTRA_DB_ID=
 ```
-#### Using the REST Client
-You can use the REST client to work with the Astra REST API. [API Reference](https://docs.datastax.com/en/astra/docs/_attachments/restv2.html#tag/Data)
-```python
-# astra_client created above
-# search a table
-res = astra_client.rest.search_table(keyspace=ASTRA_DB_KEYSPACE,
-                                     table=TABLE_NAME,
-                                     query={"firstname": {"$eq": "Cliff"}})
-print(res["count"]) # number of results
-print(res["data"]) # list of rows
-```
+Load the variables in and then create the client. This collections client can make non-vector and vector calls, depending on the call configuration.
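+One way to load those values is with the `python-dotenv` package (an optional helper assumed in this sketch, not something AstraPy itself requires; exporting the variables in your shell works just as well):
+
+```python
+import os
+
+from dotenv import load_dotenv
+
+# Read the .env file created above into the process environment
+load_dotenv()
+
+token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
+db_id = os.getenv("ASTRA_DB_ID")
+db_region = os.getenv("ASTRA_DB_REGION")
+```
+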
-#### Using the Schemas Client
-You can use the Schemas client to work with the Astra Schemas API. [API Reference](https://docs.datastax.com/en/astra/docs/_attachments/restv2.html#tag/Schemas)
 ```python
-# astra_client created above
-# create a table
-astra_client.schemas.create_table(keyspace=ASTRA_DB_KEYSPACE, table_definition={
-    "name": "my_table",
-    "columnDefinitions": [
-        {
-            "name": "firstname",
-            "typeDefinition": "text"
-        },
-        {
-            "name": "lastname",
-            "typeDefinition": "text"
-        },
-        {
-            "name": "favorite_color",
-            "typeDefinition": "text",
-        }
-    ],
-    "primaryKey": {
-        "partitionKey": [
-            "firstname"
-        ],
-        "clusteringKey": [
-            "lastname"
-        ]
+import os
+import sys
+
+from astrapy.db import AstraDB, AstraDBCollection
+from astrapy.ops import AstraDBOps
+
+# First, we work with devops
+api_key = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
+astra_ops = AstraDBOps(api_key)
+
+# Define a database to create
+database_definition = {
+    "name": "vector_test",
+    "tier": "serverless",
+    "cloudProvider": "GCP",
+    "keyspace": os.getenv("ASTRA_DB_KEYSPACE", "default_keyspace"),
+    "region": os.getenv("ASTRA_DB_REGION", None),
+    "capacityUnits": 1,
+    "user": "example",
+    "password": api_key,
+    "dbType": "vector",
+}
+
+# Create the database
+create_result = astra_ops.create_database(database_definition=database_definition)
+
+# Grab the new information from the database
+database_id = create_result["id"]
+database_region = astra_ops.get_database()[0]["info"]["region"]
+database_base_url = "apps.astra.datastax.com"
+
+# Build the endpoint URL:
+api_endpoint = f"https://{database_id}-{database_region}.{database_base_url}"
+
+# Initialize our vector db
+astra_db = AstraDB(token=api_key, api_endpoint=api_endpoint)
+
+# Possible Operations
+astra_db.create_collection(collection_name="collection_test_delete", size=5)
+astra_db.delete_collection(collection_name="collection_test_delete")
+astra_db.create_collection(collection_name="collection_test", size=5)
+
+# Collections
+astra_db_collection = AstraDBCollection(
+    collection_name="collection_test", astra_db=astra_db
+)
+# Or...
+astra_db_collection = AstraDBCollection(
+    collection_name="collection_test", token=api_key, api_endpoint=api_endpoint
+)
+
+astra_db_collection.insert_one(
+    {
+        "_id": "5",
+        "name": "Coded Cleats Copy",
+        "description": "ChatGPT integrated sneakers that talk to you",
+        "$vector": [0.25, 0.25, 0.25, 0.25, 0.25],
     }
-}
-})
+)
+
+astra_db_collection.find_one({"name": "potato"})
+astra_db_collection.find_one({"name": "Coded Cleats Copy"})
 ```
+#### More Information
-#### Using the Collections Client
-You can use the Collections client to work with the Astra Document API. [API Reference](https://docs.datastax.com/en/astra/docs/_attachments/docv2.html)
-```python
-# astra_client created above
-# create multiple documents using the collections API
-my_collection = astra_client.collections.namespace(ASTRA_DB_KEYSPACE).collection(COLLECTION_NAME)
-my_collection.batch(documents=[
-    {
-        "documentId": "1",
-        "first_name": "Dang",
-        "last_name": "Son",
-    }, {
-        "documentId": "2",
-        "first_name": "Yep",
-        "last_name": "Boss",
-    }])
-```
+Check out the [notebook](https://colab.research.google.com/github/synedra/astra_vector_examples/blob/main/notebook/vector.ipynb#scrollTo=f04a1806) which has examples for finding and inserting information into the database, including vector commands.
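+For example, after inserting a few documents with embeddings you can run a similarity search ordered by `$vector` (a sketch only; the document fields and the `limit` option shown here are illustrative, not part of the example above):
+
+```python
+# Add a couple more documents with 5-dimensional embeddings
+astra_db_collection.insert_many(
+    documents=[
+        {"_id": "6", "name": "Logic Layers", "$vector": [0.1, 0.3, 0.5, 0.7, 0.9]},
+        {"_id": "7", "name": "Vision Vector Frame", "$vector": [0.1, 0.05, 0.08, 0.3, 0.6]},
+    ]
+)
+
+# Return the two documents closest to the query vector
+results = astra_db_collection.find(
+    sort={"$vector": [0.2, 0.2, 0.2, 0.2, 0.2]},
+    options={"limit": 2},
+)
+print(results["data"]["documents"])
+```
+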
-#### Using the GraphQL Client
-You can use the GraphQL client to work with the Astra GraphQL API. [API Reference](https://docs.datastax.com/en/astra/docs/using-the-astra-graphql-api.html)
-```python
-# astra_client created above
-# create multiple documents using the GraphQL API
-astra_client.gql.execute(keyspace=ASTRA_DB_KEYSPACE, query="""
-    mutation insert2Books {
-      moby: insertbook(value: {title:"Moby Dick", author:"Herman Melville"}) {
-        value {
-          title
-        }
-      }
-      catch22: insertbook(value: {title:"Catch-22", author:"Joseph Heller"}) {
-        value {
-          title
-        }
-      }
-    }
-    """)
-```
+Take a look at the [vector tests](https://github.com/datastax/astrapy/blob/master/tests/astrapy/test_collections.py) and the [collection tests](https://github.com/datastax/astrapy/blob/master/tests/astrapy/test_collections.py) for specific endpoint examples.
+
+#### Using the Ops Client
+
+You can use the Ops client to work with the Astra DevOps API. Check the [devops tests](https://github.com/datastax/astrapy/blob/master/tests/astrapy/test_devops.py)
+
+### For Developers
+
+#### Testing
+
+Ensure you provide all required environment variables:
-#### Using the HTTP Client
-You can use the HTTP client to work with any Astra/Stargate endpoint directly. [API Reference](https://docs.datastax.com/en/astra/docs/api.html)
-```python
-# astra_client created above
-# create a document on Astra using the Document API
-astra_client._rest_client.request(
-    method="PUT",
-    path=f"/api/rest/v2/namespaces/my_namespace/collections/my_collection/user_1",
-    json_data={
-        "first_name": "Cliff",
-        "last_name": "Wicklow",
-        "emails": ["cliff.wicklow@example.com"],
-    })
+
 ```
+export ASTRA_DB_ID="..."
+export ASTRA_DB_REGION="..."
+export ASTRA_DB_APPLICATION_TOKEN="..."
+export ASTRA_DB_KEYSPACE="..."
+export ASTRA_CLIENT_ID="..."
+export ASTRA_CLIENT_SECRET="..."
+```
-#### Connecting to a local Stargate Instance
-```python
-from astrapy.client import create_astra_client
-
-stargate_client = create_astra_client(base_url=http://localhost:8082,
-                                      auth_base_url=http://localhost:8081/v1/auth,
-                                      username=****,
-                                      password=****)
+
 ```
+PYTHONPATH=. pytest
+```
diff --git a/astrapy/__init__.py b/astrapy/__init__.py
index 33284e0a..2c9ca172 100644
--- a/astrapy/__init__.py
+++ b/astrapy/__init__.py
@@ -10,4 +10,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
diff --git a/astrapy/client.py b/astrapy/client.py
deleted file mode 100644
index 541878ed..00000000
--- a/astrapy/client.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# Copyright DataStax, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
- -from astrapy.collections import create_client -from astrapy.endpoints.rest import AstraRest -from astrapy.endpoints.schemas import AstraSchemas -from astrapy.endpoints.ops import AstraOps -from astrapy.endpoints.graphql import AstraGraphQL -import logging - -logger = logging.getLogger(__name__) - - -class AstraClient(): - def __init__(self, astra_collections_client=None): - self.collections = astra_collections_client - self._rest_client = astra_collections_client.astra_client - self.rest = AstraRest(self._rest_client) - self.ops = AstraOps(self._rest_client) - self.schemas = AstraSchemas(self._rest_client) - self.gql = AstraGraphQL(self._rest_client) - - -def create_astra_client(astra_database_id=None, - astra_database_region=None, - astra_application_token=None, - base_url=None, - auth_base_url=None, - username=None, - password=None, - debug=False): - astra_collections_client = create_client(astra_database_id=astra_database_id, - astra_database_region=astra_database_region, - astra_application_token=astra_application_token, - base_url=base_url, - auth_base_url=auth_base_url, - username=username, - password=password, - debug=debug) - return AstraClient(astra_collections_client=astra_collections_client) diff --git a/astrapy/collections.py b/astrapy/collections.py deleted file mode 100644 index b95de552..00000000 --- a/astrapy/collections.py +++ /dev/null @@ -1,186 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from astrapy.rest import http_methods -from astrapy.rest import create_client as create_astra_client -import logging -import json - -logger = logging.getLogger(__name__) - -DEFAULT_PAGE_SIZE = 20 -DEFAULT_BASE_PATH = "/api/rest/v2/namespaces" - - -class AstraCollection(): - def __init__(self, astra_client=None, namespace_name=None, collection_name=None): - self.astra_client = astra_client - self.namespace_name = namespace_name - self.collection_name = collection_name - self.base_path = f"{DEFAULT_BASE_PATH}/{namespace_name}/collections/{collection_name}" - if astra_client.auth_base_url is not None: - self.base_path = f"/v2/namespaces/{namespace_name}/collections/{collection_name}" - - def _get(self, path=None, options=None): - full_path = f"{self.base_path}/{path}" if path else self.base_path - response = self.astra_client.request(method=http_methods.GET, - path=full_path, - url_params=options) - if isinstance(response, dict): - return response["data"] - return None - - def _put(self, path=None, document=None): - return self.astra_client.request(method=http_methods.PUT, - path=f"{self.base_path}/{path}", - json_data=document) - - def upgrade(self): - return self.astra_client.request(method=http_methods.POST, - path=f"{self.base_path}/upgrade") - - def get_schema(self): - return self.astra_client.request(method=http_methods.GET, - path=f"{self.base_path}/json-schema") - - def create_schema(self, schema=None): - return self.astra_client.request(method=http_methods.PUT, - path=f"{self.base_path}/json-schema", - json_data=schema) - - def update_schema(self, schema=None): - return self.astra_client.request(method=http_methods.PUT, - path=f"{self.base_path}/json-schema", - json_data=schema) - - def get(self, path=None): - return self._get(path=path) - - def find(self, query=None, options=None): - options = {} if options is None else options - request_params = {"where": json.dumps( - query), "page-size": DEFAULT_PAGE_SIZE} - request_params.update(options) - response = self.astra_client.request(method=http_methods.GET, - path=self.base_path, - url_params=request_params) - if isinstance(response, dict): - return response - return None - - def find_one(self, query=None, options=None): - options = {} if options is None else options - request_params = {"where": json.dumps(query), "page-size": 1} - request_params.update(options) - response = self._get(path=None, options=request_params) - if response is not None: - keys = list(response.keys()) - if(len(keys) == 0): - return None - return response[keys[0]] - return None - - def create(self, path=None, document=None): - if path is not None: - return self._put(path=path, document=document) - return self.astra_client.request(method=http_methods.POST, - path=self.base_path, - json_data=document) - - def update(self, path, document): - return self.astra_client.request(method=http_methods.PATCH, - path=f"{self.base_path}/{path}", - json_data=document) - - def replace(self, path, document): - return self._put(path=path, document=document) - - def delete(self, path): - return self.astra_client.request(method=http_methods.DELETE, - path=f"{self.base_path}/{path}") - - def batch(self, documents=None, id_path=""): - if id_path == "": - id_path = "documentId" - return self.astra_client.request(method=http_methods.POST, - path=f"{self.base_path}/batch", - json_data=documents, - url_params={"id-path": id_path}) - - def push(self, path=None, value=None): - json_data = {"operation": "$push", "value": value} - res = self.astra_client.request(method=http_methods.POST, 
- path=f"{self.base_path}/{path}/function", - json_data=json_data) - return res.get("data") - - def pop(self, path=None): - json_data = {"operation": "$pop"} - res = self.astra_client.request(method=http_methods.POST, - path=f"{self.base_path}/{path}/function", - json_data=json_data) - return res.get("data") - - -class AstraNamespace(): - def __init__(self, astra_client=None, namespace_name=None): - self.astra_client = astra_client - self.namespace_name = namespace_name - self.base_path = f"{DEFAULT_BASE_PATH}/{namespace_name}" - - def collection(self, collection_name): - return AstraCollection(astra_client=self.astra_client, - namespace_name=self.namespace_name, - collection_name=collection_name) - - def get_collections(self): - res = self.astra_client.request(method=http_methods.GET, - path=f"{self.base_path}/collections") - return res.get("data") - - def create_collection(self, name=""): - return self.astra_client.request(method=http_methods.POST, - path=f"{self.base_path}/collections", - json_data={"name": name}) - - def delete_collection(self, name=""): - return self.astra_client.request(method=http_methods.DELETE, - path=f"{self.base_path}/collections/{name}") - - -class AstraDocumentClient(): - def __init__(self, astra_client=None): - self.astra_client = astra_client - - def namespace(self, namespace_name): - return AstraNamespace(astra_client=self.astra_client, namespace_name=namespace_name) - - -def create_client(astra_database_id=None, - astra_database_region=None, - astra_application_token=None, - base_url=None, - auth_base_url=None, - username=None, - password=None, - debug=False): - astra_client = create_astra_client(astra_database_id=astra_database_id, - astra_database_region=astra_database_region, - astra_application_token=astra_application_token, - base_url=base_url, - auth_base_url=auth_base_url, - username=username, - password=password, - debug=debug) - return AstraDocumentClient(astra_client=astra_client) diff --git a/astrapy/db.py b/astrapy/db.py new file mode 100644 index 00000000..db24c56e --- /dev/null +++ b/astrapy/db.py @@ -0,0 +1,375 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from astrapy.defaults import ( + DEFAULT_AUTH_HEADER, + DEFAULT_KEYSPACE_NAME, + DEFAULT_BASE_PATH, +) +from astrapy.utils import make_payload, make_request, http_methods, parse_endpoint_url + +import logging +import json + +logger = logging.getLogger(__name__) + + +class AstraDBCollection: + def __init__( + self, + collection_name, + astra_db=None, + token=None, + api_endpoint=None, + namespace=None, + ): + if astra_db is None: + if token is None or api_endpoint is None: + raise AssertionError("Must provide token and api_endpoint") + + astra_db = AstraDB( + token=token, api_endpoint=api_endpoint, namespace=namespace + ) + + self.astra_db = astra_db + self.collection_name = collection_name + self.base_path = f"{self.astra_db.base_path}/{collection_name}" + + def _request(self, *args, skip_error_check=False, **kwargs): + response = make_request( + *args, + **kwargs, + base_url=self.astra_db.base_url, + auth_header=DEFAULT_AUTH_HEADER, + token=self.astra_db.token, + ) + responsebody = response.json() + + if not skip_error_check and "errors" in responsebody: + raise ValueError(json.dumps(responsebody["errors"])) + else: + return responsebody + + def _get(self, path=None, options=None): + full_path = f"{self.base_path}/{path}" if path else self.base_path + response = self._request( + method=http_methods.GET, path=full_path, url_params=options + ) + if isinstance(response, dict): + return response + return None + + def _post(self, path=None, document=None): + response = self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=document + ) + return response + + def get(self, path=None): + return self._get(path=path) + + def find(self, filter=None, projection=None, sort={}, options=None): + json_query = make_payload( + top_level="find", + filter=filter, + projection=projection, + options=options, + sort=sort, + ) + + response = self._post( + document=json_query, + ) + + return response + + @staticmethod + def paginate(*, method, options, **kwargs): + response0 = method(options=options, **kwargs) + next_page_state = response0["data"]["nextPageState"] + options0 = options + for document in response0["data"]["documents"]: + yield document + while next_page_state is not None: + options1 = {**options0, **{"pagingState": next_page_state}} + response1 = method(options=options1, **kwargs) + for document in response1["data"]["documents"]: + yield document + next_page_state = response1["data"]["nextPageState"] + + def paginated_find(self, filter=None, projection=None, sort=None, options=None): + return self.paginate( + method=self.find, + filter=filter, + projection=projection, + sort=sort, + options=options, + ) + + def pop(self, filter, update, options): + json_query = make_payload( + top_level="findOneAndUpdate", filter=filter, update=update, options=options + ) + + response = self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + def push(self, filter, update, options): + json_query = make_payload( + top_level="findOneAndUpdate", filter=filter, update=update, options=options + ) + + response = self._request( + method=http_methods.POST, + path=self.base_path, + json_data=json_query, + ) + + return response + + def find_one_and_replace( + self, sort={}, filter=None, replacement=None, options=None + ): + json_query = make_payload( + top_level="findOneAndReplace", + filter=filter, + replacement=replacement, + options=options, + sort=sort, + ) + + response = self._request( + method=http_methods.POST, 
path=f"{self.base_path}", json_data=json_query + ) + + return response + + def find_one_and_update(self, sort={}, update=None, filter=None, options=None): + json_query = make_payload( + top_level="findOneAndUpdate", + filter=filter, + update=update, + options=options, + sort=sort, + ) + + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + ) + + return response + + def find_one(self, filter={}, projection={}, sort={}, options={}): + json_query = make_payload( + top_level="findOne", + filter=filter, + projection=projection, + options=options, + sort=sort, + ) + + response = self._post( + document=json_query, + ) + + return response + + def insert_one(self, document): + json_query = make_payload(top_level="insertOne", document=document) + + response = self._request( + method=http_methods.POST, path=self.base_path, json_data=json_query + ) + + return response + + def insert_many(self, documents, options=None, partial_failures_allowed=False): + json_query = make_payload( + top_level="insertMany", documents=documents, options=options + ) + + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + skip_error_check=partial_failures_allowed, + ) + + return response + + def update_one(self, filter, update): + json_query = make_payload(top_level="updateOne", filter=filter, update=update) + + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data=json_query, + ) + + return response + + def replace(self, path, document): + return self._put(path=path, document=document) + + def delete(self, id): + json_query = { + "deleteOne": { + "filter": {"_id": id}, + } + } + + response = self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + def delete_subdocument(self, id, subdoc): + json_query = { + "findOneAndUpdate": { + "filter": {"_id": id}, + "update": {"$unset": {subdoc: ""}}, + } + } + + response = self._request( + method=http_methods.POST, path=f"{self.base_path}", json_data=json_query + ) + + return response + + def upsert(self, document): + """ + Emulate an upsert operation for a single document, + whereby a document is inserted if its _id is new, or completely + replaces and existing one if that _id is already saved in the collection. + Returns: the _id of the inserted document. + """ + # Attempt to insert the given document + result = self.insert_one(document) + + # Check if we hit an error + if ( + "errors" in result + and "errorCode" in result["errors"][0] + and result["errors"][0]["errorCode"] == "DOCUMENT_ALREADY_EXISTS" + ): + # Now we attempt to update + result = self.find_one_and_replace( + filter={"_id": document["_id"]}, + replacement=document, + ) + upserted_id = result["data"]["document"]["_id"] + else: + upserted_id = result["status"]["insertedIds"][0] + + return upserted_id + + +class AstraDB: + def __init__( + self, + token=None, + api_endpoint=None, + namespace=None, + ): + if token is None or api_endpoint is None: + raise AssertionError("Must provide token and api_endpoint") + + if namespace is None: + logger.info( + f"ASTRA_DB_KEYSPACE is not set. 
Defaulting to '{DEFAULT_KEYSPACE_NAME}'" + ) + namespace = DEFAULT_KEYSPACE_NAME + + # Store the initial parameters + self.token = token + ( + self.database_id, + self.database_region, + self.database_domain, + ) = parse_endpoint_url(api_endpoint) + + # Set the Base URL for the API calls + self.base_url = ( + f"https://{self.database_id}-{self.database_region}.{self.database_domain}" + ) + self.base_path = f"{DEFAULT_BASE_PATH}/{namespace}" + + # Set the namespace parameter + self.namespace = namespace + + def _request(self, *args, skip_error_check=False, **kwargs): + response = make_request( + *args, + **kwargs, + base_url=self.base_url, + auth_header=DEFAULT_AUTH_HEADER, + token=self.token, + ) + + responsebody = response.json() + + if not skip_error_check and "errors" in responsebody: + raise ValueError(json.dumps(responsebody["errors"])) + else: + return responsebody + + return result + + def collection(self, collection_name): + return AstraDBCollection(collection_name=collection_name, astra_db=self) + + def get_collections(self): + response = self._request( + method=http_methods.POST, + path=self.base_path, + json_data={"findCollections": {}}, + ) + + return response + + def create_collection(self, size=None, options={}, function="", collection_name=""): + if size and not options: + options = {"vector": {"size": size}} + if function: + options["vector"]["function"] = function + if options: + jsondata = {"name": collection_name, "options": options} + else: + jsondata = {"name": collection_name} + + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"createCollection": jsondata}, + ) + + return response + + def delete_collection(self, collection_name=""): + response = self._request( + method=http_methods.POST, + path=f"{self.base_path}", + json_data={"deleteCollection": {"name": collection_name}}, + ) + + return response diff --git a/astrapy/defaults.py b/astrapy/defaults.py new file mode 100644 index 00000000..4ed56aca --- /dev/null +++ b/astrapy/defaults.py @@ -0,0 +1,10 @@ +DEFAULT_AUTH_PATH = "/api/rest/v1/auth" +DEFAULT_BASE_PATH = "/api/json/v1" + +DEFAULT_TIMEOUT = 30000 +DEFAULT_AUTH_HEADER = "X-Cassandra-Token" +DEFAULT_KEYSPACE_NAME = "default_keyspace" +DEFAULT_REGION = "us-east1" + +DEFAULT_DEV_OPS_PATH_PREFIX = "/v2" +DEFAULT_DEV_OPS_URL = "api.astra.datastax.com" diff --git a/astrapy/endpoints/__init__.py b/astrapy/endpoints/__init__.py index 33284e0a..2c9ca172 100644 --- a/astrapy/endpoints/__init__.py +++ b/astrapy/endpoints/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/astrapy/endpoints/graphql.py b/astrapy/endpoints/graphql.py deleted file mode 100644 index a90f87fa..00000000 --- a/astrapy/endpoints/graphql.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from gql import Client, gql -from gql.transport.requests import RequestsHTTPTransport -from gql.transport.requests import log as requests_logger - -requests_logger.setLevel(logging.WARNING) - - -class AstraGraphQL(): - - def __init__(self, client=None): - self.client = client - self.gql = gql - self.schema_gql_client = self.create_gql_client( - url=f"{client.base_url}/api/graphql-schema") - self.keyspace_clients = {} - - def create_gql_client(self, url=""): - headers = {'X-Cassandra-Token': self.client.astra_application_token} - transport = RequestsHTTPTransport(url=url, - verify=True, - retries=1, - headers=headers) - return Client(transport=transport, fetch_schema_from_transport=True) - - def get_keyspace_client(self, keyspace=""): - return self.create_gql_client(url=f"{self.client.base_url}/api/graphql/{keyspace}") - - def execute(self, query="", variables=None, keyspace=""): - gql_client = self.schema_gql_client - if(keyspace != ""): - gql_client = self.keyspace_clients.get(keyspace) - if(gql_client is None): - self.keyspace_clients[keyspace] = self.get_keyspace_client( - keyspace) - gql_client = self.keyspace_clients[keyspace] - return gql_client.execute(gql(query), variable_values=variables) diff --git a/astrapy/endpoints/ops.py b/astrapy/endpoints/ops.py index b41840ce..58487db5 100644 --- a/astrapy/endpoints/ops.py +++ b/astrapy/endpoints/ops.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from astrapy.rest import http_methods, DEFAULT_TIMEOUT +from astrapyjson.client import http_methods, DEFAULT_TIMEOUT import copy import requests @@ -20,239 +20,320 @@ PATH_PREFIX = "/v2" -class AstraOps(): - +class AstraOps: def __init__(self, client=None): self.client = copy.deepcopy(client) self.client.base_url = DEFAULT_HOST self.client.auth_header = "Authorization" - self.client.astra_application_token = f"Bearer {self.client.astra_application_token}" + self.client.astra_application_token = ( + f"Bearer {self.client.astra_application_token}" + ) def get_databases(self, options=None): options = {} if options is None else options - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/databases", - url_params=options) + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/databases", url_params=options + ) def create_database(self, database_definition=None): - r = requests.request(method=http_methods.POST, - url=f"{self.client.base_url}{PATH_PREFIX}/databases", - json=database_definition, - timeout=DEFAULT_TIMEOUT, - headers={self.client.auth_header: self.client.astra_application_token}) - if(r.status_code == 201): + r = requests.request( + method=http_methods.POST, + url=f"{self.client.base_url}{PATH_PREFIX}/databases", + json=database_definition, + timeout=DEFAULT_TIMEOUT, + headers={self.client.auth_header: self.client.astra_application_token}, + ) + if r.status_code == 201: return {"id": r.headers["Location"]} return None def terminate_database(self, database=""): - r = requests.request(method=http_methods.POST, - url=f"{self.client.base_url}{PATH_PREFIX}/databases/{database}/terminate", - timeout=DEFAULT_TIMEOUT, - headers={self.client.auth_header: self.client.astra_application_token}) - if(r.status_code == 202): + r = requests.request( + method=http_methods.POST, + url=f"{self.client.base_url}{PATH_PREFIX}/databases/{database}/terminate", + 
timeout=DEFAULT_TIMEOUT, + headers={self.client.auth_header: self.client.astra_application_token}, + ) + if r.status_code == 202: return database return None def get_database(self, database=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/databases/{database}") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/databases/{database}" + ) def create_keyspace(self, database="", keyspace=""): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/keyspaces/{keyspace}") + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/keyspaces/{keyspace}", + ) def park_database(self, database=""): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/park") + return self.client.request( + method=http_methods.POST, path=f"{PATH_PREFIX}/databases/{database}/park" + ) def unpark_database(self, database=""): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/unpark") + return self.client.request( + method=http_methods.POST, path=f"{PATH_PREFIX}/databases/{database}/unpark" + ) def resize_database(self, database="", options=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/resize", - json_data=options) + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/resize", + json_data=options, + ) def reset_database_password(self, database="", options=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/resetPassword", - json_data=options) + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/resetPassword", + json_data=options, + ) def get_secure_bundle(self, database=""): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/secureBundleURL") + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/secureBundleURL", + ) def get_datacenters(self, database=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/databases/{database}/datacenters") + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/databases/{database}/datacenters", + ) def create_datacenter(self, database="", options=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/datacenters", - json_data=options) + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/datacenters", + json_data=options, + ) def terminate_datacenter(self, database="", datacenter=""): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/datacenters/{datacenter}/terminate") - - def get_access_list(self, database=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/databases/{database}/access-list") - - def replace_access_list(self, database="", access_list=None): - return self.client.request(method=http_methods.PUT, - path=f"{PATH_PREFIX}/databases/{database}/access-list", - json_data=access_list) - - def update_access_list(self, database="", access_list=None): - return self.client.request(method=http_methods.PATCH, - path=f"{PATH_PREFIX}/databases/{database}/access-list", - json_data=access_list) - - def 
add_access_list_address(self, database="", address=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/databases/{database}/access-list", - json_data=address) - - def delete_access_list(self, database=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/databases/{database}/access-list") - - def get_private_link(self, database=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/private-link") - - def get_datacenter_private_link(self, database="", datacenter=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/private-link") - - def create_datacenter_private_link(self, database="", datacenter="", private_link=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/private-link", - json_data=private_link) - - def create_datacenter_endpoint(self, database="", datacenter="", endpoint=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoint", - json_data=endpoint) - - def update_datacenter_endpoint(self, database="", datacenter="", endpoint=None): - return self.client.request(method=http_methods.PUT, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint['id']}", - json_data=endpoint) - - def get_datacenter_endpoint(self, database="", datacenter="", endpoint=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}") - - def delete_datacenter_endpoint(self, database="", datacenter="", endpoint=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}") + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/datacenters/{datacenter}/terminate", + ) + + def get_access_list(self, database=""): + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/databases/{database}/access-list", + ) + + def replace_access_list(self, database="", access_list=None): + return self.client.request( + method=http_methods.PUT, + path=f"{PATH_PREFIX}/databases/{database}/access-list", + json_data=access_list, + ) + + def update_access_list(self, database="", access_list=None): + return self.client.request( + method=http_methods.PATCH, + path=f"{PATH_PREFIX}/databases/{database}/access-list", + json_data=access_list, + ) + + def add_access_list_address(self, database="", address=None): + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/databases/{database}/access-list", + json_data=address, + ) + + def delete_access_list(self, database=""): + return self.client.request( + method=http_methods.DELETE, + path=f"{PATH_PREFIX}/databases/{database}/access-list", + ) + + def get_private_link(self, database=""): + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/private-link", + ) + + def get_datacenter_private_link(self, database="", datacenter=""): + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/private-link", + ) + + def 
create_datacenter_private_link( + self, database="", datacenter="", private_link=None + ): + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/private-link", + json_data=private_link, + ) + + def create_datacenter_endpoint(self, database="", datacenter="", endpoint=None): + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoint", + json_data=endpoint, + ) + + def update_datacenter_endpoint(self, database="", datacenter="", endpoint=None): + return self.client.request( + method=http_methods.PUT, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint['id']}", + json_data=endpoint, + ) + + def get_datacenter_endpoint(self, database="", datacenter="", endpoint=""): + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}", + ) + + def delete_datacenter_endpoint(self, database="", datacenter="", endpoint=""): + return self.client.request( + method=http_methods.DELETE, + path=f"{PATH_PREFIX}/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}", + ) def get_available_classic_regions(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/availableRegions") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/availableRegions" + ) def get_available_regions(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/regions/serverless") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/regions/serverless" + ) def get_roles(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/roles") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/organizations/roles" + ) def create_role(self, role_definition=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/organizations/roles", - json_data=role_definition) + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/organizations/roles", + json_data=role_definition, + ) def get_role(self, role=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/roles/{role}") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/organizations/roles/{role}" + ) def update_role(self, role="", role_definition=None): - return self.client.request(method=http_methods.PUT, - path=f"{PATH_PREFIX}/organizations/roles/{role}", - json_data=role_definition) + return self.client.request( + method=http_methods.PUT, + path=f"{PATH_PREFIX}/organizations/roles/{role}", + json_data=role_definition, + ) def delete_role(self, role=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/organizations/roles/{role}") + return self.client.request( + method=http_methods.DELETE, path=f"{PATH_PREFIX}/organizations/roles/{role}" + ) def invite_user(self, user_definition=None): - return self.client.request(method=http_methods.PUT, - path=f"{PATH_PREFIX}/organizations/users", - json_data=user_definition) + return self.client.request( + method=http_methods.PUT, + path=f"{PATH_PREFIX}/organizations/users", + json_data=user_definition, + ) def get_users(self): - return self.client.request(method=http_methods.GET, - 
path=f"{PATH_PREFIX}/organizations/users") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/organizations/users" + ) def get_user(self, user=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/users/{user}") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/organizations/users/{user}" + ) def remove_user(self, user=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/organizations/users/{user}") + return self.client.request( + method=http_methods.DELETE, path=f"{PATH_PREFIX}/organizations/users/{user}" + ) def update_user_roles(self, user="", roles=None): - return self.client.request(method=http_methods.PUT, - path=f"{PATH_PREFIX}/organizations/users/{user}/roles", - json_data=roles) + return self.client.request( + method=http_methods.PUT, + path=f"{PATH_PREFIX}/organizations/users/{user}/roles", + json_data=roles, + ) def get_clients(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/clientIdSecrets") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/clientIdSecrets" + ) def create_token(self, roles=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/clientIdSecrets", - json_data=roles) + return self.client.request( + method=http_methods.POST, + path=f"{PATH_PREFIX}/clientIdSecrets", + json_data=roles, + ) def delete_token(self, token=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/clientIdSecret/{token}") + return self.client.request( + method=http_methods.DELETE, path=f"{PATH_PREFIX}/clientIdSecret/{token}" + ) def get_organization(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/currentOrg") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/currentOrg" + ) def get_access_lists(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/access-lists") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/access-lists" + ) def get_access_list_template(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/access-list/template") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/access-list/template" + ) def validate_access_list(self): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/access-list/validate") + return self.client.request( + method=http_methods.POST, path=f"{PATH_PREFIX}/access-list/validate" + ) def get_private_links(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/organizations/private-link") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/organizations/private-link" + ) def get_streaming_providers(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/streaming/providers") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/streaming/providers" + ) def get_streaming_tenants(self): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/streaming/tenants") + return self.client.request( + method=http_methods.GET, path=f"{PATH_PREFIX}/streaming/tenants" + ) def create_streaming_tenant(self, tenant=None): - return self.client.request(method=http_methods.POST, - path=f"{PATH_PREFIX}/streaming/tenants", - json_data=tenant) + return self.client.request( + method=http_methods.POST, + 
path=f"{PATH_PREFIX}/streaming/tenants", + json_data=tenant, + ) def delete_streaming_tenant(self, tenant="", cluster=""): - return self.client.request(method=http_methods.DELETE, - path=f"{PATH_PREFIX}/streaming/tenants/{tenant}/clusters/{cluster}", - json_data=tenant) + return self.client.request( + method=http_methods.DELETE, + path=f"{PATH_PREFIX}/streaming/tenants/{tenant}/clusters/{cluster}", + json_data=tenant, + ) def get_streaming_tenant(self, tenant=""): - return self.client.request(method=http_methods.GET, - path=f"{PATH_PREFIX}/streaming/tenants/{tenant}/limits") + return self.client.request( + method=http_methods.GET, + path=f"{PATH_PREFIX}/streaming/tenants/{tenant}/limits", + ) diff --git a/astrapy/endpoints/rest.py b/astrapy/endpoints/rest.py index 6b5bd765..b5e31981 100644 --- a/astrapy/endpoints/rest.py +++ b/astrapy/endpoints/rest.py @@ -13,49 +13,59 @@ # limitations under the License. import json -from astrapy.rest import http_methods +from astrapyjson.config.rest import http_methods DEFAULT_PAGE_SIZE = 20 PATH_PREFIX = "/api/rest/v2/keyspaces" -class AstraRest(): - +class AstraRest: def __init__(self, client=None): self.client = client self.path_prefix = PATH_PREFIX if client.auth_base_url is not None: - self.path_prefix = '/v2/keyspaces' + self.path_prefix = "/v2/keyspaces" def search_table(self, keyspace="", table="", query=None, options=None): options = {} if options is None else options - request_params = {"where": json.dumps( - query), "page-size": DEFAULT_PAGE_SIZE} + request_params = {"where": json.dumps(query), "page-size": DEFAULT_PAGE_SIZE} request_params.update(options) - return self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/{keyspace}/{table}", - url_params=request_params) + return self.client.request( + method=http_methods.GET, + path=f"{self.path_prefix}/{keyspace}/{table}", + url_params=request_params, + ) def add_row(self, keyspace="", table="", row=None): - return self.client.request(method=http_methods.POST, - path=f"{self.path_prefix}/{keyspace}/{table}", - json_data=row) + return self.client.request( + method=http_methods.POST, + path=f"{self.path_prefix}/{keyspace}/{table}", + json_data=row, + ) def get_rows(self, keyspace="", table="", key_path="", options=None): - return self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", - json_data=options) + return self.client.request( + method=http_methods.GET, + path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", + json_data=options, + ) def replace_rows(self, keyspace="", table="", key_path="", row=""): - return self.client.request(method=http_methods.PUT, - path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", - json_data=row) + return self.client.request( + method=http_methods.PUT, + path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", + json_data=row, + ) def update_rows(self, keyspace="", table="", key_path="", row=""): - return self.client.request(method=http_methods.PATCH, - path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", - json_data=row) + return self.client.request( + method=http_methods.PATCH, + path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", + json_data=row, + ) def delete_rows(self, keyspace="", table="", key_path=""): - return self.client.request(method=http_methods.DELETE, - path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}") + return self.client.request( + method=http_methods.DELETE, + path=f"{self.path_prefix}/{keyspace}/{table}/{key_path}", + ) diff --git 
a/astrapy/endpoints/schemas.py b/astrapy/endpoints/schemas.py deleted file mode 100644 index 88d1660e..00000000 --- a/astrapy/endpoints/schemas.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from astrapy.rest import http_methods - -PATH_PREFIX = "/api/rest/v2/schemas" - - -class AstraSchemas(): - - def __init__(self, client=None): - self.client = client - self.path_prefix = PATH_PREFIX - if client.auth_base_url is not None: - self.path_prefix = '/v2/schemas' - - def get_keyspaces(self): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces") - return res.get("data", []) - - def get_keyspace(self, keyspace=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}") - return res.get("data") - - def create_table(self, keyspace="", table_definition=None): - return self.client.request(method=http_methods.POST, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables", - json_data=table_definition) - - def get_tables(self, keyspace=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables") - return res.get("data") - - def get_table(self, keyspace="", table=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}") - return res.get("data") - - def update_table(self, keyspace="", table_definition=None): - return self.client.request(method=http_methods.PUT, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table_definition['name']}", - json_data=table_definition) - - def delete_table(self, keyspace="", table=""): - return self.client.request(method=http_methods.DELETE, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}") - - def create_column(self, keyspace="", table="", column_definition=None): - return self.client.request(method=http_methods.POST, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/columns", - json_data=column_definition) - - def get_columns(self, keyspace="", table=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/columns") - return res.get("data") - - def get_column(self, keyspace="", table="", column=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/columns/{column}") - return res.get("data") - - def update_column(self, keyspace="", table="", column="", column_definition=None): - return self.client.request(method=http_methods.PUT, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/columns/{column}", - json_data=column_definition) - - def delete_column(self, keyspace="", table="", column=""): - return self.client.request(method=http_methods.DELETE, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/columns/{column}") - - def get_indexes(self, keyspace="", table=""): - return 
self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/indexes") - - def create_index(self, keyspace="", table="", index_definition=None): - return self.client.request(method=http_methods.POST, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/indexes", - json_data=index_definition) - - def delete_index(self, keyspace="", table="", index=""): - return self.client.request(method=http_methods.DELETE, - path=f"{self.path_prefix}/keyspaces/{keyspace}/tables/{table}/indexes/{index}") - - def get_types(self, keyspace=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/types") - return res.get("data") - - def get_type(self, keyspace="", udt=""): - res = self.client.request(method=http_methods.GET, - path=f"{self.path_prefix}/keyspaces/{keyspace}/types/{udt}") - return res.get("data") - - def create_type(self, keyspace="", udt_definition=None): - return self.client.request(method=http_methods.POST, - path=f"{self.path_prefix}/keyspaces/{keyspace}/types", - json_data=udt_definition) - - def update_type(self, keyspace="", udt_definition=None): - return self.client.request(method=http_methods.PUT, - path=f"{self.path_prefix}/keyspaces/{keyspace}/types", - json_data=udt_definition) - - def delete_type(self, keyspace="", udt=""): - return self.client.request(method=http_methods.DELETE, - path=f"{self.path_prefix}/keyspaces/{keyspace}/types/{udt}") diff --git a/astrapy/endpoints/test.py b/astrapy/endpoints/test.py new file mode 100644 index 00000000..e69de29b diff --git a/astrapy/ops.py b/astrapy/ops.py new file mode 100644 index 00000000..533dbb72 --- /dev/null +++ b/astrapy/ops.py @@ -0,0 +1,346 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+from astrapy.utils import make_request, http_methods
+from astrapy.defaults import DEFAULT_DEV_OPS_PATH_PREFIX, DEFAULT_DEV_OPS_URL
+
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+class AstraDBOps:
+    def __init__(self, token, dev_ops_url=None):
+        dev_ops_url = dev_ops_url or DEFAULT_DEV_OPS_URL
+
+        self.token = "Bearer " + token
+        self.base_url = f"https://{dev_ops_url}{DEFAULT_DEV_OPS_PATH_PREFIX}"
+
+    def _ops_request(self, method, path, options=None, json_data=None):
+        options = {} if options is None else options
+
+        return make_request(
+            base_url=self.base_url,
+            method=method,
+            auth_header="Authorization",
+            token=self.token,
+            json_data=json_data,
+            url_params=options,
+            path=path,
+        )
+
+    def get_databases(self, options=None):
+        response = self._ops_request(
+            method=http_methods.GET, path="/databases", options=options
+        ).json()
+
+        return response
+
+    def create_database(self, database_definition=None):
+        r = self._ops_request(
+            method=http_methods.POST, path="/databases", json_data=database_definition
+        )
+
+        if r.status_code == 201:
+            return {"id": r.headers["Location"]}
+
+        return None
+
+    def terminate_database(self, database=""):
+        r = self._ops_request(
+            method=http_methods.POST, path=f"/databases/{database}/terminate"
+        )
+
+        if r.status_code == 202:
+            return database
+
+        return None
+
+    def get_database(self, database="", options=None):
+        return self._ops_request(
+            method=http_methods.GET,
+            path=f"/databases/{database}",
+            options=options,
+        ).json()
+
+    def create_keyspace(self, database="", keyspace=""):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/keyspaces/{keyspace}",
+        )
+
+    def park_database(self, database=""):
+        return self._ops_request(
+            method=http_methods.POST, path=f"/databases/{database}/park"
+        ).json()
+
+    def unpark_database(self, database=""):
+        return self._ops_request(
+            method=http_methods.POST, path=f"/databases/{database}/unpark"
+        ).json()
+
+    def resize_database(self, database="", options=None):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/resize",
+            json_data=options,
+        ).json()
+
+    def reset_database_password(self, database="", options=None):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/resetPassword",
+            json_data=options,
+        ).json()
+
+    def get_secure_bundle(self, database=""):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/secureBundleURL",
+        ).json()
+
+    def get_datacenters(self, database=""):
+        return self._ops_request(
+            method=http_methods.GET,
+            path=f"/databases/{database}/datacenters",
+        ).json()
+
+    def create_datacenter(self, database="", options=None):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/datacenters",
+            json_data=options,
+        ).json()
+
+    def terminate_datacenter(self, database="", datacenter=""):
+        return self._ops_request(
+            method=http_methods.POST,
+            path=f"/databases/{database}/datacenters/{datacenter}/terminate",
+        ).json()
+
+    def get_access_list(self, database=""):
+        return self._ops_request(
+            method=http_methods.GET,
+            path=f"/databases/{database}/access-list",
+        ).json()
+
+    def replace_access_list(self, database="", access_list=None):
+        return self._ops_request(
+            method=http_methods.PUT,
+            path=f"/databases/{database}/access-list",
+            json_data=access_list,
+        ).json()
+
+    def update_access_list(self, database="", access_list=None):
+        return self._ops_request(
method=http_methods.PATCH, + path=f"/databases/{database}/access-list", + json_data=access_list, + ).json() + + def add_access_list_address(self, database="", address=None): + return self._ops_request( + method=http_methods.POST, + path=f"/databases/{database}/access-list", + json_data=address, + ).json() + + def delete_access_list(self, database=""): + return self._ops_request( + method=http_methods.DELETE, + path=f"/databases/{database}/access-list", + ).json() + + def get_private_link(self, database=""): + return self._ops_request( + method=http_methods.GET, + path=f"/organizations/clusters/{database}/private-link", + ).json() + + def get_datacenter_private_link(self, database="", datacenter=""): + return self._ops_request( + method=http_methods.GET, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/private-link", + ).json() + + def create_datacenter_private_link( + self, database="", datacenter="", private_link=None + ): + return self._ops_request( + method=http_methods.POST, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/private-link", + json_data=private_link, + ).json() + + def create_datacenter_endpoint(self, database="", datacenter="", endpoint=None): + return self._ops_request( + method=http_methods.POST, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/endpoint", + json_data=endpoint, + ).json() + + def update_datacenter_endpoint(self, database="", datacenter="", endpoint=None): + return self._ops_request( + method=http_methods.PUT, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint['id']}", + json_data=endpoint, + ).json() + + def get_datacenter_endpoint(self, database="", datacenter="", endpoint=""): + return self._ops_request( + method=http_methods.GET, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}", + ).json() + + def delete_datacenter_endpoint(self, database="", datacenter="", endpoint=""): + return self._ops_request( + method=http_methods.DELETE, + path=f"/organizations/clusters/{database}/datacenters/{datacenter}/endpoints/{endpoint}", + ).json() + + def get_available_classic_regions(self): + return self._ops_request( + method=http_methods.GET, path=f"/availableRegions" + ).json() + + def get_available_regions(self): + return self._ops_request( + method=http_methods.GET, path=f"/regions/serverless" + ).json() + + def get_roles(self): + return self._ops_request( + method=http_methods.GET, path=f"/organizations/roles" + ).json() + + def create_role(self, role_definition=None): + return self._ops_request( + method=http_methods.POST, + path=f"/organizations/roles", + json_data=role_definition, + ).json() + + def get_role(self, role=""): + return self._ops_request( + method=http_methods.GET, path=f"/organizations/roles/{role}" + ).json() + + def update_role(self, role="", role_definition=None): + return self._ops_request( + method=http_methods.PUT, + path=f"/organizations/roles/{role}", + json_data=role_definition, + ).json() + + def delete_role(self, role=""): + return self._ops_request( + method=http_methods.DELETE, path=f"/organizations/roles/{role}" + ).json() + + def invite_user(self, user_definition=None): + return self._ops_request( + method=http_methods.PUT, + path=f"/organizations/users", + json_data=user_definition, + ).json() + + def get_users(self): + return self._ops_request( + method=http_methods.GET, path=f"/organizations/users" + ).json() + + def get_user(self, user=""): + return self._ops_request( + 
method=http_methods.GET, path=f"/organizations/users/{user}" + ).json() + + def remove_user(self, user=""): + return self._ops_request( + method=http_methods.DELETE, path=f"/organizations/users/{user}" + ).json() + + def update_user_roles(self, user="", roles=None): + return self._ops_request( + method=http_methods.PUT, + path=f"/organizations/users/{user}/roles", + json_data=roles, + ).json() + + def get_clients(self): + return self._ops_request( + method=http_methods.GET, path=f"/clientIdSecrets" + ).json() + + def create_token(self, roles=None): + return self._ops_request( + method=http_methods.POST, + path=f"/clientIdSecrets", + json_data=roles, + ).json() + + def delete_token(self, token=""): + return self._ops_request( + method=http_methods.DELETE, path=f"/clientIdSecret/{token}" + ).json() + + def get_organization(self): + return self._ops_request(method=http_methods.GET, path=f"/currentOrg").json() + + def get_access_lists(self): + return self._ops_request(method=http_methods.GET, path=f"/access-lists").json() + + def get_access_list_template(self): + return self._ops_request( + method=http_methods.GET, path=f"/access-list/template" + ).json() + + def validate_access_list(self): + return self._ops_request( + method=http_methods.POST, path=f"/access-list/validate" + ).json() + + def get_private_links(self): + return self._ops_request( + method=http_methods.GET, path=f"/organizations/private-link" + ).json() + + def get_streaming_providers(self): + return self._ops_request( + method=http_methods.GET, path=f"/streaming/providers" + ).json() + + def get_streaming_tenants(self): + return self._ops_request( + method=http_methods.GET, path=f"/streaming/tenants" + ).json() + + def create_streaming_tenant(self, tenant=None): + return self._ops_request( + method=http_methods.POST, + path=f"/streaming/tenants", + json_data=tenant, + ).json() + + def delete_streaming_tenant(self, tenant="", cluster=""): + return self._ops_request( + method=http_methods.DELETE, + path=f"/streaming/tenants/{tenant}/clusters/{cluster}", + json_data=tenant, + ).json() + + def get_streaming_tenant(self, tenant=""): + return self._ops_request( + method=http_methods.GET, + path=f"/streaming/tenants/{tenant}/limits", + ).json() diff --git a/astrapy/rest.py b/astrapy/rest.py deleted file mode 100644 index 4e007beb..00000000 --- a/astrapy/rest.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import requests - -logger = logging.getLogger(__name__) - - -REQUESTED_WITH = "AstraPy" -DEFAULT_AUTH_PATH = "/api/rest/v1/auth" -DEFAULT_TIMEOUT = 30000 -DEFAULT_AUTH_HEADER = "X-Cassandra-Token" - - -class http_methods(): - GET = "GET" - POST = "POST" - PUT = "PUT" - PATCH = "PATCH" - DELETE = "DELETE" - - -class AstraClient(): - def __init__(self, astra_database_id=None, - astra_database_region=None, - astra_application_token=None, - base_url=None, - auth_base_url=None, - username=None, - password=None, - auth_token=None, - debug=False, - auth_header=None): - self.astra_database_id = astra_database_id - self.astra_database_region = astra_database_region - self.astra_application_token = astra_application_token - self.base_url = base_url - self.auth_base_url = auth_base_url - self.username = username - self.password = password - self.auth_token = auth_token - self.auth_header = auth_header - self.debug = debug - self.auth_header = DEFAULT_AUTH_HEADER - - def request(self, method=http_methods.GET, path=None, json_data=None, url_params=None): - if self.auth_token: - auth_token = self.auth_token - else: - auth_token = self.astra_application_token - r = requests.request(method=method, url=f"{self.base_url}{path}", - params=url_params, json=json_data, timeout=DEFAULT_TIMEOUT, - headers={self.auth_header: auth_token}) - try: - return r.json() - except: - return None - - -def get_token(auth_base_url=None, username=None, password=None): - try: - r = requests.request(method=http_methods.POST, - url=f"{auth_base_url}", - json={"username": username, "password": password}, - timeout=DEFAULT_TIMEOUT) - token_response = r.json() - return token_response["authToken"] - except: - return None - - -def create_client(astra_database_id=None, - astra_database_region=None, - astra_application_token=None, - base_url=None, - auth_base_url=None, - username=None, - password=None, - debug=False): - if base_url is None: - base_url = f"https://{astra_database_id}-{astra_database_region}.apps.astra.datastax.com" - auth_token = None - if auth_base_url: - auth_token = get_token(auth_base_url=auth_base_url, - username=username, password=password) - if auth_token == None: - raise Exception('A valid token is required') - return AstraClient(astra_database_id=astra_database_id, - astra_database_region=astra_database_region, - astra_application_token=astra_application_token, - base_url=base_url, - auth_base_url=auth_base_url, - username=username, - password=password, - auth_token=auth_token, - debug=debug) diff --git a/astrapy/utils.py b/astrapy/utils.py new file mode 100644 index 00000000..ea792734 --- /dev/null +++ b/astrapy/utils.py @@ -0,0 +1,71 @@ +import requests +import logging +import re + +from astrapy.defaults import DEFAULT_TIMEOUT + +logger = logging.getLogger(__name__) + + +class http_methods: + GET = "GET" + POST = "POST" + PUT = "PUT" + PATCH = "PATCH" + DELETE = "DELETE" + + +def make_request( + base_url, + auth_header, + token, + method=http_methods.POST, + path=None, + json_data=None, + url_params=None, +): + try: + r = requests.request( + method=method, + url=f"{base_url}{path}", + params=url_params, + json=json_data, + timeout=DEFAULT_TIMEOUT, + headers={auth_header: token}, + ) + + return r + except Exception as e: + logger.warning(e) + + return {"error": "An unknown error occurred", "details": str(e)} + + +def make_payload(top_level, **kwargs): + params = {} + for key, value in kwargs.items(): + params[key] = value + + json_query = {top_level: {}} + + # Adding keys only if they're 
provided
+    for key, value in params.items():
+        if value is not None:
+            json_query[top_level][key] = value
+
+    return json_query
+
+
+def parse_endpoint_url(url):
+    # Regular expression pattern to match the given URL format
+    pattern = r"https://(?P<db_id>[a-fA-F0-9\-]{36})-(?P<db_region>[a-zA-Z0-9\-]+)\.(?P<db_hostname>[a-zA-Z0-9\-\.]+\.com)"
+
+    match = re.match(pattern, url)
+    if match:
+        return (
+            match.group("db_id"),
+            match.group("db_region"),
+            match.group("db_hostname"),
+        )
+    else:
+        raise ValueError("Invalid URL format")
diff --git a/requirements.txt b/requirements.txt
index 768c35a7..c1a63efa 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,8 @@
-faker==13.0.0
-gql==3.0.0
-pytest==7.0.1
-pytest-cov==3.0.0
-pytest-testdox==3.0.0
-requests==2.27.1
-requests-toolbelt==0.9.1
+faker~=19.11.0
+pytest~=7.4.2
+pytest-cov~=4.1.0
+pytest-testdox~=3.1.0
+requests~=2.31.0
+requests-toolbelt~=1.0.0
+python-dotenv~=1.0.0
+pre-commit~=3.5.0
diff --git a/scripts/astrapy_latest_interface.py b/scripts/astrapy_latest_interface.py
new file mode 100644
index 00000000..a692fe18
--- /dev/null
+++ b/scripts/astrapy_latest_interface.py
@@ -0,0 +1,68 @@
+import os
+import sys
+
+from dotenv import load_dotenv
+
+from astrapy.db import AstraDB, AstraDBCollection
+from astrapy.ops import AstraDBOps
+
+sys.path.append("../")
+
+load_dotenv()
+
+# First, we work with devops
+token = os.getenv("ASTRA_DB_APPLICATION_TOKEN")
+astra_ops = AstraDBOps(token)
+
+# Define a database to create
+database_definition = {
+    "name": "vector_test",
+    "tier": "serverless",
+    "cloudProvider": "GCP",
+    "keyspace": os.getenv("ASTRA_DB_KEYSPACE", "default_keyspace"),
+    "region": os.getenv("ASTRA_DB_REGION", None),
+    "capacityUnits": 1,
+    "user": "example",
+    "password": token,
+    "dbType": "vector",
+}
+
+# Create the database
+create_result = astra_ops.create_database(database_definition=database_definition)
+
+# Grab the new information from the database
+database_id = create_result["id"]
+database_region = astra_ops.get_database()[0]["info"]["region"]
+database_base_url = "apps.astra.datastax.com"
+
+# Build the endpoint URL:
+api_endpoint = f"https://{database_id}-{database_region}.{database_base_url}"
+
+# Initialize our vector db
+astra_db = AstraDB(token=token, api_endpoint=api_endpoint)
+
+# Possible Operations
+astra_db.create_collection(collection_name="collection_test_delete", size=5)
+astra_db.delete_collection(collection_name="collection_test_delete")
+astra_db.create_collection(collection_name="collection_test", size=5)
+
+# Collections
+astra_db_collection = AstraDBCollection(
+    collection_name="collection_test", astra_db=astra_db
+)
+# Or...
+astra_db_collection = AstraDBCollection( + collection_name="collection_test", token=token, api_endpoint=api_endpoint +) + +astra_db_collection.insert_one( + { + "_id": "5", + "name": "Coded Cleats Copy", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.25, 0.25, 0.25, 0.25, 0.25], + } +) + +astra_db_collection.find_one({"name": "potato"}) +astra_db_collection.find_one({"name": "Coded Cleats Copy"}) diff --git a/setup.py b/setup.py index caccf932..aa5fb9a4 100644 --- a/setup.py +++ b/setup.py @@ -14,42 +14,46 @@ from setuptools import setup from os import path + this_directory = path.abspath(path.dirname(__file__)) -with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: +with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: long_description = f.read() setup( - name='astrapy', + name="astrapy", packages=[ - 'astrapy', - 'astrapy/endpoints', + "astrapy", + "astrapy/endpoints", ], - version='0.3.3', - license='Apache license 2.0', - description='AstraPy is a Pythonic SDK for DataStax Astra', + version="0.5.0", + license="Apache license 2.0", + description="AstraPy is a Pythonic SDK for DataStax Astra", long_description=long_description, - long_description_content_type='text/markdown', - author='DataStax', - author_email='oss@datastax.com', - url='https://github.com/datastax/astrapy', - download_url='https://github.com/datastax/astrapy/archive/refs/tags/v0.3.3.tar.gz', - keywords=['DataStax Astra', 'Stargate'], + long_description_content_type="text/markdown", + author="Kirsten Hunter", + author_email="kirsten.hunter@datastax.com", + url="https://github.com/datastax/astrapy", + keywords=["DataStax Astra", "Stargate"], install_requires=[ - "requests>=2.27,<3", - "requests_toolbelt>=0.9.1,<1", - 'gql>=3.0.0', + "faker~=19.11.0", + "pytest~=7.4.2", + "pytest-cov~=4.1.0", + "pytest-testdox~=3.1.0", + "requests~=2.31.0", + "requests-toolbelt~=1.0.0", + "python-dotenv~=1.0.0", ], classifiers=[ - 'Development Status :: 3 - Alpha', - 'Intended Audience :: Developers', - 'Topic :: Software Development :: Build Tools', - 'License :: OSI Approved :: Apache Software License', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.4', - 'Programming Language :: Python :: 3.5', - 'Programming Language :: Python :: 3.6', - 'Programming Language :: Python :: 3.7', - 'Programming Language :: Python :: 3.8', - 'Programming Language :: Python :: 3.9', + "Development Status :: 3 - Alpha", + "Intended Audience :: Developers", + "Topic :: Software Development :: Build Tools", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.4", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], ) diff --git a/tests/__init__.py b/tests/__init__.py index 33284e0a..2c9ca172 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. 
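The interface script above assembles the Data API endpoint from a database ID and region; the `parse_endpoint_url` helper added in `astrapy/utils.py` performs the reverse decomposition. A minimal round-trip sketch, assuming a placeholder database ID and region (neither refers to a real database):

```python
# Round-trip sketch for the endpoint helpers; the ID and region are placeholders.
from astrapy.utils import parse_endpoint_url

api_endpoint = (
    "https://01234567-89ab-cdef-0123-456789abcdef-us-east1.apps.astra.datastax.com"
)

db_id, db_region, db_hostname = parse_endpoint_url(api_endpoint)

assert db_id == "01234567-89ab-cdef-0123-456789abcdef"
assert db_region == "us-east1"
assert db_hostname == "apps.astra.datastax.com"
```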
diff --git a/tests/astrapy/__init__.py b/tests/astrapy/__init__.py index 33284e0a..2c9ca172 100644 --- a/tests/astrapy/__init__.py +++ b/tests/astrapy/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/tests/astrapy/test_client.py b/tests/astrapy/test_client.py deleted file mode 100644 index c85c1f95..00000000 --- a/tests/astrapy/test_client.py +++ /dev/null @@ -1,473 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import List -from astrapy.client import create_astra_client, AstraClient -import pytest -import logging -import os -import uuid -import time -from faker import Faker - - -logger = logging.getLogger(__name__) -fake = Faker() - -ASTRA_DB_ID = os.environ.get('ASTRA_DB_ID') -ASTRA_DB_REGION = os.environ.get('ASTRA_DB_REGION') -ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN') -ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE') -TABLE_NAME = fake.bothify(text="users_????") - -STARGATE_BASE_URL = os.environ.get('STARGATE_BASE_URL') -STARGATE_AUTH_URL = os.environ.get('STARGATE_AUTH_URL') -STARGATE_USERNAME = os.environ.get('STARGATE_USERNAME') -STARGATE_PASSWORD = os.environ.get('STARGATE_PASSWORD') - - -@pytest.fixture -def astra_client(): - return create_astra_client(astra_database_id=ASTRA_DB_ID, - astra_database_region=ASTRA_DB_REGION, - astra_application_token=ASTRA_DB_APPLICATION_TOKEN) - - -@pytest.fixture -def stargate_client(): - return create_astra_client(base_url=STARGATE_BASE_URL, - auth_base_url=STARGATE_AUTH_URL, - username=STARGATE_USERNAME, - password=STARGATE_PASSWORD) - - -@pytest.fixture -def table_definition(): - return { - "name": TABLE_NAME, - "columnDefinitions": [ - { - "name": "firstname", - "typeDefinition": "text" - }, - { - "name": "lastname", - "typeDefinition": "text" - }, - { - "name": "favorite_color", - "typeDefinition": "text", - } - ], - "primaryKey": { - "partitionKey": [ - "firstname" - ], - "clusteringKey": [ - "lastname" - ] - } - } - - -@pytest.fixture -def column_definition(): - return { - "name": "favorite_food", - "typeDefinition": "text" - } - - -@pytest.fixture -def index_definition(): - return { - "column": "favorite_color", - "name": "favorite_color_idx", - "ifNotExists": True - } - - -@pytest.fixture -def udt_definition(): - return { - "name": "custom", - "ifNotExists": True, - "fields": [ - { - "name": "title", - "typeDefinition": "text" - } - ] - } - - -@pytest.mark.it('should initialize an AstraDB REST Client') -def test_connect(astra_client): - assert type(astra_client) is AstraClient - - -@pytest.mark.it('should initialize a Stargate REST Client') -def test_stargate_connect(stargate_client): - assert type(stargate_client) is AstraClient - - 
-@pytest.mark.it('should get databases') -def test_get_databases(astra_client): - databases = astra_client.ops.get_databases() - assert type(databases) is list - - -@pytest.mark.it('should get a database') -def test_get_databases(astra_client): - databases = astra_client.ops.get_databases(options={"include": "active"}) - database = astra_client.ops.get_database(databases[0]["id"]) - assert databases[0]["id"] == database["id"] - -# # TODO: when deleting a keyspace is available, round out the test -# @pytest.mark.it('should create a keyspace') -# def test_create_keyspace(astra_client): -# res = astra_client.ops.create_keyspace( -# database=ASTRA_DB_ID, keyspace="new_keyspace") -# print(res) -# assert res is not None - -# # TODO: find a way to terminate pending databases -# @pytest.mark.it('should create and delete a database') -# def test_create_database(astra_client): -# create_res = astra_client.ops.create_database({ -# "name": "astrapy-test-db", -# "keyspace": "astrapy_test", -# "cloudProvider": "AWS", -# "tier": "serverless", -# "capacityUnits": 1, -# "region": "us-west-2" -# }) -# assert type(create_res["id"]) is str -# time.sleep(10) -# delete_res = astra_client.ops.terminate_database(create_res["id"]) -# assert delete_res is not None - - -@pytest.mark.it('should get a secure bundle') -def test_get_secure_bundle(astra_client): - bundle = astra_client.ops.get_secure_bundle(database=ASTRA_DB_ID) - assert bundle["downloadURL"] is not None - - -@pytest.mark.it('should get datacenters') -def test_get_datacenters(astra_client): - datacenters = astra_client.ops.get_datacenters(database=ASTRA_DB_ID) - assert type(datacenters) is list - - -@pytest.mark.it('should get a private link') -def test_get_private_link(astra_client): - private_link = astra_client.ops.get_private_link(database=ASTRA_DB_ID) - assert private_link["clusterID"] == ASTRA_DB_ID - - -@pytest.mark.it('should get available classic regions') -def test_get_available_classic_regions(astra_client): - regions = astra_client.ops.get_available_classic_regions() - assert type(regions) is list - - -@pytest.mark.it('should get available regions') -def test_get_available_regions(astra_client): - regions = astra_client.ops.get_available_regions() - assert type(regions) is list - - -@pytest.mark.it('should get roles') -def test_get_roles(astra_client): - roles = astra_client.ops.get_roles() - assert type(roles) is list - - -@pytest.mark.it('should get users') -def test_get_users(astra_client): - users = astra_client.ops.get_users() - assert users["OrgID"] is not None - - -@pytest.mark.it('should get clients') -def test_get_clients(astra_client): - clients = astra_client.ops.get_clients() - assert clients["clients"] is not None - - -@pytest.mark.it('should get an organization') -def test_get_organization(astra_client): - organization = astra_client.ops.get_organization() - assert organization["id"] is not None - - -@pytest.mark.it('should get an access list template') -def test_get_access_list_template(astra_client): - access_list_template = astra_client.ops.get_access_list_template() - assert access_list_template["addresses"] is not None - - -@pytest.mark.it('should get all private links') -def test_get_private_links(astra_client): - private_links = astra_client.ops.get_private_links() - assert type(private_links) is list - - -@pytest.mark.it('should get all streaming providers') -def test_get_streaming_providers(astra_client): - streaming_providers = astra_client.ops.get_streaming_providers() - assert streaming_providers["aws"] is not None - 
- -@pytest.mark.it('should get all streaming tenants') -def test_get_streaming_tenants(astra_client): - streaming_tenants = astra_client.ops.get_streaming_tenants() - assert type(streaming_tenants) is list - - -@pytest.mark.it('should get all keyspaces') -def test_get_keyspaces(astra_client): - keyspaces = astra_client.schemas.get_keyspaces() - assert type(keyspaces) is list - - -@pytest.mark.it('should get a keyspace') -def test_get_keyspace(astra_client): - keyspaces = astra_client.schemas.get_keyspaces() - keyspace = astra_client.schemas.get_keyspace(keyspace=keyspaces[0]["name"]) - assert keyspace["name"] == keyspaces[0]["name"] - - -@pytest.mark.it('should create a table') -def test_create_table(astra_client, table_definition): - table = astra_client.schemas.create_table(keyspace=ASTRA_DB_KEYSPACE, - table_definition=table_definition) - assert table["name"] == table_definition["name"] - - -@pytest.mark.it('should get all tables') -def test_get_tables(astra_client): - tables = astra_client.schemas.get_tables(keyspace=ASTRA_DB_KEYSPACE) - assert type(tables) is list - - -@pytest.mark.it('should get a table') -def test_get_table(astra_client, table_definition): - table = astra_client.schemas.get_table(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"]) - assert table["name"] == table_definition["name"] - - -@pytest.mark.it('should update a table') -def test_update_table(astra_client, table_definition): - table_definition["tableOptions"] = {"defaultTimeToLive": 0} - table = astra_client.schemas.update_table(keyspace=ASTRA_DB_KEYSPACE, - table_definition=table_definition) - assert table["name"] == table_definition["name"] - - -@pytest.mark.it('should create a column') -def test_create_column(astra_client, table_definition, column_definition): - column = astra_client.schemas.create_column(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - column_definition=column_definition) - assert column["name"] == column_definition["name"] - - -@pytest.mark.it('should get columns') -def test_get_columns(astra_client, table_definition): - columns = astra_client.schemas.get_columns(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"]) - assert type(columns) is list - - -@pytest.mark.it('should get a column') -def test_get_column(astra_client, table_definition, column_definition): - column = astra_client.schemas.get_column(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - column=column_definition["name"]) - assert column["name"] == column_definition["name"] - - -@pytest.mark.it('should delete a column') -def test_delete_column(astra_client, table_definition, column_definition): - res = astra_client.schemas.delete_column(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - column=column_definition["name"]) - assert res is None - - -@pytest.mark.it('should create an index') -def test_create_index(astra_client, table_definition, index_definition): - res = astra_client.schemas.create_index(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - index_definition=index_definition) - assert res["success"] == True - - -@pytest.mark.it('should get all indexes') -def test_get_indexes(astra_client, table_definition): - indexes = astra_client.schemas.get_indexes(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"]) - assert type(indexes) is list - - -@pytest.mark.it('should delete an index') -def test_delete_index(astra_client, table_definition, index_definition): - res = astra_client.schemas.delete_index(keyspace=ASTRA_DB_KEYSPACE, - 
table=table_definition["name"], - index=index_definition["name"]) - assert res is None - - -@pytest.mark.it('should create a type') -def test_create_type(astra_client, udt_definition): - udt = astra_client.schemas.create_type(keyspace=ASTRA_DB_KEYSPACE, - udt_definition=udt_definition) - assert udt["name"] == udt_definition["name"] - - -@pytest.mark.it('should get all types') -def test_get_types(astra_client): - udts = astra_client.schemas.get_types(keyspace=ASTRA_DB_KEYSPACE) - assert type(udts) is list - - -@pytest.mark.it('should get a type') -def test_get_type(astra_client, udt_definition): - udt = astra_client.schemas.get_type( - keyspace=ASTRA_DB_KEYSPACE, udt=udt_definition["name"]) - assert udt["name"] == udt_definition["name"] - - -@pytest.mark.it('should update a type') -def test_update_type(astra_client): - udt_definition = {"name": "custom", - "addFields": [{ - "name": "description", - "typeDefinition": "text", - }]} - res = astra_client.schemas.update_type(keyspace=ASTRA_DB_KEYSPACE, - udt_definition=udt_definition) - assert res is None - - -@pytest.mark.it('should delete a type') -def test_delete_type(astra_client, udt_definition): - res = astra_client.schemas.delete_type(keyspace=ASTRA_DB_KEYSPACE, - udt=udt_definition["name"]) - assert res is None - - -@pytest.mark.it('should add rows') -def test_add_row(astra_client, table_definition): - row_definition = {"firstname": "Cliff", "lastname": "Wicklow"} - row = astra_client.rest.add_row(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - row=row_definition) - assert row["firstname"] == row_definition["firstname"] - - -@pytest.mark.it('should get rows') -def test_get_rows(astra_client, table_definition): - rows = astra_client.rest.get_rows(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - key_path="Cliff/Wicklow") - assert rows["count"] is not None - assert rows["data"][0]["firstname"] == "Cliff" - - -@pytest.mark.it('should search a table') -def test_search_table(astra_client, table_definition): - query = {"firstname": {"$eq": "Cliff"}} - res = astra_client.rest.search_table(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"], - query=query) - assert res["count"] is not None - assert res["data"][0]["firstname"] == "Cliff" - - -@pytest.mark.it('should query the gql schema') -def test_gql_schema(astra_client): - query = """{ - keyspaces { - name - } - }""" - res = astra_client.gql.execute(query=query) - assert res["keyspaces"] is not None - - -@pytest.mark.it('should use gql to create a table') -def test_gql_create_table(astra_client): - query = """ - mutation createTable ($keyspaceName: String!) 
{ - book: createTable( - keyspaceName: $keyspaceName, - tableName: "book", - partitionKeys: [ - { name: "title", type: { basic: TEXT } } - ] - clusteringKeys: [ - { name: "author", type: { basic: TEXT } } - ] - ) - } - """ - res = astra_client.gql.execute(query=query, - variables={"keyspaceName": ASTRA_DB_KEYSPACE}) - assert res["book"] is True - - -@pytest.mark.it('should use gql to insert into a table') -def test_gql_insert_table(astra_client): - query = """ - mutation insert2Books { - moby: insertbook(value: {title:"Moby Dick", author:"Herman Melville"}) { - value { - title - } - } - catch22: insertbook(value: {title:"Catch-22", author:"Joseph Heller"}) { - value { - title - } - } - } - """ - res = astra_client.gql.execute(query=query, keyspace=ASTRA_DB_KEYSPACE) - assert res["moby"] is not None - - -@pytest.mark.it('should use gql to delete a table') -def test_gql_delete_table(astra_client): - query = """ - mutation dropTable ($keyspaceName: String!) { - book: dropTable( - keyspaceName: $keyspaceName, - tableName: "book" - ) - } - """ - res = astra_client.gql.execute(query=query, - variables={"keyspaceName": ASTRA_DB_KEYSPACE}) - assert res["book"] is True - - -@pytest.mark.it('should delete a table') -def test_delete_table(astra_client, table_definition): - res = astra_client.schemas.delete_table(keyspace=ASTRA_DB_KEYSPACE, - table=table_definition["name"]) - assert res == None diff --git a/tests/astrapy/test_collections.py b/tests/astrapy/test_collections.py index 61fd34e3..001ea27d 100644 --- a/tests/astrapy/test_collections.py +++ b/tests/astrapy/test_collections.py @@ -12,268 +12,456 @@ # See the License for the specific language governing permissions and # limitations under the License. -from astrapy.collections import create_client, AstraCollection +from astrapy.db import AstraDBCollection, AstraDB +from astrapy.defaults import DEFAULT_KEYSPACE_NAME, DEFAULT_REGION + import uuid import pytest import logging import os from faker import Faker +import json logger = logging.getLogger(__name__) fake = Faker() -ASTRA_DB_ID = os.environ.get('ASTRA_DB_ID') -ASTRA_DB_REGION = os.environ.get('ASTRA_DB_REGION') -ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN') -ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE') -TEST_COLLECTION_NAME = "test" +from dotenv import load_dotenv +load_dotenv() -@pytest.fixture -def test_collection(): - astra_client = create_client(astra_database_id=ASTRA_DB_ID, - astra_database_region=ASTRA_DB_REGION, - astra_application_token=ASTRA_DB_APPLICATION_TOKEN) - return astra_client.namespace(ASTRA_DB_KEYSPACE).collection(TEST_COLLECTION_NAME) + +ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID") +ASTRA_DB_REGION = os.environ.get("ASTRA_DB_REGION", DEFAULT_REGION) +ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN") +ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME) +ASTRA_DB_BASE_URL = os.environ.get("ASTRA_DB_BASE_URL", "apps.astra.datastax.com") + +TEST_COLLECTION_NAME = "test_collection" +cliffu = str(uuid.uuid4()) @pytest.fixture def cliff_uuid(): - return str(uuid.uuid4()) + return cliffu @pytest.fixture -def test_namespace(): - astra_client = create_client(astra_database_id=ASTRA_DB_ID, - astra_database_region=ASTRA_DB_REGION, - astra_application_token=ASTRA_DB_APPLICATION_TOKEN) - return astra_client.namespace(ASTRA_DB_KEYSPACE) - - -@pytest.mark.it('should initialize an AstraDB Collections Client') -def test_connect(test_collection): - assert type(test_collection) is AstraCollection - - 
-@pytest.mark.it('should create a collection') -def test_create_collection(test_namespace): - res = test_namespace.create_collection(name="pytest_collection") - assert res is None - res2 = test_namespace.create_collection(name="test_schema") - assert res2 is None - - -@pytest.mark.it('should get all collections') -def test_get_collections(test_namespace): - res = test_namespace.get_collections() - assert type(res) is list - - -@pytest.mark.it('should create a collection with schema') -def test_create_schema(test_namespace): - schema = { - "$id": "https://example.com/person.schema.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Person", - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The persons first name." - }, - "lastName": { - "type": "string", - "description": "The persons last name." - }, - "age": { - "description": "Age in years which must be equal to or greater than zero.", - "type": "integer", - "minimum": 0 - } - } - } - test_collection = test_namespace.collection("test_schema") - res = test_collection.create_schema(schema=schema) - assert res["schema"] is not None - - -@pytest.mark.it('should update a collection with schema') -def test_update_schema(test_namespace): - schema = { - "$id": "https://example.com/person.schema.json", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Person", - "type": "object", - "properties": { - "firstName": { - "type": "string", - "description": "The persons first name." - } - } - } - test_collection = test_namespace.collection("test_schema") - res = test_collection.update_schema(schema=schema) - assert res["schema"] is not None +def test_collection(): + astra_db_collection = AstraDBCollection( + collection_name=TEST_COLLECTION_NAME, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=f"https://{ASTRA_DB_ID}-{ASTRA_DB_REGION}.{ASTRA_DB_BASE_URL}", + namespace=ASTRA_DB_KEYSPACE, + ) + + return astra_db_collection + + +@pytest.fixture +def test_db(): + astra_db = AstraDB( + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=f"https://{ASTRA_DB_ID}-{ASTRA_DB_REGION}.{ASTRA_DB_BASE_URL}", + namespace=ASTRA_DB_KEYSPACE, + ) + + return astra_db + +@pytest.mark.describe("should create a vector collection") +def test_create_collection(test_db): + res = test_db.create_collection(collection_name=TEST_COLLECTION_NAME, size=5) + print("CREATE", res) + assert res is not None -@pytest.mark.it('should delete a collection') -def test_delete_collection(test_namespace): - res = test_namespace.delete_collection(name="pytest_collection") - assert res is None - res2 = test_namespace.delete_collection(name="test_schema") - assert res2 is None +@pytest.mark.describe("should get all collections") +def test_get_collections(test_db): + res = test_db.get_collections() + assert res["status"]["collections"] is not None -@pytest.mark.it('should create a document') -def test_create_document(test_collection, cliff_uuid): - test_collection.create(path=cliff_uuid, document={ + +@pytest.mark.describe("should create a document") +def test_create_document_cliff(test_collection, cliff_uuid): + json_query = { + "_id": cliff_uuid, "first_name": "Cliff", "last_name": "Wicklow", - }) - document = test_collection.get(path=cliff_uuid) - assert document["first_name"] == "Cliff" + } + + test_collection.insert_one(document=json_query) + + document = test_collection.find_one(filter={"_id": cliff_uuid}) + + assert document is not None + + +@pytest.mark.describe("should create a vector document") +def 
test_create_document(test_collection): + json_query = { + "_id": "4", + "name": "Coded Cleats Copy", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.25, 0.25, 0.25, 0.25, 0.25], + } + + res = test_collection.insert_one(document=json_query) + assert res is not None -@pytest.mark.it('should create multiple documents') -def test_batch(test_collection): + +@pytest.mark.describe("Find one document") +def test_find_document(test_collection): + document = test_collection.find_one(filter={"_id": "4"}) + print("DOC", document) + assert document is not None + + +@pytest.mark.describe("should create multiple documents: nonvector") +def test_insert_many(test_collection): id_1 = fake.bothify(text="????????") id_2 = fake.bothify(text="????????") - documents = [{ - "_id": id_1, - "first_name": "Dang", - "last_name": "Son", - }, { - "_id": id_2, - "first_name": "Yep", - "last_name": "Boss", - }] - res = test_collection.batch(documents=documents, id_path="_id") - assert res["documentIds"] is not None - - document_1 = test_collection.get(path=id_1) - assert document_1["first_name"] == "Dang" - - document_2 = test_collection.get(path=id_2) - assert document_2["first_name"] == "Yep" - - -@pytest.mark.it('should create a subdocument') + id_3 = fake.bothify(text="????????") + documents = [ + { + "_id": id_1, + "first_name": "Dang", + "last_name": "Son", + }, + { + "_id": id_2, + "first_name": "Yep", + "last_name": "Boss", + }, + ] + res = test_collection.insert_many(documents=documents) + assert res is not None + + documents2 = [ + { + "_id": id_2, + "first_name": "Yep", + "last_name": "Boss", + }, + { + "_id": id_3, + "first_name": "Miv", + "last_name": "Fuff", + }, + ] + res = test_collection.insert_many( + documents=documents2, + partial_failures_allowed=True, + ) + print(res) + assert set(res["status"]["insertedIds"]) == set() + + res = test_collection.insert_many( + documents=documents2, + options={"ordered": False}, + partial_failures_allowed=True, + ) + print(res) + assert set(res["status"]["insertedIds"]) == {id_3} + + document = test_collection.find(filter={"first_name": "Yep"}) + assert document is not None + + +@pytest.mark.describe("create many vector documents") +def test_create_documents(test_collection): + json_query = [ + { + "_id": "1", + "name": "Coded Cleats", + "description": "ChatGPT integrated sneakers that talk to you", + "$vector": [0.1, 0.15, 0.3, 0.12, 0.05], + }, + { + "_id": "2", + "name": "Logic Layers", + "description": "An AI quilt to help you sleep forever", + "$vector": [0.45, 0.09, 0.01, 0.2, 0.11], + }, + { + "_id": "3", + "name": "Vision Vector Frame", + "description": "Vision Vector Frame - A deep learning display that controls your mood", + "$vector": [0.1, 0.05, 0.08, 0.3, 0.6], + }, + ] + + res = test_collection.insert_many(documents=json_query) + assert res is not None + + +@pytest.mark.describe("should create a subdocument") def test_create_subdocument(test_collection, cliff_uuid): - test_collection.create(path=f"{cliff_uuid}/addresses", document={ - "home": { - "city": "New York", - "state": "NY", - } - }) - document = test_collection.get(path=f"{cliff_uuid}/addresses") - assert document["home"]["state"] == "NY" + document = test_collection.update_one( + filter={"_id": cliff_uuid}, + update={"$set": {"addresses.city": "New York", "addresses.state": "NY"}}, + ) + print("SUBSUB", document) + + document = test_collection.find_one(filter={"_id": cliff_uuid}) + print("SUBDOC", document) + assert document["data"]["document"]["addresses"] is 
not None -@pytest.mark.it('should create a document without an ID') +@pytest.mark.describe("should create a document without an ID") def test_create_document_without_id(test_collection): - response = test_collection.create(document={ - "first_name": "New", - "last_name": "Guy", - }) - document = test_collection.get(path=response["documentId"]) - assert document["first_name"] == "New" + response = test_collection.insert_one( + document={ + "first_name": "New", + "last_name": "Guy", + } + ) + assert response is not None + document = test_collection.find_one(filter={"first_name": "New"}) + assert document["data"]["document"]["last_name"] == "Guy" -@pytest.mark.it('should udpate a document') +@pytest.mark.describe("should update a document") def test_update_document(test_collection, cliff_uuid): - test_collection.update(path=cliff_uuid, document={ - "first_name": "Dang", - }) - document = test_collection.get(path=cliff_uuid) - assert document["first_name"] == "Dang" - - -@pytest.mark.it('replace a subdocument') -def test_replace_subdocument(test_collection, cliff_uuid): - test_collection.replace(path=f"{cliff_uuid}/addresses", document={ - "work": { - "city": "New York", - "state": "NY", - } - }) - document = test_collection.get(path=f"{cliff_uuid}/addresses/work") - assert document["state"] == "NY" - document_2 = test_collection.get(path=f"{cliff_uuid}/addresses/home") - assert document_2 is None + test_collection.update_one( + filter={"_id": cliff_uuid}, + update={"$set": {"first_name": "Dang"}}, + ) + document = test_collection.find_one(filter={"_id": cliff_uuid}) + assert document["data"]["document"]["_id"] == cliff_uuid + + +@pytest.mark.describe("replace a non-vector document") +def test_replace_document(test_collection, cliff_uuid): + test_collection.find_one_and_replace( + filter={"_id": cliff_uuid}, + replacement={ + "_id": cliff_uuid, + "addresses": { + "work": { + "city": "New York", + "state": "NY", + } + }, + }, + ) + document = test_collection.find_one(filter={"_id": cliff_uuid}) + print(document) + + assert document is not None + document_2 = test_collection.find_one( + filter={"_id": cliff_uuid}, projection={"addresses.work.city": 1} + ) + print("HOME", json.dumps(document_2, indent=4)) -@pytest.mark.it('should delete a subdocument') + +@pytest.mark.describe("should delete a subdocument") def test_delete_subdocument(test_collection, cliff_uuid): - test_collection.delete(path=f"{cliff_uuid}/addresses") - document = test_collection.get(path=f"{cliff_uuid}/addresses") - assert document is None + response = test_collection.delete_subdocument(id=cliff_uuid, subdoc="addresses") + document = test_collection.find(filter={"_id": cliff_uuid}) + assert response is not None -@pytest.mark.it('should delete a document') +@pytest.mark.describe("should delete a document") def test_delete_document(test_collection, cliff_uuid): - test_collection.delete(path=cliff_uuid) - document = test_collection.get(path=cliff_uuid) - assert document is None + response = test_collection.delete(id=cliff_uuid) + + assert response is not None + + +@pytest.mark.describe("Find documents using vector search") +def test_find_documents_vector(test_collection): + sort = {"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]} + options = {"limit": 100} + + document = test_collection.find(sort=sort, options=options) + assert document is not None + + +@pytest.mark.describe("Find documents using vector search with error") +def test_find_documents_vector_error(test_collection): + sort = ({"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]},) + 
options = {"limit": 100} + + try: + test_collection.find(sort=sort, options=options) + except ValueError as e: + assert e is not None + +@pytest.mark.describe("Find documents using vector search and projection") +def test_find_documents_vector_proj(test_collection): + sort = {"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]} + options = {"limit": 100} + projection = {"$vector": 1, "$similarity": 1} -@pytest.mark.it('should find documents') + document = test_collection.find(sort=sort, options=options, projection=projection) + assert document is not None + + +@pytest.mark.describe("Find a document using vector search and projection") +def test_find_documents_vector_proj(test_collection): + sort = {"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]} + projection = {"$vector": 1} + + document = test_collection.find(sort=sort, options={}, projection=projection) + assert document is not None + + +@pytest.mark.describe("Find one and update with vector search") +def test_find_one_and_update_vector(test_collection): + sort = {"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]} + update = {"$set": {"status": "active"}} + options = {"returnDocument": "after"} + + result = test_collection.find_one_and_update( + sort=sort, update=update, options=options + ) + print(result) + document = test_collection.find_one(filter={"status": "active"}) + print(document) + assert document["data"]["document"] is not None + + +@pytest.mark.describe("Find one and replace with vector search") +def test_find_one_and_replace_vector(test_collection): + sort = {"$vector": [0.15, 0.1, 0.1, 0.35, 0.55]} + replacement = { + "_id": "3", + "name": "Vision Vector Frame", + "description": "Vision Vector Frame - A deep learning display that controls your mood", + "$vector": [0.1, 0.05, 0.08, 0.3, 0.6], + "status": "inactive", + } + options = {"returnDocument": "after"} + + test_collection.find_one_and_replace( + sort=sort, replacement=replacement, options=options + ) + document = test_collection.find_one(filter={"name": "Vision Vector Frame"}) + assert document["data"]["document"] is not None + + +@pytest.mark.describe("should find documents, non-vector") def test_find_documents(test_collection): user_id = str(uuid.uuid4()) - test_collection.create(path=user_id, document={ - "first_name": f"Cliff-{user_id}", - "last_name": "Wicklow", - }) + test_collection.insert_one( + document={ + "_id": user_id, + "first_name": f"Cliff-{user_id}", + "last_name": "Wicklow", + }, + ) user_id_2 = str(uuid.uuid4()) - test_collection.create(path=user_id_2, document={ - "first_name": f"Cliff-{user_id}", - "last_name": "Danger", - }) - documents = test_collection.find(query={ - "first_name": {"$eq": f"Cliff-{user_id}"}, - }) - assert len(documents["data"].keys()) == 2 - assert documents["data"][user_id]["last_name"] == "Wicklow" - assert documents["data"][user_id_2]["last_name"] == "Danger" - - -@pytest.mark.it('should find a single document') + test_collection.insert_one( + document={ + "_id": user_id_2, + "first_name": f"Cliff-{user_id}", + "last_name": "Danger", + }, + ) + document = test_collection.find(filter={"first_name": f"Cliff-{user_id}"}) + assert document is not None + + +@pytest.mark.describe("should find a single document, non-vector") def test_find_one_document(test_collection): user_id = str(uuid.uuid4()) - test_collection.create(path=user_id, document={ - "first_name": f"Cliff-{user_id}", - "last_name": "Wicklow", - }) + test_collection.insert_one( + document={ + "_id": user_id, + "first_name": f"Cliff-{user_id}", + "last_name": "Wicklow", + }, + ) user_id_2 = 
str(uuid.uuid4()) - test_collection.create(path=user_id_2, document={ - "first_name": f"Cliff-{user_id}", - "last_name": "Danger", - }) - document = test_collection.find_one(query={ - "first_name": {"$eq": f"Cliff-{user_id}"}, - }) - assert document["first_name"] == f"Cliff-{user_id}" - document = test_collection.find_one(query={ - "first_name": {"$eq": f"Cliff-Not-There"}, - }) - assert document is None - - -@pytest.mark.it('should use document functions') -def test_functions(test_collection): - user_id = str(uuid.uuid4()) - test_collection.create(path=user_id, document={ - "first_name": f"Cliff-{user_id}", - "last_name": "Wicklow", - "roles": ["admin", "user"] - }) + test_collection.insert_one( + document={ + "_id": user_id_2, + "first_name": f"Cliff-{user_id}", + "last_name": "Danger", + }, + ) + document = test_collection.find_one(filter={"first_name": f"Cliff-{user_id}"}) + print("DOCUMENT", document) + + assert document["data"]["document"] is not None + + document = test_collection.find_one(filter={"first_name": f"Cliff-Not-There"}) + assert document["data"]["document"] == None + + +@pytest.mark.describe("upsert a document") +def test_upsert_document(test_collection, cliff_uuid): + new_uuid = str(uuid.uuid4()) + + test_collection.upsert( + { + "_id": new_uuid, + "addresses": { + "work": { + "city": "Seattle", + "state": "WA", + } + }, + } + ) + + document = test_collection.find_one(filter={"_id": new_uuid}) + + # Check the document exists and that the city field is Seattle + assert document is not None + assert document["data"]["document"]["addresses"]["work"]["city"] == "Seattle" + assert "country" not in document["data"]["document"]["addresses"]["work"] + + test_collection.upsert( + { + "_id": cliff_uuid, + "addresses": {"work": {"city": "Everett", "state": "WA", "country": "USA"}}, + } + ) + + document = test_collection.find_one(filter={"_id": cliff_uuid}) - pop_res = test_collection.pop(path=f"{user_id}/roles") - assert pop_res == "user" + assert document is not None + assert document["data"]["document"]["addresses"]["work"]["city"] == "Everett" + assert "country" in document["data"]["document"]["addresses"]["work"] - doc_1 = test_collection.get(path=user_id) - assert len(doc_1["roles"]) == 1 - test_collection.push(path=f"{user_id}/roles", value="users") - doc_2 = test_collection.get(path=user_id) - assert len(doc_2["roles"]) == 2 +@pytest.mark.describe("should use document functions") +def test_functions(test_collection): + user_id = str(uuid.uuid4()) + test_collection.insert_one( + document={ + "_id": user_id, + "first_name": f"Cliff-{user_id}", + "last_name": "Wicklow", + "roles": ["admin", "user"], + }, + ) + update = {"$pop": {"roles": 1}} + options = {"returnDocument": "after"} + + pop_res = test_collection.pop( + filter={"_id": user_id}, update=update, options=options + ) + + doc_1 = test_collection.find_one(filter={"_id": user_id}) + assert doc_1["data"]["document"]["_id"] == user_id + + update = {"$push": {"roles": "users"}} + options = {"returnDocument": "after"} + + test_collection.push(filter={"_id": user_id}, update=update, options=options) + doc_2 = test_collection.find_one(filter={"_id": user_id}) + assert doc_2["data"]["document"]["_id"] == user_id + + +@pytest.mark.describe("should delete a collection") +def test_delete_collection(test_db): + res = test_db.delete_collection(collection_name="test_collection") + assert res is not None + res2 = test_db.delete_collection(collection_name="test_collection") + assert res2 is not None diff --git 
a/tests/astrapy/test_devops.py b/tests/astrapy/test_devops.py new file mode 100644 index 00000000..01bb1f66 --- /dev/null +++ b/tests/astrapy/test_devops.py @@ -0,0 +1,86 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from astrapy.ops import AstraDBOps +from astrapy.defaults import DEFAULT_KEYSPACE_NAME, DEFAULT_REGION + +import pytest +import logging +import os +from faker import Faker + +logger = logging.getLogger(__name__) +fake = Faker() + + +from dotenv import load_dotenv + +load_dotenv() + + +# Parameter for the ops testing +ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID") +ASTRA_DB_REGION = os.environ.get("ASTRA_DB_REGION", DEFAULT_REGION) +ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN") +ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME) +ASTRA_DB_BASE_URL = os.environ.get("ASTRA_DB_BASE_URL", "apps.astra.datastax.com") + + +@pytest.fixture +def devops_client(): + return AstraDBOps(token=ASTRA_DB_APPLICATION_TOKEN) + + +@pytest.mark.describe("should initialize an AstraDB Ops Client") +def test_client_type(devops_client): + assert type(devops_client) is AstraDBOps + + +@pytest.mark.describe("should get all databases") +def test_get_databases(devops_client): + response = devops_client.get_databases() + assert type(response) is list + + +@pytest.mark.describe("should create a database") +def test_create_database(devops_client): + database_definition = { + "name": "vector_test_create", + "tier": "serverless", + "cloudProvider": "GCP", + "keyspace": ASTRA_DB_KEYSPACE, + "region": ASTRA_DB_REGION, + "capacityUnits": 1, + "user": "token", + "password": ASTRA_DB_APPLICATION_TOKEN, + "dbType": "vector", + } + response = devops_client.create_database(database_definition=database_definition) + assert response["id"] is not None + ASTRA_TEMP_DB = response["id"] + + check_db = devops_client.get_database(database=ASTRA_TEMP_DB) + assert check_db is not None + + response = devops_client.terminate_database(database=ASTRA_TEMP_DB) + assert response is None + + +@pytest.mark.describe("should create a keyspace") +def test_create_keyspace(devops_client): + response = devops_client.create_keyspace( + keyspace="test_namespace", database=os.environ["ASTRA_DB_ID"] + ) + + assert response is not None diff --git a/tests/astrapy/test_pagination.py b/tests/astrapy/test_pagination.py new file mode 100644 index 00000000..79cfa842 --- /dev/null +++ b/tests/astrapy/test_pagination.py @@ -0,0 +1,103 @@ +# Copyright DataStax, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import os +import logging +from typing import Iterable, TypeVar + +from astrapy.db import AstraDBCollection, AstraDB +from astrapy.defaults import DEFAULT_KEYSPACE_NAME, DEFAULT_REGION + +from dotenv import load_dotenv +import pytest + +logger = logging.getLogger(__name__) + + +load_dotenv() + + +ASTRA_DB_ID = os.environ.get("ASTRA_DB_ID") +ASTRA_DB_REGION = os.environ.get("ASTRA_DB_REGION", DEFAULT_REGION) +ASTRA_DB_APPLICATION_TOKEN = os.environ.get("ASTRA_DB_APPLICATION_TOKEN") +ASTRA_DB_KEYSPACE = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME) +ASTRA_DB_BASE_URL = os.environ.get("ASTRA_DB_BASE_URL", "apps.astra.datastax.com") + + +TEST_COLLECTION_NAME = "test_collection" +INSERT_BATCH_SIZE = 20 # max 20, fixed by API constraints +N = 200 # must be EVEN +FIND_LIMIT = 183 # Keep this > 20 and <= N to actually put pagination to test + +T = TypeVar("T") + + +def mk_vector(i, N): + angle = 2 * math.pi * i / N + return [math.cos(angle), math.sin(angle)] + + +def _batch_iterable(iterable: Iterable[T], batch_size: int) -> Iterable[Iterable[T]]: + this_batch = [] + for entry in iterable: + this_batch.append(entry) + if len(this_batch) == batch_size: + yield this_batch + this_batch = [] + if this_batch: + yield this_batch + + +@pytest.fixture(scope="module") +def test_collection(): + astra_db = AstraDB( + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=f"https://{ASTRA_DB_ID}-{ASTRA_DB_REGION}.{ASTRA_DB_BASE_URL}", + namespace=ASTRA_DB_KEYSPACE, + ) + res = astra_db.create_collection(collection_name=TEST_COLLECTION_NAME, size=2) + astra_db_collection = AstraDBCollection( + collection_name=TEST_COLLECTION_NAME, + token=ASTRA_DB_APPLICATION_TOKEN, + api_endpoint=f"https://{ASTRA_DB_ID}-{ASTRA_DB_REGION}.{ASTRA_DB_BASE_URL}", + namespace=ASTRA_DB_KEYSPACE, + ) + if int(os.getenv("TEST_PAGINATION_SKIP_INSERTION", "0")) == 0: + inserted_ids = set() + for i_batch in _batch_iterable(range(N), INSERT_BATCH_SIZE): + batch_ids = astra_db_collection.insert_many( + documents=[{"_id": str(i), "$vector": mk_vector(i, N)} for i in i_batch] + )["status"]["insertedIds"] + inserted_ids = inserted_ids | set(batch_ids) + assert inserted_ids == {str(i) for i in range(N)} + yield astra_db_collection + if int(os.getenv("TEST_PAGINATION_SKIP_DELETE_COLLECTION", "0")) == 0: + res = astra_db.delete_collection(collection_name=TEST_COLLECTION_NAME) + + +@pytest.mark.describe( + "should retrieve the required amount of documents, all different, through pagination" +) +def test_find_paginated(test_collection): + options = {"limit": FIND_LIMIT} + projection = {"$vector": 0} + + paginated_documents = test_collection.paginated_find( + projection=projection, + options=options, + ) + paginated_ids = [doc["_id"] for doc in paginated_documents] + assert len(paginated_ids) == FIND_LIMIT + assert len(paginated_ids) == len(set(paginated_ids)) diff --git a/tests/astrapy/test_rest.py b/tests/astrapy/test_rest.py deleted file mode 100644 index e11b76a7..00000000 --- a/tests/astrapy/test_rest.py +++ /dev/null @@ -1,97 +0,0 @@ -# Copyright DataStax, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from astrapy.rest import create_client, AstraClient, http_methods -import pytest -import logging -import os -import uuid - -logger = logging.getLogger(__name__) - -ASTRA_DB_ID = os.environ.get('ASTRA_DB_ID') -ASTRA_DB_REGION = os.environ.get('ASTRA_DB_REGION') -ASTRA_DB_APPLICATION_TOKEN = os.environ.get('ASTRA_DB_APPLICATION_TOKEN') -ASTRA_DB_KEYSPACE = os.environ.get('ASTRA_DB_KEYSPACE') -ASTRA_DB_COLLECTION = "rest" - -STARGATE_BASE_URL = os.environ.get('STARGATE_BASE_URL') -STARGATE_AUTH_URL = os.environ.get('STARGATE_AUTH_URL') -STARGATE_USERNAME = os.environ.get('STARGATE_USERNAME') -STARGATE_PASSWORD = os.environ.get('STARGATE_PASSWORD') - - -@pytest.fixture -def astra_rest_client(): - return create_client(astra_database_id=ASTRA_DB_ID, - astra_database_region=ASTRA_DB_REGION, - astra_application_token=ASTRA_DB_APPLICATION_TOKEN) - - -@pytest.fixture -def stargate_rest_client(): - return create_client(base_url=STARGATE_BASE_URL, - auth_base_url=STARGATE_AUTH_URL, - username=STARGATE_USERNAME, - password=STARGATE_PASSWORD) - - -@pytest.mark.it('should initialize an AstraDB REST Client') -def test_connect(astra_rest_client): - assert type(astra_rest_client) is AstraClient - - -@pytest.mark.it('should initialize a Stargate REST Client') -def test_stargate_connect(stargate_rest_client): - assert type(stargate_rest_client) is AstraClient - - -@pytest.mark.it('should create a document') -def test_creating_document(astra_rest_client): - doc_uuid = uuid.uuid4() - r = astra_rest_client.request( - method=http_methods.PUT, - path=f"/api/rest/v2/namespaces/{ASTRA_DB_KEYSPACE}/collections/{ASTRA_DB_COLLECTION}/{doc_uuid}", - json_data={ - "name": "Cliff", - "last_name": "Wicklow", - "emails": ["cliff.wicklow@example.com"], - }) - assert r["documentId"] == str(doc_uuid) - - -@pytest.mark.it('should create a stargate keyspace') -def test_creating_stargate_keyspace(stargate_rest_client): - doc_uuid = uuid.uuid4() - r = stargate_rest_client.request( - method=http_methods.POST, - path=f"/v2/schemas/namespaces", - json_data={ - "name": ASTRA_DB_KEYSPACE - }) - assert r["name"] == ASTRA_DB_KEYSPACE - - -@pytest.mark.it('should create a stargate document') -def test_creating_stargate_document(stargate_rest_client): - doc_uuid = uuid.uuid4() - r = stargate_rest_client.request( - method=http_methods.PUT, - path=f"/v2/namespaces/{ASTRA_DB_KEYSPACE}/collections/{ASTRA_DB_COLLECTION}/{doc_uuid}", - json_data={ - "first_name": "Cliff", - "last_name": "Wicklow", - "emails": ["cliff.wicklow@example.com"], - }) - assert r["documentId"] == str(doc_uuid)
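For orientation, below is a minimal, hedged sketch of how the document-creation flow covered by the deleted `tests/astrapy/test_rest.py` maps onto the JSON API collection client exercised by the new tests. It reuses only calls that appear in this diff (`AstraDBCollection`, `insert_one`, `find_one`, and the `["data"]["document"]` response envelope), assumes the same `ASTRA_DB_*` environment variables as the new test modules, and assumes the target collection already exists (created e.g. as in the pagination fixture). The collection name `rest` is carried over from the deleted test and is otherwise arbitrary.

```python
# Sketch only (not part of the diff): the old REST "create a document" test,
# re-expressed with the JSON API collection client used in the new tests.
import os
import uuid

from astrapy.db import AstraDBCollection
from astrapy.defaults import DEFAULT_KEYSPACE_NAME

# Same endpoint construction as the new test fixtures.
api_endpoint = (
    f"https://{os.environ['ASTRA_DB_ID']}-{os.environ['ASTRA_DB_REGION']}"
    ".apps.astra.datastax.com"
)
token = os.environ["ASTRA_DB_APPLICATION_TOKEN"]
keyspace = os.environ.get("ASTRA_DB_KEYSPACE", DEFAULT_KEYSPACE_NAME)

# Assumes the "rest" collection already exists in this keyspace
# (it can be created as shown in the pagination fixture above).
collection = AstraDBCollection(
    collection_name="rest",
    token=token,
    api_endpoint=api_endpoint,
    namespace=keyspace,
)

# Insert a document with an explicit _id, mirroring the deleted PUT-by-uuid test.
doc_id = str(uuid.uuid4())
collection.insert_one(
    document={
        "_id": doc_id,
        "first_name": "Cliff",
        "last_name": "Wicklow",
        "emails": ["cliff.wicklow@example.com"],
    }
)

# The response envelope matches the assertions in the new collection tests:
# the stored document comes back under ["data"]["document"].
found = collection.find_one(filter={"_id": doc_id})
assert found["data"]["document"]["_id"] == doc_id
```

The envelope check mirrors the `doc["data"]["document"]` assertions used throughout the new collection tests, rather than the `documentId` field returned by the removed Document API endpoint.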