Skip to content

Commit dff3f4d

Browse files
committed
Add transpose
1 parent 8b510e5 commit dff3f4d

File tree

10 files changed

+96
-25
lines changed

10 files changed

+96
-25
lines changed

services/ml/compilation_impl_nn.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,11 @@ CompilationImplNN::CompilationImplNN(const ModelImplNN* model,
3333
}
3434

3535
CompilationImplNN::~CompilationImplNN() {
36-
// ANeuralNetworksCompilation_free(nn_compilation_);
37-
// The nn_compilation_ will be deleted in execution phase.
36+
#if defined(OS_ANDROID)
37+
ANeuralNetworksCompilation_free(nn_compilation_);
38+
#else
39+
IE(ie_compilation_free)(ie_compilation_);
40+
#endif
3841
}
3942

4043
void CompilationImplNN::Finish(int32_t preference, FinishCallback callback) {

services/ml/execution_impl_nn.cc

Lines changed: 8 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -14,23 +14,17 @@
1414

1515
namespace ml {
1616

17-
// TODO:: CompilationImplNN* => std::unique<CompilationImplNN> so that
18-
// ie_compilation_free(ie_compilation_); can host in class CompilationImplNN.
1917
ExecutionImplNN::ExecutionImplNN(const CompilationImplNN* compilation,
2018
mojo::ScopedSharedBufferHandle memory)
2119
: operands_(compilation->operands_),
2220
operations_(compilation->operations_),
2321
inputs_(compilation->inputs_),
2422
outputs_(compilation->outputs_),
2523
memory_(std::move(memory)),
26-
#if defined(OS_ANDROID)
27-
nn_compilation_(compilation->nn_compilation_) {
28-
#else
29-
ie_compilation_(compilation->ie_compilation_) {
30-
#endif
24+
compilation_impl_(compilation) {
3125
#if defined(OS_LINUX) || defined(OS_WIN)
3226
// Create Execution
33-
IE(ie_execution_create)(ie_compilation_, &ie_execution_);
27+
IE(ie_execution_create)(compilation_impl_->ie_compilation_, &ie_execution_);
3428
#endif
3529
uint32_t total_length = 0;
3630
inputs_info_.reserve(inputs_.size());
@@ -54,9 +48,7 @@ ExecutionImplNN::ExecutionImplNN(const CompilationImplNN* compilation,
5448

5549
ExecutionImplNN::~ExecutionImplNN() {
5650
#if defined(OS_ANDROID)
57-
ANeuralNetworksCompilation_free(nn_compilation_);
5851
#else
59-
IE(ie_compilation_free)(ie_compilation_);
6052
IE(ie_execution_free)(ie_execution_);
6153
#endif
6254
DLOG(INFO) << "ANeuralNetworksCompilation_free";
@@ -91,8 +83,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
9183
int32_t result = 0;
9284
#if defined(OS_ANDROID)
9385
ANeuralNetworksExecution* nn_execution;
94-
result =
95-
ANeuralNetworksExecution_create(nn_compilation_, &nn_execution);
86+
result = ANeuralNetworksExecution_create(compilation_impl_->nn_compilation_,
87+
&nn_execution);
9688
#endif
9789
for (size_t i = 0; i < inputs_info_.size(); ++i) {
9890
std::unique_ptr<OperandInfo>& info = inputs_info_[i];
@@ -101,8 +93,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
10193
nn_execution, i, NULL, static_cast<void*>(info->mapping.get()),
10294
info->length);
10395
#else
104-
result = IE(ie_execution_set_input)(ie_execution_, i,
105-
info->mapping.get(), info->length);
96+
result = IE(ie_execution_set_input)(ie_execution_, i, info->mapping.get(),
97+
info->length);
10698
#endif
10799
}
108100

@@ -113,8 +105,8 @@ void ExecutionImplNN::StartCompute(mojom::UserBufferPtr user_buffer,
113105
nn_execution, i, NULL, static_cast<void*>(info->mapping.get()),
114106
info->length);
115107
#else
116-
result = IE(ie_execution_set_output)(
117-
ie_execution_, i, info->mapping.get(), info->length);
108+
result = IE(ie_execution_set_output)(ie_execution_, i, info->mapping.get(),
109+
info->length);
118110
#endif
119111
}
120112

services/ml/execution_impl_nn.h

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <vector>
99

1010
#include "base/macros.h"
11+
#include "base/memory/scoped_refptr.h"
1112
#include "services/ml/common.h"
1213
#include "services/ml/compilation_impl_nn.h"
1314
#include "services/ml/model_impl_nn.h"
@@ -28,8 +29,7 @@ namespace ml {
2829

2930
class ExecutionImplNN : public mojom::Execution {
3031
public:
31-
ExecutionImplNN(const CompilationImplNN*,
32-
mojo::ScopedSharedBufferHandle);
32+
ExecutionImplNN(const CompilationImplNN*, mojo::ScopedSharedBufferHandle);
3333
~ExecutionImplNN() override;
3434

3535
void StartCompute(mojom::UserBufferPtr user_buffer,
@@ -46,13 +46,11 @@ class ExecutionImplNN : public mojom::Execution {
4646
std::vector<std::unique_ptr<OperandInfo>> inputs_info_;
4747
std::vector<std::unique_ptr<OperandInfo>> outputs_info_;
4848
mojo::ScopedSharedBufferHandle memory_;
49-
49+
const CompilationImplNN* compilation_impl_;
5050
#if defined(OS_LINUX) || defined(OS_WIN)
51-
ie_compilation_t* ie_compilation_;
5251
ie_execution_t* ie_execution_;
53-
#else
54-
ANeuralNetworksCompilation* nn_compilation_;
5552
#endif
53+
5654
DISALLOW_COPY_AND_ASSIGN(ExecutionImplNN);
5755
};
5856

third_party/blink/renderer/modules/ml/v2/BUILD.gn

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ blink_modules_sources("v2") {
3636
"ops/reshape.h",
3737
"ops/softmax.cc",
3838
"ops/softmax.h",
39+
"ops/transpose.cc",
40+
"ops/transpose.h",
3941
]
4042

4143
public_deps = [

third_party/blink/renderer/modules/ml/v2/nn_compilation.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void NNCompilation::OnCreateExecution(
7070

7171
if (result_code == ml::mojom::blink::NOT_ERROR) {
7272
resolver->Resolve(MakeGarbageCollected<NNExecution>(
73-
std::move(init_params), std::move(name_index_)));
73+
std::move(init_params), name_index_));
7474
} else {
7575
resolver->Reject(MakeGarbageCollected<DOMException>(
7676
DOMExceptionCode::kInvalidStateError,

third_party/blink/renderer/modules/ml/v2/nn_context.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include "third_party/blink/renderer/modules/ml/v2/ops/relu.h"
2424
#include "third_party/blink/renderer/modules/ml/v2/ops/reshape.h"
2525
#include "third_party/blink/renderer/modules/ml/v2/ops/softmax.h"
26+
#include "third_party/blink/renderer/modules/ml/v2/ops/transpose.h"
2627
#include "third_party/blink/renderer/platform/bindings/exception_code.h"
2728

2829
namespace blink {
@@ -237,6 +238,10 @@ Operand* NNContext::matmul(Operand* a, Operand* b) {
237238
return MakeGarbageCollected<MatMul>(a, b);
238239
}
239240

241+
Operand* NNContext::transpose(Operand* input, WTF::Vector<int32_t> new_shape) {
242+
return MakeGarbageCollected<Transpose>(input, std::move(new_shape));
243+
}
244+
240245
ScriptPromise NNContext::createModel(ScriptState* script_state,
241246
const NamedOperandVector& outputs) {
242247
auto* resolver = MakeGarbageCollected<ScriptPromiseResolver>(script_state);

third_party/blink/renderer/modules/ml/v2/nn_context.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class NNContext final : public ScriptWrappable,
7171
Operand* relu(Operand*);
7272
Operand* matmul(Operand*, Operand*);
7373
ScriptPromise createModel(ScriptState*, const NamedOperandVector&);
74+
Operand* transpose(Operand*, WTF::Vector<int32_t>);
7475

7576
// ExecutionContextLifecycleObserver overrides.
7677
void ContextDestroyed() override;

third_party/blink/renderer/modules/ml/v2/nn_context.idl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ interface NNContext {
2828
Operand softmax(Operand input);
2929
Operand relu(Operand input);
3030
Operand matmul(Operand a, Operand b);
31+
Operand transpose(Operand input, optional sequence<long> permutation=[]);
3132

3233
// Create Model
3334
[CallWith=ScriptState] Promise<Model> createModel(sequence<NamedOperand> outputs);
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
// Copyright 2020 The Chromium Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
#include "third_party/blink/renderer/modules/ml/v2/ops/transpose.h"
6+
7+
#include <memory>
8+
9+
#include "third_party/blink/renderer/modules/ml/neural_network_context.h"
10+
11+
namespace blink {
12+
13+
Transpose::Transpose(Operand* input, WTF::Vector<int32_t> permutation)
14+
: Output({input}), permutation_(permutation) {}
15+
16+
void Transpose::AddLayer(NNModel* model, uint32_t& index) {
17+
Vector<uint32_t> input_indexes;
18+
// Add input index to input_indexes.
19+
for (auto& input : Output::Inputs()) {
20+
input_indexes.push_back(input->Index());
21+
}
22+
23+
// Add permutation operand and set the value.
24+
if (!permutation_.IsEmpty()) {
25+
uint32_t permutation_index = index++;
26+
// The new shape is 1-D tensor.
27+
Vector<uint32_t> permutation_dims(1, permutation_.size());
28+
model->AddTensorOperand(permutation_index, permutation_dims, permutation_);
29+
input_indexes.push_back(permutation_index);
30+
}
31+
32+
// Add Reshape output operand.
33+
uint32_t output_index = index++;
34+
Operand::SetIndex(output_index);
35+
model->AddUnspecifiedOperand();
36+
37+
model->AddOperation(NeuralNetworkContext::kTranspose, input_indexes,
38+
{output_index});
39+
}
40+
41+
} // namespace blink
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
// Copyright 2020 The Chromium Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style license that can be
3+
// found in the LICENSE file.
4+
5+
#ifndef THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_TRANSPOSE_H_
6+
#define THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_TRANSPOSE_H_
7+
8+
#include "third_party/blink/renderer/modules/ml/v2/nn_model.h"
9+
#include "third_party/blink/renderer/modules/ml/v2/operand.h"
10+
#include "third_party/blink/renderer/modules/ml/v2/ops/output.h"
11+
#include "third_party/blink/renderer/platform/wtf/vector.h"
12+
13+
namespace blink {
14+
15+
class Transpose final : public Output {
16+
public:
17+
Transpose(Operand*, WTF::Vector<int32_t>);
18+
~Transpose() override = default;
19+
20+
void AddLayer(NNModel* model, uint32_t& index) override;
21+
22+
private:
23+
Vector<int32_t> permutation_;
24+
};
25+
26+
} // namespace blink
27+
28+
#endif // THIRD_PARTY_BLINK_RENDERER_MODULES_ML_OPS_TRANSPOSE_H_

0 commit comments

Comments
 (0)