diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 0b66a8f45e..de658b1363 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -2311,7 +2311,7 @@ def _choose_scale_float8( block_size, tensor.shape ) tensor_reshaped = tensor.view(shape_for_reduction) - max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=True) + max_abs = tensor_reshaped.abs().amax(dim=reduction_dims, keepdim=False) if hp_value_lb is not None or hp_value_ub is not None: max_abs = torch.clamp(max_abs, min=hp_value_lb, max=hp_value_ub) scale = max_abs / quant_max