Skip to content

Commit

Permalink
feat: enable multiple prompts for T2I (#125)
Browse files Browse the repository at this point in the history
* enable multiple prompts for T2I

* refactor(runner): make prompt splitting more general

This commit ensures that the prompt splitting logic implemented in the
previous commits can be reused in multiple pipelines.

---------

Co-authored-by: Rick Staa <[email protected]>
  • Loading branch information
ad-astra-video and rickstaa authored Jul 17, 2024
1 parent 68ca7df commit 68b8d85
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 10 deletions.
31 changes: 21 additions & 10 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,15 @@

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
split_prompt,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
Expand All @@ -15,15 +24,6 @@
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file

from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -222,7 +222,18 @@ def __call__(
# Default to 8step
kwargs["num_inference_steps"] = 8

output = self.ldm(prompt, **kwargs)
# Allow users to specify multiple (negative) prompts using the '|' separator.
prompts = split_prompt(prompt, max_splits=3)
prompt = prompts.pop("prompt")
kwargs.update(prompts)
neg_prompts = split_prompt(
kwargs.pop("negative_prompt", ""),
key_prefix="negative_prompt",
max_splits=3,
)
kwargs.update(neg_prompts)

output = self.ldm(prompt=prompt, **kwargs)

if safety_check:
_, has_nsfw_concept = self._safety_checker.check_nsfw_images(output.images)
Expand Down
1 change: 1 addition & 0 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
get_torch_device,
is_lightning_model,
is_turbo_model,
split_prompt,
validate_torch_device,
)
33 changes: 33 additions & 0 deletions runner/app/pipelines/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,39 @@ def is_turbo_model(model_id: str) -> bool:
return re.search(r"[-_]turbo", model_id, re.IGNORECASE) is not None


def split_prompt(
    input_prompt: str,
    separator: str = "|",
    key_prefix: str = "prompt",
    max_splits: int = -1,
) -> dict[str, str]:
    """Split a separator-delimited prompt into a dictionary of numbered prompts.

    The first segment is stored under ``key_prefix``. Every following segment is
    stored, in order, under ``{key_prefix}_2``, ``{key_prefix}_3``, and so on.
    Leading and trailing whitespace is stripped from each segment.

    Args:
        input_prompt: The input prompt string to be split.
        separator: The string used to split the input prompt. Defaults to '|'.
        key_prefix: Prefix for keys in the returned dictionary for all prompts,
            including the main prompt. Defaults to 'prompt'.
        max_splits: Maximum number of prompts to return. When the input holds
            more segments than this, the remainder (separators included) is
            kept in the last prompt. Zero or a negative value means no limit.
            Defaults to -1.

    Returns:
        A dictionary of all prompts, including the main prompt. It always
        contains at least the ``key_prefix`` entry, which is an empty string
        when ``input_prompt`` is empty.

    Raises:
        ValueError: If ``separator`` is an empty string (raised by
            ``str.split``).
    """
    # ``str.split`` expects the number of cuts, which is one fewer than the
    # number of resulting segments; a negative value disables the limit.
    # Non-positive ``max_splits`` is mapped to "no limit" so that keys are
    # always numbered consecutively.
    max_cuts = max_splits - 1 if max_splits > 0 else -1
    segments = [
        segment.strip() for segment in input_prompt.split(separator, max_cuts)
    ]

    prompt_dict = {key_prefix: segments[0]}
    # Secondary prompts are numbered from 2 because the main prompt is
    # implicitly number 1.
    prompt_dict.update(
        {
            f"{key_prefix}_{position}": segment
            for position, segment in enumerate(segments[1:], start=2)
        }
    )

    return prompt_dict


class SafetyChecker:
"""Checks images for unsafe or inappropriate content using a pretrained model.
Expand Down

0 comments on commit 68b8d85

Please sign in to comment.