MTP: clean-up (#9)

* MTP: clean-up

* review: use llama_context_type instead of llama_graph_type

* review: remove llama_model_has_mtp

* review: fix convert issues

* convert: fix pycheck

* review: formatting

* use `mtp-` for identifying MTP models

* convert: fix mtp conversion
This commit is contained in:
Aman Gupta
2026-05-13 11:12:20 +08:00
parent 5c58cc5bdd
commit 3c3aebaaa0
20 changed files with 687 additions and 646 deletions
+12 -27
View File
@@ -756,6 +756,14 @@ private:
}
auto cparams = common_context_params_to_llama(params_dft);
const bool spec_mtp = std::find(params_base.speculative.types.begin(),
params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end();
if (spec_mtp) {
cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
}
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams));
ctx_dft_seq_rm_type = common_context_can_seq_rm(ctx_dft.get());
@@ -764,36 +772,13 @@ private:
params_base.speculative.draft.ctx_dft = ctx_dft.get();
} else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(),
COMMON_SPECULATIVE_TYPE_MTP) != params_base.speculative.types.end()) {
// MTP head lives in the *target* GGUF — load it as a sibling model
// with override_arch and feed it through the existing ctx_dft slot.
char trunk_arch[64] = {0};
llama_model_meta_val_str(model_tgt, "general.architecture", trunk_arch, sizeof(trunk_arch));
const char * mtp_arch = nullptr;
if (std::string(trunk_arch) == "qwen35") {
mtp_arch = "qwen35_mtp";
} else if (std::string(trunk_arch) == "qwen35moe") {
mtp_arch = "qwen35moe_mtp";
} else {
SRV_ERR("MTP not supported for trunk architecture '%s'\n", trunk_arch);
return false;
}
SRV_INF("loading MTP head from '%s' (override_arch=%s)\n",
params_base.model.path.c_str(), mtp_arch);
auto mparams_mtp = common_model_params_to_llama(params_base);
mparams_mtp.override_arch = mtp_arch;
model_dft.reset(llama_model_load_from_file(params_base.model.path.c_str(), mparams_mtp));
if (model_dft == nullptr) {
SRV_ERR("failed to load MTP head from '%s'\n", params_base.model.path.c_str());
return false;
}
SRV_INF("creating MTP draft context against the target model '%s'\n",
params_base.model.path.c_str());
auto cparams_mtp = common_context_params_to_llama(params_base);
cparams_mtp.ctx_type = LLAMA_CONTEXT_TYPE_MTP;
ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams_mtp));
ctx_dft.reset(llama_init_from_model(model_tgt, cparams_mtp));
if (ctx_dft == nullptr) {
SRV_ERR("%s", "failed to create MTP context\n");
return false;