Files
llama.cpp/common/speculative.cpp
2026-06-21 11:37:12 +03:00

2305 lines
86 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include "speculative.h"
#include "common.h"
#include "ggml.h"
#include "llama.h"
#include "log.h"
#include "ngram-cache.h"
#include "ngram-map.h"
#include "ngram-mod.h"
#include "sampling.h"
#include "../src/llama-ext.h" // staging API: llama_set_embeddings_nextn / llama_get_embeddings_nextn_ith (used by MTP)
#include <algorithm>
#include <cassert>
#include <cstring>
#include <iomanip>
#include <map>
#include <cinttypes>
#define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128
#define SPEC_VOCAB_CHECK_START_TOKEN_ID 5
const std::map<std::string, common_speculative_type> common_speculative_type_from_name_map = {
{"none", COMMON_SPECULATIVE_TYPE_NONE},
{"draft-simple", COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE},
{"draft-eagle3", COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3},
{"draft-mtp", COMMON_SPECULATIVE_TYPE_DRAFT_MTP},
{"ngram-simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
{"ngram-map-k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
{"ngram-map-k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
{"ngram-mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
{"ngram-cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
};
static std::string common_speculative_get_devices_str(const std::vector<ggml_backend_dev_t> & devices) {
std::string result;
for (size_t i = 0; i < devices.size(); i++) {
if (devices[i] == nullptr) {
continue;
}
if (!result.empty()) result += ", ";
result += ggml_backend_dev_name(devices[i]);
}
return result.empty() ? "default" : result;
}
struct common_speculative_config {
common_speculative_type type;
common_params_speculative params;
common_speculative_config(common_speculative_type t,
const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {}
};
static bool common_speculative_are_compatible(
const llama_model * model_tgt,
const llama_model * model_dft) {
const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt);
const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft);
const auto vocab_type_tgt = llama_vocab_type(vocab_tgt);
LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt);
const auto vocab_type_dft = llama_vocab_type(vocab_dft);
LOG_DBG("%s: vocab_type dft: %d\n", __func__, vocab_type_dft);
if (vocab_type_tgt != vocab_type_dft) {
LOG_WRN("%s: draft model vocab type must match target model to use speculation but "
"vocab_type_dft = %d while vocab_type_tgt = %d\n", __func__, vocab_type_dft, vocab_type_tgt);
return false;
}
if (llama_vocab_get_add_bos(vocab_tgt) != llama_vocab_get_add_bos(vocab_dft) ||
(llama_vocab_get_add_bos(vocab_tgt) && llama_vocab_bos(vocab_tgt) != llama_vocab_bos(vocab_dft))) {
LOG_WRN("%s: draft model bos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
llama_vocab_get_add_bos(vocab_tgt), llama_vocab_get_add_bos(vocab_dft),
llama_vocab_bos(vocab_tgt), llama_vocab_bos(vocab_dft));
return false;
}
if (llama_vocab_get_add_eos(vocab_tgt) != llama_vocab_get_add_eos(vocab_dft) ||
(llama_vocab_get_add_eos(vocab_tgt) && llama_vocab_eos(vocab_tgt) != llama_vocab_eos(vocab_dft))) {
LOG_WRN("%s: draft model eos tokens must match target model to use speculation. add: %d - %d, id: %d - %d)\n",
__func__,
llama_vocab_get_add_eos(vocab_tgt), llama_vocab_get_add_eos(vocab_dft),
llama_vocab_eos(vocab_tgt), llama_vocab_eos(vocab_dft));
return false;
}
{
const int n_vocab_tgt = llama_vocab_n_tokens(vocab_tgt);
const int n_vocab_dft = llama_vocab_n_tokens(vocab_dft);
const int vocab_diff = n_vocab_tgt > n_vocab_dft
? n_vocab_tgt - n_vocab_dft
: n_vocab_dft - n_vocab_tgt;
if (vocab_diff > SPEC_VOCAB_MAX_SIZE_DIFFERENCE) {
LOG_DBG("%s: draft model vocab must closely match target model to use speculation but ", __func__);
LOG_DBG("target vocab size %d does not match draft vocab size %d - difference %d, max allowed %d\n",
n_vocab_tgt, llama_vocab_n_tokens(vocab_dft), vocab_diff, SPEC_VOCAB_MAX_SIZE_DIFFERENCE);
return false;
}
for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) {
const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i);
const char * token_text_dft = llama_vocab_get_text(vocab_dft, i);
if (std::strcmp(token_text_tgt, token_text_dft) != 0) {
LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__);
LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i,
common_token_to_piece(vocab_tgt, i).c_str(),
common_token_to_piece(vocab_dft, i).c_str());
return false;
}
}
}
return true;
}
using common_speculative_draft_params_vec = std::vector<common_speculative_draft_params>;
// state of an implementation of speculative decoding
//
// each implementation has a unique type and a state that is implementation-specific
// in a subclass of common_speculative_impl
struct common_speculative_impl {
const common_speculative_type type;
uint32_t n_seq;
size_t n_call_begin = 0; // number of times this implementation was called for refresh.
size_t n_call_draft = 0; // number of times this implementation was called for generation.
size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
std::vector<size_t> n_acc_tokens_per_pos; // number of tokens accepted per draft position.
// TODO: track performance of most recent calls
const bool gen_perf = true; // whether to generate performance stats.
int64_t t_begin_us = 0; // total time spent in refresh of this implementation in microseconds.
int64_t t_draft_us = 0; // total time spent in generating drafts in this implementation in microseconds.
int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
common_speculative_impl(common_speculative_type type, uint32_t n_seq) : type(type), n_seq(n_seq) {}
virtual ~common_speculative_impl() = default;
virtual void begin(llama_seq_id seq_id, const llama_tokens & prompt) = 0;
virtual bool process(const llama_batch & batch) = 0;
virtual void draft(common_speculative_draft_params_vec & dparams) = 0;
virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0;
// (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary).
virtual bool get_state(llama_seq_id /*seq_id*/, std::vector<uint8_t> & /*data*/) const { return false; }
virtual void set_state(llama_seq_id /*seq_id*/, const std::vector<uint8_t> & /*data*/) {}
// true if this implementation requires the target context to extract post-norm embeddings
virtual bool need_embd() const = 0;
// true if this implementation requires the target context to extract pre-norm embeddings
virtual bool need_embd_nextn() const { return false; }
};
struct common_speculative_impl_draft_simple : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch;
std::vector<common_sampler_ptr> smpls;
common_speculative_impl_draft_simple(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, n_seq)
, params(params.draft)
{
auto * ctx_dft = this->params.ctx_dft;
auto * ctx_tgt = this->params.ctx_tgt;
LOG_INF("%s: adding speculative implementation 'draft-simple'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
ctx_tgt ? "yes" : "no",
ctx_dft ? "yes" : "no",
common_speculative_get_devices_str(this->params.devices).c_str());
batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1);
// TODO: optimize or pass from outside?
// {
// common_params_sampling params;
// params.no_perf = false;
//
// params.top_k = 40;
// params.top_p = 0.9;
//
// params.samplers = {
// COMMON_SAMPLER_TYPE_TOP_K,
// COMMON_SAMPLER_TYPE_TOP_P,
// COMMON_SAMPLER_TYPE_INFILL,
// };
//
// result->smpl = common_sampler_init(llama_get_model(ctx_dft), params);
// }
smpls.resize(n_seq);
for (auto & smpl : smpls) {
common_params_sampling params;
params.no_perf = false;
params.top_k = 10;
params.samplers = {
COMMON_SAMPLER_TYPE_TOP_K,
};
smpl.reset(common_sampler_init(llama_get_model(ctx_dft), params));
}
const bool vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft));
LOG_DBG("%s: vocab_cmpt = %d\n", __func__, vocab_cmpt);
if (!vocab_cmpt) {
LOG_ERR("%s: the target and draft vocabs are not compatible\n", __func__);
throw std::runtime_error("draft model vocab type must match target model to use speculation");
}
if (n_seq != llama_n_seq_max(ctx_dft)) {
LOG_ERR("%s: n_seq mismatch: %d != %d\n", __func__, n_seq, llama_n_seq_max(ctx_dft));
throw std::runtime_error("the draft model number of sequences is incompatible with the speculative n_seq");
}
}
~common_speculative_impl_draft_simple() override {
llama_batch_free(batch);
}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
}
bool process(const llama_batch & batch) override {
auto * ctx_dft = params.ctx_dft;
const int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_ERR("%s: failed to decode draft batch, ret = %d\n", __func__, ret);
return false;
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
++i_batch;
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
auto & result = *dp.result;
result.push_back(id);
if ((params.n_max <= (int) result.size()) ||
(dp.n_max > 0 && dp.n_max <= (int) result.size())) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
}
if (batch.n_tokens == 0) {
break;
}
// evaluate the drafted tokens on the draft model
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
for (auto & dp : dparams) {
if (!dp.drafting) {
continue;
}
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
// EAGLE3 speculative decoding state
//
// Input of draft decoder: (This is different compared to MTP)
// At "pos P", the decoder takes input pair (t_{P+1}, g_P), with RoPE at P.
// - t_{P+1} = token at sequence pos P+1 (the *next* token after P)
// - g_P = encoder output = projection of target's extracted hidden states at P
//
// Deferred boundary (MTP doesn't have this issue):
// Within a single process() call with n_tokens, we can only write decoder KV for
// training pos 0..n_tokens-2. The last training pos (n_tokens-1) needs t_{n_tokens}
// which lies *outside* this batch — it is the token target will sample next or the first token from next ubatch.
// So the last training pos of each process() call is *deferred* to whichever next call has
// the missing token in hand:
// - multi-ubatch prefill: the next process()'s first token completes the pair
// (handled by the per-seq "cross-ubatch bridge")
// - single-ubatch prefill / after verify: draft()'s seed step uses "dp.id_last"
// (target's freshest sample) to complete the pair
//
// Per-seq carry-over state:
// pending_g_last [n_embd_dec] ┐ the deferred boundary's (g, pos). Set by
// pending_pos_last llama_pos ┘ process() at end of ubatch (= last row);
// rebased by accept() to first-non-accepted pos.
// verify_g [N × n_embd_dec] snapshot of process()'s encoder output;
// verify_pos_first llama_pos consumed by accept() to recover the right
// verify_g_rows int32_t pending_g_last row for any n_accepted value.
//
// Performance is overall good but there is waste in verify cycle:
// process() runs encoder + decoder on the *full* verify batch including rows for
// rejected drafts. The KV at those positions is then dropped.
//
// TODO: Not sure if we need optimization for this waste?
// If so we may need hybrid stash:
// in verify mode, have process() only stash features and let draft() seed run
// encoder+decoder on n_accepted+1 rows).
struct common_speculative_impl_draft_eagle3 : public common_speculative_impl {
common_params_speculative_draft params;
llama_batch batch;
std::vector<common_sampler_ptr> smpls;
// backend sampler chain per seq, attached to ctx_dft
std::vector<llama_sampler *> backend_chains;
int32_t n_embd_dec = 0; // draft hidden size
int32_t n_embd_enc = 0; // target_layer_ids_n * target_hidden_size
int32_t n_embd_tgt = 0; // target model hidden size
const int32_t * target_layer_ids = nullptr; // model_dft's extract layer indices
uint32_t target_layer_ids_n = 0;
// [per-seq] deferred boundary state
std::vector<std::vector<float>> pending_g_last;
std::vector<llama_pos> pending_pos_last;
// [per-seq] snapshot of the most recent process()'s encoder output
std::vector<std::vector<float>> verify_g; // [n_seq][n_rows * n_embd_dec]
std::vector<llama_pos> verify_pos_first; // [n_seq] — pos of verify_g[seq][0]
std::vector<int32_t> verify_g_rows; // [n_seq] — number of rows
// scratch buffer for concatenated target features [n_tokens, n_embd_enc]
std::vector<float> features_buf;
std::vector<float> g_embd_buf;
common_speculative_impl_draft_eagle3(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, n_seq)
, params(params.draft)
{
LOG_INF("%s: adding speculative implementation 'draft-eagle3'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%f, backend_sampling=%d\n", __func__, params.draft.n_max, params.draft.n_min, params.draft.p_min, (int) params.draft.backend_sampling);
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "EAGLE3 requires ctx_tgt and ctx_dft to be set");
const llama_model * model_dft = llama_get_model(ctx_dft);
const llama_model * model_tgt = llama_get_model(ctx_tgt);
target_layer_ids = llama_model_target_layer_ids (model_dft);
target_layer_ids_n = llama_model_target_layer_ids_n(model_dft);
if (target_layer_ids_n != 3) {
throw std::runtime_error("draft model is not eagle3 (expected 3 extract layers, got " +
std::to_string(target_layer_ids_n) + ")");
}
n_embd_tgt = llama_model_n_embd(model_tgt);
n_embd_dec = llama_model_n_embd(model_dft);
n_embd_enc = (int32_t) target_layer_ids_n * n_embd_tgt;
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd_dec, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; eagle3 decoder needs both.
// TODO: fix, how to call without malloc
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
// offload draft sampling to the backend
backend_chains.assign(n_seq, nullptr);
if (this->params.backend_sampling) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
backend_chains[seq_id] = chain;
}
}
// turn on extraction of the target layers' input embeddings
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
llama_set_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k], true);
}
// turn on extraction of the draft model's pre-norm hidden state
// (used both for the encoder output g_embd and the decoder pre-norm output).
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
pending_g_last.assign(n_seq, std::vector<float>(n_embd_dec, 0.0f));
pending_pos_last.assign(n_seq, -1);
verify_g.assign(n_seq, std::vector<float>());
verify_pos_first.assign(n_seq, -1);
verify_g_rows.assign(n_seq, 0);
}
~common_speculative_impl_draft_eagle3() override {
auto * ctx_dft = this->params.ctx_dft;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
if (backend_chains[seq_id] == nullptr) {
continue;
}
if (ctx_dft) {
llama_set_sampler(ctx_dft, seq_id, nullptr);
}
llama_sampler_free(backend_chains[seq_id]);
}
backend_chains.clear();
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
}
llama_batch_free(batch);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
// expected state after prefill: ctx_dft has pos 0..N-2 (last position is deferred to
// draft()'s seed step). Warn only if more than one position is missing.
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 2) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-2=%d — process() did not run on every prefill ubatch. "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 2);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// i_batch_beg[seq] / i_batch_end[seq]: inclusive batch indices of this seq's
// first/last token in batch_in. Assumes per-seq tokens are contiguous within
// the ubatch (server's default ordering).
std::vector<int32_t> i_batch_beg(n_seq, -1);
std::vector<int32_t> i_batch_end(n_seq, -1);
for (int k = 0; k < n_tokens; ++k) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
const llama_seq_id seq_id = batch_in.seq_id[k][0];
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
continue;
}
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
// Interleave each extract_layer's hidden state into a contiguous buffer of
// shape [n_tokens, target_layer_ids_n * n_embd_tgt]. Then run EAGLE3 encoder
// to get one g_embd row per token.
features_buf.resize((size_t) n_tokens * n_embd_enc, 0.0f);
for (uint32_t k = 0; k < target_layer_ids_n; ++k) {
const float * layer = llama_get_embeddings_layer_inp(ctx_tgt, (uint32_t) target_layer_ids[k]);
if (!layer) {
GGML_ABORT("EAGLE3: target layer %d input not extracted.", target_layer_ids[k]);
}
for (int32_t i = 0; i < n_tokens; ++i) {
float * dst = features_buf.data() + (size_t) i * n_embd_enc + k * (size_t) n_embd_tgt;
const float * src = layer + (size_t) i * n_embd_tgt;
std::memcpy(dst, src, (size_t) n_embd_tgt * sizeof(float));
}
}
g_embd_buf.resize((size_t) n_tokens * n_embd_dec);
// llama_encode() requires the full encoder batch to fit in n_ubatch.
// Allow batch > ubatch: eagle3's per-token encoder can be chunked safely.
const int32_t n_ubatch_dft = (int32_t) llama_n_ubatch(ctx_dft);
for (int32_t i = 0; i < n_tokens; i += n_ubatch_dft) {
const int32_t n_chunk = std::min(n_ubatch_dft, n_tokens - i);
llama_batch enc_batch = {
/*.n_tokens =*/ n_chunk,
/*.token =*/ nullptr,
/*.embd =*/ features_buf.data() + (size_t) i * n_embd_enc,
/*.pos =*/ nullptr,
/*.n_seq_id =*/ nullptr,
/*.seq_id =*/ nullptr,
/*.logits =*/ nullptr,
};
const int32_t rc = llama_encode(ctx_dft, enc_batch);
if (rc != 0) {
LOG_ERR("%s: llama_encode(ctx_dft) failed rc=%d (n_tokens=%d, offset=%d)\n",
__func__, rc, (int) n_chunk, (int) i);
return false;
}
// g_embd has shape [n_chunk, n_embd_dec] in ctx_dft's pre-norm embeddings buffer.
const float * g_embd_chunk = llama_get_embeddings_nextn(ctx_dft);
GGML_ASSERT(g_embd_chunk && "EAGLE3 encoder produced no output.");
std::memcpy(g_embd_buf.data() + (size_t) i * n_embd_dec,
g_embd_chunk,
(size_t) n_chunk * n_embd_dec * sizeof(float));
}
const float * g_embd = g_embd_buf.data();
const size_t row_bytes = (size_t) n_embd_dec * sizeof(float);
// EAGLE3 decoder input convention: at memory pos P the input pair is
// (token[P+1], g_embd[P]). This shifts the token index "left by one" relative to g_embd.
//
// Per seq, in order:
// (a) cross-ubatch bridge — when applicable, write the previously-deferred
// pos using this ubatch's first token + pending_g_last.
// (b) main write loop — for k in [beg, end-1], write (token[k+1], g_embd[k])
// at pos[k]. The last training pos (k=end) is left unwritten = new
// deferred boundary, completed by the next process() or draft() call.
// (c) refresh deferred state — stash this ubatch's full g_embd into verify_g,
// update pending_g_last / pending_pos_last to the last row.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
const int32_t beg = i_batch_beg[seq_id];
const int32_t end = i_batch_end[seq_id];
if (beg < 0 || end < 0) {
continue;
}
// cross-ubatch bridge — complete the prior ubatch's deferred boundary.
// Fires iff all three preconditions hold:
// 1) pending_pos_last >= 0
// 2) pending_pos_last + 1 == pos[beg]
// 3) pending_pos_last > dft_pos_max // TODO: is this check needed?
const llama_pos pending_pos = pending_pos_last[seq_id];
if (pending_pos >= 0 && pending_pos + 1 == batch_in.pos[beg]) {
const llama_pos dft_pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pending_pos > dft_pos_max) {
common_batch_add(batch, batch_in.token[beg], pending_pos, { seq_id }, /*logits=*/ false);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec,
pending_g_last[seq_id].data(), row_bytes);
}
}
for (int32_t k = beg; k < end; ++k) {
common_batch_add(batch, batch_in.token[k + 1], batch_in.pos[k], { seq_id }, /*logits=*/ false);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec,
g_embd + (size_t) k * n_embd_dec, row_bytes);
}
// refresh deferred state
const int32_t n_rows = end - beg + 1;
verify_pos_first[seq_id] = batch_in.pos[beg];
pending_pos_last[seq_id] = batch_in.pos[end];
verify_g_rows[seq_id] = n_rows;
verify_g[seq_id].resize((size_t) n_rows * n_embd_dec, 0.0f);
std::memcpy(verify_g[seq_id].data(), g_embd + (size_t) beg * n_embd_dec, row_bytes * n_rows);
std::memcpy(pending_g_last[seq_id].data(), g_embd + (size_t) end * n_embd_dec, row_bytes);
}
if (batch.n_tokens > 0) {
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (n_tokens=%d, ubatch_pos[0]=%d)\n",
__func__, rc, (int) batch.n_tokens, (int) batch_in.pos[0]);
return false;
}
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const size_t row_bytes = (size_t) n_embd_dec * sizeof(float);
// Complete the deferred boundary pair (dp.id_last, pending_g_last) at memory
// pos pending_pos_last. dp.id_last is target's freshest sample (= corrected
// token after verify, or first generated token after prefill), matching the
// EAGLE3 input convention (token[P+1], g_embd[P]) at pos P.
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
if (pending_pos_last[seq_id] < 0) {
continue;
}
n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());
llama_memory_seq_rm(llama_get_memory(ctx_dft), seq_id, pending_pos_last[seq_id], -1);
common_batch_add(batch, dp.id_last, pending_pos_last[seq_id], { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec,
pending_g_last[seq_id].data(),
row_bytes);
}
if (batch.n_tokens == 0) {
return;
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode returned %d\n", __func__, ret);
return;
}
int i = 0;
while (n_drafting > 0) {
int i_batch = 0;
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_batch, true);
// pre-norm hidden state of this position becomes g_embd for the next step
const float * prenorm = llama_get_embeddings_nextn_ith(ctx_dft, i_batch);
++i_batch;
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
// (configurable via --spec-draft-p-min, set to 0.0 to disable early-stop)
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
auto & result = *dp.result;
result.push_back(id);
if (params.n_max <= (int) result.size()) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_batch_add(batch, id, pending_pos_last[seq_id] + (i + 1), { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd_dec, prenorm, row_bytes);
}
if (batch.n_tokens == 0) {
break;
}
ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
++i;
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t n_rows = verify_g_rows[seq_id];
if (n_rows <= 0) {
return;
}
const int32_t i_g = std::min<int32_t>(n_accepted, n_rows - 1);
pending_pos_last[seq_id] = verify_pos_first[seq_id] + i_g;
std::memcpy(pending_g_last[seq_id].data(),
verify_g[seq_id].data() + (size_t) i_g * n_embd_dec,
(size_t) n_embd_dec * sizeof(float));
}
// we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets:
// their single-position checkpoints drop it on restore
bool need_boundary_stash() const {
const llama_model * model_tgt = llama_get_model(params.ctx_tgt);
return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt);
}
bool get_state(llama_seq_id seq_id, std::vector<uint8_t> & data) const override {
if (!need_boundary_stash()) {
return false;
}
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) {
return false;
}
const llama_pos pos = pending_pos_last[seq_id];
const std::vector<float> & g = pending_g_last[seq_id];
data.resize(sizeof(llama_pos) + g.size() * sizeof(float));
std::memcpy(data.data(), &pos, sizeof(llama_pos));
std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float));
return true;
}
void set_state(llama_seq_id seq_id, const std::vector<uint8_t> & data) override {
if (!need_boundary_stash()) {
return;
}
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) {
return;
}
llama_pos pos = -1;
std::memcpy(&pos, data.data(), sizeof(llama_pos));
pending_pos_last[seq_id] = pos;
pending_g_last[seq_id].resize(n_embd_dec);
std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float));
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_draft_mtp : public common_speculative_impl {
common_params_speculative_draft params; // reuses the draft-model params slot (ctx_tgt/ctx_dft)
llama_batch batch;
std::vector<common_sampler_ptr> smpls;
// backend sampler chain per seq, attached to ctx_dft
std::vector<llama_sampler *> backend_chains;
int32_t n_embd = 0;
// One MTP draft driver, three modes (set once in the ctor):
// is_mem_shared (gemma4): shares the target KV, runs all heads in one graph.
// chain_heads (step35): n_mtp_layers trained heads, one per draft step.
// neither (qwen35 / qwen35moe): a single trained MTP head.
int32_t n_mtp_layers = 1;
bool is_mem_shared = false; // gemma4
bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared
// Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1.
// The last h-row of one process() call needs the first token of the NEXT
// call to pair with, so it's stashed here until that next call fires.
std::vector<std::vector<float>> pending_h; // [n_seq][n_embd]
std::vector<int32_t> i_batch_beg;
std::vector<int32_t> i_batch_end;
// Hidden rows from the most recent target verification batch, grouped by seq.
// Row 0 corresponds to the sampled token, row N to the Nth accepted draft token.
std::vector<std::vector<float>> verify_h;
std::vector<int32_t> verify_h_rows;
std::vector<int> i_last;
std::vector<std::vector<float>> chain_h;
common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq)
, params(params.draft)
{
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
GGML_ASSERT(ctx_tgt && ctx_dft && "MTP requires ctx_tgt and ctx_dft to be set");
n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft));
GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) &&
"MTP input row width must match the target h_nextn width");
n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft)));
LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__);
LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling);
LOG_INF("%s: - gpu_layers=%d, cache_k=%s, cache_v=%s, ctx_tgt=%s, ctx_dft=%s, devices=[%s]\n", __func__,
this->params.n_gpu_layers,
ggml_type_name(this->params.cache_type_k),
ggml_type_name(this->params.cache_type_v),
ctx_tgt ? "yes" : "no",
ctx_dft ? "yes" : "no",
common_speculative_get_devices_str(this->params.devices).c_str());
const int32_t n_b = (int32_t) llama_n_batch(ctx_dft);
batch = llama_batch_init(/*n_tokens=*/ n_b, /*embd=*/ n_embd, /*n_seq_max=*/ 1);
// llama_batch_init allocates only one of token/embd; MTP needs both.
// TODO: fix, how to call without malloc
batch.token = (llama_token *) malloc(sizeof(llama_token) * n_b);
smpls.resize(n_seq);
for (auto & s : smpls) {
common_params_sampling sparams;
sparams.no_perf = false;
sparams.top_k = 10;
sparams.samplers = { COMMON_SAMPLER_TYPE_TOP_K };
s.reset(common_sampler_init(llama_get_model(ctx_dft), sparams));
}
// offload draft sampling to the backend
backend_chains.assign(n_seq, nullptr);
if (this->params.backend_sampling) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
llama_sampler_chain_add(chain, llama_sampler_init_top_k(10));
if (!llama_set_sampler(ctx_dft, seq_id, chain)) {
LOG_WRN("%s: backend offload failed for seq_id=%d; using CPU sampler\n", __func__, (int) seq_id);
llama_sampler_free(chain);
chain = nullptr;
}
backend_chains[seq_id] = chain;
}
}
llama_set_embeddings_nextn(ctx_tgt, true, /*masked*/ false);
llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true);
is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt;
chain_heads = n_mtp_layers > 1 && !is_mem_shared;
if (chain_heads) {
this->params.n_max = std::min(this->params.n_max, n_mtp_layers);
chain_h.assign(n_seq, {});
for (auto & c : chain_h) {
c.reserve((size_t) (this->params.n_max + 1) * n_embd);
}
}
pending_h.assign(n_seq, std::vector<float>(n_embd, 0.0f));
i_last.assign(n_seq, -1);
i_batch_beg.assign(n_seq, -1);
i_batch_end.assign(n_seq, -1);
verify_h.assign(n_seq, {});
verify_h_rows.assign(n_seq, 0);
}
~common_speculative_impl_draft_mtp() override {
auto * ctx_dft = this->params.ctx_dft;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) backend_chains.size(); ++seq_id) {
if (backend_chains[seq_id] == nullptr) {
continue;
}
if (ctx_dft) {
llama_set_sampler(ctx_dft, seq_id, nullptr);
}
llama_sampler_free(backend_chains[seq_id]);
}
backend_chains.clear();
if (batch.token != nullptr) {
free(batch.token);
batch.token = nullptr;
}
llama_batch_free(batch);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
const int32_t N = (int32_t) prompt.size();
if (N <= 0) {
return;
}
auto * ctx_dft = this->params.ctx_dft;
const llama_pos pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_dft), seq_id);
if (pos_max < N - 1 && !is_mem_shared) {
LOG_WRN("%s: ctx_dft pos_max=%d < N-1=%d - "
"process() hook may not have run on every prefill ubatch "
"(need_embd / logits=1 on every prompt position?). "
"Drafts may degrade.\n",
__func__, (int) pos_max, N - 1);
}
}
bool process(const llama_batch & batch_in) override {
if (batch_in.n_tokens <= 0) {
return true;
}
// TODO: how to make it work with vision tokens?
if (batch_in.token == nullptr || batch_in.embd != nullptr) {
return true;
}
const int32_t n_tokens = batch_in.n_tokens;
// remember the frist and last batch index for each sequence
std::fill(i_batch_beg.begin(), i_batch_beg.end(), -1);
std::fill(i_batch_end.begin(), i_batch_end.end(), -1);
for (int k = 0; k < n_tokens; ++k) {
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
GGML_ASSERT(batch_in.n_seq_id[k] == 1);
if (batch_in.seq_id[k][0] == seq_id) {
i_batch_end[seq_id] = k;
if (i_batch_beg[seq_id] < 0) {
i_batch_beg[seq_id] = k;
}
}
}
}
auto * ctx_tgt = this->params.ctx_tgt;
auto * ctx_dft = this->params.ctx_dft;
const size_t row_bytes = (size_t) n_embd * sizeof(float);
// if kv is shared with target (e.g Gemma4), then we can skip this catch-up decode
if (!is_mem_shared) {
common_batch_clear(batch);
for (int k = 0; k < n_tokens; ++k) {
common_batch_add(batch, batch_in.token[k], batch_in.pos[k], { batch_in.seq_id[k][0] }, 0);
}
// shift the tgt embeddings to the right by one position
// assumes that the tokens in the batch are sequential for each sequence
// i.e. we cannot have seq_id like this: [0, 0, 0, 1, 1, 0, 1, 1]
// ^--- this is a problem
// TODO:this is generally true, but would be nice to assert it
{
const float * h_tgt = llama_get_embeddings_nextn(ctx_tgt);
std::memcpy(batch.embd + (size_t) 1 * n_embd, h_tgt, row_bytes * (n_tokens-1));
}
// fill the pending embeddings from a previous run
auto set_h = [&](int idx, const float * h_row) {
std::memcpy(batch.embd + (size_t) idx * n_embd, h_row, row_bytes);
};
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
set_h(i_batch_beg[seq_id], pending_h[seq_id].data());
}
auto * mem_dft = llama_get_memory(ctx_dft);
bool ok = true;
for (int head = 0; head < n_mtp_layers; ++head) {
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_beg[seq_id] < 0) {
continue;
}
llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1);
}
llama_set_nextn_layer_offset(ctx_dft, head);
}
const int32_t rc = llama_decode(ctx_dft, batch);
if (rc != 0) {
LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n",
__func__, head, (int) rc, (int) batch_in.pos[0]);
ok = false;
break;
}
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
if (!ok) {
return false;
}
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (i_batch_end[seq_id] < 0) {
continue;
}
const int32_t n_rows = i_batch_end[seq_id] - i_batch_beg[seq_id] + 1;
verify_h_rows[seq_id] = n_rows;
verify_h[seq_id].resize((size_t) n_rows * n_embd);
for (int32_t i = 0; i < n_rows; ++i) {
const float * h = llama_get_embeddings_nextn_ith(ctx_tgt, i_batch_beg[seq_id] + i);
std::memcpy(verify_h[seq_id].data() + (size_t) i * n_embd, h, row_bytes);
}
std::memcpy(pending_h[seq_id].data(),
verify_h[seq_id].data() + (size_t) (n_rows - 1) * n_embd, row_bytes);
}
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
auto & ctx_dft = params.ctx_dft;
common_batch_clear(batch);
// keep track of which sequences are still drafting
int n_drafting = 0;
std::vector<bool> drafting(n_seq);
const size_t row_bytes = (size_t) n_embd * sizeof(float);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
n_drafting++;
drafting[seq_id] = true;
common_sampler_reset(smpls[seq_id].get());
common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes);
i_last[seq_id] = batch.n_tokens - 1;
if (chain_heads) {
chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end());
}
}
int i = 0;
while (n_drafting > 0) {
// each step decodes under a different head, i.e. a different decoder layer, and
// KV is per layer. process() filled this layer's KV only for positions < n_past
// (prompt + accepted prefix) — nothing in the draft region yet. so reset the
// draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact)
// and select head i so it rebuilds its own layer's KV there; decoding just the
// latest token would leave its attention reading cells only another head wrote.
if (chain_heads) {
auto * mem_dft = llama_get_memory(ctx_dft);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (drafting[seq_id]) {
llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1);
}
}
llama_set_nextn_layer_offset(ctx_dft, i);
}
int ret = llama_decode(ctx_dft, batch);
if (ret != 0) {
LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret);
break;
}
// rebuild the batch for the next step: the growing-KV paths re-add only the
// new token (the KV already holds the prefix), while chained heads re-add the
// whole prefix at the next head. dropped sequences are simply not re-added.
common_batch_clear(batch);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
if (!drafting[seq_id]) {
continue;
}
auto * smpl = smpls[seq_id].get();
common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true);
const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]);
const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - seq_id %d, draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
seq_id, k, i, cur_p->data[k].id, cur_p->data[k].p,
common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
}
// add drafted token for each sequence
const llama_token id = cur_p->data[0].id;
// only collect very high-confidence draft tokens
if (cur_p->data[0].p < params.p_min) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
common_sampler_accept(smpl, id, true);
auto & dp = dparams.at(seq_id);
auto & result = *dp.result;
result.push_back(id);
if (params.n_max <= (int) result.size()) {
drafting[seq_id] = false;
n_drafting--;
continue;
}
if (chain_heads) {
// ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546
chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd);
const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far
for (int t = 0; t < n_rows; ++t) {
const llama_token tok = (t == 0) ? dp.id_last : result[t - 1];
common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd,
chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes);
}
} else if (is_mem_shared) {
// note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens
// ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37
common_batch_add(batch, id, dp.n_past, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
} else {
common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true);
std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes);
}
i_last[seq_id] = batch.n_tokens - 1;
}
if (batch.n_tokens == 0) {
break;
}
++i;
}
if (chain_heads) {
llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes
}
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
if (dp.result->size() < (size_t) params.n_min) {
dp.result->clear();
}
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool /*is_other*/) override {
if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) {
return;
}
const int32_t n_rows = verify_h_rows[seq_id];
if (n_rows <= 0) {
return;
}
const int32_t i_h = std::min<int32_t>(n_accepted, n_rows - 1);
const size_t row_bytes = (size_t) n_embd * sizeof(float);
std::memcpy(pending_h[seq_id].data(), verify_h[seq_id].data() + (size_t) i_h * n_embd, row_bytes);
}
bool need_embd() const override {
return false;
}
bool need_embd_nextn() const override {
return true;
}
};
// state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_impl_ngram_simple : public common_speculative_impl {
common_params_speculative_ngram_map params;
// shared across all sequences
common_ngram_simple_config config;
common_speculative_impl_ngram_simple(
const common_params_speculative & params, uint32_t n_seq,
common_ngram_simple_config config)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, n_seq)
, params(params.ngram_simple)
, config(config)
{
LOG_INF("%s: adding speculative implementation 'ngram-simple'\n", __func__);
LOG_INF("%s: - size_n=%d, size_m=%d, min_hits=%d\n", __func__,
this->params.size_n, this->params.size_m, this->params.min_hits);
}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
}
bool process(const llama_batch & /*batch*/) override {
// TODO: implement
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
assert(dparams.size() == n_seq);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
*dp.result = common_ngram_simple_draft(config, *dp.prompt, dp.id_last);
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_map_k : public common_speculative_impl {
// n_seq configs
std::vector<common_ngram_map> config;
common_speculative_impl_ngram_map_k(
const common_ngram_map & config,
uint32_t n_seq)
: common_speculative_impl(config.key_only ? COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K
: COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, n_seq)
{
for (uint32_t i = 0; i < n_seq; i++) {
this->config.push_back(config);
}
LOG_INF("%s: adding speculative implementation '%s'\n", __func__, common_speculative_type_to_str(this->type).c_str());
LOG_INF("%s: - size_key=%d, size_value=%d, key_only=%d, min_hits=%d\n", __func__,
config.size_key, config.size_value, config.key_only, config.min_hits);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
GGML_ASSERT(seq_id < (llama_seq_id) n_seq);
common_ngram_map_begin(config[seq_id], prompt);
}
bool process(const llama_batch & /*batch*/) override {
// TODO: implement
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
assert(dparams.size() == n_seq);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
common_ngram_map_draft(config[seq_id], *dp.prompt, dp.id_last, *dp.result);
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
GGML_ASSERT((seq_id < (llama_seq_id) config.size()));
if (is_other) {
return;
}
common_ngram_map_accept(config[seq_id], n_accepted);
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_mod : public common_speculative_impl {
common_params_speculative_ngram_mod params;
// shared across all sequences
common_ngram_mod mod;
// enable trace logging if LLAMA_TRACE is set
const bool verbose;
struct seq_info {
// the last position in the prompt that was added to the ngram container
size_t i_last = 0;
// length of the last drafted n-gram (number of tokens returned by draft)
size_t n_draft_last = 0;
// consecutive accept rounds with low acceptance fraction (< 0.5)
int n_low = 0;
};
std::vector<seq_info> sinfos;
common_speculative_impl_ngram_mod(
const common_params_speculative & params,
uint32_t n_seq)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, n_seq)
, params(params.ngram_mod)
, mod(params.ngram_mod.n_match, 4*1024*1024)
, verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
LOG_INF("%s: adding speculative implementation 'ngram-mod'\n", __func__);
LOG_INF("%s: - n_match=%d, n_max=%d, n_min=%d\n", __func__,
this->params.n_match, this->params.n_max, this->params.n_min);
LOG_INF("%s: - mod size=%zu (%.3f MB)\n", __func__,
mod.size(), (float)(mod.size_bytes())/1024/1024);
if (this->params.n_match < 16) {
LOG_WRN("%s: ngram_mod n_match=%d is too small - poor quality is possible, "
"see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, this->params.n_match);
}
sinfos.resize(n_seq);
}
void begin(llama_seq_id seq_id, const llama_tokens & prompt) override {
auto & sinfo = sinfos[seq_id];
sinfo.i_last = 0;
sinfo.n_draft_last = 0;
const size_t n = mod.get_n();
if (prompt.size() < n) {
return;
}
for (size_t i = 0; i < prompt.size() - n; ++i) {
mod.add(prompt.data() + i);
}
sinfo.i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
mod.reset();
}
}
void draft_one(
llama_seq_id seq_id,
common_speculative_draft_params & dparams) {
auto & sinfo = sinfos[seq_id];
auto & result = *dparams.result;
const auto & prompt = *dparams.prompt;
sinfo.n_draft_last = 0;
const size_t cur_len = prompt.size();
if (cur_len < mod.get_n()) {
return;
}
const size_t n = mod.get_n();
// add new ngrams in chunks
if (sinfo.i_last + 32 < cur_len) {
for (size_t i = sinfo.i_last; i < cur_len - n; ++i) {
mod.add(prompt.data() + i);
}
sinfo.i_last = cur_len - n;
}
result.resize(n + params.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = prompt.at(cur_len - n + 1 + i);
}
result[n - 1] = dparams.id_last;
for (int i = 0; i < params.n_max; ++i) {
const llama_token token = mod.get(result.data() + i);
if (token == common_ngram_mod::EMPTY) {
if (i < params.n_min) {
result.clear();
return;
}
result.resize(n + i);
break;
}
result[n + i] = token;
}
// only return the m tokens that were drafted
for (size_t i = 0; n + i < result.size(); ++i) {
result[i] = result[n + i];
}
result.resize(result.size() - n);
// store length of drafted n-gram for later acceptance analysis
sinfo.n_draft_last = result.size();
}
bool process(const llama_batch & /*batch*/) override {
// TODO: implement
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
assert(dparams.size() == n_seq);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
draft_one(seq_id, dp);
}
}
void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) override {
if (is_other) {
return;
}
auto & sinfo = sinfos[seq_id];
// compute acceptance fraction if we have a recorded draft length
if (sinfo.n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)sinfo.n_draft_last;
if (f_acc < 0.25) {
sinfo.n_low++;
if (sinfo.n_low >= 5) {
if (verbose) {
LOG_WRN("%s: low acceptance streak (%d) - resetting ngram_mod\n", __func__, sinfo.n_low);
}
mod.reset();
sinfo.n_low = 0;
sinfo.i_last = 0;
}
} else {
sinfo.n_low = 0;
}
}
}
bool need_embd() const override {
return false;
}
};
struct common_speculative_impl_ngram_cache : public common_speculative_impl {
common_params_speculative_ngram_cache params;
uint16_t n_draft;
bool save_dynamic;
bool save_static;
struct seq_info {
size_t cache_size = 0; // number of tokens in n-gram cache
common_ngram_cache ngram_cache_context;
common_ngram_cache ngram_cache_dynamic;
common_ngram_cache ngram_cache_static;
};
std::vector<seq_info> sinfos;
common_speculative_impl_ngram_cache(
const common_params_speculative & params,
uint32_t n_seq,
uint16_t n_draft,
const std::string & path_static,
const std::string & path_dynamic,
bool save_dynamic,
bool save_static)
: common_speculative_impl(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, n_seq)
, params(params.ngram_cache)
, n_draft(n_draft)
, save_dynamic(save_dynamic)
, save_static(save_static)
{
LOG_INF("%s: adding speculative implementation 'ngram-cache'\n", __func__);
LOG_INF("%s: - n_draft=%d, cache_static=%s, cache_dynamic=%s\n", __func__,
n_draft,
path_static.empty() ? "none" : path_static.c_str(),
path_dynamic.empty() ? "none" : path_dynamic.c_str());
sinfos.resize(n_seq);
if (!path_static.empty()) {
try {
auto ngram_cache_static = common_ngram_cache_load(path_static);
for (auto & sinfo : sinfos) {
sinfo.ngram_cache_static = ngram_cache_static;
}
} catch (...) {
LOG_ERR("failed to open static lookup cache: %s", path_static.c_str());
GGML_ABORT("Couldn't read static lookup cache");
}
}
if (!path_dynamic.empty()) {
try {
auto ngram_cache_dynamic = common_ngram_cache_load(path_dynamic);
for (auto & sinfo : sinfos) {
sinfo.ngram_cache_dynamic = ngram_cache_dynamic;
}
} catch (...) {
LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str());
GGML_ABORT("Couldn't read dynamic lookup cache");
}
}
}
void begin(llama_seq_id /*seq_id*/, const llama_tokens & /*prompt*/) override {
// noop
}
void draft_one(
llama_seq_id seq_id,
common_speculative_draft_params & dparams) {
auto & sinfo = sinfos[seq_id];
auto & result = *dparams.result;
const auto & prompt = *dparams.prompt;
if (sinfo.cache_size < prompt.size() + 1) {
llama_tokens tokens_new;
tokens_new.reserve(prompt.size() + 1 - sinfo.cache_size);
for (size_t j = sinfo.cache_size; j < prompt.size(); ++j) {
tokens_new.push_back(prompt[j]);
}
tokens_new.push_back(dparams.id_last); // add the last token
// Update context ngram cache with new dparams.prompt:
common_ngram_cache_update(
sinfo.ngram_cache_context,
LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
tokens_new, tokens_new.size(), false);
sinfo.cache_size = prompt.size() + 1;
}
llama_tokens inp;
inp.reserve(prompt.size() + 1);
for (size_t j = 0; j < prompt.size(); ++j) {
inp.push_back(prompt[j]);
}
inp.push_back(dparams.id_last);
result.push_back(dparams.id_last);
common_ngram_cache_draft(
inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX,
sinfo.ngram_cache_context,
sinfo.ngram_cache_dynamic,
sinfo.ngram_cache_static);
if (result.size() > 0) {
// delete first token in result (which is the id_last token)
result.erase(result.begin());
}
}
bool process(const llama_batch & /*batch*/) override {
// TODO: implement
return true;
}
void draft(common_speculative_draft_params_vec & dparams) override {
assert(dparams.size() == n_seq);
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) {
auto & dp = dparams[seq_id];
if (!dp.drafting) {
continue;
}
draft_one(seq_id, dp);
}
}
void accept(llama_seq_id /*seq_id*/, uint16_t /*n_accepted*/, bool /*is_other*/) override {
// noop
}
bool need_embd() const override {
return false;
}
};
struct common_speculative {
common_speculative_draft_params_vec dparams;
// list of implementations to use and their states
std::vector<std::unique_ptr<common_speculative_impl>> impls;
// which implementaion was used for a given seq_id
std::vector<common_speculative_impl *> impl_last;
};
static common_ngram_map get_common_ngram_map(
common_speculative_type type,
const common_params_speculative_ngram_map & config) {
uint16_t size_key = config.size_n;
uint16_t size_value = config.size_m;
bool key_only = type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
uint16_t min_hits = config.min_hits;
return common_ngram_map(size_key, size_value, key_only, min_hits);
}
static common_speculative_impl_ngram_cache create_state_ngram_cache(
const common_speculative_config & config,
uint32_t n_seq,
const std::string & path_static,
const std::string & path_dynamic) {
uint16_t n_draft = 8; // TODO get from config?
// TODO bool param in common/common.h to set save_static/save_dynamic?
bool save_static = false;
bool save_dynamic = false;
common_speculative_impl_ngram_cache state(config.params, n_seq, n_draft, path_static, path_dynamic, save_static, save_dynamic);
return state;
}
std::string common_speculative_type_name_str(const std::vector<common_speculative_type> & types) {
std::string result;
for (size_t i = 0; i < types.size(); i++) {
if (i > 0) {
result += ",";
}
result += common_speculative_type_to_str(types[i]);
}
return result;
}
const char * common_speculative_all_types_str() {
static std::string all_types_str = []() {
std::vector<common_speculative_type> types;
types.reserve(COMMON_SPECULATIVE_TYPE_COUNT);
for (int i = 0; i < COMMON_SPECULATIVE_TYPE_COUNT; i++) {
types.push_back((common_speculative_type) i);
}
return common_speculative_type_name_str(types);
}();
return all_types_str.c_str();
}
std::string common_speculative_type_to_str(common_speculative_type type) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_NONE: return "none";
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: return "draft-simple";
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: return "draft-eagle3";
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: return "draft-mtp";
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram-simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram-map-k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram-map-k4v";
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram-mod";
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram-cache";
default: return "unknown";
}
}
std::vector<common_speculative_type> common_speculative_types_from_names(const std::vector<std::string> & names) {
std::vector<common_speculative_type> types;
types.reserve(names.size());
for (const auto & name : names) {
auto type = common_speculative_type_from_name_map.find(name);
if (type != common_speculative_type_from_name_map.end()) {
if (type->second == COMMON_SPECULATIVE_TYPE_NONE) {
return std::vector<common_speculative_type> { COMMON_SPECULATIVE_TYPE_NONE };
}
types.push_back(type->second);
continue;
}
throw std::invalid_argument("unknown speculative type: " + name);
}
return types;
}
common_speculative_type common_speculative_type_from_name(const std::string & name) {
const auto it = common_speculative_type_from_name_map.find(name);
if (it == common_speculative_type_from_name_map.end()) {
return COMMON_SPECULATIVE_TYPE_COUNT;
}
return it->second;
}
static uint32_t common_get_enabled_speculative_configs(const std::vector<common_speculative_type> & configs) {
uint32_t result = 0;
for (size_t i = 0; i < configs.size(); i++) {
result |= (1u << configs[i]);
}
return result;
}
int32_t common_speculative_n_max(const common_params_speculative * spec) {
int32_t n_max = 0;
for (const auto type : spec->types) {
switch (type) {
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE:
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3:
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP:
n_max = std::max(n_max, std::max(0, spec->draft.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE:
n_max = std::max(n_max, (int32_t) spec->ngram_simple.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K:
n_max = std::max(n_max, (int32_t) spec->ngram_map_k.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V:
n_max = std::max(n_max, (int32_t) spec->ngram_map_k4v.size_m);
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD:
n_max = std::max(n_max, std::max(0, spec->ngram_mod.n_max));
break;
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE:
n_max = std::max(n_max, (int32_t) 8);
break;
case COMMON_SPECULATIVE_TYPE_NONE:
case COMMON_SPECULATIVE_TYPE_COUNT:
break;
}
}
return n_max;
}
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(common_params_speculative & params, uint32_t n_seq) {
// Compute the implementations to use based on the config and their order of preference
std::vector<common_speculative_config> configs = {}; // list of speculative configs to try
{
uint32_t enabled_configs = common_get_enabled_speculative_configs(params.types);
bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE));
bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr;
bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr;
bool has_ngram_cache = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_CACHE));
bool has_ngram_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE));
bool has_ngram_map_k = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K));
bool has_ngram_map_k4v = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V));
bool has_ngram_mod = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_NGRAM_MOD));
// when adding a new type - update here the logic above
static_assert(COMMON_SPECULATIVE_TYPE_COUNT == 9);
// this list here defines the priority of the speculators
// the one with highest priority are listed first
if (has_ngram_simple) {
// This implementation can guess a lot of tokens without any draft model.
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params));
}
if (has_ngram_map_k) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params));
}
if (has_ngram_map_k4v) {
// This implementation can guess tokens with high acceptance rate but is more expensive.
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
}
if (has_ngram_mod) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params));
}
if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
}
if (has_draft_simple) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE, params));
}
if (has_draft_eagle3) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params));
}
if (has_draft_mtp) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params));
}
}
std::vector<std::unique_ptr<common_speculative_impl>> impls = {};
for (const common_speculative_config & config : configs) {
switch (config.type) {
case COMMON_SPECULATIVE_TYPE_NONE:
break;
case COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE: {
impls.push_back(std::make_unique<common_speculative_impl_draft_simple>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3: {
impls.push_back(std::make_unique<common_speculative_impl_draft_eagle3>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_DRAFT_MTP: {
impls.push_back(std::make_unique<common_speculative_impl_draft_mtp>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: {
common_ngram_map ngram_map = get_common_ngram_map(config.type, config.params.ngram_simple);
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_impl_ngram_simple>(
/* .params = */ config.params,
/* .n_seq = */ n_seq,
/* .state = */ config_simple
);
impls.push_back(std::move(state));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: {
impls.push_back(
std::make_unique<common_speculative_impl_ngram_map_k>(
get_common_ngram_map(config.type, config.params.ngram_map_k), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: {
impls.push_back(
std::make_unique<common_speculative_impl_ngram_map_k>(
get_common_ngram_map(config.type, config.params.ngram_map_k4v), n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
impls.push_back(
std::make_unique<common_speculative_impl_ngram_mod>(config.params, n_seq));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache(
config, n_seq,
params.ngram_cache.lookup_cache_static,
params.ngram_cache.lookup_cache_dynamic);
impls.push_back(std::make_unique<common_speculative_impl_ngram_cache>(state));
break;
}
default:
break;
}
}
if (impls.empty()) {
LOG_WRN("%s: no implementations specified for speculative decoding\n", __func__);
return nullptr;
}
auto * result = new common_speculative {
/* .dparams = */ common_speculative_draft_params_vec(n_seq),
/* .impls = */ std::move(impls),
/* .impl_last = */ std::vector<common_speculative_impl *>(n_seq, nullptr)
};
return result;
}
void common_speculative_free(common_speculative * spec) {
if (spec == nullptr) {
return;
}
delete spec;
}
common_speculative_draft_params & common_speculative_get_draft_params(
common_speculative * spec,
llama_seq_id seq_id) {
GGML_ASSERT(spec);
GGML_ASSERT(seq_id < (llama_seq_id) spec->dparams.size());
return spec->dparams[seq_id];
}
void common_speculative_begin(common_speculative * spec, llama_seq_id seq_id, const llama_tokens & prompt) {
if (spec == nullptr) {
return;
}
for (auto & impl : spec->impls) {
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
impl->begin(seq_id, prompt);
impl->n_call_begin++;
}
}
bool common_speculative_process(common_speculative * spec, const llama_batch & batch) {
bool result = true;
if (spec == nullptr) {
return result;
}
for (auto & impl : spec->impls) {
result = result && impl->process(batch);
}
return result;
}
bool common_speculative_need_embd(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd()) {
return true;
}
}
return false;
}
bool common_speculative_need_embd_nextn(common_speculative * spec) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->need_embd_nextn()) {
return true;
}
}
return false;
}
void common_speculative_draft(common_speculative * spec) {
if (spec == nullptr) {
return;
}
auto & dparams = spec->dparams;
{
int n_drafting = 0;
for (auto & dp : dparams) {
GGML_ASSERT(!dp.drafting || dp.result->empty());
if (dp.drafting) {
n_drafting++;
}
}
if (n_drafting == 0) {
return;
}
}
for (auto & impl : spec->impls) {
{
common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
impl->draft(dparams);
impl->n_call_draft++;
}
int n_drafting = 0;
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) {
auto & dp = dparams[seq_id];
auto & result = *dp.result;
// a new draft has been sampled
if (dp.drafting && !result.empty()) {
dp.drafting = false;
if (dp.n_max > 0) {
if (!result.empty() && (int) result.size() > dp.n_max) {
LOG_DBG("%s: truncating draft to %d tokens\n", __func__, dp.n_max);
result.resize(dp.n_max);
}
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
common_speculative_type_to_str(impl.get()->type).c_str(), dp.prompt->size(),
impl.get()->n_call_draft, result.size());
// remember which implementation was used
spec->impl_last[seq_id] = impl.get();
impl->n_gen_drafts++;
impl->n_gen_tokens += result.size();
}
}
if (dp.drafting) {
n_drafting++;
}
}
if (n_drafting == 0) {
break;
}
}
// these sequences failed to generate a draft
for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) dparams.size(); ++seq_id) {
auto & dp = dparams[seq_id];
if (dp.drafting) {
dp.drafting = false;
}
}
}
void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, uint16_t n_accepted) {
common_speculative_impl * impl = spec->impl_last[seq_id];
GGML_ASSERT(impl);
{
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
if (impl->n_acc_tokens_per_pos.size() < n_accepted) {
impl->n_acc_tokens_per_pos.resize(n_accepted, 0);
}
for (size_t i = 0; i < n_accepted; ++i) {
impl->n_acc_tokens_per_pos[i]++;
}
if (n_accepted > 0) {
impl->n_acc_drafts++;
impl->n_acc_tokens += n_accepted;
}
impl->accept(seq_id, n_accepted, false);
impl->n_call_accept++;
}
// accept with the rest of the implementations, using is_other == true
for (auto & impl_other : spec->impls) {
if (impl_other.get() != impl) {
impl_other->accept(seq_id, n_accepted, true);
}
}
}
// TODO: support the case of more than one speculative implementations having a state
bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector<uint8_t> & data) {
if (spec == nullptr) {
return false;
}
for (auto & impl : spec->impls) {
if (impl->get_state(seq_id, data)) {
return true;
}
}
return false;
}
void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector<uint8_t> & data) {
if (spec == nullptr) {
return;
}
for (auto & impl : spec->impls) {
impl->set_state(seq_id, data);
}
}
void common_speculative_print_stats(const common_speculative * spec) {
if (spec == nullptr) {
return;
}
for (const auto & impl : spec->impls) {
std::string str_perf;
if (impl->gen_perf) {
std::ostringstream oss;
oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", ";
oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", ";
oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0;
str_perf = ", dur(b,g,a) = " + oss.str() + " ms";
} else {
str_perf = "";
}
std::string str_stats;
if (impl->n_call_accept > 0) {
const double mean =
1.0 + (double) impl->n_acc_tokens / (double) impl->n_call_accept;
std::ostringstream tmp;
tmp << std::fixed << std::setprecision(3);
for (size_t i = 0; i < impl->n_acc_tokens_per_pos.size(); ++i) {
if (i > 0) {
tmp << ", ";
}
tmp << (double) impl->n_acc_tokens_per_pos[i] / (double) impl->n_call_accept;
}
std::ostringstream oss;
oss << std::fixed << std::setprecision(2) << mean;
str_stats = ", #mean acc len = " + oss.str() + ", #acc rate/pos = (" + tmp.str() + ")";
}
LOG_INF("statistics %16s: #calls(b,g,a) = %4zu %6zu %6zu, #gen drafts = %6zu, #acc drafts = %5zu, #gen tokens = %6zu, #acc tokens = %5zu%s%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,
impl->n_acc_drafts,
impl->n_gen_tokens,
impl->n_acc_tokens,
str_stats.c_str(),
str_perf.c_str());
}
}