We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent a22417c commit e4a25e4Copy full SHA for e4a25e4
jax/_src/pallas/fuser/block_spec.py
@@ -2225,6 +2225,7 @@ def register_eltwise_rule(prim: core.Primitive):
2225
register_eltwise_rule(lax.cos_p)
2226
register_eltwise_rule(lax.sqrt_p)
2227
register_eltwise_rule(lax.rsqrt_p)
2228
+register_eltwise_rule(lax.square_p)
2229
register_eltwise_rule(lax.log_p)
2230
register_eltwise_rule(lax.integer_pow_p)
2231
0 commit comments