diff --git a/.github/workflows/maintests.yml b/.github/workflows/maintests.yml index fd458cb..51edb84 100644 --- a/.github/workflows/maintests.yml +++ b/.github/workflows/maintests.yml @@ -19,8 +19,9 @@ jobs: build: env: AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache - HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1 TESTS_ONGOING: 1 + HORDE_SDK_TESTING: 1 + HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1 runs-on: ubuntu-latest strategy: matrix: diff --git a/.github/workflows/prtests.yml b/.github/workflows/prtests.yml index 38990ce..144c384 100644 --- a/.github/workflows/prtests.yml +++ b/.github/workflows/prtests.yml @@ -23,6 +23,7 @@ jobs: env: AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache TESTS_ONGOING: 1 + HORDE_SDK_TESTING: 1 HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1 runs-on: ubuntu-latest strategy: diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 463a2a7..51ad9fc 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,12 +21,13 @@ repos: pass_filenames: false additional_dependencies: [ pytest, - pydantic, + pydantic==2.7.4, types-Pillow, types-requests, types-pytz, types-setuptools, types-urllib3, types-aiofiles, - StrEnum + StrEnum, + horde_model_reference==0.8.1, ] diff --git a/horde_sdk/ai_horde_api/apimodels/base.py b/horde_sdk/ai_horde_api/apimodels/base.py index 0d410ea..12f5943 100644 --- a/horde_sdk/ai_horde_api/apimodels/base.py +++ b/horde_sdk/ai_horde_api/apimodels/base.py @@ -219,7 +219,7 @@ class ImageGenerateParamMixin(HordeAPIDataObject): karras: bool = True """Set to True if you want to use the Karras scheduling.""" tiling: bool = False - """Deprecated.""" + """Set to True if you want to use seamless tiling.""" hires_fix: bool = False """Set to True if you want to use the hires fix.""" hires_fix_denoising_strength: float | None = Field(default=None, ge=0, le=1) @@ -234,17 +234,17 @@ class ImageGenerateParamMixin(HordeAPIDataObject): """Set to True if you want the ControlNet map returned instead of a generated image.""" facefixer_strength: float | None = Field(default=None, ge=0, le=1) """The strength of the facefixer model.""" - loras: list[LorasPayloadEntry] = Field(default_factory=list) + loras: list[LorasPayloadEntry] | None = None """A list of lora parameters to use.""" - tis: list[TIPayloadEntry] = Field(default_factory=list) + tis: list[TIPayloadEntry] | None = None """A list of textual inversion (embedding) parameters to use.""" - extra_texts: list[ExtraTextEntry] = Field(default_factory=list) + extra_texts: list[ExtraTextEntry] | None = None """A list of extra texts and prompts to use in the comfyUI workflow.""" workflow: str | KNOWN_WORKFLOWS | None = None """The specific comfyUI workflow to use.""" transparent: bool | None = None """When true, will generate an image with a transparent background""" - special: dict[Any, Any] = Field(default_factory=dict) + special: dict[Any, Any] | None = None """Reserved for future use.""" use_nsfw_censor: bool = False """If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image.""" diff --git a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py index 03c5514..7de8309 100644 --- a/horde_sdk/ai_horde_api/apimodels/generate/_pop.py +++ b/horde_sdk/ai_horde_api/apimodels/generate/_pop.py @@ -253,8 +253,15 @@ def validate_id(cls, v: str | JobID) -> JobID | str: return v + _ids_present: bool = False + + @property + def ids_present(self) -> bool: + """Whether or not the IDs are present.""" + return self._ids_present + @model_validator(mode="after") - def ids_present(self) -> ImageGenerateJobPopResponse: + def validate_ids_present(self) -> ImageGenerateJobPopResponse: """Ensure that either id_ or ids is present.""" if self.model is None: if self.skipped.is_empty(): @@ -270,6 +277,8 @@ def ids_present(self) -> ImageGenerateJobPopResponse: logger.debug("Sorting IDs") self.ids.sort() + self._ids_present = True + return self @override @@ -418,11 +427,9 @@ class PopInput(HordeAPIObject): max_length=1000, ) """The worker name, version and website.""" - models: list[str] | None = None + models: list[str] """The models this worker can generate.""" - name: str | None = Field( - None, - ) + name: str """The Name of the Worker.""" nsfw: bool | None = Field( False, diff --git a/horde_sdk/generic_api/apimodels.py b/horde_sdk/generic_api/apimodels.py index 79b5243..377f03a 100644 --- a/horde_sdk/generic_api/apimodels.py +++ b/horde_sdk/generic_api/apimodels.py @@ -79,9 +79,14 @@ def get_extra_fields_to_exclude_from_log(self) -> set[str]: """Return an additional set of fields to exclude from the log_safe_model_dump method.""" return set() - def log_safe_model_dump(self) -> dict[Any, Any]: + def log_safe_model_dump(self, extra_exclude: set[str] | None = None) -> dict[Any, Any]: """Return a dict of the model's fields, with any sensitive fields redacted.""" - return self.model_dump(exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log()) + if extra_exclude is None: + extra_exclude = set() + + return self.model_dump( + exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude, + ) class HordeResponse(HordeAPIMessage): diff --git a/pyproject.toml b/pyproject.toml index a9d3523..74a6393 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,7 +128,9 @@ exclude = [ concurrency = ["gevent"] [tool.pytest.ini_options] +# You can use `and`, `or`, `not` and parentheses to filter with the `-m` option markers = [ - # "slow: marks tests as slow (deselect with '-m \"not slow\"')", + "slow: marks tests as slow (deselect with '-m \"not slow\"')", "object_verify: marks tests that verify the API object structure and layout", + "api_side_ci: indicates that the test is intended to run during CI for the API", ] diff --git a/requirements.txt b/requirements.txt index d6d29b6..1e9eff4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ -horde_model_reference~=0.7.0 +horde_model_reference~=0.8.1 -pydantic +pydantic==2.7.4 requests StrEnum loguru diff --git a/tests/ai_horde_api/test_ai_worker_roundtrip_api_calls.py b/tests/ai_horde_api/test_ai_worker_roundtrip_api_calls.py new file mode 100644 index 0000000..2a6fb8d --- /dev/null +++ b/tests/ai_horde_api/test_ai_worker_roundtrip_api_calls.py @@ -0,0 +1,174 @@ +import asyncio + +import aiohttp +import PIL.Image +import pytest +import yarl +from loguru import logger + +from horde_sdk.ai_horde_api.ai_horde_clients import ( + AIHordeAPIAsyncClientSession, + AIHordeAPIAsyncSimpleClient, +) +from horde_sdk.ai_horde_api.apimodels import ( + ImageGenerateAsyncRequest, + ImageGenerateJobPopRequest, + ImageGenerateJobPopResponse, + ImageGenerateStatusResponse, + ImageGenerationJobSubmitRequest, + JobSubmitResponse, +) +from horde_sdk.ai_horde_api.consts import ( + GENERATION_STATE, +) +from horde_sdk.ai_horde_api.fields import JobID + + +class TestImageWorkerRoundtrip: + async def fake_worker_checkin( + self, + aiohttp_session: aiohttp.ClientSession, + horde_client_session: AIHordeAPIAsyncClientSession, + image_gen_request: ImageGenerateAsyncRequest, + ) -> None: + assert image_gen_request.params is not None + + effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2 + + job_pop_request = ImageGenerateJobPopRequest( + name="fake CI worker", + bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen", + max_pixels=effective_resolution, + models=image_gen_request.models, + ) + + job_pop_response = await horde_client_session.submit_request( + job_pop_request, + job_pop_request.get_default_success_response_type(), + ) + + assert isinstance(job_pop_response, ImageGenerateJobPopResponse) + logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}") + + assert not job_pop_response.ids_present + assert job_pop_response.skipped is not None + + logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}") + + async def fake_worker( + self, + aiohttp_session: aiohttp.ClientSession, + horde_client_session: AIHordeAPIAsyncClientSession, + image_gen_request: ImageGenerateAsyncRequest, + ) -> None: + assert image_gen_request.params is not None + + effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2 + + job_pop_request = ImageGenerateJobPopRequest( + name="fake CI worker", + bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen", + max_pixels=effective_resolution, + models=image_gen_request.models, + ) + + max_tries = 5 + try_wait = 5 + current_try = 0 + + while True: + job_pop_response = await horde_client_session.submit_request( + job_pop_request, + job_pop_request.get_default_success_response_type(), + ) + + assert isinstance(job_pop_response, ImageGenerateJobPopResponse) + logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}") + logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}") + + if not job_pop_response.ids_present: + if current_try >= max_tries: + raise RuntimeError("Max tries exceeded") + + logger.info(f"Waiting {try_wait} seconds before retrying") + await asyncio.sleep(try_wait) + current_try += 1 + continue + + # We're going to send a blank image base64 encoded + fake_image = PIL.Image.new( + "RGB", + (image_gen_request.params.width, image_gen_request.params.height), + (255, 255, 255), + ) + + fake_image_bytes = fake_image.tobytes() + + r2_url = job_pop_response.r2_upload + + assert r2_url is not None + + async with aiohttp_session.put( + yarl.URL(r2_url, encoded=True), + data=fake_image_bytes, + skip_auto_headers=["content-type"], + timeout=aiohttp.ClientTimeout(total=10), + ) as response: + assert response.status == 200 + + assert job_pop_response.ids is not None + assert len(job_pop_response.ids) == 1 + + job_submit_request = ImageGenerationJobSubmitRequest( + id=job_pop_response.ids[0], + state=GENERATION_STATE.ok, + generation="R2", + seed="1312", + ) + + job_submit_response = await horde_client_session.submit_request( + job_submit_request, + job_submit_request.get_default_success_response_type(), + ) + + assert isinstance(job_submit_response, JobSubmitResponse) + assert job_submit_response.reward is not None and job_submit_response.reward > 0 + + break + + @pytest.mark.api_side_ci + @pytest.mark.asyncio + async def test_basic_image_roundtrip(self, simple_image_gen_request: ImageGenerateAsyncRequest) -> None: + aiohttp_session = aiohttp.ClientSession() + horde_client_session = AIHordeAPIAsyncClientSession(aiohttp_session) + + async with aiohttp_session, horde_client_session: + simple_client = AIHordeAPIAsyncSimpleClient(horde_client_session=horde_client_session) + + await self.fake_worker_checkin(aiohttp_session, horde_client_session, simple_image_gen_request) + + image_gen_task = asyncio.create_task(simple_client.image_generate_request(simple_image_gen_request)) + + fake_worker_task = asyncio.create_task( + self.fake_worker( + aiohttp_session, + horde_client_session, + simple_image_gen_request, + ), + ) + + await asyncio.gather(image_gen_task, fake_worker_task) + + image_gen_response, job_id = image_gen_task.result() + + assert isinstance(image_gen_response, ImageGenerateStatusResponse) + assert isinstance(job_id, JobID) + + assert len(image_gen_response.generations) == 1 + + generation = image_gen_response.generations[0] + assert generation.seed == "1312" + assert generation.img is not None + assert not generation.gen_metadata + + assert generation.censored is False diff --git a/tests/conftest.py b/tests/conftest.py index d29c349..bdec2f1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,7 @@ import pathlib import pytest +from loguru import logger os.environ["TESTS_ONGOING"] = "1" @@ -15,6 +16,21 @@ def check_tests_ongoing_env_var() -> None: """Checks that the TESTS_ONGOING environment variable is set.""" assert os.getenv("TESTS_ONGOING", None) is not None, "TESTS_ONGOING environment variable not set" + AI_HORDE_TESTING = os.getenv("AI_HORDE_TESTING", None) + HORDE_SDK_TESTING = os.getenv("HORDE_SDK_TESTING", None) + if AI_HORDE_TESTING is None and HORDE_SDK_TESTING is None: + logger.warning( + "Neither AI_HORDE_TESTING nor HORDE_SDK_TESTING environment variables are set. " + "Is this a local development test run? If so, set AI_HORDE_TESTING=1 or HORDE_SDK_TESTING=1 to suppress " + "this warning", + ) + + if AI_HORDE_TESTING is not None: + logger.info("AI_HORDE_TESTING environment variable set.") + + if HORDE_SDK_TESTING is not None: + logger.info("HORDE_SDK_TESTING environment variable set.") + @pytest.fixture(scope="session") def ai_horde_api_key() -> str: @@ -30,7 +46,7 @@ def simple_image_gen_request(ai_horde_api_key: str) -> ImageGenerateAsyncRequest prompt="a cat in a hat", models=["Deliberate"], params=ImageGenerationInputPayload( - steps=1, + steps=5, n=1, ), ) diff --git a/tox.ini b/tox.ini index 0d0db14..ac5c805 100644 --- a/tox.ini +++ b/tox.ini @@ -14,6 +14,8 @@ skip_empty = True description = base evironment passenv = AIWORKER_CACHE_HOME + HORDE_SDK_TESTING + AI_HORDE_TESTING TESTS_ONGOING [testenv:pre-commit] @@ -33,7 +35,7 @@ deps = requests -r requirements.txt commands = - pytest tests {posargs} --cov + pytest tests {posargs} --cov -m "not api_side_ci" [testenv:tests-no-api-calls] @@ -48,4 +50,4 @@ deps = requests -r requirements.txt commands = - pytest tests {posargs} --ignore-glob=*api_calls.py --cov + pytest tests {posargs} --ignore-glob=*api_calls.py -m "not api_side_ci"