Skip to content

Commit

Permalink
feat: proper asyncio support, horde context manager
Browse files Browse the repository at this point in the history
- e.g. `AIHordeAPISimpleClient`, `AIHordeAPISession`, and `AIHordeAPIClient`
- Requests which require some sort of follow up, such as image generation requests, now can have the appropriate cancel request sent automatically with
- See `AIHordeAPISession` (new context handler) or its usage in `test_ai_horde_api_calls.py` for an idea of what was accomplished.
  • Loading branch information
tazlin committed Jul 13, 2023
1 parent 0f78f1c commit cafe3c5
Show file tree
Hide file tree
Showing 17 changed files with 520 additions and 98 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,5 @@ dmypy.json
out.json
.vscode/launch.json
.vscode/settings.json

examples/requested_images/*.*
6 changes: 3 additions & 3 deletions examples/ai_horde_api_client.example.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from horde_sdk.ai_horde_api import AIHordeAPIClient
from horde_sdk.ai_horde_api import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse


def do_generate_check(ai_horde_api_client: AIHordeAPIClient) -> None:
def do_generate_check(ai_horde_api_client: AIHordeAPISimpleClient) -> None:
pass


def main() -> None:
"""Just a proof of concept - but several other pieces of functionality exist."""

ai_horde_api_client = AIHordeAPIClient()
ai_horde_api_client = AIHordeAPISimpleClient()

image_generate_async_request = ImageGenerateAsyncRequest(
apikey="0000000000",
Expand Down
64 changes: 59 additions & 5 deletions examples/async_ai_horde_api_client_example.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from __future__ import annotations

import asyncio
import time
from pathlib import Path

from horde_sdk.ai_horde_api import AIHordeAPIClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest
import aiohttp

from horde_sdk.ai_horde_api import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse


async def main() -> None:
print("Starting...")
ai_horde_api_client = AIHordeAPIClient()
ai_horde_api_client = AIHordeAPISimpleClient()

image_generate_async_request = ImageGenerateAsyncRequest(
apikey="0000000000",
prompt="A cat in a hat",
models=["Deliberate"],
)

print("Submitting image generation request...")
response = await ai_horde_api_client.async_submit_request(
image_generate_async_request,
image_generate_async_request.get_success_response_type(),
Expand All @@ -26,8 +30,17 @@ async def main() -> None:
print(f"Error: {response.message}")
return

print("Image generation request submitted!")
image_done = False
start_time = time.time()
check_counter = 0
# Keep making ImageGenerateCheckRequests until the job is done.
while True:
while not image_done:
if time.time() - start_time > 20 or check_counter == 0:
print(f"{time.time() - start_time} seconds elapsed ({check_counter} checks made)...")
start_time = time.time()

check_counter += 1
check_response = await ai_horde_api_client.async_get_generate_check(
apikey="0000000000",
generation_id=response.id_,
Expand All @@ -38,10 +51,51 @@ async def main() -> None:
return

if check_response.done:
print("Image is done!")
image_done = True
break

await asyncio.sleep(5)

# Get the image with a ImageGenerateStatusRequest.
image_generate_status_request = ImageGenerateStatusRequest(
id=response.id_,
)

status_response = await ai_horde_api_client.async_submit_request(
image_generate_status_request,
image_generate_status_request.get_success_response_type(),
)

if isinstance(status_response, RequestErrorResponse):
print(f"Error: {status_response.message}")
return

for image_gen in status_response.generations:
print("Image generation:")
print(f"ID: {image_gen.id}")
print(f"URL: {image_gen.img}")
# debug(image_gen)
print("Downloading image...")

image_bytes = None
# image_gen.img is a url, download it using aiohttp.
async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp:
image_bytes = await resp.read()

if image_bytes is None:
print("Error: Could not download image.")
return

# Open a file in write mode and write the image bytes to it.
dir_to_write_to = Path("examples/requested_images/")
dir_to_write_to.mkdir(parents=True, exist_ok=True)
filepath_to_write_to = dir_to_write_to / f"{image_gen.id}.webp"
with open(filepath_to_write_to, "wb") as image_file:
image_file.write(image_bytes)

print(f"Image downloaded to {filepath_to_write_to}!")


if __name__ == "__main__":
asyncio.run(main())
4 changes: 2 additions & 2 deletions horde_sdk/ai_horde_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from horde_sdk.ai_horde_api.ai_horde_client import (
AIHordeAPIClient,
AIHordeAPISimpleClient,
)
from horde_sdk.ai_horde_api.consts import (
ALCHEMY_FORMS,
Expand All @@ -14,7 +14,7 @@
)

__all__ = [
"AIHordeAPIClient",
"AIHordeAPISimpleClient",
"AI_HORDE_BASE_URL",
"AI_HORDE_API_URL_Literals",
"ALCHEMY_FORMS",
Expand Down
62 changes: 58 additions & 4 deletions horde_sdk/ai_horde_api/ai_horde_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Definitions to help interact with the AI-Horde API."""
from __future__ import annotations

import urllib.parse

from loguru import logger

from horde_sdk.ai_horde_api.apimodels import (
DeleteImageGenerateRequest,
ImageGenerateAsyncRequest,
ImageGenerateAsyncResponse,
ImageGenerateCheckRequest,
ImageGenerateCheckResponse,
ImageGenerateStatusRequest,
Expand All @@ -14,10 +18,13 @@
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.ai_horde_api.metadata import AIHordePathData
from horde_sdk.generic_api.apimodels import RequestErrorResponse
from horde_sdk.generic_api.generic_client import GenericHordeAPIClient
from horde_sdk.generic_api.generic_client import (
GenericHordeAPISession,
GenericHordeAPISimpleClient,
)


class AIHordeAPIClient(GenericHordeAPIClient):
class AIHordeAPISimpleClient(GenericHordeAPISimpleClient):
"""Represent an API client specifically configured for the AI-Horde API."""

def __init__(self) -> None:
Expand Down Expand Up @@ -156,7 +163,7 @@ def delete_pending_image(
Args:
generation_id (GenerationID): The ID of the request to delete.
"""
api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey)
api_request = DeleteImageGenerateRequest(id=generation_id)

api_response = self.submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
Expand All @@ -170,11 +177,58 @@ async def async_delete_pending_image(
apikey: str,
generation_id: GenerationID | str,
) -> ImageGenerateStatusResponse | RequestErrorResponse:
api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey)
api_request = DeleteImageGenerateRequest(id=generation_id)

api_response = await self.async_submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
self._handle_api_error(api_response, api_request.get_endpoint_url())
return api_response

return api_response


class AIHordeAPISession(AIHordeAPISimpleClient, GenericHordeAPISession):
"""Represent an API session specifically configured for the AI-Horde API.
If you make a request which requires follow up (such as a request to generate an image), this will delete the
generation in progress when the context manager exits. If you do not want this to happen, use `AIHordeAPIClient`.
"""

def __enter__(self) -> AIHordeAPISession:
_self = super().__enter__()
if not isinstance(_self, AIHordeAPISession):
raise TypeError("Unexpected type returned from super().__enter__()")

return _self

async def __aenter__(self) -> AIHordeAPISession:
_self = await super().__aenter__()
if not isinstance(_self, AIHordeAPISession):
raise TypeError("Unexpected type returned from super().__aenter__()")

return _self


class AIHordeAPIClient:
async def image_generate_request(self, image_gen_request: ImageGenerateAsyncRequest) -> ImageGenerateAsyncResponse:
async with AIHordeAPISession() as image_gen_client:
response = await image_gen_client.async_submit_request(
api_request=image_gen_request,
expected_response_type=image_gen_request.get_success_response_type(),
)

if isinstance(response, RequestErrorResponse):
raise RuntimeError(f"Error response received: {response.message}")

check_request_type = response.get_follow_up_default_request()
follow_up_data = response.get_follow_up_data()
check_request = check_request_type.model_validate(follow_up_data)
async with AIHordeAPISession() as check_client:
while True:
check_response = await check_client.async_submit_request(
api_request=check_request,
expected_response_type=check_request.get_success_response_type(),
)

if isinstance(check_response, RequestErrorResponse):
raise RuntimeError(f"Error response received: {check_response.message}")
31 changes: 28 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
BaseAIHordeRequest,
BaseImageGenerateParam,
)
from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest
from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckRequest
from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest, ImageGenerateStatusRequest
from horde_sdk.ai_horde_api.consts import KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated, BaseRequestWorkerDriven, BaseResponse
from horde_sdk.generic_api.apimodels import (
BaseRequestAuthenticated,
BaseRequestWorkerDriven,
BaseResponse,
)


class ImageGenerateAsyncResponse(BaseResponse):
Expand All @@ -19,11 +24,31 @@ class ImageGenerateAsyncResponse(BaseResponse):
v2 API Model: `RequestAsync`
"""

id_: str | GenerationID = Field(alias="id")
id_: str | GenerationID = Field(alias="id") # TODO: Remove `str`?
"""The UUID for this image generation."""
kudos: float
message: str | None = None

@override
@classmethod
def is_requiring_follow_up(cls) -> bool:
return True

@override
def get_follow_up_data(self) -> dict[str, object]:
return {"id": self.id_}

@classmethod
def get_follow_up_default_request(cls) -> type[ImageGenerateCheckRequest]:
return ImageGenerateCheckRequest

@override
@classmethod
def get_follow_up_request_types(
cls,
) -> list[type[ImageGenerateCheckRequest | ImageGenerateStatusRequest]]:
return [ImageGenerateCheckRequest, ImageGenerateStatusRequest]

@override
@classmethod
def get_api_model_name(cls) -> str | None:
Expand Down
2 changes: 0 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/generate/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import ImageID, WorkerID
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated


class ImageGeneration(BaseModel):
Expand Down Expand Up @@ -53,7 +52,6 @@ def get_api_model_name(cls) -> str | None:

class DeleteImageGenerateRequest(
BaseAIHordeRequest,
BaseRequestAuthenticated,
BaseImageJobRequest,
):
"""Represents a DELETE request to the `/v2/generate/status/{id}` endpoint."""
Expand Down
50 changes: 49 additions & 1 deletion horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,48 @@ class HordeAPIMessage(HordeAPIModel):
class BaseResponse(HordeAPIMessage):
"""Represents any response from any Horde API."""

@classmethod
def is_requiring_follow_up(cls) -> bool:
"""Return whether this response requires a follow up request of some kind."""
return False

def get_follow_up_data(self) -> dict[str, object]:
"""Return the information required from this response to submit a follow up request.
Note that this dict uses the alias field names (as seen on the API), not the python field names.
GenerationIDs will be returned as `{"id": "00000000-0000-0000-0000-000000000000"}` instead of
`{"id_": "00000000-0000-0000-0000-000000000000"}`.
This means it is suitable for passing directly
to a constructor, such as `ImageGenerateStatusRequest(**response.get_follow_up_required_info())`.
"""
raise NotImplementedError("This response does not require a follow up request")

@classmethod
def get_follow_up_default_request(cls) -> type[BaseRequest]:
raise NotImplementedError("This response does not require a follow up request")

@classmethod
def get_follow_up_request_types(cls) -> list[type]: # TODO type hint???
"""Return a list of all the possible follow up request types for this response"""
return [cls.get_follow_up_default_request()]

_follow_up_handled: bool = False

def set_follow_up_handled(self) -> None:
"""Set this response as having had its follow up request handled.
This is used for context management.
"""
self._follow_up_handled = True

def is_follow_up_handled(self) -> bool:
"""Return whether this response has had its follow up request handled.
This is used for context management.
"""
return self._follow_up_handled

@classmethod
def is_array_response(cls) -> bool:
"""Return whether this response is an array of an internal type."""
Expand Down Expand Up @@ -128,6 +170,11 @@ def get_http_method(cls) -> HTTPMethod:
"""The 'accept' header field."""
# X_Fields # TODO

client_agent: str = Field(
default="horde_sdk:0.2.0:https://githib.com/haidra-org/horde-sdk",
alias="Client-Agent",
)

@classmethod
def get_endpoint_url(cls) -> str:
"""Return the endpoint URL, including the path to the specific API action defined by this object"""
Expand Down Expand Up @@ -168,7 +215,8 @@ def get_header_fields(cls) -> list[str]:

@classmethod
def is_recovery_enabled(cls) -> bool:
"""Return whether this request should attempt to recover from during a client failure.
"""Return whether this request should attempt to recover from a client failure by submitting a request
specified by `get_recovery_request_type`.
This is used in for context management.
"""
Expand Down
Loading

0 comments on commit cafe3c5

Please sign in to comment.