@lhutton1
Contributor

This commit fixes support for the argmax operation by allowing fp8/bf16 input operands with an int64 output type in the profile compliance data so that it aligns with the spec.
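
For readers less familiar with the compliance tables, here is a minimal sketch of how an `allOf` entry is expected to behave. It is a simplified model and not the actual MLIR implementation: `ComplianceEntry`, `extensionsSatisfied`, `argmaxExtensionTable` and `isArgmaxAllowed` are hypothetical names, operand types are plain strings, and the specification version check is left out.

```cpp
#include <algorithm>
#include <set>
#include <string>
#include <vector>

// Hypothetical, simplified stand-ins for the TOSA validation types.
enum class Extension { int64, fp8e4m3, fp8e5m2, bf16 };
enum class CheckCondition { anyOf, allOf };

struct ComplianceEntry {
  std::vector<Extension> required; // extensions named by the entry
  std::string inputType;           // argmax input element type
  std::string outputType;          // argmax output element type
  CheckCondition condition;        // how `required` must be satisfied
};

// Returns true if the enabled extensions satisfy the entry's requirement.
bool extensionsSatisfied(const ComplianceEntry &entry,
                         const std::set<Extension> &enabled) {
  auto isEnabled = [&](Extension e) { return enabled.count(e) != 0; };
  if (entry.condition == CheckCondition::allOf)
    return std::all_of(entry.required.begin(), entry.required.end(), isEnabled);
  return std::any_of(entry.required.begin(), entry.required.end(), isEnabled);
}

// A cut-down model of the argmax extension entries touched by this change.
std::vector<ComplianceEntry> argmaxExtensionTable() {
  return {
      ComplianceEntry{{Extension::fp8e4m3}, "fp8e4m3", "i32", CheckCondition::anyOf},
      ComplianceEntry{{Extension::fp8e5m2}, "fp8e5m2", "i32", CheckCondition::anyOf},
      ComplianceEntry{{Extension::bf16}, "bf16", "i32", CheckCondition::anyOf},
      // New entries: both the float extension and int64 must be enabled.
      ComplianceEntry{{Extension::fp8e4m3, Extension::int64}, "fp8e4m3", "i64", CheckCondition::allOf},
      ComplianceEntry{{Extension::fp8e5m2, Extension::int64}, "fp8e5m2", "i64", CheckCondition::allOf},
      ComplianceEntry{{Extension::bf16, Extension::int64}, "bf16", "i64", CheckCondition::allOf},
  };
}

// Returns true if some entry matches the type pair and its extensions are enabled.
bool isArgmaxAllowed(const std::vector<ComplianceEntry> &table,
                     const std::string &inputType,
                     const std::string &outputType,
                     const std::set<Extension> &enabled) {
  return std::any_of(table.begin(), table.end(),
                     [&](const ComplianceEntry &entry) {
                       return entry.inputType == inputType &&
                              entry.outputType == outputType &&
                              extensionsSatisfied(entry, enabled);
                     });
}
```

In this model, a bf16 input with an i64 output is accepted only when both the `bf16` and `int64` extensions are enabled, while bf16 with an i32 output still needs `bf16` alone.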

Change-Id: I129383923c6aac639907d0fa6e83ee4bc97c774d
@llvmbot
Member

llvmbot commented Nov 10, 2025

@llvm/pr-subscribers-mlir-tosa

Author: Luke Hutton (lhutton1)

Changes

This commit fixes support for the argmax operation by allowing fp8/bf16 input operands with an int64 output type in the profile compliance data so that it aligns with the spec.


Full diff: https://github.com/llvm/llvm-project/pull/167378.diff

2 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc (+10-1)
  • (modified) mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir (+16)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
index c774d870a8c45..0005402cd1f44 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaComplianceData.h.inc
@@ -476,7 +476,16 @@ extensionComplianceMap = {
         {{fp32T, i64T}, SpecificationVersion::V_1_1_DRAFT}}},
       {{Extension::fp8e4m3}, {{{fp8e4m3T, i32T}, SpecificationVersion::V_1_0}}},
       {{Extension::fp8e5m2}, {{{fp8e5m2T, i32T}, SpecificationVersion::V_1_0}}},
-      {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}}}},
+      {{Extension::bf16}, {{{bf16T, i32T}, SpecificationVersion::V_1_0}}},
+      {{Extension::fp8e4m3, Extension::int64},
+       {{{fp8e4m3T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::fp8e5m2, Extension::int64},
+       {{{fp8e5m2T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf},
+      {{Extension::bf16, Extension::int64},
+       {{{bf16T, i64T}, SpecificationVersion::V_1_1_DRAFT}},
+       allOf}}},
     {"tosa.avg_pool2d",
      {{{Extension::int16},
        {{{i16T, i16T, i16T, i32T, i16T}, SpecificationVersion::V_1_0}}},
diff --git a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
index 9bd7aa8f0783e..acbff73b8b948 100644
--- a/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-validation-version-1p1-valid.mlir
@@ -130,3 +130,19 @@ func.func @test_cast_to_block_scaled_mxint8(%arg0: tensor<4x32xf32>) -> (tensor<
   %0:2 = tosa.cast_to_block_scaled %arg0 {block_size = #tosa.block_size<BLOCK_SIZE_32> : i32, stochastic_round = false} : (tensor<4x32xf32>) -> (tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>)
   return %0#0, %0#1 : tensor<4x32x!tosa.mxint8>, tensor<4x1xf8E8M0FNU>
 }
+
+// -----
+
+// CHECK-LABEL: test_argmax_fp8_i64
+func.func @test_argmax_fp8_i64(%arg0: tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xf8E5M2>) -> tensor<12x16xi64>
+  return %0 : tensor<12x16xi64>
+}
+
+// -----
+
+// CHECK-LABEL: test_argmax_bf16_i64
+func.func @test_argmax_bf16_i64(%arg0: tensor<12x8x16xbf16>) -> tensor<12x16xi64> {
+  %0 = tosa.argmax %arg0 { axis = 1 : i32 } : (tensor<12x8x16xbf16>) -> tensor<12x16xi64>
+  return %0 : tensor<12x16xi64>
+}

@llvmbot
Member

llvmbot commented Nov 10, 2025

@llvm/pr-subscribers-mlir

@IanTaylerLessa-arm
Contributor

LGTM. Thanks!

@lhutton1 lhutton1 merged commit d5388c3 into llvm:main Nov 12, 2025
13 checks passed
@lhutton1 lhutton1 deleted the fix-argmax-int64-support branch November 12, 2025 10:42
git-crd pushed a commit to git-crd/crd-llvm-project that referenced this pull request Nov 13, 2025
…m#167378)

This commit fixes support for the argmax operation by allowing fp8/bf16
input operands with an int64 output type in the profile compliance such
that it aligns with the spec.