Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-02-04 18:21:52 +08:00)
[None][fix] Fix MTP 1-model sampler (#10369)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Signed-off-by: Wangshanshan <30051912+dominicshanshan@users.noreply.github.com>
parent d9aef94431
commit 13b0ab9c0e
@@ -50,8 +50,7 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
                            get_spec_metadata,
                            update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
-                                  Eagle3ResourceManager, Eagle3SpecMetadata)
+from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2784,9 +2783,9 @@ class PyTorchModelEngine(ModelEngine):
                         num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
-            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
-                spec_metadata.populate_sampling_params_for_one_model(
-                    scheduled_requests.all_requests())
+            # No-op for non 1-model
+            spec_metadata.populate_sampling_params_for_one_model(
+                scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata

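The change above drops the isinstance guard and calls populate_sampling_params_for_one_model on every spec_metadata, relying on the method being a no-op for non-1-model modes. A minimal Python sketch of that pattern, assuming a no-op default on the SpecMetadata base class and an MTP-side override that reads the allow_advanced_sampling flag added later in this commit (the class bodies here are illustrative, not the repository's actual implementations):

```python
from dataclasses import dataclass, field
from typing import Any, List


@dataclass
class SpecMetadata:
    """Base speculative-decoding metadata (illustrative sketch)."""

    def populate_sampling_params_for_one_model(self, requests: List[Any]) -> None:
        # Default is a no-op, so the engine can call this unconditionally
        # for every spec mode ("No-op for non 1-model").
        return


@dataclass
class MTPSpecMetadata(SpecMetadata):
    """MTP 1-model metadata (sketch); the flag name comes from the diff."""
    allow_advanced_sampling: bool = False
    sampling_params: List[Any] = field(default_factory=list)

    def populate_sampling_params_for_one_model(self, requests: List[Any]) -> None:
        # Collect per-request sampling params only when advanced (non-greedy)
        # sampling is enabled; greedy decoding needs nothing extra here.
        if self.allow_advanced_sampling:
            self.sampling_params = [req.sampling_params for req in requests]
```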
@@ -282,16 +282,12 @@ def create_py_executor(
             )
             llm_args.disable_overlap_scheduler = True

-    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
-        if not spec_config.allow_advanced_sampling:
-            logger.warning(
-                f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
-                "want to use non-greedy sampling, please set allow_advanced_sampling=True."
-            )
-        elif spec_config.spec_dec_mode.is_mtp_one_model():
-            logger.warning(
-                "Advanced sampling is not supported for MTP yet - this will be added soon."
-            )
+    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
+    ) and not spec_config.allow_advanced_sampling:
+        logger.warning(
+            f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
+            "want to use non-greedy sampling, please set allow_advanced_sampling=True."
+        )

     if mm_encoder_only:
         llm_args.mm_encoder_only = True

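After this change the warning path no longer special-cases MTP, so non-greedy sampling with an MTP one-model engine is opted into the same way as for other one-engine modes: by setting allow_advanced_sampling on the speculative config. A hedged usage sketch; the flag name comes from the diff, but its exact placement on MTPDecodingConfig and the example model are assumptions:

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import MTPDecodingConfig

# Assumption: allow_advanced_sampling is exposed on the MTP speculative
# config and maps to spec_config.allow_advanced_sampling in the diff.
spec_config = MTPDecodingConfig(
    num_nextn_predict_layers=1,
    allow_advanced_sampling=True,  # avoid the greedy-decoding fallback warning
)

llm = LLM(model="deepseek-ai/DeepSeek-V3", speculative_config=spec_config)
outputs = llm.generate(
    ["Explain speculative decoding in one sentence."],
    SamplingParams(temperature=0.8, top_p=0.95),  # non-greedy sampling
)
print(outputs[0].outputs[0].text)
```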
@@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
             mtp_num_modules=spec_config.num_nextn_predict_layers,
             max_num_requests=max_num_requests,
             mtp_hidden_states_manager=spec_resource_manager,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_mtp_eagle():
         return Eagle3SpecMetadata(
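The last hunk forwards the flag from the config into the MTP metadata when it is constructed. A condensed, self-contained sketch of that flow, assuming the keyword shown in the diff; the simplified factory and config classes below are stand-ins, not the real get_spec_metadata signature:

```python
class SpecConfig:
    """Stand-in for the speculative config consulted by get_spec_metadata."""

    def __init__(self, allow_advanced_sampling: bool = False):
        self.allow_advanced_sampling = allow_advanced_sampling


class MTPSpecMetadata:
    """Stand-in for the MTP metadata that now accepts the flag."""

    def __init__(self, allow_advanced_sampling: bool = False):
        self.allow_advanced_sampling = allow_advanced_sampling


def make_mtp_metadata(spec_config: SpecConfig) -> MTPSpecMetadata:
    # Mirrors the added keyword argument: the config flag is forwarded into
    # the metadata so the 1-model sampler can consult it at runtime.
    return MTPSpecMetadata(
        allow_advanced_sampling=spec_config.allow_advanced_sampling)


metadata = make_mtp_metadata(SpecConfig(allow_advanced_sampling=True))
assert metadata.allow_advanced_sampling
```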