Skip to content

Commit

Permalink
Fix LoRA load path
Browse files Browse the repository at this point in the history
  • Loading branch information
bghira committed Aug 29, 2024
1 parent c3db52a commit 4eb8c20
Showing 1 changed file with 7 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,18 @@ def download_adapter(self, adapter_type: str, adapter_path: str):

adapter_filename = "pytorch_lora_weights.safetensors"
cache_dir = config.get_huggingface_model_path()
path_to_adapter = f"{cache_dir}/{adapter_path}"
path_to_adapter = f"{cache_dir}/{self.clean_adapter_name(adapter_path)}"
hf_hub_download(
repo_id=adapter_path, filename=adapter_filename, local_dir=cache_dir
)

return path_to_adapter

def clean_adapter_name(self, adapter_path: str) -> str:
    """Return *adapter_path* with filesystem-unsafe characters replaced.

    Turns a Hugging Face repo id such as ``"org/repo"`` into a flat,
    path-safe name (``"org_repo"``) by substituting ``/``, ``\\`` and
    ``:`` with underscores.
    """
    # Single C-level pass instead of three chained .replace() calls.
    unsafe_to_underscore = str.maketrans({"/": "_", "\\": "_", ":": "_"})
    return adapter_path.translate(unsafe_to_underscore)

def load_adapter(
self,
adapter_type: str,
Expand All @@ -117,9 +122,7 @@ def load_adapter(
):
"""load the adapter from the path"""
# remove / and other chars from the adapter name
clean_adapter_name = (
adapter_path.replace("/", "_").replace("\\", "_").replace(":", "_")
)
clean_adapter_name = self.clean_adapter_name(adapter_path)
lycoris_wrapper = None
if adapter_type == "lora":
self.pipeline.load_lora_weights(
Expand Down

0 comments on commit 4eb8c20

Please sign in to comment.