compiled RMSNorm
ghstack-source-id: c4efb81ec6acc5442955908cc376df3e6d889af3
Pull Request resolved: #442
tianyu-l committed Jul 10, 2024
1 parent bc3ec02 commit 7afe902
Showing 10 changed files with 29 additions and 14 deletions.
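
Summary: this commit adds a compiled_rmsnorm norm type. The eager RMSNorm module gains a compile flag that, when set, wraps the norm computation in torch.compile(..., fullgraph=True). The debug model config switches to the new variant, the llama2 7B and 13B configs move from fused_rmsnorm to plain rmsnorm, and the memory estimator (and its test) falls back to rmsnorm, which it supports.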
6 changes: 5 additions & 1 deletion estimation.py
@@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig):
         )
         job_config.model.norm_type = "rmsnorm"
 
+    if job_config.model.norm_type == "compiled_rmsnorm":
+        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
+        job_config.model.norm_type = "rmsnorm"
+
     if job_config.training.compile:
-        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False
 
     parallel_dims = ParallelDims(
2 changes: 1 addition & 1 deletion test_runner.py
@@ -266,7 +266,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--memory_estimation.enabled",
+                    "--memory_estimation.enabled --model.norm_type rmsnorm",
                ]
            ],
            "FSDP2 Memory Tracking and Estimation",
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
@@ -165,7 +165,7 @@ def __init__(self):
             "--model.norm_type",
             type=str,
             default="rmsnorm",
-            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
+            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
         )
         self.parser.add_argument(
             "--model.tokenizer_path",
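
Note: with the new choice registered, selecting it takes a single flag (or the norm_type key in a TOML config). A minimal sketch of how the selection maps to norm construction; the dim value here is illustrative, not from this commit:

from torchtitan.models.norms import create_norm

# Passing "--model.norm_type compiled_rmsnorm" on the CLI (or setting norm_type
# in the TOML config) selects the torch.compile-wrapped variant.
norm = create_norm("compiled_rmsnorm", dim=4096)  # returns RMSNorm(..., compile=True)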
21 changes: 16 additions & 5 deletions torchtitan/models/norms.py
@@ -42,6 +42,8 @@ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
         return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
     elif norm_type == "rmsnorm":
         return RMSNorm(dim, eps=eps)
+    elif norm_type == "compiled_rmsnorm":
+        return RMSNorm(dim, eps=eps, compile=True)
     elif norm_type == "fused_rmsnorm":
         return FusedRMSNorm(dim, eps=eps)
     else:
@@ -87,17 +89,26 @@ class RMSNorm(nn.Module):
     """
 
-    def __init__(self, dim: int, eps: float = 1e-6):
+    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
+        self.rmsnorm_fn = (
+            torch.compile(self.compute_rmsnorm, fullgraph=True)
+            if compile
+            else self.compute_rmsnorm
+        )
 
-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+    @staticmethod
+    def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
+        def _norm(x, eps):
+            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
+
+        output = _norm(x.float(), eps).type_as(x)
+        return output * weight
 
     def forward(self, x: torch.Tensor):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return self.rmsnorm_fn(x, self.weight, self.eps)
 
     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)  # type: ignore
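
Note: the compiled path can be exercised standalone. The following sketch restates the pattern from this diff as a pure function and checks parity with eager execution; it assumes PyTorch 2.x and uses arbitrary example shapes:

import torch

def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
    # A pure function of tensors and a float, so torch.compile can capture it fullgraph.
    def _norm(x, eps):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    return _norm(x.float(), eps).type_as(x) * weight

compiled_rmsnorm = torch.compile(compute_rmsnorm, fullgraph=True)

x = torch.randn(2, 8, 256)  # (batch, seq, dim); sizes are illustrative
w = torch.ones(256)
torch.testing.assert_close(compiled_rmsnorm(x, w, 1e-6), compute_rmsnorm(x, w, 1e-6))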
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
@@ -21,7 +21,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "debugmodel"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "compiled_rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./test/assets/test_tiktoken.model"
 
2 changes: 1 addition & 1 deletion train_configs/llama2_13b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "13B"
-norm_type = "fused_rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
 
 [optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama2_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "70B"
-norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
 
 [optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama2_7b.toml
@@ -17,7 +17,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "7B"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"
 
 [optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama3_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "70B"
-norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]
2 changes: 1 addition & 1 deletion train_configs/llama3_8b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "8B"
-norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"
 
 [optimizer]
