Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions DeepQuant/CustomForwards/MultiHeadAttention.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


def unrolledQuantMhaForward(
self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor, need_weights: bool = True
) -> Tensor:
"""
Export-friendly forward that explicitly unrolls the multi-head logic.
Expand Down Expand Up @@ -84,4 +84,11 @@ def unrolledQuantMhaForward(
)

attnOutput = self.out_proj(attnOutput)
return attnOutput

if need_weights:
# return average attention weights over heads
attnWeights = attnWeights.view(batchSize, self.num_heads, seqLen, seqLen)
attnWeights = attnWeights.sum(dim=1) / self.num_heads
else:
attnWeights = None
return attnOutput, attnWeights
6 changes: 6 additions & 0 deletions DeepQuant/QuantManipulation/DequantModifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,12 @@ def unifyLinearDequants(

# Collect and rewire the linear's arguments
for arg in oldArgs:

# NEW: skip None and non-Node args, keep them as-is
if arg is None or not isinstance(arg, fx.Node):
newLinArgs.append(arg)
continue

if arg.op == "call_module" and "dequant" in arg.target.lower():
if "bias_dequant" in arg.target.lower():
biasDequantNode = arg
Expand Down
5 changes: 4 additions & 1 deletion DeepQuant/QuantManipulation/ParameterExtractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,10 @@ def safe_get_scale(quant_obj: Any) -> Any:
if maybe_scale is None:
return None
if isinstance(maybe_scale, torch.Tensor):
return maybe_scale.item()
if maybe_scale.numel() == 1:
return maybe_scale.item()
else:
return maybe_scale
elif isinstance(maybe_scale, float):
return maybe_scale
try:
Expand Down
12 changes: 11 additions & 1 deletion DeepQuant/QuantManipulation/QuantNodesDivider.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,17 @@ def split_quant_nodes(
for user_node in list(node.users.keys()):
new_args = []
for arg in user_node.args:
new_args.append(dequant_node if arg is node else arg)
if arg is node:
new_args.append(dequant_node)
# if the argument is a tuple or list (e.g. concatenation, addition)
elif isinstance(arg, (tuple, list)):
seq_type = type(arg)
new_seq = []
for a in arg:
new_seq.append(dequant_node if a is node else a)
new_args.append(seq_type(new_seq))
else:
new_args.append(arg)
user_node.args = tuple(new_args)

nodes_to_erase.append(node)
Expand Down
59 changes: 59 additions & 0 deletions Tests/TestConvChannelWise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Victor Jung <[email protected]>
# Federico Brancasi <[email protected]>

import pytest
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant.scaled_int import (
Int8ActPerTensorFloat,
Int32Bias,
Int8WeightPerChannelFloat
)
from DeepQuant.ExportBrevitas import exportBrevitas


class QuantConvNet(nn.Module):
    """
    Minimal quantized network used as an export fixture.

    The model consists of:
      - An input quantizer (QuantIdentity)
      - A single QuantConv2d layer that uses the
        Int8WeightPerChannelFloat weight quantizer
    """

    # Keyword arguments forwarded verbatim to the quantized conv layer.
    convAndLinQuantParams = {
        "bias": True,
        "weight_bit_width": 4,
        "bias_quant": Int32Bias,
        "input_quant": Int8ActPerTensorFloat,
        "weight_quant": Int8WeightPerChannelFloat,
        "output_quant": Int8ActPerTensorFloat,
        "return_quant_tensor": True,
    }

    def __init__(self, in_channels: int = 1) -> None:
        """
        Args:
            in_channels: Number of input channels of the conv layer.
        """
        super().__init__()
        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)

        # Layer geometry first, then the shared quantization settings.
        convKwargs = dict(
            in_channels=in_channels,
            out_channels=16,
            kernel_size=3,
            padding=1,
            **QuantConvNet.convAndLinQuantParams
        )
        self.conv1 = qnn.QuantConv2d(**convKwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Quantize the input and apply the convolution.

        Args:
            x: Input tensor fed to the input quantizer.

        Returns:
            The output of the quantized convolution.
        """
        return self.conv1(self.inputQuant(x))


@pytest.mark.SingleLayerTests
def deepQuantTestConv() -> None:
    """Export a QuantConvNet in eval mode with a seeded random 1x1x28x28 input."""
    # Seed before building the model so its initialization and the sample
    # input are reproducible across runs.
    torch.manual_seed(42)

    quantNet = QuantConvNet()
    quantNet.eval()

    dummyInput = torch.randn(1, 1, 28, 28)
    exportBrevitas(quantNet, dummyInput, debug=True)
108 changes: 108 additions & 0 deletions Tests/TestSimpleCNNChannelWise.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Federico Brancasi <[email protected]>


import pytest
import torch
import torch.nn as nn
import brevitas.nn as qnn
from brevitas.quant.scaled_int import (
Int8ActPerTensorFloat,
Int32Bias,
Int8WeightPerChannelFloat,
)
from DeepQuant.ExportBrevitas import exportBrevitas


class SimpleQuantCNN(nn.Module):
    """
    Small quantized CNN used as an export fixture.

    Layer stack, in order:
      - Input quantizer (QuantIdentity)
      - QuantConv2d -> QuantReLU -> MaxPool2d (twice)
      - Flatten
      - QuantLinear classifier head
    """

    # Keyword arguments shared by both conv layers and the linear layer.
    convAndLinQuantParams = {
        "bias": True,
        "weight_bit_width": 4,
        "bias_quant": Int32Bias,
        "input_quant": Int8ActPerTensorFloat,
        "weight_quant": Int8WeightPerChannelFloat,
        "output_quant": Int8ActPerTensorFloat,
        "return_quant_tensor": True,
    }

    def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None:
        """
        Build all layers of the network.

        Args:
            in_channels: Channel count of the input (e.g., 1 for grayscale).
            num_classes: Number of outputs of the final linear layer.
        """
        super().__init__()
        self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True)

        # First stage: in_channels -> 16 feature maps, then 2x2 max pooling.
        firstConvKwargs = dict(
            in_channels=in_channels,
            out_channels=16,
            kernel_size=3,
            padding=1,
            **SimpleQuantCNN.convAndLinQuantParams
        )
        self.conv1 = qnn.QuantConv2d(**firstConvKwargs)
        self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.pool1 = nn.MaxPool2d(kernel_size=2)

        # Second stage: 16 -> 32 feature maps, then 2x2 max pooling.
        secondConvKwargs = dict(
            in_channels=16,
            out_channels=32,
            kernel_size=3,
            padding=1,
            **SimpleQuantCNN.convAndLinQuantParams
        )
        self.conv2 = qnn.QuantConv2d(**secondConvKwargs)
        self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True)
        self.pool2 = nn.MaxPool2d(kernel_size=2)

        self.flatten = nn.Flatten()
        # A 28x28 input is halved twice by the pools, leaving 7x7 maps
        # across 32 channels.
        self.fc = qnn.QuantLinear(
            in_features=32 * 7 * 7,
            out_features=num_classes,
            **SimpleQuantCNN.convAndLinQuantParams
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through every layer of the network.

        Args:
            x: Input tensor of shape [batch_size, in_channels, height, width].

        Returns:
            The output of the final linear layer (batch_size, num_classes).
        """
        quantized = self.inputQuant(x)

        stageOne = self.pool1(self.relu1(self.conv1(quantized)))
        stageTwo = self.pool2(self.relu2(self.conv2(stageOne)))

        return self.fc(self.flatten(stageTwo))


@pytest.mark.ModelTests
def deepQuantTestSimpleCNN() -> None:
    """Export a SimpleQuantCNN in eval mode with a seeded random 1x1x28x28 input."""
    # Seed before building the model so its initialization and the sample
    # input are reproducible across runs.
    torch.manual_seed(42)

    quantNet = SimpleQuantCNN()
    quantNet.eval()

    dummyInput = torch.randn(1, 1, 28, 28)
    exportBrevitas(quantNet, dummyInput, debug=True)