From 525a047625fab4a94cb9c95f1a80c65f411f5b64 Mon Sep 17 00:00:00 2001
From: tazlin
Date: Mon, 2 Oct 2023 22:37:25 -0400
Subject: [PATCH] feat: SIGINT handling; pause job pop for big jobs; better logs

---
 bridgeData_template.yaml                   |   1 +
 .../process_management/horde_process.py    |  28 +-
 .../process_management/messages.py         |   7 -
 .../process_management/process_manager.py  | 290 +++++++++++++-----
 requirements.txt                           |   3 +-
 run_worker.py                              |  10 +-
 6 files changed, 256 insertions(+), 83 deletions(-)

diff --git a/bridgeData_template.yaml b/bridgeData_template.yaml
index 7403b68e..3d103d3b 100644
--- a/bridgeData_template.yaml
+++ b/bridgeData_template.yaml
@@ -135,6 +135,7 @@ models_to_load:
 # This is to avoid loading models which you do not want either due to VRAM constraints, or due to NSFW content
 models_to_skip:
 - "pix2pix"
+- "SDXL_beta::stability.ai#6901" # Do not remove this, as this model would never work
 #- "stable_diffusion_inpainting" # Inpainting is generally quite heavy along with other models for smaller GPUs.
 #- "stable_diffusion_2.1", # Stable diffusion 2.1 has bigger memory requirements than 1.5, so if your card cannot lift, it, disable it
 #- "stable_diffusion_2.0", # Same as Stable diffusion 2.1
diff --git a/horde_worker_regen/process_management/horde_process.py b/horde_worker_regen/process_management/horde_process.py
index c515054d..7386bd0e 100644
--- a/horde_worker_regen/process_management/horde_process.py
+++ b/horde_worker_regen/process_management/horde_process.py
@@ -2,6 +2,7 @@
 
 import abc
 import enum
+import signal
 import time
 from abc import abstractmethod
 from enum import auto
@@ -155,6 +156,7 @@ def receive_and_handle_control_messages(self) -> None:
 
         if message.control_flag == HordeControlFlag.END_PROCESS:
             self._end_process = True
+            logger.info("Received end process message")
             return
 
         self._receive_and_handle_control_message(message)
@@ -165,11 +167,31 @@ def worker_cycle(self) -> None:
 
     def main_loop(self) -> None:
         """The main loop of the worker process."""
+        signal.signal(signal.SIGINT, signal_handler)
+
         while not self._end_process:
-            time.sleep(self._loop_interval)
+            try:
+                time.sleep(self._loop_interval)
 
-            self.receive_and_handle_control_messages()
+                self.receive_and_handle_control_messages()
 
-            self.worker_cycle()
+                self.worker_cycle()
+            except KeyboardInterrupt:
+                logger.info("Keyboard interrupt received")
+
+        self.send_process_state_change_message(
+            process_state=HordeProcessState.PROCESS_ENDING,
+            info="Process ending",
+        )
 
         self.cleanup_and_exit()
+
+        logger.info("Process ended")
+        self.send_process_state_change_message(
+            process_state=HordeProcessState.PROCESS_ENDED,
+            info="Process ended",
+        )
+
+
+def signal_handler(sig: int, frame: object) -> None:
+    logger.info("SIGINT received; waiting for the main process to coordinate shutdown")
diff --git a/horde_worker_regen/process_management/messages.py b/horde_worker_regen/process_management/messages.py
index 687b668a..91ebff4b 100644
--- a/horde_worker_regen/process_management/messages.py
+++ b/horde_worker_regen/process_management/messages.py
@@ -62,13 +62,6 @@ class HordeProcessState(enum.Enum):
 
     EVALUATING_SAFETY = auto()
 
-    def can_accept_job(self) -> bool:
-        return (
-            self == HordeProcessState.WAITING_FOR_JOB
-            or self == HordeProcessState.INFERENCE_COMPLETE
-            or self == HordeProcessState.INFERENCE_FAILED
-        )
-
 
 class HordeProcessMessage(BaseModel):
     """Process messages are sent from the child processes to the main process."""
diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py
index 8c31d92e..13261695 100644
--- a/horde_worker_regen/process_management/process_manager.py
+++ b/horde_worker_regen/process_management/process_manager.py
@@ -3,6 +3,7 @@
 import multiprocessing
 import random
 import time
+from asyncio import CancelledError
 from asyncio import Lock as Lock_Asyncio
 from collections import deque
 from collections.abc import Mapping
@@ -92,7 +93,15 @@ def __init__(
         self.last_process_state = last_process_state
 
     def is_process_busy(self) -> bool:
-        return not self.last_process_state.can_accept_job()
+        return (
+            self.last_process_state == HordeProcessState.INFERENCE_STARTING
+            or self.last_process_state == HordeProcessState.ALCHEMY_STARTING
+            or self.last_process_state == HordeProcessState.DOWNLOADING_MODEL
+            or self.last_process_state == HordeProcessState.PRELOADING_MODEL
+            or self.last_process_state == HordeProcessState.JOB_RECEIVED
+            or self.last_process_state == HordeProcessState.EVALUATING_SAFETY
+            or self.last_process_state == HordeProcessState.PROCESS_STARTING
+        )
 
     def __repr__(self) -> str:
         return str(
@@ -101,7 +110,11 @@ def __repr__(self) -> str:
         )
 
     def can_accept_job(self) -> bool:
-        return self.last_process_state.can_accept_job()
+        return (
+            self.last_process_state == HordeProcessState.WAITING_FOR_JOB
+            or self.last_process_state == HordeProcessState.INFERENCE_COMPLETE
+            or self.last_process_state == HordeProcessState.ALCHEMY_COMPLETE
+        )
 
 
 class HordeModelMap(RootModel[dict[str, ModelInfo]]):
@@ -191,6 +204,26 @@ def get_first_available_inference_process(self) -> HordeProcessInfo | None:
                 return p
         return None
 
+    def get_first_inference_process_to_kill(self) -> HordeProcessInfo | None:
+        for p in self.values():
+            if p.process_type != HordeProcessKind.INFERENCE:
+                continue
+
+            if (
+                p.last_process_state == HordeProcessState.WAITING_FOR_JOB
+                or p.last_process_state == HordeProcessState.PROCESS_STARTING
+                or p.last_process_state == HordeProcessState.DOWNLOADING_MODEL
+                or p.last_process_state == HordeProcessState.INFERENCE_COMPLETE
+            ):
+                return p
+
+        return None
+
     def get_safety_process(self) -> HordeProcessInfo | None:
         for p in self.values():
             if p.process_type == HordeProcessKind.SAFETY:
@@ -206,7 +239,7 @@ def num_safety_processes(self) -> int:
 
     def get_first_available_safety_process(self) -> HordeProcessInfo | None:
         for p in self.values():
-            if p.process_type == HordeProcessKind.SAFETY and p.can_accept_job():
+            if p.process_type == HordeProcessKind.SAFETY and p.last_process_state == HordeProcessState.WAITING_FOR_JOB:
                 return p
         return None
 
@@ -292,6 +325,9 @@ def get_process_total_ram_usage(self) -> int:
 
     _completed_jobs_lock: Lock_Asyncio
 
+    kudos_generated_this_session: float = 0
+    session_start_time: float = 0
+
     _aiohttp_session: aiohttp.ClientSession
 
     stable_diffusion_reference: StableDiffusion_ModelReference | None
@@ -317,7 +353,7 @@ def get_process_total_ram_usage(self) -> int:
     _api_call_loop_interval = 0.1
     """The number of seconds to wait between each loop of the main API call loop."""
 
-    _api_get_user_info_interval = 5
+    _api_get_user_info_interval = 10
     """The number of seconds to wait between each fetch of the user info."""
 
     _last_get_user_info_time: float = 0
@@ -338,6 +374,8 @@ def num_total_processes(self) -> int:
     """A semaphore that limits the number of inference processes that can run at once."""
 
     _disk_lock: Lock_MultiProcessing
 
+    _shutting_down = False
+
     def __init__(
         self,
         *,
         bridge_data: reGenBridgeData,
         horde_model_reference_manager: ModelReferenceManager,
         max_download_processes: int = 1,
         max_concurrent_inference_processes: int = 1,
     ) -> None:
+        self.session_start_time = time.time()
+
         self.bridge_data = bridge_data
         self._process_map = ProcessMap({})
@@ -426,6 +466,12 @@ def __init__(
             time.sleep(5)
 
     def is_time_for_shutdown(self) -> bool:
+        if len(self.completed_jobs) > 0:
+            return False
+
+        if len(self.jobs_being_safety_checked) > 0 or len(self.jobs_pending_safety_check) > 0:
+            return False
+
         if len(self.jobs_in_progress) > 0:
             return False
 
@@ -435,8 +481,8 @@ def is_time_for_shutdown(self) -> bool:
         any_process_alive = False
 
         for process_info in self._process_map.values():
-            if process_info.is_process_busy():
-                return False
+            if process_info.process_type != HordeProcessKind.INFERENCE:
+                continue
 
             if process_info.last_process_state != HordeProcessState.PROCESS_ENDED:
                 any_process_alive = True
@@ -557,48 +603,50 @@ def start_inference_processes(self) -> None:
 
     def end_inference_processes(self) -> None:
         """End any inference processes above the configured limit, or all of them if shutting down."""
-        if self.is_time_for_shutdown():
-            num_processes_to_end = self._process_map.num_inference_processes()
-        else:
-            num_processes_to_end = self._process_map.num_inference_processes() - self.max_inference_processes
-
-        # If the number of processes to end is less than 0, log a critical error and raise a ValueError
-        if num_processes_to_end < 0:
-            logger.critical(
-                f"There are already {self._process_map.num_inference_processes()} inference processes running, but "
-                f"max_inference_processes is set to {self.max_inference_processes}",
-            )
-            raise ValueError("num_processes_to_end cannot be less than 0")
+        # Don't end a process while there are queued jobs that still need to run
+        if len(self.job_deque) > 0 and len(self.job_deque) != len(self.jobs_in_progress):
+            return
 
-        # End the required number of processes
-        for _ in range(num_processes_to_end):
-            # Get the process to end
-            process_info = self._process_map.get_first_available_inference_process()
+        # Get the process to end
+        process_info = self._process_map.get_first_inference_process_to_kill()
 
-            if process_info is None:
-                logger.critical(
-                    f"Expected to find {num_processes_to_end} inference processes to end, but found none",
-                )
-                raise ValueError("Expected to find a process to end, but found none")
+        if process_info is None:
+            return
 
-            # Send the process a message to end
-            process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS))
+        # Send the process a message to end
+        process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS))
 
-            # Update the process map
-            self._process_map.update_entry(process_id=process_info.process_id)
+        # Update the process map
+        self._process_map.update_entry(process_id=process_info.process_id)
 
-            logger.info(f"Ended inference process {process_info.process_id}")
+        logger.info(f"Ended inference process {process_info.process_id}")
 
     total_num_completed_jobs: int = 0
 
+    def end_safety_processes(self) -> None:
+        """End the first available safety process, if any."""
+
+        # Get the process to end
+        process_info = self._process_map.get_first_available_safety_process()
+
+        if process_info is None:
+            return
+
+        # Send the process a message to end
+        process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS))
+
+        # Update the process map
+        self._process_map.update_entry(process_id=process_info.process_id)
+
+        logger.info(f"Ended safety process {process_info.process_id}")
+
     def receive_and_handle_process_messages(self) -> None:
         """Receive and handle any messages from the child processes."""
         while not self._process_message_queue.empty():
             message: HordeProcessMessage = self._process_message_queue.get()
 
             logger.debug(
-                f"Received {type(message).__name__}: "
-                f"{message.model_dump(exclude={'job_result_images_base64', 'replacement_image_base64'})}",
+                f"Received {type(message).__name__} from process {message.process_id}",
+                # f"{message.model_dump(exclude={'job_result_images_base64', 'replacement_image_base64'})}",
             )
 
             if not isinstance(message, HordeProcessMessage):
@@ -632,7 +680,7 @@ def receive_and_handle_process_messages(self) -> None:
                 message.process_id in self._process_map
                 and message.horde_model_state != self._process_map[message.process_id].loaded_horde_model_name
             ):
-                logger.info(f"Process {message.process_id} loaded model {message.horde_model_name}")
+                logger.info(f"Process {message.process_id} has model {message.horde_model_name} loaded.")
 
                 self._process_map.update_entry(
                     process_id=message.process_id,
@@ -706,7 +754,7 @@ def receive_and_handle_process_messages(self) -> None:
                 if message.safety_evaluations[i].is_csam:
                     num_images_csam += 1
 
-            logger.info(f"Job {message.job_id} had {num_images_censored} images censored")
+            logger.debug(f"Job {message.job_id} had {num_images_censored} images censored")
 
             if num_images_censored > 0:
                 completed_job_info.censored = True
@@ -824,6 +872,33 @@ def start_inference(self) -> None:
                 )
                 time.sleep(0.1)
 
+        logger.info(f"Starting inference for job {next_job.id_} on process {process_with_model.process_id}")
+        logger.info(f"Model: {next_job.model}, Using: {next_job.source_processing}")
+        extra_info = ""
+        if next_job.payload.control_type is not None:
+            extra_info += f"Control type: {next_job.payload.control_type}"
+        if next_job.payload.loras:
+            if extra_info:
+                extra_info += ", "
+            extra_info += f"{len(next_job.payload.loras)} LoRAs"
+        if next_job.payload.tis:
+            if extra_info:
+                extra_info += ", "
+            extra_info += f"{len(next_job.payload.tis)} TIs"
+        if next_job.payload.post_processing is not None and len(next_job.payload.post_processing) > 0:
+            if extra_info:
+                extra_info += ", "
+            extra_info += f"Post processing: {next_job.payload.post_processing}"
+        if next_job.payload.hires_fix:
+            if extra_info:
+                extra_info += ", "
+            extra_info += "HiRes fix"
+
+        if extra_info:
+            logger.info(extra_info)
+
+        logger.info(f"{next_job.payload.width}x{next_job.payload.height} for {next_job.payload.ddim_steps} steps")
+
         self.jobs_in_progress.append(next_job)
         process_with_model.pipe_connection.send(
             HordeInferenceControlMessage(
@@ -996,6 +1071,7 @@ async def api_submit_job(
         if self._consecutive_failed_results >= self._consecutive_failed_results_max:
             async with self._completed_jobs_lock:
                 self.completed_jobs.remove(completed_job_info)
+                self._consecutive_failed_results = 0
             return
 
         image_in_buffer = self.base64_image_to_stream_buffer(completed_job_info.job_result_images_base64[0])
@@ -1036,9 +1112,10 @@ async def api_submit_job(
                 self._consecutive_failed_results += 1
                 return
 
-        logger.info(
+        logger.success(
             f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward} kudos.",
         )
+        self.kudos_generated_this_session += job_submit_response.reward
         async with self._completed_jobs_lock:
             self.completed_jobs.remove(completed_job_info)
             self._consecutive_failed_results = 0
@@ -1056,14 +1133,53 @@ async def api_submit_job(
     _job_pop_frequency = 1.0
     _last_job_pop_time = 0.0
 
+    _max_pending_megapixelsteps = 150
+    _triggered_max_pending_megapixelsteps_time = 0.0
+    _triggered_max_pending_megapixelsteps = False
+
+    def get_pending_megapixelsteps(self) -> int:
+        """Get the number of megapixelsteps that are pending in the job deque."""
+        job_deque_mps = sum(job.payload.width * job.payload.height * job.payload.ddim_steps for job in self.job_deque)
+        in_progress_mps = sum(
+            job.payload.width * job.payload.height * job.payload.ddim_steps for job in self.jobs_in_progress
+        )
+
+        return int((job_deque_mps + in_progress_mps) / 1_000_000)
+
+    def should_wait_for_pending_megapixelsteps(self) -> bool:
+        """Check if the number of megapixelsteps in the job deque is above the limit."""
+        return self.get_pending_megapixelsteps() > self._max_pending_megapixelsteps
+
     async def api_job_pop(self) -> None:
         """If the job deque is not full, add any jobs that are available to the job deque."""
+        if self._shutting_down:
+            return
+
         if len(self.job_deque) >= self.bridge_data.queue_size + 1:  # FIXME?
             return
 
         if self._testing_jobs_added >= self._testing_max_jobs:
             return
 
+        if self.should_wait_for_pending_megapixelsteps():
+            if self._triggered_max_pending_megapixelsteps is False:
+                self._triggered_max_pending_megapixelsteps = True
+                self._triggered_max_pending_megapixelsteps_time = time.time()
+                logger.info(
+                    f"Paused job pops until pending megapixelsteps drop below {self._max_pending_megapixelsteps}",
+                )
+
+            return
+
+        if self._triggered_max_pending_megapixelsteps:
+            self._triggered_max_pending_megapixelsteps = False
+            logger.info(
+                f"Pending megapixelsteps decreased below {self._max_pending_megapixelsteps}, continuing with job pops",
+            )
+
         if time.time() - self._last_job_pop_time < self._job_pop_frequency:
             return
 
@@ -1080,8 +1196,8 @@ async def api_job_pop(self) -> None:
             job_pop_request = ImageGenerateJobPopRequest(
                 apikey=self.bridge_data.api_key,
                 name=self.bridge_data.dreamer_worker_name,
-                bridge_agent="AI Horde Worker reGen:1:https://github.com/Haidra-Org/",
-                bridge_version=1,  # TODO TIs broken
+                bridge_agent="AI Horde Worker reGen:2:https://github.com/Haidra-Org/",
+                bridge_version=2,
                 models=self.bridge_data.image_models_to_load,
                 nsfw=self.bridge_data.nsfw,
                 threads=self.max_concurrent_inference_processes,
@@ -1092,7 +1208,7 @@ async def api_job_pop(self) -> None:
                 allow_unsafe_ipaddr=self.bridge_data.allow_unsafe_ip,
                 allow_post_processing=self.bridge_data.allow_post_processing,
                 allow_controlnet=self.bridge_data.allow_controlnet,
-                allow_lora=False,  # TODO loras broken
+                allow_lora=self.bridge_data.allow_lora,
             )
 
             job_pop_response = await self.horde_client_session.submit_request(
                 job_pop_request,
                 ImageGenerateJobPopResponse,
             )
 
             if isinstance(job_pop_response, RequestErrorResponse):
-                logger.error(f"Failed to pop job (API Error): {job_pop_response}")
+                if "maintenance mode" in job_pop_response.message:
+                    logger.warning(f"Failed to pop job (Maintenance Mode): {job_pop_response}")
+                else:
+                    logger.error(f"Failed to pop job (API Error): {job_pop_response}")
                 self._job_pop_frequency = self._error_job_pop_frequency
                 return
         except Exception as e:
@@ -1189,6 +1308,9 @@ async def api_job_pop(self) -> None:
 
     _user_info_failed_reason: str | None = None
 
     async def api_get_user_info(self) -> None:
+        if self._shutting_down:
+            return
+
         request = FindUserRequest(apikey=self.bridge_data.api_key)
         try:
             response = await self.horde_client_session.submit_request(request, FindUserResponse)
@@ -1204,7 +1326,12 @@ async def api_get_user_info(self) -> None:
                 self._user_info_failed_reason = None
 
                 if self.user_info.kudos_details is not None:
-                    logger.debug(f"Kudos Accumulated: {self.user_info.kudos_details.accumulated }")
+                    # Log kudos earned this session and the kudos/hour rate since the session started
+                    kudos_per_hour = self.kudos_generated_this_session / (time.time() - self.session_start_time) * 3600
+                    logger.info(
+                        f"Kudos this session: {self.kudos_generated_this_session} (~{kudos_per_hour:.2f} kudos/hour)",
+                    )
+                    logger.info(f"Worker Kudos Accumulated: {self.user_info.kudos_details.accumulated}")
 
         except ClientError as e:
             self._user_info_failed = True
@@ -1232,24 +1359,31 @@ async def _api_call_loop(self) -> None:
         async with self.horde_client_session:
             while True:
                 with logger.catch():
-                    if self._user_info_failed:
-                        await asyncio.sleep(5)
+                    try:
+                        if self._user_info_failed:
+                            await asyncio.sleep(5)
 
-                    tasks = [self.api_job_pop(), self.api_submit_job()]
+                        tasks = [self.api_job_pop(), self.api_submit_job()]
 
-                    if self._last_get_user_info_time + self._api_get_user_info_interval < time.time():
-                        self._last_get_user_info_time = time.time()
-                        # tasks.append(self.api_get_user_info())
+                        if self._last_get_user_info_time + self._api_get_user_info_interval < time.time():
+                            self._last_get_user_info_time = time.time()
+                            tasks.append(self.api_get_user_info())
 
-                    if len(tasks) > 0:
-                        await asyncio.gather(*tasks, return_exceptions=True)
+                        if len(tasks) > 0:
+                            results = await asyncio.gather(*tasks, return_exceptions=True)
 
-                        for task in tasks:
-                            if isinstance(task, Exception):
-                                logger.exception(f"Task failed: {task}")
+                            # Print all exceptions
+                            for result in results:
+                                if isinstance(result, Exception):
+                                    logger.exception(f"Exception in api call loop: {result}")
 
-                    if self._user_info_failed:
-                        logger.error("The server failed to respond. Is the horde or your internet down?")
+                            if self._user_info_failed:
+                                logger.error("The server failed to respond. Is the horde or your internet down?")
+
+                        if self.is_time_for_shutdown():
+                            break
+                    except CancelledError:
+                        self._shutting_down = True
 
                 await asyncio.sleep(self._api_call_loop_interval)
 
@@ -1258,28 +1392,35 @@ async def _process_control_loop(self) -> None:
         self.start_inference_processes()
 
         while True:
-            if self.stable_diffusion_reference is None:
-                return
+            try:
+                if self.stable_diffusion_reference is None:
+                    return
 
-            # We don't want to pop jobs from the deque while we are adding jobs to it
-            # TODO: Is this necessary?
-            async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
-                self.receive_and_handle_process_messages()
+                # We don't want to pop jobs from the deque while we are adding jobs to it
+                # TODO: Is this necessary?
+                async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
+                    self.receive_and_handle_process_messages()
 
-            if len(self.jobs_pending_safety_check) > 0:
-                async with self._jobs_safety_check_lock:
-                    self.start_evaluate_safety()
+                if len(self.jobs_pending_safety_check) > 0:
+                    async with self._jobs_safety_check_lock:
+                        self.start_evaluate_safety()
 
-            if self.is_time_for_shutdown():
-                break
+                if self.is_free_inference_process_available() and len(self.job_deque) > 0:
+                    self.preload_models()
+                    self.start_inference()
+                    self.unload_models()
 
-            if self.is_free_inference_process_available() and len(self.job_deque) > 0:
-                self.preload_models()
-                self.start_inference()
-                self.unload_models()
+                if self._shutting_down:
+                    self.end_inference_processes()
 
-            await asyncio.sleep(self._loop_interval)
+                if self.is_time_for_shutdown():
+                    break
+
+                await asyncio.sleep(self._loop_interval)
+            except CancelledError:
+                self._shutting_down = True
 
+        self.end_safety_processes()
         logger.info("Shutting down process manager")
 
     async def _main_loop(self) -> None:
@@ -1291,4 +1432,11 @@ async def _main_loop(self) -> None:
 
     def start(self) -> None:
         """Start the process manager."""
+        import signal
+
+        signal.signal(signal.SIGINT, self.signal_handler)
         asyncio.run(self._main_loop())
+
+    def signal_handler(self, sig: int, frame: object) -> None:
+        logger.warning("Shutting down after current jobs are finished...")
+        self._shutting_down = True
diff --git a/requirements.txt b/requirements.txt
index d80ffa9a..f8bcc1a9 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 torch
 
-horde_sdk>=7.10.0
+horde_sdk>=0.7.11
 horde_model_reference
 horde_safety
 hordelib
@@ -8,6 +8,7 @@ hordelib
 
 python-dotenv
 pyyaml
+python-Levenshtein
 pydantic
 typing_extensions
diff --git a/run_worker.py b/run_worker.py
index 252df672..8264b38d 100644
--- a/run_worker.py
+++ b/run_worker.py
@@ -1,3 +1,4 @@
+import argparse
 import contextlib
 import multiprocessing
 import time
@@ -77,13 +78,20 @@ def ensure_model_db_downloaded() -> ModelReferenceManager:
 
     print(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
 
+    # Accept repeatable -v flags (e.g. -vvv) to control log verbosity
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-v", action="count", default=3, help="Increase verbosity of output")
+
+    args = parser.parse_args()
+
+    logger.remove()
     from hordelib.utils.logger import HordeLog
 
     # Initialise logging with loguru
     HordeLog.initialise(
         setup_logging=True,
         process_id=None,
-        verbosity_count=5,  # FIXME
+        verbosity_count=args.v,  # FIXME
     )
 
     # We only need to download the legacy DBs once, so we do it here instead of in the worker processes
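-- 
A note on the pending-megapixelsteps throttle added above: api_job_pop() now
stops popping new jobs once the total queued work (job deque plus jobs in
progress) exceeds _max_pending_megapixelsteps (150), and resumes once it drains
back below that mark. The snippet below is a minimal standalone sketch of that
bookkeeping, not the worker's actual classes; the Job type and its field names
are stand-ins chosen for the example.

    from dataclasses import dataclass

    @dataclass
    class Job:
        # Hypothetical stand-in for the payload fields the worker reads.
        width: int
        height: int
        ddim_steps: int

    MAX_PENDING_MEGAPIXELSTEPS = 150

    def pending_megapixelsteps(jobs: list[Job]) -> int:
        # One megapixelstep = one million pixel-steps (width * height * steps).
        return sum(j.width * j.height * j.ddim_steps for j in jobs) // 1_000_000

    def should_pause_job_pops(queued: list[Job], in_progress: list[Job]) -> bool:
        return pending_megapixelsteps(queued + in_progress) > MAX_PENDING_MEGAPIXELSTEPS

    # A single 1024x1024 job at 150 steps is ~157 megapixelsteps, so it is
    # enough on its own to pause further pops until it finishes:
    assert should_pause_job_pops([Job(1024, 1024, 150)], [])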
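A note on the SIGINT flow: the main process installs a handler that only flips
_shutting_down; the API loop then stops popping jobs, in-flight jobs finish and
submit, and child processes are told to exit via an explicit END_PROCESS
control message rather than by the signal itself (each child installs its own
handler so Ctrl+C does not interrupt inference mid-job). Below is a simplified
sketch of that two-phase pattern in a single process, with illustrative names
only; it is not the worker's actual shutdown code.

    import signal
    import time
    from collections import deque

    class GracefulWorker:
        """Toy model of the shutdown flow: SIGINT only sets a flag, and the
        loop finishes the job it is on before exiting."""

        def __init__(self, jobs: deque[str]) -> None:
            self.jobs = jobs
            self.shutting_down = False
            signal.signal(signal.SIGINT, self._on_sigint)

        def _on_sigint(self, sig: int, frame: object) -> None:
            # Don't raise KeyboardInterrupt mid-job; just stop taking new work.
            print("Shutting down after current jobs are finished...")
            self.shutting_down = True

        def run(self) -> None:
            while self.jobs and not self.shutting_down:
                job = self.jobs.popleft()
                time.sleep(0.5)  # stand-in for inference
                print(f"finished {job}")

    if __name__ == "__main__":
        GracefulWorker(deque(["job-1", "job-2", "job-3"])).run()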