Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 52 additions & 20 deletions ds4.c
Original file line number Diff line number Diff line change
Expand Up @@ -10557,7 +10557,7 @@ static int metal_graph_decode_test(
g.after_ffn_hc = embedded_hc;
}
if (ok) ok = metal_graph_encode_output_head(&g, model, weights, vocab_dim);
if (ok) ok = ds4_gpu_end_commands() != 0;
if (ok) ok = ds4_gpu_end_commands_labeled("decode test") != 0;

if (ok) {
ok = ds4_gpu_tensor_read(g.after_ffn_hc, 0, gpu_hc, hc_dim * sizeof(float)) != 0 &&
Expand Down Expand Up @@ -10832,11 +10832,9 @@ static bool metal_graph_encode_token_raw_swa(
DS4_N_HC) != 0;

/*
 * Start executing decode prefixes while the CPU is still encoding later
 * layers. Split points are layer-based because this executor is a fixed DS4
 * tape, not a dynamic node graph. The default preserves the measured single
 * split after layer four, where the prefix is large enough to hide useful
 * work without starving the second command buffer; the flush-every knob is
 * for isolating/mitigating Metal command-buffer hangs.
 */
uint32_t split_after_layers = 4;
const char *split_env = getenv("DS4_METAL_GRAPH_TOKEN_SPLIT_LAYERS");
Expand All @@ -10845,7 +10843,15 @@ static bool metal_graph_encode_token_raw_swa(
unsigned long v = strtoul(split_env, &end, 10);
if (end != split_env && v <= DS4_N_LAYER) split_after_layers = (uint32_t)v;
}
uint32_t flush_every_layers = 0;
const char *flush_every_env = getenv("DS4_METAL_GRAPH_TOKEN_FLUSH_EVERY");
if (flush_every_env && flush_every_env[0]) {
char *end = NULL;
unsigned long v = strtoul(flush_every_env, &end, 10);
if (end != flush_every_env && v > 0 && v <= DS4_N_LAYER) flush_every_layers = (uint32_t)v;
}

uint32_t segment_start = 0;
for (uint32_t il = 0; ok && il < DS4_N_LAYER; il++) {
ok = metal_graph_encode_decode_layer(g,
model,
Expand All @@ -10860,8 +10866,18 @@ static bool metal_graph_encode_token_raw_swa(
ds4_gpu_tensor *tmp = g->cur_hc;
g->cur_hc = g->after_ffn_hc;
g->after_ffn_hc = tmp;
if (ok && allow_split_flush && split_after_layers != 0 && il + 1u == split_after_layers) {
ok = ds4_gpu_flush_commands() != 0;
if (ok && allow_split_flush) {
const uint32_t layer_done = il + 1u;
const bool should_flush =
flush_every_layers != 0
? (layer_done < DS4_N_LAYER && layer_done % flush_every_layers == 0)
: (split_after_layers != 0 && layer_done == split_after_layers);
if (should_flush) {
char label[80];
snprintf(label, sizeof(label), "decode token layers %u-%u", segment_start, il);
ok = ds4_gpu_flush_commands_labeled(label) != 0;
segment_start = layer_done;
}
}
}

Expand Down Expand Up @@ -12786,7 +12802,11 @@ static bool metal_graph_eval_token_raw_swa(
bool ok = ds4_gpu_begin_commands() != 0;
if (ok) ok = metal_graph_encode_token_raw_swa(g, model, weights, token, pos, logits != NULL, true);
const double t_encoded = profile ? now_sec() : 0.0;
if (ok) ok = ds4_gpu_end_commands() != 0;
if (ok) {
char label[80];
snprintf(label, sizeof(label), "decode token tail pos %u", pos);
ok = ds4_gpu_end_commands_labeled(label) != 0;
}
const double t_done = profile ? now_sec() : 0.0;

if (ok && logits) {
Expand Down Expand Up @@ -12836,7 +12856,11 @@ static bool metal_graph_eval_token_raw_swa_top(
1,
1) != 0;
}
if (ok) ok = ds4_gpu_end_commands() != 0;
if (ok) {
char label[80];
snprintf(label, sizeof(label), "decode token top pos %u", pos);
ok = ds4_gpu_end_commands_labeled(label) != 0;
}
if (ok) ok = ds4_gpu_tensor_read(g->comp_selected, 0, top_id, sizeof(*top_id)) != 0;
if (ok && logits) {
ok = ds4_gpu_tensor_read(g->logits, 0, logits, (uint64_t)DS4_N_VOCAB * sizeof(float)) != 0;
Expand Down Expand Up @@ -13199,6 +13223,7 @@ static bool imatrix_collector_save(
}

static bool metal_graph_reset_prefill_state(ds4_gpu_graph *g) {
memset(g->layer_n_comp, 0, sizeof(g->layer_n_comp));
memset(g->layer_n_index_comp, 0, sizeof(g->layer_n_index_comp));
g->mtp_n_raw = 0;
for (uint32_t il = 0; il < DS4_N_LAYER; il++) {
Expand Down Expand Up @@ -15992,7 +16017,7 @@ int ds4_engine_mtp_draft_tokens(ds4_engine *e) {
}

/*
 * Returns the session's live token checkpoint, or NULL when `s` is NULL or
 * the checkpoint is currently marked invalid (checkpoint_valid == false).
 * The returned pointer aliases session-owned storage; it is not a copy.
 */
const ds4_tokens *ds4_session_tokens(ds4_session *s) {
    /* Single guarded return: an earlier unconditional
     * `return s ? &s->checkpoint : NULL;` placed ahead of this line made the
     * checkpoint_valid check unreachable. */
    return s && s->checkpoint_valid ? &s->checkpoint : NULL;
}

#ifndef DS4_NO_GPU
Expand Down Expand Up @@ -17495,7 +17520,7 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
NULL);
if (!ok) {
snprintf(err, errlen, "%s resumed prefill failed while extending checkpoint", backend_name);
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return 1;
}
ds4_tokens_copy(&s->checkpoint, prompt);
Expand All @@ -17510,14 +17535,20 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
s->logits))
{
snprintf(err, errlen, "%s decode failed while extending checkpoint", backend_name);
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return 1;
}
token_vec_push(&s->checkpoint, prompt->v[i]);
}
return 0;
}

if (!metal_graph_reset_prefill_state(&s->graph)) {
snprintf(err, errlen, "%s prefill reset failed", backend_name);
ds4_session_invalidate(s);
return 1;
}

bool ok;
if (s->prefill_cap < (uint32_t)prompt->len) {
ds4_sync_progress progress = {
Expand All @@ -17537,7 +17568,7 @@ int ds4_session_sync(ds4_session *s, const ds4_tokens *prompt, char *err, size_t
}
if (!ok) {
snprintf(err, errlen, "%s prefill failed", backend_name);
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return 1;
}
ds4_tokens_copy(&s->checkpoint, prompt);
Expand Down Expand Up @@ -17750,7 +17781,7 @@ static int ds4_session_eval_internal(ds4_session *s, int token, bool probe_mtp,
s->logits))
{
snprintf(err, errlen, "%s decode failed", ds4_backend_name(e->backend));
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return 1;
}
token_vec_push(&s->checkpoint, token);
Expand Down Expand Up @@ -17927,7 +17958,7 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
if (!ok) {
free(row_logits);
snprintf(err, errlen, "%s decode failed", ds4_backend_name(e->backend));
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return -1;
}
memcpy(s->logits, row_logits, (size_t)DS4_N_VOCAB * sizeof(s->logits[0]));
Expand Down Expand Up @@ -18283,7 +18314,7 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
* Fall through to the exact sequential verifier below. */
} else {
snprintf(err, errlen, "MTP verifier failed");
s->checkpoint_valid = false;
ds4_session_invalidate(s);
DS4_MTP_KEEP_ACCEPTED(0);
spec_frontier_free(&frontier);
free(row_logits);
Expand Down Expand Up @@ -18330,7 +18361,7 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
NULL))
{
snprintf(err, errlen, "%s decode failed", ds4_backend_name(e->backend));
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return -1;
}
token_vec_push(&s->checkpoint, drafts[i]);
Expand All @@ -18346,7 +18377,7 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
(uint64_t)DS4_N_VOCAB * sizeof(s->logits[0])) == 0)
{
snprintf(err, errlen, "%s logits readback failed", ds4_backend_name(e->backend));
s->checkpoint_valid = false;
ds4_session_invalidate(s);
return -1;
}
logits_on_host = true;
Expand Down Expand Up @@ -18382,6 +18413,7 @@ int ds4_session_eval_speculative_argmax(ds4_session *s, int first_token,
}

void ds4_session_invalidate(ds4_session *s) {
if (!s) return;
s->checkpoint_valid = false;
s->checkpoint.len = 0;
s->mtp_draft_valid = false;
Expand All @@ -18395,7 +18427,7 @@ void ds4_session_rewind(ds4_session *s, int pos) {
}

/*
 * Returns the number of tokens in the session's live checkpoint.
 * Yields 0 when `s` is NULL or the checkpoint is marked invalid, so callers
 * never observe a stale length.
 */
int ds4_session_pos(ds4_session *s) {
    /* Single guarded return: a leading `return s->checkpoint.len;` would
     * dereference a NULL session and bypass the validity check below. */
    return s && s->checkpoint_valid ? s->checkpoint.len : 0;
}

int ds4_session_ctx(ds4_session *s) {
Expand Down
1 change: 1 addition & 0 deletions ds4.h
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ int ds4_session_ctx(ds4_session *s);
int ds4_engine_routed_quant_bits(ds4_engine *e);
bool ds4_engine_has_mtp(ds4_engine *e);
int ds4_engine_mtp_draft_tokens(ds4_engine *e);
/* Returns NULL when the live checkpoint is invalidated after a backend error. */
const ds4_tokens *ds4_session_tokens(ds4_session *s);

/* Disk KV cache payload helpers. The server owns the outer file header and
Expand Down
14 changes: 12 additions & 2 deletions ds4_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1410,8 +1410,18 @@ extern "C" int ds4_gpu_tensor_copy(ds4_gpu_tensor *dst, uint64_t dst_offset,
}

/*
 * Command-batch entry points for the CUDA backend.
 *
 * There is no deferred command buffer here: begin is a constant success, and
 * flush/end/synchronize each call cudaDeviceSynchronize() and hand the result
 * to cuda_ok() together with a short label. The *_labeled variants let the
 * caller pick that label; a NULL or empty label selects the default name.
 * NOTE(review): cuda_ok() is defined outside this view; it is presumed to
 * return nonzero on success and to report the label on failure - confirm.
 *
 * Each symbol is defined exactly once; the unlabeled functions delegate to
 * the labeled ones instead of carrying a second, duplicate definition.
 */
extern "C" int ds4_gpu_begin_commands(void) { return 1; }
extern "C" int ds4_gpu_flush_commands_labeled(const char *label) {
    return cuda_ok(cudaDeviceSynchronize(), label && label[0] ? label : "flush");
}
extern "C" int ds4_gpu_flush_commands(void) {
    return ds4_gpu_flush_commands_labeled("flush");
}
extern "C" int ds4_gpu_end_commands_labeled(const char *label) {
    return cuda_ok(cudaDeviceSynchronize(), label && label[0] ? label : "end commands");
}
extern "C" int ds4_gpu_end_commands(void) {
    return ds4_gpu_end_commands_labeled("end commands");
}
extern "C" int ds4_gpu_synchronize(void) { return cuda_ok(cudaDeviceSynchronize(), "synchronize"); }

extern "C" int ds4_gpu_set_model_map(const void *model_map, uint64_t model_size) {
Expand Down
2 changes: 2 additions & 0 deletions ds4_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ int ds4_gpu_tensor_copy(ds4_gpu_tensor *dst, uint64_t dst_offset,

/*
 * Command batching. Every function below returns nonzero on success and 0 on
 * failure (callers test the result with `!= 0`).
 */
int ds4_gpu_begin_commands(void);
int ds4_gpu_flush_commands(void);
/* Same as ds4_gpu_flush_commands(), but tags the submitted work with `label`
 * so a backend failure can be attributed to it. NULL or "" is accepted. */
int ds4_gpu_flush_commands_labeled(const char *label);
int ds4_gpu_end_commands(void);
/* Same as ds4_gpu_end_commands(), but tags the final batch with `label`.
 * NULL or "" makes the backend fall back to its default name. */
int ds4_gpu_end_commands_labeled(const char *label);
int ds4_gpu_synchronize(void);

int ds4_gpu_set_model_map(const void *model_map, uint64_t model_size);
Expand Down
26 changes: 22 additions & 4 deletions ds4_metal.m
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,19 @@ static void ds4_gpu_close_batch_encoder(void) {
g_batch_enc = nil;
}

/*
 * Tags a Metal command buffer with a human-readable name.
 * Does nothing when `cb` is nil or `label` is NULL or empty, so any label
 * already set on the buffer is left untouched in those cases.
 */
static void ds4_gpu_label_command_buffer(id<MTLCommandBuffer> cb, const char *label) {
    if (cb != nil && label != NULL && label[0] != '\0') {
        cb.label = [NSString stringWithUTF8String:label];
    }
}

/*
 * Blocks until `cb` has finished executing and reports a failure on stderr.
 *
 * The name printed in the error message is chosen in this order: the label
 * stored on the command buffer itself, then the caller-supplied `label`,
 * then the generic "command batch".
 *
 * Returns 1 when the buffer did not end in MTLCommandBufferStatusError,
 * 0 otherwise.
 */
static int ds4_gpu_wait_command_buffer(id<MTLCommandBuffer> cb, const char *label) {
    [cb waitUntilCompleted];
    if (cb.status != MTLCommandBufferStatusError) return 1;

    /* Keep the NSString objects in locals so the UTF-8 views stay valid for
     * the duration of the fprintf call. */
    NSString *cb_label = cb.label;
    const char *use_label = cb_label ? [cb_label UTF8String] : NULL;
    if (!use_label || !use_label[0]) use_label = label;
    if (!use_label || !use_label[0]) use_label = "command batch";

    /* The error or its description may be nil; never pass NULL to %s. */
    NSString *description = [cb.error localizedDescription];
    const char *reason = description ? [description UTF8String] : NULL;

    fprintf(stderr, "ds4: Metal %s failed: %s\n",
            use_label,
            reason ? reason : "unknown error");
    return 0;
}
static int ds4_gpu_finish_command_buffer(id<MTLCommandBuffer> cb, int owned, const char *label) {
if (!owned) return 1;

ds4_gpu_label_command_buffer(cb, label);
[cb commit];
int ok = ds4_gpu_wait_pending_command_buffers(label);
if (!ds4_gpu_wait_command_buffer(cb, label)) ok = 0;
Expand Down Expand Up @@ -3940,13 +3949,14 @@ int ds4_gpu_begin_commands(void) {
return g_batch_cb != nil;
}

int ds4_gpu_flush_commands(void) {
int ds4_gpu_flush_commands_labeled(const char *label) {
if (!g_initialized && !ds4_gpu_init()) return 0;
if (!g_batch_cb) return 0;

ds4_gpu_close_batch_encoder();
id<MTLCommandBuffer> cb = g_batch_cb;
g_batch_cb = nil;
ds4_gpu_label_command_buffer(cb, label);
[cb commit];
[g_pending_cbs addObject:cb];

Expand All @@ -3959,12 +3969,20 @@ int ds4_gpu_flush_commands(void) {
return 1;
}

/*
 * Unlabeled flush: submits the current batch under the default name
 * "command batch". Returns whatever ds4_gpu_flush_commands_labeled() returns
 * (0 on failure, nonzero on success).
 *
 * The stray `int ds4_gpu_end_commands(void) {` opener that preceded this
 * definition is dropped; it left an unterminated function header in front of
 * this one.
 */
int ds4_gpu_flush_commands(void) {
    return ds4_gpu_flush_commands_labeled("command batch");
}

/*
 * Ends the current command batch and hands it to
 * ds4_gpu_finish_command_buffer() as an owned buffer.
 *
 * `label` names the final command buffer for error reporting; NULL or ""
 * falls back to "command batch".
 *
 * Returns 0 when no batch is open (g_batch_cb is nil); otherwise returns the
 * result of ds4_gpu_finish_command_buffer().
 */
int ds4_gpu_end_commands_labeled(const char *label) {
    if (!g_batch_cb) return 0;
    ds4_gpu_close_batch_encoder();

    /* Detach the batch from the global before finishing it. */
    id<MTLCommandBuffer> cb = g_batch_cb;
    g_batch_cb = nil;

    /* Only one return: an earlier hard-coded "command batch" return ahead of
     * this line made the caller's label unreachable. */
    return ds4_gpu_finish_command_buffer(cb, 1, label && label[0] ? label : "command batch");
}

/*
 * Unlabeled variant of ds4_gpu_end_commands_labeled(): ends the current
 * batch under the default name and forwards the callee's result unchanged.
 */
int ds4_gpu_end_commands(void) {
    const char *const default_label = "command batch";
    return ds4_gpu_end_commands_labeled(default_label);
}

int ds4_gpu_synchronize(void) {
Expand Down
4 changes: 4 additions & 0 deletions ds4_server.c
Original file line number Diff line number Diff line change
Expand Up @@ -8935,6 +8935,10 @@ static void kv_cache_close(kv_disk_cache *kc) {

static char *render_tokens_text(ds4_engine *engine, const ds4_tokens *tokens, size_t *out_len) {
buf b = {0};
if (!engine || !tokens) {
if (out_len) *out_len = 0;
return buf_take(&b);
}
for (int i = 0; i < tokens->len; i++) {
size_t len = 0;
char *piece = ds4_token_text(engine, tokens->v[i], &len);
Expand Down