Skip to content

Commit

Permalink
fix: don't drop a job while a process is concurrently working on it
Browse files Browse the repository at this point in the history
additionally fixes a bug where multiple processes would compete to preload the same models
  • Loading branch information
tazlin committed Oct 3, 2023
1 parent e07f50b commit a8b3985
Showing 1 changed file with 47 additions and 18 deletions.
65 changes: 47 additions & 18 deletions horde_worker_regen/process_management/process_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,10 @@ def num_available_inference_processes(self) -> int:
def get_first_available_inference_process(self) -> HordeProcessInfo | None:
    """Return the first inference process able to take a job, or None.

    Processes whose last reported state is ``PRELOADED_MODEL`` are skipped:
    they already have a model staged and are reserved for the job that
    triggered the preload, so handing them a new job here would conflict.
    """
    for candidate in self.values():
        # Only inference-type processes are eligible.
        if candidate.process_type != HordeProcessKind.INFERENCE:
            continue
        if not candidate.can_accept_job():
            continue
        # Skip processes holding a freshly preloaded model.
        if candidate.last_process_state != HordeProcessState.PRELOADED_MODEL:
            return candidate

    return None

def get_first_inference_process_to_kill(self) -> HordeProcessInfo | None:
Expand Down Expand Up @@ -257,9 +260,12 @@ def num_busy_processes(self) -> int:
return count

def __repr__(self) -> str:
    """Summarize every tracked process as a single status line.

    Inference processes are labelled with their loaded model name;
    all other process kinds are labelled with the process-type name.
    """
    parts: list[str] = ["Processes: "]
    for pid, info in self.items():
        if info.process_type == HordeProcessKind.INFERENCE:
            label = info.loaded_horde_model_name
        else:
            label = info.process_type.name
        parts.append(f"{pid}: ({label}) ")
        parts.append(f"{info.last_process_state.name}; ")

    return "".join(parts)
Expand Down Expand Up @@ -635,6 +641,9 @@ def end_inference_processes(self) -> None:

logger.info(f"Ended inference process {process_info.process_id}")

# Join the process with a timeout of 1 second
process_info.mp_process.join(timeout=1)

total_num_completed_jobs: int = 0

def end_safety_processes(self) -> None:
Expand Down Expand Up @@ -679,6 +688,11 @@ def receive_and_handle_process_messages(self) -> None:
logger.debug(f"Process {message.process_id} changed state to {message.process_state}")
if message.process_state == HordeProcessState.INFERENCE_STARTING:
logger.info(f"Process {message.process_id} is starting inference on model {message.info}")
self._horde_model_map.update_entry(
horde_model_name=message.info,
load_state=ModelLoadState.IN_USE,
process_id=message.process_id,
)

if isinstance(message, HordeModelStateChangeMessage):
self._horde_model_map.update_entry(
Expand All @@ -687,6 +701,13 @@ def receive_and_handle_process_messages(self) -> None:
process_id=message.process_id,
)

if message.horde_model_state == ModelLoadState.LOADING:
logger.debug(f"Process {message.process_id} is loading model {message.horde_model_name}")
self._process_map.update_entry(
process_id=message.process_id,
loaded_horde_model_name=message.horde_model_name,
)

if (
message.horde_model_state == ModelLoadState.LOADED_IN_VRAM
or message.horde_model_state == ModelLoadState.LOADED_IN_RAM
Expand All @@ -698,7 +719,8 @@ def receive_and_handle_process_messages(self) -> None:
loaded_message = f"Process {message.process_id} has model {message.horde_model_name} loaded. "

if message.time_elapsed is not None:
loaded_message += f"Loading took {message.time_elapsed} seconds"
# round to 2 decimal places
loaded_message += f"Loading took {message.time_elapsed:.2f} seconds"

logger.info(loaded_message)

Expand Down Expand Up @@ -740,7 +762,11 @@ def receive_and_handle_process_messages(self) -> None:
)
logger.debug(f"Jobs in progress: {self.jobs_in_progress}")

self.job_deque.popleft()
for job in self.job_deque:
if job.id_ == message.job_info.id_:
self.job_deque.remove(job)
break

self.total_num_completed_jobs += 1
if message.time_elapsed is not None:
logger.info(
Expand All @@ -760,8 +786,8 @@ def receive_and_handle_process_messages(self) -> None:
)
elif isinstance(message, HordeSafetyResultMessage):
completed_job_info: CompletedJobInfo | None = None
for i, job in enumerate(self.jobs_being_safety_checked):
if job.job_info.id_ == message.job_id:
for i, job_being_safety_checked in enumerate(self.jobs_being_safety_checked):
if job_being_safety_checked.job_info.id_ == message.job_id:
completed_job_info = self.jobs_being_safety_checked.pop(i)
break

Expand All @@ -783,7 +809,7 @@ def receive_and_handle_process_messages(self) -> None:

logger.debug(
f"Job {message.job_id} had {num_images_censored} images censored and took "
f"{message.time_elapsed} seconds to check safety",
f"{message.time_elapsed:.2f} seconds to check safety",
)

if num_images_censored > 0:
Expand All @@ -803,17 +829,19 @@ def preload_models(self) -> None:
# Starting from the left of the deque, preload models that are not yet loaded up to the
# number of inference processes
# that are available
num_already_loaded_model = 0
for job in self.job_deque:
model_is_loaded = False

if job.model is None:
raise ValueError(f"job.model is None ({job})")

if self._horde_model_map.is_model_loaded(job.model) or self._horde_model_map.is_model_loading(job.model):
num_already_loaded_model += 1
continue
for process in self._process_map.values():
if process.loaded_horde_model_name == job.model:
model_is_loaded = True
break

if num_already_loaded_model >= self._process_map.num_inference_processes():
break
if model_is_loaded:
continue

available_process = self._process_map.get_first_available_inference_process()

Expand Down Expand Up @@ -1166,7 +1194,7 @@ async def api_submit_job(
return

logger.success(
f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward} kudos.",
f"Submitted job {job_info.id_} (model: {job_info.model}) for {job_submit_response.reward:.2f} kudos.",
)
self.kudos_generated_this_session += job_submit_response.reward
async with self._completed_jobs_lock:
Expand Down Expand Up @@ -1468,9 +1496,10 @@ async def _process_control_loop(self) -> None:
self.start_evaluate_safety()

if self.is_free_inference_process_available() and len(self.job_deque) > 0:
self.preload_models()
self.start_inference()
self.unload_models()
async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
self.preload_models()
self.start_inference()
self.unload_models()

if self._shutting_down:
self.end_inference_processes()
Expand All @@ -1482,9 +1511,9 @@ async def _process_control_loop(self) -> None:
logger.info(f"{self._process_map}")
logger.info(f"Number of jobs popped: {len(self.job_deque)}")
logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}")
logger.info(f"Number of jobs completed: {len(self.completed_jobs)}")
logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}")
logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}")
logger.info(f"Number of jobs completed: {len(self.completed_jobs)}")
logger.info(f"Number of jobs submitted: {self.total_num_completed_jobs}")

self._last_status_message_time = time.time()
Expand Down

0 comments on commit a8b3985

Please sign in to comment.