
[Feature Request] Cannot create tensor from raw bytes + dtypes #1296

Open
Narsil opened this issue Jul 29, 2024 · 3 comments
Labels
bug Something isn't working

Comments

Narsil commented Jul 29, 2024

It is not possible at the moment to create or manipulate tensors containing bfloat16 data outside of MLX.

x = mx.array(jnp.ones((2, 2), dtype=jnp.bfloat16))

This is because everything that goes through memoryview and friends seems to fail on bf16.
This would be important for safetensors in order to implement in-memory loading (weights = safetensors.mlx.load(f.read()), for instance) and also lazy loading of individual tensors.
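The in-memory loading pattern described above can be sketched with NumPy, whose frombuffer already accepts raw bytes plus an out-of-band dtype and shape. This is a minimal illustration of the requested API shape, not MLX code; it works for buffer-protocol dtypes like float32, and bfloat16 is precisely the dtype it cannot express:

```python
import numpy as np

# Serialize a tensor to raw bytes, then rebuild it from bytes + shape + dtype.
# This is the "bytes + shape + dtype" pattern the issue asks MLX to support.
original = np.arange(6, dtype=np.float32).reshape(2, 3)
raw = original.tobytes()

# frombuffer reinterprets the bytes without copying; shape is applied separately.
restored = np.frombuffer(raw, dtype=np.float32).reshape(2, 3)
assert np.array_equal(restored, original)
```

A safetensors-style loader is essentially this plus a header mapping tensor names to (dtype, shape, byte-offset) triples into one large buffer.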

mx.load("file.safetensors") works great.

As far as I can tell, memoryview objects lose the actual dtype anyway.

Any API that accepts bytes + shape + dtype would be super generic, I feel (with or without copying, depending on constraints).
This would allow me to correctly implement all supported dtypes for MLX within safetensors itself
(and allow others to do advanced things like loading tensors directly from network sockets).
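In the meantime, bf16 payloads can be smuggled through NumPy by reinterpreting them as uint16 and widening to float32. This is a sketch of a well-known workaround, not an MLX API; the helper name bf16_bytes_to_float32 is made up for illustration, and the bit layout assumes little-endian storage:

```python
import numpy as np

# bfloat16 is float32 with the low 16 mantissa bits dropped, so a bf16
# payload can be widened losslessly to float32 by zero-padding each
# element back to 32 bits.
def bf16_bytes_to_float32(raw: bytes, shape):
    as_u32 = np.frombuffer(raw, dtype=np.uint16).astype(np.uint32)
    return (as_u32 << 16).view(np.float32).reshape(shape)

# Round-trip check: the top half of each float32 is its bf16 encoding.
# These values are exactly representable in bf16, so no rounding occurs.
vals = np.array([1.0, -2.0, 0.5, 4.0], dtype=np.float32)
bf16_raw = (vals.view(np.uint32) >> 16).astype(np.uint16).tobytes()
print(bf16_bytes_to_float32(bf16_raw, (2, 2)))
```

The resulting float32 array could then be handed to mx.array and, if desired, cast back down inside MLX, at the cost of a temporary 2x memory blow-up, which is exactly why a native bytes + shape + dtype constructor would be preferable.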

Thanks a lot for this work.

awni added the bug label Jul 29, 2024

Narsil commented Jul 30, 2024

A quick comment on re-reading this:

The feature request is not about fixing the memoryview path for JAX -> MLX, but really about a way to create tensors from raw bytes. (The memoryview just showcases why it's necessary; we cannot require jax/tf to be present for this to work.)


awni commented Jul 30, 2024

Can you give an example of what you mean / how that would look? As far as I understand the Python buffer protocol does not support bfloat16.

E.g. memoryview(jnp.ones((2, 2), dtype=jnp.bfloat16)) raises an error.
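The limitation can be seen in the buffer-protocol format codes that NumPy reports: the protocol describes element types with struct-style characters, and no character exists for bfloat16. A small illustration (using NumPy so it runs without JAX installed):

```python
import numpy as np

# The buffer protocol tags element types with struct-style format codes:
# 'f' is float32, 'e' is IEEE float16. There is no code for bfloat16,
# so memoryview cannot carry it; a bf16 export has to travel as raw
# bytes plus an out-of-band dtype tag.
mv_f32 = memoryview(np.ones(4, dtype=np.float32))
mv_f16 = memoryview(np.ones(4, dtype=np.float16))
print(mv_f32.format)
print(mv_f16.format)
```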


awni commented Aug 1, 2024

@Narsil I'm still not fully understanding what API you are looking for / what's missing? Right now you can create an array from a Python memoryview object which should be pretty flexible:

import numpy as np
import mlx.core as mx

a = np.array([1, 2, 3])
buffer = memoryview(a)
a_mx = mx.array(buffer)

Does that work for you, or are you looking for something different? Maybe the one thing that's missing is bfloat16 support?
