@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.quantization import (
Float8DynamicActivationFloat8SemiSparseWeightConfig,
Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
quantize_,
)
from torchao.quantization.quantize_.workflows.float8.float8_packing_format import (
Float8PackingFormat,
)
from torchao.sparsity import apply_fake_sparsity
from torchao.utils import is_sm_at_least_90

logging.basicConfig(
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestFloat8SemiSparseTensor(common_utils.TestCase):
@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
@common_utils.parametrize("compile", [True, False])
@common_utils.parametrize(
"packing_format",
[Float8PackingFormat.SPARSE_CUTLASS, Float8PackingFormat.SPARSE_CUSPARSELT],
)
def test_fp8_cutlass_sparse(self, compile, packing_format):
with torch.inference_mode():
input = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
model = (
nn.Sequential(
nn.Linear(256, 1024),
nn.Linear(1024, 256),
)
.bfloat16()
.cuda()
.eval()
)

apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)

# Quantized
quantize_(model_copy, Float8DynamicActivationFloat8WeightConfig())
dense_result = model_copy(input)

# Sparse + quantized
quantize_(
model,
Float8DynamicActivationFloat8SemiSparseWeightConfig(
float8_packing_format=packing_format
),
)
if compile:
model = torch.compile(model)
sparse_result = model(input)

torch.testing.assert_close(
dense_result.to(torch.float),
sparse_result.to(torch.float),
atol=3e-1,
rtol=3e-1,
)

@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_fp8_cutlass_sparse_lowering_op_clone(self):
with torch.inference_mode():
model = nn.Linear(256, 1024).half().cuda().eval()
apply_fake_sparsity(model)
quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

original = model.weight.dequantize()
cloned = model.weight.clone().dequantize()

for o, c in zip(original, cloned):
torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)

@unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_fp8_cutlass_sparse_lowering_op_to(self):
# Need to run with inference mode to avoid dispatching to `aten.to_copy`
with torch.inference_mode():
model = nn.Linear(256, 1024).half().cuda().eval()
apply_fake_sparsity(model)
model_copy = copy.deepcopy(model)
expected = model_copy.weight.to(dtype=torch.float)

quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

original = torch.ops.aten.to.dtype_layout(
model.weight,
dtype=torch.float,
layout=torch.strided,
)
torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)


common_utils.instantiate_parametrized_tests(TestFloat8SemiSparseTensor)

if __name__ == "__main__":
unittest.main()
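
For context on the `apply_fake_sparsity` calls above: 2:4 (semi-structured) sparsity keeps at most two non-zero values in every contiguous group of four weights along the reduction dimension, which is the pattern the sparse CUTLASS and cuSPARSELt kernels consume. A minimal sketch of that invariant, independent of this PR (the helper name `is_24_sparse` is purely illustrative):

```python
import torch


def is_24_sparse(w: torch.Tensor) -> bool:
    # True if every contiguous group of 4 elements along the last dim has
    # at most 2 non-zeros, i.e. the 2:4 pattern the sparse kernels expect.
    groups = w.reshape(*w.shape[:-1], -1, 4)
    return bool(((groups != 0).sum(dim=-1) <= 2).all())


w = torch.tensor([[1.0, 0.0, 0.0, 2.0, 0.0, 3.0, 4.0, 0.0]])
assert is_24_sparse(w)
```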
71 changes: 54 additions & 17 deletions torchao/quantization/quant_api.py
@@ -57,7 +57,7 @@
make_packed_linear_int8_dynamic_activation_intx_weight_tensor,
)
from torchao.dtypes.utils import Layout
from torchao.float8.config import e4m3_dtype, e5m2_dtype
from torchao.float8.config import e4m3_dtype
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.inference import (
Float8MMConfig,
@@ -76,6 +76,7 @@
from torchao.quantization.quantize_.workflows import (
Float8OpaqueTensor,
Float8PackingFormat,
Float8SemiSparseTensor,
Float8Tensor,
Int4ChooseQParamsAlgorithm,
Int4MarlinSparseTensor,
@@ -2000,44 +2001,80 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
`weight_dtype`: data type for quantized weight tensor.
"""

layout: Layout = CutlassSemiSparseLayout()
activation_dtype: torch.dtype = e5m2_dtype
activation_dtype: torch.dtype = e4m3_dtype
weight_dtype: torch.dtype = e4m3_dtype
granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = PerRow()
activation_value_lb: Optional[float] = None
activation_value_ub: Optional[float] = None
float8_packing_format: Float8PackingFormat = Float8PackingFormat.SPARSE_CUTLASS
version: int = 2

def __post_init__(self):
torch._C._log_api_usage_once(
"torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
)

assert self.float8_packing_format in {
Float8PackingFormat.SPARSE_CUTLASS,
Float8PackingFormat.SPARSE_CUSPARSELT,
}, f"{self.float8_packing_format} is not supported"


@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
module: torch.nn.Module, config: Float8DynamicActivationFloat8SemiSparseWeightConfig
module: torch.nn.Module,
config: Float8DynamicActivationFloat8SemiSparseWeightConfig,
*,
parameter_name: str = "weight",
):
assert is_sm_at_least_90(), "Float8 quantization is only supported on compute capability >= 9.0"

if isinstance(module, Float8Linear):
module = _unwrap_float8_linear(module)

weight = module.weight
unquantized_param = getattr(module, parameter_name)
weight_dtype = config.weight_dtype
activation_dtype = config.activation_dtype
layout = config.layout

if not isinstance(layout, CutlassSemiSparseLayout):
version = config.version
activation_granularity, weight_granularity = _normalize_granularity(
config.granularity
)
activation_value_lb = config.activation_value_lb
activation_value_ub = config.activation_value_ub
act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
activation_dtype,
activation_granularity,
hp_value_lb=activation_value_lb,
hp_value_ub=activation_value_ub,
)
packing_format = config.float8_packing_format

if version == 2:
Review comment (Contributor Author): I think I hardcoded this to change to v2 by default, but we should probably split the BC-breaking change into its own commit.

quantized_param = Float8SemiSparseTensor.from_hp(
unquantized_param,
float8_dtype=weight_dtype,
granularity=weight_granularity,
packing_format=packing_format,
act_quant_kwargs=act_quant_kwargs,
)
else:
raise NotImplementedError(
f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
f"Only version 2 of Float8DynamicActivationFloat8SemiSparseWeightConfig is supported. Received {version}."
)

weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
weight = to_linear_activation_quantized(
weight,
_float8_cutlass_quant,
quant_kwargs={"target_dtype": activation_dtype},
setattr(
module,
parameter_name,
torch.nn.Parameter(quantized_param, requires_grad=False),
)
module.extra_repr = types.MethodType(
partial(
_module_extra_repr,
original_extra_repr=module.extra_repr,
parameter_name=parameter_name,
),
module,
)

module.weight = torch.nn.Parameter(weight, requires_grad=False)
module.extra_repr = types.MethodType(_linear_extra_repr, module)
return module
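
Taken together, the new config surface can be exercised end to end much like the test at the top of this PR. A minimal usage sketch (not part of the diff; it assumes the defaults introduced here — `version=2`, `PerRow()` granularity — reuses the import paths from the test file, and needs an SM 9.0+ GPU):

```python
import torch
from torch import nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quantize_.workflows.float8.float8_packing_format import (
    Float8PackingFormat,
)
from torchao.sparsity import apply_fake_sparsity

model = nn.Linear(256, 1024).bfloat16().cuda().eval()
apply_fake_sparsity(model)  # prune the weight into a 2:4 pattern first

quantize_(
    model,
    Float8DynamicActivationFloat8SemiSparseWeightConfig(
        float8_packing_format=Float8PackingFormat.SPARSE_CUSPARSELT,
    ),
)
# model.weight is now a Float8SemiSparseTensor; fp8 activation quantization
# is applied dynamically at inference time via act_quant_kwargs.
```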


4 changes: 4 additions & 0 deletions torchao/quantization/quantize_/common/kernel_preference.py
@@ -30,5 +30,9 @@ class KernelPreference(str, Enum):
"""
FBGEMM = "fbgemm"

"""Use torchao cutlass kernel for fp8 + 2:4 sparse mm, requires building torchao with CUDA
"""
SPARSE_CUTLASS = "sparse_cutlass"
Review comment (Contributor): My understanding is that this is a new packing format; why is it a new kernel preference?

Review comment (Contributor Author): sparse_cutlass vs. sparse_cusparselt/hipsparselt is something we will need for the AMD support coming up next half, which sounds like a kernel preference to me (deciding which op to use). But if this is more of a general thing and packing_format is the more specific way to decide op dispatch, I am fine with using that as well.

Review comment (Contributor): @jcaip, it would be good to specify whether both the data format and the kernels differ, or whether the data format is the same and only the kernels differ.



torch.serialization.add_safe_globals([KernelPreference])
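
To make the review question above concrete: with one packing format per backend, the stored layout already determines which sparse GEMM can consume it, so dispatch could key off `Float8PackingFormat` without a new `KernelPreference` entry. A hypothetical sketch of that idea (the backend strings and the helper `_pick_sparse_backend` are illustrative, not torchao APIs):

```python
from torchao.quantization.quantize_.workflows.float8.float8_packing_format import (
    Float8PackingFormat,
)


def _pick_sparse_backend(packing_format: Float8PackingFormat) -> str:
    # Each packing format presumably corresponds to one storage layout and one
    # matmul implementation (CUTLASS sparse metadata vs. the cuSPARSELt
    # compressed blob), so the format alone is enough to select the kernel.
    if packing_format == Float8PackingFormat.SPARSE_CUTLASS:
        return "cutlass"
    if packing_format == Float8PackingFormat.SPARSE_CUSPARSELT:
        return "cusparselt"
    raise ValueError(f"Unsupported packing format: {packing_format}")


assert _pick_sparse_backend(Float8PackingFormat.SPARSE_CUTLASS) == "cutlass"
```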
3 changes: 3 additions & 0 deletions torchao/quantization/quantize_/workflows/__init__.py
@@ -6,6 +6,9 @@
Float8Tensor,
QuantizeTensorToFloat8Kwargs,
)
from .float8.float8_semi_sparse_tensor import (
Float8SemiSparseTensor,
)
from .int4.int4_choose_qparams_algorithm import Int4ChooseQParamsAlgorithm
from .int4.int4_marlin_sparse_tensor import (
Int4MarlinSparseTensor,
torchao/quantization/quantize_/workflows/float8/float8_packing_format.py
@@ -29,3 +29,8 @@ class Float8PackingFormat(str, Enum):
needed for the rest of the system to understand the specific format that's adopted.
"""
OPAQUE = "opaque"
"""
Sparse packing formats for 2:4 sparsity + FP8 quantization
"""
SPARSE_CUTLASS = "sparse_cutlass"
Review comment (Contributor): The intent is for the sparse tensor to use OPAQUE; you can keep these formats internal to your workflow.

SPARSE_CUSPARSELT = "sparse_cusparselt"
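
One possible shape for the reviewer's suggestion above: keep the public enum at `OPAQUE` and move the CUTLASS/cuSPARSELt split into a private detail of the sparse workflow. A purely hypothetical sketch (`_SparseKernelBackend` is an invented name, not part of this PR or torchao):

```python
import enum


class _SparseKernelBackend(str, enum.Enum):
    # Private to the float8 sparse workflow; the public Float8PackingFormat
    # would only ever expose OPAQUE for this tensor subclass.
    CUTLASS = "cutlass"
    CUSPARSELT = "cusparselt"


# e.g. Float8SemiSparseTensor.from_hp(...) could record a _SparseKernelBackend
# on the tensor and dispatch on it, so SPARSE_CUTLASS / SPARSE_CUSPARSELT never
# need to appear in the public packing-format enum.
```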