Skip to content

Commit

Permalink
fix: sort r2_uploads so it matches ids sort
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Sep 26, 2024
1 parent 8efc611 commit 58ab404
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
12 changes: 9 additions & 3 deletions horde_sdk/ai_horde_api/apimodels/generate/_pop.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,14 @@ def ids_present(self) -> bool:
"""Whether or not the IDs are present."""
return self._ids_present

def _sort_ids(self) -> None:
"""Sort the IDs in place and sort so r2_uploads is changed so the same index changes occur."""
if len(self.ids) > 1:
logger.debug("Sorting IDs")
self.ids.sort()
if self.r2_uploads is not None:
self.r2_uploads.sort()

@model_validator(mode="after")
def validate_ids_present(self) -> ImageGenerateJobPopResponse:
"""Ensure that either id_ or ids is present."""
Expand All @@ -273,9 +281,7 @@ def validate_ids_present(self) -> ImageGenerateJobPopResponse:
if self.id_ is None and len(self.ids) == 0:
raise ValueError("Neither id_ nor ids were present in the response.")

if len(self.ids) > 1:
logger.debug("Sorting IDs")
self.ids.sort()
self._sort_ids()

self._ids_present = True

Expand Down
30 changes: 30 additions & 0 deletions tests/ai_horde_api/test_ai_horde_api_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,6 +460,36 @@ def test_ImageGenerateJobPopResponse() -> None:

assert all(post_processor in KNOWN_UPSCALERS for post_processor in test_response.payload.post_processing)

test_response = ImageGenerateJobPopResponse(
ids=[
JobID(root=UUID("00000000-0000-0000-0000-000000000001")),
JobID(root=UUID("00000000-0000-0000-0000-000000000002")),
JobID(root=UUID("00000000-0000-0000-0000-000000000000")),
],
payload=ImageGenerateJobPopPayload(
prompt="A cat in a hat",
),
model="Deliberate",
r2_uploads=[
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000001.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=zxcbvfakesignature%3D&Expires=1727390285",
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000000.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=345567dfakes2ignature%3D&Expires=1727390285",
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000002.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=asdfg32fakesignature%3D&Expires=1727390285",
],
skipped=ImageGenerateJobPopSkippedStatus(),
)

assert test_response.ids_present
assert test_response.ids == [
JobID(root=UUID("00000000-0000-0000-0000-000000000000")),
JobID(root=UUID("00000000-0000-0000-0000-000000000001")),
JobID(root=UUID("00000000-0000-0000-0000-000000000002")),
]
assert test_response.r2_uploads == [
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000000.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=345567dfakes2ignature%3D&Expires=1727390285",
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000001.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=zxcbvfakesignature%3D&Expires=1727390285",
"https://abbaabbaabbaabbaabbaabbaabbaabba.r2.cloudflarestorage.com/horde-transient/00000000-0000-0000-0000-000000000002.webp?AWSAccessKeyId=deadbeefdeadbeefdeadbeefdeadbeef&Signature=asdfg32fakesignature%3D&Expires=1727390285",
]


def test_ImageGenerateJobPopResponse_hashability() -> None:
test_response_ids = ImageGenerateJobPopResponse(
Expand Down

0 comments on commit 58ab404

Please sign in to comment.