Skip to content

Commit b2622ba

Browse files
lisa0314shiyi9801
authored andcommitted
[Resample2d] Using default axes as a workaround for OV (chromium#99)
Fix chromium#92
1 parent 4108f49 commit b2622ba

File tree

1 file changed

+26
-20
lines changed

1 file changed

+26
-20
lines changed

services/webnn/ort/graph_builder_ort.cc

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,6 +1331,8 @@ void GraphBuilderOrt::AddResample2dOperation(
13311331
const std::string node_name = GetNodeName(resample2d.label);
13321332
const std::string input_name = GetOperandName(resample2d.input_operand_id);
13331333
const std::string output_name = GetOperandName(resample2d.output_operand_id);
1334+
const std::vector<uint32_t>& input_shape =
1335+
GetOperand(resample2d.input_operand_id).descriptor.shape();
13341336
const std::vector<uint32_t>& output_shape =
13351337
GetOperand(resample2d.output_operand_id).descriptor.shape();
13361338
std::vector<const char*> input_names = {input_name.c_str()};
@@ -1342,34 +1344,40 @@ void GraphBuilderOrt::AddResample2dOperation(
13421344
const std::string roi_name = "";
13431345
input_names.push_back(roi_name.c_str());
13441346

1347+
// When axes != [2, 3], webnn blink side will insert transpose before and
1348+
// after resample2d -
1349+
// https://source.chromium.org/chromium/chromium/src/+/main:third_party/blink/renderer/modules/ml/webnn/ml_graph_type_converter.cc;l=1438.
13451350
CHECK_EQ(resample2d.axes.size(), 2u);
1351+
CHECK_EQ(resample2d.axes[0], 2u);
1352+
CHECK_EQ(resample2d.axes[1], 3u);
1353+
1354+
CHECK_EQ(input_shape.size(), 4u);
13461355
std::string scales_name;
13471356
std::string sizes_name;
1357+
// Here we using default axes([0,..., R-1]) due to this issue-
1358+
// https://github.com/shiyi9801/chromium/issues/92.
13481359
if (resample2d.scales) {
1349-
// The number of elements of scales should be the same as the rank of axes
1350-
// if provided.
1351-
std::array<float, 2> scales_data = {resample2d.scales->at(0),
1360+
// The number of elements of scales should be the same as the rank of input
1361+
// or axes.
1362+
std::array<float, 4> scales_data = {1, 1, resample2d.scales->at(0),
13521363
resample2d.scales->at(1)};
1353-
scales_name = CreateInitializer<float>({2}, scales_data);
1364+
scales_name = CreateInitializer<float>({4}, scales_data);
13541365
sizes_name = "";
13551366
} else {
1356-
// The number of elements of sizes should be the same as the length of axes
1357-
// if provided.
1358-
std::array<int64_t, 2> sizes_data = {
1359-
base::checked_cast<int64_t>(output_shape[resample2d.axes[0]]),
1360-
base::checked_cast<int64_t>(output_shape[resample2d.axes[1]])};
1361-
sizes_name = CreateInitializer<int64_t>({2}, sizes_data);
1367+
// The number of elements of sizes should be the same as the rank of input
1368+
// or axes.
1369+
CHECK_EQ(output_shape.size(), 4u);
1370+
std::array<int64_t, 4> sizes_data = {
1371+
base::checked_cast<int64_t>(output_shape[0]),
1372+
base::checked_cast<int64_t>(output_shape[1]),
1373+
base::checked_cast<int64_t>(output_shape[2]),
1374+
base::checked_cast<int64_t>(output_shape[3])};
1375+
sizes_name = CreateInitializer<int64_t>({4}, sizes_data);
13621376
scales_name = "";
13631377
}
13641378
input_names.push_back(scales_name.c_str());
13651379
input_names.push_back(sizes_name.c_str());
13661380

1367-
std::array<int64_t, 2> axes = {
1368-
base::checked_cast<int64_t>(resample2d.axes[0]),
1369-
base::checked_cast<int64_t>(resample2d.axes[1])};
1370-
ScopedOrtOpAttrPtr attr_axes =
1371-
model_builder_.CreateAttribute(/*name=*/"axes", axes);
1372-
13731381
std::string mode;
13741382
switch (resample2d.mode) {
13751383
case mojom::Resample2d::InterpolationMode::kLinear:
@@ -1379,10 +1387,8 @@ void GraphBuilderOrt::AddResample2dOperation(
13791387
mode = "nearest";
13801388
break;
13811389
}
1382-
ScopedOrtOpAttrPtr attr_mode =
1383-
model_builder_.CreateAttribute(/*name=*/"mode", mode);
1384-
std::array<OrtOpAttr*, 2> attributes = {attr_axes.Release(),
1385-
attr_mode.Release()};
1390+
std::array<OrtOpAttr*, 1> attributes = {
1391+
model_builder_.CreateAttribute(/*name=*/"mode", mode).Release()};
13861392

13871393
std::array<const char*, 1> output_names = {output_name.c_str()};
13881394

0 commit comments

Comments
 (0)