[TRTLLM-8160][feat] Add max_total_draft_tokens (#8366)

Signed-off-by: Yue Weng <25103990+yweng0828@users.noreply.github.com>
YueWeng 2025-10-21 23:11:04 +08:00 committed by GitHub
parent a0024f4d34
commit 8dc4aac5b6
18 changed files with 156 additions and 80 deletions

View File

@ -320,13 +320,11 @@ def create_autodeploy_executor(ad_config: LlmArgs):
max_draft_len = (
0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len
)
max_total_draft_tokens = 0
if ad_config.speculative_config is None:
max_total_draft_tokens = 0
elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"):
max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens
else:
max_total_draft_tokens = max_draft_len
max_total_draft_tokens = (
0
if ad_config.speculative_config is None
else ad_config.speculative_config.max_total_draft_tokens
)
# initialize model engine
engine = ADEngine.build_from_config(ad_config=ad_config)
@ -417,6 +415,7 @@ def create_autodeploy_executor(ad_config: LlmArgs):
max_input_len=ad_config.max_input_len,
max_batch_size=ad_config.max_batch_size,
max_draft_len=max_draft_len,
max_total_draft_tokens=max_total_draft_tokens,
max_beam_width=ad_config.max_beam_width,
)
return py_executor

View File

@ -510,7 +510,7 @@ class DeepseekV3Attention(MLA):
aux_stream: Optional[torch.cuda.Stream] = None,
):
config = model_config.pretrained_config
predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1
predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1
super().__init__(hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,

View File

@ -250,10 +250,10 @@ class KvCacheCreator:
if not pytorch_backend_config.disable_overlap_scheduler:
num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_draft_len
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
if spec_cfg is not None:
num_extra_tokens_per_seq += spec_cfg.max_draft_len
num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens
num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg)
if self._dummy_reqs is None:
@ -808,6 +808,8 @@ def create_py_executor_instance(
max_beam_width=max_beam_width,
max_draft_len=spec_config.max_draft_len
if spec_config is not None else 0,
max_total_draft_tokens=spec_config.max_total_draft_tokens
if spec_config is not None else 0,
kv_cache_transceiver=kv_cache_transceiver,
guided_decoder=guided_decoder,
start_worker=start_worker,
@ -824,13 +826,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int,
max_num_sequences = max_batch_size * mapping.pp_size
max_draft_len = (0 if speculative_config is None else
speculative_config.max_draft_len)
max_total_draft_tokens = 0
if speculative_config is None:
max_total_draft_tokens = 0
elif hasattr(speculative_config, 'max_total_draft_tokens'):
max_total_draft_tokens = speculative_config.max_total_draft_tokens
else:
max_total_draft_tokens = max_draft_len
max_total_draft_tokens = (0 if speculative_config is None else
speculative_config.max_total_draft_tokens)
return TorchSampler.Args(
max_seq_len=max_seq_len,

View File

@ -93,7 +93,8 @@ class CUDAGraphRunner:
@property
def max_possible_draft_len(self):
engine = self._get_engine()
return (engine.original_max_draft_len if self.enable_spec_decode else 0)
return (engine.original_max_total_draft_tokens
if self.enable_spec_decode else 0)
def get_graph_key(
self,
@ -102,10 +103,12 @@ class CUDAGraphRunner:
engine = self._get_engine()
if engine.is_draft_model and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
# If 'is_first_draft' is True, even with tree decoding, draft_len will only be 'max_draft_len', not 'max_total_draft_tokens',
# because we pad the input to 'max_draft_len' for the first draft layer.
draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
else:
draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
key = (batch_size, draft_len, False)
return key
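In short, the target model keys its graphs on the full tree size while the draft model keys on the padded first-draft length. A condensed, hypothetical mirror of the selection above (not the runner's actual API):
def graph_key(batch_size: int, is_draft_model: bool, is_first_draft: bool,
              enable_spec_decode: bool, max_draft_len: int,
              max_total_draft_tokens: int) -> tuple:
    # Hypothetical mirror of get_graph_key, for illustration only.
    if is_draft_model:
        # The first draft layer is padded to max_draft_len, so the tree size never appears here.
        draft_len = max_draft_len if is_first_draft else 0
        return (batch_size, draft_len, is_first_draft)
    draft_len = max_total_draft_tokens if enable_spec_decode else 0
    return (batch_size, draft_len, False)

# e.g. a target model and a draft model sharing an 8-node tree of depth 3:
assert graph_key(4, False, False, True, 3, 8) == (4, 8, False)
assert graph_key(4, True, True, True, 3, 8) == (4, 3, True)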

View File

@ -93,11 +93,11 @@ class ModelEngine(ABC):
def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
max_batch_size: int, max_num_tokens: int,
max_draft_len: int,
max_total_draft_tokens: int,
enable_padding: bool) -> list[int]:
# This is the largest possible batch size for a pure decoding batch.
max_cuda_graph_bs = min(max_batch_size,
int(max_num_tokens / (1 + max_draft_len)))
int(max_num_tokens / (1 + max_total_draft_tokens)))
result = []
# This function assumes cuda_graph_batch_sizes is sorted
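For example, with hypothetical sizes (padding considerations aside), the cap works out as follows:
# Hypothetical numbers: max_batch_size = 64, max_num_tokens = 256, max_total_draft_tokens = 7.
max_cuda_graph_bs = min(64, 256 // (1 + 7))  # = 32
cuda_graph_batch_sizes = [1, 2, 4, 8, 16, 32, 64, 128]
filtered = [bs for bs in cuda_graph_batch_sizes if bs <= max_cuda_graph_bs]
assert filtered == [1, 2, 4, 8, 16, 32]
# 64 and 128 are dropped: a pure decode batch that large could exceed
# max_num_tokens once every request carries a full draft token tree.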
@ -162,11 +162,13 @@ class PyTorchModelEngine(ModelEngine):
ExpertStatistic.create(self.dist.rank)
self.pytorch_backend_config = pytorch_backend_config
self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0
self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0
# The draft model won't have any draft tokens attached to
# generation requests when we invoke it autoregressively
if spec_config is not None and is_draft_model:
spec_config.max_draft_len = 0
spec_config.max_total_draft_tokens = 0
self.spec_config = spec_config
self.is_spec_decode = spec_config is not None
self.sparse_attention_config = sparse_attention_config
@ -277,7 +279,7 @@ class PyTorchModelEngine(ModelEngine):
self.spec_metadata = None
update_spec_config_from_model_config(self.spec_config,
self.model.config)
max_num_draft_tokens = self.original_max_draft_len * batch_size
max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size
self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ),
dtype=torch.int,
device='cuda')
@ -297,9 +299,11 @@ class PyTorchModelEngine(ModelEngine):
self.without_logits = self.spec_config.spec_dec_mode.without_logits(
) or self.model_is_wrapped
self.max_draft_len = spec_config.max_draft_len
self.max_total_draft_tokens = spec_config.max_total_draft_tokens
else:
self.without_logits = False
self.max_draft_len = 0
self.max_total_draft_tokens = 0
self.guided_decoder: Optional[CapturableGuidedDecoder] = None
@ -320,7 +324,7 @@ class PyTorchModelEngine(ModelEngine):
self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes(
pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size,
self.max_num_tokens, self.original_max_draft_len,
self.max_num_tokens, self.original_max_total_draft_tokens,
self._cuda_graph_padding_enabled
) if pytorch_backend_config.cuda_graph_batch_sizes else []
@ -364,7 +368,7 @@ class PyTorchModelEngine(ModelEngine):
@property
def runtime_draft_len(self):
return self.max_draft_len if self.enable_spec_decode else 0
return self.max_total_draft_tokens if self.enable_spec_decode else 0
def set_lora_model_config(self,
lora_target_modules: list[str],
@ -585,20 +589,20 @@ class PyTorchModelEngine(ModelEngine):
if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance(
spec_resource_manager, Eagle3ResourceManager):
# The CDL path uses draft_len > 0 for the number of iterations in the drafting loop.
draft_lengths.append(self.original_max_draft_len)
draft_lengths.append(self.original_max_total_draft_tokens)
else:
draft_lengths.append(self.max_draft_len)
draft_lengths.append(self.max_total_draft_tokens)
else:
# For non-draft model, we also capture the CUDA graph instance for draft length 0,
# so that when we disable spec decode at runtime, we can still run the captured graph.
# Note that for one engine mode, we are not able to turn off spec decode at runtime.
if (self.max_draft_len > 0
if (self.max_total_draft_tokens > 0
and not self.spec_config.spec_dec_mode.use_one_engine()
# Assume that speculation is always on if the user didn't give us a max_concurrency
# value. This will save on memory.
and self.spec_config.max_concurrency is not None):
draft_lengths.append(0)
draft_lengths = [self.max_draft_len]
draft_lengths = [self.max_total_draft_tokens]
for bs in cuda_graph_batch_sizes:
if bs > self.batch_size:
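Putting this together, a hypothetical target model with speculation gated by max_concurrency would capture graphs for both the full draft length and draft length 0:
max_total_draft_tokens = 7
cuda_graph_batch_sizes = [1, 2, 4]
draft_lengths = [max_total_draft_tokens, 0]  # draft length 0 lets spec decode be disabled at runtime
capture_keys = [(bs, dl) for bs in cuda_graph_batch_sizes for dl in draft_lengths]
# -> [(1, 7), (1, 0), (2, 7), (2, 0), (4, 7), (4, 0)]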
@ -757,7 +761,7 @@ class PyTorchModelEngine(ModelEngine):
num_ctx_requests + num_gen_tokens)),
token_nums=[1] * num_gen_tokens,
is_gen=True,
max_num_draft_tokens=self.max_draft_len,
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(request_ids=list(
@ -830,7 +834,7 @@ class PyTorchModelEngine(ModelEngine):
def _get_cuda_graph_draft_lengths(
self, resource_manager: ResourceManager) -> List[int]:
"""Determines the draft lengths for which to capture CUDA graphs."""
draft_lengths = [self.max_draft_len]
draft_lengths = [self.max_total_draft_tokens]
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@ -1027,7 +1031,7 @@ class PyTorchModelEngine(ModelEngine):
"""
if self.enable_spec_decode and not self._disable_overlap_scheduler:
# When enabling overlap scheduler, the kv cache for draft tokens will
# be prepared in advance by using the max_draft_len. But we need to use
# be prepared in advance by using the max_total_draft_tokens. But we need to use
# new_tokens_lens_device to get the real past kv lengths and the
# correct position ids. And to avoid blocking the async data transfer,
# we need to preprocess the inputs in forward to update the position_ids and
@ -2252,7 +2256,7 @@ class PyTorchModelEngine(ModelEngine):
# attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors
is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode(
spec_resource_manager, self.is_draft_model, self.attn_backend,
self.model_is_wrapped)
self.model_is_wrapped, spec_metadata.is_spec_dec_tree)
attn_metadata.update_spec_dec_param(
is_spec_dec_mode, spec_metadata.is_spec_dec_tree,
spec_metadata.is_spec_dec_dynamic_tree,

View File

@ -160,6 +160,7 @@ class PyExecutor:
max_batch_size: int = 8,
max_beam_width: int = 1,
max_draft_len: int = 0,
max_total_draft_tokens: int = 0,
kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
guided_decoder: Optional[GuidedDecoder] = None,
garbage_collection_gen0_threshold: Optional[int] = None,
@ -195,6 +196,7 @@ class PyExecutor:
self.active = True
self.max_beam_width = max_beam_width
self.max_draft_len = max_draft_len
self.max_total_draft_tokens = max_total_draft_tokens
self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens
self.print_log = model_engine.pytorch_backend_config.print_iter_log
self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats
@ -1040,7 +1042,7 @@ class PyExecutor:
self.use_spec_decode = self.drafter.should_use_spec_decode(
self.active_requests, self.max_batch_size,
self.model_engine.max_num_tokens,
self.model_engine.spec_config.max_draft_len)
self.model_engine.spec_config.max_total_draft_tokens)
logger.debug(f"Use spec decode: {self.use_spec_decode}")
self.model_engine.enable_spec_decode = self.use_spec_decode
@ -1050,10 +1052,10 @@ class PyExecutor:
LlmRequestState.GENERATION_IN_PROGRESS,
LlmRequestState.DISAGG_GENERATION_INIT):
continue
max_draft_len = self.model_engine.spec_config.max_draft_len
max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
request.draft_tokens = [
0
] * max_draft_len if max_draft_len > 0 else []
] * max_total_draft_tokens if max_total_draft_tokens > 0 else []
# When the overlap scheduler is enabled and we already prepared the draft tokens in the previous batch,
# we don't need to initialize py_draft_tokens at this stage because we haven't appended the accepted tokens to the request yet.
@ -1223,11 +1225,11 @@ class PyExecutor:
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_len
max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens
if max_draft_len > 0 and self.use_spec_decode:
req.py_draft_tokens = [0] * max_draft_len
req.py_draft_pages_allocated = max_draft_len
if max_total_draft_tokens > 0 and self.use_spec_decode:
req.py_draft_tokens = [0] * max_total_draft_tokens
req.py_draft_pages_allocated = max_total_draft_tokens
else:
req.py_draft_tokens = []
req.py_draft_pages_allocated = 0
@ -1615,7 +1617,7 @@ class PyExecutor:
request_ids=[0],
is_gen=True,
prepare_resource=True,
max_num_draft_tokens=self.max_draft_len,
max_num_draft_tokens=self.max_total_draft_tokens,
)[0]
llm_request.is_attention_dp_dummy = True
spec_resource_manager = self.resource_manager.get_resource_manager(

View File

@ -357,7 +357,9 @@ def create_py_executor(
from tensorrt_llm._torch.speculative.drafting_loops import \
ChainDrafter
return ChainDrafter(spec_config.max_draft_len, model)
return ChainDrafter(spec_config.max_draft_len,
spec_config.max_total_draft_tokens,
model)
else:
drafting_loop_wrapper = None
@ -397,11 +399,11 @@ def create_py_executor(
if not pytorch_backend_config.disable_overlap_scheduler:
model_engine_max_seq_len = model_engine.max_seq_len + 1
if spec_config is not None:
model_engine_max_seq_len += spec_config.max_draft_len
model_engine_max_seq_len += spec_config.max_total_draft_tokens
if spec_config is not None:
model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config)
model_engine_max_seq_len += spec_config.max_draft_len
model_engine_max_seq_len += spec_config.max_total_draft_tokens
max_seq_len = model_engine_max_seq_len
max_num_tokens = model_engine.max_num_tokens
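A worked example with illustrative numbers, assuming the nesting shown above and no extra KV tokens:
# Illustrative numbers only: base max_seq_len 4096, max_total_draft_tokens 7,
# overlap scheduler enabled, get_num_extra_kv_tokens() assumed to return 0.
base_max_seq_len, max_total_draft_tokens, extra_kv = 4096, 7, 0
max_seq_len = base_max_seq_len + 1                 # overlap scheduler reserves one extra token
max_seq_len += max_total_draft_tokens              # draft tokens prepared one iteration ahead
max_seq_len += extra_kv + max_total_draft_tokens   # KV room for the draft tokens themselves
assert max_seq_len == 4111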
@ -471,7 +473,8 @@ def create_py_executor(
"vocab_size_padded": model_engine.model.vocab_size_padded
}
if spec_config is not None:
kwargs["max_num_draft_tokens"] = spec_config.max_draft_len
kwargs[
"max_num_draft_tokens"] = spec_config.max_total_draft_tokens
if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder(
):

View File

@ -589,7 +589,7 @@ class TorchSampler(Sampler):
def __init__(self, args: Args):
self.max_seq_len = args.max_seq_len
self.max_tokens = args.max_draft_len + 1
self.max_tokens = args.max_total_draft_tokens + 1
assert args.max_beam_width == self.MAX_BEAM_WIDTH, (
"TorchSampler only supports beam_width = 1"
)
@ -738,9 +738,9 @@ class TorchSampler(Sampler):
we can find the longest match by comparing all the paths.
Args:
request: LlmRequest. The request with draft tokens.
new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer.
The tokens generated by the target model
The relationship between [max_draft_len + 1] and the draft token tree:
The relationship between [max_total_draft_tokens + 1] and the draft token tree:
If the current node is accepted, what is the NEXT token_id that the target model will generate?
For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root
node in the draft token tree if it is accepted.
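A minimal sketch of the host buffer shape this implies, with hypothetical sizes:
import torch

max_total_draft_tokens, max_num_sequences, MAX_BEAM_WIDTH = 8, 16, 1
# Row i holds, per sequence, the token the target model would emit next if node i of the
# draft token tree (row 0 = the root) were accepted.
new_tokens = torch.zeros(max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH,
                         dtype=torch.int32)
print(new_tokens.shape)  # torch.Size([9, 16, 1])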

View File

@ -29,14 +29,14 @@ class Drafter(ABC):
@final
def should_use_spec_decode(self, requests: List[LlmRequest],
max_batch_size: int, max_num_tokens: int,
max_draft_len: int) -> bool:
max_total_draft_tokens: int) -> bool:
"""
You probably don't want to override this. ModelEngine
assumes that speculation is always on if max_concurrency
is not specified by the user's spec config.
"""
# Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>=0
# Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_total_draft_tokens>=0
if self.max_concurrency is None:
return True
@ -45,7 +45,7 @@ class Drafter(ABC):
if not requests or max_batch_size <= 0 or max_num_tokens <= 0:
return False
tokens_per_request = 1 + max_draft_len
tokens_per_request = 1 + max_total_draft_tokens
token_cap = max_num_tokens // tokens_per_request
if token_cap <= 0:
return False
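A worked example of the token budget, using the same numbers as the unit test further below (max_concurrency assumed to be 6):
max_num_tokens, max_total_draft_tokens = 28, 4   # values from the unit test below
tokens_per_request = 1 + max_total_draft_tokens  # every request carries the full draft tree
token_cap = max_num_tokens // tokens_per_request # 28 // 5 = 5
num_effective_requests = min(12, 8, token_cap)   # active requests, max_batch_size, token cap
assert num_effective_requests <= 6               # max_concurrency = 6 -> spec decode stays on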
@ -63,7 +63,7 @@ class Drafter(ABC):
scheduled_requests: The scheduled requests to pad
"""
for req in scheduled_requests.generation_requests:
max_draft_tokens = self.max_draft_tokens
max_draft_tokens = self.max_draft_len
num_draft_tokens = get_draft_token_length(req)
req.py_draft_tokens.extend(
0 for _ in range(max_draft_tokens - num_draft_tokens))

View File

@ -107,12 +107,14 @@ def prepare_for_generation(attn_metadata: AttentionMetadata,
class ChainDrafter(torch.nn.Module):
def __init__(self, max_draft_len: int, draft_model: torch.nn.Module):
def __init__(self, max_draft_len: int, max_total_draft_tokens: int,
draft_model: torch.nn.Module):
super().__init__()
self.draft_model = draft_model
self.config = self.draft_model.config
self.model_config = self.draft_model.model_config
self.max_draft_len = max_draft_len
self.max_total_draft_tokens = max_total_draft_tokens
def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor,
attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata,

View File

@ -131,6 +131,7 @@ class SpeculativeDecodingMode(IntEnum):
is_draft_model: bool,
attention_backend: Type[AttentionBackend],
use_chain_drafter: bool,
is_spec_dec_tree: bool,
):
"""
If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@ -154,8 +155,10 @@ class SpecMetadata:
"""
# The max number of requests in a single batch.
max_num_requests: int
# The max number of draft tokens.
# The number of draft layers. (Also the number of draft tokens for the linear tree.)
max_draft_len: int
# The max number of draft tokens for static and dynamic trees.
max_total_draft_tokens: int
# The number of gen-phase sequences in the batch.
num_generations: int = 0
# Whether CUDA graph is enabled.
@ -191,9 +194,13 @@ class SpecMetadata:
# The number of layers
num_layers: int = 0
# if spec-dec tree is a tree or a chain (linear tree)
is_spec_dec_tree: bool = False
# if spec-dec tree wouldn't be changed at all, the mask won't be computed every step.
# NOTE: Although the linear tree can be treated as a special case of a static tree,
# NOTE: we do not set `is_spec_dec_tree` to True in that case.
# NOTE: i.e., for the linear tree, is_spec_dec_tree == False and is_spec_dec_dynamic_tree == False.
# whether the spec-dec mode uses a tree (either a static or a dynamic tree).
is_spec_dec_tree: bool = False
# whether the spec-dec mode is a dynamic tree.
is_spec_dec_dynamic_tree: bool = False
def __post_init__(self):
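Taken together with the get_spec_metadata changes further below, the three tree shapes map to the flags as in this hypothetical helper:
def tree_flags(eagle_choices, use_dynamic_tree: bool):
    # Hypothetical helper mirroring how get_spec_metadata (below) sets these flags.
    is_spec_dec_tree = eagle_choices is not None or use_dynamic_tree
    is_spec_dec_dynamic_tree = use_dynamic_tree
    return is_spec_dec_tree, is_spec_dec_dynamic_tree

assert tree_flags(None, False) == (False, False)               # linear chain
assert tree_flags([[0], [1], [0, 0]], False) == (True, False)  # static tree
assert tree_flags(None, True) == (True, True)                  # dynamic tree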

View File

@ -51,7 +51,8 @@ class ModelDrafter(Drafter):
self,
spec_config: "DecodingBaseConfig",
draft_model_engine: "ModelEngine",
max_draft_tokens: int,
max_draft_len: int,
max_total_draft_tokens: int,
draft_seq_slot_manager: SeqSlotManager,
sampler: Sampler,
spec_resource_manager: Optional[BaseResourceManager] = None,
@ -62,8 +63,11 @@ class ModelDrafter(Drafter):
# Validate required parameters
if draft_model_engine is None:
raise ValueError("draft_model_engine cannot be None")
if max_draft_tokens < 0:
raise ValueError("max_draft_tokens must be >= 0")
if max_draft_len < 0:
raise ValueError("max_draft_len must be >= 0")
if max_total_draft_tokens < 0:
raise ValueError("max_total_draft_tokens must be >= 0")
assert max_draft_len <= max_total_draft_tokens
# Model and resource management
self.draft_model_engine = draft_model_engine
@ -72,7 +76,8 @@ class ModelDrafter(Drafter):
# Configuration
self.spec_config = spec_config
self.max_draft_tokens = max_draft_tokens
self.max_draft_len = max_draft_len
self.max_total_draft_tokens = max_total_draft_tokens
# Sampling
self.sampler = sampler
self.guided_decoder = guided_decoder
@ -153,9 +158,11 @@ class ModelDrafter(Drafter):
Create a chunked context request for accepted tokens.
Only applicable if the draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
"""
# Pad input_tokens to max_draft_tokens
# Pad input_tokens to max_draft_len
# We use max_draft_len instead of max_total_draft_tokens here,
# because at most max_draft_len draft tokens are accepted.
input_tokens.extend(
0 for _ in range(self.max_draft_tokens - num_accepted_tokens))
0 for _ in range(self.max_draft_len - num_accepted_tokens))
new_request = self._create_draft_request(request, input_tokens)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens
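A minimal sketch of the padding step, assuming 2 of a possible max_draft_len = 3 tokens were accepted:
max_draft_len, num_accepted_tokens = 3, 2
input_tokens = [101, 102]  # hypothetical accepted draft token ids
input_tokens.extend(0 for _ in range(max_draft_len - num_accepted_tokens))
assert input_tokens == [101, 102, 0]  # padded to max_draft_len, not max_total_draft_tokens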
@ -469,7 +476,7 @@ class ModelDrafter(Drafter):
# We already updated the target state, so the new_tokens_lens should be all ones.
new_tokens_lens = torch.ones(batch_size, device=device)
next_draft_tokens = torch.zeros(batch_size,
self.max_draft_tokens,
self.max_draft_len,
device=device)
# Create a new SampleStateTensorsMTP object with the additional fields
@ -563,7 +570,7 @@ class ModelDrafter(Drafter):
# Chunked prefill request in progress; no need to append draft tokens
continue
py_draft_logits = []
for token_idx in range(self.max_draft_tokens):
for token_idx in range(self.max_draft_len):
target_model_req.py_draft_tokens.append(
draft_tokens_host[token_idx][req_idx])
py_draft_logits.append(draft_logits[token_idx][req_idx])
@ -646,7 +653,7 @@ class ModelDrafter(Drafter):
previous_draft_state = initial_draft_state
# Generate remaining draft tokens iteratively
for i in range(self.max_draft_tokens - 1):
for i in range(self.max_draft_len - 1):
if len(draft_batch.generation_requests) == 0:
break
@ -722,7 +729,7 @@ class ModelDrafter(Drafter):
target_inputs,
outputs["new_draft_tokens"],
draft_position=0,
draft_length=self.max_draft_tokens,
draft_length=self.max_draft_len,
draft_batch=draft_batch,
req_id_to_old_request=req_id_to_old_request)

View File

@ -170,7 +170,7 @@ class NGramDrafter(Drafter):
super().__init__(spec_config.max_concurrency)
assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool."
self.spec_config = spec_config
self.max_draft_tokens = spec_config.max_draft_len
self.max_draft_len = spec_config.max_draft_len
self.spec_resource_manager = ngram_pool_manager
def prepare_draft_tokens(

View File

@ -21,6 +21,7 @@ class SaveHiddenStatesDrafter(Drafter):
super().__init__(spec_config.max_concurrency)
self.spec_config = spec_config
self.max_draft_len = spec_config.max_draft_len
self.max_total_draft_tokens = spec_config.max_total_draft_tokens
self._iter = 1
self._output_directory = spec_config.output_directory
self._file_prefix = spec_config.file_prefix

View File

@ -23,6 +23,7 @@ def get_spec_metadata(spec_config,
if spec_config.spec_dec_mode.is_mtp_one_model():
return MTPSpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
mtp_num_modules=spec_config.num_nextn_predict_layers,
max_num_requests=max_num_requests,
@ -31,6 +32,7 @@ def get_spec_metadata(spec_config,
if spec_config.spec_dec_mode.is_mtp_eagle():
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
@ -45,6 +47,7 @@ def get_spec_metadata(spec_config,
if spec_config.spec_dec_mode.is_eagle3():
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
@ -55,14 +58,15 @@ def get_spec_metadata(spec_config,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=spec_config.eagle3_layers_to_capture,
is_mtp_eagle=False,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
eagle_choices=spec_config.eagle_choices,
is_spec_dec_tree=spec_config.eagle_choices is not None,
is_spec_dec_tree=spec_config.eagle_choices is not None
or spec_config.use_dynamic_tree,
is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree,
)
if spec_config.spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelSpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
@ -78,6 +82,7 @@ def get_spec_metadata(spec_config,
}
return Eagle3SpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=1,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=model_config.num_hidden_layers,
@ -87,13 +92,13 @@ def get_spec_metadata(spec_config,
is_draft_model=is_draft_model,
eagle3_resource_manager=spec_resource_manager,
layers_to_capture=spec_config.eagle3_layers_to_capture,
max_total_draft_tokens=1,
)
if spec_config.spec_dec_mode.is_draft_target() or \
spec_config.spec_dec_mode.is_ngram() or \
spec_config.spec_dec_mode.is_user_provided():
return SpecMetadata(
max_draft_len=spec_config.max_draft_len,
max_total_draft_tokens=spec_config.max_total_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
)
@ -186,6 +191,7 @@ def get_spec_drafter(model_engine,
return ModelDrafter(spec_config,
draft_model_engine,
spec_config.max_draft_len,
spec_config.max_total_draft_tokens,
SeqSlotManager(max_num_requests),
sampler,
spec_resource_manager=spec_resource_manager,

View File

@ -413,7 +413,14 @@ class _ModelFormatKind(Enum):
class DecodingBaseConfig(StrictBaseModel):
# The number of the drafter layers.
max_draft_len: Optional[int] = None
# The number of draft tokens in the draft token tree.
# If it's a linear tree, each draft layer will only generate one draft token.
# In this case, max_draft_len == max_total_draft_tokens.
# If it's a static or dynamic tree, each draft layer may generate more than one draft token.
# In this case, max_total_draft_tokens >= max_draft_len.
max_total_draft_tokens: Optional[int] = None
speculative_model_dir: Optional[Union[str, Path]] = None
# PyTorch only.
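As an illustration, assuming the usual EAGLE choices format where each entry is a root-to-node path:
# Hypothetical static tree: the root branches to {0, 1}, and node 0 branches to {0, 1}.
eagle_choices = [[0], [1], [0, 0], [0, 1]]
max_total_draft_tokens = len(eagle_choices)               # 4 nodes -> 4 draft tokens per step
max_draft_len = max(len(path) for path in eagle_choices)  # depth 2 -> 2 draft layers
assert max_total_draft_tokens >= max_draft_len
# For a linear chain such as [[0], [0, 0], [0, 0, 0]] the two values coincide (both 3).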
@ -526,6 +533,10 @@ class MedusaDecodingConfig(DecodingBaseConfig):
medusa_choices: Optional[List[List[int]]] = None
num_medusa_heads: Optional[int] = None
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_total_draft_tokens = self.max_draft_len  # Currently, Medusa only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -544,8 +555,6 @@ class EagleDecodingConfig(DecodingBaseConfig):
use_dynamic_tree: Optional[bool] = False
# The topK value for each layer when enable dynamic tree.
dynamic_tree_max_topK: Optional[int] = None
# The number of draft tokens in the draft tokens tree.
max_total_draft_tokens: Optional[int] = None
# The number of eagle layer. will not be used in pytorch flow, just for compatibility with TRT flow
num_eagle_layers: Optional[int] = None
# The number of non-leaves in each layer.
@ -580,7 +589,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
# Checks whether the input eagle choices is valid
# and reset the max_draft_len and num_eagle_layers if necessary
if self.eagle_choices is not None:
# If eagle_choices is provided, use_dynamic_tree will not be used
# If eagle_choices is provided, use_dynamic_tree should not be used
assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree needs to be False"
# Get num_eagle_layers from eagle_choices
@ -651,6 +660,12 @@ class EagleDecodingConfig(DecodingBaseConfig):
return len(self.eagle3_layers_to_capture)
return 3
@functools.cached_property
def is_linear_tree(self) -> bool:
if self.eagle_choices is None and self.use_dynamic_tree is False:
return True
return False
class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
output_directory: str
@ -703,6 +718,10 @@ class UserProvidedDecodingConfig(DecodingBaseConfig):
drafter: object # Type is Drafter
resource_manager: object = None # Type is Optional[ResourceManager]
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_total_draft_tokens = self.max_draft_len  # Currently, UserProvided only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -735,6 +754,10 @@ class NGramDecodingConfig(DecodingBaseConfig):
is_use_oldest: bool = True
is_public_pool: bool = True
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_total_draft_tokens = self.max_draft_len  # Currently, NGram only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -747,6 +770,10 @@ class NGramDecodingConfig(DecodingBaseConfig):
class DraftTargetDecodingConfig(DecodingBaseConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_total_draft_tokens = self.max_draft_len  # Currently, DraftTarget only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -776,10 +803,18 @@ class MTPDecodingConfig(DecodingBaseConfig):
BEGIN_THINKING_PHASE_TOKEN: int = 128798
END_THINKING_PHASE_TOKEN: int = 128799
def __init__(self, **kwargs):
super().__init__(**kwargs)
if 'num_nextn_predict_layers' in kwargs:
self.max_draft_len = kwargs['num_nextn_predict_layers']
self.max_total_draft_tokens = kwargs[
'num_nextn_predict_layers']  # Currently, MTP only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
out = cls(**data)
out.max_draft_len = out.num_nextn_predict_layers
out.max_total_draft_tokens = out.num_nextn_predict_layers  # Currently, MTP only supports a linear tree
return out
decoding_type: ClassVar[str] = "MTP"
@ -814,6 +849,10 @@ class AutoDecodingConfig(DecodingBaseConfig):
Attributes that are inherited from the base class are ignored.
"""
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.max_total_draft_tokens = self.max_draft_len  # Currently, Auto only supports a linear tree
@classmethod
def from_dict(cls, data: dict):
return cls(**data)
@ -1156,6 +1195,7 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror):
def __init__(self, **data):
super().__init__(**data)
self.max_total_draft_tokens = self.max_draft_len  # Currently, Lookahead only supports a linear tree
self._check_fields()
def calculate_speculative_resource(self):

View File

@ -73,7 +73,7 @@ def test_dynamic_spec_decode(enforce_single_worker,
# Mock should_use_spec_decode to turn on/off spec decode dynamically.
def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens,
max_draft_len):
max_total_draft_tokens):
for req in requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue
@ -198,40 +198,42 @@ def test_should_use_spec_decode():
assert drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
max_num_tokens=4096 * 8,
max_draft_len=4)
max_total_draft_tokens=4)
# Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True
active_requests = [object()] * 12
assert drafter.should_use_spec_decode(active_requests,
max_batch_size=5,
max_num_tokens=4096 * 8,
max_draft_len=4)
max_total_draft_tokens=4)
# Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True
active_requests = [object()] * 12
assert drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
max_num_tokens=28,
max_draft_len=4)
max_total_draft_tokens=4)
# Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False
active_requests = [object()] * 12
assert not drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
max_num_tokens=4096 * 8,
max_draft_len=4)
max_total_draft_tokens=4)
# Edge case - None active requests OFF case
active_requests = []
assert not drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
max_num_tokens=4096 * 8,
max_draft_len=4)
max_total_draft_tokens=4)
# Edge case - Token cap equals 0 OFF case: token_cap = 4 // (1+4) = 0 → min(12, 8, 0) = 0 <= 6 → False
active_requests = [object()] * 12
assert not drafter.should_use_spec_decode(
active_requests, max_batch_size=8, max_num_tokens=4, max_draft_len=4)
assert not drafter.should_use_spec_decode(active_requests,
max_batch_size=8,
max_num_tokens=4,
max_total_draft_tokens=4)
if __name__ == "__main__":

View File

@ -309,6 +309,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase):
spec_metadata = MTPSpecMetadata(max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_len=mtp_num_modules,
max_total_draft_tokens=mtp_num_modules,
mtp_num_modules=mtp_num_modules)
spec_metadata.draft_tokens = draft_tokens
@ -891,6 +892,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase):
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_len=num_nextn_predict_layers,
max_total_draft_tokens=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids
@ -1386,6 +1388,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase):
max_num_requests=32,
spec_dec_mode=spec_config.spec_dec_mode,
max_draft_len=num_nextn_predict_layers,
max_total_draft_tokens=num_nextn_predict_layers,
mtp_num_modules=num_nextn_predict_layers,
mtp_hidden_states_manager=spec_manager)
spec_metadata.request_ids = request_ids