diff --git a/requirements.dev.txt b/requirements.dev.txt index 7577853..c9a2fdd 100644 --- a/requirements.dev.txt +++ b/requirements.dev.txt @@ -5,3 +5,4 @@ ruff==0.0.275 types-requests==2.31.0.1 tox==4.6.3 pre-commit==3.3.3 +pytest-cov diff --git a/src/horde_sdk/ai_horde_api/__init__.py b/src/horde_sdk/ai_horde_api/__init__.py index 42ea779..29db661 100644 --- a/src/horde_sdk/ai_horde_api/__init__.py +++ b/src/horde_sdk/ai_horde_api/__init__.py @@ -1,3 +1,4 @@ +from horde_sdk.ai_horde_api.apimodels._stats import StatsImageModels, StatsModelsResponse from horde_sdk.ai_horde_api.apimodels.generate._async import ImageGenerateAsyncRequest, ImageGenerateAsyncResponse from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckRequest, ImageGenerateCheckResponse from horde_sdk.ai_horde_api.apimodels.generate._pop import ImageGenerateJobPopRequest, ImageGenerateJobResponse @@ -6,7 +7,10 @@ ImageGenerateStatusRequest, ImageGenerateStatusResponse, ) -from horde_sdk.ai_horde_api.apimodels.stats import StatsImageModels, StatsModelsResponse +from horde_sdk.ai_horde_api.apimodels.generate._submit import ( + ImageGenerationJobSubmitRequest, + ImageGenerationJobSubmitResponse, +) __all__ = [ "ImageGenerateAsyncRequest", @@ -20,4 +24,6 @@ "CancelImageGenerateRequest", "StatsImageModels", "StatsModelsResponse", + "ImageGenerationJobSubmitRequest", + "ImageGenerationJobSubmitResponse", ] diff --git a/src/horde_sdk/ai_horde_api/apimodels/_base.py b/src/horde_sdk/ai_horde_api/apimodels/_base.py index f4c0e85..61027dc 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/_base.py +++ b/src/horde_sdk/ai_horde_api/apimodels/_base.py @@ -1,12 +1,96 @@ +from pydantic import BaseModel, Field, field_validator +from typing_extensions import override + +from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING +from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL from horde_sdk.ai_horde_api.fields import GenerationID, WorkerID -from horde_sdk.generic_api.apimodels import BaseRequestAuthenticated +from horde_sdk.generic_api.apimodels import BaseRequest +from horde_sdk.utils import seed_to_int + + +class BaseAIHordeRequest(BaseRequest): + @override + @classmethod + def get_api_url(cls) -> str: + return AI_HORDE_BASE_URL -class BaseImageGenerateJobRequest(BaseRequestAuthenticated): +class BaseImageGenerateJobRequest(BaseModel): + """Mix-in class for data relating to image generation jobs.""" + id: str | GenerationID # noqa: A003 """The UUID for this job.""" -class BaseWorkerRequest(BaseRequestAuthenticated): +class BaseWorkerRequest(BaseModel): + """Mix-in class for data relating to worker requests.""" + worker_id: str | WorkerID """The UUID of the worker in question for this request.""" + + +class LorasPayloadEntry(BaseModel): + """Represents a single lora parameter. + + v2 API Model: `ModelPayloadLorasStable` + """ + + name: str = Field(min_length=1, max_length=255) + """The name of the LoRa model to use.""" + model: float = Field(default=1, ge=0, le=5) + """The strength of the LoRa against the stable diffusion model.""" + clip: float = Field(default=1, ge=0, le=5) + """The strength of the LoRa against the clip model.""" + inject_trigger: str | None = Field(default=None, min_length=1, max_length=30) + """Any trigger required to activate the LoRa model.""" + + +class BaseImageGenerateParam(BaseModel): + """Represents some of the data included in a request to the `/v2/generate/async` endpoint. + + Also is the corresponding information returned on a job pop to the `/v2/generate/pop` endpoint. + v2 API Model: `ModelPayloadStable` + """ + + sampler_name: KNOWN_SAMPLERS = KNOWN_SAMPLERS.k_lms + cfg_scale: float = 7.5 + denoising_strength: float | None = Field(default=1, ge=0, le=1) + seed: str | None = None + height: int = 512 + width: int = 512 + seed_variation: int | None = None + post_processing: list[str] = Field(default_factory=list) + karras: bool = True + tiling: bool = False + hires_fix: bool = False + clip_skip: int = 1 + control_type: str | None = None + image_is_control: bool | None = None + return_control_map: bool | None = None + facefixer_strength: float | None = Field(default=None, ge=0, le=1) + loras: list[LorasPayloadEntry] = Field(default_factory=list) + special: dict = Field(default_factory=dict) + steps: int = Field(default=25, ge=1) + + n_iter: int = Field(default=1, ge=1) + use_nsfw_censor: bool = False + + @field_validator("sampler_name") + def sampler_name_must_be_known(cls, v): + """Ensure that the sampler name is in this list of supported samplers.""" + if v not in KNOWN_SAMPLERS.__members__: + raise ValueError(f"Unknown sampler name {v}") + return v + + @field_validator("seed") + def seed_to_int_if_str(cls, v): + """Ensure that the seed is an integer. If it is a string, convert it to an integer.""" + return str(seed_to_int(v)) + + +class BaseImageGenerateImg2Img(BaseModel): + """Mix-in class for data relating to img2img generation.""" + + source_image: str | None = None + source_processing: KNOWN_SOURCE_PROCESSING = KNOWN_SOURCE_PROCESSING.txt2img + source_mask: str | None = None diff --git a/src/horde_sdk/ai_horde_api/apimodels/_shared.py b/src/horde_sdk/ai_horde_api/apimodels/_shared.py deleted file mode 100644 index 275116b..0000000 --- a/src/horde_sdk/ai_horde_api/apimodels/_shared.py +++ /dev/null @@ -1,82 +0,0 @@ -import pydantic -from pydantic import Field, field_validator -from typing_extensions import override - -from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING -from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL -from horde_sdk.generic_api.apimodels import BaseRequest, BaseRequestWorkerDriven -from horde_sdk.utils import seed_to_int - - -class BaseAIHordeRequest(BaseRequest): - @override - @classmethod - def get_api_url(cls) -> str: - return AI_HORDE_BASE_URL - - -class LorasPayloadEntry(pydantic.BaseModel): - """Represents a single lora parameter. - - v2 API Model: `ModelPayloadLorasStable` - """ - - name: str = Field(min_length=1, max_length=255) - """The name of the LoRa model to use.""" - model: float = Field(default=1, ge=0, le=5) - """The strength of the LoRa against the stable diffusion model.""" - clip: float = Field(default=1, ge=0, le=5) - """The strength of the LoRa against the clip model.""" - inject_trigger: str | None = Field(default=None, min_length=1, max_length=30) - """Any trigger required to activate the LoRa model.""" - - -class BaseImageGenerateParam(pydantic.BaseModel): - """Represents some of the data included in a request to the `/v2/generate/async` endpoint. - - Also is the corresponding information returned on a job pop to the `/v2/generate/pop` endpoint. - v2 API Model: `ModelPayloadStable` - """ - - sampler_name: KNOWN_SAMPLERS = KNOWN_SAMPLERS.k_lms - cfg_scale: float = 7.5 - denoising_strength: float | None = Field(default=1, ge=0, le=1) - seed: str | None = None - height: int = 512 - width: int = 512 - seed_variation: int | None = None - post_processing: list[str] = Field(default_factory=list) - karras: bool = True - tiling: bool = False - hires_fix: bool = False - clip_skip: int = 1 - control_type: str | None = None - image_is_control: bool | None = None - return_control_map: bool | None = None - facefixer_strength: float | None = Field(default=None, ge=0, le=1) - loras: list[LorasPayloadEntry] = Field(default_factory=list) - special: dict = Field(default_factory=dict) - steps: int = Field(default=25, ge=1) - - n_iter: int = Field(default=1, ge=1) - use_nsfw_censor: bool = False - - @field_validator("sampler_name") - def sampler_name_must_be_known(cls, v): - """Ensure that the sampler name is in this list of supported samplers.""" - if v not in KNOWN_SAMPLERS.__members__: - raise ValueError(f"Unknown sampler name {v}") - return v - - @field_validator("seed") - def seed_to_int_if_str(cls, v): - """Ensure that the seed is an integer. If it is a string, convert it to an integer.""" - return str(seed_to_int(v)) - - -class ImageGenerateImg2ImgData(BaseRequestWorkerDriven): - """Mix-in class for data relating to img2img generation.""" - - source_image: str | None = None - source_processing: KNOWN_SOURCE_PROCESSING = KNOWN_SOURCE_PROCESSING.txt2img - source_mask: str | None = None diff --git a/src/horde_sdk/ai_horde_api/apimodels/stats.py b/src/horde_sdk/ai_horde_api/apimodels/_stats.py similarity index 77% rename from src/horde_sdk/ai_horde_api/apimodels/stats.py rename to src/horde_sdk/ai_horde_api/apimodels/_stats.py index 6f58e26..b7da8f7 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/stats.py +++ b/src/horde_sdk/ai_horde_api/apimodels/_stats.py @@ -1,12 +1,17 @@ from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest +from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import BaseResponse class StatsModelsResponse(BaseResponse): + """Represents the data returned from the `/v2/stats/img/models` endpoint. + + v2 API Model: `ImgModelStats` + """ + model_config = {"frozen": True} day: dict[str, int] @@ -20,6 +25,8 @@ def get_api_model_name(cls) -> str | None: class StatsImageModels(BaseAIHordeRequest): + """Represents the data needed to make a request to the `/v2/stats/img/models` endpoint.""" + @override @classmethod def get_api_model_name(cls) -> str | None: diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py index bcfd764..06920dd 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py @@ -1,10 +1,10 @@ from pydantic import Field, model_validator from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._shared import ( +from horde_sdk.ai_horde_api.apimodels._base import ( BaseAIHordeRequest, + BaseImageGenerateImg2Img, BaseImageGenerateParam, - ImageGenerateImg2ImgData, ) from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.ai_horde_api.fields import GenerationID @@ -13,7 +13,10 @@ class ImageGenerateAsyncResponse(BaseResponse): - """Represents the data returned from the `/v2/generate/async` endpoint.""" + """Represents the data returned from the `/v2/generate/async` endpoint. + + v2 API Model: `RequestAsync` + """ id: str | GenerationID # noqa: A003 """The UUID for this image generation.""" @@ -30,7 +33,7 @@ class ImageGenerationInputPayload(BaseImageGenerateParam): n: int = Field(default=1, ge=1) -class ImageGenerateAsyncRequest(BaseAIHordeRequest, ImageGenerateImg2ImgData, BaseRequestWorkerDriven): +class ImageGenerateAsyncRequest(BaseAIHordeRequest, BaseImageGenerateImg2Img, BaseRequestWorkerDriven): """Represents the data needed to make a request to the `/v2/generate/async` endpoint. v2 API Model: `GenerationInputStable` diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_check.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_check.py index 3acae33..5be5755 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_check.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_check.py @@ -1,7 +1,6 @@ from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._base import BaseImageGenerateJobRequest -from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest +from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateJobRequest from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api.apimodels import BaseResponse diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py index 4dea489..b5ae4c0 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py @@ -2,7 +2,7 @@ from pydantic import Field, field_validator from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest, BaseImageGenerateParam +from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateParam from horde_sdk.ai_horde_api.consts import KNOWN_SOURCE_PROCESSING from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.ai_horde_api.fields import GenerationID diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_status.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_status.py index 1dcb6da..03be05b 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_status.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_status.py @@ -1,8 +1,7 @@ from pydantic import BaseModel, Field from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._base import BaseImageGenerateJobRequest -from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest +from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest, BaseImageGenerateJobRequest from horde_sdk.ai_horde_api.apimodels.generate._check import ImageGenerateCheckResponse from horde_sdk.ai_horde_api.consts import GENERATION_STATE from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py index 56facdf..a8d9952 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_submit.py @@ -1,6 +1,6 @@ from typing_extensions import override -from horde_sdk.ai_horde_api.apimodels._shared import BaseAIHordeRequest +from horde_sdk.ai_horde_api.apimodels._base import BaseAIHordeRequest from horde_sdk.ai_horde_api.endpoints import AI_HORDE_API_URL_Literals from horde_sdk.ai_horde_api.fields import GenerationID from horde_sdk.consts import HTTPMethod @@ -11,6 +11,11 @@ class ImageGenerationJobSubmitResponse(BaseResponse): reward: float """The amount of kudos gained for submitting this request.""" + @override + @classmethod + def get_api_model_name(cls) -> str | None: + return "GenerationSubmitted" + class ImageGenerationJobSubmitRequest(BaseAIHordeRequest, BaseRequestAuthenticated): """Represents the data needed to make a job submit 'request' from a worker to the /v2/generate/submit endpoint. diff --git a/src/horde_sdk/ai_horde_api/endpoints.py b/src/horde_sdk/ai_horde_api/endpoints.py index 5646866..8798a1c 100644 --- a/src/horde_sdk/ai_horde_api/endpoints.py +++ b/src/horde_sdk/ai_horde_api/endpoints.py @@ -2,6 +2,8 @@ from strenum import StrEnum +from horde_sdk.generic_api.endpoints import url_with_path + AI_HORDE_BASE_URL = "https://aihorde.net/api/" if os.environ.get("HORDE_URL", None): @@ -15,6 +17,8 @@ class AI_HORDE_API_URL_Literals(StrEnum): """The URL actions 'paths' to the endpoints. Includes the find/replace strings in brackets for path (non-query) variables.""" + swagger = "/swagger.json" + # Note that the leading slash is included for consistency with the swagger docs, # but it is dropped when the URL is actually constructed (see `url_with_path` in `horde_sdk.generic_api.endpoints`) v2_stats_img_models = "/v2/stats/img/models" @@ -65,3 +69,10 @@ class AI_HORDE_API_URL_Literals(StrEnum): v2_workers_all = "/v2/workers" v2_workers = "/v2/workers/{worker_id}" + + +def get_ai_horde_swagger_url() -> str: + return url_with_path( + base_url=AI_HORDE_BASE_URL, + path=AI_HORDE_API_URL_Literals.swagger, + ) diff --git a/src/horde_sdk/ai_horde_api/utils/swagger.py b/src/horde_sdk/ai_horde_api/utils/swagger.py deleted file mode 100644 index cc7b9ad..0000000 --- a/src/horde_sdk/ai_horde_api/utils/swagger.py +++ /dev/null @@ -1,261 +0,0 @@ -from __future__ import annotations - -import requests -from horde_sdk.ai_horde_api.endpoints import AI_HORDE_BASE_URL -from horde_sdk.consts import HTTPMethod -from pydantic import BaseModel, Field, model_validator - -SWAGGER_DOC_URL = f"{AI_HORDE_BASE_URL}/swagger.json" - - -class SwaggerModelDefinitionAdditionalProperty(BaseModel): - # TODO: Is this actually a recursive SwaggerModelDefinitionProperty? - model_config = {"extra": "forbid"} - - type_: str | None = Field(None, alias="type") - description: str | None = None - - -class SwaggerModelDefinitionProperty(BaseModel): - """A property of a model (data structure) used in the API. - - This might also be referred to as a "field" or an "attribute" of the model. - - See https://swagger.io/docs/specification/data-models/data-types/#objects - """ - - model_config = {"extra": "forbid"} - - type_: str | None = Field(None, alias="type") - description: str | None = None - title: str | None = None - default: object | None = None - example: object | None = None - format_: str | None = Field(None, alias="format") - enum: list[str] | None = None - ref: str | None = Field(None, alias="$ref") - minimum: float | None = None - maximum: float | None = None - minLength: int | None = None - maxLength: int | None = None - multipleOf: float | None = None - uniqueItems: bool | None = None - additionalProperties: SwaggerModelDefinitionAdditionalProperty | None = None - items: SwaggerModelDefinitionProperty | list[SwaggerModelDefinitionProperty] | None = None - - -class SwaggerModelDefinitionRef(BaseModel): - """A reference to a model (data structure) used in the API.""" - - model_config = {"extra": "forbid"} - - ref: str | None = Field(None, alias="$ref") - - -class SwaggerModelDefinition(BaseModel): - """A definition of a model (data structure) used in the API.""" - - model_config = {"extra": "forbid"} - - type_: str | None = Field(None, alias="type") - properties: dict[str, SwaggerModelDefinitionProperty] | None = None - required: list[str] | None = None - - -class SwaggerModelDefinitionSchemaValidationMethods(BaseModel): - """When used instead of a SwaggerModelDefinition, it means that the model is validated against one or more schemas. - - See the `allOf`, `oneOf`, and `anyOf` properties of the Swagger spec: - https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/ - """ - - model_config = {"extra": "forbid"} - - allOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None - """The model must match all of the schemas in this list.""" - oneOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None - """The model must match exactly one of the schemas in this list.""" - anyOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None - """The model must match at least one of the schemas in this list.""" - - @model_validator(mode="before") - def one_method_specified(cls, v): - """Ensure at least one of the validation methods is specified.""" - if not any([v.get("allOf"), v.get("oneOf"), v.get("anyOf")]): - raise ValueError("At least one of allOf, oneOf, or anyOf must be specified.") - - return v - - -class SwaggerDocTagsItem(BaseModel): - model_config = {"extra": "forbid"} - - name: str | None = None - description: str | None = None - - -class SwaggerDocResponseItem(BaseModel): - model_config = {"extra": "forbid"} - - description: str | None = None - - -class SwaggerDocInfo(BaseModel): - model_config = {"extra": "forbid"} - - title: str | None = None - version: str | None = None - description: str | None = None - - -class SwaggerEndpointMethodParameterSchemaItem(BaseModel): - model_config = {"extra": "forbid"} - ref: str | None = Field(None, alias="$ref") - - -class SwaggerEndpointMethodParameterSchemaProperty(BaseModel): - model_config = {"extra": "forbid"} - - type_: str | None = Field(None, alias="type") - format_: str | None = Field(None, alias="format") - - -class SwaggerEndpointMethodParameterSchema(BaseModel): - model_config = {"extra": "forbid"} - - type_: str | None = Field(None, alias="type") - properties: dict[str, SwaggerEndpointMethodParameterSchemaProperty] | None = None - - -class SwaggerEndpointMethodParameter(BaseModel): - model_config = {"extra": "forbid"} - - name: str | None = None - in_: str = Field("", alias="in") - description: str | None = None - required: bool | None = None - schema_: SwaggerEndpointMethodParameterSchema | SwaggerEndpointMethodParameterSchemaItem | None = Field( - None, alias="schema" - ) - default: object | None = None - type_: str | None = Field(None, alias="type") - format_: str | None = Field(None, alias="format") - - -class SwaggerEndpointResponseSchemaItem(BaseModel): - # model_config = {"extra": "forbid"} - - ref: str | None = Field(None, alias="$ref") - - -class SwaggerEndpointResponseSchema(BaseModel): - # model_config = {"extra": "forbid"} - - ref: str | None = Field(None, alias="$ref") - type_: str | None = Field(None, alias="type") - items: SwaggerEndpointResponseSchemaItem | None = None - - -class SwaggerEndpointResponse(BaseModel): - model_config = {"extra": "forbid"} - - description: str - schema_: SwaggerEndpointResponseSchema | None = Field(None, alias="schema") - - -class SwaggerEndpointMethod(BaseModel): - model_config = {"extra": "forbid"} - - summary: str | None = None - description: str | None = None - operation_id: str | None = Field(None, alias="operationId") - parameters: list[SwaggerEndpointMethodParameter] | None = None - responses: dict[str, SwaggerEndpointResponse] | None = None - tags: list[str] | None = None - - -class SwaggerEndpointParameter(BaseModel): - model_config = {"extra": "forbid"} - - name: str | None = None - in_: str = Field("", alias="in") - description: str | None = None - required: bool | None = None - type_: str | None = Field(None, alias="type") - - -class SwaggerEndpoint(BaseModel): - model_config = {"extra": "forbid"} - - parameters: list[SwaggerEndpointParameter] | None = None - - get: SwaggerEndpointMethod | None = None - post: SwaggerEndpointMethod | None = None - put: SwaggerEndpointMethod | None = None - delete: SwaggerEndpointMethod | None = None - patch: SwaggerEndpointMethod | None = None - options: SwaggerEndpointMethod | None = None - head: SwaggerEndpointMethod | None = None - - def get_endpoint_method_from_http_method(self, http_method: HTTPMethod | str) -> SwaggerEndpointMethod | None: - """Get the endpoint method for the given HTTP method.""" - if isinstance(http_method, str): - http_method = HTTPMethod(http_method) - return getattr(self, http_method.value.lower(), None) - - @model_validator(mode="before") - def at_least_one_method_specified(cls, v): - """Ensure at least one method is specified.""" - if not any( - [ - v.get("get"), - v.get("post"), - v.get("put"), - v.get("delete"), - v.get("patch"), - v.get("options"), - v.get("head"), - ] - ): - raise ValueError("At least one method must be specified.") - - return v - - -class SwaggerDoc(BaseModel): - model_config = {"extra": "forbid"} - - swagger: str - """The swagger version of the document""" - basePath: str - """The base path (after the top level domain) of the API. IE, `/api`.""" - paths: dict[str, SwaggerEndpoint] - """The endpoints of the API. IE, `/api/v2/generate/async`.""" - info: SwaggerDocInfo - """The info section of the document""" - produces: list[str] - """The content types that the API can produce in responses.""" - consumes: list[str] - """The content types that the API can consume in payloads.""" - tags: list[SwaggerDocTagsItem] | None = None - """Metadata about the document.""" - responses: dict[str, SwaggerDocResponseItem] | None = None - """Unknown""" - definitions: dict[str, SwaggerModelDefinition | SwaggerModelDefinitionSchemaValidationMethods] - """The definitions of the models (data structures) used in the API.""" - - -class SwaggerParser: - _swagger_json: dict - - def __init__(self) -> None: - # Try to get the swagger.json from the server - try: - response = requests.get(SWAGGER_DOC_URL) - response.raise_for_status() - self._swagger_json = response.json() - except requests.exceptions.HTTPError as e: - raise RuntimeError(f"Failed to get swagger.json from server: {e.response.text}") from e - - def get_swagger_doc(self) -> SwaggerDoc: - return SwaggerDoc.model_validate(self._swagger_json) diff --git a/src/horde_sdk/consts.py b/src/horde_sdk/consts.py index d955271..9ccf66f 100644 --- a/src/horde_sdk/consts.py +++ b/src/horde_sdk/consts.py @@ -20,6 +20,9 @@ class HTTPMethod(StrEnum): CONNECT = "CONNECT" +PAYLOAD_HTTP_METHODS = {HTTPMethod.POST, HTTPMethod.PUT, HTTPMethod.PATCH} + + class HTTPStatusCode(Enum): """An enum representing all HTTP status codes.""" diff --git a/src/horde_sdk/generic_api/__init__.py b/src/horde_sdk/generic_api/__init__.py index 600a14b..ba02948 100644 --- a/src/horde_sdk/generic_api/__init__.py +++ b/src/horde_sdk/generic_api/__init__.py @@ -1,7 +1,5 @@ """Tools for making or interacting with any horde APIs.""" -from horde_sdk.generic_api._error import RequestErrorResponse -from horde_sdk.generic_api.apimodels import BaseRequest, BaseRequestAuthenticated, BaseRequestUserSpecific -from horde_sdk.generic_api.generic_client import GenericHordeAPIClient +# isort:skip_file from horde_sdk.generic_api.metadata import ( GenericAcceptTypes, GenericHeaderFields, @@ -9,6 +7,12 @@ GenericQueryFields, ) + +from horde_sdk.generic_api._error import RequestErrorResponse +from horde_sdk.generic_api.apimodels import BaseRequest, BaseRequestAuthenticated, BaseRequestUserSpecific +from horde_sdk.generic_api.generic_client import GenericHordeAPIClient + + __all__ = [ "RequestErrorResponse", "BaseRequest", diff --git a/src/horde_sdk/generic_api/apimodels.py b/src/horde_sdk/generic_api/apimodels.py index 507d02a..eb3d13d 100644 --- a/src/horde_sdk/generic_api/apimodels.py +++ b/src/horde_sdk/generic_api/apimodels.py @@ -6,8 +6,8 @@ from pydantic import BaseModel, Field, field_validator from horde_sdk.consts import HTTPMethod, HTTPStatusCode +from horde_sdk.generic_api import GenericAcceptTypes from horde_sdk.generic_api.endpoints import url_with_path -from horde_sdk.generic_api.metadata import GenericAcceptTypes class HordeAPIMessage(BaseModel, abc.ABC): @@ -70,15 +70,15 @@ def get_expected_response_type(cls) -> type[BaseResponse]: """Return the `type` of the response expected.""" -class BaseRequestAuthenticated(BaseRequest): - """Represents abstractly a authenticated request, IE, using an API key.""" +class BaseRequestAuthenticated(BaseModel): + """Mix-in class to describe an endpoint which requires authentication.""" apikey: str # TODO validator """A horde API key.""" -class BaseRequestUserSpecific(BaseRequestAuthenticated): - """Represents the minimum for any request specifying a specific user to the API.""" +class BaseRequestUserSpecific(BaseModel): + """Mix-in class to describe an endpoint for which you can specify a user.""" "" user_id: str """The user's ID, as a `str`, but only containing numeric values.""" @@ -91,8 +91,8 @@ def user_id_is_numeric(cls, value: str) -> str: return value -class BaseRequestWorkerDriven(BaseRequestAuthenticated): - """Represents the minimum for any request which is ultimately backed by a worker (Such as on AI-Horde).""" +class BaseRequestWorkerDriven(BaseModel): + """ "Mix-in class to describe an endpoint for which you can specify workers.""" trusted_workers: bool = False slow_workers: bool = False diff --git a/src/horde_sdk/generic_api/generic_client.py b/src/horde_sdk/generic_api/generic_client.py index d7578bd..cc644a4 100644 --- a/src/horde_sdk/generic_api/generic_client.py +++ b/src/horde_sdk/generic_api/generic_client.py @@ -3,14 +3,14 @@ import requests from pydantic import BaseModel -from horde_sdk.generic_api._error import RequestErrorResponse -from horde_sdk.generic_api.apimodels import BaseRequest, BaseResponse -from horde_sdk.generic_api.metadata import ( +from horde_sdk.generic_api import ( GenericAcceptTypes, GenericHeaderFields, GenericPathFields, GenericQueryFields, ) +from horde_sdk.generic_api._error import RequestErrorResponse +from horde_sdk.generic_api.apimodels import BaseRequest, BaseResponse class _ParsedRequest(BaseModel): diff --git a/src/horde_sdk/generic_api/utils/swagger.py b/src/horde_sdk/generic_api/utils/swagger.py new file mode 100644 index 0000000..d6e8544 --- /dev/null +++ b/src/horde_sdk/generic_api/utils/swagger.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +import json +import re +from abc import ABC, abstractmethod +from pathlib import Path + +import requests +from horde_sdk.consts import PAYLOAD_HTTP_METHODS, HTTPMethod, HTTPStatusCode +from loguru import logger +from pydantic import BaseModel, Field, model_validator +from typing_extensions import override + + +class SwaggerModelDefinitionAdditionalProperty(BaseModel): + # TODO: Is this actually a recursive SwaggerModelDefinitionProperty? + model_config = {"extra": "forbid"} + + type_: str | None = Field(None, alias="type") + description: str | None = None + + +class SwaggerModelDefinitionProperty(BaseModel): + """A property of a model (data structure) used in the API. + + This might also be referred to as a "field" or an "attribute" of the model. + + See https://swagger.io/docs/specification/data-models/data-types/#objects + """ + + model_config = {"extra": "forbid"} + + type_: str | None = Field(None, alias="type") + description: str | None = None + title: str | None = None + default: object | None = None + example: object | None = None + format_: str | None = Field(None, alias="format") + enum: list[str] | None = None + ref: str | None = Field(None, alias="$ref") + minimum: float | None = None + maximum: float | None = None + minLength: int | None = None + maxLength: int | None = None + multipleOf: float | None = None + uniqueItems: bool | None = None + additionalProperties: SwaggerModelDefinitionAdditionalProperty | None = None + items: SwaggerModelDefinitionProperty | list[SwaggerModelDefinitionProperty] | None = None + + +class SwaggerModelDefinitionRef(BaseModel): + """A reference to a model (data structure) used in the API.""" + + model_config = {"extra": "forbid"} + + ref: str | None = Field(None, alias="$ref") + + +class SwaggerModelDefinitionEntry(BaseModel, ABC): + """An entry in the definitions section of the swagger doc. This could be a model definition, or a schema validation + method object. See `SwaggerModelDefinitionSchemaValidationMethods` for more info.""" + + @abstractmethod + def get_all_definitions(self) -> list[SwaggerModelDefinition | SwaggerModelDefinitionRef]: + """Get all definitions from all validation methods.""" + raise NotImplementedError + + +class SwaggerModelDefinition(SwaggerModelDefinitionEntry): + """A definition of a model (data structure) used in the API.""" + + model_config = {"extra": "forbid"} + + type_: str | None = Field(None, alias="type") + properties: dict[str, SwaggerModelDefinitionProperty] | None = None + required: list[str] | None = None + + @override + def get_all_definitions(self) -> list[SwaggerModelDefinition | SwaggerModelDefinitionRef]: + """Get all definitions from all validation methods.""" + return [self] # This looks odd, but `SwaggerModelDefinitionSchemaValidationMethods` is the wrinkle here. + + +class SwaggerModelDefinitionSchemaValidationMethods(SwaggerModelDefinitionEntry): + """When used instead of a SwaggerModelDefinition, it means that the model is validated against one or more schemas. + + See the `allOf`, `oneOf`, and `anyOf` properties of the Swagger spec: + https://swagger.io/docs/specification/data-models/oneof-anyof-allof-not/ + """ + + model_config = {"extra": "forbid"} + + allOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None + """The model must match all of the schemas in this list.""" + oneOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None + """The model must match exactly one of the schemas in this list.""" + anyOf: list[SwaggerModelDefinition | SwaggerModelDefinitionRef] | None = None + """The model must match at least one of the schemas in this list.""" + + @model_validator(mode="before") + def one_method_specified(cls, v): + """Ensure at least one of the validation methods is specified.""" + if not any([v.get("allOf"), v.get("oneOf"), v.get("anyOf")]): + raise ValueError("At least one of allOf, oneOf, or anyOf must be specified.") + + return v + + @override + def get_all_definitions(self) -> list[SwaggerModelDefinition | SwaggerModelDefinitionRef]: + """Get all definitions from all validation methods.""" + return_list = [] + if self.allOf: + return_list.extend(self.allOf) + if self.oneOf: + return_list.extend(self.oneOf) + if self.anyOf: + return_list.extend(self.anyOf) + + return return_list + + +class SwaggerDocTagsItem(BaseModel): + model_config = {"extra": "forbid"} + + name: str | None = None + description: str | None = None + + +class SwaggerDocResponseItem(BaseModel): + model_config = {"extra": "forbid"} + + description: str | None = None + + +class SwaggerDocInfo(BaseModel): + model_config = {"extra": "forbid"} + + title: str | None = None + version: str | None = None + description: str | None = None + + +class SwaggerEndpointMethodParameterSchemaRef(BaseModel): + model_config = {"extra": "forbid"} + ref: str | None = Field(alias="$ref") + + +class SwaggerEndpointMethodParameterSchemaProperty(BaseModel): + model_config = {"extra": "forbid"} + + type_: str | None = Field(None, alias="type") + format_: str | None = Field(None, alias="format") + + +class SwaggerEndpointMethodParameterSchema(BaseModel): + model_config = {"extra": "forbid"} + + type_: str | None = Field(None, alias="type") + properties: dict[str, SwaggerEndpointMethodParameterSchemaProperty] | None = None + + +class SwaggerEndpointMethodParameter(BaseModel): + model_config = {"extra": "forbid"} + + name: str | None = None + in_: str = Field("", alias="in") + description: str | None = None + required: bool | None = None + schema_: SwaggerEndpointMethodParameterSchema | SwaggerEndpointMethodParameterSchemaRef | None = Field( + None, alias="schema" + ) + default: object | None = None + type_: str | None = Field(None, alias="type") + format_: str | None = Field(None, alias="format") + + +class SwaggerEndpointResponseSchemaItem(BaseModel): + # model_config = {"extra": "forbid"} + + ref: str | None = Field(None, alias="$ref") + + +class SwaggerEndpointResponseSchema(BaseModel): + # model_config = {"extra": "forbid"} + + ref: str | None = Field(None, alias="$ref") + type_: str | None = Field(None, alias="type") + items: SwaggerEndpointResponseSchemaItem | None = None + + +class SwaggerEndpointResponse(BaseModel): + model_config = {"extra": "forbid"} + + description: str + schema_: SwaggerEndpointResponseSchema | None = Field(None, alias="schema") + + +class SwaggerEndpointMethod(BaseModel): + model_config = {"extra": "forbid"} + + summary: str | None = None + description: str | None = None + operation_id: str | None = Field(None, alias="operationId") + parameters: list[SwaggerEndpointMethodParameter] | None = None + responses: dict[str, SwaggerEndpointResponse] | None = None + tags: list[str] | None = None + + +class SwaggerEndpointParameter(BaseModel): + model_config = {"extra": "forbid"} + + name: str | None = None + in_: str = Field("", alias="in") + description: str | None = None + required: bool | None = None + type_: str | None = Field(None, alias="type") + + +class SwaggerEndpoint(BaseModel): + model_config = {"extra": "forbid"} + + parameters: list[SwaggerEndpointParameter] | None = None + + get: SwaggerEndpointMethod | None = None + post: SwaggerEndpointMethod | None = None + put: SwaggerEndpointMethod | None = None + delete: SwaggerEndpointMethod | None = None + patch: SwaggerEndpointMethod | None = None + options: SwaggerEndpointMethod | None = None + head: SwaggerEndpointMethod | None = None + + def get_endpoint_method_from_http_method(self, http_method: HTTPMethod | str) -> SwaggerEndpointMethod | None: + """Get the endpoint method for the given HTTP method.""" + if isinstance(http_method, str): + http_method = HTTPMethod(http_method) + return getattr(self, http_method.value.lower(), None) + + def get_defined_endpoints(self) -> dict[str, SwaggerEndpointMethod]: + """Get the endpoints that are specified in the swagger doc.""" + return_dict = {} + for http_method, endpoint_method in self.__dict__.items(): + if isinstance(endpoint_method, SwaggerEndpointMethod): + return_dict[http_method] = endpoint_method + return return_dict + + @model_validator(mode="before") + def at_least_one_method_specified(cls, v): + """Ensure at least one method is specified.""" + if not any( + [ + v.get("get"), + v.get("post"), + v.get("put"), + v.get("delete"), + v.get("patch"), + v.get("options"), + v.get("head"), + ] + ): + raise ValueError("At least one method must be specified.") + + return v + + +class SwaggerDoc(BaseModel): + model_config = {"extra": "forbid"} + + swagger: str + """The swagger version of the document""" + basePath: str + """The base path (after the top level domain) of the API. IE, `/api`.""" + paths: dict[str, SwaggerEndpoint] + """The endpoints of the API. IE, `/api/v2/generate/async`.""" + info: SwaggerDocInfo + """The info section of the document""" + produces: list[str] + """The content types that the API can produce in responses.""" + consumes: list[str] + """The content types that the API can consume in payloads.""" + tags: list[SwaggerDocTagsItem] | None = None + """Metadata about the document.""" + responses: dict[str, SwaggerDocResponseItem] | None = None + """Unknown""" + definitions: dict[str, SwaggerModelDefinition | SwaggerModelDefinitionSchemaValidationMethods] + """The definitions of the models (data structures) used in the API.""" + + def extract_all_response_examples(self) -> dict[str, dict[HTTPMethod, dict[HTTPStatusCode, object]]]: + """Extract all response examples from the swagger doc. + + This in the form of: + `dict[endpoint_path, dict[http_method, dict[http_status_code, example_response]]]` + """ + every_endpoint_example: dict[str, dict[HTTPMethod, dict[HTTPStatusCode, object]]] = {} + + # Iterate through each endpoint in the Swagger documentation + for endpoint_path, endpoint in self.paths.items(): + endpoint_examples: dict[HTTPMethod, dict[HTTPStatusCode, object]] = {} + + # Iterate through each HTTP method used by the endpoint + for http_method_name, endpoint_method_definition in endpoint.get_defined_endpoints().items(): + endpoint_method_examples: dict[HTTPStatusCode, object] = {} + + if not endpoint_method_definition.responses: + continue + + logger.debug(f"Found {endpoint_path} {http_method_name.upper()} with response") + + # Iterate through each defined response for the HTTP method + for http_status_code_str, response_definition in endpoint_method_definition.responses.items(): + http_status_code_object = HTTPStatusCode(int(http_status_code_str)) + + if not response_definition.schema_: + continue + + # If the response schema is a reference to an API model, resolve the reference and include + # its defaults too + if isinstance(response_definition.schema_, SwaggerEndpointResponseSchema): + logger.debug(f"Resolving {response_definition.schema_.ref}") + example_response = self._resolve_model_ref_defaults(response_definition.schema_.ref) + endpoint_method_examples.update( + {http_status_code_object: example_response}, + ) + # If there is an explicit, endpoint specific schema, use that too + elif isinstance(response_definition.schema_, SwaggerEndpointResponseSchemaItem): + if not response_definition.schema_.ref: + continue + logger.debug(f"Resolving {response_definition.schema_.ref}") + example_response = self._resolve_model_ref_defaults(response_definition.schema_.ref) + endpoint_method_examples.update( + {http_status_code_object: example_response}, + ) + + if endpoint_method_examples: + endpoint_examples.update( + {HTTPMethod(http_method_name.upper()): endpoint_method_examples}, + ) + + if endpoint_examples: + every_endpoint_example[endpoint_path] = endpoint_examples + + return every_endpoint_example + + def extract_all_payload_examples(self) -> dict[str, dict[HTTPMethod, dict[str, object]]]: + """Extract all examples from the swagger doc. + + This in the form of: + `dict[endpoint_path, dict[param_name, param_example_value]]]` + """ + every_endpoint_example: dict[str, dict[HTTPMethod, dict[str, object]]] = {} + for endpoint_path, endpoint in self.paths.items(): + endpoint_examples: dict[HTTPMethod, dict[str, object]] = {} + + for http_method_name, endpoint_method_definition in endpoint.get_defined_endpoints().items(): + if http_method_name.upper() not in PAYLOAD_HTTP_METHODS: + continue + + if not endpoint_method_definition.parameters: + continue + + logger.debug(f"Found {endpoint_path} {http_method_name.upper()} with payload") + _payload_definitions = [d for d in endpoint_method_definition.parameters if d.name == "payload"] + + if len(_payload_definitions) != 1: + raise RuntimeError( + f"Expected to find exactly one payload definition for {endpoint_path} {http_method_name}" + ) + + payload_definition = _payload_definitions[0] + + if not payload_definition.schema_: + raise RuntimeError( + f"Expected to find a schema for {endpoint_path} {http_method_name} payload definition" + ) + + if isinstance(payload_definition.schema_, SwaggerEndpointMethodParameterSchemaRef): + logger.debug(f"Resolving {payload_definition.schema_.ref}") + example_payload = self._resolve_model_ref_defaults(payload_definition.schema_.ref) + endpoint_examples.update( + {HTTPMethod(http_method_name.upper()): example_payload}, + ) + + elif isinstance(payload_definition.schema_, SwaggerEndpointMethodParameterSchema): + if not payload_definition.schema_.properties: + continue + if not payload_definition.required: + continue + endpoint_specific_schema = {} + for prop_name, prop in payload_definition.schema_.properties.items(): + if not prop.type_: + raise RuntimeError( + f"Expected to find a type for {endpoint_path} {http_method_name} payload definition" + ) + endpoint_specific_schema[prop_name] = default_swagger_value_from_type_name(prop.type_) + + endpoint_examples.update( + {HTTPMethod(http_method_name.upper()): endpoint_specific_schema}, + ) + + if endpoint_examples: + every_endpoint_example[endpoint_path] = endpoint_examples + + return every_endpoint_example + + @staticmethod + def filename_from_endpoint_path(endpoint_path: str, http_method: HTTPMethod) -> str: + """Get the filename for the given endpoint path.""" + endpoint_path = re.sub(r"\W+", "_", endpoint_path) + endpoint_path = endpoint_path + "_" + http_method.value.lower() + return re.sub(r"__+", "_", endpoint_path) + + def write_all_payload_examples_to_file(self, directory: str | Path) -> bool: + directory = Path(directory) + all_examples = self.extract_all_payload_examples() + for endpoint_path, endpoint_examples_info in all_examples.items(): + for http_method, example_payload in endpoint_examples_info.items(): + filename = self.filename_from_endpoint_path(endpoint_path, http_method) + filepath = directory / f"{filename}.json" + with open(filepath, "w") as f: + json.dump(example_payload, f, indent=4) + return True + + def _resolve_model_ref_defaults( + self, + ref: str | None, + ) -> dict[str, object]: + """For ref entries, recursively resolve the default values for all properties. + + Note that this will combine all properties from all definitions that are referenced. + This function is more useful when trying to determine the expected payload for a request + or the expected response for a particular endpoint, rather than the schema for a particular model. + """ + # example ref: "#/definitions/RequestError" + + if not ref: + return {} + + return_dict = {} + + if ref.startswith("#/definitions/"): + ref = ref[len("#/definitions/") :] + + if ref not in self.definitions: + raise RuntimeError(f"Failed to find definition for {ref}") + + found_def_parent = self.definitions[ref] + + logger.debug(f"Found definition for {ref}") + + all_defs = found_def_parent.get_all_definitions() + + if not all_defs: + raise RuntimeError(f"Failed to find any definitions for {ref}") + for definition in all_defs: + if not isinstance(definition, SwaggerModelDefinitionRef): + continue + + logger.debug(f"Recursing ref: {definition.ref}") + return_dict.update(self._resolve_model_ref_defaults(definition.ref)) + + if len(all_defs) > 1: + logger.debug(f"Found {len(all_defs)} definitions for {ref}") + + for definition in all_defs: + if not isinstance(definition, SwaggerModelDefinition): + if definition.ref: + continue + raise RuntimeError(f"Unexpected definition type: {type(definition)}") + + if not definition.properties: + continue + + for prop_name, prop in definition.properties.items(): + if prop_name == "models": + pass + if prop.ref: + continue + if prop.example is not None: + return_dict[prop_name] = prop.example + continue + if prop.default is not None: + return_dict[prop_name] = prop.default + continue + + if not prop.type_: + raise RuntimeError(f"Failed to find type for {prop_name} in {ref}") + return_dict[prop_name] = default_swagger_value_from_type_name(prop.type_) + + return return_dict + + +class SwaggerParser: + _swagger_json: dict + + def __init__(self, swagger_doc_url: str) -> None: + # Try to get the swagger.json from the server + try: + response = requests.get(swagger_doc_url) + response.raise_for_status() + self._swagger_json = response.json() + except requests.exceptions.HTTPError as e: + raise RuntimeError(f"Failed to get swagger.json from server: {e.response.text}") from e + + def get_swagger_doc(self) -> SwaggerDoc: + return SwaggerDoc.model_validate(self._swagger_json) + + def get_all_examples(self) -> dict[str, dict[str, object]]: + return {} + + +_SWAGGER_TYPE_TO_PYTHON_TYPE = { + "integer": int, + "number": float, + "string": str, + "boolean": bool, + "array": list, + "object": dict, +} + + +def resolve_swagger_type_name(type_name: str) -> type: + """Resolve a swagger type name to a python type.""" + return _SWAGGER_TYPE_TO_PYTHON_TYPE[type_name] + + +def default_swagger_value_from_type_name(type_name: str) -> object: + return resolve_swagger_type_name(type_name)() diff --git a/src/horde_sdk/models.py b/src/horde_sdk/models.py index 0b15658..103d5c5 100644 --- a/src/horde_sdk/models.py +++ b/src/horde_sdk/models.py @@ -1,16 +1,22 @@ import json -models_json = None -with open("src/horde_sdk/models.json") as f: - models = f.read() - models_json = json.loads(models) -# sort by value -month_models_json = models_json["month"] -month_models_json = dict(sorted(month_models_json.items(), key=lambda item: item[1], reverse=True)) +def main(): + models_json = None + with open("src/horde_sdk/models.json") as f: + models = f.read() + models_json = json.loads(models) -# print top 5 models -for i, (model_name, model) in enumerate(month_models_json.items()): - if i == 10: - break - print(f"{model_name}: {model}") + # sort by value + month_models_json = models_json["month"] + month_models_json = dict(sorted(month_models_json.items(), key=lambda item: item[1], reverse=True)) + + # print top 5 models + for i, (model_name, model) in enumerate(month_models_json.items()): + if i == 10: + break + print(f"{model_name}: {model}") + + +if __name__ == "__main__": + main() diff --git a/src/horde_sdk/ratings_api/apimodels.py b/src/horde_sdk/ratings_api/apimodels.py index d5e58e4..2e49b9f 100644 --- a/src/horde_sdk/ratings_api/apimodels.py +++ b/src/horde_sdk/ratings_api/apimodels.py @@ -136,7 +136,7 @@ def get_api_model_name(cls) -> str | None: # region Requests -class BaseRequestImageSpecific(BaseRatingsAPIRequest, BaseRequestAuthenticated): +class BaseRequestImageSpecific(BaseModel): """Represents the minimum for any request specifying a specific user to the API.""" image_id: uuid.UUID @@ -227,7 +227,7 @@ def get_expected_response_type() -> type[UserValidateResponse]: return UserValidateResponse -class UserCheckRequest(BaseRatingsAPIRequest, BaseRequestUserSpecific): +class UserCheckRequest(BaseRatingsAPIRequest, BaseRequestUserSpecific, BaseRequestAuthenticated): """Represents the data needed to make a request to the `/v1/user/check/` endpoint.""" minutes: int = Field(ge=1) diff --git a/src/horde_sdk/scripts/write_all_payload_examples.py b/src/horde_sdk/scripts/write_all_payload_examples.py new file mode 100644 index 0000000..b2e6c29 --- /dev/null +++ b/src/horde_sdk/scripts/write_all_payload_examples.py @@ -0,0 +1,17 @@ +from pathlib import Path + +from horde_sdk.ai_horde_api.endpoints import get_ai_horde_swagger_url +from horde_sdk.generic_api.utils.swagger import SwaggerParser + + +def main(*, test_data_path: Path | None = None): + ai_horde_swagger_doc = SwaggerParser(get_ai_horde_swagger_url()).get_swagger_doc() + if not test_data_path: + test_data_path = ( + Path(__file__).parent.parent.parent.parent / "tests" / "test_data" / "ai_horde_api" / "example_payloads" + ) + ai_horde_swagger_doc.write_all_payload_examples_to_file(test_data_path) + + +if __name__ == "__main__": + main() diff --git a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py index 258bca3..a4d6581 100644 --- a/tests/ai_horde_api/test_dynamically_validate_against_swagger.py +++ b/tests/ai_horde_api/test_dynamically_validate_against_swagger.py @@ -1,7 +1,11 @@ import horde_sdk.ai_horde_api as ai_horde_api -from horde_sdk.ai_horde_api.utils.swagger import SwaggerEndpoint, SwaggerParser +from horde_sdk.ai_horde_api.endpoints import get_ai_horde_swagger_url from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api._reflection import get_all_request_types +from horde_sdk.generic_api.utils.swagger import ( + SwaggerEndpoint, + SwaggerParser, +) def test_all_ai_horde_model_defs_in_swagger() -> None: @@ -15,12 +19,15 @@ def test_all_ai_horde_model_defs_in_swagger() -> None: # Retrieve the swagger doc swagger_doc = None try: - swagger_doc = SwaggerParser().get_swagger_doc() + swagger_doc = SwaggerParser(get_ai_horde_swagger_url()).get_swagger_doc() except RuntimeError as e: raise RuntimeError(f"Failed to get swagger doc: {e}") from e assert swagger_doc, "Failed to get SwaggerDoc" - all_swagger_defined_models = swagger_doc.definitions.keys() + swagger_defined_models = swagger_doc.definitions.keys() + swagger_defined_examples: dict[ + str, dict[HTTPMethod, dict[str, object]] + ] = swagger_doc.extract_all_payload_examples() for request_type in all_request_types: endpoint_subpath = request_type.get_endpoint_subpath() @@ -49,7 +56,9 @@ def test_all_ai_horde_model_defs_in_swagger() -> None: # a payload else: assert ( - request_type.get_api_model_name() in all_swagger_defined_models + request_type.get_api_model_name() in swagger_defined_models ), f"Model is defined in horde_sdk, but not in swagger: {request_type.get_api_model_name()}" assert endpoint_subpath in swagger_doc.paths, f"Missing {request_type.__name__} in swagger" + + assert endpoint_subpath in swagger_defined_examples, f"Missing {request_type.__name__} in swagger examples" diff --git a/tests/ai_horde_api/test_swagger.py b/tests/ai_horde_api/test_swagger.py index 67b222b..64356dc 100644 --- a/tests/ai_horde_api/test_swagger.py +++ b/tests/ai_horde_api/test_swagger.py @@ -1,12 +1,13 @@ -from horde_sdk.ai_horde_api.utils.swagger import SwaggerDoc, SwaggerParser +from horde_sdk.ai_horde_api.endpoints import get_ai_horde_swagger_url +from horde_sdk.generic_api.utils.swagger import SwaggerDoc, SwaggerParser def test_swagger_parser_init(): - SwaggerParser() + SwaggerParser(get_ai_horde_swagger_url()) def test_get_swagger_doc(): - parser = SwaggerParser() + parser = SwaggerParser(get_ai_horde_swagger_url()) doc = parser.get_swagger_doc() assert isinstance(doc, SwaggerDoc) assert doc.swagger == "2.0" @@ -28,3 +29,17 @@ def test_get_swagger_doc(): assert doc.responses assert len(doc.responses) > 0 + + +def test_extract_all_payload_examples() -> None: + swagger_doc = SwaggerParser(get_ai_horde_swagger_url()).get_swagger_doc() + + all_request_examples = swagger_doc.extract_all_payload_examples() + assert len(all_request_examples) > 0, "Failed to extract any examples from the swagger doc" + + +def test_extract_all_response_examples() -> None: + swagger_doc = SwaggerParser(get_ai_horde_swagger_url()).get_swagger_doc() + + all_response_examples = swagger_doc.extract_all_response_examples() + assert len(all_response_examples) > 0, "Failed to extract any examples from the swagger doc" diff --git a/tests/ratings_api/test_init.py b/tests/ratings_api/test_init.py new file mode 100644 index 0000000..cf91dad --- /dev/null +++ b/tests/ratings_api/test_init.py @@ -0,0 +1,6 @@ +def test_init(): + pass + + +def test_init_meta(): + pass diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_filters_filter_id_patch.json b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_filter_id_patch.json new file mode 100644 index 0000000..cfe2b90 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_filter_id_patch.json @@ -0,0 +1,6 @@ +{ + "regex": "ac.*", + "filter_type": 10, + "description": "", + "replacement": "" +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_filters_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_post.json new file mode 100644 index 0000000..b446979 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_post.json @@ -0,0 +1,4 @@ +{ + "prompt": "", + "filter_type": 0 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_filters_put.json b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_put.json new file mode 100644 index 0000000..cfe2b90 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_filters_put.json @@ -0,0 +1,6 @@ +{ + "regex": "ac.*", + "filter_type": 10, + "description": "", + "replacement": "" +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json new file mode 100644 index 0000000..57ffabd --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_async_post.json @@ -0,0 +1,17 @@ +{ + "prompt": "", + "nsfw": false, + "trusted_workers": false, + "slow_workers": true, + "censor_nsfw": false, + "workers": [], + "worker_blacklist": false, + "models": [], + "source_image": "", + "source_processing": "img2img", + "source_mask": "", + "r2": true, + "shared": false, + "replacement_filter": true, + "dry_run": false +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_pop_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_pop_post.json new file mode 100644 index 0000000..8dd0639 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_pop_post.json @@ -0,0 +1,18 @@ +{ + "name": "", + "priority_usernames": [], + "nsfw": false, + "models": [], + "bridge_version": 1, + "bridge_agent": "AI Horde Worker:11:https://github.com/db0/AI-Horde-Worker", + "threads": 1, + "require_upfront_kudos": false, + "max_pixels": 262144, + "blacklist": [], + "allow_img2img": true, + "allow_painting": true, + "allow_unsafe_ipaddr": true, + "allow_post_processing": true, + "allow_controlnet": true, + "allow_lora": true +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_rate_id_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_rate_id_post.json new file mode 100644 index 0000000..88959f8 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_rate_id_post.json @@ -0,0 +1,4 @@ +{ + "best": "6038971e-f0b0-4fdd-a3bb-148f561f815e", + "ratings": [] +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_submit_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_submit_post.json new file mode 100644 index 0000000..02f4b2d --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_submit_post.json @@ -0,0 +1,7 @@ +{ + "id": "00000000-0000-0000-0000-000000000000", + "generation": "R2", + "state": "ok", + "seed": 0, + "censored": false +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_async_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_async_post.json new file mode 100644 index 0000000..c0709e2 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_async_post.json @@ -0,0 +1,10 @@ +{ + "prompt": "", + "softprompt": "", + "trusted_workers": false, + "slow_workers": true, + "workers": [], + "worker_blacklist": false, + "models": [], + "dry_run": false +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_pop_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_pop_post.json new file mode 100644 index 0000000..e714919 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_pop_post.json @@ -0,0 +1,13 @@ +{ + "name": "", + "priority_usernames": [], + "nsfw": false, + "models": [], + "bridge_version": 1, + "bridge_agent": "AI Horde Worker:11:https://github.com/db0/AI-Horde-Worker", + "threads": 1, + "require_upfront_kudos": false, + "max_length": 512, + "max_context_length": 2048, + "softprompts": [] +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_submit_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_submit_post.json new file mode 100644 index 0000000..3450bf4 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_generate_text_submit_post.json @@ -0,0 +1,5 @@ +{ + "id": "00000000-0000-0000-0000-000000000000", + "generation": "R2", + "state": "ok" +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_async_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_async_post.json new file mode 100644 index 0000000..d6bf6d9 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_async_post.json @@ -0,0 +1,5 @@ +{ + "forms": [], + "source_image": "", + "slow_workers": true +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_pop_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_pop_post.json new file mode 100644 index 0000000..ff6880b --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_pop_post.json @@ -0,0 +1,10 @@ +{ + "name": "", + "priority_usernames": [], + "forms": [], + "amount": 1, + "bridge_version": 1, + "bridge_agent": "AI Horde Worker:11:https://github.com/db0/AI-Horde-Worker", + "threads": 1, + "max_tiles": 16 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_submit_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_submit_post.json new file mode 100644 index 0000000..d35d0a8 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_interrogate_submit_post.json @@ -0,0 +1,5 @@ +{ + "id": "", + "result": "", + "state": "" +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_award_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_award_post.json new file mode 100644 index 0000000..5d6abf7 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_award_post.json @@ -0,0 +1,4 @@ +{ + "username": "", + "amount": 0 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_kai_user_id_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_kai_user_id_post.json new file mode 100644 index 0000000..1f17a1f --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_kai_user_id_post.json @@ -0,0 +1,5 @@ +{ + "kai_id": 0, + "kudos_amount": 0, + "trusted": false +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_transfer_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_transfer_post.json new file mode 100644 index 0000000..5d6abf7 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_kudos_transfer_post.json @@ -0,0 +1,4 @@ +{ + "username": "", + "amount": 0 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_put.json b/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_put.json new file mode 100644 index 0000000..c19c61e --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_put.json @@ -0,0 +1,8 @@ +{ + "kudos": 5000, + "expiry": 30, + "name": "Mutual Aid", + "max_image_pixels": -1, + "max_image_steps": -1, + "max_text_tokens": -1 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_sharedkey_id_patch.json b/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_sharedkey_id_patch.json new file mode 100644 index 0000000..c19c61e --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_sharedkeys_sharedkey_id_patch.json @@ -0,0 +1,8 @@ +{ + "kudos": 5000, + "expiry": 30, + "name": "Mutual Aid", + "max_image_pixels": -1, + "max_image_steps": -1, + "max_text_tokens": -1 +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_status_modes_put.json b/tests/test_data/ai_horde_api/example_payloads/_v2_status_modes_put.json new file mode 100644 index 0000000..1016d45 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_status_modes_put.json @@ -0,0 +1,5 @@ +{ + "maintenance": false, + "invite_only": false, + "raid": false +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_teams_post.json b/tests/test_data/ai_horde_api/example_payloads/_v2_teams_post.json new file mode 100644 index 0000000..db68aa5 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_teams_post.json @@ -0,0 +1,4 @@ +{ + "name": "", + "info": "Anarchy is emergent order." +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_teams_team_id_patch.json b/tests/test_data/ai_horde_api/example_payloads/_v2_teams_team_id_patch.json new file mode 100644 index 0000000..db68aa5 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_teams_team_id_patch.json @@ -0,0 +1,4 @@ +{ + "name": "", + "info": "Anarchy is emergent order." +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_users_user_id_put.json b/tests/test_data/ai_horde_api/example_payloads/_v2_users_user_id_put.json new file mode 100644 index 0000000..6ae894b --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_users_user_id_put.json @@ -0,0 +1,17 @@ +{ + "kudos": 0.0, + "concurrency": 0, + "usage_multiplier": 0.0, + "worker_invited": 0, + "moderator": false, + "public_workers": false, + "monthly_kudos": 0, + "username": "", + "trusted": false, + "flagged": false, + "customizer": false, + "vpn": false, + "special": false, + "reset_suspicion": false, + "contact": "email@example.com" +} diff --git a/tests/test_data/ai_horde_api/example_payloads/_v2_workers_worker_id_put.json b/tests/test_data/ai_horde_api/example_payloads/_v2_workers_worker_id_put.json new file mode 100644 index 0000000..805ca2b --- /dev/null +++ b/tests/test_data/ai_horde_api/example_payloads/_v2_workers_worker_id_put.json @@ -0,0 +1,8 @@ +{ + "maintenance": false, + "maintenance_msg": "", + "paused": false, + "info": "", + "name": "", + "team": "0bed257b-e57c-4327-ac64-40cdfb1ac5e6" +} diff --git a/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitRequest.json b/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitRequest.json new file mode 100644 index 0000000..02f4b2d --- /dev/null +++ b/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitRequest.json @@ -0,0 +1,7 @@ +{ + "id": "00000000-0000-0000-0000-000000000000", + "generation": "R2", + "state": "ok", + "seed": 0, + "censored": false +} diff --git a/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitResponse.json b/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitResponse.json new file mode 100644 index 0000000..df1be44 --- /dev/null +++ b/tests/test_data/ai_horde_api/example_responses/ImageGenerationJobSubmitResponse.json @@ -0,0 +1,3 @@ +{ + "reward": 10 +} diff --git a/tests/test_data/swagger.json b/tests/test_data/ai_horde_api/swagger.json similarity index 100% rename from tests/test_data/swagger.json rename to tests/test_data/ai_horde_api/swagger.json diff --git a/tests/test_dynamically_check_apimodels.py b/tests/test_dynamically_check_apimodels.py index cf49a7b..76f7ac6 100644 --- a/tests/test_dynamically_check_apimodels.py +++ b/tests/test_dynamically_check_apimodels.py @@ -6,8 +6,15 @@ import horde_sdk.ai_horde_api as ai_horde_api import horde_sdk.generic_api as generic_api import horde_sdk.ratings_api as ratings_api +from horde_sdk.consts import HTTPMethod from horde_sdk.generic_api._reflection import get_all_request_types from horde_sdk.generic_api.apimodels import BaseResponse +from horde_sdk.generic_api.utils.swagger import SwaggerDoc + +EXAMPLE_PAYLOADS: dict[ModuleType, Path] = { + ai_horde_api: Path("tests/test_data/ai_horde_api/example_payloads"), + ratings_api: Path("tests/test_data/ratings_api/example_payloads"), +} EXAMPLE_RESPONSES: dict[ModuleType, Path] = { ai_horde_api: Path("tests/test_data/ai_horde_api/example_responses"), @@ -35,7 +42,7 @@ def test_get_all_request_types(self) -> None: # noqa: D102 ), f"Response type is not a subclass of `BaseResponse`: {request_type}" @staticmethod - def dynamic_json_load(module_name: str, sample_data_folder: str | Path) -> None: + def dynamic_json_load(module: ModuleType) -> None: """Attempts to create instances of all non-abstract children of `RequestBase`.""" # This test does a lot of heavy lifting. If you're looking to make additions/changes. # This is probably the first test that will fail if you break something. @@ -43,6 +50,10 @@ def dynamic_json_load(module_name: str, sample_data_folder: str | Path) -> None: # If you're here because it failed and you're not sure why, # check the implementation of `BaseRequestUserSpecific` and `UserRatingsRequest` + module_name = module.__name__ + example_payload_folder = EXAMPLE_PAYLOADS[module] + example_response_folder = EXAMPLE_RESPONSES[module] + all_request_types: list[type[generic_api.BaseRequest]] = get_all_request_types(module_name) for request_type in all_request_types: @@ -56,15 +67,24 @@ def dynamic_json_load(module_name: str, sample_data_folder: str | Path) -> None: response_type, BaseResponse ), f"Response type is not a subclass of `BaseResponse`: {response_type}" - target_file = f"{sample_data_folder}/{response_type.__name__}.json" - assert os.path.exists(target_file), f"Missing sample data file: {target_file}" + if request_type.get_http_method() not in [HTTPMethod.GET, HTTPMethod.DELETE]: + example_payload_filename = SwaggerDoc.filename_from_endpoint_path( + request_type.get_endpoint_subpath(), + request_type.get_http_method(), + ) + + target_payload_file_path = f"{example_payload_folder}/{example_payload_filename}.json" + assert os.path.exists( + target_payload_file_path + ), f"Missing example payload file: {target_payload_file_path}" - with open(target_file) as sample_file_handle: + target_response_file_path = f"{example_response_folder}/{response_type.__name__}.json" + with open(target_response_file_path) as sample_file_handle: sample_data_json = json.loads(sample_file_handle.read()) response_type(**sample_data_json) def test_horde_api(self) -> None: - self.dynamic_json_load(ai_horde_api.__name__, EXAMPLE_RESPONSES[ai_horde_api]) + self.dynamic_json_load(ai_horde_api) def test_ratings_api(self) -> None: - self.dynamic_json_load(ratings_api.__name__, EXAMPLE_RESPONSES[ratings_api]) + self.dynamic_json_load(ratings_api) diff --git a/tox.ini b/tox.ini index b5e57f1..68d4207 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ env_list = [coverage:paths] source = - horde_model_reference/ + horde_sdk/ ignore_errors = True skip_empty = True