From 79284b5f3bba4c4d7f3111f2895e1f3f5bf68762 Mon Sep 17 00:00:00 2001
From: Christopher Childs
Date: Sat, 11 Nov 2023 19:37:38 -0800
Subject: [PATCH] Replace processes when a model is unloaded

On Linux, it seems that a tremendous amount of memory is allocated by
something outside Python when one worker is allowed to process many
jobs on many different models. To limit the damage from that behavior,
we now replace a process when its model is scheduled to be unloaded.

I realize this probably makes it a _little_ harder to decouple a
process from a model, but it's a huge stability improvement.

This also changes the model management strategy slightly: a model is
now allocated to every open worker before any model is unloaded.
Previously, you would have N = `threads` + `queue` processes, but jobs
were very likely to be scheduled on only `threads` of those workers.
---
 .../process_management/process_manager.py | 222 ++++++++++++++----
 1 file changed, 182 insertions(+), 40 deletions(-)

diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index 5520fed5..96651db5 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -1,11 +1,13 @@
 import asyncio
 import base64
+import collections
+import datetime
 import multiprocessing
 import os
 import random
 import sys
 import time
-from asyncio import CancelledError
+from asyncio import CancelledError, Task
 from asyncio import Lock as Lock_Asyncio
 from collections import deque
 from collections.abc import Mapping
@@ -81,6 +83,8 @@ class HordeProcessInfo:
     """The type of this process."""
     last_process_state: HordeProcessState
     """The last known state of this process."""
+    last_timestamp: datetime.datetime
+    """The last time this process info was updated. While the process is working normally, this value should change frequently."""
 
     loaded_horde_model_name: str | None = None
     """The name of the horde model that is (supposedly) currently loaded in this process."""
@@ -115,6 +119,7 @@ def __init__(
         self.process_id = process_id
         self.process_type = process_type
         self.last_process_state = last_process_state
+        self.last_timestamp = datetime.datetime.now()
 
     def is_process_busy(self) -> bool:
         """Return true if the process is actively engaged in a task.
@@ -186,6 +191,15 @@ def update_entry(
         if process_id is not None:
             self.root[horde_model_name].process_id = process_id
 
+    def expire_entry(self, horde_model_name: str) -> ModelInfo | None:
+        """
+        Remove the tracked information for a horde model.
+
+        :param horde_model_name: the name of the model to remove
+        :return: the removed ModelInfo, or None if the model was not being tracked
+        """
+        return self.root.pop(horde_model_name, None)
+
     def is_model_loaded(self, horde_model_name: str) -> bool:
         """Return true if the given model is loaded in any process."""
         if horde_model_name not in self.root:
             return False
@@ -242,6 +256,8 @@ def update_entry(
         if total_vram_bytes is not None:
             self[process_id].total_vram_bytes = total_vram_bytes
 
+        self[process_id].last_timestamp = datetime.datetime.now()
+
     def num_inference_processes(self) -> int:
         """Return the number of inference processes."""
         count = 0
@@ -260,6 +276,14 @@ def num_available_inference_processes(self) -> int:
 
     def get_first_available_inference_process(self) -> HordeProcessInfo | None:
         """Return the first available inference process, or None if there are none available."""
+        for p in self.values():
+            if (
+                p.process_type == HordeProcessType.INFERENCE
+                and p.last_process_state == HordeProcessState.WAITING_FOR_JOB
+                and p.loaded_horde_model_name is None
+            ):
+                return p
+
         for p in self.values():
             if p.process_type == HordeProcessType.INFERENCE and p.can_accept_job():
                 if p.last_process_state == HordeProcessState.PRELOADED_MODEL:
@@ -336,7 +360,10 @@ def __repr__(self) -> str:
         base_string = "Processes: "
         for process_id, process_info in self.items():
             if process_info.process_type == HordeProcessType.INFERENCE:
-                base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) "
+                base_string += (
+                    f"{process_id}: ({process_info.loaded_horde_model_name} "
+                    f"[last event: {process_info.last_timestamp}]) "
+                )
             else:
                 base_string += f"{process_id}: ({process_info.process_type.name}) "
             base_string += f"{process_info.last_process_state.name}; "
@@ -380,6 +407,30 @@ def is_job_checked_for_safety(self) -> bool:
         return self.censored is not None
 
 
+class LRUCache:
+    def __init__(self, capacity: int) -> None:
+        """
+        Initialize the LRU cache.
+        :param capacity: the maximum number of keys tracked by the cache
+        """
+        self.capacity = capacity
+        self.cache: "collections.OrderedDict[str, ModelInfo | None]" = collections.OrderedDict()
+
+    def append(self, key: str) -> object:
+        """
+        Add a key to the LRU cache, evicting the least recently used key if the cache is full.
+        :param key: the key to add or refresh
+        :return: the evicted key, or None if nothing was evicted
+        """
+        bumped = None
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        elif len(self.cache) >= self.capacity:
+            bumped, _ = self.cache.popitem(last=False)
+        self.cache[key] = None
+        return bumped
+
+
 class HordeWorkerProcessManager:
     """Manages and controls processes to act as a horde worker."""
 
@@ -412,6 +463,9 @@ def max_concurrent_inference_processes(self) -> int:
 
     target_vram_overhead_bytes_map: Mapping[int, int] | None = None
 
+    process_timeout: datetime.timedelta
+    """The maximum amount of time a process may go without checking in with the main process manager before it is considered hung."""
+
     @property
     def max_queue_size(self) -> int:
         """The maximum number of jobs that can be queued."""
@@ -501,6 +555,8 @@ def num_total_processes(self) -> int:
 
     _shutting_down = False
 
+    _lru: LRUCache
+
     def __init__(
         self,
         *,
@@ -542,6 +598,8 @@ def __init__(
 
         self._inference_semaphore = Semaphore(self._max_concurrent_inference_processes, ctx=ctx)
         self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads
+        self._lru = LRUCache(self.max_inference_processes)
+        self.process_timeout = datetime.timedelta(minutes=5)
 
         # If there is only one model to load and only one inference process, then we can only run one job at a time
         # and there is no point in having more than one inference process
@@ -743,32 +801,41 @@ def start_inference_processes(self) -> None:
         for _ in range(num_processes_to_start):
             # Create a two-way communication pipe for the parent and child processes
             pid = len(self._process_map)
-            pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True)
-
-            # Create a new process that will run the start_inference_process function
-            process = multiprocessing.Process(
-                target=start_inference_process,
-                args=(
-                    pid,
-                    self._process_message_queue,
-                    child_pipe_connection,
-                    self._inference_semaphore,
-                    self._disk_lock,
-                ),
-            )
+            self._start_inference_process(pid)
 
-            process.start()
+            logger.info(f"Started inference process (id: {pid})")
 
-            # Add the process to the process map
-            self._process_map[pid] = HordeProcessInfo(
-                mp_process=process,
-                pipe_connection=pipe_connection,
-                process_id=pid,
-                process_type=HordeProcessType.INFERENCE,
-                last_process_state=HordeProcessState.PROCESS_STARTING,
-            )
+    def _start_inference_process(self, pid: int) -> HordeProcessInfo:
+        """
+        Start an inference process.
- logger.info(f"Started inference process (id: {pid})") + :param pid: process ID to assign to the process + :return: + """ + logger.info(f"Starting inference process on PID {pid}") + pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True) + # Create a new process that will run the start_inference_process function + process = multiprocessing.Process( + target=start_inference_process, + args=( + pid, + self._process_message_queue, + child_pipe_connection, + self._inference_semaphore, + self._disk_lock, + ), + ) + process.start() + # Add the process to the process map + process_info = HordeProcessInfo( + mp_process=process, + pipe_connection=pipe_connection, + process_id=pid, + process_type=HordeProcessType.INFERENCE, + last_process_state=HordeProcessState.PROCESS_STARTING, + ) + self._process_map[pid] = process_info + return process_info def end_inference_processes(self) -> None: """End any inference processes above the configured limit, or all of them if shutting down.""" @@ -778,19 +845,35 @@ def end_inference_processes(self) -> None: # Get the process to end process_info = self._process_map._get_first_inference_process_to_kill() - if process_info is None: - return + if process_info is not None: + self._end_inference_process(process_info) + + def _end_inference_process(self, process_info: HordeProcessInfo) -> None: + """ + Ends an inference process. + :param process_info: HordeProcessInfo for the process to end + :return: None + """ # Send the process a message to end process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS)) - # Update the process map self._process_map.update_entry(process_id=process_info.process_id) - logger.info(f"Ended inference process {process_info.process_id}") - # Join the process with a timeout of 1 second process_info.mp_process.join(timeout=1) + process_info.mp_process.kill() + + def _replace_inference_process(self, process_info: HordeProcessInfo) -> None: + """ + Replaces an inference process (for whatever reason). + + :param process_info: process to replace + :return: None + """ + logger.debug(f"Replacing {process_info}") + self._end_inference_process(process_info) + self._start_inference_process(process_info.process_id) total_num_completed_jobs: int = 0 @@ -1058,11 +1141,34 @@ def preload_models(self) -> bool: if model_is_loaded: continue - available_process = self._process_map.get_first_available_inference_process() + available_process = None + model_to_unload = self._lru.append(job.model) + if model_to_unload is not None: + for p in self._process_map.values(): + if p.loaded_horde_model_name == model_to_unload and ( + p.last_process_state == HordeProcessState.INFERENCE_COMPLETE + or p.last_process_state == HordeProcessState.WAITING_FOR_JOB + ): + available_process = p + if available_process is None: + available_process = self._process_map.get_first_available_inference_process() if available_process is None: return False + if ( + available_process.last_process_state != HordeProcessState.WAITING_FOR_JOB + and available_process.loaded_horde_model_name is not None + ): + # We're going to restart the process and then exit the loop, because + # available_process is very quickly _not_ going to be available. + # We also don't want to block waiting for the newly forked job to become + # available, so we'll wait for it to become ready before scheduling a model + # to be loaded on it. 
+                self._replace_inference_process(available_process)
+                self._horde_model_map.expire_entry(available_process.loaded_horde_model_name)
+                return False
+
             logger.debug(f"Preloading model {job.model} on process {available_process.process_id}")
             logger.debug(f"Available inference processes: {self._process_map}")
             logger.debug(f"Horde model map: {self._horde_model_map}")
@@ -1145,11 +1251,7 @@ def start_inference(self) -> None:
                 next_n_models = list(self.get_next_n_models(self.max_inference_processes))
 
                 # If the model would be used by another process soon, don't unload it
-                if (
-                    self.max_concurrent_inference_processes > 1
-                    and process_info.loaded_horde_model_name
-                    in next_n_models[: self.max_concurrent_inference_processes - 1]
-                ):
+                if process_info.loaded_horde_model_name in next_n_models:
                     continue
 
                 process_info.pipe_connection.send(
@@ -1770,6 +1872,8 @@ async def api_job_pop(self) -> None:
         async with self._job_deque_lock, self._job_pop_timestamps_lock:
             self.job_deque.append(job_pop_response)
+            jobs = [f"<{x.id_}: {x.model}>" for x in self.job_deque]
+            logger.info(f'Job queue: {", ".join(jobs)}')
             # self._testing_jobs_added += 1
             self.job_pop_timestamps[str(job_pop_response.id_)] = time.time()
@@ -1885,6 +1989,7 @@ async def _process_control_loop(self) -> None:
         self.start_inference_processes()
 
         while True:
+            logger.debug("_process_control_loop looped")
             try:
                 if self.stable_diffusion_reference is None:
                     return
@@ -1908,6 +2013,8 @@ async def _process_control_loop(self) -> None:
                 async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
                     self.receive_and_handle_process_messages()
 
+                self.replace_hung_processes()
+
                 self.unload_models()
 
                 if self._shutting_down:
@@ -1921,6 +2028,8 @@ async def _process_control_loop(self) -> None:
                     logger.info(f"{self._process_map}")
                    logger.info(f"Threads being used: {self._max_concurrent_inference_processes}")
                     logger.info(f"Number of jobs popped: {len(self.job_deque)}")
+                    jobs = [f"<{x.id_}: {x.model}>" for x in self.job_deque]
+                    logger.info(f'Job queue: {", ".join(jobs)}')
                     logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}")
                     logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}")
                     logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}")
@@ -2002,13 +2111,34 @@ async def _bridge_data_loop(self) -> None:
         except CancelledError:
             self._shutting_down = True
 
+    @staticmethod
+    def _handle_exception(task: Task) -> None:
+        """
+        Log any exception raised by a completed asyncio task.
+
+        :param task: the completed asyncio task to inspect
+        :return: None
+        """
+        # Done-callbacks are invoked synchronously with the finished task, so this
+        # must not be a coroutine; inspect the task's result instead of awaiting it.
+        if task.cancelled():
+            return
+        if task.exception() is not None:
+            logger.error(f"Caught exception in task {task.get_name()}: {task.exception()}")
+
     async def _main_loop(self) -> None:
         # Run both loops concurrently
+        process_control_loop = asyncio.create_task(self._process_control_loop(), name="process_control_loop")
+        api_call_loop = asyncio.create_task(self._api_call_loop(), name="api_call_loop")
+        job_submit_loop = asyncio.create_task(self._job_submit_loop(), name="job_submit_loop")
+        bridge_data_loop = asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop")
+        process_control_loop.add_done_callback(self._handle_exception)
+        api_call_loop.add_done_callback(self._handle_exception)
+        job_submit_loop.add_done_callback(self._handle_exception)
+        bridge_data_loop.add_done_callback(self._handle_exception)
         await asyncio.gather(
-            asyncio.create_task(self._process_control_loop(), name="process_control_loop"),
-            asyncio.create_task(self._api_call_loop(), name="api_call_loop"),
-            asyncio.create_task(self._job_submit_loop(), name="job_submit_loop"),
-            asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop"),
+            process_control_loop,
+            api_call_loop,
+            job_submit_loop,
+            bridge_data_loop,
         )
 
     _caught_sigints = 0
@@ -2046,3 +2176,15 @@ def shutdown() -> None:
             sys.exit(0)
 
         threading.Thread(target=shutdown).start()
+
+    def replace_hung_processes(self) -> None:
+        """
+        Replace any process that has not checked in within `self.process_timeout`.
+
+        :return: None
+        """
+        now = datetime.datetime.now()
+        for process_info in self._process_map.values():
+            if (now - process_info.last_timestamp) > self.process_timeout:
+                logger.error(f"{process_info} has exceeded its timeout and will be replaced")
+                self._replace_inference_process(process_info)