@@ -202,6 +202,7 @@ def __init__(
         op_types_to_quantize: tuple[str, ...] | None = None,
         quant_axes: tuple[tuple[str, int], ...] | None = None,
         bits: int = 4,
+        channel_wised_quantize: bool = False,
     ):
         """
         This is a class for weight only affine quantization configuration.
@@ -236,6 +237,9 @@ def __init__(
         self.is_symmetric = is_symmetric
         self.bits = bits
         self.accuracy_level = accuracy_level
+        self.channel_wised_quantize = channel_wised_quantize
+        if channel_wised_quantize and quant_format == QuantFormat.QOperator:
+            raise NotImplementedError("channel_wised_quantize is not yet supported with QuantFormat.QOperator")


 class NVAWQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
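A minimal usage sketch of the new flag, assuming the public import paths below (the module may be `matmul_nbits_quantizer` in newer releases); `DefaultWeightOnlyQuantConfig` is the config class this hunk modifies:

```python
# hedged sketch: per-channel (block_size = K) int4 QDQ quantization config
from onnxruntime.quantization import QuantFormat
from onnxruntime.quantization.matmul_4bits_quantizer import DefaultWeightOnlyQuantConfig  # path assumed

config = DefaultWeightOnlyQuantConfig(
    quant_format=QuantFormat.QDQ,   # QOperator + channel_wised_quantize raises NotImplementedError
    is_symmetric=True,
    bits=4,
    channel_wised_quantize=True,    # block size becomes K, i.e. one scale per output channel
)
```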
@@ -734,6 +738,26 @@ def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, Gr
     return None, None


+# transpose an int4 matrix packed two values per uint8 (low nibble = even index)
+def transpose_packed_int4_matrix(packed, rows, cols):
+    # unpack to an int4 matrix
+    total = rows * cols
+    high = (packed >> 4) & 0x0F
+    low = packed & 0x0F
+    int4_vals = np.empty(total, dtype=np.uint8)
+    int4_vals[0::2] = low
+    int4_vals[1::2] = high
+    int4_matrix = int4_vals.reshape((rows, cols))
+
+    # transpose the int4 matrix
+    int4_matrix_transposed = int4_matrix.T
+
+    # repack pairs of nibbles into uint8
+    flat = int4_matrix_transposed.reshape(-1)
+    packed = ((flat[1::2] << 4) & 0xF0) | (flat[0::2] & 0x0F)
+    return packed.astype(np.uint8)
+
+
 class DefaultWeightOnlyQuantizer:
     def __init__(self, config: DefaultWeightOnlyQuantConfig):
         self.config = config
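A quick round-trip check of the packing convention used by `transpose_packed_int4_matrix` (a sketch only; it assumes the helper above is in scope and that `rows * cols` is even, as it is for the packed weights in this file):

```python
import numpy as np

def _unpack_int4(packed, rows, cols):
    # mirror of the unpack step above: low nibble holds the even-indexed element
    vals = np.empty(rows * cols, dtype=np.uint8)
    vals[0::2] = packed & 0x0F
    vals[1::2] = (packed >> 4) & 0x0F
    return vals.reshape(rows, cols)

rows, cols = 4, 6
m = np.random.randint(0, 16, size=rows * cols, dtype=np.uint8)
packed = ((m[1::2] << 4) & 0xF0) | (m[0::2] & 0x0F)          # same packing as the quantizer
t = transpose_packed_int4_matrix(packed, rows, cols)
assert np.array_equal(_unpack_int4(t, cols, rows), m.reshape(rows, cols).T)
```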
@@ -770,6 +794,10 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
+            # block size equals rows (K) when channel-wise quantization is enabled
+            block_size = rows if self.config.channel_wised_quantize else self.config.block_size
+            k_blocks = (rows + block_size - 1) // block_size
+
             assert qbits == 4, "QDQ format only support 4 bits quantization"
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
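To make the sizing concrete, a small worked example of the channel-wise branch above (the shapes are illustrative, not taken from any particular model):

```python
# illustrative numbers for the channel-wise (block_size = K) path
rows, cols = 4096, 11008                            # K x N weight
block_size = rows                                   # channel_wised_quantize=True
k_blocks = (rows + block_size - 1) // block_size    # == 1: one block, one scale per output column
packed_bytes = (rows * cols + 1) // 2               # two int4 weights per byte
zp_bytes = (cols * k_blocks + 1) // 2               # one packed 4-bit zero point per column
```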
@@ -812,6 +840,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
             )
             scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")

+        # If QDQ, channel-wise, and symmetric quantization are all enabled, optimize for Intel NPU: transposing the weight to NHWC layout improves performance
+        qdq_opt_for_intel_npu_enabled = self.config.quant_format == QuantFormat.QDQ \
+            and self.config.channel_wised_quantize and self.config.is_symmetric
+        if qdq_opt_for_intel_npu_enabled:
+            rows, cols = b_ndarray.shape
+            packed = transpose_packed_int4_matrix(packed, rows, cols)
+            scales = scales.reshape((cols, 1))  # (cols, 1)
+            b_quant = onnx.helper.make_tensor(b_tensor.name + f"_DQ_Q{bits}", qtype, [cols, rows], packed.tobytes(), True)
+            scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_DQ_scales")
+
         for input in b_graph.input:
             if input.name == input_b:
                 b_graph.input.remove(input)
@@ -849,15 +887,21 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
         else:
             dq_input_names = [b_quant.name, scales_tensor.name]
             dq_output_names = [b_quant.name + "_output"]
-            matmul_input_names = [node.input[0], dq_output_names[0]]
+            tp_input_names = [dq_output_names[0]]
+            tp_output_names = [dq_output_names[0] + "_transposed"]
+            matmul_input_names = [node.input[0], tp_output_names[0] if qdq_opt_for_intel_npu_enabled else dq_output_names[0]]
             matmul_output_names = [node.output[0]]
             if not self.config.is_symmetric:
                 zp_tensor = onnx.helper.make_tensor(
                     b_tensor.name + "_DQ_zero_points", qtype, scales.shape, zero_points.tobytes(), True
                 )
                 dq_input_names.append(zp_tensor.name)
                 b_graph.initializer.extend([zp_tensor])
-            dq_kwargs = {"axis": 0, "block_size": self.config.block_size}
+            rows, cols = b_ndarray.shape
+            dq_kwargs = {
+                "axis": 1 if qdq_opt_for_intel_npu_enabled else 0,
+                "block_size": rows if self.config.channel_wised_quantize else self.config.block_size
+            }
             dq_node = onnx.helper.make_node(
                 "DequantizeLinear",
                 inputs=dq_input_names,
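For intuition, the blocked `DequantizeLinear` configured here amounts to the following numpy computation on the Intel NPU layout (a sketch, not the operator implementation; shapes and values are illustrative): the stored quantized weight is (N, K), scales are (N, 1), and with `axis=1`, `block_size=K` every row, i.e. every output channel, shares one scale.

```python
import numpy as np

# sketch of per-channel symmetric dequantization on the transposed (N, K) layout
N, K = 8, 16
q = np.random.randint(-8, 8, size=(N, K)).astype(np.float32)   # int4 values (stored packed in the real graph)
scales = np.random.rand(N, 1).astype(np.float32)                # one scale per output channel
w_nk = q * scales                                               # DequantizeLinear(axis=1, block_size=K), symmetric
w_kn = w_nk.T                                                   # Transpose(perm=[1, 0]) restores (K, N) for MatMul
```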
@@ -871,7 +915,16 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
                 outputs=matmul_output_names,
                 name=node.name + f"_matmul_Q{bits}" if node.name else "",
             )
-            output_nodes.extend([dq_node, matmul_node])
+            if qdq_opt_for_intel_npu_enabled:
+                tp_node = onnx.helper.make_node(
+                    "Transpose",
+                    inputs=tp_input_names,
+                    outputs=tp_output_names,
+                    perm=[1, 0],
+                )
+                output_nodes.extend([dq_node, tp_node, matmul_node])
+            else:
+                output_nodes.extend([dq_node, matmul_node])

         return output_nodes

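With all hunks applied, the QDQ + channel-wise + symmetric path emits a `DequantizeLinear -> Transpose(perm=[1,0]) -> MatMul` chain per quantized weight instead of the plain `DequantizeLinear -> MatMul`. A small inspection snippet to confirm that pattern on a quantized model (the file name is an assumption for illustration):

```python
import onnx
from collections import Counter

model = onnx.load("model_int4_qdq.onnx")   # hypothetical output of the quantizer
ops = Counter(n.op_type for n in model.graph.node)
# with qdq_opt_for_intel_npu_enabled, each quantized MatMul contributes one of each
print(ops["DequantizeLinear"], ops["Transpose"], ops["MatMul"])
```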