Skip to content

Commit

Permalink
merge xl path, fix vae
Browse files Browse the repository at this point in the history
  • Loading branch information
lllyasviel committed Aug 23, 2023
1 parent 4ed5b77 commit 7003d5f
Showing 1 changed file with 12 additions and 28 deletions.
40 changes: 12 additions & 28 deletions scripts/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,15 @@ def call_vae_using_process(p, x, batch_size=None, mask=None):
with devices.autocast():
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
if torch.all(torch.isnan(vae_output)).item():
logger.info(f'ControlNet find Nans in the VAE encoding. \n '
f'Now ControlNet will automatically retry.\n '
f'To always start with 32-bit VAE, use --no-half-vae commandline flag.')
devices.dtype_vae = torch.float32
x = x.to(devices.dtype_vae)
p.sd_model.first_stage_model.to(devices.dtype_vae)
vae_output = p.sd_model.encode_first_stage(x)
vae_output = p.sd_model.get_first_stage_encoding(vae_output)
vae_cache.set(x, vae_output)
logger.info(f'ControlNet used {str(devices.dtype_vae)} VAE to encode {vae_output.shape}.')
latent = vae_output
Expand Down Expand Up @@ -366,7 +375,8 @@ def process_sample(*args, **kwargs):
mark_prompt_context(getattr(process, 'hr_uc', []), positive=False)
return process.sample_before_CN_hack(*args, **kwargs)

def forward_sd15(self, x, timesteps=None, context=None, **kwargs):
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
is_sdxl = y is not None and model_is_sdxl
total_controlnet_embedding = [0.0] * 13
total_t2i_adapter_embedding = [0.0] * 4
require_inpaint_hijack = False
Expand Down Expand Up @@ -680,40 +690,14 @@ def forward_sd15(self, x, timesteps=None, context=None, **kwargs):

return h

def forward_sdxl(self, x, timesteps=None, context=None, y=None, **kwargs):
# Handle cond-uncond marker
cond_mark, outer.current_uc_indices, context = unmark_prompt_context(context)

hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)

if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)

h = x
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)

return self.out(h)

def forward_webui(*args, **kwargs):
# webui will handle other compoments
try:
if shared.cmd_opts.lowvram:
lowvram.send_everything_to_cpu()

if model_is_sdxl:
return forward_sdxl(*args, **kwargs)
else:
return forward_sd15(*args, **kwargs)
return forward(*args, **kwargs)
finally:
if self.lowvram:
for param in self.control_params:
Expand Down

0 comments on commit 7003d5f

Please sign in to comment.