Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ def compression_param_info(
}
return output

def compress_scale(
self,
scale: Tensor,
quantization_args: QuantizationArgs,
) -> Dict[str, torch.Tensor]:
assert quantization_args.scale_dtype is not None
return scale.to(quantization_args.scale_dtype)

def compress_weight(
self,
weight: Tensor,
Expand All @@ -103,7 +111,9 @@ def compress_weight(
if device is not None:
weight_packed = weight_packed.to(device)
compressed_dict["weight_packed"] = weight_packed
compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
compressed_dict["weight_scale"] = self.compress_scale(
scale=scale, quantization_args=quantization_args
)
return compressed_dict

def decompress_weight(
Expand All @@ -130,7 +140,21 @@ class MXFP4PackedCompressor(NVFP4PackedCompressor):
Alias for mxfp4 quantized models
"""

pass
def compress_scale(
self,
scale: Tensor,
quantization_args: QuantizationArgs,
) -> Dict[str, torch.Tensor]:
assert quantization_args.scale_dtype is not None
scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
return scale_exp.to(quantization_args.scale_dtype)

def decompress_weight(
self,
compressed_data: Dict[str, Tensor],
quantization_args: Optional[QuantizationArgs] = None,
) -> torch.Tensor:
raise NotImplementedError("MXFP4 Decompression is currently not supported")


@torch.compile(fullgraph=True, dynamic=True)
Expand Down
2 changes: 2 additions & 0 deletions src/compressed_tensors/config/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ def _get_quant_compression_format(
is_weight_only = weight_args is not None and input_args is None

if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
if weight_args.group_size == 32:
return CompressionFormat.mxfp4_pack_quantized
return CompressionFormat.nvfp4_pack_quantized

if is_weight_only: # w4a16 and w8a16
Expand Down
41 changes: 40 additions & 1 deletion src/compressed_tensors/quantization/quant_scheme.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from copy import deepcopy
from typing import List, Optional

import torch
from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization.quant_args import (
FP8_E4M3_DATA,
Expand Down Expand Up @@ -192,6 +192,43 @@ def is_preset_scheme(name: str) -> bool:
),
)

MXFP4A16 = dict(
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
symmetric=True,
dynamic=False,
group_size=32,
scale_dtype=torch.uint8,
zp_dtype=torch.uint8,
)
)

MXFP4 = dict(
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
symmetric=True,
dynamic=False,
group_size=32,
scale_dtype=torch.uint8,
zp_dtype=torch.uint8,
),
input_activations=QuantizationArgs(
num_bits=4,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
dynamic=True,
symmetric=True,
group_size=32,
scale_dtype=torch.uint8,
zp_dtype=torch.uint8,
),
)


# 8 bit integer weights and 8 bit activations quantization
INT8_W8A8 = dict(
weights=QuantizationArgs(
Expand Down Expand Up @@ -343,4 +380,6 @@ def is_preset_scheme(name: str) -> bool:
"FP8_BLOCK": FP8_BLOCK,
"NVFP4A16": NVFP4A16,
"NVFP4": NVFP4,
"MXFP4A16": MXFP4A16,
"MXFP4": MXFP4,
}
17 changes: 14 additions & 3 deletions src/compressed_tensors/quantization/utils/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@
round_to_quantized_type_dtype,
)
from compressed_tensors.quantization.quant_scheme import QuantizationScheme
from compressed_tensors.quantization.utils.mxfp4_utils import (
generate_mxfp4_scales,
maybe_convert_from_mxfp4_exp,
should_generatre_mxfp4_scales,
)
from compressed_tensors.utils import deprecated
from loguru import logger
from torch import FloatTensor, IntTensor, Tensor
Expand Down Expand Up @@ -88,7 +93,10 @@ def calculate_qparams(
# 1. Generate scale and zero-point
if quantization_args.symmetric:
max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
scales = max_val_pos / (float(bit_range) / 2)
if should_generatre_mxfp4_scales(args=quantization_args):
scales = generate_mxfp4_scales(x=max_val_pos)
else:
scales = max_val_pos / (float(bit_range) / 2)
zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
else:
if (
Expand All @@ -112,7 +120,10 @@ def calculate_qparams(
scales, dtype=quantization_args.scale_dtype
)

# 4. Update any 0s with small values to
# 4. Optionally remove exponent
scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)

# 5. Update any 0s with small values to
# prevent div by 0
eps = _get_dtype_eps(
dtype=quantization_args.scale_dtype
Expand All @@ -125,7 +136,7 @@ def calculate_qparams(
scales,
)

# 5. Round the zp to zp_dtype
# 6. Round the zp to zp_dtype
zero_points = round_to_quantized_type_dtype(
zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
)
Expand Down
44 changes: 25 additions & 19 deletions src/compressed_tensors/quantization/utils/mxfp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,29 @@
# limitations under the License.

import torch
from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
from compressed_tensors.quantization.quant_args import (
BFLOAT16_DATA,
FP4_E2M1_DATA,
QuantizationArgs,
)


__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
__all__ = [
"maybe_convert_from_mxfp4_exp",
"generate_mxfp4_scales",
"round_to_power_2",
"should_generatre_mxfp4_scales",
]

# Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py # noqa: E501


def convert_mxfp4_exp_scale(
scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
def should_generatre_mxfp4_scales(args: QuantizationArgs):
return args.num_bits == 4 and args.type == "float" and args.group_size == 32


def maybe_convert_from_mxfp4_exp(
args: QuantizationArgs, scale: torch.Tensor
) -> torch.Tensor:
"""
Converts mxfp4 scales. Scales are powers of 2, with the
Expand All @@ -32,10 +45,12 @@ def convert_mxfp4_exp_scale(
:param scale: uint8 exponent scale
:param dtype: dense dtype
"""
assert scale.dtype == torch.uint8
scale_exp = scale.to(torch.int32) - 127
scale = 2.00 ** (scale_exp.to(torch.float))
return scale.to(dtype)
original_dtype = scale.dtype
if should_generatre_mxfp4_scales(args):
scale_exp = scale.to(torch.int32) - 127
scale = 2.00 ** (scale_exp.to(torch.float))
return scale.to(original_dtype)
return scale


def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
Expand Down Expand Up @@ -77,21 +92,12 @@ def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
Generate mxfp4 scales. The scales require the following steps
1. Round to the closest power of 2
2. Convert to exponent
3. Store in uint8

Called when calculating qparams using observers.

:param x: tensor to round to closest power of 2
:returns uint8 scales as exponents
:returns scales as exponents
"""
# Round to closest power of 2
scale_power_2 = round_to_power_2(x)
# Convert to exponent
scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
# Clamp and store in uint8, as expected by mxfp4
scale_exp = torch.clamp(
scale_exp,
max=torch.iinfo(torch.uint8).max,
min=torch.iinfo(torch.uint8).min,
)
return scale_exp.to(torch.uint8)
return 127 + torch.floor(torch.log2(scale_power_2)) - 2
24 changes: 21 additions & 3 deletions tests/test_quantization/test_utils/test_mxfp4_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.

import torch
from compressed_tensors.quantization import round_to_quantized_type_dtype
from compressed_tensors.quantization.utils import (
convert_mxfp4_exp_scale,
generate_mxfp4_scales,
maybe_convert_from_mxfp4_exp,
round_to_power_2,
)

Expand Down Expand Up @@ -61,6 +62,12 @@ def test_round_power_2():


def test_mxfp4_scales_e2e():
from compressed_tensors.quantization.quant_args import (
QuantizationArgs,
QuantizationStrategy,
QuantizationType,
)

mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))

x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
Expand All @@ -71,8 +78,19 @@ def test_mxfp4_scales_e2e():
max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))

scales_generated = generate_mxfp4_scales(block_max)
converted_ct = convert_mxfp4_exp_scale(scales_generated)
args = QuantizationArgs(
num_bits=4,
type=QuantizationType.FLOAT,
strategy=QuantizationStrategy.GROUP,
group_size=32,
scale_dtype=torch.uint8,
zp_dtype=torch.uint8,
)

scales = generate_mxfp4_scales(block_max)
scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)

converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)

scales_exp = torch.log2(converted_ct)
block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
Expand Down