Conversation

@namgyu-youn (Contributor) commented Nov 22, 2025

Overview:
In _choose_scale_float8, the per-tensor quantization case (len(block_size) == 0) uses tensor.amax(keepdim=True) while _choose_qparams_affine uses torch.amax(..., keepdim=False) for the same purpose.

This PR aligns _choose_scale_float8 with _choose_qparams_affine by using tensor.amax(keepdim=False), which yields a 1-D scale factor.
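A minimal sketch of the keepdim difference on a toy tensor (illustration only, not the torchao code itself):

```python
import torch

x = torch.randn(4, 4)

# keepdim=True keeps a singleton dim per reduced axis -> shape (1, 1)
print(x.abs().amax(keepdim=True).shape)   # torch.Size([1, 1])

# keepdim=False collapses the reduced dims -> 0-D scalar tensor
print(x.abs().amax(keepdim=False).shape)  # torch.Size([])
```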

Related Issue/PR: #3324

Test Plan: test/quantization/test_quant_primitives.py

@pytorch-bot bot commented Nov 22, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3374

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label on Nov 22, 2025
@namgyu-youn (Contributor, Author) commented:

@pytorchbot label "topic: not user facing"

@pytorch-bot bot added the topic: not user facing label on Nov 22, 2025
@namgyu-youn (Contributor, Author) commented:

Local test code:

```python
import torch
from torchao.quantization.quant_primitives import _choose_scale_float8, _choose_scale_float8_old

a = torch.randn(4, 4)

scale = _choose_scale_float8(a, block_size=(4, 1))  # keepdim=False
print(scale)

scale_old = _choose_scale_float8_old(a, block_size=(4, 1))  # keepdim=True
print(scale_old)
```

The results are the same:

```
tensor([[0.0023, 0.0044, 0.0040, 0.0027]])
tensor([[0.0023, 0.0044, 0.0040, 0.0027]])
```

@jerryzh168 (Contributor) commented Nov 22, 2025

This is not enough, I think; there is also

```python
output_shape = [
    input_size // block_size[i] for i, input_size in enumerate(tensor.shape)
]
scale = scale.reshape(output_shape)
```

which will expand the dimension of the scale.
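A toy illustration of that point, using a (4, 4) input and block_size=(4, 1) as in the local test above (simplified from the actual _choose_scale_float8 code, with float8_e4m3fn assumed as the target dtype):

```python
import torch

tensor = torch.randn(4, 4)
block_size = (4, 1)  # per-column blocks

# amax over the blocked dim with keepdim=False gives a 1-D scale of length 4
amax = tensor.abs().amax(dim=0, keepdim=False)
scale = amax / torch.finfo(torch.float8_e4m3fn).max

# The reshape quoted above restores one dim per input dim, so the scale
# ends up 2-D again regardless of the keepdim choice
output_shape = [
    input_size // block_size[i] for i, input_size in enumerate(tensor.shape)
]
print(scale.shape)                        # torch.Size([4])
print(scale.reshape(output_shape).shape)  # torch.Size([1, 4])
```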

I actually tried this locally, and it doesn't seem to work very well. I think we can leave it for now, but one thing that would be useful is to remove the need for _maybe_expand_scale_to_tensor_shape. I don't know why we need it if float8 is already using a scale that matches the dimensionality of the input:

```python
# Reshape scale back to match the expected output shape
# The scale tensor should have the same shape as the input divided by block_size
output_shape = [
    input_size // block_size[i] for i, input_size in enumerate(tensor.shape)
]
scale = scale.reshape(output_shape)
```

Maybe try removing the calls to _maybe_expand_scale_to_tensor_shape in the code and making sure all tests still pass; that would be a good task to work on.
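A rough sketch of why that expansion looks redundant for float8 (an illustration of the idea only, not the actual _maybe_expand_scale_to_tensor_shape implementation):

```python
import torch

tensor = torch.randn(4, 4)
scale = torch.rand(1, 4)  # float8-style scale: one (possibly size-1) dim per input dim

# Explicitly expanding the scale to the full tensor shape...
expanded = scale.expand(tensor.shape)

# ...gives the same result as letting broadcasting handle it, because the
# scale already has the same number of dims as the input
assert torch.equal(tensor / expanded, tensor / scale)
```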

Another thing we can do is simplify the implementation of _slice_scale_for_dimension, since for float8 the scale always matches the dimensionality of the input.
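As a sketch of the kind of simplification this could allow when scale.ndim == tensor.ndim (hypothetical logic, not the current _slice_scale_for_dimension code), a slice of the input along a dimension maps directly to a slice of the scale along the same dimension, with indices divided by the block size:

```python
import torch

tensor = torch.randn(8, 4)
block_size = (1, 4)   # per-row blocks -> scale has shape (8, 1)
scale = torch.rand(8, 1)

# Slice the input along dim 0; assuming the slice boundaries are block-aligned,
# the matching scale slice is obtained by dividing the indices by the block size
dim, start, end = 0, 2, 6
sliced_tensor = tensor.narrow(dim, start, end - start)
sliced_scale = scale.narrow(dim, start // block_size[dim], (end - start) // block_size[dim])
print(sliced_tensor.shape, sliced_scale.shape)  # torch.Size([4, 4]) torch.Size([4, 1])
```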

@namgyu-youn marked this pull request as draft on November 27, 2025 05:49
