Input tensor is not an XLA tensor on AWS Trainium instance #8510

Open
JmeanJmy opened this issue Dec 20, 2024 · 1 comment

Comments


JmeanJmy commented Dec 20, 2024

Hi team, I'm currently testing my training job on an AWS Trainium instance. I encountered the error Input tensor is not an XLA tensor: torch.FloatTensor when using the PyTorch Conv1d/Linear modules. I've confirmed that the input tensor has been moved to XLA, as I explicitly called .to(xm.xla_device()) before passing the input tensor to the module's forward method. However, I found that the error is actually caused by the weight and bias created inside those PyTorch modules, e.g. here: https://github.com/pytorch/pytorch/blob/main/torch/nn/modules/conv.py#L375. I printed the device for self.weight and self.bias and they are on cpu. I had to modify the Conv1d source code to resolve the issue, e.g.:

def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):
    # Move the parameters to the same device as the input (the XLA device);
    # nn.Module has no `device` attribute, so use the input's device instead
    weight = weight.to(input.device)
    if bias is not None:
        bias = bias.to(input.device)

    if self.padding_mode != 'zeros':
        return F.conv1d(
            F.pad(input, self._reversed_padding_repeated_twice, mode=self.padding_mode),
            weight, bias, self.stride, _single(0), self.dilation, self.groups
        )
    return F.conv1d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

Does anyone know how to make sure those parameters end up on the XLA device?
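To diagnose this kind of problem, it helps to print where each parameter of the module actually lives before and after moving it. A minimal sketch (runs on CPU, no Trainium/torch_xla required):

```python
import torch.nn as nn

# A freshly constructed module keeps its parameters on CPU until
# it is explicitly moved with .to(device)
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3)

# named_parameters() covers both weight and bias, so nothing is missed
for name, param in conv.named_parameters():
    print(name, param.device)
```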


radna0 commented Dec 25, 2024

Speaking from experience @JmeanJmy, Input tensor is not an XLA tensor is somewhat misleading: it can mean either the model or a tensor is not on the XLA device. Have you tried simply moving Conv.weight and Conv.bias (or better, the whole Conv module) to the device from outside the source code? Something this trivial should not require modifying PyTorch's source.
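Calling `.to(device)` on the module moves its weight and bias along with it, which is usually all that's needed. A minimal sketch, using CPU as a stand-in device so it runs anywhere (on Trainium you would use `xm.xla_device()` from `torch_xla.core.xla_model` instead):

```python
import torch
import torch.nn as nn

# Stand-in for xm.xla_device() so the example runs without torch_xla
device = torch.device("cpu")

# Moving the module moves self.weight and self.bias with it,
# so no modification of _conv_forward is needed
conv = nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3).to(device)
x = torch.randn(2, 4, 16).to(device)

out = conv(x)
assert conv.weight.device == x.device  # parameters follow the module
```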
