6 changes: 3 additions & 3 deletions benchmarks/float8/float8_inference_roofline.py
@@ -39,7 +39,7 @@

import torchao
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -433,13 +433,13 @@ def run(
kernel_preference=KernelPreference.TORCH,
)
elif recipe_name == "mxfp8_cublas":
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.AUTO,
)
elif recipe_name == "mxfp4_cutlass":
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
kernel_preference=KernelPreference.AUTO,
4 changes: 2 additions & 2 deletions test/integration/test_vllm.py
@@ -41,7 +41,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
from vllm import LLM, SamplingParams

from torchao.prototype.mx_formats import MXFPInferenceConfig
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization.quant_api import (
CutlassInt4PackedLayout,
@@ -70,7 +70,7 @@ def get_tests() -> List[TorchAoConfig]:
Int8DynamicActivationInt4WeightConfig(layout=CutlassInt4PackedLayout())
)
]
SM100_TESTS = [TorchAoConfig(MXFPInferenceConfig())]
SM100_TESTS = [TorchAoConfig(MXDynamicActivationMXWeightConfig())]

# Check CUDA availability first
if not torch.cuda.is_available():
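The test above passes the renamed config directly into transformers' TorchAoConfig wrapper, exactly as SM100_TESTS does. A minimal serving sketch along those lines; the checkpoint name, dtype, and device_map arguments are illustrative assumptions, not taken from this PR, and SM100+ hardware is required for the non-emulated path:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

# Wrap the torchao config so from_pretrained() quantizes linear weights
# on load, mirroring SM100_TESTS above.
quant_config = TorchAoConfig(MXDynamicActivationMXWeightConfig())

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder checkpoint
    torch_dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")

inputs = tokenizer("Hello, world", return_tensors="pt").to("cuda")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))
```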
8 changes: 4 additions & 4 deletions test/prototype/mx_formats/test_inference_workflow.py
@@ -13,7 +13,7 @@
from torch.profiler import ProfilerActivity, profile

from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -106,7 +106,7 @@ def test_inference_workflow_mx(
kernel_choice = KernelPreference.EMULATED
else:
kernel_choice = KernelPreference.AUTO
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=elem_dtype,
weight_dtype=elem_dtype,
kernel_preference=kernel_choice,
@@ -247,7 +247,7 @@ class VLLMIntegrationTestCase(TorchAOIntegrationTestCase):
reason="torch.compile requires PyTorch 2.8+",
)
def test_slice_and_copy_similar_to_vllm(self):
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
@@ -260,7 +260,7 @@ def test_slice_and_copy_similar_to_vllm(self):
reason="torch.compile requires PyTorch 2.8+",
)
def test_narrow_similar_to_vllm(self):
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
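The two vLLM-shaped tests above check that a weight quantized by this config survives the slicing and copying vLLM performs when loading sharded checkpoints. A schematic sketch of that pattern with made-up shapes; whether narrow() works on the quantized subclass is precisely what the tests assert, so treat this as illustration only:

```python
import torch

from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

# Quantize a throwaway linear layer the way the tests do.
layer = torch.nn.Linear(
    1024, 1024, bias=False, dtype=torch.bfloat16, device="cuda"
)
quantize_(
    layer,
    MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
    ),
)

# vLLM-style weight loading: take a row shard of the quantized weight,
# into which freshly loaded data would then be copied.
shard = layer.weight.narrow(0, 0, 512)
print(type(layer.weight).__name__, shard.shape)
```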
4 changes: 2 additions & 2 deletions test/prototype/mx_formats/test_mx_serialization.py
@@ -13,7 +13,7 @@
import torch.nn as nn

from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -41,7 +41,7 @@ def test_serialization(recipe_name):
fname = None
with tempfile.NamedTemporaryFile(delete=False, mode="w") as f:
if recipe_name == "mxfp8":
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.EMULATED,
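For reference, a round-trip sketch of what the serialization test exercises. It assumes torchao's config_to_dict/config_from_dict helpers in torchao.core.config; if your torchao version spells these differently, treat this as pseudocode:

```python
import json
import tempfile

import torch
from torchao.core.config import config_from_dict, config_to_dict
from torchao.prototype.mx_formats import MXDynamicActivationMXWeightConfig

config = MXDynamicActivationMXWeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    weight_dtype=torch.float8_e4m3fn,
)

# Write the config to a JSON file, as the test's NamedTemporaryFile does.
with tempfile.NamedTemporaryFile(delete=False, mode="w", suffix=".json") as f:
    json.dump(config_to_dict(config), f)
    fname = f.name

# Read it back and reconstruct the config object.
with open(fname) as f:
    restored = config_from_dict(json.load(f))

assert isinstance(restored, MXDynamicActivationMXWeightConfig)
```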
6 changes: 3 additions & 3 deletions torchao/prototype/mx_formats/README.md
@@ -108,7 +108,7 @@ import torch.nn as nn
from torchao.quantization import quantize_
import torchao.prototype.mx_formats
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -120,7 +120,7 @@ x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)
# mxfp8

m_mxfp8 = copy.deepcopy(m)
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float8_e4m3fn,
weight_dtype=torch.float8_e4m3fn,
kernel_preference=KernelPreference.AUTO,
@@ -132,7 +132,7 @@ y_mxfp8 = m_mxfp8(x)
# mxfp4

m_mxfp4 = copy.deepcopy(m)
config = MXFPInferenceConfig(
config = MXDynamicActivationMXWeightConfig(
activation_dtype=torch.float4_e2m1fn_x2,
weight_dtype=torch.float4_e2m1fn_x2,
kernel_preference=KernelPreference.AUTO,
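Stitched together, the README example after this PR reads roughly as follows. The toy model definition and the omission of kernel_preference (the README passes KernelPreference.AUTO explicitly) are editorial assumptions:

```python
import copy

import torch
import torch.nn as nn

from torchao.quantization import quantize_
from torchao.prototype.mx_formats.inference_workflow import (
    MXDynamicActivationMXWeightConfig,
)

# Toy model standing in for the README's elided definition.
m = nn.Sequential(nn.Linear(32, 128, bias=False, dtype=torch.bfloat16)).cuda()
x = torch.randn(128, 32, device="cuda", dtype=torch.bfloat16)

# mxfp8: e4m3 activations and weights, default block size 32
m_mxfp8 = copy.deepcopy(m)
quantize_(
    m_mxfp8,
    MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        weight_dtype=torch.float8_e4m3fn,
    ),
)
y_mxfp8 = m_mxfp8(x)

# mxfp4: packed fp4 (e2m1) activations and weights
m_mxfp4 = copy.deepcopy(m)
quantize_(
    m_mxfp4,
    MXDynamicActivationMXWeightConfig(
        activation_dtype=torch.float4_e2m1fn_x2,
        weight_dtype=torch.float4_e2m1fn_x2,
    ),
)
y_mxfp4 = m_mxfp4(x)
```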
4 changes: 2 additions & 2 deletions torchao/prototype/mx_formats/__init__.py
@@ -5,7 +5,7 @@

# Note: Prototype and subject to change
from torchao.prototype.mx_formats.inference_workflow import (
MXFPInferenceConfig,
MXDynamicActivationMXWeightConfig,
NVFP4InferenceConfig,
NVFP4MMConfig,
)
@@ -17,7 +17,7 @@
__all__ = [
"MXLinearConfig",
"MXLinearRecipeName",
"MXFPInferenceConfig",
"MXDynamicActivationMXWeightConfig",
"NVFP4InferenceConfig",
"NVFP4MMConfig",
]
29 changes: 3 additions & 26 deletions torchao/prototype/mx_formats/inference_workflow.py
@@ -36,39 +36,16 @@
)


# TODO The naming for these configs is a little weird, rename before moving to public API
# Note: This API is extra prototype and will change in the future
@dataclass
class MXFPInferenceConfig(AOBaseConfig):
class MXDynamicActivationMXWeightConfig(AOBaseConfig):
"""
MX Format Inference Quantization

This module provides support for running inference with float8 quantization using MX formats.
Review comment (Contributor): nit: this comment seems outdated; the config supports both mxfp8 and mxfp4, right?

The quantization flow works as follows:

1. Weight Quantization:
- In _mx_inference_linear_transform(), the module's weight is converted to an MXTensor
- The weight is quantized to the specified dtype (float8_e4m3fn by default)
- This happens when quantize_() is called with an MXFPInferenceConfig

2. Activation Quantization:
- A callable (_input_activation_quant_func_mxfp) is defined that will quantize
activations during inference to the same dtype
- This function is passed to to_linear_activation_quantized() along with the
already-quantized weight

3. Runtime Flow:
- When the quantized module is called, the input goes through the LinearActivationQuantizedTensor
- The input (activation) is quantized just-in-time using the provided function
- The MX quantized activation and MX weight are used together in F.linear

Requirements:
- NVIDIA SM100+ hardware (Blackwell or newer) is required for execution
- PyTorch 2.5+ for proper serialization support

See also:
- LinearActivationQuantizedTensor in torchao.quantization.quant_api
- MXTensor in torchao.prototype.mx_formats.mx_tensor
"""

block_size: int = 32
@@ -95,9 +72,9 @@ def _linear_extra_repr(self):
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={_quantization_type(self.weight)}"


@register_quantize_module_handler(MXFPInferenceConfig)
@register_quantize_module_handler(MXDynamicActivationMXWeightConfig)
def _mx_inference_linear_transform(
module: torch.nn.Module, config: MXFPInferenceConfig
module: torch.nn.Module, config: MXDynamicActivationMXWeightConfig
):
weight = module.weight

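The docstring trimmed by this file's diff described a three-step flow: quantize the weight once at quantize_() time, quantize activations just-in-time via a callable, then dispatch F.linear on the MX activation/weight pair. A schematic sketch of that flow, using the docstring's own "See also" pointers; the exact signatures of MXTensor.to_mx and to_linear_activation_quantized are assumptions:

```python
import torch

from torchao.prototype.mx_formats.mx_tensor import MXTensor
from torchao.quantization.linear_activation_quantized_tensor import (
    to_linear_activation_quantized,
)


def sketch_mx_transform(
    module: torch.nn.Module, block_size: int = 32
) -> torch.nn.Module:
    # Step 1: weight quantization, done once when quantize_() is called.
    mx_weight = MXTensor.to_mx(module.weight, torch.float8_e4m3fn, block_size)

    # Step 2: an activation-quantization callable applied at inference time.
    def quant_input(x: torch.Tensor) -> torch.Tensor:
        return MXTensor.to_mx(x, torch.float8_e4m3fn, block_size)

    # Step 3: wrap weight + callable so each forward pass quantizes the
    # input just-in-time and runs F.linear on the MX pair.
    module.weight = torch.nn.Parameter(
        to_linear_activation_quantized(mx_weight, quant_input),
        requires_grad=False,
    )
    return module
```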