[None][feat] Clean up ngram auto mode, add max_concurrency to configs (#6676)

Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
Mike Iovine 2025-08-07 12:51:47 -04:00 committed by GitHub
parent 4055b764db
commit e968f98b43
9 changed files with 52 additions and 41 deletions

View File

@@ -871,6 +871,10 @@ class PyExecutor:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests)
self.model_engine.enable_spec_decode = self.use_spec_decode
# If speculation is off, this function sets py_draft_tokens to None
# for all active requests. If it's on, we initialize py_draft_tokens
# with dummy draft tokens so the scheduler is aware that speculation
# is about to happen.
self._prepare_draft_requests()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
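To make the comment above concrete, here is a minimal sketch of what _prepare_draft_requests does conceptually; aside from py_draft_tokens, use_spec_decode, and active_requests (which appear in this hunk), the max_draft_len attribute is an assumption, and this is not the real PyExecutor code.

def _prepare_draft_requests(self):
    # Sketch only, not the actual implementation.
    for request in self.active_requests:
        if not self.use_spec_decode:
            # Speculation is off: the scheduler should see no draft tokens.
            request.py_draft_tokens = None
        else:
            # Speculation is on: fill in dummy draft tokens so the scheduler
            # budgets for the extra tokens speculation will produce.
            request.py_draft_tokens = [0] * self.max_draft_len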

View File

@@ -170,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
)
executor_config.enable_chunked_context = False
spec_config = executor_config.speculative_config
if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
if not spec_config.spec_dec_mode.support_overlap_scheduler():
logger.warning(
f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
)
executor_config.pytorch_backend_config.disable_overlap_scheduler = True
def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
if executor_config.mapping is None:

View File

@@ -1,3 +1,4 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -23,4 +24,5 @@ __all__ = [
"get_spec_resource_manager",
"get_spec_worker",
"update_spec_config_from_model_config",
"suggest_spec_config",
]

View File

@@ -0,0 +1,17 @@
def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig":
"""
Suggests a reasonable draft-model-free speculation scheme.
Used when the user specifies spec_mode == AUTO.
For now, we always use an ngram scheme that is disabled at
batch sizes above 32.
"""
from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig
return NGramDecodingConfig(
max_draft_len=5 if max_batch_size <= 4 else 3,
max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
max_concurrency=32,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
)
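For reference, these are the values the heuristic above produces at a few batch sizes; the loop is illustrative only and simply echoes the branches of suggest_spec_config.

for bs in (2, 8, 64):
    cfg = suggest_spec_config(max_batch_size=bs)
    # bs=2  -> max_draft_len=5, max_matching_ngram_size=3, max_concurrency=32
    # bs=8  -> max_draft_len=3, max_matching_ngram_size=5, max_concurrency=32
    # bs=64 -> max_draft_len=3, max_matching_ngram_size=5, max_concurrency=32
    print(bs, cfg.max_draft_len, cfg.max_matching_ngram_size, cfg.max_concurrency)

Regardless of the chosen draft lengths, max_concurrency=32 means speculation is skipped at runtime whenever more than 32 requests are active.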

View File

@@ -9,6 +9,9 @@ from ..pyexecutor.scheduler import ScheduledRequests
class Drafter(ABC):
"""Abstract base class for all drafter implementations."""
def __init__(self, max_concurrency: Optional[int] = None) -> None:
self.max_concurrency = max_concurrency
@abstractmethod
def prepare_draft_tokens(
self,
@@ -25,4 +28,6 @@
def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
"""Check if spec decode should be used for the current iteration."""
if self.max_concurrency is not None:
return len(requests) <= self.max_concurrency
return True
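A toy subclass, only to show how the max_concurrency gate behaves once a concrete drafter inherits it; the NoopDrafter class and the fake request list are made up for illustration and are not part of this commit.

class NoopDrafter(Drafter):
    # Hypothetical drafter used only for this example.
    def prepare_draft_tokens(self, scheduled_requests, resource_manager=None):
        pass  # a real drafter would attach draft tokens to each request

drafter = NoopDrafter(max_concurrency=32)
# Speculate while at most 32 requests are active, skip otherwise.
assert drafter.should_use_spec_decode([object()] * 32) is True
assert drafter.should_use_spec_decode([object()] * 33) is False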

View File

@@ -48,6 +48,8 @@ class ModelDrafter(Drafter):
spec_resource_manager: Optional[BaseResourceManager] = None,
guided_decoder: Optional[GuidedDecoder] = None,
):
super().__init__(spec_config.max_concurrency)
# Validate required parameters
if draft_model_engine is None:
raise ValueError("draft_model_engine cannot be None")

View File

@@ -168,6 +168,7 @@ class NGramDrafter(Drafter):
spec_config: NGramDecodingConfig,
ngram_pool_manager: NGramPoolManager = None,
):
super().__init__(spec_config.max_concurrency)
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
@@ -178,11 +179,6 @@
scheduled_requests: ScheduledRequests,
resource_manager: Optional[ResourceManager] = None,
) -> None:
# Disable NGram speculative decoding auto heuristic for batch size > 32.
if self.spec_config.is_auto_heuristic and len(
scheduled_requests.all_requests()) > 32:
return
# Sort by request_id when py_batch_idx is None as a fallback.
# This happens in the disagg case: for a set of new requests, we draft
# before forward_step, so py_batch_idx is not assigned.
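The sort the comment above refers to is roughly of this shape; the exact key used in the commit is not shown in this hunk, so treat this as an assumed sketch of the fallback ordering.

# Assumed sketch: order by py_batch_idx when assigned, by request_id otherwise.
def _request_order(req):
    if req.py_batch_idx is None:
        return (1, req.request_id)
    return (0, req.py_batch_idx)

ordered_requests = sorted(scheduled_requests.all_requests(), key=_request_order)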

View File

@@ -32,8 +32,8 @@ from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ class _TorchLLM(BaseLLM):
spec_config = self.args.speculative_config
max_batch_size = self._executor_config.max_batch_size
# Apply default heuristic to AutoDecodingConfig based on benchmark results
# With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
# With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
# With concurrency > 32, speculative decoding is disabled.
if spec_config is not None and spec_config.decoding_type == "AUTO":
if not self.args.disable_overlap_scheduler:
logger.info(
"Disable overlap scheduler to enable Auto speculative decoding with Ngram."
)
# From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
# Therefore, we disable overlap scheduler to enable NGram speculative decoding.
self.args.disable_overlap_scheduler = True
spec_config = NGramDecodingConfig(
max_draft_len=5 if max_batch_size <= 4 else 3,
max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
is_keep_all=True,
is_use_oldest=True,
is_public_pool=True,
# Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
is_auto_heuristic=True,
)
logger.info(
f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
)
from tensorrt_llm._torch.speculative import suggest_spec_config
spec_config = suggest_spec_config(max_batch_size)
update_executor_config(
self._executor_config,
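From the user's side, this path is taken when AUTO speculation is requested through the LLM API. A hedged sketch of such a call; the exact keyword arguments accepted by LLM can vary between versions, so treat them as assumptions, and the model path is a placeholder.

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import AutoDecodingConfig

# AutoDecodingConfig triggers the suggest_spec_config heuristic shown above.
llm = LLM(model="/path/to/model",
          speculative_config=AutoDecodingConfig())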

View File

@@ -342,6 +342,11 @@ class DecodingBaseConfig(StrictBaseModel):
max_draft_len: Optional[int] = None
speculative_model_dir: Optional[Union[str, Path]] = None
# PyTorch only.
# When specified, speculation will be disabled at batch sizes above
# this value. Otherwise, speculation will always be on.
max_concurrency: Optional[int] = None
@classmethod
def from_dict(cls, data: dict):
# dispatch to the correct decoding config
@@ -469,9 +474,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
is_keep_all: bool = True
is_use_oldest: bool = True
is_public_pool: bool = True
# Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
# User should not set this flag. Use AutoDecodingConfig instead.
is_auto_heuristic: bool = False
@classmethod
def from_dict(cls, data: dict):
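Because max_concurrency now lives on DecodingBaseConfig, any speculation config can opt into the batch-size gate directly. A small usage sketch; the field values are arbitrary examples, not recommendations.

from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig

# Example values only: speculate while 16 or fewer requests are in flight.
spec_config = NGramDecodingConfig(
    max_draft_len=3,
    max_matching_ngram_size=5,
    max_concurrency=16,
)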
@@ -535,13 +537,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
"""
Configuration for auto speculative decoding.
This config is used to automatically select the best speculative decoding algorithm.
This config automatically selects a good, draft-model-free
speculation algorithm using a heuristic.
According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
Default heuristic:
With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
With concurrency > 32, speculative decoding is disabled.
Attributes that are inherited from the base class are ignored.
"""
@classmethod