diff --git a/.gitignore b/.gitignore index abd6ff3..aecc5d3 100644 --- a/.gitignore +++ b/.gitignore @@ -24,8 +24,9 @@ dist/ *.gz *-ubyte -*.pth +*.pt *.onnx *.npz onnx/* Dataset/* +mnist_model.pth \ No newline at end of file diff --git a/DeepQuant/CustomForwards/Activations.py b/DeepQuant/CustomForwards/Activations.py index 2a30848..d114513 100644 --- a/DeepQuant/CustomForwards/Activations.py +++ b/DeepQuant/CustomForwards/Activations.py @@ -4,63 +4,28 @@ # # Federico Brancasi - import torch.nn as nn -from torch import Tensor from brevitas.nn.quant_layer import QuantNonLinearActLayer +from torch import Tensor -class InnerForwardImplWrapperActivation(nn.Module): - """ - A small wrapper around the activation function of a Brevitas QuantActivation layer. - - This wrapper exposes the original activation function as a standalone submodule - so that FX tracing can display it as a separate node. - """ +class WrapperActivation(nn.Module): + """Expose inner activation so FX sees it as a leaf.""" def __init__(self, actImpl: nn.Module) -> None: - """ - Args: - act_impl: The original activation function module (e.g. an instance of nn.ReLU). - """ super().__init__() self.actImpl = actImpl def forward(self, quantInput: Tensor) -> Tensor: - """ - Applies the wrapped activation function. - - Args: - quant_input: Input tensor after input quantization. - - Returns: - Output tensor after applying the activation. - """ return self.actImpl(quantInput) -def quantActivationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor: - """ - Unrolled forward pass for a Brevitas QuantActivation layer. - - Steps: - 1) Apply self.input_quant to the input. - 2) Apply the activation function via the wrapped activation implementation. - 3) Apply self.act_quant to the activation output. - - Args: - self: The QuantNonLinearActLayer instance. - inp: The input tensor. - - Returns: - Output tensor after applying activation and output quantization. - """ +def activationForward(self: QuantNonLinearActLayer, inp: Tensor) -> Tensor: + """Unroll input→act→output quant steps.""" quantInput = self.input_quant(inp) if self.input_quant is not None else inp - # Use the wrapped activation if available; otherwise pass through. if hasattr(self, "wrappedActImpl"): output = self.wrappedActImpl(quantInput) else: output = quantInput - import IPython; IPython.embed() quantOutput = self.act_quant(output) if self.act_quant is not None else output return quantOutput diff --git a/DeepQuant/CustomForwards/Linear.py b/DeepQuant/CustomForwards/Linear.py deleted file mode 100644 index 9043677..0000000 --- a/DeepQuant/CustomForwards/Linear.py +++ /dev/null @@ -1,75 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - - -import torch.nn as nn -from torch import Tensor -from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer - - -class InnerForwardImplWrapperLinear(nn.Module): - """ - A small wrapper around the 'innerForwardImpl' of a Brevitas QuantLinear - (QuantWeightBiasInputOutputLayer). - - We want to expose the logic within 'innerForwardImpl' as a standalone - submodule, so that FX tracing can see it as a leaf. - """ - - def __init__(self, innerForwardImpl: nn.Module) -> None: - """ - Args: - innerForwardImpl: The original function that processes - (quant_input, quant_weight, quant_bias). 
- """ - super().__init__() - self.innerForwardImpl = innerForwardImpl - - def forward( - self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor - ) -> Tensor: - """ - Applies the wrapped innerForwardImpl. - - Args: - quant_input: Input after input_quant. - quant_weight: Weight after weight_quant. - quant_bias: Bias after bias_quant (or None). - - Returns: - A torch.Tensor with the linear operation applied. - """ - return self.innerForwardImpl(quantInput, quantWeight, quantBias) - - -def quantWBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor: - """ - Unrolled forward pass for a Brevitas QuantLinear: - - Steps: - 1) self.input_quant - 2) self.weight_quant - 3) self.bias_quant (if bias is present) - 4) innerForwardImpl (wrapped) - 5) self.output_quant - - Args: - self: The QuantWeightBiasInputOutputLayer instance. - inp: The input Tensor to be processed. - - Returns: - Output Tensor after the unrolled quantized linear steps. - """ - quantInput = self.input_quant(inp) - quantWeight = self.weight_quant(self.weight) - - quantBias = None - if self.bias is not None: - quantBias = self.bias_quant(self.bias, quantInput, quantWeight) - - output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias) - quantOutput = self.output_quant(output) - return quantOutput diff --git a/DeepQuant/CustomForwards/MultiHeadAttention.py b/DeepQuant/CustomForwards/MultiHeadAttention.py index 76fe3ae..b76fc7e 100644 --- a/DeepQuant/CustomForwards/MultiHeadAttention.py +++ b/DeepQuant/CustomForwards/MultiHeadAttention.py @@ -4,67 +4,76 @@ # # Federico Brancasi - import math +from typing import Optional, Tuple + import torch import torch.nn.functional as F -from torch import Tensor from brevitas.nn.quant_mha import QuantMultiheadAttention +from torch import Tensor -def unrolledQuantMhaForward( - self: QuantMultiheadAttention, query: Tensor, key: Tensor, value: Tensor +def _mhaForwardImpl( + self: QuantMultiheadAttention, + query: Tensor, + key: Tensor, + value: Tensor, + need_transpose_in: bool, + need_transpose_out: bool, ) -> Tensor: - """ - Export-friendly forward that explicitly unrolls the multi-head logic. - - Steps: - 1) Q, K, V projections - 2) Reshapes & permutes for multi-head - 3) Scales queries - 4) Applies softmax and intermediate quantizations - 5) Out projection - - Args: - self: The QuantMultiheadAttention instance. - query: The query tensor of shape [sequence_len, batch_size, embed_dim]. - key: The key tensor, same shape as query. - value: The value tensor, same shape as query. - - Returns: - A torch.Tensor of shape [sequence_len, batch_size, embed_dim] - after the unrolled MHA steps. - """ - # 1) Q, K, V projections - qOut = self.q_proj(query) - kOut = self.k_proj(key) - vOut = self.v_proj(value) + """Core MHA forward implementation.""" + # FBRANCASI: Handle batch_first by transposing if needed + if need_transpose_in: + if key is value: + if query is key: + query = key = value = query.transpose(1, 0) + else: + query, key = [x.transpose(1, 0) for x in (query, key)] + value = key + else: + query, key, value = [x.transpose(1, 0) for x in (query, key, value)] + + if self.in_proj is not None: + # FBRANCASI: Handle packed projections (default case for models like ViT) + # Only support self-attention where query == key == value + if not (query is key and key is value): + raise RuntimeError( + "Packed in_proj is supported only for self-attention with k is v is q. Set packed_in_proj=False." 
+ ) + qkv = self.in_proj(query) + qkv_tensor = qkv.value if hasattr(qkv, "value") else qkv + qOut, kOut, vOut = qkv_tensor.chunk(3, dim=-1) + else: + q_result = self.q_proj(query) + k_result = self.k_proj(key) + v_result = self.v_proj(value) + + qOut = q_result.value if hasattr(q_result, "value") else q_result + kOut = k_result.value if hasattr(k_result, "value") else k_result + vOut = v_result.value if hasattr(v_result, "value") else v_result - # 2) Multi-head reshape seqLen, batchSize, embedDim = qOut.shape headDim = embedDim // self.num_heads qOut = ( - qOut.view(seqLen, batchSize, self.num_heads, headDim) - .permute(1, 2, 0, 3) - .reshape(batchSize * self.num_heads, seqLen, headDim) + qOut.contiguous() + .view(seqLen, batchSize * self.num_heads, headDim) + .transpose(0, 1) ) kOut = ( - kOut.view(seqLen, batchSize, self.num_heads, headDim) - .permute(1, 2, 0, 3) - .reshape(batchSize * self.num_heads, seqLen, headDim) + kOut.contiguous() + .view(seqLen, batchSize * self.num_heads, headDim) + .transpose(0, 1) ) vOut = ( - vOut.view(seqLen, batchSize, self.num_heads, headDim) - .permute(1, 2, 0, 3) - .reshape(batchSize * self.num_heads, seqLen, headDim) + vOut.contiguous() + .view(seqLen, batchSize * self.num_heads, headDim) + .transpose(0, 1) ) - # 3) Scale queries, then quantize qScaled = qOut / math.sqrt(headDim) qScaled = self.q_scaled_quant(qScaled) - # 4) Transpose + quantize K, compute attention weights k_t = kOut.transpose(-2, -1) k_t = self.k_transposed_quant(k_t) @@ -73,15 +82,67 @@ def unrolledQuantMhaForward( attnWeights = F.softmax(attnWeights, dim=-1) attnWeights = self.attn_output_weights_quant(attnWeights) - # 5) Quantize V, multiply, reshape back, and final out projection vOut = self.v_quant(vOut) attnOutput = torch.bmm(attnWeights, vOut) attnOutput = ( - attnOutput.view(batchSize, self.num_heads, seqLen, headDim) - .permute(2, 0, 1, 3) - .reshape(seqLen, batchSize, embedDim) + attnOutput.transpose(0, 1).contiguous().view(seqLen, batchSize, embedDim) ) - attnOutput = self.out_proj(attnOutput) + out_result = self.out_proj(attnOutput) + attnOutput = out_result.value if hasattr(out_result, "value") else out_result + + if need_transpose_out: + attnOutput = attnOutput.transpose(1, 0) + return attnOutput + + +def mhaForwardBatchFirst( + self: QuantMultiheadAttention, + query: Tensor, + key: Tensor, + value: Tensor, + need_weights: bool = True, + **kwargs, +) -> Tuple[Tensor, Optional[Tensor]]: + """MHA forward for batch_first=True.""" + attn_output = _mhaForwardImpl( + self, query, key, value, need_transpose_in=True, need_transpose_out=True + ) + return (attn_output, None) + + +def mhaForwardSeqFirst( + self: QuantMultiheadAttention, + query: Tensor, + key: Tensor, + value: Tensor, + need_weights: bool = True, + **kwargs, +) -> Tuple[Tensor, Optional[Tensor]]: + """MHA forward for batch_first=False.""" + attn_output = _mhaForwardImpl( + self, query, key, value, need_transpose_in=False, need_transpose_out=False + ) + return (attn_output, None) + + +def mhaForward( + self: QuantMultiheadAttention, + query: Tensor, + key: Tensor, + value: Tensor, + need_weights: bool = True, + **kwargs, +) -> Tuple[Tensor, Optional[Tensor]]: + """Explicit, export-friendly MHA forward. + + This function will be replaced with the appropriate batch_first or seq_first version + during module transformation based on the module's batch_first attribute. 
+ """ + # FBRANCASI: Appropriate version before tracing + if self.batch_first: + return mhaForwardBatchFirst(self, query, key, value, need_weights, **kwargs) + else: + return mhaForwardSeqFirst(self, query, key, value, need_weights, **kwargs) diff --git a/DeepQuant/CustomForwards/WBIOL.py b/DeepQuant/CustomForwards/WBIOL.py new file mode 100644 index 0000000..81e9e65 --- /dev/null +++ b/DeepQuant/CustomForwards/WBIOL.py @@ -0,0 +1,36 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import torch.nn as nn +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer +from torch import Tensor + + +class WrapperWBIOL(nn.Module): + """Expose `inner_forward_impl` as a standalone submodule.""" + + def __init__(self, innerForwardImpl: nn.Module) -> None: + super().__init__() + self.innerForwardImpl = innerForwardImpl + + def forward( + self, quantInput: Tensor, quantWeight: Tensor, quantBias: Tensor + ) -> Tensor: + return self.innerForwardImpl(quantInput, quantWeight, quantBias) + + +def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor: + """Quant-in → quant-weight/bias → matmul → quant-out.""" + quantInput = self.input_quant(inp) + quantWeight = self.weight_quant(self.weight) + + quantBias = None + if self.bias is not None: + quantBias = self.bias_quant(self.bias, quantInput, quantWeight) + + output = self.wrappedInnerForwardImpl(quantInput, quantWeight, quantBias) + quantOutput = self.output_quant(output) + return quantOutput diff --git a/DeepQuant/CustomTracer.py b/DeepQuant/CustomTracer.py deleted file mode 100644 index fab5dbe..0000000 --- a/DeepQuant/CustomTracer.py +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Custom Brevitas tracer implementation for handling module transformation and tracing. -""" - -import torch.nn as nn -from brevitas.fx.brevitas_tracer import ( - _symbolic_trace, - _is_brevitas_leaf_module, - Tracer, -) -from torch.fx.graph_module import GraphModule -from typing import List, Type, Optional - - -class CustomBrevitasTracer(Tracer): - """ - A custom tracer that allows explicit control over leaf and non-leaf module designation. - - This tracer extends the Brevitas tracer to provide fine-grained control over which modules - should be treated as leaf modules (traced as a single unit) vs non-leaf modules - (traced into their constituent operations). - """ - - def __init__( - self, - leafClasses: Optional[List[Type[nn.Module]]] = None, - nonLeafClasses: Optional[List[Type[nn.Module]]] = None, - debug: bool = False, - ) -> None: - """ - Initialize the custom tracer with optional leaf and non-leaf module lists. - - Args: - leaf_classes: List of module classes to be treated as leaf modules. - non_leaf_classes: List of module classes to be treated as non-leaf modules. - debug: Whether to print debug information during tracing. - """ - super().__init__() - self.leafClasses = leafClasses if leafClasses is not None else [] - self.nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else [] - self.debug = debug - - def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None: - """ - Add a module class to the list of leaf modules. - - Args: - module_cls: The module class to register as a leaf module. 
- """ - if moduleCls not in self.leafClasses: - self.leafClasses.append(moduleCls) - - def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None: - """ - Add a module class to the list of non-leaf modules. - - Args: - module_cls: The module class to register as a non-leaf module. - """ - if moduleCls not in self.nonLeafClasses: - self.nonLeafClasses.append(moduleCls) - - def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool: - """ - Determine whether a module should be treated as a leaf module. - - The decision follows this priority: - 1. If module is in leaf_classes, treat as leaf - 2. If module is in non_leaf_classes, treat as non-leaf - 3. Otherwise, fall back to default Brevitas behavior - - Args: - m: The module to check. - module_qualified_name: The fully qualified name of the module. - - Returns: - bool: True if the module should be treated as a leaf module, False otherwise. - """ - # First check explicitly registered classes - if any(isinstance(m, lc) for lc in self.leafClasses): - return True - if any(isinstance(m, nlc) for nlc in self.nonLeafClasses): - return False - # Fall back to default Brevitas behavior - return _is_brevitas_leaf_module(m, moduleQualifiedName) - - -def customBrevitasTrace( - root: nn.Module, concreteArgs=None, tracer: Optional[CustomBrevitasTracer] = None -) -> GraphModule: - """ - Create an FX GraphModule using the CustomBrevitasTracer. - - Args: - root: The root module to trace. - concrete_args: Concrete arguments to use for tracing. - tracer: Optional pre-configured CustomBrevitasTracer instance. - - Returns: - GraphModule: The traced module. - """ - if tracer is None: - tracer = CustomBrevitasTracer() - return _symbolic_trace(tracer, root, concreteArgs) diff --git a/DeepQuant/Export.py b/DeepQuant/Export.py new file mode 100644 index 0000000..239063e --- /dev/null +++ b/DeepQuant/Export.py @@ -0,0 +1,55 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from pathlib import Path +from typing import Optional, Union + +import torch +import torch.nn as nn + +from DeepQuant.Pipeline.DequantUnify import mergeDequants +from DeepQuant.Pipeline.Injection import injectCustomForwards +from DeepQuant.Pipeline.OnnxExport import exportToOnnx +from DeepQuant.Pipeline.OriginalTracing import traceOriginalModel +from DeepQuant.Pipeline.QuantSplit import splitQuantNodes + + +def brevitasToTrueQuant( + model: nn.Module, + exampleInput: torch.Tensor, + exportPath: Optional[Union[str, Path]] = Path.cwd() / "Tests" / "ONNX", + debug: bool = False, + checkEquivalence: bool = False, +) -> nn.Module: + """ + Export a Brevitas model to an FX GraphModule with unrolled quantization operations. + + This function applies a series of transformations to make the quantization steps + explicit in the model's computation graph, enabling efficient integer-only execution. 
+ """ + + # Pipeline Step 1: Trace the original model + tracedModel, originalOutput = traceOriginalModel(model, exampleInput, debug) + + # Pipeline Step 2: Inject custom forward implementations + transformedModel, transformedOutput = injectCustomForwards( + tracedModel, exampleInput, originalOutput, debug, checkEquivalence + ) + + # Pipeline Step 3: Split quantization nodes + splitModel, splitOutput = splitQuantNodes( + transformedModel, exampleInput, transformedOutput, debug, checkEquivalence + ) + + # Pipeline Step 4: Unify dequant nodes + unifiedModel, _ = mergeDequants( + splitModel, exampleInput, splitOutput, debug, checkEquivalence + ) + + # Pipeline Step 5: Export to ONNX + onnxFile, _ = exportToOnnx(unifiedModel, exampleInput, exportPath, debug) + + return unifiedModel diff --git a/DeepQuant/ExportBrevitas.py b/DeepQuant/ExportBrevitas.py deleted file mode 100644 index 0ab87f0..0000000 --- a/DeepQuant/ExportBrevitas.py +++ /dev/null @@ -1,304 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -import torch -import torch.nn as nn -from pathlib import Path - -from DeepQuant.Injects.Transformations import ( - LinearTransformation, # Transformation for quantized linear layers (QuantLinear, QuantConv2d) - ActivationTransformation, # Transformation for quantized activation functions (QuantReLU, etc.) - MHATransformation, # Transformation for quantized multi-head attention modules -) -from DeepQuant.Injects.Executor import ( - TransformationExecutor, -) # Orchestrates sequential transformations -from .CustomTracer import ( - CustomBrevitasTracer, - customBrevitasTrace, -) # Custom FX tracer for Brevitas modules -from DeepQuant.QuantManipulation.ParameterExtractor import ( - extract_brevitas_proxy_params, # Extracts quantization parameters from Brevitas proxies - print_quant_params, # Displays quantization parameters in a readable format -) -from DeepQuant.QuantManipulation.QuantNodesDivider import ( - split_quant_nodes, -) # Splits quantization nodes into Quant/Dequant pairs -from brevitas.export.inference import ( - quant_inference_mode, -) # Inference mode for quantized models -from brevitas.export import ( - export_onnx_qcdq, -) # Native Brevitas ONNX export functions -from DeepQuant.QuantManipulation.DequantModifier import ( - unifyLinearDequants, -) # Unifies dequant nodes in linear layers -from brevitas.fx import brevitas_symbolic_trace # Brevitas-specific symbolic tracing -from DeepQuant.Utils.GraphPrinter import ( - GraphModulePrinter, -) # Custom Graph Printer -from DeepQuant.Utils.FxInterpreter import NodeTracer - - -# ANSI color codes for improved debug output readability -BLUE = "\033[94m" -RED = "\033[31m" -ENDC = "\033[0m" - - -def exportBrevitas( - model: nn.Module, exampleInput: torch.Tensor, debug: bool = False -) -> nn.Module: - """ - Export a Brevitas model to an FX GraphModule with unrolled quantization operations. - - This function applies a series of transformations to make the quantization steps - explicit in the model's computation graph, then traces the transformed model using - a custom FX tracer. - - Args: - model: The Brevitas-based model to export. - example_input: A representative input tensor for shape tracing. - debug: If True, prints transformation progress information. - - Returns: - nn.Module: An FX GraphModule with explicit quantization operations. 
- """ - - EXPORT_FOLDER = Path().cwd() - if Path().cwd().name == "DeepQuant": - EXPORT_FOLDER = EXPORT_FOLDER / "Tests/ONNX" - EXPORT_FOLDER.mkdir(parents=True, exist_ok=True) - - printer = GraphModulePrinter() - - ############################################################################### - # 1. Original Network - ############################################################################### - - model = brevitas_symbolic_trace( - model - ) # Symbolically trace the original model using Brevitas - if debug: - print("\n\n=== 1. Original Network ===\n") - printer.print_tabular(model) - print() - - with ( - torch.no_grad(), - quant_inference_mode(model), - ): # Disable gradients and use quantized inference mode - outputModel = model( - exampleInput - ) # Compute original model output on example input for validation - - # export_onnx_qcdq( # Export original model to ONNX format with QCDQ (Quant-Cast-DeQuant) nodes - # model, # Model to export - # args=exampleInput, # Example input for tracing - # export_path=EXPORT_FOLDER / "1_model_qcdq_original.onnx", - # opset_version=13, - # ) - - ############################################################################### - # 2. Injection of New Modules - ############################################################################### - - # Create transformation sequence in appropriate order - transformations = [ - MHATransformation(), # Multi-head attention transformation (applied first) - LinearTransformation(), # Quantized linear layers transformation - ActivationTransformation(), # Quantized activation functions transformation - ] - - # Initialize custom tracer for Brevitas - tracer = CustomBrevitasTracer(debug=debug) - - # Create and execute transformation sequence using the executor - executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) - transformedModel = executor.execute( - model, exampleInput - ) # Apply all transformations to the model - - # Generate FX graph using the same tracer for consistency - fxModel = customBrevitasTrace( - root=transformedModel, # Transformed model to trace - concreteArgs=(exampleInput,), - tracer=tracer, # Use same tracer to maintain consistency with transformations - ) - fxModel.recompile() # Recompile the FX module to update its forward method - with torch.no_grad(): - outputFxModel = fxModel(exampleInput) # Compute transformed model output - - if isinstance(outputModel, tuple): - outputModel = outputModel[0] - - if torch.allclose( - outputFxModel, outputModel, atol=1e-5 - ): # Check numerical equivalence within tolerance - if debug: - print(f"{BLUE} ✓ Injection of New Modules: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if outputs differ significantly - f"{RED} ✗ Injection of New Modules changed the output significantly{ENDC}" - ) - - if debug: - print(f"{BLUE} ✓ All transformations completed successfully!{ENDC}") - if debug: - print("\n=== 2. Network after the Injection of New Modules ===\n") - printer.print_tabular(fxModel) - - # export_onnx_qcdq( # Export transformed model to ONNX - # fxModel, # Transformed model - # args=exampleInput, - # export_path=EXPORT_FOLDER / "2_model_qcdq_transformed.onnx", - # opset_version=13, - # ) - - ############################################################################### - # 3. 
Extraction of Parameters & Split of Quant Nodes - ############################################################################### - - # Extract quantization parameters from the network's proxies - proxyParams = extract_brevitas_proxy_params( - fxModel - ) # Get scale, zero_point, bit_width for each quant node - - if debug: - print_quant_params( - proxyParams - ) # Display extracted parameters in a readable format - - # Split quantization nodes into separate Quant and Dequant nodes - splitFxModel = split_quant_nodes( - fxModel, proxyParams, debug - ) # Transform quant nodes into quant-dequant pairs - splitFxModel.recompile() # Recompile to update forward method with new nodes - - with torch.no_grad(): - outputFxModelSplitQuant = splitFxModel( - exampleInput - ) # Compute output after node splitting - - # print("Output Original: ", output_model) - # print("Output Split: ", output_fx_model_split_quant) - - if torch.allclose( - outputModel, outputFxModelSplitQuant, atol=1e-5 - ): # Verify numerical consistency - if debug: - print(f"{BLUE} ✓ Split of Quant Nodes: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if inconsistent - f"{RED} ✗ Split of Quant Nodes changed the output significantly{ENDC}" - ) - - if debug: - print("\n=== 3. Network after the Split of Quant Nodes ===\n") - printer.print_tabular(splitFxModel) - print() - - torch.onnx.export( - splitFxModel, - args=exampleInput, - f=EXPORT_FOLDER / "3_model_splitted_quant.onnx", - opset_version=13, - keep_initializers_as_inputs=True, - do_constant_folding=False, - ) - - # return split_fx_model - - ############################################################################### - # 4. Modification of Dequant Nodes (shift them down) - ############################################################################### - - # Perform the unification of linear dequant nodes (move dequantization after computation) - fxModelUnified = unifyLinearDequants(splitFxModel, debug=debug) - fxModelUnified.recompile() # Recompile to update forward method with new node arrangement - - # Compute output after dequant node unification - with torch.no_grad(): - outputFxModelDequantModified = fxModelUnified( - exampleInput - ) # Output after dequant modification - - print("Output Original: ", outputModel) - print("Output Dequant Modified: ", outputFxModelDequantModified) - - if debug: - print("\n=== 4. Network after the Modification of Dequant Nodes ===\n") - printer.print_tabular(fxModelUnified) - print() - - # # Verify numerical consistency after dequant modification - # if torch.allclose( - # output_model, output_fx_model_dequant_modified, atol=1e-5 - # ): # Verify numerical consistency - # if debug: - # print(f"{BLUE} ✓ Modification of Dequant Nodes: output is consistent{ENDC}") - # else: - # raise RuntimeError( # Raise error if inconsistent - # f"{RED} ✗ Modification of Dequant Nodes changed the output significantly{ENDC}" - # ) - - # if debug: - # print("\n=== 4. 
Network after the Modification of Dequant Nodes ===\n") - # printer.print_tabular(fx_model_unified) - # print() - - onnxFile: str = EXPORT_FOLDER / "4_model_dequant_moved.onnx" - torch.onnx.export( - fxModelUnified, - args=exampleInput, - # f=EXPORT_FOLDER / "4_model_dequant_moved.onnx", - f=onnxFile, - opset_version=13, - keep_initializers_as_inputs=True, - do_constant_folding=False, - input_names=["input"], - output_names=["output"], - ) - - # Verify numerical consistency after dequant modification - if torch.allclose( - outputModel, outputFxModelDequantModified, atol=1e-5 - ): # Verify numerical consistency - if debug: - print(f"{BLUE} ✓ Modification of Dequant Nodes: output is consistent{ENDC}") - else: - raise RuntimeError( # Raise error if inconsistent - f"{RED} ✗ Modification of Dequant Nodes changed the output significantly{ENDC}" - ) - - import numpy as np - import onnxruntime as ort - import onnx - - # Step 2: Load the model and run shape inference - # (All tensors in ONNX graph should have explicit shape information) - onnxModel = onnx.load(onnxFile) - inferredModel = onnx.shape_inference.infer_shapes(onnxModel) - - # Step 3: Save the model with inferred shapes - onnx.save(inferredModel, onnxFile) - - inputFile: str = EXPORT_FOLDER / "inputs.npz" - np.savez(inputFile, input=exampleInput.cpu()) - print("Input npz: ", exampleInput) - print(f"Input data saved to {inputFile} ✓") - - # onnxruntime to run the exported model - ortSession: ort.InferenceSession = ort.InferenceSession(onnxFile) - ortInputs: dict = {"input": exampleInput.cpu().numpy()} - ortOutput: np.ndarray = ortSession.run(None, ortInputs)[0] - - outputFile: str = EXPORT_FOLDER / "outputs.npz" - np.savez(outputFile, output=ortOutput) - print("Output npz: ", ortOutput) - print(f"Output data saved to {outputFile} ✓") - - return fxModelUnified # Return the final optimized FX GraphModule diff --git a/DeepQuant/Injects/Base.py b/DeepQuant/Injects/Base.py deleted file mode 100644 index e9d72b9..0000000 --- a/DeepQuant/Injects/Base.py +++ /dev/null @@ -1,109 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Base transformation infrastructure for the Brevitas export process. - -This module provides the foundational TransformationPass class that handles: -- Module type matching -- Forward method injection -- Output validation -- Recursive submodule transformation -""" - -import torch -import torch.nn as nn -from abc import ABC, abstractmethod -from typing import Any, Optional, Union, Tuple -from ..CustomTracer import CustomBrevitasTracer - - -class TransformationPass(ABC): - """ - Generic transformation pass for modifying Brevitas modules. - - A transformation pass targets specific module types and applies custom forward - implementations while ensuring output consistency. - """ - - def __init__( - self, - moduleCls: Union[type, Tuple[type, ...]], - validationTol: float = 1e-6, - ) -> None: - """ - Initialize a transformation pass. - - Args: - module_cls: Module class(es) this transformation targets. - injection_fn: Function that modifies the module's forward pass. - validation_tol: Tolerance for numerical comparison in validation. - """ - self.moduleCls = moduleCls - self.validationTol = validationTol - - def checkModuleType(self, module: nn.Module) -> bool: - """ - Check if a module is an instance of the target class(es). - - Args: - module: Module to check. 
- - Returns: - bool: True if module is an instance of self.module_cls. - """ - return isinstance(module, self.moduleCls) - - @abstractmethod - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject the custom forward implementation into a module. - - Args: - module: Module whose forward method will be replaced. - tracer: Optional tracer for registering module classes. - """ - pass - - def validateTransformation( - self, outputBefore: Any, outputAfter: Any, atol: Optional[float] = None - ) -> bool: - """ - Validate transformation by comparing outputs. - - Args: - output_before: Model output before transformation. - output_after: Model output after transformation. - atol: Optional custom tolerance for comparison. - - Returns: - bool: True if outputs match within tolerance. - """ - if atol is None: - atol = self.validationTol - return torch.allclose(outputBefore, outputAfter, atol=atol) - - def transform( - self, model: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> bool: - """ - Apply the transformation to all matching submodules. - - Args: - model: Model containing submodules to transform. - tracer: Optional tracer for registering transformed modules. - - Returns: - bool: True if any modules were transformed. - """ - transformDone = False - for _, submodule in model.named_modules(): - if self.checkModuleType(submodule): - self.injectForward(submodule, tracer) - transformDone = True - return transformDone diff --git a/DeepQuant/Injects/Executor.py b/DeepQuant/Injects/Executor.py deleted file mode 100644 index e41f3e9..0000000 --- a/DeepQuant/Injects/Executor.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Executor module for handling transformation sequences in the Brevitas export process. -""" - -import torch -import torch.nn as nn -from typing import List, Optional -from .Base import TransformationPass -from ..CustomTracer import CustomBrevitasTracer - -# ANSI color codes -BLUE = "\033[94m" -RED = "\033[91m" -ENDC = "\033[0m" - - -class TransformationExecutor: - """ - Manages and executes a sequence of model transformations. - - The executor applies each transformation in sequence, validating that model outputs - remain consistent after each transformation step. - """ - - def __init__( - self, - transformations: List[TransformationPass], - debug: bool = False, - tracer: Optional[CustomBrevitasTracer] = None, - ) -> None: - """ - Initialize the transformation executor. - - Args: - transformations: List of transformation passes to apply. - debug: Whether to print debug information during execution. - tracer: Optional CustomBrevitasTracer instance for module registration. - """ - self.transformations = transformations - self.debug = debug - self.tracer = tracer - - def execute(self, model: nn.Module, exampleInput: torch.Tensor) -> nn.Module: - """ - Execute all transformations on the model in sequence. - - For each transformation: - 1. Apply the transformation - 2. Validate that model outputs remain consistent - 3. Update the reference output for the next transformation - - Args: - model: The PyTorch model to transform. - example_input: A representative input tensor for validation. - - Returns: - nn.Module: The transformed model. - - Raises: - RuntimeError: If any transformation results in output mismatch. 
- """ - model.eval() - with torch.no_grad(): - outputBefore = model(exampleInput) - if isinstance(outputBefore, tuple): - outputBefore = outputBefore[0] - - for transformation in self.transformations: - if transformation.transform(model, tracer=self.tracer): - outputAfter = model(exampleInput) - if isinstance(outputAfter, tuple): - outputAfter = outputAfter[0] - - if not transformation.validateTransformation( - outputBefore, outputAfter - ): - raise RuntimeError( - f"{RED} ✗ {transformation.__class__.__name__} failed - outputs mismatch{ENDC}" - ) - - if self.debug: - print( - f"{BLUE} ✓ {transformation.__class__.__name__} transformation successful\n{ENDC}" - f" leafClasses: {self.tracer.leafClasses}\n" - f" nonLeafClasses: {self.tracer.nonLeafClasses}\n" - ) - - outputBefore = outputAfter - - return model diff --git a/DeepQuant/Injects/Transformations.py b/DeepQuant/Injects/Transformations.py deleted file mode 100644 index 9a0e031..0000000 --- a/DeepQuant/Injects/Transformations.py +++ /dev/null @@ -1,143 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Transformation classes for different types of Brevitas modules. - -This module provides specific transformation passes for each type of quantized module: -- Linear layers (QuantLinear, QuantConv2d) -- Activation functions (QuantReLU, QuantSigmoid) -- Multi-head attention (QuantMultiheadAttention) - -Each transformation class implements the abstract injectForward method from TransformationPass -to define its specific module transformation logic. -""" - -import torch.nn as nn -from typing import Optional -from brevitas.nn.quant_layer import ( - QuantWeightBiasInputOutputLayer, - QuantNonLinearActLayer, -) -from brevitas.nn.quant_mha import QuantMultiheadAttention - -from .Base import TransformationPass -from ..CustomForwards.Linear import InnerForwardImplWrapperLinear, quantWBIOLForward -from ..CustomForwards.MultiHeadAttention import unrolledQuantMhaForward -from ..CustomTracer import CustomBrevitasTracer -from ..CustomForwards.Activations import ( - InnerForwardImplWrapperActivation, - quantActivationForward, -) - - -class LinearTransformation(TransformationPass): - """ - Transformation pass for quantized linear layers (QuantLinear, QuantConv2d). - - Replaces the default forward with an unrolled implementation that exposes - all quantization steps in the computation graph. - """ - - def __init__(self) -> None: - """Initialize the linear transformation pass.""" - super().__init__( - moduleCls=QuantWeightBiasInputOutputLayer, - validationTol=1e-6, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for linear layers. - - Args: - module: The linear module to transform. - tracer: Optional tracer for registering transformed modules. - """ - module.wrappedInnerForwardImpl = InnerForwardImplWrapperLinear( - module.inner_forward_impl - ) - module.forward = quantWBIOLForward.__get__(module) - - if tracer: - tracer.registerLeafModule(InnerForwardImplWrapperLinear) - tracer.registerNonLeafModule(QuantWeightBiasInputOutputLayer) - - -class ActivationTransformation(TransformationPass): - """ - Transformation pass for quantized activation functions. - - Replaces the default forward with an unrolled implementation that exposes - the input quantization and activation quantization steps. 
- """ - - def __init__(self) -> None: - """Initialize the activation transformation pass.""" - super().__init__( - moduleCls=QuantNonLinearActLayer, - validationTol=1e-6, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for activation layers. - - This method instantiates the original activation function (if provided) and - wraps it using InnerForwardImplWrapperActivation, then overrides the forward method. - - Args: - module: The activation module to transform. - tracer: Optional tracer for registering transformed modules. - """ - # If the activation implementation was provided (e.g. nn.ReLU for QuantReLU), - # instantiate it. Otherwise, default to an identity. - if hasattr(module, "act_impl") and module.act_impl is not None: - actInstance = module.act_impl() # e.g. nn.ReLU() - else: - actInstance = nn.Identity() - - module.wrappedActImpl = InnerForwardImplWrapperActivation(actInstance) - module.forward = quantActivationForward.__get__(module) - - if tracer: - tracer.registerLeafModule(InnerForwardImplWrapperActivation) - tracer.registerNonLeafModule(QuantNonLinearActLayer) - - -class MHATransformation(TransformationPass): - """ - Transformation pass for quantized multi-head attention layers. - - Replaces the default forward with an unrolled implementation that exposes - all attention operations and their associated quantization steps. - """ - - def __init__(self) -> None: - """Initialize the MHA transformation pass.""" - super().__init__( - moduleCls=QuantMultiheadAttention, - validationTol=1e-5, - ) - - def injectForward( - self, module: nn.Module, tracer: Optional[CustomBrevitasTracer] = None - ) -> None: - """ - Inject custom forward implementation for MHA layers. - - Args: - module: The MHA module to transform. - tracer: Optional tracer for registering transformed modules. - """ - module.forward = unrolledQuantMhaForward.__get__(module) - - if tracer: - tracer.registerNonLeafModule(QuantMultiheadAttention) diff --git a/DeepQuant/Pipeline/DequantUnify.py b/DeepQuant/Pipeline/DequantUnify.py new file mode 100644 index 0000000..f125e31 --- /dev/null +++ b/DeepQuant/Pipeline/DequantUnify.py @@ -0,0 +1,122 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.QuantManipulation.DequantModifier import unifyLinearDequants +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter +from DeepQuant.Utils.TensorRecorder import TensorRecorder + + +def mergeDequants( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, + checkEquivalence: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """ + Unify dequantization nodes to enable integer-only computation. + + This step modifies the dequantization nodes in the graph to allow + operations to run in the integer domain, applying dequantization + only after the computations are complete (Requantization). 
+ """ + printer = GraphModulePrinter() + tensorRecorder = TensorRecorder(debug=debug) + + if debug: + # FBRANCASI: Register hooks to record tensors from the split model (before dequant modification) + tensorRecorder.registerForwardHooks( + model, + nodeTypes=[ + "wrappedInnerForwardImpl", + "dequant", + "unified_dequant", + "linear", + "conv", + "quant", + "act", + "bias_quant", + "act_quant", + "relu", + ], + ) + + # FBRANCASI: Run the model to record tensors before modification + with torch.no_grad(): + _ = model(exampleInput) + + if debug: + # FBRANCASI: Save tensors as reference for comparison + tensorRecorder.setReferenceTensors() + + # FBRANCASI: Register mappings from wrappedInnerForwardImpl nodes to expected unified_dequant nodes + for node in model.graph.nodes: + if node.op == "call_module" and "wrappedInnerForwardImpl" in node.target: + baseName = node.target.replace(".wrappedInnerForwardImpl", "") + dequantName = f"{baseName}_unified_dequant" + dequantName = dequantName.replace(".", "_") + + tensorRecorder.recordNodeMapping(node.target, dequantName) + + unifiedModel = unifyLinearDequants(model, debug=debug) + unifiedModel.recompile() + + if debug: + print(cc.header("4. Network after Modification of Dequant Nodes")) + printer.printTabular(unifiedModel) + print() + + with torch.no_grad(): + output = unifiedModel(exampleInput) + + # FBRANCASI: Check output equivalence with a warning instead of error + if checkEquivalence: + # FBRANCASI: Handle case where output/referenceOutput might be tuples + refToCompare = referenceOutput[0] if isinstance(referenceOutput, tuple) else referenceOutput + outToCompare = output[0] if isinstance(output, tuple) else output + if not torch.allclose(refToCompare, outToCompare, atol=1e-5) and debug: + print( + cc.warning( + "Modification of Dequant Nodes may have changed the output slightly" + ) + ) + + if debug: + # FBRANCASI: Register hooks for the unified model and compare tensors + tensorRecorder.registerForwardHooks( + unifiedModel, + nodeTypes=[ + "wrappedInnerForwardImpl", + "dequant", + "unified_dequant", + "linear", + "conv", + "quant", + "act", + "bias_quant", + "act_quant", + "relu", + ], + ) + + # FBRANCASI: Run the model to record tensors after modification + with torch.no_grad(): + _ = unifiedModel(exampleInput) + + # FBRANCASI: Compare tensors before and after modification + print(cc.info("Tensor Comparison Before/After Dequant Unification:")) + results = tensorRecorder.compareTensors() + tensorRecorder.printComparisonResults(results) + + tensorRecorder.removeHooks() + + return unifiedModel, output diff --git a/DeepQuant/Pipeline/Injection.py b/DeepQuant/Pipeline/Injection.py new file mode 100644 index 0000000..75b168f --- /dev/null +++ b/DeepQuant/Pipeline/Injection.py @@ -0,0 +1,68 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.Transforms.Executor import TransformationExecutor +from DeepQuant.Transforms.Transformations import ( + ActivationTransformation, + LinearTransformation, + MHATransformation, +) +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def injectCustomForwards( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, + checkEquivalence: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """Inject custom forward implementations into the model.""" + printer = GraphModulePrinter() + + tracer = QuantTracer(debug=debug) + + transformations = [ + MHATransformation(), + LinearTransformation(), + ActivationTransformation(), + ] + + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) + transformedModel = executor.execute(model, exampleInput) + + fxModel = customBrevitasTrace( + root=transformedModel, + tracer=tracer, + ) + fxModel.recompile() + + with torch.no_grad(): + output = fxModel(exampleInput) + + if checkEquivalence: + outputToCompare = output[0] if isinstance(output, tuple) else output + if torch.allclose(referenceOutput, outputToCompare, atol=1e-5): + if debug: + print(cc.success("Injection of New Modules: output is consistent")) + else: + raise RuntimeError( + cc.error("Injection of New Modules changed the output significantly") + ) + + if debug: + print(cc.header("2. Network after Injection of New Modules")) + printer.printTabular(fxModel) + print() + + return fxModel, output diff --git a/DeepQuant/Pipeline/OnnxExport.py b/DeepQuant/Pipeline/OnnxExport.py new file mode 100644 index 0000000..de7bb09 --- /dev/null +++ b/DeepQuant/Pipeline/OnnxExport.py @@ -0,0 +1,89 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from pathlib import Path +from typing import Tuple, Union + +import numpy as np +import onnx +import onnxruntime as ort +import torch +import torch.nn as nn + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc + + +def create_deterministic_session(): + """ + Create ONNX Runtime session with deterministic settings for exact reproducibility. 
+ """ + options = ort.SessionOptions() + + options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + options.use_deterministic_compute = True + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + + options.enable_cpu_mem_arena = False + options.enable_mem_pattern = False + options.enable_mem_reuse = False + + options.log_severity_level = 3 + options.enable_profiling = False + + return options + + +def exportToOnnx( + model: nn.Module, + exampleInput: torch.Tensor, + exportPath: Union[str, Path], + debug: bool = False, +) -> Tuple[Path, np.ndarray]: + """Export model to ONNX format and save input/output data.""" + exportPath = Path(exportPath) + exportPath.mkdir(parents=True, exist_ok=True) + + onnxFile = exportPath / "network.onnx" + inputFile = exportPath / "inputs.npz" + outputFile = exportPath / "outputs.npz" + + torch.onnx.export( + model, + args=exampleInput, + f=onnxFile, + opset_version=17, + keep_initializers_as_inputs=False, # FBRANCASI: Prevent warnings + do_constant_folding=True, + input_names=["input"], + output_names=["output"], + ) + + onnxModel = onnx.load(onnxFile) + inferredModel = onnx.shape_inference.infer_shapes(onnxModel) + onnx.save(inferredModel, onnxFile) + + np.savez(inputFile, input=exampleInput.cpu().numpy()) + if debug: + print() + print(cc.success(f"Input data saved to {inputFile}")) + + options = create_deterministic_session() + # ortSession = ort.InferenceSession(onnxFile) + ortSession = ort.InferenceSession( + onnxFile, sess_options=options, providers=["CPUExecutionProvider"] + ) + ortInputs = {"input": exampleInput.cpu().numpy()} + ortOutput = ortSession.run(None, ortInputs)[0] + + np.savez(outputFile, output=ortOutput) + if debug: + print(cc.success(f"Output data saved to {outputFile}\n")) + + return onnxFile, ortOutput diff --git a/DeepQuant/Pipeline/OriginalTracing.py b/DeepQuant/Pipeline/OriginalTracing.py new file mode 100644 index 0000000..d9e6bb9 --- /dev/null +++ b/DeepQuant/Pipeline/OriginalTracing.py @@ -0,0 +1,38 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn +from brevitas.export.inference import quant_inference_mode +from brevitas.fx import brevitas_symbolic_trace + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def traceOriginalModel( + model: nn.Module, exampleInput: torch.Tensor, debug: bool = False +) -> Tuple[nn.Module, torch.Tensor]: + """Symbolically trace the original model using Brevitas.""" + printer = GraphModulePrinter() + + tracedModel = brevitas_symbolic_trace(model) + + if debug: + print(cc.header("1. Original Network")) + printer.printTabular(tracedModel) + print() + + with torch.no_grad(), quant_inference_mode(model): + output = model(exampleInput) + + # FBRANCASI: Handle case where output is a tuple (e.g., MHA) + if isinstance(output, tuple): + output = output[0] + + return tracedModel, output diff --git a/DeepQuant/Pipeline/QuantSplit.py b/DeepQuant/Pipeline/QuantSplit.py new file mode 100644 index 0000000..fafb80e --- /dev/null +++ b/DeepQuant/Pipeline/QuantSplit.py @@ -0,0 +1,65 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import torch +import torch.nn as nn + +from DeepQuant.QuantManipulation.QuantizationParameterExtractor import ( + extractBrevitasProxyParams, + printQuantParams, +) +from DeepQuant.QuantManipulation.QuantNodesDivider import convertQuantOperations +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter + + +def splitQuantNodes( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, + checkEquivalence: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """ + Split quantization nodes into separate Quant and Dequant nodes. + + This step transforms each quantization operation into explicit + Quant and Dequant node pairs, providing clear separation between + quantized and floating-point operations. + """ + printer = GraphModulePrinter() + + proxyParams = extractBrevitasProxyParams(model) + + if debug: + printQuantParams(proxyParams) + + splitModel = convertQuantOperations(model, proxyParams, debug) + splitModel.recompile() + + with torch.no_grad(): + output = splitModel(exampleInput) + + if checkEquivalence: + # FBRANCASI: Handle case where output/referenceOutput might be tuples + refToCompare = referenceOutput[0] if isinstance(referenceOutput, tuple) else referenceOutput + outToCompare = output[0] if isinstance(output, tuple) else output + if torch.allclose(refToCompare, outToCompare, atol=1e-5): + if debug: + print(cc.success("Split of Quant Nodes: output is consistent")) + else: + raise RuntimeError( + cc.error("Split of Quant Nodes changed the output significantly") + ) + + if debug: + print(cc.header("3. Network after Split of Quant Nodes")) + printer.printTabular(splitModel) + print() + + return splitModel, output diff --git a/DeepQuant/QuantManipulation/DequantModifier.py b/DeepQuant/QuantManipulation/DequantModifier.py index 8bd9ae5..d5342bd 100644 --- a/DeepQuant/QuantManipulation/DequantModifier.py +++ b/DeepQuant/QuantManipulation/DequantModifier.py @@ -4,66 +4,24 @@ # # Federico Brancasi -""" -This module provides a function to unify the linear dequant nodes (input, weight, bias) -into a single final dequant node after the linear wrappedInnerForwardImpl. - -Key steps: - 1) Rewire bias quant to reference the quant nodes of input/weight instead of their dequant. - 2) Rewire the linear's wrappedInnerForwardImpl so it references bias_quant instead of bias_dequant. - 3) Clone the bias dequant parameters (scale/zero_point/bit_width) to a new Dequant node - placed after the linear, removing the old bias_dequant node from the graph. - 4) Remove the input_dequant and weight_dequant nodes as well, once they have no more users. - 5) Recompile the FX GraphModule so that the generated forward code no longer references - the removed nodes. - -By the end, the linear operation is in the integer domain, and the final dequant occurs only once. -""" - import torch.fx as fx from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc -BLUE = "\033[94m" -ENDC = "\033[0m" -CHECK = " ✓" -ARROW = " ›" - - -def unifyLinearDequants( - fxModel: fx.GraphModule, debug: bool = False -) -> fx.GraphModule: - """ - Unify the linear dequant nodes (input, weight, bias) into a single final dequant node. - - This transformation: - * Redirects the linear's inputs to the quant nodes (removing input_dequant, weight_dequant). 
- * Updates bias_quant to reference those same quant nodes, removing references to dequant. - * Creates a new Dequant node after the linear operation, reusing the bias dequant parameters. - * Erases the old dequant nodes from the graph and submodules. - * Recompiles the graph so the final forward does not reference removed nodes. - - Args: - fxModel (fx.GraphModule): The input FX GraphModule to be modified. - debug (bool): If True, prints debug information. - - Returns: - fx.GraphModule: The modified FX GraphModule with a single dequant node after the linear. - """ +def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule: + """Unify the linear dequant nodes (input, weight, bias) into a single final dequant node.""" graph = fxModel.graph allNodes = list(graph.nodes) if debug: - print(f"{BLUE}{ARROW} Starting Modification of Dequant Nodes...{ENDC}") + print(cc.info("Starting Modification of Dequant Nodes...")) for node in allNodes: - # Identify the "wrappedInnerForwardImpl" call for linear if node.op != "call_module" or "wrappedInnerForwardImpl" not in node.target: continue - # Typically the node args are: - # (linear1_input_dequant, linear1_weight_dequant, linear1_bias_dequant) oldArgs = list(node.args) biasDequantNode = None @@ -72,9 +30,11 @@ def unifyLinearDequants( newLinArgs = [] - # Collect and rewire the linear's arguments for arg in oldArgs: - if arg.op == "call_module" and "dequant" in arg.target.lower(): + # FCONTI: there is no Bias, propagate this to the newLinArgs + if arg is None: + newLinArgs.append(arg) + elif arg.op == "call_module" and "dequant" in arg.target.lower(): if "bias_dequant" in arg.target.lower(): biasDequantNode = arg elif "weight_dequant" in arg.target.lower(): @@ -82,7 +42,6 @@ def unifyLinearDequants( else: inputDequantNode = arg - # Replace the dequant input with the corresponding quant node quantNode = arg.args[0] newLinArgs.append(quantNode) else: @@ -91,46 +50,59 @@ def unifyLinearDequants( node.args = tuple(newLinArgs) if biasDequantNode is None: - # This would be unusual if a linear is missing bias or missing a bias_dequant + # FCONTI: this happens if a linear layer has no bias if debug: - print(f"Skipping {node.target}: no biasDequantNode found.") - continue - - # The bias_quant node that feeds biasDequantNode might reference input/weight dequant - # We rewrite it so that it references the input/weight quant nodes - biasQuantNode = biasDequantNode.args[0] - if ( - biasQuantNode.op == "call_module" - and "bias_quant" in biasQuantNode.target.lower() - ): - new_bq_args = list(biasQuantNode.args) - # Typically new_bq_args = [bias, input_dequant, weight_dequant] - for i, bq_arg in enumerate(new_bq_args): - if bq_arg.op == "call_module" and "dequant" in bq_arg.target.lower(): - new_bq_args[i] = bq_arg.args[0] # The corresponding quant node - biasQuantNode.args = tuple(new_bq_args) + print(f"Skipping bias for {node.target}: no biasDequantNode found.") + biasQuantNode = None else: - if debug: + biasQuantNode = biasDequantNode.args[0] + if ( + biasQuantNode.op == "call_module" + and "bias_quant" in biasQuantNode.target.lower() + ): + newBqArgs = list(biasQuantNode.args) + for i, bqArg in enumerate(newBqArgs): + if bqArg.op == "call_module" and "dequant" in bqArg.target.lower(): + newBqArgs[i] = bqArg.args[0] + biasQuantNode.args = tuple(newBqArgs) + else: + if debug: + print( + "Warning: Did not find a typical 'bias_quant' node shape in the graph." + ) + + # FCONTI: if there is a bias node, use it for scale/zeropoint/bitwidth. 
+ # otherwise, rely on weight*input + if biasDequantNode is not None: + oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target) + dequantScale = oldBiasDequantMod.scale + dequantZeroPoint = oldBiasDequantMod.zeroPoint + dequantBitWidth = oldBiasDequantMod.bitWidth + oldDequantMod = oldBiasDequantMod + else: + oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target) + oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target) + dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale + # FCONTI: technically it should be: + # dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights) + # how to appropriately compute sum(weights)? + # for now we restrict ourselves to oIDM.zP = 0, so dZP = 0 + if debug and oldInputDequantMod.zeroPoint != 0.0: print( - "Warning: Did not find a typical 'bias_quant' node shape in the graph." + f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!" ) + dequantZeroPoint = 0.0 + dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one? + oldDequantMod = oldWeightDequantMod - # Erase input_dequant/weight_dequant from the graph - # They should now have zero real users for dnode in (inputDequantNode, weightDequantNode): if dnode is not None: - # For safety, remove all references for usr in list(dnode.users.keys()): dnode.users[usr] = None if hasattr(fxModel, dnode.target): delattr(fxModel, dnode.target) graph.erase_node(dnode) - # Now we create the final single Dequant node after the linear - # by cloning the bias_dequant submodule's parameters - oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target) - - # Construct a new Dequant module from the old bias_dequant newDequantModName = ( node.target.replace(".wrappedInnerForwardImpl", "") + "_unified_dequant" ) @@ -138,21 +110,19 @@ def unifyLinearDequants( newDequantModName = newDequantModName.replace(".", "_") unifiedDequantMod = Dequant( - original_module=oldBiasDequantMod.original_module, - scale=oldBiasDequantMod.scale, - zero_point=oldBiasDequantMod.zero_point, - bit_width=oldBiasDequantMod.bit_width, + originalModule=oldDequantMod.originalModule, + scale=dequantScale, + zeroPoint=dequantZeroPoint, + bitWidth=dequantBitWidth, ) fxModel.add_module(newDequantModName, unifiedDequantMod) - # Insert the new dequant node after the linear's forward_impl with graph.inserting_after(node): newDequantNode = graph.call_module(newDequantModName, args=(node,)) - # Reroute all users of node to the new dequant node - old_users = list(node.users.keys()) - for usr in old_users: + oldUsers = list(node.users.keys()) + for usr in oldUsers: if usr is not newDequantNode: newArgs = list(usr.args) for i, a in enumerate(newArgs): @@ -160,29 +130,24 @@ def unifyLinearDequants( newArgs[i] = newDequantNode usr.args = tuple(newArgs) - # Remove the old bias_dequant node from the graph - for usr in list(biasDequantNode.users.keys()): - biasDequantNode.users[usr] = None - if hasattr(fxModel, biasDequantNode.target): - delattr(fxModel, biasDequantNode.target) - graph.erase_node(biasDequantNode) + if biasDequantNode is not None: + for usr in list(biasDequantNode.users.keys()): + biasDequantNode.users[usr] = None + if hasattr(fxModel, biasDequantNode.target): + delattr(fxModel, biasDequantNode.target) + graph.erase_node(biasDequantNode) if debug: - print(f" {CHECK} Modification done for {node.target}") + print(cc.success(f"Modification done for {node.target}")) - # Clean up any leftover references 
graph.lint() graph.eliminate_dead_code() - # Remove submodules that are now unused fxModel.delete_all_unused_submodules() - # Recompile so that the generated forward code no longer references removed nodes fxModel.recompile() if debug: - print( - f"{BLUE}{ARROW} Modification of Dequant Nodes completed successfully{ENDC}" - ) + print(cc.info("Modification of Dequant Nodes completed successfully")) return fxModel diff --git a/DeepQuant/QuantManipulation/ParameterExtractor.py b/DeepQuant/QuantManipulation/ParameterExtractor.py deleted file mode 100644 index b11d77b..0000000 --- a/DeepQuant/QuantManipulation/ParameterExtractor.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -This module extracts quantization proxy parameters from an exported FX model. -It retrieves scale, zero_point, bit_width and deduces the signedness of the quant -modules in the model by using type- and attribute-based checks rather than string -inspection. - -The safe_get_is_signed() function first looks for an explicit `is_signed` attribute, -then uses the module's min_val (if available) to infer signedness (a negative value -indicates signed quantization). If neither is available, it falls back to checking -the zero_point (a zero or near-zero value suggests unsigned quantization). - -The extracted parameters are printed using a color-coded format. -""" - -from typing import Any, Dict -import torch -import torch.nn as nn -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector -from brevitas.proxy.parameter_quant import ( - WeightQuantProxyFromInjector, - BiasQuantProxyFromInjector, -) -from colorama import Fore, Style - - -def safe_get_scale(quant_obj: Any) -> Any: - """ - Safely retrieve the scale from a Brevitas quant proxy object. - - Args: - quant_obj: The quant proxy object. - - Returns: - The scale as a float if available, otherwise None. - """ - if quant_obj is None: - return None - maybe_scale = quant_obj.scale() if callable(quant_obj.scale) else quant_obj.scale - if maybe_scale is None: - return None - if isinstance(maybe_scale, torch.Tensor): - return maybe_scale.item() - elif isinstance(maybe_scale, float): - return maybe_scale - try: - return float(maybe_scale) - except Exception: - return None - - -def safe_get_zero_point(quant_obj: Any) -> Any: - """ - Safely retrieve the zero_point from a Brevitas quant proxy object. - - Args: - quant_obj: The quant proxy object. - - Returns: - The zero_point as a float if available, otherwise None. - """ - if quant_obj is None: - return None - maybe_zp = ( - quant_obj.zero_point() - if callable(quant_obj.zero_point) - else quant_obj.zero_point - ) - if maybe_zp is None: - return None - if isinstance(maybe_zp, torch.Tensor): - return maybe_zp.item() - elif isinstance(maybe_zp, float): - return maybe_zp - try: - return float(maybe_zp) - except Exception: - return None - - -def safe_get_is_signed(quant_obj: Any) -> bool: - """ - Determine whether a quant proxy/module is signed. - - The function first checks for an explicit `is_signed` attribute. - If not found, it checks for a `min_val` attribute: a negative min_val - indicates signed quantization. If that is unavailable, it examines the - zero_point (if nearly zero, it is assumed unsigned). Defaults to True. - - Args: - quant_obj: The quant proxy object. - - Returns: - True if the quantization is signed, False otherwise. 
- """ - if hasattr(quant_obj, "is_signed"): - return getattr(quant_obj, "is_signed") - if hasattr(quant_obj, "min_val"): - try: - return quant_obj.min_val < 0 - except Exception: - pass - zp = safe_get_zero_point(quant_obj) - if zp is not None: - # If zero_point is near zero, assume unsigned quantization. - return not (abs(zp) < 1e-5) - return True - - -def extract_brevitas_proxy_params(model: nn.Module) -> Dict[str, Dict[str, Any]]: - """ - Recursively scan the exported FX model to find quant proxy submodules of types: - ActQuantProxyFromInjector, WeightQuantProxyFromInjector, or BiasQuantProxyFromInjector. - For each matching module, extract the scale, zero_point, bit_width, and deduced signedness. - - Args: - model: The exported FX model. - - Returns: - A dictionary mapping module names to their quantization parameters: - { - 'module_name': { - 'scale': float or None, - 'zero_point': float or None, - 'bit_width': float or None, - 'is_signed': bool - }, - ... - } - """ - params_dict: Dict[str, Dict[str, Any]] = {} - - def recurse_modules(parent_mod: nn.Module, prefix: str = "") -> None: - for child_name, child_mod in parent_mod.named_children(): - full_name = f"{prefix}.{child_name}" if prefix else child_name - if isinstance( - child_mod, - ( - ActQuantProxyFromInjector, - WeightQuantProxyFromInjector, - BiasQuantProxyFromInjector, - ), - ): - scl = safe_get_scale(child_mod) - zp = safe_get_zero_point(child_mod) - bw = ( - child_mod.bit_width() - ) # Assumes bit_width() returns a numeric value. - is_signed = safe_get_is_signed(child_mod) - params_dict[full_name] = { - "scale": scl, - "zero_point": zp, - "bit_width": bw, - "is_signed": is_signed, - } - recurse_modules(child_mod, prefix=full_name) - - recurse_modules(model) - return params_dict - - -def print_quant_params(params_dict: Dict[str, Dict[str, Any]]) -> None: - """ - Print the extracted quantization parameters for each proxy module in a - color-coded format. - - Args: - params_dict: Dictionary containing quantization parameters. - """ - print(f"\n{Fore.BLUE}Extracted Parameters from the Network:{Style.RESET_ALL}") - for layer_name, quant_values in params_dict.items(): - print(f" {Fore.BLUE}{layer_name}:{Style.RESET_ALL}") - for param_key, param_val in quant_values.items(): - print(f" {param_key}: {param_val}") - print() diff --git a/DeepQuant/QuantManipulation/QuantDequantNodes.py b/DeepQuant/QuantManipulation/QuantDequantNodes.py index 7332833..64256da 100644 --- a/DeepQuant/QuantManipulation/QuantDequantNodes.py +++ b/DeepQuant/QuantManipulation/QuantDequantNodes.py @@ -4,130 +4,77 @@ # # Federico Brancasi -""" -Basic implementation of Quant and Dequant modules. -""" +from typing import Optional import torch import torch.nn as nn -from typing import Any, Optional, Union class Quant(nn.Module): - """ - Fake-quant module that applies a "saturating" approach using scale, zero_point, bit_width, - and signedness parameters extracted from a Brevitas parameter dictionary. - - This module simulates quantization effects on tensors by scaling, shifting, rounding, - and clamping their values. - """ + """Quantization module that applies scale, zero-point, and bit-width constraints.""" def __init__( self, - original_module: nn.Module, + originalModule: nn.Module, scale: float, - zero_point: float, - bit_width: float, + zeroPoint: float, + bitWidth: float, signed: Optional[bool] = True, ) -> None: - """ - Initialize the Quant module. - - Args: - original_module: The original Brevitas quant module (kept for reference). 
- scale: Scale factor used for quantization. - zero_point: Zero-point used for quantization. - bit_width: Bit width for the quantized representation (e.g., 8.0, 32.0). - signed: Boolean flag indicating if quantization is signed. - """ super().__init__() - self.original_module = original_module + self.originalModule = originalModule self.scale = scale - self.zero_point = zero_point - self.bit_width = bit_width + self.zeroPoint = zeroPoint + self.bitWidth = bitWidth self.signed = signed - if self.bit_width is not None: - bw_int = int(self.bit_width) + if self.bitWidth is not None: + bwInt = int(self.bitWidth) if self.signed: - self.min_val = -(2 ** (bw_int - 1)) - self.max_val = (2 ** (bw_int - 1)) - 1 + self.minVal = -(2 ** (bwInt - 1)) + self.maxVal = (2 ** (bwInt - 1)) - 1 else: - self.min_val = 0 - self.max_val = (2**bw_int) - 1 + self.minVal = 0 + self.maxVal = (2**bwInt) - 1 else: - self.min_val = None - self.max_val = None + self.minVal = None + self.maxVal = None def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Apply fake quantization to the input tensor. - - The quantization process is as follows: - 1) Scale the input tensor by 1/scale. - 2) Shift the scaled tensor by the zero_point. - 3) Round the shifted tensor to the nearest integer. - 4) Clamp the rounded tensor to the representable range based on bit_width - and signedness. - - Args: - x: Input tensor. - - Returns: - The fake quantized tensor. - """ - if self.scale is None or self.zero_point is None: + """Quantize the input tensor.""" + if self.scale is None or self.zeroPoint is None: return x - x_scaled = x / self.scale - x_shifted = x_scaled + self.zero_point - x_rounded = torch.round(x_shifted) - if self.bit_width is not None: - x_rounded = torch.clamp(x_rounded, self.min_val, self.max_val) - return x_rounded + xScaled = x / self.scale + xShifted = xScaled + self.zeroPoint + xRounded = torch.floor(xShifted + 0.5) + + if self.bitWidth is not None: + xRounded = torch.clamp(xRounded, self.minVal, self.maxVal) + return xRounded class Dequant(nn.Module): - """ - Dequant module that re-applies scale and zero_point to invert the quantization effect. - """ + """Dequantization module that applies inverse scale and zero-point transformations.""" def __init__( self, - original_module: nn.Module, + originalModule: nn.Module, scale: float, - zero_point: float, - bit_width: float, + zeroPoint: float, + bitWidth: float, signed: Optional[bool] = True, ) -> None: - """ - Initialize the Dequant module. - - Args: - original_module: The original Brevitas quant module. - scale: Scale factor from extracted parameters. - zero_point: Zero-point from extracted parameters. - bit_width: Bit width from extracted parameters. - signed: Boolean flag indicating if quantization is signed. - """ super().__init__() - self.original_module = original_module + self.originalModule = originalModule self.scale = scale - self.zero_point = zero_point - self.bit_width = bit_width + self.zeroPoint = zeroPoint + self.bitWidth = bitWidth self.signed = signed def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Undo the fake quantization by reversing the shift and scale. - - Args: - x: Input tensor. - - Returns: - The dequantized tensor. 
- """ - if self.scale is None or self.zero_point is None: + """Dequantize the input tensor.""" + if self.scale is None or self.zeroPoint is None: return x - x_dequant = (x - self.zero_point) * self.scale - return x_dequant + dequantizedX = (x - self.zeroPoint) * self.scale + return dequantizedX diff --git a/DeepQuant/QuantManipulation/QuantNodesDivider.py b/DeepQuant/QuantManipulation/QuantNodesDivider.py index 6b7ab10..b7e2026 100644 --- a/DeepQuant/QuantManipulation/QuantNodesDivider.py +++ b/DeepQuant/QuantManipulation/QuantNodesDivider.py @@ -4,145 +4,141 @@ # # Federico Brancasi -""" -Module for transforming FX graphs by splitting quantization nodes into Quant and Dequant, -while skipping activation quant nodes to preserve nonzero outputs. -""" +from typing import Any, Dict, List, Tuple import torch.fx as fx -from typing import Dict, Any, List, Tuple -from .QuantDequantNodes import Quant, Dequant import torch.nn as nn -BLUE = "\033[94m" -ENDC = "\033[0m" -ARROW = " ›" +from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant, Quant +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc -def create_quant_dequant_nodes( +def insertQuantDequantPair( graph: fx.Graph, node: fx.Node, - fx_model: fx.GraphModule, - quant_name: str, - dequant_name: str, - original_module: nn.Module, - param_dict: Dict[str, Any], + fxModel: fx.GraphModule, + quantName: str, + dequantName: str, + originalModule: nn.Module, + paramDict: Dict[str, Any], ) -> Tuple[fx.Node, fx.Node]: - """ - Create separate Quant and Dequant nodes for a given FX node. - - This function replaces a single quantization node (e.g. weight_quant) - with two call_module nodes: one for Quant and one for Dequant. Because - the Quant module only accepts one Tensor argument, multiple arguments - (e.g. bias, input, weight) must be reduced to one. - - Args: - graph: The FX graph to insert new nodes into. - node: The original node referencing a quantization module. - fx_model: The GraphModule containing submodules. - quant_name: Name for the new Quant submodule. - dequant_name: Name for the new Dequant submodule. - original_module: The original Brevitas quant module. - param_dict: Dictionary with keys 'scale', 'zero_point', 'bit_width', - and 'is_signed'. - - Returns: - A tuple containing the newly created Quant and Dequant nodes. 
- """ + """Create separate Quant and Dequant nodes for a given FX node.""" if "bias_quant" in node.target.lower(): - main_arg = node.args[0] + mainArg = node.args[0] elif "weight_quant" in node.target.lower(): - main_arg = node.args[0] + mainArg = node.args[0] else: - main_arg = node.args[0] + mainArg = node.args[0] - scale_val = param_dict.get("scale", None) - zp_val = param_dict.get("zero_point", None) - bw_val = param_dict.get("bit_width", None) - signed_val = param_dict.get("is_signed", True) + scaleVal = paramDict.get("scale", None) + zpVal = paramDict.get("zero_point", None) + bwVal = paramDict.get("bit_width", None) + signedVal = paramDict.get("is_signed", True) - fx_model.add_module( - quant_name, Quant(original_module, scale_val, zp_val, bw_val, signed=signed_val) + fxModel.add_module( + quantName, Quant(originalModule, scaleVal, zpVal, bwVal, signed=signedVal) ) - fx_model.add_module( - dequant_name, - Dequant(original_module, scale_val, zp_val, bw_val, signed=signed_val), + fxModel.add_module( + dequantName, + Dequant(originalModule, scaleVal, zpVal, bwVal, signed=signedVal), ) with graph.inserting_after(node): - quant_node = graph.call_module(quant_name, args=(main_arg,)) + quantNode = graph.call_module(quantName, args=(mainArg,)) - with graph.inserting_after(quant_node): - dequant_node = graph.call_module(dequant_name, args=(quant_node,)) + with graph.inserting_after(quantNode): + dequantNode = graph.call_module(dequantName, args=(quantNode,)) - return quant_node, dequant_node + return quantNode, dequantNode -def split_quant_nodes( - fx_model: fx.GraphModule, full_params_dict: Dict[str, Dict[str, Any]], debug: bool +def convertQuantOperations( + fxModel: fx.GraphModule, fullParamsDict: Dict[str, Dict[str, Any]], debug: bool ) -> fx.GraphModule: - """ - Transform an FX graph by splitting each "call_module(...quant...)" node into - separate Quant -> Dequant nodes, skipping activation quant nodes to preserve - numeric accuracy. - - Args: - fx_model: The input FX GraphModule. - full_params_dict: A dictionary mapping module names to quantization parameters. - debug: Whether to print debug output. - - Returns: - The updated FX GraphModule with weight/bias quant calls split. 
- """ - graph = fx_model.graph - nodes_to_erase: List[fx.Node] = [] + """Split quantization nodes into separate Quant and Dequant nodes.""" + graph = fxModel.graph + nodesToRemove: List[fx.Node] = [] if debug: - print(f"{BLUE}{ARROW} Starting Quantization Node Splitting...{ENDC}") + print(cc.info("Starting Quantization Node Splitting...")) - all_nodes = list(graph.nodes) + allNodes = list(graph.nodes) - for node in all_nodes: + for node in allNodes: if ( node.op == "call_module" and "quant" in node.target.lower() and "act_impl" not in node.target.lower() ): - top_level = node.target.split(".")[0] - if top_level in ["sigmoid"]: - continue # Skip sigmoid + topLevel = node.target.split(".")[0] + if topLevel in ["sigmoid"]: + continue # FBRANCASI: Skip sigmoid - original_module = fx_model.get_submodule(node.target) - safe_target = node.target.replace(".", "_").replace("_quant", "") - quant_name = f"{safe_target}_quant_1" - dequant_name = f"{safe_target}_dequant" - param_info = full_params_dict.get(node.target, {}) + originalModule = fxModel.get_submodule(node.target) + safeTarget = node.target.replace(".", "_").replace("_quant", "") + quantName = f"{safeTarget}_quant_1" + dequantName = f"{safeTarget}_dequant" + paramInfo = fullParamsDict.get(node.target, {}) - quant_node, dequant_node = create_quant_dequant_nodes( + quantNode, dequantNode = insertQuantDequantPair( graph, node, - fx_model, - quant_name, - dequant_name, - original_module, - param_info, + fxModel, + quantName, + dequantName, + originalModule, + paramInfo, ) - # Re-route all users of the original node. - for user_node in list(node.users.keys()): - new_args = [] - for arg in user_node.args: - new_args.append(dequant_node if arg is node else arg) - user_node.args = tuple(new_args) - - nodes_to_erase.append(node) - - for erase_node in nodes_to_erase: - graph.erase_node(erase_node) + usersUpdated = False + for userNode in list(node.users.keys()): + if ( + userNode.op == "call_function" + and hasattr(userNode.target, "__name__") + and userNode.target.__name__ == "cat" + ): + # FBRANCASI: This is a concatenation operation - Special Handling + newCatArgs = list(userNode.args) + if len(newCatArgs) >= 1 and isinstance(newCatArgs[0], list): + tensorsList = newCatArgs[0] + updatedTensors = [] + for tensor in tensorsList: + if tensor is node: + updatedTensors.append(dequantNode) + else: + updatedTensors.append(tensor) + newCatArgs[0] = updatedTensors + userNode.args = tuple(newCatArgs) + usersUpdated = True + elif ( + userNode.op == "call_function" + and userNode.target == getattr + and len(userNode.args) >= 2 + and userNode.args[0] is node + and userNode.args[1] == "value" + ): + # FBRANCASI: Special handling for .value access on dequant output + # Replace getattr(dequant_node, 'value') with just dequant_node + userNode.replace_all_uses_with(dequantNode) + nodesToRemove.append(userNode) + usersUpdated = True + else: + # FBRANCASI: Standard node reference replacement + newArgs = [] + for arg in userNode.args: + newArgs.append(dequantNode if arg is node else arg) + userNode.args = tuple(newArgs) + usersUpdated = True + + if usersUpdated: + nodesToRemove.append(node) + + for eraseNode in nodesToRemove: + graph.erase_node(eraseNode) graph.lint() if debug: - print(f"{BLUE}{ARROW} Quantization Node Splitting completed Successfully{ENDC}") + print(cc.info("Quantization Node Splitting completed Successfully")) - return fx_model + return fxModel diff --git a/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py 
b/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py new file mode 100644 index 0000000..22c0629 --- /dev/null +++ b/DeepQuant/QuantManipulation/QuantizationParameterExtractor.py @@ -0,0 +1,110 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Any, Dict + +import torch +import torch.nn as nn +from brevitas.proxy.parameter_quant import ( + BiasQuantProxyFromInjector, + WeightQuantProxyFromInjector, +) +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector + + +def getScale(quantObj: Any) -> Any: + """Extract scale parameter from quantization object.""" + if quantObj is None: + return None + maybeScale = quantObj.scale() if callable(quantObj.scale) else quantObj.scale + if maybeScale is None: + return None + if isinstance(maybeScale, torch.Tensor): + return maybeScale.item() + elif isinstance(maybeScale, float): + return maybeScale + try: + return float(maybeScale) + except Exception: + return None + + +def getZeroPoint(quantObj: Any) -> Any: + """Extract zero point parameter from quantization object.""" + if quantObj is None: + return None + maybeZp = ( + quantObj.zero_point() if callable(quantObj.zero_point) else quantObj.zero_point + ) + if maybeZp is None: + return None + if isinstance(maybeZp, torch.Tensor): + return maybeZp.item() + elif isinstance(maybeZp, float): + return maybeZp + try: + return float(maybeZp) + except Exception: + return None + + +def getIsSigned(quantObj: Any) -> bool: + """Determine if quantization is signed.""" + if hasattr(quantObj, "is_signed"): + return getattr(quantObj, "is_signed") + if hasattr(quantObj, "min_val"): + try: + return quantObj.min_val < 0 + except Exception: + pass + zp = getZeroPoint(quantObj) + if zp is not None: + # If zero_point is near zero, assume unsigned quantization. + return not (abs(zp) < 1e-5) + return True + + +def extractBrevitasProxyParams(model: nn.Module) -> Dict[str, Dict[str, Any]]: + """Extract quantization parameters from Brevitas proxy modules.""" + paramsDict: Dict[str, Dict[str, Any]] = {} + + def recurseModules(parentMod: nn.Module, prefix: str = "") -> None: + for childName, childMod in parentMod.named_children(): + fullName = f"{prefix}.{childName}" if prefix else childName + if isinstance( + childMod, + ( + ActQuantProxyFromInjector, + WeightQuantProxyFromInjector, + BiasQuantProxyFromInjector, + ), + ): + scl = getScale(childMod) + zp = getZeroPoint(childMod) + bw = childMod.bit_width() + isSigned = getIsSigned(childMod) + paramsDict[fullName] = { + "scale": scl, + "zero_point": zp, + "bit_width": bw, + "is_signed": isSigned, + } + recurseModules(childMod, prefix=fullName) + + recurseModules(model) + return paramsDict + + +def printQuantParams(paramsDict: Dict[str, Dict[str, Any]]) -> None: + """Print extracted quantization parameters in a readable format.""" + from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc + + print(f"{cc.wrap('Extracted Parameters from the Network:', cc.blue)}") + for layerName, quantValues in paramsDict.items(): + print(f" {cc.wrap(layerName + ':', cc.blue)}") + for paramKey, paramVal in quantValues.items(): + print(f" {paramKey}: {paramVal}") + print() diff --git a/DeepQuant/Transforms/Base.py b/DeepQuant/Transforms/Base.py new file mode 100644 index 0000000..392d6cf --- /dev/null +++ b/DeepQuant/Transforms/Base.py @@ -0,0 +1,53 @@ +# Copyright 2025 ETH Zurich and University of Bologna. 
+# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from abc import ABC, abstractmethod +from typing import Any, Optional, Tuple, Union + +import torch +import torch.nn as nn + +from DeepQuant.Utils.CustomTracer import QuantTracer + + +class TransformationPass(ABC): + """Base class for module transformation passes.""" + + def __init__( + self, + moduleCls: Union[type, Tuple[type, ...]], + validationTol: float = 1e-6, + ) -> None: + self.moduleCls = moduleCls + self.validationTol = validationTol + + def checkModuleType(self, module: nn.Module) -> bool: + """Check if a module is an instance of the target class(es).""" + return isinstance(module, self.moduleCls) + + @abstractmethod + def injectForward( + self, module: nn.Module, tracer: Optional[QuantTracer] = None + ) -> None: + """Inject the custom forward implementation into a module.""" + pass + + def validateTransformation( + self, outputBefore: Any, outputAfter: Any, atol: Optional[float] = None + ) -> bool: + """Validate transformation by comparing outputs.""" + if atol is None: + atol = self.validationTol + return torch.allclose(outputBefore, outputAfter, atol=atol) + + def transform(self, model: nn.Module, tracer: Optional[QuantTracer] = None) -> bool: + """Apply the transformation to all matching submodules.""" + transformDone = False + for _, submodule in model.named_modules(): + if self.checkModuleType(submodule): + self.injectForward(submodule, tracer) + transformDone = True + return transformDone diff --git a/DeepQuant/Transforms/Executor.py b/DeepQuant/Transforms/Executor.py new file mode 100644 index 0000000..09068f7 --- /dev/null +++ b/DeepQuant/Transforms/Executor.py @@ -0,0 +1,65 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import List, Optional + +import torch +import torch.nn as nn + +from DeepQuant.Transforms.Base import TransformationPass +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.CustomTracer import QuantTracer + + +class TransformationExecutor: + """Runs a sequence of transformation passes.""" + + def __init__( + self, + transformations: List[TransformationPass], + debug: bool = False, + tracer: Optional[QuantTracer] = None, + ) -> None: + self.transformations = transformations + self.debug = debug + self.tracer = tracer + + def execute(self, model: nn.Module, exampleInput: torch.Tensor) -> nn.Module: + """Execute all transformations on the model.""" + model.eval() + with torch.no_grad(): + outputBefore = model(exampleInput) + if isinstance(outputBefore, tuple): + outputBefore = outputBefore[0] + + for transformation in self.transformations: + if transformation.transform(model, tracer=self.tracer): + outputAfter = model(exampleInput) + if isinstance(outputAfter, tuple): + outputAfter = outputAfter[0] + + if not transformation.validateTransformation( + outputBefore, outputAfter + ): + raise RuntimeError( + cc.error( + f"{transformation.__class__.__name__} failed - outputs mismatch" + ) + ) + + if self.debug: + print( + cc.success( + f"{transformation.__class__.__name__} transformation successful" + ) + ) + if self.tracer: + print(f" leafClasses: {self.tracer.leafClasses}") + print(f" nonLeafClasses: {self.tracer.nonLeafClasses}") + + outputBefore = outputAfter + + return model diff --git a/DeepQuant/Transforms/Transformations.py b/DeepQuant/Transforms/Transformations.py new file mode 100644 index 0000000..9464c03 --- /dev/null +++ b/DeepQuant/Transforms/Transformations.py @@ -0,0 +1,95 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Optional + +import torch.nn as nn +from brevitas.nn.quant_layer import ( + QuantNonLinearActLayer, + QuantWeightBiasInputOutputLayer, +) +from brevitas.nn.quant_mha import QuantMultiheadAttention + +from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward +from DeepQuant.CustomForwards.MultiHeadAttention import ( + mhaForwardBatchFirst, + mhaForwardSeqFirst, +) +from DeepQuant.CustomForwards.WBIOL import WBIOLForward, WrapperWBIOL +from DeepQuant.Transforms.Base import TransformationPass +from DeepQuant.Utils.CustomTracer import QuantTracer + + +class LinearTransformation(TransformationPass): + """Transforms quantized linear layers.""" + + def __init__(self) -> None: + super().__init__( + moduleCls=QuantWeightBiasInputOutputLayer, + validationTol=1e-6, + ) + + def injectForward( + self, module: nn.Module, tracer: Optional[QuantTracer] = None + ) -> None: + """Inject custom forward for linear layers.""" + module.wrappedInnerForwardImpl = WrapperWBIOL(module.inner_forward_impl) + module.forward = WBIOLForward.__get__(module) + + if tracer: + tracer.registerLeafModule(WrapperWBIOL) + tracer.registerNonLeafModule(QuantWeightBiasInputOutputLayer) + + +class ActivationTransformation(TransformationPass): + """Transforms quantized activation layers.""" + + def __init__(self) -> None: + super().__init__( + moduleCls=QuantNonLinearActLayer, + validationTol=1e-6, + ) + + def injectForward( + self, module: nn.Module, tracer: Optional[QuantTracer] = None + ) -> None: + """Inject custom forward for activation layers.""" + # FBRANCASI: If the activation implementation was provided (e.g. nn.ReLU + # for QuantReLU), instantiate it. Otherwise, default to an identity. + if hasattr(module, "act_impl") and module.act_impl is not None: + actInstance = module.act_impl() + else: + actInstance = nn.Identity() + + module.wrappedActImpl = WrapperActivation(actInstance) + module.forward = activationForward.__get__(module) + + if tracer: + tracer.registerLeafModule(WrapperActivation) + tracer.registerNonLeafModule(QuantNonLinearActLayer) + + +class MHATransformation(TransformationPass): + """Transforms quantized multi-head attention layers.""" + + def __init__(self) -> None: + super().__init__( + moduleCls=QuantMultiheadAttention, + validationTol=1e-5, + ) + + def injectForward( + self, module: nn.Module, tracer: Optional[QuantTracer] = None + ) -> None: + """Inject custom forward for multi-head attention layers.""" + # Select the appropriate forward function based on batch_first + if module.batch_first: + module.forward = mhaForwardBatchFirst.__get__(module) + else: + module.forward = mhaForwardSeqFirst.__get__(module) + + if tracer: + tracer.registerNonLeafModule(QuantMultiheadAttention) diff --git a/DeepQuant/Utils/ConsoleFormatter.py b/DeepQuant/Utils/ConsoleFormatter.py new file mode 100644 index 0000000..e5d03f8 --- /dev/null +++ b/DeepQuant/Utils/ConsoleFormatter.py @@ -0,0 +1,54 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi +class ConsoleColor: + """Console color utilities for formatted terminal output.""" + + # Color codes + blue = "\033[94m" + green = "\033[92m" + red = "\033[91m" + yellow = "\033[93m" + cyan = "\033[96m" + magenta = "\033[95m" + bold = "\033[1m" + reset = "\033[0m" + + # Symbols + checkmark = " ✓" + cross = " ✗" + arrow = " ›" + + @staticmethod + def wrap(text: str, color: str) -> str: + """Wrap text with color codes.""" + return f"{color}{text}{ConsoleColor.reset}" + + @staticmethod + def success(text: str) -> str: + """Format a success message.""" + return ConsoleColor.wrap(f"{ConsoleColor.checkmark} {text}", ConsoleColor.green) + + @staticmethod + def error(text: str) -> str: + """Format an error message.""" + return ConsoleColor.wrap(f"{ConsoleColor.cross} {text}", ConsoleColor.red) + + @staticmethod + def info(text: str) -> str: + """Format an informational message.""" + return ConsoleColor.wrap(f"{ConsoleColor.arrow} {text}", ConsoleColor.blue) + + @staticmethod + def warning(text: str) -> str: + """Format a warning message.""" + return ConsoleColor.wrap(text, ConsoleColor.yellow) + + @staticmethod + def header(text: str) -> str: + """Format a step header with separator lines.""" + separator = "=" * 50 + header_text = f"{separator}\n{text}\n{separator}" + return f"\n{ConsoleColor.wrap(header_text, ConsoleColor.magenta)}" diff --git a/DeepQuant/Utils/CustomTracer.py b/DeepQuant/Utils/CustomTracer.py new file mode 100644 index 0000000..4343496 --- /dev/null +++ b/DeepQuant/Utils/CustomTracer.py @@ -0,0 +1,57 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import List, Optional, Type + +import torch.nn as nn +from brevitas.fx.brevitas_tracer import ( + Tracer, + _is_brevitas_leaf_module, + _symbolic_trace, +) +from torch.fx.graph_module import GraphModule + + +class QuantTracer(Tracer): + """Enhanced tracer with fine-grained control over module tracing.""" + + def __init__( + self, + leafClasses: Optional[List[Type[nn.Module]]] = None, + nonLeafClasses: Optional[List[Type[nn.Module]]] = None, + debug: bool = False, + ) -> None: + super().__init__() + self.leafClasses = leafClasses if leafClasses is not None else [] + self.nonLeafClasses = nonLeafClasses if nonLeafClasses is not None else [] + self.debug = debug + + def registerLeafModule(self, moduleCls: Type[nn.Module]) -> None: + """Register a module class as a leaf module.""" + if moduleCls not in self.leafClasses: + self.leafClasses.append(moduleCls) + + def registerNonLeafModule(self, moduleCls: Type[nn.Module]) -> None: + """Register a module class as a non-leaf module.""" + if moduleCls not in self.nonLeafClasses: + self.nonLeafClasses.append(moduleCls) + + def is_leaf_module(self, m: nn.Module, moduleQualifiedName: str) -> bool: + """Determine if a module should be treated as a leaf module.""" + if any(isinstance(m, lc) for lc in self.leafClasses): + return True + if any(isinstance(m, nlc) for nlc in self.nonLeafClasses): + return False + return _is_brevitas_leaf_module(m, moduleQualifiedName) + + +def customBrevitasTrace( + root: nn.Module, concreteArgs=None, tracer: Optional[QuantTracer] = None +) -> GraphModule: + """Create an FX GraphModule using the QuantTracer (a custom Brevitas tracer).""" + if tracer is None: + tracer = QuantTracer() + return _symbolic_trace(tracer, root, concreteArgs) diff --git 
a/DeepQuant/Utils/FixCTT2Graph.py b/DeepQuant/Utils/FixCTT2Graph.py new file mode 100644 index 0000000..2a22907 --- /dev/null +++ b/DeepQuant/Utils/FixCTT2Graph.py @@ -0,0 +1,147 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +""" +Script to fix the CCTTrueQuantized ONNX model by duplicating shared constants. +This resolves the issue where a single Floor constant (onnx::Floor_772) is shared +across multiple bias quantization operations. +""" + +import argparse +import os + +import numpy as np +import onnx +from onnx import helper + + +def fix_shared_constants(model_path, output_path): + """Fix shared constants in ONNX model by creating unique copies.""" + print(f"Loading ONNX model from: {model_path}") + model = onnx.load(model_path) + + graph = model.graph + + shared_floor_tensor = None + for initializer in graph.initializer: + if initializer.name == "onnx::Floor_772": + shared_floor_tensor = initializer + break + + if shared_floor_tensor is None: + print("No shared Floor constant found. Model may already be fixed.") + return False + + print(f"Found shared Floor constant: {shared_floor_tensor.name}") + print(f"Tensor shape: {shared_floor_tensor.dims}") + + floor_nodes = [] + for node in graph.node: + if node.op_type == "Floor": + for input_name in node.input: + if input_name == shared_floor_tensor.name: + floor_nodes.append(node) + break + + print(f"Found {len(floor_nodes)} Floor nodes sharing the constant:") + for node in floor_nodes: + print(f" - {node.name}") + + new_initializers = [] + for i, node in enumerate(floor_nodes): + unique_name = f"Floor_772_unique_{i}_{node.name.replace('/', '_')}" + + new_tensor = helper.make_tensor( + name=unique_name, + data_type=shared_floor_tensor.data_type, + dims=shared_floor_tensor.dims, + vals=( + shared_floor_tensor.float_data + if shared_floor_tensor.float_data + else np.frombuffer( + shared_floor_tensor.raw_data, dtype=np.float32 + ).tolist() + ), + ) + + new_initializers.append(new_tensor) + + for j, input_name in enumerate(node.input): + if input_name == shared_floor_tensor.name: + node.input[j] = unique_name + break + + print(f" Created unique constant: {unique_name} for node: {node.name}") + + graph.initializer.remove(shared_floor_tensor) + + for new_tensor in new_initializers: + graph.initializer.append(new_tensor) + + inputs_to_remove = [] + for input_tensor in graph.input: + if input_tensor.name == shared_floor_tensor.name: + inputs_to_remove.append(input_tensor) + + for input_tensor in inputs_to_remove: + graph.input.remove(input_tensor) + + try: + onnx.checker.check_model(model) + print("Model validation passed!") + except Exception as e: + print(f"Model validation failed: {e}") + return False + + print(f"Saving fixed model to: {output_path}") + onnx.save(model, output_path) + + return True + + +def main(): + parser = argparse.ArgumentParser( + description="Fix shared constants in CCTTrueQuantized ONNX model" + ) + parser.add_argument( + "--input", + required=True, + help="Path to input ONNX model", + ) + parser.add_argument( + "--output", + required=True, + help="Path to output fixed ONNX model", + ) + + args = parser.parse_args() + + if not os.path.exists(args.input): + print(f"Error: Input file does not exist: {args.input}") + return 1 + + success = fix_shared_constants(args.input, args.output) + + if success: + print("Successfully fixed the ONNX model!") + print(f"Original model: 
{args.input}") + print(f"Fixed model: {args.output}") + + # FBRANCASI: Replace the original model with the fixed one + backup_path = args.input + ".backup" + print(f"Creating backup: {backup_path}") + os.rename(args.input, backup_path) + os.rename(args.output, args.input) + print("Replaced original model with fixed version") + + return 0 + else: + print("Failed to fix the ONNX model") + return 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/DeepQuant/Utils/FxInterpreter.py b/DeepQuant/Utils/FxInterpreter.py deleted file mode 100644 index 1ac434a..0000000 --- a/DeepQuant/Utils/FxInterpreter.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -FX Graph tracer that traces each node by wrapping submodules with proxy objects. -""" - -import torch -import torch.nn as nn -import torch.fx as fx -from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union, Callable -import functools -import inspect - - -class NodeTracer: - """ - Traces execution through an FX graph by wrapping each module with a - proxy that logs input and output values. - """ - - def __init__(self, debug: bool = True) -> None: - """ - Initialize the tracer. - - Args: - debug: Whether to print debug information. - """ - self.debug = debug - self.BLUE = "\033[94m" - self.GREEN = "\033[92m" - self.YELLOW = "\033[93m" - self.RED = "\033[91m" - self.RESET = "\033[0m" - self.traced_modules: Dict[str, nn.Module] = {} - self.call_count: Dict[str, int] = {} - - def trace( - self, model: fx.GraphModule, example_input: torch.Tensor - ) -> Optional[torch.Tensor]: - """ - Trace the execution of the model by wrapping modules with proxies. - - Args: - model: The FX GraphModule to trace. - example_input: The input tensor. - - Returns: - The model output, if successful. - """ - if self.debug: - print( - f"\n{self.BLUE}===== Starting FX Graph Execution Tracing ====={self.RESET}\n" - ) - print( - f"{self.BLUE}Input shape: {tuple(example_input.shape)}, dtype: {example_input.dtype}{self.RESET}\n" - ) - - # Wrap all submodules with our proxy - self._wrap_modules(model) - - # Create a copy of the original model to restore wrapped modules after tracing - original_modules = { - name: module - for name, module in model.named_modules() - if not isinstance(module, fx.GraphModule) - } - - try: - # Execute the model with the example input - with torch.no_grad(): - output = model(example_input) - - if self.debug: - print(f"\n{self.GREEN}Execution completed successfully!{self.RESET}") - if isinstance(output, torch.Tensor): - print( - f"{self.GREEN}Output shape: {tuple(output.shape)}, dtype: {output.dtype}{self.RESET}" - ) - else: - print(f"{self.GREEN}Output type: {type(output)}{self.RESET}") - - return output - - except Exception as e: - if self.debug: - print(f"\n{self.RED}Error during execution: {str(e)}{self.RESET}") - return None - - finally: - # Restore original modules - self._restore_modules(model, original_modules) - - def _wrap_modules(self, model: fx.GraphModule) -> None: - """ - Wrap all relevant modules with tracing proxies. - - Args: - model: The model containing modules to wrap. 
- """ - # Find relevant modules that match nodes in the graph - for name, module in list(model.named_modules()): - if not isinstance(module, fx.GraphModule): - if hasattr(module, "forward"): - original_forward = module.forward - self.traced_modules[name] = original_forward - - # Create wrapped forward method with tracing - @functools.wraps(original_forward) - def traced_forward(self, *args, **kwargs): - module_name = self._tracing_name - - # Increment call count - self._tracer.call_count.setdefault(module_name, 0) - self._tracer.call_count[module_name] += 1 - call_idx = self._tracer.call_count[module_name] - - # Print module info before call - if self._tracer.debug: - module_type = type(self).__name__ - print( - f"\n{self._tracer.YELLOW}[{module_name} ({module_type}) - Call #{call_idx}]{self._tracer.RESET}" - ) - - # Print input tensor info - for i, arg in enumerate(args): - if isinstance(arg, torch.Tensor): - print( - f" Input {i}: Tensor{tuple(arg.shape)} ({arg.dtype})" - ) - # Sample values for extra context - if arg.numel() > 0: - flat = arg.reshape(-1) - sample = flat[:3].tolist() - sample_str = ", ".join( - ( - f"{x:.6f}" - if isinstance(x, float) - else str(x) - ) - for x in sample - ) - print( - f" Values: [{sample_str}{'...' if flat.numel() > 3 else ''}]" - ) - elif ( - isinstance(arg, (list, tuple)) - and len(arg) > 0 - and isinstance(arg[0], torch.Tensor) - ): - print( - f" Input {i}: {type(arg).__name__} of {len(arg)} Tensors" - ) - else: - print(f" Input {i}: {type(arg).__name__}") - - # Call original forward method - result = self._original_forward(*args, **kwargs) - - # Print output info - if self._tracer.debug: - if isinstance(result, torch.Tensor): - print( - f" {self._tracer.GREEN}Output: Tensor{tuple(result.shape)} ({result.dtype}){self._tracer.RESET}" - ) - # Sample output values - if result.numel() > 0: - flat = result.reshape(-1) - sample = flat[:3].tolist() - sample_str = ", ".join( - f"{x:.6f}" if isinstance(x, float) else str(x) - for x in sample - ) - print( - f" Values: [{sample_str}{'...' if flat.numel() > 3 else ''}]" - ) - elif isinstance(result, (list, tuple)) and len(result) > 0: - print( - f" {self._tracer.GREEN}Output: {type(result).__name__} of length {len(result)}{self._tracer.RESET}" - ) - else: - print( - f" {self._tracer.GREEN}Output: {type(result).__name__}{self._tracer.RESET}" - ) - - return result - - # Attach tracer reference and original forward to the wrapped method - traced_forward.__self__ = module - traced_forward.__self__._tracer = self - traced_forward.__self__._original_forward = original_forward - traced_forward.__self__._tracing_name = name - - # Replace forward with wrapped version - module.forward = traced_forward.__get__(module) - - def _restore_modules( - self, model: fx.GraphModule, original_modules: Dict[str, nn.Module] - ) -> None: - """ - Restore original forward methods for all wrapped modules. - - Args: - model: The model containing wrapped modules. - original_modules: Dictionary of original modules. 
- """ - for name, original_forward in self.traced_modules.items(): - parts = name.split(".") - current = model - - # Navigate to the module - for part in parts: - if hasattr(current, part): - current = getattr(current, part) - else: - break - - # Restore original forward if found - if hasattr(current, "forward") and hasattr(current, "_original_forward"): - current.forward = original_forward diff --git a/DeepQuant/Utils/GraphPrinter.py b/DeepQuant/Utils/GraphPrinter.py index d3d6b9e..35dc97b 100644 --- a/DeepQuant/Utils/GraphPrinter.py +++ b/DeepQuant/Utils/GraphPrinter.py @@ -4,85 +4,23 @@ # # Federico Brancasi -""" -This module provides a specialized GraphModulePrinter class to display an FX GraphModule -in a tabular format, including optional metadata about quantization (like eps, n_levels, signed). +from typing import List, Literal -Usage: - from DeepQuant.graph_printer import GraphModulePrinter - - printer = GraphModulePrinter() - printer.print_tabular( - fx_model, - show_opcode=True, - show_class=True, - show_name=True, - show_target=True, - show_args=True, - show_kwargs=True, - show_eps=False, - show_nlevels=True, - show_signed=True, - unicode=False - ) - -Note: -- This example assumes that each node in the graph may have a `node.meta['quant']` dict - with fields like eps_in, eps_out, n_levels_in, n_levels_out, signed_in, and signed_out. -- If these fields are not present, the code will gracefully skip them or display placeholders. -- If you do not have such metadata in node.meta, you can adapt the logic to suit your needs. -""" - -import math -from typing import Any, List, Literal, Optional import torch.fx as fx - -try: - # Optional: colorama for colored output (requires `pip install colorama`) - from colorama import Fore, Back, Style - - COLORAMA_AVAILABLE = True -except ImportError: - COLORAMA_AVAILABLE = False - -try: - # Optional: tabulate for printing tables (requires `pip install tabulate`) - from tabulate import tabulate - - TABULATE_AVAILABLE = True -except ImportError: - TABULATE_AVAILABLE = False +from colorama import Back, Fore, Style +from tabulate import tabulate class GraphModulePrinter: - """ - Class for printing an FX GraphModule in a tabular format, optionally displaying - quantization metadata stored in node.meta['quant']. - - The code is based on an example snippet from a supervisor. The logic is adjusted - to fit our code style and to gracefully handle missing metadata. - """ + """Formatter and printer for FX graph modules.""" @staticmethod - def quant_info( + def quantInfo( node: fx.Node, prop: Literal["eps_in", "eps_out", "n_levels", "signed"] ) -> str: - """ - Retrieve a string representation of the quantization property for a given node. - - Args: - node: The FX node containing potential quantization metadata. - prop: The quantization property to display. One of 'eps_in', 'eps_out', - 'n_levels', or 'signed'. - - Returns: - A string representation of the requested property if it exists, or '{}' otherwise. - """ if "quant" not in node.meta: return "{}" - # At this point, we assume node.meta['quant'] is a dict-like object containing - # fields such as eps_in, eps_out, n_levels_in, n_levels_out, signed_in, signed_out, etc. qmeta = node.meta["quant"] if prop == "eps_in": @@ -90,12 +28,10 @@ def quant_info( elif prop == "eps_out": return str(qmeta.get("eps_out", "{}")) elif prop == "n_levels": - # This is just an example: we might have n_levels_in, n_levels_out, etc. 
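+            # Show the input/output level counts together as "n_levels_in -> n_levels_out"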
n_in = qmeta.get("n_levels_in", "{}") n_out = qmeta.get("n_levels_out", "{}") return f"{n_in} -> {n_out}" elif prop == "signed": - # Example: 'signed_in' and 'signed_out' s_in = qmeta.get("signed_in", "{}") s_out = qmeta.get("signed_out", "{}") return f"{s_in} -> {s_out}" @@ -103,196 +39,126 @@ def quant_info( return "{}" @staticmethod - def class_info(node: fx.Node, gm: fx.GraphModule, unicode: bool = False) -> str: - """ - Retrieve class name for call_module nodes. For example, if node.target is - referencing a submodule of type nn.Conv2d, this returns 'Conv2d'. - - Args: - node: The FX node to analyze. - gm: The FX GraphModule containing the node. - unicode: If True, optionally highlight certain classes. - - Returns: - The class name as a string, or '' if not applicable. - """ + def classInfo(node: fx.Node, gm: fx.GraphModule, unicode: bool = False) -> str: if node.op == "call_module": submodule = gm.get_submodule(node.target) class_name = submodule.__class__.__name__ - if not COLORAMA_AVAILABLE or not unicode: + if not unicode: return class_name - # Optionally highlight if it's a special class, e.g. 'PACT' or so. if "PACT" in class_name: return Fore.GREEN + class_name + Style.RESET_ALL return class_name return "" @staticmethod - def node_info(node: fx.Node, attr: str, unicode: bool = False) -> str: - """ - Retrieve a specified attribute from the node (e.g. 'op', 'name', 'target', 'args'). - - Args: - node: The FX node. - attr: The name of the attribute to retrieve (e.g. 'op', 'name', 'target', 'args'). - unicode: If True, highlight certain functions in color. - - Returns: - A string representation of the requested attribute, or '' if not present. - """ + def nodeInfo(node: fx.Node, attr: str, unicode: bool = False) -> str: if not hasattr(node, attr): return "" value = getattr(node, attr) if attr == "op": - # Optionally highlight certain call_function ops - if node.op == "call_function" and COLORAMA_AVAILABLE and unicode: - # Example of a function whitelist + if node.op == "call_function" and unicode: whitelist_functions = ["getitem"] - if node.target.__name__ not in whitelist_functions: + if ( + hasattr(node.target, "__name__") + and node.target.__name__ not in whitelist_functions + ): return Back.YELLOW + str(value) + Style.RESET_ALL return str(value) @classmethod - def get_node_spec( + def getNodeSpec( cls, node: fx.Node, gm: fx.GraphModule, - show_opcode: bool = True, - show_class: bool = True, - show_name: bool = True, - show_target: bool = True, - show_args: bool = True, - show_kwargs: bool = True, - show_eps: bool = False, - show_nlevels: bool = True, - show_signed: bool = True, + showOpcode: bool = True, + showClass: bool = True, + showName: bool = True, + showTarget: bool = True, + showArgs: bool = True, + showKwargs: bool = True, + showEps: bool = False, + showNlevels: bool = True, + showSigned: bool = True, unicode: bool = False, ) -> List[str]: - """ - Collect string representations of the node's attributes/metadata for printing. - - Args: - node: The FX node to process. - gm: The FX GraphModule containing the node. - show_opcode: Whether to display the node's op code. - show_class: Whether to display the submodule class name (for call_module). - show_name: Whether to display the node's name. - show_target: Whether to display the node's target. - show_args: Whether to display the node's args. - show_kwargs: Whether to display the node's kwargs. - show_eps: Whether to display the quantization eps_in/eps_out (if available). 
- show_nlevels: Whether to display the n_levels_in -> n_levels_out. - show_signed: Whether to display the signed_in -> signed_out. - unicode: If True, apply color highlights for certain attributes. - - Returns: - A list of strings representing each requested attribute in order. - """ - node_specs: List[str] = [] - - if show_opcode: - node_specs.append(cls.node_info(node, "op", unicode)) - if show_class: - node_specs.append(cls.class_info(node, gm, unicode)) - if show_name: - node_specs.append(cls.node_info(node, "name", unicode)) - if show_target: - node_specs.append(cls.node_info(node, "target", unicode)) - if show_args: - node_specs.append(cls.node_info(node, "args", unicode)) - if show_kwargs: - node_specs.append(cls.node_info(node, "kwargs", unicode)) - - if show_nlevels: - node_specs.append(cls.quant_info(node, "n_levels")) - if show_signed: - node_specs.append(cls.quant_info(node, "signed")) - if show_eps: - node_specs.append(cls.quant_info(node, "eps_in")) - node_specs.append(cls.quant_info(node, "eps_out")) - - return node_specs + nodeSpecs: List[str] = [] + + if showOpcode: + nodeSpecs.append(cls.nodeInfo(node, "op", unicode)) + if showClass: + nodeSpecs.append(cls.classInfo(node, gm, unicode)) + if showName: + nodeSpecs.append(cls.nodeInfo(node, "name", unicode)) + if showTarget: + nodeSpecs.append(cls.nodeInfo(node, "target", unicode)) + if showArgs: + nodeSpecs.append(cls.nodeInfo(node, "args", unicode)) + if showKwargs: + nodeSpecs.append(cls.nodeInfo(node, "kwargs", unicode)) + + if showNlevels: + nodeSpecs.append(cls.quantInfo(node, "n_levels")) + if showSigned: + nodeSpecs.append(cls.quantInfo(node, "signed")) + if showEps: + nodeSpecs.append(cls.quantInfo(node, "eps_in")) + nodeSpecs.append(cls.quantInfo(node, "eps_out")) + + return nodeSpecs @classmethod - def print_tabular( + def printTabular( cls, gm: fx.GraphModule, - show_opcode: bool = True, - show_class: bool = True, - show_name: bool = True, - show_target: bool = True, - show_args: bool = False, - show_kwargs: bool = False, - show_eps: bool = False, - show_nlevels: bool = False, - show_signed: bool = False, + showOpcode: bool = True, + showClass: bool = True, + showName: bool = True, + showTarget: bool = True, + showArgs: bool = False, + showKwargs: bool = False, + showEps: bool = False, + showNlevels: bool = False, + showSigned: bool = False, unicode: bool = False, ) -> None: - """ - Print the graph in a tabular format with optional quantization metadata. - - Args: - gm: The FX GraphModule to display. - show_opcode: Whether to display the node's op code. - show_class: Whether to display the submodule class name (for call_module). - show_name: Whether to display the node's name. - show_target: Whether to display the node's target. - show_args: Whether to display the node's args. - show_kwargs: Whether to display the node's kwargs. - show_eps: Whether to display the quantization eps_in/eps_out (if available). - show_nlevels: Whether to display the n_levels_in -> n_levels_out. - show_signed: Whether to display the signed_in -> signed_out. - unicode: If True, apply color highlights for certain attributes. - - Returns: - None - """ - if not TABULATE_AVAILABLE: - print( - "Warning: 'tabulate' is not installed. Install via 'pip install tabulate' to use print_tabular." 
- ) - return - - node_list = list(gm.graph.nodes) - node_specs = [ - cls.get_node_spec( + nodeList = list(gm.graph.nodes) + nodeSpecs = [ + cls.getNodeSpec( node, gm, - show_opcode=show_opcode, - show_class=show_class, - show_name=show_name, - show_target=show_target, - show_args=show_args, - show_kwargs=show_kwargs, - show_eps=show_eps, - show_nlevels=show_nlevels, - show_signed=show_signed, + showOpcode=showOpcode, + showClass=showClass, + showName=showName, + showTarget=showTarget, + showArgs=showArgs, + showKwargs=showKwargs, + showEps=showEps, + showNlevels=showNlevels, + showSigned=showSigned, unicode=unicode, ) - for node in node_list + for node in nodeList ] headers = [] - if show_opcode: + if showOpcode: headers.append("opcode") - if show_class: + if showClass: headers.append("class") - if show_name: + if showName: headers.append("name") - if show_target: + if showTarget: headers.append("target") - if show_args: + if showArgs: headers.append("args") - if show_kwargs: + if showKwargs: headers.append("kwargs") - if show_nlevels: + if showNlevels: headers.append("n_levels") - if show_signed: + if showSigned: headers.append("signed") - if show_eps: + if showEps: headers.append("eps_in") headers.append("eps_out") - from tabulate import tabulate # safe import inside method - - print(tabulate(node_specs, headers=headers, tablefmt="mixed_grid")) + print(tabulate(nodeSpecs, headers=headers, tablefmt="mixed_grid")) diff --git a/DeepQuant/Utils/ONNXcutter.py b/DeepQuant/Utils/ONNXcutter.py new file mode 100644 index 0000000..44e7562 --- /dev/null +++ b/DeepQuant/Utils/ONNXcutter.py @@ -0,0 +1,275 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +""" +ONNX Model Cutter Utility + +This script allows to cut an ONNX model at any specified output tensor name +and optionally generate the output values at the cut point. 
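+Use the --list option to discover valid tensor names before choosing a cut point.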
+ +Basic Usage: + python onnx_cutter.py + +Examples: + # Cut model at specific tensor + python onnx_cutter.py model.onnx /MatMul_4_output_0 cut_model.onnx + + # List all available tensor names in the model + python onnx_cutter.py model.onnx --list + + # Cut model and generate output.npz with random input + python onnx_cutter.py model.onnx /Conv_output_0 cut_model.onnx --generate-output + + # Cut model and generate output.npz using specific input data + python onnx_cutter.py model.onnx /Conv_output_0 cut_model.onnx --generate-output --input-npz inputs.npz + + # Cut model and test it after creation + python onnx_cutter.py model.onnx /Conv_output_0 cut_model.onnx --test + +Arguments: + input_model: Path to input ONNX model + cut_point: Tensor name where to cut (e.g., /MatMul_4_output_0) + output_model: Path to save the cut model + +Options: + --list, -l: List all available tensor names in the model + --test, -t: Test the cut model after creation + --generate-output, -g: Generate outputs.npz file containing the output at cut point + --input-npz: Path to input.npz file for generating output (optional) + +Output Files: + - : The cut ONNX model + - outputs.npz: (if --generate-output) Contains the output tensor values at cut point +""" + +import argparse +import sys +from pathlib import Path + +import numpy as np +import onnx +import onnxruntime as ort +from onnx import utils + + +def create_deterministic_session(): + options = ort.SessionOptions() + + options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL + + options.use_deterministic_compute = True + options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL + + options.intra_op_num_threads = 1 + options.inter_op_num_threads = 1 + + options.enable_cpu_mem_arena = False + options.enable_mem_pattern = False + options.enable_mem_reuse = False + + options.log_severity_level = 3 + options.enable_profiling = False + + return options + + +def list_available_outputs(model_path): + print(f"Loading model: {model_path}") + model = onnx.load(model_path) + + print("\nAvailable intermediate tensor names:") + print("=" * 50) + + all_outputs = [] + for node in model.graph.node: + for output in node.output: + all_outputs.append(output) + + # Sort and display + all_outputs.sort() + for i, output in enumerate(all_outputs, 1): + print(f"{i:3d}. 
{output}") + + print(f"\nTotal intermediate tensors: {len(all_outputs)}") + return all_outputs + + +def cut_onnx_model( + input_path, cut_point, output_path, generate_output=False, input_data=None +): + print(f"Loading model: {input_path}") + model = onnx.load(input_path) + + all_outputs = [] + for node in model.graph.node: + for output in node.output: + all_outputs.append(output) + + if cut_point not in all_outputs: + print(f"Error: Cut point '{cut_point}' not found in model!") + print(f"Available outputs: {len(all_outputs)}") + + similar = [out for out in all_outputs if cut_point.split("/")[-1] in out] + if similar: + print("Similar outputs found:") + for sim in similar[:10]: + print(f" - {sim}") + return False + + print(f"Cutting model at: {cut_point}") + + utils.extract_model( + input_path, + output_path, + input_names=[inp.name for inp in model.graph.input], + output_names=[cut_point], + ) + + print(f"Cut model saved to: {output_path}") + + try: + cut_model = onnx.load(output_path) + print(f"Verification: Cut model has {len(cut_model.graph.node)} nodes") + print(f"Input: {[inp.name for inp in cut_model.graph.input]}") + print(f"Output: {[out.name for out in cut_model.graph.output]}") + + try: + inferred_model = onnx.shape_inference.infer_shapes(cut_model) + onnx.save(inferred_model, output_path) + print("Shape inference applied successfully") + except Exception as e: + print(f"Warning: Could not apply shape inference: {e}") + + if generate_output: + output_dir = Path(output_path).parent + outputFile = output_dir / "outputs.npz" + + options = create_deterministic_session() + ortSession = ort.InferenceSession( + output_path, sess_options=options, providers=["CPUExecutionProvider"] + ) + input_info = ortSession.get_inputs()[0] + + if input_data is None: + shape = input_info.shape + shape = [ + 1 if isinstance(dim, str) or dim == "batch_size" else dim + for dim in shape + ] + input_data = np.random.randn(*shape).astype(np.float32) + print(f"Generated random input with shape: {shape}") + + ortInputs = {input_info.name: input_data} + ortOutput = ortSession.run(None, ortInputs)[0] + + np.savez(outputFile, output=ortOutput) + print(f"Output data saved to {outputFile}") + print(f"Output shape: {ortOutput.shape}") + + except Exception as e: + print(f"Error verifying cut model: {e}") + return False + + return True + + +def test_cut_model(model_path, input_shape=None): + try: + session = ort.InferenceSession(model_path) + input_info = session.get_inputs()[0] + + print("\nTesting cut model:") + print(f"Input name: {input_info.name}") + print(f"Input shape: {input_info.shape}") + print(f"Input type: {input_info.type}") + + if input_shape is None: + shape = input_info.shape + shape = [1 if isinstance(dim, str) else dim for dim in shape] + else: + shape = input_shape + + dummy_input = np.random.randn(*shape).astype(np.float32) + + outputs = session.run(None, {input_info.name: dummy_input}) + + print("Test successful!") + print(f"Output shape: {outputs[0].shape}") + print(f"Output type: {outputs[0].dtype}") + + except Exception as e: + print(f"Error testing cut model: {e}") + + +def main(): + parser = argparse.ArgumentParser(description="Cut ONNX model at specified tensor") + parser.add_argument("input_model", help="Input ONNX model path") + parser.add_argument( + "cut_point", nargs="?", help="Tensor name to cut at (e.g., /MatMul_4_output_0)" + ) + parser.add_argument("output_model", nargs="?", help="Output ONNX model path") + parser.add_argument( + "--list", "-l", action="store_true", help="List 
available tensor names" + ) + parser.add_argument( + "--test", "-t", action="store_true", help="Test the cut model after creation" + ) + parser.add_argument( + "--generate-output", "-g", action="store_true", help="Generate output.npz file" + ) + parser.add_argument( + "--input-npz", help="Path to input.npz file (for generating output)" + ) + + args = parser.parse_args() + + if not Path(args.input_model).exists(): + print(f"Error: Input model '{args.input_model}' not found!") + sys.exit(1) + + if args.list: + list_available_outputs(args.input_model) + return + + if not args.cut_point or not args.output_model: + print("Error: cut_point and output_model are required!") + print("Use --list to see available tensor names") + print( + "Example: python onnx_cutter.py CCTTQ.onnx /MatMul_4_output_0 cut_model.onnx" + ) + sys.exit(1) + + input_data = None + if args.input_npz: + if not Path(args.input_npz).exists(): + print(f"Error: Input npz file '{args.input_npz}' not found!") + sys.exit(1) + try: + data = np.load(args.input_npz) + input_data = ( + data["input"] if "input" in data else data[list(data.keys())[0]] + ) + print( + f"Loaded input data from {args.input_npz} with shape: {input_data.shape}" + ) + except Exception as e: + print(f"Error loading input npz file: {e}") + sys.exit(1) + + success = cut_onnx_model( + args.input_model, + args.cut_point, + args.output_model, + generate_output=args.generate_output, + input_data=input_data, + ) + + if success and args.test: + test_cut_model(args.output_model) + + +if __name__ == "__main__": + main() diff --git a/DeepQuant/Utils/TensorRecorder.py b/DeepQuant/Utils/TensorRecorder.py new file mode 100644 index 0000000..073a6b6 --- /dev/null +++ b/DeepQuant/Utils/TensorRecorder.py @@ -0,0 +1,177 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
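+# Illustrative usage sketch (names such as `fxModel`, `quantModel`, and the
+# nodeTypes substrings are hypothetical; the exact call order used by the
+# tests may differ):
+#
+#     recorder = TensorRecorder(debug=True)
+#     recorder.registerForwardHooks(fxModel, nodeTypes=["linear", "act"])
+#     fxModel(sampleInput)                          # record reference activations
+#     recorder.setReferenceTensors()
+#     recorder.recordNodeMapping("fc1", "fc1_quant")  # optional name remapping
+#     recorder.registerForwardHooks(quantModel, nodeTypes=["linear", "act"])
+#     quantModel(sampleInput)                       # record activations to compare
+#     recorder.printComparisonResults(recorder.compareTensors())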
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from collections import OrderedDict +from typing import Dict, List, Optional, Set + +import torch +import torch.fx as fx + +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc + + +class TensorRecorder: + """Records and compares tensor values during model execution.""" + + def __init__(self, debug: bool = False): + self.debug = debug + self._hooks: List[torch.utils.hooks.RemovableHandle] = [] + self._current: Dict[str, torch.Tensor] = {} + self._reference: Optional[Dict[str, torch.Tensor]] = None + self._executionOrder: List[str] = [] + self._nameMap: Dict[str, str] = {} + self._ignore: Set[str] = set() + + def clear(self) -> None: + self.removeHooks() + self._current.clear() + self._reference = None + self._executionOrder.clear() + self._nameMap.clear() + self._ignore.clear() + + def removeHooks(self) -> None: + for hook in self._hooks: + hook.remove() + self._hooks.clear() + + def registerForwardHooks( + self, model: fx.GraphModule, nodeTypes: Optional[List[str]] = None + ) -> None: + self.removeHooks() + wanted = [w.lower() for w in nodeTypes] if nodeTypes else [] + + def makeHook(name: str): + def hook(_, __, output): + if isinstance(output, torch.Tensor): + self._current[name] = output.detach().clone() + if name not in self._executionOrder: + self._executionOrder.append(name) + + return hook + + for name, module in model.named_modules(): + if name and any(w in name.lower() for w in wanted): + self._hooks.append(module.register_forward_hook(makeHook(name))) + + def recordNodeMapping(self, referenceName: str, currentName: str) -> None: + self._nameMap[referenceName] = currentName + if self.debug: + print(f"Registered mapping: {referenceName} → {currentName}") + + def setReferenceTensors(self) -> None: + self._reference = {k: v.clone() for k, v in self._current.items()} + self._referenceOrder = list(self._executionOrder) + + def compareTensors(self) -> Dict[str, Dict]: + if self._reference is None: + raise RuntimeError("setReferenceTensors has not been called") + + results: Dict[str, Dict] = OrderedDict() + for refName, refTensor in self._reference.items(): + if refName in self._ignore: + continue + + curName = self._nameMap.get(refName, refName) + if curName not in self._current: + results[refName] = {"match": False, "error": f"missing '{curName}'"} + continue + + curTensor = self._current[curName] + equal = torch.equal(refTensor, curTensor) + diffMask = refTensor != curTensor + + results[refName] = { + "match": equal, + "mapped": curName != refName, + "current_name": curName, + "shape": tuple(refTensor.shape), + "diff_count": diffMask.sum().item() if not equal else 0, + "diff_mask": diffMask, + "ref_tensor": refTensor, + "cur_tensor": curTensor, + } + return results + + def _topDifferences( + self, ref: torch.Tensor, cur: torch.Tensor, diffMask: torch.Tensor + ) -> List[str]: + maskFlat = diffMask.reshape(-1).bool() + if maskFlat.sum() == 0: + return [] + + absDiff = (ref - cur).abs().reshape(-1)[maskFlat] + unique, counts = torch.unique(absDiff, return_counts=True) + order = counts.argsort(descending=True) + + lines: List[str] = [] + for idx in order[:5]: + delta = unique[idx].item() + count = counts[idx].item() + sampleIndex = (absDiff == delta).nonzero(as_tuple=False)[0].item() + globalIndex = maskFlat.nonzero(as_tuple=False)[sampleIndex].item() + beforeValue = ref.reshape(-1)[globalIndex].item() + afterValue = cur.reshape(-1)[globalIndex].item() + + lines.append( + f" · Δ={delta:.6f} ({count} values) e.g. 
idx {globalIndex}: " + f"{beforeValue:.6f} → {afterValue:.6f}" + ) + return lines + + def printComparisonResults(self, results: Dict[str, Dict]) -> None: + if not results: + print("No comparison data available.") + return + + matches = sum(1 for r in results.values() if r["match"]) + total = len(results) + + print( + f"Compared {total}: " + f"{cc.wrap(str(matches) + ' equal', cc.green)}, " + f"{cc.wrap(str(total - matches) + ' different', cc.red)}\n" + ) + + orderedNames = getattr(self, "_referenceOrder", list(results.keys())) + for name in orderedNames: + if name not in results: + continue + + res = results[name] + statusColor = cc.green if res["match"] else cc.red + statusTag = cc.wrap("[OK]" if res["match"] else "[DIFF]", statusColor) + mappedNote = f" → {res['current_name']}" if res["mapped"] else "" + + print(f" {statusTag} {name}{mappedNote} | shape {res['shape']}") + if res["match"]: + continue + + if "error" in res: + print(cc.wrap(f" {res['error']}", cc.yellow)) + continue + + diffCount = res["diff_count"] + totalValues = torch.tensor(res["shape"]).prod().item() + percentage = diffCount / totalValues * 100 + absDiff = (res["ref_tensor"] - res["cur_tensor"]).abs() + nonZero = absDiff[absDiff > 0] + minDiff = nonZero.min().item() if nonZero.numel() else 0.0 + + print(f" Max diff: {absDiff.max().item():.8f}") + print(f" Min diff: {minDiff:.8f}") + print(f" Mean diff: {absDiff.mean().item():.8f}") + print( + f" Total differing values: {diffCount} of {totalValues} ({percentage:.4f}%)" + ) + + topLines = self._topDifferences( + res["ref_tensor"], res["cur_tensor"], res["diff_mask"] + ) + if topLines: + print(" Most common differences (up to 5):") + for line in topLines: + print(line) diff --git a/DeepQuant/__init__.py b/DeepQuant/__init__.py new file mode 100644 index 0000000..6ac381f --- /dev/null +++ b/DeepQuant/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +# FBRANCASI: Workaround for PyTorch/FX API change: ensure private alias exists +import torch.fx.node as _fx_node + +if not hasattr(_fx_node.Node, "_Node__update_args_kwargs"): + _fx_node.Node._Node__update_args_kwargs = _fx_node.Node._update_args_kwargs + +from DeepQuant.Export import brevitasToTrueQuant + +__all__ = ["brevitasToTrueQuant"] diff --git a/README.md b/README.md index c625c62..9c16e1f 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A library for true-quantization and optimization of neural networks. -Deeploy is developed as part of the PULP project, a joint effort between ETH Zurich and the University of Bologna. +DeepQuant is developed as part of the PULP project, a joint effort between ETH Zurich and the University of Bologna. ## License @@ -20,4 +20,4 @@ pip install -e . We provide comprehensive tests with pytest, to execute all tests, simply run `pytest`. We mark our tests in two categories, `SingleLayerTests` and `ModelTests`, to execute the tests of the specific category, you can run `pytest -m `. For instance, to execute only the single layer tests, you can run `pytest -m SingleLayerTests`. ## ⚠️ Disclaimer ⚠️ -This library is currently in **beta stage** and under active development. Interfaces and features are subject to change, and stability is not yet guaranteed. Use at your own risk, and feel free to report any issues or contribute to its improvement. 
\ No newline at end of file +This library is currently in **beta stage** and under active development. Interfaces and features are subject to change, and stability is not yet guaranteed. Use at your own risk, and feel free to report any issues or contribute to its improvement. diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/batches.meta b/Tests/Data/CIFAR/cifar-10-batches-py/batches.meta new file mode 100644 index 0000000..4467a6e Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/batches.meta differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_1 b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_1 new file mode 100644 index 0000000..ab404a5 Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_1 differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_2 b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_2 new file mode 100644 index 0000000..6bf1369 Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_2 differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_3 b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_3 new file mode 100644 index 0000000..66a0d63 Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_3 differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_4 b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_4 new file mode 100644 index 0000000..cf8d03d Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_4 differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_5 b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_5 new file mode 100644 index 0000000..468b2aa Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/data_batch_5 differ diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/readme.html b/Tests/Data/CIFAR/cifar-10-batches-py/readme.html new file mode 100644 index 0000000..e377ade --- /dev/null +++ b/Tests/Data/CIFAR/cifar-10-batches-py/readme.html @@ -0,0 +1 @@ + diff --git a/Tests/Data/CIFAR/cifar-10-batches-py/test_batch b/Tests/Data/CIFAR/cifar-10-batches-py/test_batch new file mode 100644 index 0000000..3e03f1f Binary files /dev/null and b/Tests/Data/CIFAR/cifar-10-batches-py/test_batch differ diff --git a/Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth b/Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth new file mode 100644 index 0000000..dd6805e Binary files /dev/null and b/Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth differ diff --git a/Tests/Models/CCT/CCT/cct.py b/Tests/Models/CCT/CCT/cct.py new file mode 100644 index 0000000..e013501 --- /dev/null +++ b/Tests/Models/CCT/CCT/cct.py @@ -0,0 +1,606 @@ +import torch.nn as nn +from torch.hub import load_state_dict_from_url + +from .utils.helpers import fc_check, pe_check +from .utils.tokenizer import Tokenizer +from .utils.transformers import TransformerClassifier +from .registry import register_model + +model_urls = { + 'cct_7_3x1_32': + 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar10_300epochs.pth', + 'cct_7_3x1_32_sine': + 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar10_5000epochs.pth', + 'cct_7_3x1_32_c100': + 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_cifar100_300epochs.pth', + 'cct_7_3x1_32_sine_c100': + 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_3x1_32_sine_cifar100_5000epochs.pth', + 'cct_7_7x2_224_sine': + 'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_7_7x2_224_flowers102.pth', + 'cct_14_7x2_224': + 
'https://shi-labs.com/projects/cct/checkpoints/pretrained/cct_14_7x2_224_imagenet.pth', + 'cct_14_7x2_384': + 'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_imagenet.pth', + 'cct_14_7x2_384_fl': + 'https://shi-labs.com/projects/cct/checkpoints/finetuned/cct_14_7x2_384_flowers102.pth', +} + + +class CCT(nn.Module): + + def __init__(self, + img_size = 224, + embedding_dim = 768, + n_input_channels = 3, + n_conv_layers = 1, + kernel_size = 7, + stride = 2, + padding = 3, + pooling_kernel_size = 3, + pooling_stride = 2, + pooling_padding = 1, + dropout = 0., + attention_dropout = 0.1, + stochastic_depth = 0.1, + num_layers = 14, + num_heads = 6, + mlp_ratio = 4.0, + num_classes = 1000, + positional_embedding = 'learnable', + *args, + **kwargs): + super(CCT, self).__init__() + + self.tokenizer = Tokenizer(n_input_channels = n_input_channels, + n_output_channels = embedding_dim, + kernel_size = kernel_size, + stride = stride, + padding = padding, + pooling_kernel_size = pooling_kernel_size, + pooling_stride = pooling_stride, + pooling_padding = pooling_padding, + max_pool = True, + activation = nn.ReLU, + n_conv_layers = n_conv_layers, + conv_bias = False) + + self.classifier = TransformerClassifier(sequence_length = self.tokenizer.sequence_length( + n_channels = n_input_channels, height = img_size, width = img_size), + embedding_dim = embedding_dim, + seq_pool = True, + dropout = dropout, + attention_dropout = attention_dropout, + stochastic_depth = stochastic_depth, + num_layers = num_layers, + num_heads = num_heads, + mlp_ratio = mlp_ratio, + num_classes = num_classes, + positional_embedding = positional_embedding) + + def forward(self, x): + x = self.tokenizer(x) + return self.classifier(x) + + +def _cct(arch, + pretrained, + progress, + num_layers, + num_heads, + mlp_ratio, + embedding_dim, + kernel_size = 3, + stride = None, + padding = None, + positional_embedding = 'learnable', + *args, + **kwargs): + stride = stride if stride is not None else max(1, (kernel_size // 2) - 1) + padding = padding if padding is not None else max(1, (kernel_size // 2)) + model = CCT(num_layers = num_layers, + num_heads = num_heads, + mlp_ratio = mlp_ratio, + embedding_dim = embedding_dim, + kernel_size = kernel_size, + stride = stride, + padding = padding, + *args, + **kwargs) + + if pretrained: + if arch in model_urls: + state_dict = load_state_dict_from_url(model_urls[arch], progress = progress) + if positional_embedding == 'learnable': + state_dict = pe_check(model, state_dict) + elif positional_embedding == 'sine': + state_dict['classifier.positional_emb'] = model.state_dict()['classifier.positional_emb'] + state_dict = fc_check(model, state_dict) + model.load_state_dict(state_dict) + else: + raise RuntimeError(f'Variant {arch} does not yet have pretrained weights.') + return model + + +@register_model +def cct_2(arch, pretrained, progress, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers = 2, + num_heads = 2, + mlp_ratio = 1, + embedding_dim = 128, + *args, + **kwargs) + + +@register_model +def cct_4(arch, pretrained, progress, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers = 4, + num_heads = 2, + mlp_ratio = 1, + embedding_dim = 128, + *args, + **kwargs) + + +@register_model +def cct_6(arch, pretrained, progress, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers = 6, + num_heads = 4, + mlp_ratio = 2, + embedding_dim = 256, + *args, + **kwargs) + + +@register_model +def cct_7(arch, pretrained, 
progress, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers = 7, + num_heads = 4, + mlp_ratio = 2, + embedding_dim = 256, + *args, + **kwargs) + + +@register_model +def cct_1(arch, pretrained, progress, embedding_dim = 128, num_heads = 2, num_layers = 2, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers=num_layers, + num_heads=num_heads, + mlp_ratio=2, + embedding_dim=embedding_dim, + *args, + **kwargs) + + +@register_model +def cct_14(arch, pretrained, progress, *args, **kwargs): + return _cct(arch, + pretrained, + progress, + num_layers = 14, + num_heads = 6, + mlp_ratio = 3, + embedding_dim = 384, + *args, + **kwargs) + + +@register_model +def cct_2_3x2_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_2('cct_2_3x2_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_2_3x2_32_sine(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_2('cct_2_3x2_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_4_3x2_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_4('cct_4_3x2_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_4_3x2_32_sine(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_4('cct_4_3x2_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_6_3x1_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_6('cct_6_3x1_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_6_3x1_32_sine(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_6('cct_6_3x1_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_6_3x2_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_6('cct_6_3x2_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_6_3x2_32_sine(pretrained = False, + progress = False, + img_size = 32, + 
positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_6('cct_6_3x2_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_test(pretrained=False, + progress=False, + img_size=8, + positional_embedding='learnable', + num_classes=10, + embedding_dim=128, + num_heads=2, + num_layers=2, + n_conv_layers=1, + *args, + **kwargs): + return cct_1('cct_1_3x1_32', + pretrained, + progress, + kernel_size=3, + n_conv_layers=n_conv_layers, + img_size=img_size, + positional_embedding=positional_embedding, + num_classes=num_classes, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + *args, + **kwargs) + + +@register_model +def cct_7_3x1_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_7('cct_7_3x1_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_3x1_32_sine(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_7('cct_7_3x1_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_3x1_32_c100(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 100, + *args, + **kwargs): + return cct_7('cct_7_3x1_32_c100', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_3x1_32_sine_c100(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 100, + *args, + **kwargs): + return cct_7('cct_7_3x1_32_sine_c100', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 1, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_3x2_32(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'learnable', + num_classes = 10, + *args, + **kwargs): + return cct_7('cct_7_3x2_32', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_3x2_32_sine(pretrained = False, + progress = False, + img_size = 32, + positional_embedding = 'sine', + num_classes = 10, + *args, + **kwargs): + return cct_7('cct_7_3x2_32_sine', + pretrained, + progress, + kernel_size = 3, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_7x2_224(pretrained = False, + progress = False, + img_size = 224, + positional_embedding = 'learnable', + num_classes = 102, + *args, + **kwargs): + return cct_7('cct_7_7x2_224', + pretrained, + progress, + kernel_size = 7, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = 
positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_7_7x2_224_sine(pretrained = False, + progress = False, + img_size = 224, + positional_embedding = 'sine', + num_classes = 102, + *args, + **kwargs): + return cct_7('cct_7_7x2_224_sine', + pretrained, + progress, + kernel_size = 7, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_14_7x2_224(pretrained = False, + progress = False, + img_size = 224, + positional_embedding = 'learnable', + num_classes = 1000, + *args, + **kwargs): + return cct_14('cct_14_7x2_224', + pretrained, + progress, + kernel_size = 7, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_14_7x2_384(pretrained = False, + progress = False, + img_size = 384, + positional_embedding = 'learnable', + num_classes = 1000, + *args, + **kwargs): + return cct_14('cct_14_7x2_384', + pretrained, + progress, + kernel_size = 7, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) + + +@register_model +def cct_14_7x2_384_fl(pretrained = False, + progress = False, + img_size = 384, + positional_embedding = 'learnable', + num_classes = 102, + *args, + **kwargs): + return cct_14('cct_14_7x2_384_fl', + pretrained, + progress, + kernel_size = 7, + n_conv_layers = 2, + img_size = img_size, + positional_embedding = positional_embedding, + num_classes = num_classes, + *args, + **kwargs) diff --git a/Tests/Models/CCT/CCT/registry.py b/Tests/Models/CCT/CCT/registry.py new file mode 100644 index 0000000..583fdbd --- /dev/null +++ b/Tests/Models/CCT/CCT/registry.py @@ -0,0 +1,5 @@ +def register_model(func): + """ + Fallback wrapper in case timm isn't installed + """ + return func diff --git a/Tests/Models/CCT/CCT/utils/__init__.py b/Tests/Models/CCT/CCT/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Tests/Models/CCT/CCT/utils/embedder.py b/Tests/Models/CCT/CCT/utils/embedder.py new file mode 100644 index 0000000..33ac302 --- /dev/null +++ b/Tests/Models/CCT/CCT/utils/embedder.py @@ -0,0 +1,39 @@ +import torch.nn as nn + + +class Embedder(nn.Module): + + def __init__(self, + word_embedding_dim = 300, + vocab_size = 100000, + padding_idx = 1, + pretrained_weight = None, + embed_freeze = False, + *args, + **kwargs): + super(Embedder, self).__init__() + self.embeddings = nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) \ + if pretrained_weight is not None else \ + nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) + self.embeddings.weight.requires_grad = not embed_freeze + + def forward_mask(self, mask): + bsz, seq_len = mask.shape + new_mask = mask.view(bsz, seq_len, 1) + new_mask = new_mask.sum(-1) + new_mask = (new_mask > 0) + return new_mask + + def forward(self, x, mask = None): + embed = self.embeddings(x) + embed = embed if mask is None else embed * self.forward_mask(mask).unsqueeze(-1).float() + return embed, mask + + @staticmethod + def init_weight(m): + if isinstance(m, nn.Linear): + nn.init.trunc_normal_(m.weight, std = .02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + else: + nn.init.normal_(m.weight) diff --git a/Tests/Models/CCT/CCT/utils/helpers.py b/Tests/Models/CCT/CCT/utils/helpers.py new file mode 100644 
index 0000000..6d18820 --- /dev/null +++ b/Tests/Models/CCT/CCT/utils/helpers.py @@ -0,0 +1,45 @@ +import logging +import math + +import torch +import torch.nn.functional as F + +_logger = logging.getLogger('train') + + +def resize_pos_embed(posemb, posemb_new, num_tokens = 1): + # Copied from `timm` by Ross Wightman: + # github.com/rwightman/pytorch-image-models + # Rescale the grid of position embeddings when loading from state_dict. Adapted from + # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 + ntok_new = posemb_new.shape[1] + if num_tokens: + posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] + ntok_new -= num_tokens + else: + posemb_tok, posemb_grid = posemb[:, :0], posemb[0] + gs_old = int(math.sqrt(len(posemb_grid))) + gs_new = int(math.sqrt(ntok_new)) + posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) + posemb_grid = F.interpolate(posemb_grid, size = (gs_new, gs_new), mode = 'bilinear') + posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) + posemb = torch.cat([posemb_tok, posemb_grid], dim = 1) + return posemb + + +def pe_check(model, state_dict, pe_key = 'classifier.positional_emb'): + if pe_key is not None and pe_key in state_dict.keys() and pe_key in model.state_dict().keys(): + if model.state_dict()[pe_key].shape != state_dict[pe_key].shape: + state_dict[pe_key] = resize_pos_embed(state_dict[pe_key], + model.state_dict()[pe_key], + num_tokens = model.classifier.num_tokens) + return state_dict + + +def fc_check(model, state_dict, fc_key = 'classifier.fc'): + for key in [f'{fc_key}.weight', f'{fc_key}.bias']: + if key is not None and key in state_dict.keys() and key in model.state_dict().keys(): + if model.state_dict()[key].shape != state_dict[key].shape: + _logger.warning(f'Removing {key}, number of classes has changed.') + state_dict[key] = model.state_dict()[key] + return state_dict diff --git a/Tests/Models/CCT/CCT/utils/stochastic_depth.py b/Tests/Models/CCT/CCT/utils/stochastic_depth.py new file mode 100644 index 0000000..1ce2c54 --- /dev/null +++ b/Tests/Models/CCT/CCT/utils/stochastic_depth.py @@ -0,0 +1,39 @@ +# Thanks to rwightman's timm package +# github.com:rwightman/pytorch-image-models + +import torch +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """ + Obtained from: github.com:rwightman/pytorch-image-models + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype = x.dtype, device = x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """ + Obtained from: github.com:rwightman/pytorch-image-models + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob = None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/Tests/Models/CCT/CCT/utils/tokenizer.py b/Tests/Models/CCT/CCT/utils/tokenizer.py new file mode 100644 index 0000000..26f02ed --- /dev/null +++ b/Tests/Models/CCT/CCT/utils/tokenizer.py @@ -0,0 +1,116 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class Tokenizer(nn.Module): + + def __init__(self, + kernel_size, + stride, + padding, + pooling_kernel_size = 3, + pooling_stride = 2, + pooling_padding = 1, + n_conv_layers = 1, + n_input_channels = 3, + n_output_channels = 64, + in_planes = 64, + activation = None, + max_pool = True, + conv_bias = False): + super(Tokenizer, self).__init__() + + n_filter_list = [n_input_channels] + \ + [in_planes for _ in range(n_conv_layers - 1)] + \ + [n_output_channels] + + self.conv_layers = nn.Sequential(*[ + nn.Sequential( + nn.Conv2d(n_filter_list[i], + n_filter_list[i + 1], + kernel_size = (kernel_size, kernel_size), + stride = (stride, stride), + padding = (padding, padding), + bias = conv_bias), + nn.Identity() if activation is None else activation(), + nn.MaxPool2d(kernel_size = pooling_kernel_size, stride = pooling_stride, padding = pooling_padding + ) if max_pool else nn.Identity()) for i in range(n_conv_layers) + ]) + + self.flattener = nn.Flatten(2, 3) + self.apply(self.init_weight) + + def sequence_length(self, n_channels = 3, height = 224, width = 224): + return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] + + def forward(self, x): + return self.flattener(self.conv_layers(x)).transpose(-2, -1) + + @staticmethod + def init_weight(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) + + +class TextTokenizer(nn.Module): + + def __init__(self, + kernel_size, + stride, + padding, + pooling_kernel_size = 3, + pooling_stride = 2, + pooling_padding = 1, + embedding_dim = 300, + n_output_channels = 128, + activation = None, + max_pool = True, + *args, + **kwargs): + super(TextTokenizer, self).__init__() + + self.max_pool = max_pool + self.conv_layers = nn.Sequential( + nn.Conv2d(1, + n_output_channels, + kernel_size = (kernel_size, embedding_dim), + stride = (stride, 1), + padding = (padding, 0), + bias = False), + nn.Identity() if activation is None else activation(), + nn.MaxPool2d( + kernel_size = (pooling_kernel_size, 1), stride = (pooling_stride, + 1), padding = (pooling_padding, + 0)) if max_pool else nn.Identity()) + + self.apply(self.init_weight) + + def seq_len(self, seq_len = 32, embed_dim = 300): + return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] + + def forward_mask(self, mask): + new_mask = mask.unsqueeze(1).float() + cnn_weight = torch.ones((1, 1, self.conv_layers[0].kernel_size[0]), device = mask.device, dtype = torch.float) + new_mask = F.conv1d(new_mask, cnn_weight, None, self.conv_layers[0].stride[0], self.conv_layers[0].padding[0], + 
1, 1) + if self.max_pool: + new_mask = F.max_pool1d(new_mask, self.conv_layers[2].kernel_size[0], self.conv_layers[2].stride[0], + self.conv_layers[2].padding[0], 1, False, False) + new_mask = new_mask.squeeze(1) + new_mask = (new_mask > 0) + return new_mask + + def forward(self, x, mask = None): + x = x.unsqueeze(1) + x = self.conv_layers(x) + x = x.transpose(1, 3).squeeze(1) + if mask is not None: + mask = self.forward_mask(mask).unsqueeze(-1).float() + x = x * mask + return x, mask + + @staticmethod + def init_weight(m): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight) diff --git a/Tests/Models/CCT/CCT/utils/transformers.py b/Tests/Models/CCT/CCT/utils/transformers.py new file mode 100644 index 0000000..f14472c --- /dev/null +++ b/Tests/Models/CCT/CCT/utils/transformers.py @@ -0,0 +1,366 @@ +import torch +import torch.nn.functional as F +from torch.nn import Dropout, Identity, LayerNorm, Linear, Module, ModuleList, Parameter, init + +from .stochastic_depth import DropPath + + +class Attention(Module): + """ + Attention module with explicit Q, K, V branches for ONNX export. + """ + + def __init__(self, dim, num_heads=8, attention_dropout=0.1, projection_dropout=0.1): + super().__init__() + self.num_heads = num_heads + head_dim = dim // self.num_heads + self.scale = head_dim ** -0.5 + + + self.q_proj = Linear(dim, dim, bias=False) + self.k_proj = Linear(dim, dim, bias=False) + self.v_proj = Linear(dim, dim, bias=False) + + self.attn_drop = Dropout(attention_dropout) + self.proj = Linear(dim, dim) + self.proj_drop = Dropout(projection_dropout) + + def forward(self, x): + B, N, C = x.shape + q = self.q_proj(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + k = self.k_proj(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + v = self.v_proj(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3) + + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = F.softmax(attn, dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class MaskedAttention(Module): + + def __init__(self, dim, num_heads = 8, attention_dropout = 0.1, projection_dropout = 0.1): + super().__init__() + self.num_heads = num_heads + head_dim = dim // self.num_heads + self.scale = head_dim**-0.5 + + self.qkv = Linear(dim, dim * 3, bias = False) + self.attn_drop = Dropout(attention_dropout) + self.proj = Linear(dim, dim) + self.proj_drop = Dropout(projection_dropout) + + def forward(self, x, mask = None): + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] + + attn = (q @ k.transpose(-2, -1)) * self.scale + + if mask is not None: + mask_value = -torch.finfo(attn.dtype).max + assert mask.shape[-1] == attn.shape[-1], 'mask has incorrect dimensions' + mask = mask[:, None, :] * mask[:, :, None] + mask = mask.unsqueeze(1).repeat(1, self.num_heads, 1, 1) + attn.masked_fill_(~mask, mask_value) + + attn = attn.softmax(dim = -1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class TransformerEncoderLayer(Module): + """ + Inspired by torch.nn.TransformerEncoderLayer and timm. 
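+    Applies a pre-norm attention residual (src + DropPath(Attn(LayerNorm(src)))),
+    a post-attention LayerNorm, and a GELU MLP residual with dropout and DropPath.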
+ """ + + def __init__(self, + d_model, + nhead, + dim_feedforward = 2048, + dropout = 0.1, + attention_dropout = 0.1, + drop_path_rate = 0.1): + super(TransformerEncoderLayer, self).__init__() + self.pre_norm = LayerNorm(d_model) + self.self_attn = Attention(dim = d_model, + num_heads = nhead, + attention_dropout = attention_dropout, + projection_dropout = dropout) + + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout1 = Dropout(dropout) + self.norm1 = LayerNorm(d_model) + self.linear2 = Linear(dim_feedforward, d_model) + self.dropout2 = Dropout(dropout) + + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() + + self.activation = F.gelu + + def forward(self, src: torch.Tensor, *args, **kwargs) -> torch.Tensor: + src = src + self.drop_path(self.self_attn(self.pre_norm(src))) + src = self.norm1(src) + src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) + src = src + self.drop_path(self.dropout2(src2)) + return src + + +class MaskedTransformerEncoderLayer(Module): + """ + Inspired by torch.nn.TransformerEncoderLayer and timm. + """ + + def __init__(self, + d_model, + nhead, + dim_feedforward = 2048, + dropout = 0.1, + attention_dropout = 0.1, + drop_path_rate = 0.1): + super(MaskedTransformerEncoderLayer, self).__init__() + self.pre_norm = LayerNorm(d_model) + self.self_attn = MaskedAttention(dim = d_model, + num_heads = nhead, + attention_dropout = attention_dropout, + projection_dropout = dropout) + + self.linear1 = Linear(d_model, dim_feedforward) + self.dropout1 = Dropout(dropout) + self.norm1 = LayerNorm(d_model) + self.linear2 = Linear(dim_feedforward, d_model) + self.dropout2 = Dropout(dropout) + + self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0 else Identity() + + self.activation = F.gelu + + def forward(self, src: torch.Tensor, mask = None, *args, **kwargs) -> torch.Tensor: + src = src + self.drop_path(self.self_attn(self.pre_norm(src), mask)) + src = self.norm1(src) + src2 = self.linear2(self.dropout1(self.activation(self.linear1(src)))) + src = src + self.drop_path(self.dropout2(src2)) + return src + + +class TransformerClassifier(Module): + + def __init__(self, + seq_pool = True, + embedding_dim = 768, + num_layers = 12, + num_heads = 12, + mlp_ratio = 4.0, + num_classes = 1000, + dropout = 0.1, + attention_dropout = 0.1, + stochastic_depth = 0.1, + positional_embedding = 'learnable', + sequence_length = None): + super().__init__() + positional_embedding = positional_embedding if \ + positional_embedding in ['sine', 'learnable', 'none'] else 'sine' + dim_feedforward = int(embedding_dim * mlp_ratio) + self.embedding_dim = embedding_dim + self.sequence_length = sequence_length + self.seq_pool = seq_pool + self.num_tokens = 0 + + assert sequence_length is not None or positional_embedding == 'none', \ + f"Positional embedding is set to {positional_embedding} and" \ + f" the sequence length was not specified." 
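+        # Illustrative note: with seq_pool=True no class token is prepended, so
+        # sequence_length must match the number of tokens produced by the Tokenizer.
+        # For example, a 32x32 input through two 3x3 convs with 3x3/stride-2 max
+        # pools yields an 8x8 grid, i.e. 64 tokens and a positional_emb of shape
+        # (1, 64, embedding_dim).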
+ + if not seq_pool: + sequence_length += 1 + self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad = True) + self.num_tokens = 1 + else: + self.attention_pool = Linear(self.embedding_dim, 1) + + if positional_embedding != 'none': + if positional_embedding == 'learnable': + self.positional_emb = Parameter(torch.zeros(1, sequence_length, embedding_dim), requires_grad = True) + init.trunc_normal_(self.positional_emb, std = 0.2) + else: + self.positional_emb = Parameter(self.sinusoidal_embedding(sequence_length, embedding_dim), + requires_grad = False) + else: + self.positional_emb = None + + self.dropout = Dropout(p = dropout) + dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] + self.blocks = ModuleList([ + TransformerEncoderLayer(d_model = embedding_dim, + nhead = num_heads, + dim_feedforward = dim_feedforward, + dropout = dropout, + attention_dropout = attention_dropout, + drop_path_rate = dpr[i]) for i in range(num_layers) + ]) + self.norm = LayerNorm(embedding_dim) + + self.fc = Linear(embedding_dim, num_classes) + self.apply(self.init_weight) + + def forward(self, x): + if self.positional_emb is None and x.size(1) < self.sequence_length: + x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode = 'constant', value = 0) + + if not self.seq_pool: + cls_token = self.class_emb.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim = 1) + + if self.positional_emb is not None: + x += self.positional_emb + + x = self.dropout(x) + + for blk in self.blocks: + x = blk(x) + x = self.norm(x) + + if self.seq_pool: + x = torch.matmul(F.softmax(self.attention_pool(x), dim = 1).transpose(-1, -2), x).squeeze(-2) + else: + x = x[:, 0] + + x = self.fc(x) + return x + + @staticmethod + def init_weight(m): + if isinstance(m, Linear): + init.trunc_normal_(m.weight, std = .02) + if isinstance(m, Linear) and m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, LayerNorm): + init.constant_(m.bias, 0) + init.constant_(m.weight, 1.0) + + @staticmethod + def sinusoidal_embedding(n_channels, dim): + pe = torch.FloatTensor([[p / (10000**(2 * (i // 2) / dim)) for i in range(dim)] for p in range(n_channels)]) + pe[:, 0::2] = torch.sin(pe[:, 0::2]) + pe[:, 1::2] = torch.cos(pe[:, 1::2]) + return pe.unsqueeze(0) + + +class MaskedTransformerClassifier(Module): + + def __init__(self, + seq_pool = True, + embedding_dim = 768, + num_layers = 12, + num_heads = 12, + mlp_ratio = 4.0, + num_classes = 1000, + dropout = 0.1, + attention_dropout = 0.1, + stochastic_depth = 0.1, + positional_embedding = 'sine', + seq_len = None, + *args, + **kwargs): + super().__init__() + positional_embedding = positional_embedding if \ + positional_embedding in ['sine', 'learnable', 'none'] else 'sine' + dim_feedforward = int(embedding_dim * mlp_ratio) + self.embedding_dim = embedding_dim + self.seq_len = seq_len + self.seq_pool = seq_pool + self.num_tokens = 0 + + assert seq_len is not None or positional_embedding == 'none', \ + f"Positional embedding is set to {positional_embedding} and" \ + f" the sequence length was not specified." 
+ + if not seq_pool: + seq_len += 1 + self.class_emb = Parameter(torch.zeros(1, 1, self.embedding_dim), requires_grad = True) + self.num_tokens = 1 + else: + self.attention_pool = Linear(self.embedding_dim, 1) + + if positional_embedding != 'none': + if positional_embedding == 'learnable': + seq_len += 1 # padding idx + self.positional_emb = Parameter(torch.zeros(1, seq_len, embedding_dim), requires_grad = True) + init.trunc_normal_(self.positional_emb, std = 0.2) + else: + self.positional_emb = Parameter(self.sinusoidal_embedding(seq_len, embedding_dim, padding_idx = True), + requires_grad = False) + else: + self.positional_emb = None + + self.dropout = Dropout(p = dropout) + dpr = [x.item() for x in torch.linspace(0, stochastic_depth, num_layers)] + self.blocks = ModuleList([ + MaskedTransformerEncoderLayer(d_model = embedding_dim, + nhead = num_heads, + dim_feedforward = dim_feedforward, + dropout = dropout, + attention_dropout = attention_dropout, + drop_path_rate = dpr[i]) for i in range(num_layers) + ]) + self.norm = LayerNorm(embedding_dim) + + self.fc = Linear(embedding_dim, num_classes) + self.apply(self.init_weight) + + def forward(self, x, mask = None): + if self.positional_emb is None and x.size(1) < self.seq_len: + x = F.pad(x, (0, 0, 0, self.n_channels - x.size(1)), mode = 'constant', value = 0) + + if not self.seq_pool: + cls_token = self.class_emb.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim = 1) + if mask is not None: + mask = torch.cat([torch.ones(size = (mask.shape[0], 1), device = mask.device), mask.float()], dim = 1) + mask = (mask > 0) + + if self.positional_emb is not None: + x += self.positional_emb + + x = self.dropout(x) + + for blk in self.blocks: + x = blk(x, mask = mask) + x = self.norm(x) + + if self.seq_pool: + x = torch.matmul(F.softmax(self.attention_pool(x), dim = 1).transpose(-1, -2), x).squeeze(-2) + else: + x = x[:, 0] + + x = self.fc(x) + return x + + @staticmethod + def init_weight(m): + if isinstance(m, Linear): + init.trunc_normal_(m.weight, std = .02) + if isinstance(m, Linear) and m.bias is not None: + init.constant_(m.bias, 0) + elif isinstance(m, LayerNorm): + init.constant_(m.bias, 0) + init.constant_(m.weight, 1.0) + + @staticmethod + def sinusoidal_embedding(n_channels, dim, padding_idx = False): + pe = torch.FloatTensor([[p / (10000**(2 * (i // 2) / dim)) for i in range(dim)] for p in range(n_channels)]) + pe[:, 0::2] = torch.sin(pe[:, 0::2]) + pe[:, 1::2] = torch.cos(pe[:, 1::2]) + pe = pe.unsqueeze(0) + if padding_idx: + return torch.cat([torch.zeros((1, 1, dim)), pe], dim = 1) + return pe diff --git a/Tests/Models/CCT/config.yaml b/Tests/Models/CCT/config.yaml new file mode 100644 index 0000000..b1ac70c --- /dev/null +++ b/Tests/Models/CCT/config.yaml @@ -0,0 +1,12 @@ +cct: + pretrained: False + img_size: 32 + num_classes: 10 + embedding_dim: 64 + num_heads: 1 + num_layers: 2 + batch_size: 1 + opset_version: 12 + +training: + learning_rate: 0.01 \ No newline at end of file diff --git a/Tests/Models/CCT/mnistCheckpoint.py b/Tests/Models/CCT/mnistCheckpoint.py new file mode 100644 index 0000000..b53c256 --- /dev/null +++ b/Tests/Models/CCT/mnistCheckpoint.py @@ -0,0 +1,214 @@ +import numpy as np +import onnxruntime as ort +import onnx +import os +import torch +import torchvision +from torchvision import transforms +from utils.utils import * + +def preprocess_mnist(batch_size, image_size): + """ + Preprocess MNIST dataset with configurable image size. 
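+    Images are converted to 3-channel float32 arrays of shape
+    (batch_size, 3, image_size, image_size); labels are int64 of shape (batch_size,).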
+ + Args: + batch_size: Number of images to process + image_size: Size to resize images to (will be used as both height and width) + + Returns: + Tuple of (images, labels) as numpy arrays + """ + transform = transforms.Compose([ + transforms.Resize((image_size, image_size)), + transforms.Grayscale(num_output_channels=3), + transforms.ToTensor() + ]) + + dataset = torchvision.datasets.MNIST(root="./data", train=False, transform=transform, download=True) + indices = np.random.choice(len(dataset), batch_size, replace=False) + images = torch.stack([dataset[i][0] for i in indices]) + labels = np.array([dataset[i][1] for i in indices], dtype=np.int64) + + return images.numpy(), labels + +def run_original_onnx_model(input_data, labels, model_path): + """ + Run inference on original ONNX model to get gradients. + + Args: + input_data: Input data for the model + labels: Labels for the model + model_path: Path to the original ONNX model (without SGD) + + Returns: + Dictionary of model outputs (gradients) + """ + ort_session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"]) + + output_names = [output.name for output in ort_session.get_outputs()] + print(f"Model has {len(output_names)} outputs: {output_names}") + + outputs = ort_session.run(None, {"input": input_data, "labels": labels}) + + output_dict = {} + for i, name in enumerate(output_names): + output_dict[name] = outputs[i] + + return output_dict + +def get_initializer_from_onnx(model_path, initializer_name): + """ + Extract initializer tensor from ONNX model. + + Args: + model_path: Path to the ONNX model + initializer_name: Name of the initializer to extract + + Returns: + Numpy array of the initializer tensor + """ + model = onnx.load(model_path) + for initializer in model.graph.initializer: + if initializer.name == initializer_name: + # Convert ONNX tensor to numpy array + from onnx import numpy_helper + return numpy_helper.to_array(initializer) + + raise ValueError(f"Initializer {initializer_name} not found in model") + +def apply_sgd_update(weight, gradient, learning_rate=0.01): + """ + Manually apply SGD update to weights. + + Args: + weight: Current weight tensor + gradient: Gradient tensor + learning_rate: Learning rate for SGD + + Returns: + Updated weight tensor + """ + return weight - learning_rate * gradient + +def create_test_input_output(): + """ + Create test input and output files with manual SGD implementation. 
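+    The update is plain SGD, w_new = w - learning_rate * grad, applied to
+    classifier_fc_weight and classifier_fc_bias (e.g. with learning_rate 0.01,
+    a weight of 0.50 and a gradient of 2.0 becomes 0.48).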
+ """ + # Load config + config = load_config() + if isinstance(config, tuple): + # Handle the case where load_config returns a tuple of values + pretrained, img_size, num_classes, embedding_dim, num_heads, num_layers, batch_size, opset_version = config + else: + # Handle the case where load_config returns a dictionary + img_size = config.get("img_size", 16) # Default to 16 if not specified + batch_size = config.get("batch_size", 8) # Default to 8 if not specified + embedding_dim = config.get("embedding_dim", 384) + num_heads = config.get("num_heads", 6) + num_layers = config.get("num_layers", 7) + + print(f"Using image size: {img_size}, batch size: {batch_size}") + + folder_name = f"CCT_train_{img_size}_{embedding_dim}_{num_heads}_{num_layers}" + + base_dir = os.path.dirname(os.path.abspath(__file__)) + folder_path = os.path.join(base_dir, "onnx", folder_name) + os.makedirs(folder_path, exist_ok=True) + + # Path to original training network + network_path = os.path.join(base_dir, "onnx", folder_name, "network_train.onnx") + input_path = os.path.join(base_dir, "onnx", folder_name, "inputs.npz") + output_path = os.path.join(base_dir, "onnx", folder_name, "outputs.npz") + + print(f"Original network path: {network_path}") + print(f"Input path: {input_path}") + print(f"Output path: {output_path}") + + if not os.path.exists(network_path): + raise FileNotFoundError(f"ONNX model file not found: {network_path}") + + # Create input data with the specified image size + input_data, labels = preprocess_mnist(batch_size, img_size) + np.savez(input_path, input=input_data, labels=labels) + print(f"✅ Input saved to inputs.npz (image size: {img_size}x{img_size}, batch size: {batch_size})") + + # Run the original model to get gradients + outputs_dict = run_original_onnx_model(input_data, labels, model_path=network_path) + + # Extract parameter gradients from outputs + weight_grad_name = None + bias_grad_name = None + + # Try to find the gradient outputs by name patterns + for name in outputs_dict.keys(): + if "classifier_fc_weight_grad" in name: + weight_grad_name = name + elif "classifier_fc_Gemm_Grad_dC_reduced" in name or "classifier_fc_bias_grad" in name: + bias_grad_name = name + + if not weight_grad_name: + print("❌ Could not find weight gradient in outputs") + print(f"Available outputs: {list(outputs_dict.keys())}") + return + + if not bias_grad_name: + print("❌ Could not find bias gradient in outputs") + print(f"Available outputs: {list(outputs_dict.keys())}") + return + + print(f"Found weight gradient: {weight_grad_name}") + print(f"Found bias gradient: {bias_grad_name}") + + # Extract original weights from model + try: + classifier_fc_weight = get_initializer_from_onnx(network_path, "classifier_fc_weight") + classifier_fc_bias = get_initializer_from_onnx(network_path, "classifier_fc_bias") + print(f"✅ Successfully extracted original parameters from model") + print(f" Weight shape: {classifier_fc_weight.shape}") + print(f" Bias shape: {classifier_fc_bias.shape}") + except Exception as e: + print(f"❌ Error extracting parameters: {e}") + import traceback + traceback.print_exc() + return + + # Apply SGD manually + learning_rate = load_train_config() + + weight_grad = outputs_dict[weight_grad_name] + bias_grad = outputs_dict[bias_grad_name] + + # Apply SGD update + weight_updated = apply_sgd_update(classifier_fc_weight, weight_grad, learning_rate) + bias_updated = apply_sgd_update(classifier_fc_bias, bias_grad, learning_rate) + + print(f"✅ Successfully applied SGD updates to parameters") + + # Create 
output dict with updated parameters + sgd_outputs = { + "classifier_fc_weight_updated": weight_updated, + "classifier_fc_bias_updated": bias_updated + } + + # Save updated parameters to output file + np.savez(output_path, **sgd_outputs) + print(f"✅ Updated parameters saved to {output_path}") + + # Print output shapes + print("Final output shapes:") + for name, arr in sgd_outputs.items(): + print(f" {name}: {arr.shape}") + +def main(): + """ + Main function to run the script with better error handling + """ + try: + create_test_input_output() + except Exception as e: + print(f"❌ Error: {e}") + import traceback + traceback.print_exc() + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/Tests/Models/CCT/testinfergenerate.py b/Tests/Models/CCT/testinfergenerate.py new file mode 100644 index 0000000..5f5028f --- /dev/null +++ b/Tests/Models/CCT/testinfergenerate.py @@ -0,0 +1,69 @@ +import onnx +import onnxruntime as ort +import numpy as np +import sys +import os +import torch +from CCT.cct import cct_test +from utils.utils import * + +def generate_cct_onnx_and_data(save_path=None): + """ Generate ONNX model for CCT based on config, with optional save path """ + + pretrained, img_size, num_classes, embedding_dim, num_heads, num_layers, batch_size, opset_version = load_config() + print(f"✅ Loaded config: img_size={img_size}, embedding_dim={embedding_dim}, num_heads={num_heads}, num_layers={num_layers}, opset_version={opset_version}") + + input_shape = (1, 3, img_size, img_size) + + folder_name = f"CCT_infer_{img_size}_{embedding_dim}_{num_heads}_{num_layers}" + + base_path = save_path if save_path else os.path.join(os.path.dirname(os.path.abspath(__file__)), "onnx", folder_name) + + onnx_file = os.path.join(base_path, "network.onnx") + input_file = os.path.join(base_path, "inputs.npz") + output_file = os.path.join(base_path, "outputs.npz") + + os.makedirs(base_path, exist_ok=True) + + model = cct_test( + pretrained=pretrained, + img_size=img_size, + num_classes=num_classes, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + n_conv_layers=2 + ) + model.eval() + model = randomize_layernorm_params(model) + + input_data = np.random.randn(*input_shape).astype(np.float32) + np.savez(input_file, input=input_data) + + input_tensor = torch.tensor(input_data) + + torch.onnx.export( + model, + input_tensor, + onnx_file, + opset_version=opset_version, + input_names=["input"], + output_names=["output"], + dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + ) + onnx_model = onnx.load(onnx_file) + onnx_model = randomize_onnx_initializers(onnx_model) + print(f"✅ ONNX model saved to {onnx_file}") + rename_and_save_onnx(onnx_file, onnx_file) + + run_onnx_optimization_infer(onnx_file, embedding_dim, num_heads, input_shape) + rename_and_save_onnx(onnx_file, onnx_file) + ort_session = ort.InferenceSession(onnx_file) + output_data = ort_session.run(None, {"input": input_data})[0] + + np.savez(output_file, output=output_data) + print(f"✅ Output data saved to {output_file}") + +if __name__ == "__main__": + save_path = sys.argv[1] if len(sys.argv) > 1 else None + generate_cct_onnx_and_data(save_path) diff --git a/Tests/Models/CCT/testtraingenerate.py b/Tests/Models/CCT/testtraingenerate.py new file mode 100644 index 0000000..c738243 --- /dev/null +++ b/Tests/Models/CCT/testtraingenerate.py @@ -0,0 +1,160 @@ +import onnx +from onnx import helper +import torch +import os +import sys +import io +from CCT.cct import cct_test +from 
onnxruntime.training import artifacts +from utils.utils import * +from utils.fixshape import infer_shapes_with_custom_ops, print_onnx_shapes +from mnistCheckpoint import create_test_input_output +from utils.appendoptimizer import * + +def generate_cct_training_onnx(save_path=None): + """ Generate ONNX training model for CCT based on config, with optional save path """ + + pretrained, img_size, num_classes, embedding_dim, num_heads, num_layers, batch_size, opset_version = load_config() + + input_shape = (batch_size, 3, img_size, img_size) + + folder_name = f"CCT_train_{img_size}_{embedding_dim}_{num_heads}_{num_layers}" + + + base_path = save_path if save_path else os.path.join(os.path.dirname(os.path.abspath(__file__)), "onnx", folder_name) + os.makedirs(base_path, exist_ok=True) # Ensure directory exists + + onnx_infer_file = os.path.join(base_path, "network_infer.onnx") + onnx_train_file = os.path.join(base_path, "network_train.onnx") + onnx_output_file = os.path.join(base_path, "network.onnx") + onnx_train_optim = os.path.join(base_path, "network_train_optim.onnx") + + # Create CCT model and randomize layer norm parameters + model = cct_test( + pretrained=pretrained, + img_size=img_size, + num_classes=num_classes, + embedding_dim=embedding_dim, + num_heads=num_heads, + num_layers=num_layers, + n_conv_layers = 2 + ) + model.train() + model = randomize_layernorm_params(model) + + # Generate random input data for export + input_tensor = torch.randn(*input_shape, dtype=torch.float32) + + # Export model to ONNX in training mode + f = io.BytesIO() + torch.onnx.export( + model, + input_tensor, + f, + input_names=["input"], + output_names=["output"], + opset_version=opset_version, + do_constant_folding=False, # Ensure parameters are not folded into constants + # training=torch.onnx.TrainingMode.TRAINING, + # dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}}, + export_params=True, + keep_initializers_as_inputs=False, + ) + + # Load ONNX model from buffer and save it as network_infer.onnx + onnx_model = onnx.load_model_from_string(f.getvalue()) + # print("Randomizing initializers in inference model...") + onnx_model = randomize_onnx_initializers(onnx_model) + + onnx.save(onnx_model, onnx_infer_file) + print(f"✅ Inference ONNX model saved to {onnx_infer_file}") + + # Run optimization on the inference model + rename_and_save_onnx(onnx_infer_file, onnx_infer_file) + run_onnx_optimization(onnx_infer_file, embedding_dim, num_heads, input_shape) + print_onnx_shapes(onnx_infer_file) + onnx_model = onnx.load(onnx_infer_file) + + # Get all parameter names and require_grad names + all_param_names = [init.name for init in onnx_model.graph.initializer] + print(f" All Parameters: {all_param_names}") + + # requires_grad = [name for name in all_param_names if name in [ + # 'classifier_norm_bias', 'classifier_norm_weight', 'classifier_attention_pool_weight', 'classifier_attention_pool_bias', 'classifier_fc_weight', 'classifier_blocks_0_pre_norm_bias', 'classifier_fc_bias', 'node_0_classifier_attention_pool_Transpose__0' + # ]] + # requires_grad = [ name for name in all_param_names if "const" not in name] + + # requires_grad = [name for name in all_param_names if name in [ + # 'classifier_fc_weight', 'classifier_fc_bias', 'node_0_classifier_attention_pool_Transpose__0', 'classifier_norm_weight', 'classifier_norm_bias', 'classifier_attention_pool_bias' ]] + requires_grad = [name for name in all_param_names if name in [ + 'classifier_fc_weight', 'classifier_fc_bias' ]] + # requires_grad = [name 
for name in all_param_names if name in [ + # 'classifier_fc_weight', 'classifier_fc_bias']] + + # requires_grad = [name for name in all_param_names if name in [ + # 'node_0_classifier_blocks_0_linear1_Transpose__0', 'classifier_blocks_0_linear1_bias', 'node_0_classifier_blocks_0_linear2_Transpose__0' + # ]] + # requires_grad = [name for name in all_param_names if name in [ + # 'node_0_classifier_blocks_0_self_attn_q_proj_Transpose__0', 'node_0_classifier_blocks_0_self_attn_k_proj_Transpose__0', 'node_0_classifier_blocks_0_self_attn_v_proj_Transpose__0', + # 'node_0_classifier_blocks_0_self_attn_proj_Transpose__0', 'classifier_blocks_0_self_attn_proj_bias', 'classifier_blocks_0_pre_norm_weight', 'classifier_blocks_0_pre_norm_bias', 'classifier_positional_emb' + # ]] + + frozen_params = [name for name in all_param_names if name not in requires_grad] + + print(f"🔹 Training Only: {requires_grad}") + print(f"🔹 Frozen Parameters: {frozen_params}") + + + # Generate artifacts for training + artifacts.generate_artifacts( + onnx_model, + optimizer=artifacts.OptimType.SGD, + loss=artifacts.LossType.CrossEntropyLoss, + requires_grad=requires_grad, + frozen_params=frozen_params, + artifact_directory=base_path, + + ) + + training_model_path = os.path.join(base_path, "training_model.onnx") + if os.path.exists(training_model_path): + os.rename(training_model_path, onnx_train_file) + print(f"✅ Final Training ONNX model saved as {onnx_train_file}") + + # load the training model + onnx_model = onnx.load(onnx_train_file) + graph = onnx_model.graph + grad_tensor_names = [ name + '_grad' for name in requires_grad ] + + + for grad_name in grad_tensor_names: + if not any(output.name == grad_name for output in graph.output): + + grad_output = helper.make_tensor_value_info(grad_name, onnx.TensorProto.FLOAT, None) + graph.output.append(grad_output) + onnx.save(onnx_model, onnx_train_optim) + onnx.save(onnx_model, onnx_train_file) + + # train file for generating golden model debug + # train_optim file for further optimization + + # Run optimization on the training model + onnx_output_file = os.path.join(base_path, "network.onnx") + run_train_onnx_optimization(onnx_train_optim, onnx_output_file) + infer_shapes_with_custom_ops(onnx_output_file, onnx_output_file) + rename_nodes(onnx_output_file, onnx_output_file) + print_onnx_shapes(onnx_output_file) + + print(f"✅ Training ONNX model saved to {onnx_output_file}") + create_test_input_output() + print(f"✅ Created test input and output data") + + learning_rate = load_train_config() + add_sgd_nodes(onnx_output_file, onnx_output_file, learning_rate=learning_rate) + infer_shapes_with_custom_ops(onnx_output_file, onnx_output_file) + type_inference(onnx_output_file, onnx_output_file) + print(f"✅ Added SGD nodes to {onnx_output_file}") + +if __name__ == "__main__": + save_path = sys.argv[1] if len(sys.argv) > 1 else None + generate_cct_training_onnx(save_path) diff --git a/Tests/Models/CCT/utils/__init__.py b/Tests/Models/CCT/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/Tests/Models/CCT/utils/appendoptimizer.py b/Tests/Models/CCT/utils/appendoptimizer.py new file mode 100644 index 0000000..e9b53dc --- /dev/null +++ b/Tests/Models/CCT/utils/appendoptimizer.py @@ -0,0 +1,182 @@ +import onnx +from onnx import helper, shape_inference +from onnx import TensorProto +import re + +def add_sgd_nodes(model_path, output_path, learning_rate=0.01): + """ + Add SGD nodes to the ONNX model and remove original gradient outputs. 
+
+    Args:
+        model_path: Path to the original ONNX model
+        output_path: Path to save the modified ONNX model
+        learning_rate: Learning rate for SGD (default: 0.01)
+    """
+    # Load the ONNX model
+    model = onnx.load(model_path)
+    graph = model.graph
+
+    # Get references to the gradient output nodes
+    node_classifier_output = None
+    classifier_fc_weight_grad = None
+
+    # Store all outputs for later filtering
+    original_outputs = []
+    grad_outputs_to_remove = []
+
+    for output in graph.output:
+        original_outputs.append(output)
+        if re.match(r"node_\d+_classifier_fc_Gemm_Grad_dC_reduced", output.name):
+            # Reduced dC output of the classifier Gemm, i.e. the bias gradient
+            node_classifier_output = output
+            grad_outputs_to_remove.append(output)
+        elif output.name == "classifier_fc_weight_grad":
+            classifier_fc_weight_grad = output
+            grad_outputs_to_remove.append(output)
+
+    if not node_classifier_output or not classifier_fc_weight_grad:
+        raise ValueError("Could not find required output nodes")
+
+    # Find the original classifier_fc_Gemm parameters among the initializers
+    classifier_fc_weight = None
+    classifier_fc_bias = None
+
+    for initializer in graph.initializer:
+        if initializer.name == "classifier_fc_weight":
+            classifier_fc_weight = initializer
+        elif initializer.name == "classifier_fc_bias":
+            classifier_fc_bias = initializer
+
+    if not classifier_fc_weight or not classifier_fc_bias:
+        raise ValueError("Could not find required initializers")
+
+    # Check output shapes and ensure they match
+    from onnx import numpy_helper
+    weight_shape = list(numpy_helper.to_array(classifier_fc_weight).shape)  # Convert to list
+    bias_shape = list(numpy_helper.to_array(classifier_fc_bias).shape)  # Convert to list
+
+    weight_grad_shape = get_output_shape(graph, "classifier_fc_weight_grad")
+    bias_grad_shape = get_output_shape(graph, node_classifier_output.name)
+
+    print(f"Weight shape: {weight_shape}, Weight grad shape: {weight_grad_shape}")
+    print(f"Bias shape: {bias_shape}, Bias grad shape: {bias_grad_shape}")
+
+    # Compare shapes after converting to the same type (list)
+    if weight_shape != weight_grad_shape:
+        print("Warning: Weight shape doesn't match gradient shape, but will continue")
+        print(f"Weight type: {type(weight_shape)}, Gradient type: {type(weight_grad_shape)}")
+
+    # Also compare bias shapes after converting to the same type
+    if bias_shape != bias_grad_shape:
+        print("Warning: Bias shape doesn't match gradient shape, but will continue")
+        print(f"Bias type: {type(bias_shape)}, Gradient type: {type(bias_grad_shape)}")
+
+    # Convert shapes back to tuples for creating tensor_value_info
+    weight_shape_tuple = tuple(weight_shape)
+    bias_shape_tuple = tuple(bias_shape)
+
+    # Get the proper tensor types for output shape inference
+    weight_type = get_value_info_type(classifier_fc_weight)
+    bias_type = get_value_info_type(classifier_fc_bias)
+
+    # Create SGD node for the weight update with learning_rate as attribute
+    sgd_weight_node = helper.make_node(
+        op_type="SGD",
+        inputs=[
+            "classifier_fc_weight",       # weights to update
+            "classifier_fc_weight_grad",  # weight gradient
+        ],
+        outputs=["classifier_fc_weight_updated"],
+        name="classifier_fc_weight_sgd",
+        domain="",
+        lr=float(learning_rate)  # Ensure learning_rate is a float and set as attribute
+    )
+
+    # Create SGD node for the bias update with learning_rate as attribute
+    sgd_bias_node = helper.make_node(
+        op_type="SGD",
+        inputs=[
+            "classifier_fc_bias",          # bias to update
+            node_classifier_output.name,   # bias gradient (reduced dC of the classifier Gemm)
+        ],
+        outputs=["classifier_fc_bias_updated"],
+        name="classifier_fc_bias_sgd",
+        domain="",
+        lr=float(learning_rate)  # Ensure 
learning_rate is a float and set as attribute + ) + + # Add the new SGD nodes to the graph + graph.node.extend([sgd_weight_node, sgd_bias_node]) + + # Create value info for the outputs with proper shapes + updated_weight_output = helper.make_tensor_value_info( + name="classifier_fc_weight_updated", + elem_type=TensorProto.FLOAT, + shape=weight_shape_tuple + ) + + updated_bias_output = helper.make_tensor_value_info( + name="classifier_fc_bias_updated", + elem_type=TensorProto.FLOAT, + shape=bias_shape_tuple + ) + + # Clear original outputs and add only what we want to keep + graph.ClearField("output") + + # Add only the SGD-updated outputs + graph.output.extend([updated_weight_output, updated_bias_output]) + + # Keep any other original outputs that weren't gradients + for output in original_outputs: + if output not in grad_outputs_to_remove: + graph.output.append(output) + + # Run shape inference to verify and update all shapes + try: + inferred_model = shape_inference.infer_shapes(model) + print("✅ Shape inference successful") + model = inferred_model + except Exception as e: + print(f"⚠️ Shape inference warning: {e}") + print("Continuing with explicit shapes...") + + # Save the modified model + onnx.save(model, output_path) + print(f"✅ Modified model saved to {output_path}") + print(f"Learning rate set to {learning_rate} as node attribute") + print("Original gradient outputs have been removed") + +def get_output_shape(graph, output_name): + """ + Get the shape of an output tensor by name. + + Args: + graph: ONNX graph + output_name: Name of the output tensor + + Returns: + List representing the tensor shape + """ + # First check among graph outputs + for output in graph.output: + if output.name == output_name: + return [dim.dim_value for dim in output.type.tensor_type.shape.dim] + + # If not found in outputs, look for value_info + for value_info in graph.value_info: + if value_info.name == output_name: + return [dim.dim_value for dim in value_info.type.tensor_type.shape.dim] + + raise ValueError(f"Cannot find shape for output {output_name}") + +def get_value_info_type(tensor): + """ + Get the data type of a tensor. 
+
+    Args:
+        tensor: ONNX tensor
+
+    Returns:
+        ONNX data type
+    """
+    return tensor.data_type
\ No newline at end of file
diff --git a/Tests/Models/CCT/utils/fixshape.py b/Tests/Models/CCT/utils/fixshape.py
new file mode 100644
index 0000000..083044b
--- /dev/null
+++ b/Tests/Models/CCT/utils/fixshape.py
@@ -0,0 +1,276 @@
+import onnx
+import numpy as np
+import argparse
+from onnx import helper, numpy_helper, shape_inference, TensorProto
+import logging
+from collections import defaultdict
+import sys
+import copy
+from typing import List, Dict, Any, Optional
+
+
+def register_custom_shape_inference():
+    """Best-effort registration of a shape calculator for SoftmaxCrossEntropyGrad.
+
+    Relies on private onnx.shape_inference helpers; if they are unavailable, the
+    resulting exception is caught by the caller and per-node inference is used instead.
+    """
+    from onnx.shape_inference import _bring_proto, _get_shape_calculator_dict
+
+    def softmax_cross_entropy_grad_shape_inference(ctx):
+        node = ctx.node
+
+        # The output gradient has the same type/shape as the log_prob input
+        log_prob_type_proto = ctx.get_input_type(2)
+        if log_prob_type_proto is None:
+            return
+
+        ctx.set_output_type(0, log_prob_type_proto)
+        print("SoftmaxCrossEntropyGrad shape inference: output shape set from log_prob input")
+
+    shape_calculator_dict = _get_shape_calculator_dict()
+    shape_calculator_dict["com.microsoft.SoftmaxCrossEntropyGrad"] = softmax_cross_entropy_grad_shape_inference
+
+
+def infer_shapes_with_custom_ops(model_path: str, output_model_path: Optional[str] = None) -> onnx.ModelProto:
+    model = onnx.load(model_path)
+
+    # Custom ops live in the com.microsoft domain (op_type alone does not carry the domain)
+    microsoft_ops = sorted({node.op_type for node in model.graph.node if node.domain == "com.microsoft"})
+    print(f"Found Microsoft custom ops: {microsoft_ops}")
+
+    try:
+        register_custom_shape_inference()
+        inferred_model = shape_inference.infer_shapes(model)
+        print("Shape inference succeeded")
+    except Exception as e:
+        print(f"Error: {str(e)}")
+        print("Falling back to per-node shape inference ...")
+
+        inferred_model = model
+        for i, node in enumerate(model.graph.node):
+            try:
+                subgraph_model = extract_subgraph(model, [node])
+                inferred_subgraph = shape_inference.infer_shapes(subgraph_model)
+                update_model_with_inferred_shapes(inferred_model, inferred_subgraph, i)
+                print(f"Node {i}: {node.op_type} succeeded")
+            except Exception as node_err:
+                print(f"Node {i}: {node.op_type} shape inference failed: {str(node_err)}")
+                if node.domain == "com.microsoft":
+                    print(f"Trying custom inference for {node.op_type}")
+                    try:
+                        apply_custom_inference(inferred_model, node)
+                        print(f"Node {i}: {node.op_type} inferred by custom inference")
+                    except Exception as custom_err:
+                        print(f"Custom inference failed: {str(custom_err)}")
+
+    if output_model_path:
+        onnx.save(inferred_model, output_model_path)
+        print(f"ONNX model with shapes saved: {output_model_path}")
+
+    return inferred_model
+
+
+def apply_custom_inference(model: onnx.ModelProto, node: onnx.NodeProto) -> None:
+    graph = model.graph
+
+    if node.domain == "com.microsoft" and node.op_type == "SoftmaxCrossEntropyGrad":
+        # The output gradient has the same shape as the log_prob input (input 2)
+        if len(node.input) >= 3 and len(node.output) >= 1:
+            log_prob_shape = get_tensor_shape(model, node.input[2])
+            if log_prob_shape:
+                set_tensor_shape(graph, node.output[0], log_prob_shape)
+                print(f"SoftmaxCrossEntropyGrad: {log_prob_shape}")
+
+    elif node.domain == "com.microsoft":
+        # Generic heuristic for other gradient ops: mirror the shape of input 1
+        if "Grad" in node.op_type:
+            if len(node.input) >= 2 and len(node.output) >= 1:
+                input_shape = get_tensor_shape(model, node.input[1])
+                if input_shape:
+                    set_tensor_shape(graph, node.output[0], input_shape)
+                    print(f"Set {node.op_type} output to the shape of input {node.input[1]}: {input_shape}")
+
+
+def extract_subgraph(model: onnx.ModelProto, nodes: List[onnx.NodeProto]) -> onnx.ModelProto:
+
+    
subgraph = onnx.ModelProto() + subgraph.CopyFrom(model) + + + del subgraph.graph.node[:] + subgraph.graph.node.extend(nodes) + + return subgraph + + +def update_model_with_inferred_shapes(model: onnx.ModelProto, inferred_subgraph: onnx.ModelProto, node_index: int) -> None: + + if node_index < len(model.graph.node): + node = model.graph.node[node_index] + inferred_node = inferred_subgraph.graph.node[0] + + + for i, output_name in enumerate(node.output): + if i < len(inferred_node.output): + update_value_info_shape(model.graph, output_name, + get_value_info_by_name(inferred_subgraph.graph, inferred_node.output[i])) + + +def get_value_info_by_name(graph: onnx.GraphProto, name: str) -> Optional[onnx.ValueInfoProto]: + + for info in graph.output: + if info.name == name: + return info + + + for info in graph.value_info: + if info.name == name: + return info + + + for info in graph.input: + if info.name == name: + return info + + return None + + +def update_value_info_shape(graph: onnx.GraphProto, name: str, value_info: Optional[onnx.ValueInfoProto]) -> None: + + if not value_info or not value_info.type.tensor_type.shape: + return + + + existing_info = get_value_info_by_name(graph, name) + if existing_info: + existing_info.type.tensor_type.shape.CopyFrom(value_info.type.tensor_type.shape) + else: + + new_info = onnx.ValueInfoProto() + new_info.name = name + new_info.type.tensor_type.shape.CopyFrom(value_info.type.tensor_type.shape) + graph.value_info.append(new_info) + + +def set_tensor_shape(graph: onnx.GraphProto, tensor_name: str, shape: List[int]) -> None: + + value_info = get_value_info_by_name(graph, tensor_name) + if not value_info: + + value_info = onnx.ValueInfoProto() + value_info.name = tensor_name + graph.value_info.append(value_info) + + + value_info.type.tensor_type.shape.Clear() + + + for dim_value in shape: + dim = value_info.type.tensor_type.shape.dim.add() + if dim_value > 0: + dim.dim_value = dim_value + else: + + dim.dim_param = "?" 
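+
+# Minimal usage sketch (assumption: run against an exported training graph such as the
+# network.onnx produced by testtraingenerate.py; file names here are illustrative only):
+#   infer_shapes_with_custom_ops("network.onnx", "network_shaped.onnx")
+#   print_onnx_shapes("network_shaped.onnx")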
+ +def get_tensor_shape(model, tensor_name): + for initializer in model.graph.initializer: + if initializer.name == tensor_name: + return tuple(initializer.dims) + + for input_tensor in model.graph.input: + if input_tensor.name == tensor_name: + return tuple(dim.dim_value for dim in input_tensor.type.tensor_type.shape.dim) + + for output_tensor in model.graph.output: + if output_tensor.name == tensor_name: + return tuple(dim.dim_value for dim in output_tensor.type.tensor_type.shape.dim) + + return None + +def print_onnx_shapes(model_path): + model= onnx.load(model_path) + graph = model.graph + shape_info = {} + for input_info in graph.input: + shape = [] + for dim in input_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + shape_info[input_info.name] = shape + + + for output_info in graph.output: + shape = [] + for dim in output_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + shape_info[output_info.name] = shape + + for value_info in graph.value_info: + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + shape_info[value_info.name] = shape + + initializers = {init.name for init in graph.initializer} + + print("\nInput:") + for input_info in graph.input: + if input_info.name in shape_info: + print(f" {input_info.name}: {shape_info[input_info.name]}") + else: + print(f" {input_info.name}: Unknown") + + + print("\nOutput:") + for output_info in graph.output: + if output_info.name in shape_info: + print(f" {output_info.name}: {shape_info[output_info.name]}") + else: + print(f" {output_info.name}: Unknown") + + + print("\nNode info:") + for i, node in enumerate(graph.node): + print(f"\nNode {i+1}: {node.name} (type: {node.op_type})") + + print(" Input:") + for j, input_name in enumerate(node.input): + if input_name in initializers: + print(f" {j+1}. {input_name}: [Initializer]") + elif input_name in shape_info: + print(f" {j+1}. {input_name}: {shape_info[input_name]}") + else: + print(f" {j+1}. {input_name}: Unknown") + + print(" Output:") + for j, output_name in enumerate(node.output): + if output_name in shape_info: + print(f" {j+1}. {output_name}: {shape_info[output_name]}") + else: + print(f" {j+1}. 
{output_name}: Unknown") + + if node.attribute: + print(" Property:") + for attr in node.attribute: + print(f" {attr.name}") \ No newline at end of file diff --git a/Tests/Models/CCT/utils/trainoptimization.py b/Tests/Models/CCT/utils/trainoptimization.py new file mode 100644 index 0000000..8c50f68 --- /dev/null +++ b/Tests/Models/CCT/utils/trainoptimization.py @@ -0,0 +1,1591 @@ +import onnx +import os +import re +import subprocess +import yaml +from onnx import helper, numpy_helper, shape_inference +import numpy as np +import copy + +def add_c_to_gemm(input_model_path, output_model_path): + model = onnx.load(input_model_path) + graph = model.graph + + for node in graph.node: + if node.op_type == 'Gemm': + + if len(node.input) == 2: + print(f"Find Gemm without C: {node.name}") + + input_a_name = node.input[0] + input_b_name = node.input[1] + + + b_shape = None + for init in graph.initializer: + if init.name == input_b_name: + b_tensor = numpy_helper.to_array(init) + b_shape = b_tensor.shape + break + + + if b_shape is None: + for vi in graph.value_info: + if vi.name == input_b_name: + b_shape = [dim.dim_value for dim in vi.type.tensor_type.shape.dim] + break + + if b_shape is None: + output_name = node.output[0] + for vi in graph.value_info + [graph.output[i] for i in range(len(graph.output))]: + if vi.name == output_name: + output_shape = [dim.dim_value for dim in vi.type.tensor_type.shape.dim] + + transB = 0 + for attr in node.attribute: + if attr.name == 'transB' and attr.i == 1: + transB = 1 + + + c_length = output_shape[-1] + b_shape = [c_length, 0] if transB == 0 else [0, c_length] + break + + if b_shape is not None: + + transB = 0 + for attr in node.attribute: + if attr.name == 'transB' and attr.i == 1: + transB = 1 + + c_shape = [b_shape[1]] if not transB else [b_shape[0]] + + c_tensor = np.zeros(c_shape, dtype=np.float32) + c_name = f"{node.name}_c_bias" + + c_initializer = numpy_helper.from_array(c_tensor, name=c_name) + graph.initializer.append(c_initializer) + + + node.input.append(c_name) + print(f"Add C: {c_name}, Shape: {c_shape}") + else: + print(f"Warning: Cannot find {node.name} shape, pass this node.") + + + onnx.save(model, output_model_path) + print(f"Saved to: {output_model_path}") + +def replace_biasgelu_with_gelu_add(input_model_path, output_model_path): + + model = onnx.load(input_model_path) + + # Collect all value_info entries by name for easy lookup + value_info_map = {} + for vi in model.graph.value_info: + value_info_map[vi.name] = vi + + # Add input and output value_info to the map + for inp in model.graph.input: + value_info_map[inp.name] = inp + + for out in model.graph.output: + value_info_map[out.name] = out + + # Create new node list and value_info list + new_nodes = [] + new_value_info = [] + biasgelu_count = 0 + + # Counter for generating unique names + unique_id = 0 + def get_unique_name(prefix): + nonlocal unique_id + name = f"{prefix}_{unique_id}" + unique_id += 1 + return name + + # Process all nodes + for node in model.graph.node: + if node.op_type == 'BiasGelu': + biasgelu_count += 1 + + # Get BiasGelu inputs and outputs + input_name = node.input[0] # X + bias_name = node.input[1] # Bias + output_name = node.output[0] # Y + + # Generate unique name prefix + prefix = node.name if node.name else f"gelu_add" + + # Step 1: First apply Add operation to add bias + add_output = get_unique_name(f"{prefix}_add_out") + add_node = helper.make_node( + 'Add', + inputs=[input_name, bias_name], + outputs=[add_output], + name=f"{prefix}_add" + ) + 
new_nodes.append(add_node) + + # Create value_info for add_output with proper type and shape + # Use the same type and shape as the input tensor if available + if input_name in value_info_map: + input_value_info = value_info_map[input_name] + add_output_value_info = helper.make_tensor_value_info( + add_output, + input_value_info.type.tensor_type.elem_type, + [d.dim_value if d.dim_value else d.dim_param for d in input_value_info.type.tensor_type.shape.dim] + ) + new_value_info.append(add_output_value_info) + value_info_map[add_output] = add_output_value_info + + # Step 2: Then apply Gelu activation function + gelu_node = helper.make_node( + 'Gelu', + inputs=[add_output], + outputs=[output_name], + name=f"{prefix}_gelu" + ) + new_nodes.append(gelu_node) + + # If we have output value_info, make sure it's preserved + # Otherwise, create it with the same shape and type as the input to Gelu + if output_name not in value_info_map and add_output in value_info_map: + output_value_info = helper.make_tensor_value_info( + output_name, + value_info_map[add_output].type.tensor_type.elem_type, + [d.dim_value if d.dim_value else d.dim_param for d in value_info_map[add_output].type.tensor_type.shape.dim] + ) + new_value_info.append(output_value_info) + value_info_map[output_name] = output_value_info + else: + # Keep other nodes unchanged + new_nodes.append(node) + + print(f"Replaced {biasgelu_count} BiasGelu nodes with Gelu+Add combinations") + + # Create new graph with all collected value_info + new_graph = helper.make_graph( + nodes=new_nodes, + name=model.graph.name, + inputs=model.graph.input, + outputs=model.graph.output, + initializer=model.graph.initializer, + value_info=list(model.graph.value_info) + new_value_info + ) + + # Build new model, preserving original model metadata + new_model = helper.make_model( + new_graph, + producer_name=model.producer_name, + producer_version=model.producer_version, + domain=model.domain, + model_version=model.model_version, + doc_string=model.doc_string + ) + + # Copy opset imports + del new_model.opset_import[:] + new_model.opset_import.extend(model.opset_import) + + # Add Microsoft domain if not present (for Gelu) + has_ms_domain = any(opset.domain == "com.microsoft" for opset in new_model.opset_import) + if not has_ms_domain: + ms_opset = helper.make_opsetid("com.microsoft", 1) + new_model.opset_import.append(ms_opset) + + # Copy IR version + new_model.ir_version = model.ir_version + + # Run shape inference to ensure all shapes are properly defined + try: + new_model = shape_inference.infer_shapes(new_model) + print("Shape inference successful") + except Exception as e: + print(f"Warning: Shape inference failed: {e}") + + # Skip validation and directly save if needed + try: + onnx.checker.check_model(new_model) + print("Model validation successful!") + except Exception as e: + print(f"Warning: Model validation failed, but still saving: {e}") + + # Save the modified model + onnx.save(new_model, output_model_path) + print(f"Saved modified model to {output_model_path}") + + return new_model + +def fix_layernorm_output(input_model_path: str, output_model_path: str) -> bool: + """ + Fix output types and shapes for all LayerNorm operators in an ONNX model. 
+ + Args: + input_model_path (str): Path to the input model file + output_model_path (str): Path to save the output model file + + Returns: + bool: True if the operation succeeded, False otherwise + """ + try: + # Load the model + model = onnx.load(input_model_path) + graph = model.graph + + # Find all LayerNorm nodes + layernorm_count = 0 + updated_count = 0 + tensor_info = {} + + # Collect tensor information + # Process input tensors + for input_tensor in graph.input: + name = input_tensor.name + shape = [dim.dim_value if dim.dim_value > 0 else None for dim in input_tensor.type.tensor_type.shape.dim] + elem_type = input_tensor.type.tensor_type.elem_type + tensor_info[name] = {"shape": shape, "elem_type": elem_type} + + # Process intermediate and output tensors + for value_info in list(graph.value_info) + list(graph.output): + name = value_info.name + shape = [dim.dim_value if dim.dim_value > 0 else None for dim in value_info.type.tensor_type.shape.dim] + elem_type = value_info.type.tensor_type.elem_type + tensor_info[name] = {"shape": shape, "elem_type": elem_type} + + # Fix each LayerNorm node + for node in graph.node: + if node.op_type == 'LayerNormalization': + layernorm_count += 1 + + if not node.input: + continue + + # Get input information + input_name = node.input[0] + if input_name not in tensor_info: + continue + + input_info = tensor_info[input_name] + input_shape = input_info["shape"] + input_elem_type = input_info["elem_type"] + + # Get axis attribute + axis = -1 + for attr in node.attribute: + if attr.name == "axis": + axis = attr.i + break + + # Process all outputs + for i, output_name in enumerate(node.output): + # Determine correct output shape and type + output_shape = None + output_elem_type = input_elem_type + + if i == 0: # Main output - same shape as input + output_shape = input_shape + else: # mean and std outputs - shape depends on normalization axis + # Handle negative axis index + if axis < 0 and input_shape and None not in input_shape: + axis = len(input_shape) + axis + + # Create shape for mean and std (remove normalization axis) + if input_shape and None not in input_shape and 0 <= axis < len(input_shape): + output_shape = list(input_shape) + output_shape.pop(axis) # Remove normalization axis + + # Find and remove existing value info + for value_info in list(graph.value_info): + if value_info.name == output_name: + graph.value_info.remove(value_info) + break + + # Create new value info + if output_shape and None not in output_shape: + new_value_info = onnx.helper.make_tensor_value_info( + output_name, + output_elem_type, + output_shape + ) + graph.value_info.append(new_value_info) + + # Update graph output if needed + for j, output in enumerate(list(graph.output)): + if output.name == output_name: + if output_shape and None not in output_shape: + new_output = onnx.helper.make_tensor_value_info( + output_name, + output_elem_type, + output_shape + ) + graph.output.remove(output) + graph.output.insert(j, new_output) + break + + # Update tensor info dictionary + tensor_info[output_name] = { + "shape": output_shape, + "elem_type": output_elem_type + } + + print(f" Output {i}: {output_name}, shape={output_shape}") # Debug info + + updated_count += 1 + + # Save the model + onnx.save(model, output_model_path) + print(f"Updated {updated_count}/{layernorm_count} LayerNorm nodes, model saved to {output_model_path}") + return True + + except Exception as e: + print(f"Error fixing LayerNorm outputs: {str(e)}") + return False + + +def 
modify_conflict_outputs(input_model_path, output_model_path): + model = onnx.load(input_model_path) + graph = model.graph + + select_nodes = [] + for node in graph.node: + if node.op_type == 'LayerNormalization' or node.op_type == 'MaxPool': + # if node.op_type == 'MaxPool': + select_nodes.append(node) + + print(f"Find {len(select_nodes)} Maxpool") + + outputs_to_remove = [] + + new_nodes = [] + + for node in graph.node: + if (node.op_type == 'LayerNormalization' or node.op_type == 'MaxPool') and len(node.output) > 1: + # if (node.op_type == 'MaxPool') and len(node.output) > 1: + outputs_to_remove.extend(node.output[1:]) + + new_node = onnx.NodeProto() + new_node.CopyFrom(node) + first_output = node.output[0] + + del new_node.output[:] + new_node.output.append(first_output) + + new_nodes.append(new_node) + else: + new_nodes.append(node) + + del graph.node[:] + graph.node.extend(new_nodes) + + new_outputs = [] + for output in graph.output: + if output.name not in outputs_to_remove: + new_outputs.append(output) + + del graph.output[:] + graph.output.extend(new_outputs) + + onnx.save(model, output_model_path) + print(f"Saved to: {output_model_path}") + +def convert_squeeze_unsqueeze_input_to_attr(input_model_path, output_model_path): + """ + Convert Squeeze and Unsqueeze nodes with axes as input to axes as attribute. + This is useful for compatibility with older ONNX versions where axes was only supported as an attribute. + + Args: + input_model_path: Path to the input ONNX model + output_model_path: Path to save the converted ONNX model + """ + model = onnx.load(input_model_path) + + modified_nodes = [] + modified_count = 0 + + initializers = {init.name: init for init in model.graph.initializer} + + for node in model.graph.node: + # Check if the node is Squeeze or Unsqueeze with more than one input + if (node.op_type in ['Squeeze', 'Unsqueeze']) and len(node.input) > 1: + modified_count += 1 + + data_input = node.input[0] + + axes_input_name = node.input[1] + + if axes_input_name in initializers: + # Get the axes values from the initializer + axes_initializer = initializers[axes_input_name] + axes_np = numpy_helper.to_array(axes_initializer) + axes_list = axes_np.tolist() + + # Make the axes a scalar if it's a single value + if isinstance(axes_list, list) and len(axes_list) == 1: + axes_list = axes_list[0] + + # Create a new node with axes as attribute instead of input + new_node = helper.make_node( + op_type=node.op_type, + inputs=[data_input], + outputs=list(node.output), + name=node.name, + axes=axes_list + ) + + # Copy other attributes if they exist + for attr in node.attribute: + if attr.name != 'axes': + new_node.attribute.append(attr) + + modified_nodes.append(new_node) + else: + # If we can't find the axes initializer, keep the original node + print(f"Warning: Cannot find '{node.name}' axes initializer. 
Keep the original node.") + modified_nodes.append(node) + else: + # Keep all other nodes as they are + modified_nodes.append(node) + + print(f"Modified {modified_count} Squeeze/Unsqueeze nodes") + + # Identify initializers that are no longer referenced + # This happens when we convert the axes from input to attribute + used_inputs = set() + for node in modified_nodes: + for input_name in node.input: + used_inputs.add(input_name) + + unused_initializers = set() + for init in model.graph.initializer: + if init.name not in used_inputs: + unused_initializers.add(init.name) + + # Create a new graph with the modified nodes and without unused initializers + new_graph = helper.make_graph( + nodes=modified_nodes, + name=model.graph.name, + inputs=model.graph.input, + outputs=model.graph.output, + initializer=[init for init in model.graph.initializer if init.name not in unused_initializers] + ) + + # Copy over value_info from the original model + for vi in model.graph.value_info: + new_graph.value_info.append(vi) + + # Create a new model with the updated graph + new_model = helper.make_model( + new_graph, + producer_name=model.producer_name, + producer_version=model.producer_version, + domain=model.domain, + model_version=model.model_version, + doc_string=model.doc_string + ) + + # Copy over IR version and opset imports + new_model.ir_version = model.ir_version + new_model.opset_import.extend(model.opset_import) + + # Save the model + onnx.save(new_model, output_model_path) + print(f"Saved to {output_model_path}") + + return new_model + +def run_optmization_remove_biasgelu(onnx_train_file, onnx_out_file): + """ + Replace BiasGelu operations with Add+Gelu while maintaining shape consistency. + + Args: + onnx_train_file: Path to input ONNX model file + onnx_out_file: Path to output ONNX model file + """ + # Load the model + model = onnx.load(onnx_train_file) + graph = model.graph + + # Create new nodes list to replace the old ones + new_nodes = [] + replaced_count = 0 + + # Process all nodes + for node in graph.node: + if node.op_type == "BiasGelu": + print(f"🔄 Replacing BiasGeluFusion: {node.name}") + replaced_count += 1 + + # Get input and output tensors + X, Bias = node.input + output = node.output[0] + + # Create intermediate tensor name + intermediate_output = f"{X}_add_bias" + + # Create Add node + add_node = helper.make_node( + "Add", + inputs=[X, Bias], + outputs=[intermediate_output], + name=f"{node.name}_Add", + ) + + # Create Gelu node + gelu_node = helper.make_node( + "Gelu", + inputs=[intermediate_output], + outputs=node.output, + name=f"{node.name}_Gelu", + ) + + # Add shape information for the intermediate tensor + # Try to find X's shape info + X_shape = None + X_type = 1 # Default to FLOAT + + # Look for X in inputs, outputs, value_info, or initializers + for info in graph.input: + if info.name == X: + X_shape = [d.dim_value if d.HasField("dim_value") else -1 for d in info.type.tensor_type.shape.dim] + X_type = info.type.tensor_type.elem_type + break + + if X_shape is None: + for info in graph.value_info: + if info.name == X: + X_shape = [d.dim_value if d.HasField("dim_value") else -1 for d in info.type.tensor_type.shape.dim] + X_type = info.type.tensor_type.elem_type + break + + # Add value_info for intermediate tensor + if X_shape: + value_info = helper.make_tensor_value_info( + intermediate_output, + X_type, + X_shape + ) + graph.value_info.append(value_info) + + # Add the new nodes to our list + new_nodes.extend([add_node, gelu_node]) + else: + # Keep other nodes unchanged + 
new_nodes.append(node) + + # Replace nodes in the graph + graph.ClearField("node") + graph.node.extend(new_nodes) + + # Create a copy of the model for selective shape inference + safe_model = copy.deepcopy(model) + + # Remove Microsoft custom operators that might cause shape inference to fail + ms_nodes = [] + for node in safe_model.graph.node: + if node.domain == "com.microsoft": + ms_nodes.append(node) + + if ms_nodes: + print(f"⚠️ Found {len(ms_nodes)} Microsoft custom operators that might affect shape inference") + + # Try to run shape inference on the modified model without MS operators + try: + # Create a temporary graph without Microsoft operators + temp_model = copy.deepcopy(safe_model) + temp_graph = temp_model.graph + + # Remove Microsoft custom operators + temp_nodes = [node for node in temp_graph.node if node.domain != "com.microsoft"] + temp_graph.ClearField("node") + temp_graph.node.extend(temp_nodes) + + # Run shape inference on this simplified model + inferred_model = shape_inference.infer_shapes(temp_model) + + # Collect inferred shapes for our new nodes + inferred_value_infos = {} + for value_info in inferred_model.graph.value_info: + inferred_value_infos[value_info.name] = value_info + + # Update the original model with any newly inferred shapes + for name, value_info in inferred_value_infos.items(): + # Skip if already exists + if any(info.name == name for info in model.graph.value_info): + continue + + model.graph.value_info.append(value_info) + + print("✅ Partial shape inference completed for non-Microsoft operators") + except Exception as e: + print(f"⚠️ Partial shape inference failed: {e}") + else: + # No Microsoft operators, try regular shape inference + try: + model = shape_inference.infer_shapes(model) + print("✅ Shape inference completed successfully") + except Exception as e: + print(f"⚠️ Shape inference failed: {e}") + + # Save the modified model + onnx.save(model, onnx_out_file) + + if replaced_count > 0: + print(f"✅ Successfully replaced {replaced_count} BiasGelu nodes with Add + GELU.") + else: + print("⚠️ No BiasGelu nodes were replaced.") + + return model + +def optimize_reshape_fusion(input_model_path: str, output_model_path: str) -> None: + """ + Optimize ONNX model by fusing consecutive Reshape operations. 
+ + Args: + input_model_path: Path to the input ONNX model + output_model_path: Path where the optimized ONNX model will be saved + """ + print(f"Loading model: {input_model_path}") + model = onnx.load(input_model_path) + + # Create mapping from node name to node + node_map = {} + for node in model.graph.node: + node_map[node.name] = node + + # Create mapping from input name to producing node + input_to_node = {} + for node in model.graph.node: + for output in node.output: + input_to_node[output] = node + + # Create mapping from output name to consuming nodes + output_to_nodes = {} + for node in model.graph.node: + for input_name in node.input: + if input_name not in output_to_nodes: + output_to_nodes[input_name] = [] + output_to_nodes[input_name].append(node) + + # Find all Reshape nodes + reshape_nodes = [node for node in model.graph.node if node.op_type == "Reshape"] + + # Track nodes to be removed by index rather than node objects + # This avoids the "unhashable type: 'NodeProto'" error + nodes_to_remove_indices = [] + + # Track value info to keep + value_info_to_keep = set(vi.name for vi in model.graph.value_info) + + # For each Reshape node, check if its input is also from a Reshape node + for reshape_node in reshape_nodes: + # Get the input of the current Reshape node + input_name = reshape_node.input[0] + + # Check if the input comes from another Reshape operation + if input_name in input_to_node and input_to_node[input_name].op_type == "Reshape": + previous_reshape = input_to_node[input_name] + + # Check if the previous Reshape is only used by the current Reshape + if input_name in output_to_nodes and len(output_to_nodes[input_name]) == 1: + print(f"Found fusible Reshape pair: {previous_reshape.name} -> {reshape_node.name}") + + # Get the shape tensors for both Reshape nodes + prev_shape_tensor_name = previous_reshape.input[1] + current_shape_tensor_name = reshape_node.input[1] + + # Modify the current Reshape node to connect directly to the input of the previous Reshape + reshape_node.input[0] = previous_reshape.input[0] + + # Mark the previous Reshape node for removal by its index + for i, node in enumerate(model.graph.node): + if (node.name == previous_reshape.name and + node.op_type == previous_reshape.op_type and + node.input == previous_reshape.input and + node.output == previous_reshape.output): + nodes_to_remove_indices.append(i) + break + + # Intermediate value info doesn't need to be kept + if input_name in value_info_to_keep: + value_info_to_keep.remove(input_name) + + # Handle custom nodes from Microsoft + # Since Microsoft nodes might have a different structure or behavior + # We need to be careful when dealing with them + custom_nodes = [node for node in model.graph.node if node.domain.startswith('com.microsoft')] + print(f"Found {len(custom_nodes)} Microsoft custom nodes. 
These will be preserved.") + + # Create a new graph excluding the nodes to be removed + new_nodes = [] + for i, node in enumerate(model.graph.node): + if i not in nodes_to_remove_indices: + new_nodes.append(node) + + # Create a new value info list, keeping only the needed value info + new_value_info = [vi for vi in model.graph.value_info if vi.name in value_info_to_keep] + + # Create a new graph + new_graph = helper.make_graph( + nodes=new_nodes, + name=model.graph.name, + inputs=model.graph.input, + outputs=model.graph.output, + initializer=model.graph.initializer, + value_info=new_value_info + ) + + # Create a new model + new_model = helper.make_model( + new_graph, + producer_name="ONNX Reshape Fusion Optimizer", + ir_version=model.ir_version, + opset_imports=model.opset_import + ) + + # Preserve custom opsets from the original model + new_model.opset_import.extend([opset for opset in model.opset_import if opset.domain.startswith('com.microsoft')]) + + # Save the optimized model + onnx.save(new_model, output_model_path) + + # Print statistics + print(f"Original model node count: {len(model.graph.node)}") + print(f"Optimized model node count: {len(new_model.graph.node)}") + print(f"Removed Reshape nodes: {len(nodes_to_remove_indices)}") + print(f"Optimized model saved to: {output_model_path}") + + +def remove_identity_reducesum(input_model_path, output_model_path): + """ + Remove Identity and removable ReduceSum nodes from the model, + ensuring correct output naming + + Args: + input_model_path (str): Input ONNX model path + output_model_path (str): Output ONNX model path + + Returns: + onnx.ModelProto: Processed model + """ + import onnx + import numpy as np + from onnx import shape_inference, helper, TensorProto + + # Load the model and infer shapes + model = onnx.load(input_model_path) + try: + model = shape_inference.infer_shapes(model) + except Exception as e: + print(f"Warning: Shape inference failed: {e}. 
Continuing without shape information.") + + graph = model.graph + + # Build node mapping + node_map = {node.name: node for node in graph.node} + + # Store tensor shapes from value_info, inputs, and outputs + shape_info = {} + for info in list(graph.value_info) + list(graph.input) + list(graph.output): + if hasattr(info.type.tensor_type.shape, 'dim'): + dims = [] + for dim in info.type.tensor_type.shape.dim: + if dim.dim_value: + dims.append(dim.dim_value) + else: + dims.append(-1) + shape_info[info.name] = dims + + # Get initializer shapes + for initializer in graph.initializer: + shape_info[initializer.name] = list(initializer.dims) + + # Store nodes to remove and replacement mapping + nodes_to_remove = [] + replacement_map = {} + reshape_nodes_to_add = [] + + # Process Identity nodes + for node in graph.node: + if node.op_type == "Identity": + input_name = node.input[0] + output_name = node.output[0] + + replacement_map[output_name] = input_name + nodes_to_remove.append(node) + + # Process ReduceSum nodes + for node in graph.node: + if node.op_type == "ReduceSum": + input_name = node.input[0] + output_name = node.output[0] + + # Check for dimension 1 reduction with keepdims=0 + keepdims = 1 # Default value + for attr in node.attribute: + if attr.name == "keepdims": + keepdims = attr.i + break + + # Get reduction axes + axes = [] + for attr in node.attribute: + if attr.name == "axes": + axes = list(attr.ints) + break + + # If opset >= 13, axes might be an input + if len(node.input) > 1 and not axes: + axes_name = node.input[1] + for initializer in graph.initializer: + if initializer.name == axes_name: + axes = onnx.numpy_helper.to_array(initializer).tolist() + if not isinstance(axes, list): + axes = [axes] + break + + # Get input shape + if input_name in shape_info: + input_shape = shape_info[input_name] + + # Check if all reduction axes have dimension 1 + all_dim_one = True + for axis in axes: + # Handle negative axis + if axis < 0: + axis = len(input_shape) + axis + + if 0 <= axis < len(input_shape) and input_shape[axis] == 1: + continue + else: + all_dim_one = False + break + + if all_dim_one and axes: + if keepdims == 1: + # Simple replacement case + replacement_map[output_name] = input_name + nodes_to_remove.append(node) + elif keepdims == 0: + # Need to add a Reshape node + # Calculate output shape by removing dimensions with size 1 + output_shape = [] + for i, dim in enumerate(input_shape): + if i not in axes and (i + len(input_shape) not in axes): + output_shape.append(dim) + + # Create shape tensor for Reshape + shape_tensor_name = f"{node.name}_shape" + shape_tensor = helper.make_tensor( + name=shape_tensor_name, + data_type=TensorProto.INT64, + dims=[len(output_shape)], + vals=output_shape + ) + + # Create Reshape node + reshape_node = helper.make_node( + "Reshape", + inputs=[input_name, shape_tensor_name], + outputs=[output_name], + name=f"{node.name}_reshape" + ) + + # Store for later addition + reshape_nodes_to_add.append((reshape_node, shape_tensor)) + nodes_to_remove.append(node) + + # Update inputs of other nodes + for node in graph.node: + if node not in nodes_to_remove: + for i, input_name in enumerate(node.input): + if input_name in replacement_map: + node.input[i] = replacement_map[input_name] + + # Update graph outputs + for output in graph.output: + if output.name in replacement_map: + output.name = replacement_map[output.name] + + # Remove nodes and add new reshape nodes + new_nodes = [node for node in graph.node if node not in nodes_to_remove] + + # Add shape 
tensors to initializers + for _, shape_tensor in reshape_nodes_to_add: + graph.initializer.append(shape_tensor) + + # Add reshape nodes + for reshape_node, _ in reshape_nodes_to_add: + new_nodes.append(reshape_node) + + # Clear and re-add nodes + graph.ClearField("node") + graph.node.extend(new_nodes) + + # Save model + onnx.save(model, output_model_path) + + print(f"Saved to {output_model_path}") + print(f"Removed {len(nodes_to_remove)} nodes") + print(f"Added {len(reshape_nodes_to_add)} Reshape nodes") + + return model + + +def convert_reducesum_axes_to_attr(input_file: str, output_file: str): + model = onnx.load(input_file) + graph = model.graph + + new_nodes = [] + + initializers = {init.name: init for init in graph.initializer} + + for node in graph.node: + if node.op_type == "ReduceSum": + if len(node.input) >= 2: + data_input = node.input[0] + axes_input = node.input[1] + + if axes_input in initializers: + axes_tensor = initializers[axes_input] + axes_np = numpy_helper.to_array(axes_tensor) + axes_list = axes_np.tolist() + + new_node = helper.make_node( + op_type="ReduceSum", + inputs=[data_input], + outputs=node.output, + name=node.name, + axes=axes_list + ) + + for attr in node.attribute: + if attr.name != "axes": + new_node.attribute.append(attr) + + new_nodes.append(new_node) + else: + new_nodes.append(node) + else: + new_nodes.append(node) + else: + new_nodes.append(node) + + new_graph = helper.make_graph( + nodes=new_nodes, + name=graph.name, + inputs=graph.input, + outputs=graph.output, + initializer=graph.initializer, + value_info=graph.value_info + ) + + new_model = helper.make_model( + new_graph, + producer_name="ReduceSumAxesConverter", + ir_version=model.ir_version, + opset_imports=model.opset_import + ) + + new_model.metadata_props.extend(model.metadata_props) + + for domain in model.domain: + new_model.domain.append(domain) + + onnx.save(new_model, output_file) + print(f"Model converted and saved to: {output_file}") + + +def convert_fusedmatmul_to_gemm(input_model_path, output_model_path): + """ + Convert Microsoft's FusedMatMul nodes to standard Gemm nodes in an ONNX model. + This function handles custom ops and adds a zero tensor for the C input of Gemm when needed. + The three inputs to Gemm will be named A, B, and C in the function implementation. 
+ + Args: + input_model_path: Path to the input ONNX model + output_model_path: Path to save the converted ONNX model + """ + # Load the model + model = onnx.load(input_model_path) + + # Track necessary changes + new_nodes = [] + new_initializers = [] + + # Process each node in the graph + for node in model.graph.node: + # Check if the node is a FusedMatMul from Microsoft domain + if node.op_type == "FusedMatMul" and node.domain == "com.microsoft": + # Extract attributes from FusedMatMul + alpha = 1.0 + transA = 0 + transB = 0 + + for attr in node.attribute: + if attr.name == "alpha": + alpha = attr.f + elif attr.name == "transA": + transA = attr.i + elif attr.name == "transB": + transB = attr.i + + # Get inputs and output of FusedMatMul + # In our implementation, we'll call these A and B + A = node.input[0] + B = node.input[1] + output = node.output[0] + + # Create a name for the zero tensor (C input for Gemm) + C = f"{output}_zero_bias" + + # To determine the shape of C, we need to find the output shape + # For this, we need to analyze the graph and infer shapes + + a_shape = None + b_shape = None + + # Try to find shapes from value_info or initializers + for vi in model.graph.value_info: + if vi.name == A: + a_shape = [dim.dim_value for dim in vi.type.tensor_type.shape.dim] + elif vi.name == B: + b_shape = [dim.dim_value for dim in vi.type.tensor_type.shape.dim] + + # Check initializers if shapes not found in value_info + if a_shape is None or b_shape is None: + for init in model.graph.initializer: + if init.name == A: + a_shape = list(init.dims) + elif init.name == B: + b_shape = list(init.dims) + + # If we couldn't determine exact shapes, use a placeholder approach + if a_shape and b_shape: + # Calculate output shape based on MatMul rules and transA/transB + if transA: + a_shape = a_shape[::-1] + if transB: + b_shape = b_shape[::-1] + + # For matmul: [M,K] * [K,N] = [M,N] + # The bias/C needs to be shape [N] + c_shape = [b_shape[-1]] + else: + # If we can't determine shapes, we'll add a placeholder initializer + c_shape = [1] # Placeholder + + # Create a zero tensor for C input + zero_tensor = numpy_helper.from_array( + np.zeros(c_shape, dtype=np.float32), + name=C + ) + new_initializers.append(zero_tensor) + + # Create the Gemm node + gemm_node = helper.make_node( + "Gemm", + inputs=[A, B, C], # Using A, B, C naming convention + outputs=[output], + name=f"{node.name}_gemm", + alpha=alpha, + beta=1.0, # Standard beta value + transA=transA, + transB=transB + ) + + new_nodes.append(gemm_node) + else: + # Keep other nodes as they are + new_nodes.append(node) + + # Create a new graph with updated nodes and initializers + new_graph = helper.make_graph( + nodes=new_nodes, + name=model.graph.name, + inputs=model.graph.input, + outputs=model.graph.output, + initializer=list(model.graph.initializer) + new_initializers, + value_info=model.graph.value_info + ) + + # Create a new model with the updated graph + # Preserve opset imports and other model metadata + new_model = helper.make_model( + new_graph, + producer_name="FusedMatMul2Gemm", + opset_imports=model.opset_import, + ir_version=model.ir_version + ) + + # Copy domain information for custom ops + for domain in model.domain: + new_model.domain.append(domain) + + # Copy model metadata + new_model.metadata_props.extend(model.metadata_props) + + # Save the new model + onnx.save(new_model, output_model_path) + print(f"Converted model saved to {output_model_path}") + + return new_model + +def convert_sum_to_add(input_model_path, 
output_model_path): + """ + Convert Sum operators to Add operators in an ONNX model. + Sum operator can take multiple inputs, while Add takes exactly two inputs. + This function breaks down Sum operators with >2 inputs into a series of Add operators. + + Args: + input_model_path: Path to the input ONNX model + output_model_path: Path to save the converted ONNX model + """ + # Load the model + model = onnx.load(input_model_path) + + # Track necessary changes + new_nodes = [] + processed_nodes = set() + + # Process each node in the graph + for i, node in enumerate(model.graph.node): + # Skip already processed nodes + if i in processed_nodes: + continue + + # Check if the node is a Sum operator + if node.op_type == "Sum": + input_count = len(node.input) + + if input_count == 1: + # Sum with one input is just an Identity + identity_node = helper.make_node( + "Identity", + inputs=[node.input[0]], + outputs=node.output, + name=f"{node.name}_identity" + ) + new_nodes.append(identity_node) + + elif input_count == 2: + # Sum with two inputs can be directly converted to Add + add_node = helper.make_node( + "Add", + inputs=[node.input[0], node.input[1]], + outputs=node.output, + name=f"{node.name}_add" + ) + new_nodes.append(add_node) + + else: + # Sum with more than two inputs needs to be broken down into a series of Add operations + # We'll create intermediate outputs for all but the last Add + intermediate_outputs = [] + + for j in range(input_count - 1): + if j == 0: + # First Add takes the first two inputs of Sum + input1 = node.input[0] + input2 = node.input[1] + else: + # Subsequent Adds take the output of the previous Add and the next input + input1 = intermediate_outputs[-1] + input2 = node.input[j + 1] + + # For the last Add, use the original output, otherwise create an intermediate output + if j == input_count - 2: + output = node.output[0] + else: + output = f"{node.name}_intermediate_{j}" + intermediate_outputs.append(output) + + # Create the Add node + add_node = helper.make_node( + "Add", + inputs=[input1, input2], + outputs=[output], + name=f"{node.name}_add_{j}" + ) + new_nodes.append(add_node) + + # Mark this node as processed + processed_nodes.add(i) + else: + # Keep other nodes as they are + new_nodes.append(node) + + # Create a new graph with updated nodes + new_graph = helper.make_graph( + nodes=new_nodes, + name=model.graph.name, + inputs=model.graph.input, + outputs=model.graph.output, + initializer=model.graph.initializer, + value_info=model.graph.value_info + ) + + # Create a new model with the updated graph + # Preserve opset imports and other model metadata + new_model = helper.make_model( + new_graph, + producer_name="SumToAddConverter", + opset_imports=model.opset_import, + ir_version=model.ir_version + ) + + # Copy domain information for custom ops + for domain in model.domain: + new_model.domain.append(domain) + + # Copy model metadata + new_model.metadata_props.extend(model.metadata_props) + + # Save the new model + onnx.save(new_model, output_model_path) + print(f"Converted model saved to {output_model_path}") + + return new_model + +def rename_softmaxgrad_op(input_model_path: str, output_model_path: str, + old_op_name: str = "SoftmaxGrad_13", + new_op_name: str = "SoftmaxGrad"): + """ + Rename Microsoft's custom operator SoftmaxGrad_13 to SoftmaxGrad. 
+
+    Args:
+        input_model_path: Path to the input ONNX model
+        output_model_path: Path to save the converted ONNX model
+        old_op_name: Original operator name (default: "SoftmaxGrad_13")
+        new_op_name: New operator name (default: "SoftmaxGrad")
+    """
+    model = onnx.load(input_model_path)
+
+    modified_nodes = []
+    modified_count = 0
+
+    # Process each node in the graph
+    for node in model.graph.node:
+        # Check if the node is the target Microsoft domain operator
+        if node.op_type == old_op_name and node.domain == "com.microsoft":
+            modified_count += 1
+
+            # Create a new node with the updated op_type
+            new_node = helper.make_node(
+                op_type=new_op_name,
+                inputs=list(node.input),
+                outputs=list(node.output),
+                name=node.name,
+                domain=node.domain  # Keep the original domain
+            )
+
+            # Copy all attributes from the original node
+            for attr in node.attribute:
+                new_node.attribute.append(attr)
+
+            modified_nodes.append(new_node)
+        else:
+            # Keep all other nodes as they are
+            modified_nodes.append(node)
+
+    print(f"Modified {modified_count} {old_op_name} nodes to {new_op_name}")
+
+    # Create a new graph with the modified nodes
+    new_graph = helper.make_graph(
+        nodes=modified_nodes,
+        name=model.graph.name,
+        inputs=model.graph.input,
+        outputs=model.graph.output,
+        initializer=model.graph.initializer
+    )
+
+    # Copy over value_info from the original model
+    for vi in model.graph.value_info:
+        new_graph.value_info.append(vi)
+
+    # Create a new model with the updated graph
+    new_model = helper.make_model(
+        new_graph,
+        producer_name=model.producer_name,
+        producer_version=model.producer_version,
+        domain=model.domain,
+        model_version=model.model_version,
+        doc_string=model.doc_string
+    )
+
+    # Copy over the IR version and opset imports, replacing the default opset
+    # entry added by helper.make_model so we do not end up with duplicate entries
+    new_model.ir_version = model.ir_version
+    del new_model.opset_import[:]
+    new_model.opset_import.extend(model.opset_import)
+
+    # Save the model
+    onnx.save(new_model, output_model_path)
+    print(f"Saved to {output_model_path}")
+
+    return new_model
+
+def remove_softmax_loss_outputs(input_model_path, output_model_path):
+    """
+    Remove loss outputs from SoftmaxCrossEntropyLoss nodes, keeping only the log probability output.
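+    The first node output (the loss) is dropped from both the node and the graph
+    outputs; only the second output (the log probabilities) is kept. A typical
+    call, with placeholder paths:
+
+        remove_softmax_loss_outputs("train_graph.onnx", "train_graph_noloss.onnx")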
+ + Args: + input_model_path (str): Path to the input ONNX model + output_model_path (str): Path to save the modified ONNX model + """ + import onnx + + # Load the model + model = onnx.load(input_model_path) + graph = model.graph + + # Find SoftmaxCrossEntropyLoss nodes + target_nodes = [] + for node in graph.node: + if node.op_type == 'SoftmaxCrossEntropyLoss': + target_nodes.append(node) + + print(f"Found {len(target_nodes)} SoftmaxCrossEntropyLoss nodes") + + # Outputs to remove (first output - loss) + outputs_to_remove = [] + + # Create new nodes with modified outputs + new_nodes = [] + for node in graph.node: + if node.op_type == 'SoftmaxCrossEntropyLoss' and len(node.output) > 1: + # Keep only the second output (log probabilities) and remove the first (loss) + outputs_to_remove.append(node.output[0]) + + # Create a new node with only the second output + new_node = onnx.NodeProto() + new_node.CopyFrom(node) + log_prob_output = node.output[1] + + # Clear outputs and set only the log probability output + del new_node.output[:] + new_node.output.append(log_prob_output) + + new_nodes.append(new_node) + else: + # Keep other nodes unchanged + new_nodes.append(node) + + # Replace all nodes with the new set + del graph.node[:] + graph.node.extend(new_nodes) + + # Filter graph outputs to remove loss outputs + new_outputs = [] + for output in graph.output: + if output.name not in outputs_to_remove: + new_outputs.append(output) + + # Replace graph outputs with filtered list + del graph.output[:] + graph.output.extend(new_outputs) + + # Save the modified model + onnx.save(model, output_model_path) + print(f"Saved model with loss outputs removed to: {output_model_path}") + +def remove_softmax_grad_loss_inputs(input_model_path, output_model_path): + + # Load the model + model = onnx.load(input_model_path) + graph = model.graph + + # Find SoftmaxCrossEntropyLossGrad nodes + target_nodes = [] + for node in graph.node: + if node.op_type == 'SoftmaxCrossEntropyLossGrad': + target_nodes.append(node) + + print(f"Found {len(target_nodes)} SoftmaxCrossEntropyLossGrad nodes") + + # Inputs to remove (first input) + inputs_to_remove = [] + + # Create new nodes with modified inputs + new_nodes = [] + for node in graph.node: + if node.op_type == 'SoftmaxCrossEntropyLossGrad' and len(node.input) > 2: + # Remove the first input and keep the rest + first_input = node.input[0] + inputs_to_remove.append(first_input) + + # Create a new node without the first input + new_node = onnx.NodeProto() + new_node.CopyFrom(node) + + # Keep only the second and third inputs + remaining_inputs = list(node.input[1:]) + del new_node.input[:] + new_node.input.extend(remaining_inputs) + + new_nodes.append(new_node) + else: + # Keep other nodes unchanged + new_nodes.append(node) + + # Replace all nodes with the new set + del graph.node[:] + graph.node.extend(new_nodes) + + # Save the modified model + onnx.save(model, output_model_path) + print(f"Saved model with SoftmaxCrossEntropyLossGrad first input removed to: {output_model_path}") + +def optimize_softmax_axis(input_model_path, output_model_path): + + model = onnx.load(input_model_path) + + # Track if we made any changes + optimized = False + + # Create a map of value_info by name for easy access + value_info_map = {vi.name: vi for vi in model.graph.value_info} + value_info_map.update({vi.name: vi for vi in model.graph.input}) + value_info_map.update({vi.name: vi for vi in model.graph.output}) + + # Function to get shape from value_info + def get_shape(tensor_name): + if 
tensor_name in value_info_map: + shape = [] + for dim in value_info_map[tensor_name].type.tensor_type.shape.dim: + if dim.dim_param: + # Handle symbolic dimensions (set to -1 for dynamic dimension) + shape.append(-1) + else: + shape.append(dim.dim_value) + return shape + return None + + # Track the names of nodes to be removed + nodes_to_remove = [] + + # Track new nodes and value_infos to be added + new_nodes = [] + new_value_infos = [] + + # For each node in the graph + for i, node in enumerate(model.graph.node): + if node.op_type == "Softmax": + # Get the input and output names + input_name = node.input[0] + output_name = node.output[0] + + # Get the axis attribute + axis = None + for attr in node.attribute: + if attr.name == "axis": + axis = attr.i + break + + # If axis is not set, it defaults to 1 in ONNX + if axis is None: + axis = 1 + + # Get the input shape + input_shape = get_shape(input_name) + if input_shape is None: + print(f"Warning: Could not determine shape for {input_name}, skipping optimization") + continue + + # Check if all dimensions after axis are 1 + all_ones_after_axis = all(dim == 1 for dim in input_shape[axis+1:]) if axis+1 < len(input_shape) else True + + # Only optimize if the axis is not the last dimension and all subsequent dimensions are 1 + if axis != len(input_shape) - 1 and all_ones_after_axis and axis >= 0: + print(f"Optimizing Softmax node with input shape {input_shape} and axis={axis}") + + # Create unique names for intermediate tensors + reshape_before_output = f"{input_name}_reshaped_before_softmax" + softmax_output = f"{output_name}_after_softmax" + + # Calculate new shapes + # Move the axis dimension to the end and flatten all the 1s + new_shape_before = [] + for i in range(len(input_shape)): + if i < axis: + new_shape_before.append(input_shape[i]) + elif i == axis: + continue + elif i > axis: + continue + new_shape_before.append(input_shape[axis]) + + # Create reshape node before softmax + reshape_before_node = helper.make_node( + "Reshape", + inputs=[input_name, f"{input_name}_shape_before"], + outputs=[reshape_before_output], + name=f"Reshape_before_softmax_{output_name}" + ) + + # Create initializer for the shape tensor + shape_tensor_before = numpy_helper.from_array( + np.array(new_shape_before, dtype=np.int64), + name=f"{input_name}_shape_before" + ) + + # Create new softmax node with axis set to -1 (last dimension) + new_softmax_node = helper.make_node( + "Softmax", + inputs=[reshape_before_output], + outputs=[softmax_output], + name=f"Softmax_optimized_{output_name}", + axis=-1 # Use -1 to always target the last dimension + ) + + # Create reshape node after softmax to restore original shape + reshape_after_node = helper.make_node( + "Reshape", + inputs=[softmax_output, f"{output_name}_shape_after"], + outputs=[output_name], + name=f"Reshape_after_softmax_{output_name}" + ) + + # Create initializer for the shape tensor + shape_tensor_after = numpy_helper.from_array( + np.array(input_shape, dtype=np.int64), + name=f"{output_name}_shape_after" + ) + + # Create value info for reshape_before_output + reshape_before_vi = helper.make_tensor_value_info( + reshape_before_output, + value_info_map[input_name].type.tensor_type.elem_type, + new_shape_before + ) + + # Create value info for softmax_output + softmax_output_vi = helper.make_tensor_value_info( + softmax_output, + value_info_map[output_name].type.tensor_type.elem_type, + new_shape_before # Shape doesn't change after softmax + ) + + # Add all new nodes and value infos + 
new_nodes.extend([reshape_before_node, new_softmax_node, reshape_after_node]) + new_value_infos.extend([reshape_before_vi, softmax_output_vi]) + model.graph.initializer.extend([shape_tensor_before, shape_tensor_after]) + + # Mark the original node for removal + nodes_to_remove.append(node) + optimized = True + + # Remove the original nodes that were optimized + for node in nodes_to_remove: + model.graph.node.remove(node) + + # Add the new nodes and value infos + model.graph.node.extend(new_nodes) + model.graph.value_info.extend(new_value_infos) + + # Save the optimized model + print(f"Saving optimized model to {output_model_path}") + onnx.save(model, output_model_path) + + print(f"Optimization complete. Modified {len(nodes_to_remove)} Softmax nodes.") + return optimized + diff --git a/Tests/Models/CCT/utils/utils.py b/Tests/Models/CCT/utils/utils.py new file mode 100644 index 0000000..ac75d20 --- /dev/null +++ b/Tests/Models/CCT/utils/utils.py @@ -0,0 +1,450 @@ +import onnx +import os +import re +import subprocess +import yaml +import torch +from onnx import helper, numpy_helper, shape_inference +import numpy as np +import onnxruntime.tools +from onnxruntime.tools import symbolic_shape_infer +import copy +from .fixshape import print_onnx_shapes +from .trainoptimization import * +import random +from onnx import TensorProto + +def make_c_name(name, count=0): + if name.lower() in ["input", "output"]: + return name # Keep 'input' and 'output' as is + + name = re.sub(r'input|output', '', name, flags=re.IGNORECASE) # Remove 'input' and 'output' from other names + name = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if name is None or name == "": + name = f'node_{count}' + if name[0].isdigit() or name[0] == '_': + name = f'node_{count}' + name + return name + +def rename_onnx_nodes(model): + i_node = 0 + for node in model.graph.node: + i_node += 1 + node.name = make_c_name(node.name, i_node) + for i, input_name in enumerate(node.input): + node.input[i] = make_c_name(input_name) + for i, output_name in enumerate(node.output): + node.output[i] = make_c_name(output_name) + + for input in model.graph.input: + input.name = make_c_name(input.name) + for output in model.graph.output: + output.name = make_c_name(output.name) + + for init in model.graph.initializer: + init.name = make_c_name(init.name) + + return model + +def rename_and_save_onnx(input_onnx, output_onnx): + model = onnx.load(input_onnx) + model = rename_onnx_nodes(model) + onnx.save(model, output_onnx) + print(f"✅ Renamed ONNX model saved to {output_onnx}") + +def run_onnx_optimization_infer(onnx_file, embedding_dim, num_heads, input_shape): + + batch_size, channels, height, width = input_shape # Extract input dimensions + try: + print("🔹 Fixing dynamic shape...") + subprocess.run([ + "python", "-m", "onnxruntime.tools.make_dynamic_shape_fixed", + "--input_name", "input", + "--input_shape", f"{batch_size},{channels},{height},{width}", + onnx_file, onnx_file + ], check=True) + + print("🔹 Running symbolic shape inference...") + subprocess.run([ + "python", "-m", "onnxruntime.tools.symbolic_shape_infer", + "--input", onnx_file, "--output", onnx_file, "--verbose", "3" + ], check=True) + + print("🔹 Optimizing ONNX model for ViT...") + subprocess.run([ + "python", "-m", "onnxruntime.transformers.optimizer", + "--input", onnx_file, "--output", onnx_file, + "--model_type", "vit", + "--num_heads", str(num_heads), # Controlled via config + "--hidden_size", str(embedding_dim), # Ensures hidden size = embedding_dim + "--use_multi_head_attention", + 
"--disable_bias_skip_layer_norm", + "--disable_skip_layer_norm", + "--disable_bias_gelu" + ], check=True) + + print("✅ ONNX model optimization complete!") + + except subprocess.CalledProcessError as e: + print(f"❌ Error during ONNX optimization: {e}") + +def run_onnx_optimization(onnx_file, embedding_dim, num_heads, input_shape): + """ Run ONNX Runtime tools to optimize the model """ + + batch_size, channels, height, width = input_shape # Extract input dimensions + + try: + print("🔹 Fixing dynamic shape...") + subprocess.run([ + "python", "-m", "onnxruntime.tools.make_dynamic_shape_fixed", + "--input_name", "input", + "--input_shape", f"{batch_size},{channels},{height},{width}", + onnx_file, onnx_file + ], check=True) + + print("🔹 Running symbolic shape inference...") + subprocess.run([ + "python", "-m", "onnxruntime.tools.symbolic_shape_infer", + "--input", onnx_file, "--output", onnx_file, "--verbose", "3" + ], check=True) + + print("🔹 Optimizing ONNX model for ViT...") + subprocess.run([ + "python", "-m", "onnxruntime.transformers.optimizer", + "--input", onnx_file, "--output", onnx_file, + "--model_type", "vit", + "--num_heads", str(num_heads), # Controlled via config + "--hidden_size", str(embedding_dim), # Ensures hidden size = embedding_dim + "--use_multi_head_attention", + "--disable_bias_skip_layer_norm", + "--disable_skip_layer_norm", + "--disable_bias_gelu", + "--disable_layer_norm", # compatible with opset 15 + ], check=True) + + print("✅ ONNX model optimization complete!") + + except subprocess.CalledProcessError as e: + print(f"❌ Error during ONNX optimization: {e}") + +def load_config(config_filename="../config.yaml"): + """Load and parse config.yaml, returning CCT-specific parameters in a single return statement.""" + # Resolve config.yaml relative to the script's location + script_dir = os.path.dirname(os.path.abspath(__file__)) + config_file = os.path.join(script_dir, config_filename) + + with open(config_file, "r") as f: + config = yaml.safe_load(f).get("cct", {}) + + return ( + config["pretrained"], + config["img_size"], + config["num_classes"], + config["embedding_dim"], + config["num_heads"], + config["num_layers"], + config["batch_size"], + config.get("opset_version", 12) # Default value for opset_version + ) + +def load_train_config(config_filename="../config.yaml"): + """Load and parse config.yaml, returning CCT-specific parameters in a single return statement.""" + # Resolve config.yaml relative to the script's location + script_dir = os.path.dirname(os.path.abspath(__file__)) + config_file = os.path.join(script_dir, config_filename) + + with open(config_file, "r") as f: + config = yaml.safe_load(f).get("training", {}) + + return config.get("learning_rate", 0.01) + + +def run_train_onnx_optimization(onnx_train_file, onnx_output_file): + # remove the second output of maxpool + print(f"🔹 Running optimization for {onnx_train_file}...") + + run_optmization_remove_biasgelu(onnx_train_file, onnx_train_file) + print(f"✅ Successfully removed BiasGeluFusion. Saved as {onnx_train_file}") + + fix_layernorm_output(onnx_train_file, onnx_train_file) + print( + f"✅ Successfully fixed LayerNormalization opset version. Saved as {onnx_train_file}" + ) + optimize_softmax_axis(onnx_train_file, onnx_train_file) + print( + f"✅ Successfully optimized Softmax axis. Saved as {onnx_train_file}") + + optimize_reshape_fusion(onnx_train_file, onnx_train_file) + print( + f"✅ Successfully optimized Reshape nodes. 
Saved as {onnx_output_file}") + + modify_conflict_outputs(onnx_train_file, onnx_train_file) + print( + f"✅ Successfully removed all second outputs from Maxpool nodes. Saved as {onnx_output_file}" + ) + + convert_squeeze_unsqueeze_input_to_attr(onnx_train_file, onnx_train_file) + print( + f"✅ Successfully converted Squeeze inputs to attributes. Saved as {onnx_output_file}" + ) + + add_c_to_gemm(onnx_train_file, onnx_output_file) + print(f"✅ Successfully added C to Gemm nodes. Saved as {onnx_output_file}") + + remove_identity_reducesum(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully removed Identity and ReduceSum nodes. Saved as {onnx_output_file}" + ) + + convert_reducesum_axes_to_attr(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully converted ReduceSum axes to attributes. Saved as {onnx_output_file}" + ) + + convert_fusedmatmul_to_gemm(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully converted FusedMatMul to Gemm nodes. Saved as {onnx_output_file}" + ) + + convert_sum_to_add(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully converted Sum to Add nodes. Saved as {onnx_output_file}" + ) + + rename_softmaxgrad_op(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully renamed SoftmaxGrad nodes. Saved as {onnx_output_file}" + ) + + remove_softmax_loss_outputs(onnx_output_file, onnx_output_file) + print(f"✅ Successfully removed Softmax Loss outputs. Saved as {onnx_output_file}") + + remove_softmax_grad_loss_inputs(onnx_output_file, onnx_output_file) + print( + f"✅ Successfully removed Softmax Grad Loss inputs. Saved as {onnx_output_file}" + ) + + print_onnx_shapes(onnx_output_file) + +def rename_nodes(model_path, output_path): + """ + Rename nodes in an ONNX model by replacing all characters that are invalid + for C variable names with underscores. 
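+    Names that would start with a digit are prefixed with an underscore, and the
+    mapping from original to cleaned names is returned. A sketch of the intended
+    usage, with placeholder paths:
+
+        name_map = rename_nodes("network.onnx", "network_renamed.onnx")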
+ + Args: + model_path: Path to the input ONNX model + output_path: Path to save the renamed model + """ + # Load the model + model = onnx.load(model_path) + + # Create a map to store original to new name mappings + name_map = {} + + # Helper function to replace invalid C variable name characters with underscores + def clean_name(name): + if name is None: + return None + # Replace any character that's not alphanumeric or underscore with underscore + # Ensure name starts with a letter or underscore (C variable rule) + cleaned = re.sub(r'[^a-zA-Z0-9_]', '_', name) + if cleaned and cleaned[0].isdigit(): + cleaned = '_' + cleaned + return cleaned + + # Process graph inputs + for input in model.graph.input: + if input.name: + new_name = clean_name(input.name) + name_map[input.name] = new_name + input.name = new_name + + # Process graph outputs + for output in model.graph.output: + if output.name: + new_name = clean_name(output.name) + name_map[output.name] = new_name + output.name = new_name + + # Process initializers + for initializer in model.graph.initializer: + if initializer.name: + new_name = clean_name(initializer.name) + name_map[initializer.name] = new_name + initializer.name = new_name + + # Process nodes + for node in model.graph.node: + # Rename node name if it exists + if node.name: + node.name = clean_name(node.name) + + # Rename node inputs + for i, input_name in enumerate(node.input): + if input_name in name_map: + node.input[i] = name_map[input_name] + else: + new_name = clean_name(input_name) + if new_name != input_name: + name_map[input_name] = new_name + node.input[i] = new_name + + # Rename node outputs + for i, output_name in enumerate(node.output): + if output_name in name_map: + node.output[i] = name_map[output_name] + else: + new_name = clean_name(output_name) + if new_name != output_name: + name_map[output_name] = new_name + node.output[i] = new_name + + # Rename attribute names if they contain node names + for attribute in node.attribute: + if attribute.type == onnx.AttributeProto.GRAPH: + # Handle subgraphs if present (recursive call would be needed) + pass + elif attribute.type == onnx.AttributeProto.STRINGS: + # Handle string attributes that might contain node names + for i, s in enumerate(attribute.strings): + s_str = s.decode('utf-8') if isinstance(s, bytes) else s + if s_str in name_map: + attribute.strings[i] = name_map[s_str].encode('utf-8') + + # Check for any value_info that might need renaming + for value_info in model.graph.value_info: + if value_info.name: + new_name = clean_name(value_info.name) + name_map[value_info.name] = new_name + value_info.name = new_name + + # Save the updated model + onnx.save(model, output_path) + + return name_map + +def randomize_layernorm_params(model): + for name, module in model.named_modules(): + if isinstance(module, torch.nn.LayerNorm): + with torch.no_grad(): + module.weight.data = module.weight.data + torch.randn_like(module.weight.data) * 1e-6 + module.bias.data = module.bias.data + torch.randn_like(module.bias.data) * 1e-6 + + return model + +def randomize_onnx_initializers(model, seed=None, exclude_patterns=None): + if seed is not None: + np.random.seed(seed) + random.seed(seed) + + if exclude_patterns is None: + exclude_patterns = ["const", "shape", "Constant"] + + graph = model.graph + + modified_count = 0 + zero_count = 0 + skipped_count = 0 + + for initializer in graph.initializer: + # Skip initializers with specific patterns in their names + if any(pattern in initializer.name for pattern in exclude_patterns): 
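+            # Initializers whose names contain any exclude pattern (by default
+            # "const", "shape", "Constant"; plain substring match) are left untouched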
+ skipped_count += 1 + continue + + # Convert initializer to numpy array + np_array = numpy_helper.to_array(initializer) + + # Check if array contains only zeros + if np.all(np_array == 0): + zero_count += 1 + + # Determine appropriate scale based on tensor dimension and type + if np_array.dtype == np.float32 or np_array.dtype == np.float64: + # Use Kaiming/He initialization for weights + if len(np_array.shape) > 1: + fan_in = np_array.shape[0] + scale = np.sqrt(2.0 / fan_in) + np_array = np.random.normal(0, scale, np_array.shape).astype(np_array.dtype) + else: + # For bias terms or 1D tensors + np_array = np.random.uniform(-0.1, 0.1, np_array.shape).astype(np_array.dtype) + elif np_array.dtype == np.int64 or np_array.dtype == np.int32: + # For integer tensors (e.g., indices) + max_val = min(100, 2**(np_array.itemsize*8 - 1) - 1) # Avoid overflow + np_array = np.random.randint(-max_val, max_val, np_array.shape).astype(np_array.dtype) + + # Create new tensor from modified numpy array + new_tensor = numpy_helper.from_array(np_array, initializer.name) + + # Replace original initializer with new tensor + initializer.CopyFrom(new_tensor) + modified_count += 1 + + + print(f"Randomization complete:") + print(f"- Total initializers: {len(graph.initializer)}") + print(f"- Zero initializers found and randomized: {zero_count}") + print(f"- Skipped initializers (based on patterns): {skipped_count}") + print(f"- Modified initializers: {modified_count}") + + return model + +def type_inference(input_model_path, output_model_path): + """ + Perform type inference on ONNX model, setting float32 type for variables without explicit types. + + Args: + input_model_path: Input ONNX model path + output_model_path: Output ONNX model path + """ + # Load the ONNX model + model = onnx.load(input_model_path) + graph = model.graph + + # Process input tensors + for input_tensor in graph.input: + if not input_tensor.type.tensor_type.elem_type: + print(f"Setting input variable {input_tensor.name} type to FLOAT") + input_tensor.type.tensor_type.elem_type = TensorProto.FLOAT + + # Process output tensors + for output_tensor in graph.output: + if not output_tensor.type.tensor_type.elem_type: + print(f"Setting output variable {output_tensor.name} type to FLOAT") + output_tensor.type.tensor_type.elem_type = TensorProto.FLOAT + + # Process intermediate variables + for value_info in graph.value_info: + if not value_info.type.tensor_type.elem_type: + print(f"Setting intermediate variable {value_info.name} type to FLOAT") + value_info.type.tensor_type.elem_type = TensorProto.FLOAT + + # Check for any tensors mentioned in nodes but missing type info + processed_tensors = set([tensor.name for tensor in graph.input] + + [tensor.name for tensor in graph.output] + + [tensor.name for tensor in graph.value_info]) + + missing_tensors = set() + for node in graph.node: + for input_name in node.input: + if input_name not in processed_tensors and input_name: + missing_tensors.add(input_name) + for output_name in node.output: + if output_name not in processed_tensors and output_name: + missing_tensors.add(output_name) + + # Add missing tensors to value_info with FLOAT type (keeping existing shapes) + for tensor_name in missing_tensors: + # Create a basic ValueInfo for the tensor (shape will be inferred by ONNX) + tensor_value_info = onnx.helper.make_tensor_value_info( + name=tensor_name, + elem_type=TensorProto.FLOAT, + shape=None # Let ONNX infer the shape + ) + graph.value_info.append(tensor_value_info) + print(f"Added missing tensor 
{tensor_name} with FLOAT type") + + # Save the modified model + onnx.save(model, output_model_path) + print(f"Successfully saved type-inferred model to {output_model_path}") \ No newline at end of file diff --git a/Tests/TestCCT.py b/Tests/TestCCT.py new file mode 100644 index 0000000..bc1d0e7 --- /dev/null +++ b/Tests/TestCCT.py @@ -0,0 +1,460 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +from brevitas.fx.brevitas_tracer import symbolic_trace +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.graph.utils import replace_all_uses_except +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) + +from DeepQuant.Transforms.Executor import TransformationExecutor +from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter +from Tests.Models.CCT.CCT.cct import cct_2_3x2_32 + + +def injectCustomForwards( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, + checkEquivalence: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """Custom inject function for CCT that excludes ActivationTransformation.""" + printer = GraphModulePrinter() + + tracer = QuantTracer(debug=debug) + + transformations = [ + MHATransformation(), + LinearTransformation(), + # ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility + ] + + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) + transformedModel = executor.execute(model, exampleInput) + + fxModel = customBrevitasTrace( + root=transformedModel, + tracer=tracer, + ) + fxModel.recompile() + + with torch.no_grad(): + output = fxModel(exampleInput) + + if checkEquivalence: + if torch.allclose(referenceOutput, output, atol=1e-5): + if debug: + print(cc.success("Injection of New Modules: output is consistent")) + else: + raise RuntimeError( + cc.error("Injection of New Modules changed the output significantly") + ) + + if debug: + print(cc.header("2. Network after Injection of New Modules")) + printer.printTabular(fxModel) + print() + + return fxModel, output + + +def prepareCCT(model) -> nn.Module: + """ + Prepare a quantized CCT model for testing with export support. 
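+
+    A typical call mirrors the test below:
+
+        model = cct_2_3x2_32().eval()
+        quantizedModel = prepareCCT(model)
+        output = quantizedModel(torch.randn(1, 3, 32, 32))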
+ """ + + if not hasattr(model, "graph"): + model = symbolic_trace(model) + + print("=== FIXING QUANTIZATION ISSUES ===") + + transpose_fixes = [] + qkv_fixes = [] + matmul_fixes = [] + + # FBRANCASI: Fix 1, Find transpose -> add patterns + for node in model.graph.nodes: + if node.op == "call_method" and node.target == "transpose": + for user in node.users: + if ( + "add" in user.name + or user.target in [torch.add] + or (user.op == "call_method" and user.target in ["add", "add_"]) + ): + transpose_fixes.append((node, user)) + break + + # FBRANCASI: Fix 2, Find QKV -> reshape patterns + for node in model.graph.nodes: + if node.op == "call_module" and "qkv" in node.target: + for user in node.users: + if user.op == "call_method" and user.target == "reshape": + qkv_fixes.append((node, user)) + break + + # FBRANCASI: Fix 3, Find matmul operations that need dequantization + for node in model.graph.nodes: + if node.op == "call_function" and node.target == torch.matmul: + matmul_fixes.append(node) + elif node.op == "call_method" and node.target == "matmul": + matmul_fixes.append(node) + elif ( + node.op == "call_function" + and hasattr(node.target, "__name__") + and node.target.__name__ == "matmul" + ): + matmul_fixes.append(node) + elif hasattr(node, "name") and "matmul" in node.name: + matmul_fixes.append(node) + elif ( + node.op == "call_function" + and hasattr(node.target, "__module__") + and node.target.__module__ == "operator" + and hasattr(node.target, "__name__") + and node.target.__name__ == "matmul" + ): + matmul_fixes.append(node) + + # FBRANCASI: Apply transpose fixes + print(f"\nApplying {len(transpose_fixes)} transpose fixes...") + for node, user in transpose_fixes: + print(f" Fixing: {node.name} -> {user.name}") + + quant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, return_quant_tensor=True + ) + + quant_name = f"{node.name}_quant_fix" + model.add_module(quant_name, quant_identity) + + with model.graph.inserting_after(node): + quant_node = model.graph.call_module(quant_name, args=(node,)) + + # Replace uses + replace_all_uses_except(node, quant_node, [quant_node]) + + # FBRANCASI: Apply QKV fixes + print(f"\nApplying {len(qkv_fixes)} QKV fixes...") + for node, reshape_user in qkv_fixes: + print(f" Fixing: {node.name} -> {reshape_user.name}") + + quant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, + return_quant_tensor=False, # FBRANCASI: return regular tensor for reshape + ) + + quant_name = f"{node.name}_reshape_fix" + model.add_module(quant_name, quant_identity) + + with model.graph.inserting_after(node): + quant_node = model.graph.call_module(quant_name, args=(node,)) + + reshape_user.update_arg(0, quant_node) + + # FBRANCASI: Apply matmul fixes + print(f"\nApplying {len(matmul_fixes)} matmul fixes...") + for node in matmul_fixes: + print( + f" Fixing matmul: {node.name}, args: {[arg.name if hasattr(arg, 'name') else str(arg) for arg in node.args]}" + ) + + # FBRANCASI: Add dequantization before both inputs of matmul + for i, arg in enumerate(node.args): + if isinstance(arg, torch.fx.Node): + print(f" Processing arg {i}: {arg.name}") + dequant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, + return_quant_tensor=False, # FBRANCASI: Return regular tensor for matmul + ) + + dequant_name = f"{arg.name}_matmul_dequant_{i}" + model.add_module(dequant_name, dequant_identity) + + with model.graph.inserting_before(node): + dequant_node = model.graph.call_module(dequant_name, args=(arg,)) + + # Update the matmul argument + 
node.update_arg(i, dequant_node) + print(f" Updated arg {i} to: {dequant_node.name}") + + model.recompile() + model.graph.lint() + + print("\n=== GRAPH MODIFICATION COMPLETE ===") + + # Debug: Print graph structure to understand the flow + print("\n=== DEBUG: Graph structure after fixes ===") + for node in model.graph.nodes: + if ( + "matmul" in node.name + or (node.op == "call_method" and node.target == "transpose") + or "permute" in node.name + ): + print( + f"Node: {node.name}, op: {node.op}, target: {node.target}, args: {[arg.name if hasattr(arg, 'name') else str(arg) for arg in node.args]}" + ) + # Print users of permute and transpose nodes + if "permute" in node.name or ( + node.op == "call_method" and node.target == "transpose" + ): + print(f" Users: {[user.name for user in node.users]}") + + # FBRANCASI: First pass - identify which Linear layers feed into matmul through permute/transpose + linear_to_matmul = set() + for node in model.graph.nodes: + if hasattr(node, "name") and "matmul" in node.name: + # Trace back through the args to find Linear layers + for arg in node.args: + if isinstance(arg, torch.fx.Node): + # Check if this path leads back to a linear layer + current = arg + visited = set() + while current and current not in visited: + visited.add(current) + if current.op == "call_module" and any( + proj in current.target + for proj in ["q_proj", "k_proj", "v_proj"] + ): + linear_to_matmul.add(current.target) + break + # Trace back through the first argument + if current.args and isinstance(current.args[0], torch.fx.Node): + current = current.args[0] + else: + break + + print(f"\nLinear layers that feed into matmul: {linear_to_matmul}") + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": False, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "return_quant_tensor": True, # FBRANCASI: We'll handle this specially for q,k,v projections + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + } + + quantActMap = {} + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + model = preprocess_for_quantize( + model, + equalize_iters=10, + equalize_scale_computation="range", + trace_model=False, # FBRANCASI: Already traced + ) + + quantizedModel = quantize( + graph_model=model, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + # FBRANCASI: Apply post-quantization fixes for matmul operations + print("\n=== POST-QUANTIZATION FIXES ===") + + nodes_needing_dequant = set() + + node_map = {node.name: node for node in quantizedModel.graph.nodes} + + import operator + + for node in quantizedModel.graph.nodes: + # FBRANCASI: Look for @ operator (represented as call_function with operator.matmul) + is_matmul = False + if node.op == "call_function": + if node.target == operator.matmul: + is_matmul = True + elif node.target == torch.matmul: + is_matmul = True + elif 
hasattr(node.target, "__name__") and node.target.__name__ == "matmul": + is_matmul = True + + if is_matmul: + print(f"\nFound matmul node: {node.name}") + print(f" Target: {node.target}") + print(f" Args: {node.args}") + print(f" Arg types: {[type(arg) for arg in node.args]}") + + # FBRANCASI: Mark both arguments as needing dequantization + for i, arg in enumerate(node.args): + print(f" Checking arg {i}: type={type(arg)}") + if hasattr(arg, "name") and hasattr(arg, "op"): + nodes_needing_dequant.add(arg) + print(f" Added node to dequant: {arg.name}") + else: + print(f" Skipped arg {i}: {arg}") + + print(f"\nNodes needing dequantization: {[n.name for n in nodes_needing_dequant]}") + + # FBRANCASI: Insert dequantization for each node that feeds into matmul + dequant_nodes = {} + for node in nodes_needing_dequant: + print(f"\nAdding dequantization after node: {node.name}") + + dequant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, + return_quant_tensor=False, # FBRANCASI: Return regular tensor for matmul + ) + + dequant_name = f"{node.name}_dequant_for_matmul" + quantizedModel.add_module(dequant_name, dequant_identity) + + with quantizedModel.graph.inserting_after(node): + dequant_node = quantizedModel.graph.call_module(dequant_name, args=(node,)) + + dequant_nodes[node] = dequant_node + + for user in list(node.users): + is_matmul_user = False + if user.op == "call_function": + if user.target == operator.matmul or user.target == torch.matmul: + is_matmul_user = True + elif ( + hasattr(user.target, "__name__") + and user.target.__name__ == "matmul" + ): + is_matmul_user = True + elif ( + hasattr(user.target, "__module__") + and user.target.__module__ == "operator" + and hasattr(user.target, "__name__") + and user.target.__name__ == "matmul" + ): + is_matmul_user = True + + if is_matmul_user: + print(f" Updating matmul {user.name} to use dequantized input") + new_args = [] + for i, arg in enumerate(user.args): + if arg == node: + new_args.append(dequant_node) + print( + f" Updated arg {i} from {node.name} to {dequant_node.name}" + ) + else: + new_args.append(arg) + user.args = tuple(new_args) + + quantizedModel.recompile() + quantizedModel.graph.lint() + + print("\n=== POST-QUANTIZATION FIXES COMPLETE ===") + + return quantizedModel + + +@pytest.mark.ModelTests +def deepQuantTestCCT(): + torch.manual_seed(42) + sampleInput = torch.randn(1, 3, 32, 32) + + model = cct_2_3x2_32() # FBRANCASI: 2 encoder layers, kernel dim 3, 2 convs, 32x32 + model.eval() + + print(model) + + quantizedModel = prepareCCT(model) + + print(f"\nTesting the Quantized Model with input shape: {sampleInput.shape}") + with torch.no_grad(): + output = quantizedModel(sampleInput) + print(f"Output shape: {output.shape}") + print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]") + + # FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it + import DeepQuant.Pipeline.Injection as injection_module + + # FBRANCASI: Store original function + original_inject = injection_module.injectCustomForwards + + # FBRANCASI: Override with our custom function + injection_module.injectCustomForwards = injectCustomForwards + + # FBRANCASI: Force reload of Export module to pick up the override + import importlib + + import DeepQuant.Export + + importlib.reload(DeepQuant.Export) + + try: + from DeepQuant.Export import brevitasToTrueQuant + + quantizedModel.eval() + brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) + finally: + # FBRANCASI: Restore 
original function and reload Export module again + injection_module.injectCustomForwards = original_inject + importlib.reload(DeepQuant.Export) + importlib.reload(DeepQuant.Export) + + # FBRANCASI: Important note + # Right now ONNX is not exporting the graph with GELUs folded and some nodes dont have shapes. + # + # If you need to use this ONNX in Deeploy (https://github.com/pulp-platform/Deeploy), please run + # these commands on the generated network.onnx to fix these problems that can arise in Deeploy: + # + # > python -m onnxruntime.transformers.optimizer --input Tests/ONNX/network.onnx --output network.onnx + # --model_type vit --num_heads 6 --hidden_size 384 --use_multi_head_attention --disable_bias_gelu + # --disable_bias_skip_layer_norm --disable_skip_layer_norm --use_multi_head_attention --opt_level 0 + # + # > python -m onnxruntime.tools.symbolic_shape_infer --input network.onnx --output network.onnx + # + # Also, if you have duplicated shared Floor constants in the graph (this will create problems in + # Deeploy), you can fix this using the script FixCTT2Graph.py under the Utils folder of DeepQuant diff --git a/Tests/TestCCTPretrained.py b/Tests/TestCCTPretrained.py new file mode 100644 index 0000000..7808938 --- /dev/null +++ b/Tests/TestCCTPretrained.py @@ -0,0 +1,517 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +from typing import Tuple + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from brevitas.fx.brevitas_tracer import symbolic_trace +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.graph.utils import replace_all_uses_except +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) +from torch.utils.data import DataLoader, Subset +from tqdm import tqdm + +from DeepQuant.Transforms.Executor import TransformationExecutor +from DeepQuant.Transforms.Transformations import LinearTransformation, MHATransformation +from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc +from DeepQuant.Utils.CustomTracer import QuantTracer, customBrevitasTrace +from DeepQuant.Utils.GraphPrinter import GraphModulePrinter +from Tests.Models.CCT.CCT.cct import cct_2_3x2_32 + + +def injectCustomForwards( + model: nn.Module, + exampleInput: torch.Tensor, + referenceOutput: torch.Tensor, + debug: bool = False, + checkEquivalence: bool = False, +) -> Tuple[nn.Module, torch.Tensor]: + """Custom inject function for CCT that excludes ActivationTransformation.""" + printer = GraphModulePrinter() + + tracer = QuantTracer(debug=debug) + + transformations = [ + MHATransformation(), + LinearTransformation(), + # ActivationTransformation(), # FBRANCASI: Commented out for CCT compatibility + ] + + executor = TransformationExecutor(transformations, debug=debug, tracer=tracer) + transformedModel = executor.execute(model, exampleInput) + + fxModel = customBrevitasTrace( + root=transformedModel, + tracer=tracer, + ) + fxModel.recompile() + + with torch.no_grad(): + output = fxModel(exampleInput) + + if checkEquivalence: + if torch.allclose(referenceOutput, output, atol=1e-5): + if debug: + print(cc.success("Injection of New Modules: output is consistent")) + else: + raise RuntimeError( + 
cc.error("Injection of New Modules changed the output significantly") + ) + + if debug: + print(cc.header("2. Network after Injection of New Modules")) + printer.printTabular(fxModel) + print() + + return fxModel, output + + +def evaluateModel(model, dataLoader, evalDevice, name="Model"): + model.eval() + correctTop1 = 0 + correctTop5 = 0 + total = 0 + + with torch.no_grad(): + for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"): + isTQ = "TQ" in name + + if isTQ: + # FBRANCASI: Process different batches for the TQ model + for i in range(inputs.size(0)): + singleInput = inputs[i : i + 1].to(evalDevice) + singleOutput = model(singleInput) + + _, predicted = singleOutput.max(1) + if predicted.item() == targets[i].item(): + correctTop1 += 1 + + _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True) + if targets[i].item() in top5Pred[0].cpu().numpy(): + correctTop5 += 1 + + total += 1 + else: + inputs = inputs.to(evalDevice) + targets = targets.to(evalDevice) + output = model(inputs) + + _, predicted = output.max(1) + correctTop1 += (predicted == targets).sum().item() + + _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): + if targets[i] in top5Pred[i]: + correctTop5 += 1 + + total += targets.size(0) + + top1Accuracy = 100.0 * correctTop1 / total + top5Accuracy = 100.0 * correctTop5 / total + + print( + f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), " + f"Top-5 Accuracy: {top5Accuracy:.2f}%" + ) + + return top1Accuracy, top5Accuracy + + +def calibrateModel(model, calibLoader): + model.eval() + with torch.no_grad(), calibration_mode(model): + for inputs, _ in tqdm(calibLoader, desc="Calibrating model"): + inputs = inputs.to("cpu") + model(inputs) + print("Calibration completed.") + + +def prepareFQCCT(model) -> nn.Module: + """ + Prepare a quantized CCT model for testing with export support. 
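+
+    A typical call mirrors the pretrained test below (data loader setup omitted):
+
+        FQModel = prepareFQCCT(originalModel.to("cpu"))
+        calibrateModel(FQModel, calibLoader)
+        FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ CCT-2")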
+ """ + + if not hasattr(model, "graph"): + model = symbolic_trace(model) + + print("=== FIXING QUANTIZATION ISSUES ===") + + transpose_fixes = [] + qkv_fixes = [] + matmul_fixes = [] + + # FBRANCASI: Fix 1, Find transpose -> add patterns + for node in model.graph.nodes: + if node.op == "call_method" and node.target == "transpose": + for user in node.users: + if ( + "add" in user.name + or user.target in [torch.add] + or (user.op == "call_method" and user.target in ["add", "add_"]) + ): + transpose_fixes.append((node, user)) + break + + # FBRANCASI: Fix 2, Find QKV -> reshape patterns + for node in model.graph.nodes: + if node.op == "call_module" and "qkv" in node.target: + for user in node.users: + if user.op == "call_method" and user.target == "reshape": + qkv_fixes.append((node, user)) + break + + # FBRANCASI: Fix 3, Find matmul operations that need dequantization (Run version) + for node in model.graph.nodes: + if node.op == "call_function" and node.target == torch.matmul: + matmul_fixes.append(node) + elif node.op == "call_method" and node.target == "__matmul__": + matmul_fixes.append(node) + elif hasattr(node, "target") and str(node.target) == "matmul": + matmul_fixes.append(node) + + # FBRANCASI: Apply transpose fixes + print(f"\nApplying {len(transpose_fixes)} transpose fixes...") + for node, user in transpose_fixes: + print(f" Fixing: {node.name} -> {user.name}") + + quant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, return_quant_tensor=True + ) + + quant_name = f"{node.name}_quant_fix" + model.add_module(quant_name, quant_identity) + + with model.graph.inserting_after(node): + quant_node = model.graph.call_module(quant_name, args=(node,)) + + # Replace uses + replace_all_uses_except(node, quant_node, [quant_node]) + + # FBRANCASI: Apply QKV fixes + print(f"\nApplying {len(qkv_fixes)} QKV fixes...") + for node, reshape_user in qkv_fixes: + print(f" Fixing: {node.name} -> {reshape_user.name}") + + quant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, + return_quant_tensor=False, # FBRANCASI: return regular tensor for reshape + ) + + quant_name = f"{node.name}_reshape_fix" + model.add_module(quant_name, quant_identity) + + with model.graph.inserting_after(node): + quant_node = model.graph.call_module(quant_name, args=(node,)) + + reshape_user.update_arg(0, quant_node) + + # FBRANCASI: Note matmul fixes found for later processing + print(f"\nFound {len(matmul_fixes)} matmul operations for post-quantization fixing") + + model.recompile() + model.graph.lint() + + # FBRANCASI: Print graph structure for debugging (Run version) + print("\n=== GRAPH STRUCTURE AFTER INITIAL FIXES ===") + for node in model.graph.nodes: + if node.op != "placeholder" and node.op != "output": + print(f" {node.name}: {node.op} - {node.target}") + + print("\n=== GRAPH MODIFICATION COMPLETE ===") + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": False, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + # FBRANCASI: Linear layers ENABLED in Run version + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + } + + quantActMap = {} + + 
quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + model = preprocess_for_quantize( + model, + equalize_iters=10, + equalize_scale_computation="range", + trace_model=False, # FBRANCASI: Already traced + ) + + quantizedModel = quantize( + graph_model=model, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + # FBRANCASI: Post-quantization matmul dequantization (Run version specific) + print("\n=== POST-QUANTIZATION MATMUL FIXES ===") + + # Find all matmul operations using @ operator + matmul_nodes = [] + for node in quantizedModel.graph.nodes: + if hasattr(node, "target") and ( + (hasattr(node.target, "__name__") and node.target.__name__ == "matmul") + or str(node.target) == "" + or (node.op == "call_function" and node.target == torch.matmul) + ): + matmul_nodes.append(node) + print(f"Found matmul node: {node.name}") + + print(f"\nTotal matmul nodes found: {len(matmul_nodes)}") + + # For each matmul, trace back to find linear layers and insert dequantization + for matmul_node in matmul_nodes: + print(f"\nProcessing matmul node: {matmul_node.name}") + + # Check both arguments of matmul + for arg_idx, arg in enumerate(matmul_node.args): + if hasattr(arg, "op"): + print(f" Checking arg {arg_idx}: {arg.name}") + + # Trace back to find if this comes from a linear layer + def find_linear_source(node, visited=None): + if visited is None: + visited = set() + if node in visited: + return None + visited.add(node) + + if node.op == "call_module" and isinstance( + quantizedModel.get_submodule(node.target), qnn.QuantLinear + ): + return node + + # Check node inputs + for inp in node.all_input_nodes: + result = find_linear_source(inp, visited) + if result: + return result + return None + + linear_source = find_linear_source(arg) + + if linear_source: + print(f" Found linear source: {linear_source.name}") + + # Insert dequantization after the argument node + dequant_identity = qnn.QuantIdentity( + act_quant=Int8ActPerTensorFloat, + return_quant_tensor=False, # Return regular tensor + ) + + dequant_name = f"{arg.name}_matmul_dequant" + quantizedModel.add_module(dequant_name, dequant_identity) + + with quantizedModel.graph.inserting_after(arg): + dequant_node = quantizedModel.graph.call_module( + dequant_name, args=(arg,) + ) + + # Update matmul to use dequantized input + matmul_node.update_arg(arg_idx, dequant_node) + print(f" Inserted dequantization: {dequant_name}") + + quantizedModel.recompile() + quantizedModel.graph.lint() + + print("\n=== FINAL QUANTIZATION COMPLETE ===") + + return quantizedModel + + +@pytest.mark.ModelTests +def deepQuantTestCCT(): + torch.manual_seed(42) + + # FBRANCASI: Setup CIFAR-10 dataset + transformsVal = transforms.Compose( + [ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ] + ) + + dataset = torchvision.datasets.CIFAR10( + root="./Tests/Data/CIFAR", train=False, download=True, transform=transformsVal + ) + + DATASET_LIMIT = 256 + dataset = Subset(dataset, list(range(DATASET_LIMIT))) + print(f"Validation dataset size set to {len(dataset)} images.") + + calibLoader = DataLoader( + Subset(dataset, list(range(128))), batch_size=32, shuffle=False, pin_memory=True + ) + valLoader = DataLoader(dataset, 
batch_size=32, shuffle=False, pin_memory=True) + + # FBRANCASI: Device setup + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("mps" if torch.backends.mps.is_available() else device) + print(f"Using device: {device}") + + # FBRANCASI: Load original floating point model + originalModel = cct_2_3x2_32() + checkpointPath = "./Tests/Data/checkpoint_epoch_200_cct2_cifar10.pth" + checkpoint = torch.load(checkpointPath, map_location="cpu", weights_only=False) + + # FBRANCASI: Convert state dict from qkv to q_proj, k_proj, v_proj format + original_state_dict = checkpoint["model_state_dict"] + converted_state_dict = {} + + for key, value in original_state_dict.items(): + if "qkv.weight" in key: + # Split QKV weight into separate Q, K, V weights + dim = value.shape[0] // 3 + q_weight = value[:dim] + k_weight = value[dim : 2 * dim] + v_weight = value[2 * dim :] + + # Create new keys for separate projections + base_key = key.replace("qkv.weight", "") + converted_state_dict[base_key + "q_proj.weight"] = q_weight + converted_state_dict[base_key + "k_proj.weight"] = k_weight + converted_state_dict[base_key + "v_proj.weight"] = v_weight + else: + # Keep all other weights as is + converted_state_dict[key] = value + + originalModel.load_state_dict(converted_state_dict) + originalModel = originalModel.eval().to(device) + print("Original CCT-2 loaded from checkpoint with converted attention weights.") + + print("Evaluating original model...") + originalTop1, originalTop5 = evaluateModel( + originalModel, valLoader, device, "Original CCT-2" + ) + + print("Preparing and quantizing CCT-2...") + FQModel = prepareFQCCT(originalModel.to("cpu")) + + print("Calibrating FQ model...") + calibrateModel(FQModel, calibLoader) + + print("Evaluating FQ model...") + # FBRANCASI: Use CPU for brevitas models + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ CCT-2") + + sampleInput = torch.randn(1, 3, 32, 32).to("cpu") + + # FBRANCASI: Override the injectCustomForwards function in the module before DeepQuant.Export imports it + import DeepQuant.Pipeline.Injection as injection_module + + # FBRANCASI: Store original function + original_inject = injection_module.injectCustomForwards + + # FBRANCASI: Override with our custom function + injection_module.injectCustomForwards = injectCustomForwards + + # FBRANCASI: Force reload of Export module to pick up the override + import importlib + + import DeepQuant.Export + + importlib.reload(DeepQuant.Export) + + try: + from DeepQuant.Export import brevitasToTrueQuant + + TQModel = brevitasToTrueQuant(FQModel, sampleInput, debug=True) + finally: + # FBRANCASI: Restore original function and reload Export module again + injection_module.injectCustomForwards = original_inject + importlib.reload(DeepQuant.Export) + + numParameters = sum(p.numel() for p in TQModel.parameters()) + print(f"Number of parameters: {numParameters:,}") + + print("Evaluating TQ model...") + TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ CCT-2") + + print("\nComparison Summary:") + print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}") + print("-" * 75) + print(f"{'Original CCT-2':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}") + print(f"{'FQ CCT-2':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}") + print(f"{'TQ CCT-2':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}") + print( + f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}" + ) + print( + 
f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}" + ) + + if abs(FQTop1 - TQTop1) > 5.0: + print( + f"Warning: Large accuracy drop between FQ and TQ models. " + f"Difference: {abs(FQTop1 - TQTop1):.2f}%" + ) + + # FBRANCASI: Important note + # Right now ONNX is not exporting the graph with GELUs folded and some nodes dont have shapes. + # + # If you need to use this ONNX in Deeploy (https://github.com/pulp-platform/Deeploy), please run + # these commands on the generated network.onnx to fix these problems that can arise in Deeploy: + # + # > python -m onnxruntime.transformers.optimizer --input Tests/ONNX/network.onnx --output network.onnx + # --model_type vit --num_heads 6 --hidden_size 384 --use_multi_head_attention --disable_bias_gelu + # --disable_bias_skip_layer_norm --disable_skip_layer_norm --use_multi_head_attention --opt_level 0 + # + # > python -m onnxruntime.tools.symbolic_shape_infer --input network.onnx --output network.onnx + # + # Also, if you have duplicated shared Floor constants in the graph (this will create problems in + # Deeploy), you can fix this using the script FixCTT2Graph.py under the Utils folder of DeepQuant diff --git a/Tests/TestConv.py b/Tests/TestConv.py index 011612c..b9b5e05 100644 --- a/Tests/TestConv.py +++ b/Tests/TestConv.py @@ -5,22 +5,23 @@ # Victor Jung # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class QuantConvNet(nn.Module): + """Simple quantized CNN with a single conv layer.""" - convAndLinQuantParams = { + convQuantParams = { "bias": True, "weight_bit_width": 4, "bias_quant": Int32Bias, @@ -30,31 +31,26 @@ class QuantConvNet(nn.Module): "return_quant_tensor": True, } - def __init__(self, in_channels: int = 1) -> None: + def __init__(self, inChannels: int = 1) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) - self.conv1 = qnn.QuantConv2d( - in_channels=in_channels, + in_channels=inChannels, out_channels=16, kernel_size=3, padding=1, - **QuantConvNet.convAndLinQuantParams + **QuantConvNet.convQuantParams, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.inputQuant(x) x = self.conv1(x) - return x @pytest.mark.SingleLayerTests def deepQuantTestConv() -> None: - torch.manual_seed(42) - model = QuantConvNet().eval() sampleInput = torch.randn(1, 1, 28, 28) - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True) diff --git a/Tests/TestLinear.py b/Tests/TestLinear.py index 675653f..39bdfee 100644 --- a/Tests/TestLinear.py +++ b/Tests/TestLinear.py @@ -4,33 +4,28 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest - -### PyTorch Imports ### import torch import torch.nn as nn - -### Brevitas Import ### -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class QuantLinearNet(nn.Module): + """Simple quantized network with a single linear layer.""" - def __init__(self, in_features: int = 16, hidden_features: int = 32) -> None: + def __init__(self, inFeatures: int = 16, hiddenFeatures: int = 32) -> None: 
super().__init__() - self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) - self.linear1 = qnn.QuantLinear( - in_features=in_features, - out_features=hidden_features, + in_features=inFeatures, + out_features=hiddenFeatures, bias=True, weight_bit_width=4, bias_quant=Int32Bias, @@ -41,19 +36,14 @@ def __init__(self, in_features: int = 16, hidden_features: int = 32) -> None: ) def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.inputQuant(x) x = self.linear1(x) - return x @pytest.mark.SingleLayerTests def deepQuantTestLinear() -> None: - torch.manual_seed(42) - model = QuantLinearNet().eval() sampleInput = torch.randn(1, 4, 16) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True) diff --git a/Tests/TestMHSA.py b/Tests/TestMHSA.py index d5be3a9..96ae93c 100644 --- a/Tests/TestMHSA.py +++ b/Tests/TestMHSA.py @@ -4,39 +4,34 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn -from torch import Tensor -from DeepQuant.ExportBrevitas import exportBrevitas - from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, Uint8ActPerTensorFloat, ) +from torch import Tensor + +from DeepQuant import brevitasToTrueQuant class QuantMHSANet(nn.Module): + """Simple quantized network with multi-head self-attention.""" - def __init__(self, embed_dim: int, num_heads: int) -> None: - """ - Args: - embed_dim: The dimension of each embedding vector. - num_heads: The number of attention heads. - """ + def __init__(self, embedDim: int, numHeads: int) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) self.mha = qnn.QuantMultiheadAttention( - embed_dim=embed_dim, - num_heads=num_heads, + embed_dim=embedDim, + num_heads=numHeads, dropout=0.0, bias=True, - packed_in_proj=False, # separate Q, K, V - batch_first=False, # expects (sequence, batch, embed_dim) + packed_in_proj=False, # FBRANCASI: separate Q, K, V + batch_first=False, # FBRANCASI: expects (sequence, batch, embed_dim) in_proj_input_quant=Int8ActPerTensorFloat, in_proj_weight_quant=Int8WeightPerTensorFloat, in_proj_bias_quant=Int32Bias, @@ -51,16 +46,6 @@ def __init__(self, embed_dim: int, num_heads: int) -> None: ) def forward(self, x: Tensor) -> Tensor: - """ - Forward pass that first quantizes the input, then applies multi-head attention. - - Args: - x: Input tensor of shape [sequence_len, batch_size, embed_dim]. - - Returns: - A tuple (output, None) as per the Brevitas MHA API, where output has shape - [sequence_len, batch_size, embed_dim]. - """ x = self.inputQuant(x) out = self.mha(x, x, x) return out @@ -68,10 +53,7 @@ def forward(self, x: Tensor) -> Tensor: @pytest.mark.SingleLayerTests def deepQuantTestMHSA() -> None: - torch.manual_seed(42) - - model = QuantMHSANet(embed_dim=16, num_heads=4).eval() + model = QuantMHSANet(embedDim=16, numHeads=4).eval() sampleInput = torch.randn(10, 2, 16) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, checkEquivalence=True) diff --git a/Tests/TestMobileNetV3Small.py b/Tests/TestMobileNetV3Small.py index 7a36392..308e3fc 100644 --- a/Tests/TestMobileNetV3Small.py +++ b/Tests/TestMobileNetV3Small.py @@ -2,39 +2,30 @@ # Licensed under the Apache License, Version 2.0, see LICENSE for details. 
# SPDX-License-Identifier: Apache-2.0 # -# Victor Juing +# Victor Jung +import brevitas.nn as qnn import pytest import torch import torch.nn as nn import torchvision.models as models -from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool -import brevitas.nn as qnn +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) -from brevitas.graph.quantize import quantize -from DeepQuant.ExportBrevitas import exportBrevitas +from DeepQuant import brevitasToTrueQuant def prepareMBNetV3Model() -> nn.Module: - """ - Prepare a quantized MobileNetV3Small model for testing. - Steps: - 1) Load the torchvision MobileNetV3Small. - 2) Convert it to eval mode. - 3) Preprocess and adapt average pooling. - 4) Quantize it using Brevitas. - - Returns: - A quantized MobileNetV3Small model ready for export tests. - """ - baseModel = models.mobilenet_v3_small(weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1) + """Prepare a quantized MobileNetV3Small model for testing.""" + baseModel = models.mobilenet_v3_small( + weights=models.MobileNet_V3_Small_Weights.IMAGENET1K_V1 + ) baseModel = baseModel.eval() computeLayerMap = { @@ -99,9 +90,7 @@ def prepareMBNetV3Model() -> nn.Module: baseModel = preprocess_for_quantize( baseModel, equalize_iters=20, equalize_scale_computation="range" ) - baseModel = AdaptiveAvgPoolToAvgPool().apply( - baseModel, torch.ones(1, 3, 224, 224) - ) + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, torch.ones(1, 3, 224, 224)) quantizedModel = quantize( graph_model=baseModel, @@ -115,10 +104,7 @@ def prepareMBNetV3Model() -> nn.Module: @pytest.mark.ModelTests def deepQuantTestMobileNetV3Small() -> None: - torch.manual_seed(42) - - quantizedModel = prepareMBNetV3Model() + model = prepareMBNetV3Model() sampleInput = torch.randn(1, 3, 224, 224) - - exportBrevitas(quantizedModel, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True) diff --git a/Tests/TestResNet18.py b/Tests/TestResNet18.py index 3b62a06..186af74 100644 --- a/Tests/TestResNet18.py +++ b/Tests/TestResNet18.py @@ -5,37 +5,27 @@ # Federico Brancasi +import brevitas.nn as qnn import pytest import torch import torch.nn as nn import torchvision.models as models -from brevitas.graph.quantize import preprocess_for_quantize from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool -import brevitas.nn as qnn +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) -from brevitas.graph.quantize import quantize -from DeepQuant.ExportBrevitas import exportBrevitas +from DeepQuant import brevitasToTrueQuant def prepareResnet18Model() -> nn.Module: - """ - Prepare a quantized ResNet18 model for testing. - Steps: - 1) Load the torchvision ResNet18. - 2) Convert it to eval mode. - 3) Preprocess and adapt average pooling. - 4) Quantize it using Brevitas. - - Returns: - A quantized ResNet18 model ready for export tests. 
- """ + """Prepare a fake-quantized (FQ) ResNet18 model.""" baseModel = models.resnet18(weights=models.ResNet18_Weights.DEFAULT) + baseModel = baseModel.eval() computeLayerMap = { @@ -67,16 +57,7 @@ def prepareResnet18Model() -> nn.Module: ), } - quantActMap = { - nn.ReLU: ( - qnn.QuantReLU, - { - "act_quant": Uint8ActPerTensorFloat, - "return_quant_tensor": True, - "bit_width": 8, - }, - ), - } + quantActMap = {} quantIdentityMap = { "signed": ( @@ -100,9 +81,7 @@ def prepareResnet18Model() -> nn.Module: baseModel = preprocess_for_quantize( baseModel, equalize_iters=20, equalize_scale_computation="range" ) - baseModel = AdaptiveAvgPoolToAvgPool().apply( - baseModel, torch.ones(1, 3, 224, 224) - ) + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, torch.ones(1, 3, 224, 224)) quantizedResnet = quantize( graph_model=baseModel, @@ -118,8 +97,6 @@ def prepareResnet18Model() -> nn.Module: def deepQuantTestResnet18() -> None: torch.manual_seed(42) - quantizedModel = prepareResnet18Model() sampleInput = torch.randn(1, 3, 224, 224) - - exportBrevitas(quantizedModel, sampleInput, debug=True) + brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) diff --git a/Tests/TestResNet18Pretrained.py b/Tests/TestResNet18Pretrained.py new file mode 100644 index 0000000..d0f856c --- /dev/null +++ b/Tests/TestResNet18Pretrained.py @@ -0,0 +1,279 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import tarfile +import urllib.request +from pathlib import Path + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) +from torch.utils.data import DataLoader, Subset +from torchvision.datasets import ImageFolder +from tqdm import tqdm + +from DeepQuant import brevitasToTrueQuant + + +def evaluateModel(model, dataLoader, evalDevice, name="Model"): + model.eval() + correctTop1 = 0 + correctTop5 = 0 + total = 0 + + with torch.no_grad(): + for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"): + isTQ = "TQ" in name + + if isTQ: + # FBRANCASI: Process different batches for the TQ model + for i in range(inputs.size(0)): + singleInput = inputs[i : i + 1].to(evalDevice) + singleOutput = model(singleInput) + + _, predicted = singleOutput.max(1) + if predicted.item() == targets[i].item(): + correctTop1 += 1 + + _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True) + if targets[i].item() in top5Pred[0].cpu().numpy(): + correctTop5 += 1 + + total += 1 + else: + inputs = inputs.to(evalDevice) + targets = targets.to(evalDevice) + output = model(inputs) + + _, predicted = output.max(1) + correctTop1 += (predicted == targets).sum().item() + + _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): + if targets[i] in top5Pred[i]: + correctTop5 += 1 + + total += targets.size(0) + + top1Accuracy = 100.0 * correctTop1 / total + top5Accuracy = 100.0 * correctTop5 / total + + print( + f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), " + f"Top-5 Accuracy: {top5Accuracy:.2f}%" + ) + + return 
top1Accuracy, top5Accuracy + + +def calibrateModel(model, calibLoader): + model.eval() + with torch.no_grad(), calibration_mode(model): + for inputs, _ in tqdm(calibLoader, desc="Calibrating model"): + inputs = inputs.to("cpu") + model(inputs) + print("Calibration completed.") + + +def prepareFQResNet18(): + """Prepare a fake-quantized (FQ) ResNet18 model.""" + baseModel = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + ) + baseModel = baseModel.eval().to("cpu") + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + } + + quantActMap = {} + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + dummyInput = torch.ones(1, 3, 224, 224).to("cpu") + + print("Preprocessing model for quantization...") + baseModel = preprocess_for_quantize( + baseModel, equalize_iters=20, equalize_scale_computation="range" + ) + + print("Converting AdaptiveAvgPool to AvgPool...") + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput) + + print("Quantizing model...") + FQModel = quantize( + graph_model=baseModel, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + return FQModel + + +@pytest.mark.ModelTests +def deepQuantTestResnet18() -> None: + HOME = Path.home() + BASE = HOME / "Documents" / "ImagenetV2" + TAR_URL = ( + "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/" + "imagenetv2-matched-frequency.tar.gz" + ) + TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz" + EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val" + + if not TAR_PATH.exists(): + BASE.mkdir(parents=True, exist_ok=True) + print(f"Downloading ImageNetV2 from {TAR_URL}...") + urllib.request.urlretrieve(TAR_URL, TAR_PATH) + + if not EXTRACT_DIR.exists(): + print(f"Extracting to {EXTRACT_DIR}...") + with tarfile.open(TAR_PATH, "r:*") as tar: + for member in tqdm(tar.getmembers(), desc="Extracting files"): + tar.extract(member, BASE) + print("Extraction completed.") + + transformsVal = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal) + dataset.classes = sorted(dataset.classes, key=lambda x: int(x)) + dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)} + + newSamples = [] + for path, _ in dataset.samples: + clsName = Path(path).parent.name + newLabel = dataset.class_to_idx[clsName] + newSamples.append((path, newLabel)) + dataset.samples = newSamples + dataset.targets = [s[1] for s in newSamples] + + # FBRANCASI: Optional, reduce number of example for faster 
validation + DATASET_LIMIT = 256 + dataset = Subset(dataset, list(range(DATASET_LIMIT))) + print(f"Validation dataset size set to {len(dataset)} images.") + + calibLoader = DataLoader( + Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True + ) + valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True) + + # FBRANCASI: I'm on mac, so mps for me + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("mps" if torch.backends.mps.is_available() else device) + print(f"Using device: {device}") + + originalModel = torchvision.models.resnet18( + weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1 + ) + originalModel = originalModel.eval().to(device) + print("Original ResNet18 loaded.") + + print("Evaluating original model...") + originalTop1, originalTop5 = evaluateModel( + originalModel, valLoader, device, "Original ResNet18" + ) + + print("Preparing and quantizing ResNet18...") + FQModel = prepareFQResNet18() + + print("Calibrating FQ model...") + calibrateModel(FQModel, calibLoader) + + print("Evaluating FQ model...") + # FBRANCASI: I'm on mac, mps doesn't work with brevitas + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ResNet18") + + sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu") + TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True) + + numParameters = sum(p.numel() for p in TQModel.parameters()) + print(f"Number of parameters: {numParameters:,}") + + print("Evaluating TQ model...") + TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ResNet18") + + print("\nComparison Summary:") + print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}") + print("-" * 75) + print(f"{'Original ResNet18':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}") + print(f"{'FQ ResNet18':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}") + print(f"{'TQ ResNet18':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}") + print( + f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}" + ) + print( + f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}" + ) + + if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0: + print( + f"Warning: Large accuracy drop between FQ and TQ models. " + f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, " + f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%" + ) diff --git a/Tests/TestResNet50Pretrained.py b/Tests/TestResNet50Pretrained.py new file mode 100644 index 0000000..135469e --- /dev/null +++ b/Tests/TestResNet50Pretrained.py @@ -0,0 +1,279 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import tarfile +import urllib.request +from pathlib import Path + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) +from torch.utils.data import DataLoader, Subset +from torchvision.datasets import ImageFolder +from tqdm import tqdm + +from DeepQuant import brevitasToTrueQuant + + +def evaluateModel(model, dataLoader, evalDevice, name="Model"): + model.eval() + correctTop1 = 0 + correctTop5 = 0 + total = 0 + + with torch.no_grad(): + for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"): + isTQ = "TQ" in name + + if isTQ: + # FBRANCASI: Process different batches for the TQ model + for i in range(inputs.size(0)): + singleInput = inputs[i : i + 1].to(evalDevice) + singleOutput = model(singleInput) + + _, predicted = singleOutput.max(1) + if predicted.item() == targets[i].item(): + correctTop1 += 1 + + _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True) + if targets[i].item() in top5Pred[0].cpu().numpy(): + correctTop5 += 1 + + total += 1 + else: + inputs = inputs.to(evalDevice) + targets = targets.to(evalDevice) + output = model(inputs) + + _, predicted = output.max(1) + correctTop1 += (predicted == targets).sum().item() + + _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): + if targets[i] in top5Pred[i]: + correctTop5 += 1 + + total += targets.size(0) + + top1Accuracy = 100.0 * correctTop1 / total + top5Accuracy = 100.0 * correctTop5 / total + + print( + f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), " + f"Top-5 Accuracy: {top5Accuracy:.2f}%" + ) + + return top1Accuracy, top5Accuracy + + +def calibrateModel(model, calibLoader): + model.eval() + with torch.no_grad(), calibration_mode(model): + for inputs, _ in tqdm(calibLoader, desc="Calibrating model"): + inputs = inputs.to("cpu") + model(inputs) + print("Calibration completed.") + + +def prepareFQResNet50(): + """Prepare a fake-quantized (FQ) ResNet50 model.""" + baseModel = torchvision.models.resnet50( + weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2 + ) + baseModel = baseModel.eval().to("cpu") + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + } + + quantActMap = {} + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + dummyInput = torch.ones(1, 3, 224, 
224).to("cpu") + + print("Preprocessing model for quantization...") + baseModel = preprocess_for_quantize( + baseModel, equalize_iters=20, equalize_scale_computation="range" + ) + + print("Converting AdaptiveAvgPool to AvgPool...") + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput) + + print("Quantizing model...") + FQModel = quantize( + graph_model=baseModel, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + return FQModel + + +@pytest.mark.ModelTests +def deepQuantTestResnet50() -> None: + HOME = Path.home() + BASE = HOME / "Documents" / "ImagenetV2" + TAR_URL = ( + "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/" + "imagenetv2-matched-frequency.tar.gz" + ) + TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz" + EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val" + + if not TAR_PATH.exists(): + BASE.mkdir(parents=True, exist_ok=True) + print(f"Downloading ImageNetV2 from {TAR_URL}...") + urllib.request.urlretrieve(TAR_URL, TAR_PATH) + + if not EXTRACT_DIR.exists(): + print(f"Extracting to {EXTRACT_DIR}...") + with tarfile.open(TAR_PATH, "r:*") as tar: + for member in tqdm(tar.getmembers(), desc="Extracting files"): + tar.extract(member, BASE) + print("Extraction completed.") + + transformsVal = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal) + dataset.classes = sorted(dataset.classes, key=lambda x: int(x)) + dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)} + + newSamples = [] + for path, _ in dataset.samples: + clsName = Path(path).parent.name + newLabel = dataset.class_to_idx[clsName] + newSamples.append((path, newLabel)) + dataset.samples = newSamples + dataset.targets = [s[1] for s in newSamples] + + # FBRANCASI: Optional, reduce number of example for faster validation + DATASET_LIMIT = 256 + dataset = Subset(dataset, list(range(DATASET_LIMIT))) + print(f"Validation dataset size set to {len(dataset)} images.") + + calibLoader = DataLoader( + Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True + ) + valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True) + + # FBRANCASI: I'm on mac, so mps for me + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("mps" if torch.backends.mps.is_available() else device) + print(f"Using device: {device}") + + originalModel = torchvision.models.resnet50( + weights=torchvision.models.ResNet50_Weights.IMAGENET1K_V2 + ) + originalModel = originalModel.eval().to(device) + print("Original ResNet50 loaded.") + + print("Evaluating original model...") + originalTop1, originalTop5 = evaluateModel( + originalModel, valLoader, device, "Original ResNet50" + ) + + print("Preparing and quantizing ResNet50...") + FQModel = prepareFQResNet50() + + print("Calibrating FQ model...") + calibrateModel(FQModel, calibLoader) + + print("Evaluating FQ model...") + # FBRANCASI: I'm on mac, mps doesn't work with brevitas + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ResNet50") + + sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu") + TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True) + + numParameters = sum(p.numel() for p in 
TQModel.parameters()) + print(f"Number of parameters: {numParameters:,}") + + print("Evaluating TQ model...") + TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ResNet50") + + print("\nComparison Summary:") + print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}") + print("-" * 75) + print(f"{'Original ResNet50':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}") + print(f"{'FQ ResNet50':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}") + print(f"{'TQ ResNet50':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}") + print( + f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}" + ) + print( + f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}" + ) + + if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0: + print( + f"Warning: Large accuracy drop between FQ and TQ models. " + f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, " + f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%" + ) \ No newline at end of file diff --git a/Tests/TestSimpleCNN.py b/Tests/TestSimpleCNN.py index bc755ec..e1afb48 100644 --- a/Tests/TestSimpleCNN.py +++ b/Tests/TestSimpleCNN.py @@ -4,29 +4,23 @@ # # Federico Brancasi - +import brevitas.nn as qnn import pytest import torch import torch.nn as nn -import brevitas.nn as qnn from brevitas.quant.scaled_int import ( Int8ActPerTensorFloat, - Int32Bias, Int8WeightPerTensorFloat, + Int32Bias, ) -from DeepQuant.ExportBrevitas import exportBrevitas + +from DeepQuant import brevitasToTrueQuant class SimpleQuantCNN(nn.Module): - """ - A simple quantized CNN that includes: - - Input quantization - - Two QuantConv2d layers with Quantized ReLU - - MaxPool2d - - A final QuantLinear layer - """ + """A simple quantized CNN with two conv layers and a linear layer.""" - convAndLinQuantParams = { + convQuantParams = { "bias": True, "weight_bit_width": 4, "bias_quant": Int32Bias, @@ -36,21 +30,16 @@ class SimpleQuantCNN(nn.Module): "return_quant_tensor": True, } - def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None: - """ - Args: - in_channels: Number of input channels (e.g., 1 for grayscale). - num_classes: Number of output classes for the final linear layer. - """ + def __init__(self, inChannels: int = 1, numClasses: int = 10) -> None: super().__init__() self.inputQuant = qnn.QuantIdentity(return_quant_tensor=True) self.conv1 = qnn.QuantConv2d( - in_channels=in_channels, + in_channels=inChannels, out_channels=16, kernel_size=3, padding=1, - **SimpleQuantCNN.convAndLinQuantParams + **SimpleQuantCNN.convQuantParams, ) self.relu1 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True) self.pool1 = nn.MaxPool2d(kernel_size=2) @@ -60,28 +49,19 @@ def __init__(self, in_channels: int = 1, num_classes: int = 10) -> None: out_channels=32, kernel_size=3, padding=1, - **SimpleQuantCNN.convAndLinQuantParams + **SimpleQuantCNN.convQuantParams, ) self.relu2 = qnn.QuantReLU(bit_width=4, return_quant_tensor=True) self.pool2 = nn.MaxPool2d(kernel_size=2) self.flatten = nn.Flatten() self.fc = qnn.QuantLinear( - in_features=32 * 7 * 7, # If input is 28x28, shape after pooling is 7x7 - out_features=num_classes, - **SimpleQuantCNN.convAndLinQuantParams + in_features=32 * 7 * 7, + out_features=numClasses, + **SimpleQuantCNN.convQuantParams, ) def forward(self, x: torch.Tensor) -> torch.Tensor: - """ - Forward pass of the SimpleQuantCNN. - - Args: - x: Input tensor of shape [batch_size, in_channels, height, width]. - - Returns: - A quantized output tensor (batch_size, num_classes). 
- """ x = self.inputQuant(x) x = self.conv1(x) @@ -99,10 +79,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @pytest.mark.ModelTests def deepQuantTestSimpleCNN() -> None: - torch.manual_seed(42) - model = SimpleQuantCNN().eval() sampleInput = torch.randn(1, 1, 28, 28) - - exportBrevitas(model, sampleInput, debug=True) + brevitasToTrueQuant(model, sampleInput, debug=True, checkEquivalence=True) diff --git a/Tests/TestSimpleFCNN.py b/Tests/TestSimpleFCNN.py index 33b90f6..c3c7821 100644 --- a/Tests/TestSimpleFCNN.py +++ b/Tests/TestSimpleFCNN.py @@ -4,39 +4,25 @@ # # Federico Brancasi - -import warnings - -warnings.filterwarnings("ignore", category=UserWarning) -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_cuda.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_cudnn.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_mps.*") -warnings.filterwarnings("ignore", category=UserWarning, message=".*has_mkldnn.*") -warnings.filterwarnings( - "ignore", category=UserWarning, message=".*experimental feature.*" -) -warnings.filterwarnings("ignore", category=UserWarning, message=".*deprecated.*") - from pathlib import Path -from tqdm import tqdm +import brevitas.nn as qnn import torch import torch.nn as nn import torch.optim as optim -from torch.utils.data import DataLoader -from torchvision import datasets, transforms - -import brevitas.nn as qnn -from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.quantize import preprocess_for_quantize, quantize from brevitas.quant import ( Int8ActPerTensorFloat, Int8WeightPerTensorFloat, Int32Bias, Uint8ActPerTensorFloat, ) +from torch.utils.data import DataLoader +from torchvision import datasets, transforms +from tqdm import tqdm -from DeepQuant.ExportBrevitas import exportBrevitas +from DeepQuant import brevitasToTrueQuant class SimpleFCNN(nn.Module): @@ -65,7 +51,6 @@ def trainModel( epochs: int = 10, learningRate: float = 0.001, ) -> nn.Module: - """Train the model if no saved weights exist.""" if savePath.exists(): print(f"Loading existing model from {savePath}") @@ -89,7 +74,6 @@ def trainModel( print(f"Epoch [{epoch+1}/{epochs}], Loss: {runningLoss/len(trainLoader):.4f}") - # Evaluate model.eval() correct = 0 total = 0 @@ -102,36 +86,35 @@ def trainModel( print(f"Accuracy on the test set: {100 * correct / total:.2f}%") - # Save model torch.save(model.state_dict(), savePath) print(f"Model saved to {savePath}") return model -def calibrate_model( - model: nn.Module, calib_loader: DataLoader, device: torch.device +def calibrateModel( + model: nn.Module, calibLoader: DataLoader, device: torch.device ) -> None: - """Calibrate the quantized model.""" model.eval() model.to(device) with ( torch.no_grad(), calibration_mode(model), - tqdm(calib_loader, desc="Calibrating") as pbar, + tqdm(calibLoader, desc="Calibrating") as pbar, ): for images, _ in pbar: images = images.to(device) images = images.to(torch.float) model(images) + DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") EXPORT_FOLDER = Path().cwd() / "Tests" MODEL_PATH = EXPORT_FOLDER / "Models" DATA_PATH = EXPORT_FOLDER / "Data" + def deepQuantTestSimpleFCNN() -> None: - EXPORT_FOLDER.mkdir(parents=True, exist_ok=True) MODEL_PATH.mkdir(parents=True, exist_ok=True) @@ -143,26 +126,21 @@ def deepQuantTestSimpleFCNN() -> None: ] ) - train_dataset = datasets.MNIST( + trainDataset = datasets.MNIST( root=DATA_PATH, 
train=True, download=True, transform=transform ) - test_dataset = datasets.MNIST( + testDataset = datasets.MNIST( root=DATA_PATH, train=False, download=True, transform=transform ) - trainLoader = DataLoader(train_dataset, batch_size=64, shuffle=True) - testLoader = DataLoader( - test_dataset, batch_size=64, shuffle=False, pin_memory=True - ) + trainLoader = DataLoader(trainDataset, batch_size=64, shuffle=True) + testLoader = DataLoader(testDataset, batch_size=64, shuffle=False, pin_memory=True) - # Train or load model - m = SimpleFCNN() - model = trainModel(m, trainLoader, testLoader, MODEL_PATH / "mnist_model.pth") + model = SimpleFCNN() + model = trainModel(model, trainLoader, testLoader, MODEL_PATH / "mnist_model.pth") - # Prepare for quantization model = preprocess_for_quantize(model) - # Quantization configurations computeLayerMap = { nn.Linear: ( qnn.QuantLinear, @@ -208,7 +186,6 @@ def deepQuantTestSimpleFCNN() -> None: ), } - # Quantize and calibrate modelQuant = quantize( model, compute_layer_map=computeLayerMap, @@ -216,11 +193,9 @@ def deepQuantTestSimpleFCNN() -> None: quant_identity_map=quantIdentityMap, ) - calibrate_model(modelQuant, testLoader, DEVICE) + calibrateModel(modelQuant, testLoader, DEVICE) - # Export and transform sampleInput, _ = next(iter(testLoader)) sampleInput = sampleInput[0:1] - print(f"Sample input shape: {sampleInput.shape}") - exportBrevitas(modelQuant, sampleInput.to(DEVICE), debug=True) + brevitasToTrueQuant(modelQuant, sampleInput.to(DEVICE), debug=True) diff --git a/Tests/TestVitB32.py b/Tests/TestVitB32.py new file mode 100644 index 0000000..b09b500 --- /dev/null +++ b/Tests/TestVitB32.py @@ -0,0 +1,139 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. +# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +import torchvision.models as models +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) + +from DeepQuant import brevitasToTrueQuant + + +def prepare_vit_b_32(model: nn.Module) -> nn.Module: + """ + Prepare a quantized ViT-B/32 model using Brevitas. 
+ """ + + compute_layer_map = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "weight_bit_width": 8, + }, + ), + nn.MultiheadAttention: ( + qnn.QuantMultiheadAttention, + { + "in_proj_input_quant": Int8ActPerTensorFloat, + "in_proj_weight_quant": Int8WeightPerTensorFloat, + "in_proj_bias_quant": Int32Bias, + "attn_output_weights_quant": Uint8ActPerTensorFloat, + "q_scaled_quant": Int8ActPerTensorFloat, + "k_transposed_quant": Int8ActPerTensorFloat, + "v_quant": Int8ActPerTensorFloat, + "out_proj_input_quant": Int8ActPerTensorFloat, + "out_proj_weight_quant": Int8WeightPerTensorFloat, + "out_proj_bias_quant": Int32Bias, + "out_proj_output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "weight_bit_width": 8, + }, + ), + } + + quant_act_map = { + nn.GELU: ( + qnn.QuantReLU, # FBRANCASI: Approximating GELU with QuantReLU + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + quant_identity_map = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + print("\nPreprocessing model for quantization...") + model = preprocess_for_quantize( + model, + equalize_iters=10, + equalize_scale_computation="range", + ) + + print("\nQuantizing model...") + quantized_model = quantize( + graph_model=model, + compute_layer_map=compute_layer_map, + quant_act_map=quant_act_map, + quant_identity_map=quant_identity_map, + ) + + return quantized_model + + +@pytest.mark.ModelTests +def deepQuantTestViT(): + torch.manual_seed(42) + sampleInput = torch.randn(1, 3, 224, 224) + + vit_model = models.vit_b_32(weights=models.ViT_B_32_Weights.IMAGENET1K_V1) + vit_model.eval() + + print(f"\nTesting ViT-B/32 model with input shape: {sampleInput.shape}") + + quantized_vit = prepare_vit_b_32(vit_model) + + with torch.no_grad(): + output = quantized_vit(sampleInput) + if isinstance(output, tuple): + output = output[0] + print(f"Output shape: {output.shape}") + print(f"Output range: [{output.min().item():.3f}, {output.max().item():.3f}]") + + brevitasToTrueQuant(quantized_vit, sampleInput, debug=True, checkEquivalence=False) diff --git a/Tests/TestVitB32Pretrained.py b/Tests/TestVitB32Pretrained.py new file mode 100644 index 0000000..4cfafea --- /dev/null +++ b/Tests/TestVitB32Pretrained.py @@ -0,0 +1,302 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import tarfile +import urllib.request +from pathlib import Path + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +import torchvision +import torchvision.transforms as transforms +from brevitas.graph.calibrate import calibration_mode +from brevitas.graph.per_input import AdaptiveAvgPoolToAvgPool +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) +from torch.utils.data import DataLoader, Subset +from torchvision.datasets import ImageFolder +from tqdm import tqdm + +from DeepQuant import brevitasToTrueQuant + + +def evaluateModel(model, dataLoader, evalDevice, name="Model"): + model.eval() + correctTop1 = 0 + correctTop5 = 0 + total = 0 + + with torch.no_grad(): + for inputs, targets in tqdm(dataLoader, desc=f"Evaluating {name}"): + isTQ = "TQ" in name + + if isTQ: + # FBRANCASI: Process different batches for the TQ model + for i in range(inputs.size(0)): + singleInput = inputs[i : i + 1].to(evalDevice) + singleOutput = model(singleInput) + if isinstance(singleOutput, tuple): + singleOutput = singleOutput[0] + + _, predicted = singleOutput.max(1) + if predicted.item() == targets[i].item(): + correctTop1 += 1 + + _, top5Pred = singleOutput.topk(5, dim=1, largest=True, sorted=True) + if targets[i].item() in top5Pred[0].cpu().numpy(): + correctTop5 += 1 + + total += 1 + else: + inputs = inputs.to(evalDevice) + targets = targets.to(evalDevice) + output = model(inputs) + if isinstance(output, tuple): + output = output[0] + + _, predicted = output.max(1) + correctTop1 += (predicted == targets).sum().item() + + _, top5Pred = output.topk(5, dim=1, largest=True, sorted=True) + for i in range(targets.size(0)): + if targets[i] in top5Pred[i]: + correctTop5 += 1 + + total += targets.size(0) + + top1Accuracy = 100.0 * correctTop1 / total + top5Accuracy = 100.0 * correctTop5 / total + + print( + f"{name} - Top-1 Accuracy: {top1Accuracy:.2f}% ({correctTop1}/{total}), " + f"Top-5 Accuracy: {top5Accuracy:.2f}%" + ) + + return top1Accuracy, top5Accuracy + + +def calibrateModel(model, calibLoader): + model.eval() + with torch.no_grad(), calibration_mode(model): + for inputs, _ in tqdm(calibLoader, desc="Calibrating model"): + inputs = inputs.to("cpu") + output = model(inputs) + if isinstance(output, tuple): + output = output[0] + print("Calibration completed.") + + +def prepareFQVitB32(): + """Prepare a fake-quantized (FQ) ViT-B/32 model.""" + baseModel = torchvision.models.vit_b_32( + weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1 + ) + baseModel = baseModel.eval().to("cpu") + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + nn.MultiheadAttention: ( + qnn.QuantMultiheadAttention, + { + "in_proj_input_quant": Int8ActPerTensorFloat, + "in_proj_weight_quant": Int8WeightPerTensorFloat, + "in_proj_bias_quant": Int32Bias, + "attn_output_weights_quant": Uint8ActPerTensorFloat, + "q_scaled_quant": Int8ActPerTensorFloat, + "k_transposed_quant": Int8ActPerTensorFloat, + "v_quant": Int8ActPerTensorFloat, + "out_proj_input_quant": Int8ActPerTensorFloat, + "out_proj_weight_quant": Int8WeightPerTensorFloat, 
+ "out_proj_bias_quant": Int32Bias, + "out_proj_output_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 8, + }, + ), + } + + quantActMap = {} + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + dummyInput = torch.ones(1, 3, 224, 224).to("cpu") + + print("Preprocessing model for quantization...") + baseModel = preprocess_for_quantize( + baseModel, equalize_iters=20, equalize_scale_computation="range" + ) + + print("Converting AdaptiveAvgPool to AvgPool...") + baseModel = AdaptiveAvgPoolToAvgPool().apply(baseModel, dummyInput) + + print("Quantizing model...") + FQModel = quantize( + graph_model=baseModel, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + return FQModel + + +@pytest.mark.ModelTests +def deepQuantTestVitB32Pretrained() -> None: + HOME = Path.home() + BASE = HOME / "Documents" / "ImagenetV2" + TAR_URL = ( + "https://huggingface.co/datasets/vaishaal/ImageNetV2/resolve/main/" + "imagenetv2-matched-frequency.tar.gz" + ) + TAR_PATH = BASE / "imagenetv2-matched-frequency.tar.gz" + EXTRACT_DIR = BASE / "imagenetv2-matched-frequency-format-val" + + if not TAR_PATH.exists(): + BASE.mkdir(parents=True, exist_ok=True) + print(f"Downloading ImageNetV2 from {TAR_URL}...") + urllib.request.urlretrieve(TAR_URL, TAR_PATH) + + if not EXTRACT_DIR.exists(): + print(f"Extracting to {EXTRACT_DIR}...") + with tarfile.open(TAR_PATH, "r:*") as tar: + for member in tqdm(tar.getmembers(), desc="Extracting files"): + tar.extract(member, BASE) + print("Extraction completed.") + + transformsVal = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + dataset = ImageFolder(root=str(EXTRACT_DIR), transform=transformsVal) + dataset.classes = sorted(dataset.classes, key=lambda x: int(x)) + dataset.class_to_idx = {cls: i for i, cls in enumerate(dataset.classes)} + + newSamples = [] + for path, _ in dataset.samples: + clsName = Path(path).parent.name + newLabel = dataset.class_to_idx[clsName] + newSamples.append((path, newLabel)) + dataset.samples = newSamples + dataset.targets = [s[1] for s in newSamples] + + # FBRANCASI: Optional, reduce number of example for faster validation + DATASET_LIMIT = 256 + dataset = Subset(dataset, list(range(DATASET_LIMIT))) + print(f"Validation dataset size set to {len(dataset)} images.") + + calibLoader = DataLoader( + Subset(dataset, list(range(256))), batch_size=32, shuffle=False, pin_memory=True + ) + valLoader = DataLoader(dataset, batch_size=32, shuffle=False, pin_memory=True) + + # FBRANCASI: I'm on mac, so mps for me + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("mps" if torch.backends.mps.is_available() else device) + print(f"Using device: {device}") + + originalModel = torchvision.models.vit_b_32( + weights=torchvision.models.ViT_B_32_Weights.IMAGENET1K_V1 + ) + 
originalModel = originalModel.eval().to(device) + print("Original ViT-B/32 loaded.") + + print("Evaluating original model...") + originalTop1, originalTop5 = evaluateModel( + originalModel, valLoader, device, "Original ViT-B/32" + ) + + print("Preparing and quantizing ViT-B/32...") + FQModel = prepareFQVitB32() + + print("Calibrating FQ model...") + calibrateModel(FQModel, calibLoader) + + print("Evaluating FQ model...") + # FBRANCASI: I'm on mac, mps doesn't work with brevitas + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + FQTop1, FQTop5 = evaluateModel(FQModel, valLoader, device, "FQ ViT-B/32") + + sampleInputImg = torch.randn(1, 3, 224, 224).to("cpu") + TQModel = brevitasToTrueQuant(FQModel, sampleInputImg, debug=True) + + numParameters = sum(p.numel() for p in TQModel.parameters()) + print(f"Number of parameters: {numParameters:,}") + + print("Evaluating TQ model...") + TQTop1, TQTop5 = evaluateModel(TQModel, valLoader, device, "TQ ViT-B/32") + + print("\nComparison Summary:") + print(f"{'Model':<25} {'Top-1 Accuracy':<25} {'Top-5 Accuracy':<25}") + print("-" * 75) + print(f"{'Original ViT-B/32':<25} {originalTop1:<24.2f} {originalTop5:<24.2f}") + print(f"{'FQ ViT-B/32':<25} {FQTop1:<24.2f} {FQTop5:<24.2f}") + print(f"{'TQ ViT-B/32':<25} {TQTop1:<24.2f} {TQTop5:<24.2f}") + print( + f"{'FQ Drop':<25} {originalTop1 - FQTop1:<24.2f} {originalTop5 - FQTop5:<24.2f}" + ) + print( + f"{'TQ Drop':<25} {originalTop1 - TQTop1:<24.2f} {originalTop5 - TQTop5:<24.2f}" + ) + + if abs(FQTop1 - TQTop1) > 5.0 or abs(FQTop5 - TQTop5) > 5.0: + print( + f"Warning: Large accuracy drop between FQ and TQ models. " + f"Top-1 difference: {abs(FQTop1 - TQTop1):.2f}%, " + f"Top-5 difference: {abs(FQTop5 - TQTop5):.2f}%" + ) diff --git a/Tests/TestYOLOv5.py b/Tests/TestYOLOv5.py new file mode 100644 index 0000000..7231492 --- /dev/null +++ b/Tests/TestYOLOv5.py @@ -0,0 +1,127 @@ +# Copyright 2025 ETH Zurich and University of Bologna. +# Licensed under the Apache License, Version 2.0, see LICENSE for details. 
+# SPDX-License-Identifier: Apache-2.0 +# +# Federico Brancasi + +import brevitas.nn as qnn +import pytest +import torch +import torch.nn as nn +from brevitas.graph.quantize import preprocess_for_quantize, quantize +from brevitas.quant import ( + Int8ActPerTensorFloat, + Int8WeightPerTensorFloat, + Int32Bias, + Uint8ActPerTensorFloat, +) + +from DeepQuant import brevitasToTrueQuant + + +def prepareYOLOv5Backbone() -> nn.Module: + """Prepare a quantized partial YOLOv5 model for testing.""" + from ultralytics import YOLO + + model = YOLO("Models/yolov5nu.pt") + pytorchModel = model.model + + # FBRANCASI: Just first few layers for simplicity + backbone = pytorchModel.model[0:4] + + computeLayerMap = { + nn.Conv2d: ( + qnn.QuantConv2d, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 4, + }, + ), + nn.Linear: ( + qnn.QuantLinear, + { + "input_quant": Int8ActPerTensorFloat, + "weight_quant": Int8WeightPerTensorFloat, + "output_quant": Int8ActPerTensorFloat, + "bias_quant": Int32Bias, + "bias": True, + "return_quant_tensor": True, + "output_bit_width": 8, + "weight_bit_width": 4, + }, + ), + } + + quantActMap = { + nn.SiLU: ( + qnn.QuantReLU, # FBRANCASI: As a substitute for now + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + nn.ReLU: ( + qnn.QuantReLU, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + nn.LeakyReLU: ( + qnn.QuantReLU, # FBRANCASI: As a substitute for now + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + quantIdentityMap = { + "signed": ( + qnn.QuantIdentity, + { + "act_quant": Int8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + "unsigned": ( + qnn.QuantIdentity, + { + "act_quant": Uint8ActPerTensorFloat, + "return_quant_tensor": True, + "bit_width": 8, + }, + ), + } + + backbone = preprocess_for_quantize( + backbone, equalize_iters=10, equalize_scale_computation="range" + ) + + quantizedModel = quantize( + graph_model=backbone, + compute_layer_map=computeLayerMap, + quant_act_map=quantActMap, + quant_identity_map=quantIdentityMap, + ) + + return quantizedModel + + +@pytest.mark.ModelTests +def deepQuantTestYOLOv5(): + torch.manual_seed(42) + quantizedModel = prepareYOLOv5Backbone() + sampleInput = torch.randn(1, 3, 128, 128) + quantizedModel.eval() + brevitasToTrueQuant(quantizedModel, sampleInput, debug=True) diff --git a/Tests/__init__.py b/Tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/conftest.py b/conftest.py deleted file mode 100644 index 950c05e..0000000 --- a/conftest.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright 2025 ETH Zurich and University of Bologna. -# Licensed under the Apache License, Version 2.0, see LICENSE for details. -# SPDX-License-Identifier: Apache-2.0 -# -# Federico Brancasi - -""" -Pytest configuration file that suppresses specific warnings, including those -related to torch.tensor constant registration in FX tracing. 
-""" - -import warnings - -warnings.filterwarnings("ignore", category=DeprecationWarning) -warnings.filterwarnings("ignore", category=UserWarning, message="Named tensors.*") -warnings.filterwarnings( - "ignore", category=UserWarning, message=".*__torch_function__.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="Was not able to add assertion.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_cuda' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_cudnn' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_mps' is deprecated.*" -) -warnings.filterwarnings( - "ignore", category=UserWarning, message="'has_mkldnn' is deprecated.*" -) diff --git a/pyproject.toml b/pyproject.toml index 0534afe..34841ab 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,13 +33,14 @@ dependencies = [ "onnx", "onnxoptimizer", "onnxruntime", + "ultralytics", ] [tool.setuptools] packages = ["DeepQuant"] [tool.pytest.ini_options] -python_files = ["Tests/*.py"] +python_files = ["Tests/Test*.py"] python_functions = ["deepQuantTest*"] markers = [ "SingleLayerTests: Tests for individual layers",