diff --git a/estimation.py b/estimation.py
index e82a7b71..ddf24d8a 100644
--- a/estimation.py
+++ b/estimation.py
@@ -57,8 +57,12 @@ def estimate_memory(job_config: JobConfig):
         )
         job_config.model.norm_type = "rmsnorm"

+    if job_config.model.norm_type == "compiled_rmsnorm":
+        logger.info("Compiled RMSNorm is not supported yet. Switching to RMSNorm.")
+        job_config.model.norm_type = "rmsnorm"
+
     if job_config.training.compile:
-        logger.info("Compile mode is not supported yet. " "Switching to Eager mode.")
+        logger.info("Compile mode is not supported yet. Switching to eager mode.")
         job_config.training.compile = False

     parallel_dims = ParallelDims(
diff --git a/test_runner.py b/test_runner.py
index cba63544..319f99d7 100755
--- a/test_runner.py
+++ b/test_runner.py
@@ -266,7 +266,7 @@ def build_test_list():
         OverrideDefinitions(
             [
                 [
-                    "--memory_estimation.enabled",
+                    "--memory_estimation.enabled --model.norm_type rmsnorm",
                 ]
             ],
             "FSDP2 Memory Tracking and Estimation",
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index bfea8f83..3ade1b9d 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -165,7 +165,7 @@ def __init__(self):
             "--model.norm_type",
             type=str,
             default="rmsnorm",
-            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
+            help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
         )
         self.parser.add_argument(
             "--model.tokenizer_path",
diff --git a/torchtitan/models/norms.py b/torchtitan/models/norms.py
index 4245fe41..10a6b853 100644
--- a/torchtitan/models/norms.py
+++ b/torchtitan/models/norms.py
@@ -42,6 +42,8 @@ def create_norm(norm_type: str, dim: int, eps: float = 1e-6):
         return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
     elif norm_type == "rmsnorm":
         return RMSNorm(dim, eps=eps)
+    elif norm_type == "compiled_rmsnorm":
+        return RMSNorm(dim, eps=eps, compile=True)
     elif norm_type == "fused_rmsnorm":
         return FusedRMSNorm(dim, eps=eps)
     else:
@@ -87,17 +89,26 @@ class RMSNorm(nn.Module):

     """

-    def __init__(self, dim: int, eps: float = 1e-6):
+    def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
         super().__init__()
         self.eps = eps
         self.weight = nn.Parameter(torch.ones(dim))
+        self.rmsnorm_fn = (
+            torch.compile(self.compute_rmsnorm, fullgraph=True)
+            if compile
+            else self.compute_rmsnorm
+        )
+
+    @staticmethod
+    def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
+        def _norm(x, eps):
+            return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

-    def _norm(self, x: torch.Tensor):
-        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+        output = _norm(x.float(), eps).type_as(x)
+        return output * weight

     def forward(self, x: torch.Tensor):
-        output = self._norm(x.float()).type_as(x)
-        return output * self.weight
+        return self.rmsnorm_fn(x, self.weight, self.eps)

     def reset_parameters(self):
         torch.nn.init.ones_(self.weight)  # type: ignore
diff --git a/train_configs/debug_model.toml b/train_configs/debug_model.toml
index 1b4b3539..cb2fb215 100644
--- a/train_configs/debug_model.toml
+++ b/train_configs/debug_model.toml
@@ -21,7 +21,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "debugmodel"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "compiled_rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 # test tokenizer.model, for debug purpose only
 tokenizer_path = "./test/assets/test_tiktoken.model"
diff --git a/train_configs/llama2_13b.toml b/train_configs/llama2_13b.toml
index 719fc445..05e3c27b 100644
--- a/train_configs/llama2_13b.toml
+++ b/train_configs/llama2_13b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "13B"
-norm_type = "fused_rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

 [optimizer]
diff --git a/train_configs/llama2_70b.toml b/train_configs/llama2_70b.toml
index c8ec9595..5b2dd493 100644
--- a/train_configs/llama2_70b.toml
+++ b/train_configs/llama2_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "70B"
-norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

 [optimizer]
diff --git a/train_configs/llama2_7b.toml b/train_configs/llama2_7b.toml
index 7e2196fb..9b72246a 100644
--- a/train_configs/llama2_7b.toml
+++ b/train_configs/llama2_7b.toml
@@ -17,7 +17,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama2"
 flavor = "7B"
-norm_type = "fused_rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/tokenizer.model"

 [optimizer]
diff --git a/train_configs/llama3_70b.toml b/train_configs/llama3_70b.toml
index 218f3783..93b529f6 100644
--- a/train_configs/llama3_70b.toml
+++ b/train_configs/llama3_70b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "70B"
-norm_type = "rmsnorm" # [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

 [optimizer]
diff --git a/train_configs/llama3_8b.toml b/train_configs/llama3_8b.toml
index 2fb89004..95a53d56 100644
--- a/train_configs/llama3_8b.toml
+++ b/train_configs/llama3_8b.toml
@@ -18,7 +18,7 @@ save_tb_folder = "tb"
 [model]
 name = "llama3"
 flavor = "8B"
-norm_type = "rmsnorm" # [layernorm / np_layernorm / rmsnorm / fused_rmsnorm]
+norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
 tokenizer_path = "./torchtitan/datasets/tokenizer/original/tokenizer.model"

 [optimizer]
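For reference, a minimal sketch of how the new `compiled_rmsnorm` option added by this patch could be exercised through `create_norm`, assuming a PyTorch 2.x environment where `torch.compile` is available and `torchtitan` is importable; the tolerances and tensor shapes below are illustrative only, not part of the patch.

```python
# Sketch: build the eager and compiled RMSNorm variants via create_norm and
# check they agree numerically. Assumes PyTorch 2.x (torch.compile) and an
# importable torchtitan checkout; shapes/tolerances are illustrative.
import torch

from torchtitan.models.norms import create_norm

dim = 256
x = torch.randn(2, 8, dim)

eager_norm = create_norm("rmsnorm", dim=dim)               # RMSNorm(dim, compile=False)
compiled_norm = create_norm("compiled_rmsnorm", dim=dim)   # RMSNorm(dim, compile=True)

# Both variants call the same compute_rmsnorm staticmethod, so the compiled
# path should match the eager path up to compilation-level numerical noise.
torch.testing.assert_close(compiled_norm(x), eager_norm(x), rtol=1e-5, atol=1e-5)
print("compiled_rmsnorm matches rmsnorm")
```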