From dfc3d54de128ee51750c8adf24534717609931eb Mon Sep 17 00:00:00 2001 From: sallyjunjun Date: Thu, 5 Sep 2024 13:39:29 +0800 Subject: [PATCH] add load and save huggingface ckpt in baichuan2, qwen2 and gemma --- configs/7B_baichuan2.py | 13 +- configs/7B_gemma.py | 21 +- configs/7B_qwen2.py | 17 +- internlm/checkpoint/checkpoint_manager.py | 2 +- internlm/model/modeling_baichuan2.py | 193 ++++++++++++++++- internlm/model/modeling_gemma.py | 251 ++++++++++++++++++++-- internlm/model/modeling_qwen2.py | 244 ++++++++++++++++++++- internlm/model/modules/mlp.py | 7 +- internlm/model/modules/norm.py | 2 +- internlm/utils/utils.py | 1 + 10 files changed, 704 insertions(+), 47 deletions(-) diff --git a/configs/7B_baichuan2.py b/configs/7B_baichuan2.py index 77b77e4a..fdc1b0ab 100644 --- a/configs/7B_baichuan2.py +++ b/configs/7B_baichuan2.py @@ -10,21 +10,26 @@ NUM_LAYER = 32 -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +MODEL_ONLY_FOLDER = "local:llm_ckpts_baichuan2/xxxx" # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" +SAVE_CKPT_FOLDER = "local:llm_ckpts_baichuan2" # boto3 Ckpt folder format: # import os # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" CHECKPOINT_EVERY = 50 ckpt = dict( enable_save_ckpt=False, # enable ckpt save. + enable_internevo2hf_ckpt=False, # enable ckpt save for huggingface format. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) # with an automatic restart mechanism upon training reboot. diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py index 84ec119f..16d35b53 100644 --- a/configs/7B_gemma.py +++ b/configs/7B_gemma.py @@ -12,21 +12,26 @@ NUM_LAYER = 28 -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +MODEL_ONLY_FOLDER = "local:llm_ckpts_gemma/2_hf"#"/mnt/petrelfs/geruijun/hf-gemma-7b-ckpt/" # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" +SAVE_CKPT_FOLDER = "local:llm_ckpts_gemma" # boto3 Ckpt folder format: # import os # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" CHECKPOINT_EVERY = 50 ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. + enable_save_ckpt=True, # enable ckpt save. + enable_internevo2hf_ckpt=True, # enable ckpt save for huggingface format. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. 
the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) # with an automatic restart mechanism upon training reboot. @@ -54,7 +59,7 @@ # defaults to 0, means disable evaluate valid_every=0, pack_sample_into_one=False, - total_steps=20, + total_steps=4, skip_batches="", # rampup_batch_size (str): A string with three space-separated integers representing the # starting batch size, the increment, and the number of steps between @@ -185,9 +190,9 @@ """ parallel = dict( zero1=dict(size=-1), - tensor=dict(size=1, mode="mtp"), + tensor=dict(size=2, mode="isp"), pipeline=dict(size=1, interleaved_overlap=True), - weight=dict(size=1, overlap=True, memory_pool=True), + weight=dict(size=2, overlap=True, memory_pool=True), ) cudnn_deterministic = False diff --git a/configs/7B_qwen2.py b/configs/7B_qwen2.py index f2f9623b..a40cb108 100644 --- a/configs/7B_qwen2.py +++ b/configs/7B_qwen2.py @@ -7,25 +7,30 @@ HIDDEN_SIZE = 3584 NUM_ATTENTION_HEAD = 28 NUM_KV_ATTENTION_HEAD = 4 -MLP_RATIO = 2.6875 +MLP_RATIO = 5.25 NUM_LAYER = 28 -MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +MODEL_ONLY_FOLDER = "local:llm_ckpts_qwen2/xxxx/" # Ckpt folder format: # fs: 'local:/mnt/nfs/XXX' -SAVE_CKPT_FOLDER = "local:llm_ckpts" -LOAD_CKPT_FOLDER = "local:llm_ckpts/49" +SAVE_CKPT_FOLDER = "local:llm_ckpts_qwen2" # boto3 Ckpt folder format: # import os # BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint # SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" -# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" CHECKPOINT_EVERY = 50 ckpt = dict( - enable_save_ckpt=False, # enable ckpt save. + enable_save_ckpt=True, # enable ckpt save. + enable_internevo2hf_ckpt=True, # enable ckpt save for huggingface format. save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'load_ckpt_info' setting guide: + # 1. the 'path' indicate ckpt path, + # 2. the 'content‘ means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all" + # 3. the ’ckpt_type‘ means the type of checkpoint to be loaded, support: "internevo", "hf", or other custom-defined + # load function such as "llama" + load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="hf"), # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) # with an automatic restart mechanism upon training reboot. diff --git a/internlm/checkpoint/checkpoint_manager.py b/internlm/checkpoint/checkpoint_manager.py index cfec1612..2f7f5d4e 100644 --- a/internlm/checkpoint/checkpoint_manager.py +++ b/internlm/checkpoint/checkpoint_manager.py @@ -459,7 +459,7 @@ def try_save_checkpoint(self, train_state, force=False): logger.info( f"Finish to convert internevo2hf checkpoint from {save_ckpt_folder} to {save_hf_ckpt_folder}." 
) - torch.distributed.barrier() + torch.distributed.barrier() return now_break diff --git a/internlm/model/modeling_baichuan2.py b/internlm/model/modeling_baichuan2.py index b7daf55b..5a018848 100644 --- a/internlm/model/modeling_baichuan2.py +++ b/internlm/model/modeling_baichuan2.py @@ -130,7 +130,7 @@ def __init__( device=device, dtype=dtype, qk_interleaved=qk_interleaved, - enable_qkv_fusion=False, + enable_qkv_fusion=True, ) self.dropout1 = nn.Dropout(drop_rate) @@ -444,8 +444,195 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): @staticmethod def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["tok_embeddings.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wqkv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.W_pack.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.out_proj.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + 
new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError \ No newline at end of file + def permute(qkv, num_heads, num_kv_heads, head_dim, adapt_hf=True): + if adapt_hf: + return qkv + q_per_kv = num_heads // num_kv_heads + qkv = rearrange(qkv.T, "o (g n i) -> o g n i", n=q_per_kv + 2, i=head_dim) + q, k, v = qkv[..., :q_per_kv, :], qkv[..., -2:-1, :], qkv[..., -1:, :] + q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) + k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) + qkv = torch.cat((q, k, v), dim=2) + qkv = rearrange(qkv, "o g n i -> o (g n i)").T + return qkv + + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + # load states + states, num_shards = Baichuan2.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, ffn norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn + state_dict[f"model.layers.{layer_i}.self_attn.W_pack.weight"] = permute( + torch.cat([states[i][f"layers.{layer_i}.attention.wqkv.weight"] for i in range(num_shards)], dim=0), + num_heads=model_config["num_attention_heads"], + # num_kv_attention_heads equals to num_attention_heads in MHA + num_kv_heads=model_config["num_attention_heads"], + head_dim=model_config["hidden_size"] // model_config["num_attention_heads"], + adapt_hf=model_config.get("adapt_hf", True), + ) + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.out_proj.weight"] for i in range(num_shards)], dim=row_dim + ) + # ffn + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + # embedding, output + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + state_dict.update( + { + 
"model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat( + [states[i][embedding_key] for i in range(num_shards)], dim=embed_concat_dim + ), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modeling_gemma.py b/internlm/model/modeling_gemma.py index 10eb27cd..4813a6f1 100644 --- a/internlm/model/modeling_gemma.py +++ b/internlm/model/modeling_gemma.py @@ -1,9 +1,11 @@ # Copyright (c) InternLM. All rights reserved. import math +import os from typing import Optional import torch from torch import nn +from tqdm import tqdm from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode @@ -26,16 +28,20 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) try: from flash_attn.modules.mlp import ParallelFusedMLP except ImportError: pass -MODEL_TYPE = "GEMMA" - -logger = get_logger(__file__) internlm_accelerator = get_accelerator() +logger = get_logger(__file__) class GemmaDecoder(nn.Module): @@ -146,7 +152,9 @@ def __init__( self.dropout1 = nn.Dropout(drop_rate) self.dropout2 = nn.Dropout(drop_rate) - self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + self.attention_norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset + ) self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) @@ -180,7 +188,6 @@ def __init__( dtype=dtype, ) - self.use_glu = use_glu self.use_swiglu = use_swiglu self.use_scaled_init = use_scaled_init @@ -224,8 +231,7 @@ def reset_parameters(self): param.data ) - def forward( - self, hidden_states, residual=None, **kwargs): + def forward(self, hidden_states, residual=None, **kwargs): if self.checkpoint and self.training: args = convert_attn_kwargs_to_args(kwargs) return activation_checkpoint(self._forward, False, hidden_states, residual, *args) @@ -400,8 +406,8 @@ def __init__( self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") if first: - self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - for _, param in self.tok_embeddings.named_parameters(): + self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embed_tokens.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: @@ -451,7 +457,9 @@ def __init__( if last: if not apply_post_layer_norm: - self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + self.norm = new_layer_norm( + norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset + ) self.output = new_linear( name="output", @@ -497,8 +505,8 @@ def __init__( def forward(self, 
hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "tok_embeddings"): - hidden_states = self.tok_embeddings(input_ids) + if hasattr(self, "embed_tokens"): + hidden_states = self.embed_tokens(input_ids) if self.embed_grad_scale != 1: hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() @@ -521,11 +529,224 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): return (hidden_states, extra_hidden_states_list) return hidden_states - + @staticmethod def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["embed_tokens.weight"] = torch.chunk( + state_dict.get("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + 
state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + if "model.lm_head.weight" in state_dict: + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("model.lm_head.weight"), # we do not tie lm head with embedding + split_size, + dim=0, + )[local_rank] + state_dict.pop("model.embed_tokens.weight") + else: + new_state_dict["output.weight"] = torch.chunk( + # gemma model ties lm head with embedding in transformers implementation + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError \ No newline at end of file + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + + # load states + states, num_shards = Gemma.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, mlp norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn wqkv weight and bias + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], + dim=0, + ) + # attn wo weight + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim + ) + + # mlp + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in 
range(num_shards)], dim=0 + ) + + # embedding, head + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] + else: + embed_concat_dim = 0 + _, size_1 = states[0][embedding_key].shape + embdim_pertp = size_1 // num_shards + tok_emb_list = [ + torch.concat( + [ + states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] + for tp in range(num_shards) + ], + dim=0, + ) + for local_rank in range(num_shards) + ] + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), + "lm_head.weight": torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + # Save the index as well + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/modeling_qwen2.py index 6ac8127a..d3700baa 100644 --- a/internlm/model/modeling_qwen2.py +++ b/internlm/model/modeling_qwen2.py @@ -4,7 +4,6 @@ from typing import Optional import torch -from einops import rearrange from torch import nn from tqdm import tqdm @@ -383,8 +382,8 @@ def __init__( checkpoint_layer_num = int(num_layers * checkpoint) if first: - self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) - for _, param in self.tok_embeddings.named_parameters(): + self.embed_tokens = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.embed_tokens.named_parameters(): if init_type == "normal": normal_(std=embedding_init_std)(param) else: @@ -485,8 +484,8 @@ def __init__( def forward(self, hidden_states=None, input_ids=None, **kwargs): # attention_mask: compute attention on the places where the value is 1 - if hasattr(self, "tok_embeddings"): - hidden_states = self.tok_embeddings(input_ids) + if hasattr(self, "embed_tokens"): + hidden_states = self.embed_tokens(input_ids) if self.embed_grad_scale != 1: hidden_states = ( self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() @@ -515,8 +514,239 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): @staticmethod def load_hf_weights(folder: str, model: nn.Module) -> None: - raise NotImplementedError + assert folder is not None, "Please specify the folder of the pretrained model" + if gpc.is_rank_for_log(): + logger.info(f"Loading pretrained model from {folder}") + + fns = get_fns(folder) + model_fns = [ + os.path.join(folder, fn) + for fn in fns + if (fn.endswith(".bin") and fn.startswith("pytorch_model")) + or (fn.endswith(".safetensors") and fn.startswith("model")) + ] + model_fns.sort() + + state_dict = {} + for model_fn in model_fns: + state_dict.update(llm_load(model_fn, map_location="cpu")) + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + tp_rank = gpc.get_local_rank(ParallelMode.TENSOR) + wp_size = gpc.get_world_size(ParallelMode.WEIGHT) + wp_rank = gpc.get_local_rank(ParallelMode.WEIGHT) + tp_mode = gpc.config.parallel.tensor["mode"] + split_size = 
wp_size if tp_mode == "isp" else tp_size + local_rank = wp_rank if tp_mode == "isp" else tp_rank + row_dim = 0 if tp_mode == "isp" else 1 + if gpc.config.model.get("embed_split_hidden", True): + embed_concat_dim = 1 + else: + embed_concat_dim = 0 + + new_state_dict = {} + + # embedding + if (gpc.get_local_rank(ParallelMode.PIPELINE) == 0) or (not gpc.is_using_parallel_mode(ParallelMode.PIPELINE)): + new_state_dict["embed_tokens.weight"] = torch.chunk( + state_dict.pop("model.embed_tokens.weight"), + split_size, + dim=embed_concat_dim, + )[local_rank] + + for idx, i in enumerate(range(model.first_layer, model.last_layer)): + layer_ids = i + + # attn + state_dict[f"layers.{i}.attention.wq.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wq.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.q_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wk.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.k_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wv.bias"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.v_proj.bias"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.attention.wo.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.self_attn.o_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # ffn + state_dict[f"layers.{i}.feed_forward.w1.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.gate_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w3.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.up_proj.weight"), + split_size, + dim=0, + )[local_rank] + state_dict[f"layers.{i}.feed_forward.w2.weight"] = torch.chunk( + state_dict.pop(f"model.layers.{layer_ids}.mlp.down_proj.weight"), + split_size, + dim=row_dim, + )[local_rank] + + # attn norm + state_dict[f"layers.{i}.attention_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.input_layernorm.weight" + ) + # ffn norm + state_dict[f"layers.{i}.ffn_norm.weight"] = state_dict.pop( + f"model.layers.{layer_ids}.post_attention_layernorm.weight" + ) + + # replace value within decoder layer + for name in list(state_dict.keys()): + if name.startswith(f"layers.{i}"): + new_state_dict[name.replace(f".{i}.", f".{idx}.")] = state_dict.pop(name) + + # output + if gpc.is_last_rank(ParallelMode.PIPELINE): + new_state_dict["output.weight"] = torch.chunk( + state_dict.pop("lm_head.weight"), + split_size, + dim=0, + )[local_rank] + new_state_dict["norm.weight"] = state_dict.pop("model.norm.weight") + + missing_keys, unexpected_keys = model.load_state_dict(new_state_dict, strict=False) + if len(state_dict) > 0: + logger.warning(f"Be cautious, checkpoint state_dict keys={state_dict.keys()} have not beed loaded.") + + if gpc.get_local_rank(ParallelMode.DATA) == 0: + pp_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(ParallelMode.PIPELINE) + logger.info( + f"Missing keys:{missing_keys}, 
unexpected keys:{unexpected_keys} in " + f"tp:{gpc.get_local_rank(ParallelMode.TENSOR)}, pp:{pp_rank}" + ) + + internlm_accelerator.empty_cache() @staticmethod def convert_internevo2hf_weights(src: str, tgt: str) -> None: - raise NotImplementedError \ No newline at end of file + model_config = gpc.config.model + tp_mode = gpc.config.parallel.tensor["mode"] + row_dim = 0 if tp_mode == "isp" else 1 + + # load states + states, num_shards = Qwen2.load_sharded_states(src) + + # convert state_dict + state_dict = {} + embedding_key_list = ["tok_embeddings.weight", "embed_tokens.weight", None] + for layer_i in tqdm(range(model_config["num_layers"])): + # attn norm, mlp norm + state_dict.update( + { + f"model.layers.{layer_i}.input_layernorm.weight": states[0][ + f"layers.{layer_i}.attention_norm.weight" + ].clone(), + f"model.layers.{layer_i}.post_attention_layernorm.weight": states[0][ + f"layers.{layer_i}.ffn_norm.weight" + ].clone(), + } + ) + # attn wqkv weight and bias + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.q_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wq.bias"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.k_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wk.bias"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.weight"] for i in range(num_shards)], + dim=0, + ) + state_dict[f"model.layers.{layer_i}.self_attn.v_proj.bias"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wv.bias"] for i in range(num_shards)], + dim=0, + ) + # attn wo weight + state_dict[f"model.layers.{layer_i}.self_attn.o_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.attention.wo.weight"] for i in range(num_shards)], dim=row_dim + ) + + # mlp + state_dict[f"model.layers.{layer_i}.mlp.gate_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w1.weight"] for i in range(num_shards)], dim=0 + ) + state_dict[f"model.layers.{layer_i}.mlp.down_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w2.weight"] for i in range(num_shards)], dim=row_dim + ) + state_dict[f"model.layers.{layer_i}.mlp.up_proj.weight"] = torch.cat( + [states[i][f"layers.{layer_i}.feed_forward.w3.weight"] for i in range(num_shards)], dim=0 + ) + + # embedding, head + for embedding_key in embedding_key_list: + if embedding_key in states[0]: + break + if embedding_key is None: + raise KeyError("Cannot find embedding key!") + if model_config["embed_split_hidden"]: + embed_concat_dim = 1 + tok_emb_list = [states[i][embedding_key] for i in range(num_shards)] + else: + embed_concat_dim = 0 + _, size_1 = states[0][embedding_key].shape + embdim_pertp = size_1 // num_shards + tok_emb_list = [ + torch.concat( + [ + states[tp][embedding_key][:, embdim_pertp * local_rank : embdim_pertp * (local_rank + 1)] + for tp in range(num_shards) + ], + dim=0, + ) + for local_rank in range(num_shards) + ] + state_dict.update( + { + "model.norm.weight": states[0]["norm.weight"], + "model.embed_tokens.weight": torch.cat(tok_emb_list, dim=embed_concat_dim), + "lm_head.weight": 
torch.cat([states[i]["output.weight"] for i in range(num_shards)], dim=0), + }, + ) + + # save state_dict to hf format + shards, index = shard_checkpoint(state_dict, weights_name=SAFE_WEIGHTS_NAME) + for shard_file, shard in shards.items(): + llm_save(save_path=os.path.join(tgt, shard_file), saved_obj=shard, metadata={"format": "pt"}) + if index is not None: + # Save the index as well + llm_save(save_path=os.path.join(tgt, SAFE_WEIGHTS_INDEX_NAME), saved_obj=index) diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index dfdbc5bf..b836ff3d 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -3,13 +3,13 @@ from typing import Dict, Optional -from internlm.utils.utils import ActivationType import torch from torch import nn from internlm.model.modules.linear import new_linear from internlm.model.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger +from internlm.utils.utils import ActivationType logger = get_logger(__file__) @@ -72,7 +72,10 @@ def __init__( ): super().__init__() - assert activation_type in (ActivationType.swiglu.name, ActivationType.gelu.name), f"Unsupported activation type: {activation_type}" + assert activation_type in ( + ActivationType.swiglu.name, + ActivationType.gelu.name, + ), f"Unsupported activation type: {activation_type}" self.mlp_layer_fusion = mlp_layer_fusion self.activation_type = activation_type diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index 70717ff9..2a9700f8 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -16,7 +16,7 @@ def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False): if norm_type == "rmsnorm": rmsnorm_params = inspect.signature(RMSNorm).parameters - if 'add_unit_offset' in rmsnorm_params: + if "add_unit_offset" in rmsnorm_params: return RMSNorm(normalized_shape, eps, add_unit_offset) else: return RMSNorm(normalized_shape, eps) diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index ac2b62ae..ca6b3215 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -68,6 +68,7 @@ class ActivationType(Enum): swiglu = 1 gelu = 2 + def check_attention_argument(*args, **kwargs) -> str: # self, qkv, ... # self, q, kv, ....