# TensorRT-LLM/tensorrt_llm/_torch/speculative/utils.py

from typing import Optional

from ..pyexecutor.guided_decoder import GuidedDecoder
from ..pyexecutor.sampler import TorchSampler
from ..pyexecutor.seq_slot_manager import SeqSlotManager
from ..speculative.interface import SpecMetadata
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
Eagle3OneModelWorker, Eagle3ResourceManager,
Eagle3SpecMetadata)
from .model_drafter import ModelDrafter
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter


def get_spec_metadata(spec_config,
model_config,
max_num_requests,
max_num_tokens,
spec_resource_manager=None,
is_draft_model=False):
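    """
    Build the SpecMetadata instance matching the configured speculative
    decoding mode, or return None if the mode requires no metadata.
    """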
if spec_config.spec_dec_mode.is_mtp_one_model():
return MTPSpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
mtp_num_modules=spec_config.num_nextn_predict_layers,
max_num_requests=max_num_requests,
mtp_hidden_states_manager=spec_resource_manager,
)
if spec_config.spec_dec_mode.is_mtp_eagle():
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=None,
is_mtp_eagle=True,
)
if spec_config.spec_dec_mode.is_eagle3():
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=spec_config.eagle3_layers_to_capture,
is_mtp_eagle=False,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
eagle_choices=spec_config.eagle_choices,
is_spec_dec_tree=spec_config.eagle_choices is not None,
is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree,
)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelSpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
layers_to_capture=spec_config.eagle3_layers_to_capture,
)
if spec_config.spec_dec_mode.is_save_hidden_states():
if spec_config.eagle3_layers_to_capture is None:
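            # No capture layers were specified; default to an early layer, a
            # mid-stack layer, a late layer, and -1.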
spec_config.eagle3_layers_to_capture = {
1, model_config.num_hidden_layers // 2 - 1,
model_config.num_hidden_layers - 4, -1
}
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
hidden_size=model_config.hidden_size,
max_num_tokens=max_num_tokens,
dtype=model_config.torch_dtype,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=spec_config.eagle3_layers_to_capture,
max_total_draft_tokens=1,
)
if spec_config.spec_dec_mode.is_draft_target() or \
spec_config.spec_dec_mode.is_ngram() or \
spec_config.spec_dec_mode.is_user_provided():
return SpecMetadata(
max_draft_len=spec_config.max_draft_len,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
)
return None


def get_spec_resource_manager(model_engine, draft_model_engine=None):
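    """
    Build the resource manager matching the speculative decoding mode of the
    given model engine, or return None if the mode needs no manager.
    """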
spec_config = model_engine.spec_config
if spec_config is None:
return None
model_config = model_engine.model.config
max_num_requests = model_engine.batch_size
max_seq_len = model_engine.max_seq_len
max_num_tokens = model_engine.max_num_tokens
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_eagle_one_model():
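        # The MTP Eagle one-model path only needs a hidden states manager
        # when relaxed acceptance for thinking is enabled.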
if spec_config.use_relaxed_acceptance_for_thinking:
return MTPHiddenStatesManager(
spec_config,
model_config.torch_dtype,
model_config.hidden_size,
max_num_requests,
)
else:
return None
if spec_dec_mode.is_mtp_one_model():
return MTPHiddenStatesManager(
spec_config,
model_config.torch_dtype,
model_config.hidden_size,
max_num_requests,
)
if spec_dec_mode.is_eagle3() or spec_dec_mode.is_mtp_eagle():
assert draft_model_engine is not None, "Draft model engine is required for Eagle3 and MTP Eagle two model flow."
return Eagle3ResourceManager(
spec_config,
draft_model_engine.model.config.torch_dtype,
model_config.hidden_size,
max_num_requests,
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_save_hidden_states():
return Eagle3ResourceManager(
spec_config,
model_engine.model.config.torch_dtype,
model_config.hidden_size,
max_num_requests,
max_seq_len,
max_num_tokens,
)
if spec_dec_mode.is_ngram():
return NGramPoolManager(spec_config, max_num_requests)
if spec_dec_mode.is_user_provided():
return spec_config.resource_manager
return None


def get_spec_decoder(sampler_args: TorchSampler.Args,
spec_config: "DecodingBaseConfig"):
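    """
    Build the sampler matching the configured speculative decoding mode.
    Raises ValueError if the mode has no dedicated sampler.
    """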
if spec_config.spec_dec_mode.is_mtp_one_model():
return MTPSampler(sampler_args,
nextn=spec_config.num_nextn_predict_layers)
if spec_config.spec_dec_mode.is_eagle3(
) or spec_config.spec_dec_mode.is_mtp_eagle():
        # TorchSampler handles Eagle3 gracefully by integrating d2t into the sampling process.
return TorchSampler(sampler_args)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelSampler(sampler_args)
raise ValueError(
f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")


def get_spec_drafter(model_engine,
draft_model_engine,
sampler,
spec_resource_manager,
guided_decoder: Optional[GuidedDecoder] = None):
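    """
    Build the drafter matching the speculative decoding mode of the given
    model engine, or return None if the mode does not use one.
    """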
spec_config = model_engine.spec_config
if spec_config is None:
return None
if spec_config.spec_dec_mode.is_user_provided():
return spec_config.drafter
max_num_requests = model_engine.batch_size
if spec_config.spec_dec_mode.is_draft_target(
) or spec_config.spec_dec_mode.is_eagle3(
) or spec_config.spec_dec_mode.is_mtp_eagle():
return ModelDrafter(spec_config,
draft_model_engine,
spec_config.max_draft_len,
SeqSlotManager(max_num_requests),
sampler,
spec_resource_manager=spec_resource_manager,
guided_decoder=guided_decoder)
if spec_config.spec_dec_mode.is_ngram():
return NGramDrafter(spec_config, spec_resource_manager)
if spec_config.spec_dec_mode.is_save_hidden_states():
return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)
return None


def get_num_spec_layers(spec_config):
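    """
    Return the number of speculative layers that one-model modes add to the
    base model; 0 for all other modes.
    """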
if spec_config.spec_dec_mode.is_mtp_one_model():
return spec_config.num_nextn_predict_layers
if spec_config.spec_dec_mode.is_eagle3_one_model():
num_eagle_layers = spec_config.num_eagle_layers
return num_eagle_layers if num_eagle_layers is not None else 1
return 0


def get_spec_worker(spec_config, model_config, mapping):
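    """
    Build the in-model worker for one-model speculative decoding modes, or
    return None if the mode has no dedicated worker.
    """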
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_vanilla():
return MTPWorker(spec_config, model_config)
if spec_dec_mode.is_mtp_eagle_one_model():
return MTPEagleWorker(spec_config, model_config)
if spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return None


def get_num_extra_kv_tokens(spec_config):
"""
Implementation detail for one model implementations of speculative decoding. Extra
KV cache tokens are required.
"""
if spec_config is None:
return 0
if spec_config.spec_dec_mode.is_eagle3_one_model():
return spec_config.max_draft_len - 1
return 0


def update_spec_config_from_model_config(spec_config, model_config):
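    """
    Propagate MTP settings from the model config onto the spec config.
    """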
if spec_config.spec_dec_mode.is_mtp_one_model():
# Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.
spec_config.max_draft_len = spec_config.num_nextn_predict_layers
# Use `num_nextn_predict_layers_from_model_config` to decide decoding mode MTP / MTP_EAGLE.
spec_config.num_nextn_predict_layers_from_model_config = model_config.num_nextn_predict_layers