diff --git a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
index b9bd8edef..dd3c2a463 100644
--- a/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
+++ b/src/compressed_tensors/compressors/quantized_compressors/fp4_quantized.py
@@ -81,6 +81,14 @@ def compression_param_info(
         }
         return output
 
+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> torch.Tensor:
+        assert quantization_args.scale_dtype is not None
+        return scale.to(quantization_args.scale_dtype)
+
     def compress_weight(
         self,
         weight: Tensor,
@@ -103,7 +111,9 @@ def compress_weight(
         if device is not None:
             weight_packed = weight_packed.to(device)
         compressed_dict["weight_packed"] = weight_packed
-        compressed_dict["weight_scale"] = scale.to(quantization_args.scale_dtype)
+        compressed_dict["weight_scale"] = self.compress_scale(
+            scale=scale, quantization_args=quantization_args
+        )
         return compressed_dict
 
     def decompress_weight(
@@ -130,7 +140,21 @@ class MXFP4PackedCompressor(NVFP4PackedCompressor):
     Alias for mxfp4 quantized models
     """
 
-    pass
+    def compress_scale(
+        self,
+        scale: Tensor,
+        quantization_args: QuantizationArgs,
+    ) -> torch.Tensor:
+        assert quantization_args.scale_dtype is not None
+        scale_exp = 127 + torch.floor(torch.log2(scale)).to(torch.int32) - 2
+        return scale_exp.to(quantization_args.scale_dtype)
+
+    def decompress_weight(
+        self,
+        compressed_data: Dict[str, Tensor],
+        quantization_args: Optional[QuantizationArgs] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError("MXFP4 decompression is currently not supported")
 
 
 @torch.compile(fullgraph=True, dynamic=True)
diff --git a/src/compressed_tensors/config/format.py b/src/compressed_tensors/config/format.py
index 4f6610de3..5d0c11436 100644
--- a/src/compressed_tensors/config/format.py
+++ b/src/compressed_tensors/config/format.py
@@ -50,6 +50,8 @@ def _get_quant_compression_format(
     is_weight_only = weight_args is not None and input_args is None
 
     if weight_args.num_bits == 4 and weight_args.type == QuantizationType.FLOAT.value:
+        if weight_args.group_size == 32:
+            return CompressionFormat.mxfp4_pack_quantized
         return CompressionFormat.nvfp4_pack_quantized
     if is_weight_only:
         # w4a16 and w8a16
diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py
index d31e133e8..6e4e103c4 100644
--- a/src/compressed_tensors/quantization/quant_scheme.py
+++ b/src/compressed_tensors/quantization/quant_scheme.py
@@ -11,11 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import warnings
 from copy import deepcopy
 from typing import List, Optional
+import torch
 from compressed_tensors.config import CompressionFormat
 from compressed_tensors.quantization.quant_args import (
     FP8_E4M3_DATA,
@@ -192,6 +192,43 @@ def is_preset_scheme(name: str) -> bool:
     ),
 )
 
+MXFP4A16 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    )
+)
+
+MXFP4 = dict(
+    weights=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        symmetric=True,
+        dynamic=False,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    ),
+    input_activations=QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        dynamic=True,
+        symmetric=True,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    ),
+)
+
+
 # 8 bit integer weights and 8 bit activations quantization
 INT8_W8A8 = dict(
     weights=QuantizationArgs(
@@ -343,4 +380,6 @@ def is_preset_scheme(name: str) -> bool:
     "FP8_BLOCK": FP8_BLOCK,
     "NVFP4A16": NVFP4A16,
     "NVFP4": NVFP4,
+    "MXFP4A16": MXFP4A16,
+    "MXFP4": MXFP4,
 }
diff --git a/src/compressed_tensors/quantization/utils/helpers.py b/src/compressed_tensors/quantization/utils/helpers.py
index 45a4ef83c..59c5f245a 100644
--- a/src/compressed_tensors/quantization/utils/helpers.py
+++ b/src/compressed_tensors/quantization/utils/helpers.py
@@ -27,6 +27,11 @@
     round_to_quantized_type_dtype,
 )
 from compressed_tensors.quantization.quant_scheme import QuantizationScheme
+from compressed_tensors.quantization.utils.mxfp4_utils import (
+    generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
+    should_generate_mxfp4_scales,
+)
 from compressed_tensors.utils import deprecated
 from loguru import logger
 from torch import FloatTensor, IntTensor, Tensor
@@ -88,7 +93,10 @@ def calculate_qparams(
     # 1. Generate scale and zero-point
     if quantization_args.symmetric:
         max_val_pos = torch.max(torch.abs(min_vals), torch.abs(max_vals))
-        scales = max_val_pos / (float(bit_range) / 2)
+        if should_generate_mxfp4_scales(args=quantization_args):
+            scales = generate_mxfp4_scales(x=max_val_pos)
+        else:
+            scales = max_val_pos / (float(bit_range) / 2)
         zero_points = torch.zeros(scales.shape, device=device, dtype=min_vals.dtype)
     else:
         if (
@@ -112,7 +120,10 @@ def calculate_qparams(
         scales, dtype=quantization_args.scale_dtype
     )
 
-    # 4. Update any 0s with small values to
+    # 4. Optionally convert mxfp4 exponent scales back to real-valued scales
+    scales = maybe_convert_from_mxfp4_exp(quantization_args, scales)
+
+    # 5. Update any 0s with small values to
     # prevent div by 0
     eps = _get_dtype_eps(
         dtype=quantization_args.scale_dtype
@@ -125,7 +136,7 @@ def calculate_qparams(
         scales,
     )
 
-    # 5. Round the zp to zp_dtype
+    # 6. Round the zp to zp_dtype
     zero_points = round_to_quantized_type_dtype(
         zero_points, dtype=quantization_args.zp_dtype, cast_to_original_dtype=False
     )
diff --git a/src/compressed_tensors/quantization/utils/mxfp4_utils.py b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
index 17821ae72..21dd841fb 100644
--- a/src/compressed_tensors/quantization/utils/mxfp4_utils.py
+++ b/src/compressed_tensors/quantization/utils/mxfp4_utils.py
@@ -13,16 +13,29 @@
 # limitations under the License.
 
 import torch
-from compressed_tensors.quantization.quant_args import BFLOAT16_DATA, FP4_E2M1_DATA
+from compressed_tensors.quantization.quant_args import (
+    BFLOAT16_DATA,
+    FP4_E2M1_DATA,
+    QuantizationArgs,
+)
 
-__all__ = ["convert_mxfp4_exp_scale", "generate_mxfp4_scales", "round_to_power_2"]
+__all__ = [
+    "maybe_convert_from_mxfp4_exp",
+    "generate_mxfp4_scales",
+    "round_to_power_2",
+    "should_generate_mxfp4_scales",
+]
 
 
 # Reference: https://github.com/vllm-project/vllm/blob/main/tests/quantization/reference_mxfp4.py  # noqa: E501
-def convert_mxfp4_exp_scale(
-    scale: torch.Tensor, dtype: torch.dtype = torch.bfloat16
+def should_generate_mxfp4_scales(args: QuantizationArgs) -> bool:
+    return args.num_bits == 4 and args.type == "float" and args.group_size == 32
+
+
+def maybe_convert_from_mxfp4_exp(
+    args: QuantizationArgs, scale: torch.Tensor
 ) -> torch.Tensor:
     """
     Converts mxfp4 scales. Scales are powers of 2, with the
@@ -32,10 +45,12 @@ def convert_mxfp4_exp_scale(
     :param scale: uint8 exponent scale
     :param dtype: dense dtype
     """
-    assert scale.dtype == torch.uint8
-    scale_exp = scale.to(torch.int32) - 127
-    scale = 2.00 ** (scale_exp.to(torch.float))
-    return scale.to(dtype)
+    original_dtype = scale.dtype
+    if should_generate_mxfp4_scales(args):
+        scale_exp = scale.to(torch.int32) - 127
+        scale = 2.00 ** (scale_exp.to(torch.float))
+        return scale.to(original_dtype)
+    return scale
 
 
 def round_to_power_2(x: torch.Tensor) -> torch.Tensor:
@@ -77,21 +92,12 @@ def generate_mxfp4_scales(x: torch.Tensor) -> torch.Tensor:
     Generate mxfp4 scales. The scales require the following steps
     1. Round to the closest power of 2
     2. Convert to exponent
-    3. Store in uint8
 
     Called when calculating qparams using observers.
 
     :param x: tensor to round to closest power of 2
-    :returns uint8 scales as exponents
+    :returns scales as exponents
     """
     # Round to closest power of 2
     scale_power_2 = round_to_power_2(x)
-    # Convert to exponent
-    scale_exp = 127 + torch.floor(torch.log2(scale_power_2)).to(torch.int32) - 2
-    # Clamp and store in uint8, as expected by mxfp4
-    scale_exp = torch.clamp(
-        scale_exp,
-        max=torch.iinfo(torch.uint8).max,
-        min=torch.iinfo(torch.uint8).min,
-    )
-    return scale_exp.to(torch.uint8)
+    return 127 + torch.floor(torch.log2(scale_power_2)) - 2
diff --git a/tests/test_quantization/test_utils/test_mxfp4_utils.py b/tests/test_quantization/test_utils/test_mxfp4_utils.py
index 723228bec..15ac84801 100644
--- a/tests/test_quantization/test_utils/test_mxfp4_utils.py
+++ b/tests/test_quantization/test_utils/test_mxfp4_utils.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import torch
+from compressed_tensors.quantization import round_to_quantized_type_dtype
 from compressed_tensors.quantization.utils import (
-    convert_mxfp4_exp_scale,
     generate_mxfp4_scales,
+    maybe_convert_from_mxfp4_exp,
     round_to_power_2,
 )
 
@@ -61,6 +62,12 @@ def test_round_power_2():
 
 
 def test_mxfp4_scales_e2e():
+    from compressed_tensors.quantization.quant_args import (
+        QuantizationArgs,
+        QuantizationStrategy,
+        QuantizationType,
+    )
+
     mock_weight = torch.normal(mean=0.0002, std=0.0576, size=(2880, 2880))
     x = mock_weight.reshape(*mock_weight.shape[:-1], -1, 32).to(torch.bfloat16)
 
@@ -71,8 +78,19 @@ def test_mxfp4_scales_e2e():
     max_vals = torch.max(max_vals, torch.zeros_like(max_vals))
     block_max = torch.max(torch.abs(min_vals), torch.abs(max_vals))
 
-    scales_generated = generate_mxfp4_scales(block_max)
-    converted_ct = convert_mxfp4_exp_scale(scales_generated)
+    args = QuantizationArgs(
+        num_bits=4,
+        type=QuantizationType.FLOAT,
+        strategy=QuantizationStrategy.GROUP,
+        group_size=32,
+        scale_dtype=torch.uint8,
+        zp_dtype=torch.uint8,
+    )
+
+    scales = generate_mxfp4_scales(block_max)
+    scales = round_to_quantized_type_dtype(scales, dtype=args.scale_dtype)
+
+    converted_ct = maybe_convert_from_mxfp4_exp(args=args, scale=scales)
 
     scales_exp = torch.log2(converted_ct)
     block_max_exp = torch.floor(torch.log2(round_to_power_2(block_max))) - 2
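Note for reviewers: the changes above store MXFP4 scales as biased E8M0-style exponents. `generate_mxfp4_scales` emits `127 + floor(log2(round_to_power_2(x))) - 2` (the `-2` appears to fold in the FP4 E2M1 dynamic range), the compressor casts that to `uint8`, and `maybe_convert_from_mxfp4_exp` recovers a real-valued scale as `2 ** (exp - 127)`. The sketch below mirrors that round trip with plain torch only; the helper names are local stand-ins (not library API), and it rounds block maxima down to a power of two for simplicity, whereas the library's `round_to_power_2` rounds to the nearest power of two.

```python
# Standalone illustration of the MXFP4 scale encode/decode round trip.
# Uses only torch; function names here are hypothetical stand-ins.
import torch


def encode_mxfp4_scale(block_max: torch.Tensor) -> torch.Tensor:
    # Round each block max down to a power of two (simplification of
    # round_to_power_2), then store the biased exponent as uint8:
    # exp = 127 + log2(power_of_two) - 2.
    power_of_two = 2.0 ** torch.floor(torch.log2(block_max))
    exponent = 127 + torch.floor(torch.log2(power_of_two)) - 2
    return exponent.clamp(0, 255).to(torch.uint8)


def decode_mxfp4_scale(exponent: torch.Tensor) -> torch.Tensor:
    # Invert the exponent encoding: scale = 2 ** (exp - 127).
    return 2.0 ** (exponent.to(torch.int32) - 127).to(torch.float32)


if __name__ == "__main__":
    block_max = torch.tensor([0.11, 0.47, 3.2])
    enc = encode_mxfp4_scale(block_max)
    dec = decode_mxfp4_scale(enc)
    # Roughly: tensor([121, 123, 126], dtype=torch.uint8)
    #          tensor([0.0156, 0.0625, 0.5000])
    print(enc, dec)
```

Under these assumptions the decoded scale is the rounded block max divided by four, which is what `test_mxfp4_scales_e2e` checks via `block_max_exp = floor(log2(round_to_power_2(block_max))) - 2`.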