This commit is contained in:
Mike Iovine 2026-01-13 11:26:17 +08:00 committed by GitHub
commit 4406c1d76a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 11 additions and 15 deletions

View File

@ -48,8 +48,7 @@ from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
Eagle3ResourceManager, Eagle3SpecMetadata)
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import SpecDecodingTensor
from ..utils import (get_model_extra_attrs,
@ -2684,9 +2683,9 @@ class PyTorchModelEngine(ModelEngine):
num_accepted_draft_tokens)]
if isinstance(spec_metadata, Eagle3SpecMetadata):
spec_metadata.request_accepted_path = request_accepted_path
if isinstance(spec_metadata, Eagle3OneModelSpecMetadata):
spec_metadata.populate_sampling_params_for_one_model(
scheduled_requests.all_requests())
# This call is a no-op for spec metadata that is not one-model.
spec_metadata.populate_sampling_params_for_one_model(
scheduled_requests.all_requests())
spec_metadata.prepare()
inputs['spec_metadata'] = spec_metadata

View File

@ -281,16 +281,12 @@ def create_py_executor(
)
llm_args.disable_overlap_scheduler = True
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine():
if not spec_config.allow_advanced_sampling:
logger.warning(
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
)
elif spec_config.spec_dec_mode.is_mtp_one_model():
logger.warning(
"Advanced sampling is not supported for MTP yet - this will be added soon."
)
if spec_config is not None and spec_config.spec_dec_mode.use_one_engine(
) and not spec_config.allow_advanced_sampling:
logger.warning(
f"Falling back to greedy decoding for {spec_config.decoding_type}. If you "
"want to use non-greedy sampling, please set allow_advanced_sampling=True."
)
if mm_encoder_only:
llm_args.mm_encoder_only = True

View File

@ -31,6 +31,7 @@ def get_spec_metadata(spec_config,
mtp_num_modules=spec_config.num_nextn_predict_layers,
max_num_requests=max_num_requests,
mtp_hidden_states_manager=spec_resource_manager,
allow_advanced_sampling=spec_config.allow_advanced_sampling,
)
if spec_config.spec_dec_mode.is_mtp_eagle():
return Eagle3SpecMetadata(