Skip to content

Commit 995e148

Browse files
authored
feat(metail): Parallelize SDPA across multiple simdgroups (#144)
1 parent 69a0b1c commit 995e148

File tree

6 files changed

+238
-69
lines changed

6 files changed

+238
-69
lines changed

gpt_oss/metal/source/include/internal/math.h

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <assert.h>
34
#include <stddef.h>
45
#include <stdint.h>
56

@@ -19,11 +20,21 @@ inline static size_t math_sub_sat(size_t a, size_t b) {
1920
return a > b ? a - b : 0;
2021
}
2122

22-
static size_t math_round_up_po2(size_t bytes, size_t multiple) {
23+
static size_t math_round_down_po2(size_t number, size_t multiple) {
24+
assert(multiple != 0);
25+
assert((multiple & (multiple - 1)) == 0);
26+
27+
return number & -multiple;
28+
}
29+
30+
static size_t math_round_up_po2(size_t number, size_t multiple) {
31+
assert(multiple != 0);
32+
assert((multiple & (multiple - 1)) == 0);
33+
2334
const size_t multiple_mask = multiple - 1;
24-
if ((bytes & multiple_mask) != 0) {
25-
bytes |= multiple_mask;
26-
bytes += 1;
35+
if ((number & multiple_mask) != 0) {
36+
number |= multiple_mask;
37+
number += 1;
2738
}
28-
return bytes;
39+
return number;
2940
}

gpt_oss/metal/source/include/internal/metal.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
118118
size_t num_threadgroups_z,
119119
size_t params_size,
120120
const void* params,
121-
size_t num_buffers,
122-
const struct gptoss_metal_buffer** buffers,
123-
const size_t* buffer_offsets);
121+
size_t num_device_buffers,
122+
const struct gptoss_metal_buffer** device_buffers,
123+
const size_t* device_buffer_offsets,
124+
size_t threadgroup_buffer_size);
124125

125126
enum gptoss_status gptoss_metal_command_buffer_commit(
126127
const struct gptoss_metal_command_buffer* command_buffer);

gpt_oss/metal/source/include/internal/metal.hpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -246,10 +246,11 @@ class CommandBuffer {
246246
const std::array<size_t, 3>& threadgroup_size,
247247
const std::array<size_t, 3>& num_threadgroups,
248248
size_t params_size, const void* params,
249-
std::initializer_list<const Buffer*> buffers = {})
249+
std::initializer_list<const Buffer*> device_buffers = {},
250+
size_t threadgroup_buffer_size = 0)
250251
{
251-
std::vector<const gptoss_metal_buffer*> buffer_handles(buffers.size());
252-
std::transform(buffers.begin(), buffers.end(), buffer_handles.begin(),
252+
std::vector<const gptoss_metal_buffer*> buffer_handles(device_buffers.size());
253+
std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),
253254
[](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
254255
Check(gptoss_metal_command_buffer_encode_launch_kernel(
255256
&command_buffer_, function.handle(),
@@ -258,7 +259,8 @@ class CommandBuffer {
258259
params_size, params,
259260
buffer_handles.size(),
260261
buffer_handles.data(),
261-
/*buffer_offsets=*/nullptr),
262+
/*buffer_offsets=*/nullptr,
263+
threadgroup_buffer_size),
262264
"gptoss_metal_command_buffer_encode_launch_kernel");
263265
}
264266

gpt_oss/metal/source/metal-kernels.c

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
4646
threadgroup_size, 1, 1,
4747
num_threadgroups, 1, 1,
4848
sizeof(args), &args,
49-
1, &output_buffer, &output_offset);
49+
1, &output_buffer, &output_offset,
50+
/*threadgroup_buffer_size=*/0);
5051
}
5152

5253
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
@@ -93,7 +94,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
9394
threadgroup_size, 1, 1,
9495
num_threadgroups, 1, 1,
9596
sizeof(args), &args,
96-
1, &output_buffer, &output_offset);
97+
1, &output_buffer, &output_offset,
98+
/*threadgroup_buffer_size=*/0);
9799
}
98100

99101
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
@@ -140,7 +142,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
140142
threadgroup_size, 1, 1,
141143
num_threadgroups, 1, 1,
142144
sizeof(args), &args,
143-
1, &output_buffer, &output_offset);
145+
1, &output_buffer, &output_offset,
146+
/*threadgroup_buffer_size=*/0);
144147
}
145148

146149
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
@@ -180,7 +183,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
180183
threadgroup_size, 1, 1,
181184
num_threadgroups, 1, 1,
182185
sizeof(args), &args,
183-
3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL);
186+
3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL,
187+
/*threadgroup_buffer_size=*/0);
184188
}
185189

186190
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
@@ -222,7 +226,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
222226
sizeof(args), &args,
223227
3,
224228
(const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer},
225-
(const size_t[]) {token_offset, weight_offset, output_offset});
229+
(const size_t[]) {token_offset, weight_offset, output_offset},
230+
/*threadgroup_buffer_size=*/0);
226231
}
227232

228233
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
@@ -268,7 +273,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
268273
sizeof(args), &args,
269274
3,
270275
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer},
271-
(const size_t[]) {input_offset, weight_offset, output_offset});
276+
(const size_t[]) {input_offset, weight_offset, output_offset},
277+
/*threadgroup_buffer_size=*/0);
272278
}
273279

274280
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
@@ -325,7 +331,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
325331
sizeof(args), &args,
326332
4,
327333
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
328-
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
334+
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset},
335+
/*threadgroup_buffer_size=*/0);
329336
}
330337

331338
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
@@ -382,7 +389,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
382389
sizeof(args), &args,
383390
4,
384391
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
385-
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
392+
(const size_t[]) {input_offset, weight_offset, bias_offset, output_offset},
393+
/*threadgroup_buffer_size=*/0);
386394
}
387395

388396
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
@@ -437,7 +445,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
437445
sizeof(args), &args,
438446
4,
439447
(const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer},
440-
(const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset});
448+
(const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset},
449+
/*threadgroup_buffer_size=*/0);
441450
}
442451

443452
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
@@ -510,7 +519,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
510519
sizeof(args), &args,
511520
6,
512521
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
513-
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
522+
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
523+
/*threadgroup_buffer_size=*/0);
514524
}
515525

516526
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
@@ -581,7 +591,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
581591
sizeof(args), &args,
582592
6,
583593
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
584-
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
594+
(const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset},
595+
/*threadgroup_buffer_size=*/0);
585596
}
586597

587598
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
@@ -631,7 +642,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
631642
threadgroup_size, 1, 1,
632643
num_qk_heads / num_simdgroups, num_tokens, 1,
633644
sizeof(args), &args,
634-
1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL);
645+
1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL,
646+
/*threadgroup_buffer_size=*/0);
635647
}
636648

637649
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
@@ -680,7 +692,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
680692
sizeof(args), &args,
681693
3,
682694
(const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer},
683-
(const size_t[]) {input_offset, expert_offset, output_offset});
695+
(const size_t[]) {input_offset, expert_offset, output_offset},
696+
/*threadgroup_buffer_size=*/0);
684697
}
685698

686699
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
@@ -715,7 +728,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
715728
sizeof(args), &args,
716729
2,
717730
(const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer},
718-
(const size_t[]) {input_offset, output_offset});
731+
(const size_t[]) {input_offset, output_offset},
732+
/*threadgroup_buffer_size=*/0);
719733
}
720734

721735
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
@@ -753,6 +767,11 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
753767
return gptoss_status_invalid_argument;
754768
}
755769

770+
const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window);
771+
const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads,
772+
max_context_tokens * f32_sdpa_fn->simdgroup_threads);
773+
const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads);
774+
756775
const struct gptoss_sdpa_args args = {
757776
.qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
758777
.num_kv_tokens = num_kv_tokens,
@@ -761,12 +780,13 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
761780

762781
return gptoss_metal_command_buffer_encode_launch_kernel(
763782
command_buffer, f32_sdpa_fn,
764-
/*threadgroup_size=*/32, 1, 1,
783+
threadgroup_size, 1, 1,
765784
num_q_tokens, num_kv_heads, 1,
766785
sizeof(args), &args,
767786
5,
768787
(const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer},
769-
(const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset});
788+
(const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset},
789+
/*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
770790
}
771791

772792
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
@@ -813,5 +833,6 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
813833
sizeof(args), &args,
814834
4,
815835
(const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer},
816-
(const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset});
836+
(const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset},
837+
/*threadgroup_buffer_size=*/0);
817838
}

gpt_oss/metal/source/metal.m

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -380,9 +380,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
380380
size_t num_threadgroups_z,
381381
size_t params_size,
382382
const void* params,
383-
size_t num_buffers,
384-
const struct gptoss_metal_buffer** buffers,
385-
const size_t* buffer_offsets)
383+
size_t num_device_buffers,
384+
const struct gptoss_metal_buffer** device_buffers,
385+
const size_t* device_buffer_offsets,
386+
size_t threadgroup_buffer_size)
386387
{
387388
if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {
388389
return gptoss_status_invalid_state;
@@ -396,11 +397,14 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
396397
// Set kernel arguments
397398
[command_encoder_obj setComputePipelineState:pipeline_state_obj];
398399
[command_encoder_obj setBytes:params length:params_size atIndex:0];
399-
for (size_t i = 0; i < num_buffers; ++i) {
400-
id<MTLBuffer> buffer_obj = (id<MTLBuffer>) buffers[i]->object;
401-
const NSUInteger offset = buffer_offsets == NULL ? 0 : (NSUInteger) buffer_offsets[i];
400+
for (size_t i = 0; i < num_device_buffers; ++i) {
401+
id<MTLBuffer> buffer_obj = (id<MTLBuffer>) device_buffers[i]->object;
402+
const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i];
402403
[command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];
403404
}
405+
if (threadgroup_buffer_size != 0) {
406+
[command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0];
407+
}
404408

405409
// Dispatch kernel
406410
const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);

0 commit comments

Comments
 (0)