
Add AMAX, AVG, NORM1, NORM2, MUL, MUL_NO_ZEROS reduction modes#325

Open
rsuderman wants to merge 3 commits into iree-org:main from rsuderman:reduction_rest

Conversation

@rsuderman (Contributor)

Enable the remaining cuDNN reduction modes in ReductionAttr and add the corresponding MLIR schemas to the asm emitter:

  • NORM1 lowers to abs + sum.dim_IntList.
  • AMAX lowers to abs + amax.
  • AVG lowers to mean.dim (float dtypes only — torch.aten.mean.dim is not defined on integer tensors, so the sample skips int32 for AVG).
  • NORM2 lowers to mul + sum.dim_IntList + sqrt.
  • MUL lowers directly to torch.prims.prod.
  • MUL_NO_ZEROS uses aten.ne.Scalar to build an i1 mask, then aten.where.ScalarOther to substitute 1 for zero entries before feeding the result to torch.prims.prod, so zero inputs are excluded from the product.

Extend samples/reduction/reduction_ops.cpp to exercise every new mode. Input data is built by a per-mode generateReductionInputData helper so MUL/MUL_NO_ZEROS get a non-trivial pattern (mostly 1s with a 2 and a 3, plus injected zeros for MUL_NO_ZEROS) that stays in range for fp16/int32, and the expected value is computed by the existing reference reduction loop rather than hardcoded.

Add lit tests for each new mode under tests/lit/ and register them in tests/CMakeLists.txt.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: Rob Suderman <rob.suderman@gmail.com>

# Conflicts:
#	include/fusilli/support/asm_emitter.h
#	samples/reduction/reduction_ops.cpp
@sjain-stanford (Member) left a comment


Might need a rebase / CI seems unclean.

#define FUSILLI_REDUCTION_MODES(OP) \
OP(NOT_SET) \
OP(SUM) \
/* OP(ADD) */ \
(Member)

Is ADD dropped because it's equivalent to SUM or some other reason?

Comment on lines +2002 to +2010
permuteX, // {0}
dimListOss.str(), // {1}
suffix, // {2}
getResultNamesAsm(), // {3}
getOperandNamesAsm(), // {4}
getOperandTypesAsm(), // {5}
getResultTypesAsm(), // {6}
permuteY, // {7}
boolType // {8}
(Member)

hyper nit: Use /* {0} */ style comments in format string placeholders for consistency with the rest.

{0}
{1}
%dtype_{2} = torch.constant.none
{3}_{2}_perm = torch.prims.prod {4}, %reduction_dims_{2}, %dtype_{2} : {5}, !torch.list<int>, !torch.none -> {6}
(Member)

Claude tells me this is not registered in the torch dialect. Is that true?

Do we need to use torch.aten.prod.dim_int (with chaining for multiple reduction dims)?

2 participants