Skip to content

Commit e4a25e4

Browse files
Marcello MaggioniGoogle-ML-Automation
authored andcommitted
[JAX] Adding square as a fuser target
PiperOrigin-RevId: 831459678
1 parent a22417c commit e4a25e4

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

jax/_src/pallas/fuser/block_spec.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2225,6 +2225,7 @@ def register_eltwise_rule(prim: core.Primitive):
22252225
register_eltwise_rule(lax.cos_p)
22262226
register_eltwise_rule(lax.sqrt_p)
22272227
register_eltwise_rule(lax.rsqrt_p)
2228+
register_eltwise_rule(lax.square_p)
22282229
register_eltwise_rule(lax.log_p)
22292230
register_eltwise_rule(lax.integer_pow_p)
22302231

0 commit comments

Comments
 (0)