Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support union ControlNet #2988

Merged
merged 4 commits into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions internal_controlnet/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ControlMode,
HiResFixOption,
PuLIDMode,
ControlNetUnionControlType,
)
from annotator.util import HWC3

Expand Down Expand Up @@ -202,6 +203,11 @@ def parse_effective_region_mask(cls, value) -> np.ndarray:
# https://github.com/ToTheBeginning/PuLID
pulid_mode: PuLIDMode = PuLIDMode.FIDELITY

# ControlNet control type for ControlNet union model.
# https://github.com/xinsir6/ControlNetPlus/tree/main
# The value of this field is only used when the model is ControlNetUnion.
union_control_type: ControlNetUnionControlType = ControlNetUnionControlType.UNKNOWN

# ------- API only fields -------
# The tensor input for ipadapter. When this field is set in the API,
# the base64string will be interpreted by torch.load to reconstruct ipadapter
Expand Down
88 changes: 83 additions & 5 deletions scripts/cldm.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import List
import torch
import torch.nn as nn

from modules import devices

from scripts.controlnet_core.controlnet_union import ControlAddEmbedding, ResBlockUnionControlnet

try:
from sgm.modules.diffusionmodules.openaimodel import conv_nd, linear, zero_module, timestep_embedding, \
Expand All @@ -26,7 +27,7 @@ def __init__(self, config, state_dict=None):

def reset(self):
    """Do nothing; this is a no-op hook that always returns ``None``."""
    return None

def forward(self, *args, **kwargs):
    """Delegate the call to ``self.control_model`` and return its result.

    All positional and keyword arguments are passed through unchanged.
    """
    wrapped = self.control_model
    return wrapped(*args, **kwargs)

Expand Down Expand Up @@ -57,7 +58,7 @@ def send_me_to_gpu(module, _):
def fullvram(self):
    """Move this module to the device that ``devices`` reports for "controlnet"."""
    target_device = devices.get_device_for("controlnet")
    self.to(target_device)


class ControlNet(nn.Module):
def __init__(
Expand Down Expand Up @@ -90,6 +91,7 @@ def __init__(
use_linear_in_transformer=False,
adm_in_channels=None,
transformer_depth_middle=None,
union_controlnet=False,
device=None,
global_average_pooling=False,
):
Expand Down Expand Up @@ -280,10 +282,74 @@ def __init__(
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch

if union_controlnet:
self.num_control_type = 6
num_trans_channel = 320
num_trans_head = 8
num_trans_layer = 1
num_proj_channel = 320
self.task_embedding = nn.Parameter(torch.empty(
self.num_control_type, num_trans_channel, dtype=self.dtype, device=device
))

self.transformer_layes = nn.Sequential(*[
ResBlockUnionControlnet(
num_trans_channel, num_trans_head, dtype=self.dtype, device=device
)
for _ in range(num_trans_layer)
])
self.spatial_ch_projs = nn.Linear(
num_trans_channel, num_proj_channel, dtype=self.dtype, device=device
)

control_add_embed_dim = 256
self.control_add_embedding = ControlAddEmbedding(
control_add_embed_dim, time_embed_dim, self.num_control_type,
dtype=self.dtype, device=device
)
else:
self.task_embedding = None
self.control_add_embedding = None

def union_controlnet_merge(
    self,
    hint: torch.Tensor,
    control_type: List[int],
    emb: torch.Tensor,
    context: torch.Tensor
):
    """Fuse the control image(s) of a union ControlNet into one guided hint.

    Equivalent to: https://github.com/xinsir6/ControlNetPlus/tree/main

    Args:
        hint: Control images, indexed along dim 0 by control image.
        control_type: List of enum values, one per control image. Each
            value selects a row of ``self.task_embedding``.
        emb: Embedding forwarded to ``self.input_hint_block``.
        context: Context forwarded to ``self.input_hint_block``.

    Returns:
        The sum, over the processed control images, of the hint features
        plus the per-channel offset produced by ``self.spatial_ch_projs``.

    Note:
        Only the first control image is processed for now; additional
        entries of ``control_type`` are ignored. ``control_type`` must be
        non-empty, otherwise ``torch.cat`` receives an empty list.
    """
    # Only the first control image is merged for now. The fusion loop below
    # must cover exactly the images encoded here: iterating over
    # len(control_type) instead indexes past condition_list and x as soon as
    # more than one control type is supplied.
    num_images = min(1, len(control_type))

    inputs = []
    condition_list = []
    for idx in range(num_images):
        controlnet_cond = self.input_hint_block(hint[idx], emb, context)
        # Average over dims 2 and 3 to get one feature vector per image,
        # then tag it with the embedding of its control type.
        feat_seq = torch.mean(controlnet_cond, dim=(2, 3))
        feat_seq += self.task_embedding[control_type[idx]]

        inputs.append(feat_seq.unsqueeze(1))
        condition_list.append(controlnet_cond)

    # One token per encoded image goes through the transformer layers.
    x = torch.cat(inputs, dim=1)
    x = self.transformer_layes(x)

    controlnet_cond_fuser = None
    for idx, controlnet_cond in enumerate(condition_list):
        # Project the image's token and add it to every spatial position.
        alpha = self.spatial_ch_projs(x[:, idx])
        alpha = alpha.unsqueeze(-1).unsqueeze(-1)
        o = controlnet_cond + alpha
        if controlnet_cond_fuser is None:
            controlnet_cond_fuser = o
        else:
            controlnet_cond_fuser += o
    return controlnet_cond_fuser

def make_zero_conv(self, channels):
    """Build the ``channels`` -> ``channels`` output convolution for one block.

    The convolution comes from ``conv_nd`` (fourth argument 1, presumably
    the kernel size, with ``padding=0``), is passed through ``zero_module``
    (presumably zero-initialising its parameters -- confirm against the
    helper), and is wrapped in a ``TimestepEmbedSequential``.
    """
    conv = conv_nd(self.dims, channels, channels, 1, padding=0)
    return TimestepEmbedSequential(zero_module(conv))

def forward(self, x, hint, timesteps, context, y=None, **kwargs):
def forward(self, x, hint, timesteps, context, y=None, control_type: List[int] = None, **kwargs):
original_type = x.dtype

x = x.to(self.dtype)
Expand All @@ -297,7 +363,19 @@ def forward(self, x, hint, timesteps, context, y=None, **kwargs):
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(self.dtype)
emb = self.time_embed(t_emb)

guided_hint = self.input_hint_block(hint, emb, context)
guided_hint = None
if self.control_add_embedding is not None:
assert control_type is not None

emb += self.control_add_embedding(control_type, emb.dtype, emb.device)
if len(control_type) > 0:
if len(hint.shape) < 5:
hint = hint.unsqueeze(dim=0)
guided_hint = self.union_controlnet_merge(hint, control_type, emb, context)

if guided_hint is None:
guided_hint = self.input_hint_block(hint, emb, context)

outs = []

if self.num_classes is not None:
Expand Down
6 changes: 6 additions & 0 deletions scripts/controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,6 +1027,10 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
control_model_type.is_controlnet and
model_net.control_model.global_average_pooling
)

if control_model_type == ControlModelType.ControlNetUnion:
logger.info(f"ControlNetUnion control type: {unit.union_control_type}")

forward_param = ControlParams(
control_model=model_net,
preprocessor=preprocessor_dict,
Expand All @@ -1047,6 +1051,8 @@ def ad_process_control(cc: List[torch.Tensor], cn_ad_keyframe_idx=cn_ad_keyframe
if unit.effective_region_mask is not None
else None
),
# TODO: Implement merge of units with the same union model.
union_control_types=[unit.union_control_type],
)
forward_params.append(forward_param)

Expand Down
118 changes: 118 additions & 0 deletions scripts/controlnet_core/controlnet_union.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from collections import OrderedDict
import torch
import torch.nn as nn

try:
from sgm.modules.diffusionmodules.openaimodel import (
timestep_embedding,
)

using_sgm = True
except ImportError:
from ldm.modules.diffusionmodules.openaimodel import (
timestep_embedding,
)

using_sgm = False


def attention_pytorch(
    q, k, v, heads, mask=None, attn_precision=None, skip_reshape=False
):
    """Multi-head attention via ``torch.nn.functional.scaled_dot_product_attention``.

    Args:
        q, k, v: 3-D tensors whose last dim is ``heads * dim_head``; or,
            when ``skip_reshape`` is true, 4-D tensors already split into
            heads with ``dim_head`` as the last dim.
        heads: Number of attention heads.
        mask: Optional attention mask forwarded as ``attn_mask``.
        attn_precision: Accepted but not used by this implementation.
        skip_reshape: Skip the split into heads when the inputs are
            already 4-D.

    Returns:
        Tensor reshaped to ``(batch, -1, heads * dim_head)``.
    """
    if skip_reshape:
        batch, _, _, head_dim = q.shape
    else:
        batch, _, width = q.shape
        head_dim = width // heads
        # Split the channel dim into heads and move heads in front of the
        # sequence dim.
        q, k, v = [
            t.view(batch, -1, heads, head_dim).transpose(1, 2)
            for t in (q, k, v)
        ]

    attended = torch.nn.functional.scaled_dot_product_attention(
        q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False
    )
    # Merge the heads back into a single channel dim.
    return attended.transpose(1, 2).reshape(batch, -1, heads * head_dim)


class ControlAddEmbedding(nn.Module):
    """Embed the set of active control types into a vector of size ``out_dim``.

    A multi-hot vector over the control types is expanded with
    ``timestep_embedding`` (``in_dim`` features per control type), flattened
    to ``num_control_type * in_dim`` features, and passed through
    Linear -> SiLU -> Linear.
    """

    def __init__(
        self,
        in_dim,
        out_dim,
        num_control_type,
        dtype=None,
        device=None,
    ):
        super().__init__()
        self.num_control_type = num_control_type
        self.in_dim = in_dim

        factory = {"dtype": dtype, "device": device}
        flat_features = in_dim * num_control_type
        self.linear_1 = nn.Linear(flat_features, out_dim, **factory)
        self.linear_2 = nn.Linear(out_dim, out_dim, **factory)

    def forward(self, control_type, dtype, device):
        """Return the embedding for the given control type indices.

        Args:
            control_type: Index (or list of indices) of the active control
                types; those positions are set to 1.0 in the multi-hot vector.
            dtype: Dtype the embedding is cast to before the linear layers.
            device: Device on which the multi-hot vector is created.
        """
        multi_hot = torch.zeros((self.num_control_type,), device=device)
        multi_hot[control_type] = 1.0

        embedded = timestep_embedding(
            multi_hot.flatten(), self.in_dim, repeat_only=False
        )
        embedded = embedded.to(dtype)
        embedded = embedded.reshape((-1, self.num_control_type * self.in_dim))

        hidden = torch.nn.functional.silu(self.linear_1(embedded))
        return self.linear_2(hidden)


class OptimizedAttention(nn.Module):
    """Self-attention with a fused q/k/v input projection.

    ``dropout`` and ``operations`` are accepted by the constructor but are
    not used by this implementation.
    """

    def __init__(self, c, nhead, dropout=0.0, dtype=None, device=None, operations=None):
        super().__init__()
        self.heads = nhead
        self.c = c

        factory = {"dtype": dtype, "device": device}
        # A single projection yields q, k and v concatenated on the last dim.
        self.in_proj = nn.Linear(c, 3 * c, bias=True, **factory)
        self.out_proj = nn.Linear(c, c, bias=True, **factory)

    def forward(self, x):
        """Project ``x`` to q/k/v, attend over dim 1, and apply ``out_proj``."""
        qkv = self.in_proj(x)
        q, k, v = torch.split(qkv, self.c, dim=2)
        attended = attention_pytorch(q, k, v, self.heads)
        return self.out_proj(attended)


class QuickGELU(nn.Module):
    """Sigmoid-based GELU approximation: ``x * sigmoid(1.702 * x)``."""

    def forward(self, x: torch.Tensor):
        gate = torch.sigmoid(1.702 * x)
        return x * gate


class ResBlockUnionControlnet(nn.Module):
    """Pre-norm transformer block: attention then MLP, each with a residual.

    The MLP expands ``dim`` to ``4 * dim``, applies ``QuickGELU``, and
    projects back to ``dim``.
    """

    def __init__(self, dim, nhead, dtype=None, device=None, operations=None):
        super().__init__()
        factory = {"dtype": dtype, "device": device}

        self.attn = OptimizedAttention(dim, nhead, operations=operations, **factory)
        self.ln_1 = nn.LayerNorm(dim, **factory)

        hidden_dim = dim * 4
        mlp_layers = OrderedDict()
        mlp_layers["c_fc"] = nn.Linear(dim, hidden_dim, **factory)
        mlp_layers["gelu"] = QuickGELU()
        mlp_layers["c_proj"] = nn.Linear(hidden_dim, dim, **factory)
        self.mlp = nn.Sequential(mlp_layers)

        self.ln_2 = nn.LayerNorm(dim, **factory)

    def attention(self, x: torch.Tensor):
        """Apply the block's self-attention module to ``x``."""
        return self.attn(x)

    def forward(self, x: torch.Tensor):
        """Return ``x`` after the attention and MLP residual branches."""
        after_attn = x + self.attention(self.ln_1(x))
        return after_attn + self.mlp(self.ln_2(after_attn))
19 changes: 15 additions & 4 deletions scripts/controlnet_model_guess.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,14 +221,25 @@ def build_model_by_guess(state_dict, unet, model_path: str) -> ControlModel:
final_state_dict[key] = p_new
state_dict = final_state_dict

config['use_fp16'] = devices.dtype_unet == torch.float16
if "control_add_embedding.linear_1.bias" in state_dict: # Controlnet Union
config["union_controlnet"] = True
final_state_dict = {}
for k in list(state_dict.keys()):
new_k = k.replace('.attn.in_proj_', '.attn.in_proj.')
final_state_dict[new_k] = state_dict.pop(k)
state_dict = final_state_dict

network = PlugableControlModel(config, state_dict)
network.to(devices.dtype_unet)
if "instant_id" in model_path.lower():
control_model_type = ControlModelType.ControlNetUnion
elif "instant_id" in model_path.lower():
control_model_type = ControlModelType.InstantID
else:
control_model_type = ControlModelType.ControlNet

config['use_fp16'] = devices.dtype_unet == torch.float16

network = PlugableControlModel(config, state_dict)
network.to(devices.dtype_unet)

return ControlModel(network, control_model_type)

if 'conv_in.weight' in state_dict:
Expand Down
24 changes: 24 additions & 0 deletions scripts/controlnet_ui/controlnet_ui_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
PuLIDMode,
ControlMode,
ResizeMode,
ControlNetUnionControlType,
)
from modules import shared
from modules.ui_components import FormRow, FormHTML, ToolButton
Expand Down Expand Up @@ -265,6 +266,7 @@ def __init__(
self.output_dir_state = None
self.advanced_weighting = gr.State(None)
self.pulid_mode = None
self.union_control_type = None

# API-only fields
self.ipadapter_input = gr.State(None)
Expand Down Expand Up @@ -487,6 +489,13 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
visible=False,
)

with gr.Row():
self.union_control_type = gr.Textbox(
label="Union Control Type",
value=ControlNetUnionControlType.UNKNOWN.value,
visible=False,
)

with gr.Row(elem_classes=["controlnet_control_type", "controlnet_row"]):
self.type_filter = (
gr.Dropdown
Expand Down Expand Up @@ -664,6 +673,7 @@ def render(self, tabname: str, elem_id_tabname: str) -> None:
self.advanced_weighting,
self.effective_region_mask,
self.pulid_mode,
self.union_control_type,
)

unit = gr.State(ControlNetUnit())
Expand Down Expand Up @@ -841,6 +851,19 @@ def filter_selected(k: str):
show_progress=False,
)

def register_union_control_type(self):
    """Keep ``self.union_control_type`` in sync with the control type filter.

    Whenever ``self.type_filter`` changes, its value is converted with
    ``ControlNetUnionControlType.from_str`` and the enum's ``value`` is
    written to the ``union_control_type`` component.
    """

    def filter_selected(k: str):
        selected = ControlNetUnionControlType.from_str(k)
        logger.debug(f"Switch to union control type {selected}")
        return gr.update(value=selected.value)

    listener_kwargs = dict(
        fn=filter_selected,
        inputs=[self.type_filter],
        outputs=[self.union_control_type],
        show_progress=False,
    )
    self.type_filter.change(**listener_kwargs)

def register_sd_version_changed(self):
def sd_version_changed(type_filter: str, current_model: str):
"""When SD version changes, update model dropdown choices."""
Expand Down Expand Up @@ -1227,6 +1250,7 @@ def register_core_callbacks(self):
self.register_webcam_mirror_toggle()
self.register_refresh_all_models()
self.register_build_sliders()
self.register_union_control_type()
self.register_shift_preview()
self.register_shift_upload_mask()
self.register_shift_pulid_mode()
Expand Down
Loading
Loading