
[MPS] torchao low-bit-precision optim does not expose 'backend' argument to torch.compile #955

Open
bghira opened this issue Sep 26, 2024 · 10 comments
Labels: good first issue

Comments

@bghira commented Sep 26, 2024

On Apple MPS platforms, torchao training works great until we involve the AdamW8bit optimiser:

    assert wrapper_code_gen_cls is not None, f"Device {device_type} not supported"
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: Device mps not supported

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

and conditionally setting:

    if torch.backends.mps.is_available():
        import torch._dynamo
        torch._dynamo.config.suppress_errors = True

to suppress this error and 'fall back to eager mode' does not help here; it merely hides the notice that this parameter could be set.

Using aot_eager instead of inductor is required for MPS, and the hard-coded backend also shuts out third-party backends.

Can we expose the backend parameter, perhaps as compile_backend?
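
For illustration, the plumbing could look roughly like this (a sketch only; compile_optim_step and the compile_backend name are my proposal, not existing torchao API):

import torch

def compile_optim_step(step_fn, backend=None):
    # Pick a torch.compile backend the current device can actually codegen for:
    # inductor has no MPS support (as of this report), so fall back to aot_eager there.
    if backend is None:
        backend = "aot_eager" if torch.backends.mps.is_available() else "inductor"
    return torch.compile(step_fn, backend=backend, fullgraph=True)

A compile_backend keyword on the optimiser could then simply forward its value to this call.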

@msaroufim added the good first issue label Sep 26, 2024
@msaroufim (Member) commented Sep 26, 2024

This should be an easy PR to make. @bghira, would you be interested in taking a stab at it? If you need any advice, we hang out in #torchao on discord.gg/gpumode.

Also curious: is this a production use case? We haven't taken Mac perf super seriously, but hey, if we have users maybe we should.

@bghira (Author) commented Sep 26, 2024

Actually, using aot_eager gets autograd involved and then dtype complaints happen. The gradients need to be in fp32 precision ... for a low-bit optim? 🤔

@bghira (Author) commented Sep 26, 2024

Yeah, simpletuner supports finetuning diffusion models via torch-mps with or without optimum-quanto, up to the 12B-parameter Flux model, which really takes advantage of quantisation: down from 30GB for pure bf16 training (stochastic rounding etc.) to 15GB with quantisation to int8 (MPS doesn't support fp8).

@msaroufim (Member) commented:
Oh interesting, you're also looking at diffusion models? We now have a working group dedicated to that.

@bghira (Author) commented Sep 26, 2024

Either way, I'm not seeing memory savings with the 8-bit AdamW, since I need the gradients to be upcast to fp32. The 4-bit optim uses some ops not yet implemented in PyTorch's MPS backend, and enabling CPU fallback results in:

(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: error: 'mps.multiply' op requires the same element type for all operands and results
(mpsFileLoc): /AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphUtilities.mm:233:0: note: see current operation: %4 = "mps.multiply"(%2, %arg2) : (tensor<16x3072xf32>, tensor<1xbf16>) -> tensor<*xf32>
/AppleInternal/Library/BuildRoots/e0873e53-5185-11ef-9a51-9ab6d782fe32/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShadersGraph/mpsgraph/MetalPerformanceShadersGraph/Core/Files/MPSGraphExecutable.mm:953: failed assertion `original module failed verification'
Traceback (most recent call last)

A glorious MPS-library-level error and a full application fault.

@gau-nernst (Collaborator) commented:
Adding a compile backend flag makes sense, though I'm not sure which other backends are also useful for codegen-ing the optim step.

If I'm not wrong, aot_eager will run in eager mode, so there will be no memory-saving benefits. The memory saving relies on the fact that we dequant and re-quant inside the kernel, so the dequantised tensors never materialize in global memory.
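
A crude sketch of that point (illustrative only; torchao's real implementation uses block-wise quantisation, not the per-tensor scale shown here):

import torch

def eager_update(exp_avg_q, scale, grad, beta1=0.9):
    # In eager / aot_eager, this dequantised fp32 copy is a real tensor in memory,
    # so peak usage is no better than keeping the optimizer state in fp32.
    exp_avg = exp_avg_q.float() * scale
    exp_avg = exp_avg.lerp(grad.float(), 1 - beta1)        # moment update
    new_scale = exp_avg.abs().amax().clamp(min=1e-12)
    exp_avg_q = (exp_avg / new_scale * 127).round().to(torch.int8)  # re-quantize
    return exp_avg_q, new_scale

When torch.compile fuses that sequence into a single kernel, the fp32 intermediates live only in registers, which is where the saving comes from.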

@bghira (Author) commented Sep 26, 2024

So you would require a custom MPS extension for PyTorch that accomplishes the same thing you currently rely on CUDA kernels for?

e.g. following Apple's example: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_pytorch_operation

They provide a downloadable, compilable sample:

'''
Copyright © 2023 Apple Inc.

See LICENSE folder for this sample’s licensing information.

Abstract:
The code for compiling the custom pytorch extension.
'''

import torch.utils.cpp_extension

compiled_lib = torch.utils.cpp_extension.load(
    name='CustomSoftshrink',
    sources=['CustomSoftshrink.mm'],
    extra_cflags=['-std=c++17'],
)

and the relevant Objective-C++ example code that links into Metal directly:

/*
See the LICENSE.txt file for this sample’s licensing information.

Abstract:
The code that registers a PyTorch custom operation.
*/


#include <torch/extension.h>
#include "CustomSoftshrink.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
  return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

torch::Tensor& dispatchSoftShrinkKernel(const torch::Tensor& input, torch::Tensor& output, float lambda) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
...
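
For completeness, once that extension builds, the op is called from Python through the compiled_lib module returned by the load() call above; the exposed function name comes from the extension's binding code, so the call below is only a placeholder:

import torch

x = torch.randn(16, 3072, device="mps")
# `mps_softshrink` is a placeholder name; use whatever the extension actually registers
# on the compiled_lib module loaded earlier.
y = compiled_lib.mps_softshrink(x, 0.5)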

@gau-nernst (Collaborator) commented:
Yes, if you want 8-bit Adam for MPS, you probably need a custom kernel. (Unless torch.compile or Triton supports emitting MPS kernels in the future 👀)

Just curious: if you want memory-saving techniques, would something like LoRA/QLoRA be more suitable? Low-bit optim only makes sense when you train the whole big model, and I don't think it's practical to do a full fine-tune of Flux on MPS. (Yes, you mentioned developing on MPS first and then moving to CUDA, so there are still valid points to having a working low-bit optim on MPS, though you could probably do that testing on CPU instead.)

@bghira (Author) commented Sep 27, 2024

We are already quantising the whole model even though we train with a LoRA - actually a LyCORIS LoKr in this case.

For my 128GB unit, I can do a full finetune with 57GB usage when ZeRO-3 storage offload is working, and I was hoping for more options.

tinygrad is doing autogenerated Metal kernels 🤷 so it's possible to achieve.

@gau-nernst (Collaborator) commented Sep 27, 2024

@msaroufim I don't know if there are already efforts on this, but a working group for inductor/triton+MPS might be interesting 👀 (totally out of scope of this issue though)
