Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

float8 delayed scaling: private API to fix user overriding buffers #1292

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions test/float8/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,29 @@ def test_inference_mode(self):
with torch.inference_mode(mode=True):
y = m(x)

@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
def test_to_empty_delayed_scaling_with_float8_all_gather(self):
with torch.device("meta"):
m_ref = nn.Sequential(nn.Linear(32, 32))
config = Float8LinearConfig(
cast_config_input=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_weight=CastConfig(scaling_type=ScalingType.DELAYED),
cast_config_grad_output=CastConfig(scaling_type=ScalingType.DELAYED),
enable_fsdp_float8_all_gather=True,
)
m_fp8 = convert_to_float8_training(m_ref, config=config)

assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer
assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer
assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer

m_fp8.to_empty(device="cuda")
m_fp8[0]._maybe_fixup_delayed_scaling_buffers()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we would need to call this inside torchtitan’s training loop?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah, which is definitely not ideal


assert m_fp8[0].fp8_amax_weight is m_fp8[0].weight._amax_buffer
assert m_fp8[0].fp8_amax_history_weight is m_fp8[0].weight._amax_history_buffer
assert m_fp8[0].fp8_scale_weight is m_fp8[0].weight._scale_buffer


class TestScaledMM:
@unittest.skipIf(
Expand Down
13 changes: 13 additions & 0 deletions torchao/float8/float8_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,19 @@ def extra_repr(self):
s = f'{super().extra_repr()}, cast_configs={cast_config_str}"'
return s

def _maybe_fixup_delayed_scaling_buffers(self):
if (
self.config.enable_fsdp_float8_all_gather
and self.config.cast_config_weight.scaling_type is ScalingType.DELAYED
):
# in case the module weight-related buffers got overwritten by
# the user (such as when calling `model.to_empty`), we
# re-link the weight wrapper buffers to point to the correct
# location
self.weight._amax_buffer = self.fp8_amax_weight
self.weight._amax_history_buffer = self.fp8_amax_history_weight
self.weight._scale_buffer = self.fp8_scale_weight

@classmethod
def from_float(
cls,
Expand Down
Loading