[Feature] GPU scatter support for complex64 #1214
Comments
Do you know which op within the log mel spectrogram is causing the error? My guess is that it's in padding the spectrogram. Probably the quickest fix is to run the scatter op on the CPU:

```python
import mlx.core as mx

with mx.stream(mx.cpu):
    ...  # run the complex64 scatter op here
```

Or you could re-implement it in terms of … We should definitely support the complex64 scatter op directly too, but hopefully the above unblocks you.
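One possible way to avoid a complex64 scatter without moving to the CPU, assuming the failing op is a plain index assignment, is to scatter the real and imaginary parts as two separate real-valued arrays and recombine them afterwards. This is not necessarily the re-implementation suggested above. The sketch uses NumPy so it runs anywhere; `scatter_complex_via_real` is a hypothetical helper name, and porting it to `mx.array` indexing is untested here.

```python
import numpy as np


def scatter_complex_via_real(dst, indices, updates):
    """Hypothetical helper: scatter complex `updates` into a copy of `dst`
    at `indices` using only real-valued (float32) scatters."""
    # Copy the real and imaginary parts into separate float32 arrays.
    re = dst.real.astype(np.float32)
    im = dst.imag.astype(np.float32)
    # Two real-valued scatters (index assignment) replace one complex scatter.
    re[indices] = updates.real
    im[indices] = updates.imag
    # Recombine into a fresh complex64 array; `dst` itself is left unchanged.
    out = np.empty(dst.shape, dtype=np.complex64)
    out.real = re
    out.imag = im
    return out


dst = np.zeros(5, dtype=np.complex64)
out = scatter_complex_via_real(
    dst, np.array([1, 3]), np.array([1 + 2j, -3j], dtype=np.complex64)
)
print(out)
```

The same split also works for an additive scatter (for example `np.add.at`), since addition acts independently on the real and imaginary parts.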
Thanks for the quick reply and the help! I swapped out my
I tried paring down the issue further to specific lines in … Thanks again for your help! I'm just sharing this result in case others need a quick workaround.
No worries! Glad you have a workaround. It's strange that it's still failing; I can't seem to reproduce it locally. Are you on the latest MLX version? We added full GPU FFT support recently, and that error would make sense on an older version.
My version was 0.15.0. I upgraded to 0.15.1 and tried again but got the same error. I have a 2020 MacBook Air with an M1 and 16 GB of memory. Stack trace:
I'm trying to implement a VQ-VAE with a loss function that uses the log Mel spectrogram. I can train my network using MSE and L1 loss on float32 values, but when I apply the Mel transform to these arrays I get the following error:
Is this on the short-term roadmap? If not, I might try to implement the feature or maybe a short-term workaround.
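For context, a log-Mel spectrogram loss like the one described above can be sketched as follows. This is a minimal NumPy version, not the poster's code: the frame size (`n_fft`), hop length, number of Mel bands, sample rate, HTK-style Mel formula, and the L1 distance are all assumptions, and the function names are hypothetical. It shows where complex values enter the pipeline (the STFT) and that everything after the magnitude step is real-valued.

```python
import numpy as np


def hz_to_mel(f):
    # HTK-style Mel scale (assumed; other variants exist).
    return 2595.0 * np.log10(1.0 + f / 700.0)


def mel_to_hz(m):
    return 700.0 * (10.0 ** (m / 2595.0) - 1.0)


def mel_filterbank(n_mels, n_fft, sample_rate):
    # Triangular filters spaced evenly on the Mel scale, shape (n_mels, n_fft // 2 + 1).
    n_freqs = n_fft // 2 + 1
    fft_freqs = np.linspace(0.0, sample_rate / 2, n_freqs)
    mel_pts = np.linspace(hz_to_mel(0.0), hz_to_mel(sample_rate / 2), n_mels + 2)
    hz_pts = mel_to_hz(mel_pts)
    fb = np.zeros((n_mels, n_freqs), dtype=np.float32)
    for i in range(n_mels):
        left, center, right = hz_pts[i], hz_pts[i + 1], hz_pts[i + 2]
        up = (fft_freqs - left) / (center - left)
        down = (right - fft_freqs) / (right - center)
        fb[i] = np.maximum(0.0, np.minimum(up, down))
    return fb


def log_mel(x, n_fft=512, hop=128, n_mels=40, sample_rate=16000, eps=1e-6):
    # Frame the signal and apply a Hann window.
    window = np.hanning(n_fft).astype(np.float32)
    n_frames = 1 + (len(x) - n_fft) // hop
    frames = np.stack([x[i * hop : i * hop + n_fft] * window for i in range(n_frames)])
    # Complex STFT, shape (n_frames, n_fft // 2 + 1): the only complex step.
    spec = np.fft.rfft(frames, axis=-1)
    # Power spectrum is real-valued from here on.
    power = np.abs(spec) ** 2
    mel = power @ mel_filterbank(n_mels, n_fft, sample_rate).T
    return np.log(mel + eps)


def log_mel_l1_loss(pred, target, **kw):
    # Mean absolute difference between the two log-Mel spectrograms.
    return float(np.mean(np.abs(log_mel(pred, **kw) - log_mel(target, **kw))))


rng = np.random.default_rng(0)
a = rng.standard_normal(16000).astype(np.float32)
b = rng.standard_normal(16000).astype(np.float32)
print(log_mel(a).shape, log_mel_l1_loss(a, b))
```

In MLX the same structure would apply, with the STFT as the step that produces complex64 arrays; any padding or scatter on those arrays is where a complex64 GPU kernel would be needed.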
System information: