pipeline at dev branch #98

Open
SmileTAT opened this issue Sep 26, 2024 · 4 comments


@SmileTAT

On the dev branch, I initialize the pipeline with the following code, but the output image is covered with a red layer:
```python
import os

import torch
from diffusers import UniPCMultistepScheduler
from safetensors.torch import load_model
from transformers import CLIPTextModel

# BrushNetModel and StableDiffusionPowerPaintBrushNetPipeline (and possibly a custom
# UNet2DConditionModel) come from the PowerPaint repo; import them from wherever
# they live on your branch.

weight_dtype = torch.float16  # assumption: the original report does not show this value

# brushnet-based version
unet = UNet2DConditionModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="unet",
    revision=None,
    torch_dtype=weight_dtype,
)
text_encoder = CLIPTextModel.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="text_encoder",
    revision=None,
    torch_dtype=weight_dtype,
)
brushnet = BrushNetModel.from_unet(unet)

checkpoint_dir = "cache/huggingface/hub/models--JunhaoZhuang--PowerPaint-v2-1/snapshots/5ae2be3ac38b162df209b7ad5de036d339081e33"
base_model_path = os.path.join(checkpoint_dir, "realisticVisionV60B1_v51VAE")

pipe = StableDiffusionPowerPaintBrushNetPipeline.from_pretrained(
    base_model_path,
    brushnet=brushnet,
    text_encoder=text_encoder,
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=False,
    safety_checker=None,
)
# Reload the UNet from the base model
pipe.unet = UNet2DConditionModel.from_pretrained(
    base_model_path,
    subfolder="unet",
    revision=None,
    torch_dtype=weight_dtype,
)
# Load the PowerPaint BrushNet weights into the pipeline's BrushNet
load_model(
    pipe.brushnet,
    os.path.join(checkpoint_dir, "PowerPaint_Brushnet/diffusion_pytorch_model.safetensors"),
)

# IMPORTANT: add learnable tokens for task prompts into the tokenizer
placeholder_tokens = ["P_ctxt", "P_shape", "P_obj"]  # [v.placeholder_tokens for k, v in args.task_prompt.items()]
initializer_token = ["a", "a", "a"]  # [v.initializer_token for k, v in args.task_prompt.items()]
num_vectors_per_token = [10, 10, 10]  # [v.num_vectors_per_token for k, v in args.task_prompt.items()]
placeholder_token_ids = pipe.add_tokens(
    placeholder_tokens, initializer_token, num_vectors_per_token, initialize_parameters=True
)

text_state_dict = torch.load(os.path.join(checkpoint_dir, "PowerPaint_Brushnet/pytorch_model.bin"))

# Pull the base ("wrapped") embedding and the per-task trainable embeddings out of the state dict
prefix = "text_model.embeddings.token_embedding."
wrapped_weight = text_state_dict.pop(prefix + "wrapped.weight")
P_ctxt_weight = text_state_dict.pop(prefix + "trainable_embeddings.P_ctxt")
P_shape_weight = text_state_dict.pop(prefix + "trainable_embeddings.P_shape")
P_obj_weight = text_state_dict.pop(prefix + "trainable_embeddings.P_obj")

# Merge them into one embedding matrix, in the same order as placeholder_tokens
text_state_dict[prefix + "weight"] = torch.cat(
    [wrapped_weight, P_ctxt_weight, P_shape_weight, P_obj_weight]
)

msg = pipe.text_encoder.load_state_dict(text_state_dict, strict=False)
print(f"text load sd: {msg}")

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
```
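The last part of the snippet folds the checkpoint's `wrapped.weight` and the three `trainable_embeddings.*` entries into a single `token_embedding.weight` matrix. A minimal NumPy sketch of that merge on toy arrays follows; the helper name `merge_token_embeddings`, the sizes, and the assumption that each trainable entry has shape `(num_vectors, hidden_size)` are illustrative, not taken from the PowerPaint checkpoint.

```python
import numpy as np

PREFIX = "text_model.embeddings.token_embedding."


def merge_token_embeddings(state_dict, order):
    """Hypothetical helper: fold wrapped + trainable embeddings into one matrix.

    Pops the wrapped base embedding and one trainable block per placeholder
    token, then stacks them row-wise in `order`. Mutates state_dict in place,
    like the pop() calls in the snippet above.
    """
    rows = [state_dict.pop(PREFIX + "wrapped.weight")]
    for name in order:
        rows.append(state_dict.pop(PREFIX + "trainable_embeddings." + name))
    state_dict[PREFIX + "weight"] = np.concatenate(rows, axis=0)
    return state_dict


# Toy sizes (assumed): vocab of 5, hidden size 4, 10 vectors per placeholder token.
vocab, dim, n_vec = 5, 4, 10
sd = {PREFIX + "wrapped.weight": np.zeros((vocab, dim))}
for i, name in enumerate(["P_ctxt", "P_shape", "P_obj"], start=1):
    # Fill each block with a distinct value so its rows are easy to locate.
    sd[PREFIX + "trainable_embeddings." + name] = np.full((n_vec, dim), float(i))

merged = merge_token_embeddings(sd, ["P_ctxt", "P_shape", "P_obj"])
weight = merged[PREFIX + "weight"]
print(weight.shape)  # (35, 4)
```

Rows 5 to 14 hold `P_ctxt`, 15 to 24 `P_shape`, and 25 to 34 `P_obj`, so the concatenation order has to match the order in which the placeholder tokens were appended to the tokenizer; a mismatch would silently assign the wrong vectors to each task token.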
@zengyh1900
Collaborator

Hi @SmileTAT, what do you mean by "red layer"? Would you mind sharing some screenshots?

@SmileTAT
Author

SmileTAT commented Sep 27, 2024

[Screenshot IMG_0426: input and output images] @zengyh1900

@SmileTAT
Author


Differences between the dev and main branches:

1. Channel order when concatenating the mask with the conditioning latents:
   - dev: `conditioning_latents = torch.concat([mask, conditioning_latents], 1)`
   - main: `conditioning_latents = torch.concat([conditioning_latents, mask], 1)`
2. Comparison used to binarize the mask:
   - dev: `original_mask = (original_mask.sum(1)[:, None, :, :] > 0).to(image.dtype)`
   - main: `original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(image.dtype)`
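A small NumPy sketch of why these two changes matter for pretrained weights, mirroring the torch lines above with `np.concatenate` and `.astype`. The helper names, the tensor sizes, and the assumption that the mask is normalized to [-1, 1] before thresholding are illustrative, not taken from the PowerPaint code.

```python
import numpy as np


def concat_conditioning(conditioning_latents, mask, branch):
    """Hypothetical helper mirroring difference 1 (torch.concat along dim 1)."""
    if branch == "main":
        return np.concatenate([conditioning_latents, mask], axis=1)  # mask channel last
    if branch == "dev":
        return np.concatenate([mask, conditioning_latents], axis=1)  # mask channel first
    raise ValueError(f"unknown branch: {branch}")


def binarize_mask(original_mask, branch):
    """Hypothetical helper mirroring difference 2: sum over channels, then threshold."""
    summed = original_mask.sum(1)[:, None, :, :]  # (N, C, H, W) -> (N, 1, H, W)
    if branch == "main":
        return (summed < 0).astype(original_mask.dtype)
    if branch == "dev":
        return (summed > 0).astype(original_mask.dtype)
    raise ValueError(f"unknown branch: {branch}")


# Toy tensors (assumed sizes): 4 latent channels plus 1 mask channel, 2x2 spatial.
latents = np.arange(16, dtype=np.float32).reshape(1, 4, 2, 2)
mask = np.ones((1, 1, 2, 2), dtype=np.float32)
main_in = concat_conditioning(latents, mask, "main")
dev_in = concat_conditioning(latents, mask, "dev")
# Same shape, but channel 0 is a latent channel on main and the mask on dev.
print(np.array_equal(main_in[:, 0], dev_in[:, 0]))  # False

# 3-channel mask normalized to [-1, 1] (assumed): -1 in the left column, 1 in the right.
rgb_mask = np.array([[-1.0, 1.0], [-1.0, 1.0]], dtype=np.float32)[None, None].repeat(3, axis=1)
main_m = binarize_mask(rgb_mask, "main")
dev_m = binarize_mask(rgb_mask, "dev")
print(main_m[0, 0].tolist())  # [[1.0, 0.0], [1.0, 0.0]]
print(dev_m[0, 0].tolist())   # [[0.0, 1.0], [0.0, 1.0]]
```

Under these assumptions, the two thresholds select complementary regions, and the concat change moves the mask into a different input channel. If the released PowerPaint v2-1 BrushNet weights were trained with the main-branch layout, the dev branch would feed them inputs they were not trained on, which is consistent with the maintainer's advice below.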

@zengyh1900
Collaborator

Oh, I see. I refactored the dev branch, so if you are running our pretrained weights on it, there will probably be some problems. Please run the app on the dev branch with your own trained weights.
