
[Feature] GPU scatter support for complex64 #1214

Open
credwood opened this issue Jun 17, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@credwood

credwood commented Jun 17, 2024

I'm trying to implement a VQ-VAE with a loss function that uses the log Mel spectrogram. I can train my network using MSE and L1 loss with float32 values, but when I apply the Mel transform to these arrays I get the following error:

ValueError: [scatter] GPU scatter does not yet support complex64 for the input or updates.

Is this on the short-term roadmap? If not, I might try to implement the feature or maybe a short-term workaround.

System information:

  • OS version: macOS Sonoma 14.5
  • MLX version: 0.15.0
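For context on what the error is complaining about: a scatter is an indexed write into an array. The sketch below uses NumPy only because it runs anywhere; in MLX, the equivalent indexed assignment lowers to the scatter primitive that, at the time, rejected complex64 inputs on the GPU.

```python
import numpy as np

# A "scatter" is an indexed write: updates land at given positions of an array.
out = np.zeros(8, dtype=np.complex64)
idx = np.array([1, 4, 6])
updates = np.array([1 + 2j, 3 - 1j, 5j], dtype=np.complex64)

# This indexed assignment is the scatter. In MLX on the GPU, the
# complex64 dtype of `out`/`updates` is what triggered the ValueError.
out[idx] = updates
```

Complex values enter the spectrogram pipeline through the FFT, so any scatter-backed op applied to the FFT output hits this limitation.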
@barronalex
Collaborator

Do you know which op within the log Mel spectrogram is causing the error? My guess would be that it's the padding of the spectrogram.

Probably the quickest fix is to run the scatter op on the CPU with:

with mx.stream(mx.cpu):
    # run the scatter op here

Or you could re-implement it in terms of mx.pad, which is how the Whisper MLX example does it.

We should definitely support the complex64 scatter op directly too, but hopefully the above unblocks you.
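The mx.pad suggestion can be sketched as follows. This is a NumPy illustration of the idea, not the Whisper code itself: instead of allocating a zero buffer and scatter-writing the signal into the middle (the failing path), pad the array directly, which never emits a scatter.

```python
import numpy as np

x = np.arange(4, dtype=np.complex64)  # stand-in for complex FFT output
pad = 2

# Scatter-style padding: zero buffer plus an indexed write.
# In MLX this indexed assignment is the op that fails on the GPU.
scattered = np.zeros(len(x) + 2 * pad, dtype=np.complex64)
scattered[pad:pad + len(x)] = x

# Pad-style: same result, no scatter involved.
padded = np.pad(x, (pad, pad))

assert np.array_equal(scattered, padded)  # [0, 0, 0, 1, 2, 3, 0, 0]
```

The MLX equivalent would use mx.pad the same way, keeping the whole pipeline on the GPU.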

@credwood
Author


Thanks for the quick reply and the help! I swapped out my stft function for the one in the Whisper example linked above, and that triggered the same problem (I was already using mx.pad in my padding function). I adapted the log_mel_spectrogram function from that same Whisper example code and was able to get my model training by running the following lines in that function on CPU:

with mx.stream(mx.cpu):
    freqs = stft(audio, window, nperseg=n_fft, noverlap=hop_length)
    magnitudes = freqs[:, :-1, :].abs().square()
    filters = mel_filters(n_mels, n_fft)
    mel_spec = magnitudes @ filters.T

I tried paring the issue down further to specific lines in stft, along with the other lines in log_mel_spectrogram, but this is currently the easiest workaround for me.

Thanks again for your help! I'm just sharing this result in case others need a quick workaround.
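For readers unfamiliar with the pipeline, the freqs/magnitudes lines above compute, in essence, the following. This is a NumPy sketch with a minimal stand-in for stft (the framing details and the stft/mel_filters helpers are assumptions based on the Whisper-style code, not the exact implementation):

```python
import numpy as np

n_fft, hop = 8, 4
audio = np.random.default_rng(0).standard_normal(32).astype(np.float32)
window = np.hanning(n_fft).astype(np.float32)

# Frame the signal into overlapping windows (a minimal stand-in for stft).
n_frames = 1 + (len(audio) - n_fft) // hop
frames = np.stack([audio[i * hop:i * hop + n_fft] for i in range(n_frames)])

# The rfft of each windowed frame is complex -- this is where the
# complex64 dtype enters the pipeline, and hence where MLX's GPU
# scatter limitation surfaced.
freqs = np.fft.rfft(frames * window).astype(np.complex64)
magnitudes = np.abs(freqs) ** 2  # back to real-valued power
```

Once the magnitudes are real again, the Mel filterbank matmul and the log can safely run on the GPU; only the complex intermediate needs the CPU stream workaround.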

@barronalex
Collaborator

No worries! Glad you have a workaround.

Strange that it's still failing; I don't seem to be able to reproduce it locally.

Are you on the latest MLX version? We added full GPU FFT support recently and that error would make sense on an older version.

@credwood
Author

credwood commented Jun 17, 2024


My version was 0.15.0; I upgraded to 0.15.1 and tried again but got the same error. I have a 2020 MacBook Air with an M1 and 16 GB of memory.

Stack trace:

  File "/Users/red/projects/audio_app/train.py", line 109, in main
    loss = step(batch)
           ^^^^^^^^^^^
  File "/Users/red/projects/audio_app/train.py", line 91, in step
    loss, grads = loss_and_grad_fn(model, X)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/red/opt/anaconda3/envs/native/lib/python3.11/site-packages/mlx/nn/utils.py", line 34, in wrapped_value_grad_fn
    value, grad = value_grad_fn(model.trainable_parameters(), *args, **kwargs)

ValueError: [scatter] GPU scatter does not yet support complex64 for the input or updates.

@awni awni added the enhancement New feature or request label Jul 15, 2024