Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable specifying schedulers for text-to-image and image-to-image #132

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 26 additions & 15 deletions runner/app/pipelines/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,19 @@
import os
from enum import Enum
from typing import List, Optional, Tuple
from copy import deepcopy

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
)
from diffusers import (
AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from app.pipelines.utils import (SafetyChecker, get_model_dir,
get_torch_device, is_lightning_model,
is_turbo_model, load_scheduler_presets,
create_scheduler)
from diffusers import (AutoPipelineForImage2Image,
EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
StableDiffusionInstructPix2PixPipeline,
StableDiffusionXLPipeline, UNet2DConditionModel)
from huggingface_hub import file_download, hf_hub_download
from PIL import ImageFile
from safetensors.torch import load_file
Expand Down Expand Up @@ -125,6 +119,12 @@ def __init__(self, model_id: str):
model_id, **kwargs
).to(torch_device)

#save the default scheduler
self.default_scheduler = deepcopy(self.ldm.scheduler)
#load the scheduler presets
self.scheduler_presets = load_scheduler_presets(self.__class__.__name__)
logger.info(f"loaded scheduler presets for {self.__class__.__name__}")

sfast_enabled = os.getenv("SFAST", "").strip().lower() == "true"
deepcache_enabled = os.getenv("DEEPCACHE", "").strip().lower() == "true"
if sfast_enabled and deepcache_enabled:
Expand Down Expand Up @@ -222,6 +222,17 @@ def __call__(
# Default to 8step
kwargs["num_inference_steps"] = 8

set_scheduler = kwargs.pop("scheduler", None)
logger.info(f"setting pipeline scheduler to: {set_scheduler}")
if set_scheduler:
new_scheduler, args, error = create_scheduler(set_scheduler, self.scheduler_presets)
if new_scheduler:
self.ldm.scheduler = new_scheduler.from_config(self.default_scheduler.config, **args)
else:
raise ValueError(f"scheduler could not be created: {error}")
else:
self.ldm.scheduler = self.default_scheduler

output = self.ldm(prompt, image=image, **kwargs)

if safety_check:
Expand Down
41 changes: 25 additions & 16 deletions runner/app/pipelines/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,18 @@
import os
from enum import Enum
from typing import List, Optional, Tuple
from copy import deepcopy

import PIL
import torch
from app.pipelines.base import Pipeline
from app.pipelines.utils import (
SafetyChecker,
get_model_dir,
get_torch_device,
is_lightning_model,
is_turbo_model,
split_prompt,
)
from diffusers import (
AutoPipelineForText2Image,
EulerDiscreteScheduler,
FluxPipeline,
StableDiffusion3Pipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from app.pipelines.utils import (SafetyChecker, get_model_dir,
get_torch_device, is_lightning_model,
is_turbo_model, split_prompt,
load_scheduler_presets, create_scheduler)
from diffusers import (AutoPipelineForText2Image, EulerDiscreteScheduler,
StableDiffusion3Pipeline, StableDiffusionXLPipeline,
UNet2DConditionModel, FluxPipeline)
from diffusers.models import AutoencoderKL
from huggingface_hub import file_download, hf_hub_download
from safetensors.torch import load_file
Expand Down Expand Up @@ -139,6 +131,12 @@ def __init__(self, model_id: str):
torch_device
)

#save the default scheduler
self.default_scheduler = deepcopy(self.ldm.scheduler)
#load the scheduler presets
self.scheduler_presets = load_scheduler_presets(self.__class__.__name__)
logger.info(f"loaded scheduler presets for {self.__class__.__name__}")

if os.environ.get("TORCH_COMPILE"):
torch._inductor.config.conv_1x1_as_mm = True
torch._inductor.config.coordinate_descent_tuning = True
Expand Down Expand Up @@ -263,6 +261,17 @@ def __call__(
)
kwargs.update(neg_prompts)

set_scheduler = kwargs.pop("scheduler", None)
logger.info(f"setting pipeline scheduler to: {set_scheduler}")
if set_scheduler:
new_scheduler, args, error = create_scheduler(set_scheduler, self.scheduler_presets)
if new_scheduler:
self.ldm.scheduler = new_scheduler.from_config(self.default_scheduler.config, **args)
else:
raise ValueError(f"scheduler could not be created: {error}")
else:
self.ldm.scheduler = self.default_scheduler

output = self.ldm(prompt=prompt, **kwargs)

if safety_check:
Expand Down
5 changes: 5 additions & 0 deletions runner/app/pipelines/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,8 @@
split_prompt,
validate_torch_device,
)

from app.pipelines.utils.schedulers import (
load_scheduler_presets,
create_scheduler
)
62 changes: 62 additions & 0 deletions runner/app/pipelines/utils/schedulers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
import importlib, json, logging
from typing import Tuple, get_type_hints

logger = logging.getLogger(__name__)

text_to_image_presets = {
"DPM++ 2M": { "name": "DPMSolverMultistepScheduler", "args": {} },
"DPM++ 2M Karras": { "name": "DPMSolverMultistepScheduler", "args": { "use_karras_sigmas": True } },
"DPM++ 2M SDE": { "name": "DPMSolverMultistepScheduler", "args": { "algorithm_type": "sde-dpmsolver++" } },
"DPM++ 2M SDE Karras": { "name": "DPMSolverMultistepScheduler", "args": { "use_karras_sigmas": True, "algorithm_type": "sde-dpmsolver++" } },
"DPM++ 2S a": { "name": "DPMSolverSinglestepScheduler", "args": {} },
"DPM++ 2S a Karras": { "name": "DPMSolverSinglestepScheduler", "args": { "use_karras_sigmas": True } },
"DPM++ SDE": { "name": "DPMSolverSinglestepScheduler", "args": {} },
"DPM++ SDE Karras": { "name": "DPMSolverSinglestepScheduler", "args": { "use_karras_sigmas": True } },
"DPM2": { "name": "KDPM2DiscreteScheduler", "args": {} },
"DPM2 Karras": { "name": "KDPM2DiscreteScheduler", "args": { "use_karras_sigmas": True } },
"DPM2 a": { "name": "KDPM2AncestralDiscreteScheduler", "args": {} },
"DPM2 a Karras": { "name": "KDPM2AncestralDiscreteScheduler", "args": { "use_karras_sigmas": True } },
"Euler": { "name": "EulerDiscreteScheduler", "args": {} },
"Euler a": { "name": "EulerAncestralDiscreteScheduler", "args": {} },
"Euler flow match": { "name": "FlowMatchEulerDiscreteScheduler", "args": {} },
"Huen": { "name": "HeunDiscreteScheduler", "args": {} },
"LMS": { "name": "LMSDiscreteScheduler", "args": {} },
"LMS Karras": { "name": "LMSDiscreteScheduler", "args": { "use_karras_sigmas": True } }
}


def load_scheduler_presets(pipeline: str) -> any:
if pipeline == "TextToImagePipeline":
return text_to_image_presets
elif pipeline == "ImageToImagePipeline":
return text_to_image_presets
else:
return {}

def create_scheduler(scheduler: str, presets: dict) -> Tuple[object, dict, str]:
"""
creates scheduler from provided settings (name/args). Presets available for convenience setting main schedulers
"""
set_sch = json.loads(scheduler)
set_sch_name = set_sch.get("name", None)
if set_sch_name in presets:
set_sch["name"] = presets[set_sch_name]["name"]
set_sch["args"] = presets[set_sch_name]["args"] | set_sch["args"]

try:
sch_cls = getattr(importlib.import_module("diffusers"), set_sch["name"])
#convert params to correct type. The from_config method passes the **kwargs provided directly
#to the __init__ method of the scheduler (diffusers/src/diffusers/configuration_utils.py)
type_hints = get_type_hints(sch_cls.__init__)
for arg in set_sch["args"]:
try:
set_sch["args"][arg] = type_hints[arg](set_sch["args"][arg])
except:
return None, None, f"params not in correct format: {arg}={set_sch['args'][arg]}"

return sch_cls, set_sch["args"], ""

except BaseException as e:
return None, None, f"scheduler not available: {e}"


5 changes: 5 additions & 0 deletions runner/app/routes/image_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,10 @@ async def image_to_image(
int,
Form(description="Number of images to generate per prompt."),
] = 1,
scheduler: Annotated[
str,
Form(description="Set scheduler for pipeline to use per documentation or presets available")
] = "",
pipeline: Pipeline = Depends(get_pipeline),
token: HTTPAuthorizationCredentials = Depends(HTTPBearer(auto_error=False)),
):
Expand Down Expand Up @@ -153,6 +157,7 @@ async def image_to_image(
seed=seed,
num_images_per_prompt=1,
num_inference_steps=num_inference_steps,
scheduler=scheduler
)
images.extend(imgs)
has_nsfw_concept.extend(nsfw_checks)
Expand Down
7 changes: 7 additions & 0 deletions runner/app/routes/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,13 @@ class TextToImageParams(BaseModel):
int,
Field(default=1, description="Number of images to generate per prompt."),
]
scheduler: Annotated[
str,
Field(
default="",
description="Set scheduler for pipeline to use per documentation or presets available"
),
]


RESPONSES = {
Expand Down
12 changes: 12 additions & 0 deletions runner/gateway.openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,12 @@ components:
title: Num Images Per Prompt
description: Number of images to generate per prompt.
default: 1
scheduler:
type: string
title: Scheduler
description: Set scheduler for pipeline to use per documentation or presets
available
default: ''
type: object
required:
- prompt
Expand Down Expand Up @@ -677,6 +683,12 @@ components:
title: Num Images Per Prompt
description: Number of images to generate per prompt.
default: 1
scheduler:
type: string
title: Scheduler
description: Set scheduler for pipeline to use per documentation or presets
available
default: ''
type: object
required:
- prompt
Expand Down
12 changes: 12 additions & 0 deletions runner/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,12 @@ components:
title: Num Images Per Prompt
description: Number of images to generate per prompt.
default: 1
scheduler:
type: string
title: Scheduler
description: Set scheduler for pipeline to use per documentation or presets
available
default: ''
type: object
required:
- prompt
Expand Down Expand Up @@ -691,6 +697,12 @@ components:
title: Num Images Per Prompt
description: Number of images to generate per prompt.
default: 1
scheduler:
type: string
title: Scheduler
description: Set scheduler for pipeline to use per documentation or presets
available
default: ''
type: object
required:
- prompt
Expand Down
Loading