diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 359480ea..a2d58e71 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -201,7 +201,10 @@ def num_available_inference_processes(self) -> int: def get_first_available_inference_process(self) -> HordeProcessInfo | None: for p in self.values(): if p.process_type == HordeProcessKind.INFERENCE and p.can_accept_job(): + if p.last_process_state == HordeProcessState.PRELOADED_MODEL: + continue return p + return None def get_first_inference_process_to_kill(self) -> HordeProcessInfo | None: @@ -257,9 +260,12 @@ def num_busy_processes(self) -> int: return count def __repr__(self) -> str: - base_string = "" + base_string = "Processes: " for process_id, process_info in self.items(): - base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) " + if process_info.process_type == HordeProcessKind.INFERENCE: + base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) " + else: + base_string += f"{process_id}: ({process_info.process_type.name}) " base_string += f"{process_info.last_process_state.name}; " return base_string @@ -635,6 +641,9 @@ def end_inference_processes(self) -> None: logger.info(f"Ended inference process {process_info.process_id}") + # Join the process with a timeout of 1 second + process_info.mp_process.join(timeout=1) + total_num_completed_jobs: int = 0 def end_safety_processes(self) -> None: @@ -679,6 +688,11 @@ def receive_and_handle_process_messages(self) -> None: logger.debug(f"Process {message.process_id} changed state to {message.process_state}") if message.process_state == HordeProcessState.INFERENCE_STARTING: logger.info(f"Process {message.process_id} is starting inference on model {message.info}") + self._horde_model_map.update_entry( + horde_model_name=message.info, + load_state=ModelLoadState.IN_USE, + process_id=message.process_id, + ) if isinstance(message, HordeModelStateChangeMessage): self._horde_model_map.update_entry( @@ -687,6 +701,13 @@ def receive_and_handle_process_messages(self) -> None: process_id=message.process_id, ) + if message.horde_model_state == ModelLoadState.LOADING: + logger.debug(f"Process {message.process_id} is loading model {message.horde_model_name}") + self._process_map.update_entry( + process_id=message.process_id, + loaded_horde_model_name=message.horde_model_name, + ) + if ( message.horde_model_state == ModelLoadState.LOADED_IN_VRAM or message.horde_model_state == ModelLoadState.LOADED_IN_RAM @@ -698,7 +719,8 @@ def receive_and_handle_process_messages(self) -> None: loaded_message = f"Process {message.process_id} has model {message.horde_model_name} loaded. " if message.time_elapsed is not None: - loaded_message += f"Loading took {message.time_elapsed} seconds" + # round to 2 decimal places + loaded_message += f"Loading took {message.time_elapsed:.2f} seconds" logger.info(loaded_message) @@ -740,7 +762,11 @@ def receive_and_handle_process_messages(self) -> None: ) logger.debug(f"Jobs in progress: {self.jobs_in_progress}") - self.job_deque.popleft() + for job in self.job_deque: + if job.id_ == message.job_info.id_: + self.job_deque.remove(job) + break + self.total_num_completed_jobs += 1 if message.time_elapsed is not None: logger.info( @@ -760,8 +786,8 @@ def receive_and_handle_process_messages(self) -> None: ) elif isinstance(message, HordeSafetyResultMessage): completed_job_info: CompletedJobInfo | None = None - for i, job in enumerate(self.jobs_being_safety_checked): - if job.job_info.id_ == message.job_id: + for i, job_being_safety_checked in enumerate(self.jobs_being_safety_checked): + if job_being_safety_checked.job_info.id_ == message.job_id: completed_job_info = self.jobs_being_safety_checked.pop(i) break @@ -783,7 +809,7 @@ def receive_and_handle_process_messages(self) -> None: logger.debug( f"Job {message.job_id} had {num_images_censored} images censored and took " - f"{message.time_elapsed} seconds to check safety", + f"{message.time_elapsed:.2f} seconds to check safety", ) if num_images_censored > 0: @@ -803,17 +829,19 @@ def preload_models(self) -> None: # Starting from the left of the deque, preload models that are not yet loaded up to the # number of inference processes # that are available - num_already_loaded_model = 0 for job in self.job_deque: + model_is_loaded = False + if job.model is None: raise ValueError(f"job.model is None ({job})") - if self._horde_model_map.is_model_loaded(job.model) or self._horde_model_map.is_model_loading(job.model): - num_already_loaded_model += 1 - continue + for process in self._process_map.values(): + if process.loaded_horde_model_name == job.model: + model_is_loaded = True + break - if num_already_loaded_model >= self._process_map.num_inference_processes(): - break + if model_is_loaded: + continue available_process = self._process_map.get_first_available_inference_process() @@ -1166,7 +1194,7 @@ async def api_submit_job( return logger.success( - f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward} kudos.", + f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward:.2f} kudos.", ) self.kudos_generated_this_session += job_submit_response.reward async with self._completed_jobs_lock: @@ -1468,9 +1496,10 @@ async def _process_control_loop(self) -> None: self.start_evaluate_safety() if self.is_free_inference_process_available() and len(self.job_deque) > 0: - self.preload_models() - self.start_inference() - self.unload_models() + async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock: + self.preload_models() + self.start_inference() + self.unload_models() if self._shutting_down: self.end_inference_processes() @@ -1482,9 +1511,9 @@ async def _process_control_loop(self) -> None: logger.info(f"{self._process_map}") logger.info(f"Number of jobs popped: {len(self.job_deque)}") logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}") - logger.info(f"Number of jobs completed: {len(self.completed_jobs)}") logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}") logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}") + logger.info(f"Number of jobs completed: {len(self.completed_jobs)}") logger.info(f"Number of jobs submitted: {self.total_num_completed_jobs}") self._last_status_message_time = time.time()