Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[float8] Allow specifying arbitrary dtype for each tensor #1326

Draft
wants to merge 7 commits into
base: gh/lw/2/base
Choose a base branch
from

Conversation

lw
Copy link
Contributor

@lw lw commented Nov 22, 2024

[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: 7dabc91df68ce20e15551c5488071579e49c263c
Pull Request resolved: #1326
Copy link

pytorch-bot bot commented Nov 22, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1326

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 97c9983 with merge base 1a0dbf1 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: d9c0e023cb8667d6f13e2845e9b6845e1669f78a
Pull Request resolved: #1326
@lw lw added the topic: new feature Use this tag if this PR adds a new feature label Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: e86e1f6a42610d776bccd3f33044f23e311688eb
Pull Request resolved: #1326
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: aa9f551c8d274f349c4298932fc95c88040abb09
Pull Request resolved: #1326
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: c339ea060b7062871a5f57a939e8880e5f727de4
Pull Request resolved: #1326
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
ghstack-source-id: 4b3a2f0007d74e3453cefde1307f2a9c5271e83e
Pull Request resolved: #1326
@@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

  1. can we add a comment on what this is used for, and that None means the default e4m3|e5m2 value will be used?
  2. optional - thoughts about naming this in a more specific way such as target_dtype, lowp_dtype, etc? dtype is a bit ambiguous across torchao unfortunately :(

@@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?

NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
NoopFwToFloat8E5M2BwStatic,
NoopFwToFloat8BwDelayed,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for updating these!

@@ -303,13 +311,16 @@ def inner_func():

# Calculate the new scales from the updated history stacks
new_input_scales = amax_history_to_scale_stack(
fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will likely have to rebase on top of #1329 which changed this line

@@ -62,6 +62,7 @@ class CastConfig:
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None

def short_str(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also add the dtype here, so it appears when we print an instance of Float8Linear? Float8Linear.__extra_repr__ calls this method.

@vkuzo
Copy link
Contributor

vkuzo commented Nov 26, 2024

This is great! LGTM, had some comments but all are pretty nitty. CI is green - ship it!

[ghstack-poisoned]
lw added a commit that referenced this pull request Dec 4, 2024
ghstack-source-id: d8300e2a07c087f3cd51b03e0e21125a83a29489
Pull Request resolved: #1326
@lw
Copy link
Contributor Author

lw commented Dec 4, 2024

Superseded by #1378

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants