# tensorrt_llm/_torch/speculative/utils.py

from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
from tensorrt_llm._torch.speculative.interface import SpecConfig, SpecMetadata

from .draft_target import DraftTargetSpecMetadata
from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
                     Eagle3OneModelWorker, Eagle3ResourceManager,
                     Eagle3SpecMetadata)
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                  MTPSpecMetadata, MTPWorker)
from .ngram import NGramDrafter, NGramPoolManager


def get_spec_metadata(spec_config,
                      max_num_requests,
                      max_num_tokens,
                      spec_resource_manager=None,
                      is_draft_model=False):
    """Return the SpecMetadata for the configured speculative decoding mode,
    or None if the mode is not recognized."""
    if spec_config.spec_dec_mode.is_mtp():
        return MTPSpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            mtp_num_modules=spec_config.num_nextn_predict_layers,
            max_num_requests=max_num_requests,
            mtp_hidden_states_manager=spec_resource_manager,
        )
    if spec_config.spec_dec_mode.is_eagle3():
        return Eagle3SpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
            num_layers=spec_config.num_layers,
            hidden_size=spec_config.hidden_size,
            max_num_tokens=max_num_tokens,
            dtype=spec_config.dtype,
            is_draft_model=is_draft_model,
            eagle3_resource_manager=spec_resource_manager,
        )
    if spec_config.spec_dec_mode.is_eagle3_one_model():
        return Eagle3OneModelSpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
            num_layers=spec_config.num_layers,
            hidden_size=spec_config.hidden_size,
            max_num_tokens=max_num_tokens,
        )
    if spec_config.spec_dec_mode.is_draft_target():
        return DraftTargetSpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
        )
    # NGram and user-provided drafters only need the generic metadata.
    if (spec_config.spec_dec_mode.is_ngram()
            or spec_config.spec_dec_mode.is_user_provided()):
        return SpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            max_num_requests=max_num_requests,
        )
    return None
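
# Illustrative usage (a sketch, not part of the original module): wiring the
# metadata for the Eagle3 two-model flow. `model_engine` and `draft_engine`
# are assumed to be the caller's target and draft model engines; the resource
# manager comes from get_spec_resource_manager() below.
#
#   resource_manager = get_spec_resource_manager(model_engine,
#                                                draft_model_engine=draft_engine)
#   spec_metadata = get_spec_metadata(model_engine.spec_config,
#                                     max_num_requests=model_engine.batch_size,
#                                     max_num_tokens=model_engine.max_num_tokens,
#                                     spec_resource_manager=resource_manager)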


def get_spec_resource_manager(model_engine,
                              draft_model_engine=None,
                              drafter=None):
    """Create or fetch the resource manager backing the configured mode,
    or None if the mode needs no extra resources."""
    spec_config = model_engine.spec_config
    if spec_config is None:
        return None
    model_config = model_engine.model.config
    max_num_requests = model_engine.batch_size
    max_seq_len = model_engine.max_seq_len
    max_num_tokens = model_engine.max_num_tokens
    spec_dec_mode = spec_config.spec_dec_mode
    if spec_dec_mode.is_mtp_eagle():
        # MTP-Eagle only needs hidden-state bookkeeping when relaxed
        # acceptance is enabled for thinking tokens.
        if spec_config.use_relaxed_acceptance_for_thinking:
            return MTPHiddenStatesManager(
                spec_config,
                model_config.torch_dtype,
                model_config.hidden_size,
                max_num_requests,
            )
        else:
            return None
    if spec_dec_mode.is_mtp():
        return MTPHiddenStatesManager(
            spec_config,
            model_config.torch_dtype,
            model_config.hidden_size,
            max_num_requests,
        )
    if spec_dec_mode.is_eagle3():
        assert draft_model_engine is not None, "Draft model engine is required for the Eagle3 two-model flow."
        return Eagle3ResourceManager(
            spec_config,
            draft_model_engine.model.config.torch_dtype,
            model_config.hidden_size,
            max_num_requests,
            max_seq_len,
            max_num_tokens,
        )
    if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
        assert drafter is not None, "Drafter is required for NGram or user-provided speculative decoding."
        return drafter.spec_resource_manager
    return None


def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
    """Return the sampler for the given speculative decoding mode."""
    if spec_config.spec_dec_mode.is_mtp():
        return MTPSampler(sampler_args,
                          nextn=spec_config.num_nextn_predict_layers)
    if spec_config.spec_dec_mode.is_eagle3():
        # TorchSampler handles Eagle3 gracefully by integrating d2t (the
        # draft-to-target token mapping) into the sampling process.
        return TorchSampler(sampler_args)
    if spec_config.spec_dec_mode.is_eagle3_one_model():
        return Eagle3OneModelSampler(sampler_args)
    raise ValueError(
        f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")


def get_spec_drafter(model_engine):
    """Return the drafter that produces draft tokens, or None if the mode
    uses none."""
    spec_config = model_engine.spec_config
    if spec_config is None:
        return None
    max_num_requests = model_engine.batch_size
    if spec_config.spec_dec_mode.is_ngram():
        return NGramDrafter(spec_config,
                            NGramPoolManager(spec_config, max_num_requests))
    if spec_config.spec_dec_mode.is_user_provided():
        # User-provided mode: the caller supplies the drafter on the config.
        return spec_config.drafter
    return None


def get_num_spec_layers(spec_config):
    """Return the number of extra decoder layers the speculative mode adds."""
    if spec_config.spec_dec_mode.is_mtp():
        return spec_config.num_nextn_predict_layers
    elif spec_config.spec_dec_mode.is_eagle3_one_model():
        return 1
    else:
        return 0


def get_spec_worker(spec_config, mapping):
    """Return the worker that runs speculation inside the target model's
    forward pass, or None if the mode has no such worker."""
    if spec_config.spec_dec_mode.is_mtp():
        return MTPWorker(spec_config)
    elif spec_config.spec_dec_mode.is_mtp_eagle():
        return MTPEagleWorker(spec_config)
    elif spec_config.spec_dec_mode.is_eagle3_one_model():
        return Eagle3OneModelWorker(spec_config, mapping)
    else:
        return None
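
# End-to-end wiring sketch (illustrative; `model_engine`, `sampler_args`, and
# `mapping` are assumed to come from the caller). Not every helper applies to
# every mode: get_spec_worker() returns None for drafter-based modes, and
# get_spec_decoder() raises ValueError for modes without a dedicated sampler.
#
#   drafter = get_spec_drafter(model_engine)
#   resource_manager = get_spec_resource_manager(model_engine, drafter=drafter)
#   num_extra_layers = get_num_spec_layers(model_engine.spec_config)
#   worker = get_spec_worker(model_engine.spec_config, mapping)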