[MPS] torchao low-bit-precision optim does not expose 'backend' argument to torch.compile #955
This should be an easy PR to make. @bghira, would you be interested in taking a stab at it? If you need any advice, we hang out in #torchao on discord.gg/gpumode. Also curious: is this a production use case? We haven't taken Mac perf super seriously, but hey, if we have users, maybe we should.
actually, using
yeah, simpletuner supports finetuning diffusion models via torch-mps, with or without optimum-quanto, up to the 12B-parameter Flux model, which really takes advantage of quantisation: down from 30G at pure bf16 training (stochastic rounding etc.) to 15GB with quantisation to int8 (MPS doesn't support fp8)
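For scale, a back-of-envelope check on the weight memory alone (the quoted 30G/15GB figures presumably also cover gradients, optimizer state, and activations):

```python
# Weight-only memory for a 12B-parameter model at two precisions.
# Gradients, optimizer state, and activations add more on top of this.
PARAMS = 12_000_000_000

bf16_gib = PARAMS * 2 / 2**30  # bf16: 2 bytes per parameter
int8_gib = PARAMS * 1 / 2**30  # int8: 1 byte per parameter

print(f"bf16 weights: {bf16_gib:.1f} GiB, int8 weights: {int8_gib:.1f} GiB")
```

Halving the weight footprint is exactly where the int8 quantisation saving comes from; everything else scales with what the optimizer and autograd keep around.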
Oh interesting, you're also looking at diffusion models? We have a working group now dedicated to that.
either way, not seeing memory savings with the 8-bit AdamW, as I need the gradients to be upcast to fp32. the 4-bit optim uses some ops not implemented in MPS PyTorch yet, and enabling CPU fallback results in a glorious MPS-library-level error and a full application fault
Adding a compile backend flag makes sense, though I'm not sure which other backends are also useful for codegen'd optim. If I'm not wrong, aot_eager will run in eager mode, so there will be no memory-saving benefit. The memory saving relies on the fact that we do dequant and re-quant inside the kernel, so the dequantized tensors never materialize in global memory.
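To make the proposal concrete, here is a minimal sketch of threading a caller-chosen backend through to the compile step rather than hard-coding it. The names (`make_compiled_step`, `compile_backend`) are hypothetical, not torchao's actual API, and `compile_fn` stands in for `torch.compile` so the sketch runs standalone:

```python
def make_compiled_step(step_fn, compile_backend="inductor", compile_fn=None):
    """Compile an optimizer step function with a caller-chosen backend.

    `compile_fn` is a stand-in for torch.compile so this sketch is
    self-contained; in a real integration it would be torch.compile itself.
    """
    if compile_fn is None:
        # No compiler available: fall back to the uncompiled step.
        return step_fn
    # Forward the backend choice instead of baking in "inductor".
    return compile_fn(step_fn, backend=compile_backend)
```

On MPS a caller would pass `compile_backend="aot_eager"`; CUDA users would keep the default `inductor` path untouched.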
so you will require a custom MPS extension for PyTorch which accomplishes the same thing that you currently rely on CUDA kernels for? e.g. following Apple's example: https://developer.apple.com/documentation/metal/metal_sample_code_library/customizing_a_pytorch_operation they provide a downloadable, compilable sample:

```python
# Copyright © 2023 Apple Inc.
# See LICENSE folder for this sample’s licensing information.
#
# Abstract:
# The code for compiling the custom pytorch extension.

import torch.utils.cpp_extension

compiled_lib = torch.utils.cpp_extension.load(
    name='CustomSoftshrink',
    sources=['CustomSoftshrink.mm'],
    extra_cflags=['-std=c++17'],
)
```

and the relevant cpp example code that links into Metal directly:

```objc
/*
See the LICENSE.txt file for this sample’s licensing information.

Abstract:
The code that registers a PyTorch custom operation.
*/

#include <torch/extension.h>
#include "CustomSoftshrink.h"

#import <Foundation/Foundation.h>
#import <Metal/Metal.h>

// Helper function to retrieve the `MTLBuffer` from a `torch::Tensor`.
static inline id<MTLBuffer> getMTLBufferStorage(const torch::Tensor& tensor) {
    return __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
}

torch::Tensor& dispatchSoftShrinkKernel(const torch::Tensor& input, torch::Tensor& output, float lambda) {
    @autoreleasepool {
        id<MTLDevice> device = MTLCreateSystemDefaultDevice();
        ...
```
Yes, if you want 8-bit Adam on MPS, you probably need a custom kernel (unless torch.compile or Triton supports emitting MPS kernels in the future 👀). Just curious: if you want memory-saving techniques, would something like LoRA/QLoRA be more suitable? Low-bit optim only makes sense when you train the whole big model, and I don't think it's practical to do a full fine-tune of Flux on MPS. (Yes, you mentioned developing on MPS first and then moving to CUDA, so there are still valid points to having a working low-bit optim on MPS, though you could probably do testing on CPU instead.)
we are already quantising the whole model, though, with a LoRA (actually a LyCORIS LoKr in this case). for my 128G unit, i can do a full finetune with 57G usage when ZeRO3 is working for storage offload, and was hoping for more options. tinygrad is doing autogen metal kernels 🤷 it's possible to achieve
@msaroufim I don't know if there are already efforts on this, but a working group for inductor/Triton + MPS might be interesting 👀 (totally out of scope for this issue, though)
on apple mps platforms, torchao training works great until we involve the AdamW8bit optimiser:

and conditionally setting:

to suppress this error and 'fall back to eager mode' does not work in this situation; it merely hides the notice that this parameter could be set.

using `aot_eager` instead of `inductor` is required for MPS, but additionally it kinda limits third-party backends. can we expose the `backend` parameter, perhaps as `compile_backend`?
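The conditional selection described above could be sketched like this (the helper name `choose_compile_backend` is hypothetical, not an existing torchao or PyTorch function):

```python
def choose_compile_backend(device_type: str) -> str:
    """Pick a torch.compile backend by device type.

    inductor cannot generate code for MPS, so fall back to aot_eager there;
    every other device keeps the default inductor path.
    """
    return "aot_eager" if device_type == "mps" else "inductor"
```

An exposed `compile_backend` parameter would let callers make this choice themselves, including plugging in third-party backends, instead of it being baked into the library.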