Adds jax.lax.scaled_dot for scaled dot products.
#32918
Open
+1,063
−0
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Adds
jax.lax.scaled_dotfor scaled dot products.This change introduces a new
jax.lax.scaled_dotfunction, which computes a dot product where the inputs can be float8 types. It produces the Composite op that could be lowered to the triton, cuBLAS, cuDNN, or rewritten as the regular dot.The fallback: If the scaled-dot is not enabled the composite call gets inlined as a sequence of the ops: the float8 inputs and scales are converted to bfloat16, the scales broadcasted and then multiplied with the corresponding operands elementwise, then passed to
jax.lax.dot_general.The function includes input validation. Tests are added to cover various scenarios, including error conditions and jit compilation.