add load and save huggingface ckpt in baichuan2, qwen2 and gemma
sallyjunjun committed Sep 5, 2024
1 parent 1e5d0d1 commit dfc3d54
Showing 10 changed files with 704 additions and 47 deletions.
13 changes: 9 additions & 4 deletions configs/7B_baichuan2.py
@@ -10,21 +10,26 @@
NUM_LAYER = 32


MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded; supported: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded; supported: "internevo", "hf", or other custom-defined
# load functions such as "llama"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
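For reference, a ckpt block that both loads a Hugging Face checkpoint and exports HF-format checkpoints on save could look like the following minimal sketch (checkpoint_every and auto_resume are assumed to keep the semantics described in the comments above; the folder paths are placeholders):

ckpt = dict(
    enable_save_ckpt=True,  # write training checkpoints
    enable_internevo2hf_ckpt=True,  # also export a Hugging Face-format copy when saving
    save_ckpt_folder="local:llm_ckpts_baichuan2",
    # load model weights only, from a Hugging Face-format folder
    load_ckpt_info=dict(path="local:llm_ckpts_baichuan2/xxxx", content=("model",), ckpt_type="hf"),
    checkpoint_every=50,  # assumed to mirror CHECKPOINT_EVERY above
    auto_resume=False,  # resume automatically from save_ckpt_folder after an interruption
)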
21 changes: 13 additions & 8 deletions configs/7B_gemma.py
@@ -12,21 +12,26 @@
NUM_LAYER = 28


MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/2_hf"#"/mnt/petrelfs/geruijun/hf-gemma-7b-ckpt/"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
enable_save_ckpt=True, # enable ckpt save.
enable_internevo2hf_ckpt=True, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded; supported: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded; supported: "internevo", "hf", or other custom-defined
# load functions such as "llama"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
@@ -54,7 +59,7 @@
# defaults to 0, means disable evaluate
valid_every=0,
pack_sample_into_one=False,
total_steps=20,
total_steps=4,
skip_batches="",
# rampup_batch_size (str): A string with three space-separated integers representing the
# starting batch size, the increment, and the number of steps between
@@ -185,9 +190,9 @@
"""
parallel = dict(
zero1=dict(size=-1),
tensor=dict(size=1, mode="mtp"),
tensor=dict(size=2, mode="isp"),
pipeline=dict(size=1, interleaved_overlap=True),
weight=dict(size=1, overlap=True, memory_pool=True),
weight=dict(size=2, overlap=True, memory_pool=True),
)

cudnn_deterministic = False
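A note on the parallel change above (tensor mode switched to "isp" and weight parallel set to 2): the Hugging Face loaders added in this commit shard checkpoint tensors across the weight-parallel group in "isp" mode and across the tensor-parallel group otherwise. A minimal sketch of that selection, assuming InternEvo's usual gpc/ParallelMode imports:

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

tp_mode = gpc.config.parallel.tensor["mode"]
if tp_mode == "isp":
    split_size = gpc.get_world_size(ParallelMode.WEIGHT)  # shard across weight-parallel ranks
    local_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
else:
    split_size = gpc.get_world_size(ParallelMode.TENSOR)  # shard across tensor-parallel ranks
    local_rank = gpc.get_local_rank(ParallelMode.TENSOR)
row_dim = 0 if tp_mode == "isp" else 1  # row-parallel weights are split along dim 1 only outside isp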
17 changes: 11 additions & 6 deletions configs/7B_qwen2.py
@@ -7,25 +7,30 @@
HIDDEN_SIZE = 3584
NUM_ATTENTION_HEAD = 28
NUM_KV_ATTENTION_HEAD = 4
MLP_RATIO = 2.6875
MLP_RATIO = 5.25
NUM_LAYER = 28


MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/"
# Ckpt folder format:
# fs: 'local:/mnt/nfs/XXX'
SAVE_CKPT_FOLDER = "local:llm_ckpts"
LOAD_CKPT_FOLDER = "local:llm_ckpts/49"
SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2"

# boto3 Ckpt folder format:
# import os
# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint
# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm"
# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/"
CHECKPOINT_EVERY = 50
ckpt = dict(
enable_save_ckpt=False, # enable ckpt save.
enable_save_ckpt=True, # enable ckpt save.
enable_internevo2hf_ckpt=True, # enable ckpt save for huggingface format.
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
# 'load_ckpt_info' setting guide:
# 1. the 'path' indicates the ckpt path,
# 2. the 'content' means which states will be loaded; supported: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the 'ckpt_type' means the type of checkpoint to be loaded; supported: "internevo", "hf", or other custom-defined
# load functions such as "llama"
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
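The MLP_RATIO change above (2.6875 to 5.25) drives the MLP intermediate size derived from HIDDEN_SIZE. A back-of-the-envelope check, assuming the common llama-style rounding to a multiple of 256 (the exact rounding rule used by InternEvo is an assumption here):

HIDDEN_SIZE = 3584
MLP_RATIO = 5.25
multiple_of = 256  # assumed rounding granularity
intermediate = multiple_of * ((int(HIDDEN_SIZE * MLP_RATIO) + multiple_of - 1) // multiple_of)
print(intermediate)  # 18944, matching Qwen2-7B's published intermediate size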
2 changes: 1 addition & 1 deletion internlm/checkpoint/checkpoint_manager.py
@@ -459,7 +459,7 @@ def try_save_checkpoint(self, train_state, force=False):
logger.info(
f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}."
)
torch.distributed.barrier()
torch.distributed.barrier()

return now_break

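The one-line change above appears to be an indentation fix that moves torch.distributed.barrier() out of the log-rank-only branch, so every rank reaches the barrier after the internevo2hf conversion rather than only the rank that logs. A minimal sketch of that pattern using plain torch.distributed (the surrounding save logic is elided):

import torch.distributed as dist

def finish_hf_conversion(is_log_rank: bool):
    if is_log_rank:
        print("Finish to convert internevo2hf checkpoint.")  # only one rank logs
    dist.barrier()  # every rank must hit the barrier, or the others wait forever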
193 changes: 190 additions & 3 deletions internlm/model/modeling_baichuan2.py
@@ -130,7 +130,7 @@ def __init__(
device=device,
dtype=dtype,
qk_interleaved=qk_interleaved,
enable_qkv_fusion=False,
enable_qkv_fusion=True,
)

self.dropout1 = nn.Dropout(drop_rate)
@@ -444,8 +444,195 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs):

@staticmethod
def load_hf_weights(folder: str, model: nn.Module) -> None:
raise NotImplementedError
assert folder is not None, "Please specify the folder of the pretrained model"
if gpc.is_rank_for_log():
logger.info(f"Loading pretrained model from {folder}")

fns = get_fns(folder)
model_fns = [
os.path.join(folder, fn)
for fn in fns
if (fn.endswith(".bin") and fn.startswith("pytorch_model"))
or (fn.endswith(".safetensors") and fn.startswith("model"))
]
model_fns.sort()

state_dict = {}
for model_fn in model_fns:
state_dict.update(llm_load(model_fn, map_location="cpu"))

tp_size = gpc.get_world_size(ParallelMode.TENSOR)
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
wp_size = gpc.get_world_size(ParallelMode.WEIGHT)
wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT)
tp_mode = gpc.config.parallel.tensor["mode"]
split_size = wp_size if tp_mode == "isp" else tp_size
local_rank = wp_rank if tp_mode == "isp" else tp_rank
row_dim = 0 if tp_mode == "isp" else 1
if gpc.config.model.get("embed_split_hidden", True):
embed_concat_dim = 1
else:
embed_concat_dim = 0

new_state_dict = {}

# embedding
if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)):
new_state_dict["tok_embeddings.weight"] = torch.chunk(
state_dict.pop("model.embed_tokens.weight"),
split_size,
dim=embed_concat_dim,
)[local_rank]

for idx, i in enumerate(range(model.first_layer, model.last_layer)):
layer_ids = i

# attn
state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"),
split_size,
dim=row_dim,
)[local_rank]

# ffn
state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"),
split_size,
dim=0,
)[local_rank]
state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk(
state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"),
split_size,
dim=row_dim,
)[local_rank]

# attn norm
state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop(
f"model.layers.{layer_ids}.input_layernorm.weight"
)
# ffn norm
state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop(
f"model.layers.{layer_ids}.post_attention_layernorm.weight"
)

# replace value within decoder layer
for name in list(state_dict.keys()):
if name.startswith(f"layers.{i}"):
new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name)

# output
if gpc.is_last_rank(ParallelMode.PIPELINE):
new_state_dict["output.weight"] = torch.chunk(
state_dict.pop("lm_head.weight"),
split_size,
dim=0,
)[local_rank]
new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight")

missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False)
if len(state_dict) > 0:
logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.")

if gpc.get_local_rank(ParallelMode.DATA) == 0:
pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE)
logger.info(
f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in "
f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}"
)

internlm_accelerator.empty_cache()
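The loader above repeats one pattern for every parameter: read the full Hugging Face tensor, torch.chunk it split_size ways, and keep only the local rank's slice (dim 0 for column-parallel weights such as wqkv/w1/w3, row_dim for row-parallel weights such as out_proj/w2). A stripped-down sketch with illustrative shapes (not taken from this model):

import torch

split_size, local_rank = 2, 0  # e.g. two parallel ranks, viewed from rank 0
row_dim = 1  # 1 for mtp/msp tensor parallel, 0 for isp

full_w1 = torch.randn(11008, 4096)  # hypothetical gate_proj weight from the HF checkpoint
full_w2 = torch.randn(4096, 11008)  # hypothetical down_proj weight

local_w1 = torch.chunk(full_w1, split_size, dim=0)[local_rank]  # column-parallel: split the output dim
local_w2 = torch.chunk(full_w2, split_size, dim=row_dim)[local_rank]  # row-parallel: split the input dim

assert local_w1.shape == (5504, 4096)
assert local_w2.shape == (4096, 5504)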

@staticmethod
def convert_internevo2hf_weights(src: str, tgt: str) -> None:
raise NotImplementedError
def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True):
if adapt_hf:
return qkv
q_per_kv = num_heads // num_kv_heads
qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim)
q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :]
q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
qkv = torch.cat((q, k, v), dim=2)
qkv = rearrange(qkv, "o g n i -> o (g n i)").T
return qkv

model_config = gpc.config.model
tp_mode = gpc.config.parallel.tensor["mode"]
row_dim = 0 if tp_mode == "isp" else 1
if model_config["embed_split_hidden"]:
embed_concat_dim = 1
else:
embed_concat_dim = 0

# load states
states, num_shards = Baichuan2.load_sharded_states(src)

# convert state_dict
state_dict = {}
embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None]
for layer_i in tqdm(range(model_config["num_layers"])):
# attn norm, ffn norm
state_dict.update(
{
f"model.layers.{layer_i}.input_layernorm.weight": states[0][
f"layers.{layer_i}.attention_norm.weight"
].clone(),
f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][
f"layers.{layer_i}.ffn_norm.weight"
].clone(),
}
)
# attn
state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute(
torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0),
num_heads=model_config["num_attention_heads"],
# num_kv_attention_heads equals to num_attention_heads in MHA
num_kv_heads=model_config["num_attention_heads"],
head_dim=model_config["hidden_size"] // model_config["num_attention_heads"],
adapt_hf=model_config.get("adapt_hf", True),
)
state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim
)
# ffn
state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0
)
state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim
)
state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat(
[states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0
)
# embedding, output
for embedding_key in embedding_key_list:
if embedding_key in states[0]:
break
if embedding_key is None:
raise KeyError("Cannot find embedding key!")
state_dict.update(
{
"model.norm.weight": states[0]["norm.weight"],
"model.embed_tokens.weight": torch.cat(
[states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim
),
"lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0),
},
)

# save state_dict to hf format
shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
for shard_file, shard in shards.items():
llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"})
if index is not None:
llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index)
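The HF-format save at the end relies on the standard transformers/safetensors sharding helpers. A self-contained sketch of the same idea with plain file writes in place of llm_save (shard_checkpoint is assumed to still be importable from transformers.modeling_utils, as it was at the time of this commit; the state_dict here is illustrative):

import json
import os

import torch
from safetensors.torch import save_file
from transformers.modeling_utils import shard_checkpoint
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME

state_dict = {"model.norm.weight": torch.ones(4096)}  # illustrative converted weights
tgt = "./hf_ckpt"
os.makedirs(tgt, exist_ok=True)

shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME)
for shard_file, shard in shards.items():
    save_file(shard, os.path.join(tgt, shard_file), metadata={"format": "pt"})
if index is not None:  # an index file is only produced when the checkpoint has multiple shards
    with open(os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), "w") as f:
        json.dump(index, f, indent=2)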
