
[Feature] searchsorted #1255

Open
weirdykid opened this issue Jul 7, 2024 · 10 comments
Labels
enhancement New feature or request


@weirdykid

Is there an equivalent to np.searchsorted, or a way I could reasonably implement something similar with the existing ops?
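For reference, np.searchsorted(a, v) returns, for each value in v, the index at which it would be inserted into the already-sorted array a to keep it sorted; the side argument picks which end of a run of equal values to use. A small NumPy illustration:

```python
import numpy as np

a = np.array([1, 2, 2, 3, 5])  # must already be sorted ascending
v = np.array([0, 2, 4, 6])

left = np.searchsorted(a, v)                 # first index i with a[i] >= x
right = np.searchsorted(a, v, side="right")  # first index i with a[i] > x

print(left)   # [0 1 4 5]
print(right)  # [0 3 4 5]
```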

@awni
Member

awni commented Jul 7, 2024

We don't have it, but you can certainly do a binary search with existing ops. Here's something that works for 1D arrays; it should be fairly straightforward to support an axis parameter if you need it:

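import mlx.core as mx

def searchsorted(a, b):
    # Branch-free binary search: for each value in b, return its insertion
    # index into the sorted 1D array a (np.searchsorted's side="left").
    axis = 0
    size = a.shape[axis]
    # Each step at least halves upper - lower, so floor(log2(size)) + 1 steps suffice.
    steps = size.bit_length()
    lower = mx.zeros(b.shape, dtype=mx.uint32)
    upper = mx.full(b.shape, vals=size, dtype=mx.uint32)
    for _ in range(steps):
        indices = (lower + upper) // 2
        # indices == size only once lower == upper == size: clamp the gather
        # and mask that case out so lower cannot move past the end.
        lt = mx.logical_and(indices < size, a[mx.minimum(indices, size - 1)] < b)
        lower = mx.where(lt, indices + 1, lower)
        upper = mx.where(lt, upper, indices)
    return lower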

It will also be a lot faster if you mx.compile it, particularly if you call it multiple times with the same shapes.

@awni
Member

awni commented Jul 7, 2024

I'm open to adding a little binary search implementation like that into MLX to support searchsorted. We could use the above as a starting point.
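Whatever the starting point, it helps to check a vectorized binary search against np.searchsorted, including values below the minimum, above the maximum, and runs of duplicates. Below is a NumPy sketch in the same style (fixed step count, where, and integer-array gathers, all of which mlx.core also provides); binary_searchsorted and its side argument are illustrative names, not an existing API.

```python
import numpy as np

def binary_searchsorted(a, v, side="left"):
    # Fixed-step, branch-free binary search over the sorted 1D array a,
    # vectorized over every query in v.
    size = a.shape[0]
    lower = np.zeros(v.shape, dtype=np.int64)
    upper = np.full(v.shape, size, dtype=np.int64)
    # Each step at least halves upper - lower, so floor(log2(size)) + 1 steps suffice.
    for _ in range(size.bit_length()):
        mid = (lower + upper) // 2
        probe = a[np.minimum(mid, size - 1)]  # clamped gather, always in bounds
        # "left": move right while a[mid] < x; "right": while a[mid] <= x.
        go_right = (probe < v) if side == "left" else (probe <= v)
        go_right &= mid < size  # lower == upper == size: nothing left to probe
        lower = np.where(go_right, mid + 1, lower)
        upper = np.where(go_right, upper, mid)
    return lower

a = np.array([1, 2, 2, 3, 5])
v = np.array([0, 2, 4, 6])
print(binary_searchsorted(a, v))           # [0 1 4 5]
print(binary_searchsorted(a, v, "right"))  # [0 3 4 5]
```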

@awni awni added the enhancement New feature or request label Jul 7, 2024
@awni
Member

awni commented Jul 8, 2024

Another option is to do something like the following. It's linear in the size of a, but probably quite a bit faster, especially for small arrays, since it uses far fewer operations:

def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)
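The broadcast comparison builds a len(b) x len(a) boolean matrix and counts, per query, how many elements of a are strictly smaller, which is exactly the side="left" insertion index; switching to <= counts equal elements too and gives side="right". A NumPy sketch (linear_searchsorted and side are illustrative names; the same broadcasting works on MLX arrays):

```python
import numpy as np

def linear_searchsorted(a, b, side="left"):
    # Broadcast to a (len(b), len(a)) boolean matrix, then count per row.
    # Work and memory are O(len(a) * len(b)), but it is only two array ops.
    if side == "left":
        hits = a[None, :] < b[:, None]   # elements strictly below each query
    else:
        hits = a[None, :] <= b[:, None]  # also count elements equal to the query
    return hits.sum(axis=1)

a = np.array([1, 2, 2, 3, 5])
b = np.array([0, 2, 4, 6])
print(linear_searchsorted(a, b))           # [0 1 4 5]
print(linear_searchsorted(a, b, "right"))  # [0 3 4 5]
```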

@weirdykid
Author

Ah okay, I think I can make do with these workarounds for the time being. Thanks!!

@Saanidhyavats
Contributor

@awni can I start working on this issue, since it was last active on July 7?

@awni
Member

awni commented Sep 12, 2024

By all means

@Saanidhyavats
Contributor


Should we go with the linear approach, or with the one NumPy uses?

@awni
Member

awni commented Sep 19, 2024

I would compare

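import mlx.core as mx

def searchsorted(a, b):
    # Branch-free binary search: for each value in b, return its insertion
    # index into the sorted 1D array a (np.searchsorted's side="left").
    axis = 0
    size = a.shape[axis]
    # Each step at least halves upper - lower, so floor(log2(size)) + 1 steps suffice.
    steps = size.bit_length()
    lower = mx.zeros(b.shape, dtype=mx.uint32)
    upper = mx.full(b.shape, vals=size, dtype=mx.uint32)
    for _ in range(steps):
        indices = (lower + upper) // 2
        # indices == size only once lower == upper == size: clamp the gather
        # and mask that case out so lower cannot move past the end.
        lt = mx.logical_and(indices < size, a[mx.minimum(indices, size - 1)] < b)
        lower = mx.where(lt, indices + 1, lower)
        upper = mx.where(lt, upper, indices)
    return lower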

and

def searchsorted(a, b):
    return (a[None, :] < b[:, None]).sum(axis=1)

And see which is faster. Presumably there will be a size at which the first is faster but it will start out slower. We could try to dispatch based on that. Or just use the more scalable version.

@Saanidhyavats
Contributor

Saanidhyavats commented Sep 19, 2024


Assuming A and B are the lengths of arrays a and b:
the time and space complexity of the first approach are O(B log A) and O(B),
and of the second approach O(AB) and O(AB).

From a scalability point of view (comparing time and space complexity), the first approach looks more appropriate, right?
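To put numbers on those bounds, here is a rough count of comparisons for the largest case discussed in this thread (A = 2,097,152 sorted elements, B = 1024 queries). op_counts is an illustrative helper; it counts element comparisons only and ignores constant factors, kernel launches, and memory traffic.

```python
def op_counts(A, B):
    # Binary search: B queries, each narrowed over floor(log2(A)) + 1 steps.
    binary = B * A.bit_length()
    # Linear search: every query compared against every element of a,
    # which also materializes an A x B boolean matrix.
    linear = A * B
    return binary, linear

binary, linear = op_counts(2_097_152, 1024)
print(binary, linear)  # 22528 2147483648
```

The linear version's boolean matrix alone has about 2.1 billion entries (2 GiB if each boolean takes one byte), which is why it stops scaling long before the binary search does.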

@angeloskath
Member

The constant factors of the logarithmic approach are considerably larger, so it is not as simple as that. The following numbers are from my laptop. Also note that mx.compile helps the binary search quite a bit. These all ran on the GPU; the CPU could be faster for some searches.

Sorted size | Search size | Binary search | Binary search compiled | Linear search
------------+-------------+---------------+------------------------+----------------
     1024   |        1    |     0.65 ms   |             0.38 ms    |      0.14 ms
     1024   |        4    |     0.58 ms   |             0.37 ms    |      0.14 ms
     1024   |       16    |     0.56 ms   |             0.36 ms    |      0.14 ms
     1024   |       64    |     0.56 ms   |             0.36 ms    |      0.14 ms
     1024   |      256    |     0.56 ms   |             0.35 ms    |      0.17 ms
     1024   |     1024    |     0.57 ms   |             0.35 ms    |      0.22 ms
    16384   |        1    |     0.76 ms   |             0.41 ms    |      0.21 ms
    16384   |        4    |     0.75 ms   |             0.42 ms    |      0.14 ms
    16384   |       16    |     0.74 ms   |             0.45 ms    |      0.16 ms
    16384   |       64    |     0.75 ms   |             0.43 ms    |      0.22 ms
    16384   |      256    |     0.74 ms   |             0.42 ms    |      0.41 ms
    16384   |     1024    |     0.80 ms   |             0.44 ms    |      1.18 ms
  2097152   |        1    |     1.02 ms   |             0.53 ms    |      0.61 ms
  2097152   |        4    |     0.98 ms   |             0.55 ms    |      0.91 ms
  2097152   |       16    |     1.00 ms   |             0.55 ms    |      2.11 ms
  2097152   |       64    |     1.01 ms   |             0.55 ms    |      7.77 ms
  2097152   |      256    |     1.02 ms   |             0.56 ms    |     34.49 ms
  2097152   |     1024    |     1.03 ms   |             0.57 ms    |    132.58 ms

The TL;DR is that if you want to search in fewer than 16k elements, or if you are only searching for 1-2 elements, there is little point in using the binary search. If, on the other hand, you are searching for a lot of elements in a large sorted array (millions of elements), then you can expect a 100x improvement from binary search :-) .
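A size-based dispatch could encode that rule of thumb directly. The sketch below is a hypothetical helper (pick_search and its threshold parameters are not part of MLX); the defaults are read off the TL;DR and a single laptop GPU benchmark, so they are rough. For instance, in the table the compiled binary search already edges out the linear search at 2,097,152 elements even for one query (0.53 ms vs 0.61 ms), so thresholds should be re-measured on the target hardware.

```python
def pick_search(sorted_size, num_queries,
                max_linear_sorted=16_384, max_linear_queries=2):
    # Small sorted arrays or very few queries: the one-shot linear
    # comparison wins on constant factors.
    if sorted_size < max_linear_sorted or num_queries <= max_linear_queries:
        return "linear"
    # Otherwise the O(B log A) binary search scales far better.
    return "binary"

print(pick_search(1024, 1024))       # linear
print(pick_search(2_097_152, 1024))  # binary
```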

Projects
None yet
Development

No branches or pull requests

4 participants