Skip to content

Commit

Permalink
diffusion backend should reuse pipeline components for img2img / prom…
Browse files Browse the repository at this point in the history
…pt_variation
  • Loading branch information
bghira committed May 26, 2024
1 parent e7c44a9 commit f63a2e5
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions discord_tron_client/classes/image_manipulation/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,9 @@ class DiffusionPipelineManager:

def __init__(self):
hw_limits = hardware.get_hardware_limits()
self.torch_dtype = torch.float16
self.torch_dtype = torch.bfloat16
if torch.backends.mps.is_available():
self.torch_dtype = torch.float32
self.torch_dtype = torch.float16
self.is_memory_constrained = False
self.model_id = None
if (
Expand Down Expand Up @@ -137,6 +137,9 @@ def create_pipeline(self, model_id: str, pipe_type: str, use_safetensors: bool =
elif pipe_type in ["prompt_variation"]:
# Use the long prompt weighting pipeline.
logger.debug(f"Creating a LPW pipeline for {model_id}")
if model_id in self.pipelines:
# reuse the components from an already-defined model
extra_args = self.pipelines[model_id].components
pipeline = pipeline_class.from_pretrained(
model_id,
torch_dtype=self.torch_dtype,
Expand Down Expand Up @@ -208,7 +211,7 @@ def get_model_latest_hash(
return result
except Exception as e:
logger.error(f"Could not get model metadata: {e}")
return None
return False

def get_repo_last_modified(
self,
Expand All @@ -227,6 +230,9 @@ def is_model_latest(
if latest_hash is None:
logger.debug(f"is_model_latest could not retrieve metadata: {latest_hash}")
return None
if latest_hash is False:
logger.debug(f"is_model_latest could not retrieve metadata: {latest_hash}, but we are assuming it's fine.")
return True
current_hash = self.pipeline_versions.get(model_id, {}).get("latest_hash", "unknown")
last_modified = self.pipeline_versions.get(model_id, {}).get("last_modified", "unknown")
latest_modified = self.get_repo_last_modified(model_id)
Expand Down Expand Up @@ -269,6 +275,7 @@ def get_pipe(
if (
model_id in self.last_pipe_type
and self.last_pipe_type[model_id] != pipe_type
and pipe_type != "prompt_variation"
):
logger.warn(
f"Clearing out an incorrect pipeline type for the same model. Going from {self.last_pipe_type[model_id]} to {pipe_type}. Model: {model_id}"
Expand Down

0 comments on commit f63a2e5

Please sign in to comment.