Hello, I'm working on a function, and one of its steps is to randomly select an element of an array that meets a certain condition and return that element's index. If the function is not decorated with `jax.jit`, it works. But if the function is decorated with `jax.jit`, it fails. Is there an alternative?
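A minimal sketch of the kind of code that behaves this way, assuming the selection uses single-argument `jnp.where` (this is an assumed reconstruction, not the poster's code; `pick_eager` is a hypothetical name). Called without `jit`, the number of matches is a concrete value. Under `jit`, the length of the index array would depend on traced values, so JAX refuses to trace it:

```python
import jax
import jax.numpy as jnp

def pick_eager(arr, step):
  # Hypothetical eager version: jnp.where with only a condition returns
  # index arrays whose length depends on how many elements match.
  key = jax.random.PRNGKey(step)
  matched_inds = jnp.where(arr.reshape(-1) > 2)[0]
  idx = jax.random.randint(key, shape=(), minval=0, maxval=matched_inds.shape[0])
  return matched_inds[idx]  # flat index into arr

arr = jnp.array([[1, 2, 3], [4, 5, 6]])

# Outside jit this runs, because the number of matches is known concretely.
print(pick_eager(arr, 0))

# Under jit the output shape of jnp.where would depend on traced values,
# so JAX raises a ConcretizationTypeError while tracing.
try:
  jax.jit(pick_eager)(arr, 0)
except jax.errors.ConcretizationTypeError as err:
  print("fails under jit:", type(err).__name__)
```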
You can use the `size` argument to `jnp.where` to keep the size of the arrays static, and then use `randint` to index into the valid elements:

```python
import jax
import jax.numpy as jnp

def get_eligible_cells(arr, step):
  key = jax.random.PRNGKey(step)
  mask = (arr > 2)
  # With a static `size`, jnp.where always returns arr.size indices:
  # the matching flat indices first, then padding (0 by default).
  matched_inds = jnp.where(jnp.reshape(mask, (-1)), size=arr.size)[0]
  # Draw a position among the first mask.sum() entries (maxval is exclusive).
  # If no element matches, the result is not meaningful; check mask.any().
  idx = jax.random.randint(key, shape=(), minval=0, maxval=mask.sum())
  return matched_inds[idx]  # flat index into arr

array_2d = jnp.array(([1, 2, 3], [4, 5, 6]))
ind = jax.jit(get_eligible_cells)(array_2d, 0)
print(ind)
# 4
```
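Another option, sketched here under stated assumptions rather than as a tested recipe: `jax.random.choice` accepts a probability vector `p`, so giving every non-matching element zero weight samples uniformly among the matches without building a padded index array. The function name `pick_eligible_index` and the `-1` sentinel for "nothing matched" are hypothetical choices for this sketch:

```python
import jax
import jax.numpy as jnp

@jax.jit
def pick_eligible_index(arr, step):
  # Hypothetical helper: sample a flat index uniformly among elements with arr > 2.
  key = jax.random.PRNGKey(step)
  mask = arr.reshape(-1) > 2
  count = mask.sum()
  # Equal weight for matches, zero weight for everything else.
  # jnp.maximum avoids dividing by zero when nothing matches.
  p = mask / jnp.maximum(count, 1)
  # mask.size is a static Python int, so the output shape is fixed under jit.
  idx = jax.random.choice(key, mask.size, p=p)
  # Assumed convention: return -1 when no element satisfies the condition.
  return jnp.where(count > 0, idx, -1)

arr = jnp.array([[1, 2, 3], [4, 5, 6]])
flat = pick_eligible_index(arr, 0)
# Convert the flat index back to (row, col) if needed.
print(flat, jnp.unravel_index(flat, arr.shape))
```

Every array stays the size of the input, and the empty case is handled explicitly instead of silently falling back to a padding value.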