Skip to content

Commit

Permalink
Merge pull request #221 from Haidra-Org/main
Browse files Browse the repository at this point in the history
tests: api side worker test; fix: misc fixes
  • Loading branch information
tazlin authored Jul 20, 2024
2 parents e528ff1 + 6845419 commit 9ea5d79
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 21 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/maintests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ jobs:
build:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
TESTS_ONGOING: 1
HORDE_SDK_TESTING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/prtests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ jobs:
env:
AIWORKER_CACHE_HOME: ${{ github.workspace }}/.cache
TESTS_ONGOING: 1
HORDE_SDK_TESTING: 1
HORDE_MODEL_REFERENCE_MAKE_FOLDERS: 1
runs-on: ubuntu-latest
strategy:
Expand Down
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,13 @@ repos:
pass_filenames: false
additional_dependencies: [
pytest,
pydantic,
pydantic==2.7.4,
types-Pillow,
types-requests,
types-pytz,
types-setuptools,
types-urllib3,
types-aiofiles,
StrEnum
StrEnum,
horde_model_reference==0.8.1,
]
10 changes: 5 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ class ImageGenerateParamMixin(HordeAPIDataObject):
karras: bool = True
"""Set to True if you want to use the Karras scheduling."""
tiling: bool = False
"""Deprecated."""
"""Set to True if you want to use seamless tiling."""
hires_fix: bool = False
"""Set to True if you want to use the hires fix."""
hires_fix_denoising_strength: float | None = Field(default=None, ge=0, le=1)
Expand All @@ -234,17 +234,17 @@ class ImageGenerateParamMixin(HordeAPIDataObject):
"""Set to True if you want the ControlNet map returned instead of a generated image."""
facefixer_strength: float | None = Field(default=None, ge=0, le=1)
"""The strength of the facefixer model."""
loras: list[LorasPayloadEntry] = Field(default_factory=list)
loras: list[LorasPayloadEntry] | None = None
"""A list of lora parameters to use."""
tis: list[TIPayloadEntry] = Field(default_factory=list)
tis: list[TIPayloadEntry] | None = None
"""A list of textual inversion (embedding) parameters to use."""
extra_texts: list[ExtraTextEntry] = Field(default_factory=list)
extra_texts: list[ExtraTextEntry] | None = None
"""A list of extra texts and prompts to use in the comfyUI workflow."""
workflow: str | KNOWN_WORKFLOWS | None = None
"""The specific comfyUI workflow to use."""
transparent: bool | None = None
"""When true, will generate an image with a transparent background"""
special: dict[Any, Any] = Field(default_factory=dict)
special: dict[Any, Any] | None = None
"""Reserved for future use."""
use_nsfw_censor: bool = False
"""If the request is SFW, and the worker accidentally generates NSFW, it will send back a censored image."""
Expand Down
17 changes: 12 additions & 5 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,15 @@ def validate_id(cls, v: str | JobID) -> JobID | str:

return v

_ids_present: bool = False

@property
def ids_present(self) -> bool:
"""Whether or not the IDs are present."""
return self._ids_present

@model_validator(mode="after")
def ids_present(self) -> ImageGenerateJobPopResponse:
def validate_ids_present(self) -> ImageGenerateJobPopResponse:
"""Ensure that either id_ or ids is present."""
if self.model is None:
if self.skipped.is_empty():
Expand All @@ -270,6 +277,8 @@ def ids_present(self) -> ImageGenerateJobPopResponse:
logger.debug("Sorting IDs")
self.ids.sort()

self._ids_present = True

return self

@override
Expand Down Expand Up @@ -418,11 +427,9 @@ class PopInput(HordeAPIObject):
max_length=1000,
)
"""The worker name, version and website."""
models: list[str] | None = None
models: list[str]
"""The models this worker can generate."""
name: str | None = Field(
None,
)
name: str
"""The Name of the Worker."""
nsfw: bool | None = Field(
False,
Expand Down
9 changes: 7 additions & 2 deletions horde_sdk/generic_api/apimodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,14 @@ def get_extra_fields_to_exclude_from_log(self) -> set[str]:
"""Return an additional set of fields to exclude from the log_safe_model_dump method."""
return set()

def log_safe_model_dump(self) -> dict[Any, Any]:
def log_safe_model_dump(self, extra_exclude: set[str] | None = None) -> dict[Any, Any]:
"""Return a dict of the model's fields, with any sensitive fields redacted."""
return self.model_dump(exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log())
if extra_exclude is None:
extra_exclude = set()

return self.model_dump(
exclude=self.get_sensitive_fields() | self.get_extra_fields_to_exclude_from_log() | extra_exclude,
)


class HordeResponse(HordeAPIMessage):
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ exclude = [
concurrency = ["gevent"]

[tool.pytest.ini_options]
# You can use `and`, `or`, `not` and parentheses to filter with the `-m` option
markers = [
# "slow: marks tests as slow (deselect with '-m \"not slow\"')",
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
"object_verify: marks tests that verify the API object structure and layout",
"api_side_ci: indicates that the test is intended to run during CI for the API",
]
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
horde_model_reference~=0.7.0
horde_model_reference~=0.8.1

pydantic
pydantic==2.7.4
requests
StrEnum
loguru
Expand Down
174 changes: 174 additions & 0 deletions tests/ai_horde_api/test_ai_worker_roundtrip_api_calls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
import asyncio

import aiohttp
import PIL.Image
import pytest
import yarl
from loguru import logger

from horde_sdk.ai_horde_api.ai_horde_clients import (
AIHordeAPIAsyncClientSession,
AIHordeAPIAsyncSimpleClient,
)
from horde_sdk.ai_horde_api.apimodels import (
ImageGenerateAsyncRequest,
ImageGenerateJobPopRequest,
ImageGenerateJobPopResponse,
ImageGenerateStatusResponse,
ImageGenerationJobSubmitRequest,
JobSubmitResponse,
)
from horde_sdk.ai_horde_api.consts import (
GENERATION_STATE,
)
from horde_sdk.ai_horde_api.fields import JobID


class TestImageWorkerRoundtrip:
async def fake_worker_checkin(
self,
aiohttp_session: aiohttp.ClientSession,
horde_client_session: AIHordeAPIAsyncClientSession,
image_gen_request: ImageGenerateAsyncRequest,
) -> None:
assert image_gen_request.params is not None

effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2

job_pop_request = ImageGenerateJobPopRequest(
name="fake CI worker",
bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen",
max_pixels=effective_resolution,
models=image_gen_request.models,
)

job_pop_response = await horde_client_session.submit_request(
job_pop_request,
job_pop_request.get_default_success_response_type(),
)

assert isinstance(job_pop_response, ImageGenerateJobPopResponse)
logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}")

assert not job_pop_response.ids_present
assert job_pop_response.skipped is not None

logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}")

async def fake_worker(
self,
aiohttp_session: aiohttp.ClientSession,
horde_client_session: AIHordeAPIAsyncClientSession,
image_gen_request: ImageGenerateAsyncRequest,
) -> None:
assert image_gen_request.params is not None

effective_resolution = (image_gen_request.params.width * image_gen_request.params.height) * 2

job_pop_request = ImageGenerateJobPopRequest(
name="fake CI worker",
bridge_agent="AI Horde Worker reGen:8.0.1-citests:https://github.com/Haidra-Org/horde-worker-reGen",
max_pixels=effective_resolution,
models=image_gen_request.models,
)

max_tries = 5
try_wait = 5
current_try = 0

while True:
job_pop_response = await horde_client_session.submit_request(
job_pop_request,
job_pop_request.get_default_success_response_type(),
)

assert isinstance(job_pop_response, ImageGenerateJobPopResponse)
logger.info(f"{job_pop_response.log_safe_model_dump({'skipped'})}")
logger.info(f"Checked in as fake worker ({effective_resolution}): {job_pop_response.skipped}")

if not job_pop_response.ids_present:
if current_try >= max_tries:
raise RuntimeError("Max tries exceeded")

logger.info(f"Waiting {try_wait} seconds before retrying")
await asyncio.sleep(try_wait)
current_try += 1
continue

# We're going to send a blank image base64 encoded
fake_image = PIL.Image.new(
"RGB",
(image_gen_request.params.width, image_gen_request.params.height),
(255, 255, 255),
)

fake_image_bytes = fake_image.tobytes()

r2_url = job_pop_response.r2_upload

assert r2_url is not None

async with aiohttp_session.put(
yarl.URL(r2_url, encoded=True),
data=fake_image_bytes,
skip_auto_headers=["content-type"],
timeout=aiohttp.ClientTimeout(total=10),
) as response:
assert response.status == 200

assert job_pop_response.ids is not None
assert len(job_pop_response.ids) == 1

job_submit_request = ImageGenerationJobSubmitRequest(
id=job_pop_response.ids[0],
state=GENERATION_STATE.ok,
generation="R2",
seed="1312",
)

job_submit_response = await horde_client_session.submit_request(
job_submit_request,
job_submit_request.get_default_success_response_type(),
)

assert isinstance(job_submit_response, JobSubmitResponse)
assert job_submit_response.reward is not None and job_submit_response.reward > 0

break

@pytest.mark.api_side_ci
@pytest.mark.asyncio
async def test_basic_image_roundtrip(self, simple_image_gen_request: ImageGenerateAsyncRequest) -> None:
aiohttp_session = aiohttp.ClientSession()
horde_client_session = AIHordeAPIAsyncClientSession(aiohttp_session)

async with aiohttp_session, horde_client_session:
simple_client = AIHordeAPIAsyncSimpleClient(horde_client_session=horde_client_session)

await self.fake_worker_checkin(aiohttp_session, horde_client_session, simple_image_gen_request)

image_gen_task = asyncio.create_task(simple_client.image_generate_request(simple_image_gen_request))

fake_worker_task = asyncio.create_task(
self.fake_worker(
aiohttp_session,
horde_client_session,
simple_image_gen_request,
),
)

await asyncio.gather(image_gen_task, fake_worker_task)

image_gen_response, job_id = image_gen_task.result()

assert isinstance(image_gen_response, ImageGenerateStatusResponse)
assert isinstance(job_id, JobID)

assert len(image_gen_response.generations) == 1

generation = image_gen_response.generations[0]
assert generation.seed == "1312"
assert generation.img is not None
assert not generation.gen_metadata

assert generation.censored is False
18 changes: 17 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pathlib

import pytest
from loguru import logger

os.environ["TESTS_ONGOING"] = "1"

Expand All @@ -15,6 +16,21 @@ def check_tests_ongoing_env_var() -> None:
"""Checks that the TESTS_ONGOING environment variable is set."""
assert os.getenv("TESTS_ONGOING", None) is not None, "TESTS_ONGOING environment variable not set"

AI_HORDE_TESTING = os.getenv("AI_HORDE_TESTING", None)
HORDE_SDK_TESTING = os.getenv("HORDE_SDK_TESTING", None)
if AI_HORDE_TESTING is None and HORDE_SDK_TESTING is None:
logger.warning(
"Neither AI_HORDE_TESTING nor HORDE_SDK_TESTING environment variables are set. "
"Is this a local development test run? If so, set AI_HORDE_TESTING=1 or HORDE_SDK_TESTING=1 to suppress "
"this warning",
)

if AI_HORDE_TESTING is not None:
logger.info("AI_HORDE_TESTING environment variable set.")

if HORDE_SDK_TESTING is not None:
logger.info("HORDE_SDK_TESTING environment variable set.")


@pytest.fixture(scope="session")
def ai_horde_api_key() -> str:
Expand All @@ -30,7 +46,7 @@ def simple_image_gen_request(ai_horde_api_key: str) -> ImageGenerateAsyncRequest
prompt="a cat in a hat",
models=["Deliberate"],
params=ImageGenerationInputPayload(
steps=1,
steps=5,
n=1,
),
)
Expand Down
6 changes: 4 additions & 2 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ skip_empty = True
description = base evironment
passenv =
AIWORKER_CACHE_HOME
HORDE_SDK_TESTING
AI_HORDE_TESTING
TESTS_ONGOING

[testenv:pre-commit]
Expand All @@ -33,7 +35,7 @@ deps =
requests
-r requirements.txt
commands =
pytest tests {posargs} --cov
pytest tests {posargs} --cov -m "not api_side_ci"


[testenv:tests-no-api-calls]
Expand All @@ -48,4 +50,4 @@ deps =
requests
-r requirements.txt
commands =
pytest tests {posargs} --ignore-glob=*api_calls.py --cov
pytest tests {posargs} --ignore-glob=*api_calls.py -m "not api_side_ci"

0 comments on commit 9ea5d79

Please sign in to comment.