diff --git a/internal_controlnet/external_code.py b/internal_controlnet/external_code.py new file mode 100644 index 000000000..f2a3c7bb1 --- /dev/null +++ b/internal_controlnet/external_code.py @@ -0,0 +1,458 @@ +from enum import Enum +from typing import List, Any, Optional, Union, Tuple, Dict +import numpy as np +from modules import scripts, processing, shared +from scripts import global_state +from scripts.processor import preprocessor_sliders_config, model_free_preprocessors +from scripts.logging import logger + +from modules.api import api + + +def get_api_version() -> int: + return 2 + + +class ControlMode(Enum): + """ + The improved guess mode. + """ + + BALANCED = "Balanced" + PROMPT = "My prompt is more important" + CONTROL = "ControlNet is more important" + + +class ResizeMode(Enum): + """ + Resize modes for ControlNet input images. + """ + + RESIZE = "Just Resize" + INNER_FIT = "Crop and Resize" + OUTER_FIT = "Resize and Fill" + + def int_value(self): + if self == ResizeMode.RESIZE: + return 0 + elif self == ResizeMode.INNER_FIT: + return 1 + elif self == ResizeMode.OUTER_FIT: + return 2 + assert False, "NOTREACHED" + + +resize_mode_aliases = { + 'Inner Fit (Scale to Fit)': 'Crop and Resize', + 'Outer Fit (Shrink to Fit)': 'Resize and Fill', + 'Scale to Fit (Inner Fit)': 'Crop and Resize', + 'Envelope (Outer Fit)': 'Resize and Fill', +} + + +def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode: + if isinstance(value, str): + return ResizeMode(resize_mode_aliases.get(value, value)) + elif isinstance(value, int): + assert value >= 0 + if value == 3: # 'Just Resize (Latent upscale)' + return ResizeMode.RESIZE + + if value >= len(ResizeMode): + logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.') + return ResizeMode.RESIZE + + return [e for e in ResizeMode][value] + else: + return value + + +def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode: + if isinstance(value, str): + return ControlMode(value) + elif isinstance(value, int): + return [e for e in ControlMode][value] + else: + return value + + +def visualize_inpaint_mask(img): + if img.ndim == 3 and img.shape[2] == 4: + result = img.copy() + mask = result[:, :, 3] + mask = 255 - mask // 2 + result[:, :, 3] = mask + return np.ascontiguousarray(result.copy()) + return img + + +def pixel_perfect_resolution( + image: np.ndarray, + target_H: int, + target_W: int, + resize_mode: ResizeMode, +) -> int: + """ + Calculate the estimated resolution for resizing an image while preserving aspect ratio. + + The function first calculates scaling factors for height and width of the image based on the target + height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger + scaling factor to estimate the new resolution. + + If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image + fits within the target dimensions, potentially leaving some empty space. + + If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target + dimensions are fully filled, potentially cropping the image. + + After calculating the estimated resolution, the function prints some debugging information. + + Args: + image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels]. + target_H (int): The target height for the image. + target_W (int): The target width for the image. + resize_mode (ResizeMode): The mode for resizing. + + Returns: + int: The estimated resolution after resizing. + """ + raw_H, raw_W, _ = image.shape + + k0 = float(target_H) / float(raw_H) + k1 = float(target_W) / float(raw_W) + + if resize_mode == ResizeMode.OUTER_FIT: + estimation = min(k0, k1) * float(min(raw_H, raw_W)) + else: + estimation = max(k0, k1) * float(min(raw_H, raw_W)) + + logger.debug(f"Pixel Perfect Computation:") + logger.debug(f"resize_mode = {resize_mode}") + logger.debug(f"raw_H = {raw_H}") + logger.debug(f"raw_W = {raw_W}") + logger.debug(f"target_H = {target_H}") + logger.debug(f"target_W = {target_W}") + logger.debug(f"estimation = {estimation}") + + return int(np.round(estimation)) + + +InputImage = Union[np.ndarray, str] +InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] + + +class ControlNetUnit: + """ + Represents an entire ControlNet processing unit. + """ + + def __init__( + self, + enabled: bool = True, + module: Optional[str] = None, + model: Optional[str] = None, + weight: float = 1.0, + image: Optional[InputImage] = None, + resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT, + low_vram: bool = False, + processor_res: int = -1, + threshold_a: float = -1, + threshold_b: float = -1, + guidance_start: float = 0.0, + guidance_end: float = 1.0, + pixel_perfect: bool = False, + control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED, + **_kwargs, + ): + self.enabled = enabled + self.module = module + self.model = model + self.weight = weight + self.image = image + self.resize_mode = resize_mode + self.low_vram = low_vram + self.processor_res = processor_res + self.threshold_a = threshold_a + self.threshold_b = threshold_b + self.guidance_start = guidance_start + self.guidance_end = guidance_end + self.pixel_perfect = pixel_perfect + self.control_mode = control_mode + + def __eq__(self, other): + if not isinstance(other, ControlNetUnit): + return False + + return vars(self) == vars(other) + + +def to_base64_nparray(encoding: str): + """ + Convert a base64 image into the image type the extension uses + """ + + return np.array(api.decode_base64_to_image(encoding)).astype('uint8') + + +def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from a StableDiffusionProcessing. + """ + + return get_all_units(p.scripts, p.script_args) + + +def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from an existing script runner. + Use this function to fetch units from the list of all scripts arguments. + """ + + cn_script = find_cn_script(script_runner) + if cn_script: + return get_all_units_from(script_args[cn_script.args_from:cn_script.args_to]) + + return [] + + +def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]: + """ + Fetch ControlNet processing units from ControlNet script arguments. + Use `external_code.get_all_units` to fetch units from the list of all scripts arguments. + """ + + def is_stale_unit(script_arg: Any) -> bool: + """ Returns whether the script_arg is potentially an stale version of + ControlNetUnit created before module reload.""" + return ( + 'ControlNetUnit' in type(script_arg).__name__ and + not isinstance(script_arg, ControlNetUnit) + ) + + def is_controlnet_unit(script_arg: Any) -> bool: + """ Returns whether the script_arg is ControlNetUnit or anything that + can be treated like ControlNetUnit. """ + return ( + isinstance(script_arg, (ControlNetUnit, dict)) or + ( + hasattr(script_arg, '__dict__') and + set(vars(ControlNetUnit()).keys()).issubset( + set(vars(script_arg).keys())) + ) + ) + + all_units = [ + to_processing_unit(script_arg) + for script_arg in script_args + if is_controlnet_unit(script_arg) + ] + if not all_units: + logger.warning( + "No ControlNetUnit detected in args. It is very likely that you are having an extension conflict." + f"Here are args received by ControlNet: {script_args}.") + if any(is_stale_unit(script_arg) for script_arg in script_args): + logger.debug( + "Stale version of ControlNetUnit detected. The ControlNetUnit received" + "by ControlNet is created before the newest load of ControlNet extension." + "They will still be used by ControlNet as long as they provide same fields" + "defined in the newest version of ControlNetUnit." + ) + + return all_units + + +def get_single_unit_from(script_args: List[Any], index: int = 0) -> Optional[ControlNetUnit]: + """ + Fetch a single ControlNet processing unit from ControlNet script arguments. + The list must not contain script positional arguments. It must only contain processing units. + """ + + i = 0 + while i < len(script_args) and index >= 0: + if index == 0 and script_args[i] is not None: + return to_processing_unit(script_args[i]) + i += 1 + + index -= 1 + + return None + + +def get_max_models_num(): + """ + Fetch the maximum number of allowed ControlNet models. + """ + + max_models_num = shared.opts.data.get("control_net_max_models_num", 1) + return max_models_num + + +def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit: + """ + Convert different types to processing unit. + If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details. + """ + + ext_compat_keys = { + 'guessmode': 'guess_mode', + 'guidance': 'guidance_end', + 'lowvram': 'low_vram', + 'input_image': 'image' + } + + if isinstance(unit, dict): + unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} + + mask = None + if 'mask' in unit: + mask = unit['mask'] + del unit['mask'] + + if 'image' in unit and not isinstance(unit['image'], dict): + unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit[ + 'image'] else None + + if 'guess_mode' in unit: + logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.') + + unit = ControlNetUnit(**unit) + + # temporary, check #602 + # assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' + return unit + + +def update_cn_script_in_processing( + p: processing.StableDiffusionProcessing, + cn_units: List[ControlNetUnit], + **_kwargs, # for backwards compatibility +): + """ + Update the arguments of the ControlNet script in `p.script_args` in place, reading from `cn_units`. + `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. + + Does not update `p.script_args` if any of the folling is true: + - ControlNet is not present in `p.scripts` + - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet + """ + + cn_units_type = type(cn_units) if type(cn_units) in (list, tuple) else list + script_args = list(p.script_args) + update_cn_script_in_place(p.scripts, script_args, cn_units) + p.script_args = cn_units_type(script_args) + + +def update_cn_script_in_place( + script_runner: scripts.ScriptRunner, + script_args: List[Any], + cn_units: List[ControlNetUnit], + **_kwargs, # for backwards compatibility +): + """ + Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`. + `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. + + Does not update `script_args` if any of the folling is true: + - ControlNet is not present in `script_runner` + - `script_args` is not filled with script arguments for scripts that are processed before ControlNet + """ + + cn_script = find_cn_script(script_runner) + if cn_script is None or len(script_args) < cn_script.args_from: + return + + # fill in remaining parameters to satisfy max models, just in case script needs it. + max_models = shared.opts.data.get("control_net_max_models_num", 1) + cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) + + cn_script_args_diff = 0 + for script in script_runner.alwayson_scripts: + if script is cn_script: + cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) + script_args[script.args_from:script.args_to] = cn_units + script.args_to = script.args_from + len(cn_units) + else: + script.args_from += cn_script_args_diff + script.args_to += cn_script_args_diff + + +def get_models(update: bool = False) -> List[str]: + """ + Fetch the list of available models. + Each value is a valid candidate of `ControlNetUnit.model`. + + Keyword arguments: + update -- Whether to refresh the list from disk. (default False) + """ + + if update: + global_state.update_cn_models() + + return list(global_state.cn_models_names.values()) + + +def get_modules(alias_names: bool = False) -> List[str]: + """ + Fetch the list of available preprocessors. + Each value is a valid candidate of `ControlNetUnit.module`. + + Keyword arguments: + alias_names -- Whether to get the ui alias names instead of internal keys + """ + + modules = list(global_state.cn_preprocessor_modules.keys()) + + if alias_names: + modules = [global_state.preprocessor_aliases.get(module, module) for module in modules] + + return modules + + +def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]: + """ + get the detail of all preprocessors including + sliders: the slider config in Auto1111 webUI + + Keyword arguments: + alias_names -- Whether to get the module detail with alias names instead of internal keys + """ + + _module_detail = {} + _module_list = get_modules(False) + _module_list_alias = get_modules(True) + + _output_list = _module_list if not alias_names else _module_list_alias + for index, module in enumerate(_output_list): + if _module_list[index] in preprocessor_sliders_config: + _module_detail[module] = { + "model_free": module in model_free_preprocessors, + "sliders": preprocessor_sliders_config[_module_list[index]] + } + else: + _module_detail[module] = { + "model_free": False, + "sliders": [] + } + + return _module_detail + + +def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]: + """ + Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script. + """ + + if script_runner is None: + return None + + for script in script_runner.alwayson_scripts: + if is_cn_script(script): + return script + + +def is_cn_script(script: scripts.Script) -> bool: + """ + Determine whether `script` is a ControlNet script. + """ + + return script.title().lower() == 'controlnet' diff --git a/scripts/controlnet.py b/scripts/controlnet.py index 7e0d15079..aec70caab 100644 --- a/scripts/controlnet.py +++ b/scripts/controlnet.py @@ -4,7 +4,6 @@ from collections import OrderedDict from copy import copy from typing import Dict, Optional, Tuple -import importlib import modules.scripts as scripts from modules import shared, devices, script_callbacks, processing, masking, images import gradio as gr @@ -13,16 +12,6 @@ from einops import rearrange from scripts import global_state, hook, external_code, processor, batch_hijack, controlnet_version, utils from scripts.controlnet_ui import controlnet_ui_group -importlib.reload(processor) -importlib.reload(utils) -importlib.reload(global_state) -importlib.reload(hook) -importlib.reload(external_code) -# Reload ui group as `ControlNetUnit` is redefined in `external_code`. If `controlnet_ui_group` -# is not reloaded, `UiControlNetUnit` will inherit from a stale version of `ControlNetUnit`, -# which can cause typecheck to fail. -importlib.reload(controlnet_ui_group) -importlib.reload(batch_hijack) from scripts.cldm import PlugableControlModel from scripts.processor import * from scripts.adapter import PlugableAdapter diff --git a/scripts/external_code.py b/scripts/external_code.py index c085c32a9..08cefb7af 100644 --- a/scripts/external_code.py +++ b/scripts/external_code.py @@ -1,453 +1 @@ -from enum import Enum -from typing import List, Any, Optional, Union, Tuple, Dict -import numpy as np -from modules import scripts, processing, shared -from scripts import global_state -from scripts.processor import preprocessor_sliders_config, model_free_preprocessors -from scripts.logging import logger - -from modules.api import api - - -def get_api_version() -> int: - return 2 - - -class ControlMode(Enum): - """ - The improved guess mode. - """ - - BALANCED = "Balanced" - PROMPT = "My prompt is more important" - CONTROL = "ControlNet is more important" - - -class ResizeMode(Enum): - """ - Resize modes for ControlNet input images. - """ - - RESIZE = "Just Resize" - INNER_FIT = "Crop and Resize" - OUTER_FIT = "Resize and Fill" - - def int_value(self): - if self == ResizeMode.RESIZE: - return 0 - elif self == ResizeMode.INNER_FIT: - return 1 - elif self == ResizeMode.OUTER_FIT: - return 2 - assert False, "NOTREACHED" - - -resize_mode_aliases = { - 'Inner Fit (Scale to Fit)': 'Crop and Resize', - 'Outer Fit (Shrink to Fit)': 'Resize and Fill', - 'Scale to Fit (Inner Fit)': 'Crop and Resize', - 'Envelope (Outer Fit)': 'Resize and Fill', -} - - -def resize_mode_from_value(value: Union[str, int, ResizeMode]) -> ResizeMode: - if isinstance(value, str): - return ResizeMode(resize_mode_aliases.get(value, value)) - elif isinstance(value, int): - assert value >= 0 - if value == 3: # 'Just Resize (Latent upscale)' - return ResizeMode.RESIZE - - if value >= len(ResizeMode): - logger.warning(f'Unrecognized ResizeMode int value {value}. Fall back to RESIZE.') - return ResizeMode.RESIZE - - return [e for e in ResizeMode][value] - else: - return value - - -def control_mode_from_value(value: Union[str, int, ControlMode]) -> ControlMode: - if isinstance(value, str): - return ControlMode(value) - elif isinstance(value, int): - return [e for e in ControlMode][value] - else: - return value - - -def visualize_inpaint_mask(img): - if img.ndim == 3 and img.shape[2] == 4: - result = img.copy() - mask = result[:, :, 3] - mask = 255 - mask // 2 - result[:, :, 3] = mask - return np.ascontiguousarray(result.copy()) - return img - - -def pixel_perfect_resolution( - image: np.ndarray, - target_H: int, - target_W: int, - resize_mode: ResizeMode, -) -> int: - """ - Calculate the estimated resolution for resizing an image while preserving aspect ratio. - - The function first calculates scaling factors for height and width of the image based on the target - height and width. Then, based on the chosen resize mode, it either takes the smaller or the larger - scaling factor to estimate the new resolution. - - If the resize mode is OUTER_FIT, the function uses the smaller scaling factor, ensuring the whole image - fits within the target dimensions, potentially leaving some empty space. - - If the resize mode is not OUTER_FIT, the function uses the larger scaling factor, ensuring the target - dimensions are fully filled, potentially cropping the image. - - After calculating the estimated resolution, the function prints some debugging information. - - Args: - image (np.ndarray): A 3D numpy array representing an image. The dimensions represent [height, width, channels]. - target_H (int): The target height for the image. - target_W (int): The target width for the image. - resize_mode (ResizeMode): The mode for resizing. - - Returns: - int: The estimated resolution after resizing. - """ - raw_H, raw_W, _ = image.shape - - k0 = float(target_H) / float(raw_H) - k1 = float(target_W) / float(raw_W) - - if resize_mode == ResizeMode.OUTER_FIT: - estimation = min(k0, k1) * float(min(raw_H, raw_W)) - else: - estimation = max(k0, k1) * float(min(raw_H, raw_W)) - - logger.debug(f"Pixel Perfect Computation:") - logger.debug(f"resize_mode = {resize_mode}") - logger.debug(f"raw_H = {raw_H}") - logger.debug(f"raw_W = {raw_W}") - logger.debug(f"target_H = {target_H}") - logger.debug(f"target_W = {target_W}") - logger.debug(f"estimation = {estimation}") - - return int(np.round(estimation)) - - -InputImage = Union[np.ndarray, str] -InputImage = Union[Dict[str, InputImage], Tuple[InputImage, InputImage], InputImage] - - -class ControlNetUnit: - """ - Represents an entire ControlNet processing unit. - """ - - def __init__( - self, - enabled: bool=True, - module: Optional[str]=None, - model: Optional[str]=None, - weight: float=1.0, - image: Optional[InputImage]=None, - resize_mode: Union[ResizeMode, int, str] = ResizeMode.INNER_FIT, - low_vram: bool=False, - processor_res: int=-1, - threshold_a: float=-1, - threshold_b: float=-1, - guidance_start: float=0.0, - guidance_end: float=1.0, - pixel_perfect: bool=False, - control_mode: Union[ControlMode, int, str] = ControlMode.BALANCED, - **_kwargs, - ): - self.enabled = enabled - self.module = module - self.model = model - self.weight = weight - self.image = image - self.resize_mode = resize_mode - self.low_vram = low_vram - self.processor_res = processor_res - self.threshold_a = threshold_a - self.threshold_b = threshold_b - self.guidance_start = guidance_start - self.guidance_end = guidance_end - self.pixel_perfect = pixel_perfect - self.control_mode = control_mode - - def __eq__(self, other): - if not isinstance(other, ControlNetUnit): - return False - - return vars(self) == vars(other) - - -def to_base64_nparray(encoding: str): - """ - Convert a base64 image into the image type the extension uses - """ - - return np.array(api.decode_base64_to_image(encoding)).astype('uint8') - - -def get_all_units_in_processing(p: processing.StableDiffusionProcessing) -> List[ControlNetUnit]: - """ - Fetch ControlNet processing units from a StableDiffusionProcessing. - """ - - return get_all_units(p.scripts, p.script_args) - - -def get_all_units(script_runner: scripts.ScriptRunner, script_args: List[Any]) -> List[ControlNetUnit]: - """ - Fetch ControlNet processing units from an existing script runner. - Use this function to fetch units from the list of all scripts arguments. - """ - - cn_script = find_cn_script(script_runner) - if cn_script: - return get_all_units_from(script_args[cn_script.args_from:cn_script.args_to]) - - return [] - - -def get_all_units_from(script_args: List[Any]) -> List[ControlNetUnit]: - """ - Fetch ControlNet processing units from ControlNet script arguments. - Use `external_code.get_all_units` to fetch units from the list of all scripts arguments. - """ - def is_stale_unit(script_arg: Any) -> bool: - """ Returns whether the script_arg is potentially an stale version of - ControlNetUnit created before module reload.""" - return ( - 'ControlNetUnit' in type(script_arg).__name__ and - not isinstance(script_arg, ControlNetUnit) - ) - - def is_controlnet_unit(script_arg: Any) -> bool: - """ Returns whether the script_arg is ControlNetUnit or anything that - can be treated like ControlNetUnit. """ - return ( - isinstance(script_arg, (ControlNetUnit, dict)) or - ( - hasattr(script_arg, '__dict__') and - set(vars(ControlNetUnit()).keys()).issubset( - set(vars(script_arg).keys())) - ) - ) - - all_units = [ - to_processing_unit(script_arg) - for script_arg in script_args - if is_controlnet_unit(script_arg) - ] - if not all_units: - logger.warning("No ControlNetUnit detected in args. It is very likely that you are having an extension conflict." - f"Here are args received by ControlNet: {script_args}.") - if any(is_stale_unit(script_arg) for script_arg in script_args): - logger.debug( - "Stale version of ControlNetUnit detected. The ControlNetUnit received" - "by ControlNet is created before the newest load of ControlNet extension." - "They will still be used by ControlNet as long as they provide same fields" - "defined in the newest version of ControlNetUnit." - ) - - return all_units - - -def get_single_unit_from(script_args: List[Any], index: int=0) -> Optional[ControlNetUnit]: - """ - Fetch a single ControlNet processing unit from ControlNet script arguments. - The list must not contain script positional arguments. It must only contain processing units. - """ - - i = 0 - while i < len(script_args) and index >= 0: - if index == 0 and script_args[i] is not None: - return to_processing_unit(script_args[i]) - i += 1 - - index -= 1 - - return None - -def get_max_models_num(): - """ - Fetch the maximum number of allowed ControlNet models. - """ - - max_models_num = shared.opts.data.get("control_net_max_models_num", 1) - return max_models_num - -def to_processing_unit(unit: Union[Dict[str, Any], ControlNetUnit]) -> ControlNetUnit: - """ - Convert different types to processing unit. - If `unit` is a dict, alternative keys are supported. See `ext_compat_keys` in implementation for details. - """ - - ext_compat_keys = { - 'guessmode': 'guess_mode', - 'guidance': 'guidance_end', - 'lowvram': 'low_vram', - 'input_image': 'image' - } - - if isinstance(unit, dict): - unit = {ext_compat_keys.get(k, k): v for k, v in unit.items()} - - mask = None - if 'mask' in unit: - mask = unit['mask'] - del unit['mask'] - - if 'image' in unit and not isinstance(unit['image'], dict): - unit['image'] = {'image': unit['image'], 'mask': mask} if mask is not None else unit['image'] if unit['image'] else None - - if 'guess_mode' in unit: - logger.warning('Guess Mode is removed since 1.1.136. Please use Control Mode instead.') - - unit = ControlNetUnit(**unit) - - # temporary, check #602 - #assert isinstance(unit, ControlNetUnit), f'bad argument to controlnet extension: {unit}\nexpected Union[dict[str, Any], ControlNetUnit]' - return unit - - -def update_cn_script_in_processing( - p: processing.StableDiffusionProcessing, - cn_units: List[ControlNetUnit], - **_kwargs, # for backwards compatibility -): - """ - Update the arguments of the ControlNet script in `p.script_args` in place, reading from `cn_units`. - `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. - - Does not update `p.script_args` if any of the folling is true: - - ControlNet is not present in `p.scripts` - - `p.script_args` is not filled with script arguments for scripts that are processed before ControlNet - """ - - cn_units_type = type(cn_units) if type(cn_units) in (list, tuple) else list - script_args = list(p.script_args) - update_cn_script_in_place(p.scripts, script_args, cn_units) - p.script_args = cn_units_type(script_args) - - -def update_cn_script_in_place( - script_runner: scripts.ScriptRunner, - script_args: List[Any], - cn_units: List[ControlNetUnit], - **_kwargs, # for backwards compatibility -): - """ - Update the arguments of the ControlNet script in `script_args` in place, reading from `cn_units`. - `cn_units` and its elements are not modified. You can call this function repeatedly, as many times as you want. - - Does not update `script_args` if any of the folling is true: - - ControlNet is not present in `script_runner` - - `script_args` is not filled with script arguments for scripts that are processed before ControlNet - """ - - cn_script = find_cn_script(script_runner) - if cn_script is None or len(script_args) < cn_script.args_from: - return - - # fill in remaining parameters to satisfy max models, just in case script needs it. - max_models = shared.opts.data.get("control_net_max_models_num", 1) - cn_units = cn_units + [ControlNetUnit(enabled=False)] * max(max_models - len(cn_units), 0) - - cn_script_args_diff = 0 - for script in script_runner.alwayson_scripts: - if script is cn_script: - cn_script_args_diff = len(cn_units) - (cn_script.args_to - cn_script.args_from) - script_args[script.args_from:script.args_to] = cn_units - script.args_to = script.args_from + len(cn_units) - else: - script.args_from += cn_script_args_diff - script.args_to += cn_script_args_diff - - -def get_models(update: bool=False) -> List[str]: - """ - Fetch the list of available models. - Each value is a valid candidate of `ControlNetUnit.model`. - - Keyword arguments: - update -- Whether to refresh the list from disk. (default False) - """ - - if update: - global_state.update_cn_models() - - return list(global_state.cn_models_names.values()) - - -def get_modules(alias_names: bool = False) -> List[str]: - """ - Fetch the list of available preprocessors. - Each value is a valid candidate of `ControlNetUnit.module`. - - Keyword arguments: - alias_names -- Whether to get the ui alias names instead of internal keys - """ - - modules = list(global_state.cn_preprocessor_modules.keys()) - - if alias_names: - modules = [global_state.preprocessor_aliases.get(module, module) for module in modules] - - return modules - - -def get_modules_detail(alias_names: bool = False) -> Dict[str, Any]: - """ - get the detail of all preprocessors including - sliders: the slider config in Auto1111 webUI - - Keyword arguments: - alias_names -- Whether to get the module detail with alias names instead of internal keys - """ - - _module_detail = {} - _module_list = get_modules(False) - _module_list_alias = get_modules(True) - - _output_list = _module_list if not alias_names else _module_list_alias - for index, module in enumerate(_output_list): - if _module_list[index] in preprocessor_sliders_config: - _module_detail[module] = { - "model_free": module in model_free_preprocessors, - "sliders": preprocessor_sliders_config[_module_list[index]] - } - else: - _module_detail[module] = { - "model_free": False, - "sliders": [] - } - - return _module_detail - - -def find_cn_script(script_runner: scripts.ScriptRunner) -> Optional[scripts.Script]: - """ - Find the ControlNet script in `script_runner`. Returns `None` if `script_runner` does not contain a ControlNet script. - """ - - if script_runner is None: - return None - - for script in script_runner.alwayson_scripts: - if is_cn_script(script): - return script - - -def is_cn_script(script: scripts.Script) -> bool: - """ - Determine whether `script` is a ControlNet script. - """ - - return script.title().lower() == 'controlnet' +from internal_controlnet.external_code import * diff --git a/tests/external_code_api/importlib_reload_test.py b/tests/external_code_api/importlib_reload_test.py new file mode 100644 index 000000000..3c81a99fb --- /dev/null +++ b/tests/external_code_api/importlib_reload_test.py @@ -0,0 +1,24 @@ +import unittest +import importlib +utils = importlib.import_module('extensions.sd-webui-controlnet.tests.utils', 'utils') +utils.setup_test_env() + +from scripts import external_code + + +class TestImportlibReload(unittest.TestCase): + def setUp(self): + self.ControlNetUnit = external_code.ControlNetUnit + + def test_reload_does_not_redefine(self): + importlib.reload(external_code) + NewControlNetUnit = external_code.ControlNetUnit + self.assertEqual(self.ControlNetUnit, NewControlNetUnit) + + def test_force_import_does_not_redefine(self): + external_code_copy = importlib.import_module('extensions.sd-webui-controlnet.scripts.external_code', 'external_code') + self.assertEqual(self.ControlNetUnit, external_code_copy.ControlNetUnit) + + +if __name__ == '__main__': + unittest.main()