
Commit

🐛 Fix transformer depth issue
huchenlei committed Aug 21, 2023
1 parent 839ffae commit d6eb229
Showing 2 changed files with 19 additions and 6 deletions.
models/control-lora-canny-rank128.yaml (5 changes: 3 additions & 2 deletions)
@@ -13,10 +13,11 @@ model:
     model_channels: 320
     num_res_blocks: 2
     attention_resolutions: [2, 4]
-    transformer_depth: 10
+    transformer_depth: [0, 2, 10]
+    transformer_depth_middle: 10
     channel_mult: [1, 2, 4]
     use_linear_in_transformer: True
-    context_dim: [2048, 2048,2048,2048,2048,2048,2048,2048,2048,2048]
+    context_dim: [2048,2048,2048,2048,2048,2048,2048,2048,2048,2048]
     num_heads: -1
     num_head_channels: 64
     hint_channels: 3
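
The new values drive the ControlNet per resolution level instead of with one global number: transformer_depth now has one entry per channel_mult level, and transformer_depth_middle sets the depth of the middle block on its own (the cldm.py change below is what reads the config this way). A minimal illustrative sketch of that mapping, in plain Python with made-up variable names, not code from the repository:

# Illustrative only: how a per-level transformer_depth lines up with channel_mult.
channel_mult = [1, 2, 4]             # one entry per UNet resolution level
transformer_depth = [0, 2, 10]       # transformer blocks used at each level
transformer_depth_middle = 10        # depth of the middle block

for level, (mult, depth) in enumerate(zip(channel_mult, transformer_depth)):
    print(f"level {level}: channel mult x{mult}, {depth} transformer block(s)")
print(f"middle block: {transformer_depth_middle} transformer block(s)")
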
scripts/cldm.py (20 changes: 16 additions & 4 deletions)
@@ -77,8 +77,6 @@ def use_controlnet_lora_operations():


 def set_attr(obj, attr, value):
-    print(f"setting {attr}")
-
     attrs = attr.split(".")
     for name in attrs[:-1]:
         obj = getattr(obj, name)
@@ -182,6 +180,7 @@ def __init__(
         num_attention_blocks=None,
         disable_middle_self_attn=False,
         use_linear_in_transformer=False,
+        transformer_depth_middle=None,
     ):
         use_fp16 = getattr(devices, 'dtype_unet', devices.dtype) == th.float16 and not getattr(shared.cmd_opts, "no_half_controlnet", False)

@@ -204,6 +203,13 @@ def __init__(
         if num_head_channels == -1:
             assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'

+        if isinstance(transformer_depth, int):
+            transformer_depth = len(channel_mult) * [transformer_depth]
+        if transformer_depth_middle is None:
+            transformer_depth_middle = transformer_depth[-1]
+
+        self.max_transformer_depth = max([*transformer_depth, transformer_depth_middle])
+
         self.dims = dims
         self.image_size = image_size
         self.in_channels = in_channels
@@ -313,7 +319,7 @@ def __init__(
                             num_head_channels=dim_head,
                             use_new_attention_order=use_new_attention_order,
                         ) if not use_spatial_transformer else SpatialTransformer(
-                            ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                            ch, num_heads, dim_head, depth=transformer_depth[level], context_dim=context_dim,
                             disable_self_attn=disabled_sa, use_linear=use_linear_in_transformer,
                             use_checkpoint=use_checkpoint
                         )
@@ -373,7 +379,7 @@ def __init__(
                 use_new_attention_order=use_new_attention_order,
                 # always uses a self-attn
             ) if not use_spatial_transformer else SpatialTransformer(
-                ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
+                ch, num_heads, dim_head, depth=transformer_depth_middle, context_dim=context_dim,
                 disable_self_attn=disable_middle_self_attn, use_linear=use_linear_in_transformer,
                 use_checkpoint=use_checkpoint
             ),
@@ -409,6 +415,12 @@ def forward(self, x, hint, timesteps, context, **kwargs):
             guided_hint = self.align(guided_hint, h1, w1)

         h = x.type(self.dtype)
+
+        # `context` is only used in SpatialTransformer.
+        if not isinstance(context, list):
+            context = [context] * self.max_transformer_depth
+        assert len(context) >= self.max_transformer_depth
+
         for module, zero_conv in zip(self.input_blocks, self.zero_convs):
             if guided_hint is not None:
                 h = module(h, emb, context)
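
Taken together, the cldm.py changes let transformer_depth be either a single int or a per-level list, derive the middle-block depth when it is not given, and broadcast a single conditioning tensor into a list so every transformer depth has a context entry to read. A standalone sketch of that logic with assumed, illustrative names (normalize_transformer_depth and broadcast_context are not the extension's API):

def normalize_transformer_depth(transformer_depth, channel_mult, transformer_depth_middle=None):
    # A plain int means "same depth at every level"; expand it to a per-level list.
    if isinstance(transformer_depth, int):
        transformer_depth = len(channel_mult) * [transformer_depth]
    # The middle block defaults to the depth of the deepest level.
    if transformer_depth_middle is None:
        transformer_depth_middle = transformer_depth[-1]
    max_depth = max([*transformer_depth, transformer_depth_middle])
    return transformer_depth, transformer_depth_middle, max_depth

def broadcast_context(context, max_transformer_depth):
    # `context` is only consumed by SpatialTransformer blocks; a single tensor is
    # repeated so the deepest transformer still has an entry per depth.
    if not isinstance(context, list):
        context = [context] * max_transformer_depth
    assert len(context) >= max_transformer_depth
    return context

# With the values from control-lora-canny-rank128.yaml:
depths, middle, max_depth = normalize_transformer_depth([0, 2, 10], [1, 2, 4])
print(depths, middle, max_depth)                  # [0, 2, 10] 10 10
print(len(broadcast_context("cond", max_depth)))  # 10
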
