Skip to content

Commit

Permalink
chore: mypy disallow_untyped_defs = True
Browse files Browse the repository at this point in the history
This requires function signatures to be typed (no 'dynamic' functions), effectively making this library entirely typed.
  • Loading branch information
tazlin committed Jul 11, 2023
1 parent 62c69fa commit 9adf702
Show file tree
Hide file tree
Showing 17 changed files with 29 additions and 24 deletions.
2 changes: 1 addition & 1 deletion codegen/codegen_regex_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys


def main(path):
def main(path: str) -> None:
print(f"Processing {path}")
with open(path) as f:
contents = f.read()
Expand Down
2 changes: 2 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[mypy]
disallow_untyped_defs = True
4 changes: 2 additions & 2 deletions src/horde_sdk/ai_horde_api/apimodels/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ class BaseImageGenerateParam(BaseModel):
use_nsfw_censor: bool = False

@field_validator("sampler_name")
def sampler_name_must_be_known(cls, v):
def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS:
"""Ensure that the sampler name is in this list of supported samplers."""
if v not in KNOWN_SAMPLERS.__members__:
raise ValueError(f"Unknown sampler name {v}")
return v

@field_validator("seed")
def seed_to_int_if_str(cls, v):
def seed_to_int_if_str(cls, v: str | int) -> str | int:
"""Ensure that the seed is an integer. If it is a string, convert it to an integer."""
return str(seed_to_int(v))

Expand Down
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ImageGenerateAsyncRequest(
replacement_filter: bool = True

@model_validator(mode="before")
def validate_censor_nsfw(cls, values):
def validate_censor_nsfw(cls, values: dict) -> dict:
if values.get("censor_nsfw", None) and values.get("nsfw", None):
raise ValueError("censor_nsfw is only valid when nsfw is False")
return values
Expand Down
2 changes: 1 addition & 1 deletion src/horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ class ImageGenerateJobResponse(BaseResponse):
"""The r2 upload link to use to upload this image."""

@field_validator("source_processing")
def source_processing_must_be_known(cls, v):
def source_processing_must_be_known(cls, v: str | KNOWN_SOURCE_PROCESSING) -> str | KNOWN_SOURCE_PROCESSING:
"""Ensure that the source processing is in this list of supported source processing."""
if v not in KNOWN_SOURCE_PROCESSING.__members__:
raise ValueError(f"Unknown source processing {v}")
Expand Down
5 changes: 4 additions & 1 deletion src/horde_sdk/ai_horde_api/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ def id_as_uuid(self) -> uuid.UUID:
return uuid.UUID(str(self.id), version=4)

@field_validator("id")
def id_must_be_uuid(cls, v):
def id_must_be_uuid(cls, v: str | uuid.UUID) -> str | uuid.UUID:
"""Ensure that the ID is a valid UUID."""
if isinstance(v, uuid.UUID):
return v

try:
uuid.UUID(v, version=4)
except ValueError as e:
Expand Down
4 changes: 2 additions & 2 deletions src/horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_api_model_name(cls) -> str | None:
such as for a GET request.
"""

def to_json_horde_sdk_safe(self):
def to_json_horde_sdk_safe(self) -> str:
"""Return the model as a JSON string, taking into account the paradigms of the horde_sdk.
If you use the default json dumping behavior, you will find some rough edges, such as alias
Expand Down Expand Up @@ -85,7 +85,7 @@ def from_dict_or_array(cls, dict_or_array: dict | list) -> Self:
return cls(**dict_or_array)

@override
def to_json_horde_sdk_safe(self):
def to_json_horde_sdk_safe(self) -> str:
# TODO: Is there a more pydantic way to do this?
if self.is_array_response():
self_array = self.get_array()
Expand Down
4 changes: 2 additions & 2 deletions src/horde_sdk/generic_api/utils/swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class SwaggerModelDefinitionSchemaValidation(SwaggerModelEntry):
"""The model must match at least one of the schemas in this list."""

@model_validator(mode="before")
def one_method_specified(cls, v):
def one_method_specified(cls, v: dict) -> dict:
"""Ensure at least one of the validation methods is specified."""
if not any([v.get("allOf"), v.get("oneOf"), v.get("anyOf")]):
raise ValueError("At least one of allOf, oneOf, or anyOf must be specified.")
Expand Down Expand Up @@ -257,7 +257,7 @@ def get_defined_endpoints(self) -> dict[str, SwaggerEndpointMethod]:
return return_dict

@model_validator(mode="before")
def at_least_one_method_specified(cls, v):
def at_least_one_method_specified(cls, v: dict) -> dict:
"""Ensure at least one method is specified."""
if not any(
[
Expand Down
2 changes: 1 addition & 1 deletion src/horde_sdk/scripts/write_all_payload_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from horde_sdk.generic_api.utils.swagger import SwaggerParser


def main(*, test_data_path: Path | None = None):
def main(*, test_data_path: Path | None = None) -> None:
ai_horde_swagger_doc = SwaggerParser(
swagger_doc_url=get_ai_horde_swagger_url(),
).get_swagger_doc()
Expand Down
2 changes: 1 addition & 1 deletion src/horde_sdk/scripts/write_all_response_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from horde_sdk.generic_api.utils.swagger import SwaggerParser


def main(*, test_data_path: Path | None = None):
def main(*, test_data_path: Path | None = None) -> None:
ai_horde_swagger_doc = SwaggerParser(
swagger_doc_url=get_ai_horde_swagger_url(),
).get_swagger_doc()
Expand Down
4 changes: 2 additions & 2 deletions src/horde_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import random


def seed_to_int(self, s=None):
if type(s) is int:
def seed_to_int(s: int | str | None = None) -> int:
if isinstance(s, int):
return s
if s is None or s == "":
# return a random int
Expand Down
4 changes: 2 additions & 2 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from horde_sdk.ai_horde_api.consts import KNOWN_SAMPLERS, KNOWN_SOURCE_PROCESSING


def test_api_endpoint():
def test_api_endpoint() -> None:
ImageGenerateAsyncRequest.get_api_url()
ImageGenerateAsyncRequest.get_endpoint_subpath()
ImageGenerateAsyncRequest.get_endpoint_url()


def test_ImageGenerateAsyncRequest():
def test_ImageGenerateAsyncRequest() -> None:
test_async_request = ImageGenerateAsyncRequest(
apikey="000000000",
models=["Deliberate"],
Expand Down
6 changes: 3 additions & 3 deletions tests/ai_horde_api/test_api_calls.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def default_image_gen_request(self) -> ImageGenerateAsyncRequest:
models=["Deliberate"],
)

def test_AIHordeAPIClient_init(self):
def test_AIHordeAPIClient_init(self) -> None:
AIHordeAPIClient()

def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncRequest):
def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncRequest) -> None:
client = AIHordeAPIClient()

image_async_response: ImageGenerateAsyncResponse | RequestErrorResponse = client.generate_image_async(
Expand All @@ -54,7 +54,7 @@ def test_generate_async(self, default_image_gen_request: ImageGenerateAsyncReque

assert isinstance(cancel_response, CancelImageGenerateRequest.get_success_response_type())

def test_workers_all(self):
def test_workers_all(self) -> None:
client = AIHordeAPIClient()

api_request = AllWorkersDetailsRequest(type=WORKER_TYPE.image)
Expand Down
4 changes: 2 additions & 2 deletions tests/ai_horde_api/test_swagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
from horde_sdk.generic_api.utils.swagger import SwaggerDoc, SwaggerParser


def test_swagger_parser_init():
def test_swagger_parser_init() -> None:
SwaggerParser(swagger_doc_url=get_ai_horde_swagger_url())


def test_get_swagger_doc():
def test_get_swagger_doc() -> None:
parser = SwaggerParser(swagger_doc_url=get_ai_horde_swagger_url())
doc = parser.get_swagger_doc()
assert isinstance(doc, SwaggerDoc)
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
def pytest_collection_modifyitems(items):
def pytest_collection_modifyitems(items: list) -> None:
"""Modifies test items in place to ensure test modules run in a given order."""
MODULE_ORDER = ["tests_generic", "test_utils", "test_dynamically_check_apimodels"]
# `test.scripts` must run first because it downloads the legacy database
Expand Down
2 changes: 1 addition & 1 deletion tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from horde_sdk.generic_api import RequestErrorResponse


def test_error_response():
def test_error_response() -> None:
with open("tests/test_data/RequestErrorResponse.json") as error_response_file:
json_error_response = json.load(error_response_file)
RequestErrorResponse(**json_error_response)
2 changes: 1 addition & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from horde_sdk.generic_api.endpoints import url_with_path


def test_url_with_path():
def test_url_with_path() -> None:
example_url = "https://example.com/api/"
example_path = "/example/path"
example_path_no_leading_slash = "example/path"
Expand Down

0 comments on commit 9adf702

Please sign in to comment.