From b8ef014a856cec32198c5979b7e2d86d7b11e5bc Mon Sep 17 00:00:00 2001 From: tazlin Date: Fri, 6 Oct 2023 12:38:07 -0400 Subject: [PATCH 1/2] fix: be less likely to fail when PIL data is corrupt or wrong --- .../process_management/inference_process.py | 22 ++++++++- .../process_management/messages.py | 2 + .../process_management/process_manager.py | 49 +++++++++++++++---- .../process_management/safety_process.py | 15 +++++- 4 files changed, 76 insertions(+), 12 deletions(-) diff --git a/horde_worker_regen/process_management/inference_process.py b/horde_worker_regen/process_management/inference_process.py index 4616d346..9778d45f 100644 --- a/horde_worker_regen/process_management/inference_process.py +++ b/horde_worker_regen/process_management/inference_process.py @@ -155,6 +155,7 @@ def on_horde_model_state_change( process_state=process_state, info=f"Model {horde_model_name} {horde_model_state.name}", ) + self.send_memory_report_message(include_vram=True) def download_callback(self, downloaded_bytes: int, total_bytes: int) -> None: if downloaded_bytes % (total_bytes / 20) == 0: @@ -246,7 +247,12 @@ def preload_model( def start_inference(self, job_info: ImageGenerateJobPopResponse) -> list[Image] | None: with self._inference_semaphore: self._is_busy = True - results = self._horde.basic_inference(job_info) + try: + results = self._horde.basic_inference(job_info) + except Exception as e: + logger.critical(f"Inference failed: {type(e).__name__} {e}") + return None + self._is_busy = False return results @@ -369,6 +375,20 @@ def _receive_and_handle_control_message(self, message: HordeControlMessage) -> N time_start = time.time() images = self.start_inference(message.job_info) + + if images is None: + self.send_memory_report_message(include_vram=True) + self.send_process_state_change_message( + process_state=HordeProcessState.INFERENCE_FAILED, + info=f"Inference failed for job {message.job_info.id_}", + ) + + logger.debug("Unloading models from RAM") + self.unload_models_from_ram() + logger.debug("Unloaded models from RAM") + self.send_memory_report_message(include_vram=True) + return + process_state = HordeProcessState.INFERENCE_COMPLETE if images else HordeProcessState.INFERENCE_FAILED logger.debug(f"Finished inference with process state {process_state}") self.send_inference_result_message( diff --git a/horde_worker_regen/process_management/messages.py b/horde_worker_regen/process_management/messages.py index a1e0eb14..b9171ff4 100644 --- a/horde_worker_regen/process_management/messages.py +++ b/horde_worker_regen/process_management/messages.py @@ -61,6 +61,7 @@ class HordeProcessState(enum.Enum): ALCHEMY_FAILED = auto() EVALUATING_SAFETY = auto() + SAFETY_FAILED = auto() class HordeProcessMessage(BaseModel): @@ -114,6 +115,7 @@ class HordeSafetyEvaluation(BaseModel): is_nsfw: bool is_csam: bool replacement_image_base64: str | None + failed: bool = False class HordeSafetyResultMessage(HordeProcessMessage): diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 9174c508..bbb6b8da 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -848,6 +848,13 @@ def receive_and_handle_process_messages(self) -> None: for i in range(len(completed_job_info.job_result_images_base64)): replacement_image = message.safety_evaluations[i].replacement_image_base64 + + if message.safety_evaluations[i].failed: + logger.error( + f"Job {message.job_id} image 
#{i} wasn't safety checked ", + ) + continue + if replacement_image is not None: completed_job_info.job_result_images_base64[i] = replacement_image num_images_censored += 1 @@ -1173,18 +1180,22 @@ def start_evaluate_safety(self) -> None: ), ) - def base64_image_to_stream_buffer(self, image_base64: str) -> BytesIO: + def base64_image_to_stream_buffer(self, image_base64: str) -> BytesIO | None: """Convert a base64 image to a BytesIO stream buffer.""" - image_as_pil = PIL.Image.open(BytesIO(base64.b64decode(image_base64))) - image_buffer = BytesIO() - image_as_pil.save( - image_buffer, - format="WebP", - quality=95, # FIXME # TODO - method=6, - ) + try: + image_as_pil = PIL.Image.open(BytesIO(base64.b64decode(image_base64))) + image_buffer = BytesIO() + image_as_pil.save( + image_buffer, + format="WebP", + quality=95, # FIXME # TODO + method=6, + ) - return image_buffer + return image_buffer + except Exception as e: + logger.error(f"Failed to convert base64 image to stream buffer: {e}") + return None _consecutive_failed_results = 0 _consecutive_failed_results_max = 10 @@ -1229,6 +1240,24 @@ async def api_submit_job( image_in_buffer = self.base64_image_to_stream_buffer(completed_job_info.job_result_images_base64[0]) + if image_in_buffer is None: + logger.critical( + f"There is an invalid image in the job results for {job_info.id_}, removing from completed jobs", + ) + async with self._completed_jobs_lock: + self.completed_jobs.remove(completed_job_info) + self._consecutive_failed_results = 0 + + for follow_up_request in completed_job_info.job_info.get_follow_up_failure_cleanup_request(): + follow_up_response = self.horde_client_session.submit_request( + follow_up_request, + JobSubmitResponse, + ) + + if isinstance(follow_up_response, RequestErrorResponse): + logger.error(f"Failed to submit followup request: {follow_up_response}") + return + # TODO: This would be better (?) 
if we could use aiohttp instead of requests
+            # except for the fact that it causes S3 to return a 403 Forbidden error
+
diff --git a/horde_worker_regen/process_management/safety_process.py b/horde_worker_regen/process_management/safety_process.py
index 9dcbbdf6..0f103bc8 100644
--- a/horde_worker_regen/process_management/safety_process.py
+++ b/horde_worker_regen/process_management/safety_process.py
@@ -126,7 +126,20 @@ def _receive_and_handle_control_message(self, message: HordeControlMessage) -> N
         for image_base64 in message.images_base64:
             # Decode the image from base64
             image_bytes = BytesIO(base64.b64decode(image_base64))
-            image_as_pil = PIL.Image.open(image_bytes)
+            try:
+                image_as_pil = PIL.Image.open(image_bytes)
+            except Exception as e:
+                logger.error(f"Failed to open image: {type(e).__name__} {e}")
+                safety_evaluations.append(
+                    HordeSafetyEvaluation(
+                        is_nsfw=True,
+                        is_csam=True,
+                        replacement_image_base64=None,
+                        failed=True,
+                    ),
+                )
+
+                continue
 
             nsfw_result: NSFWResult | None = self._nsfw_checker.check_for_nsfw(
                 image=image_as_pil,

From e4e6fc3dd47a0e4db1fa9734a4f768c3da2c45a9 Mon Sep 17 00:00:00 2001
From: tazlin
Date: Mon, 9 Oct 2023 12:16:37 -0400
Subject: [PATCH 2/2] refactor: more error checking, better flow control; docs:
 more docstrings

- Improved the chances that messages line up with expectations in
  `HordeWorkerProcessManager`'s main loop
- Have processes report failures more gracefully and in more situations
- Resume popping jobs after a certain number of seconds pass when pausing
  for queued megapixelsteps
- Added many inline comments and missing docstrings
---
 .../process_management/horde_process.py       |  70 ++-
 .../process_management/inference_process.py   | 102 +++-
 .../process_management/messages.py            |  90 ++-
 .../process_management/process_manager.py     | 558 ++++++++++++------
 .../process_management/safety_process.py      |   4 +-
 run_worker.py                                 |   2 +-
 6 files changed, 607 insertions(+), 219 deletions(-)

diff --git a/horde_worker_regen/process_management/horde_process.py b/horde_worker_regen/process_management/horde_process.py
index e242f448..9d911eb1 100644
--- a/horde_worker_regen/process_management/horde_process.py
+++ b/horde_worker_regen/process_management/horde_process.py
@@ -58,20 +58,25 @@ class HordeProcess(abc.ABC):
     """The time to sleep between each loop iteration."""
 
     _end_process: bool = False
+    """Whether the process should end soon."""
 
     _memory_report_interval: float = 5.0
     """The time to wait between each memory report."""
 
     _last_sent_process_state: HordeProcessState = HordeProcessState.PROCESS_STARTING
+    """The last process state that was sent to the main process."""
 
     _vram_total_bytes: int = 0
+    """The total number of bytes of VRAM available on the GPU."""
 
     def get_vram_usage_bytes(self) -> int:
+        """Return the number of bytes of VRAM used by the GPU."""
        from hordelib.comfy_horde import get_torch_free_vram_mb, get_torch_total_vram_mb
 
         return get_torch_total_vram_mb() - get_torch_free_vram_mb()
 
     def get_vram_total_bytes(self) -> int:
+        """Return the total number of bytes of VRAM available on the GPU."""
        from hordelib.comfy_horde import get_torch_total_vram_mb
 
         return get_torch_total_vram_mb()
@@ -83,6 +88,16 @@ def __init__(
         pipe_connection: Connection,
         disk_lock: Lock,
     ) -> None:
+        """Initialise the process.
+
+        Args:
+            process_id (int): The ID of the process. This is not the same as the process PID.
+            process_message_queue (ProcessQueue): The queue the main process uses to receive messages from all worker \
+                processes.
+ pipe_connection (Connection): Receives `HordeControlMessage`s from the main process. + disk_lock (Lock): A lock used to prevent multiple processes from accessing disk at the same time. + """ + self.process_id = process_id self.process_message_queue = process_message_queue self.pipe_connection = pipe_connection @@ -109,6 +124,15 @@ def send_process_state_change_message( info: str, time_elapsed: float | None = None, ) -> None: + """Send a process state change message to the main process. + + Args: + process_state (HordeProcessState): The state of the process. + info (str): Information about the process. + time_elapsed (float | None, optional): The time elapsed during the last operation, if applicable. \ + Defaults to None. + + """ message = HordeProcessStateChangeMessage( process_state=process_state, process_id=self.process_id, @@ -122,7 +146,12 @@ def send_process_state_change_message( _last_heartbeat_time: float = 0.0 def send_heartbeat_message(self) -> None: - """Send a heartbeat message to the main process.""" + """Send a heartbeat message to the main process, indicating that the process is still alive + during an operation. + + Note that this will only send a heartbeat message if the last heartbeat was sent more than + `_heartbeat_limit_interval_seconds` ago. + """ if (time.time() - self._last_heartbeat_time) < self._heartbeat_limit_interval_seconds: return @@ -137,14 +166,18 @@ def send_heartbeat_message(self) -> None: self._last_heartbeat_time = time.time() @abstractmethod - def cleanup_and_exit(self) -> None: + def cleanup_for_exit(self) -> None: """Cleanup and exit the process.""" def send_memory_report_message( self, include_vram: bool = False, ) -> None: - """Send a memory report message to the main process.""" + """Send a memory report message to the main process. + + Args: + include_vram (bool, optional): Whether to include VRAM usage in the message. Defaults to False. + """ message = HordeProcessMemoryMessage( process_id=self.process_id, info="Memory report", @@ -160,7 +193,11 @@ def send_memory_report_message( @abstractmethod def _receive_and_handle_control_message(self, message: HordeControlMessage) -> None: - """Receive and handle a control message from the main process.""" + """Receive and handle a control message from the main process. + + Args: + message (HordeControlMessage): The message to handle. + """ def receive_and_handle_control_messages(self) -> None: """Get and handle any control messages pending from the main process.""" @@ -179,7 +216,8 @@ def receive_and_handle_control_messages(self) -> None: self._receive_and_handle_control_message(message) def worker_cycle(self) -> None: - """Called after messages have been received and handled. Override this to implement any additional logic.""" + """Called after messages have been received and handled. 
Override this to implement any process-specific \
+            logic."""
         return
 
     def main_loop(self) -> None:
         """
         signal.signal(signal.SIGINT, signal_handler)
 
         while not self._end_process:
-            try:
-                time.sleep(self._loop_interval)
-
-                self.receive_and_handle_control_messages()
-
-                self.worker_cycle()
-            except KeyboardInterrupt:
-                logger.info("Keyboard interrupt received")
+            time.sleep(self._loop_interval)
+            self.receive_and_handle_control_messages()
+            self.worker_cycle()
 
+        # We escaped the loop, so the process is ending
         self.send_process_state_change_message(
             process_state=HordeProcessState.PROCESS_ENDING,
             info="Process ending",
         )
 
-        self.cleanup_and_exit()
+        self.cleanup_for_exit()
 
         logger.info("Process ended")
         self.send_process_state_change_message(
             process_state=HordeProcessState.PROCESS_ENDED,
             info="Process ended",
         )
+
+        # We are exiting, so send a final memory report
+        self.send_memory_report_message(include_vram=True)
+
+        # Exit the process (we expect to be a child process)
         sys.exit(0)
 
 
@@ -215,6 +254,9 @@
 
 
 def signal_handler(sig: int, frame: object) -> None:
+    """Called when a signal is received. This will exit the process gracefully if the process has only received one \
+        signal, or exit immediately if the process has received two signals."""
+
     global _signals_caught
     if _signals_caught >= 1:
         logger.warning("Received second signal, exiting immediately")
diff --git a/horde_worker_regen/process_management/inference_process.py b/horde_worker_regen/process_management/inference_process.py
index 9778d45f..52da8a2b 100644
--- a/horde_worker_regen/process_management/inference_process.py
+++ b/horde_worker_regen/process_management/inference_process.py
@@ -2,11 +2,9 @@
 
 import base64
 import contextlib
-import enum
 import io
 import sys
 import time
-from enum import auto
 
 try:
     from multiprocessing.connection import PipeConnection as Connection  # type: ignore
@@ -42,7 +40,7 @@
     from hordelib.nodes.node_model_loader import HordeCheckpointLoader
     from hordelib.shared_model_manager import SharedModelManager
 else:
-    # Create a dummy class to prevent type errors
+    # Create a dummy class to prevent type errors at runtime
     class HordeCheckpointLoader:
         pass
 
@@ -53,11 +51,6 @@ class SharedModelManager:
         pass
 
 
-class HordeProcessKind(enum.Enum):
-    INFERENCE = auto()
-    SAFETY = auto()
-
-
 class HordeInferenceProcess(HordeProcess):
     _inference_semaphore: Semaphore
     """A semaphore used to limit the number of concurrent inference jobs."""
@@ -67,8 +60,11 @@ class HordeInferenceProcess(HordeProcess):
     _shared_model_manager: SharedModelManager
     """The SharedModelManager instance used by this process. It is not shared between processes (despite the name)."""
     _checkpoint_loader: HordeCheckpointLoader
+    """The HordeCheckpointLoader instance used by this process. This is how hordelib signals comfyui \
+        to load a model. It is not shared between processes."""
 
     _active_model_name: str | None = None
+    """The name of the currently active model. Note that other models may be loaded in RAM or VRAM."""
 
     def __init__(
         self,
@@ -78,6 +74,16 @@ def __init__(
         inference_semaphore: Semaphore,
         disk_lock: Lock,
     ) -> None:
+        """Initialise the HordeInferenceProcess.
+
+        Args:
+            process_id (int): The ID of the process. This is not the same as the process PID.
+            process_message_queue (ProcessQueue): The queue the main process uses to receive messages from all worker \
+                processes.
+            pipe_connection (Connection): Receives `HordeControlMessage`s from the main process.
+ inference_semaphore (Semaphore): A semaphore used to limit the number of concurrent inference jobs. + disk_lock (Lock): A lock used to prevent multiple processes from accessing disk at the same time. + """ super().__init__( process_id=process_id, process_message_queue=process_message_queue, @@ -85,8 +91,14 @@ def __init__( disk_lock=disk_lock, ) - from hordelib.horde import HordeLib - from hordelib.shared_model_manager import SharedModelManager + # We import these here to guard against potentially importing them in the main process + # which would create shared objects, potentially causing issues + try: + from hordelib.horde import HordeLib + from hordelib.shared_model_manager import SharedModelManager + except Exception as e: + logger.critical(f"Failed to import HordeLib or SharedModelManager: {type(e).__name__} {e}") + sys.exit(1) self._inference_semaphore = inference_semaphore @@ -129,6 +141,7 @@ def __init__( ) def _comfyui_callback(self, label: str, data: dict, _id: str) -> None: + # TODO self.send_heartbeat_message() def on_horde_model_state_change( @@ -138,7 +151,16 @@ def on_horde_model_state_change( horde_model_state: ModelLoadState, time_elapsed: float | None = None, ) -> None: - """Update the main process with the current process state and model state.""" + """Update the main process with the current process state and model state. + + Args: + horde_model_name (str): The name of the model. + process_state (HordeProcessState): The state of the process. + horde_model_state (ModelLoadState): The state of the model. + time_elapsed (float | None, optional): The time elapsed during the last operation, if applicable. \ + Defaults to None. + """ + self.send_memory_report_message(include_vram=True) model_update_message = HordeModelStateChangeMessage( @@ -158,6 +180,7 @@ def on_horde_model_state_change( self.send_memory_report_message(include_vram=True) def download_callback(self, downloaded_bytes: int, total_bytes: int) -> None: + # TODO if downloaded_bytes % (total_bytes / 20) == 0: self.send_process_state_change_message( process_state=HordeProcessState.DOWNLOADING_MODEL, @@ -165,6 +188,7 @@ def download_callback(self, downloaded_bytes: int, total_bytes: int) -> None: ) def download_model(self, horde_model_name: str) -> None: + # TODO self.send_process_state_change_message( process_state=HordeProcessState.DOWNLOADING_MODEL, info=f"Downloading model {horde_model_name}", @@ -196,6 +220,13 @@ def preload_model( will_load_loras: bool, seamless_tiling_enabled: bool, ) -> None: + """Preload a model into RAM. + + Args: + horde_model_name (str): The name of the model to preload. + will_load_loras (bool): Whether or not the model will be loaded into VRAM. + seamless_tiling_enabled (bool): Whether or not seamless tiling is enabled. + """ if self._active_model_name == horde_model_name: return @@ -245,18 +276,28 @@ def preload_model( _is_busy: bool = False def start_inference(self, job_info: ImageGenerateJobPopResponse) -> list[Image] | None: + """Start an inference job in the HordeLib instance. + + Args: + job_info (ImageGenerateJobPopResponse): The job to start inference on. + + Returns: + list[Image] | None: The generated images, or None if inference failed. 
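+
+        Illustrative usage (a sketch; `job_info` stands in for a job popped from the API):
+
+            images = self.start_inference(job_info)
+            if images is None:
+                # Inference failed; treat this as a faulted job rather than crashing the process.
+                ...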
+ """ with self._inference_semaphore: self._is_busy = True try: results = self._horde.basic_inference(job_info) except Exception as e: logger.critical(f"Inference failed: {type(e).__name__} {e}") + self._is_busy = False return None self._is_busy = False return results def unload_models_from_vram(self) -> None: + """Unload all models from VRAM.""" from hordelib.comfy_horde import unload_all_models_vram unload_all_models_vram() @@ -278,6 +319,7 @@ def unload_models_from_vram(self) -> None: ) def unload_models_from_ram(self) -> None: + """Unload all models from RAM.""" from hordelib.comfy_horde import unload_all_models_ram unload_all_models_ram() @@ -301,7 +343,8 @@ def unload_models_from_ram(self) -> None: logger.info("Unloaded all models from RAM") self._active_model_name = None - def cleanup_and_exit(self) -> None: + def cleanup_for_exit(self) -> None: + """Cleanup the process pending a shutdown.""" self.unload_models_from_ram() self.send_process_state_change_message( process_state=HordeProcessState.PROCESS_ENDED, @@ -315,6 +358,14 @@ def send_inference_result_message( images: list[Image] | None, time_elapsed: float, ) -> None: + """Send an inference result message to the main process. + + Args: + process_state (HordeProcessState): The state of the process. + job_info (ImageGenerateJobPopResponse): The job that was inferred. + images (list[Image] | None): The generated images, or None if inference failed. + time_elapsed (float): The time elapsed during the last operation. + """ images_as_base64 = [] if images is not None: @@ -346,6 +397,11 @@ def send_inference_result_message( @override def _receive_and_handle_control_message(self, message: HordeControlMessage) -> None: + """Receive and handle a control message from the main process. + + Args: + message (HordeControlMessage): The message to handle. + """ logger.debug(f"Received ({type(message)}): {message.control_flag}") if isinstance(message, HordePreloadInferenceModelMessage): @@ -378,15 +434,33 @@ def _receive_and_handle_control_message(self, message: HordeControlMessage) -> N if images is None: self.send_memory_report_message(include_vram=True) - self.send_process_state_change_message( + self.send_inference_result_message( process_state=HordeProcessState.INFERENCE_FAILED, - info=f"Inference failed for job {message.job_info.id_}", + job_info=message.job_info, + images=None, + time_elapsed=time.time() - time_start, ) + active_model_name = self._active_model_name logger.debug("Unloading models from RAM") self.unload_models_from_ram() logger.debug("Unloaded models from RAM") self.send_memory_report_message(include_vram=True) + + if active_model_name is None: + logger.critical("No active model name, cannot update model state") + + else: + self.preload_model( + active_model_name, + will_load_loras=True, + seamless_tiling_enabled=False, + ) + + self.send_process_state_change_message( + process_state=HordeProcessState.WAITING_FOR_JOB, + info="Waiting for job", + ) return process_state = HordeProcessState.INFERENCE_COMPLETE if images else HordeProcessState.INFERENCE_FAILED diff --git a/horde_worker_regen/process_management/messages.py b/horde_worker_regen/process_management/messages.py index b9171ff4..2b9d9d9e 100644 --- a/horde_worker_regen/process_management/messages.py +++ b/horde_worker_regen/process_management/messages.py @@ -13,6 +13,11 @@ class ModelLoadState(enum.Enum): + """The state of a model. 
+ + e.g., if a model is `IN_USE` or `LOADED_IN_VRAM` + """ + DOWNLOADING = auto() ON_DISK = auto() LOADING = auto() @@ -29,36 +34,60 @@ def is_loaded(self) -> bool: class ModelInfo(BaseModel): + """Information about a model loaded or used by a process.""" + horde_model_name: str horde_model_load_state: ModelLoadState process_id: int class HordeProcessState(enum.Enum): + """The state of a process. + + e.g., if a process is `INFERENCE_STARTING` or `WAITING_FOR_JOB` + """ + PROCESS_STARTING = auto() + """The process is starting.""" PROCESS_ENDING = auto() + """The process is ending.""" PROCESS_ENDED = auto() + """The process has ended.""" WAITING_FOR_JOB = auto() + """The process is waiting for a job.""" JOB_RECEIVED = auto() + """The process has received a job.""" DOWNLOADING_MODEL = auto() + """The process is downloading a model.""" DOWNLOAD_COMPLETE = auto() + """The process has finished downloading a model.""" PRELOADING_MODEL = auto() + """The process is preloading a model.""" PRELOADED_MODEL = auto() + """The process has finished preloading a model.""" UNLOADED_MODEL_FROM_VRAM = auto() + """The process has unloaded a model from VRAM.""" UNLOADED_MODEL_FROM_RAM = auto() + """The process has unloaded a model from RAM.""" INFERENCE_STARTING = auto() + """The process is starting inference.""" INFERENCE_COMPLETE = auto() + """The process has finished inference.""" INFERENCE_FAILED = auto() + """The process has failed inference.""" ALCHEMY_STARTING = auto() + """The process is starting performing alchemy jobs.""" ALCHEMY_COMPLETE = auto() + """The process has finished performing alchemy jobs.""" ALCHEMY_FAILED = auto() + """The process has failed performing alchemy jobs.""" EVALUATING_SAFETY = auto() SAFETY_FAILED = auto() @@ -68,33 +97,55 @@ class HordeProcessMessage(BaseModel): """Process messages are sent from the child processes to the main process.""" process_id: int + """The ID of the process that sent the message.""" info: str + """Information about the process.""" time_elapsed: float | None = None + """The time elapsed since the process started.""" class HordeProcessMemoryMessage(HordeProcessMessage): + """Memory messages that are sent from the child processes to the main process.""" + ram_usage_bytes: int + """The number of bytes of RAM used by the process.""" vram_usage_bytes: int | None = None + """The number of bytes of VRAM used by the GPU.""" vram_total_bytes: int | None = None + """The total number of bytes of VRAM available on the GPU.""" class HordeProcessHeartbeatMessage(HordeProcessMessage): - pass + """Heartbeat messages that are sent from the child processes to the main process.""" class HordeProcessStateChangeMessage(HordeProcessMessage): + """State change messages that are sent from the child processes to the main process.""" + process_state: HordeProcessState + """The state of the process.""" class HordeModelStateChangeMessage(HordeProcessStateChangeMessage): + """Model state change messages that are sent from the child processes to the main process. + + See also `ModelLoadState`. 
+ """ + horde_model_name: str + """The name of the model as defined in the horde model reference.""" horde_model_state: ModelLoadState + """The state of the model.""" class HordeDownloadProgressMessage(HordeModelStateChangeMessage): + """Download progress messages that are sent from the child processes to the main process.""" + total_downloaded_bytes: int + """The total number of bytes downloaded so far.""" total_bytes: int + """The total number of bytes that will be downloaded.""" @property def progress_percent(self) -> float: @@ -102,63 +153,98 @@ def progress_percent(self) -> float: class HordeDownloadCompleteMessage(HordeModelStateChangeMessage): - pass + """Download complete messages that are sent from the child processes to the main process.""" class HordeInferenceResultMessage(HordeProcessMessage): + """Inference result messages that are sent from the child processes to the main process.""" + job_result_images_base64: list[str] | None = None + """The base64 strings of the images generated by the job.""" state: GENERATION_STATE + """The state of the job to be sent to the API.""" job_info: ImageGenerateJobPopResponse + """The job as sent by the API.""" class HordeSafetyEvaluation(BaseModel): + """The result of a safety evaluation.""" + is_nsfw: bool + """If the image is NSFW.""" is_csam: bool + """If the image is CSAM.""" replacement_image_base64: str | None + """The base64 string of the replacement image if it was censored.""" failed: bool = False + """If the safety evaluation failed.""" class HordeSafetyResultMessage(HordeProcessMessage): + """Safety result messages that are sent from the child processes to the main process.""" + job_id: JobID + """The ID of the job that was evaluated.""" safety_evaluations: list[HordeSafetyEvaluation] + """A list of safety evaluations for each image in the job.""" class HordeControlFlag(enum.Enum): + """Control flags are sent from the main process to the child processes.""" + DOWNLOAD_MODEL = auto() + """Signal the child process to download a model.""" PRELOAD_MODEL = auto() + """Signal the child process to preload a model.""" START_INFERENCE = auto() + """Signal the child process to start inference.""" EVALUATE_SAFETY = auto() + """Signal the child process to evaluate safety of images from inference.""" UNLOAD_MODELS_FROM_VRAM = auto() + """Signal the child process to unload models from VRAM.""" UNLOAD_MODELS_FROM_RAM = auto() + """Signal the child process to unload models from RAM.""" END_PROCESS = auto() + """Signal the child process to end.""" class HordeControlMessage(BaseModel): """Control messages are sent from the main process to the child processes.""" control_flag: HordeControlFlag + """The control flag signaling the child process to perform an action.""" class HordeControlModelMessage(HordeControlMessage): horde_model_name: str + """The name of the model as defined in the horde model reference.""" class HordePreloadInferenceModelMessage(HordeControlModelMessage): will_load_loras: bool + """If the model will be patched with LoRa(s).""" seamless_tiling_enabled: bool + """If seamless tiling will be enabled.""" class HordeInferenceControlMessage(HordeControlModelMessage): job_info: ImageGenerateJobPopResponse + """The job as sent by the API.""" class HordeSafetyControlMessage(HordeControlMessage): job_id: JobID + """The ID of the job that was evaluated.""" prompt: str + """The prompt used to generate the images.""" censor_nsfw: bool + """If NSFW images should be censored.""" sfw_worker: bool + """If the worker is SFW.""" images_base64: 
list[str] + """The base64 strings of the images generated by the job.""" horde_model_info: dict + """The model info as defined in the horde model reference.""" @model_validator(mode="after") def validate_censor_flags_logical(self) -> HordeSafetyControlMessage: diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index bbb6b8da..56157ca7 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -42,7 +42,7 @@ from horde_worker_regen.bridge_data.load_config import BridgeDataLoader from horde_worker_regen.consts import BRIDGE_CONFIG_FILENAME from horde_worker_regen.process_management._aliased_types import ProcessQueue -from horde_worker_regen.process_management.inference_process import HordeProcessKind +from horde_worker_regen.process_management.horde_process import HordeProcessType from horde_worker_regen.process_management.messages import ( HordeControlFlag, HordeControlMessage, @@ -69,16 +69,27 @@ class HordeProcessInfo: + """Contains information about a horde child process.""" + mp_process: multiprocessing.Process + """The multiprocessing.Process object for this process.""" pipe_connection: Connection + """The connection through which messages can be sent to this process.""" process_id: int - process_type: HordeProcessKind + """The ID of this process. This is not an OS process ID.""" + process_type: HordeProcessType + """The type of this process.""" last_process_state: HordeProcessState + """The last known state of this process.""" loaded_horde_model_name: str | None = None + """The name of the horde model that is (supposedly) currently loaded in this process.""" ram_usage_bytes: int = 0 + """The amount of RAM used by this process.""" vram_usage_bytes: int = 0 + """The amount of VRAM used by this process.""" total_vram_bytes: int = 0 + """The total amount of VRAM available to this process.""" # TODO: VRAM usage @@ -87,9 +98,18 @@ def __init__( mp_process: multiprocessing.Process, pipe_connection: Connection, process_id: int, - process_type: HordeProcessKind, + process_type: HordeProcessType, last_process_state: HordeProcessState, ) -> None: + """Initializes a new HordeProcessInfo object. + + Args: + mp_process (multiprocessing.Process): The multiprocessing.Process object for this process. + pipe_connection (Connection): The connection through which messages can be sent to this process. + process_id (int): The ID of this process. This is not an OS process ID. + process_type (HordeProcessType): The type of this process. + last_process_state (HordeProcessState): The last known state of this process. + """ self.mp_process = mp_process self.pipe_connection = pipe_connection self.process_id = process_id @@ -97,6 +117,9 @@ def __init__( self.last_process_state = last_process_state def is_process_busy(self) -> bool: + """Return true if the process is actively engaged in a task. 
+ This does not include the process starting up or shutting down.""" + return ( self.last_process_state == HordeProcessState.INFERENCE_STARTING or self.last_process_state == HordeProcessState.ALCHEMY_STARTING @@ -114,6 +137,7 @@ def __repr__(self) -> str: ) def can_accept_job(self) -> bool: + """Return true if the process can accept a job.""" return ( self.last_process_state == HordeProcessState.WAITING_FOR_JOB or self.last_process_state == HordeProcessState.INFERENCE_COMPLETE @@ -129,6 +153,18 @@ def update_entry( load_state: ModelLoadState | None = None, process_id: int | None = None, ) -> None: + """Update the entry for the given model name. If the model does not exist, it will be created. + + Args: + horde_model_name (str): The (horde) name of the model to update. + load_state (ModelLoadState | None, optional): The load state of the model. Defaults to None. + process_id (int | None, optional): The process ID of the process that has this model loaded. \ + Defaults to None. + + Raises: + ValueError: If the process_id is None and the model does not exist. + ValueError: If the load_state is None and the model does not exist. + """ if horde_model_name not in self.root: if process_id is None: raise ValueError("process_id must be provided when adding a new model to the map") @@ -148,11 +184,13 @@ def update_entry( self.root[horde_model_name].process_id = process_id def is_model_loaded(self, horde_model_name: str) -> bool: + """Return true if the given model is loaded in any process.""" if horde_model_name not in self.root: return False return self.root[horde_model_name].horde_model_load_state.is_loaded() def is_model_loading(self, horde_model_name: str) -> bool: + """Return true if the given model is currently being loaded in any process.""" if horde_model_name not in self.root: return False return self.root[horde_model_name].horde_model_load_state == ModelLoadState.LOADING @@ -173,6 +211,19 @@ def update_entry( vram_usage_bytes: int | None = None, total_vram_bytes: int | None = None, ) -> None: + """Update the entry for the given process ID. If the process does not exist, it will be created. + + Args: + process_id (int): The ID of the process to update. + last_process_state (HordeProcessState | None, optional): The last process state of the process. \ + Defaults to None. + loaded_horde_model_name (str | None, optional): The name of the horde model that is (supposedly) \ + currently loaded in this process. Defaults to None. + ram_usage_bytes (int | None, optional): The amount of RAM used by this process. Defaults to None. + vram_usage_bytes (int | None, optional): The amount of VRAM used by this process. Defaults to None. + total_vram_bytes (int | None, optional): The total amount of VRAM available to this process. \ + Defaults to None. 
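+
+        Illustrative usage (a sketch; `process_map` stands in for an instance of this class, and the
+        process ID for any process already in the map):
+
+            process_map.update_entry(
+                process_id=0,
+                last_process_state=HordeProcessState.WAITING_FOR_JOB,
+            )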
+ """ if last_process_state is not None: self[process_id].last_process_state = last_process_state @@ -189,31 +240,36 @@ def update_entry( self[process_id].total_vram_bytes = total_vram_bytes def num_inference_processes(self) -> int: + """Return the number of inference processes.""" count = 0 for p in self.values(): - if p.process_type == HordeProcessKind.INFERENCE: + if p.process_type == HordeProcessType.INFERENCE: count += 1 return count def num_available_inference_processes(self) -> int: + """Return the number of inference processes that are available to accept jobs.""" count = 0 for p in self.values(): - if p.process_type == HordeProcessKind.INFERENCE and not p.is_process_busy(): + if p.process_type == HordeProcessType.INFERENCE and not p.is_process_busy(): count += 1 return count def get_first_available_inference_process(self) -> HordeProcessInfo | None: + """Return the first available inference process, or None if there are none available.""" for p in self.values(): - if p.process_type == HordeProcessKind.INFERENCE and p.can_accept_job(): + if p.process_type == HordeProcessType.INFERENCE and p.can_accept_job(): if p.last_process_state == HordeProcessState.PRELOADED_MODEL: continue return p return None - def get_first_inference_process_to_kill(self) -> HordeProcessInfo | None: + def _get_first_inference_process_to_kill(self) -> HordeProcessInfo | None: + """Return the first inference process eligible to be killed, or None if there are none. + Used during shutdown.""" for p in self.values(): - if p.process_type != HordeProcessKind.INFERENCE: + if p.process_type != HordeProcessType.INFERENCE: continue if ( @@ -232,31 +288,37 @@ def get_first_inference_process_to_kill(self) -> HordeProcessInfo | None: return None def get_safety_process(self) -> HordeProcessInfo | None: + """Return the safety process.""" for p in self.values(): - if p.process_type == HordeProcessKind.SAFETY: + if p.process_type == HordeProcessType.SAFETY: return p return None def num_safety_processes(self) -> int: + """Return the number of safety processes.""" count = 0 for p in self.values(): - if p.process_type == HordeProcessKind.SAFETY: + if p.process_type == HordeProcessType.SAFETY: count += 1 return count def get_first_available_safety_process(self) -> HordeProcessInfo | None: + """Return the first available safety process, or None if there are none available.""" for p in self.values(): - if p.process_type == HordeProcessKind.SAFETY and p.last_process_state == HordeProcessState.WAITING_FOR_JOB: + if p.process_type == HordeProcessType.SAFETY and p.last_process_state == HordeProcessState.WAITING_FOR_JOB: return p return None def get_process_by_horde_model_name(self, horde_model_name: str) -> HordeProcessInfo | None: + """Return the process that has the given horde model loaded, or None if there is none.""" for p in self.values(): if p.loaded_horde_model_name == horde_model_name: return p return None def num_busy_processes(self) -> int: + """Return the number of processes that are actively engaged in a task. 
This does not include processes which + are starting up or shutting down, or in a faulted state.""" count = 0 for p in self.values(): if p.is_process_busy(): @@ -266,7 +328,7 @@ def num_busy_processes(self) -> int: def __repr__(self) -> str: base_string = "Processes: " for process_id, process_info in self.items(): - if process_info.process_type == HordeProcessKind.INFERENCE: + if process_info.process_type == HordeProcessType.INFERENCE: base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) " else: base_string += f"{process_id}: ({process_info.process_type.name}) " @@ -281,16 +343,24 @@ class TorchDeviceInfo(BaseModel): total_memory: int -class TorchDeviceMap(RootModel[dict[int, TorchDeviceInfo]]): +class TorchDeviceMap(RootModel[dict[int, TorchDeviceInfo]]): # TODO pass class CompletedJobInfo(BaseModel): + """Contains information about a job that has been generated. It is used to track the state of the job + as it goes through the safety process and then when it is returned to the requesting user.""" + job_info: ImageGenerateJobPopResponse + """The API response which has all of the information about the job.""" job_result_images_base64: list[str] | None = None + """A list of base64 encoded images that are the result of the job.""" state: GENERATION_STATE + """The state of the job to send to the API.""" censored: bool | None = None + """Whether or not the job was censored. This is set by the safety process.""" time_to_generate: float + """The time it took to generate the job. This is set by the inference process.""" @property def is_job_checked_for_safety(self) -> bool: @@ -441,6 +511,8 @@ def __init__( self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads + # If there is only one model to load and only one inference process, then we can only run one job at a time + # and there is no point in having more than one inference process if len(self.bridge_data.image_models_to_load) == 1 and self.max_concurrent_inference_processes == 1: self.max_inference_processes = 1 @@ -454,7 +526,7 @@ def __init__( self._jobs_safety_check_lock = Lock_Asyncio() - self.target_vram_overhead_bytes_map = target_vram_overhead_bytes_map + self.target_vram_overhead_bytes_map = target_vram_overhead_bytes_map # TODO self.total_ram_bytes = psutil.virtual_memory().total @@ -492,8 +564,6 @@ def __init__( self._process_message_queue = multiprocessing.Queue() - # The parent process already downloaded and converted the model references - self.stable_diffusion_reference = None while self.stable_diffusion_reference is None: @@ -521,22 +591,23 @@ def is_time_for_shutdown(self) -> bool: ): return True + # If any job hasn't been submitted to the API yet, then we can't shut down if len(self.completed_jobs) > 0: return False + # If there are any jobs in progress, then we can't shut down if len(self.jobs_being_safety_checked) > 0 or len(self.jobs_pending_safety_check) > 0: return False - if len(self.jobs_in_progress) > 0: return False - if len(self.job_deque) > 0: return False any_process_alive = False for process_info in self._process_map.values(): - if process_info.process_type != HordeProcessKind.INFERENCE: + # The safety process gets shut down last and is part of cleanup + if process_info.process_type != HordeProcessType.INFERENCE: continue if ( @@ -546,12 +617,14 @@ def is_time_for_shutdown(self) -> bool: any_process_alive = True continue + # If there are any inference processes still alive, then we can't shut down return not any_process_alive def 
is_free_inference_process_available(self) -> bool: + """Return true if there is an inference process available which can accept a job.""" return self._process_map.num_available_inference_processes() > 0 - def get_expected_ram_usage(self, horde_model_name: str) -> int: + def get_expected_ram_usage(self, horde_model_name: str) -> int: # TODO: Use or rework this if self.stable_diffusion_reference is None: raise ValueError("stable_diffusion_reference is None") @@ -610,7 +683,7 @@ def start_safety_processes(self) -> None: mp_process=process, pipe_connection=pipe_connection, process_id=pid, - process_type=HordeProcessKind.SAFETY, + process_type=HordeProcessType.SAFETY, last_process_state=HordeProcessState.PROCESS_STARTING, ) @@ -655,7 +728,7 @@ def start_inference_processes(self) -> None: mp_process=process, pipe_connection=pipe_connection, process_id=pid, - process_type=HordeProcessKind.INFERENCE, + process_type=HordeProcessType.INFERENCE, last_process_state=HordeProcessState.PROCESS_STARTING, ) @@ -668,7 +741,7 @@ def end_inference_processes(self) -> None: return # Get the process to end - process_info = self._process_map.get_first_inference_process_to_kill() + process_info = self._process_map._get_first_inference_process_to_kill() if process_info is None: return @@ -703,7 +776,19 @@ def end_safety_processes(self) -> None: logger.info(f"Ended safety process {process_info.process_id}") def receive_and_handle_process_messages(self) -> None: - """Receive and handle any messages from the child processes.""" + """Receive and handle any messages from the child processes. This is the backbone of the \ + inter-process communication system and is the main way that the parent process knows what is going on \ + in the child processes. + + **Note** also that this is a synchronous function and any interaction with objects that are shared between \ + coroutines should be done with care. Critically, this function should be called with locks already \ + acquired on any shared objects. + + See also `._process_map` and `._horde_model_map`, which are updated by this function, and `HordeProcessState` \ + and `ModelLoadState` for the possible states that the processes and models can be in. 
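+
+        Illustrative calling pattern (a sketch; the locks shown are this class's own members):
+
+            async with self._completed_jobs_lock, self._jobs_safety_check_lock:
+                self.receive_and_handle_process_messages()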
+ """ + + # We want to completely flush the queue, to maximize the chances we get the most up to date information while not self._process_message_queue.empty(): message: HordeProcessMessage = self._process_message_queue.get() @@ -712,12 +797,26 @@ def receive_and_handle_process_messages(self) -> None: # f"{message.model_dump(exclude={'job_result_images_base64', 'replacement_image_base64'})}", ) + # These events happening are program-breaking conditions that (hopefully) should never happen in production + # and are mainly to make debugging easier when making changes to the code, but serve as a guard against + # truly catastrophic failures if not isinstance(message, HordeProcessMessage): raise ValueError(f"Received a message that is not a HordeProcessMessage: {message}") - if message.process_id not in self._process_map: raise ValueError(f"Received a message from an unknown process: {message}") + # If the process is updating us on its memory usage, update the process map for those values only + # and then continue to the next message + if isinstance(message, HordeProcessMemoryMessage): + self._process_map.update_entry( + process_id=message.process_id, + ram_usage_bytes=message.ram_usage_bytes, + vram_usage_bytes=message.vram_usage_bytes, + total_vram_bytes=message.vram_total_bytes, + ) + continue + + # If the process state has changed, update the process map if isinstance(message, HordeProcessStateChangeMessage): self._process_map.update_entry( process_id=message.process_id, @@ -726,20 +825,20 @@ def receive_and_handle_process_messages(self) -> None: logger.debug(f"Process {message.process_id} changed state to {message.process_state}") if message.process_state == HordeProcessState.INFERENCE_STARTING: - logger.info(f"Process {message.process_id} is starting inference on model {message.info}") + # logger.info(f"Process {message.process_id} is starting inference on model {message.info}") loaded_model_name = self._process_map[message.process_id].loaded_horde_model_name if loaded_model_name is None: raise ValueError( f"Process {message.process_id} has no model loaded, but is starting inference", ) - self._horde_model_map.update_entry( horde_model_name=loaded_model_name, load_state=ModelLoadState.IN_USE, process_id=message.process_id, ) + # If The model state has changed, update the model map if isinstance(message, HordeModelStateChangeMessage): self._horde_model_map.update_entry( horde_model_name=message.horde_model_name, @@ -754,6 +853,7 @@ def receive_and_handle_process_messages(self) -> None: loaded_horde_model_name=message.horde_model_name, ) + # If the model was just loaded, so update the process map and log a message with the time it took if ( message.horde_model_state == ModelLoadState.LOADED_IN_VRAM or message.horde_model_state == ModelLoadState.LOADED_IN_RAM @@ -783,22 +883,13 @@ def receive_and_handle_process_messages(self) -> None: # FIXME this message is wrong for download processes logger.info(f"Process {message.process_id} unloaded model {message.horde_model_name}") - if isinstance(message, HordeProcessMemoryMessage): - self._process_map.update_entry( - process_id=message.process_id, - ram_usage_bytes=message.ram_usage_bytes, - vram_usage_bytes=message.vram_usage_bytes, - total_vram_bytes=message.vram_total_bytes, - ) - + # If the process is sending us an inference job result: + # - if its a faulted job, log an error and add it to the list of completed jobs to be sent to the API + # - if its a completed job, add it to the list of jobs pending safety checks if isinstance(message, 
HordeInferenceResultMessage): - if message.job_result_images_base64 is None: - logger.error(f"Received an inference result message with a None job_result: {message}") - continue - _num_jobs_in_progress = len(self.jobs_in_progress) - # Remove the job from the jobs in progress by matching the job ID (.id_) + # Remove the job from the jobs in progress by matching the job ID (.id_) self.jobs_in_progress = [job for job in self.jobs_in_progress if job.id_ != message.job_info.id_] if len(self.jobs_in_progress) != _num_jobs_in_progress - 1: @@ -823,14 +914,33 @@ def receive_and_handle_process_messages(self) -> None: logger.info(f"Inference finished for job {message.job_info.id_}") logger.debug(f"Job didn't include time_elapsed: {message.job_info}") - self.jobs_pending_safety_check.append( - CompletedJobInfo( - job_info=message.job_info, - job_result_images_base64=message.job_result_images_base64, - state=message.state, - time_to_generate=message.time_elapsed if message.time_elapsed is not None else 0, - ), - ) + if message.state != GENERATION_STATE.faulted: + self.jobs_pending_safety_check.append( + CompletedJobInfo( + job_info=message.job_info, + job_result_images_base64=message.job_result_images_base64, + state=message.state, + time_to_generate=message.time_elapsed if message.time_elapsed is not None else 0, + ), + ) + else: + logger.error( + f"Job {message.job_info.id_} faulted on process {message.process_id}: {message.info}", + ) + + self.completed_jobs.append( + CompletedJobInfo( + job_info=message.job_info, + job_result_images_base64=None, + state=message.state, + time_to_generate=message.time_elapsed if message.time_elapsed is not None else 0, + ), + ) + + # If the process is sending us a safety job result: + # - if an unexpected error occurred, log an error a + # - if the job was censored, replace the images with the replacement images + # - add the job to the list of completed jobs to be sent to the API elif isinstance(message, HordeSafetyResultMessage): completed_job_info: CompletedJobInfo | None = None for i, job_being_safety_checked in enumerate(self.jobs_being_safety_checked): @@ -846,13 +956,17 @@ def receive_and_handle_process_messages(self) -> None: num_images_censored = 0 num_images_csam = 0 + any_safety_failed = False + for i in range(len(completed_job_info.job_result_images_base64)): replacement_image = message.safety_evaluations[i].replacement_image_base64 if message.safety_evaluations[i].failed: logger.error( - f"Job {message.job_id} image #{i} wasn't safety checked ", + f"Job {message.job_id} image #{i} faulted during safety checks. " + "Check the safety process logs for more information.", ) + any_safety_failed = True continue if replacement_image is not None: @@ -866,7 +980,9 @@ def receive_and_handle_process_messages(self) -> None: f"{message.time_elapsed:.2f} seconds to check safety", ) - if num_images_censored > 0: + if any_safety_failed: + completed_job_info.state = GENERATION_STATE.faulted + elif num_images_censored > 0: completed_job_info.censored = True if num_images_csam > 0: completed_job_info.state = GENERATION_STATE.csam @@ -877,12 +993,15 @@ def receive_and_handle_process_messages(self) -> None: self.completed_jobs.append(completed_job_info) - def preload_models(self) -> None: - """Preload models that are likely to be used soon.""" + def preload_models(self) -> bool: + """Preload models that are likely to be used soon. + + Returns: + True if a model was preloaded, False otherwise. 
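+
+        One possible calling pattern (an illustrative sketch): since at most one model is preloaded
+        per call, a caller may loop until nothing more can be dispatched:
+
+            while self.preload_models():
+                pass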
+ """ # Starting from the left of the deque, preload models that are not yet loaded up to the - # number of inference processes - # that are available + # number of inference processes that are available for job in self.job_deque: model_is_loaded = False @@ -907,7 +1026,7 @@ def preload_models(self) -> None: available_process = self._process_map.get_first_available_inference_process() if available_process is None: - return + return False logger.debug(f"Preloading model {job.model} on process {available_process.process_id}") logger.debug(f"Available inference processes: {self._process_map}") @@ -936,7 +1055,9 @@ def preload_models(self) -> None: loaded_horde_model_name=job.model, ) - break + return True + + return False def start_inference(self) -> None: """Start inference for the next job in the deque, if possible.""" @@ -1005,6 +1126,7 @@ def start_inference(self) -> None: ) logger.info(f"Starting inference for job {next_job.id_} on process {process_with_model.process_id}") + # region Log job info logger.info(f"Model: {next_job.model}") if next_job.source_image is not None: logger.info(f"Using {next_job.source_processing}") @@ -1036,6 +1158,7 @@ def start_inference(self) -> None: f"{next_job.payload.width}x{next_job.payload.height} for {next_job.payload.ddim_steps} steps " f"with sampler {next_job.payload.sampler_name}", ) + # endregion self.jobs_in_progress.append(next_job) process_with_model.pipe_connection.send( @@ -1047,7 +1170,11 @@ def start_inference(self) -> None: ) def unload_from_ram(self, process_id: int) -> None: - """Unload models from a process, either from VRAM or both VRAM and system RAM.""" + """Unload models from a process, either from VRAM or both VRAM and system RAM. + + Args: + process_id: The process to unload models from. + """ if process_id not in self._process_map: raise ValueError(f"process_id {process_id} is not in the process map") @@ -1079,7 +1206,14 @@ def unload_from_ram(self, process_id: int) -> None: ) def get_next_n_models(self, n: int) -> set[str]: - """Get the next n models that will be used in the job deque.""" + """Get the next n models that will be used in the job deque. + + Args: + n: The number of models to get. + + Returns: + A set of the next n models that will be used in the job deque. + """ next_n_models: set[str] = set() jobs_traversed = 0 @@ -1136,6 +1270,7 @@ def unload_models(self) -> None: self.unload_from_ram(process_info.process_id) def start_evaluate_safety(self) -> None: + """Start evaluating the safety of the next job pending a safety check, if any.""" if len(self.jobs_pending_safety_check) == 0: return @@ -1181,7 +1316,14 @@ def start_evaluate_safety(self) -> None: ) def base64_image_to_stream_buffer(self, image_base64: str) -> BytesIO | None: - """Convert a base64 image to a BytesIO stream buffer.""" + """Convert a base64 image to a BytesIO stream buffer. + + Args: + image_base64: The base64 image to convert. + + Returns: + A BytesIO stream buffer containing the image, or None if the conversion failed. 
+ """ try: image_as_pil = PIL.Image.open(BytesIO(base64.b64decode(image_base64))) image_buffer = BytesIO() @@ -1197,12 +1339,13 @@ def base64_image_to_stream_buffer(self, image_base64: str) -> BytesIO | None: logger.error(f"Failed to convert base64 image to stream buffer: {e}") return None - _consecutive_failed_results = 0 - _consecutive_failed_results_max = 10 + _consecutive_failed_job_submits = 0 + """The number of consecutive failed attempts to submit a job result to the API.""" + _max_consecutive_failed_job_submits = 10 + """The maximum number of consecutive failed attempts to submit a job result to the API.""" - async def api_submit_job( - self, - ) -> None: + async def api_submit_job(self) -> None: + """Submit a job result to the API, if any are completed (safety checked too) and ready to be submitted.""" if len(self.completed_jobs) == 0: return @@ -1211,70 +1354,74 @@ async def api_submit_job( submit_job_request_type = job_info.get_follow_up_default_request_type() - if completed_job_info.job_result_images_base64 is None: - raise ValueError("completed_job_info.job_result_images_base64 is None") - - if len(completed_job_info.job_result_images_base64) > 1: - raise NotImplementedError("Only single image jobs are supported right now") + if completed_job_info.job_result_images_base64 is not None: + if len(completed_job_info.job_result_images_base64) > 1: + raise NotImplementedError("Only single image jobs are supported right now") + if completed_job_info.censored is None: + raise ValueError("completed_job_info.censored is None") if job_info.id_ is None: raise ValueError("job_info.id_ is None") if job_info.payload.seed is None: raise ValueError("job_info.payload.seed is None") - if job_info.r2_upload is None: + if job_info.r2_upload is None: # TODO: r2_upload should be being set somewhere raise ValueError("job_info.r2_upload is None") - if completed_job_info.censored is None: - raise ValueError("completed_job_info.censored is None") - # TODO: n_iter support try: - if self._consecutive_failed_results >= self._consecutive_failed_results_max: + if self._consecutive_failed_job_submits >= self._max_consecutive_failed_job_submits: async with self._completed_jobs_lock: self.completed_jobs.remove(completed_job_info) - self._consecutive_failed_results = 0 + self._consecutive_failed_job_submits = 0 return - image_in_buffer = self.base64_image_to_stream_buffer(completed_job_info.job_result_images_base64[0]) - - if image_in_buffer is None: - logger.critical( - f"There is an invalid image in the job results for {job_info.id_}, removing from completed jobs", - ) - async with self._completed_jobs_lock: - self.completed_jobs.remove(completed_job_info) - self._consecutive_failed_results = 0 - - for follow_up_request in completed_job_info.job_info.get_follow_up_failure_cleanup_request(): - follow_up_response = self.horde_client_session.submit_request( - follow_up_request, - JobSubmitResponse, - ) + if completed_job_info.job_result_images_base64 is not None: + image_in_buffer = self.base64_image_to_stream_buffer(completed_job_info.job_result_images_base64[0]) - if isinstance(follow_up_response, RequestErrorResponse): - logger.error(f"Failed to submit followup request: {follow_up_response}") + if image_in_buffer is None: + logger.critical( + f"There is an invalid image in the job results for {job_info.id_}, " + "removing from completed jobs", + ) + async with self._completed_jobs_lock: + self.completed_jobs.remove(completed_job_info) + self._consecutive_failed_job_submits = 0 + + for follow_up_request in 
completed_job_info.job_info.get_follow_up_failure_cleanup_request(): + follow_up_response = self.horde_client_session.submit_request( + follow_up_request, + JobSubmitResponse, + ) + + if isinstance(follow_up_response, RequestErrorResponse): + logger.error(f"Failed to submit followup request: {follow_up_response}") + return + + # TODO: This would be better (?) if we could use aiohttp instead of requests + # except for the fact that it causes S3 to return a 403 Forbidden error + + # async with self._aiohttp_session.put( + # yarl.URL(job_info.r2_upload, encoded=True), + # data=image_in_buffer.getvalue(), + # ) as response: + # if response.status != 200: + # logger.error(f"Failed to upload image to R2: {response}") + # return + + response = requests.put(job_info.r2_upload, data=image_in_buffer.getvalue()) + + if response.status_code != 200: + logger.error(f"Failed to upload image to R2: {response}") + self._consecutive_failed_job_submits += 1 return - # TODO: This would be better (?) if we could use aiohttp instead of requests - # except for the fact that it causes S3 to return a 403 Forbidden error - - # async with self._aiohttp_session.put( - # yarl.URL(job_info.r2_upload, encoded=True), - # data=image_in_buffer.getvalue(), - # ) as response: - # if response.status != 200: - # logger.error(f"Failed to upload image to R2: {response}") - # return - - response = requests.put(job_info.r2_upload, data=image_in_buffer.getvalue()) - - if response.status_code != 200: - logger.error(f"Failed to upload image to R2: {response}") - self._consecutive_failed_results += 1 - return + if completed_job_info.state == GENERATION_STATE.faulted: + logger.error( + f"Job {job_info.id_} faulted, removing from completed jobs", + ) submit_job_request = submit_job_request_type( apikey=self.bridge_data.api_key, @@ -1282,11 +1429,12 @@ async def api_submit_job( seed=int(job_info.payload.seed), generation="R2", # TODO # FIXME state=completed_job_info.state, - censored=completed_job_info.censored, + censored=bool(completed_job_info.censored), # TODO: is this cast problematic? 
@@ -1282,11 +1429,12 @@ async def api_submit_job(
                 seed=int(job_info.payload.seed),
                 generation="R2",  # TODO # FIXME
                 state=completed_job_info.state,
-                censored=completed_job_info.censored,
+                censored=bool(completed_job_info.censored),  # TODO: is this cast problematic?
             )

             job_submit_response = await self.horde_client_session.submit_request(submit_job_request, JobSubmitResponse)

+            # If the job submit response is an error, log it and increment the number of consecutive failed job submits
             if isinstance(job_submit_response, RequestErrorResponse):
                 if (
                     "Processing Job with ID" in job_submit_response.message
@@ -1295,7 +1443,7 @@ async def api_submit_job(
                     logger.warning(f"Job {job_info.id_} does not exist, removing from completed jobs")
                     async with self._completed_jobs_lock:
                         self.completed_jobs.remove(completed_job_info)
-                        self._consecutive_failed_results = 0
+                        self._consecutive_failed_job_submits = 0

                     return

@@ -1303,62 +1451,86 @@ async def api_submit_job(
                     logger.debug(f"Job {job_info.id_} has already been submitted, removing from completed jobs")
                     async with self._completed_jobs_lock:
                         self.completed_jobs.remove(completed_job_info)
-                        self._consecutive_failed_results = 0
+                        self._consecutive_failed_job_submits = 0

                     return

                 error_string = "Failed to submit job (API Error) "
-                error_string += f"{self._consecutive_failed_results}/{self._consecutive_failed_results_max} "
+                error_string += f"{self._consecutive_failed_job_submits}/{self._max_consecutive_failed_job_submits} "
                 error_string += f": {job_submit_response}"
                 logger.error(error_string)
-                self._consecutive_failed_results += 1
+                self._consecutive_failed_job_submits += 1
                 return

+            # Get the time the job was popped from the job deque
             async with self._job_pop_timestamps_lock:
                 time_popped = self.job_pop_timestamps.pop(str(completed_job_info.job_info.id_))
             time_taken = round(time.time() - time_popped, 2)

-            logger.success(
-                f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward:.2f} kudos. "
-                f"Job popped {time_taken} seconds ago and took {completed_job_info.time_to_generate:.2f} to generate.",
-            )
+            # If the job was not faulted, log the job submission as a success
+            if completed_job_info.state != GENERATION_STATE.faulted:
+                logger.success(
+                    f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward:.2f} "
+                    f"kudos. Job popped {time_taken} seconds ago and took {completed_job_info.time_to_generate:.2f} "
+                    "to generate.",
+                )
+            # If the job was faulted, log an error
+            else:
+                logger.error(
+                    f"{job_info.id_} faulted, not submitting for kudos. Job popped {time_taken} seconds ago and took "
+                    f"{completed_job_info.time_to_generate:.2f} to generate.",
+                )

+            # If the job took a long time to generate, log a warning (unless speed warnings are suppressed)
             if not self.bridge_data.suppress_speed_warnings:
-                if (job_submit_response.reward / time_taken) < 0.1:
+                if job_submit_response.reward > 0 and (job_submit_response.reward / time_taken) < 0.1:
                     logger.warning(
                         f"This job ({job_info.id_}) may have been in the queue for a long time. ",
                     )

-                if (job_submit_response.reward / completed_job_info.time_to_generate) < 0.4:
+                if (
+                    job_submit_response.reward > 0
+                    and (job_submit_response.reward / completed_job_info.time_to_generate) < 0.4
+                ):
                     logger.warning(
                         f"This job ({job_info.id_}) took longer than is ideal; if this persists consider "
                         "lowering your max_power, using less threads, disabling post processing and/or controlnets.",
                     )

             self.kudos_generated_this_session += job_submit_response.reward
+
+            # Finally, remove the job from the completed jobs list and reset the number of consecutive failed job submits
             async with self._completed_jobs_lock:
                 self.completed_jobs.remove(completed_job_info)
-                self._consecutive_failed_results = 0
+                self._consecutive_failed_job_submits = 0
+
         except Exception as e:
             logger.error(f"Failed to submit job (Unexpected Error): {e}")
-            self._consecutive_failed_results += 1
+            self._consecutive_failed_job_submits += 1
             return

         await asyncio.sleep(self._api_call_loop_interval)

-    _testing_max_jobs = 10000
-    _testing_jobs_added = 0
-    _testing_job_queue_length = 1
+    # _testing_max_jobs = 10000
+    # _testing_jobs_added = 0
+    # _testing_job_queue_length = 1

     _default_job_pop_frequency = 1.0
+    """The default frequency at which to pop jobs from the API."""
     _error_job_pop_frequency = 5.0
+    """The frequency at which to pop jobs from the API when an error occurs."""
     _job_pop_frequency = 1.0
+    """The frequency at which to pop jobs from the API. Can be altered if an error occurs."""
     _last_job_pop_time = 0.0
+    """The time at which the last job was popped from the API."""

     _max_pending_megapixelsteps = 45
+    """The maximum number of megapixelsteps that can be pending in the job deque before job pops are paused."""
     _triggered_max_pending_megapixelsteps_time = 0.0
+    """The time at which the number of megapixelsteps in the job deque exceeded the limit."""
     _triggered_max_pending_megapixelsteps = False
+    """Whether the number of megapixelsteps in the job deque exceeded the limit."""

     def get_pending_megapixelsteps(self) -> int:
         """Get the number of megapixelsteps that are pending in the job deque."""
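The two guarded ratios above are kudos-per-second thresholds; the new `reward > 0` checks keep faulted, zero-reward jobs from tripping them spuriously. Plugging in illustrative numbers:

    # Worked example of the speed-warning thresholds (numbers are made up).
    reward = 10.0            # kudos granted for the job
    time_taken = 120.0       # seconds since the job was popped
    time_to_generate = 30.0  # seconds spent generating

    print(reward / time_taken)        # 0.083... < 0.1 -> "queued a long time" warning
    print(reward / time_to_generate)  # 0.33...  < 0.4 -> "slower than ideal" warning
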
", ) - if (job_submit_response.reward / completed_job_info.time_to_generate) < 0.4: + if ( + job_submit_response.reward > 0 + and (job_submit_response.reward / completed_job_info.time_to_generate) < 0.4 + ): logger.warning( f"This job ({job_info.id_}) took longer than is ideal; if this persists consider " "lowering your max_power, using less threads, disabling post processing and/or controlnets.", ) self.kudos_generated_this_session += job_submit_response.reward + + # Finally, remove the job from the completed jobs list and reset the number of consecutive failed job async with self._completed_jobs_lock: self.completed_jobs.remove(completed_job_info) - self._consecutive_failed_results = 0 + self._consecutive_failed_job_submits = 0 + except Exception as e: logger.error(f"Failed to submit job (Unexpected Error): {e}") - self._consecutive_failed_results += 1 + self._consecutive_failed_job_submits += 1 return await asyncio.sleep(self._api_call_loop_interval) - _testing_max_jobs = 10000 - _testing_jobs_added = 0 - _testing_job_queue_length = 1 + # _testing_max_jobs = 10000 + # _testing_jobs_added = 0 + # _testing_job_queue_length = 1 _default_job_pop_frequency = 1.0 + """The default frequency at which to pop jobs from the API.""" _error_job_pop_frequency = 5.0 + """The frequency at which to pop jobs from the API when an error occurs.""" _job_pop_frequency = 1.0 + """The frequency at which to pop jobs from the API. Can be altered if an error occurs.""" _last_job_pop_time = 0.0 + """The time at which the last job was popped from the API.""" _max_pending_megapixelsteps = 45 + """The maximum number of megapixelsteps that can be pending in the job deque before job pops are paused.""" _triggered_max_pending_megapixelsteps_time = 0.0 + """The time at which the number of megapixelsteps in the job deque exceeded the limit.""" _triggered_max_pending_megapixelsteps = False + """Whether the number of megapixelsteps in the job deque exceeded the limit.""" def get_pending_megapixelsteps(self) -> int: """Get the number of megapixelsteps that are pending in the job deque.""" @@ -1371,8 +1543,38 @@ def get_pending_megapixelsteps(self) -> int: def should_wait_for_pending_megapixelsteps(self) -> bool: """Check if the number of megapixelsteps in the job deque is above the limit.""" + # TODO: Option to increase the limit for higher end GPUs + return self.get_pending_megapixelsteps() > self._max_pending_megapixelsteps + async def _get_source_images(self, job_pop_response: ImageGenerateJobPopResponse) -> ImageGenerateJobPopResponse: + # TODO: Move this into horde_sdk + for field_name in ["source_image", "source_mask"]: + field_value = getattr(job_pop_response, field_name) + if field_value is not None and "https://" in field_value: + fail_count = 0 + while True: + try: + if fail_count >= 10: + logger.error(f"Failed to download {field_name} after {fail_count} attempts") + break + response = await self._aiohttp_session.get(field_value) + response.raise_for_status() + new_response_dict = job_pop_response.model_dump(by_alias=True) + + content = await response.content.read() + + new_response_dict[field_name] = base64.b64encode(content).decode("utf-8") + job_pop_response = ImageGenerateJobPopResponse(**new_response_dict) + logger.debug(f"Downloaded {field_name} for job {job_pop_response.id_}") + break + except Exception as e: + logger.warning(f"Failed to download {field_name}: {e}") + fail_count += 1 + time.sleep(0.5) + + return job_pop_response + async def api_job_pop(self) -> None: """If the job deque is not full, add any 
@@ -1381,18 +1583,23 @@ async def api_job_pop(self) -> None:
         if len(self.job_deque) >= self.bridge_data.queue_size + 1:  # FIXME?
             return

-        if len(self.job_deque) == 1 and self.completed_jobs == 0:
+        # We let the first job run through to make sure things are working
+        # (if we're doomed to fail with 1 job, we're doomed to fail with 2 jobs)
+        if len(self.job_deque) != 0 and len(self.completed_jobs) == 0:
             return

         # if self._testing_jobs_added >= self._testing_max_jobs:
         #     return

+        # Don't start jobs if we can't evaluate safety (NSFW/CSAM)
         if self._process_map.get_first_available_safety_process() is None:
             return

+        # Don't start jobs if we can't run inference
         if self._process_map.get_first_available_inference_process() is None:
             return

+        # If there are long-running jobs, don't start any more even if there is space in the deque
         if self.should_wait_for_pending_megapixelsteps():
             if self._triggered_max_pending_megapixelsteps is False:
                 self._triggered_max_pending_megapixelsteps = True
@@ -1402,7 +1609,12 @@ async def api_job_pop(self) -> None:
                 )
                 return

-            if (time.time() - self._triggered_max_pending_megapixelsteps_time) > 0.0:
+            # Assuming a megapixelstep takes 0.75 seconds, if 2/3 of the time has passed since the limit was triggered,
+            # we can assume that the pending megapixelsteps will be below the limit soon. Otherwise we continue to wait
+
+            if not (time.time() - self._triggered_max_pending_megapixelsteps_time) > (
+                (self._max_pending_megapixelsteps * 0.75) * (2 / 3)
+            ):
                 return

             self._triggered_max_pending_megapixelsteps = False
@@ -1412,6 +1624,7 @@ async def api_job_pop(self) -> None:

         self._triggered_max_pending_megapixelsteps = False

+        # We don't want to pop jobs too frequently, so we wait a bit between each pop
         if time.time() - self._last_job_pop_time < self._job_pop_frequency:
             return

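The backoff above works out to a fixed window under the default settings: 45 megapixelsteps at an assumed 0.75 seconds each is about 33.75 seconds of pending work, and pops resume once two-thirds of that has elapsed:

    # Worked example of the job-pop backoff window with the defaults above.
    max_pending_megapixelsteps = 45
    assumed_seconds_per_megapixelstep = 0.75

    wait_seconds = (max_pending_megapixelsteps * assumed_seconds_per_megapixelstep) * (2 / 3)
    print(wait_seconds)  # 22.5 -> pops resume ~22.5s after the limit was first triggered
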
@@ -1448,6 +1661,7 @@ async def api_job_pop(self) -> None:
                 ImageGenerateJobPopResponse,
             )

+            # TODO: horde_sdk should handle this and return a field with a enum(?) of the reason
             if isinstance(job_pop_response, RequestErrorResponse):
                 if "maintenance mode" in job_pop_response.message:
                     logger.warning(f"Failed to pop job (Maintenance Mode): {job_pop_response}")
@@ -1455,8 +1669,13 @@ async def api_job_pop(self) -> None:
                 logger.error(f"Failed to pop job (API Error): {job_pop_response}")
                 self._job_pop_frequency = self._error_job_pop_frequency
                 return
+
         except Exception as e:
-            logger.error(f"Failed to pop job (Unexpected Error): {e}")
+            if self._job_pop_frequency == self._error_job_pop_frequency:
+                logger.error(f"Failed to pop job (Unexpected Error): {e}")
+            else:
+                logger.warning(f"Failed to pop job (Unexpected Error): {e}")
+
             self._job_pop_frequency = self._error_job_pop_frequency
             return

@@ -1464,70 +1683,35 @@ async def api_job_pop(self) -> None:
         info_string = "No job available. "

         if len(self.job_deque) > 0:
-            info_string += f"Current job deque length: {len(self.job_deque)}. "
+            info_string += f"Current number of popped jobs: {len(self.job_deque)}. "
+        info_string += f"(Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})"

         if job_pop_response.id_ is None:
-            logger.info(
-                f"No job available. (Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})",
-            )
+            logger.info(info_string)
             return

         logger.info(f"Popped job {job_pop_response.id_} (model: {job_pop_response.model})")

-        if job_pop_response.payload.seed is None:
+        # region TODO: move to horde_sdk
+        if job_pop_response.payload.seed is None:  # TODO # FIXME
             logger.warning(f"Job {job_pop_response.id_} has no seed!")
             new_response_dict = job_pop_response.model_dump(by_alias=True)
             new_response_dict["payload"]["seed"] = random.randint(0, (2**32) - 1)
-            job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)

         if job_pop_response.payload.denoising_strength is not None and job_pop_response.source_image is None:
             logger.debug(f"Job {job_pop_response.id_} has denoising_strength but no source image!")
             new_response_dict = job_pop_response.model_dump(by_alias=True)
             new_response_dict["payload"]["denoising_strength"] = None
+
+        if job_pop_response.payload.seed is None or (
+            job_pop_response.payload.denoising_strength is not None and job_pop_response.source_image is None
+        ):
             job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)

-        if job_pop_response.source_image is not None and "https://" in job_pop_response.source_image:
-            # Download and convert the source image to base64
-            fail_count = 0
-            while True:
-                try:
-                    if fail_count >= 10:
-                        logger.error(f"Failed to download source image after {fail_count} attempts")
-                        break
-                    source_image_response = requests.get(job_pop_response.source_image)
-                    source_image_response.raise_for_status()
-                    new_response_dict = job_pop_response.model_dump(by_alias=True)
+        job_pop_response = await self._get_source_images(job_pop_response)

-                    new_response_dict["source_image"] = base64.b64encode(source_image_response.content).decode("utf-8")
-                    job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)
-                    logger.debug(f"Downloaded source image for job {job_pop_response.id_}")
-                    break
-                except Exception as e:
-                    logger.error(f"Failed to download source image: {e}")
-                    fail_count += 1
-                    time.sleep(0.5)
-
-        if job_pop_response.source_mask is not None and "https://" in job_pop_response.source_mask:
-            # Download and convert the source image to base64
-            fail_count = 0
-            while True:
-                try:
-                    if fail_count >= 10:
-                        logger.error(f"Failed to download source image after {fail_count} attempts")
-                        break
-                    source_mask_response = requests.get(job_pop_response.source_mask)
-                    source_mask_response.raise_for_status()
-                    new_response_dict = job_pop_response.model_dump(by_alias=True)
-
-                    new_response_dict["source_mask"] = base64.b64encode(source_mask_response.content).decode("utf-8")
-                    job_pop_response = ImageGenerateJobPopResponse(**new_response_dict)
-                    logger.debug(f"Downloaded source image for job {job_pop_response.id_}")
-                    break
-                except Exception as e:
-                    logger.error(f"Failed to download source_mask: {e}")
-                    fail_count += 1
-                    time.sleep(0.5)
+        # endregion

         if job_pop_response.id_ is None:
             logger.error("Job has no id!")
@@ -1535,7 +1719,7 @@ async def api_job_pop(self) -> None:

         async with self._job_deque_lock, self._job_pop_timestamps_lock:
             self.job_deque.append(job_pop_response)
-            self._testing_jobs_added += 1
+            # self._testing_jobs_added += 1
             self.job_pop_timestamps[str(job_pop_response.id_)] = time.time()

     _user_info_failed = False
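The seed and denoising_strength fixups above rely on pydantic's dump-patch-rebuild idiom: the response model is treated as immutable, so the hunk mutates a plain dict and reconstructs the model once, after all patches are applied. A self-contained sketch with a stand-in model, not the real ImageGenerateJobPopResponse:

    # Sketch of the dump-patch-rebuild idiom used for the missing-seed fallback.
    import random

    from pydantic import BaseModel

    class Payload(BaseModel):
        seed: int | None = None

    class FakeJobPopResponse(BaseModel):
        payload: Payload

    response = FakeJobPopResponse(payload=Payload())
    if response.payload.seed is None:
        patched = response.model_dump(by_alias=True)
        patched["payload"]["seed"] = random.randint(0, (2**32) - 1)
        response = FakeJobPopResponse(**patched)  # rebuild once, after every patch
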
@@ -1562,9 +1746,12 @@ async def api_get_user_info(self) -> None:
             if self.user_info.kudos_details is not None:
                 # print kudos this session and kudos per hour based on self.session_start_time
                 kudos_per_hour = self.kudos_generated_this_session / (time.time() - self.session_start_time) * 3600
-                logger.success(
-                    f"Kudos this session: {self.kudos_generated_this_session:.2f} (~{kudos_per_hour:.2f} kudos/hour)",
-                )
+
+                if self.kudos_generated_this_session > 0:
+                    logger.success(
+                        f"Kudos this session: {self.kudos_generated_this_session:.2f} "
+                        f"(~{kudos_per_hour:.2f} kudos/hour)",
+                    )

                 logger.info(f"Worker Kudos Accumulated: {self.user_info.kudos_details.accumulated }")

         except ClientError as e:
@@ -1649,8 +1836,6 @@ async def _process_control_loop(self) -> None:
             if self.stable_diffusion_reference is None:
                 return

-            # We don't want to pop jobs from the deque while we are adding jobs to it
-            # TODO: Is this necessary?
             async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
                 self.receive_and_handle_process_messages()

@@ -1660,12 +1845,13 @@ async def _process_control_loop(self) -> None:

             if self.is_free_inference_process_available() and len(self.job_deque) > 0:
                 async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
-                    self.preload_models()
-                    self.start_inference()
+                    # So long as we didn't preload a model this cycle, we can start inference
+                    # We want to get any messages next cycle from preloading processes to make sure
+                    # the state of everything is up to date
+                    if not self.preload_models():
+                        self.start_inference()

                 await asyncio.sleep(self._loop_interval / 2)

-            # We don't want to pop jobs from the deque while we are adding jobs to it
-            # TODO: Is this necessary?
             async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
                 self.receive_and_handle_process_messages()

@@ -1685,6 +1871,7 @@ async def _process_control_loop(self) -> None:
                 logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}")
                 logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}")
                 logger.info(f"Number of jobs completed: {len(self.completed_jobs)}")
+                # TODO: Faulted
                 logger.info(f"Number of jobs submitted: {self.total_num_completed_jobs}")

                 self._last_status_message_time = time.time()
@@ -1788,14 +1975,13 @@ def signal_handler(self, sig: int, frame: object) -> None:
             self._caught_sigints += 1
             logger.warning("Shutting down after current jobs are finished...")
             self._shutting_down = True
-            self._start_timed_shutdown()

     def _start_timed_shutdown(self) -> None:
         import threading

         def shutdown() -> None:
             # Just in case the process manager gets stuck on shutdown
-            time.sleep((self.get_pending_megapixelsteps() * 3.5) + 10)
+            time.sleep((len(self.jobs_pending_safety_check) * 4) + 2)

             for process in self._process_map.values():
                 process.mp_process.kill()
diff --git a/horde_worker_regen/process_management/safety_process.py b/horde_worker_regen/process_management/safety_process.py
index 0f103bc8..6d952211 100644
--- a/horde_worker_regen/process_management/safety_process.py
+++ b/horde_worker_regen/process_management/safety_process.py
@@ -188,5 +188,5 @@ def _receive_and_handle_control_message(self, message: HordeControlMessage) -> None:
         self.send_process_state_change_message(HordeProcessState.WAITING_FOR_JOB, "Waiting for job")

     @override
-    def cleanup_and_exit(self) -> None:
-        return super().cleanup_and_exit()
+    def cleanup_for_exit(self) -> None:
+        return super().cleanup_for_exit()
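The reworked `_start_timed_shutdown` above is a watchdog: graceful shutdown gets a deadline derived from the remaining safety checks, after which any still-running child process is killed. The shape of the pattern, with illustrative names:

    # Sketch of the shutdown watchdog pattern (names are illustrative).
    import threading
    import time

    def start_timed_shutdown(processes: list, pending_safety_checks: int) -> None:
        deadline_seconds = (pending_safety_checks * 4) + 2  # mirrors the patched formula

        def shutdown() -> None:
            time.sleep(deadline_seconds)
            for process in processes:  # last resort if shutdown is stuck
                process.kill()

        threading.Thread(target=shutdown, daemon=True).start()
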
diff --git a/run_worker.py b/run_worker.py
index 962c6209..3bb24230 100644
--- a/run_worker.py
+++ b/run_worker.py
@@ -49,7 +49,7 @@ def ensure_model_db_downloaded() -> ModelReferenceManager:
             horde_model_reference_manager=horde_model_reference_manager,
         )
     except Exception as e:
-        logger.debug(e)
+        logger.exception(e)

         if "No such file or directory" in str(e):
             logger.error(f"Could not find {BRIDGE_CONFIG_FILENAME}. Please create it and try again.")
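The run_worker change is small but meaningful: loguru's logger.exception records the full traceback at ERROR level, while logger.debug(e) printed only the exception's message, at a level many worker configurations filter out entirely:

    # Why logger.exception(e) beats logger.debug(e) here (loguru shown).
    from loguru import logger

    try:
        raise FileNotFoundError("No such file or directory")
    except Exception as e:
        logger.debug(e)      # one easily-missed line, no traceback
        logger.exception(e)  # ERROR-level record with the full stack trace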