Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: proper asyncio support, horde context manager #8

Merged
merged 1 commit into from
Jul 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,5 @@ dmypy.json
out.json
.vscode/launch.json
.vscode/settings.json

examples/requested_images/*.*
6 changes: 3 additions & 3 deletions examples/ai_horde_api_client.example.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
from horde_sdk.ai_horde_api import AIHordeAPIClient
from horde_sdk.ai_horde_api import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse


def do_generate_check(ai_horde_api_client: AIHordeAPIClient) -> None:
def do_generate_check(ai_horde_api_client: AIHordeAPISimpleClient) -> None:
pass


def main() -> None:
"""Just a proof of concept - but several other pieces of functionality exist."""

ai_horde_api_client = AIHordeAPIClient()
ai_horde_api_client = AIHordeAPISimpleClient()

image_generate_async_request = ImageGenerateAsyncRequest(
apikey="0000000000",
Expand Down
64 changes: 59 additions & 5 deletions examples/async_ai_horde_api_client_example.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,26 @@
from __future__ import annotations

import asyncio
import time
from pathlib import Path

from horde_sdk.ai_horde_api import AIHordeAPIClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest
import aiohttp

from horde_sdk.ai_horde_api import AIHordeAPISimpleClient
from horde_sdk.ai_horde_api.apimodels import ImageGenerateAsyncRequest, ImageGenerateStatusRequest
from horde_sdk.generic_api.apimodels import RequestErrorResponse


async def main() -> None:
print("Starting...")
ai_horde_api_client = AIHordeAPIClient()
ai_horde_api_client = AIHordeAPISimpleClient()

image_generate_async_request = ImageGenerateAsyncRequest(
apikey="0000000000",
prompt="A cat in a hat",
models=["Deliberate"],
)

print("Submitting image generation request...")
response = await ai_horde_api_client.async_submit_request(
image_generate_async_request,
image_generate_async_request.get_success_response_type(),
Expand All @@ -26,8 +30,17 @@ async def main() -> None:
print(f"Error: {response.message}")
return

print("Image generation request submitted!")
image_done = False
start_time = time.time()
check_counter = 0
# Keep making ImageGenerateCheckRequests until the job is done.
while True:
while not image_done:
if time.time() - start_time > 20 or check_counter == 0:
print(f"{time.time() - start_time} seconds elapsed ({check_counter} checks made)...")
start_time = time.time()

check_counter += 1
check_response = await ai_horde_api_client.async_get_generate_check(
apikey="0000000000",
generation_id=response.id_,
Expand All @@ -38,10 +51,51 @@ async def main() -> None:
return

if check_response.done:
print("Image is done!")
image_done = True
break

await asyncio.sleep(5)

# Get the image with a ImageGenerateStatusRequest.
image_generate_status_request = ImageGenerateStatusRequest(
id=response.id_,
)

status_response = await ai_horde_api_client.async_submit_request(
image_generate_status_request,
image_generate_status_request.get_success_response_type(),
)

if isinstance(status_response, RequestErrorResponse):
print(f"Error: {status_response.message}")
return

for image_gen in status_response.generations:
print("Image generation:")
print(f"ID: {image_gen.id}")
print(f"URL: {image_gen.img}")
# debug(image_gen)
print("Downloading image...")

image_bytes = None
# image_gen.img is a url, download it using aiohttp.
async with aiohttp.ClientSession() as session, session.get(image_gen.img) as resp:
image_bytes = await resp.read()

if image_bytes is None:
print("Error: Could not download image.")
return

# Open a file in write mode and write the image bytes to it.
dir_to_write_to = Path("examples/requested_images/")
dir_to_write_to.mkdir(parents=True, exist_ok=True)
filepath_to_write_to = dir_to_write_to / f"{image_gen.id}.webp"
with open(filepath_to_write_to, "wb") as image_file:
image_file.write(image_bytes)

print(f"Image downloaded to {filepath_to_write_to}!")


if __name__ == "__main__":
asyncio.run(main())
4 changes: 2 additions & 2 deletions horde_sdk/ai_horde_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from horde_sdk.ai_horde_api.ai_horde_client import (
AIHordeAPIClient,
AIHordeAPISimpleClient,
)
from horde_sdk.ai_horde_api.consts import (
ALCHEMY_FORMS,
Expand All @@ -14,7 +14,7 @@
)

__all__ = [
"AIHordeAPIClient",
"AIHordeAPISimpleClient",
"AI_HORDE_BASE_URL",
"AI_HORDE_API_URL_Literals",
"ALCHEMY_FORMS",
Expand Down
62 changes: 58 additions & 4 deletions horde_sdk/ai_horde_api/ai_horde_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
"""Definitions to help interact with the AI-Horde API."""
from __future__ import annotations

import urllib.parse

from loguru import logger

from horde_sdk.ai_horde_api.apimodels import (
DeleteImageGenerateRequest,
ImageGenerateAsyncRequest,
ImageGenerateAsyncResponse,
ImageGenerateCheckRequest,
ImageGenerateCheckResponse,
ImageGenerateStatusRequest,
Expand All @@ -14,10 +18,13 @@
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.ai_horde_api.metadata import AIHordePathData
from horde_sdk.generic_api.apimodels import RequestErrorResponse
from horde_sdk.generic_api.generic_client import GenericHordeAPIClient
from horde_sdk.generic_api.generic_client import (
GenericHordeAPISession,
GenericHordeAPISimpleClient,
)


class AIHordeAPIClient(GenericHordeAPIClient):
class AIHordeAPISimpleClient(GenericHordeAPISimpleClient):
"""Represent an API client specifically configured for the AI-Horde API."""

def __init__(self) -> None:
Expand Down Expand Up @@ -156,7 +163,7 @@ def delete_pending_image(
Args:
generation_id (GenerationID): The ID of the request to delete.
"""
api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey)
api_request = DeleteImageGenerateRequest(id=generation_id)

api_response = self.submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
Expand All @@ -170,11 +177,58 @@ async def async_delete_pending_image(
apikey: str,
generation_id: GenerationID | str,
) -> ImageGenerateStatusResponse | RequestErrorResponse:
api_request = DeleteImageGenerateRequest(id=generation_id, apikey=apikey)
api_request = DeleteImageGenerateRequest(id=generation_id)

api_response = await self.async_submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
self._handle_api_error(api_response, api_request.get_endpoint_url())
return api_response

return api_response


class AIHordeAPISession(AIHordeAPISimpleClient, GenericHordeAPISession):
"""Represent an API session specifically configured for the AI-Horde API.

If you make a request which requires follow up (such as a request to generate an image), this will delete the
generation in progress when the context manager exits. If you do not want this to happen, use `AIHordeAPIClient`.
"""

def __enter__(self) -> AIHordeAPISession:
_self = super().__enter__()
if not isinstance(_self, AIHordeAPISession):
raise TypeError("Unexpected type returned from super().__enter__()")

return _self

async def __aenter__(self) -> AIHordeAPISession:
_self = await super().__aenter__()
if not isinstance(_self, AIHordeAPISession):
raise TypeError("Unexpected type returned from super().__aenter__()")

return _self


class AIHordeAPIClient:
async def image_generate_request(self, image_gen_request: ImageGenerateAsyncRequest) -> ImageGenerateAsyncResponse:
async with AIHordeAPISession() as image_gen_client:
response = await image_gen_client.async_submit_request(
api_request=image_gen_request,
expected_response_type=image_gen_request.get_success_response_type(),
)

if isinstance(response, RequestErrorResponse):
raise RuntimeError(f"Error response received: {response.message}")

check_request_type = response.get_follow_up_default_request()
follow_up_data = response.get_follow_up_data()
check_request = check_request_type.model_validate(follow_up_data)
async with AIHordeAPISession() as check_client:
while True:
check_response = await check_client.async_submit_request(
api_request=check_request,
expected_response_type=check_request.get_success_response_type(),
)

if isinstance(check_response, RequestErrorResponse):
raise RuntimeError(f"Error response received: {check_response.message}")
31 changes: 28 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@
BaseAIHordeRequest,
BaseImageGenerateParam,
)
from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest
from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckRequest
from horde_sdk.ai_horde_api.apimodels.generate._status import DeleteImageGenerateRequest, ImageGenerateStatusRequest
from horde_sdk.ai_horde_api.consts import KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated, BaseRequestWorkerDriven, BaseResponse
from horde_sdk.generic_api.apimodels import (
BaseRequestAuthenticated,
BaseRequestWorkerDriven,
BaseResponse,
)


class ImageGenerateAsyncResponse(BaseResponse):
Expand All @@ -19,11 +24,31 @@ class ImageGenerateAsyncResponse(BaseResponse):
v2 API Model: `RequestAsync`
"""

id_: str | GenerationID = Field(alias="id")
id_: str | GenerationID = Field(alias="id") # TODO: Remove `str`?
"""The UUID for this image generation."""
kudos: float
message: str | None = None

@override
@classmethod
def is_requiring_follow_up(cls) -> bool:
return True

@override
def get_follow_up_data(self) -> dict[str, object]:
return {"id": self.id_}

@classmethod
def get_follow_up_default_request(cls) -> type[ImageGenerateCheckRequest]:
return ImageGenerateCheckRequest

@override
@classmethod
def get_follow_up_request_types(
cls,
) -> list[type[ImageGenerateCheckRequest | ImageGenerateStatusRequest]]:
return [ImageGenerateCheckRequest, ImageGenerateStatusRequest]

@override
@classmethod
def get_api_model_name(cls) -> str | None:
Expand Down
2 changes: 0 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/generate/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import ImageID, WorkerID
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated


class ImageGeneration(BaseModel):
Expand Down Expand Up @@ -53,7 +52,6 @@ def get_api_model_name(cls) -> str | None:

class DeleteImageGenerateRequest(
BaseAIHordeRequest,
BaseRequestAuthenticated,
BaseImageJobRequest,
):
"""Represents a DELETE request to the `/v2/generate/status/{id}` endpoint."""
Expand Down
50 changes: 49 additions & 1 deletion horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,48 @@ class HordeAPIMessage(HordeAPIModel):
class BaseResponse(HordeAPIMessage):
"""Represents any response from any Horde API."""

@classmethod
def is_requiring_follow_up(cls) -> bool:
"""Return whether this response requires a follow up request of some kind."""
return False

def get_follow_up_data(self) -> dict[str, object]:
"""Return the information required from this response to submit a follow up request.

Note that this dict uses the alias field names (as seen on the API), not the python field names.
GenerationIDs will be returned as `{"id": "00000000-0000-0000-0000-000000000000"}` instead of
`{"id_": "00000000-0000-0000-0000-000000000000"}`.

This means it is suitable for passing directly
to a constructor, such as `ImageGenerateStatusRequest(**response.get_follow_up_required_info())`.
"""
raise NotImplementedError("This response does not require a follow up request")

@classmethod
def get_follow_up_default_request(cls) -> type[BaseRequest]:
raise NotImplementedError("This response does not require a follow up request")

@classmethod
def get_follow_up_request_types(cls) -> list[type]: # TODO type hint???
"""Return a list of all the possible follow up request types for this response"""
return [cls.get_follow_up_default_request()]

_follow_up_handled: bool = False

def set_follow_up_handled(self) -> None:
"""Set this response as having had its follow up request handled.

This is used for context management.
"""
self._follow_up_handled = True

def is_follow_up_handled(self) -> bool:
"""Return whether this response has had its follow up request handled.

This is used for context management.
"""
return self._follow_up_handled

@classmethod
def is_array_response(cls) -> bool:
"""Return whether this response is an array of an internal type."""
Expand Down Expand Up @@ -128,6 +170,11 @@ def get_http_method(cls) -> HTTPMethod:
"""The 'accept' header field."""
# X_Fields # TODO

client_agent: str = Field(
default="horde_sdk:0.2.0:https://githib.com/haidra-org/horde-sdk",
alias="Client-Agent",
)

@classmethod
def get_endpoint_url(cls) -> str:
"""Return the endpoint URL, including the path to the specific API action defined by this object"""
Expand Down Expand Up @@ -168,7 +215,8 @@ def get_header_fields(cls) -> list[str]:

@classmethod
def is_recovery_enabled(cls) -> bool:
"""Return whether this request should attempt to recover from during a client failure.
"""Return whether this request should attempt to recover from a client failure by submitting a request
specified by `get_recovery_request_type`.

This is used in for context management.
"""
Expand Down
Loading