diff --git a/internal_controlnet/args.py b/internal_controlnet/args.py
index 7e6dcec6e..413794f84 100644
--- a/internal_controlnet/args.py
+++ b/internal_controlnet/args.py
@@ -15,6 +15,7 @@
     ControlMode,
     HiResFixOption,
     PuLIDMode,
+    ControlNetUnionControlType,
 )
 from annotator.util import HWC3
 
@@ -202,6 +203,11 @@ def parse_effective_region_mask(cls, value) -> np.ndarray:
     # https://github.com/ToTheBeginning/PuLID
     pulid_mode: PuLIDMode = PuLIDMode.FIDELITY
 
+    # ControlNet control type for the ControlNet union model.
+    # https://github.com/xinsir6/ControlNetPlus/tree/main
+    # The value of this field is only used when the model is ControlNetUnion.
+    union_control_type: ControlNetUnionControlType = ControlNetUnionControlType.UNKNOWN
+
     # ------- API only fields -------
     # The tensor input for ipadapter. When this field is set in the API,
     # the base64string will be interpret by torch.load to reconstruct ipadapter
diff --git a/scripts/cldm.py b/scripts/cldm.py
index be7612e98..c404fd956 100644
--- a/scripts/cldm.py
+++ b/scripts/cldm.py
@@ -1,8 +1,9 @@
+from typing import List
 import torch
 import torch.nn as nn
 
 from modules import devices
 
-
+from scripts.controlnet_core.controlnet_union import ControlAddEmbedding, ResBlockUnionControlnet
 try:
     from sgm.modules.diffusionmodules.openaimodel import conv_nd, linear, zero_module, timestep_embedding, \
@@ -26,7 +27,7 @@ def __init__(self, config, state_dict=None):
 
     def reset(self):
         pass
-    
+
     def forward(self, *args, **kwargs):
         return self.control_model(*args, **kwargs)
 
@@ -57,7 +58,7 @@ def send_me_to_gpu(module, _):
     def fullvram(self):
         self.to(devices.get_device_for("controlnet"))
         return
-    
+
 
 class ControlNet(nn.Module):
     def __init__(
@@ -90,6 +91,7 @@ def __init__(
         use_linear_in_transformer=False,
         adm_in_channels=None,
         transformer_depth_middle=None,
+        union_controlnet=False,
         device=None,
         global_average_pooling=False,
     ):
@@ -280,10 +282,74 @@ def __init__(
         self.middle_block_out = self.make_zero_conv(ch)
         self._feature_size += ch
 
+        if union_controlnet:
+            self.num_control_type = 6
+            num_trans_channel = 320
+            num_trans_head = 8
+            num_trans_layer = 1
+            num_proj_channel = 320
+            self.task_embedding = nn.Parameter(torch.empty(
+                self.num_control_type, num_trans_channel, dtype=self.dtype, device=device
+            ))
+
+            self.transformer_layes = nn.Sequential(*[
+                ResBlockUnionControlnet(
+                    num_trans_channel, num_trans_head, dtype=self.dtype, device=device
+                )
+                for _ in range(num_trans_layer)
+            ])
+            self.spatial_ch_projs = nn.Linear(
+                num_trans_channel, num_proj_channel, dtype=self.dtype, device=device
+            )
+
+            control_add_embed_dim = 256
+            self.control_add_embedding = ControlAddEmbedding(
+                control_add_embed_dim, time_embed_dim, self.num_control_type,
+                dtype=self.dtype, device=device
+            )
+        else:
+            self.task_embedding = None
+            self.control_add_embedding = None
+
+    def union_controlnet_merge(
+        self,
+        hint: torch.Tensor,
+        control_type: List[int],
+        emb: torch.Tensor,
+        context: torch.Tensor
+    ):
+        """Note: control_type is a list of control type indices; the length
+        of the list is the number of control images."""
+        # Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main
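+        # Each encoded hint's mean-pooled feature gets its control type's task
+        # embedding added; the resulting tokens go through the transformer
+        # block, are projected per channel, and added back onto the encoded
+        # hints, which are then summed into a single conditioning tensor.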
+        inputs = []
+        condition_list = []
+
+        for idx in range(min(1, len(control_type))):
+            controlnet_cond = self.input_hint_block(hint[idx], emb, context)
+            feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
+            if idx < len(control_type):
+                feat_seq += self.task_embedding[control_type[idx]]
+
+            inputs.append(feat_seq.unsqueeze(1))
+            condition_list.append(controlnet_cond)
+
+        x = torch.cat(inputs, dim=1)
+        x = self.transformer_layes(x)
+        controlnet_cond_fuser = None
+        for idx in range(len(control_type)):
+            alpha = self.spatial_ch_projs(x[:, idx])
+            alpha = alpha.unsqueeze(-1).unsqueeze(-1)
+            o = condition_list[idx] + alpha
+            if controlnet_cond_fuser is None:
+                controlnet_cond_fuser = o
+            else:
+                controlnet_cond_fuser += o
+        return controlnet_cond_fuser
+
     def make_zero_conv(self, channels):
         return TimestepEmbedSequential(zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
 
-    def forward(self, x, hint, timesteps, context, y=None, **kwargs):
+    def forward(self, x, hint, timesteps, context, y=None, control_type: List[int] = None, **kwargs):
         original_type = x.dtype
 
         x = x.to(self.dtype)
@@ -297,7 +363,19 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
         t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
         emb = self.time_embed(t_emb)
 
-        guided_hint = self.input_hint_block(hint, emb, context)
+        guided_hint = None
+        if self.control_add_embedding is not None:
+            assert control_type is not None
+
+            emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
+            if len(control_type) > 0:
+                if len(hint.shape) < 5:
+                    hint = hint.unsqueeze(dim=0)
+                guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)
+
+        if guided_hint is None:
+            guided_hint = self.input_hint_block(hint, emb, context)
+
         outs = []
 
         if self.num_classes is not None:
diff --git a/scripts/controlnet.py b/scripts/controlnet.py
index 6fdee8c10..d2a122b16 100644
--- a/scripts/controlnet.py
+++ b/scripts/controlnet.py
@@ -1027,6 +1027,10 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
             control_model_type.is_controlnet and
             model_net.control_model.global_average_pooling
         )
+
+        if control_model_type == ControlModelType.ControlNetUnion:
+            logger.info(f"ControlNetUnion control type: {unit.union_control_type}")
+
         forward_param = ControlParams(
             control_model=model_net,
             preprocessor=preprocessor_dict,
@@ -1047,6 +1051,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
                 if unit.effective_region_mask is not None
                 else None
             ),
+            # TODO: Implement merge of units with the same union model.
+            union_control_types=[unit.union_control_type],
         )
         forward_params.append(forward_param)
 
diff --git a/scripts/controlnet_core/controlnet_union.py b/scripts/controlnet_core/controlnet_union.py
new file mode 100644
index 000000000..153354b55
--- /dev/null
+++ b/scripts/controlnet_core/controlnet_union.py
@@ -0,0 +1,118 @@
+from collections import OrderedDict
+import torch
+import torch.nn as nn
+
+try:
+    from sgm.modules.diffusionmodules.openaimodel import (
+        timestep_embedding,
+    )
+
+    using_sgm = True
+except ImportError:
+    from ldm.modules.diffusionmodules.openaimodel import (
+        timestep_embedding,
+    )
+
+    using_sgm = False
+
+
+def attention_pytorch(
+    q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False
+):
+    if skip_reshape:
+        b, _, _, dim_head = q.shape
+    else:
+        b, _, dim_head = q.shape
+        dim_head //= heads
+        q, k, v = map(
+            lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2),
+            (q, k, v),
+        )
+
+    out = torch.nn.functional.scaled_dot_product_attention(
+        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
+    )
+    out = out.transpose(1, 2).reshape(b, -1, heads * dim_head)
+    return out
+
+
+class ControlAddEmbedding(nn.Module):
+    def __init__(
+        self,
+        in_dim,
+        out_dim,
+        num_control_type,
+        dtype=None,
+        device=None,
+    ):
+        super().__init__()
+        self.num_control_type = num_control_type
+        self.in_dim = in_dim
+        self.linear_1 = nn.Linear(
+            in_dim * num_control_type, out_dim, dtype=dtype, device=device
+        )
+        self.linear_2 = nn.Linear(out_dim, out_dim, dtype=dtype, device=device)
+
+    def forward(self, control_type, dtype, device):
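+        # control_type holds integer control type indices. They are one-hot
+        # encoded, every 0/1 entry is mapped to a sinusoidal timestep-style
+        # embedding of size in_dim, and the flattened result is projected by a
+        # two-layer MLP to out_dim so it can be added to the time embedding.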
+        c_type = torch.zeros((self.num_control_type,), device=device)
+        c_type[control_type] = 1.0
+        c_type = (
+            timestep_embedding(c_type.flatten(), self.in_dim, repeat_only=False)
+            .to(dtype)
+            .reshape((-1, self.num_control_type * self.in_dim))
+        )
+        return self.linear_2(torch.nn.functional.silu(self.linear_1(c_type)))
+
+
+class OptimizedAttention(nn.Module):
+    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.heads = nhead
+        self.c = c
+
+        self.in_proj = nn.Linear(c, c * 3, bias=True, dtype=dtype, device=device)
+        self.out_proj = nn.Linear(c, c, bias=True, dtype=dtype, device=device)
+
+    def forward(self, x):
+        x = self.in_proj(x)
+        q, k, v = x.split(self.c, dim=2)
+        out = attention_pytorch(q, k, v, self.heads)
+        return self.out_proj(out)
+
+
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResBlockUnionControlnet(nn.Module):
+    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
+        super().__init__()
+        self.attn = OptimizedAttention(
+            dim, nhead, dtype=dtype, device=device, operations=operations
+        )
+        self.ln_1 = nn.LayerNorm(dim, dtype=dtype, device=device)
+        self.mlp = nn.Sequential(
+            OrderedDict(
+                [
+                    (
+                        "c_fc",
+                        nn.Linear(dim, dim * 4, dtype=dtype, device=device),
+                    ),
+                    ("gelu", QuickGELU()),
+                    (
+                        "c_proj",
+                        nn.Linear(dim * 4, dim, dtype=dtype, device=device),
+                    ),
+                ]
+            )
+        )
+        self.ln_2 = nn.LayerNorm(dim, dtype=dtype, device=device)
+
+    def attention(self, x: torch.Tensor):
+        return self.attn(x)
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
diff --git a/scripts/controlnet_model_guess.py b/scripts/controlnet_model_guess.py
index d735a3285..030f43deb 100644
--- a/scripts/controlnet_model_guess.py
+++ b/scripts/controlnet_model_guess.py
@@ -221,14 +221,25 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
             final_state_dict[key] = p_new
         state_dict = final_state_dict
 
-        config['use_fp16'] = devices.dtype_unet == torch.float16
+        if "control_add_embedding.linear_1.bias" in state_dict:  # ControlNet Union
+            config["union_controlnet"] = True
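+            # Rename the checkpoint's fused attention keys
+            # (".attn.in_proj_weight" / ".attn.in_proj_bias") so they load
+            # into the in_proj Linear defined in OptimizedAttention.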
+            final_state_dict = {}
+            for k in list(state_dict.keys()):
+                new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
+                final_state_dict[new_k] = state_dict.pop(k)
+            state_dict = final_state_dict
 
-        network = PlugableControlModel(config, state_dict)
-        network.to(devices.dtype_unet)
-        if "instant_id" in model_path.lower():
+            control_model_type = ControlModelType.ControlNetUnion
+        elif "instant_id" in model_path.lower():
             control_model_type = ControlModelType.InstantID
         else:
             control_model_type = ControlModelType.ControlNet
+
+        config['use_fp16'] = devices.dtype_unet == torch.float16
+
+        network = PlugableControlModel(config, state_dict)
+        network.to(devices.dtype_unet)
+
         return ControlModel(network, control_model_type)
 
     if 'conv_in.weight' in state_dict:
diff --git a/scripts/controlnet_ui/controlnet_ui_group.py b/scripts/controlnet_ui/controlnet_ui_group.py
index c37407d64..21d3d85e0 100644
--- a/scripts/controlnet_ui/controlnet_ui_group.py
+++ b/scripts/controlnet_ui/controlnet_ui_group.py
@@ -24,6 +24,7 @@
     PuLIDMode,
     ControlMode,
     ResizeMode,
+    ControlNetUnionControlType,
 )
 from modules import shared
 from modules.ui_components import FormRow, FormHTML, ToolButton
@@ -265,6 +266,7 @@ def __init__(
         self.output_dir_state = None
         self.advanced_weighting = gr.State(None)
         self.pulid_mode = None
+        self.union_control_type = None
 
         # API-only fields
         self.ipadapter_input = gr.State(None)
@@ -487,6 +489,13 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
                     visible=False,
                 )
 
+            with gr.Row():
+                self.union_control_type = gr.Textbox(
+                    label="Union Control Type",
+                    value=ControlNetUnionControlType.UNKNOWN.value,
+                    visible=False,
+                )
+
             with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
                 self.type_filter = (
                     gr.Dropdown
@@ -664,6 +673,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
             self.advanced_weighting,
             self.effective_region_mask,
             self.pulid_mode,
+            self.union_control_type,
         )
 
         unit = gr.State(ControlNetUnit())
@@ -841,6 +851,19 @@ def filter_selected(k: str):
             show_progress=False,
         )
 
+    def register_union_control_type(self):
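+        # Hidden textbox that mirrors the control type filter; its value is
+        # collected with the other per-unit inputs, so selecting e.g. "Canny"
+        # stores the matching ControlNetUnionControlType on the unit.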
= "SparseCtrl, Yuwei Guo" + ControlNetUnion = "ControlNetUnion, xinsir6" @property def is_controlnet(self) -> bool: @@ -181,6 +183,7 @@ def is_controlnet(self) -> bool: ControlModelType.ControlNet, ControlModelType.ControlLoRA, ControlModelType.InstantID, + ControlModelType.ControlNetUnion, ) @property @@ -273,3 +276,61 @@ def int_value(self): elif self == ResizeMode.OUTER_FIT: return 2 assert False, "NOTREACHED" + + +class ControlNetUnionControlType(Enum): + """ + ControlNet control type for ControlNet union model. + https://github.com/xinsir6/ControlNetPlus/tree/main + """ + + OPENPOSE = "OpenPose" + DEPTH = "Depth" + # hed/pidi/scribble/ted + SOFT_EDGE = "Soft Edge" + # canny/lineart/anime_lineart/mlsd + HARD_EDGE = "Hard Edge" + NORMAL_MAP = "Normal Map" + SEGMENTATION = "Segmentation" + + UNKNOWN = "Unknown" + + @staticmethod + def all_tags() -> List[str]: + """ Tags can be handled by union ControlNet """ + return [ + "openpose", + "depth", + "softedge", + "scribble", + "canny", + "lineart", + "mlsd", + "normalmap", + "segmentation", + ] + + @staticmethod + def from_str(s: str) -> ControlNetUnionControlType: + s = s.lower() + + if s == "openpose": + return ControlNetUnionControlType.OPENPOSE + elif s == "depth": + return ControlNetUnionControlType.DEPTH + elif s in ["scribble", "softedge"]: + return ControlNetUnionControlType.SOFT_EDGE + elif s in ["canny", "lineart", "mlsd"]: + return ControlNetUnionControlType.HARD_EDGE + elif s == "normalmap": + return ControlNetUnionControlType.NORMAL_MAP + elif s == "segmentation": + return ControlNetUnionControlType.SEGMENTATION + + return ControlNetUnionControlType.UNKNOWN + + def int_value(self) -> int: + if self == ControlNetUnionControlType.UNKNOWN: + raise ValueError("Unknown control type cannot be encoded.") + + return list(ControlNetUnionControlType).index(self) diff --git a/scripts/hook.py b/scripts/hook.py index 593826b8a..4e485e375 100644 --- a/scripts/hook.py +++ b/scripts/hook.py @@ -6,7 +6,12 @@ from typing import Optional, Any, List from scripts.logging import logger -from scripts.enums import ControlModelType, AutoMachine, HiResFixOption +from scripts.enums import ( + ControlModelType, + AutoMachine, + HiResFixOption, + ControlNetUnionControlType, +) from scripts.ipadapter.ipadapter_model import ImageEmbed from scripts.controlnet_sparsectrl import SparseCtrl from modules import devices, lowvram, shared, scripts @@ -173,6 +178,7 @@ def __init__( hr_option: HiResFixOption = HiResFixOption.BOTH, control_context_override: Optional[Any] = None, effective_region_mask: Optional[torch.Tensor] = None, + union_control_types: List[ControlNetUnionControlType] = None, **kwargs # To avoid errors ): self.control_model = control_model @@ -189,6 +195,7 @@ def __init__( self.hr_option = hr_option self.control_context_override = control_context_override self.effective_region_mask = effective_region_mask + self.union_control_types = union_control_types or [] self.used_hint_cond = None self.used_hint_cond_latent = None self.used_hint_inpaint_hijack = None @@ -608,7 +615,15 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs): hint=hint, timesteps=timesteps, context=controlnet_context, - y=y + y=y, + control_type=( + [ + t.int_value() + for t in param.union_control_types + ] + if param.control_model_type == ControlModelType.ControlNetUnion + else None + ), ) if is_sdxl: diff --git a/scripts/supported_preprocessor.py b/scripts/supported_preprocessor.py index 775299f83..22d02a427 100644 --- a/scripts/supported_preprocessor.py +++ 
+        if self == ControlNetUnionControlType.UNKNOWN:
+            raise ValueError("Unknown control type cannot be encoded.")
+
+        return list(ControlNetUnionControlType).index(self)
diff --git a/scripts/hook.py b/scripts/hook.py
index 593826b8a..4e485e375 100644
--- a/scripts/hook.py
+++ b/scripts/hook.py
@@ -6,7 +6,12 @@
 from typing import Optional, Any, List
 
 from scripts.logging import logger
-from scripts.enums import ControlModelType, AutoMachine, HiResFixOption
+from scripts.enums import (
+    ControlModelType,
+    AutoMachine,
+    HiResFixOption,
+    ControlNetUnionControlType,
+)
 from scripts.ipadapter.ipadapter_model import ImageEmbed
 from scripts.controlnet_sparsectrl import SparseCtrl
 from modules import devices, lowvram, shared, scripts
@@ -173,6 +178,7 @@ def __init__(
         hr_option: HiResFixOption = HiResFixOption.BOTH,
         control_context_override: Optional[Any] = None,
         effective_region_mask: Optional[torch.Tensor] = None,
+        union_control_types: List[ControlNetUnionControlType] = None,
         **kwargs  # To avoid errors
     ):
         self.control_model = control_model
@@ -189,6 +195,7 @@ def __init__(
         self.hr_option = hr_option
         self.control_context_override = control_context_override
         self.effective_region_mask = effective_region_mask
+        self.union_control_types = union_control_types or []
         self.used_hint_cond = None
         self.used_hint_cond_latent = None
         self.used_hint_inpaint_hijack = None
@@ -608,7 +615,15 @@ def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
                     hint=hint,
                     timesteps=timesteps,
                     context=controlnet_context,
-                    y=y
+                    y=y,
+                    control_type=(
+                        [
+                            t.int_value()
+                            for t in param.union_control_types
+                        ]
+                        if param.control_model_type == ControlModelType.ControlNetUnion
+                        else None
+                    ),
                 )
 
                 if is_sdxl:
diff --git a/scripts/supported_preprocessor.py b/scripts/supported_preprocessor.py
index 775299f83..22d02a427 100644
--- a/scripts/supported_preprocessor.py
+++ b/scripts/supported_preprocessor.py
@@ -5,6 +5,7 @@
 import torch
 
 from modules import shared, devices
+from scripts.enums import ControlNetUnionControlType
 from scripts.logging import logger
 from scripts.utils import ndarray_lru_cache
 
@@ -173,7 +174,8 @@ def tag_to_filters(cls, tag: str) -> Set[str]:
         }
 
         tag = tag.lower()
-        return set([tag] + filters_aliases.get(tag, []))
+        union_tags = ["union"] if tag in ControlNetUnionControlType.all_tags() else []
+        return set([tag] + filters_aliases.get(tag, []) + union_tags)
 
     @classmethod
     def unload_unused(cls, active_processors: Set["Preprocessor"]):
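For reference, a minimal sketch of how a UI tag ends up as the integer index passed to the union model (assuming the extension root is on sys.path so scripts.enums imports as in the patched modules above):

    from scripts.enums import ControlNetUnionControlType

    control_type = ControlNetUnionControlType.from_str("canny")  # -> HARD_EDGE
    indices = [control_type.int_value()]                         # -> [3]
    # hook.py passes control_type=indices into ControlNet.forward, which adds the
    # ControlAddEmbedding output to the time embedding and routes the hint through
    # union_controlnet_merge.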