diff --git a/DeepQuant/CustomForwards/MultiHeadAttention.py b/DeepQuant/CustomForwards/MultiHeadAttention.py index 76fe3ae..240fc21 100644 --- a/DeepQuant/CustomForwards/MultiHeadAttention.py +++ b/DeepQuant/CustomForwards/MultiHeadAttention.py @@ -13,7 +13,7 @@ def unrolledQuantMhaForward( - self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor + self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor, need_weights: bool = True ) -> Tensor: """ Export-friendly forward that explicitly unrolls the multi-head logic. @@ -84,4 +84,11 @@ def unrolledQuantMhaForward( ) attnOutput = self.out_proj(attnOutput) - return attnOutput + + if need_weights: + # return average attention weights over heads + attnWeights = attnWeights.view(batchSize, self.num_heads, seqLen, seqLen) + attnWeights = attnWeights.sum(dim=1) / self.num_heads + else: + attnWeights = None + return attnOutput, attnWeights diff --git a/DeepQuant/QuantManipulation/DequantModifier.py b/DeepQuant/QuantManipulation/DequantModifier.py index 8bd9ae5..930affa 100644 --- a/DeepQuant/QuantManipulation/DequantModifier.py +++ b/DeepQuant/QuantManipulation/DequantModifier.py @@ -74,6 +74,12 @@ def unifyLinearDequants( # Collect and rewire the linear's arguments for arg in oldArgs: + + # NEW: skip None and non-Node args, keep them as-is + if arg is None or not isinstance(arg, fx.Node): + newLinArgs.append(arg) + continue + if arg.op == "call_module" and "dequant" in arg.target.lower(): if "bias_dequant" in arg.target.lower(): biasDequantNode = arg diff --git a/DeepQuant/QuantManipulation/ParameterExtractor.py b/DeepQuant/QuantManipulation/ParameterExtractor.py index b11d77b..c3c6757 100644 --- a/DeepQuant/QuantManipulation/ParameterExtractor.py +++ b/DeepQuant/QuantManipulation/ParameterExtractor.py @@ -45,7 +45,10 @@ def safe_get_scale(quant_obj: Any) -> Any: if maybe_scale is None: return None if isinstance(maybe_scale, torch.Tensor): - return maybe_scale.item() + if 
maybe_scale.numel() == 1: + return maybe_scale.item() + else: + return maybe_scale elif isinstance(maybe_scale, float): return maybe_scale try: diff --git a/DeepQuant/QuantManipulation/QuantNodesDivider.py b/DeepQuant/QuantManipulation/QuantNodesDivider.py index 6b7ab10..154c4c2 100644 --- a/DeepQuant/QuantManipulation/QuantNodesDivider.py +++ b/DeepQuant/QuantManipulation/QuantNodesDivider.py @@ -132,7 +132,17 @@ def split_quant_nodes( for user_node in list(node.users.keys()): new_args = [] for arg in user_node.args: - new_args.append(dequant_node if arg is node else arg) + if arg is node: + new_args.append(dequant_node) + # if the argument is a tuple or list (e.g. concatenation, addition) + elif isinstance(arg, (tuple, list)): + seq_type = type(arg) + new_seq = [] + for a in arg: + new_seq.append(dequant_node if a is node else a) + new_args.append(seq_type(new_seq)) + else: + new_args.append(arg) user_node.args = tuple(new_args) nodes_to_erase.append(node) diff --git a/Tests/TestConvChannelWise.py b/Tests/TestConvChannelWise.py new file mode 100644 index 0000000..273905f --- /dev/null +++ b/Tests/TestConvChannelWise.py @@ -0,0 +1,59 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0
+#
+# Victor Jung
+# Federico Brancasi
+
+import pytest
+import torch
+import torch.nn as nn
+import brevitas.nn as qnn
+from brevitas.quant.scaled_int import (
+    Int8ActPerTensorFloat,
+    Int32Bias,
+    Int8WeightPerChannelFloat
+)
+from DeepQuant.ExportBrevitas import exportBrevitas
+
+
+class QuantConvNet(nn.Module):
+    """Minimal quantized model: an input quantizer followed by one QuantConv2d."""
+
+    # Keyword arguments unpacked into the QuantConv2d constructor.
+    convAndLinQuantParams = {
+        "bias": True,
+        "weight_bit_width": 4,
+        "bias_quant": Int32Bias,
+        "input_quant": Int8ActPerTensorFloat,
+        "weight_quant": Int8WeightPerChannelFloat,
+        "output_quant": Int8ActPerTensorFloat,
+        "return_quant_tensor": True,
+    }
+
+    def __init__(self, in_channels: int = 1) -> None:
+        """Create the input quantizer and a 3x3 conv with 16 output channels."""
+        super().__init__()
+        convQuantCfg = QuantConvNet.convAndLinQuantParams
+        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)
+        self.conv1 = qnn.QuantConv2d(
+            in_channels=in_channels,
+            out_channels=16,
+            kernel_size=3,
+            padding=1,
+            **convQuantCfg
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Apply the input quantizer, then the convolution, and return the result."""
+        return self.conv1(self.inputQuant(x))
+
+
+@pytest.mark.SingleLayerTests
+def deepQuantTestConv() -> None:
+    """Run exportBrevitas on an eval-mode QuantConvNet with a seeded random input."""
+    torch.manual_seed(42)
+
+    net = QuantConvNet().eval()
+    dummyInput = torch.randn(1, 1, 28, 28)
+    exportBrevitas(net, dummyInput, debug=True)
diff --git a/Tests/TestSimpleCNNChannelWise.py b/Tests/TestSimpleCNNChannelWise.py
new file mode 100644
index 0000000..e8b52d1
--- /dev/null
+++ b/Tests/TestSimpleCNNChannelWise.py
@@ -0,0 +1,108 @@
+# Copyright 2025 ETH Zurich and University of Bologna.
+# Licensed under the Apache License, Version 2.0, see LICENSE for details.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Federico Brancasi
+
+
+import pytest
+import torch
+import torch.nn as nn
+import brevitas.nn as qnn
+from brevitas.quant.scaled_int import (
+    Int8ActPerTensorFloat,
+    Int32Bias,
+    Int8WeightPerChannelFloat,
+)
+from DeepQuant.ExportBrevitas import exportBrevitas
+
+
+class SimpleQuantCNN(nn.Module):
+    """
+    Small quantized CNN built from Brevitas layers.
+
+    Layer order: input quantizer, two QuantConv2d -> QuantReLU -> MaxPool2d
+    stages, Flatten, and a final QuantLinear layer.
+    """
+
+    # Keyword arguments shared by both QuantConv2d layers and the QuantLinear.
+    convAndLinQuantParams = {
+        "bias": True,
+        "weight_bit_width": 4,
+        "bias_quant": Int32Bias,
+        "input_quant": Int8ActPerTensorFloat,
+        "weight_quant": Int8WeightPerChannelFloat,
+        "output_quant": Int8ActPerTensorFloat,
+        "return_quant_tensor": True,
+    }
+
+    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
+        """
+        Create the layers.
+
+        Args:
+            in_channels: Number of input channels (e.g., 1 for grayscale).
+            num_classes: Number of output features of the final linear layer.
+        """
+        super().__init__()
+        sharedQuantCfg = SimpleQuantCNN.convAndLinQuantParams
+
+        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)
+
+        # Stage 1: in_channels -> 16 channels.
+        self.conv1 = qnn.QuantConv2d(
+            in_channels=in_channels,
+            out_channels=16,
+            kernel_size=3,
+            padding=1,
+            **sharedQuantCfg
+        )
+        self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
+        self.pool1 = nn.MaxPool2d(kernel_size=2)
+
+        # Stage 2: 16 -> 32 channels.
+        self.conv2 = qnn.QuantConv2d(
+            in_channels=16,
+            out_channels=32,
+            kernel_size=3,
+            padding=1,
+            **sharedQuantCfg
+        )
+        self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
+        self.pool2 = nn.MaxPool2d(kernel_size=2)
+
+        self.flatten = nn.Flatten()
+        # 32 channels x 7 x 7: the two kernel-2 pools reduce a 28x28 input to 7x7.
+        self.fc = qnn.QuantLinear(
+            in_features=32 * 7 * 7,
+            out_features=num_classes,
+            **sharedQuantCfg
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """
+        Run the network on a batch.
+
+        Args:
+            x: Input tensor of shape [batch_size, in_channels, height, width].
+
+        Returns:
+            Output of the final linear layer, shape (batch_size, num_classes).
+        """
+        # Input quantizer followed by the first conv/ReLU/pool stage.
+        stage1 = self.pool1(self.relu1(self.conv1(self.inputQuant(x))))
+        # Second conv/ReLU/pool stage.
+        stage2 = self.pool2(self.relu2(self.conv2(stage1)))
+        # Flatten and apply the linear layer.
+        return self.fc(self.flatten(stage2))
+
+
+@pytest.mark.ModelTests
+def deepQuantTestSimpleCNN() -> None:
+    """Run exportBrevitas on an eval-mode SimpleQuantCNN with a seeded random input."""
+    torch.manual_seed(42)
+
+    net = SimpleQuantCNN().eval()
+    dummyInput = torch.randn(1, 1, 28, 28)
+
+    exportBrevitas(net, dummyInput, debug=True)