Skip to content

Commit

Permalink
fix: more accurate swagger evaluation, better Base* classes
Browse files Browse the repository at this point in the history
Too many changes to list. See diff.
  • Loading branch information
tazlin committed Jul 9, 2023
1 parent 4e89015 commit 45fadf5
Show file tree
Hide file tree
Showing 92 changed files with 11,478 additions and 863 deletions.
3 changes: 3 additions & 0 deletions src/horde_sdk/ai_horde_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
ImageGenerationJobSubmitRequest,
ImageGenerationJobSubmitResponse,
)
from horde_sdk.ai_horde_api.apimodels.workers._workers_all import AllWorkersDetailsRequest, AllWorkersDetailsResponse

__all__ = [
"ImageGenerateAsyncRequest",
Expand All @@ -26,4 +27,6 @@
"StatsModelsResponse",
"ImageGenerationJobSubmitRequest",
"ImageGenerationJobSubmitResponse",
"AllWorkersDetailsRequest",
"AllWorkersDetailsResponse",
]
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,5 @@ def get_endpoint_subpath(cls) -> str:

@override
@classmethod
def get_expected_response_type(cls) -> type[StatsModelsResponse]:
def get_success_response_type(cls) -> type[StatsModelsResponse]:
return StatsModelsResponse
26 changes: 19 additions & 7 deletions src/horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseRequestWorkerDriven, BaseResponse
from horde_sdk.consts import HTTPMethod, HTTPStatusCode
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated, BaseRequestWorkerDriven, BaseResponse


class ImageGenerateAsyncResponse(BaseResponse):
Expand All @@ -33,7 +33,12 @@ class ImageGenerationInputPayload(BaseImageGenerateParam):
n: int = Field(default=1, ge=1)


class ImageGenerateAsyncRequest(BaseAIHordeRequest, BaseImageGenerateImg2Img, BaseRequestWorkerDriven):
class ImageGenerateAsyncRequest(
BaseAIHordeRequest,
BaseRequestAuthenticated,
BaseImageGenerateImg2Img,
BaseRequestWorkerDriven,
):
"""Represents the data needed to make a request to the `/v2/generate/async` endpoint.
v2 API Model: `GenerationInputStable`
Expand Down Expand Up @@ -68,11 +73,18 @@ def get_http_method(cls) -> HTTPMethod:
return HTTPMethod.POST

@override
@staticmethod
def get_endpoint_subpath() -> str:
@classmethod
def get_endpoint_subpath(cls) -> str:
return AI_HORDE_API_URL_Literals.v2_generate_async

@override
@staticmethod
def get_expected_response_type() -> type[ImageGenerateAsyncResponse]:
@classmethod
def get_success_response_type(cls) -> type[ImageGenerateAsyncResponse]:
return ImageGenerateAsyncResponse

@override
@classmethod
def get_success_status_response_pairs(cls) -> dict[HTTPStatusCode, type[BaseResponse]]:
return {
HTTPStatusCode.ACCEPTED: cls.get_success_response_type(),
}
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,5 +59,5 @@ def get_endpoint_subpath(cls) -> str:

@override
@classmethod
def get_expected_response_type(cls) -> type[ImageGenerateCheckResponse]:
def get_success_response_type(cls) -> type[ImageGenerateCheckResponse]:
return ImageGenerateCheckResponse
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def get_endpoint_subpath() -> str:

@override
@staticmethod
def get_expected_response_type() -> type[ImageGenerateJobResponse]:
def get_success_response_type() -> type[ImageGenerateJobResponse]:
return ImageGenerateJobResponse


Expand Down
9 changes: 7 additions & 2 deletions src/horde_sdk/ai_horde_api/apimodels/generate/_status.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import ImageID, WorkerID
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated


class ImageGenerateStatusRequest(BaseAIHordeRequest, BaseImageGenerateJobRequest):
Expand Down Expand Up @@ -54,7 +55,11 @@ def get_api_model_name(cls) -> str | None:
return "RequestStatusStable"


class CancelImageGenerateRequest(BaseAIHordeRequest, BaseImageGenerateJobRequest):
class CancelImageGenerateRequest(
BaseAIHordeRequest,
BaseRequestAuthenticated,
BaseImageGenerateJobRequest,
):
"""Represents a DELETE request to the `/v2/generate/status/{id}` endpoint."""

@override
Expand All @@ -74,5 +79,5 @@ def get_endpoint_subpath() -> str:

@override
@staticmethod
def get_expected_response_type() -> type[ImageGenerateStatusResponse]:
def get_success_response_type() -> type[ImageGenerateStatusResponse]:
return ImageGenerateStatusResponse
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,5 +51,5 @@ def get_endpoint_subpath() -> str:

@override
@staticmethod
def get_expected_response_type() -> type[ImageGenerationJobSubmitResponse]:
def get_success_response_type() -> type[ImageGenerationJobSubmitResponse]:
return ImageGenerationJobSubmitResponse
140 changes: 140 additions & 0 deletions src/horde_sdk/ai_horde_api/apimodels/workers/_workers_all.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.consts import WORKER_TYPE
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import TeamID, WorkerID
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated, BaseResponse, HordeAPIModel
from pydantic import BaseModel, Field
from typing_extensions import override


class TeamDetailsLite(HordeAPIModel):
name: str | None = None
"""The Name given to this team."""
id_: str | TeamID | None = Field(None, alias="id")
"""The UUID of this team."""

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "TeamDetailsLite"


class WorkerKudosDetails(HordeAPIModel):
generated: float | None = None
"""How much Kudos this worker has received for generating images."""
uptime: int | None = None
"""How much Kudos this worker has received for staying online longer."""

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "WorkerKudosDetails"


class WorkerDetailItem(HordeAPIModel):
type_: WORKER_TYPE = Field(alias="type")
name: str
id_: str | WorkerID = Field(alias="id")
online: bool | None = None
requests_fulfilled: int | None = None
kudos_rewards: float | None = None
kudos_details: WorkerKudosDetails | None = None
performance: str | None = None
threads: int | None = None
uptime: int | None = None
maintenance_mode: bool
paused: bool | None = None
info: str | None = None
nsfw: bool | None = None
owner: str | None = None
trusted: bool | None = None
flagged: bool | None = None
suspicious: int | None = None
uncompleted_jobs: int | None = None
models: list[str] | None = None
forms: list[str] | None = None
team: TeamDetailsLite | None = None
contact: str | None = Field(None, min_length=4, max_length=500)
bridge_agent: str = Field(max_length=1000)
max_pixels: int | None = None
megapixelsteps_generated: int | None = None
img2img: bool | None = None
painting: bool | None = None
post_processing: bool | None = None
lora: bool | None = None
max_length: int | None = None
max_context_length: int | None = None
tokens_generated: int | None = None

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "WorkerDetailItem"


class AllWorkersDetailsResponse(BaseResponse):
_workers: list[WorkerDetailItem]

@property
def workers(self) -> list[WorkerDetailItem]:
return self._workers

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "WorkerDetails"

@override
@classmethod
def is_array_response(cls) -> bool:
return True

@override
@classmethod
def get_array_item_type(cls) -> type[BaseModel]:
return WorkerDetailItem

@override
def set_array(self, list_: list) -> None:
if not isinstance(list_, list):
raise ValueError("list_ must be a list")

parsed_list = []
for item in list_:
parsed_list.append(WorkerDetailItem(**item))

self._workers = parsed_list

@override
def get_array(self) -> list:
return self._workers.copy()


class AllWorkersDetailsRequest(BaseAIHordeRequest, BaseRequestAuthenticated):
type_: WORKER_TYPE = Field(alias="type")

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return None

@override
@classmethod
def get_endpoint_subpath(cls) -> str:
return AI_HORDE_API_URL_Literals.v2_workers_all

@override
@classmethod
def get_http_method(cls) -> HTTPMethod:
return HTTPMethod.GET

@override
@classmethod
def get_success_response_type(cls) -> type[BaseResponse]:
return AllWorkersDetailsResponse

@override
@classmethod
def get_header_fields(cls) -> list[str]:
return ["type_"]
6 changes: 6 additions & 0 deletions src/horde_sdk/ai_horde_api/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ class KNOWN_SOURCE_PROCESSING(StrEnum):
class GENERATION_STATE(StrEnum):
ok = auto()
censored = auto()


class WORKER_TYPE(StrEnum):
image = auto()
text = auto()
interrogation = auto()
4 changes: 4 additions & 0 deletions src/horde_sdk/ai_horde_api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,7 @@ class WorkerID(_UUID_Identifier):

class ImageID(_UUID_Identifier):
"""Represents the ID of an image. Instances of this class can be compared with a `str` or a UUID object."""


class TeamID(_UUID_Identifier):
"""Represents the ID of a team. Instances of this class can be compared with a `str` or a UUID object."""
24 changes: 24 additions & 0 deletions src/horde_sdk/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,27 @@ class HTTPStatusCode(Enum):
NOT_IMPLEMENTED = 501
SERVICE_UNAVAILABLE = 503
GATEWAY_TIMEOUT = 504


def get_all_success_status_codes() -> list[HTTPStatusCode]:
"""Return a list of all success status codes."""
return [status_code for status_code in HTTPStatusCode if is_success_status_code(status_code)]


def get_all_error_status_codes() -> list[HTTPStatusCode]:
"""Return a list of all error status codes."""
return [status_code for status_code in HTTPStatusCode if is_error_status_code(status_code)]


def is_success_status_code(status_code: HTTPStatusCode | int) -> bool:
"""Return True if the status code is a success code, False otherwise."""
if isinstance(status_code, HTTPStatusCode):
status_code = status_code.value
return 200 <= status_code < 300


def is_error_status_code(status_code: HTTPStatusCode | int) -> bool:
"""Return True if the status code is an error code, False otherwise."""
if isinstance(status_code, HTTPStatusCode):
status_code = status_code.value
return 400 <= status_code < 600
Loading

0 comments on commit 45fadf5

Please sign in to comment.