@@ -253,10 +253,51 @@ static enum gptoss_status process_tokens(
                 GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
                 return status;
             }
+
+            status = gptoss_metal_command_buffer_encode_launch_f32_rope(
+                command_buffer,
+                &model->f32_rope_fn,
+                /*threadgroup_size=*/ 32,
+                &context->qkv_activation_buffer,
+                /*input_offset=*/ 0,
+                &context->control_buffer,
+                /*control_offset=*/ 0,
+                model->rope_theta,
+                model->interpolation_scale,
+                model->yarn_offset,
+                model->yarn_scale,
+                model->yarn_multiplier,
+                input_batch_size,
+                model->num_heads,
+                model->num_kv_heads,
+                model->head_dim,
+                /*token_offset=*/ input_batch_start);
+            if (status != gptoss_status_success) {
+                GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
+                return status;
+            }
+
+            for (uint32_t t = 0; t < input_batch_size; t++) {
+                for (uint32_t kv = 0; kv < 2; kv++) {
+                    for (uint32_t h = 0; h < model->num_kv_heads; h++) {
+                        status = gptoss_metal_command_buffer_encode_copy_buffer(
+                            command_buffer,
+                            &context->qkv_activation_buffer,
+                            /*input_offset=*/ (t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
+                            &context->kvcache_buffer,
+                            /*output_offset=*/ (((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
+                            /*size=*/ model->head_dim * sizeof(float));
+                        if (status != gptoss_status_success) {
+                            GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
+                            return status;
+                        }
+                    }
+                }
+            }
         } else {
-            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
+            status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
                 command_buffer,
-                &model->f32_bf16w_matmul_fn,
+                &model->f32_bf16w_matmul_qkv_fn,
                 model->attn_qkv_threadgroup_size,
                 &context->rmsnorm_activation_buffer,
                 /*input_offset=*/ 0,
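Note on the added copy loop in the dense path: the offsets above imply a KV cache laid out as [block][kv_head][token][K|V][head_dim] in floats, with a token's K and V vectors stored back to back inside each kv-head's slab, so each copy now moves only head_dim floats (the removed code below copied 2 * num_kv_heads * head_dim floats per token at a different offset). A minimal sketch of the two offset computations, using hypothetical helper names that are not part of this change:

```c
#include <stddef.h>
#include <stdint.h>

// Source offset: the K (kv == 0) or V (kv == 1) slice of token `t` for kv-head `h`
// inside the packed QKV activations, which store [Q | K | V] per token.
static inline size_t qkv_src_offset(
    uint32_t t, uint32_t kv, uint32_t h,
    uint32_t attn_qkv_dim, uint32_t num_heads, uint32_t num_kv_heads, uint32_t head_dim)
{
    return (size_t) (t * attn_qkv_dim + (num_heads + kv * num_kv_heads + h) * head_dim) * sizeof(float);
}

// Destination offset: block `n`, kv-head `h`, absolute token index `token`,
// K (kv == 0) or V (kv == 1), in a cache shaped [block][kv_head][token][K|V][head_dim].
static inline size_t kvcache_dst_offset(
    uint32_t n, uint32_t h, uint32_t token, uint32_t kv,
    uint32_t num_kv_heads, uint32_t max_tokens, uint32_t head_dim)
{
    return (size_t) (((n * num_kv_heads + h) * max_tokens + token) * 2 + kv) * head_dim * sizeof(float);
}
```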
@@ -266,49 +307,24 @@ static enum gptoss_status process_tokens(
                 /*bias_offset=*/ model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
                 &context->qkv_activation_buffer,
                 /*output_offset=*/ 0,
+                &context->kvcache_buffer,
+                /*kv_offset=*/ n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
                 &context->control_buffer,
                 /*control_offset=*/ 0,
                 /*num_tokens=*/ input_batch_size,
                 /*num_cols=*/ model->embedding_dim,
-                /*num_rows=*/ attn_qkv_dim);
-            if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
-                return status;
-            }
-        }
-        status = gptoss_metal_command_buffer_encode_launch_f32_rope(
-            command_buffer,
-            &model->f32_rope_fn,
-            /*threadgroup_size=*/ 32,
-            &context->qkv_activation_buffer,
-            /*input_offset=*/ 0,
-            &context->control_buffer,
-            /*control_offset=*/ 0,
-            model->rope_theta,
-            model->interpolation_scale,
-            model->yarn_offset,
-            model->yarn_scale,
-            model->yarn_multiplier,
-            input_batch_size,
-            model->num_heads,
-            model->num_kv_heads,
-            model->head_dim,
-            /*token_offset=*/ input_batch_start);
-        if (status != gptoss_status_success) {
-            GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
-            return status;
-        }
-
-        for (uint32_t t = 0; t < input_batch_size; t++) {
-            status = gptoss_metal_command_buffer_encode_copy_buffer(
-                command_buffer,
-                &context->qkv_activation_buffer,
-                /*input_offset=*/ (t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
-                &context->kvcache_buffer,
-                /*output_offset=*/ (n * context->max_tokens + input_batch_start + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
-                /*size=*/ 2 * model->num_kv_heads * model->head_dim * sizeof(float));
+                /*num_q_heads=*/ model->num_heads,
+                /*num_kv_heads=*/ model->num_kv_heads,
+                /*attn_head_dim=*/ model->head_dim,
+                /*token_offset=*/ input_batch_start,
+                /*max_tokens=*/ context->max_tokens,
+                /*rope_base=*/ model->rope_theta,
+                /*interpolation_scale=*/ model->interpolation_scale,
+                /*yarn_offset=*/ model->yarn_offset,
+                /*yarn_scale=*/ model->yarn_scale,
+                /*yarn_multiplier=*/ model->yarn_multiplier);
             if (status != gptoss_status_success) {
-                GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
+                GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
                 return status;
             }
         }
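In the non-dense path, the separate f32_rope launch and per-token buffer copies are removed: the fused f32_bf16w_matmul_qkv launch above receives the RoPE parameters and the KV-cache buffer directly, with kv_offset selecting block n's region of the same layout sketched earlier. A small illustrative restatement of that base offset (the helper name is hypothetical):

```c
#include <stddef.h>
#include <stdint.h>

// Start of block n's region in a cache shaped [block][kv_head][token][K|V][head_dim].
static inline size_t block_kv_base_offset(
    uint32_t n, uint32_t num_kv_heads, uint32_t max_tokens, uint32_t head_dim)
{
    return (size_t) n * num_kv_heads * max_tokens * 2 * head_dim * sizeof(float);
}
```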
@@ -320,16 +336,15 @@ static enum gptoss_status process_tokens(
             &context->qkv_activation_buffer,
             /*q_offset=*/ attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
             &context->kvcache_buffer,
-            /*k_offset=*/ n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
-            &context->kvcache_buffer,
-            /*v_offset=*/ (n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float),
+            /*kv_offset=*/ n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
             &model->shared_weight_buffer,
             /*s_offset=*/ model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
             &context->sdpa_activation_buffer,
             /*output_offset=*/ 0,
             &context->control_buffer,
             /*control_offset=*/ 0,
             /*window=*/ n % 2 == 0 ? model->attention_window : UINT32_MAX,
+            /*kv_stride=*/ 2 * context->max_tokens * model->head_dim,
             num_block_output_tokens,
             input_batch_start + input_batch_size - num_block_output_tokens,
             model->num_heads, model->num_kv_heads, model->head_dim);
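The SDPA launch now takes a single kv_offset plus a kv_stride instead of separate k_offset/v_offset arguments. Given the layout above, kv_stride = 2 * max_tokens * head_dim is the per-kv-head stride in floats, with a token's V vector immediately following its K vector. A hedged sketch of how a consumer could address the cache from that base (names are illustrative, not the kernel's actual code):

```c
#include <stddef.h>
#include <stdint.h>

// `kv` points at kvcache_buffer + kv_offset for the current block.
// K of kv-head `h`, token `t` starts at h * kv_stride + t * 2 * head_dim floats;
// the matching V vector immediately follows it.
static inline const float* k_vec(const float* kv, uint32_t h, uint32_t t,
                                 uint32_t kv_stride, uint32_t head_dim) {
    return kv + (size_t) h * kv_stride + (size_t) t * 2 * head_dim;
}

static inline const float* v_vec(const float* kv, uint32_t h, uint32_t t,
                                 uint32_t kv_stride, uint32_t head_dim) {
    return k_vec(kv, h, t, kv_stride, head_dim) + head_dim;
}
```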