Skip to content

Conversation

@copybara-service
Copy link

@copybara-service copybara-service bot commented Oct 28, 2025

Adds jax.lax.scaled_dot for scaled dot products.

This change introduces a new jax.lax.scaled_dot function, which computes a dot product where the inputs can be float8 types. It produces the Composite op that could be lowered to the triton, cuBLAS, cuDNN, or rewritten as the regular dot.

The fallback: If the scaled-dot is not enabled the composite call gets inlined as a sequence of the ops: the float8 inputs and scales are converted to bfloat16, the scales broadcasted and then multiplied with the corresponding operands elementwise, then passed to jax.lax.dot_general.

The function includes input validation. Tests are added to cover various scenarios, including error conditions and jit compilation.

@copybara-service copybara-service bot force-pushed the test_821573195 branch 11 times, most recently from efb8541 to 14b7b51 Compare November 3, 2025 15:34
@copybara-service copybara-service bot force-pushed the test_821573195 branch 12 times, most recently from 982bf06 to 684cb72 Compare November 24, 2025 08:43
@copybara-service copybara-service bot force-pushed the test_821573195 branch 6 times, most recently from 37c8f17 to baa6bbf Compare December 13, 2025 16:02
@copybara-service copybara-service bot force-pushed the test_821573195 branch 12 times, most recently from 9379e55 to c3b4bb5 Compare December 16, 2025 11:16
This change introduces a new `jax.lax.scaled_dot` function, which computes a dot product where the inputs can be float8 types. It produces the Composite op that could be lowered to the triton, cuBLAS, cuDNN, or rewritten as the regular dot.

The fallback: If the scaled-dot is not enabled the composite call gets inlined as a sequence of the ops: the float8 inputs and scales are converted to bfloat16, the scales broadcasted and then multiplied with the corresponding operands elementwise, then passed to `jax.lax.dot_general`.

The function includes input validation. Tests are added to cover various scenarios, including error conditions and jit compilation.

PiperOrigin-RevId: 821573195
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant