@@ -1331,6 +1331,8 @@ void GraphBuilderOrt::AddResample2dOperation(
13311331 const std::string node_name = GetNodeName (resample2d.label );
13321332 const std::string input_name = GetOperandName (resample2d.input_operand_id );
13331333 const std::string output_name = GetOperandName (resample2d.output_operand_id );
1334+ const std::vector<uint32_t >& input_shape =
1335+ GetOperand (resample2d.input_operand_id ).descriptor .shape ();
13341336 const std::vector<uint32_t >& output_shape =
13351337 GetOperand (resample2d.output_operand_id ).descriptor .shape ();
13361338 std::vector<const char *> input_names = {input_name.c_str ()};
@@ -1342,34 +1344,40 @@ void GraphBuilderOrt::AddResample2dOperation(
13421344 const std::string roi_name = " " ;
13431345 input_names.push_back (roi_name.c_str ());
13441346
1347+ // When axes != [2, 3], webnn blink side will insert transpose before and
1348+ // after resample2d -
1349+ // https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc;l=1438.
13451350 CHECK_EQ (resample2d.axes .size (), 2u );
1351+ CHECK_EQ (resample2d.axes [0 ], 2u );
1352+ CHECK_EQ (resample2d.axes [1 ], 3u );
1353+
1354+ CHECK_EQ (input_shape.size (), 4u );
13461355 std::string scales_name;
13471356 std::string sizes_name;
1357+ // Here we using default axes([0,..., R-1]) due to this issue-
1358+ // https://github.com/shiyi9801/chromium/issues/92.
13481359 if (resample2d.scales ) {
1349- // The number of elements of scales should be the same as the rank of axes
1350- // if provided .
1351- std::array<float , 2 > scales_data = {resample2d.scales ->at (0 ),
1360+ // The number of elements of scales should be the same as the rank of input
1361+ // or axes .
1362+ std::array<float , 4 > scales_data = {1 , 1 , resample2d.scales ->at (0 ),
13521363 resample2d.scales ->at (1 )};
1353- scales_name = CreateInitializer<float >({2 }, scales_data);
1364+ scales_name = CreateInitializer<float >({4 }, scales_data);
13541365 sizes_name = " " ;
13551366 } else {
1356- // The number of elements of sizes should be the same as the length of axes
1357- // if provided.
1358- std::array<int64_t , 2 > sizes_data = {
1359- base::checked_cast<int64_t >(output_shape[resample2d.axes [0 ]]),
1360- base::checked_cast<int64_t >(output_shape[resample2d.axes [1 ]])};
1361- sizes_name = CreateInitializer<int64_t >({2 }, sizes_data);
1367+ // The number of elements of sizes should be the same as the rank of input
1368+ // or axes.
1369+ CHECK_EQ (output_shape.size (), 4u );
1370+ std::array<int64_t , 4 > sizes_data = {
1371+ base::checked_cast<int64_t >(output_shape[0 ]),
1372+ base::checked_cast<int64_t >(output_shape[1 ]),
1373+ base::checked_cast<int64_t >(output_shape[2 ]),
1374+ base::checked_cast<int64_t >(output_shape[3 ])};
1375+ sizes_name = CreateInitializer<int64_t >({4 }, sizes_data);
13621376 scales_name = " " ;
13631377 }
13641378 input_names.push_back (scales_name.c_str ());
13651379 input_names.push_back (sizes_name.c_str ());
13661380
1367- std::array<int64_t , 2 > axes = {
1368- base::checked_cast<int64_t >(resample2d.axes [0 ]),
1369- base::checked_cast<int64_t >(resample2d.axes [1 ])};
1370- ScopedOrtOpAttrPtr attr_axes =
1371- model_builder_.CreateAttribute (/* name=*/ " axes" , axes);
1372-
13731381 std::string mode;
13741382 switch (resample2d.mode ) {
13751383 case mojom::Resample2d::InterpolationMode::kLinear :
@@ -1379,10 +1387,8 @@ void GraphBuilderOrt::AddResample2dOperation(
13791387 mode = " nearest" ;
13801388 break ;
13811389 }
1382- ScopedOrtOpAttrPtr attr_mode =
1383- model_builder_.CreateAttribute (/* name=*/ " mode" , mode);
1384- std::array<OrtOpAttr*, 2 > attributes = {attr_axes.Release (),
1385- attr_mode.Release ()};
1390+ std::array<OrtOpAttr*, 1 > attributes = {
1391+ model_builder_.CreateAttribute (/* name=*/ " mode" , mode).Release ()};
13861392
13871393 std::array<const char *, 1 > output_names = {output_name.c_str ()};
13881394
0 commit comments