TensorRT-LLM/tensorrt_llm/_torch/speculative/utils.py
Thor Johnsen 5d438be59a
[TRTLLM-5000][feat] Pytorch implementation of ngram drafter (#3936)
* v1.5

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* v1.5.4: Add back draft_overhead to spec dec stats

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* v1.5.5: fix CI error

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* v1.6: fix CI error 8196 > 8192

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* Address reviewer concerns

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* Address reviewer concerns

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* precommit run

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

* v2.0: Address reviewer concerns

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* v2.1: add fix from wili

Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>

* Revert changes that require use of TypeAlias because that requires python version >= 3.10

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>

---------

Signed-off-by: Thor Johnsen <41591019+thorjohnsen@users.noreply.github.com>
Signed-off-by: wili-65535 <wili-65535@users.noreply.github.com>
Co-authored-by: wili-65535 <wili-65535@users.noreply.github.com>
2025-05-21 10:40:00 +08:00

58 lines
2.3 KiB
Python
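
"""Dispatch helpers for speculative decoding: map a spec-decoding config
to the metadata, resource manager, sampler, and extra-layer count for the
supported modes (MTP, MTP-Eagle, Eagle3, NGram)."""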

from .eagle3 import Eagle3Sampler, Eagle3SpecMetadata
from .mtp import MTPHiddenStatesManager, MTPSampler, MTPSpecMetadata
from .ngram import NGramPoolManager


def get_spec_metadata(spec_config,
                      max_num_requests,
                      spec_resource_manager=None):
    """Return the per-mode speculative-decoding metadata, or None if the mode needs none."""
    if spec_config.spec_dec_mode.is_mtp():
        return MTPSpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=spec_config.spec_dec_mode,
            mtp_num_modules=spec_config.num_nextn_predict_layers,
            max_num_requests=max_num_requests,
            mtp_hidden_states_manager=spec_resource_manager)
    elif spec_config.spec_dec_mode.is_eagle3():
        return Eagle3SpecMetadata(max_draft_tokens=spec_config.max_draft_tokens,
                                  spec_dec_mode=spec_config.spec_dec_mode,
                                  max_num_requests=max_num_requests,
                                  num_layers=spec_config.num_layers,
                                  hidden_size=spec_config.hidden_size)
    else:
        return None
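
# Shape of the inputs, as read by the branches above (a sketch, not the
# authoritative config API): `spec_config.spec_dec_mode` must expose the
# is_*() predicates, and the per-mode fields consulted are:
#
#   MTP:    max_draft_tokens, num_nextn_predict_layers
#   Eagle3: max_draft_tokens, num_layers, hidden_size
#
#   metadata = get_spec_metadata(spec_config, max_num_requests=16,
#                                spec_resource_manager=manager)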


def get_spec_resource_manager(spec_config, model_config, max_num_requests):
    """Create the resource manager required by the spec-decoding mode, if any."""
    if spec_config.spec_dec_mode.is_mtp_eagle():
        if spec_config.use_relaxed_acceptance_for_thinking:
            return MTPHiddenStatesManager(spec_config, model_config.torch_dtype,
                                          model_config.hidden_size,
                                          max_num_requests)
        else:
            return None
    elif spec_config.spec_dec_mode.is_mtp():
        return MTPHiddenStatesManager(spec_config, model_config.torch_dtype,
                                      model_config.hidden_size,
                                      max_num_requests)
    elif spec_config.spec_dec_mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    else:
        return None
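
# Dispatch summary for resource managers: plain MTP always gets an
# MTPHiddenStatesManager; MTP-Eagle only allocates one when relaxed
# acceptance for thinking is enabled; NGram gets the NGramPoolManager
# added by this change; every other mode runs without one.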


def get_spec_decoder(max_seq_len, spec_config):
    """Return the sampler matching the spec-decoding mode, or None."""
    if spec_config.spec_dec_mode.is_mtp():
        return MTPSampler(max_seq_len, spec_config)
    elif spec_config.spec_dec_mode.is_eagle3():
        return Eagle3Sampler(max_seq_len)
    else:
        return None


def get_num_spec_layers(spec_config):
    """Number of extra predictor layers the mode adds to the model (0 if none)."""
    if spec_config.spec_dec_mode.is_mtp():
        return spec_config.num_nextn_predict_layers
    else:
        return 0
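

# Minimal smoke-test sketch (illustrative, not part of the library API):
# `_NoSpecMode` and `_NoSpecConfig` are hypothetical stand-ins whose mode
# predicates all report False, so every helper above falls through to its
# default return value.
if __name__ == "__main__":

    class _NoSpecMode:

        def is_mtp(self):
            return False

        def is_mtp_eagle(self):
            return False

        def is_eagle3(self):
            return False

        def is_ngram(self):
            return False

    class _NoSpecConfig:
        spec_dec_mode = _NoSpecMode()

    cfg = _NoSpecConfig()
    assert get_spec_metadata(cfg, max_num_requests=8) is None
    assert get_spec_resource_manager(cfg, model_config=None,
                                     max_num_requests=8) is None
    assert get_spec_decoder(max_seq_len=8192, spec_config=cfg) is None
    assert get_num_spec_layers(cfg) == 0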