Skip to content

Commit 8b510e5

Browse files
committed
Add matmul
1 parent fe6218b commit 8b510e5

File tree

6 files changed

+80
-0
lines changed

6 files changed

+80
-0
lines changed

third_party/blink/renderer/modules/ml/v2/BUILD.gn

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ blink_modules_sources("v2") {
2424
"ops/conv.h",
2525
"ops/input.cc",
2626
"ops/input.h",
27+
"ops/matmul.cc",
28+
"ops/matmul.h",
2729
"ops/output.cc",
2830
"ops/output.h",
2931
"ops/pooling.cc",

third_party/blink/renderer/modules/ml/v2/nn_context.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "third_party/blink/renderer/modules/ml/v2/ops/constant.h"
1919
#include "third_party/blink/renderer/modules/ml/v2/ops/conv.h"
2020
#include "third_party/blink/renderer/modules/ml/v2/ops/input.h"
21+
#include "third_party/blink/renderer/modules/ml/v2/ops/matmul.h"
2122
#include "third_party/blink/renderer/modules/ml/v2/ops/pooling.h"
2223
#include "third_party/blink/renderer/modules/ml/v2/ops/relu.h"
2324
#include "third_party/blink/renderer/modules/ml/v2/ops/reshape.h"
@@ -232,6 +233,10 @@ Operand* NNContext::relu(Operand* input) {
232233
return MakeGarbageCollected<Relu>(input);
233234
}
234235

236+
Operand* NNContext::matmul(Operand* a, Operand* b) {
237+
return MakeGarbageCollected<MatMul>(a, b);
238+
}
239+
235240
ScriptPromise NNContext::createModel(ScriptState* script_state,
236241
const NamedOperandVector& outputs) {
237242
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver>(script_state);

third_party/blink/renderer/modules/ml/v2/nn_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ class NNContext final : public ScriptWrappable,
6969
Operand* reshape(Operand*, WTF::Vector<int32_t>);
7070
Operand* softmax(Operand*);
7171
Operand* relu(Operand*);
72+
Operand* matmul(Operand*, Operand*);
7273
ScriptPromise createModel(ScriptState*, const NamedOperandVector&);
7374

7475
// ExecutionContextLifecycleObserver overrides.

third_party/blink/renderer/modules/ml/v2/nn_context.idl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ interface NNContext {
2727
Operand reshape(Operand input, sequence<long> newShape);
2828
Operand softmax(Operand input);
2929
Operand relu(Operand input);
30+
Operand matmul(Operand a, Operand b);
3031

3132
// Create Model
3233
[CallWith=ScriptState] Promise<Model> createModel(sequence<NamedOperand> outputs);
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright 2020 The Chromium Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
#include "third_party/blink/renderer/modules/ml/v2/ops/matmul.h"
6+
7+
#include <memory>
8+
9+
#include "third_party/blink/renderer/modules/ml/neural_network_context.h"
10+
11+
namespace blink {
12+
13+
MatMul::MatMul(Operand* a, Operand* b) : Output({a, b}) {}
14+
15+
void MatMul::AddLayer(NNModel* model, uint32_t& index) {
16+
Vector<uint32_t> input_indexes;
17+
// Add input index to input_indexes.
18+
for (auto& input : Output::Inputs()) {
19+
input_indexes.push_back(input->Index());
20+
}
21+
22+
// We can't get the bias size.
23+
uint32_t bias_index = index++;
24+
model->AddUnspecifiedOperand();
25+
input_indexes.push_back(bias_index);
26+
27+
// Add fused code operand and set the value.
28+
uint32_t fuse_index = index++;
29+
model->AddScalarOperand(fuse_index, 0);
30+
input_indexes.push_back(fuse_index);
31+
32+
// There are no MatMul defined in Android NN API, We use kFullyConnected
33+
// instead of MatMul.
34+
uint32_t matmul_index = index++;
35+
model->AddScalarOperand(matmul_index, 0);
36+
input_indexes.push_back(matmul_index);
37+
38+
// Add MatMul output operand.
39+
uint32_t output_index = index++;
40+
Operand::SetIndex(output_index);
41+
model->AddUnspecifiedOperand();
42+
43+
model->AddOperation(NeuralNetworkContext::kFullyConnected, input_indexes,
44+
{output_index});
45+
}
46+
47+
} // namespace blink
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// Copyright 2020 The Chromium Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_MATMUL_H_
6+
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_MATMUL_H_
7+
8+
#include "third_party/blink/renderer/modules/ml/v2/nn_model.h"
9+
#include "third_party/blink/renderer/modules/ml/v2/operand.h"
10+
#include "third_party/blink/renderer/modules/ml/v2/ops/output.h"
11+
12+
namespace blink {
13+
14+
class MatMul final : public Output {
15+
public:
16+
MatMul(Operand*, Operand*);
17+
~MatMul() override = default;
18+
19+
void AddLayer(NNModel* model, uint32_t& index) override;
20+
};
21+
22+
} // namespace blink
23+
24+
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_MATMUL_H_

0 commit comments

Comments
 (0)