Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)
[None][feat] Clean up ngram auto mode, add max_concurrency to configs (#6676)
Signed-off-by: Mike Iovine <6158008+mikeiovine@users.noreply.github.com>
commit e968f98b43 (parent 4055b764db)
@@ -871,6 +871,10 @@ class PyExecutor:
        self.use_spec_decode = self.drafter.should_use_spec_decode(
            self.active_requests)
        self.model_engine.enable_spec_decode = self.use_spec_decode
        # If speculation is off, this function sets py_draft_tokens to None
        # for all active requests. If it's on, we initialize py_draft_tokens
        # with dummy draft tokens to make the scheduler aware of the fact
        # that speculation is about to happen.
        self._prepare_draft_requests()

        scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
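For intuition, here is a minimal, self-contained sketch of what this hunk does each iteration. The class and helpers below are simplified stand-ins rather than the real PyExecutor/Drafter/LlmRequest, and the draft length and threshold are illustrative values only:

from typing import List, Optional

class _ToyRequest:
    def __init__(self) -> None:
        self.py_draft_tokens: Optional[List[int]] = None

def should_use_spec_decode(requests: List[_ToyRequest],
                           max_concurrency: Optional[int] = 32) -> bool:
    # Mirrors the gate added to Drafter.should_use_spec_decode later in this commit.
    if max_concurrency is not None:
        return len(requests) <= max_concurrency
    return True

def prepare_draft_requests(requests: List[_ToyRequest],
                           use_spec_decode: bool,
                           max_draft_len: int = 3) -> None:
    # Dummy tokens tell the scheduler to reserve room for speculation;
    # None tells it speculation is off for this iteration.
    for req in requests:
        req.py_draft_tokens = [0] * max_draft_len if use_spec_decode else None

active = [_ToyRequest() for _ in range(40)]
use_spec = should_use_spec_decode(active)   # False: 40 active requests > 32
prepare_draft_requests(active, use_spec)
print(active[0].py_draft_tokens)            # None -> speculation skipped this step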
@@ -170,6 +170,14 @@ def _mangle_executor_config(executor_config: ExecutorConfig):
    )
    executor_config.enable_chunked_context = False

    spec_config = executor_config.speculative_config
    if not executor_config.pytorch_backend_config.disable_overlap_scheduler and spec_config is not None:
        if not spec_config.spec_dec_mode.support_overlap_scheduler():
            logger.warning(
                f"Disable overlap scheduler for speculation mode {spec_config.spec_dec_mode.name}"
            )
            executor_config.pytorch_backend_config.disable_overlap_scheduler = True


def _get_mapping(executor_config: ExecutorConfig) -> Mapping:
    if executor_config.mapping is None:
@@ -1,3 +1,4 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
@@ -23,4 +24,5 @@ __all__ = [
    "get_spec_resource_manager",
    "get_spec_worker",
    "update_spec_config_from_model_config",
    "suggest_spec_config",
]
tensorrt_llm/_torch/speculative/auto_heuristic.py (new file, 17 lines)
@@ -0,0 +1,17 @@
def suggest_spec_config(max_batch_size: int) -> "DecodingBaseConfig":
    """
    Suggests a reasonable draft model free speculation scheme.
    Used when the user specifies spec_mode == AUTO.

    For now, we always use an ngram scheme that gets disabled at
    BS>=32.
    """
    from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig
    return NGramDecodingConfig(
        max_draft_len=5 if max_batch_size <= 4 else 3,
        max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
        max_concurrency=32,
        is_keep_all=True,
        is_use_oldest=True,
        is_public_pool=True,
    )
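A short usage sketch of the new helper (assumes a TensorRT-LLM build that contains this commit; the printed values follow directly from the conditionals above):

from tensorrt_llm._torch.speculative import suggest_spec_config

small = suggest_spec_config(max_batch_size=4)
print(small.max_draft_len, small.max_matching_ngram_size)   # 5 3
large = suggest_spec_config(max_batch_size=64)
print(large.max_draft_len, large.max_matching_ngram_size)   # 3 5
# Both configs carry max_concurrency=32, so drafting is skipped whenever the
# number of active requests exceeds 32, regardless of max_batch_size.
print(small.max_concurrency, large.max_concurrency)         # 32 32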
@@ -9,6 +9,9 @@ from ..pyexecutor.scheduler import ScheduledRequests
class Drafter(ABC):
    """Abstract base class for all drafter implementations."""

    def __init__(self, max_concurrency: Optional[int] = None) -> None:
        self.max_concurrency = max_concurrency

    @abstractmethod
    def prepare_draft_tokens(
        self,
@@ -25,4 +28,6 @@ class Drafter(ABC):

    def should_use_spec_decode(self, requests: List[LlmRequest]) -> bool:
        """Check if spec decode should be used for the current iteration."""
        if self.max_concurrency is not None:
            return len(requests) <= self.max_concurrency
        return True
@@ -48,6 +48,8 @@ class ModelDrafter(Drafter):
        spec_resource_manager: Optional[BaseResourceManager] = None,
        guided_decoder: Optional[GuidedDecoder] = None,
    ):
        super().__init__(spec_config.max_concurrency)

        # Validate required parameters
        if draft_model_engine is None:
            raise ValueError("draft_model_engine cannot be None")
@@ -168,6 +168,7 @@ class NGramDrafter(Drafter):
        spec_config: NGramDecodingConfig,
        ngram_pool_manager: NGramPoolManager = None,
    ):
        super().__init__(spec_config.max_concurrency)
        assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
        self.spec_config = spec_config
        self.max_draft_len = spec_config.max_draft_len
@@ -178,11 +179,6 @@ class NGramDrafter(Drafter):
        scheduled_requests: ScheduledRequests,
        resource_manager: Optional[ResourceManager] = None,
    ) -> None:
        # Disable NGram speculative decoding auto heuristic for batch size > 32.
        if self.spec_config.is_auto_heuristic and len(
                scheduled_requests.all_requests()) > 32:
            return

        # Sort by request_id when py_batch_idx is None as a fallback.
        # This happens in the disagg case: for a set of new requests, we draft
        # before forward_step, so py_batch_idx is not assigned.
@@ -32,8 +32,8 @@ from ..logger import logger
from ..sampling_params import SamplingParams
from ..scheduling_params import SchedulingParams
from .llm_args import (TORCH_LLMARGS_EXPLICIT_DOCSTRING,
                       TRT_LLMARGS_EXPLICIT_DOCSTRING, NGramDecodingConfig,
                       PeftCacheConfig, PybindMirror, TorchLlmArgs, TrtLlmArgs)
                       TRT_LLMARGS_EXPLICIT_DOCSTRING, PeftCacheConfig,
                       PybindMirror, TorchLlmArgs, TrtLlmArgs)
from .llm_utils import (CachedModelLoader, KvCacheRetentionConfig,
                        LlmBuildStats, ModelLoader, _ModelRuntimeContext)
from .mpi_session import MpiPoolSession, external_mpi_comm_available
@@ -1015,32 +1015,10 @@ class _TorchLLM(BaseLLM):

        spec_config = self.args.speculative_config
        max_batch_size = self._executor_config.max_batch_size
        # Apply default heuristic to AutoDecodingConfig based on benchmark results
        # With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
        # With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
        # With concurrency > 32, speculative decoding is disabled.

        if spec_config is not None and spec_config.decoding_type == "AUTO":
            if not self.args.disable_overlap_scheduler:
                logger.info(
                    "Disable overlap scheduler to enable Auto speculative decoding with Ngram."
                )
                # From benchmark results, we found that NGram speculative decoding provides better performance than overlap scheduler with low concurrency <= 32.
                # Therefore, we disable overlap scheduler to enable NGram speculative decoding.
                self.args.disable_overlap_scheduler = True

            spec_config = NGramDecodingConfig(
                max_draft_len=5 if max_batch_size <= 4 else 3,
                max_matching_ngram_size=3 if max_batch_size <= 4 else 5,
                is_keep_all=True,
                is_use_oldest=True,
                is_public_pool=True,
                # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
                is_auto_heuristic=True,
            )

            logger.info(
                f"Apply heuristic to incomplete NGramDecodingConfig: max_draft_len={spec_config.max_draft_len}, max_matching_ngram_size={spec_config.max_matching_ngram_size}"
            )
            from tensorrt_llm._torch.speculative import suggest_spec_config
            spec_config = suggest_spec_config(max_batch_size)

        update_executor_config(
            self._executor_config,
@@ -342,6 +342,11 @@ class DecodingBaseConfig(StrictBaseModel):
    max_draft_len: Optional[int] = None
    speculative_model_dir: Optional[Union[str, Path]] = None

    # PyTorch only.
    # When specified, speculation will be disabled at batch sizes above
    # this value. Otherwise, speculation will always be on.
    max_concurrency: Optional[int] = None

    @classmethod
    def from_dict(cls, data: dict):
        # dispatch to the correct decoding config
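Because max_concurrency now lives on DecodingBaseConfig, any speculative config can opt into the batch-size gate. A hedged sketch (the field names come from the hunks above; the values are illustrative, not recommendations):

from tensorrt_llm.llmapi.llm_args import NGramDecodingConfig

spec = NGramDecodingConfig(
    max_draft_len=3,
    max_matching_ngram_size=5,
    max_concurrency=32,  # drafting is skipped once active requests exceed 32
)
# Passed as speculative_config to the PyTorch-backend LLM, this value is
# forwarded to Drafter.should_use_spec_decode via super().__init__ above.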
@@ -469,9 +474,6 @@ class NGramDecodingConfig(DecodingBaseConfig):
    is_keep_all: bool = True
    is_use_oldest: bool = True
    is_public_pool: bool = True
    # Flag to indicate the NGramDecodingConfig is instantiated by auto heuristic.
    # User should not set this flag. Use AutoDecodingConfig instead.
    is_auto_heuristic: bool = False

    @classmethod
    def from_dict(cls, data: dict):
@@ -535,13 +537,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
    """
    Configuration for auto speculative decoding.

    This config is used to automatically select the best speculative decoding algorithm.
    This config will automatically select a good, draft-model free
    speculation algorithm with some heuristic.

    According to benchmark results, the best algorithm in general is NGRAM with low concurrency <= 32.
    Default heuristic:
    With concurrency <= 4, max_draft_len = 5, max_matching_ngram_size = 3
    With concurrency <= 32, max_draft_len = 3, max_matching_ngram_size = 5
    With concurrency > 32, speculative decoding is disabled.
    Attributes that are inherited from the base class are ignored.
    """

    @classmethod
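For completeness, a hedged end-to-end sketch of the AUTO path this docstring describes (the LLM constructor call and model path are illustrative; AutoDecodingConfig is imported from the same llm_args module shown above):

from tensorrt_llm import LLM
from tensorrt_llm.llmapi.llm_args import AutoDecodingConfig

# With decoding_type == "AUTO", _TorchLLM disables the overlap scheduler and
# swaps in suggest_spec_config(max_batch_size), per the llm.py hunk above.
llm = LLM(
    model="/path/to/model",            # placeholder path
    speculative_config=AutoDecodingConfig(),
)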