diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index 5520fed5..3dd3bb32 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -1,5 +1,7 @@
 import asyncio
 import base64
+import collections
+import datetime
 import multiprocessing
 import os
 import random
@@ -81,6 +83,8 @@ class HordeProcessInfo:
     """The type of this process."""
     last_process_state: HordeProcessState
     """The last known state of this process."""
+    last_timestamp: datetime.datetime
+    """The last time this process's info was updated. If the process is working normally, this should change frequently."""
     loaded_horde_model_name: str | None = None
     """The name of the horde model that is (supposedly) currently loaded in this process."""
@@ -115,6 +119,7 @@ def __init__(
         self.process_id = process_id
         self.process_type = process_type
         self.last_process_state = last_process_state
+        self.last_timestamp = datetime.datetime.now()
 
     def is_process_busy(self) -> bool:
         """Return true if the process is actively engaged in a task.
@@ -186,6 +191,9 @@ def update_entry(
         if process_id is not None:
             self.root[horde_model_name].process_id = process_id
 
+    def expire_entry(self, horde_model_name: str) -> None:
+        self.root.pop(horde_model_name, None)
+
     def is_model_loaded(self, horde_model_name: str) -> bool:
         """Return true if the given model is loaded in any process."""
         if horde_model_name not in self.root:
@@ -242,6 +250,8 @@ def update_entry(
         if total_vram_bytes is not None:
             self[process_id].total_vram_bytes = total_vram_bytes
 
+        self[process_id].last_timestamp = datetime.datetime.now()
+
     def num_inference_processes(self) -> int:
         """Return the number of inference processes."""
         count = 0
@@ -260,6 +270,12 @@ def num_available_inference_processes(self) -> int:
 
     def get_first_available_inference_process(self) -> HordeProcessInfo | None:
         """Return the first available inference process, or None if there are none available."""
+        for p in self.values():
+            if (p.process_type == HordeProcessType.INFERENCE
+                    and p.last_process_state == HordeProcessState.WAITING_FOR_JOB
+                    and p.loaded_horde_model_name is None):
+                return p
+
         for p in self.values():
             if p.process_type == HordeProcessType.INFERENCE and p.can_accept_job():
                 if p.last_process_state == HordeProcessState.PRELOADED_MODEL:
@@ -336,7 +352,8 @@ def __repr__(self) -> str:
         base_string = "Processes: "
         for process_id, process_info in self.items():
             if process_info.process_type == HordeProcessType.INFERENCE:
-                base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) "
+                base_string += (f"{process_id}: ({process_info.loaded_horde_model_name} "
+                                f"[last event: {process_info.last_timestamp}]) ")
             else:
                 base_string += f"{process_id}: ({process_info.process_type.name}) "
             base_string += f"{process_info.last_process_state.name}; "
@@ -380,6 +397,21 @@ def is_job_checked_for_safety(self) -> bool:
         return self.censored is not None
 
+
+class LRUCache:
+    def __init__(self, capacity: int) -> None:
+        self.capacity = capacity
+        self.cache: collections.OrderedDict[str, None] = collections.OrderedDict()
+
+    def append(self, key: str) -> str | None:
+        bumped = None
+        if key in self.cache:
+            self.cache.move_to_end(key)
+        elif len(self.cache) >= self.capacity:
+            bumped, _ = self.cache.popitem(last=False)
+        self.cache[key] = None
+        return bumped
+
 
 class HordeWorkerProcessManager:
     """Manages and controls processes to act as a horde worker."""
 
@@ -412,6 +444,9 @@ def max_concurrent_inference_processes(self) -> int:
 
     target_vram_overhead_bytes_map: Mapping[int, int] | None = None
+    process_timeout: datetime.timedelta
+    """The maximum time a process may go without checking in before it is considered hung and gets replaced."""
+
     @property
     def max_queue_size(self) -> int:
         """The maximum number of jobs that can be queued."""
@@ -501,6 +536,8 @@ def num_total_processes(self) -> int:
 
     _shutting_down = False
 
+    _lru: LRUCache
+
     def __init__(
         self,
         *,
@@ -542,6 +579,8 @@ def __init__(
         self._inference_semaphore = Semaphore(self._max_concurrent_inference_processes, ctx=ctx)
 
         self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads
+        self._lru = LRUCache(self.max_inference_processes)
+        self.process_timeout = datetime.timedelta(minutes=5)
 
         # If there is only one model to load and only one inference process, then we can only run one job at a time
         # and there is no point in having more than one inference process
@@ -743,33 +782,34 @@ def start_inference_processes(self) -> None:
         for _ in range(num_processes_to_start):
             # Create a two-way communication pipe for the parent and child processes
             pid = len(self._process_map)
-            pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True)
-
-            # Create a new process that will run the start_inference_process function
-            process = multiprocessing.Process(
-                target=start_inference_process,
-                args=(
-                    pid,
-                    self._process_message_queue,
-                    child_pipe_connection,
-                    self._inference_semaphore,
-                    self._disk_lock,
-                ),
-            )
-
-            process.start()
-
-            # Add the process to the process map
-            self._process_map[pid] = HordeProcessInfo(
-                mp_process=process,
-                pipe_connection=pipe_connection,
-                process_id=pid,
-                process_type=HordeProcessType.INFERENCE,
-                last_process_state=HordeProcessState.PROCESS_STARTING,
-            )
+            self._start_inference_process(pid)
 
             logger.info(f"Started inference process (id: {pid})")
 
+    def _start_inference_process(self, pid: int) -> None:
+        logger.info(f"Starting inference process (id: {pid})")
+        pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True)
+        # Create a new process that will run the start_inference_process function
+        process = multiprocessing.Process(
+            target=start_inference_process,
+            args=(
+                pid,
+                self._process_message_queue,
+                child_pipe_connection,
+                self._inference_semaphore,
+                self._disk_lock,
+            ),
+        )
+        process.start()
+        # Add the process to the process map
+        self._process_map[pid] = HordeProcessInfo(
+            mp_process=process,
+            pipe_connection=pipe_connection,
+            process_id=pid,
+            process_type=HordeProcessType.INFERENCE,
+            last_process_state=HordeProcessState.PROCESS_STARTING,
+        )
+
     def end_inference_processes(self) -> None:
         """End any inference processes above the configured limit, or all of them if shutting down."""
         if len(self.job_deque) > 0 and len(self.job_deque) != len(self.jobs_in_progress):
@@ -778,19 +818,23 @@ def end_inference_processes(self) -> None:
 
         # Get the process to end
         process_info = self._process_map._get_first_inference_process_to_kill()
 
-        if process_info is None:
-            return
+        if process_info is not None:
+            self._end_inference_process(process_info)
 
+    def _end_inference_process(self, process_info: HordeProcessInfo) -> None:
         # Send the process a message to end
         process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS))
-
         # Update the process map
         self._process_map.update_entry(process_id=process_info.process_id)
-
         logger.info(f"Ended inference process {process_info.process_id}")
-
         # Join the process with a timeout of 1 second
         process_info.mp_process.join(timeout=1)
+        process_info.mp_process.kill()
+
+    def _replace_inference_process(self, process_info: HordeProcessInfo) -> None:
+        logger.debug(f"Replacing {process_info}")
+        self._end_inference_process(process_info)
+        self._start_inference_process(process_info.process_id)
 
     total_num_completed_jobs: int = 0
 
@@ -1058,11 +1102,31 @@ def preload_models(self) -> bool:
             if model_is_loaded:
                 continue
 
-            available_process = self._process_map.get_first_available_inference_process()
+            available_process = None
+            model_to_unload = self._lru.append(job.model)
+            if model_to_unload is not None:
+                for p in self._process_map.values():
+                    if (p.loaded_horde_model_name == model_to_unload
+                            and (p.last_process_state == HordeProcessState.INFERENCE_COMPLETE
+                                 or p.last_process_state == HordeProcessState.WAITING_FOR_JOB)):
+                        available_process = p
+            if available_process is None:
+                available_process = self._process_map.get_first_available_inference_process()
 
             if available_process is None:
                 return False
 
+            if (available_process.last_process_state != HordeProcessState.WAITING_FOR_JOB
+                    and available_process.loaded_horde_model_name is not None):
+                # We're going to restart this process and then bail out, because
+                # available_process is about to stop being available.
+                # We also don't want to block waiting for the newly forked process to
+                # come up, so return False now and let a later pass schedule a model
+                # to be loaded on it once it reports that it is ready.
+                self._replace_inference_process(available_process)
+                self._horde_model_map.expire_entry(available_process.loaded_horde_model_name)
+                return False
+
             logger.debug(f"Preloading model {job.model} on process {available_process.process_id}")
             logger.debug(f"Available inference processes: {self._process_map}")
             logger.debug(f"Horde model map: {self._horde_model_map}")
@@ -1145,11 +1209,7 @@ def start_inference(self) -> None:
             next_n_models = list(self.get_next_n_models(self.max_inference_processes))
 
             # If the model would be used by another process soon, don't unload it
-            if (
-                self.max_concurrent_inference_processes > 1
-                and process_info.loaded_horde_model_name
-                in next_n_models[: self.max_concurrent_inference_processes - 1]
-            ):
+            if process_info.loaded_horde_model_name in next_n_models:
                 continue
 
             process_info.pipe_connection.send(
@@ -1770,6 +1830,8 @@ async def api_job_pop(self) -> None:
 
         async with self._job_deque_lock, self._job_pop_timestamps_lock:
             self.job_deque.append(job_pop_response)
+            jobs = [f"<{j.id_}: {j.model}>" for j in self.job_deque]
+            logger.info(f"Job queue: {', '.join(jobs)}")
             # self._testing_jobs_added += 1
             self.job_pop_timestamps[str(job_pop_response.id_)] = time.time()
@@ -1885,6 +1947,7 @@ async def _process_control_loop(self) -> None:
         self.start_inference_processes()
 
         while True:
+            logger.debug("_process_control_loop looped")
             try:
                 if self.stable_diffusion_reference is None:
                     return
@@ -1908,6 +1971,8 @@ async def _process_control_loop(self) -> None:
                 async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
                     self.receive_and_handle_process_messages()
 
+                self.replace_hung_processes()
+
                 self.unload_models()
 
                 if self._shutting_down:
@@ -1921,6 +1986,8 @@ async def _process_control_loop(self) -> None:
                     logger.info(f"{self._process_map}")
                     logger.info(f"Threads being used: {self._max_concurrent_inference_processes}")
                    logger.info(f"Number of jobs popped: {len(self.job_deque)}")
+                    jobs = [f"<{j.id_}: {j.model}>" for j in self.job_deque]
+                    logger.info(f"Job queue: {', '.join(jobs)}")
                     logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}")
                     logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}")
                     logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}")
@@ -2002,13 +2069,28 @@ async def _bridge_data_loop(self) -> None:
             except CancelledError:
                 self._shutting_down = True
 
+    @staticmethod
+    def _handle_exception(task: asyncio.Task) -> None:
+        # Done callback: log any exception raised by a task that has finished.
+        exception = task.exception() if not task.cancelled() else None
+        if exception is not None:
+            logger.error(f"Caught exception in task {task.get_name()}: {exception}")
+
     async def _main_loop(self) -> None:
         # Run both loops concurrently
+        process_control_loop = asyncio.create_task(self._process_control_loop(), name="process_control_loop")
+        api_call_loop = asyncio.create_task(self._api_call_loop(), name="api_call_loop")
+        job_submit_loop = asyncio.create_task(self._job_submit_loop(), name="job_submit_loop")
+        bridge_data_loop = asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop")
+        process_control_loop.add_done_callback(self._handle_exception)
+        api_call_loop.add_done_callback(self._handle_exception)
+        job_submit_loop.add_done_callback(self._handle_exception)
+        bridge_data_loop.add_done_callback(self._handle_exception)
         await asyncio.gather(
-            asyncio.create_task(self._process_control_loop(), name="process_control_loop"),
-            asyncio.create_task(self._api_call_loop(), name="api_call_loop"),
-            asyncio.create_task(self._job_submit_loop(), name="job_submit_loop"),
-            asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop"),
+            process_control_loop,
+            api_call_loop,
+            job_submit_loop,
+            bridge_data_loop,
         )
 
     _caught_sigints = 0
@@ -2046,3 +2128,10 @@ def shutdown() -> None:
             sys.exit(0)
 
         threading.Thread(target=shutdown).start()
+
+    def replace_hung_processes(self) -> None:
+        now = datetime.datetime.now()
+        for process_info in self._process_map.values():
+            if (now - process_info.last_timestamp) > self.process_timeout:
+                logger.error(f"{process_info} has exceeded its timeout and will be replaced")
+                self._replace_inference_process(process_info)
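# --- Illustrative sketch, not part of the patch ---
# A minimal demonstration of how the LRUCache added above is expected to behave, since
# preload_models relies on its return value: append() marks a key as most recently used
# and returns the evicted key once capacity is exceeded, or None otherwise. The model
# names below are made-up examples.
cache = LRUCache(capacity=2)
assert cache.append("model_a") is None       # cache (oldest -> newest): model_a
assert cache.append("model_b") is None       # model_a, model_b
assert cache.append("model_a") is None       # re-appending bumps it: model_b, model_a
assert cache.append("model_c") == "model_b"  # over capacity: least recently used key is evicted
assert cache.append("model_b") == "model_a"  # the caller is told which model to unload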