diff --git a/configs/7B_gemma.py b/configs/7B_gemma.py new file mode 100644 index 00000000..84ec119f --- /dev/null +++ b/configs/7B_gemma.py @@ -0,0 +1,226 @@ +JOB_NAME = "7b_gemma_train" +model_type = "GEMMA" +DO_ALERT = False + +VOCAB_SIZE = 256000 +SEQ_LEN = 2048 +HIDDEN_SIZE = 3072 +NUM_ATTENTION_HEAD = 16 +NUM_KV_ATTENTION_HEAD = 16 +HEAD_DIM = 256 +MLP_RATIO = 8 +NUM_LAYER = 28 + + +MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx" +# Ckpt folder format: +# fs: 'local:/mnt/nfs/XXX' +SAVE_CKPT_FOLDER = "local:llm_ckpts" +LOAD_CKPT_FOLDER = "local:llm_ckpts/49" + +# boto3 Ckpt folder format: +# import os +# BOTO3_IP = os.environ["BOTO3_IP"] # boto3 bucket endpoint +# SAVE_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm" +# LOAD_CKPT_FOLDER = f"boto3:s3://model_weights.{BOTO3_IP}/internlm/snapshot/1/" +CHECKPOINT_EVERY = 50 +ckpt = dict( + enable_save_ckpt=False, # enable ckpt save. + save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt. + # 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering + # training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm) + # with an automatic restart mechanism upon training reboot. + # Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint + # path specified in `load_ckpt_info` by default. + # If you want to initialize your model weights from another model, you must set `auto_resume` to False. + # If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None. + auto_resume=False, + checkpoint_every=CHECKPOINT_EVERY, + async_upload=True, # async ckpt upload. (only work for boto3 ckpt) + async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload. + oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency. +) + +TRAIN_FOLDER = None +VALID_FOLDER = None # "/path/to/dataset" +data = dict( + seq_len=SEQ_LEN, + # micro_num means the number of micro_batch contained in one gradient update + micro_num=4, + # packed_length = micro_bsz * SEQ_LEN + micro_bsz=1, + # defaults to the value of micro_num + valid_micro_num=4, + # defaults to 0, means disable evaluate + valid_every=0, + pack_sample_into_one=False, + total_steps=20, + skip_batches="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. 
+ rampup_batch_size="", + # Datasets with less than 50 rows will be discarded + min_length=50, + train_folder=TRAIN_FOLDER, + valid_folder=VALID_FOLDER, + empty_cache_and_diag_interval=200, + diag_outlier_ratio=1.1, +) + +grad_scaler = dict( + fp16=dict( + # the initial loss scale, defaults to 2**16 + initial_scale=2**16, + # the minimum loss scale, defaults to None + min_scale=1, + # the number of steps to increase loss scale when no overflow occurs + growth_interval=1000, + ), + # the multiplication factor for increasing loss scale, defaults to 2 + growth_factor=2, + # the multiplication factor for decreasing loss scale, defaults to 0.5 + backoff_factor=0.5, + # the maximum loss scale, defaults to None + max_scale=2**24, + # the number of overflows before decreasing loss scale, defaults to 2 + hysteresis=2, +) + +hybrid_zero_optimizer = dict( + # Enable low_level_optimzer overlap_communication + overlap_sync_grad=True, + overlap_sync_param=False, + # bucket size for nccl communication params + reduce_bucket_size=512 * 1024 * 1024, + # grad clipping + clip_grad_norm=1.0, +) + +loss = dict( + label_smoothing=0, +) + +adam = dict( + lr=1e-4, + adam_beta1=0.9, + adam_beta2=0.95, + adam_beta2_c=0, + adam_eps=1e-8, + weight_decay=0.01, +) + +lr_scheduler = dict( + total_steps=data["total_steps"], + init_steps=0, # optimizer_warmup_step + warmup_ratio=0.01, + eta_min=1e-5, + last_epoch=-1, +) + +beta2_scheduler = dict( + init_beta2=adam["adam_beta2"], + c=adam["adam_beta2_c"], + cur_iter=-1, +) + +use_fp32_norm = False +model = dict( + checkpoint=False, + num_chunks=1, + num_attention_heads=NUM_ATTENTION_HEAD, + num_kv_attention_heads=NUM_KV_ATTENTION_HEAD, + embed_split_hidden=True, + vocab_size=VOCAB_SIZE, + embed_grad_scale=1, + parallel_output=True, + hidden_size=HIDDEN_SIZE, + num_layers=NUM_LAYER, + no_bias=True, + mlp_ratio=MLP_RATIO, + apply_post_layer_norm=False, + dtype="torch.bfloat16", + add_unit_offset=True, + norm_type="rmsnorm", + layer_norm_epsilon=1e-6, + head_dim=HEAD_DIM, + use_flash_attn=True, + # Whether the odd and even columns of the query and key in the model are normally interleaved. + # If it's True, the model's odd and even columns are normally ordered; if it's False, + # it means that the model has prematurely concatenated all odd columns and even columns in front + # and back, in order to improve the RoPE's computational efficiency. + # Example: + # qk_interleaved = True: q[-1] = [q1,q2,q3,q4,q5,q6,...], k[-1] = [k1,k2,k3,k4,k5,k6,...] + # qk_interleaved = False: q[-1] = [q1,q3,q5,...,q2,q4,q6,...], k[-1] = [k1,k3,k5,...,k2,k4,k6,...] + qk_interleaved=False, + use_swiglu=False, +) + +""" +zero1 parallel (dict): + 1. size: int + * if size <= 0, the size of the zero process group is equal to the size of the dp process group, + so parameters will be divided within the range of dp. + * if size == 1, zero is not used, and all dp groups retain the full amount of model parameters. + * if size > 1 and size <= dp world size, the world size of zero is a subset of dp world size. + For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. + 2. fsdp: bool, enable/disable torch's fully sharded data parallel, defaults to False. +tensor parallel (dict): + 1. size: int, the size of tensor parallel. + 2. mode: str, the tensor parallel mode, should be in ['mtp', 'msp', 'fsp', 'isp'], + defaults to 'mtp', means the pure megatron tensor parallel without sequence parallel. 
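Stepping back to the grad_scaler block above: its fields map onto the usual dynamic-loss-scaling loop (grow after growth_interval clean steps, back off after hysteresis overflows, clamp to min_scale/max_scale). The sketch below is a simplified illustration of that policy, not the InternEvo implementation; the class name DynamicScaler is made up for the example.

# Simplified dynamic loss scaling, illustrating the grad_scaler fields above.
class DynamicScaler:
    def __init__(self, initial_scale=2**16, growth_factor=2, backoff_factor=0.5,
                 growth_interval=1000, hysteresis=2, min_scale=1, max_scale=2**24):
        self.scale = initial_scale
        self.growth_factor, self.backoff_factor = growth_factor, backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis          # overflows tolerated before backing off
        self.min_scale, self.max_scale = min_scale, max_scale
        self.good_steps = 0
        self.overflows = 0

    def update(self, found_overflow: bool):
        if found_overflow:
            self.good_steps = 0
            self.overflows += 1
            if self.overflows >= self.hysteresis:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
                self.overflows = 0
        else:
            self.good_steps += 1
            if self.good_steps >= self.growth_interval:
                self.scale = min(self.scale * self.growth_factor, self.max_scale)
                self.good_steps = 0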
+ msp: megatron tensor parallel with sequence parallel, sequence parallel size = tensor parallel size. + fsp: tensor parallel by flash-attn with sequence parallel, sequence parallel size = tensor parallel size. + isp: customed intern sequence parallel without tensor parallel, can be used with weight parallel. +pipeline parallel (dict): + 1. size: int, the size of pipeline parallel. + 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler, + defaults to False. +weight parallel (dict): + 1. size: int, the size of weight parallel. + 2. overlap: bool, enable/disable all_gather/reduce_scatter communication overlap, defaults to False. + 3. memory_pool: bool, enable/disable memory pool, defaults to False. +""" +parallel = dict( + zero1=dict(size=-1), + tensor=dict(size=1, mode="mtp"), + pipeline=dict(size=1, interleaved_overlap=True), + weight=dict(size=1, overlap=True, memory_pool=True), +) + +cudnn_deterministic = False +cudnn_benchmark = False + +monitor = dict( + # feishu alert configs + alert=dict( + enable_feishu_alert=DO_ALERT, + feishu_alert_address=None, # feishu webhook to send alert message + light_monitor_address=None, # light_monitor address to send heartbeat + alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", + ), + tensorboard=dict( + queue_max_length=10, + ), +) + +# metric_dtype can be "fp32" or other string +# only when set to "fp32" will use fp32 to calc in metrics +# metric_dtype = "fp32" + +generation = dict( + ckpt_folder="/path/to/saved/ckpt", + output_folder="/path/to/save/generation", + batch_size=1, + eos_id=[2, 0], + bos_id=1, + max_length=100, + do_sample=True, + temperature=1.0, + top_k=50, + top_p=1.0, + repetition_penalty=1, + length_penalty=1.0, +) diff --git a/internlm/model/modeling_baichuan2.py b/internlm/model/modeling_baichuan2.py index 349eb0c3..b7daf55b 100644 --- a/internlm/model/modeling_baichuan2.py +++ b/internlm/model/modeling_baichuan2.py @@ -1,10 +1,14 @@ # Copyright (c) InternLM. All rights reserved. 
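For the parallel dict above (zero1=-1, tensor=1, pipeline=1), a quick sanity check of how the group sizes compose. This assumes the usual convention that the data-parallel size is the world size divided by the tensor- and pipeline-parallel sizes, and follows the docstring rule that zero1.size <= 0 makes the ZeRO group span the whole data-parallel group; illustrative arithmetic only.

# Illustrative sizing for the parallel config above.
def resolve_parallel_sizes(world_size, tensor=1, pipeline=1, zero1=-1):
    assert world_size % (tensor * pipeline) == 0
    dp = world_size // (tensor * pipeline)   # data-parallel ranks
    zero = dp if zero1 <= 0 else zero1       # size <= 0 -> shard across all of dp
    assert dp % zero == 0, "zero1 size must divide the data-parallel size"
    return dp, zero

print(resolve_parallel_sizes(world_size=8))           # (8, 8): ZeRO shards over all dp ranks
print(resolve_parallel_sizes(world_size=8, zero1=4))  # (8, 4)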
import math +import os from typing import Optional import torch +from einops import rearrange from torch import nn +from tqdm import tqdm +from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import ( @@ -13,6 +17,7 @@ scaled_init_method_uniform, uniform_, ) +from internlm.model.base_model import BaseModel from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import new_linear from internlm.model.modules.mha import MHA @@ -24,7 +29,14 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) +internlm_accelerator = get_accelerator() logger = get_logger(__file__) @@ -136,7 +148,7 @@ def __init__( mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, # TODO: to support more activation functions - activation_type="swiglu" if use_swiglu else "swiglu", + activation_type="swiglu" if use_swiglu else "gelu", ) self.use_swiglu = use_swiglu @@ -189,7 +201,7 @@ def forward(self, hidden_states, residual=None, **kwargs): else: return self._forward(hidden_states, residual, **kwargs) - def _forward(self, hidden_states=None, residual=None, *args, **kwargs): # pylint: disable=W1113 + def _forward(self, hidden_states, residual, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -258,7 +270,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Baichuan2(nn.Module): +class Baichuan2(BaseModel): """ 1D Packed Flash Llama. @@ -429,3 +441,11 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): hidden_states = self.output(hidden_states) return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + raise NotImplementedError + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + raise NotImplementedError \ No newline at end of file diff --git a/internlm/model/modeling_gemma.py b/internlm/model/modeling_gemma.py new file mode 100644 index 00000000..10eb27cd --- /dev/null +++ b/internlm/model/modeling_gemma.py @@ -0,0 +1,531 @@ +# Copyright (c) InternLM. All rights reserved. 
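Both Baichuan2 above and the new Gemma model added below now derive from internlm.model.base_model.BaseModel and stub out the two HuggingFace conversion hooks with NotImplementedError. The contents of BaseModel are not shown in this diff; the following is a minimal sketch of the contract these models appear to satisfy, an assumption rather than the actual class.

# Hypothetical sketch of the BaseModel contract implied by the stubs in this diff.
from abc import abstractmethod
from torch import nn


class BaseModel(nn.Module):
    """Base class for InternEvo model definitions (sketch; the real class may differ)."""

    @staticmethod
    @abstractmethod
    def load_hf_weights(folder: str, model: nn.Module) -> None:
        """Load HuggingFace-format weights from `folder` into `model`."""

    @staticmethod
    @abstractmethod
    def convert_internevo2hf_weights(src: str, tgt: str) -> None:
        """Convert an InternEvo checkpoint at `src` into HuggingFace format at `tgt`."""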
+import math +from typing import Optional + +import torch +from torch import nn + +from internlm.accelerator import get_accelerator +from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import global_context as gpc +from internlm.initialize.initialize_tensor import ( + normal_, + scaled_init_method_normal, + scaled_init_method_uniform, + uniform_, +) +from internlm.model.base_model import BaseModel +from internlm.model.modules.embedding import Embedding1D +from internlm.model.modules.linear import new_linear +from internlm.model.modules.mha import GQA +from internlm.model.modules.mlp import new_feed_forward +from internlm.model.modules.norm import new_layer_norm +from internlm.model.utils import ( + convert_attn_args_to_kwargs, + convert_attn_kwargs_to_args, +) +from internlm.solver.activation_checkpoint import activation_checkpoint +from internlm.utils.logger import get_logger + +try: + from flash_attn.modules.mlp import ParallelFusedMLP +except ImportError: + pass + +MODEL_TYPE = "GEMMA" + +logger = get_logger(__file__) +internlm_accelerator = get_accelerator() + + +class GemmaDecoder(nn.Module): + """ + 1D Packed Flash Llama Layer. + + Args: + hidden_size (int): The hidden size of model. 768 by default. + num_attention_heads (int): The number of attention heads. 12 by default. + head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0 by default. + drop_rate (float): The dropout rate of the input hidden state. 0.0 by default. + dtype (torch.dtype): Type of data. torch.float by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-5 by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + layer_idx (int): The index of current layer. 0 by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + device (Optional[Union[str, torch.device]]): The device will be used. + add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. + use_glu (bool): Whether to use glu. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. + tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], + "mtp" by default. 
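The mlp_ratio and multiple_of parameters documented above determine the feed-forward width. With the 7B config (HIDDEN_SIZE=3072, MLP_RATIO=8, multiple_of=256) the round-up-to-multiple rule used in internlm/model/modules/mlp.py (shown later in this diff) works out as follows; a small standalone check.

# Feed-forward width implied by the 7B Gemma config (hidden=3072, mlp_ratio=8).
hidden_size = 3072
mlp_ratio = 8
multiple_of = 256

hidden_features = int(hidden_size * mlp_ratio)
# Round up to the nearest multiple of `multiple_of`, as mlp.py does.
hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of)
print(hidden_features)  # 24576, already a multiple of 256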
+ """ + + def __init__( + self, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + head_dim: int = None, + mlp_ratio: int = 4, + attn_drop_rate: float = 0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + layer_norm_epsilon: float = 1e-6, + checkpoint: bool = False, + layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + residual_in_fp32: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm: bool = False, + fused_dropout_add_ln: bool = True, + no_bias: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + add_unit_offset: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_glu: bool = True, + use_swiglu: bool = True, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + init_type: str = "normal", + rope_base: int = 10000, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + tp_mode: str = "mtp", + ): + super().__init__() + self.checkpoint = checkpoint + # dropout selective checkpoint can only be enabled when checkpoint is disabled. + self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False + self.layer_idx = layer_idx + self.prenorm = not apply_post_layer_norm + assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" + self.fused_dropout_add_ln = fused_dropout_add_ln + self.attn_wqkv_init_std = attn_wqkv_init_std + self.attn_other_init_std = attn_other_init_std + self.ffn_uplayer_init_std = ffn_uplayer_init_std + self.ffn_other_init_std = ffn_other_init_std + + if not head_dim: + head_dim = hidden_size // num_attention_heads + + self.attention = GQA( + embed_dim=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_kv_attention_heads, + head_dim=head_dim, + dropout=attn_drop_rate, + max_position_embeddings=max_position_embeddings, + softmax_scale=1 / math.sqrt(head_dim), + causal=True, + layer_idx=layer_idx, + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + rotary_emb_dim=head_dim, + rotary_emb_scale_base=0, + device=device, + dtype=dtype, + qk_interleaved=qk_interleaved, + bias=not no_bias, + rope_base=rope_base, + enable_qkv_fusion=False, + ) + + self.dropout1 = nn.Dropout(drop_rate) + self.dropout2 = nn.Dropout(drop_rate) + self.attention_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + self.ffn_norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + + sequence_parallel = gpc.config.parallel.get("sequence_parallel", False) + parallel_mode = ParallelMode.WEIGHT if tp_mode == "isp" else ParallelMode.TENSOR + + if use_glu: + self.feed_forward = new_feed_forward( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + bias=False, + device=device, + dtype=dtype, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + activation_type="swiglu" if use_swiglu else "gelu", + ) + else: + self.feed_forward = ParallelFusedMLP( + hidden_size, + int(hidden_size * mlp_ratio), + out_features=hidden_size, + activation="gelu_approx", + process_group=gpc.get_group(parallel_mode), + bias1=False, + bias2=False, + sequence_parallel=sequence_parallel, + checkpoint_lvl=0, + heuristic="auto", + device=device, + dtype=dtype, + ) + + + self.use_glu = use_glu + self.use_swiglu = use_swiglu + self.use_scaled_init 
= use_scaled_init + self.residual_in_fp32 = residual_in_fp32 # only make sense when using prenorm + self.return_residual = False + + if init_type == "normal": + self.init_func = normal_ + self.scaled_init_func = scaled_init_method_normal + else: + self.init_func = uniform_ + self.scaled_init_func = scaled_init_method_uniform + + self.reset_parameters() + + def reset_parameters(self): + with torch.no_grad(): + for name, param in self.attention.named_parameters(): + if param.ndim == 1: + param.data.zero_() + elif "wq" in name or "wk" in name or "wv" in name: + self.init_func(std=self.attn_wqkv_init_std)(param.data) + elif self.use_scaled_init: # wo + self.scaled_init_func(sigma=self.attn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.attn_other_init_std)(param.data) + + for name, param in self.feed_forward.named_parameters(): + if self.use_glu: + if self.use_scaled_init and "w2" in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func( + std=self.ffn_uplayer_init_std if "w1" in name or "w3" in name else self.ffn_other_init_std + )(param.data) + else: + if self.use_scaled_init and "fc1" not in name: + self.scaled_init_func(sigma=self.ffn_other_init_std, num_layers=self.layer_idx + 1)(param.data) + else: + self.init_func(std=self.ffn_uplayer_init_std if "fc1" in name else self.ffn_other_init_std)( + param.data + ) + + def forward( + self, hidden_states, residual=None, **kwargs): + if self.checkpoint and self.training: + args = convert_attn_kwargs_to_args(kwargs) + return activation_checkpoint(self._forward, False, hidden_states, residual, *args) + else: + return self._forward(hidden_states, residual, **kwargs) + + def _forward(self, hidden_states, residual, *args, **kwargs): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). 
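reset_parameters above switches between a plain init and a depth-scaled init for the attention output and w2 projections. The scaled initializer comes from internlm.initialize.initialize_tensor and is not shown in this diff; the sketch below follows the common Megatron-style convention (std = sigma / sqrt(2 * num_layers)) and should be read as an assumption about its behaviour, with an illustrative helper name, not its actual definition.

# Megatron-style depth-scaled normal init (assumed behaviour of scaled_init_method_normal).
import math
import torch


def scaled_normal_init(sigma: float, num_layers: int):
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor: torch.Tensor) -> torch.Tensor:
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


w = torch.empty(3072, 3072)
scaled_normal_init(sigma=0.02, num_layers=28)(w)  # std ~= 0.02 / sqrt(56) ~= 0.0027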
+ residual: hidden_states = Attn/MLP(LN(residual)) + cu_seqlens: 1d LongTensor, len(cu_seqlens) = hidden_states + 1 + indexes: the length of index is same as hidden states, which stand for the current position + """ + if self.prenorm: + + def _dropout_and_norm_attn(_residual, _hidden_states): + _dropped = self.dropout1(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.attention_norm(_residual.to(dtype=self.attention_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint(_dropout_and_norm_attn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_attn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) + hidden_states = self.attention(hidden_states, **mixer_kwargs) + + if not isinstance(self.feed_forward, nn.Identity): + if not self.fused_dropout_add_ln: + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.ffn_norm(_residual.to(self.ffn_norm.weight.dtype)) + + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states = activation_checkpoint( + _dropout_and_norm_ffn, False, residual, hidden_states + ) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + hidden_states = self.feed_forward(hidden_states) + + return hidden_states + residual + else: + assert residual is None + + mixer_out = self.attention(hidden_states, **kwargs) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + hidden_states = self.attention_norm(self.dropout1(mixer_out) + hidden_states).to( + dtype=self.attention_norm.weight.dtype + ) + if not isinstance(self.feed_forward, nn.Identity): + mlp_out = self.feed_forward(hidden_states) + if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + hidden_states = self.ffn_norm((self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.ffn_norm.weight.dtype + ) + return hidden_states + + +class Gemma(BaseModel): + """ + 1D Packed Flash Llama. + + Args: + num_layers (int): The number of layer. 12 by default. + hidden_size (int): The size of hidden state. 768 by default. + num_attention_heads (int): The number of attention head. 12 by default. + head_dim (int): The dimention of attention head dimention. hidden_size divided by num_heads by default. + vocab_size (int): The size of vocabulary. 50304 by default. + mlp_ratio (int): The ratio of MLP layers. 4 by default. + attn_drop_rate (float): The dropout rate of attention module. 0.0 by default. + drop_rate (float): The dropout rate of input hidden state. 0.0 by default. + dtype (torch.dtype): The type of data. torch.float by default. + checkpoint (bool): Whether to use checkpointing to save VRAM. True by default. + checkpoint_fraction (float): The proportion of layers that need to be checkpointed compared to the total number + of layers. 1.0 by default. + layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. + first (bool): Whether input embedding layer or not. False by default. 
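Stripped of selective checkpointing and dtype handling, the pre-norm branch of _forward above reduces to the familiar residual pattern; a minimal reference sketch.

# Schematic of the prenorm branch in GemmaDecoder._forward (dropout/dtype details omitted).
def prenorm_block(hidden_states, residual, attention_norm, attention,
                  ffn_norm, feed_forward, dropout1, dropout2):
    dropped = dropout1(hidden_states)
    residual = dropped + residual if residual is not None else dropped
    hidden_states = attention(attention_norm(residual))

    dropped = dropout2(hidden_states)
    residual = dropped + residual
    hidden_states = feed_forward(ffn_norm(residual))

    return hidden_states + residual  # the layer returns hidden + residual, not a pair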
+ last (bool): Whether output embedding layer or not. False by default. + embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. + parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. + start_layer_idx (int): The index of start layer in the pipeline. 0 by default. + device (Optional[Union[str, torch.device]]): The device will be used. None by default. + residual_in_fp32 (bool): Whether to use residual in fp32. False by default. + add_unit_offset(bool): Add one to RMSNorm weight multiply by normed input. False by default. + use_glu (bool): Whether to use glu. True by default. + use_swiglu (bool): Whether to use swiglu. True by default. + embedding_init_std (float): std used to init embedding weight. 0.02 by default, + attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, + attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, + ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu + otherwise init fc1 weight in ffn. 0.02 by default, + ffn_other_init_std (float): std used to init ffn_other weight. 0.02 by default, + out_head_init_std (float): std used to init output lmhead weight. 0.02 by default, + init_type (str): Initialization type. Use uniform or normal. "normal" by default, + extra_pred_tokens (int): The number of extra output head for multi-token-prediction. 0 by default. + rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. + multiple_of (int): The value to make SwiGLU hidden layer size multiple of large power of 2. + """ + + def __init__( + self, + num_layers: int = 12, + hidden_size: int = 768, + num_attention_heads: int = 12, + num_kv_attention_heads: int = 12, + head_dim: int = None, + vocab_size: int = 50304, + mlp_ratio: int = 4, + attn_drop_rate: float = 0.0, + drop_rate: float = 0.0, + max_position_embeddings: int = 2048, + dtype: torch.dtype = torch.float, + checkpoint: float = 1.0, + layer_norm_epsilon: float = 1e-5, + first: bool = False, + last: bool = False, + embed_grad_scale: float = 0.1, + parallel_output: bool = True, + start_layer_idx: int = 0, + use_dynamic_ntk_rope: bool = False, + device: Optional[torch.device] = None, + apply_post_layer_norm=False, + no_bias=False, + residual_in_fp32: bool = False, + norm_type: str = "rmsnorm", + qk_interleaved: bool = False, + add_unit_offset: bool = False, + is_reward: bool = False, + dropout_selective_checkpoint: bool = True, + use_scaled_init: bool = True, + use_glu: bool = True, + use_swiglu: bool = False, + embedding_init_std: float = 0.02, + attn_wqkv_init_std: float = 0.02, + attn_other_init_std: float = 0.02, + ffn_uplayer_init_std: float = 0.02, + ffn_other_init_std: float = 0.02, + out_head_init_std: float = 0.02, + init_type: str = "normal", + extra_pred_tokens: int = 0, + rope_base: int = 10000, + norm_head: bool = False, + mlp_layer_fusion: bool = False, + multiple_of: int = 256, + ): + super().__init__() + + checkpoint_layer_num = int(num_layers * checkpoint) + self.hidden_size = hidden_size + self.embed_grad_scale = embed_grad_scale + self.parallel_output = parallel_output + self.tp_mode = "mtp" + if isinstance(gpc.config.parallel["tensor"], dict): + self.tp_mode = gpc.config.parallel["tensor"].get("mode", "mtp") + + if first: + self.tok_embeddings = Embedding1D(num_embeddings=vocab_size, embedding_dim=hidden_size) + for _, param in self.tok_embeddings.named_parameters(): + if init_type == "normal": + 
normal_(std=embedding_init_std)(param) + else: + uniform_(std=embedding_init_std)(param) + + self.layers = nn.ModuleList( + [ + GemmaDecoder( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_kv_attention_heads=num_kv_attention_heads, + head_dim=head_dim, + mlp_ratio=mlp_ratio, + attn_drop_rate=attn_drop_rate, + drop_rate=drop_rate, + max_position_embeddings=max_position_embeddings, + dtype=dtype, + layer_norm_epsilon=layer_norm_epsilon, + checkpoint=lid < checkpoint_layer_num, + layer_idx=lid + start_layer_idx, # This parameter is used for caching during generation + use_dynamic_ntk_rope=use_dynamic_ntk_rope, + residual_in_fp32=residual_in_fp32, + device=device, + apply_post_layer_norm=apply_post_layer_norm, + fused_dropout_add_ln=False, + no_bias=no_bias, + norm_type=norm_type, + add_unit_offset=add_unit_offset, + dropout_selective_checkpoint=dropout_selective_checkpoint, + use_scaled_init=use_scaled_init, + use_glu=use_glu, + use_swiglu=use_swiglu, + qk_interleaved=qk_interleaved, + attn_wqkv_init_std=attn_wqkv_init_std, + attn_other_init_std=attn_other_init_std, + ffn_uplayer_init_std=ffn_uplayer_init_std, + ffn_other_init_std=ffn_other_init_std, + init_type=init_type, + rope_base=rope_base, + mlp_layer_fusion=mlp_layer_fusion, + multiple_of=multiple_of, + tp_mode=self.tp_mode, + ) + for lid in range(num_layers) + ] + ) + + if last: + if not apply_post_layer_norm: + self.norm = new_layer_norm(norm_type, hidden_size, eps=layer_norm_epsilon, add_unit_offset=add_unit_offset) + + self.output = new_linear( + name="output", + in_features=hidden_size, + out_features=gpc.get_world_size(ParallelMode.TENSOR) if is_reward else vocab_size, + bias=False, + device=device, + is_reward=is_reward, + dtype=dtype, + weight_scale=embed_grad_scale, + norm_head=norm_head, + ) + for _, param in self.output.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + if extra_pred_tokens > 0: + self.extra_pred_tokens = extra_pred_tokens + assert not is_reward, "extra_pred_tokens > 0 means using multi token prediction, not implement for RLHF" + self.extra_outputs = nn.ModuleList( + [ + new_linear( + name="output", + in_features=hidden_size, + out_features=vocab_size, + bias=False, + device=device, + is_reward=is_reward, + dtype=dtype, + weight_scale=embed_grad_scale, + norm_head=norm_head, + ) + for _ in range(self.extra_pred_tokens) + ] + ) + for _, param in self.extra_outputs.named_parameters(): + if init_type == "normal": + normal_(std=out_head_init_std)(param) + else: + uniform_(std=out_head_init_std)(param) + + def forward(self, hidden_states=None, input_ids=None, **kwargs): + # attention_mask: compute attention on the places where the value is 1 + if hasattr(self, "tok_embeddings"): + hidden_states = self.tok_embeddings(input_ids) + if self.embed_grad_scale != 1: + hidden_states = ( + self.embed_grad_scale * hidden_states + (1 - self.embed_grad_scale) * hidden_states.detach() + ) + hidden_states = hidden_states * (self.hidden_size**0.5) + + for _, block in enumerate(self.layers): + hidden_states = block(hidden_states, residual=None, **kwargs) + + if hasattr(self, "norm"): + hidden_states = self.norm(hidden_states.to(self.norm.weight.dtype)) + if hasattr(self, "extra_pred_tokens") and self.extra_pred_tokens > 0: + extra_hidden_states_list = [self.extra_outputs[i](hidden_states) for i in range(self.extra_pred_tokens)] + else: + extra_hidden_states_list = None + if hasattr(self, "output"): + 
hidden_states = self.output(hidden_states) + + if extra_hidden_states_list is not None: + return (hidden_states, extra_hidden_states_list) + + return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + raise NotImplementedError + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + raise NotImplementedError \ No newline at end of file diff --git a/internlm/model/modeling_qwen2.py b/internlm/model/modeling_qwen2.py index 363962f2..6ac8127a 100644 --- a/internlm/model/modeling_qwen2.py +++ b/internlm/model/modeling_qwen2.py @@ -1,10 +1,14 @@ # Copyright (c) InternLM. All rights reserved. import math +import os from typing import Optional import torch +from einops import rearrange from torch import nn +from tqdm import tqdm +from internlm.accelerator import get_accelerator from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import global_context as gpc from internlm.initialize.initialize_tensor import ( @@ -13,6 +17,7 @@ scaled_init_method_uniform, uniform_, ) +from internlm.model.base_model import BaseModel from internlm.model.modules.embedding import Embedding1D from internlm.model.modules.linear import new_linear from internlm.model.modules.mha import SWA @@ -24,7 +29,14 @@ ) from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.utils.storage_manager import get_fns, llm_load, llm_save +from transformers.modeling_utils import ( + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + shard_checkpoint, +) +internlm_accelerator = get_accelerator() logger = get_logger(__file__) @@ -45,7 +57,6 @@ class Qwen2Decoder(nn.Module): residual_in_fp32 (bool): Whether to use residual in fp32. False by default. device (Optional[Union[str, torch.device]]): The device will be used. norm_type (str): Use RMS norm or layernorm."rmsnorm" by default. - use_flash_attn (bool): Whether use flash-attn. True by default. attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 0.02 by default, ffn_uplayer_init_std (float): std used to init w1, w2 weight in ffn when using glu @@ -79,11 +90,9 @@ def __init__( mlp_bias=False, norm_type: str = "rmsnorm", qk_interleaved: bool = False, - adapt_hf: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, ffn_uplayer_init_std: float = 0.02, @@ -104,7 +113,6 @@ def __init__( # dropout selective checkpoint can only be enabled when checkpoint is disabled. 
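Two details of Gemma.forward above are worth isolating: the GLM-130B-style embedding gradient scaling and Gemma's multiplication of the embedding output by sqrt(hidden_size). A standalone sketch of both, with the helper name chosen for illustration.

# The two embedding tweaks used in Gemma.forward above, in isolation.
import torch


def scale_embedding(hidden_states, hidden_size, embed_grad_scale=0.1):
    # Forward value is unchanged, but only `embed_grad_scale` of the gradient
    # flows back into the embedding table (the detached part carries no grad).
    if embed_grad_scale != 1:
        hidden_states = embed_grad_scale * hidden_states + (1 - embed_grad_scale) * hidden_states.detach()
    # Gemma scales embeddings by sqrt(hidden_size) before the first decoder layer.
    return hidden_states * (hidden_size ** 0.5)


emb = torch.randn(2, 8, 3072, requires_grad=True)
scale_embedding(emb, hidden_size=3072).sum().backward()
print(emb.grad.mean())  # ~0.1 * sqrt(3072) ~= 5.54, versus 1.0 for an unscaled pass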
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False self.layer_idx = layer_idx - self.use_flash_attn = use_flash_attn self.prenorm = not apply_post_layer_norm assert not fused_dropout_add_ln, "dropout_add_layer_norm can not be used here" self.fused_dropout_add_ln = fused_dropout_add_ln @@ -131,11 +139,9 @@ def __init__( use_dynamic_ntk_rope=use_dynamic_ntk_rope, rotary_emb_dim=head_dim, rotary_emb_scale_base=0, - use_flash_attn=use_flash_attn, device=device, dtype=dtype, qk_interleaved=qk_interleaved, - rot_embed_HF_impl=adapt_hf, qkv_bias=qkv_bias, o_bias=o_bias, rope_type=rope_type, @@ -160,6 +166,7 @@ def __init__( dtype=dtype, mlp_layer_fusion=mlp_layer_fusion, multiple_of=multiple_of, + activation_type="swiglu" if use_swiglu else "gelu", ) self.use_swiglu = use_swiglu @@ -212,7 +219,7 @@ def forward(self, hidden_states, residual=None, **kwargs): else: return self._forward(hidden_states, residual, **kwargs) - def _forward(self, hidden_states=None, residual=None, *args, **kwargs): # pylint: disable=W1113 + def _forward(self, hidden_states, residual, *args, **kwargs): r"""Pass the input through the encoder layer. Args: @@ -282,7 +289,7 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): return hidden_states -class Qwen2(nn.Module): +class Qwen2(BaseModel): """ 1D Packed Flash Qwen. @@ -299,15 +306,12 @@ class Qwen2(nn.Module): layer_norm_epsilon (float): A value added to the denominator for numerical stability. 1e-6 by default. first (bool): Whether input embedding layer or not. False by default. last (bool): Whether output embedding layer or not. False by default. - embed_split_hidden (bool): Split the embedding layer in the hidden state dimention or vocabulary dimention. - True by default. embed_grad_scale (float): Refer to GLM-130B, for training stability. 0.1 by default. parallel_output (bool): If it is necessary to collect the output of parallel computing. True by default. start_layer_idx (int): The index of start layer in the pipeline. 0 by default. device (Optional[Union[str, torch.device]]): The device will be used. None by default. residual_in_fp32 (bool): Whether to use residual in fp32. False by default. norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default. - use_flash_attn (bool): Whether to use flash-attn. True by default. embedding_init_std (float): std used to init embedding weight. 0.02 by default, attn_wqkv_init_std (float): std used to init attn_wqkv weight. 0.02 by default, attn_other_init_std (float): std used to init attn_other weight. 
0.02 by default, @@ -348,13 +352,11 @@ def __init__( mlp_bias=False, residual_in_fp32: bool = False, norm_type: str = "rmsnorm", - adapt_hf: bool = False, qk_interleaved: bool = False, is_reward: bool = False, dropout_selective_checkpoint: bool = True, use_scaled_init: bool = True, use_swiglu: bool = True, - use_flash_attn: bool = True, embedding_init_std: float = 0.02, attn_wqkv_init_std: float = 0.02, attn_other_init_std: float = 0.02, @@ -376,7 +378,6 @@ def __init__( ): super().__init__() - self.use_flash_attn = use_flash_attn self.embed_grad_scale = embed_grad_scale checkpoint_layer_num = int(num_layers * checkpoint) @@ -414,9 +415,7 @@ def __init__( dropout_selective_checkpoint=dropout_selective_checkpoint, use_scaled_init=use_scaled_init, use_swiglu=use_swiglu, - use_flash_attn=use_flash_attn, qk_interleaved=qk_interleaved, - adapt_hf=adapt_hf, attn_wqkv_init_std=attn_wqkv_init_std, attn_other_init_std=attn_other_init_std, ffn_uplayer_init_std=ffn_uplayer_init_std, @@ -513,3 +512,11 @@ def forward(self, hidden_states=None, input_ids=None, **kwargs): return (hidden_states, extra_hidden_states_list) return hidden_states + + @staticmethod + def load_hf_weights(folder: str, model: nn.Module) -> None: + raise NotImplementedError + + @staticmethod + def convert_internevo2hf_weights(src: str, tgt: str) -> None: + raise NotImplementedError \ No newline at end of file diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 9115bcc8..9b601a3e 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -355,6 +355,7 @@ def __init__( num_heads: int, num_kv_heads: int, max_position_embeddings: int = 2048, + head_dim: int = None, bias: bool = False, dropout: float = 0.0, softmax_scale: float = None, @@ -375,9 +376,15 @@ def __init__( self.embed_dim = embed_dim self.num_heads = num_heads + + if head_dim: + self.head_dim = head_dim + q_dim = head_dim * num_heads + else: + self.head_dim = self.embed_dim // num_heads + q_dim = embed_dim self.num_kv_heads = num_kv_heads self.q_per_kv = num_heads // num_kv_heads - self.head_dim = self.embed_dim // num_heads self.kv_dim = self.head_dim * num_kv_heads self.enable_qkv_fusion = enable_qkv_fusion @@ -405,7 +412,7 @@ def __init__( if enable_qkv_fusion: self.wqkv = new_linear("wqkv", embed_dim, embed_dim + 2 * self.kv_dim, bias, **factory_kwargs) else: - self.wq = new_linear("wq", embed_dim, embed_dim, bias, **factory_kwargs) + self.wq = new_linear("wq", embed_dim, q_dim, bias, **factory_kwargs) self.wk = new_linear("wk", embed_dim, self.kv_dim, bias, **factory_kwargs) self.wv = new_linear("wv", embed_dim, self.kv_dim, bias, **factory_kwargs) @@ -416,7 +423,7 @@ def __init__( causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout, layer_idx=layer_idx ) - self.wo = new_linear("wo", embed_dim, embed_dim, bias, **factory_kwargs) + self.wo = new_linear("wo", q_dim, embed_dim, bias, **factory_kwargs) def register_checkpoint_compatibility_hooks( self, pre_load_hook: Optional[Callable] = None, pre_save_hook: Optional[Callable] = None @@ -649,11 +656,8 @@ class SWA(nn.Module): rotary_emb_dim (int): The dimention of Rotary Embedding. 0 by default. rotary_emb_scale_base (int): The scaling factor of Rotary Embedding. If scale_base > 0, this implements XPos(Sun et al., https://arxiv.org/abs/2212.10554). 0 by default. - use_flash_attn (boolean): Whether to use flash attention or not.If False, vanilla attention module will be used. - False by default. 
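The GQA change above decouples the query/output projection width from embed_dim once head_dim is passed explicitly, which is what Gemma 7B needs (hidden_size=3072, but 16 heads of 256 dims each). A quick, standalone shape check under that config.

# Projection shapes for Gemma 7B after the GQA head_dim change above.
embed_dim = 3072
num_heads = 16
num_kv_heads = 16
head_dim = 256

q_dim = head_dim * num_heads        # 4096, no longer tied to embed_dim
kv_dim = head_dim * num_kv_heads    # 4096

# wq: embed_dim -> q_dim, wk/wv: embed_dim -> kv_dim, wo: q_dim -> embed_dim
print((embed_dim, q_dim), (embed_dim, kv_dim), (q_dim, embed_dim))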
device (Optional[Union[str, torch.device]]): The device will be used. dtype (Optional[torch.dtype]): The type of data. - use_flash_attn (bool): Whether to use flash-attn. True by default. rope_base (int): The value of `base` for rotary position embeddings. 10000 by default. tp_mode (str): The string value of tensor parallel mode, should be in ["mtp", "msp", "fsp", "isp"], "mtp" by default. @@ -678,10 +682,8 @@ def __init__( rope_scaling_factor: float = 1.0, rotary_emb_dim: int = 0, rotary_emb_scale_base: int = 0, - use_flash_attn: bool = True, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, - rot_embed_HF_impl: Optional[bool] = False, use_sliding_window: bool = False, sliding_window: int = None, tp_mode: str = "mtp", @@ -704,7 +706,6 @@ def __init__( self.layer_idx = layer_idx self.use_dynamic_ntk_rope = use_dynamic_ntk_rope self.rotary_emb_dim = rotary_emb_dim - self.use_flash_attn = use_flash_attn self.use_sliding_window = use_sliding_window self.sliding_window = sliding_window self.dtype = dtype @@ -712,7 +713,6 @@ def __init__( self.rope_type = rope_type self.use_logn_attn = use_logn_attn self.interleaved = qk_interleaved - self.rot_embed_HF_impl = rot_embed_HF_impl factory_kwargs = {"device": device, "dtype": dtype} @@ -780,13 +780,6 @@ def _training(self, x, **kwargs): k = rearrange(k, "b t (h d) -> b t h d", d=self.head_dim) v = rearrange(v, "b t (h d) -> b t h d", d=self.head_dim) - # qkv shift - # the rotary embedding in flash attention module in performed by separating the front and back parts, while - # most of others are done by odd-even methods. - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - kv_seq_len = k.size(0) use_window_circumstance = ( _flash_supports_window_size @@ -863,10 +856,6 @@ def _inference(self, x, inference_params=None, **kwargs): # pylint: disable=W06 k = rearrange(k, "b s (h d) -> b s h d", d=self.head_dim) v = rearrange(v, "b s (h d) -> b s h d", d=self.head_dim) - if not self.rot_embed_HF_impl: - q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1) - k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1) - kv_seq_len = k.size(0) use_window_circumstance = ( _flash_supports_window_size diff --git a/internlm/model/modules/mlp.py b/internlm/model/modules/mlp.py index 897e1363..dfdbc5bf 100644 --- a/internlm/model/modules/mlp.py +++ b/internlm/model/modules/mlp.py @@ -3,11 +3,12 @@ from typing import Dict, Optional +from internlm.utils.utils import ActivationType import torch from torch import nn from internlm.model.modules.linear import new_linear -from internlm.model.modules.utils import Silu +from internlm.model.modules.utils import Gelu, Silu from internlm.utils.logger import get_logger logger = get_logger(__file__) @@ -71,10 +72,10 @@ def __init__( ): super().__init__() - # TODO: support gelu... 
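The deleted "qkv shift" above re-ordered odd/even interleaved rotary dimensions into the front/back-half layout that the flash-attn rotary path expects, as the removed comment explains; with rot_embed_HF_impl gone, that layout choice is presumably carried by the existing qk_interleaved flag instead. A small illustration of what the removed re-ordering did.

# What the removed "qkv shift" did: interleaved (q1, q2, q3, ...) -> front/back halves.
import torch

q = torch.arange(1, 9).view(1, 1, 1, 8)            # [q1, q2, ..., q8] in the last dim
shifted = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
print(q.flatten().tolist())        # [1, 2, 3, 4, 5, 6, 7, 8]
print(shifted.flatten().tolist())  # [1, 3, 5, 7, 2, 4, 6, 8]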
- assert activation_type in ("swiglu"), f"Unsupported activation type: {activation_type}" + assert activation_type in (ActivationType.swiglu.name, ActivationType.gelu.name), f"Unsupported activation type: {activation_type}" self.mlp_layer_fusion = mlp_layer_fusion + self.activation_type = activation_type hidden_features = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) @@ -98,7 +99,12 @@ def forward(self, x): else: fussed_out = self.fused_w1_w3(x) w1_o, w3_o = torch.split(fussed_out, fussed_out.shape[-1] // 2, dim=-1) - out = self.w2(Silu(w1_o, w3_o)) + + if self.activation_type is ActivationType.swiglu.name: + out = self.w2(Silu(w1_o, w3_o)) + else: + out = self.w2(Gelu(w1_o, w3_o)) + return out diff --git a/internlm/model/modules/norm.py b/internlm/model/modules/norm.py index b94cdd43..70717ff9 100644 --- a/internlm/model/modules/norm.py +++ b/internlm/model/modules/norm.py @@ -2,6 +2,7 @@ layer norm modules """ +import inspect from typing import List, Union import torch @@ -12,8 +13,12 @@ Shape = Union[int, List[int], torch.Size] -def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5): +def new_layer_norm(norm_type: str, normalized_shape: Shape, eps: float = 1e-5, add_unit_offset=False): if norm_type == "rmsnorm": - return RMSNorm(normalized_shape, eps) + rmsnorm_params = inspect.signature(RMSNorm).parameters + if 'add_unit_offset' in rmsnorm_params: + return RMSNorm(normalized_shape, eps, add_unit_offset) + else: + return RMSNorm(normalized_shape, eps) else: # default: layernorm return nn.LayerNorm(normalized_shape, eps) diff --git a/internlm/model/modules/utils.py b/internlm/model/modules/utils.py index dd86cb1c..bf1ae048 100644 --- a/internlm/model/modules/utils.py +++ b/internlm/model/modules/utils.py @@ -20,7 +20,12 @@ def Silu(w1_o, w2_o): return F.silu(w1_o) * w2_o +def Gelu(w1_o, w2_o): + return F.gelu(w1_o) * w2_o + + Silu = torch.jit.script(Silu) +Gelu = torch.jit.script(Gelu) def update_kv_cache(kv, inference_params, layer_idx): diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 3cd43dab..34e7c007 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -35,7 +35,7 @@ torchnpu_rmsnorm_impl = False -def manual_rms_norm(my_input, weight, normalized_shape, eps): +def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=False): # layer norm should always be calculated in float32 dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1)) variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True) @@ -48,13 +48,16 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps): if weight.dtype in [torch.float16, torch.bfloat16]: my_input = my_input.to(weight.dtype) - return weight * my_input + if add_unit_offset: + return (1 + weight) * my_input + else: + return weight * my_input class _RMSNorm(torch.nn.Module): """A generic module for RMS normalization.""" - def __init__(self, normalized_shape, eps=1e-5): + def __init__(self, normalized_shape, eps=1e-5, add_unit_offset=False): super().__init__() if isinstance(normalized_shape, numbers.Integral): @@ -62,18 +65,22 @@ def __init__(self, normalized_shape, eps=1e-5): self.normalized_shape = torch.Size(normalized_shape) self.eps = eps self.weight = Parameter(torch.empty(*normalized_shape)) + self.add_unit_offset = add_unit_offset self.reset_parameters() def forward(self, _input: torch.Tensor): if apex_rmsnorm_impl: _norm_func = mixed_dtype_fused_rms_norm_affine + return _norm_func(_input, self.weight, 
self.normalized_shape, self.eps) else: _norm_func = manual_rms_norm - - return _norm_func(_input, self.weight, self.normalized_shape, self.eps) + return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset) def reset_parameters(self): - init.ones_(self.weight) + if self.add_unit_offset: + init.zeros_(self.weight) + else: + init.ones_(self.weight) def extra_repr(self): return f"{self.normalized_shape}, eps={self.eps}, " diff --git a/internlm/model/registry.py b/internlm/model/registry.py index 4cecfb5a..c923ec20 100644 --- a/internlm/model/registry.py +++ b/internlm/model/registry.py @@ -4,6 +4,7 @@ from typing import Callable from internlm.model.modeling_baichuan2 import Baichuan2 +from internlm.model.modeling_gemma import Gemma from internlm.model.modeling_internlm import InternLM1 from internlm.model.modeling_internlm2 import InternLM2 from internlm.model.modeling_llama import Llama2 @@ -87,6 +88,7 @@ def register_model_initializer() -> None: model_initializer.register_module(ModelType.LLAVA.name, Llava) model_initializer.register_module(ModelType.QWEN2.name, Qwen2) model_initializer.register_module(ModelType.BAICHUAN2.name, Baichuan2) + model_initializer.register_module(ModelType.GEMMA.name, Gemma) register_model_initializer() diff --git a/internlm/utils/utils.py b/internlm/utils/utils.py index 4150ec60..ac2b62ae 100644 --- a/internlm/utils/utils.py +++ b/internlm/utils/utils.py @@ -49,6 +49,7 @@ class ModelType(Enum): LLAVA = 5 QWEN2 = 6 BAICHUAN2 = 7 + GEMMA = 8 class DataType(Enum): @@ -63,6 +64,10 @@ class TensorParallelMode(Enum): isp = 4 +class ActivationType(Enum): + swiglu = 1 + gelu = 2 + def check_attention_argument(*args, **kwargs) -> str: # self, qkv, ... # self, q, kv, ....
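Taken together, the norm and MLP changes in this diff cover the two numeric conventions Gemma relies on: an RMSNorm whose weight stores an offset from 1 (so zero-initialised weights start as identity scaling), and a gated GELU feed-forward alongside SwiGLU. The sketch below condenses both for reference; it is simplified relative to the Apex/fused paths and the new_linear plumbing used above.

# Condensed view of the Gemma-specific numeric changes in this diff.
import torch
import torch.nn.functional as F


def rms_norm(x, weight, eps=1e-6, add_unit_offset=True):
    variance = x.float().pow(2).mean(-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # With add_unit_offset the learned weight is an offset from 1 and starts at zero.
    return ((1 + weight) * x if add_unit_offset else weight * x).to(weight.dtype)


def gated_ffn(x, w1, w2, w3, activation_type="gelu"):
    gate = F.gelu(w1(x)) if activation_type == "gelu" else F.silu(w1(x))
    return w2(gate * w3(x))   # GeGLU for Gemma, SwiGLU otherwise


h = torch.randn(2, 4, 3072)
w = torch.zeros(3072)                           # zero-init == unit scaling once the offset is added
print(rms_norm(h, w).pow(2).mean(-1).sqrt())    # ~1.0 everywhere: pure normalisation

w1, w2, w3 = (torch.nn.Linear(8, 32, bias=False), torch.nn.Linear(32, 8, bias=False),
              torch.nn.Linear(8, 32, bias=False))
print(gated_ffn(torch.randn(2, 8), w1, w2, w3).shape)   # torch.Size([2, 8])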