add compatibility for non-transformer based huggingface models
zigzagcai committed Sep 29, 2024
1 parent d2f5a61 commit 88595a8
Showing 1 changed file with 12 additions and 9 deletions.
internlm/train/pipeline.py (21 changes: 12 additions & 9 deletions)
@@ -891,18 +891,21 @@ def traverse(module):
 
 
 def inject_config(model: nn.Module) -> None:
     # Compatibility for Vision-Language Model
     if hasattr(model.config, "text_config"):
-        model_config = model.config.text_config
+        llm_cfg = model.config.text_config
     else:
-        model_config = model.config
-    gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = model_config.vocab_size
-    gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = model_config.hidden_size
-    gpc.config.model.num_layers = gpc.config.NUM_LAYER = model_config.num_hidden_layers
-    gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = model_config.num_attention_heads
-    gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = model_config.intermediate_size / model_config.hidden_size
+        llm_cfg = model.config
+    gpc.config.model.vocab_size = gpc.config.VOCAB_SIZE = llm_cfg.vocab_size
+    gpc.config.model.hidden_size = gpc.config.HIDDEN_SIZE = llm_cfg.hidden_size
+    gpc.config.model.num_layers = gpc.config.NUM_LAYER = llm_cfg.num_hidden_layers
+    # Compatibility for Mamba
+    if hasattr(llm_cfg, "num_attention_heads"):
+        gpc.config.model.num_attention_heads = gpc.config.NUM_ATTENTION_HEAD = llm_cfg.num_attention_heads
+    if hasattr(llm_cfg, "intermediate_size"):
+        gpc.config.model.mlp_ratio = gpc.config.MLP_RATIO = llm_cfg.intermediate_size / llm_cfg.hidden_size
     # For models that use GQA
-    if hasattr(model_config, "num_key_value_heads"):
-        gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = model_config.num_key_value_heads
+    if hasattr(llm_cfg, "num_key_value_heads"):
+        gpc.config.model.num_kv_attention_heads = gpc.config.NUM_KV_ATTENTION_HEAD = llm_cfg.num_key_value_heads
 
 
 def inject_model_helper(model: Union[nn.Module, nn.ModuleList], inject_info: Optional[Dict] = None) -> None:
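
To see what the new guards buy, here is a minimal sketch run against stock HuggingFace configs; it needs only transformers (a version that ships MambaConfig and LlavaConfig) and does not touch InternEvo's gpc. The helper name extract_llm_cfg is hypothetical, introduced only to mirror the text_config branch at the top of inject_config.

from transformers import LlamaConfig, LlavaConfig, MambaConfig


def extract_llm_cfg(config):
    # Hypothetical stand-in for the text_config branch in inject_config:
    # Vision-Language configs such as LlavaConfig nest the LLM settings.
    return config.text_config if hasattr(config, "text_config") else config


for cfg in (LlamaConfig(), LlavaConfig(), MambaConfig()):
    llm_cfg = extract_llm_cfg(cfg)
    # These three fields exist on every config handled here.
    fields = {
        "vocab_size": llm_cfg.vocab_size,
        "hidden_size": llm_cfg.hidden_size,
        "num_layers": llm_cfg.num_hidden_layers,
    }
    # Attention-specific fields are absent on Mamba, hence the hasattr guards.
    if hasattr(llm_cfg, "num_attention_heads"):
        fields["num_attention_heads"] = llm_cfg.num_attention_heads
    if hasattr(llm_cfg, "intermediate_size"):
        fields["mlp_ratio"] = llm_cfg.intermediate_size / llm_cfg.hidden_size
    if hasattr(llm_cfg, "num_key_value_heads"):
        fields["num_kv_attention_heads"] = llm_cfg.num_key_value_heads
    print(type(cfg).__name__, fields)

Running this shows why the change matters: MambaConfig exposes no num_attention_heads or num_key_value_heads, so the unguarded reads in the old code would raise AttributeError for Mamba-style checkpoints, while duck-typing with hasattr handles plain transformer, GQA, and SSM configs in one code path.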
