[float8] Allow specifying arbitrary dtype for each tensor #1326
base: gh/lw/2/base
Conversation
ghstack-source-id: 7dabc91df68ce20e15551c5488071579e49c263c Pull Request resolved: #1326
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1326
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure. As of commit 97c9983 with merge base 1a0dbf1: NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
ghstack-source-id: d9c0e023cb8667d6f13e2845e9b6845e1669f78a Pull Request resolved: #1326
ghstack-source-id: e86e1f6a42610d776bccd3f33044f23e311688eb Pull Request resolved: #1326
ghstack-source-id: aa9f551c8d274f349c4298932fc95c88040abb09 Pull Request resolved: #1326
ghstack-source-id: c339ea060b7062871a5f57a939e8880e5f727de4 Pull Request resolved: #1326
ghstack-source-id: 4b3a2f0007d74e3453cefde1307f2a9c5271e83e Pull Request resolved: #1326
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None
nit:
- can we add a comment on what this is used for, and that `None` means the default e4m3|e5m2 value will be used?
- optional - thoughts about naming this in a more specific way such as `target_dtype`, `lowp_dtype`, etc? `dtype` is a bit ambiguous across torchao unfortunately :(
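For context, a minimal sketch of what the requested field comment could look like, assuming `CastConfig` stays the dataclass shown in the hunk above; the docstring wording and the enum import path are illustrative, not taken from the PR:

```python
from dataclasses import dataclass
from typing import Optional

import torch

# assumed import path for the existing enums
from torchao.float8.config import ScalingGranularity, ScalingType


@dataclass
class CastConfig:
    scaling_type: ScalingType = ScalingType.DYNAMIC
    scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
    static_scale: Optional[torch.Tensor] = None
    # Low-precision dtype to cast this tensor to. If None, the per-tensor
    # default is used (e4m3 for input/weight, e5m2 for grad_output).
    dtype: Optional[torch.dtype] = None
```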
@@ -343,12 +367,14 @@ def recipe_name_to_linear_config(
     cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)

     # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
-    cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
+    cc_go = CastConfig(
+        scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
+    )
nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?
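To make the recipe's intent concrete, here is a hedged sketch of the cast configs this hunk implies — axiswise scaling everywhere, with grad_output pinned to e4m3 instead of the usual e5m2 default. Names mirror the diff; the `e4m3_dtype` binding is an assumption, and the `dtype=` argument only exists on this PR's branch:

```python
import torch

# assumed import path; CastConfig with a dtype field is from this PR's branch
from torchao.float8.config import CastConfig, ScalingGranularity

# assumption: torchao resolves this per-platform; hard-coded here for clarity
e4m3_dtype = torch.float8_e4m3fn

cc_i = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_w = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
# grad_output is also cast to e4m3, overriding the e5m2 default --
# the behavior the reviewer suggests calling out in the comments
cc_go = CastConfig(
    scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype
)
```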
-    NoopFwToFloat8E5M2BwDelayed,
-    NoopFwToFloat8E5M2BwDynamic,
-    NoopFwToFloat8E5M2BwStatic,
+    NoopFwToFloat8BwDelayed,
thanks for updating these!
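The renamed classes drop the hard-coded E5M2 from their names because the backward cast dtype is now configurable. A minimal sketch of the underlying pattern (no-op forward, cast-on-backward), not torchao's actual implementation:

```python
import torch


class NoopFwCastBw(torch.autograd.Function):
    """No-op in the forward pass; casts the gradient to a configurable
    float8 dtype in the backward pass (previously hard-coded to e5m2)."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, target_dtype: torch.dtype):
        ctx.target_dtype = target_dtype
        return x

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # round-trip through the low-precision dtype to emulate the
        # precision loss of a real float8 cast
        g = grad_output.to(ctx.target_dtype).to(grad_output.dtype)
        return g, None


x = torch.randn(4, requires_grad=True)
y = NoopFwCastBw.apply(x, torch.float8_e4m3fn)
y.sum().backward()
```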
@@ -303,13 +311,16 @@ def inner_func():

     # Calculate the new scales from the updated history stacks
     new_input_scales = amax_history_to_scale_stack(
-        fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
+        fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe
     )
will likely have to rebase on top of #1329 which changed this line
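For intuition on why the dtype must be threaded through here: with per-tensor dtypes, the scale has to be derived from the max representable value of whichever float8 dtype was configured, not a hard-coded e4m3. A hedged sketch of the idea (mirroring, not reproducing, `amax_history_to_scale_stack`):

```python
import torch


def amax_to_scale(amax: torch.Tensor, float8_dtype: torch.dtype) -> torch.Tensor:
    # map the observed amax onto the target dtype's representable range
    return torch.finfo(float8_dtype).max / torch.clamp(amax, min=1e-12)


amax = torch.tensor(3.5)
print(amax_to_scale(amax, torch.float8_e4m3fn))  # uses e4m3's max, 448.0
print(amax_to_scale(amax, torch.float8_e5m2))    # uses e5m2's max, 57344.0
```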
@@ -62,6 +62,7 @@ class CastConfig:
     scaling_type: ScalingType = ScalingType.DYNAMIC
     scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
     static_scale: Optional[torch.Tensor] = None
+    dtype: Optional[torch.dtype] = None

     def short_str(self):
can we also add the dtype here, so it appears when we print an instance of `Float8Linear`? `Float8Linear.__extra_repr__` calls this method.
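A hedged sketch of what including the dtype in the printed config could look like; the free function below is a hypothetical stand-in for `CastConfig.short_str`, and the abbreviation scheme is a guess:

```python
from typing import Optional

import torch


def cast_config_short_str(scaling_abbr: str, dtype: Optional[torch.dtype]) -> str:
    # abbreviate e.g. torch.float8_e4m3fn -> "e4m3fn"; None -> "none"
    dtype_abbr = "none" if dtype is None else str(dtype).replace("torch.float8_", "")
    return f"{scaling_abbr}_{dtype_abbr}"


print(cast_config_short_str("dyn_ten", torch.float8_e4m3fn))  # dyn_ten_e4m3fn
print(cast_config_short_str("dyn_ten", None))                 # dyn_ten_none
```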
This is great! LGTM, had some comments but all are pretty nitty. CI is green - ship it!
ghstack-source-id: d8300e2a07c087f3cd51b03e0e21125a83a29489 Pull Request resolved: #1326
Superseded by #1378
Stack from ghstack (oldest at bottom):