mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-14 06:27:45 +08:00
Merge 678948a8bc into ba1cb6831d
This commit is contained in:
commit
4406c1d76a
@ -48,8 +48,7 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
|
||||
get_spec_metadata,
|
||||
update_spec_config_from_model_config)
|
||||
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
|
||||
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
|
||||
Eagle3ResourceManager, Eagle3SpecMetadata)
|
||||
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
|
||||
from ..speculative.mtp import SampleStateTensorsMTP
|
||||
from ..speculative.utils import SpecDecodingTensor
|
||||
from ..utils import (get_model_extra_attrs,
|
||||
@ -2684,9 +2683,9 @@ class PyTorchModelEngine(ModelEngine):
|
||||
num_accepted_draft_tokens)]
|
||||
if isinstance(spec_metadata, Eagle3SpecMetadata):
|
||||
spec_metadata.request_accepted_path = request_accepted_path
|
||||
if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
|
||||
spec_metadata.populate_sampling_params_for_one_model(
|
||||
scheduled_requests.all_requests())
|
||||
# No-op for non 1-model
|
||||
spec_metadata.populate_sampling_params_for_one_model(
|
||||
scheduled_requests.all_requests())
|
||||
spec_metadata.prepare()
|
||||
inputs['spec_metadata'] = spec_metadata
|
||||
|
||||
|
||||
@ -281,16 +281,12 @@ def create_py_executor(
|
||||
)
|
||||
llm_args.disable_overlap_scheduler = True
|
||||
|
||||
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
|
||||
if not spec_config.allow_advanced_sampling:
|
||||
logger.warning(
|
||||
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
|
||||
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
|
||||
)
|
||||
elif spec_config.spec_dec_mode.is_mtp_one_model():
|
||||
logger.warning(
|
||||
"Advanced sampling is not supported for MTP yet - this will be added soon."
|
||||
)
|
||||
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
|
||||
) and not spec_config.allow_advanced_sampling:
|
||||
logger.warning(
|
||||
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
|
||||
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
|
||||
)
|
||||
|
||||
if mm_encoder_only:
|
||||
llm_args.mm_encoder_only = True
|
||||
|
||||
@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
|
||||
mtp_num_modules=spec_config.num_nextn_predict_layers,
|
||||
max_num_requests=max_num_requests,
|
||||
mtp_hidden_states_manager=spec_resource_manager,
|
||||
allow_advanced_sampling=spec_config.allow_advanced_sampling,
|
||||
)
|
||||
if spec_config.spec_dec_mode.is_mtp_eagle():
|
||||
return Eagle3SpecMetadata(
|
||||
|
||||
Loading…
Reference in New Issue
Block a user