Skip to content

Commit

Permalink
Merge pull request #2 from Haidra-Org/regen-day2
Browse files Browse the repository at this point in the history
refactor: more error checking, better flow+control; docs: more docstrings
  • Loading branch information
tazlin authored Oct 9, 2023
2 parents 52082a5 + e4e6fc3 commit e903e18
Show file tree
Hide file tree
Showing 6 changed files with 663 additions and 211 deletions.
70 changes: 56 additions & 14 deletions horde_worker_regen/process_management/horde_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,20 +58,25 @@ class HordeProcess(abc.ABC):
"""The time to sleep between each loop iteration."""

# Loop flag: `main_loop` keeps iterating while this is False.
_end_process: bool = False
"""Whether the process should end soon."""

# NOTE(review): unit is presumably seconds - confirm where the interval is consumed.
_memory_report_interval: float = 5.0
"""The time to wait between each memory report."""

# Starts out as PROCESS_STARTING until a state change is recorded.
_last_sent_process_state: HordeProcessState = HordeProcessState.PROCESS_STARTING
"""The last process state that was sent to the main process."""

# Defaults to 0 until a real value is stored.
_vram_total_bytes: int = 0
"""The total number of bytes of VRAM available on the GPU."""

def get_vram_usage_bytes(self) -> int:
    """Return how much VRAM is currently in use on the GPU.

    The value is the total reported by `get_torch_total_vram_mb` minus the free
    amount reported by `get_torch_free_vram_mb`.

    NOTE(review): both helpers are named `*_mb`, so the result is presumably in
    megabytes rather than bytes as this method's name suggests - confirm against
    the callers before relying on the unit.
    """
    # Imported inside the method so hordelib is only loaded when VRAM is queried.
    from hordelib.comfy_horde import get_torch_free_vram_mb, get_torch_total_vram_mb

    total_vram = get_torch_total_vram_mb()
    free_vram = get_torch_free_vram_mb()
    return total_vram - free_vram

def get_vram_total_bytes(self) -> int:
    """Return the total amount of VRAM available on the GPU.

    The value is passed through unchanged from `get_torch_total_vram_mb`.

    NOTE(review): the helper is named `*_mb`, so the result is presumably in
    megabytes rather than bytes as this method's name suggests - confirm against
    the callers before relying on the unit.
    """
    # Imported inside the method so hordelib is only loaded when VRAM is queried.
    from hordelib.comfy_horde import get_torch_total_vram_mb

    total_vram = get_torch_total_vram_mb()
    return total_vram
Expand All @@ -83,6 +88,16 @@ def __init__(
pipe_connection: Connection,
disk_lock: Lock,
) -> None:
"""Initialise the process.
Args:
process_id (int): The ID of the process. This is not the same as the process PID.
process_message_queue (ProcessQueue): The queue the main process uses to receive messages from all worker \
processes.
pipe_connection (Connection): Receives `HordeControlMessage`s from the main process.
disk_lock (Lock): A lock used to prevent multiple processes from accessing disk at the same time.
"""

self.process_id = process_id
self.process_message_queue = process_message_queue
self.pipe_connection = pipe_connection
Expand All @@ -109,6 +124,15 @@ def send_process_state_change_message(
info: str,
time_elapsed: float | None = None,
) -> None:
"""Send a process state change message to the main process.
Args:
process_state (HordeProcessState): The state of the process.
info (str): Information about the process.
time_elapsed (float | None, optional): The time elapsed during the last operation, if applicable. \
Defaults to None.
"""
message = HordeProcessStateChangeMessage(
process_state=process_state,
process_id=self.process_id,
Expand All @@ -122,7 +146,12 @@ def send_process_state_change_message(
_last_heartbeat_time: float = 0.0

def send_heartbeat_message(self) -> None:
"""Send a heartbeat message to the main process."""
"""Send a heartbeat message to the main process, indicating that the process is still alive
during an operation.
Note that this will only send a heartbeat message if the last heartbeat was sent more than
`_heartbeat_limit_interval_seconds` ago.
"""

if (time.time() - self._last_heartbeat_time) < self._heartbeat_limit_interval_seconds:
return
Expand All @@ -137,14 +166,18 @@ def send_heartbeat_message(self) -> None:
self._last_heartbeat_time = time.time()

# NOTE(review): two consecutive `def` lines - this looks like a rendered rename from
# `cleanup_and_exit` to `cleanup_for_exit`; confirm that only the latter should remain.
@abstractmethod
def cleanup_and_exit(self) -> None:
def cleanup_for_exit(self) -> None:
"""Clean up the process in preparation for exiting.
`main_loop` calls this after its loop has ended and the PROCESS_ENDING state change has
been sent. `main_loop` then sends PROCESS_ENDED, sends a final memory report and calls
`sys.exit(0)` itself.
"""

def send_memory_report_message(
self,
include_vram: bool = False,
) -> None:
"""Send a memory report message to the main process."""
"""Send a memory report message to the main process.
Args:
include_vram (bool, optional): Whether to include VRAM usage in the message. Defaults to False.
"""
message = HordeProcessMemoryMessage(
process_id=self.process_id,
info="Memory report",
Expand All @@ -160,7 +193,11 @@ def send_memory_report_message(

@abstractmethod
def _receive_and_handle_control_message(self, message: HordeControlMessage) -> None:
    """Receive and handle a control message from the main process.

    Subclasses must implement this. `receive_and_handle_control_messages` passes
    each message it handles to this method.

    Args:
        message (HordeControlMessage): The message to handle.
    """

def receive_and_handle_control_messages(self) -> None:
"""Get and handle any control messages pending from the main process."""
Expand All @@ -179,42 +216,47 @@ def receive_and_handle_control_messages(self) -> None:
self._receive_and_handle_control_message(message)

def worker_cycle(self) -> None:
    """Run process specific logic once per main loop iteration.

    Called by `main_loop` after messages have been received and handled. The base
    implementation does nothing; override this to implement any process specific
    logic.
    """
    return

def main_loop(self) -> None:
    """Run the worker process loop until `_end_process` is set, then shut down.

    Each iteration sleeps for `_loop_interval`, handles any pending control
    messages and then runs `worker_cycle`. Once the loop ends, the method sends
    PROCESS_ENDING, calls `cleanup_for_exit`, sends PROCESS_ENDED and a final
    memory report (including VRAM), and exits the process.

    Raises:
        SystemExit: Always, via `sys.exit(0)`, once shutdown has completed.
    """
    signal.signal(signal.SIGINT, signal_handler)

    while not self._end_process:
        try:
            time.sleep(self._loop_interval)

            self.receive_and_handle_control_messages()

            self.worker_cycle()
        except KeyboardInterrupt:
            # Only log here; the loop keeps running until `_end_process` is set.
            logger.info("Keyboard interrupt received")

    # We escaped the loop, so the process is ending
    self.send_process_state_change_message(
        process_state=HordeProcessState.PROCESS_ENDING,
        info="Process ending",
    )

    # Subclass hook: runs exactly once, between the ENDING and ENDED notifications.
    self.cleanup_for_exit()

    logger.info("Process ended")
    self.send_process_state_change_message(
        process_state=HordeProcessState.PROCESS_ENDED,
        info="Process ended",
    )

    # We are exiting, so send a final memory report
    self.send_memory_report_message(include_vram=True)

    # Exit the process (we expect to be a child process)
    sys.exit(0)


_signals_caught = 0


def signal_handler(sig: int, frame: object) -> None:
"""Called when a signal is received. This will exit the process gracefully if the process has only received one \
signal, or exit immediately if the process has received two signals."""

global _signals_caught
if _signals_caught >= 1:
logger.warning("Received second signal, exiting immediately")
Expand Down
Loading

0 comments on commit e903e18

Please sign in to comment.