Skip to content

Commit 5a0692d

Browse files
committed
fix quantizedLinear layer feeds into grapg output
1 parent ea13a05 commit 5a0692d

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

onnxruntime/core/providers/openvino/qdq_transformations/qdq_stripping.cc

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,7 @@ static bool CheckQFeedsIntoQuantizedOutput(const NodeUnit& node_unit,
346346
auto op_of_quantized_layer = node_unit.Outputs();
347347
for (auto& itr : op_of_quantized_layer) {
348348
auto it = graph_op_data_type.find(itr.node_arg.Name());
349-
if (it != graph_op_data_type.end() && it->second == "tensor(uint8)") {
349+
if (it != graph_op_data_type.end() && (it->second == "tensor(uint8)" || it->second == "tensor(uint16)")) {
350350
return true;
351351
}
352352
}
@@ -369,6 +369,11 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
369369
graph_op_data_type[src_graph.GetNodeArg(ops->Name())->Name()] = ops->Type()->data();
370370
}
371371

372+
// check If any quantized node feeds into the src graph output
373+
if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) {
374+
return true;
375+
}
376+
372377
// If UInt16 Q, don't keep it
373378
if (GetQDQDataType(q_node) == DT_UINT16 || GetQDQDataType(q_node) == DT_INT16) {
374379
reason = SkipReason::Int16QDQ;
@@ -381,9 +386,7 @@ static bool CheckQRuleSet(const NodeUnit& node_unit,
381386
} else if (op_type == "Add") {
382387
// Add keeps all Qs
383388
return true;
384-
} else if (CheckQFeedsIntoQuantizedOutput(node_unit, std::move(graph_op_data_type))) {
385-
return true;
386-
} else {
389+
} else {
387390
// Keep Q of an unsupported Op only if the target that succeeds it is a supported Op in this list
388391
return IsNextTargetNodeOfQValid(q_node, &target_node, src_graph, {"Conv", "Add", "MatMul"}, false);
389392
}

0 commit comments

Comments
 (0)