@@ -46,7 +46,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
4646 threadgroup_size , 1 , 1 ,
4747 num_threadgroups , 1 , 1 ,
4848 sizeof (args ), & args ,
49- 1 , & output_buffer , & output_offset );
49+ 1 , & output_buffer , & output_offset ,
50+ /*threadgroup_buffer_size=*/ 0 );
5051}
5152
5253enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random (
@@ -93,7 +94,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
9394 threadgroup_size , 1 , 1 ,
9495 num_threadgroups , 1 , 1 ,
9596 sizeof (args ), & args ,
96- 1 , & output_buffer , & output_offset );
97+ 1 , & output_buffer , & output_offset ,
98+ /*threadgroup_buffer_size=*/ 0 );
9799}
98100
99101enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random (
@@ -140,7 +142,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
140142 threadgroup_size , 1 , 1 ,
141143 num_threadgroups , 1 , 1 ,
142144 sizeof (args ), & args ,
143- 1 , & output_buffer , & output_offset );
145+ 1 , & output_buffer , & output_offset ,
146+ /*threadgroup_buffer_size=*/ 0 );
144147}
145148
146149enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert (
@@ -180,7 +183,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
180183 threadgroup_size , 1 , 1 ,
181184 num_threadgroups , 1 , 1 ,
182185 sizeof (args ), & args ,
183- 3 , (const struct gptoss_metal_buffer * []) {block_buffer , scale_buffer , output_buffer }, NULL );
186+ 3 , (const struct gptoss_metal_buffer * []) {block_buffer , scale_buffer , output_buffer }, NULL ,
187+ /*threadgroup_buffer_size=*/ 0 );
184188}
185189
186190enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings (
@@ -222,7 +226,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
222226 sizeof (args ), & args ,
223227 3 ,
224228 (const struct gptoss_metal_buffer * []) {token_buffer , weight_buffer , output_buffer },
225- (const size_t []) {token_offset , weight_offset , output_offset });
229+ (const size_t []) {token_offset , weight_offset , output_offset },
230+ /*threadgroup_buffer_size=*/ 0 );
226231}
227232
228233enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm (
@@ -268,7 +273,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
268273 sizeof (args ), & args ,
269274 3 ,
270275 (const struct gptoss_metal_buffer * []) {input_buffer , weight_buffer , output_buffer },
271- (const size_t []) {input_offset , weight_offset , output_offset });
276+ (const size_t []) {input_offset , weight_offset , output_offset },
277+ /*threadgroup_buffer_size=*/ 0 );
272278}
273279
274280enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul (
@@ -325,7 +331,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
325331 sizeof (args ), & args ,
326332 4 ,
327333 (const struct gptoss_metal_buffer * []) {input_buffer , weight_buffer , bias_buffer , output_buffer },
328- (const size_t []) {input_offset , weight_offset , bias_offset , output_offset });
334+ (const size_t []) {input_offset , weight_offset , bias_offset , output_offset },
335+ /*threadgroup_buffer_size=*/ 0 );
329336}
330337
331338enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add (
@@ -382,7 +389,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
382389 sizeof (args ), & args ,
383390 4 ,
384391 (const struct gptoss_metal_buffer * []) {input_buffer , weight_buffer , bias_buffer , output_buffer },
385- (const size_t []) {input_offset , weight_offset , bias_offset , output_offset });
392+ (const size_t []) {input_offset , weight_offset , bias_offset , output_offset },
393+ /*threadgroup_buffer_size=*/ 0 );
386394}
387395
388396enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding (
@@ -437,7 +445,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
437445 sizeof (args ), & args ,
438446 4 ,
439447 (const struct gptoss_metal_buffer * []) {input_buffer , weight_buffer , output_buffer , argmax_buffer },
440- (const size_t []) {input_offset , weight_offset , output_offset , argmax_offset });
448+ (const size_t []) {input_offset , weight_offset , output_offset , argmax_offset },
449+ /*threadgroup_buffer_size=*/ 0 );
441450}
442451
443452enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu (
@@ -510,7 +519,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
510519 sizeof (args ), & args ,
511520 6 ,
512521 (const struct gptoss_metal_buffer * []) {input_buffer , expert_buffer , weight_block_buffer , weight_scale_buffer , bias_buffer , output_buffer },
513- (const size_t []) {input_offset , expert_offset , weight_block_offset , weight_scale_offset , bias_offset , output_offset });
522+ (const size_t []) {input_offset , expert_offset , weight_block_offset , weight_scale_offset , bias_offset , output_offset },
523+ /*threadgroup_buffer_size=*/ 0 );
514524}
515525
516526enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul (
@@ -581,7 +591,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
581591 sizeof (args ), & args ,
582592 6 ,
583593 (const struct gptoss_metal_buffer * []) {input_buffer , expert_buffer , weight_block_buffer , weight_scale_buffer , bias_buffer , output_buffer },
584- (const size_t []) {input_offset , expert_offset , weight_block_offset , weight_scale_offset , bias_offset , output_offset });
594+ (const size_t []) {input_offset , expert_offset , weight_block_offset , weight_scale_offset , bias_offset , output_offset },
595+ /*threadgroup_buffer_size=*/ 0 );
585596}
586597
587598enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope (
@@ -631,7 +642,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
631642 threadgroup_size , 1 , 1 ,
632643 num_qk_heads / num_simdgroups , num_tokens , 1 ,
633644 sizeof (args ), & args ,
634- 1 , (const struct gptoss_metal_buffer * []) {activations_buffer }, NULL );
645+ 1 , (const struct gptoss_metal_buffer * []) {activations_buffer }, NULL ,
646+ /*threadgroup_buffer_size=*/ 0 );
635647}
636648
637649enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate (
@@ -680,7 +692,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
680692 sizeof (args ), & args ,
681693 3 ,
682694 (const struct gptoss_metal_buffer * []) {input_buffer , expert_buffer , output_buffer },
683- (const size_t []) {input_offset , expert_offset , output_offset });
695+ (const size_t []) {input_offset , expert_offset , output_offset },
696+ /*threadgroup_buffer_size=*/ 0 );
684697}
685698
686699enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk (
@@ -715,7 +728,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
715728 sizeof (args ), & args ,
716729 2 ,
717730 (const struct gptoss_metal_buffer * []) {input_buffer , output_buffer },
718- (const size_t []) {input_offset , output_offset });
731+ (const size_t []) {input_offset , output_offset },
732+ /*threadgroup_buffer_size=*/ 0 );
719733}
720734
721735enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa (
@@ -753,6 +767,11 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
753767 return gptoss_status_invalid_argument ;
754768 }
755769
770+ const size_t max_context_tokens = math_min (num_q_tokens + num_kv_tokens + 1 , window );
771+ const size_t threadgroup_size = math_min (f32_sdpa_fn -> max_threadgroup_threads ,
772+ max_context_tokens * f32_sdpa_fn -> simdgroup_threads );
773+ const size_t half_threadgroup_size = math_round_down_po2 (threadgroup_size / 2 , f32_sdpa_fn -> simdgroup_threads );
774+
756775 const struct gptoss_sdpa_args args = {
757776 .qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads ),
758777 .num_kv_tokens = num_kv_tokens ,
@@ -761,12 +780,13 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
761780
762781 return gptoss_metal_command_buffer_encode_launch_kernel (
763782 command_buffer , f32_sdpa_fn ,
764- /* threadgroup_size=*/ 32 , 1 , 1 ,
783+ threadgroup_size , 1 , 1 ,
765784 num_q_tokens , num_kv_heads , 1 ,
766785 sizeof (args ), & args ,
767786 5 ,
768787 (const struct gptoss_metal_buffer * []) {q_buffer , k_buffer , v_buffer , s_buffer , output_buffer },
769- (const size_t []) {q_offset , k_offset , v_offset , s_offset , output_offset });
788+ (const size_t []) {q_offset , k_offset , v_offset , s_offset , output_offset },
789+ /*threadgroup_buffer_size=*/ half_threadgroup_size * 8 * 4 * sizeof (float ));
770790}
771791
772792enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax (
@@ -813,5 +833,6 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
813833 sizeof (args ), & args ,
814834 4 ,
815835 (const struct gptoss_metal_buffer * []) {score_buffer , argmax_buffer , prob_buffer , sum_buffer },
816- (const size_t []) {score_offset , argmax_offset , prob_offset , sum_offset });
836+ (const size_t []) {score_offset , argmax_offset , prob_offset , sum_offset },
837+ /*threadgroup_buffer_size=*/ 0 );
817838}
0 commit comments