Skip to content

Commit

Permalink
fix: less flux slowdown
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Sep 21, 2024
1 parent 0985fa6 commit 797f9c8
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 41 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ repos:
- id: mypy
args: []
additional_dependencies:
- pydantic==2.7.4
- pydantic==2.9.2
- types-requests
- types-pytz
- types-setuptools
Expand All @@ -40,7 +40,7 @@ repos:
- horde_safety==0.2.3
- torch==2.3.1
- ruamel.yaml
- horde_engine==2.15.0
- horde_sdk==0.14.3
- horde_engine==2.15.1
- horde_sdk==0.14.7
- horde_model_reference==0.9.0
- semver
2 changes: 1 addition & 1 deletion horde-bridge.cmd
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ cd /d %~dp0
call runtime python -s -m pip -V

call python -s -m pip uninstall hordelib
call python -s -m pip install horde_sdk~=0.14.3 horde_model_reference~=0.9.0 horde_engine~=2.15.0 horde_safety~=0.2.3 -U
call python -s -m pip install horde_sdk~=0.14.7 horde_model_reference~=0.9.0 horde_engine~=2.15.1 horde_safety~=0.2.3 -U

if %ERRORLEVEL% NEQ 0 (
echo "Please run update-runtime.cmd."
Expand Down
57 changes: 40 additions & 17 deletions horde_worker_regen/process_management/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@
from horde_sdk.ai_horde_api.fields import JobID
from loguru import logger
from pydantic import BaseModel, ConfigDict, RootModel, ValidationError
from typing_extensions import override
from typing import Literal, Union
from typing_extensions import override, TypeAlias

import horde_worker_regen
from horde_worker_regen.bridge_data.data_model import reGenBridgeData
Expand Down Expand Up @@ -103,18 +104,15 @@
_async_client_exceptions = (asyncio.exceptions.TimeoutError, aiohttp.client_exceptions.ClientError, OSError)

_excludes_for_job_dump = {
"job_image_results": ...,
"job_image_results": True,
"sdk_api_job_info": {
"payload": {
"prompt",
"special",
},
"skipped": ...,
"source_image": ...,
"source_mask": ...,
"extra_source_images": ...,
"r2_upload": ...,
"r2_uploads": ...,
"payload": {"prompt": True, "special": True},
"skipped": True,
"source_image": True,
"source_mask": True,
"extra_source_images": True,
"r2_upload": True,
"r2_uploads": True,
},
}

Expand Down Expand Up @@ -1740,7 +1738,7 @@ def receive_and_handle_process_messages(self) -> None:
)

logger.debug(
f"Job data: {message.sdk_api_job_info.model_dump(exclude=_excludes_for_job_dump)}",
f"Job data: {message.sdk_api_job_info.model_dump(exclude=_excludes_for_job_dump)}", # type: ignore
)

self.completed_jobs.append(job_info)
Expand Down Expand Up @@ -2789,7 +2787,7 @@ async def api_submit_job(self) -> None:
):

model_dump = hji.model_dump(
exclude=_excludes_for_job_dump,
exclude=_excludes_for_job_dump, # type: ignore
)
if (
self.stable_diffusion_reference is not None
Expand Down Expand Up @@ -3158,13 +3156,28 @@ async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse
return job_pop_response

_last_pop_no_jobs_available: bool = False
_too_many_consecutive_failed_jobs: bool = False
_too_many_consecutive_failed_jobs_time: float = 0.0
_too_many_consecutive_failed_jobs_wait_time = 180

@logger.catch(reraise=True)
async def api_job_pop(self) -> None:
"""If the job deque is not full, add any jobs that are available to the job deque."""
if self._shutting_down:
return

cur_time = time.time()

if self._too_many_consecutive_failed_jobs:
if (
cur_time - self._too_many_consecutive_failed_jobs_time
> self._too_many_consecutive_failed_jobs_wait_time
):
self._too_many_consecutive_failed_jobs = False
self._too_many_consecutive_failed_jobs_time = 0
logger.debug("Resuming job pops after too many consecutive failed jobs")
return

if self._consecutive_failed_jobs >= 3:
logger.error(
"Too many consecutive failed jobs, pausing job pops. "
Expand All @@ -3174,9 +3187,8 @@ async def api_job_pop(self) -> None:
if self.bridge_data.exit_on_unhandled_faults:
logger.error("Exiting due to exit_on_unhandled_faults being enabled")
self._abort()
await asyncio.sleep(180)
self._consecutive_failed_jobs = 0
logger.info("Resuming job pops")
self._too_many_consecutive_failed_jobs = True
self._too_many_consecutive_failed_jobs_time = cur_time
return

max_jobs_in_queue = self.bridge_data.queue_size + 1
Expand Down Expand Up @@ -3937,6 +3949,17 @@ def print_status_method(self) -> None:
"mode. Consider disabling `extra_slow_worker` in your config.",
)

if self._too_many_consecutive_failed_jobs:
time_since_failure = time.time() - self._too_many_consecutive_failed_jobs_time
logger.error(
"Too many consecutive failed jobs. This may be due to a misconfiguration or other issue. "
"Please check your logs and configuration.",
)
logger.error(
f"Time since last job failure: {time_since_failure:.2f}s). "
f"{self._too_many_consecutive_failed_jobs_wait_time} seconds must pass before resuming.",
)

self._last_status_message_time = time.time()

_bridge_data_loop_interval = 1.0
Expand Down
20 changes: 12 additions & 8 deletions horde_worker_regen/process_management/worker_entry_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,21 +64,25 @@ def start_inference_process(
if amd_gpu:
extra_comfyui_args.append("--use-pytorch-cross-attention")

models_not_to_force_load = []
models_not_to_force_load = ["flux"]

if very_high_memory_mode:
extra_comfyui_args.append("--gpu-only")
elif high_memory_mode:
extra_comfyui_args.append("--normalvram")
models_not_to_force_load = [
"cascade",
]
models_not_to_force_load.extend(
[
"cascade",
],
)
elif low_memory_mode:
extra_comfyui_args.append("--novram")
models_not_to_force_load = [
"sdxl",
"cascade",
]
models_not_to_force_load.extend(
[
"sdxl",
"cascade",
],
)

with logger.catch(reraise=True):
hordelib.initialise(
Expand Down
12 changes: 6 additions & 6 deletions requirements.dev.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
pytest==8.3.1
mypy==1.11.0
black==24.4.2
ruff==0.5.4
tox~=4.16.0
pre-commit~=3.7.1
pytest==8.3.3
mypy==1.11.2
black==24.8.0
ruff==0.6.5
tox~=4.18.1
pre-commit~=3.8.0
build>=0.10.0
coverage>=7.2.7

Expand Down
6 changes: 3 additions & 3 deletions requirements.rocm.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
numpy==1.26.4
torch==2.3.1+rocm6.0

horde_sdk~=0.14.3
horde_sdk~=0.14.7
horde_safety~=0.2.3
horde_engine~=2.15.0
horde_engine~=2.15.1
horde_model_reference~=0.9.0

python-dotenv
Expand All @@ -13,7 +13,7 @@ wheel

python-Levenshtein

pydantic>=2.7.4
pydantic>=2.9.2
typing_extensions
requests
StrEnum
Expand Down
6 changes: 3 additions & 3 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
numpy==1.26.4
torch==2.3.1

horde_sdk~=0.14.3
horde_sdk~=0.14.7
horde_safety~=0.2.3
horde_engine~=2.15.0
horde_engine~=2.15.1
horde_model_reference>=0.9.0

python-dotenv
Expand All @@ -12,7 +12,7 @@ semver

python-Levenshtein

pydantic>=2.7.4
pydantic>=2.9.2
typing_extensions
requests
StrEnum
Expand Down

0 comments on commit 797f9c8

Please sign in to comment.