Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support MLX in WhisperAX #200

Open
wants to merge 17 commits into
base: mlx-support
Choose a base branch
from
Open

Support MLX in WhisperAX #200

wants to merge 17 commits into from

Conversation

ZachNagengast
Copy link
Contributor

@ZachNagengast ZachNagengast commented Sep 7, 2024

Add support for the WhisperAX example app, as well as various refactors and cleanup.

image

Note: this is using an unreleased version of MLX, pending merge of ml-explore/mlx-swift#130

There are still some memory issues to address, will look more into this soon.

CoreML:
Screenshot 2024-09-06 at 11 28 23 PM
MLX:
image

Important note, first time running this will likely throw an error about PrepareMetalShaders, which requires Trust & Enable on this popup when selecting the error.
image

ZachNagengast and others added 15 commits July 15, 2024 08:24
…lity (#192)

* Make additional initializers, functions, members public, for WKPro

* Allows use of default internal functions & member accesses which have
  increased protections when imported

* Initializers were Xcode generated: right click class name -> refactor
  -> generate memberwise initializers
   * memberwise initializer defaults to internal, mark as public.

* Formatting

---------

Co-authored-by: ZachNagengast <[email protected]>
… models (#193)

* Add initial mlpackage loading (if .mlmodelc not present)

-- Does not modify model loading in OS WK.  This is a hook to modify
load path URLs.

* Always load audio encoder last

* Adjust timings to account for decoder<>encoder order swap

* Add helper for mlpackage detection

---------

Co-authored-by: ZachNagengast <[email protected]>
* Fix start time logic for file loading and resampling

* Add test file
As far as I can tell, these stored properties are not meant to be changed. Therefore, change them to be immutable. This change also makes these static properties concurrency-safe.
* Add VoiceActivityDetector base class

Add base class to allow different VAD implementations

* fix spaces
Copy link
Contributor

@jkrukowski jkrukowski left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, added some comments, let me know what you think

Comment on lines 35 to 36
=======
>>>>>>> main
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

git conflict markers

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good catch 👍

Package.resolved Outdated
Comment on lines 6 to 8
"location" : "https://github.com/ml-explore/mlx-swift",
"location" : "https://github.com/davidkoski/mlx-swift.git",
"state" : {
"revision" : "597aaa5f465b4b9a17c8646b751053f84e37925b",
"version" : "0.16.0"
"revision" : "3314bc684f0ccab1793be54acddaea16c0501d3c"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, why this change?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ok, now I can see why, nvm

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been merged into mlx-swift and tagged 0.16.2

@@ -171,6 +137,74 @@ public struct ModelComputeOptions {
}
}

public struct ModelInfo: Identifiable, Hashable {
public let id = UUID()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

curious, why this property public let id = UUID() is needed? can it be uniquely identified by name?

Copy link
Contributor Author

@ZachNagengast ZachNagengast Sep 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to make it identifiable, although I was using this for the picker at one point and may be vestigial, will check.

Comment on lines 144 to 152
let keyCache = try? MLX.stacked(keyCacheResult).asMLMultiArray()
let valueCache = try? MLX.stacked(valueCacheResult).asMLMultiArray()
let decodingCache = DecodingCache(
keyCache: keyCache,
valueCache: valueCache,
alignmentWeights: nil
)

let logits = try? result.logits?.asMLMultiArray()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like you've changed it to try ? in couple of places, is it intended?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, we want this to pass the throw up through the call stack. Reviewing build warnings as well.

@iandundas
Copy link
Contributor

Hey all, just a heads-up, the example project gets this SPM error:

CleanShot 2024-09-18 at 10 34 34@2x

@davidkoski
Copy link

Hey all, just a heads-up, the example project gets this SPM error:

CleanShot 2024-09-18 at 10 34 34@2x

that fork was merged and is now the 0.16.2 tag on https://github.com/ml-explore/mlx-swift

@ZachNagengast
Copy link
Contributor Author

Awesome, thanks for the update @davidkoski!

@latenitefilms
Copy link

@ZachNagengast - This looks awesome! Apologies - rookie question... It looks like currently the MLX repo only has the base and tiny models. How hard is it for mere mortals to "build" some of the larger models for testing this out with larger models?

@latenitefilms
Copy link

Or... can you use existing models with MLX?

@ZachNagengast
Copy link
Contributor Author

Sorry missed your original message! We will fill in the mlx repo with the remaining models as part of this release, we just made these copies for consistency with our swift package. Any MLX whisper model currently existing with the same naming scheme will work in theory 👍 @jkrukowski may be able to confirm or deny.

@latenitefilms
Copy link

Legend, thanks @ZachNagengast! So basically, we do need new models for MLX, we can use the existing WhisperKit models? They need to be optimised or something?

@ZachNagengast
Copy link
Contributor Author

Yep the existing WhisperKit models are optimized for CoreML, the ones in this repo we will fill out with the equivalent weights that are compatible with this MLX PR

@latenitefilms
Copy link

Sorry for all the rookie questions, but when you say "optimised for CoreML" - does this mean they ONLY work on CoreML, or can you use these CoreML models in MLX and they're just not as fast/accurate?

Apologies - this whole Whisper world is very new to me, so I very much appreciate all your wisdom and support!

@maxlund
Copy link

maxlund commented Oct 4, 2024

Yep the existing WhisperKit models are optimized for CoreML, the ones in this repo we will fill out with the equivalent weights that are compatible with this MLX PR

@ZachNagengast Is there any model conversion script we can run, or any other source we can use, in order to create/obtain more MLX compatible model versions?

@ZachNagengast
Copy link
Contributor Author

@latenitefilms Yes the .mlmodelc models only work with CoreML at the moment.
@maxlund There is a script made by @jkrukowski to do the conversion here #169, we'll integrate this into https://github.com/argmaxinc/whisperkittools in the future.

@latenitefilms
Copy link

Legend, thanks so much @ZachNagengast! Do you have a rough/ballpark ETA of when you're hoping to finish and merge in MLX support? No rush or pressure - just wondering if it's worth trying to convert our own models or not.

Let me know if there's anything I can do to help with MLX testing/release! Would love to see this in action ASAP!

Thanks for EVERYTHING you do! Appreciate it!

@ZachNagengast
Copy link
Contributor Author

There are just a few optimizations to fix up to make it ready for release, specifically memory usage. Current issues are:

  • MLX does not require a prewarm stage, so it should skip that. Currently its loading the model twice without freeing up the memory. Can also be solved by setting a cache limit or clearing the cache after load
  • KV cache should use this instead of mlmultiarrays
  • Sampling can be compiled for some easy speedups
  • Attention should use SDPA instead of current logic

These are paraphrased from @davidkoski and @awni

Will be revisiting this after the upcoming release but feel free to test with this current branch if you see any other potential speedup besides these, all the interfaces should be the same in its final form, just faster and more memory efficient with these changes.

@latenitefilms
Copy link

Amazing! Thanks so much! Will test out and let you know if I break anything.

@anishjain123
Copy link

Any update on this? Is it usable? @latenitefilms @ZachNagengast ? Can someone please point me in the direction of how i can get this set up?

@ZachNagengast
Copy link
Contributor Author

Hi @anishjain123 these are still pending issues #200 (comment), but this branch is technically usable. We'd like to resolve the perf and memory issues before merging, which is still a high priority for us! Working on a refactor right now to allow various different model input and output types, including MLXArray, which should help with the issues converting between MLMultiArray and MLXArray.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants