mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Signed-off-by: Chenfei Zhang <chenfeiz@nvidia.com> Signed-off-by: Yilin Fan <206948969+nv-yilinf@users.noreply.github.com> Co-authored-by: Chenfei Zhang <chenfeiz@nvidia.com>
84 lines
3.4 KiB
Python
84 lines
3.4 KiB
Python
from .eagle3 import (Eagle3OneModelDecoder, Eagle3OneModelSpecMetadata,
|
|
Eagle3OneModelWorker, Eagle3Sampler, Eagle3SpecMetadata)
|
|
from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
|
|
MTPSpecMetadata, MTPWorker)
|
|
from .ngram import NGramPoolManager
|
|
|
|
|
|
def get_spec_metadata(spec_config,
                      max_num_requests,
                      max_num_tokens,
                      spec_resource_manager=None):
    """Build the speculative-decoding metadata object for ``spec_config``.

    Args:
        spec_config: Speculative-decoding config; its ``spec_dec_mode``
            selects which metadata class is constructed.
        max_num_requests: Forwarded to every metadata class.
        max_num_tokens: Forwarded only to ``Eagle3OneModelSpecMetadata``.
        spec_resource_manager: Passed as ``mtp_hidden_states_manager`` when
            ``spec_dec_mode.is_mtp()``; ignored by the other modes.

    Returns:
        An ``MTPSpecMetadata``, ``Eagle3SpecMetadata`` or
        ``Eagle3OneModelSpecMetadata`` instance, or ``None`` for any mode
        not matched by those branches.
    """
    mode = spec_config.spec_dec_mode

    if mode.is_mtp():
        return MTPSpecMetadata(
            max_draft_tokens=spec_config.max_draft_tokens,
            spec_dec_mode=mode,
            mtp_num_modules=spec_config.num_nextn_predict_layers,
            max_num_requests=max_num_requests,
            mtp_hidden_states_manager=spec_resource_manager,
        )

    def _eagle3_kwargs():
        # Keyword arguments shared by the two-model and one-model Eagle3
        # metadata classes; built lazily so non-Eagle3 modes never read them.
        return {
            "max_draft_tokens": spec_config.max_draft_tokens,
            "spec_dec_mode": mode,
            "max_num_requests": max_num_requests,
            "num_layers": spec_config.num_layers,
            "hidden_size": spec_config.hidden_size,
        }

    if mode.is_eagle3():
        return Eagle3SpecMetadata(**_eagle3_kwargs())

    if mode.is_eagle3_one_model():
        # Only the one-model variant needs the token budget.
        return Eagle3OneModelSpecMetadata(**_eagle3_kwargs(),
                                          max_num_tokens=max_num_tokens)

    return None
|
|
|
|
|
|
def get_spec_resource_manager(spec_config, model_config, max_num_requests):
    """Create the resource manager required by the speculative-decoding mode.

    Args:
        spec_config: Speculative-decoding config; its ``spec_dec_mode``
            selects the manager.
        model_config: Supplies ``torch_dtype`` and ``hidden_size`` for the
            MTP hidden-states manager (not read for other modes).
        max_num_requests: Forwarded to whichever manager is created.

    Returns:
        * MTP-Eagle: an ``MTPHiddenStatesManager`` only when
          ``spec_config.use_relaxed_acceptance_for_thinking`` is truthy,
          otherwise ``None``.
        * Other MTP modes: an ``MTPHiddenStatesManager``.
        * NGram: an ``NGramPoolManager``.
        * Anything else: ``None``.
    """
    mode = spec_config.spec_dec_mode

    # MTP-Eagle is tested before the generic MTP predicate so its special
    # case (manager only under relaxed acceptance) takes precedence.
    if mode.is_mtp_eagle():
        wants_hidden_states = spec_config.use_relaxed_acceptance_for_thinking
    elif mode.is_mtp():
        wants_hidden_states = True
    elif mode.is_ngram():
        return NGramPoolManager(spec_config, max_num_requests)
    else:
        return None

    if not wants_hidden_states:
        return None
    return MTPHiddenStatesManager(spec_config, model_config.torch_dtype,
                                  model_config.hidden_size, max_num_requests)
|
|
|
|
|
|
def get_spec_decoder(max_seq_len, spec_config):
    """Create the sampler/decoder matching ``spec_config.spec_dec_mode``.

    Args:
        max_seq_len: Maximum sequence length, forwarded to every sampler.
        spec_config: Speculative-decoding config; ``spec_dec_mode`` picks the
            implementation. It is also forwarded to the MTP and one-model
            Eagle3 decoders.

    Returns:
        An ``MTPSampler``, ``Eagle3Sampler`` or ``Eagle3OneModelDecoder``, or
        ``None`` when no branch matches.
    """
    mode = spec_config.spec_dec_mode
    if mode.is_mtp():
        return MTPSampler(max_seq_len, spec_config)
    if mode.is_eagle3():
        # The two-model Eagle3 sampler only receives the length bound.
        return Eagle3Sampler(max_seq_len)
    if mode.is_eagle3_one_model():
        return Eagle3OneModelDecoder(max_seq_len, spec_config)
    return None
|
|
|
|
|
|
def get_num_spec_layers(spec_config):
    """Return how many extra speculative layers the mode adds to the model.

    Args:
        spec_config: Speculative-decoding config.

    Returns:
        ``spec_config.num_nextn_predict_layers`` for MTP modes, ``1`` for the
        one-model Eagle3 mode, and ``0`` for everything else.
    """
    mode = spec_config.spec_dec_mode
    if mode.is_mtp():
        return spec_config.num_nextn_predict_layers
    return 1 if mode.is_eagle3_one_model() else 0
|
|
|
|
|
|
def get_spec_worker(spec_config, mapping):
    """Create the in-model speculative-decoding worker for ``spec_config``.

    Args:
        spec_config: Speculative-decoding config; ``spec_dec_mode`` selects
            the worker class and the config is forwarded to it.
        mapping: Forwarded only to ``Eagle3OneModelWorker``.

    Returns:
        An ``MTPEagleWorker``, ``MTPWorker`` or ``Eagle3OneModelWorker``, or
        ``None`` when no branch matches.
    """
    spec_dec_mode = spec_config.spec_dec_mode
    # Test the MTP-Eagle variant before the generic MTP predicate, in the same
    # order as get_spec_resource_manager. If is_mtp() also matches MTP-Eagle
    # (get_spec_metadata / get_spec_decoder rely on it to cover that mode),
    # checking is_mtp() first would make the MTPEagleWorker branch
    # unreachable. If the two predicates are disjoint, the order is harmless.
    if spec_dec_mode.is_mtp_eagle():
        return MTPEagleWorker(spec_config)
    if spec_dec_mode.is_mtp():
        return MTPWorker(spec_config)
    if spec_dec_mode.is_eagle3_one_model():
        return Eagle3OneModelWorker(spec_config, mapping)
    return None
|