Commit

fix(910B): fix bugs in 910B for varlen and fixlen FA (#309)
li126com authored Aug 30, 2024
1 parent 39c23ff commit beb391a
Showing 7 changed files with 100 additions and 30 deletions.
5 changes: 3 additions & 2 deletions internlm/core/scheduler/no_pipeline_scheduler.py
@@ -129,14 +129,15 @@ def _train_one_batch(
loss = self._call_engine_criterion(engine, output, label)
self._call_hooks("after_criterion", loss)
moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled,
# so we need to do allreduce
if gpc.config.parallel.sequence_parallel:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= scale_loss
loss /= scale_loss
loss += moe_loss
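The scheduler change above swaps `dist.ReduceOp.AVG` for `dist.ReduceOp.SUM` followed by an explicit division by the tensor-parallel world size, presumably for the reason noted later in hybrid_zero_optim.py (some HCCL releases do not perform the averaging for the AVG op). A minimal standalone sketch of why the two forms are equivalent; the helper name `all_reduce_mean` and the `group` argument are illustrative, not repository code:

```python
# Minimal sketch: SUM-then-divide as a drop-in for ReduceOp.AVG.
# Assumes torch.distributed is already initialized; `group` is a placeholder.
import torch
import torch.distributed as dist


def all_reduce_mean(tensor: torch.Tensor, group=None) -> torch.Tensor:
    """Average `tensor` across ranks without relying on ReduceOp.AVG."""
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group=group))
    return tensor
```

Doing the division on the already-reduced tensor keeps the workaround local to the call site, so the surrounding loss-scaling logic stays unchanged.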
9 changes: 5 additions & 4 deletions internlm/core/scheduler/pipeline_scheduler.py
@@ -244,7 +244,7 @@ def _get_data_label_for_current_step(self, stage_output, micro_batch_data):
label = micro_batch_data.pop("label", None)
data = {"stage_output": stage_output, **micro_batch_data}

return data, label
return data, label # pylint: disable=E0606

def _call_hooks(self, func_name: str, *args, **kwargs) -> None:
for hook in self._hooks:
@@ -309,13 +309,14 @@ def _forward_step(
output_obj = loss_reduced

moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
)
# the moe_loss is computed among the "tensor" group if sequence parallel is enabled, so we need to do allreduce
if gpc.config.parallel.sequence_parallel:
dist.all_reduce(moe_loss, op=dist.ReduceOp.AVG, group=gpc.get_group(ParallelMode.TENSOR))
dist.all_reduce(moe_loss, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.TENSOR))
moe_loss.div_(gpc.get_world_size(ParallelMode.TENSOR))
moe_loss /= self.num_microbatches
accum_moe_loss.add_(moe_loss.detach())

@@ -866,7 +867,7 @@ def _forward_step(self, engine, chunk_id):
output_obj = loss_reduced

moe_loss = (
sum(moe_losses) * gpc.config.loss.moe_loss_coeff
sum(moe_losses) * gpc.config.loss.moe_loss_coeff # pylint: disable=E0606
if hasattr(gpc.config.model, "num_experts") and gpc.config.model.num_experts > 1
else torch.tensor(0.0, device=get_current_device(), dtype=gpc.config.model.get("dtype"))
)
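In the pipeline scheduler the same AVG-to-SUM-plus-divide substitution is applied, and the per-microbatch MoE loss is then scaled by 1/num_microbatches before being accumulated into `accum_moe_loss`. A tiny illustration of that accumulation pattern; the loss values are made up:

```python
import torch

num_microbatches = 4
accum_moe_loss = torch.zeros(1)

# Pretend per-microbatch MoE losses produced by successive forward steps.
per_step_losses = [torch.tensor([0.8]), torch.tensor([1.2]),
                   torch.tensor([1.0]), torch.tensor([0.6])]

for moe_loss in per_step_losses:
    moe_loss = moe_loss / num_microbatches  # same normalization as the scheduler
    accum_moe_loss.add_(moe_loss.detach())

print(accum_moe_loss)  # tensor([0.9000]) -- the mean over the four microbatches
```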
4 changes: 3 additions & 1 deletion internlm/model/modules/embedding.py
@@ -97,7 +97,9 @@ def _update_cos_sin_cache(
if max_seqlen is not None:
seqlen = max_seqlen
elif isinstance(indexes, int):
seqlen = indexes + x.shape[1] + 1
# logic changed temporaryly
# seqlen = indexes + x.shape[1] + 1
seqlen = gpc.config.data.seq_len
else:
# Note that this statement may cause synchronization between CPU and GPU,
# so it's best to precompute and pass in max_seqlen ahead of time
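The embedding change temporarily pins the cached rotary length to `gpc.config.data.seq_len` instead of deriving it from the integer offset (`indexes + x.shape[1] + 1`). For orientation, a minimal sketch of the usual grow-on-demand cos/sin cache that this value feeds; the class and attribute names are illustrative, not the module's actual ones:

```python
# Illustrative grow-on-demand cos/sin cache; names are hypothetical.
import torch


class RotaryCache:
    def __init__(self, dim: int, base: float = 10000.0):
        self.inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self._seq_len_cached = 0
        self._cos_cached = None
        self._sin_cached = None

    def get(self, seqlen: int):
        # Rebuild only when the requested length exceeds what is cached,
        # e.g. when seqlen is pinned to the configured training sequence length.
        if seqlen > self._seq_len_cached:
            self._seq_len_cached = seqlen
            t = torch.arange(seqlen, dtype=self.inv_freq.dtype)
            freqs = torch.outer(t, self.inv_freq)  # [S, D//2]
            self._cos_cached, self._sin_cached = freqs.cos(), freqs.sin()
        return self._cos_cached[:seqlen], self._sin_cached[:seqlen]


cos, sin = RotaryCache(dim=64).get(seqlen=2048)
```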
8 changes: 5 additions & 3 deletions internlm/model/ops/rotary_emb.py
@@ -23,7 +23,7 @@
flash_rotary_impl = False

try:
from deeplink_ext.internlm_ops import ApplyRotaryEmb as DeeplinkApplyRotaryEmb
from deeplink_ext.internevo_ops import ApplyRotaryEmb as DeeplinkApplyRotaryEmb

deeplink_rotary_impl = True
except (ModuleNotFoundError, ImportError):
@@ -143,13 +143,14 @@ def rotary_emb_in_rotate_half_style(
cos (Tensor): cos, shape is [S, D//2].
sin (Tensor): sin, shape is [S, D//2].
"""
assert False, "This function has some bugs. You should not arrive here."
# reformat cos/sin shape.
cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]

if len(x.shape) == 5:
q, k, _ = x.unbind(dim=2)

q, k = q.squeeze(dim=2), k.squeeze(dim=2)
if interleaved:
q = torch.cat([q[..., ::2], q[..., 1::2]], dim=-1)
k = torch.cat([k[..., ::2], k[..., 1::2]], dim=-1)
@@ -300,6 +301,7 @@ def apply_rotary_emb(
# TODO: to support in_place argument
return DeeplinkApplyRotaryEmb.apply(x, cos, sin, interleaved, use_fused_rope)
if internlm_accelerator.get_accelerator_backend() == AcceleratorType.NPU:
return rotary_emb_in_rotate_half_style(x, cos, sin, interleaved, use_fused_rope)
# return rotary_emb_in_rotate_half_style(x, cos, sin, interleaved, use_fused_rope)
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, in_place)
else:
return ApplyRotaryEmb.apply(x, cos, sin, interleaved, in_place)
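Besides switching the Deeplink import to `deeplink_ext.internevo_ops`, the commit disables `rotary_emb_in_rotate_half_style` on NPU and routes that path through `ApplyRotaryEmb.apply` instead. As a reference for what the disabled path was meant to compute, here is the standard rotate-half RoPE formulation; this is a generic sketch, not the disabled kernel:

```python
# Generic "rotate half" RoPE for a [B, S, H, D] tensor; illustrative only.
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope_rotate_half(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # cos/sin arrive as [S, D//2]; duplicate and broadcast to [1, S, 1, D].
    cos = torch.cat((cos, cos), dim=-1)[None, :, None, :]
    sin = torch.cat((sin, sin), dim=-1)[None, :, None, :]
    return x * cos + rotate_half(x) * sin
```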
47 changes: 39 additions & 8 deletions internlm/solver/optimizer/hybrid_zero_optim.py
@@ -10,7 +10,7 @@
import torch.distributed as dist
from torch.optim import Optimizer

from internlm.accelerator import get_accelerator
from internlm.accelerator import AcceleratorType, get_accelerator
from internlm.core.context import Config, ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.parallel_context import (
@@ -117,6 +117,7 @@ def __init__(
growth_interval=growth_interval,
hysteresis=hysteresis,
max_scale=max_scale,
dtype=gpc.config.model.dtype,
)
self._found_overflow = torch.tensor([0], device=get_current_device(), dtype=torch.float32)

@@ -317,13 +318,25 @@ def _define_and_attach(param, reduce_rank=None):
)

def reduction_layernorm_func():
handle = reduce_tensor(
# BUG: 8.0.RC1.alpha003 hccl allreduce AVG op will not perform averaging operation.
# So we use sum + div here when training on Ascend machines.
op_type = (
torch.distributed.ReduceOp.SUM
if internlm_accelerator.get_accelerator_backend()
in [AcceleratorType.NPU, AcceleratorType.DIPU]
else torch.distributed.ReduceOp.AVG
)
parallel_mode = ParallelMode.WEIGHT if self.use_isp else ParallelMode.TENSOR
reduce_tensor(
param.grad,
dtype=None,
dst_rank=reduce_rank,
parallel_mode=ParallelMode.WEIGHT if self.use_isp else ParallelMode.TENSOR,
parallel_mode=parallel_mode,
op_type=op_type,
async_op=False,
)
handle.wait()
if op_type == torch.distributed.ReduceOp.SUM:
param.grad.div_(gpc.get_world_size(parallel_mode))

# define hook
# NOT IMPORTANT BUT GOOD TO KNOW:
@@ -500,37 +513,50 @@ def _reduce_grads_stored_in_bucket(self, current_bucket, reduce_rank=None):
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size, group_id, dp_parallel_mode):
grad_buckets_by_dtype = split_half_float_double(grads)
next_bucket_list = []

if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]:
op_type = torch.distributed.ReduceOp.SUM
avg_size = gpc.get_world_size(dp_parallel_mode)
else:
op_type = torch.distributed.ReduceOp.AVG
avg_size = -1

# add parameters into bucket for reduction
for tensor_list in grad_buckets_by_dtype:
param_bucket = TensorBucket(size=bucket_size)
for tensor in tensor_list:
param_bucket.add_to_bucket(tensor, allow_oversize=True)
if not param_bucket.is_empty():
self._reduce_and_copy(
bucket=param_bucket, reduce_rank=reduce_rank, group_id=group_id, dp_parallel_mode=dp_parallel_mode
bucket=param_bucket,
reduce_rank=reduce_rank,
group_id=group_id,
dp_parallel_mode=dp_parallel_mode,
op_type=op_type,
)
next_bucket_list.append(param_bucket)

# wait for the completion of previouce bucket list reduction, and do unflatten_and_copy()
# here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.unflatten_and_copy(dp_group_size=avg_size)
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()

# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
self._bucket_in_progress = next_bucket_list.copy()

def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode):
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank, group_id, dp_parallel_mode, op_type):
# flatten the tensors and do allreduce
bucket.flatten()
bucket.commu_handle = reduce_tensor(
tensor=bucket.get_flat_tensor(),
dtype=None,
dst_rank=reduce_rank,
parallel_mode=dp_parallel_mode,
op_type=op_type,
)

# update the reduced tensor
@@ -674,10 +700,15 @@ def step(self, closure=None):
for group_id in range(self.num_param_groups):
self._reduce_grads_stored_in_bucket(self._bucket_store[group_id], reduce_rank=None)

if internlm_accelerator.get_accelerator_backend() in [AcceleratorType.NPU, AcceleratorType.DIPU]:
avg_size = gpc.get_world_size(ParallelMode.DATA)
else:
avg_size = -1

# wait grads reduced and clear reduced grads
for bucket in self._bucket_in_progress:
bucket.commu_handle.wait()
bucket.unflatten_and_copy()
bucket.unflatten_and_copy(dp_group_size=avg_size)
bucket.empty()
self._bucket_in_progress = []
self._param_store.clear_grads_of_previous_reduced_params()
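The optimizer changes follow one pattern: on Ascend-style backends (NPU/DIPU) gradients are reduced with `ReduceOp.SUM` and divided afterwards, either immediately after the synchronous layer-norm reduce or, for bucketed gradients, inside `unflatten_and_copy(dp_group_size=...)`; other backends keep `ReduceOp.AVG`. A condensed sketch of that selection logic; the function names and string backend tags are illustrative stand-ins for `internlm_accelerator.get_accelerator_backend()`:

```python
# Condensed sketch of the backend-dependent reduce-op choice; illustrative only.
import torch.distributed as dist

ASCEND_LIKE = {"npu", "dipu"}


def pick_reduce_op(backend: str):
    """Return (op, needs_post_divide); some HCCL releases ignore ReduceOp.AVG."""
    if backend in ASCEND_LIKE:
        return dist.ReduceOp.SUM, True
    return dist.ReduceOp.AVG, False


def reduce_gradient(grad, group, world_size: int, backend: str):
    # Assumes an initialized process group; `group`/`world_size` are placeholders.
    op, needs_div = pick_reduce_op(backend)
    dist.all_reduce(grad, op=op, group=group)
    if needs_div:
        grad.div_(world_size)  # turn the SUM back into an average
    return grad
```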
6 changes: 4 additions & 2 deletions internlm/solver/optimizer/store.py
@@ -316,11 +316,13 @@ def empty(self):
def flatten(self):
self._flat_tensor = _flatten_dense_tensors(self._bucket)

def unflatten_and_copy(self):
def unflatten_and_copy(self, dp_group_size: int = -1):
if self._unflatten_and_copy_flag:
unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
for old, new in zip(self._bucket, unflattened_tensor_list):
old.copy_(new)
if dp_group_size != -1:
old.div_(dp_group_size)


class BucketStore_v2(BaseStore):
@@ -409,7 +411,7 @@ def build_grad_in_bucket(self, comm_stream):
grad = param.grad.clone().detach().flatten()
if padding_size > 0:
with torch.no_grad():
grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size])
grad = torch.nn.functional.pad(grad.view(-1), [0, padding_size]) # pylint: disable=E1102
grad_list = grad.split(grad.numel() // self.zero_world_size)
for rank in range(self.zero_world_size):
grad_current_rank = grad_list[rank].clone().detach()
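`unflatten_and_copy` now optionally divides each copied gradient by the data-parallel group size, which is what turns the SUM reduction back into a mean. A self-contained sketch of that flatten/reduce/unflatten round trip using the same torch helpers the store imports; the tensors and `dp_group_size` value are made up:

```python
# Flatten -> (all-reduce SUM) -> unflatten-and-divide round trip, mirroring
# TensorBucket.unflatten_and_copy(dp_group_size); values are made up.
import torch
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

grads = [torch.ones(3), torch.full((2,), 4.0)]
flat = _flatten_dense_tensors(grads)

# ... an all-reduce with ReduceOp.SUM would run on `flat` here ...

dp_group_size = 2  # pretend two data-parallel ranks contributed
for old, new in zip(grads, _unflatten_dense_tensors(flat, grads)):
    old.copy_(new)
    if dp_group_size != -1:
        old.div_(dp_group_size)

print(grads[0])  # tensor([0.5000, 0.5000, 0.5000])
```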
51 changes: 41 additions & 10 deletions internlm/solver/optimizer/utils.py
@@ -79,7 +79,14 @@ def split_half_float_double(tensor_list):
return buckets


def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA):
def reduce_tensor(
tensor,
dtype=None,
dst_rank=None,
parallel_mode=ParallelMode.DATA,
op_type=torch.distributed.ReduceOp.AVG,
async_op=True,
):
"""
Reduce the tensor in the data parallel process group
@@ -114,13 +121,11 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
use_all_reduce = dst_rank is None

if use_all_reduce:
handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
handle = dist.all_reduce(tensor=tensor, group=group, op=op_type, async_op=async_op)
else:
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
global_rank = ranks_in_group[dst_rank]
handle = dist.reduce(
tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
)
handle = dist.reduce(tensor=tensor, dst=global_rank, group=group, op=op_type, async_op=async_op)

return handle

@@ -436,6 +441,7 @@ def __init__(
min_scale: Optional[float] = None,
max_scale: Optional[float] = None,
hysteresis: int = 2,
dtype=torch.bfloat16,
):
super().__init__(initial_scale)
if min_scale:
@@ -454,17 +460,42 @@ def __init__(
self._growth_step = 0
self._hysteresis = hysteresis
self._hysteresis_step = 0
self._dtype = dtype
self._sanity_checks()

def _sanity_checks(self) -> None:
"""Check if the arguments are correct."""

if self._min_scale:
assert self._min_scale > 0, "The minimum gradient scale cannot be zero or negative"
assert self._dtype in [torch.float16, torch.bfloat16, torch.float32]

if self._min_scale is not None:
min_scale = self._min_scale.item()
assert min_scale > 0, "The minimum gradient scale cannot be zero or negative"

if self._dtype != torch.float16 and min_scale != 1.0 and gpc.is_rank_for_log():
logger.warning(f"Detect you use {self._dtype}, but min_scale: {min_scale} != 1.0")

if self._max_scale:
assert self._min_scale > 0, "The maximum gradient scale cannot be zero or negative"
assert self._growth_factor > 1, "The growth factor cannot be equal or smaller than 1"
assert self._backoff_factor < 1 and self._backoff_factor > 0, "The backoff factor must be between 0 and 1"
max_scale = self._max_scale.item()
assert max_scale > 0, "The maximum gradient scale cannot be zero or negative"

if self._dtype != torch.float16 and max_scale != 1.0 and gpc.is_rank_for_log():
logger.warning(f"Detect you use {self._dtype}, but max_scale: {max_scale} != 1.0")

if self._dtype == torch.float16:
assert self._growth_factor > 1.0, "The growth factor cannot be equal or smaller than 1"
assert self._backoff_factor < 1.0 and self._backoff_factor > 0, "The backoff factor must be between 0 and 1"
else:
assert self._growth_factor >= 1.0, "The growth factor cannot be smaller than 1"
assert (
self._backoff_factor <= 1.0 and self._backoff_factor > 0
), "The backoff factor must be between 0 and 1"

if self._growth_factor != 1.0 and gpc.is_rank_for_log():
logger.warning(f"Detect you use {self._dtype}, but growth_factor: {self._growth_factor} != 1.0")
if self._backoff_factor != 1.0 and gpc.is_rank_for_log():
logger.warning(f"Detect you use {self._dtype}, but backoff_factor: {self._backoff_factor} != 1.0")

assert self._hysteresis >= 0, "The hysteresis cannot be negative"

def update(self, overflow: bool) -> None:
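The scaler now records the training dtype and relaxes its sanity checks accordingly: fp16 still requires a real dynamic-scaling configuration, while bf16/fp32 may run with growth and backoff factors of 1.0 and only emit warnings otherwise. A simplified sketch of that dtype-aware check as a plain function, not the class's actual `_sanity_checks`:

```python
# Simplified dtype-aware sanity check; a plain function, not the scaler class.
import torch


def check_scaler_args(dtype, growth_factor: float, backoff_factor: float) -> None:
    assert dtype in (torch.float16, torch.bfloat16, torch.float32)
    if dtype == torch.float16:
        # fp16 genuinely needs dynamic loss scaling.
        assert growth_factor > 1.0
        assert 0 < backoff_factor < 1.0
    else:
        # bf16/fp32 typically run with a fixed scale, so 1.0 is allowed.
        assert growth_factor >= 1.0
        assert 0 < backoff_factor <= 1.0


check_scaler_args(torch.bfloat16, growth_factor=1.0, backoff_factor=1.0)  # ok
check_scaler_args(torch.float16, growth_factor=2.0, backoff_factor=0.5)   # ok
```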
