Commit f12a89e
[WebNN EP] Support GroupQueryAttention (GQA) (microsoft#23416)
### Description
Adds support for GroupQueryAttention via WebNN matmul, transpose,
reshape, and other operations that follow the logic in the GQA subgraph
below.
```
Abbreviations: B is batch_size, S is sequence_length, W is hidden_size, P is past_sequence_length,
N is number of attention heads, H is head size (W = N*H), h = Sqrt(H) is the scale, G is group size.
GQA inputs: query, key, value, past_key, past_value, seqlens_k, total_sequence_length
Note: if the data type of the inputs (qkv and past kv) is float16, we cast them to float32 to preserve precision.
  query              key               value
    |                 |                  |
 Reshape           Reshape            Reshape (B,S,N,H)         seqlens_k
    |                 |                  |        \              /   |
    |                 |   past_value     |     (scatter_indices*)    |
q_Transpose           |         \        |        /                  |
 (0,2,1,3)            |  past_key    ScatterND ----------------------|-----> present_value
    |                 |   /              |                           |
    |   present_key <-ScatterND      Expand(G)    (attention_bias, one/finfo_min mask*)
    |                 |                  |                /
    |             Expand(G)              |               /
    |                 |                  |              /
    |            k_Transpose             |             /
    |             (0,1,3,2)              |            /
    |                 |                  |           /
    +-----------------+------------------+----------+
    |          ScaledDotProductAttention            |
    +-----------------------------------------------+
                         |
                       output
```
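The Expand(G) steps above repeat each key/value head so that the G query heads of a group share one cached KV head. A minimal numpy sketch of that step (function name and shapes are illustrative, not the EP code):

```python
import numpy as np

def expand_kv(x, group_size):
    """Repeat each KV head group_size times: (B, kv_N, S, H) -> (B, kv_N*G, S, H)."""
    b, kv_n, s, h = x.shape
    # Insert a group axis, broadcast it to G copies, then fold it into the head axis.
    x = np.broadcast_to(x[:, :, None], (b, kv_n, group_size, s, h))
    return x.reshape(b, kv_n * group_size, s, h)
```

With kv_num_heads = 2 and G = 2, heads come out ordered [kv0, kv0, kv1, kv1], so each query-head group lines up with its shared KV head.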
The ScaledDotProductAttention logic is:
```
ScaledDotProductAttention Subgraph: The basis for MultiHeadAttention and GroupQueryAttention
inputs: query, key, value, scale, attention mask, and reshape_output_shape (for reshape)
Abbreviations: B is batch_size, S is query sequence_length, kv_S is key/value sequence length,
N is number of attention heads, H is head size, W is hidden_size
    query           key
       |             |
       +---matmul----+
             |
            div  <--- scale
             |
            add  <--- attn_mask
             |
          softmax
             |
           matmul <--- value
             |
        transpose (0,2,1,3)   B,N,S,H -> B,S,N,H
             |
          reshape             B,S,N,H -> B,S,W
             |
           output
```
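A numpy sketch of this subgraph as a reading aid (names and signature are illustrative, not the EP code; the softmax between the add and the final matmul is standard attention semantics):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sdpa(query, key_t, value, scale, attn_mask, out_shape):
    """query: (B,N,S,H); key_t: (B,N,H,kv_S), i.e. already k_Transpose(0,1,3,2);
    value: (B,N,kv_S,H); attn_mask broadcastable to (B,N,S,kv_S)."""
    scores = np.matmul(query, key_t) / scale   # matmul + div (scale = sqrt(H))
    scores = scores + attn_mask                # add attention bias
    ctx = np.matmul(softmax(scores), value)    # matmul with value -> (B,N,S,H)
    ctx = ctx.transpose(0, 2, 1, 3)            # (0,2,1,3): B,N,S,H -> B,S,N,H
    return ctx.reshape(out_shape)              # B,S,N,H -> B,S,W
```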
The scatter_indices calculation:
```
                                                                 if_prefill (0/1 constant)
                                                                           |
scatter_indices_left_constant   scatter_indices_right_constant   0 ---> Where <--- Cast <--- seqlens_k
              |                               |                            |
              |                              Add <--------------------- scatter_pos*
              |                               |
              +---------------+---------------+
                              |
                       scatter_indices
```
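One way to realize this index construction in numpy (a hypothetical helper: it assumes the left constant carries the (batch, head) coordinates, the right constant the in-sequence position, and scatter_pos shifts writes to the end of the cache during decode):

```python
import numpy as np

def build_scatter_indices(batch_size, num_heads, seq_len, seqlens_k, if_prefill):
    """Indices for ScatterND into the (B, N, P, H) KV cache."""
    b, n, s = np.meshgrid(np.arange(batch_size), np.arange(num_heads),
                          np.arange(seq_len), indexing="ij")
    # scatter_pos*: 0 when prefilling, else the per-batch past length (Where + Cast)
    scatter_pos = np.where(if_prefill, 0, seqlens_k).reshape(-1, 1, 1)
    right = s + scatter_pos                      # Add
    return np.stack([b, n, right], axis=-1)      # join left and right index parts
```

During prefill the new tokens land at positions [0, S); during decode the single new token lands at position seqlens_k.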
The attention_bias calculation:
```
ones_array (shape=B,N,S,P)                   range_of_qkv_sequence_length_constant (0,1,2,...) (shape=S)
      |                                                    |
CumSum (axis=3, exclusive=true, reversed=false)           Add <--- scatter_pos
      |                                                    |
      |                                              Expand (shape=P,S)
      |                                                    |
      +---------------------> Lesser <--------------- Transpose (1,0)
                                |
                   1 ---> Where <--- finfo_min (minimum value of FP32)
                                |
                         attention_bias
```
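A single-slice numpy sketch of this mask construction (shapes reduced from (B,N,S,P) to (S,P) for clarity; exact offsets in the EP may differ):

```python
import numpy as np

def build_attention_bias(S, P, scatter_pos):
    """Causal-style bias: 1 where a key position is visible, finfo_min elsewhere."""
    finfo_min = np.finfo(np.float32).min
    ones = np.ones((S, P), dtype=np.float32)
    key_pos = np.cumsum(ones, axis=1) - ones        # exclusive CumSum -> 0,1,...,P-1
    query_pos = np.arange(S) + scatter_pos          # Add scatter_pos
    limit = np.broadcast_to(query_pos, (P, S)).T    # Expand (P,S) then Transpose (1,0)
    cond = key_pos < limit                          # Lesser
    return np.where(cond, 1.0, finfo_min)           # Where(cond, 1, finfo_min)
```

Each query row i admits key positions strictly below i + scatter_pos, so later cache slots are masked to the FP32 minimum.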
*Note: currently only `past_sequence_length == total_sequence_length` is supported for GQA.*
### Motivation and Context
File tree: 20 files changed, +718 additions, −37 deletions
- js/web/docs
- onnxruntime/core/providers/webnn/builders
  - impl
|---|---|---|---|
| |||
49 | 49 | | |
50 | 50 | | |
51 | 51 | | |
| 52 | + | |
52 | 53 | | |
53 | 54 | | |
54 | 55 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
121 | 121 | | |
122 | 122 | | |
123 | 123 | | |
124 | | - | |
125 | | - | |
126 | | - | |
127 | | - | |
128 | | - | |
| 124 | + | |
| 125 | + | |
| 126 | + | |
| 127 | + | |
| 128 | + | |
129 | 129 | | |
130 | | - | |
131 | | - | |
132 | | - | |
| 130 | + | |
| 131 | + | |
| 132 | + | |
133 | 133 | | |
134 | 134 | | |
135 | 135 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
199 | 199 | | |
200 | 200 | | |
201 | 201 | | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
202 | 205 | | |
203 | 206 | | |
204 | 207 | | |
| |||
361 | 364 | | |
362 | 365 | | |
363 | 366 | | |
364 | | - | |
365 | | - | |
366 | | - | |
| 367 | + | |
| 368 | + | |
| 369 | + | |
367 | 370 | | |
368 | 371 | | |
369 | 372 | | |
| |||
Lines changed: 72 additions & 0 deletions
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
| 1 | + | |
| 2 | + | |
| 3 | + | |
| 4 | + | |
| 5 | + | |
| 6 | + | |
| 7 | + | |
| 8 | + | |
| 9 | + | |
| 10 | + | |
| 11 | + | |
| 12 | + | |
| 13 | + | |
| 14 | + | |
| 15 | + | |
| 16 | + | |
| 17 | + | |
| 18 | + | |
| 19 | + | |
| 20 | + | |
| 21 | + | |
| 22 | + | |
| 23 | + | |
| 24 | + | |
| 25 | + | |
| 26 | + | |
| 27 | + | |
| 28 | + | |
| 29 | + | |
| 30 | + | |
| 31 | + | |
| 32 | + | |
| 33 | + | |
| 34 | + | |
| 35 | + | |
| 36 | + | |
| 37 | + | |
| 38 | + | |
| 39 | + | |
| 40 | + | |
| 41 | + | |
| 42 | + | |
| 43 | + | |
| 44 | + | |
| 45 | + | |
| 46 | + | |
| 47 | + | |
| 48 | + | |
| 49 | + | |
| 50 | + | |
| 51 | + | |
| 52 | + | |
| 53 | + | |
| 54 | + | |
| 55 | + | |
| 56 | + | |
| 57 | + | |
| 58 | + | |
| 59 | + | |
| 60 | + | |
| 61 | + | |
| 62 | + | |
| 63 | + | |
| 64 | + | |
| 65 | + | |
| 66 | + | |
| 67 | + | |
| 68 | + | |
| 69 | + | |
| 70 | + | |
| 71 | + | |
| 72 | + | |
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
69 | 69 | | |
70 | 70 | | |
71 | 71 | | |
72 | | - | |
| 72 | + | |
73 | 73 | | |
74 | 74 | | |
75 | 75 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
70 | 70 | | |
71 | 71 | | |
72 | 72 | | |
73 | | - | |
| 73 | + | |
74 | 74 | | |
75 | 75 | | |
76 | 76 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
406 | 406 | | |
407 | 407 | | |
408 | 408 | | |
409 | | - | |
| 409 | + | |
410 | 410 | | |
411 | 411 | | |
412 | 412 | | |
| |||
Lines changed: 1 addition & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
237 | 237 | | |
238 | 238 | | |
239 | 239 | | |
240 | | - | |
| 240 | + | |
241 | 241 | | |
242 | 242 | | |
243 | 243 | | |
| |||
0 commit comments