[sparse] Migrate Float8SemiSparseTensor off of AQT #3361
@@ -0,0 +1,115 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
import unittest

import torch
from torch import nn
from torch.testing._internal import common_utils

from torchao.quantization import (
    Float8DynamicActivationFloat8SemiSparseWeightConfig,
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.quant_api import (
    quantize_,
)
from torchao.quantization.quantize_.workflows.float8.float8_packing_format import (
    Float8PackingFormat,
)
from torchao.sparsity import apply_fake_sparsity
from torchao.utils import is_sm_at_least_90

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO
)


class TestFloat8SemiSparseTensor(common_utils.TestCase):
    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @common_utils.parametrize("compile", [True, False])
    @common_utils.parametrize(
        "packing_format",
        [Float8PackingFormat.SPARSE_CUTLASS, Float8PackingFormat.SPARSE_CUSPARSELT],
    )
    def test_fp8_cutlass_sparse(self, compile, packing_format):
        with torch.inference_mode():
            input = torch.rand((256, 256), dtype=torch.bfloat16, device="cuda")
            model = (
                nn.Sequential(
                    nn.Linear(256, 1024),
                    nn.Linear(1024, 256),
                )
                .bfloat16()
                .cuda()
                .eval()
            )

            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)

            # Quantized
            quantize_(model_copy, Float8DynamicActivationFloat8WeightConfig())
            dense_result = model_copy(input)

            # Sparse + quantized
            quantize_(
                model,
                Float8DynamicActivationFloat8SemiSparseWeightConfig(
                    float8_packing_format=packing_format
                ),
            )
            if compile:
                model = torch.compile(model)
            sparse_result = model(input)

            torch.testing.assert_close(
                dense_result.to(torch.float),
                sparse_result.to(torch.float),
                atol=3e-1,
                rtol=3e-1,
            )

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_clone(self):
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = model.weight.dequantize()
            cloned = model.weight.clone().dequantize()

            for o, c in zip(original, cloned):
                torch.testing.assert_close(o, c, atol=0.0, rtol=0.0)

    @unittest.skipIf(not is_sm_at_least_90(), "Need H100 to run")
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_fp8_cutlass_sparse_lowering_op_to(self):
        # Need to run with inference mode to avoid dispatching to `aten.to_copy`
        with torch.inference_mode():
            model = nn.Linear(256, 1024).half().cuda().eval()
            apply_fake_sparsity(model)
            model_copy = copy.deepcopy(model)
            expected = model_copy.weight.to(dtype=torch.float)

            quantize_(model, Float8DynamicActivationFloat8SemiSparseWeightConfig())

            original = torch.ops.aten.to.dtype_layout(
                model.weight,
                dtype=torch.float,
                layout=torch.strided,
            )
            torch.testing.assert_close(expected, original, atol=1e-1, rtol=1e-1)


common_utils.instantiate_parametrized_tests(TestFloat8SemiSparseTensor)


if __name__ == "__main__":
    unittest.main()
@@ -30,5 +30,9 @@ class KernelPreference(str, Enum):
    """
    FBGEMM = "fbgemm"

    """Use torchao cutlass kernel for fp8 + 2:4 sparse mm, requires building torchao with CUDA
    """
    SPARSE_CUTLASS = "sparse_cutlass"


torch.serialization.add_safe_globals([KernelPreference])
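Aside, not part of the diff: a minimal sketch of why the `torch.serialization.add_safe_globals([KernelPreference])` registration matters. Because `KernelPreference` is a `(str, Enum)`, the new member compares equal to its string value, and registering the class lets checkpoints that store it load under `torch.load`'s `weights_only` mode (the default since PyTorch 2.6). The `from torchao.quantization import KernelPreference` import path is an assumption here.

```python
import io

import torch
from torchao.quantization import KernelPreference  # import path is an assumption

# str-valued enum: the member behaves like its string value.
assert KernelPreference.SPARSE_CUTLASS == "sparse_cutlass"

# add_safe_globals([KernelPreference]) is what makes the enum safe to unpickle
# under torch.load's weights_only mode.
buf = io.BytesIO()
torch.save({"kernel_preference": KernelPreference.SPARSE_CUTLASS}, buf)
buf.seek(0)
state = torch.load(buf, weights_only=True)
assert state["kernel_preference"] is KernelPreference.SPARSE_CUTLASS
```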
@@ -29,3 +29,8 @@ class Float8PackingFormat(str, Enum):
    needed for the rest of the system to understand the specific format that's adopted.
    """
    OPAQUE = "opaque"
    """
    Sparse packing formats for 2:4 sparsity + FP8 quantization
    """
    SPARSE_CUTLASS = "sparse_cutlass"
    SPARSE_CUSPARSELT = "sparse_cusparselt"
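Not part of the diff: a minimal usage sketch distilled from the test file above, showing how one of the two new packing formats is selected when quantizing a 2:4-sparse model. It assumes an SM90+ GPU and a CUDA build of torchao, as the test's skip conditions indicate.

```python
import torch
from torch import nn

from torchao.quantization import Float8DynamicActivationFloat8SemiSparseWeightConfig
from torchao.quantization.quant_api import quantize_
from torchao.quantization.quantize_.workflows.float8.float8_packing_format import (
    Float8PackingFormat,
)
from torchao.sparsity import apply_fake_sparsity

model = nn.Sequential(nn.Linear(256, 1024)).bfloat16().cuda().eval()
apply_fake_sparsity(model)  # prune weights to a 2:4 semi-structured pattern first

# SPARSE_CUTLASS packs the weight for the torchao CUTLASS kernel;
# SPARSE_CUSPARSELT packs it for cuSPARSELt.
quantize_(
    model,
    Float8DynamicActivationFloat8SemiSparseWeightConfig(
        float8_packing_format=Float8PackingFormat.SPARSE_CUSPARSELT
    ),
)

with torch.inference_mode():
    out = model(torch.rand(128, 256, dtype=torch.bfloat16, device="cuda"))
```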
I think I hardcoded this to change to v2 by default, but we should probably split the BC-breaking change into its own commit.