Skip to content

Commit

Permalink
feat: concurrent inference; better logs and error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
tazlin committed Oct 3, 2023
1 parent 5488e17 commit e07f50b
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 28 deletions.
38 changes: 27 additions & 11 deletions horde_worker_regen/process_management/horde_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from horde_worker_regen.process_management.messages import (
HordeControlFlag,
HordeControlMessage,
HordeProcessHeartbeatMessage,
HordeProcessMemoryMessage,
HordeProcessState,
HordeProcessStateChangeMessage,
Expand Down Expand Up @@ -52,9 +53,6 @@ class HordeProcess(abc.ABC):
disk_lock: Lock
"""A lock used to prevent multiple processes from accessing disk at the same time."""

_last_message_time: float = 0.0
"""The last time a message was sent or received."""

_loop_interval: float = 0.1
"""The time to sleep between each loop iteration."""

Expand Down Expand Up @@ -99,24 +97,44 @@ def __init__(
verbosity_count=5, # FIXME
)

self._last_message_time = time.time()

self.send_process_state_change_message(
process_state=HordeProcessState.PROCESS_STARTING,
info="Process starting",
)

def send_process_state_change_message(
    self,
    process_state: HordeProcessState,
    info: str,
    time_elapsed: float | None = None,
) -> None:
    """Queue a state change message for the main process and remember the state sent.

    Args:
        process_state: The state this process is reporting.
        info: Human-readable description attached to the message.
        time_elapsed: Duration in seconds, as measured by the caller, of the
            work that led to this state; None when nothing was timed.
    """
    message = HordeProcessStateChangeMessage(
        process_state=process_state,
        process_id=self.process_id,
        info=info,
        time_elapsed=time_elapsed,
    )
    self.process_message_queue.put(message)
    # Only record the state once the message has been queued.
    self._last_sent_process_state = process_state

# Minimum number of seconds between two heartbeat messages.
_heartbeat_limit_interval_seconds: float = 5.0
# Monotonic-clock reading taken when the last heartbeat was sent.
# Starts at -inf so that the very first heartbeat is never rate limited.
_last_heartbeat_time: float = float("-inf")

def send_heartbeat_message(self) -> None:
    """Send a heartbeat message to the main process.

    Calls are rate limited: nothing is sent if the previous heartbeat went
    out less than `_heartbeat_limit_interval_seconds` seconds ago.
    """
    # Use the monotonic clock for the interval check; a wall-clock adjustment
    # (e.g. NTP stepping the time backwards) must not suppress heartbeats.
    if (time.monotonic() - self._last_heartbeat_time) < self._heartbeat_limit_interval_seconds:
        return

    message = HordeProcessHeartbeatMessage(
        process_id=self.process_id,
        info="Heartbeat",
        time_elapsed=None,
    )
    self.process_message_queue.put(message)

    # Stamp after the message has been queued.
    self._last_heartbeat_time = time.monotonic()

@abstractmethod
def cleanup_and_exit(self) -> None:
"""Cleanup and exit the process."""
Expand All @@ -129,7 +147,7 @@ def send_memory_report_message(
message = HordeProcessMemoryMessage(
process_id=self.process_id,
info="Memory report",
time_elapsed=self._last_message_time - time.time(),
time_elapsed=None,
ram_usage_bytes=psutil.Process().memory_info().rss,
)

Expand All @@ -138,7 +156,6 @@ def send_memory_report_message(
message.vram_total_bytes = self.get_vram_total_bytes()

self.process_message_queue.put(message)
self._last_message_time = time.time()

@abstractmethod
def _receive_and_handle_control_message(self, message: HordeControlMessage) -> None:
Expand All @@ -148,7 +165,6 @@ def receive_and_handle_control_messages(self) -> None:
"""Get and handle any control messages pending from the main process."""
while self.pipe_connection.poll():
message = self.pipe_connection.recv()
self._last_message_time = time.time()

if not isinstance(message, HordeControlMessage):
logger.critical(f"Received unexpected message type: {type(message).__name__}")
Expand Down
34 changes: 28 additions & 6 deletions horde_worker_regen/process_management/inference_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,20 +90,33 @@ def __init__(

from hordelib.nodes.node_model_loader import HordeCheckpointLoader

self._horde = HordeLib()
self._shared_model_manager = SharedModelManager()
self._checkpoint_loader = HordeCheckpointLoader()
try:
self._horde = HordeLib(comfyui_callback=self._comfyui_callback)
self._shared_model_manager = SharedModelManager()
except Exception as e:
logger.critical(f"Failed to initialise HordeLib: {type(e).__name__} {e}")

try:
self._checkpoint_loader = HordeCheckpointLoader()
except Exception as e:
logger.critical(f"Failed to initialise HordeCheckpointLoader: {type(e).__name__} {e}")

logger.info("HordeInferenceProcess initialised")

self.send_process_state_change_message(
process_state=HordeProcessState.WAITING_FOR_JOB,
info="Waiting for job",
)

def _comfyui_callback(self, label: str, data: dict, _id: str) -> None:
    """Callback handed to HordeLib as `comfyui_callback`; reports liveness.

    Each invocation requests a heartbeat message to the main process (the
    heartbeat sender applies its own rate limit). The arguments are accepted
    to satisfy the callback signature and are not used here.
    NOTE(review): the exact meaning of `label`, `data` and `_id` is defined
    by hordelib — confirm against its callback documentation.
    """
    self.send_heartbeat_message()

def on_horde_model_state_change(
self,
horde_model_name: str,
process_state: HordeProcessState,
horde_model_state: ModelLoadState,
time_elapsed: float | None = None,
) -> None:
"""Update the main process with the current process state and model state."""
self.send_memory_report_message(include_vram=True)
Expand All @@ -114,7 +127,7 @@ def on_horde_model_state_change(
info=f"Model {horde_model_name} {horde_model_state.name}",
horde_model_name=horde_model_name,
horde_model_state=horde_model_state,
time_elapsed=self._last_message_time - time.time(),
time_elapsed=time_elapsed,
)
self.process_message_queue.put(model_update_message)

Expand All @@ -139,12 +152,15 @@ def download_model(self, horde_model_name: str) -> None:
if self._shared_model_manager.manager.is_model_available(horde_model_name):
logger.info(f"Model {horde_model_name} already downloaded")

time_start = time.time()

success = self._shared_model_manager.manager.download_model(horde_model_name, self.download_callback)

if success:
self.send_process_state_change_message(
process_state=HordeProcessState.DOWNLOAD_COMPLETE,
info=f"Downloaded model {horde_model_name}",
time_elapsed=time.time() - time_start,
)

self.on_horde_model_state_change(
Expand Down Expand Up @@ -178,6 +194,8 @@ def preload_model(
horde_model_state=ModelLoadState.LOADING,
)

time_start = time.time()

with contextlib.nullcontext(): # self.disk_lock:
self._checkpoint_loader.load_checkpoint(
will_load_loras=will_load_loras,
Expand All @@ -192,6 +210,7 @@ def preload_model(
process_state=HordeProcessState.PRELOADED_MODEL,
horde_model_name=horde_model_name,
horde_model_state=ModelLoadState.LOADED_IN_RAM,
time_elapsed=time.time() - time_start,
)

self.send_process_state_change_message(
Expand Down Expand Up @@ -260,6 +279,7 @@ def send_inference_result_message(
process_state: HordeProcessState,
job_info: ImageGenerateJobPopResponse,
images: list[Image] | None,
time_elapsed: float,
) -> None:
images_as_base64 = []

Expand All @@ -274,12 +294,11 @@ def send_inference_result_message(
process_id=self.process_id,
info="Inference result",
state=GENERATION_STATE.ok if images is not None and len(images) > 0 else GENERATION_STATE.faulted,
time_elapsed=self._last_message_time - time.time(),
time_elapsed=time_elapsed,
job_result_images_base64=images_as_base64,
job_info=job_info,
)
self.process_message_queue.put(message)
self._last_message_time = time.time()

if self._active_model_name is None:
logger.critical("No active model name, cannot update model state")
Expand Down Expand Up @@ -318,13 +337,16 @@ def _receive_and_handle_control_message(self, message: HordeControlMessage) -> N
info=f"Starting inference for {message.job_info.id_} with model {message.horde_model_name}",
)

time_start = time.time()

images = self.start_inference(message.job_info)
process_state = HordeProcessState.INFERENCE_COMPLETE if images else HordeProcessState.INFERENCE_FAILED
logger.debug(f"Finished inference with process state {process_state}")
self.send_inference_result_message(
process_state=process_state,
job_info=message.job_info,
images=images,
time_elapsed=time.time() - time_start,
)
else:
logger.critical(f"Received unexpected message: {message}")
Expand Down
6 changes: 5 additions & 1 deletion horde_worker_regen/process_management/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class HordeProcessMessage(BaseModel):

process_id: int
info: str
time_elapsed: float
time_elapsed: float | None = None


class HordeProcessMemoryMessage(HordeProcessMessage):
Expand All @@ -78,6 +78,10 @@ class HordeProcessMemoryMessage(HordeProcessMessage):
vram_total_bytes: int | None = None


class HordeProcessHeartbeatMessage(HordeProcessMessage):
    """Heartbeat message sent by a child process to the main process.

    Adds no fields of its own beyond those of `HordeProcessMessage`.
    """

    pass


class HordeProcessStateChangeMessage(HordeProcessMessage):
process_state: HordeProcessState

Expand Down
Loading

0 comments on commit e07f50b

Please sign in to comment.