Skip to content

Commit

Permalink
Fix download of .safetensors file in text-generation example (#294)
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss authored Jul 10, 2023
1 parent dbefacd commit a42056d
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions examples/text-generation/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from pathlib import Path

import torch
from huggingface_hub import snapshot_download
from huggingface_hub import list_repo_files, snapshot_download
from transformers.utils import is_offline_mode, is_safetensors_available


Expand All @@ -20,13 +20,21 @@ def get_repo_root(model_name_or_path, local_rank=-1):
if local_rank == 0:
print("Offline mode: forcing local_files_only=True")

# Only download PyTorch weights by default
allow_patterns = ["*.bin"]
# If the model repo contains any .safetensors file and
# safetensors is installed, only download safetensors weights
if is_safetensors_available():
if any(".safetensors" in filename for filename in list_repo_files(model_name_or_path)):
allow_patterns = ["*.safetensors"]

# Download only on first process
if local_rank in [-1, 0]:
cache_dir = snapshot_download(
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=["*.safetensors"] if is_safetensors_available() else ["*.bin"],
allow_patterns=allow_patterns,
max_workers=16,
)
if local_rank == -1:
Expand All @@ -40,7 +48,7 @@ def get_repo_root(model_name_or_path, local_rank=-1):
model_name_or_path,
local_files_only=is_offline_mode(),
cache_dir=os.getenv("TRANSFORMERS_CACHE", None),
allow_patterns=["*.safetensors"] if is_safetensors_available() else ["*.bin"],
allow_patterns=allow_patterns,
)


Expand Down

0 comments on commit a42056d

Please sign in to comment.