Skip to content

Commit

Permalink
fix: logic issue with KNOWN_SAMPLERS check
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Jan 25, 2024
1 parent d4c7bde commit f8d9d6b
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 4 additions & 2 deletions horde_sdk/ai_horde_api/apimodels/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,10 @@ def width_divisible_by_64(cls, value: int) -> int:
@field_validator("sampler_name")
def sampler_name_must_be_known(cls, v: str | KNOWN_SAMPLERS) -> str | KNOWN_SAMPLERS:
"""Ensure that the sampler name is in this list of supported samplers."""
if (isinstance(v, str) and v not in KNOWN_SAMPLERS.__members__) or (not isinstance(v, KNOWN_SAMPLERS)):
logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?")
if (isinstance(v, str) and v in KNOWN_SAMPLERS.__members__) or (isinstance(v, KNOWN_SAMPLERS)):
return v

logger.warning(f"Unknown sampler name {v}. Is your SDK out of date or did the API change?")

return v

Expand Down
5 changes: 3 additions & 2 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None:
models=["Deliberate"],
prompt="test prompt",
params=ImageGenerationInputPayload(
sampler_name=KNOWN_SAMPLERS.k_lms,
# sampler_name="DDIM",
sampler_name=KNOWN_SAMPLERS.DDIM,
cfg_scale=7.5,
denoising_strength=1,
seed="123456789",
Expand Down Expand Up @@ -86,7 +87,7 @@ def test_ImageGenerateAsyncRequest(ai_horde_api_key: str) -> None:
assert test_async_request.models == ["Deliberate"]
assert test_async_request.prompt == "test prompt"
assert test_async_request.params is not None
assert test_async_request.params.sampler_name == "k_lms"
assert test_async_request.params.sampler_name == "DDIM"
assert test_async_request.params.cfg_scale == 7.5
assert test_async_request.params.denoising_strength == 1
assert test_async_request.params.seed is not None
Expand Down

0 comments on commit f8d9d6b

Please sign in to comment.