From cafe3c53c3d3d299900f133661f22615eddade6b Mon Sep 17 00:00:00 2001 From: tazlin Date: Wed, 12 Jul 2023 20:53:55 -0400 Subject: [PATCH] feat: proper asyncio support, horde context manager - e.g. `AIHordeAPISimpleClient`, `AIHordeAPISession`, and `AIHordeAPIClient` - Requests which require some sort of follow up, such as image generation requests, now can have the appropriate cancel request sent automatically with - See `AIHordeAPISession` (new context handler) or its usage in `test_ai_horde_api_calls.py` for an idea of what was accomplished. --- .gitignore | 2 + examples/ai_horde_api_client.example.py | 6 +- examples/async_ai_horde_api_client_example.py | 64 +++- horde_sdk/ai_horde_api/__init__.py | 4 +- horde_sdk/ai_horde_api/ai_horde_client.py | 62 +++- .../ai_horde_api/apimodels/generate/_async.py | 31 +- .../apimodels/generate/_status.py | 2 - horde_sdk/generic_api/apimodels.py | 50 ++- horde_sdk/generic_api/generic_client.py | 315 ++++++++++++++---- horde_sdk/generic_api/metadata.py | 1 + horde_sdk/ratings_api/apimodels.py | 1 - horde_sdk/ratings_api/ratings_client.py | 4 +- pyproject.toml | 1 + requirements.dev.txt | 2 + requirements.txt | 1 + ...pi_calls.py => test_ai_horde_api_calls.py} | 56 +++- tests/conftest.py | 16 + 17 files changed, 520 insertions(+), 98 deletions(-) rename tests/ai_horde_api/{test_api_calls.py => test_ai_horde_api_calls.py} (60%) diff --git a/.gitignore b/.gitignore index e2447dc..ca9ee48 100644 --- a/.gitignore +++ b/.gitignore @@ -138,3 +138,5 @@ dmypy.json out.json .vscode/launch.json .vscode/settings.json + +examples/requested_images/*.* diff --git a/examples/ai_horde_api_client.example.py b/examples/ai_horde_api_client.example.py index 88ed9b5..d01d5cc 100644 --- a/examples/ai_horde_api_client.example.py +++ b/examples/ai_horde_api_client.example.py @@ -1,16 +1,16 @@ -from horde_sdk.ai_horde_api import AIHordeAPIClient +from horde_sdk.ai_horde_api import AIHordeAPISimpleClient from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest from horde_sdk.generic_api.apimodels import RequestErrorResponse -def do_generate_check(ai_horde_api_client: AIHordeAPIClient) -> None: +def do_generate_check(ai_horde_api_client: AIHordeAPISimpleClient) -> None: pass def main() -> None: """Just a proof of concept - but several other pieces of functionality exist.""" - ai_horde_api_client = AIHordeAPIClient() + ai_horde_api_client = AIHordeAPISimpleClient() image_generate_async_request = ImageGenerateAsyncRequest( apikey="0000000000", diff --git a/examples/async_ai_horde_api_client_example.py b/examples/async_ai_horde_api_client_example.py index e9de72f..a56ef50 100644 --- a/examples/async_ai_horde_api_client_example.py +++ b/examples/async_ai_horde_api_client_example.py @@ -1,22 +1,26 @@ from __future__ import annotations import asyncio +import time +from pathlib import Path -from horde_sdk.ai_horde_api import AIHordeAPIClient -from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest +import aiohttp + +from horde_sdk.ai_horde_api import AIHordeAPISimpleClient +from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest from horde_sdk.generic_api.apimodels import RequestErrorResponse async def main() -> None: print("Starting...") - ai_horde_api_client = AIHordeAPIClient() + ai_horde_api_client = AIHordeAPISimpleClient() image_generate_async_request = ImageGenerateAsyncRequest( apikey="0000000000", prompt="A cat in a hat", models=["Deliberate"], ) - + print("Submitting image generation request...") response = await ai_horde_api_client.async_submit_request( image_generate_async_request, image_generate_async_request.get_success_response_type(), @@ -26,8 +30,17 @@ async def main() -> None: print(f"Error: {response.message}") return + print("Image generation request submitted!") + image_done = False + start_time = time.time() + check_counter = 0 # Keep making ImageGenerateCheckRequests until the job is done. - while True: + while not image_done: + if time.time() - start_time > 20 or check_counter == 0: + print(f"{time.time() - start_time} seconds elapsed ({check_counter} checks made)...") + start_time = time.time() + + check_counter += 1 check_response = await ai_horde_api_client.async_get_generate_check( apikey="0000000000", generation_id=response.id_, @@ -38,10 +51,51 @@ async def main() -> None: return if check_response.done: + print("Image is done!") + image_done = True break await asyncio.sleep(5) + # Get the image with a ImageGenerateStatusRequest. + image_generate_status_request = ImageGenerateStatusRequest( + id=response.id_, + ) + + status_response = await ai_horde_api_client.async_submit_request( + image_generate_status_request, + image_generate_status_request.get_success_response_type(), + ) + + if isinstance(status_response, RequestErrorResponse): + print(f"Error: {status_response.message}") + return + + for image_gen in status_response.generations: + print("Image generation:") + print(f"ID: {image_gen.id}") + print(f"URL: {image_gen.img}") + # debug(image_gen) + print("Downloading image...") + + image_bytes = None + # image_gen.img is a url, download it using aiohttp. + async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp: + image_bytes = await resp.read() + + if image_bytes is None: + print("Error: Could not download image.") + return + + # Open a file in write mode and write the image bytes to it. + dir_to_write_to = Path("examples/requested_images/") + dir_to_write_to.mkdir(parents=True, exist_ok=True) + filepath_to_write_to = dir_to_write_to / f"{image_gen.id}.webp" + with open(filepath_to_write_to, "wb") as image_file: + image_file.write(image_bytes) + + print(f"Image downloaded to {filepath_to_write_to}!") + if __name__ == "__main__": asyncio.run(main()) diff --git a/horde_sdk/ai_horde_api/__init__.py b/horde_sdk/ai_horde_api/__init__.py index 03c42df..99fc834 100644 --- a/horde_sdk/ai_horde_api/__init__.py +++ b/horde_sdk/ai_horde_api/__init__.py @@ -1,5 +1,5 @@ from horde_sdk.ai_horde_api.ai_horde_client import ( - AIHordeAPIClient, + AIHordeAPISimpleClient, ) from horde_sdk.ai_horde_api.consts import ( ALCHEMY_FORMS, @@ -14,7 +14,7 @@ ) __all__ = [ - "AIHordeAPIClient", + "AIHordeAPISimpleClient", "AI_HORDE_BASE_URL", "AI_HORDE_API_URL_Literals", "ALCHEMY_FORMS", diff --git a/horde_sdk/ai_horde_api/ai_horde_client.py b/horde_sdk/ai_horde_api/ai_horde_client.py index b1f3be2..0e9879a 100644 --- a/horde_sdk/ai_horde_api/ai_horde_client.py +++ b/horde_sdk/ai_horde_api/ai_horde_client.py @@ -1,10 +1,14 @@ """Definitions to help interact with the AI-Horde API.""" +from __future__ import annotations + import urllib.parse from loguru import logger from horde_sdk.ai_horde_api.apimodels import ( DeleteImageGenerateRequest, + ImageGenerateAsyncRequest, + ImageGenerateAsyncResponse, ImageGenerateCheckRequest, ImageGenerateCheckResponse, ImageGenerateStatusRequest, @@ -14,10 +18,13 @@ from horde_sdk.ai_horde_api.fields import GenerationID from horde_sdk.ai_horde_api.metadata import AIHordePathData from horde_sdk.generic_api.apimodels import RequestErrorResponse -from horde_sdk.generic_api.generic_client import GenericHordeAPIClient +from horde_sdk.generic_api.generic_client import ( + GenericHordeAPISession, + GenericHordeAPISimpleClient, +) -class AIHordeAPIClient(GenericHordeAPIClient): +class AIHordeAPISimpleClient(GenericHordeAPISimpleClient): """Represent an API client specifically configured for the AI-Horde API.""" def __init__(self) -> None: @@ -156,7 +163,7 @@ def delete_pending_image( Args: generation_id (GenerationID): The ID of the request to delete. """ - api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey) + api_request = DeleteImageGenerateRequest(id=generation_id) api_response = self.submit_request(api_request, api_request.get_success_response_type()) if isinstance(api_response, RequestErrorResponse): @@ -170,7 +177,7 @@ async def async_delete_pending_image( apikey: str, generation_id: GenerationID | str, ) -> ImageGenerateStatusResponse | RequestErrorResponse: - api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey) + api_request = DeleteImageGenerateRequest(id=generation_id) api_response = await self.async_submit_request(api_request, api_request.get_success_response_type()) if isinstance(api_response, RequestErrorResponse): @@ -178,3 +185,50 @@ async def async_delete_pending_image( return api_response return api_response + + +class AIHordeAPISession(AIHordeAPISimpleClient, GenericHordeAPISession): + """Represent an API session specifically configured for the AI-Horde API. + + If you make a request which requires follow up (such as a request to generate an image), this will delete the + generation in progress when the context manager exits. If you do not want this to happen, use `AIHordeAPIClient`. + """ + + def __enter__(self) -> AIHordeAPISession: + _self = super().__enter__() + if not isinstance(_self, AIHordeAPISession): + raise TypeError("Unexpected type returned from super().__enter__()") + + return _self + + async def __aenter__(self) -> AIHordeAPISession: + _self = await super().__aenter__() + if not isinstance(_self, AIHordeAPISession): + raise TypeError("Unexpected type returned from super().__aenter__()") + + return _self + + +class AIHordeAPIClient: + async def image_generate_request(self, image_gen_request: ImageGenerateAsyncRequest) -> ImageGenerateAsyncResponse: + async with AIHordeAPISession() as image_gen_client: + response = await image_gen_client.async_submit_request( + api_request=image_gen_request, + expected_response_type=image_gen_request.get_success_response_type(), + ) + + if isinstance(response, RequestErrorResponse): + raise RuntimeError(f"Error response received: {response.message}") + + check_request_type = response.get_follow_up_default_request() + follow_up_data = response.get_follow_up_data() + check_request = check_request_type.model_validate(follow_up_data) + async with AIHordeAPISession() as check_client: + while True: + check_response = await check_client.async_submit_request( + api_request=check_request, + expected_response_type=check_request.get_success_response_type(), + ) + + if isinstance(check_response, RequestErrorResponse): + raise RuntimeError(f"Error response received: {check_response.message}") diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_async.py b/horde_sdk/ai_horde_api/apimodels/generate/_async.py index f114f33..4dacf1d 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_async.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_async.py @@ -5,12 +5,17 @@ BaseAIHordeRequest, BaseImageGenerateParam, ) -from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest +from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckRequest +from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest, ImageGenerateStatusRequest from horde_sdk.ai_horde_api.consts import KNOWN_SOURCE_PROCESSING from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.ai_horde_api.fields import GenerationID from horde_sdk.consts import HTTPMethod, HTTPStatusCode -from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated, BaseRequestWorkerDriven, BaseResponse +from horde_sdk.generic_api.apimodels import ( + BaseRequestAuthenticated, + BaseRequestWorkerDriven, + BaseResponse, +) class ImageGenerateAsyncResponse(BaseResponse): @@ -19,11 +24,31 @@ class ImageGenerateAsyncResponse(BaseResponse): v2 API Model: `RequestAsync` """ - id_: str | GenerationID = Field(alias="id") + id_: str | GenerationID = Field(alias="id") # TODO: Remove `str`? """The UUID for this image generation.""" kudos: float message: str | None = None + @override + @classmethod + def is_requiring_follow_up(cls) -> bool: + return True + + @override + def get_follow_up_data(self) -> dict[str, object]: + return {"id": self.id_} + + @classmethod + def get_follow_up_default_request(cls) -> type[ImageGenerateCheckRequest]: + return ImageGenerateCheckRequest + + @override + @classmethod + def get_follow_up_request_types( + cls, + ) -> list[type[ImageGenerateCheckRequest | ImageGenerateStatusRequest]]: + return [ImageGenerateCheckRequest, ImageGenerateStatusRequest] + @override @classmethod def get_api_model_name(cls) -> str | None: diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_status.py b/horde_sdk/ai_horde_api/apimodels/generate/_status.py index bfd4514..3fda525 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_status.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_status.py @@ -7,7 +7,6 @@ from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.ai_horde_api.fields import ImageID, WorkerID from horde_sdk.consts import HTTPMethod -from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated class ImageGeneration(BaseModel): @@ -53,7 +52,6 @@ def get_api_model_name(cls) -> str | None: class DeleteImageGenerateRequest( BaseAIHordeRequest, - BaseRequestAuthenticated, BaseImageJobRequest, ): """Represents a DELETE request to the `/v2/generate/status/{id}` endpoint.""" diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index 4ecd769..17cc0be 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -41,6 +41,48 @@ class HordeAPIMessage(HordeAPIModel): class BaseResponse(HordeAPIMessage): """Represents any response from any Horde API.""" + @classmethod + def is_requiring_follow_up(cls) -> bool: + """Return whether this response requires a follow up request of some kind.""" + return False + + def get_follow_up_data(self) -> dict[str, object]: + """Return the information required from this response to submit a follow up request. + + Note that this dict uses the alias field names (as seen on the API), not the python field names. + GenerationIDs will be returned as `{"id": "00000000-0000-0000-0000-000000000000"}` instead of + `{"id_": "00000000-0000-0000-0000-000000000000"}`. + + This means it is suitable for passing directly + to a constructor, such as `ImageGenerateStatusRequest(**response.get_follow_up_required_info())`. + """ + raise NotImplementedError("This response does not require a follow up request") + + @classmethod + def get_follow_up_default_request(cls) -> type[BaseRequest]: + raise NotImplementedError("This response does not require a follow up request") + + @classmethod + def get_follow_up_request_types(cls) -> list[type]: # TODO type hint??? + """Return a list of all the possible follow up request types for this response""" + return [cls.get_follow_up_default_request()] + + _follow_up_handled: bool = False + + def set_follow_up_handled(self) -> None: + """Set this response as having had its follow up request handled. + + This is used for context management. + """ + self._follow_up_handled = True + + def is_follow_up_handled(self) -> bool: + """Return whether this response has had its follow up request handled. + + This is used for context management. + """ + return self._follow_up_handled + @classmethod def is_array_response(cls) -> bool: """Return whether this response is an array of an internal type.""" @@ -128,6 +170,11 @@ def get_http_method(cls) -> HTTPMethod: """The 'accept' header field.""" # X_Fields # TODO + client_agent: str = Field( + default="horde_sdk:0.2.0:https://githib.com/haidra-org/horde-sdk", + alias="Client-Agent", + ) + @classmethod def get_endpoint_url(cls) -> str: """Return the endpoint URL, including the path to the specific API action defined by this object""" @@ -168,7 +215,8 @@ def get_header_fields(cls) -> list[str]: @classmethod def is_recovery_enabled(cls) -> bool: - """Return whether this request should attempt to recover from during a client failure. + """Return whether this request should attempt to recover from a client failure by submitting a request + specified by `get_recovery_request_type`. This is used in for context management. """ diff --git a/horde_sdk/generic_api/generic_client.py b/horde_sdk/generic_api/generic_client.py index 7f21333..8f56f2b 100644 --- a/horde_sdk/generic_api/generic_client.py +++ b/horde_sdk/generic_api/generic_client.py @@ -1,17 +1,19 @@ """The API client which can perform arbitrary horde API requests.""" -from typing import Generic, TypeVar +from __future__ import annotations + +import asyncio +from typing import TypeVar import aiohttp import requests +from loguru import logger from pydantic import BaseModel, ValidationError from strenum import StrEnum +from typing_extensions import override from horde_sdk.generic_api.apimodels import ( BaseRequest, - BaseRequestAuthenticated, - BaseRequestUserSpecific, - BaseRequestWorkerDriven, BaseResponse, RequestErrorResponse, ) @@ -22,11 +24,6 @@ GenericQueryFields, ) -HordeRequest = TypeVar("HordeRequest", bound=BaseRequest) -"""TypeVar for the request type.""" -HordeResponse = TypeVar("HordeResponse", bound=BaseResponse) -"""TypeVar for the response type.""" - class _ParsedRequest(BaseModel): endpoint_no_query: str @@ -41,10 +38,16 @@ class _ParsedRequest(BaseModel): """The body to be sent with the request, or `None` if no body should be sent.""" -class GenericHordeAPIClient: - """Interfaces with any flask API the horde provides, intended to be fairly dynamic/flexible. +HordeRequest = TypeVar("HordeRequest", bound=BaseRequest) +"""TypeVar for the request type.""" +HordeResponse = TypeVar("HordeResponse", bound=BaseResponse) +"""TypeVar for the response type.""" + + +class GenericHordeAPISimpleClient: + """Interfaces with any flask API the horde provides, but provides little error handling. - You can use the friendly, typed functions, or if you prefer more control, you can use the `submit_request` method. + This is the no-frills, simple version of the client if you want to have more control over the request process. """ _header_data_keys: type[GenericHeaderFields] = GenericHeaderFields @@ -174,16 +177,17 @@ def get_specified_data_keys(data_keys: type[StrEnum], api_request: BaseRequest) def _after_request_handling( self, + *, api_request: BaseRequest, - raw_response: requests.Response, + raw_response_json: dict, + returned_status_code: int, expected_response: type[HordeResponse], ) -> HordeResponse | RequestErrorResponse: expected_response_type = api_request.get_success_response_type() - raw_response_json = raw_response.json() # If requests response is a failure code, see if a `message` key exists in the response. # If so, return a RequestErrorResponse - if raw_response.status_code >= 400: + if returned_status_code >= 400: if len(raw_response_json) == 1 and "message" in raw_response_json: return RequestErrorResponse(**raw_response_json) @@ -288,7 +292,12 @@ def get( allow_redirects=True, ) - return self._after_request_handling(api_request, raw_response, expected_response) + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response.json(), + returned_status_code=raw_response.status_code, + expected_response=expected_response, + ) async def async_get( self, @@ -305,9 +314,13 @@ async def async_get( HordeResponse | RequestErrorResponse: The response from the API. """ parsed_request = self._validate_and_prepare_request(api_request) + if parsed_request.request_body is not None: raise RuntimeError("GET requests cannot have a body!") + raw_response_json: dict = {} + response_status: int = 599 + async with ( aiohttp.ClientSession() as session, session.get( @@ -317,9 +330,15 @@ async def async_get( allow_redirects=True, ) as response, ): - raw_response = await response.json() - - return self._after_request_handling(api_request, raw_response, expected_response) + raw_response_json = await response.json() + response_status = response.status + + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response_json, + returned_status_code=response_status, + expected_response=expected_response, + ) def post( self, @@ -344,7 +363,12 @@ def post( allow_redirects=True, ) - return self._after_request_handling(api_request, raw_response, expected_response) + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response.json(), + returned_status_code=raw_response.status_code, + expected_response=expected_response, + ) async def async_post( self, @@ -361,7 +385,8 @@ async def async_post( HordeResponse | RequestErrorResponse: The response from the API. """ parsed_request = self._validate_and_prepare_request(api_request) - + raw_response_json: dict = {} + response_status: int = 599 async with ( aiohttp.ClientSession() as session, session.post( @@ -372,9 +397,15 @@ async def async_post( allow_redirects=True, ) as response, ): - raw_response = await response.json() - - return self._after_request_handling(api_request, raw_response, expected_response) + raw_response_json = await response.json() + response_status = response.status + + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response_json, + returned_status_code=response_status, + expected_response=expected_response, + ) def put( self, @@ -399,7 +430,12 @@ def put( allow_redirects=True, ) - return self._after_request_handling(api_request, raw_response, expected_response) + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response.json(), + returned_status_code=raw_response.status_code, + expected_response=expected_response, + ) async def async_put( self, @@ -416,7 +452,8 @@ async def async_put( HordeResponse | RequestErrorResponse: The response from the API. """ parsed_request = self._validate_and_prepare_request(api_request) - + raw_response_json: dict = {} + response_status: int = 599 async with ( aiohttp.ClientSession() as session, session.put( @@ -427,9 +464,15 @@ async def async_put( allow_redirects=True, ) as response, ): - raw_response = await response.json() - - return self._after_request_handling(api_request, raw_response, expected_response) + raw_response_json = await response.json() + response_status = response.status + + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response_json, + returned_status_code=response_status, + expected_response=expected_response, + ) def patch( self, @@ -454,7 +497,12 @@ def patch( allow_redirects=True, ) - return self._after_request_handling(api_request, raw_response, expected_response) + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response.json(), + returned_status_code=raw_response.status_code, + expected_response=expected_response, + ) async def async_patch( self, @@ -471,7 +519,8 @@ async def async_patch( HordeResponse | RequestErrorResponse: The response from the API. """ parsed_request = self._validate_and_prepare_request(api_request) - + raw_response_json: dict = {} + response_status: int = 599 async with ( aiohttp.ClientSession() as session, session.patch( @@ -482,9 +531,15 @@ async def async_patch( allow_redirects=True, ) as response, ): - raw_response = await response.json() - - return self._after_request_handling(api_request, raw_response, expected_response) + raw_response_json = await response.json() + response_status = response.status + + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response_json, + returned_status_code=response_status, + expected_response=expected_response, + ) def delete( self, @@ -509,7 +564,12 @@ def delete( allow_redirects=True, ) - return self._after_request_handling(api_request, raw_response, expected_response) + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response.json(), + returned_status_code=raw_response.status_code, + expected_response=expected_response, + ) async def async_delete( self, @@ -526,7 +586,8 @@ async def async_delete( HordeResponse | RequestErrorResponse: The response from the API. """ parsed_request = self._validate_and_prepare_request(api_request) - + raw_response_json: dict = {} + response_status: int = 599 async with ( aiohttp.ClientSession() as session, session.delete( @@ -537,47 +598,159 @@ async def async_delete( allow_redirects=True, ) as response, ): - raw_response = await response.json() + raw_response_json = await response.json() + response_status = response.status + + return self._after_request_handling( + api_request=api_request, + raw_response_json=raw_response_json, + returned_status_code=response_status, + expected_response=expected_response, + ) - return self._after_request_handling(api_request, raw_response, expected_response) +class GenericHordeAPISession(GenericHordeAPISimpleClient): + """A client which can perform arbitrary horde API requests, but also keeps track of requests' responses which + need follow up. Use `submit_request` for synchronous requests, and `async_submit_request` for asynchronous + requests. -class HordeRequestHandler(Generic[HordeRequest, HordeResponse]): - request: HordeRequest - """The request to be handled.""" + This typically is the better class if you do not want to worry about handling any outstanding requests + if your program crashes. This would be the case with most non-atomic requests, such as generation requests + or anything labeled as `async` on the API. + """ - response: HordeResponse | RequestErrorResponse - """The response from the API.""" + _awaiting_requests: list[BaseRequest] + _awaiting_requests_lock: asyncio.Lock = asyncio.Lock() - def __init__(self, request: HordeRequest) -> None: - self.request = request + _pending_follow_ups: list[tuple[BaseRequest, BaseResponse]] + _pending_follow_ups_lock: asyncio.Lock = asyncio.Lock() - def __enter__(self) -> HordeRequest: - return self.request + def __init__( + self, + *, + header_fields: type[GenericHeaderFields] = GenericHeaderFields, + path_fields: type[GenericPathFields] = GenericPathFields, + query_fields: type[GenericQueryFields] = GenericQueryFields, + accept_types: type[GenericAcceptTypes] = GenericAcceptTypes, + ) -> None: + super().__init__( + header_fields=header_fields, + path_fields=path_fields, + query_fields=query_fields, + accept_types=accept_types, + ) + self._awaiting_requests = [] + self._pending_follow_ups = [] + + @override + def submit_request( + self, + api_request: BaseRequest, + expected_response_type: type[HordeResponse], + ) -> HordeResponse | RequestErrorResponse: + response = super().submit_request(api_request, expected_response_type) + self._pending_follow_ups.append((api_request, response)) + return response + + @override + async def async_submit_request( + self, + api_request: BaseRequest, + expected_response_type: type[HordeResponse], + ) -> HordeResponse | RequestErrorResponse: + async with self._awaiting_requests_lock: + self._awaiting_requests.append(api_request) + + response = await super().async_submit_request(api_request, expected_response_type) + + async with self._awaiting_requests_lock, self._pending_follow_ups_lock: + self._awaiting_requests.remove(api_request) + self._pending_follow_ups.append((api_request, response)) + + return response + + def __enter__(self) -> GenericHordeAPISession: + return self def __exit__(self, exc_type: type[Exception], exc_val: Exception, exc_tb: object) -> None: if exc_type is not None: - print(f"Error: {exc_val}, Type: {exc_type}, Traceback: {exc_tb}") - if not self.request.is_recovery_enabled(): - return - - recovery_request_type = self.request.get_recovery_request_type() - - request_params = {} - - mappable_base_types: list[type[BaseModel]] = [ - BaseRequestAuthenticated, - BaseRequestUserSpecific, - BaseRequestWorkerDriven, - ] - - # If it any of the base types are a subclass of the recovery request type, then we can map the request - # parameters to the recovery request. - # - # For example, if the recovery request type is `DeleteImageGenerateRequest`, and the request is - # `ImageGenerateAsyncRequest`, then we can map the `id` parameter from the request to the `id` parameter - # of the recovery request. - for base_type in mappable_base_types: - if issubclass(recovery_request_type, base_type): - for key in base_type.model_fields: - request_params[key] = getattr(self.request, key) + logger.error(f"Error: {exc_val}, Type: {exc_type}, Traceback: {exc_tb}") + + if not self._pending_follow_ups: + return + + for request_to_follow_up, response_to_follow_up in self._pending_follow_ups: + self._handle_exit(request_to_follow_up, response_to_follow_up) + + def _handle_exit(self, request_to_follow_up: BaseRequest, response_to_follow_up: BaseResponse) -> None: + recovery_request_type: type[BaseRequest] = request_to_follow_up.get_recovery_request_type() + + request_params: dict[str, object] = response_to_follow_up.get_follow_up_data() + + message = ( + "An exception occurred while trying to create a recovery request! " + "This is a bug in the SDK, please report it!" + ) + try: + recovery_request = recovery_request_type.model_validate(request_params) + + recovery_response = super().submit_request( + recovery_request, + recovery_request.get_success_response_type(), + ) + logger.info("Recovery request submitted!") + logger.debug(f"Recovery response: {recovery_response}") + + except Exception: + logger.critical(message) + logger.critical(f"{request_to_follow_up}") + + async def __aenter__(self) -> GenericHordeAPISession: + return self + + async def __aexit__(self, exc_type: type[Exception], exc_val: Exception, exc_tb: object) -> None: + if exc_type is not None: + logger.error(f"Error: {exc_val}, Type: {exc_type}, Traceback: {exc_tb}") + + if self._awaiting_requests: + logger.warning( + ( + "This session was used to submit asynchronous requests, but the context manager was exited " + "before all requests were returned! This may result in requests not being handled properly." + ), + ) + for request in self._awaiting_requests: + logger.warning(f"Request Unhandled: {request}") + + if not self._pending_follow_ups: + return + + await asyncio.gather( + *[ + self._handle_exit_async(request_to_follow_up, response_to_follow_up) + for request_to_follow_up, response_to_follow_up in self._pending_follow_ups + ], + ) + + async def _handle_exit_async(self, request_to_follow_up: BaseRequest, response_to_follow_up: BaseResponse) -> None: + recovery_request_type: type[BaseRequest] = request_to_follow_up.get_recovery_request_type() + + request_params: dict[str, object] = response_to_follow_up.get_follow_up_data() + + message = ( + "An exception occurred while trying to create a recovery request! " + "This is a bug in the SDK, please report it!" + ) + try: + recovery_request = recovery_request_type.model_validate(request_params) + + recovery_response = await super().async_submit_request( + recovery_request, + recovery_request.get_success_response_type(), + ) + logger.info("Recovery request submitted!") + logger.debug(f"Recovery response: {recovery_response}") + + except Exception: + logger.critical(message) + logger.critical(f"{request_to_follow_up}") diff --git a/horde_sdk/generic_api/metadata.py b/horde_sdk/generic_api/metadata.py index 4a1bf4c..d6c2127 100644 --- a/horde_sdk/generic_api/metadata.py +++ b/horde_sdk/generic_api/metadata.py @@ -11,6 +11,7 @@ class GenericHeaderFields(StrEnum): apikey = auto() accept = auto() # X_Fields = "X-Fields" # TODO? + client_agent = auto() class GenericAcceptTypes(StrEnum): diff --git a/horde_sdk/ratings_api/apimodels.py b/horde_sdk/ratings_api/apimodels.py index 1992614..1afb364 100644 --- a/horde_sdk/ratings_api/apimodels.py +++ b/horde_sdk/ratings_api/apimodels.py @@ -270,7 +270,6 @@ class UserRatingsRequest( limit: int offset: int = 0 diverge: int | None - client_agent: str | None @override @classmethod diff --git a/horde_sdk/ratings_api/ratings_client.py b/horde_sdk/ratings_api/ratings_client.py index 3b467cb..386f438 100644 --- a/horde_sdk/ratings_api/ratings_client.py +++ b/horde_sdk/ratings_api/ratings_client.py @@ -1,9 +1,9 @@ """Definitions to help interact with the Ratings API.""" -from horde_sdk.generic_api.generic_client import GenericHordeAPIClient +from horde_sdk.generic_api.generic_client import GenericHordeAPISimpleClient from horde_sdk.ratings_api.metadata import RatingsAPIPathFields, RatingsAPIQueryFields -class RatingsAPIClient(GenericHordeAPIClient): +class RatingsAPIClient(GenericHordeAPISimpleClient): """Represent a client specifically configured for the Ratings APi.""" def __init__(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index ac36772..8542c7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,6 +20,7 @@ dependencies = [ "babel", "aiohttp", "aiodns", + "pillow", ] license = {file = "LICENSE"} classifiers = [ diff --git a/requirements.dev.txt b/requirements.dev.txt index f84644b..25a17a8 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -1,8 +1,10 @@ pytest==7.4.0 +pytest-asyncio mypy==1.4.1 black==23.3.0 ruff==0.0.275 types-requests==2.31.0.1 +types-Pillow types-pytz types-setuptools types-urllib3 diff --git a/requirements.txt b/requirements.txt index 1e6c696..3f2f158 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,3 +4,4 @@ StrEnum loguru aiohttp aiodns +pillow diff --git a/tests/ai_horde_api/test_api_calls.py b/tests/ai_horde_api/test_ai_horde_api_calls.py similarity index 60% rename from tests/ai_horde_api/test_api_calls.py rename to tests/ai_horde_api/test_ai_horde_api_calls.py index 1546e1b..1486fed 100644 --- a/tests/ai_horde_api/test_api_calls.py +++ b/tests/ai_horde_api/test_ai_horde_api_calls.py @@ -1,8 +1,9 @@ +import asyncio from pathlib import Path import pytest -from horde_sdk.ai_horde_api.ai_horde_client import AIHordeAPIClient +from horde_sdk.ai_horde_api.ai_horde_client import AIHordeAPISession, AIHordeAPISimpleClient from horde_sdk.ai_horde_api.apimodels import ( AllWorkersDetailsRequest, AllWorkersDetailsResponse, @@ -28,10 +29,10 @@ def default_image_gen_request(self) -> ImageGenerateAsyncRequest: ) def test_AIHordeAPIClient_init(self) -> None: - AIHordeAPIClient() + AIHordeAPISimpleClient() def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncRequest) -> None: - client = AIHordeAPIClient() + client = AIHordeAPISimpleClient() image_async_response: ImageGenerateAsyncResponse | RequestErrorResponse = client.submit_request( api_request=default_image_gen_request, @@ -59,7 +60,7 @@ def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncReque assert isinstance(cancel_response, DeleteImageGenerateRequest.get_success_response_type()) def test_workers_all(self) -> None: - client = AIHordeAPIClient() + client = AIHordeAPISimpleClient() api_request = AllWorkersDetailsRequest(type=WORKER_TYPE.image) @@ -91,3 +92,50 @@ def test_workers_all(self) -> None: _PRODUCTION_RESPONSES_FOLDER.mkdir(parents=True, exist_ok=True) with open(_PRODUCTION_RESPONSES_FOLDER / filename, "w") as f: f.write(api_response.to_json_horde_sdk_safe()) + + +class HordeTestException(Exception): + pass + + +def test_HordeRequestSession(simple_image_gen_request: ImageGenerateAsyncRequest) -> None: + with pytest.raises(HordeTestException), AIHordeAPISession() as horde_session: + api_response = horde_session.submit_request( # noqa: F841 + simple_image_gen_request, + simple_image_gen_request.get_success_response_type(), + ) + raise HordeTestException("This tests the context manager, not the request/response.") + + +@pytest.mark.asyncio +async def test_HordeRequestSession_async(simple_image_gen_request: ImageGenerateAsyncRequest) -> None: + with AIHordeAPISession() as horde_session: + api_response = await horde_session.async_submit_request( # noqa: F841 + simple_image_gen_request, + simple_image_gen_request.get_success_response_type(), + ) + + +@pytest.mark.asyncio +async def test_HordeRequestSession_async_exception_raised(simple_image_gen_request: ImageGenerateAsyncRequest) -> None: + with pytest.raises(HordeTestException), AIHordeAPISession() as horde_session: + api_response = await horde_session.async_submit_request( # noqa: F841 + simple_image_gen_request, + simple_image_gen_request.get_success_response_type(), + ) + raise HordeTestException("This tests the context manager, not the request/response.") + + +@pytest.mark.asyncio +async def test_multiple_concurrent_async_requests(simple_image_gen_request: ImageGenerateAsyncRequest) -> None: + async def submit_request() -> None: + async with AIHordeAPISession() as horde_session: + api_response: ImageGenerateAsyncResponse | RequestErrorResponse = ( # noqa: F841 + await horde_session.async_submit_request( + simple_image_gen_request, + simple_image_gen_request.get_success_response_type(), + ) + ) + + # Run 5 concurrent requests using asyncio + await asyncio.gather(*[asyncio.create_task(submit_request()) for _ in range(5)]) diff --git a/tests/conftest.py b/tests/conftest.py index 65d3c84..b05bc86 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,19 @@ +import pytest + +from horde_sdk.ai_horde_api.apimodels import ( + ImageGenerateAsyncRequest, +) + + +@pytest.fixture(scope="function") +def simple_image_gen_request() -> ImageGenerateAsyncRequest: + return ImageGenerateAsyncRequest( + apikey="0000000000", + prompt="a cat in a hat", + models=["Deliberate"], + ) + + def pytest_collection_modifyitems(items: list) -> None: """Modifies test items in place to ensure test modules run in a given order.""" MODULE_ORDER = ["tests_generic", "test_utils", "test_dynamically_check_apimodels"]