Skip to content

Commit

Permalink
feat: better swagger validation, continued refactors
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jul 8, 2023
1 parent a2292b1 commit 4e89015
Show file tree
Hide file tree
Showing 51 changed files with 964 additions and 399 deletions.
1 change: 1 addition & 0 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ ruff==0.0.275
types-requests==2.31.0.1
tox==4.6.3
pre-commit==3.3.3
pytest-cov
8 changes: 7 additions & 1 deletion src/horde_sdk/ai_horde_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from horde_sdk.ai_horde_api.apimodels._stats import StatsImageModels, StatsModelsResponse
from horde_sdk.ai_horde_api.apimodels.generate._async import ImageGenerateAsyncRequest, ImageGenerateAsyncResponse
from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckRequest, ImageGenerateCheckResponse
from horde_sdk.ai_horde_api.apimodels.generate._pop import ImageGenerateJobPopRequest, ImageGenerateJobResponse
Expand All @@ -6,7 +7,10 @@
ImageGenerateStatusRequest,
ImageGenerateStatusResponse,
)
from horde_sdk.ai_horde_api.apimodels.stats import StatsImageModels, StatsModelsResponse
from horde_sdk.ai_horde_api.apimodels.generate._submit import (
ImageGenerationJobSubmitRequest,
ImageGenerationJobSubmitResponse,
)

__all__ = [
"ImageGenerateAsyncRequest",
Expand All @@ -20,4 +24,6 @@
"CancelImageGenerateRequest",
"StatsImageModels",
"StatsModelsResponse",
"ImageGenerationJobSubmitRequest",
"ImageGenerationJobSubmitResponse",
]
90 changes: 87 additions & 3 deletions src/horde_sdk/ai_horde_api/apimodels/_base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,96 @@
from pydantic import BaseModel, Field, field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL
from horde_sdk.ai_horde_api.fields import GenerationID, WorkerID
from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated
from horde_sdk.generic_api.apimodels import BaseRequest
from horde_sdk.utils import seed_to_int


class BaseAIHordeRequest(BaseRequest):
@override
@classmethod
def get_api_url(cls) -> str:
return AI_HORDE_BASE_URL


class BaseImageGenerateJobRequest(BaseRequestAuthenticated):
class BaseImageGenerateJobRequest(BaseModel):
"""Mix-in class for data relating to image generation jobs."""

id: str | GenerationID # noqa: A003
"""The UUID for this job."""


class BaseWorkerRequest(BaseRequestAuthenticated):
class BaseWorkerRequest(BaseModel):
"""Mix-in class for data relating to worker requests."""

worker_id: str | WorkerID
"""The UUID of the worker in question for this request."""


class LorasPayloadEntry(BaseModel):
"""Represents a single lora parameter.
v2 API Model: `ModelPayloadLorasStable`
"""

name: str = Field(min_length=1, max_length=255)
"""The name of the LoRa model to use."""
model: float = Field(default=1, ge=0, le=5)
"""The strength of the LoRa against the stable diffusion model."""
clip: float = Field(default=1, ge=0, le=5)
"""The strength of the LoRa against the clip model."""
inject_trigger: str | None = Field(default=None, min_length=1, max_length=30)
"""Any trigger required to activate the LoRa model."""


class BaseImageGenerateParam(BaseModel):
"""Represents some of the data included in a request to the `/v2/generate/async` endpoint.
Also is the corresponding information returned on a job pop to the `/v2/generate/pop` endpoint.
v2 API Model: `ModelPayloadStable`
"""

sampler_name: KNOWN_SAMPLERS = KNOWN_SAMPLERS.k_lms
cfg_scale: float = 7.5
denoising_strength: float | None = Field(default=1, ge=0, le=1)
seed: str | None = None
height: int = 512
width: int = 512
seed_variation: int | None = None
post_processing: list[str] = Field(default_factory=list)
karras: bool = True
tiling: bool = False
hires_fix: bool = False
clip_skip: int = 1
control_type: str | None = None
image_is_control: bool | None = None
return_control_map: bool | None = None
facefixer_strength: float | None = Field(default=None, ge=0, le=1)
loras: list[LorasPayloadEntry] = Field(default_factory=list)
special: dict = Field(default_factory=dict)
steps: int = Field(default=25, ge=1)

n_iter: int = Field(default=1, ge=1)
use_nsfw_censor: bool = False

@field_validator("sampler_name")
def sampler_name_must_be_known(cls, v):
"""Ensure that the sampler name is in this list of supported samplers."""
if v not in KNOWN_SAMPLERS.__members__:
raise ValueError(f"Unknown sampler name {v}")
return v

@field_validator("seed")
def seed_to_int_if_str(cls, v):
"""Ensure that the seed is an integer. If it is a string, convert it to an integer."""
return str(seed_to_int(v))


class BaseImageGenerateImg2Img(BaseModel):
"""Mix-in class for data relating to img2img generation."""

source_image: str | None = None
source_processing: KNOWN_SOURCE_PROCESSING = KNOWN_SOURCE_PROCESSING.txt2img
source_mask: str | None = None
82 changes: 0 additions & 82 deletions src/horde_sdk/ai_horde_api/apimodels/_shared.py

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseResponse


class StatsModelsResponse(BaseResponse):
"""Represents the data returned from the `/v2/stats/img/models` endpoint.
v2 API Model: `ImgModelStats`
"""

model_config = {"frozen": True}

day: dict[str, int]
Expand All @@ -20,6 +25,8 @@ def get_api_model_name(cls) -> str | None:


class StatsImageModels(BaseAIHordeRequest):
"""Represents the data needed to make a request to the `/v2/stats/img/models` endpoint."""

@override
@classmethod
def get_api_model_name(cls) -> str | None:
Expand Down
11 changes: 7 additions & 4 deletions src/horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from pydantic import Field, model_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._shared import (
from horde_sdk.ai_horde_api.apimodels._base import (
BaseAIHordeRequest,
BaseImageGenerateImg2Img,
BaseImageGenerateParam,
ImageGenerateImg2ImgData,
)
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
Expand All @@ -13,7 +13,10 @@


class ImageGenerateAsyncResponse(BaseResponse):
"""Represents the data returned from the `/v2/generate/async` endpoint."""
"""Represents the data returned from the `/v2/generate/async` endpoint.
v2 API Model: `RequestAsync`
"""

id: str | GenerationID # noqa: A003
"""The UUID for this image generation."""
Expand All @@ -30,7 +33,7 @@ class ImageGenerationInputPayload(BaseImageGenerateParam):
n: int = Field(default=1, ge=1)


class ImageGenerateAsyncRequest(BaseAIHordeRequest, ImageGenerateImg2ImgData, BaseRequestWorkerDriven):
class ImageGenerateAsyncRequest(BaseAIHordeRequest, BaseImageGenerateImg2Img, BaseRequestWorkerDriven):
"""Represents the data needed to make a request to the `/v2/generate/async` endpoint.
v2 API Model: `GenerationInputStable`
Expand Down
3 changes: 1 addition & 2 deletions src/horde_sdk/ai_horde_api/apimodels/generate/_check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._base import BaseImageGenerateJobRequest
from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateJobRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.consts import HTTPMethod
from horde_sdk.generic_api.apimodels import BaseResponse
Expand Down
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from pydantic import Field, field_validator
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest, BaseImageGenerateParam
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateParam
from horde_sdk.ai_horde_api.consts import KNOWN_SOURCE_PROCESSING
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
Expand Down
3 changes: 1 addition & 2 deletions src/horde_sdk/ai_horde_api/apimodels/generate/_status.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from pydantic import BaseModel, Field
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._base import BaseImageGenerateJobRequest
from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateJobRequest
from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckResponse
from horde_sdk.ai_horde_api.consts import GENERATION_STATE
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
Expand Down
7 changes: 6 additions & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing_extensions import override

from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest
from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest
from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals
from horde_sdk.ai_horde_api.fields import GenerationID
from horde_sdk.consts import HTTPMethod
Expand All @@ -11,6 +11,11 @@ class ImageGenerationJobSubmitResponse(BaseResponse):
reward: float
"""The amount of kudos gained for submitting this request."""

@override
@classmethod
def get_api_model_name(cls) -> str | None:
return "GenerationSubmitted"


class ImageGenerationJobSubmitRequest(BaseAIHordeRequest, BaseRequestAuthenticated):
"""Represents the data needed to make a job submit 'request' from a worker to the /v2/generate/submit endpoint.
Expand Down
11 changes: 11 additions & 0 deletions src/horde_sdk/ai_horde_api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from strenum import StrEnum

from horde_sdk.generic_api.endpoints import url_with_path

AI_HORDE_BASE_URL = "https://aihorde.net/api/"

if os.environ.get("HORDE_URL", None):
Expand All @@ -15,6 +17,8 @@ class AI_HORDE_API_URL_Literals(StrEnum):
"""The URL actions 'paths' to the endpoints. Includes the find/replace strings in brackets for path (non-query)
variables."""

swagger = "/swagger.json"

# Note that the leading slash is included for consistency with the swagger docs,
# but it is dropped when the URL is actually constructed (see `url_with_path` in `horde_sdk.generic_api.endpoints`)
v2_stats_img_models = "/v2/stats/img/models"
Expand Down Expand Up @@ -65,3 +69,10 @@ class AI_HORDE_API_URL_Literals(StrEnum):

v2_workers_all = "/v2/workers"
v2_workers = "/v2/workers/{worker_id}"


def get_ai_horde_swagger_url() -> str:
return url_with_path(
base_url=AI_HORDE_BASE_URL,
path=AI_HORDE_API_URL_Literals.swagger,
)
Loading

0 comments on commit 4e89015

Please sign in to comment.