@@ -743,10 +743,10 @@ def _add_quantized_conv_matmul_add_ops(
     weight_quantize_node: NodeProto,
     input_quantize_params: QuantizationParams,
     weight_quantize_params: QuantizationParams,
-    bias_initializer: onnx.TensorProto,
-    bias_add_name: str,
     target_output: str,
     transpose_weight: bool,
+    bias_add_name: str,
+    bias_initializer: Optional[onnx.TensorProto] = None,
     output_quantize_node: Optional[NodeProto] = None,
     output_dequantize_node: Optional[NodeProto] = None,
 ):
@@ -806,65 +806,62 @@ def _add_quantized_conv_matmul_add_ops(
     )
     model.graph.node.append(integer_op_node)

+    output_scale = input_quantize_params.scale * weight_quantize_params.scale
+    output_scale_name = "{}_output.scale".format(node.name)
+    model.graph.initializer.append(
+        numpy_helper.from_array(numpy.asarray(output_scale), name=output_scale_name)
+    )
+
+    last_output = integer_op_output
+
     # Add bias + zero point correction
     # quantize bias
-    bias_initializer = numpy_helper.to_array(bias_initializer)
-    bias_scale = input_quantize_params.scale * weight_quantize_params.scale
-    bias_zero_point = 0
-    quantized_bias = _quantize_array(
-        bias_initializer, bias_scale, bias_zero_point, dtype=numpy.int32
-    )
-    if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
-        # reshape for bias add broadcasting
-        quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+    if bias_initializer is not None:
+        bias_initializer = numpy_helper.to_array(bias_initializer)

-    quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
-    quantized_bias_initializer = numpy_helper.from_array(
-        quantized_bias, name=quantized_bias_name
-    )
-    model.graph.initializer.append(quantized_bias_initializer)
-    quantized_bias_scale_name = "{}.scale".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_scale), name=quantized_bias_scale_name
+        bias_zero_point = 0
+        quantized_bias = _quantize_array(
+            bias_initializer, output_scale, bias_zero_point, dtype=numpy.int32
         )
-    )
-    quantized_bias_zero_point_name = "{}.zero_point".format(quantized_bias_name)
-    model.graph.initializer.append(
-        numpy_helper.from_array(
-            numpy.asarray(bias_zero_point, dtype=numpy.uint8),
-            name=quantized_bias_zero_point_name,
+        if node.op_type == "Conv" and len(quantized_bias.shape) == 1:
+            # reshape for bias add broadcasting
+            quantized_bias = quantized_bias.reshape(1, quantized_bias.shape[0], 1, 1)
+
+        quantized_bias_name = "{}.bias_quantized".format(bias_add_name)
+        quantized_bias_initializer = numpy_helper.from_array(
+            quantized_bias, name=quantized_bias_name
         )
-    )
+        model.graph.initializer.append(quantized_bias_initializer)

-    # get INT32 Add inputs and outputs
-    quant_add_inputs = [
-        integer_op_output,  # MatMul/Conv integer outputs (INT32)
-        quantized_bias_name,  # Quantized bias (INT32)
-    ]
+        # get INT32 Add inputs and outputs
+        quant_add_inputs = [
+            last_output,  # MatMul/Conv integer outputs (INT32)
+            quantized_bias_name,  # Quantized bias (INT32)
+        ]

-    quant_add_name = "{}_bias_add_quant".format(node.name)
-    quant_add_output = (
-        output_quantize_node.output[0]
-        if output_quantize_node
-        else f"{quant_add_name}_output"
-    )
+        quant_add_name = "{}_bias_add_quant".format(node.name)
+        quant_add_output = (
+            output_quantize_node.output[0]
+            if output_quantize_node
+            else f"{quant_add_name}_output"
+        )

-    # create Add node and add it to graph
-    qadd_node = onnx.helper.make_node(
-        "Add",
-        quant_add_inputs,
-        [quant_add_output],
-        quant_add_name,
-    )
-    model.graph.node.append(qadd_node)
+        # create Add node and add it to graph
+        qadd_node = onnx.helper.make_node(
+            "Add",
+            quant_add_inputs,
+            [quant_add_output],
+            quant_add_name,
+        )
+        model.graph.node.append(qadd_node)
+        last_output = quant_add_output

     # create Cast node and add it to graph
-    cast_node_name = "{}_cast".format(quant_add_name)
-    cast_node_output = "{}_cast".format(quant_add_output)
+    cast_node_name = "{}_cast".format(node.name)
+    cast_node_output = "{}_output".format(cast_node_name)
     cast_node = onnx.helper.make_node(
         "Cast",
-        [quant_add_output],
+        [last_output],
         [cast_node_output],
         cast_node_name,
         to=getattr(onnx.TensorProto, "FLOAT"),  # get Float32 enum id
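The `_quantize_array` helper used above is defined outside this diff. A plausible sketch of what it does, assuming a plain affine quantization to the requested dtype (the real helper may also clip to the dtype's range; `_quantize_array_sketch` is a hypothetical name):

```python
import numpy

# Hedged sketch of the _quantize_array helper used above: affine-quantize a
# float array to an integer dtype given a scale and zero point.
def _quantize_array_sketch(array, scale, zero_point, dtype=numpy.uint8):
    quantized = numpy.round(array / scale) + zero_point
    return quantized.astype(dtype)

bias = numpy.array([0.3, -0.1], dtype=numpy.float32)
print(_quantize_array_sketch(bias, scale=0.001, zero_point=0, dtype=numpy.int32))
# -> [ 300 -100]
```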
@@ -874,9 +871,9 @@ def _add_quantized_conv_matmul_add_ops(
     # create Mul node for rescale
     mul_node_inputs = [
         cast_node_output,  # a
-        quantized_bias_scale_name,  # b -> rescale factor
+        output_scale_name,  # b -> rescale factor
     ]
-    mul_node_name = "{}_rescale_mul".format(quant_add_name)
+    mul_node_name = "{}_rescale_mul".format(cast_node_name)
     mul_node = onnx.helper.make_node(
         "Mul",
         mul_node_inputs,
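The Cast + Mul pair dequantizes the INT32 accumulator: ConvInteger/MatMulInteger outputs are in units of `input_scale * weight_scale`, which is also why the bias is quantized with that combined scale before the INT32 Add. A toy numpy check of this identity (zero points taken as 0 for brevity; in the real graph they are passed to the integer op):

```python
import numpy

# Toy check: integer matmul + INT32 bias add, then Cast + Mul rescale,
# reproduces the float result. Zero points are 0 here for brevity.
input_scale, weight_scale = 0.02, 0.05
output_scale = input_scale * weight_scale  # scale of the INT32 accumulator

x = numpy.array([[0.12, -0.04]], dtype=numpy.float32)
w = numpy.array([[0.25], [-0.10]], dtype=numpy.float32)
b = numpy.array([0.3], dtype=numpy.float32)

qx = numpy.round(x / input_scale).astype(numpy.int32)
qw = numpy.round(w / weight_scale).astype(numpy.int32)
qb = numpy.round(b / output_scale).astype(numpy.int32)  # bias at scale sx*sw

y_int32 = qx @ qw + qb                            # integer op output + INT32 Add
y = y_int32.astype(numpy.float32) * output_scale  # Cast (INT32->FP32) + Mul
print(y, x @ w + b)  # both ~[[0.334]]
```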
@@ -979,10 +976,10 @@ def _convert_quantizable_gemm_no_activations(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(gemm_node.name),
             target_output=gemm_node.output[0],
             transpose_weight=transpose_weight,
+            bias_add_name="{}_bias_add".format(gemm_node.name),
+            bias_initializer=bias_initializer,
         )

         # Cleanup
@@ -1108,14 +1105,14 @@ def _convert_quantizable_matmul_and_add(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name=bias_add_node.name,
             target_output=(
                 output_dequantize_node.output[0]
                 if output_dequantize_node
                 else bias_add_node.output[0]
             ),
             transpose_weight=True,
+            bias_add_name=bias_add_node.name,
+            bias_initializer=bias_initializer,
             output_quantize_node=output_quantize_node,
             output_dequantize_node=output_dequantize_node,
         )
@@ -1164,7 +1161,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
11641161 | | |
11651162 | DequantizeLinear |
11661163 | | |
1167- | Conv (with bias)
1164+ | Conv (with optional bias)
11681165 | |
11691166 | OUTPUT
11701167 | We end up converting to:
@@ -1174,7 +1171,7 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     |         |
     |     ConvInteger (with constant uint8 kernel)
     |         |
-    |     Add (constant bias + zero point correction)
+    |     Add (optional, constant bias + zero point correction)
     |         |
     |     Cast (INT32 -> FP32)
     |         |
@@ -1187,10 +1184,10 @@ def _convert_quantizable_conv_integer(model: ModelProto):
     conv_nodes = [n for n in model.graph.node if n.op_type in ["Conv"]]
     orig_conv_weight_name_to_node_ids = defaultdict(list)
     for conv_node in conv_nodes:
-        if len(conv_node.input) != 3:
-            # this function currently only converts Conv nodes with bias param
-            # (i.e. from folded batch norm value)
-            continue

         graph = ONNXGraph(model)

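The `len(conv_node.input)` test used below works because ONNX models a Conv bias as an optional third input. A small runnable check of that convention:

```python
from onnx import helper

# A Conv node's optional bias is simply a third input name, so checking
# len(node.input) distinguishes biased from bias-free Conv nodes.
conv_with_bias = helper.make_node("Conv", ["x", "w", "b"], ["y"], "conv_b")
conv_no_bias = helper.make_node("Conv", ["x", "w"], ["y2"], "conv_nb")
print(len(conv_with_bias.input))  # 3 -> has bias
print(len(conv_no_bias.input))    # 2 -> bias-free
```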
@@ -1226,12 +1223,15 @@ def _convert_quantizable_conv_integer(model: ModelProto):
         if input_quantize_node.op_type != "DequantizeLinear":
             continue

-        bias_initializer = graph.get_init_by_name(conv_node.input[2])
-        if bias_initializer is None:
-            _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}")
-            continue
+        if len(conv_node.input) == 3:
+            bias_initializer = graph.get_init_by_name(conv_node.input[2])
+            if bias_initializer is None:
+                # bias input exists but is not a constant initializer; skip so
+                # the dynamic bias is not silently dropped
+                _LOGGER.debug(f"Unable to find bias initializer: {conv_node.input[2]}")
+                continue
+        else:
+            bias_initializer = None

-        _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")
+        if bias_initializer is None:
+            _LOGGER.debug(f"Matched quantizable Conv weight: {conv_node.name}")
+        else:
+            _LOGGER.debug(f"Matched quantizable Conv weight and bias: {conv_node.name}")

         # Conversion
         _add_quantized_conv_matmul_add_ops(
@@ -1241,10 +1241,10 @@ def _convert_quantizable_conv_integer(model: ModelProto):
             weight_quantize_node=weight_quantize_node,
             input_quantize_params=input_quantize_params,
             weight_quantize_params=weight_quantize_params,
-            bias_initializer=bias_initializer,
-            bias_add_name="{}_bias_add".format(conv_node.name),
             target_output=conv_node.output[0],
             transpose_weight=False,
+            bias_add_name="{}_bias_add".format(conv_node.name),
+            bias_initializer=bias_initializer,
         )
         orig_conv_weight_name_to_node_ids[input_quantize_node.input[0]].append(
             "{}_quant".format(conv_node.output[0])