Skip to content

Commit

Permalink
refactor: better leverage new 'Generic' behavior
Browse files Browse the repository at this point in the history
Not the generic API, but *typing* generics.
  • Loading branch information
tazlin committed Jul 11, 2023
1 parent 48397ac commit 87d6779
Showing 1 changed file with 51 additions and 9 deletions.
60 changes: 51 additions & 9 deletions src/horde_sdk/ai_horde_api/ai_horde_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
"""Definitions to help interact with the AI-Horde API."""
import urllib.parse

from loguru import logger

from horde_sdk.ai_horde_api.apimodels.generate import ImageGenerateAsyncRequest, ImageGenerateAsyncResponse
from horde_sdk.ai_horde_api.apimodels import (
CancelImageGenerateRequest,
ImageGenerateAsyncRequest,
ImageGenerateAsyncResponse,
ImageGenerateStatusResponse,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.ai_horde_api.metadata import AIHordePathData
from horde_sdk.generic_api import GenericHordeAPIClient, RequestErrorResponse
Expand All @@ -14,12 +22,39 @@ def __init__(self) -> None:
"""Create a new instance of the RatingsAPIClient."""
super().__init__(path_fields=AIHordePathData)

_base_url: str = AI_HORDE_BASE_URL

@property
def base_url(self) -> str:
"""Get the base URL for the AI-Horde API."""
return self.base_url

@base_url.setter
def base_url(self, value: str) -> None:
"""Set the base URL for the AI-Horde API."""
if urllib.parse.urlparse(value).scheme not in ["http", "https"]:
raise ValueError(f"Invalid scheme in URL: {value}")

self.base_url = value

def _handle_api_error(self, error_response: RequestErrorResponse, endpoint_url: str) -> None:
"""Handle an error response from the API.
Args:
error_response (RequestErrorResponse): The error response to handle.
"""
logger.error("Error response received from the AI-Horde API.")
logger.error(f"Endpoint: {endpoint_url}")
logger.error(f"Message: {error_response.message}")

def generate_image_async(
self,
api_request: ImageGenerateAsyncRequest,
) -> ImageGenerateAsyncResponse | RequestErrorResponse:
"""Submit a request to the AI-Horde API to generate an image asynchronously.
This is a call to the `/v2/generate/async` endpoint.
Args:
api_request (ImageGenerateAsyncRequest): The request to submit.
Expand All @@ -29,27 +64,34 @@ def generate_image_async(
Returns:
ImageGenerateAsyncResponse | RequestErrorResponse: The response from the API.
"""
api_response = self.submit_request(api_request)
api_response = self.submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
logger.error("Error response received from the AI-Horde API.")
logger.error(f"Endpoint: {api_request.get_endpoint_url()}")
logger.error(f"Message: {api_response.message}")
self._handle_api_error(api_response, api_request.get_endpoint_url())
return api_response
if not isinstance(api_response, ImageGenerateAsyncResponse):
logger.error("Failed to generate an image asynchronously.")
logger.error(f"Unexpected response type: {type(api_response)}")
raise RuntimeError(
f"Unexpected response type. Expected ImageGenerateAsyncResponse, got {type(api_response)}"
f"Unexpected response type. Expected ImageGenerateAsyncResponse, got {type(api_response)}",
)

return api_response

def delete_pending_image_request(
def delete_pending_image(
self,
generation_id: GenerationID,
):
apikey: str,
generation_id: GenerationID | str,
) -> ImageGenerateStatusResponse | RequestErrorResponse:
"""Delete a pending image request from the AI-Horde API.
Args:
generation_id (GenerationID): The ID of the request to delete.
"""
api_request = CancelImageGenerateRequest(id=generation_id, apikey=apikey)

api_response = self.submit_request(api_request, api_request.get_success_response_type())
if isinstance(api_response, RequestErrorResponse):
self._handle_api_error(api_response, api_request.get_endpoint_url())
return api_response

return api_response

0 comments on commit 87d6779

Please sign in to comment.