@@ -955,60 +955,5 @@ Status Transform(const GraphViewer& src_graph_viewer,
955955 return status;
956956}
957957} // namespace qdq_scales_fix
958-
959- namespace bfloat16_fix {
960- void replace_bf16_with_fp16 (qdq_scales_fix::CustomGraph& gen_graph) {
961- for (auto & const_node : gen_graph.original_graph .Nodes ()) {
962- auto node = const_cast <ONNX_NAMESPACE::Node*>(const_node);
963- if (node->OpType () == " Cast" ) {
964- for (auto & [name, const_attribute] : node->GetAttributes ()) {
965- auto & attribute = const_cast <ONNX_NAMESPACE::AttributeProto&>(const_attribute);
966- if (name == " to" && attribute.type () == ONNX_NAMESPACE::AttributeProto_AttributeType_INT)
967- if (attribute.i () == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
968- attribute.set_i (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
969- }
970- }
971- for (auto & output : node->OutputDefs ()) {
972- auto & output_proto = const_cast <ONNX_NAMESPACE::TypeProto&>(output->ToProto ().type ());
973- if (output_proto.mutable_tensor_type ()->elem_type () == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16)
974- output_proto.mutable_tensor_type ()->set_elem_type (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
975- }
976- }
977-
978- for (auto & node : gen_graph.original_graph .Nodes ()) {
979- for (auto & input_def : node->InputDefs ()) {
980- ORT_THROW_IF_ERROR (graph_utils::ConvertInMemoryDataToInline (gen_graph.original_graph , input_def->Name ()));
981- }
982- }
983-
984- const auto & init_set = gen_graph.original_graph .GetAllInitializedTensors ();
985- for (auto & [key, const_tensor_proto] : init_set) {
986- auto tensor_proto = const_cast <ONNX_NAMESPACE::TensorProto*>(const_tensor_proto);
987- auto dt = tensor_proto->data_type ();
988- if (dt == ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16) {
989- auto raw_data = tensor_proto->has_raw_data () ? reinterpret_cast <std::uint16_t *>(tensor_proto->mutable_raw_data ()->data ()) : nullptr ;
990- if (raw_data) {
991- tensor_proto->set_data_type (ONNX_NAMESPACE::TensorProto_DataType_FLOAT16);
992- std::int64_t size = 1 ;
993- for (int i = 0 ; i < tensor_proto->dims_size (); ++i)
994- size *= tensor_proto->dims ()[i];
995- for (std::int64_t i = 0 ; i < size; ++i) {
996- raw_data[i] = onnxruntime::MLFloat16 (onnxruntime::BFloat16::FromBits (raw_data[i])).val ;
997- }
998- }
999- }
1000- }
1001- }
1002-
1003- Status Transform (const GraphViewer& src_graph_viewer,
1004- const logging::Logger& logger,
1005- /* out*/ std::unique_ptr<onnxruntime::Model>& model) {
1006- auto status = qdq_scales_fix::copy_model (src_graph_viewer, logger, model);
1007- auto g = qdq_scales_fix::generate_graph_from_onnx (model->MainGraph ());
1008-
1009- replace_bf16_with_fp16 (g);
1010- return status;
1011- }
1012- } // namespace bfloat16_fix
1013958} // namespace openvino_ep
1014959} // namespace onnxruntime
0 commit comments