Changes from all commits (38 commits)
832a715
Initial commit fbrancasi/dev
federicobrancasi Apr 19, 2025
59c2066
Working Resnet18
federicobrancasi Apr 23, 2025
f9325e8
Codebase Refactor
federicobrancasi Apr 24, 2025
c0fdc13
Codebase Refactor
federicobrancasi Apr 24, 2025
be6ec27
Codebase Refactor
federicobrancasi Apr 24, 2025
9b17ee7
Codebase Refactor
federicobrancasi Apr 25, 2025
908c611
update Resnet18 test
federicobrancasi Apr 25, 2025
71977cb
Fix CI
federicobrancasi Apr 25, 2025
516605e
Minor Fixes
federicobrancasi Apr 30, 2025
b998c38
Rename for better understanding
federicobrancasi Apr 30, 2025
98c828c
Refactor codebase & fix ResNet-18 test (#1)
federicobrancasi May 2, 2025
e5edfc7
Add capability to deal with linear (Conv, Gemm...) without bias (#2)
FrancescoConti Jun 10, 2025
316bd96
Merge remote-tracking branch 'upstream/main' into fbrancasi/dev
federicobrancasi Jun 10, 2025
cfd21c3
Add CCT Model and Test
federicobrancasi Jun 10, 2025
e2bc1b4
Change Rounding Policy in Quant Module
federicobrancasi Jun 10, 2025
01a60c5
Refactor CCT Test
federicobrancasi Jun 10, 2025
2cbc76c
Handle Deterministic Session for ORT and Update Tests
federicobrancasi Jun 11, 2025
2277543
Add checkEquivalence Flag to brevitasToTrueQuant function
federicobrancasi Jun 11, 2025
9e489b1
Fix Problem of .view(-1) in TensorRecorder Util
federicobrancasi Jun 16, 2025
531fa57
Modify CCT Test
federicobrancasi Jun 23, 2025
dfe48d0
Modify CCT Test
federicobrancasi Jun 23, 2025
29d7c0c
Modify CCT Test Pretrained
federicobrancasi Jun 23, 2025
ac3cd41
Fix CCT Test Pretrained
federicobrancasi Jun 23, 2025
383e8bc
Fix CCT Test (use 8bit for weight_bit_width of Linear)
federicobrancasi Jun 23, 2025
67a4094
Add CIFAR10 for CI
federicobrancasi Jun 23, 2025
b64126b
Fix CI
federicobrancasi Jun 23, 2025
21b9c74
Fix CI
federicobrancasi Jun 23, 2025
18b2642
Update README.md
federicobrancasi Jun 23, 2025
2f00a67
Add New Models and Fix MHA Problems
federicobrancasi Jun 24, 2025
b2c6fe2
Update ORT opset version
federicobrancasi Jun 24, 2025
856f249
Add ViTB32 Test
federicobrancasi Jun 24, 2025
d580088
Merge branch 'fbrancasi/dev' of github.com:pulp-platform/DeepQuant in…
federicobrancasi Jun 24, 2025
2657957
Update TestCCT
federicobrancasi Jun 25, 2025
15250c8
Update TestViTB32
federicobrancasi Jun 25, 2025
0adc5a3
Use the right version of CCT
federicobrancasi Jul 2, 2025
3d3b90e
Update Tests to use right version of CCT
federicobrancasi Jul 9, 2025
3406450
Remove wrong version of CCT from Model Folder
federicobrancasi Jul 9, 2025
e27b638
Fix pytest configuration to exclude non-test files
federicobrancasi Jul 9, 2025
3 changes: 2 additions & 1 deletion .gitignore
@@ -24,8 +24,9 @@ dist/

*.gz
*-ubyte
*.pth
*.pt
*.onnx
*.npz
onnx/*
Dataset/*
mnist_model.pth
45 changes: 5 additions & 40 deletions DeepQuant/CustomForwards/Activations.py
@@ -4,63 +4,28 @@
#
# Federico Brancasi <[email protected]>


import torch.nn as nn
from torch import Tensor
from brevitas.nn.quant_layer import QuantNonLinearActLayer
from torch import Tensor


class InnerForwardImplWrapperActivation(nn.Module):
"""
A small wrapper around the activation function of a Brevitas QuantActivation layer.

This wrapper exposes the original activation function as a standalone submodule
so that FX tracing can display it as a separate node.
"""
class WrapperActivation(nn.Module):
"""Expose inner activation so FX sees it as a leaf."""

def __init__(self, actImpl: nn.Module) -> None:
"""
Args:
act_impl: The original activation function module (e.g. an instance of nn.ReLU).
"""
super().__init__()
self.actImpl = actImpl

def forward(self, quantInput: Tensor) -> Tensor:
"""
Applies the wrapped activation function.

Args:
quant_input: Input tensor after input quantization.

Returns:
Output tensor after applying the activation.
"""
return self.actImpl(quantInput)


def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
"""
Unrolled forward pass for a Brevitas QuantActivation layer.

Steps:
1) Apply self.input_quant to the input.
2) Apply the activation function via the wrapped activation implementation.
3) Apply self.act_quant to the activation output.

Args:
self: The QuantNonLinearActLayer instance.
inp: The input tensor.

Returns:
Output tensor after applying activation and output quantization.
"""
def activationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor:
"""Unroll input→act→output quant steps."""
quantInput = self.input_quant(inp) if self.input_quant is not None else inp
# Use the wrapped activation if available; otherwise pass through.
if hasattr(self, "wrappedActImpl"):
output = self.wrappedActImpl(quantInput)
else:
output = quantInput
quantOutput = self.act_quant(output) if self.act_quant is not None else output
return quantOutput
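For context, a hand-wired usage sketch of the unrolled activation forward above (not part of this diff). It assumes the module-transformation pass normally attaches the wrapper under the attribute name `wrappedActImpl` that `activationForward` checks for; the import path is inferred from the file location in this PR.

```python
# Illustrative sketch only: wiring the unrolled forward onto a Brevitas
# activation layer by hand; the real flow is assumed to do this inside the
# module-transformation pass.
import types

import torch
import torch.nn as nn
from brevitas.nn import QuantReLU

from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward

layer = QuantReLU()

# Expose the inner activation as a leaf submodule so FX tracing shows it as a
# separate node (attribute name taken from activationForward's hasattr check).
layer.wrappedActImpl = WrapperActivation(nn.ReLU())

# Rebind the layer's forward to the unrolled input-quant -> act -> act-quant path.
layer.forward = types.MethodType(activationForward, layer)

out = layer(torch.randn(1, 8))
```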
75 changes: 0 additions & 75 deletions DeepQuant/CustomForwards/Linear.py

This file was deleted.

151 changes: 106 additions & 45 deletions DeepQuant/CustomForwards/MultiHeadAttention.py
@@ -4,67 +4,76 @@
#
# Federico Brancasi <[email protected]>


import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor
from brevitas.nn.quant_mha import QuantMultiheadAttention
from torch import Tensor


def unrolledQuantMhaForward(
self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor
def _mhaForwardImpl(
self: QuantMultiheadAttention,
query: Tensor,
key: Tensor,
value: Tensor,
need_transpose_in: bool,
need_transpose_out: bool,
) -> Tensor:
"""
Export-friendly forward that explicitly unrolls the multi-head logic.

Steps:
1) Q, K, V projections
2) Reshapes & permutes for multi-head
3) Scales queries
4) Applies softmax and intermediate quantizations
5) Out projection

Args:
self: The QuantMultiheadAttention instance.
query: The query tensor of shape [sequence_len, batch_size, embed_dim].
key: The key tensor, same shape as query.
value: The value tensor, same shape as query.

Returns:
A torch.Tensor of shape [sequence_len, batch_size, embed_dim]
after the unrolled MHA steps.
"""
# 1) Q, K, V projections
qOut = self.q_proj(query)
kOut = self.k_proj(key)
vOut = self.v_proj(value)
"""Core MHA forward implementation."""
# FBRANCASI: Handle batch_first by transposing if needed
if need_transpose_in:
if key is value:
if query is key:
query = key = value = query.transpose(1, 0)
else:
query, key = [x.transpose(1, 0) for x in (query, key)]
value = key
else:
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

if self.in_proj is not None:
# FBRANCASI: Handle packed projections (default case for models like ViT)
# Only support self-attention where query == key == value
if not (query is key and key is value):
raise RuntimeError(
"Packed in_proj is supported only for self-attention with k is v is q. Set packed_in_proj=False."
)
qkv = self.in_proj(query)
qkv_tensor = qkv.value if hasattr(qkv, "value") else qkv
qOut, kOut, vOut = qkv_tensor.chunk(3, dim=-1)
else:
q_result = self.q_proj(query)
k_result = self.k_proj(key)
v_result = self.v_proj(value)

qOut = q_result.value if hasattr(q_result, "value") else q_result
kOut = k_result.value if hasattr(k_result, "value") else k_result
vOut = v_result.value if hasattr(v_result, "value") else v_result

# 2) Multi-head reshape
seqLen, batchSize, embedDim = qOut.shape
headDim = embedDim // self.num_heads

qOut = (
qOut.view(seqLen, batchSize, self.num_heads, headDim)
.permute(1, 2, 0, 3)
.reshape(batchSize * self.num_heads, seqLen, headDim)
qOut.contiguous()
.view(seqLen, batchSize * self.num_heads, headDim)
.transpose(0, 1)
)
kOut = (
kOut.view(seqLen, batchSize, self.num_heads, headDim)
.permute(1, 2, 0, 3)
.reshape(batchSize * self.num_heads, seqLen, headDim)
kOut.contiguous()
.view(seqLen, batchSize * self.num_heads, headDim)
.transpose(0, 1)
)
vOut = (
vOut.view(seqLen, batchSize, self.num_heads, headDim)
.permute(1, 2, 0, 3)
.reshape(batchSize * self.num_heads, seqLen, headDim)
vOut.contiguous()
.view(seqLen, batchSize * self.num_heads, headDim)
.transpose(0, 1)
)

# 3) Scale queries, then quantize
qScaled = qOut / math.sqrt(headDim)
qScaled = self.q_scaled_quant(qScaled)

# 4) Transpose + quantize K, compute attention weights
k_t = kOut.transpose(-2, -1)
k_t = self.k_transposed_quant(k_t)

@@ -73,15 +82,67 @@ def unrolledQuantMhaForward(
attnWeights = F.softmax(attnWeights, dim=-1)
attnWeights = self.attn_output_weights_quant(attnWeights)

# 5) Quantize V, multiply, reshape back, and final out projection
vOut = self.v_quant(vOut)
attnOutput = torch.bmm(attnWeights, vOut)

attnOutput = (
attnOutput.view(batchSize, self.num_heads, seqLen, headDim)
.permute(2, 0, 1, 3)
.reshape(seqLen, batchSize, embedDim)
attnOutput.transpose(0, 1).contiguous().view(seqLen, batchSize, embedDim)
)

attnOutput = self.out_proj(attnOutput)
out_result = self.out_proj(attnOutput)
attnOutput = out_result.value if hasattr(out_result, "value") else out_result

if need_transpose_out:
attnOutput = attnOutput.transpose(1, 0)

return attnOutput


def mhaForwardBatchFirst(
self: QuantMultiheadAttention,
query: Tensor,
key: Tensor,
value: Tensor,
need_weights: bool = True,
**kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
"""MHA forward for batch_first=True."""
attn_output = _mhaForwardImpl(
self, query, key, value, need_transpose_in=True, need_transpose_out=True
)
return (attn_output, None)


def mhaForwardSeqFirst(
self: QuantMultiheadAttention,
query: Tensor,
key: Tensor,
value: Tensor,
need_weights: bool = True,
**kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
"""MHA forward for batch_first=False."""
attn_output = _mhaForwardImpl(
self, query, key, value, need_transpose_in=False, need_transpose_out=False
)
return (attn_output, None)


def mhaForward(
self: QuantMultiheadAttention,
query: Tensor,
key: Tensor,
value: Tensor,
need_weights: bool = True,
**kwargs,
) -> Tuple[Tensor, Optional[Tensor]]:
"""Explicit, export-friendly MHA forward.

This function will be replaced with the appropriate batch_first or seq_first version
during module transformation based on the module's batch_first attribute.
"""
# FBRANCASI: Dispatch to the appropriate version before tracing
if self.batch_first:
return mhaForwardBatchFirst(self, query, key, value, need_weights, **kwargs)
else:
return mhaForwardSeqFirst(self, query, key, value, need_weights, **kwargs)
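A quick numerical check of the head-splitting change above (illustrative only, plain PyTorch): the new `contiguous().view().transpose()` sequence produces the same `[batch*heads, seq, head_dim]` layout as the permute-based reshape it replaces.

```python
# Sanity check: both reshape strategies yield identical tensors.
import torch

seqLen, batchSize, numHeads, headDim = 4, 2, 3, 5
embedDim = numHeads * headDim
x = torch.randn(seqLen, batchSize, embedDim)

oldLayout = (
    x.view(seqLen, batchSize, numHeads, headDim)
    .permute(1, 2, 0, 3)
    .reshape(batchSize * numHeads, seqLen, headDim)
)
newLayout = (
    x.contiguous()
    .view(seqLen, batchSize * numHeads, headDim)
    .transpose(0, 1)
)
assert torch.equal(oldLayout, newLayout)
```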
36 changes: 36 additions & 0 deletions DeepQuant/CustomForwards/WBIOL.py
@@ -0,0 +1,36 @@
# Copyright 2025 ETH Zurich and University of Bologna.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
#
# Federico Brancasi <[email protected]>

import torch.nn as nn
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer
from torch import Tensor


class WrapperWBIOL(nn.Module):
"""Expose `inner_forward_impl` as a standalone submodule."""

def __init__(self, innerForwardImpl: nn.Module) -> None:
super().__init__()
self.innerForwardImpl = innerForwardImpl

def forward(
self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor
) -> Tensor:
return self.innerForwardImpl(quantInput, quantWeight, quantBias)


def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
"""Quant-in → quant-weight/bias → matmul → quant-out."""
quantInput = self.input_quant(inp)
quantWeight = self.weight_quant(self.weight)

quantBias = None
if self.bias is not None:
quantBias = self.bias_quant(self.bias, quantInput, quantWeight)

output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias)
quantOutput = self.output_quant(output)
return quantOutput
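As a small illustration of what `WrapperWBIOL` exposes (not part of this diff): the layer's own `inner_forward_impl` can be called through the wrapper with plain, unquantized tensors. The import path is inferred from the file location; in the real flow the transformation pass is assumed to attach the wrapper as `wrappedInnerForwardImpl`, which `WBIOLForward` expects.

```python
# Illustrative sketch only: calling a QuantLinear's inner matmul through the
# new wrapper with plain (unquantized) tensors.
import torch
from brevitas.nn import QuantLinear

from DeepQuant.CustomForwards.WBIOL import WrapperWBIOL

layer = QuantLinear(16, 8, bias=True)

# Wrap the layer's inner_forward_impl so FX tracing shows the matmul as a
# single leaf node (passing the bound method directly here for brevity).
wrapper = WrapperWBIOL(layer.inner_forward_impl)

x = torch.randn(4, 16)
out = wrapper(x, layer.weight, layer.bias)
print(out.shape)  # torch.Size([4, 8])
```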