Replace processes when a model is unloaded
On Linux, it seems like there is a tremendous amount of memory
allocated by something outside Python when you allow one
worker to process many jobs on many different models. In order
to limit the damage from that behavior, we'll try to replace
the processes when the model is scheduled to be unloaded.

I realize this probably makes it a _little_ harder to decouple
a process from a model, but it's a huge stability improvement.

This also switches the model management strategy a tiny bit
by allocating a model to every open worker before trying
to unload a model. Previously, you would have
N processes = `threads` + `queue`, but jobs were very likely
to be scheduled on only the `threads` number of workers.
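
Roughly, the new strategy amounts to the following (a minimal sketch with hypothetical numbers, not the worker's actual scheduling code): keep the pool at `threads` + `queue` processes, track models in an LRU of the same size, and when a new model bumps the least-recently-used one, replace the process that held it rather than unloading the model in place.

from collections import OrderedDict

threads, queue_size = 2, 1
capacity = threads + queue_size  # N processes = `threads` + `queue`

live_models: OrderedDict[str, int] = OrderedDict()  # model name -> worker process id

def preload(model: str, pid: int) -> int | None:
    """Track `model` as most recently used; return the pid of a process to replace, if any."""
    if model in live_models:
        live_models.move_to_end(model)
        return None
    evicted_pid = None
    if len(live_models) >= capacity:
        _, evicted_pid = live_models.popitem(last=False)  # the oldest model loses its process
    live_models[model] = pid
    return evicted_pid  # the caller forks a fresh process for this slot instead of unloading in place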
zten committed Nov 27, 2023
1 parent c5da46a commit 5170812
Showing 1 changed file with 129 additions and 40 deletions.
169 changes: 129 additions & 40 deletions horde_worker_regen/process_management/process_manager.py
@@ -1,5 +1,7 @@
import asyncio
import base64
import collections
import datetime
import multiprocessing
import os
import random
@@ -81,6 +83,8 @@ class HordeProcessInfo:
"""The type of this process."""
last_process_state: HordeProcessState
"""The last known state of this process."""
last_timestamp: datetime.datetime
"""Last time we updated the process info. If we're regularly working, then this value should change frequently."""
loaded_horde_model_name: str | None = None
"""The name of the horde model that is (supposedly) currently loaded in this process."""

@@ -115,6 +119,7 @@ def __init__(
self.process_id = process_id
self.process_type = process_type
self.last_process_state = last_process_state
self.last_timestamp = datetime.datetime.now()

def is_process_busy(self) -> bool:
"""Return true if the process is actively engaged in a task.
@@ -186,6 +191,9 @@ def update_entry(
if process_id is not None:
self.root[horde_model_name].process_id = process_id

def expire_entry(self, horde_model_name: str) -> None:
    """Remove the given model from the map, if it is present."""
    self.root.pop(horde_model_name, None)

def is_model_loaded(self, horde_model_name: str) -> bool:
"""Return true if the given model is loaded in any process."""
if horde_model_name not in self.root:
@@ -242,6 +250,8 @@ def update_entry(
if total_vram_bytes is not None:
self[process_id].total_vram_bytes = total_vram_bytes

self[process_id].last_timestamp = datetime.datetime.now()

def num_inference_processes(self) -> int:
"""Return the number of inference processes."""
count = 0
@@ -260,6 +270,12 @@ def num_available_inference_processes(self) -> int:

def get_first_available_inference_process(self) -> HordeProcessInfo | None:
"""Return the first available inference process, or None if there are none available."""
for p in self.values():
if p.process_type == HordeProcessType.INFERENCE \
and p.last_process_state == HordeProcessState.WAITING_FOR_JOB \
and p.loaded_horde_model_name is None:
return p

for p in self.values():
if p.process_type == HordeProcessType.INFERENCE and p.can_accept_job():
if p.last_process_state == HordeProcessState.PRELOADED_MODEL:
@@ -336,7 +352,8 @@ def __repr__(self) -> str:
base_string = "Processes: "
for process_id, process_info in self.items():
if process_info.process_type == HordeProcessType.INFERENCE:
base_string += f"{process_id}: ({process_info.loaded_horde_model_name}) "
base_string += (f"{process_id}: ({process_info.loaded_horde_model_name} "
f"[last event: {process_info.last_timestamp}]) ")
else:
base_string += f"{process_id}: ({process_info.process_type.name}) "
base_string += f"{process_info.last_process_state.name}; "
@@ -380,6 +397,21 @@ def is_job_checked_for_safety(self) -> bool:
return self.censored is not None


class LRUCache:
    """A fixed-capacity least-recently-used tracker of model names."""

    def __init__(self, capacity: int) -> None:
        self.capacity = capacity
        self.cache: collections.OrderedDict[str, None] = collections.OrderedDict()

    def append(self, key: str) -> str | None:
        """Mark `key` as most recently used and return the key that was evicted, if any."""
        bumped = None
        if key in self.cache:
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.capacity:
            bumped, _ = self.cache.popitem(last=False)
        self.cache[key] = None
        return bumped
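
A quick usage sketch of the eviction behavior (hypothetical model names): `append` returns the bumped key once capacity is exceeded, and None otherwise.

cache = LRUCache(capacity=2)
assert cache.append("model-a") is None        # cache: [model-a]
assert cache.append("model-b") is None        # cache: [model-a, model-b]
assert cache.append("model-a") is None        # refresh; model-a becomes most recent
assert cache.append("model-c") == "model-b"   # least-recently-used key is evicted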


class HordeWorkerProcessManager:
"""Manages and controls processes to act as a horde worker."""

@@ -412,6 +444,9 @@ def max_concurrent_inference_processes(self) -> int:

target_vram_overhead_bytes_map: Mapping[int, int] | None = None

process_timeout: datetime.timedelta
"""Max amount of time a job can go without checking in with the main process manager"""

@property
def max_queue_size(self) -> int:
"""The maximum number of jobs that can be queued."""
@@ -501,6 +536,8 @@ def num_total_processes(self) -> int:

_shutting_down = False

_lru: LRUCache

def __init__(
self,
*,
@@ -542,6 +579,8 @@ def __init__(
self._inference_semaphore = Semaphore(self._max_concurrent_inference_processes, ctx=ctx)

self.max_inference_processes = self.bridge_data.queue_size + self.bridge_data.max_threads
self._lru = LRUCache(self.max_inference_processes)
self.process_timeout = datetime.timedelta(minutes=5)

# If there is only one model to load and only one inference process, then we can only run one job at a time
# and there is no point in having more than one inference process
@@ -743,33 +782,34 @@ def start_inference_processes(self) -> None:
for _ in range(num_processes_to_start):
# Create a two-way communication pipe for the parent and child processes
pid = len(self._process_map)
pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True)

# Create a new process that will run the start_inference_process function
process = multiprocessing.Process(
target=start_inference_process,
args=(
pid,
self._process_message_queue,
child_pipe_connection,
self._inference_semaphore,
self._disk_lock,
),
)

process.start()

# Add the process to the process map
self._process_map[pid] = HordeProcessInfo(
mp_process=process,
pipe_connection=pipe_connection,
process_id=pid,
process_type=HordeProcessType.INFERENCE,
last_process_state=HordeProcessState.PROCESS_STARTING,
)
self._start_inference_process(pid)

logger.info(f"Started inference process (id: {pid})")

def _start_inference_process(self, pid: int) -> None:
logger.info(f"Starting inference process on PID {pid}")
pipe_connection, child_pipe_connection = multiprocessing.Pipe(duplex=True)
# Create a new process that will run the start_inference_process function
process = multiprocessing.Process(
target=start_inference_process,
args=(
pid,
self._process_message_queue,
child_pipe_connection,
self._inference_semaphore,
self._disk_lock,
),
)
process.start()
# Add the process to the process map
self._process_map[pid] = HordeProcessInfo(
mp_process=process,
pipe_connection=pipe_connection,
process_id=pid,
process_type=HordeProcessType.INFERENCE,
last_process_state=HordeProcessState.PROCESS_STARTING,
)

def end_inference_processes(self) -> None:
"""End any inference processes above the configured limit, or all of them if shutting down."""
if len(self.job_deque) > 0 and len(self.job_deque) != len(self.jobs_in_progress):
@@ -778,19 +818,23 @@ def end_inference_processes(self) -> None:
# Get the process to end
process_info = self._process_map._get_first_inference_process_to_kill()

if process_info is None:
return
if process_info is not None:
self._end_inference_process(process_info)

def _end_inference_process(self, process_info: HordeProcessInfo) -> None:
# Send the process a message to end
process_info.pipe_connection.send(HordeControlMessage(control_flag=HordeControlFlag.END_PROCESS))

# Update the process map
self._process_map.update_entry(process_id=process_info.process_id)

logger.info(f"Ended inference process {process_info.process_id}")

# Join the process with a timeout of 1 second
process_info.mp_process.join(timeout=1)
process_info.mp_process.kill()

def _replace_inference_process(self, process_info: HordeProcessInfo) -> None:
    """End the given inference process and start a fresh one in the same process-map slot."""
    logger.debug(f"Replacing {process_info}")
    self._end_inference_process(process_info)
    self._start_inference_process(process_info.process_id)

total_num_completed_jobs: int = 0

@@ -1058,11 +1102,31 @@ def preload_models(self) -> bool:
if model_is_loaded:
continue

available_process = self._process_map.get_first_available_inference_process()
available_process = None
model_to_unload = self._lru.append(job.model)
if model_to_unload is not None:
for p in self._process_map.values():
if p.loaded_horde_model_name == model_to_unload and \
(p.last_process_state == HordeProcessState.INFERENCE_COMPLETE or \
p.last_process_state == HordeProcessState.WAITING_FOR_JOB):
available_process = p
if available_process is None:
available_process = self._process_map.get_first_available_inference_process()

if available_process is None:
return False

if available_process.last_process_state != HordeProcessState.WAITING_FOR_JOB \
and available_process.loaded_horde_model_name is not None:
# We're going to restart the process and then exit the loop, because
# available_process is very quickly _not_ going to be available.
# We also don't want to block waiting for the newly forked process to become
# available, so we'll wait for it to become ready before scheduling a model
# to be loaded on it.
self._replace_inference_process(available_process)
self._horde_model_map.expire_entry(available_process.loaded_horde_model_name)
return False

logger.debug(f"Preloading model {job.model} on process {available_process.process_id}")
logger.debug(f"Available inference processes: {self._process_map}")
logger.debug(f"Horde model map: {self._horde_model_map}")
@@ -1145,11 +1209,7 @@ def start_inference(self) -> None:
next_n_models = list(self.get_next_n_models(self.max_inference_processes))

# If the model would be used by another process soon, don't unload it
if (
self.max_concurrent_inference_processes > 1
and process_info.loaded_horde_model_name
in next_n_models[: self.max_concurrent_inference_processes - 1]
):
if process_info.loaded_horde_model_name in next_n_models:
continue

process_info.pipe_connection.send(
@@ -1770,6 +1830,8 @@ async def api_job_pop(self) -> None:

async with self._job_deque_lock, self._job_pop_timestamps_lock:
self.job_deque.append(job_pop_response)
jobs = [f"<{job.id_}: {job.model}>" for job in self.job_deque]
logger.info(f"Job queue: {', '.join(jobs)}")
# self._testing_jobs_added += 1
self.job_pop_timestamps[str(job_pop_response.id_)] = time.time()

@@ -1885,6 +1947,7 @@ async def _process_control_loop(self) -> None:
self.start_inference_processes()

while True:
logger.debug("_process_control_loop looped")
try:
if self.stable_diffusion_reference is None:
return
@@ -1908,6 +1971,8 @@
async with self._job_deque_lock, self._jobs_safety_check_lock, self._completed_jobs_lock:
self.receive_and_handle_process_messages()

self.replace_hung_processes()

self.unload_models()

if self._shutting_down:
@@ -1921,6 +1986,8 @@
logger.info(f"{self._process_map}")
logger.info(f"Threads being used: {self._max_concurrent_inference_processes}")
logger.info(f"Number of jobs popped: {len(self.job_deque)}")
jobs = [f"<{job.id_}: {job.model}>" for job in self.job_deque]
logger.info(f"Job queue: {', '.join(jobs)}")
logger.info(f"Number of jobs in progress: {len(self.jobs_in_progress)}")
logger.info(f"Number of jobs pending safety check: {len(self.jobs_pending_safety_check)}")
logger.info(f"Number of jobs being safety checked: {len(self.jobs_being_safety_checked)}")
@@ -2002,13 +2069,28 @@ async def _bridge_data_loop(self) -> None:
except CancelledError:
self._shutting_down = True

@staticmethod
def _handle_exception(task: asyncio.Task) -> None:
    """Log any exception raised by a finished task. Done callbacks are invoked synchronously,
    so the task's result is inspected here rather than awaited."""
    if task.cancelled():
        return
    exception = task.exception()
    if exception is not None:
        logger.error(f"Caught exception in task {task.get_name()}: {exception}")

async def _main_loop(self) -> None:
# Run both loops concurrently
process_control_loop = asyncio.create_task(self._process_control_loop(), name="process_control_loop")
api_call_loop = asyncio.create_task(self._api_call_loop(), name="api_call_loop")
job_submit_loop = asyncio.create_task(self._job_submit_loop(), name="job_submit_loop")
bridge_data_loop = asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop")
process_control_loop.add_done_callback(self._handle_exception)
api_call_loop.add_done_callback(self._handle_exception)
job_submit_loop.add_done_callback(self._handle_exception)
bridge_data_loop.add_done_callback(self._handle_exception)
await asyncio.gather(
asyncio.create_task(self._process_control_loop(), name="process_control_loop"),
asyncio.create_task(self._api_call_loop(), name="api_call_loop"),
asyncio.create_task(self._job_submit_loop(), name="job_submit_loop"),
asyncio.create_task(self._bridge_data_loop(), name="bridge_data_loop"),
process_control_loop,
api_call_loop,
job_submit_loop,
bridge_data_loop,
)

_caught_sigints = 0
Expand Down Expand Up @@ -2046,3 +2128,10 @@ def shutdown() -> None:
sys.exit(0)

threading.Thread(target=shutdown).start()

def replace_hung_processes(self) -> None:
    """Replace any inference process that has not checked in within `process_timeout`."""
    now = datetime.datetime.now()
    for process_info in self._process_map.values():
        if (now - process_info.last_timestamp) > self.process_timeout:
            logger.error(f"{process_info} has exceeded its timeout and will be replaced")
            self._replace_inference_process(process_info)
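
For illustration (hypothetical timestamps), the watchdog above only flags processes whose last update is older than `process_timeout`:

import datetime

process_timeout = datetime.timedelta(minutes=5)
now = datetime.datetime.now()
last_seen = {
    0: now - datetime.timedelta(seconds=30),  # healthy: checked in recently
    1: now - datetime.timedelta(minutes=12),  # hung: silent for 12 minutes
}
hung = [pid for pid, ts in last_seen.items() if (now - ts) > process_timeout]
print(hung)  # [1] -- only the stale process would be replaced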
