diff --git a/horde_worker_regen/process_management/process_manager.py b/horde_worker_regen/process_management/process_manager.py index cb279de1..2ba6ce84 100644 --- a/horde_worker_regen/process_management/process_manager.py +++ b/horde_worker_regen/process_management/process_manager.py @@ -7,7 +7,6 @@ from collections import deque from collections.abc import Mapping from io import BytesIO -from multiprocessing.connection import PipeConnection from multiprocessing.context import BaseContext from multiprocessing.synchronize import Lock as Lock_MultiProcessing from multiprocessing.synchronize import Semaphore @@ -58,10 +57,15 @@ ) from horde_worker_regen.process_management.worker_entry_points import start_inference_process, start_safety_process +try: + from multiprocessing.connection import PipeConnection as Connection +except ImportError: + from multiprocessing.connection import Connection # type: ignore + class HordeProcessInfo: mp_process: multiprocessing.Process - pipe_connection: PipeConnection + pipe_connection: Connection process_id: int process_type: HordeProcessKind last_process_state: HordeProcessState @@ -76,7 +80,7 @@ class HordeProcessInfo: def __init__( self, mp_process: multiprocessing.Process, - pipe_connection: PipeConnection, + pipe_connection: Connection, process_id: int, process_type: HordeProcessKind, last_process_state: HordeProcessState, @@ -1062,8 +1066,8 @@ async def api_job_pop(self) -> None: job_pop_request = ImageGenerateJobPopRequest( apikey=self.bridge_data.api_key, name=self.bridge_data.dreamer_worker_name, - bridge_agent="AI Horde Worker:23:tazlin reGen testing", - bridge_version=23, # TODO TIs broken + bridge_agent="AI Horde Worker reGen:1:https://github.com/Haidra-Org/", + bridge_version=1, # TODO TIs broken models=self.bridge_data.image_models_to_load, nsfw=self.bridge_data.nsfw, threads=self.max_concurrent_inference_processes, @@ -1093,6 +1097,11 @@ async def api_job_pop(self) -> None: self._job_pop_frequency = self._default_job_pop_frequency + info_string = "No job available. " + if len(self.job_deque) > 0: + info_string += f"Current job deque length: {len(self.job_deque)}. " + info_string += f"(Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})" + if job_pop_response.id_ is None: logger.info( f"No job available. (Skipped reasons: {job_pop_response.skipped.model_dump(exclude_defaults=True)})", @@ -1107,6 +1116,48 @@ async def api_job_pop(self) -> None: new_response_dict["payload"]["seed"] = random.randint(0, (2**32) - 1) job_pop_response = ImageGenerateJobPopResponse(**new_response_dict) + if job_pop_response.source_image is not None and "https://" in job_pop_response.source_image: + # Download and convert the source image to base64 + fail_count = 0 + while True: + try: + if fail_count >= 10: + logger.error(f"Failed to download source image after {fail_count} attempts") + break + source_image_response = requests.get(job_pop_response.source_image) + source_image_response.raise_for_status() + new_response_dict = job_pop_response.model_dump(by_alias=True) + + new_response_dict["source_image"] = base64.b64encode(source_image_response.content).decode("utf-8") + job_pop_response = ImageGenerateJobPopResponse(**new_response_dict) + logger.debug(f"Downloaded source image for job {job_pop_response.id_}") + break + except Exception as e: + logger.error(f"Failed to download source image: {e}") + fail_count += 1 + time.sleep(0.5) + + if job_pop_response.source_mask is not None and "https://" in job_pop_response.source_mask: + # Download and convert the source image to base64 + fail_count = 0 + while True: + try: + if fail_count >= 10: + logger.error(f"Failed to download source image after {fail_count} attempts") + break + source_mask_response = requests.get(job_pop_response.source_mask) + source_mask_response.raise_for_status() + new_response_dict = job_pop_response.model_dump(by_alias=True) + + new_response_dict["source_mask"] = base64.b64encode(source_mask_response.content).decode("utf-8") + job_pop_response = ImageGenerateJobPopResponse(**new_response_dict) + logger.debug(f"Downloaded source image for job {job_pop_response.id_}") + break + except Exception as e: + logger.error(f"Failed to download source_mask: {e}") + fail_count += 1 + time.sleep(0.5) + async with self._job_deque_lock: self.job_deque.append(job_pop_response) self._testing_jobs_added += 1 diff --git a/requirements.txt b/requirements.txt index 3908cd4c..d80ffa9a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ torch -horde_sdk +horde_sdk>=7.10.0 horde_model_reference horde_safety hordelib diff --git a/run_worker.py b/run_worker.py index 8883f49a..633a1d40 100644 --- a/run_worker.py +++ b/run_worker.py @@ -34,13 +34,15 @@ def main(ctx: BaseContext) -> None: imlr = ImageModelLoadResolver(horde_model_reference_manager) + resolved_models = None if bridge_data.meta_load_instructions is not None: resolved_models = imlr.resolve_meta_instructions( list(bridge_data.meta_load_instructions), AIHordeAPIManualClient(), ) - bridge_data.image_models_to_load = list(resolved_models) + if resolved_models is not None: + bridge_data.image_models_to_load = list(set(bridge_data.image_models_to_load + list(resolved_models))) start_working(ctx=ctx, bridge_data=bridge_data)