Float8 Autocast Issue #1418

Open
asahni04 opened this issue Dec 16, 2024 · 1 comment

asahni04 commented Dec 16, 2024

Facing an error when calling x.float() on an input that was, I presume, cast to Float8 by the previous Float8Linear layers. @vkuzo

from typing import Any

import torch
from torch import nn


class Fp32LayerNorm(nn.LayerNorm):
    """
    Wrapper around :class:`~torch.nn.LayerNorm` to support mixed-precision training.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x (torch.Tensor): Input tensor.
        Returns:
            torch.Tensor: The normalized output tensor having the same shape as ``x``.
        """
        output = nn.functional.layer_norm(
            x.float(),
            self.normalized_shape,
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        return output.type_as(x)

error:

[rank2]:   File "some_file", line 134, in modulate
[rank2]:     x = norm(x)
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1841, in _call_impl
[rank2]:     return inner()
[rank2]:   File "/opt/conda/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1790, in inner
[rank2]:     result = forward_call(*args, **kwargs)
[rank2]:   File "some_file", line 96, in forward
[rank2]:     x.float(),


al_frame.py", line 632, in _fn
[rank1]:     return fn(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_api.py", line 340, in __torch_dispatch__
[rank1]:     return DTensor._op_dispatcher.dispatch(
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/distributed/tensor/_dispatch.py", line 215, in dispatch
[rank1]:     local_results = op_call(*local_tensor_args, **op_info.local_kwargs)
[rank1]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torch/_ops.py", line 716, in __call__
[rank1]:     return self._op(*args, **kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torchao/float8/float8_tensor.py", line 383, in __torch_dispatch__
[rank1]:     return FLOAT8_OPS_TABLE[func](func, args, kwargs)
[rank1]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank1]:   File "/opt/conda/lib/python3.11/site-packages/torchao/float8/float8_ops.py", line 360, in autocast_to_copy
[rank1]:     assert kwargs["dtype"] in {
[rank1]:            ^^^^^^^^^^^^^^^^^^^^
[rank1]: torch._dynamo.exc.TorchRuntimeError: Failed running call_method float(*(DTensor(local_tensor=Float8Tensor(dtype=torch.float8_e4m3fn, scale=FakeTensor(..., device='cuda:1', size=()), linear_mm_config=LinearMMConfig(output=ScaledMMConfig(emulate=False, use_fast_accum=True, fp8_output=False, pad_inner_dim=False), grad_input=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False), grad_weight=ScaledMMConfig(emulate=False, use_fast_accum=False, fp8_output=False, pad_inner_dim=False)), axiswise_dim=None
[rank1]: gemm_input_role=GemmInputRole.INPUT
[rank1]: as_orig_prec=FakeTensor(..., device='cuda:1', size=(64, 72, 1024), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1, 2, 3, 4, 5, 6, 7], mesh_dim_names=('tensor_parallel',)), placements=(Shard(dim=1),)),), **{}):
[rank1]: Only support floating point conversion for autocast w/ Float8Tensor
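For context, the assertion that fires is the dtype check in torchao's autocast_to_copy handler. A paraphrased sketch of that check, reconstructed from the traceback above (the exact allow-list is an assumption, not the verbatim torchao source):

import torch

def autocast_to_copy_check(target_dtype: torch.dtype) -> None:
    # Paraphrase of the assert at torchao/float8/float8_ops.py:360 shown in
    # the traceback (assumption: only the autocast half-precision dtypes are
    # allowed, so the torch.float32 requested by x.float() trips the assert).
    assert target_dtype in {
        torch.float16,
        torch.bfloat16,
    }, "Only support floating point conversion for autocast w/ Float8Tensor"

If a Float8Tensor really is reaching Fp32LayerNorm, one possible workaround sketch (untested; assumes the incoming activation is bf16-backed, as the as_orig_prec dtype in the traceback suggests) is to step down to a dtype the handler accepts before upcasting:

def forward(self, x: torch.Tensor) -> torch.Tensor:
    # Hypothetical workaround: cast to bfloat16 first (a dtype the Float8
    # autocast handler accepts), then upcast to fp32 for the norm math.
    x_fp32 = x.to(torch.bfloat16).float()
    output = nn.functional.layer_norm(
        x_fp32,
        self.normalized_shape,
        self.weight.float() if self.weight is not None else None,
        self.bias.float() if self.bias is not None else None,
        self.eps,
    )
    # Returning bf16 (rather than type_as(x)) avoids casting back into the
    # fp8 subclass; adjust if the caller expects a different dtype.
    return output.to(torch.bfloat16)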


vkuzo (Contributor) commented Dec 16, 2024

Could you share a standalone repro? I'm surprised to see this, as Float8Tensor is not a user-facing object and is not supposed to leak outside of Float8Linear.
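For reference, a minimal skeleton of what such a repro might look like, based on the details above (a sketch, not a confirmed reproducer: it omits the tensor parallelism / DTensor setup visible in the traceback, the Block module and shapes are illustrative, and convert_to_float8_training is assumed to be how Float8Linear was applied):

import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

class Block(nn.Module):
    def __init__(self, dim: int = 1024) -> None:
        super().__init__()
        self.proj = nn.Linear(dim, dim, bias=False)
        self.norm = Fp32LayerNorm(dim)  # the wrapper from the issue above

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # proj is swapped to Float8Linear; if its output ever reaches the
        # norm as a Float8Tensor, x.float() inside Fp32LayerNorm should hit
        # autocast_to_copy exactly as in the traceback.
        return self.norm(self.proj(x))

model = Block().to(device="cuda", dtype=torch.bfloat16)
convert_to_float8_training(model)  # swaps nn.Linear modules to Float8Linear
model = torch.compile(model)

x = torch.randn(64, 72, 1024, device="cuda", dtype=torch.bfloat16)
out = model(x)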

vkuzo added the float8 label Dec 16, 2024