diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index f186da6cd8..2033030ad0 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -48,8 +48,7 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
                             get_spec_metadata,
                             update_spec_config_from_model_config)
 from ..speculative.drafting_loops import BaseDraftingLoopWrapper
-from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
-                                  Eagle3ResourceManager, Eagle3SpecMetadata)
+from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
 from ..speculative.mtp import SampleStateTensorsMTP
 from ..speculative.utils import SpecDecodingTensor
 from ..utils import (get_model_extra_attrs,
@@ -2684,9 +2683,9 @@ class PyTorchModelEngine(ModelEngine):
                                           num_accepted_draft_tokens)]
             if isinstance(spec_metadata, Eagle3SpecMetadata):
                 spec_metadata.request_accepted_path = request_accepted_path
-            if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
-                spec_metadata.populate_sampling_params_for_one_model(
-                    scheduled_requests.all_requests())
+            # No-op for non 1-model
+            spec_metadata.populate_sampling_params_for_one_model(
+                scheduled_requests.all_requests())
             spec_metadata.prepare()
             inputs['spec_metadata'] = spec_metadata
 
diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
index bd1857dda2..22a5f1c8b6 100644
--- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
+++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
@@ -281,16 +281,12 @@ def create_py_executor(
         )
         llm_args.disable_overlap_scheduler = True
 
-    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
-        if not spec_config.allow_advanced_sampling:
-            logger.warning(
-                f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
-                "want to use non-greedy sampling, please set allow_advanced_sampling=True."
-            )
-        elif spec_config.spec_dec_mode.is_mtp_one_model():
-            logger.warning(
-                "Advanced sampling is not supported for MTP yet - this will be added soon."
-            )
+    if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
+    ) and not spec_config.allow_advanced_sampling:
+        logger.warning(
+            f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
+            "want to use non-greedy sampling, please set allow_advanced_sampling=True."
+        )
 
     if mm_encoder_only:
         llm_args.mm_encoder_only = True
diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py
index 6a22ad19bd..1eb036de99 100644
--- a/tensorrt_llm/_torch/speculative/utils.py
+++ b/tensorrt_llm/_torch/speculative/utils.py
@@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
             mtp_num_modules=spec_config.num_nextn_predict_layers,
             max_num_requests=max_num_requests,
             mtp_hidden_states_manager=spec_resource_manager,
+            allow_advanced_sampling=spec_config.allow_advanced_sampling,
         )
     if spec_config.spec_dec_mode.is_mtp_eagle():
         return Eagle3SpecMetadata(
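
Note on the model_engine.py change: dropping the isinstance(spec_metadata, Eagle3OneModelSpecMetadata) guard relies on the dispatch pattern suggested by the "No-op for non 1-model" comment, i.e. the call is safe on every metadata subclass because non one-model variants treat it as a no-op. The sketch below illustrates that pattern only; apart from the names visible in the diff (SpecMetadata, populate_sampling_params_for_one_model, allow_advanced_sampling), the class and field names are hypothetical and not the actual TensorRT-LLM definitions.

from dataclasses import dataclass, field
from typing import List


@dataclass
class SpecMetadata:
    """Base speculative-decoding metadata (heavily simplified sketch)."""

    def populate_sampling_params_for_one_model(self, requests: List) -> None:
        # No-op for non 1-model paths; only one-model subclasses override this,
        # so callers can invoke it unconditionally without an isinstance() check.
        return


@dataclass
class OneModelSpecMetadataSketch(SpecMetadata):
    """Hypothetical stand-in for a one-model metadata subclass."""
    allow_advanced_sampling: bool = False
    temperatures: List[float] = field(default_factory=list)

    def populate_sampling_params_for_one_model(self, requests: List) -> None:
        if not self.allow_advanced_sampling:
            return  # greedy fallback: nothing to capture
        # Record per-request sampling params so the drafting loop inside the
        # single engine can sample with them (field name is illustrative).
        self.temperatures = [getattr(r, "temperature", 1.0) for r in requests]


# Caller side, mirroring the now-unconditional call in model_engine.py:
def attach_sampling_params(spec_metadata: SpecMetadata, requests: List) -> None:
    # Safe for every SpecMetadata subclass; no isinstance() check needed.
    spec_metadata.populate_sampling_params_for_one_model(requests)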