Commit 35eb3cc

Metal: fused QKV projection (matmul+RoPE+KV cache init) kernel (#184)

1 parent bbc5c48 commit 35eb3cc

9 files changed: +313 -73 lines changed

gpt_oss/metal/source/context.c

Lines changed: 58 additions & 43 deletions
@@ -253,10 +253,51 @@ static enum gptoss_status process_tokens(
                 GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
                 return status;
             }
+
+            status = gptoss_metal_command_buffer_encode_launch_f32_rope(
+                command_buffer,
+                &model->f32_rope_fn,
+                /*threadgroup_size=*/32,
+                &context->qkv_activation_buffer,
+                /*input_offset=*/0,
+                &context->control_buffer,
+                /*control_offset=*/0,
+                model->rope_theta,
+                model->interpolation_scale,
+                model->yarn_offset,
+                model->yarn_scale,
+                model->yarn_multiplier,
+                input_batch_size,
+                model->num_heads,
+                model->num_kv_heads,
+                model->head_dim,
+                /*token_offset=*/input_batch_start);
+            if (status != gptoss_status_success) {
+                GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
+                return status;
+            }
+
+            for (uint32_t t = 0; t < input_batch_size; t++) {
+                for (uint32_t kv = 0; kv < 2; kv++) {
+                    for (uint32_t h = 0; h < model->num_kv_heads; h++) {
+                        status = gptoss_metal_command_buffer_encode_copy_buffer(
+                            command_buffer,
+                            &context->qkv_activation_buffer,
+                            /*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
+                            &context->kvcache_buffer,
+                            /*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
+                            /*size=*/model->head_dim * sizeof(float));
+                        if (status != gptoss_status_success) {
+                            GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
+                            return status;
+                        }
+                    }
+                }
+            }
         } else {
-            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
+            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
                 command_buffer,
-                &model->f32_bf16w_matmul_fn,
+                &model->f32_bf16w_matmul_qkv_fn,
                 model->attn_qkv_threadgroup_size,
                 &context->rmsnorm_activation_buffer,
                 /*input_offset=*/0,
@@ -266,49 +307,24 @@ static enum gptoss_status process_tokens(
                 /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
                 &context->qkv_activation_buffer,
                 /*output_offset=*/0,
+                &context->kvcache_buffer,
+                /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
                 &context->control_buffer,
                 /*control_offset=*/0,
                 /*num_tokens=*/input_batch_size,
                 /*num_cols=*/model->embedding_dim,
-                /*num_rows=*/attn_qkv_dim);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
-                return status;
-            }
-        }
-        status = gptoss_metal_command_buffer_encode_launch_f32_rope(
-            command_buffer,
-            &model->f32_rope_fn,
-            /*threadgroup_size=*/32,
-            &context->qkv_activation_buffer,
-            /*input_offset=*/0,
-            &context->control_buffer,
-            /*control_offset=*/0,
-            model->rope_theta,
-            model->interpolation_scale,
-            model->yarn_offset,
-            model->yarn_scale,
-            model->yarn_multiplier,
-            input_batch_size,
-            model->num_heads,
-            model->num_kv_heads,
-            model->head_dim,
-            /*token_offset=*/input_batch_start);
-        if (status != gptoss_status_success) {
-            GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
-            return status;
-        }
-
-        for (uint32_t t = 0; t < input_batch_size; t++) {
-            status = gptoss_metal_command_buffer_encode_copy_buffer(
-                command_buffer,
-                &context->qkv_activation_buffer,
-                /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
-                &context->kvcache_buffer,
-                /*output_offset=*/(n * context->max_tokens + input_batch_start + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
-                /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float));
+                /*num_q_heads=*/model->num_heads,
+                /*num_kv_heads=*/model->num_kv_heads,
+                /*attn_head_dim=*/model->head_dim,
+                /*token_offset=*/input_batch_start,
+                /*max_tokens=*/context->max_tokens,
+                /*rope_base=*/model->rope_theta,
+                /*interpolation_scale=*/model->interpolation_scale,
+                /*yarn_offset=*/model->yarn_offset,
+                /*yarn_scale=*/model->yarn_scale,
+                /*yarn_multiplier=*/model->yarn_multiplier);
             if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
+                GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
                 return status;
             }
         }
@@ -320,16 +336,15 @@ static enum gptoss_status process_tokens(
            &context->qkv_activation_buffer,
            /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
            &context->kvcache_buffer,
-            /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
-            &context->kvcache_buffer,
-            /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float),
+            /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
            &model->shared_weight_buffer,
            /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
            &context->sdpa_activation_buffer,
            /*output_offset=*/0,
            &context->control_buffer,
            /*control_offset=*/0,
            /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
+            /*kv_stride=*/2 * context->max_tokens * model->head_dim,
            num_block_output_tokens,
            input_batch_start + input_batch_size - num_block_output_tokens,
            model->num_heads, model->num_kv_heads, model->head_dim);
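
The copy loop in the dense path above spells out the KV cache layout the rest of the commit assumes: per layer, each KV head now owns a contiguous [max_tokens][2][head_dim] slab, whereas the old layout grouped all heads of one token together. A minimal sketch of the implied indexing, using a hypothetical helper name (the repository computes these offsets inline):

#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper, not part of the commit: float offset of the K (kv = 0)
 * or V (kv = 1) vector for a given layer, KV head, and token. It mirrors the
 * output_offset expression in the copy loop above, i.e. a per-layer layout of
 * [num_kv_heads][max_tokens][2][head_dim]. */
static size_t kvcache_offset_floats(
        uint32_t layer, uint32_t kv_head, uint32_t token, uint32_t kv,
        uint32_t max_tokens, uint32_t num_kv_heads, uint32_t head_dim) {
    return ((((size_t) layer * num_kv_heads + kv_head) * max_tokens + token) * 2 + kv) * head_dim;
}

This head-major arrangement is what lets the fused matmul kernel and the updated SDPA launch address one head's whole history with a single base offset plus stride, at the cost of one small copy per head instead of one large copy per token on the dense path.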

gpt_oss/metal/source/include/internal/kernel-args.h

Lines changed: 13 additions & 0 deletions
@@ -39,6 +39,7 @@ struct gptoss_topk_args {
 struct gptoss_sdpa_args {
     uint32_t qkv_dim;
     uint32_t num_kv_tokens;
+    uint32_t kv_stride;
     uint32_t window;
 };

@@ -126,6 +127,18 @@ struct gptoss_rope_args {
     float yarn_multiplier;
 };

+struct gptoss_qkv_args {
+    uint32_t num_column_vecs;
+    uint32_t num_rows;
+    uint32_t token_offset;
+    float freq_scale;
+    float interpolation_scale;
+    float yarn_offset;
+    float yarn_scale;
+    float yarn_multiplier;
+    uint32_t max_tokens;
+};
+
 struct gptoss_softmax_args {
     uint32_t num_vecs;
     uint32_t num_vecs_per_threadgroup;
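
The new gptoss_qkv_args block mirrors gptoss_rope_args and adds the matmul geometry (num_column_vecs, num_rows) plus max_tokens for KV cache addressing. As a hedged sketch of how a host might populate it: the helper name is hypothetical, and the freq_scale formula is an assumption inferred from the kernel below, which uses exp(dim_idx * freq_scale) as the extrapolation frequency, so standard RoPE with base rope_base would give freq_scale = -2 ln(rope_base) / head_dim.

#include <math.h>
#include <stdint.h>

/* Hypothetical helper, not from the commit. Relies on struct gptoss_qkv_args
 * as declared above. */
static struct gptoss_qkv_args make_qkv_args(
        uint32_t embedding_dim, uint32_t num_q_heads, uint32_t num_kv_heads,
        uint32_t head_dim, uint32_t token_offset, uint32_t max_tokens,
        float rope_base, float interpolation_scale,
        float yarn_offset, float yarn_scale, float yarn_multiplier) {
    struct gptoss_qkv_args args;
    args.num_column_vecs = embedding_dim / 4;  /* the kernel reads the input as float4 */
    args.num_rows = (num_q_heads + 2 * num_kv_heads) * head_dim;  /* fused Q + K + V rows */
    args.token_offset = token_offset;
    /* Assumption: exp(d * freq_scale) == rope_base^(-2d / head_dim) for pair index d. */
    args.freq_scale = -2.0f * logf(rope_base) / (float) head_dim;
    args.interpolation_scale = interpolation_scale;
    args.yarn_offset = yarn_offset;
    args.yarn_scale = yarn_scale;
    args.yarn_multiplier = yarn_multiplier;
    args.max_tokens = max_tokens;
    return args;
}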

gpt_oss/metal/source/include/internal/metal-kernels.h

Lines changed: 32 additions & 4 deletions
@@ -112,6 +112,35 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
     uint32_t num_cols,
     uint32_t num_rows);

+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+    const struct gptoss_metal_command_buffer* command_buffer,
+    const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
+    size_t threadgroup_size,
+    const struct gptoss_metal_buffer* input_buffer,
+    size_t input_offset,
+    const struct gptoss_metal_buffer* weight_buffer,
+    size_t weight_offset,
+    const struct gptoss_metal_buffer* bias_buffer,
+    size_t bias_offset,
+    const struct gptoss_metal_buffer* output_buffer,
+    size_t output_offset,
+    const struct gptoss_metal_buffer* kv_buffer,
+    size_t kv_offset,
+    const struct gptoss_metal_buffer* control_buffer,
+    size_t control_offset,
+    uint32_t num_tokens,
+    uint32_t num_cols,
+    uint32_t num_q_heads,
+    uint32_t num_kv_heads,
+    uint32_t attn_head_dim,
+    uint32_t token_offset,
+    uint32_t max_tokens,
+    float rope_base,
+    float interpolation_scale,
+    float yarn_offset,
+    float yarn_scale,
+    float yarn_multiplier);
+
 enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
     const struct gptoss_metal_command_buffer* command_buffer,
     const struct gptoss_metal_function* f32_bf16w_matmul_fn,
@@ -306,17 +335,16 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
     const struct gptoss_metal_function* f32_sdpa_fn,
     const struct gptoss_metal_buffer* q_buffer,
     size_t q_offset,
-    const struct gptoss_metal_buffer* k_buffer,
-    size_t k_offset,
-    const struct gptoss_metal_buffer* v_buffer,
-    size_t v_offset,
+    const struct gptoss_metal_buffer* kv_buffer,
+    size_t kv_offset,
     const struct gptoss_metal_buffer* s_buffer,
     size_t s_offset,
     const struct gptoss_metal_buffer* output_buffer,
     size_t output_offset,
     const struct gptoss_metal_buffer* control_buffer,
     size_t control_offset,
     uint32_t window,
+    uint32_t kv_stride,
     uint32_t num_q_tokens,
     uint32_t num_kv_tokens,
     uint32_t num_q_heads,
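
With K and V fused into one cache, the SDPA launch now takes a single kv_buffer and kv_offset plus a kv_stride, which context.c sets to 2 * context->max_tokens * model->head_dim (in floats), i.e. the size of one head's slab. A hypothetical illustration, not code from the commit, of how per-head K and V pointers fall out of that stride:

#include <stddef.h>
#include <stdint.h>

/* Illustration only: locate the K and V vectors of KV head h, token t inside
 * one layer's region of the fused cache, assuming the head-major layout and
 * kv_stride = 2 * max_tokens * head_dim established in context.c above. */
static const float* k_vector(const float* kv, uint32_t kv_stride,
                             uint32_t h, uint32_t t, uint32_t head_dim) {
    return kv + (size_t) h * kv_stride + (size_t) t * 2 * head_dim;
}

static const float* v_vector(const float* kv, uint32_t kv_stride,
                             uint32_t h, uint32_t t, uint32_t head_dim) {
    return kv + (size_t) h * kv_stride + (size_t) t * 2 * head_dim + head_dim;  /* V follows K */
}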

gpt_oss/metal/source/include/internal/model.h

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ struct gptoss_model {
     struct gptoss_metal_function bf16_f32_embeddings_fn;
     struct gptoss_metal_function f32_bf16w_rmsnorm_fn;
     struct gptoss_metal_function f32_bf16w_matmul_fn;
+    struct gptoss_metal_function f32_bf16w_matmul_qkv_fn;
     struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn;
     struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn;
     struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn;

gpt_oss/metal/source/matmul.metal

Lines changed: 88 additions & 0 deletions
@@ -67,6 +67,94 @@ kernel void gptoss_f32_bf16w_matmul(
     }
 }

+kernel void gptoss_f32_bf16w_matmul_qkv(
+    constant gptoss_qkv_args& args [[ buffer(0) ]],
+    const device float4* input [[ buffer(1) ]],
+    const device bfloat4* weight [[ buffer(2) ]],
+    const device bfloat* bias [[ buffer(3) ]],
+    device float* q [[ buffer(4) ]],
+    device float* kv [[ buffer(5) ]],
+    const device gptoss_control* control [[ buffer(6) ]],
+    threadgroup void* scratch [[ threadgroup(0) ]],
+    uint2 gid [[threadgroup_position_in_grid]],
+    uint simdgroup_tid [[thread_index_in_simdgroup]],
+    uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+    uint num_simdgroups [[simdgroups_per_threadgroup]])
+{
+    const uint simdgroup_size = 32;
+    const uint head_dim = 64;
+    const uint num_q_heads = 64;
+    const uint num_kv_heads = 8;
+    if (control->abort != 0) {
+        return;
+    }
+
+    const uint num_column_vecs = args.num_column_vecs;
+    const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+    input += gid.y * num_column_vecs + simdgroup_tid;
+    weight += num_column_vecs * row + simdgroup_tid;
+    bias += row;
+    q += gid.y * args.num_rows;
+
+    uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+    float4 sum4 = 0.0f;
+    do {
+        const bfloat4 w = *weight;
+        const float4 i = *input;
+        sum4 = metal::fma(static_cast<float4>(w), i, sum4);
+
+        weight += simdgroup_size;
+        input += simdgroup_size;
+    } while (--num_iter != 0);
+    const float2 sum2 = sum4.xy + sum4.zw;
+    float sum = sum2.x + sum2.y;
+    sum = metal::simd_sum(sum);
+    if (metal::simd_is_first()) {
+        sum += static_cast<float>(*bias);
+        static_cast<threadgroup float*>(scratch)[simdgroup_idx] = sum;
+    }
+    metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+    if (simdgroup_idx == 0) {
+        const uint num_half_simdgroups = num_simdgroups / 2;
+        if (simdgroup_tid < num_half_simdgroups) {
+            float2 vals = static_cast<const threadgroup float2*>(scratch)[simdgroup_tid];
+            const uint idx = gid.x * num_half_simdgroups + simdgroup_tid;
+            const uint head_idx = idx / (head_dim / 2);
+            const uint token_idx = args.token_offset + gid.y;
+            const uint dim_idx = idx % (head_dim / 2);
+            if (head_idx < num_q_heads + num_kv_heads) {
+                const float dim_idx_val = static_cast<float>(dim_idx);
+                const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale);
+                const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
+                const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));
+                const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
+
+                const float phi = static_cast<float>(token_idx) * inv_freq;
+                const float yarn_multiplier = args.yarn_multiplier;
+                float cosphi;
+                const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
+                cosphi *= yarn_multiplier;
+
+                vals = (float2) {
+                    vals.x * cosphi - vals.y * sinphi,
+                    vals.x * sinphi + vals.y * cosphi,
+                };
+            }
+            if (head_idx < num_q_heads) {
+                reinterpret_cast<device float2*>(q)[idx] = vals;
+            } else if (head_idx < num_q_heads + num_kv_heads) {
+                const uint h = head_idx - num_q_heads;
+                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim)[dim_idx] = vals;
+            } else {
+                const uint h = head_idx - num_q_heads - num_kv_heads;
+                reinterpret_cast<device float2*>(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;
+            }
+        }
+    }
+}
+
 kernel void gptoss_f32_bf16w_unembedding(
     constant gptoss_unembedding_args& args [[ buffer(0) ]],
     const device float4* input [[ buffer(1) ]],
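
In the kernel above, each simdgroup reduces one output row of the fused QKV matmul; simdgroup 0 then pulls adjacent row pairs out of threadgroup scratch as float2 values and rotates them in place before scattering Q to the activation buffer and K/V directly into the cache. Written out, with d the rotation-pair index within a head, t = args.token_offset + gid.y, and m = args.yarn_multiplier, the YaRN-style rotation it applies is:

\begin{aligned}
\eta_d &= e^{d \cdot \mathrm{freq\_scale}}, \qquad
\alpha_d = \operatorname{clamp}\bigl(d \cdot \mathrm{yarn\_scale} + \mathrm{yarn\_offset},\ 0,\ 1\bigr), \\
\omega_d &= (1 - \alpha_d)\,\eta_d + \alpha_d\,\eta_d \cdot \mathrm{interpolation\_scale}, \qquad
\varphi = t\,\omega_d, \\
\begin{pmatrix} x' \\ y' \end{pmatrix} &=
m \begin{pmatrix} \cos\varphi & -\sin\varphi \\ \sin\varphi & \cos\varphi \end{pmatrix}
\begin{pmatrix} x \\ y \end{pmatrix}.
\end{aligned}

Note the hard-coded simdgroup_size = 32, head_dim = 64, num_q_heads = 64, and num_kv_heads = 8 at the top of the kernel: this fused path is specialized to the gpt-oss attention geometry even though args also carries num_rows, so a different geometry would require generalizing those constants.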
