diff --git a/codegen/codegen_regex_fixes.py b/codegen/codegen_regex_fixes.py index 5a78cb4..0c44e2b 100644 --- a/codegen/codegen_regex_fixes.py +++ b/codegen/codegen_regex_fixes.py @@ -4,7 +4,7 @@ import sys -def main(path): +def main(path: str) -> None: print(f"Processing {path}") with open(path) as f: contents = f.read() diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..82aa7eb --- /dev/null +++ b/mypy.ini @@ -0,0 +1,2 @@ +[mypy] +disallow_untyped_defs = True diff --git a/src/horde_sdk/ai_horde_api/apimodels/_base.py b/src/horde_sdk/ai_horde_api/apimodels/_base.py index 61027dc..8a0c8ef 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/_base.py +++ b/src/horde_sdk/ai_horde_api/apimodels/_base.py @@ -76,14 +76,14 @@ class BaseImageGenerateParam(BaseModel): use_nsfw_censor: bool = False @field_validator("sampler_name") - def sampler_name_must_be_known(cls, v): + def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS: """Ensure that the sampler name is in this list of supported samplers.""" if v not in KNOWN_SAMPLERS.__members__: raise ValueError(f"Unknown sampler name {v}") return v @field_validator("seed") - def seed_to_int_if_str(cls, v): + def seed_to_int_if_str(cls, v: str | int) -> str | int: """Ensure that the seed is an integer. If it is a string, convert it to an integer.""" return str(seed_to_int(v)) diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py index bc73a2a..e79541a 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_async.py @@ -57,7 +57,7 @@ class ImageGenerateAsyncRequest( replacement_filter: bool = True @model_validator(mode="before") - def validate_censor_nsfw(cls, values): + def validate_censor_nsfw(cls, values: dict) -> dict: if values.get("censor_nsfw", None) and values.get("nsfw", None): raise ValueError("censor_nsfw is only valid when nsfw is False") return values diff --git a/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py index a3ff4a5..5bbbe9c 100644 --- a/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py +++ b/src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py @@ -79,7 +79,7 @@ class ImageGenerateJobResponse(BaseResponse): """The r2 upload link to use to upload this image.""" @field_validator("source_processing") - def source_processing_must_be_known(cls, v): + def source_processing_must_be_known(cls, v: str | KNOWN_SOURCE_PROCESSING) -> str | KNOWN_SOURCE_PROCESSING: """Ensure that the source processing is in this list of supported source processing.""" if v not in KNOWN_SOURCE_PROCESSING.__members__: raise ValueError(f"Unknown source processing {v}") diff --git a/src/horde_sdk/ai_horde_api/fields.py b/src/horde_sdk/ai_horde_api/fields.py index 40324b8..105c159 100644 --- a/src/horde_sdk/ai_horde_api/fields.py +++ b/src/horde_sdk/ai_horde_api/fields.py @@ -17,8 +17,11 @@ def id_as_uuid(self) -> uuid.UUID: return uuid.UUID(str(self.id), version=4) @field_validator("id") - def id_must_be_uuid(cls, v): + def id_must_be_uuid(cls, v: str | uuid.UUID) -> str | uuid.UUID: """Ensure that the ID is a valid UUID.""" + if isinstance(v, uuid.UUID): + return v + try: uuid.UUID(v, version=4) except ValueError as e: diff --git a/src/horde_sdk/generic_api/apimodels.py b/src/horde_sdk/generic_api/apimodels.py index c67e20e..834df93 100644 --- a/src/horde_sdk/generic_api/apimodels.py +++ b/src/horde_sdk/generic_api/apimodels.py @@ -21,7 +21,7 @@ def get_api_model_name(cls) -> str | None: such as for a GET request. """ - def to_json_horde_sdk_safe(self): + def to_json_horde_sdk_safe(self) -> str: """Return the model as a JSON string, taking into account the paradigms of the horde_sdk. If you use the default json dumping behavior, you will find some rough edges, such as alias @@ -85,7 +85,7 @@ def from_dict_or_array(cls, dict_or_array: dict | list) -> Self: return cls(**dict_or_array) @override - def to_json_horde_sdk_safe(self): + def to_json_horde_sdk_safe(self) -> str: # TODO: Is there a more pydantic way to do this? if self.is_array_response(): self_array = self.get_array() diff --git a/src/horde_sdk/generic_api/utils/swagger.py b/src/horde_sdk/generic_api/utils/swagger.py index 289ed92..28ec44d 100644 --- a/src/horde_sdk/generic_api/utils/swagger.py +++ b/src/horde_sdk/generic_api/utils/swagger.py @@ -99,7 +99,7 @@ class SwaggerModelDefinitionSchemaValidation(SwaggerModelEntry): """The model must match at least one of the schemas in this list.""" @model_validator(mode="before") - def one_method_specified(cls, v): + def one_method_specified(cls, v: dict) -> dict: """Ensure at least one of the validation methods is specified.""" if not any([v.get("allOf"), v.get("oneOf"), v.get("anyOf")]): raise ValueError("At least one of allOf, oneOf, or anyOf must be specified.") @@ -257,7 +257,7 @@ def get_defined_endpoints(self) -> dict[str, SwaggerEndpointMethod]: return return_dict @model_validator(mode="before") - def at_least_one_method_specified(cls, v): + def at_least_one_method_specified(cls, v: dict) -> dict: """Ensure at least one method is specified.""" if not any( [ diff --git a/src/horde_sdk/scripts/write_all_payload_examples.py b/src/horde_sdk/scripts/write_all_payload_examples.py index ac5777a..9586002 100644 --- a/src/horde_sdk/scripts/write_all_payload_examples.py +++ b/src/horde_sdk/scripts/write_all_payload_examples.py @@ -4,7 +4,7 @@ from horde_sdk.generic_api.utils.swagger import SwaggerParser -def main(*, test_data_path: Path | None = None): +def main(*, test_data_path: Path | None = None) -> None: ai_horde_swagger_doc = SwaggerParser( swagger_doc_url=get_ai_horde_swagger_url(), ).get_swagger_doc() diff --git a/src/horde_sdk/scripts/write_all_response_examples.py b/src/horde_sdk/scripts/write_all_response_examples.py index 9c3c970..1728b2f 100644 --- a/src/horde_sdk/scripts/write_all_response_examples.py +++ b/src/horde_sdk/scripts/write_all_response_examples.py @@ -4,7 +4,7 @@ from horde_sdk.generic_api.utils.swagger import SwaggerParser -def main(*, test_data_path: Path | None = None): +def main(*, test_data_path: Path | None = None) -> None: ai_horde_swagger_doc = SwaggerParser( swagger_doc_url=get_ai_horde_swagger_url(), ).get_swagger_doc() diff --git a/src/horde_sdk/utils.py b/src/horde_sdk/utils.py index fb83c2c..0d06ffd 100644 --- a/src/horde_sdk/utils.py +++ b/src/horde_sdk/utils.py @@ -1,8 +1,8 @@ import random -def seed_to_int(self, s=None): - if type(s) is int: +def seed_to_int(s: int | str | None = None) -> int: + if isinstance(s, int): return s if s is None or s == "": # return a random int diff --git a/tests/ai_horde_api/test_ai_horde_api_models.py b/tests/ai_horde_api/test_ai_horde_api_models.py index cca9008..8cdb020 100644 --- a/tests/ai_horde_api/test_ai_horde_api_models.py +++ b/tests/ai_horde_api/test_ai_horde_api_models.py @@ -6,13 +6,13 @@ from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING -def test_api_endpoint(): +def test_api_endpoint() -> None: ImageGenerateAsyncRequest.get_api_url() ImageGenerateAsyncRequest.get_endpoint_subpath() ImageGenerateAsyncRequest.get_endpoint_url() -def test_ImageGenerateAsyncRequest(): +def test_ImageGenerateAsyncRequest() -> None: test_async_request = ImageGenerateAsyncRequest( apikey="000000000", models=["Deliberate"], diff --git a/tests/ai_horde_api/test_api_calls.py b/tests/ai_horde_api/test_api_calls.py index 664c4c6..efc5054 100644 --- a/tests/ai_horde_api/test_api_calls.py +++ b/tests/ai_horde_api/test_api_calls.py @@ -26,10 +26,10 @@ def default_image_gen_request(self) -> ImageGenerateAsyncRequest: models=["Deliberate"], ) - def test_AIHordeAPIClient_init(self): + def test_AIHordeAPIClient_init(self) -> None: AIHordeAPIClient() - def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncRequest): + def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncRequest) -> None: client = AIHordeAPIClient() image_async_response: ImageGenerateAsyncResponse | RequestErrorResponse = client.generate_image_async( @@ -54,7 +54,7 @@ def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncReque assert isinstance(cancel_response, CancelImageGenerateRequest.get_success_response_type()) - def test_workers_all(self): + def test_workers_all(self) -> None: client = AIHordeAPIClient() api_request = AllWorkersDetailsRequest(type=WORKER_TYPE.image) diff --git a/tests/ai_horde_api/test_swagger.py b/tests/ai_horde_api/test_swagger.py index d5a8ee3..93e8548 100644 --- a/tests/ai_horde_api/test_swagger.py +++ b/tests/ai_horde_api/test_swagger.py @@ -2,11 +2,11 @@ from horde_sdk.generic_api.utils.swagger import SwaggerDoc, SwaggerParser -def test_swagger_parser_init(): +def test_swagger_parser_init() -> None: SwaggerParser(swagger_doc_url=get_ai_horde_swagger_url()) -def test_get_swagger_doc(): +def test_get_swagger_doc() -> None: parser = SwaggerParser(swagger_doc_url=get_ai_horde_swagger_url()) doc = parser.get_swagger_doc() assert isinstance(doc, SwaggerDoc) diff --git a/tests/conftest.py b/tests/conftest.py index 82d1b29..65d3c84 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -def pytest_collection_modifyitems(items): +def pytest_collection_modifyitems(items: list) -> None: """Modifies test items in place to ensure test modules run in a given order.""" MODULE_ORDER = ["tests_generic", "test_utils", "test_dynamically_check_apimodels"] # `test.scripts` must run first because it downloads the legacy database diff --git a/tests/test_generic.py b/tests/test_generic.py index 85586ee..27743ac 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -3,7 +3,7 @@ from horde_sdk.generic_api import RequestErrorResponse -def test_error_response(): +def test_error_response() -> None: with open("tests/test_data/RequestErrorResponse.json") as error_response_file: json_error_response = json.load(error_response_file) RequestErrorResponse(**json_error_response) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4cc4ee7..0c55162 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,7 +1,7 @@ from horde_sdk.generic_api.endpoints import url_with_path -def test_url_with_path(): +def test_url_with_path() -> None: example_url = "https://example.com/api/" example_path = "/example/path" example_path_no_leading_slash = "example/path"