fix: better multiprocess lora handling
This should hopefully help with the problem of the different inference processes falling out of sync with each other, especially as it pertains to overflows of the lora disk cache (i.e., when there are more models on disk than the worker operator specified in their config).
tazlin committed Mar 3, 2024
1 parent 8f39f68 commit 1f588ed
Showing 4 changed files with 53 additions and 31 deletions.
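
The fix is the standard multiprocessing recipe for protecting a shared on-disk resource: create one lock in the parent and hand it to every child, so only one inference process at a time touches the lora disk cache. A minimal, self-contained sketch of the pattern (names here are illustrative, not the worker's actual API):

    import multiprocessing as mp

    def worker(worker_id: int, cache_lock) -> None:
        # Only one process at a time may enter this block.
        with cache_lock:
            print(f"worker {worker_id}: updating the shared lora cache")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        cache_lock = ctx.Lock()  # created once, in the parent
        processes = [ctx.Process(target=worker, args=(i, cache_lock)) for i in range(3)]
        for p in processes:
            p.start()
        for p in processes:
            p.join()

Under the spawn start method the lock must travel to each child through the Process args tuple; a module-level global would not survive into the child.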
1 change: 1 addition & 0 deletions horde_worker_regen/download_models.py
@@ -59,6 +59,7 @@ def download_all_models(purge_unused_loras: bool = False) -> None:
     if SharedModelManager.manager.lora is None:
         logger.error("Failed to load LORA model manager")
         exit(1)
+    SharedModelManager.manager.lora.reset_adhoc_loras()
     SharedModelManager.manager.lora.download_default_loras(bridge_data.nsfw)
     SharedModelManager.manager.lora.wait_for_downloads(600)
     SharedModelManager.manager.lora.wait_for_adhoc_reset(15)
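
The wait_for_downloads(600) and wait_for_adhoc_reset(15) calls above bound how long the process blocks on background work. A generic sketch of that timeout-polling idiom, assuming a simple boolean predicate (wait_for and its arguments are illustrative, not the model manager's API):

    import time
    from typing import Callable

    def wait_for(predicate: Callable[[], bool], timeout: float, poll_interval: float = 0.5) -> bool:
        # Poll until the predicate holds or the timeout elapses.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            if predicate():
                return True
            time.sleep(poll_interval)
        return predicate()  # final check at the deadline

Returning a bool rather than raising leaves the timeout policy to the caller.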
76 changes: 45 additions & 31 deletions horde_worker_regen/process_management/inference_process.py
@@ -69,13 +69,15 @@ class HordeInferenceProcess(HordeProcess):
 
     _active_model_name: str | None = None
     """The name of the currently active model. Note that other models may be loaded in RAM or VRAM."""
+    _aux_model_lock: Lock
 
     def __init__(
         self,
         process_id: int,
         process_message_queue: ProcessQueue,
         pipe_connection: Connection,
         inference_semaphore: Semaphore,
+        aux_model_lock: Lock,
         disk_lock: Lock,
     ) -> None:
         """Initialise the HordeInferenceProcess.
@@ -95,6 +97,8 @@ def __init__(
             disk_lock=disk_lock,
         )
 
+        self._aux_model_lock = aux_model_lock
+
         # We import these here to guard against potentially importing them in the main process
         # which would create shared objects, potentially causing issues
         try:
@@ -247,46 +251,56 @@ def download_aux_models(self, job_info: ImageGenerateJobPopResponse) -> float | None:
             float | None: The time elapsed during downloading, or None if no models were downloaded.
         """
 
-        time_start = time.time()
+        with self._aux_model_lock:
+            time_start = time.time()
 
-        lora_manager = self._shared_model_manager.manager.lora
-        if lora_manager is None:
-            raise RuntimeError("Failed to load LORA model manager")
+            lora_manager = self._shared_model_manager.manager.lora
+            if lora_manager is None:
+                raise RuntimeError("Failed to load LORA model manager")
 
-        ti_manager = self._shared_model_manager.manager.ti
-        if ti_manager is None:
-            raise RuntimeError("Failed to load TI model manager")
+            ti_manager = self._shared_model_manager.manager.ti
+            if ti_manager is None:
+                raise RuntimeError("Failed to load TI model manager")
 
-        performed_a_download = False
+            performed_a_download = False
 
-        loras = job_info.payload.loras
-        tis = job_info.payload.tis
+            loras = job_info.payload.loras
+            tis = job_info.payload.tis
 
-        for lora_entry in loras:
-            if not lora_manager.is_model_available(lora_entry.name):
-                if not performed_a_download:
-                    self.send_aux_model_message(
-                        process_state=HordeProcessState.DOWNLOADING_AUX_MODEL,
-                        info="Downloading auxiliary models",
-                        time_elapsed=0.0,
-                        job_info=job_info,
-                    )
-                performed_a_download = True
-                lora_manager.fetch_adhoc_lora(lora_entry.name, timeout=45, is_version=lora_entry.is_version)
-                lora_manager.wait_for_downloads(45)
+            try:
+                lora_manager.load_model_database()
+                lora_manager.reset_adhoc_loras()
+            except Exception as e:
+                logger.error(f"Failed to reset adhoc loras: {type(e).__name__} {e}")
+
+            for lora_entry in loras:
+                if not lora_manager.is_model_available(lora_entry.name):
+                    if not performed_a_download:
+                        self.send_aux_model_message(
+                            process_state=HordeProcessState.DOWNLOADING_AUX_MODEL,
+                            info="Downloading auxiliary models",
+                            time_elapsed=0.0,
+                            job_info=job_info,
+                        )
+                    performed_a_download = True
+                    lora_manager.fetch_adhoc_lora(lora_entry.name, timeout=45, is_version=lora_entry.is_version)
+                    lora_manager.wait_for_downloads(45)
 
-        for ti_entry in tis:
-            if not ti_manager.is_model_available(ti_entry.name):
-                performed_a_download = True
-                ti_manager.fetch_adhoc_ti(ti_entry.name, timeout=45)
-                ti_manager.wait_for_downloads(45)
+            for ti_entry in tis:
+                if not ti_manager.is_model_available(ti_entry.name):
+                    performed_a_download = True
+                    ti_manager.fetch_adhoc_ti(ti_entry.name, timeout=45)
+                    ti_manager.wait_for_downloads(45)
 
-        time_elapsed = round(time.time() - time_start, 2)
+            lora_manager.save_cached_reference_to_disk()
 
-        if performed_a_download:
-            logger.info(f"Downloaded auxiliary models in {time_elapsed} seconds")
-            return time_elapsed
+            time_elapsed = round(time.time() - time_start, 2)
 
-        return None
+            if performed_a_download:
+                logger.info(f"Downloaded auxiliary models in {time_elapsed} seconds")
+                return time_elapsed
+
+            return None
 
     def preload_model(
         self,
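
The body of the guarded section is a cross-process read-modify-write: re-read the shared lora reference from disk (load_model_database), adjust it (reset_adhoc_loras and the fetches), and persist it before releasing the lock (save_cached_reference_to_disk). Reduced to its essentials, the shape looks like this (the file name and JSON layout are assumptions for illustration, not the real cache format):

    import json
    from pathlib import Path

    CACHE_PATH = Path("lora_reference.json")  # hypothetical cache file

    def record_download(cache_lock, lora_name: str) -> None:
        with cache_lock:
            # Re-read inside the lock so writes from sibling processes are seen.
            cache = json.loads(CACHE_PATH.read_text()) if CACHE_PATH.exists() else {}
            cache[lora_name] = {"downloaded": True}
            # Persist before releasing, so the next lock holder reads fresh state.
            CACHE_PATH.write_text(json.dumps(cache))

Reading the database outside the lock, or writing it back after releasing, would reintroduce exactly the out-of-sync cache the commit message describes.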
5 changes: 5 additions & 0 deletions horde_worker_regen/process_management/process_manager.py
@@ -732,6 +732,8 @@ def num_total_processes(self) -> int:
     """A semaphore that limits the number of inference processes that can run at once."""
     _disk_lock: Lock_MultiProcessing
 
+    _aux_model_lock: Lock_MultiProcessing
+
     _shutting_down = False
 
     _lru: LRUCache
Expand Down Expand Up @@ -776,6 +778,8 @@ def __init__(
self._max_concurrent_inference_processes = bridge_data.max_threads
self._inference_semaphore = Semaphore(self._max_concurrent_inference_processes, ctx=ctx)

self._aux_model_lock = Lock_MultiProcessing(ctx=ctx)

self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads
self._lru = LRUCache(self.max_inference_processes)

@@ -1031,6 +1035,7 @@ def _start_inference_process(self, pid: int) -> HordeProcessInfo:
                 child_pipe_connection,
                 self._inference_semaphore,
                 self._disk_lock,
+                self._aux_model_lock,
             ),
             kwargs={"high_memory_mode": self.bridge_data.high_memory_mode},
         )
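
Note that the manager now carries two different primitives: the existing counting Semaphore, which lets up to max_threads inference jobs run at once, and the new Lock, which makes auxiliary-model downloads strictly one-at-a-time. A short sketch of the distinction (illustrative, not the worker's code):

    import multiprocessing as mp

    def run_job(semaphore, lock, job_id: int) -> None:
        # Exactly one process at a time: downloads are serialized.
        with lock:
            print(f"job {job_id}: syncing aux models")
        # Up to 2 processes at a time: inference is throttled, not serialized.
        with semaphore:
            print(f"job {job_id}: running inference")

    if __name__ == "__main__":
        ctx = mp.get_context("spawn")
        semaphore = ctx.Semaphore(2)
        lock = ctx.Lock()
        jobs = [ctx.Process(target=run_job, args=(semaphore, lock, i)) for i in range(4)]
        for p in jobs:
            p.start()
        for p in jobs:
            p.join()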
2 changes: 2 additions & 0 deletions horde_worker_regen/process_management/worker_entry_points.py
@@ -18,6 +18,7 @@ def start_inference_process(
     pipe_connection: Connection,
     inference_semaphore: Semaphore,
     disk_lock: Lock,
+    aux_model_lock: Lock,
     *,
     high_memory_mode: bool = False,
 ) -> None:
@@ -70,6 +71,7 @@ def start_inference_process(
         pipe_connection=pipe_connection,
         inference_semaphore=inference_semaphore,
         disk_lock=disk_lock,
+        aux_model_lock=aux_model_lock,
     )
 
     worker_process.main_loop()
