diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index 32e48bc2..07c45bfe 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -487,7 +487,7 @@ def get_process_total_ram_usage(self) -> int: total += process_info.ram_usage_bytes return total - jobs_in_progress: list[ImageGenerateJobPopResponse] + jobs_in_progress: list[tuple[ImageGenerateJobPopResponse, int]] """A list of jobs that are currently in progress.""" jobs_pending_safety_check: list[HordeJobInfo] @@ -866,13 +866,14 @@ def _end_inference_process(self, process_info: HordeProcessInfo) -> None: def _replace_inference_process(self, process_info: HordeProcessInfo) -> None: """ - Replaces an inference process (for whatever reason). + Replaces an inference process (for whatever reason; probably because it crashed). :param process_info: process to replace :return: None """ logger.debug(f"Replacing {process_info}") self._end_inference_process(process_info) + self.jobs_in_progress = [job for job in self.jobs_in_progress if job[1] != process_info.process_id] self._start_inference_process(process_info.process_id) total_num_completed_jobs: int = 0 @@ -1008,7 +1009,7 @@ def receive_and_handle_process_messages(self) -> None: # Remove the job from the jobs in progress by matching the job ID (.id_) self.jobs_in_progress = [ - job for job in self.jobs_in_progress if job.id_ != message.sdk_api_job_info.id_ + job for job in self.jobs_in_progress if job[0].id_ != message.sdk_api_job_info.id_ ] if len(self.jobs_in_progress) != _num_jobs_in_progress - 1: @@ -1207,8 +1208,10 @@ def start_inference(self) -> None: # Get the first job in the deque that is not already in progress next_job: ImageGenerateJobPopResponse | None = None + # ImageGenerateJobPopResponse itself is unhashable, so we'll check the ids instead. + jobs_in_progress = {job[0].id_ for job in self.jobs_in_progress} for job in self.job_deque: - if job in self.jobs_in_progress: + if job.id_ in jobs_in_progress: continue next_job = job break @@ -1296,7 +1299,7 @@ def start_inference(self) -> None: ) # endregion - self.jobs_in_progress.append(next_job) + self.jobs_in_progress.append((next_job, process_with_model.process_id)) process_with_model.pipe_connection.send( HordeInferenceControlMessage( control_flag=HordeControlFlag.START_INFERENCE,