#include "speculative.h"

#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"

#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)

// NOTE(review): the names of the standard headers were lost when this file was extracted
// (only bare `#include` tokens survived). The list below covers every std:: facility used
// by the visible code - confirm against the upstream file and re-add anything missing.
#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE  128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5

// maps the user-facing name of a speculative implementation to its enum value
// NOTE(review): template arguments reconstructed from the initializer - must match the declaration in speculative.h
const std::map<std::string, common_speculative_type> common_speculative_type_from_name_map = {
    {"none",          COMMON_SPECULATIVE_TYPE_NONE},
    {"draft-simple",  COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
    {"draft-eagle3",  COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
    {"draft-mtp",     COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
    {"ngram-simple",  COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
    {"ngram-map-k",   COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
    {"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
    {"ngram-mod",     COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
    {"ngram-cache",   COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};

// returns a ", "-separated list of the names of the non-null devices,
// or "default" when the list has no non-null entry
static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
    std::string result;

    for (const auto & dev : devices) {
        if (dev == nullptr) {
            continue;
        }

        if (!result.empty()) {
            result += ", ";
        }

        result += ggml_backend_dev_name(dev);
    }

    return result.empty() ? "default" : result;
}

// pairs a speculative implementation type with the parameters used to construct it
struct common_speculative_config {
    common_speculative_type   type;
    common_params_speculative params;

    common_speculative_config(common_speculative_type t,
            const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
};

// checks whether the draft model's vocab is close enough to the target model's vocab
// to be used for speculation:
//   - same vocab type
//   - same add-BOS / add-EOS flags (and same BOS / EOS ids when the flag is set)
//   - vocab sizes differ by at most SPEC_VOCAB_MAX_SIZE_DIFFERENCE
//   - identical token text for every shared id >= SPEC_VOCAB_CHECK_START_TOKEN_ID
// returns false (after logging the reason) on the first mismatch
static bool common_speculative_are_compatible(
        const llama_model * model_tgt,
        const llama_model * model_dft) {
    const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
    const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);

    const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
    LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);

    const auto vocab_type_dft = llama_vocab_type(vocab_dft);
    LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);

    if (vocab_type_tgt != vocab_type_dft) {
        LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
                "vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
        return false;
    }

    if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
        (llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
        LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
                llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
        return false;
    }

    if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
        (llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
        LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
                __func__,
                llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
                llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
        return false;
    }

    {
        const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
        const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);

        const int vocab_diff = n_vocab_tgt > n_vocab_dft
            ? n_vocab_tgt - n_vocab_dft
            : n_vocab_dft - n_vocab_tgt;

        if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
            // single log call so that the message is not split across two log entries
            LOG_DBG("%s: draft model vocab must closely match target model to use speculation but "
                    "target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
                    __func__, n_vocab_tgt, n_vocab_dft, vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
            return false;
        }

        // only the ids present in both vocabs can be compared
        const int n_vocab_common = std::min(n_vocab_tgt, n_vocab_dft);

        for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < n_vocab_common; ++i) {
            const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
            const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);

            if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
                LOG_DBG("%s: draft model vocab must match target model to use speculation but "
                        "token %d content differs - target '%s', draft '%s'\n",
                        __func__, i,
                        common_token_to_piece(vocab_tgt, i).c_str(),
                        common_token_to_piece(vocab_dft, i).c_str());
                return false;
            }
        }
    }

    return true;
}

// per-sequence draft parameters, indexed by seq_id
// NOTE(review): element type reconstructed from usage (dp.drafting, dp.id_last, dp.n_past, dp.n_max, dp.result) - confirm name
using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>;

// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_impl
struct common_speculative_impl {
    const common_speculative_type type;

    uint32_t n_seq;

    size_t n_call_begin  = 0; // number of times this implementation was called for refresh.
    size_t n_call_draft  = 0; // number of times this implementation was called for generation.
    size_t n_call_accept = 0; // number of times this implementation was called for accumulation.

    size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
    size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
    size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
    size_t n_acc_tokens = 0; // number of tokens accepted by the target model.

    // NOTE(review): element type was lost in extraction; size_t assumed to match the sibling counters - confirm
    std::vector<size_t> n_acc_tokens_per_pos; // number of tokens accepted per draft position.

    // TODO: track performance of most recent calls
    const bool gen_perf = true; // whether to generate performance stats.

    int64_t t_begin_us  = 0; // total time spent in refresh of this implementation in microseconds.
    int64_t t_draft_us  = 0; // total time spent in generating drafts in this implementation in microseconds.
    int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.

    common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {}

    virtual ~common_speculative_impl() = default;

    virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0;

    virtual bool process(const llama_batch & batch) = 0;

    virtual void draft(common_speculative_draft_params_vec & dparams) = 0;

    virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;

    // (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
    // NOTE(review): buffer element type was lost in extraction; uint8_t assumed (the overrides treat it as raw bytes) - confirm
    virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
    virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}

    // true if this implementation requires the target context to extract post-norm embeddings
    virtual bool need_embd() const = 0;

    // true if this implementation requires the target context to extract pre-norm embeddings
    virtual bool need_embd_nextn() const { return false; }
};

// speculative decoding with a separate draft model:
//   - process() mirrors the given batch into the draft context
//   - draft()   greedily samples tokens from the draft context, for all drafting sequences in lockstep
struct common_speculative_impl_draft_simple : public common_speculative_impl {
    common_params_speculative_draft params;

    // scratch batch used by draft(); owned by this object (freed in the destructor)
    llama_batch batch;

    // one sampler per sequence
    // NOTE(review): element type reconstructed from the .reset()/.get() usage - confirm
    std::vector<common_sampler_ptr> smpls;

    // throws std::runtime_error if the draft model is not compatible with the target model
    // or if the draft context does not support exactly n_seq sequences
    common_speculative_impl_draft_simple(const common_params_speculative & params, uint32_t n_seq)
        : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, n_seq)
        , params(params.draft)
    {
        auto * ctx_dft = this->params.ctx_dft;
        auto * ctx_tgt = this->params.ctx_tgt;

        LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
        LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__,
                this->params.n_max, this->params.n_min, this->params.p_min);
        LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
                this->params.n_gpu_layers,
                ggml_type_name(this->params.cache_type_k),
                ggml_type_name(this->params.cache_type_v),
                ctx_tgt ? "yes" : "no",
                ctx_dft ? "yes" : "no",
                common_speculative_get_devices_str(this->params.devices).c_str());

        // validate before acquiring the raw batch: the destructor does not run when the
        // constructor throws, so throwing after llama_batch_init() would leak the batch
        const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
        LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);

        if (!vocab_cmpt) {
            LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
            throw std::runtime_error("draft model vocab type must match target model to use speculation");
        }

        if (n_seq != llama_n_seq_max(ctx_dft)) {
            LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
            throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
        }

        batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);

        // TODO: optimize or pass from outside?
        // {
        //     common_params_sampling params;
        //     params.no_perf = false;
        //
        //     params.top_k = 40;
        //     params.top_p = 0.9;
        //
        //     params.samplers = {
        //         COMMON_SAMPLER_TYPE_TOP_K,
        //         COMMON_SAMPLER_TYPE_TOP_P,
        //         COMMON_SAMPLER_TYPE_INFILL,
        //     };
        //
        //     result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
        // }

        smpls.resize(n_seq);

        for (auto & smpl : smpls) {
            common_params_sampling params;
            params.no_perf = false;
            params.top_k   = 10;
            params.samplers = {
                COMMON_SAMPLER_TYPE_TOP_K,
            };

            smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params));
        }
    }

    ~common_speculative_impl_draft_simple() override {
        llama_batch_free(batch);
    }

    void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
        // noop
    }

    // decodes the given batch with the draft context; returns false if the decode fails
    bool process(const llama_batch & batch_in) override {
        auto * ctx_dft = params.ctx_dft;

        const int ret = llama_decode(ctx_dft, batch_in);
        if (ret != 0) {
            LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
            return false;
        }

        return true;
    }

    // generates a draft for every sequence that has dp.drafting set and appends it to *dp.result
    //
    // a sequence stops drafting when:
    //   - the probability of its top candidate drops below params.p_min
    //   - its draft reaches params.n_max tokens (or dp.n_max, when dp.n_max > 0)
    // drafts shorter than params.n_min are discarded at the end
    void draft(common_speculative_draft_params_vec & dparams) override {
        auto * ctx_dft = params.ctx_dft;

        common_batch_clear(batch);

        // keep track of which sequences are still drafting
        int n_drafting = 0;
        std::vector<bool> drafting(n_seq);

        for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
            auto & dp = dparams[seq_id];

            if (!dp.drafting) {
                continue;
            }

            n_drafting++;
            drafting[seq_id] = true;

            common_sampler_reset(smpls[seq_id].get());

            common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
        }

        // nothing to draft - avoid decoding an empty batch
        if (batch.n_tokens == 0) {
            return;
        }

        int ret = llama_decode(ctx_dft, batch);
        if (ret != 0) {
            LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
            return;
        }

        int i = 0;

        while (n_drafting > 0) {
            // index of the current sequence's output row in the previously decoded batch
            int i_batch = 0;

            common_batch_clear(batch);

            for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
                if (!drafting[seq_id]) {
                    continue;
                }

                auto * smpl = smpls[seq_id].get();

                common_sampler_sample(smpl, ctx_dft, i_batch, true);
                ++i_batch;

                const auto * cur_p = common_sampler_get_candidates(smpl, true);

                for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
                    LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
                            seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
                            common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
                }

                // add drafted token for each sequence
                const llama_token id = cur_p->data[0].id;

                // only collect very high-confidence draft tokens
                if (cur_p->data[0].p < params.p_min) {
                    drafting[seq_id] = false;
                    n_drafting--;
                    continue;
                }

                common_sampler_accept(smpl, id, true);

                auto & dp     = dparams.at(seq_id);
                auto & result = *dp.result;

                result.push_back(id);

                if ((params.n_max <= (int) result.size()) ||
                    (dp.n_max > 0 && dp.n_max <= (int) result.size())) {
                    drafting[seq_id] = false;
                    n_drafting--;
                    continue;
                }

                common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
            }

            if (batch.n_tokens == 0) {
                break;
            }

            // evaluate the drafted tokens on the draft model
            ret = llama_decode(ctx_dft, batch);
            if (ret != 0) {
                LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
                break;
            }

            ++i;
        }

        for (auto & dp : dparams) {
            if (!dp.drafting) {
                continue;
            }

            if (dp.result->size() < (size_t) params.n_min) {
                dp.result->clear();
            }
        }
    }

    void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
        // noop
    }

    bool need_embd() const override {
        return false;
    }
};

// EAGLE3 speculative decoding state
//
// Input of draft decoder: (This is different compared to MTP)
//   At "pos P", the decoder takes input pair (t_{P+1}, g_P), with RoPE at P.
//     - t_{P+1} = token at sequence pos P+1 (the *next* token after P)
//     - g_P     = encoder output = projection of target's extracted hidden states at P
//
// Deferred boundary (MTP doesn't have this issue):
//   Within a single process() call with n_tokens, we can only write decoder KV for
//   training pos 0..n_tokens-2. The last training pos (n_tokens-1) needs t_{n_tokens}
//   which lies *outside* this batch — it is the token the target will sample next or the
//   first token of the next ubatch.
//   So the last training pos of each process() call is *deferred* to whichever next call has
//   the missing token in hand:
//     - multi-ubatch prefill: the next process()'s first token completes the pair
//       (handled by the per-seq "cross-ubatch bridge")
//     - single-ubatch prefill / after verify: draft()'s seed step uses "dp.id_last"
//       (target's freshest sample) to complete the pair
//
// Per-seq carry-over state:
//   pending_g_last   [n_embd_dec]     ┐ the deferred boundary's (g, pos). Set by
//   pending_pos_last llama_pos        ┘ process() at end of ubatch (= last row);
//                                       rebased by accept() to first-non-accepted pos.
//   verify_g         [N × n_embd_dec]   snapshot of process()'s encoder output;
//   verify_pos_first llama_pos          consumed by accept() to recover the right
//   verify_g_rows    int32_t            pending_g_last row for any n_accepted value.
//
// Performance is overall good but there is waste in the verify cycle:
//   process() runs encoder + decoder on the *full* verify batch including rows for
//   rejected drafts. The KV at those positions is then dropped.
//
// TODO: Not sure if we need optimization for this waste?
//       If so we may need a hybrid stash:
//       in verify mode, have process() only stash features and let draft()'s seed step run
//       encoder+decoder on n_accepted+1 rows.
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { common_params_speculative_draft params; llama_batch batch; std::vector smpls; // backend sampler chain per seq, attached to ctx_dft std::vector backend_chains; int32_t n_embd_dec = 0; // draft hidden size int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size int32_t n_embd_tgt = 0; // target model hidden size const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices uint32_t target_layer_ids_n = 0; // [per-seq] deferred boundary state std::vector> pending_g_last; std::vector pending_pos_last; // [per-seq] snapshot of the most recent process()'s encoder output std::vector> verify_g; // [n_seq][n_rows * n_embd_dec] std::vector verify_pos_first; // [n_seq] — pos of verify_g[seq][0] std::vector verify_g_rows; // [n_seq] — number of rows // scratch buffer for concatenated target features [n_tokens, n_embd_enc] std::vector features_buf; std::vector g_embd_buf; common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq) , params(params.draft) { LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling); auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "EAGLE3 requires ctx_tgt and ctx_dft to be set"); const llama_model * model_dft = llama_get_model(ctx_dft); const llama_model * model_tgt = llama_get_model(ctx_tgt); target_layer_ids = llama_model_target_layer_ids (model_dft); target_layer_ids_n = llama_model_target_layer_ids_n(model_dft); if (target_layer_ids_n != 3) { throw std::runtime_error("draft model is not eagle3 (expected 3 extract layers, got " + std::to_string(target_layer_ids_n) + ")"); } 
n_embd_tgt = llama_model_n_embd(model_tgt); n_embd_dec = llama_model_n_embd(model_dft); n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt; const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd_dec, /*n_seq_max=*/ 1); // llama_batch_init allocates only one of token/embd; eagle3 decoder needs both. // TODO: fix, how to call without malloc batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b); smpls.resize(n_seq); for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; sparams.top_k = 10; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } // offload draft sampling to the backend backend_chains.assign(n_seq, nullptr); if (this->params.backend_sampling) { for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); if (!llama_set_sampler(ctx_dft, seq_id, chain)) { LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); llama_sampler_free(chain); chain = nullptr; } backend_chains[seq_id] = chain; } } // turn on extraction of the target layers' input embeddings for (uint32_t k = 0; k < target_layer_ids_n; ++k) { llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true); } // turn on extraction of the draft model's pre-norm hidden state // (used both for the encoder output g_embd and the decoder pre-norm output). 
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); pending_g_last.assign(n_seq, std::vector(n_embd_dec, 0.0f)); pending_pos_last.assign(n_seq, -1); verify_g.assign(n_seq, std::vector()); verify_pos_first.assign(n_seq, -1); verify_g_rows.assign(n_seq, 0); } ~common_speculative_impl_draft_eagle3() override { auto * ctx_dft = this->params.ctx_dft; for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) { if (backend_chains[seq_id] == nullptr) { continue; } if (ctx_dft) { llama_set_sampler(ctx_dft, seq_id, nullptr); } llama_sampler_free(backend_chains[seq_id]); } backend_chains.clear(); if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; } llama_batch_free(batch); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { const int32_t N = (int32_t) prompt.size(); if (N <= 0) { return; } // expected state after prefill: ctx_dft has pos 0..N-2 (last position is deferred to // draft()'s seed step). Warn only if more than one position is missing. auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); if (pos_max < N - 2) { LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. " "Drafts may degrade.\n", __func__, (int) pos_max, N - 2); } } bool process(const llama_batch & batch_in) override { if (batch_in.n_tokens <= 0) { return true; } if (batch_in.token == nullptr || batch_in.embd != nullptr) { return true; } const int32_t n_tokens = batch_in.n_tokens; // i_batch_beg[seq] / i_batch_end[seq]: inclusive batch indices of this seq's // first/last token in batch_in. Assumes per-seq tokens are contiguous within // the ubatch (server's default ordering). 
std::vector i_batch_beg(n_seq, -1); std::vector i_batch_end(n_seq, -1); for (int k = 0; k < n_tokens; ++k) { GGML_ASSERT(batch_in.n_seq_id[k] == 1); const llama_seq_id seq_id = batch_in.seq_id[k][0]; if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { continue; } i_batch_end[seq_id] = k; if (i_batch_beg[seq_id] < 0) { i_batch_beg[seq_id] = k; } } auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; // Interleave each extract_layer's hidden state into a contiguous buffer of // shape [n_tokens, target_layer_ids_n * n_embd_tgt]. Then run EAGLE3 encoder // to get one g_embd row per token. features_buf.resize((size_t) n_tokens * n_embd_enc, 0.0f); for (uint32_t k = 0; k < target_layer_ids_n; ++k) { const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]); if (!layer) { GGML_ABORT("EAGLE3: target layer %d input not extracted.", target_layer_ids[k]); } for (int32_t i = 0; i < n_tokens; ++i) { float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt; const float * src = layer + (size_t) i * n_embd_tgt; std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float)); } } g_embd_buf.resize((size_t) n_tokens * n_embd_dec); // llama_encode() requires the full encoder batch to fit in n_ubatch. // Allow batch > ubatch: eagle3's per-token encoder can be chunked safely. 
const int32_t n_ubatch_dft = (int32_t) llama_n_ubatch(ctx_dft); for (int32_t i = 0; i < n_tokens; i += n_ubatch_dft) { const int32_t n_chunk = std::min(n_ubatch_dft, n_tokens - i); llama_batch enc_batch = { /*.n_tokens =*/ n_chunk, /*.token =*/ nullptr, /*.embd =*/ features_buf.data() + (size_t) i * n_embd_enc, /*.pos =*/ nullptr, /*.n_seq_id =*/ nullptr, /*.seq_id =*/ nullptr, /*.logits =*/ nullptr, }; const int32_t rc = llama_encode(ctx_dft, enc_batch); if (rc != 0) { LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n", __func__, rc, (int) n_chunk, (int) i); return false; } // g_embd has shape [n_chunk, n_embd_dec] in ctx_dft's pre-norm embeddings buffer. const float * g_embd_chunk = llama_get_embeddings_nextn(ctx_dft); GGML_ASSERT(g_embd_chunk && "EAGLE3 encoder produced no output."); std::memcpy(g_embd_buf.data() + (size_t) i * n_embd_dec, g_embd_chunk, (size_t) n_chunk * n_embd_dec * sizeof(float)); } const float * g_embd = g_embd_buf.data(); const size_t row_bytes = (size_t) n_embd_dec * sizeof(float); // EAGLE3 decoder input convention: at memory pos P the input pair is // (token[P+1], g_embd[P]). This shifts the token index "left by one" relative to g_embd. // // Per seq, in order: // (a) cross-ubatch bridge — when applicable, write the previously-deferred // pos using this ubatch's first token + pending_g_last. // (b) main write loop — for k in [beg, end-1], write (token[k+1], g_embd[k]) // at pos[k]. The last training pos (k=end) is left unwritten = new // deferred boundary, completed by the next process() or draft() call. // (c) refresh deferred state — stash this ubatch's full g_embd into verify_g, // update pending_g_last / pending_pos_last to the last row. 
common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { const int32_t beg = i_batch_beg[seq_id]; const int32_t end = i_batch_end[seq_id]; if (beg < 0 || end < 0) { continue; } // cross-ubatch bridge — complete the prior ubatch's deferred boundary. // Fires iff all three preconditions hold: // 1) pending_pos_last >= 0 // 2) pending_pos_last + 1 == pos[beg] // 3) pending_pos_last > dft_pos_max // TODO: is this check needed? const llama_pos pending_pos = pending_pos_last[seq_id]; if (pending_pos >= 0 && pending_pos + 1 == batch_in.pos[beg]) { const llama_pos dft_pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); if (pending_pos > dft_pos_max) { common_batch_add(batch, batch_in.token[beg], pending_pos, { seq_id }, /*logits=*/ false); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec, pending_g_last[seq_id].data(), row_bytes); } } for (int32_t k = beg; k < end; ++k) { common_batch_add(batch, batch_in.token[k + 1], batch_in.pos[k], { seq_id }, /*logits=*/ false); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec, g_embd + (size_t) k * n_embd_dec, row_bytes); } // refresh deferred state const int32_t n_rows = end - beg + 1; verify_pos_first[seq_id] = batch_in.pos[beg]; pending_pos_last[seq_id] = batch_in.pos[end]; verify_g_rows[seq_id] = n_rows; verify_g[seq_id].resize((size_t) n_rows * n_embd_dec, 0.0f); std::memcpy(verify_g[seq_id].data(), g_embd + (size_t) beg * n_embd_dec, row_bytes * n_rows); std::memcpy(pending_g_last[seq_id].data(), g_embd + (size_t) end * n_embd_dec, row_bytes); } if (batch.n_tokens > 0) { const int32_t rc = llama_decode(ctx_dft, batch); if (rc != 0) { LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n", __func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]); return false; } } return true; } void draft(common_speculative_draft_params_vec & dparams) override { auto & ctx_dft = params.ctx_dft; 
common_batch_clear(batch); // keep track of which sequences are still drafting int n_drafting = 0; std::vector drafting(n_seq); const size_t row_bytes = (size_t) n_embd_dec * sizeof(float); // Complete the deferred boundary pair (dp.id_last, pending_g_last) at memory // pos pending_pos_last. dp.id_last is target's freshest sample (= corrected // token after verify, or first generated token after prefill), matching the // EAGLE3 input convention (token[P+1], g_embd[P]) at pos P. for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } if (pending_pos_last[seq_id] < 0) { continue; } n_drafting++; drafting[seq_id] = true; common_sampler_reset(smpls[seq_id].get()); llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, pending_pos_last[seq_id], -1); common_batch_add(batch, dp.id_last, pending_pos_last[seq_id], { seq_id }, true); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec, pending_g_last[seq_id].data(), row_bytes); } if (batch.n_tokens == 0) { return; } int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); return; } int i = 0; while (n_drafting > 0) { int i_batch = 0; common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (!drafting[seq_id]) { continue; } auto * smpl = smpls[seq_id].get(); common_sampler_sample(smpl, ctx_dft, i_batch, true); // pre-norm hidden state of this position becomes g_embd for the next step const float * prenorm = llama_get_embeddings_nextn_ith(ctx_dft, i_batch); ++i_batch; const auto * cur_p = common_sampler_get_candidates(smpl, true); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } const llama_token id = cur_p->data[0].id; // only collect very 
high-confidence draft tokens // (configurable via --spec-draft-p-min, set to 0.0 to disable early-stop) if (cur_p->data[0].p < params.p_min) { drafting[seq_id] = false; n_drafting--; continue; } common_sampler_accept(smpl, id, true); auto & dp = dparams.at(seq_id); auto & result = *dp.result; result.push_back(id); if (params.n_max <= (int) result.size()) { drafting[seq_id] = false; n_drafting--; continue; } common_batch_add(batch, id, pending_pos_last[seq_id] + (i + 1), { seq_id }, true); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec, prenorm, row_bytes); } if (batch.n_tokens == 0) { break; } ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } ++i; } for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } } } void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override { if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { return; } const int32_t n_rows = verify_g_rows[seq_id]; if (n_rows <= 0) { return; } const int32_t i_g = std::min(n_accepted, n_rows - 1); pending_pos_last[seq_id] = verify_pos_first[seq_id] + i_g; std::memcpy(pending_g_last[seq_id].data(), verify_g[seq_id].data() + (size_t) i_g * n_embd_dec, (size_t) n_embd_dec * sizeof(float)); } // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets: // their single-position checkpoints drop it on restore bool need_boundary_stash() const { const llama_model * model_tgt = llama_get_model(params.ctx_tgt); return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); } bool get_state(llama_seq_id seq_id, std::vector & data) const override { if (!need_boundary_stash()) { return false; } if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) { return false; } const llama_pos 
pos = pending_pos_last[seq_id]; const std::vector & g = pending_g_last[seq_id]; data.resize(sizeof(llama_pos) + g.size() * sizeof(float)); std::memcpy(data.data(), &pos, sizeof(llama_pos)); std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float)); return true; } void set_state(llama_seq_id seq_id, const std::vector & data) override { if (!need_boundary_stash()) { return; } if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { return; } if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) { return; } llama_pos pos = -1; std::memcpy(&pos, data.data(), sizeof(llama_pos)); pending_pos_last[seq_id] = pos; pending_g_last[seq_id].resize(n_embd_dec); std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float)); } bool need_embd() const override { return false; } }; struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft) llama_batch batch; std::vector smpls; // backend sampler chain per seq, attached to ctx_dft std::vector backend_chains; int32_t n_embd = 0; // One MTP draft driver, three modes (set once in the ctor): // is_mem_shared (gemma4): shares the target KV, runs all heads in one graph. // chain_heads (step35): n_mtp_layers trained heads, one per draft step. // neither (qwen35 / qwen35moe): a single trained MTP head. int32_t n_mtp_layers = 1; bool is_mem_shared = false; // gemma4 bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT // call to pair with, so it's stashed here until that next call fires. std::vector> pending_h; // [n_seq][n_embd] std::vector i_batch_beg; std::vector i_batch_end; // Hidden rows from the most recent target verification batch, grouped by seq. 
// Row 0 corresponds to the sampled token, row N to the Nth accepted draft token. std::vector> verify_h; std::vector verify_h_rows; std::vector i_last; std::vector> chain_h; common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) , params(params.draft) { auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set"); n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && "MTP input row width must match the target h_nextn width"); n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft))); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__, this->params.n_gpu_layers, ggml_type_name(this->params.cache_type_k), ggml_type_name(this->params.cache_type_v), ctx_tgt ? "yes" : "no", ctx_dft ? "yes" : "no", common_speculative_get_devices_str(this->params.devices).c_str()); const int32_t n_b = (int32_t) llama_n_batch(ctx_dft); batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1); // llama_batch_init allocates only one of token/embd; MTP needs both. 
// TODO: fix, how to call without malloc batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b); smpls.resize(n_seq); for (auto & s : smpls) { common_params_sampling sparams; sparams.no_perf = false; sparams.top_k = 10; sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K }; s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams)); } // offload draft sampling to the backend backend_chains.assign(n_seq, nullptr); if (this->params.backend_sampling) { for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params()); llama_sampler_chain_add(chain, llama_sampler_init_top_k(10)); if (!llama_set_sampler(ctx_dft, seq_id, chain)) { LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id); llama_sampler_free(chain); chain = nullptr; } backend_chains[seq_id] = chain; } } llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false); llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt; chain_heads = n_mtp_layers > 1 && !is_mem_shared; if (chain_heads) { this->params.n_max = std::min(this->params.n_max, n_mtp_layers); chain_h.assign(n_seq, {}); for (auto & c : chain_h) { c.reserve((size_t) (this->params.n_max + 1) * n_embd); } } pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); i_last.assign(n_seq, -1); i_batch_beg.assign(n_seq, -1); i_batch_end.assign(n_seq, -1); verify_h.assign(n_seq, {}); verify_h_rows.assign(n_seq, 0); } ~common_speculative_impl_draft_mtp() override { auto * ctx_dft = this->params.ctx_dft; for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) { if (backend_chains[seq_id] == nullptr) { continue; } if (ctx_dft) { llama_set_sampler(ctx_dft, seq_id, nullptr); } llama_sampler_free(backend_chains[seq_id]); } backend_chains.clear(); if (batch.token != nullptr) { free(batch.token); batch.token = nullptr; } 
llama_batch_free(batch); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { const int32_t N = (int32_t) prompt.size(); if (N <= 0) { return; } auto * ctx_dft = this->params.ctx_dft; const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id); if (pos_max < N - 1 && !is_mem_shared) { LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - " "process() hook may not have run on every prefill ubatch " "(need_embd / logits=1 on every prompt position?). " "Drafts may degrade.\n", __func__, (int) pos_max, N - 1); } } bool process(const llama_batch & batch_in) override { if (batch_in.n_tokens <= 0) { return true; } // TODO: how to make it work with vision tokens? if (batch_in.token == nullptr || batch_in.embd != nullptr) { return true; } const int32_t n_tokens = batch_in.n_tokens; // remember the frist and last batch index for each sequence std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1); std::fill(i_batch_end.begin(), i_batch_end.end(), -1); for (int k = 0; k < n_tokens; ++k) { for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { GGML_ASSERT(batch_in.n_seq_id[k] == 1); if (batch_in.seq_id[k][0] == seq_id) { i_batch_end[seq_id] = k; if (i_batch_beg[seq_id] < 0) { i_batch_beg[seq_id] = k; } } } } auto * ctx_tgt = this->params.ctx_tgt; auto * ctx_dft = this->params.ctx_dft; const size_t row_bytes = (size_t) n_embd * sizeof(float); // if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode if (!is_mem_shared) { common_batch_clear(batch); for (int k = 0; k < n_tokens; ++k) { common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0); } // shift the tgt embeddings to the right by one position // assumes that the tokens in the batch are sequential for each sequence // i.e. 
we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1] // ^--- this is a problem // TODO:this is generally true, but would be nice to assert it { const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt); std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1)); } // fill the pending embeddings from a previous run auto set_h = [&](int idx, const float * h_row) { std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes); }; for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (i_batch_beg[seq_id] < 0) { continue; } set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } auto * mem_dft = llama_get_memory(ctx_dft); bool ok = true; for (int head = 0; head < n_mtp_layers; ++head) { if (chain_heads) { // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544 for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (i_batch_beg[seq_id] < 0) { continue; } llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1); } llama_set_nextn_layer_offset(ctx_dft, head); } const int32_t rc = llama_decode(ctx_dft, batch); if (rc != 0) { LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n", __func__, head, (int) rc, (int) batch_in.pos[0]); ok = false; break; } } if (chain_heads) { llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes } if (!ok) { return false; } } for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (i_batch_end[seq_id] < 0) { continue; } const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1; verify_h_rows[seq_id] = n_rows; verify_h[seq_id].resize((size_t) n_rows * n_embd); for (int32_t i = 0; i < n_rows; ++i) { const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i); std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes); } std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes); } 
return true; } void draft(common_speculative_draft_params_vec & dparams) override { auto & ctx_dft = params.ctx_dft; common_batch_clear(batch); // keep track of which sequences are still drafting int n_drafting = 0; std::vector drafting(n_seq); const size_t row_bytes = (size_t) n_embd * sizeof(float); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } n_drafting++; drafting[seq_id] = true; common_sampler_reset(smpls[seq_id].get()); common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes); i_last[seq_id] = batch.n_tokens - 1; if (chain_heads) { chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end()); } } int i = 0; while (n_drafting > 0) { // each step decodes under a different head, i.e. a different decoder layer, and // KV is per layer. process() filled this layer's KV only for positions < n_past // (prompt + accepted prefix) — nothing in the draft region yet. so reset the // draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact) // and select head i so it rebuilds its own layer's KV there; decoding just the // latest token would leave its attention reading cells only another head wrote. if (chain_heads) { auto * mem_dft = llama_get_memory(ctx_dft); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (drafting[seq_id]) { llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1); } } llama_set_nextn_layer_offset(ctx_dft, i); } int ret = llama_decode(ctx_dft, batch); if (ret != 0) { LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); break; } // rebuild the batch for the next step: the growing-KV paths re-add only the // new token (the KV already holds the prefix), while chained heads re-add the // whole prefix at the next head. dropped sequences are simply not re-added. 
common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { if (!drafting[seq_id]) { continue; } auto * smpl = smpls[seq_id].get(); common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true); const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]); const auto * cur_p = common_sampler_get_candidates(smpl, true); for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); } // add drafted token for each sequence const llama_token id = cur_p->data[0].id; // only collect very high-confidence draft tokens if (cur_p->data[0].p < params.p_min) { drafting[seq_id] = false; n_drafting--; continue; } common_sampler_accept(smpl, id, true); auto & dp = dparams.at(seq_id); auto & result = *dp.result; result.push_back(id); if (params.n_max <= (int) result.size()) { drafting[seq_id] = false; n_drafting--; continue; } if (chain_heads) { // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546 chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd); const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far for (int t = 0; t < n_rows; ++t) { const llama_token tok = (t == 0) ? dp.id_last : result[t - 1]; common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes); } } else if (is_mem_shared) { // note: with shared memory (e.g. 
Gemma4 assistants) we use the same position for all draft tokens // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37 common_batch_add(batch, id, dp.n_past, { seq_id }, true); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } else { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } i_last[seq_id] = batch.n_tokens - 1; } if (batch.n_tokens == 0) { break; } ++i; } if (chain_heads) { llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes } for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } } } void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override { if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { return; } const int32_t n_rows = verify_h_rows[seq_id]; if (n_rows <= 0) { return; } const int32_t i_h = std::min(n_accepted, n_rows - 1); const size_t row_bytes = (size_t) n_embd * sizeof(float); std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes); } bool need_embd() const override { return false; } bool need_embd_nextn() const override { return true; } }; // state of self-speculation (simple implementation, not ngram-map) struct common_speculative_impl_ngram_simple : public common_speculative_impl { common_params_speculative_ngram_map params; // shared across all sequences common_ngram_simple_config config; common_speculative_impl_ngram_simple( const common_params_speculative & params, uint32_t n_seq, common_ngram_simple_config config) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq) , params(params.ngram_simple) , config(config) { LOG_INF("%s: adding 
speculative implementation 'ngram-simple'\n", __func__); LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__, this->params.size_n, this->params.size_m, this->params.min_hits); } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } *dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } bool need_embd() const override { return false; } }; struct common_speculative_impl_ngram_map_k : public common_speculative_impl { // n_seq configs std::vector config; common_speculative_impl_ngram_map_k( const common_ngram_map & config, uint32_t n_seq) : common_speculative_impl(config.key_only ? 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K : COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, n_seq) { for (uint32_t i = 0; i < n_seq; i++) { this->config.push_back(config); } LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str()); LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__, config.size_key, config.size_value, config.key_only, config.min_hits); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { GGML_ASSERT(seq_id < (llama_seq_id) n_seq); common_ngram_map_begin(config[seq_id], prompt); } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result); } } void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override { GGML_ASSERT((seq_id < (llama_seq_id) config.size())); if (is_other) { return; } common_ngram_map_accept(config[seq_id], n_accepted); } bool need_embd() const override { return false; } }; struct common_speculative_impl_ngram_mod : public common_speculative_impl { common_params_speculative_ngram_mod params; // shared across all sequences common_ngram_mod mod; // enable trace logging if LLAMA_TRACE is set const bool verbose; struct seq_info { // the last position in the prompt that was added to the ngram container size_t i_last = 0; // length of the last drafted n-gram (number of tokens returned by draft) size_t n_draft_last = 0; // consecutive accept rounds with low acceptance fraction (< 0.5) int n_low = 0; }; std::vector sinfos; common_speculative_impl_ngram_mod( const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq) , 
params(params.ngram_mod) , mod(params.ngram_mod.n_match, 4*1024*1024) , verbose(std::getenv("LLAMA_TRACE") != nullptr) { static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t)); LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__); LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__, this->params.n_match, this->params.n_max, this->params.n_min); LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__, mod.size(), (float)(mod.size_bytes())/1024/1024); if (this->params.n_match < 16) { LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, " "see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match); } sinfos.resize(n_seq); } void begin(llama_seq_id seq_id, const llama_tokens & prompt) override { auto & sinfo = sinfos[seq_id]; sinfo.i_last = 0; sinfo.n_draft_last = 0; const size_t n = mod.get_n(); if (prompt.size() < n) { return; } for (size_t i = 0; i < prompt.size() - n; ++i) { mod.add(prompt.data() + i); } sinfo.i_last = prompt.size() - n; const double f = (double)mod.get_used() / (double)mod.size(); LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f); constexpr double f_thold = 0.25; if (f > f_thold) { LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold); mod.reset(); } } void draft_one( llama_seq_id seq_id, common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; auto & result = *dparams.result; const auto & prompt = *dparams.prompt; sinfo.n_draft_last = 0; const size_t cur_len = prompt.size(); if (cur_len < mod.get_n()) { return; } const size_t n = mod.get_n(); // add new ngrams in chunks if (sinfo.i_last + 32 < cur_len) { for (size_t i = sinfo.i_last; i < cur_len - n; ++i) { mod.add(prompt.data() + i); } sinfo.i_last = cur_len - n; } result.resize(n + params.n_max); for (size_t i = 0; i < n - 1; ++i) { result[i] = prompt.at(cur_len - n + 1 + i); } result[n - 1] = 
dparams.id_last; for (int i = 0; i < params.n_max; ++i) { const llama_token token = mod.get(result.data() + i); if (token == common_ngram_mod::EMPTY) { if (i < params.n_min) { result.clear(); return; } result.resize(n + i); break; } result[n + i] = token; } // only return the m tokens that were drafted for (size_t i = 0; n + i < result.size(); ++i) { result[i] = result[n + i]; } result.resize(result.size() - n); // store length of drafted n-gram for later acceptance analysis sinfo.n_draft_last = result.size(); } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } draft_one(seq_id, dp); } } void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override { if (is_other) { return; } auto & sinfo = sinfos[seq_id]; // compute acceptance fraction if we have a recorded draft length if (sinfo.n_draft_last > 0) { const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last; if (f_acc < 0.25) { sinfo.n_low++; if (sinfo.n_low >= 5) { if (verbose) { LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low); } mod.reset(); sinfo.n_low = 0; sinfo.i_last = 0; } } else { sinfo.n_low = 0; } } } bool need_embd() const override { return false; } }; struct common_speculative_impl_ngram_cache : public common_speculative_impl { common_params_speculative_ngram_cache params; uint16_t n_draft; bool save_dynamic; bool save_static; struct seq_info { size_t cache_size = 0; // number of tokens in n-gram cache common_ngram_cache ngram_cache_context; common_ngram_cache ngram_cache_dynamic; common_ngram_cache ngram_cache_static; }; std::vector sinfos; common_speculative_impl_ngram_cache( const common_params_speculative & params, uint32_t n_seq, uint16_t n_draft, const 
std::string & path_static, const std::string & path_dynamic, bool save_dynamic, bool save_static) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq) , params(params.ngram_cache) , n_draft(n_draft) , save_dynamic(save_dynamic) , save_static(save_static) { LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__); LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__, n_draft, path_static.empty() ? "none" : path_static.c_str(), path_dynamic.empty() ? "none" : path_dynamic.c_str()); sinfos.resize(n_seq); if (!path_static.empty()) { try { auto ngram_cache_static = common_ngram_cache_load(path_static); for (auto & sinfo : sinfos) { sinfo.ngram_cache_static = ngram_cache_static; } } catch (...) { LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); GGML_ABORT("Couldn't read static lookup cache"); } } if (!path_dynamic.empty()) { try { auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); for (auto & sinfo : sinfos) { sinfo.ngram_cache_dynamic = ngram_cache_dynamic; } } catch (...) 
{ LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); GGML_ABORT("Couldn't read dynamic lookup cache"); } } } void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override { // noop } void draft_one( llama_seq_id seq_id, common_speculative_draft_params & dparams) { auto & sinfo = sinfos[seq_id]; auto & result = *dparams.result; const auto & prompt = *dparams.prompt; if (sinfo.cache_size < prompt.size() + 1) { llama_tokens tokens_new; tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size); for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) { tokens_new.push_back(prompt[j]); } tokens_new.push_back(dparams.id_last); // add the last token // Update context ngram cache with new dparams.prompt: common_ngram_cache_update( sinfo.ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, tokens_new, tokens_new.size(), false); sinfo.cache_size = prompt.size() + 1; } llama_tokens inp; inp.reserve(prompt.size() + 1); for (size_t j = 0; j < prompt.size(); ++j) { inp.push_back(prompt[j]); } inp.push_back(dparams.id_last); result.push_back(dparams.id_last); common_ngram_cache_draft( inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, sinfo.ngram_cache_context, sinfo.ngram_cache_dynamic, sinfo.ngram_cache_static); if (result.size() > 0) { // delete first token in result (which is the id_last token) result.erase(result.begin()); } } bool process(const llama_batch & /*batch*/) override { // TODO: implement return true; } void draft(common_speculative_draft_params_vec & dparams) override { assert(dparams.size() == n_seq); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { continue; } draft_one(seq_id, dp); } } void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override { // noop } bool need_embd() const override { return false; } }; struct common_speculative { common_speculative_draft_params_vec dparams; // list of implementations to use 
and their states std::vector> impls; // which implementaion was used for a given seq_id std::vector impl_last; }; static common_ngram_map get_common_ngram_map( common_speculative_type type, const common_params_speculative_ngram_map & config) { uint16_t size_key = config.size_n; uint16_t size_value = config.size_m; bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; uint16_t min_hits = config.min_hits; return common_ngram_map(size_key, size_value, key_only, min_hits); } static common_speculative_impl_ngram_cache create_state_ngram_cache( const common_speculative_config & config, uint32_t n_seq, const std::string & path_static, const std::string & path_dynamic) { uint16_t n_draft = 8; // TODO get from config? // TODO bool param in common/common.h to set save_static/save_dynamic? bool save_static = false; bool save_dynamic = false; common_speculative_impl_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic); return state; } std::string common_speculative_type_name_str(const std::vector & types) { std::string result; for (size_t i = 0; i < types.size(); i++) { if (i > 0) { result += ","; } result += common_speculative_type_to_str(types[i]); } return result; } const char * common_speculative_all_types_str() { static std::string all_types_str = []() { std::vector types; types.reserve(COMMON_SPECULATIVE_TYPE_COUNT); for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) { types.push_back((common_speculative_type) i); } return common_speculative_type_name_str(types); }(); return all_types_str.c_str(); } std::string common_speculative_type_to_str(common_speculative_type type) { switch (type) { case COMMON_SPECULATIVE_TYPE_NONE: return "none"; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple"; case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3"; case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple"; case 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache"; default: return "unknown"; } } std::vector common_speculative_types_from_names(const std::vector & names) { std::vector types; types.reserve(names.size()); for (const auto & name : names) { auto type = common_speculative_type_from_name_map.find(name); if (type != common_speculative_type_from_name_map.end()) { if (type->second == COMMON_SPECULATIVE_TYPE_NONE) { return std::vector { COMMON_SPECULATIVE_TYPE_NONE }; } types.push_back(type->second); continue; } throw std::invalid_argument("unknown speculative type: " + name); } return types; } common_speculative_type common_speculative_type_from_name(const std::string & name) { const auto it = common_speculative_type_from_name_map.find(name); if (it == common_speculative_type_from_name_map.end()) { return COMMON_SPECULATIVE_TYPE_COUNT; } return it->second; } static uint32_t common_get_enabled_speculative_configs(const std::vector & configs) { uint32_t result = 0; for (size_t i = 0; i < configs.size(); i++) { result |= (1u << configs[i]); } return result; } int32_t common_speculative_n_max(const common_params_speculative * spec) { int32_t n_max = 0; for (const auto type : spec->types) { switch (type) { case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: n_max = std::max(n_max, std::max(0, spec->draft.n_max)); break; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: n_max = std::max(n_max, (int32_t) spec->ngram_simple.size_m); break; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: n_max = std::max(n_max, (int32_t) spec->ngram_map_k.size_m); break; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: n_max = std::max(n_max, (int32_t) spec->ngram_map_k4v.size_m); break; case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: n_max = 
std::max(n_max, std::max(0, spec->ngram_mod.n_max)); break; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: n_max = std::max(n_max, (int32_t) 8); break; case COMMON_SPECULATIVE_TYPE_NONE: case COMMON_SPECULATIVE_TYPE_COUNT: break; } } return n_max; } // initialization of the speculative decoding system // common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) { // Compute the implementations to use based on the config and their order of preference std::vector configs = {}; // list of speculative configs to try { uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types); bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr; bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE)); bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE)); bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K)); bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V)); bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD)); // when adding a new type - update here the logic above static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9); // this list here defines the priority of the speculators // the one with highest priority are listed first if (has_ngram_simple) { // This implementation can guess a lot of tokens without any draft model. 
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); } if (has_ngram_map_k) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); } if (has_ngram_map_k4v) { // This implementation can guess tokens with high acceptance rate but is more expensive. configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); } if (has_ngram_mod) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params)); } if (has_ngram_cache) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); } if (has_draft_simple) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params)); } if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } if (has_draft_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); } } std::vector> impls = {}; for (const common_speculative_config & config : configs) { switch (config.type) { case COMMON_SPECULATIVE_TYPE_NONE: break; case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: { impls.push_back(std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple); uint16_t ngram_size_key = ngram_map.size_key; uint16_t mgram_size_value = ngram_map.size_value; auto config_simple = common_ngram_simple_config { /* .size_ngram = */ ngram_size_key, /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( /* .params = */ config.params, /* .n_seq = */ n_seq, /* .state = */ config_simple ); impls.push_back(std::move(state)); break; } case 
COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: { impls.push_back( std::make_unique( get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { impls.push_back( std::make_unique( get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: { impls.push_back( std::make_unique(config.params, n_seq)); break; } case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { auto state = create_state_ngram_cache( config, n_seq, params.ngram_cache.lookup_cache_static, params.ngram_cache.lookup_cache_dynamic); impls.push_back(std::make_unique(state)); break; } default: break; } } if (impls.empty()) { LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__); return nullptr; } auto * result = new common_speculative { /* .dparams = */ common_speculative_draft_params_vec(n_seq), /* .impls = */ std::move(impls), /* .impl_last = */ std::vector(n_seq, nullptr) }; return result; } void common_speculative_free(common_speculative * spec) { if (spec == nullptr) { return; } delete spec; } common_speculative_draft_params & common_speculative_get_draft_params( common_speculative * spec, llama_seq_id seq_id) { GGML_ASSERT(spec); GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size()); return spec->dparams[seq_id]; } void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(seq_id, prompt); impl->n_call_begin++; } } bool common_speculative_process(common_speculative * spec, const llama_batch & batch) { bool result = true; if (spec == nullptr) { return result; } for (auto & impl : spec->impls) { result = result && impl->process(batch); } return result; } bool common_speculative_need_embd(common_speculative * spec) { if (spec == nullptr) { return false; } for (auto & impl : spec->impls) 
{ if (impl->need_embd()) { return true; } } return false; } bool common_speculative_need_embd_nextn(common_speculative * spec) { if (spec == nullptr) { return false; } for (auto & impl : spec->impls) { if (impl->need_embd_nextn()) { return true; } } return false; } void common_speculative_draft(common_speculative * spec) { if (spec == nullptr) { return; } auto & dparams = spec->dparams; { int n_drafting = 0; for (auto & dp : dparams) { GGML_ASSERT(!dp.drafting || dp.result->empty()); if (dp.drafting) { n_drafting++; } } if (n_drafting == 0) { return; } } for (auto & impl : spec->impls) { { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); impl->draft(dparams); impl->n_call_draft++; } int n_drafting = 0; for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { auto & dp = dparams[seq_id]; auto & result = *dp.result; // a new draft has been sampled if (dp.drafting && !result.empty()) { dp.drafting = false; if (dp.n_max > 0) { if (!result.empty() && (int) result.size() > dp.n_max) { LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max); result.resize(dp.n_max); } } if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(), impl.get()->n_call_draft, result.size()); // remember which implementation was used spec->impl_last[seq_id] = impl.get(); impl->n_gen_drafts++; impl->n_gen_tokens += result.size(); } } if (dp.drafting) { n_drafting++; } } if (n_drafting == 0) { break; } } // these sequences failed to generate a draft for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) { auto & dp = dparams[seq_id]; if (dp.drafting) { dp.drafting = false; } } } void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) { common_speculative_impl * impl = spec->impl_last[seq_id]; GGML_ASSERT(impl); { common_time_meas tm(impl->t_accept_us, 
!impl->gen_perf); if (impl->n_acc_tokens_per_pos.size() < n_accepted) { impl->n_acc_tokens_per_pos.resize(n_accepted, 0); } for (size_t i = 0; i < n_accepted; ++i) { impl->n_acc_tokens_per_pos[i]++; } if (n_accepted > 0) { impl->n_acc_drafts++; impl->n_acc_tokens += n_accepted; } impl->accept(seq_id, n_accepted, false); impl->n_call_accept++; } // accept with the rest of the implementations, using is_other == true for (auto & impl_other : spec->impls) { if (impl_other.get() != impl) { impl_other->accept(seq_id, n_accepted, true); } } } // TODO: support the case of more than one speculative implementations having a state bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data) { if (spec == nullptr) { return false; } for (auto & impl : spec->impls) { if (impl->get_state(seq_id, data)) { return true; } } return false; } void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data) { if (spec == nullptr) { return; } for (auto & impl : spec->impls) { impl->set_state(seq_id, data); } } void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; } for (const auto & impl : spec->impls) { std::string str_perf; if (impl->gen_perf) { std::ostringstream oss; oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", "; oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0; str_perf = ", dur(b,g,a) = " + oss.str() + " ms"; } else { str_perf = ""; } std::string str_stats; if (impl->n_call_accept > 0) { const double mean = 1.0 + (double) impl->n_acc_tokens / (double) impl->n_call_accept; std::ostringstream tmp; tmp << std::fixed << std::setprecision(3); for (size_t i = 0; i < impl->n_acc_tokens_per_pos.size(); ++i) { if (i > 0) { tmp << ", "; } tmp << (double) impl->n_acc_tokens_per_pos[i] / (double) impl->n_call_accept; } 
std::ostringstream oss; oss << std::fixed << std::setprecision(2) << mean; str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")"; } LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n", common_speculative_type_to_str(impl->type).c_str(), impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, impl->n_gen_drafts, impl->n_acc_drafts, impl->n_gen_tokens, impl->n_acc_tokens, str_stats.c_str(), str_perf.c_str()); } }