@@ -0,0 +1,118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
    quantize_,
)
from torchao.quantization.quantize_.workflows import (
    Float8SemiSparseTensorPackingFormat,
)
from torchao.sparsity import apply_fake_sparsity
from torchao.utils import is_sm_at_least_90

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestFloat8SemiSparseTensor(common_utils.TestCase):
    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize(
        "packing_format",
        [
            Float8SemiSparseTensorPackingFormat.SPARSE_CUTLASS,
            Float8SemiSparseTensorPackingFormat.SPARSE_CUSPARSELT,
        ],
    )
    def test_fp8_cutlass_sparse(self, compile, packing_format):
        with torch.inference_mode():
            input = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
            model = (
                nn.Sequential(
                    nn.Linear(256, 1024),
                    nn.Linear(1024, 256),
                )
                .bfloat16()
                .cuda()
                .eval()
            )

            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)

            # Quantized
            quantize_(model_copy, Float8DynamicActivationFloat8WeightConfig())
            dense_result = model_copy(input)

            # Sparse + quantized
            quantize_(
                model,
                Float8DynamicActivationFloat8SemiSparseWeightConfig(
                    float8_packing_format=packing_format
                ),
            )
            if compile:
                model = torch.compile(model)
            sparse_result = model(input)

            torch.testing.assert_close(
                dense_result.to(torch.float),
                sparse_result.to(torch.float),
                atol=3e-1,
                rtol=3e-1,
            )

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_clone(self):
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = model.weight.dequantize()
            cloned = model.weight.clone().dequantize()

            for o, c in zip(original, cloned):
                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_to(self):
        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)
            expected = model_copy.weight.to(dtype=torch.float)

            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = torch.ops.aten.to.dtype_layout(
                model.weight,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)


common_utils.instantiate_parametrized_tests(TestFloat8SemiSparseTensor)

if __name__ == "__main__":
    unittest.main()
torchao/quantization/quant_api.py (56 additions & 16 deletions)
@@ -57,7 +57,7 @@
    make_packed_linear_int8_dynamic_activation_intx_weight_tensor,
)
from torchao.dtypes.utils import Layout
from torchao.float8.config import e4m3_dtype, e5m2_dtype
from torchao.float8.config import e4m3_dtype
from torchao.float8.float8_linear import Float8Linear
from torchao.float8.inference import (
    Float8MMConfig,
@@ -74,6 +74,8 @@
    KernelPreference,
)
from torchao.quantization.quantize_.workflows import (
    Float8SemiSparseTensor,
    Float8SemiSparseTensorPackingFormat,
    Float8Tensor,
    Int4ChooseQParamsAlgorithm,
    Int4MarlinSparseTensor,
@@ -1971,44 +1973,82 @@ class Float8DynamicActivationFloat8SemiSparseWeightConfig(AOBaseConfig):
        `weight_dtype`: data type for quantized weight tensor.
    """

    layout: Layout = CutlassSemiSparseLayout()
    activation_dtype: torch.dtype = e5m2_dtype
    activation_dtype: torch.dtype = e4m3_dtype
    weight_dtype: torch.dtype = e4m3_dtype
    granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = PerRow()
    activation_value_lb: Optional[float] = None
    activation_value_ub: Optional[float] = None
    float8_packing_format: Float8SemiSparseTensorPackingFormat = (
        Float8SemiSparseTensorPackingFormat.SPARSE_CUTLASS
    )
    version: int = 2

    def __post_init__(self):
        torch._C._log_api_usage_once(
            "torchao.quantization.Float8DynamicActivationFloat8SemiSparseWeightConfig"
        )

        assert self.float8_packing_format in {
            Float8SemiSparseTensorPackingFormat.SPARSE_CUTLASS,
            Float8SemiSparseTensorPackingFormat.SPARSE_CUSPARSELT,
        }, f"{self.float8_packing_format} is not supported"


@register_quantize_module_handler(Float8DynamicActivationFloat8SemiSparseWeightConfig)
def _float8_dynamic_activation_float8_semi_sparse_weight_transform(
    module: torch.nn.Module, config: Float8DynamicActivationFloat8SemiSparseWeightConfig
    module: torch.nn.Module,
    config: Float8DynamicActivationFloat8SemiSparseWeightConfig,
    *,
    parameter_name: str = "weight",
):
    assert is_sm_at_least_90(), "Float8 quantization is only supported on CUDA>=9.0"

    if isinstance(module, Float8Linear):
        module = _unwrap_float8_linear(module)

    weight = module.weight
    unquantized_param = getattr(module, parameter_name)
    weight_dtype = config.weight_dtype
    activation_dtype = config.activation_dtype
    layout = config.layout
    version = config.version
    activation_granularity, weight_granularity = _normalize_granularity(
        config.granularity
    )
    activation_value_lb = config.activation_value_lb
    activation_value_ub = config.activation_value_ub
    act_quant_kwargs = QuantizeTensorToFloat8Kwargs(
        activation_dtype,
        activation_granularity,
        hp_value_lb=activation_value_lb,
        hp_value_ub=activation_value_ub,
    )
    packing_format = config.float8_packing_format

    if not isinstance(layout, CutlassSemiSparseLayout):
    if version == 2:
        quantized_param = Float8SemiSparseTensor.from_hp(
            unquantized_param,
            float8_dtype=weight_dtype,
            granularity=weight_granularity,
            packing_format=packing_format,
            act_quant_kwargs=act_quant_kwargs,
        )
    else:
        raise NotImplementedError(
            f"Only CutlassSemiSparseLayout layout is supported. Received {layout}."
            f"Only version 2 of Float8DynamicActivationFloat8SemiSparseWeightConfig is supported. Received {version}."
        )

    weight = _float8_cutlass_quant_sparse(weight, weight_dtype)
    weight = to_linear_activation_quantized(
        weight,
        _float8_cutlass_quant,
        quant_kwargs={"target_dtype": activation_dtype},
    setattr(
        module,
        parameter_name,
        torch.nn.Parameter(quantized_param, requires_grad=False),
    )
    module.extra_repr = types.MethodType(
        partial(
            _module_extra_repr,
            original_extra_repr=module.extra_repr,
            parameter_name=parameter_name,
        ),
        module,
    )

    module.weight = torch.nn.Parameter(weight, requires_grad=False)
    module.extra_repr = types.MethodType(_linear_extra_repr, module)
    return module


torchao/quantization/quantize_/workflows/__init__.py (6 additions & 0 deletions)
@@ -1,3 +1,7 @@
from .float8.float8_semi_sparse_tensor import (
    Float8SemiSparseTensor,
    Float8SemiSparseTensorPackingFormat,
)
from .float8.float8_tensor import (
    Float8Tensor,
    QuantizeTensorToFloat8Kwargs,
@@ -38,6 +42,8 @@
"Int4PlainInt32Tensor",
"Int4TilePackedTo4dTensor",
"Float8Tensor",
"Float8SemiSparseTensor",
"Float8SemiSparseTensorPackingFormat",
"QuantizeTensorToFloat8Kwargs",
"Int4OpaqueTensor",
"Int4ChooseQParamsAlgorithm",
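
For context, a minimal usage sketch of the version-2 path these changes enable, assembled from the test file above (quantize_, apply_fake_sparsity, and the packing-format enum are the same names the test imports; an SM 9.0+ GPU is assumed):

import torch
from torch import nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quantize_.workflows import (
    Float8SemiSparseTensorPackingFormat,
)
from torchao.sparsity import apply_fake_sparsity

# Prune a bf16 linear layer to a 2:4 pattern, then replace its weight with a
# Float8SemiSparseTensor using the cuSPARSELt packing format.
model = nn.Sequential(nn.Linear(256, 1024)).bfloat16().cuda().eval()
apply_fake_sparsity(model)
quantize_(
    model,
    Float8DynamicActivationFloat8SemiSparseWeightConfig(
        float8_packing_format=Float8SemiSparseTensorPackingFormat.SPARSE_CUSPARSELT
    ),
)

with torch.inference_mode():
    out = model(torch.rand((256, 256), dtype=torch.bfloat16, device="cuda"))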