Skip to content

Commit

Permalink
remove compiled_rmsnorm
Browse files Browse the repository at this point in the history
ghstack-source-id: ceb4fa54121be241633daf06a0ca2eb407667274
Pull Request resolved: #535
  • Loading branch information
tianyu-l committed Aug 20, 2024
1 parent b76d755 commit 40210ea
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 128 deletions.
200 changes: 96 additions & 104 deletions test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,103 @@ def build_test_list():
"""
integration_tests_flavors = defaultdict(list)
integration_tests_flavors["debug_model.toml"] = [
OverrideDefinitions(
[
[],
],
"default",
"default",
),
OverrideDefinitions(
[
[
"--training.compile",
],
],
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
],
],
"2D eager",
"2d_eager",
),
OverrideDefinitions(
[
[
"--training.compile",
"--training.tensor_parallel_degree 2",
],
],
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2",
"--model.norm_type=fused_rmsnorm",
],
],
"2D eager with fused_rmsnorm",
"2d_eager_fused_rmsnorm",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
],
[
"--checkpoint.enable_checkpoint",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
"model_weights_only_bf16",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule flexible_interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped flexible 1f1b test",
Expand All @@ -69,7 +158,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test 1f1b",
Expand All @@ -85,7 +173,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 1",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP 1D test gpipe",
Expand All @@ -101,7 +188,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule 1f1b",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP+DP 1f1b 2D test",
Expand All @@ -116,7 +202,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_schedule gpipe",
"--training.data_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm crashes with PP
],
],
"PP+DP gpipe 2D test",
Expand All @@ -130,7 +215,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+TP 2D test",
Expand All @@ -144,102 +228,13 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 2",
"--experimental.pipeline_parallel_split_points layers.4",
"--experimental.pipeline_parallel_split_mode tracer",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with tracer
],
],
"PP tracer frontend test",
"pp_tracer",
requires_seed_checkpoint=True,
ngpu=2,
),
OverrideDefinitions(
[
[],
],
"default",
"default",
),
OverrideDefinitions(
[
[
"--training.compile --model.norm_type=rmsnorm",
],
],
"1D compile",
"1d_compile",
),
OverrideDefinitions(
[
[
"--training.compile",
"--activation_checkpoint.mode selective",
"--activation_checkpoint.selective_ac_option op",
],
],
"1D compile with selective op AC",
"1d_compile_sac_op",
),
OverrideDefinitions(
[
[
"--training.compile --training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
],
],
"2D compile",
"2d_compile",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=rmsnorm",
],
],
"Eager mode 2DParallel with rmsnorm",
"eager_2d_rmsnorm",
),
OverrideDefinitions(
[
[
"--training.tensor_parallel_degree 2 --model.norm_type=fused_rmsnorm",
],
],
"Eager mode 2DParallel with fused_rmsnorm",
"eager_2d_fused_rmsnorm",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
],
[
"--checkpoint.enable_checkpoint",
"--training.steps 20",
],
],
"Checkpoint Integration Test - Save Load Full Checkpoint",
"full_checkpoint",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
],
],
"Checkpoint Integration Test - Save Model Weights Only fp32",
"model_weights_only_fp32",
),
OverrideDefinitions(
[
[
"--checkpoint.enable_checkpoint",
"--checkpoint.model_weights_only",
"--checkpoint.export_dtype bfloat16",
],
],
"Checkpoint Integration Test - Save Model Weights Only bf16",
"model_weights_only_bf16",
),
OverrideDefinitions(
[
[
Expand All @@ -248,7 +243,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
[
"--training.steps 20",
Expand All @@ -257,7 +251,6 @@ def build_test_list():
"--experimental.pipeline_parallel_split_points layers.4",
"--training.data_parallel_degree 2",
"--training.tensor_parallel_degree 2",
"--model.norm_type rmsnorm", # fused_rmsnorm not yet compatible with TP
],
],
"PP+DP+TP 3D test with save/load resume ckpt",
Expand All @@ -272,7 +265,6 @@ def build_test_list():
"--experimental.pipeline_parallel_degree 4",
"--experimental.pipeline_parallel_split_points layers.1,layers.2,layers.3,layers.4,layers.5,layers.6,layers.7",
"--experimental.pipeline_parallel_schedule interleaved_1f1b",
"--model.norm_type rmsnorm", # fused_rmsnorm throws cuda context error with pp
],
],
"PP looped 1f1b test",
Expand All @@ -292,21 +284,21 @@ def build_test_list():
OverrideDefinitions(
[
[
"--memory_estimation.enabled --model.norm_type rmsnorm",
"--training.data_parallel_type ddp",
]
],
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
"DDP",
"ddp",
ngpu=4,
),
OverrideDefinitions(
[
[
"--training.data_parallel_type ddp",
"--memory_estimation.enabled",
]
],
"DDP",
"ddp",
"FSDP2 Memory Tracking and Estimation",
"fsdp2_mem_tracker",
ngpu=4,
),
]
Expand Down
2 changes: 1 addition & 1 deletion torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def __init__(self):
"--model.norm_type",
type=str,
default="rmsnorm",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, compiled_rmsnorm, fused_rmsnorm]",
help="Type of layer normalization to use [layernorm, np_layernorm, rmsnorm, fused_rmsnorm]",
)
self.parser.add_argument(
"--model.tokenizer_path",
Expand Down
28 changes: 6 additions & 22 deletions torchtitan/models/norms.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
Args:
norm_type (str): The type of normalization layer to build.
Supported types: 1. rmsnorm 2. fused_rmsnorm 3. layernorm 4. np_layernorm
Supported types: layernorm, np_layernorm, rmsnorm, fused_rmsnorm
dim (int): The dimension of the normalization layer.
eps (float, optional): The epsilon value for numerical stability. Defaults to 1e-6.
Expand All @@ -42,13 +42,6 @@ def build_norm(norm_type: str, dim: int, eps: float = 1e-6):
return nn.LayerNorm(dim, eps=eps, elementwise_affine=False, bias=False)
elif norm_type == "rmsnorm":
return RMSNorm(dim, eps=eps)
elif norm_type == "compiled_rmsnorm":
import warnings

warnings.warn(
"compiled_rmsnorm is currently experimental and not ready to use yet."
)
return RMSNorm(dim, eps=eps, compile=True)
elif norm_type == "fused_rmsnorm":
return FusedRMSNorm(dim, eps=eps)
else:
Expand Down Expand Up @@ -94,26 +87,17 @@ class RMSNorm(nn.Module):
"""

def __init__(self, dim: int, eps: float = 1e-6, compile: bool = False):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.ones(dim))
self.rmsnorm_fn = (
torch.compile(self.compute_rmsnorm, fullgraph=True)
if compile
else self.compute_rmsnorm
)

@staticmethod
def compute_rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float):
def _norm(x, eps):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

output = _norm(x.float(), eps).type_as(x)
return output * weight
def _norm(self, x: torch.Tensor):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x: torch.Tensor):
return self.rmsnorm_fn(x, self.weight, self.eps)
output = self._norm(x.float()).type_as(x)
return output * self.weight

def reset_parameters(self):
torch.nn.init.ones_(self.weight) # type: ignore
Expand Down
2 changes: 1 addition & 1 deletion train_configs/debug_model.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ save_tb_folder = "tb"
[model]
name = "llama3"
flavor = "debugmodel"
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / compiled_rmsnorm / fused_rmsnorm
norm_type = "rmsnorm" # layernorm / np_layernorm / rmsnorm / fused_rmsnorm
# test tokenizer.model, for debug purpose only
tokenizer_path = "./test/assets/test_tiktoken.model"

Expand Down

0 comments on commit 40210ea

Please sign in to comment.