Skip to content

Commit

Permalink
Add warning to compile rmsnorm
Browse files Browse the repository at this point in the history
As titled, add a warning to the compiled_rmsnorm option, as it is not fully ready yet
(see issue #497).

We can remove this warning once issue #497 is fixed.
  • Loading branch information
wanchaol committed Aug 6, 2024
1 parent a4d88d1 commit da18116
Show file tree
Hide file tree
Showing 8 changed files with 11 additions and 7 deletions.
4 changes: 4 additions & 0 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
import warnings
warnings.warn(
"compiled_rmsnorm is currently experimental and not ready to use yet."
)
return RMSNorm(dim, eps=eps, compile=True)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps)
Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "debugmodel"
norm_type = "compiled_rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"

Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "13B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "70B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ save_tb_folder = "tb"
[model]
name = "llama2"
flavor = "7B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

[optimizer]
Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama3_405b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "405B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "70B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
Expand Down
2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "8B"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

[optimizer]
Expand Down

0 comments on commit da18116

Please sign in to comment.