
Commit f12a89e

[WebNN EP] Support GroupQueryAttention(GQA) (microsoft#23416)
### Description

Adds support for GroupQueryAttention via WebNN matmul, transpose, reshape, and other operations that follow the logic in the GQA subgraph below.

```
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length,
               N is number of attention heads, H is head size, W=N*H, h=Sqrt(H), G is group size.
GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length
Notes: If the data type of the inputs (qkv and past kv) is float16, we cast them to float32
       to ensure data precision.

         query             key               value
           |                |                  |
        Reshape          Reshape            Reshape     (B,S,H,N)
           |                |                  |
     q_Transpose            |     past_key    |      past_value      seqlens_k
      (0,2,1,3)             |        |        |          |               |
           |                |        |        |          |       (scatter_indices*)
           |                |        |        |          |          /         \
           |       present_key <- ScatterND   +--> ScatterND --> present_value
           |                |                          |
           |            Expand(G)                  Expand(G)
           |                |                          |
           |           k_Transpose                     |
           |            (0,1,3,2)                      |
           |                |                          |
           +----------------+--------------------------+   (attention_bias, one/finfo_min mask*)
                            |
          +---------------------------------------+
          |       ScaledDotProductAttention       |
          +---------------------------------------+
                            |
                          output
```

The ScaledDotProductAttention logic is:

```
ScaledDotProductAttention Subgraph: the basis for MultiHeadAttention and GroupQueryAttention
inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)
Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
               N is number of attention heads, H is head size, W is hidden_size

      query    key
        |       |
        +matmul-+   scale
            |         |
            +---div---+   attn_mask
                 |           |
                 +----add----+   value
                       |           |
                       +--matmul---+
                            |
                  (0,2,1,3) transpose   B,H,S,N -> B,S,H,N
                            |
                         Reshape        B,S,H,N -> B,S,W
                            |
                          output
```

scatter_indices's calculation:

```
                            if_prefill (0/1 constant)
                                      |
                  0 ---> Where <--- Cast <--- seqlens_k
                                      |
                                scatter_pos*
                                      |
  scatter_indices_left_constant      Add <--- scatter_indices_right_constant
                |                     |
                +----------+----------+
                           |
                    scatter_indices
```

attention_bias's calculation:

```
 ones_array (shape=B,N,S,P)    range_of_qkv_sequence_length_constant (0,1,2,...) (shape=S)
          |                                     |
  CumSum (axis=3, exclusive=true,              Add <--- scatter_pos
          reversed=false)                       |
          |                                Expand (shape=P,S)
          |                                     |
          +--------------> Lesser <-------- Transpose (1,0)
                             |
             1 ---> Where <--- finfo_min (minimum value of FP32)
                             |
                      attention_bias
```

*Note: currently we only support `past_sequence_length == total_sequence_length` for GQA.*

### Motivation and Context
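The scatter_pos selection and attention_bias construction above can be sketched in plain Python. This is an illustrative reading of the subgraphs, not the EP code: the helper names are hypothetical, the exact off-by-one of the causal cutoff and the additive 0/finfo_min form (the subgraph itself draws a one/finfo_min mask) are assumptions.

```python
FINFO_MIN = -3.4028234663852886e+38  # lowest finite float32, stands in for finfo_min


def scatter_pos(if_prefill: bool, seqlens_k: int) -> int:
    # Where(if_prefill, 0, Cast(seqlens_k)): during prefill the new kv is
    # written starting at position 0; during decode it is appended after
    # the past sequence. (Hypothetical helper name.)
    return 0 if if_prefill else seqlens_k


def attention_bias(S: int, P: int, pos: int) -> list[list[float]]:
    """Sketch of the attention_bias subgraph for one (B, N) slice.

    The exclusive CumSum over an all-ones array yields the key index j in
    each column; Lesser keeps the causally visible keys, and Where maps
    visible -> 0.0, hidden -> FINFO_MIN so the bias can be added to the
    attention scores before softmax. The "+ 1" cutoff is an assumption
    about how seqlens_k is counted.
    """
    bias = []
    for i in range(S):            # query index (the range constant, Add pos)
        allowed = i + pos + 1
        bias.append([0.0 if j < allowed else FINFO_MIN for j in range(P)])
    return bias
```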
1 parent c5d1416
20 files changed: +718 −37 lines

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -49,6 +49,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s
 | GlobalLpPool| ai.onnx(7+) | l2Pool2d | Only supports 4-D input, 'p' value is 2 |
 | Greater | ai.onnx(7-8, 9-12, 13+) | greater | |
 | GreaterOrEqual | ai.onnx(12-15, 16+) | greaterOrEqual | |
+| GroupQueryAttention | com.microsoft(1+) | add, cast, concat, constant, cumulativeSum, div, expand, lesser, matmul, reshape, scatterND, softmax, transpose, where | Only supports a constant total_sequence_length input; past_sequence_length of past kv must equal present_sequence_length of present kv. Does not support cos_cache and sin_cache inputs |
 | GRU | ai.onnx(7-13, 14-21, 22+) | gru | Only supports 'layout' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
 | HardSigmoid | ai.onnx(7+) | hardSigmoid | |
 | HardSwish | ai.onnx(14+) | hardSwish | |
```

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 8 additions & 8 deletions
```diff
@@ -121,15 +121,15 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
   return supported_nodes;
 }

-bool AreInputDataTypesSame(const std::string_view op_type,
-                           gsl::span<const int32_t> input_types,
-                           const logging::Logger& logger) {
-  for (size_t i = 1; i < input_types.size(); i++) {
-    if (input_types[0] != input_types[i]) {
+bool AreDataTypesSame(const std::string_view op_type,
+                      gsl::span<const int32_t> data_types,
+                      const logging::Logger& logger) {
+  for (size_t i = 1; i < data_types.size(); i++) {
+    if (data_types[0] != data_types[i]) {
       LOGS(logger, VERBOSE) << "[" << op_type
-                            << "] Input data types should be the same, but ["
-                            << input_types[0] << "] does not match "
-                            << input_types[i] << "].";
+                            << "] data types should be the same, but ["
+                            << data_types[0] << "] does not match "
+                            << data_types[i] << "].";
       return false;
     }
   }
```
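The renamed helper's contract is easy to pin down with a small Python analog. This is a sketch for illustration only; the real helper logs through `LOGS(logger, VERBOSE)` rather than printing.

```python
def are_data_types_same(op_type: str, data_types: list[int]) -> bool:
    """Return True iff every entry matches the first; on a mismatch,
    report it (the C++ helper logs at VERBOSE level) and reject the op."""
    for i in range(1, len(data_types)):
        if data_types[0] != data_types[i]:
            print(f"[{op_type}] data types should be the same, but "
                  f"[{data_types[0]}] does not match [{data_types[i]}].")
            return False
    return True
```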

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 6 additions & 3 deletions
```diff
@@ -199,6 +199,9 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe

 // Some ONNX ops are supported by decomposed WebNN ops.
 const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
+    {"GroupQueryAttention",
+     {"add", "cast", "concat", "constant", "cumulativeSum", "div", "expand", "lesser", "matmul", "reshape", "scatterND",
+      "softmax", "transpose", "where"}},
     {"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
     {"MatMulNBits", {"add", "dequantizeLinear", "matmul", "reshape", "transpose"}},
     {"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
@@ -361,9 +364,9 @@ const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_w
     {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
 };

-bool AreInputDataTypesSame(const std::string_view op_type,
-                           gsl::span<const int32_t> input_types,
-                           const logging::Logger& logger);
+bool AreDataTypesSame(const std::string_view op_type,
+                      gsl::span<const int32_t> input_types,
+                      const logging::Logger& logger);
 bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
 bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
                              const int32_t onnx_data_type,
```
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
```diff
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Copyright (c) Intel Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+namespace onnxruntime {
+namespace webnn {
+/*
+ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention
+inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)
+Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
+               N is number of attention heads, H is head size, W is hidden_size
+
+      query    key
+        |       |
+        +matmul-+   scale
+            |         |
+            +---div---+   attn_mask
+                 |           |
+                 +----add----+   value
+                       |           |
+                       +--matmul---+
+                            |
+                  (0,2,1,3) transpose   B,H,S,N -> B,S,H,N
+                            |
+                         Reshape        B,S,H,N -> B,S,W
+                            |
+                          output
+*/
+emscripten::val ScaledDotProductAttention(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger,
+                                          emscripten::val query, emscripten::val key, emscripten::val value,
+                                          emscripten::val scale, emscripten::val attn_mask,
+                                          std::vector<uint32_t> reshape_output_shape) {
+  emscripten::val common_options = emscripten::val::object();
+  // B,H,S,N * B,H,kv_S,N = B,H,S,kv_S
+  common_options.set("label", node.Name() + "_/Attention/qkv/matmul_1");
+  emscripten::val matmul_output =
+      model_builder.GetBuilder().call<emscripten::val>("matmul", query, key, common_options);
+
+  common_options.set("label", node.Name() + "_/Attention/qkv/div");
+  emscripten::val div_output =
+      model_builder.GetBuilder().call<emscripten::val>("mul", matmul_output, scale, common_options);
+
+  emscripten::val softmax_input = div_output;
+  if (attn_mask != emscripten::val::undefined()) {
+    common_options.set("label", node.Name() + "_/Attention/attn_mask/softmax_input");
+    softmax_input = model_builder.GetBuilder().call<emscripten::val>("add", div_output, attn_mask, common_options);
+  }
+
+  common_options.set("label", node.Name() + "_/Attention/attn_mask/softmax_input");
+  int32_t softmax_axis = 3;
+  emscripten::val softmax_output =
+      model_builder.GetBuilder().call<emscripten::val>("softmax", softmax_input, softmax_axis, common_options);
+
+  // B,H,S,kv_S * B,H,kv_S,N = B,H,S,N
+  common_options.set("label", node.Name() + "_/Attention/qkv/matmul_2");
+  emscripten::val attn_output =
+      model_builder.GetBuilder().call<emscripten::val>("matmul", softmax_output, value, common_options);
+
+  emscripten::val options = emscripten::val::object();
+  options.set("permutation", emscripten::val::array(std::vector<uint32_t>({0, 2, 1, 3})));
+  options.set("label", node.Name() + "_/Attention/qkv/transpose");
+  attn_output = model_builder.GetBuilder().call<emscripten::val>("transpose", attn_output, options);
+
+  common_options.set("label", node.Name() + "_/Attention/qkv/reshape");
+  attn_output = model_builder.GetBuilder().call<emscripten::val>(
+      "reshape", attn_output, emscripten::val::array(reshape_output_shape), common_options);
+
+  return attn_output;
+}
+
+}  // namespace webnn
+}  // namespace onnxruntime
```
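The decomposition in this helper can be mirrored numerically. The sketch below is a plain-Python single-head version under assumed small shapes: `query` is (S, H), `key_t` is passed pre-transposed as (H, kv_S), `value` is (kv_S, H), and `attn_mask` is an optional additive (S, kv_S) bias. Note the committed helper applies the scale via a `"mul"` call (the diagram's `div` presumably assumes `scale` already holds the reciprocal), so this sketch multiplies too.

```python
import math


def sdpa_single_head(query, key_t, value, scale, attn_mask=None):
    """Plain-Python sketch of the matmul -> scale -> (+mask) -> softmax
    -> matmul chain in the helper above, for one attention head."""
    S, H = len(query), len(query[0])
    kv_S = len(key_t[0])
    out = []
    for i in range(S):
        # scores = (q @ k^T) * scale, mirroring the "mul" call in the helper
        scores = [sum(query[i][h] * key_t[h][j] for h in range(H)) * scale
                  for j in range(kv_S)]
        if attn_mask is not None:
            scores = [s + m for s, m in zip(scores, attn_mask[i])]
        # numerically stable softmax over the key axis
        mx = max(scores)
        exps = [math.exp(s - mx) for s in scores]
        total = sum(exps)
        probs = [e / total for e in exps]
        # weighted sum of value rows
        out.append([sum(probs[j] * value[j][h] for j in range(kv_S))
                    for h in range(H)])
    return out
```

With zero keys the attention weights are uniform, so the output is the mean of the value rows; adding a large-negative mask entry excludes that key, which is exactly how the attention_bias feeds the `add` before softmax.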

onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -69,7 +69,7 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod
     return false;

   std::array<int32_t, 2> input_types{input0_type, input1_type};
-  if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+  if (!AreDataTypesSame(op_type, input_types, logger)) {
     return false;
   }
```

onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -70,7 +70,7 @@ bool ConcatOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& nod
     }

     std::array<int32_t, 2> input_types{input0_type, input_type};
-    if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+    if (!AreDataTypesSame(op_type, input_types, logger)) {
       return false;
     }
   }
```

onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -406,7 +406,7 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
   if (has_input3) {
     input_types.push_back(input3_type);
   }
-  if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+  if (!AreDataTypesSame(op_type, input_types, logger)) {
     return false;
   }
```

onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc

Lines changed: 1 addition & 1 deletion
```diff
@@ -237,7 +237,7 @@ bool GemmOpBuilder::HasSupportedInputsImpl(const GraphViewer&, const Node& node,
   if (has_input3) {
     input_types.push_back(input3_type);
   }
-  if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+  if (!AreDataTypesSame(op_type, input_types, logger)) {
     return false;
   }
```
243243
