Commit 9379e55
Adds
This change introduces a new `jax.lax.scaled_dot` function, which computes a dot product where the inputs can be float8 types. It produces the Composite op that could be lowered to the triton, cuBLAS, cuDNN, or rewritten as the regular dot.
The fallback: If the scaled-dot is not enabled the composite call gets inlined as a sequence of the ops: the float8 inputs and scales are converted to bfloat16, the scales broadcasted and then multiplied with the corresponding operands elementwise, then passed to `jax.lax.dot_general`.
The function includes input validation. Tests are added to cover various scenarios, including error conditions and jit compilation.
PiperOrigin-RevId: 821573195jax.lax.scaled_dot for scaled dot products.1 parent dd62dd7 commit 9379e55
File tree
6 files changed
+1063
-0
lines changed- docs
- jax
- _src/lax
- lax
- tests
6 files changed
+1063
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
152 | 152 | | |
153 | 153 | | |
154 | 154 | | |
| 155 | + | |
155 | 156 | | |
156 | 157 | | |
157 | 158 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
314 | 314 | | |
315 | 315 | | |
316 | 316 | | |
| 317 | + | |
0 commit comments