diff --git a/horde_sdk/ai_horde_api/ai_horde_clients.py b/horde_sdk/ai_horde_api/ai_horde_clients.py index c10dc15..97e9aee 100644 --- a/horde_sdk/ai_horde_api/ai_horde_clients.py +++ b/horde_sdk/ai_horde_api/ai_horde_clients.py @@ -3,6 +3,7 @@ import asyncio import base64 +import contextlib import io import time import urllib.parse @@ -779,9 +780,19 @@ def alchemy_request( class AIHordeAPIAsyncSimpleClient(BaseAIHordeSimpleClient): """An asyncio based simple client for the AI-Horde API. Start with this class if you want asyncio capabilities..""" - def __init__(self, aiohttp_session: aiohttp.ClientSession) -> None: + _horde_client_session: AIHordeAPIAsyncClientSession | None + + def __init__( + self, + aiohttp_session: aiohttp.ClientSession | None, + horde_client_session: AIHordeAPIAsyncClientSession | None = None, + ) -> None: """Create a new instance of the AIHordeAPISimpleClient.""" + if aiohttp_session is not None and horde_client_session is not None: + raise ValueError("Only one of aiohttp_session or horde_client_session can be provided") + self._aiohttp_session = aiohttp_session + self._horde_client_session = horde_client_session async def download_image_from_generation(self, generation: ImageGeneration) -> tuple[PIL.Image.Image, JobID]: """Asynchronously convert from base64 or download an image from a response. @@ -876,8 +887,19 @@ async def _do_request_with_check( AIHordeRequestError: If the request failed. The error response is included in the exception. """ + context: contextlib.AbstractContextManager | AIHordeAPIAsyncClientSession + ai_horde_session: AIHordeAPIAsyncClientSession + + if self._horde_client_session is not None: + # Use a dummy context manager to keep the type checker happy + context = contextlib.nullcontext() + ai_horde_session = self._horde_client_session + elif self._aiohttp_session is not None: + ai_horde_session = AIHordeAPIAsyncClientSession(self._aiohttp_session) + context = ai_horde_session + # This session class will cleanup incomplete requests in the event of an exception - async with AIHordeAPIAsyncClientSession(aiohttp_session=self._aiohttp_session) as ai_horde_session: + async with context: # type: ignore # Submit the initial request logger.debug( f"Submitting request: {api_request.log_safe_model_dump()} with timeout {timeout}", diff --git a/horde_sdk/ai_horde_api/fields.py b/horde_sdk/ai_horde_api/fields.py index 789b2e4..1f58086 100644 --- a/horde_sdk/ai_horde_api/fields.py +++ b/horde_sdk/ai_horde_api/fields.py @@ -41,11 +41,14 @@ def __str__(self) -> str: @override def __eq__(self, other: Any) -> bool: + if isinstance(other, UUID_Identifier): + return self.root == other.root + if isinstance(other, str): return self.root.__str__() == other if isinstance(other, uuid.UUID): - return str(self.root) == str(other) + return self.root == other return False diff --git a/horde_sdk/generic_api/generic_clients.py b/horde_sdk/generic_api/generic_clients.py index d60720c..45f5985 100644 --- a/horde_sdk/generic_api/generic_clients.py +++ b/horde_sdk/generic_api/generic_clients.py @@ -639,7 +639,8 @@ async def submit_request( ) -> HordeResponseTypeVar | RequestErrorResponse: # Add the request to the list of awaiting requests. - self._awaiting_requests.append(api_request) + async with self._awaiting_requests_lock: + self._awaiting_requests.append(api_request) # Submit the request to the API and get the response. response = await super().submit_request(api_request, expected_response_type) @@ -783,6 +784,7 @@ async def _handle_exit_async( # Log the results of each cleanup request. for i, cleanup_response in enumerate(cleanup_responses): logger.info(f"Recovery request {i+1} submitted!") + logger.debug(f"Recovery request {i+1}: {cleanup_requests[i].log_safe_model_dump()}") logger.debug(f"Recovery response {i+1}: {cleanup_response}") # Return True to indicate that all requests were handled successfully. diff --git a/pyproject.toml b/pyproject.toml index a60cf64..c000f91 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "horde_sdk" -version = "0.7.7" +version = "0.7.9" description = "A python toolkit for interacting with the horde APIs, services, and ecosystem." authors = [ {name = "tazlin", email = "tazlin.on.github@gmail.com"}, diff --git a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py index b4b6cd0..ab93391 100644 --- a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py +++ b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py @@ -109,9 +109,11 @@ def json_serializer(obj: object) -> object: with open("docs/api_to_sdk_payload_map.json", "w") as f: f.write(json.dumps(api_to_sdk_payload_model_map, indent=4, default=json_serializer)) + f.write("\n") with open("docs/api_to_sdk_response_map.json", "w") as f: f.write(json.dumps(api_to_sdk_response_model_map, indent=4, default=json_serializer)) + f.write("\n") def test_all_ai_horde_model_defs_in_swagger_from_prod_swagger() -> None: