Commit 37c8f17

loislo authored and Google-ML-Automation committed
Adds jax.lax.scaled_dot for scaled dot products.
This change introduces a new `jax.lax.scaled_dot` function, which computes a dot product whose inputs can be float8 types. It produces a composite op that can be lowered to Triton, cuBLAS, or cuDNN, or rewritten as a regular dot.

Fallback: if scaled-dot is not enabled, the composite call is inlined as a sequence of ops: the float8 inputs and scales are converted to bfloat16, the scales are broadcast and multiplied elementwise with the corresponding operands, and the results are passed to `jax.lax.dot_general`.

The function validates input shapes and dtypes and supports both 2D and 3D (batched) dot products. Tests are added to cover various scenarios, including error conditions and jit compilation.

PiperOrigin-RevId: 821573195
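
The fallback path described in the commit message can be sketched roughly as follows. This is an illustrative reference only, not the code added by this commit: the helper name `scaled_dot_fallback`, its argument order, and the scale layout (one scale per block of elements along the contracting dimension K) are assumptions made for the example, and it handles only the 2D case.

```python
import jax
import jax.numpy as jnp


def scaled_dot_fallback(lhs, lhs_scale, rhs, rhs_scale):
  """Hypothetical 2D reference: lhs is (M, K), rhs is (K, N).

  Assumed layout (not taken from the commit): lhs_scale is (M, K // block)
  and rhs_scale is (K // block, N), one scale per block along K.
  """
  k = lhs.shape[1]
  # Convert the float8 inputs and scales to bfloat16.
  lhs_bf16 = lhs.astype(jnp.bfloat16)
  rhs_bf16 = rhs.astype(jnp.bfloat16)
  lhs_s = lhs_scale.astype(jnp.bfloat16)
  rhs_s = rhs_scale.astype(jnp.bfloat16)
  # Broadcast each scale over its block along the contracting dimension.
  lhs_s = jnp.repeat(lhs_s, k // lhs_s.shape[1], axis=1)
  rhs_s = jnp.repeat(rhs_s, k // rhs_s.shape[0], axis=0)
  # Apply the scales elementwise, then contract K with a regular dot.
  return jax.lax.dot_general(
      lhs_bf16 * lhs_s,
      rhs_bf16 * rhs_s,
      dimension_numbers=(((1,), (0,)), ((), ())),
  )


# Small example: K = 4, split into two blocks of 2 elements each.
lhs = jnp.ones((2, 4), dtype=jnp.float8_e4m3fn)
rhs = jnp.ones((4, 3), dtype=jnp.float8_e4m3fn)
lhs_scale = jnp.full((2, 2), 2.0, dtype=jnp.float8_e4m3fn)
rhs_scale = jnp.full((2, 3), 0.5, dtype=jnp.float8_e4m3fn)
out = jax.jit(scaled_dot_fallback)(lhs, lhs_scale, rhs, rhs_scale)
```

With these inputs every scaled product is (1 × 2) · (1 × 0.5) = 1, so each output element sums four ones along K. The real function may use different scale dtypes and shapes; the commit's own tests are the authority on that.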
1 parent cd50850 commit 37c8f17

6 files changed: +1069 -0 lines changed

docs/jax.lax.rst

Lines changed: 1 addition & 0 deletions
@@ -152,6 +152,7 @@ Operators
 rng_uniform
 round
 rsqrt
+scaled_dot
 scatter
 scatter_add
 scatter_apply

jax/_src/lax/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -314,3 +314,4 @@
 from jax._src.lax.ann import (
 approx_top_k_p as approx_top_k_p
 )
+from jax._src.lax.scaled_dot import scaled_dot as scaled_dot
