You could pass the constants by closure instead; modifying your example, it might look something like this:

```python
import jax.numpy as jnp
from jax.lax import scan

def main_func(a_const, b_const, c_const, d_const, e_const, variable, array: jnp.ndarray):
    def compute_variables(variable, x):
        # add a dozen lines of logic here; a_const, b_const, ... are captured by closure
        return (x + a_const - b_const * variable, variable)
    return scan(compute_variables, init=variable, xs=array)
```

The general approach is: you can reference …
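To make the difference concrete, here is a minimal sketch that runs the same recurrence both ways with `jax.lax.scan` and checks that they agree. The function names `carry_version` and `closure_version` and the example values are my own, not from the original thread.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan

def carry_version(a, b, init, xs):
    # Constants threaded through the carry: they must be returned unchanged every step.
    def body(carry, x):
        v, a_c, b_c = carry
        new_v = x + a_c - b_c * v
        return (new_v, a_c, b_c), v
    (final_v, _, _), ys = scan(body, (init, a, b), xs)
    return final_v, ys

def closure_version(a, b, init, xs):
    # Constants captured from the enclosing scope: only the real state is carried.
    def body(v, x):
        new_v = x + a - b * v
        return new_v, v
    return scan(body, init, xs)

a, b, v0 = jnp.float32(1.0), jnp.float32(0.5), jnp.float32(0.0)
xs = jnp.arange(5.0)
r1 = jax.jit(carry_version)(a, b, v0, xs)
r2 = jax.jit(closure_version)(a, b, v0, xs)
```

One caveat, as I understand JAX's tracing: when the outer function is wrapped in `jax.jit`, closed-over arguments like `a` and `b` are still tracers, so closure alone does not let XLA specialize on their concrete values. What it buys you is not having to thread them through the carry and unpack them on every step.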
Sup,
I recently had to write a module with 25 `scan`s in a single file... and it was not so simple, because we had 5 variables in the carry, and most of them were constants. Minimal reproducible example here:
I believe ML also has hyperparameters which need to be passed in and are needlessly optimized over by `jit` in every line. If they were marked as constants, XLA might not bother optimizing for them, in the same way that `static_argnums` makes functions faster.
Would `scan` maybe get an optional `consts` parameter through which the constant parameters would be passed? It's a QOL feature which may also speed up `scan` and other loop computations a little bit. Is it big enough to be implemented?
(Same goes for `fori_loop` and `while_loop`.)
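Until something like a `consts` parameter exists, it can be emulated in user code. Below is a minimal sketch assuming a hypothetical helper `scan_with_consts` (not a JAX API) whose body function takes the constants as an extra first argument. Hyperparameters that should be true compile-time constants go through `static_argnums` on `jax.jit`, as the post suggests.

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax.lax import scan


def scan_with_consts(f, init, xs, consts=()):
    """Hypothetical helper, not part of JAX: emulates a `consts=` argument.

    `f` has signature f(consts, carry, x) -> (new_carry, y). The consts are bound
    with functools.partial, so they never have to be threaded through the carry.
    """
    return scan(partial(f, consts), init, xs)


def step(consts, v, x):
    a, b = consts
    return x + a - b * v, v


# a and b are static: jit sees concrete Python floats and compiles them in as
# constants, at the cost of a recompile whenever their values change.
@partial(jax.jit, static_argnums=(1, 2))
def run(xs, a, b):
    return scan_with_consts(step, jnp.float32(0.0), xs, consts=(a, b))


final_v, ys = run(jnp.arange(5.0), 1.0, 0.5)
```

The same closure trick applies to `fori_loop` and `while_loop`, whose body functions can also reference outer variables. Because each new combination of static values triggers a recompile, `static_argnums` suits a handful of fixed hyperparameters rather than values that change often.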