diff --git a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py index 6f229d881d..60feeb1859 100644 --- a/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py +++ b/tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py @@ -320,13 +320,11 @@ def create_autodeploy_executor(ad_config: LlmArgs): max_draft_len = ( 0 if ad_config.speculative_config is None else ad_config.speculative_config.max_draft_len ) - max_total_draft_tokens = 0 - if ad_config.speculative_config is None: - max_total_draft_tokens = 0 - elif hasattr(ad_config.speculative_config, "max_total_draft_tokens"): - max_total_draft_tokens = ad_config.speculative_config.max_total_draft_tokens - else: - max_total_draft_tokens = max_draft_len + max_total_draft_tokens = ( + 0 + if ad_config.speculative_config is None + else ad_config.speculative_config.max_total_draft_tokens + ) # initialize model engine engine = ADEngine.build_from_config(ad_config=ad_config) @@ -417,6 +415,7 @@ def create_autodeploy_executor(ad_config: LlmArgs): max_input_len=ad_config.max_input_len, max_batch_size=ad_config.max_batch_size, max_draft_len=max_draft_len, + max_total_draft_tokens=max_total_draft_tokens, max_beam_width=ad_config.max_beam_width, ) return py_executor diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index e99eedd224..fcf66a3aae 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -510,7 +510,7 @@ class DeepseekV3Attention(MLA): aux_stream: Optional[torch.cuda.Stream] = None, ): config = model_config.pretrained_config - predicted_tokens_per_seq = model_config.spec_config.max_draft_len + 1 if model_config.spec_config is not None else 1 + predicted_tokens_per_seq = model_config.spec_config.max_total_draft_tokens + 1 if model_config.spec_config is not None else 1 super().__init__(hidden_size=config.hidden_size, num_attention_heads=config.num_attention_heads, num_key_value_heads=config.num_key_value_heads, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 4099e0e104..e8a62aad5b 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -250,10 +250,10 @@ class KvCacheCreator: if not pytorch_backend_config.disable_overlap_scheduler: num_extra_tokens_per_seq = num_extra_tokens_per_seq + 1 if spec_cfg is not None: - num_extra_tokens_per_seq += spec_cfg.max_draft_len + num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens if spec_cfg is not None: - num_extra_tokens_per_seq += spec_cfg.max_draft_len + num_extra_tokens_per_seq += spec_cfg.max_total_draft_tokens num_extra_tokens_per_seq += get_num_extra_kv_tokens(spec_cfg) if self._dummy_reqs is None: @@ -808,6 +808,8 @@ def create_py_executor_instance( max_beam_width=max_beam_width, max_draft_len=spec_config.max_draft_len if spec_config is not None else 0, + max_total_draft_tokens=spec_config.max_total_draft_tokens + if spec_config is not None else 0, kv_cache_transceiver=kv_cache_transceiver, guided_decoder=guided_decoder, start_worker=start_worker, @@ -824,13 +826,8 @@ def create_torch_sampler_args(mapping: Mapping, *, max_seq_len: int, max_num_sequences = max_batch_size * mapping.pp_size max_draft_len = (0 if speculative_config is None else speculative_config.max_draft_len) - max_total_draft_tokens = 0 - if speculative_config is None: - max_total_draft_tokens = 0 - elif hasattr(speculative_config, 
'max_total_draft_tokens'):
-        max_total_draft_tokens = speculative_config.max_total_draft_tokens
-    else:
-        max_total_draft_tokens = max_draft_len
+    max_total_draft_tokens = (0 if speculative_config is None else
+                              speculative_config.max_total_draft_tokens)
     return TorchSampler.Args(
         max_seq_len=max_seq_len,
diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
index a08a45a8cb..411378deb8 100644
--- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
+++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py
@@ -93,7 +93,8 @@ class CUDAGraphRunner:
     @property
     def max_possible_draft_len(self):
         engine = self._get_engine()
-        return (engine.original_max_draft_len if self.enable_spec_decode else 0)
+        return (engine.original_max_total_draft_tokens
+                if self.enable_spec_decode else 0)

     def get_graph_key(
             self,
@@ -102,10 +103,12 @@ class CUDAGraphRunner:
         engine = self._get_engine()
         if engine.is_draft_model and spec_resource_manager is not None and isinstance(
                 spec_resource_manager, Eagle3ResourceManager):
+            # If 'is_first_draft' is True, the draft length is only 'max_draft_len' (not 'max_total_draft_tokens'),
+            # even with tree decoding, because the input to the first draft layer is padded to 'max_draft_len'.
             draft_len = engine.original_max_draft_len if spec_resource_manager.is_first_draft else 0
             key = (batch_size, draft_len, spec_resource_manager.is_first_draft)
         else:
-            draft_len = self.spec_config.max_draft_len if self.enable_spec_decode else 0
+            draft_len = self.spec_config.max_total_draft_tokens if self.enable_spec_decode else 0
             key = (batch_size, draft_len, False)
         return key
diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py
index f1d6146256..dba098684c 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_engine.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py
@@ -93,11 +93,11 @@ class ModelEngine(ABC):
 def _filter_cuda_graph_batch_sizes(cuda_graph_batch_sizes: list[int],
                                    max_batch_size: int, max_num_tokens: int,
-                                   max_draft_len: int,
+                                   max_total_draft_tokens: int,
                                    enable_padding: bool) -> list[int]:
     # This is the largest possible batch size for a pure decoding batch.
max_cuda_graph_bs = min(max_batch_size, - int(max_num_tokens / (1 + max_draft_len))) + int(max_num_tokens / (1 + max_total_draft_tokens))) result = [] # This function assumes cuda_graph_batch_sizes is sorted @@ -162,11 +162,13 @@ class PyTorchModelEngine(ModelEngine): ExpertStatistic.create(self.dist.rank) self.pytorch_backend_config = pytorch_backend_config self.original_max_draft_len = spec_config.max_draft_len if spec_config is not None else 0 + self.original_max_total_draft_tokens = spec_config.max_total_draft_tokens if spec_config is not None else 0 # The draft model won't have any draft tokens attached to # generation requests when we invoke it autoregressively if spec_config is not None and is_draft_model: spec_config.max_draft_len = 0 + spec_config.max_total_draft_tokens = 0 self.spec_config = spec_config self.is_spec_decode = spec_config is not None self.sparse_attention_config = sparse_attention_config @@ -277,7 +279,7 @@ class PyTorchModelEngine(ModelEngine): self.spec_metadata = None update_spec_config_from_model_config(self.spec_config, self.model.config) - max_num_draft_tokens = self.original_max_draft_len * batch_size + max_num_draft_tokens = self.original_max_total_draft_tokens * batch_size self.draft_tokens_cuda = torch.empty((max_num_draft_tokens, ), dtype=torch.int, device='cuda') @@ -297,9 +299,11 @@ class PyTorchModelEngine(ModelEngine): self.without_logits = self.spec_config.spec_dec_mode.without_logits( ) or self.model_is_wrapped self.max_draft_len = spec_config.max_draft_len + self.max_total_draft_tokens = spec_config.max_total_draft_tokens else: self.without_logits = False self.max_draft_len = 0 + self.max_total_draft_tokens = 0 self.guided_decoder: Optional[CapturableGuidedDecoder] = None @@ -320,7 +324,7 @@ class PyTorchModelEngine(ModelEngine): self._cuda_graph_batch_sizes = _filter_cuda_graph_batch_sizes( pytorch_backend_config.cuda_graph_batch_sizes, self.batch_size, - self.max_num_tokens, self.original_max_draft_len, + self.max_num_tokens, self.original_max_total_draft_tokens, self._cuda_graph_padding_enabled ) if pytorch_backend_config.cuda_graph_batch_sizes else [] @@ -364,7 +368,7 @@ class PyTorchModelEngine(ModelEngine): @property def runtime_draft_len(self): - return self.max_draft_len if self.enable_spec_decode else 0 + return self.max_total_draft_tokens if self.enable_spec_decode else 0 def set_lora_model_config(self, lora_target_modules: list[str], @@ -585,20 +589,20 @@ class PyTorchModelEngine(ModelEngine): if self.model_is_wrapped and self.is_spec_decode and spec_resource_manager is not None and isinstance( spec_resource_manager, Eagle3ResourceManager): # The CDL path uses draft_len > 0 for the number of iterations in the drafting loop. - draft_lengths.append(self.original_max_draft_len) + draft_lengths.append(self.original_max_total_draft_tokens) else: - draft_lengths.append(self.max_draft_len) + draft_lengths.append(self.max_total_draft_tokens) else: # For non-draft model, we also capture the CUDA graph instance for draft length 0, # so that when we disable spec decode at runtime, we can still run the captured graph. # Note that for one engine mode, we are not able to turn off spec decode at runtime. - if (self.max_draft_len > 0 + if (self.max_total_draft_tokens > 0 and not self.spec_config.spec_dec_mode.use_one_engine() # Assume that speculation is always on if the user didn't give us a max_concurrency # value. This will save on memory. 
and self.spec_config.max_concurrency is not None): draft_lengths.append(0) - draft_lengths = [self.max_draft_len] + draft_lengths = [self.max_total_draft_tokens] for bs in cuda_graph_batch_sizes: if bs > self.batch_size: @@ -757,7 +761,7 @@ class PyTorchModelEngine(ModelEngine): num_ctx_requests + num_gen_tokens)), token_nums=[1] * num_gen_tokens, is_gen=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self.max_total_draft_tokens, use_mrope=self.use_mrope) if spec_resource_manager is not None: spec_resource_manager.add_dummy_requests(request_ids=list( @@ -830,7 +834,7 @@ class PyTorchModelEngine(ModelEngine): def _get_cuda_graph_draft_lengths( self, resource_manager: ResourceManager) -> List[int]: """Determines the draft lengths for which to capture CUDA graphs.""" - draft_lengths = [self.max_draft_len] + draft_lengths = [self.max_total_draft_tokens] spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) @@ -1027,7 +1031,7 @@ class PyTorchModelEngine(ModelEngine): """ if self.enable_spec_decode and not self._disable_overlap_scheduler: # When enabling overlap scheduler, the kv cache for draft tokens will - # be prepared in advance by using the max_draft_len. But we need to use + # be prepared in advance by using the max_total_draft_tokens. But we need to use # new_tokens_lens_device to get the real past kv lengths and the # correct position ids. And to avoid blocking the async data transfer, # we need to preprocess the inputs in forward to update the position_ids and @@ -2252,7 +2256,7 @@ class PyTorchModelEngine(ModelEngine): # attn_metadata now depends on spec_metadata since it determines the shape/content of spec_dec parameter Tensors is_spec_dec_mode = spec_metadata.spec_dec_mode.attention_need_spec_dec_mode( spec_resource_manager, self.is_draft_model, self.attn_backend, - self.model_is_wrapped) + self.model_is_wrapped, spec_metadata.is_spec_dec_tree) attn_metadata.update_spec_dec_param( is_spec_dec_mode, spec_metadata.is_spec_dec_tree, spec_metadata.is_spec_dec_dynamic_tree, diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor.py b/tensorrt_llm/_torch/pyexecutor/py_executor.py index 02291cd4bb..e4d23937ba 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor.py @@ -160,6 +160,7 @@ class PyExecutor: max_batch_size: int = 8, max_beam_width: int = 1, max_draft_len: int = 0, + max_total_draft_tokens: int = 0, kv_cache_transceiver: Optional[KvCacheTransceiver] = None, guided_decoder: Optional[GuidedDecoder] = None, garbage_collection_gen0_threshold: Optional[int] = None, @@ -195,6 +196,7 @@ class PyExecutor: self.active = True self.max_beam_width = max_beam_width self.max_draft_len = max_draft_len + self.max_total_draft_tokens = max_total_draft_tokens self.max_num_tokens = model_engine.pytorch_backend_config.max_num_tokens self.print_log = model_engine.pytorch_backend_config.print_iter_log self.enable_iter_perf_stats = model_engine.pytorch_backend_config.enable_iter_perf_stats @@ -1040,7 +1042,7 @@ class PyExecutor: self.use_spec_decode = self.drafter.should_use_spec_decode( self.active_requests, self.max_batch_size, self.model_engine.max_num_tokens, - self.model_engine.spec_config.max_draft_len) + self.model_engine.spec_config.max_total_draft_tokens) logger.debug(f"Use spec decode: {self.use_spec_decode}") self.model_engine.enable_spec_decode = self.use_spec_decode @@ -1050,10 +1052,10 @@ class PyExecutor: LlmRequestState.GENERATION_IN_PROGRESS, 
LlmRequestState.DISAGG_GENERATION_INIT): continue - max_draft_len = self.model_engine.spec_config.max_draft_len + max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens request.draft_tokens = [ 0 - ] * max_draft_len if max_draft_len > 0 else [] + ] * max_total_draft_tokens if max_total_draft_tokens > 0 else [] # When overlap scheduler is enabled, and we already prepared the draft tokens in the previous batch, # we don't need to initialize py_draft_tokens at this stage because we haven't append the accepted tokens to the request yet. @@ -1223,11 +1225,11 @@ class PyExecutor: continue req.py_last_draft_tokens = req.py_draft_tokens - max_draft_len = self.model_engine.spec_config.max_draft_len + max_total_draft_tokens = self.model_engine.spec_config.max_total_draft_tokens - if max_draft_len > 0 and self.use_spec_decode: - req.py_draft_tokens = [0] * max_draft_len - req.py_draft_pages_allocated = max_draft_len + if max_total_draft_tokens > 0 and self.use_spec_decode: + req.py_draft_tokens = [0] * max_total_draft_tokens + req.py_draft_pages_allocated = max_total_draft_tokens else: req.py_draft_tokens = [] req.py_draft_pages_allocated = 0 @@ -1615,7 +1617,7 @@ class PyExecutor: request_ids=[0], is_gen=True, prepare_resource=True, - max_num_draft_tokens=self.max_draft_len, + max_num_draft_tokens=self.max_total_draft_tokens, )[0] llm_request.is_attention_dp_dummy = True spec_resource_manager = self.resource_manager.get_resource_manager( diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 3e2e8ef79f..214ac2014c 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -357,7 +357,9 @@ def create_py_executor( from tensorrt_llm._torch.speculative.drafting_loops import \ ChainDrafter - return ChainDrafter(spec_config.max_draft_len, model) + return ChainDrafter(spec_config.max_draft_len, + spec_config.max_total_draft_tokens, + model) else: drafting_loop_wrapper = None @@ -397,11 +399,11 @@ def create_py_executor( if not pytorch_backend_config.disable_overlap_scheduler: model_engine_max_seq_len = model_engine.max_seq_len + 1 if spec_config is not None: - model_engine_max_seq_len += spec_config.max_draft_len + model_engine_max_seq_len += spec_config.max_total_draft_tokens if spec_config is not None: model_engine_max_seq_len += get_num_extra_kv_tokens(spec_config) - model_engine_max_seq_len += spec_config.max_draft_len + model_engine_max_seq_len += spec_config.max_total_draft_tokens max_seq_len = model_engine_max_seq_len max_num_tokens = model_engine.max_num_tokens @@ -471,7 +473,8 @@ def create_py_executor( "vocab_size_padded": model_engine.model.vocab_size_padded } if spec_config is not None: - kwargs["max_num_draft_tokens"] = spec_config.max_draft_len + kwargs[ + "max_num_draft_tokens"] = spec_config.max_total_draft_tokens if spec_config is None or spec_config.spec_dec_mode.support_guided_decoder( ): diff --git a/tensorrt_llm/_torch/pyexecutor/sampler.py b/tensorrt_llm/_torch/pyexecutor/sampler.py index b52ca5f459..4b58882911 100644 --- a/tensorrt_llm/_torch/pyexecutor/sampler.py +++ b/tensorrt_llm/_torch/pyexecutor/sampler.py @@ -589,7 +589,7 @@ class TorchSampler(Sampler): def __init__(self, args: Args): self.max_seq_len = args.max_seq_len - self.max_tokens = args.max_draft_len + 1 + self.max_tokens = args.max_total_draft_tokens + 1 assert args.max_beam_width == self.MAX_BEAM_WIDTH, ( "TorchSampler only supports beam_width = 1" ) @@ 
-738,9 +738,9 @@ class TorchSampler(Sampler): we can find the longest match by comparing all the paths. Args: request: LlmRequest. The request with draft tokens. - new_tokens: torch.Tensor. [max_draft_len + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. + new_tokens: torch.Tensor. [max_total_draft_tokens + 1, max_num_sequences, MAX_BEAM_WIDTH], host buffer. The tokens generated by the target model - The relationship between [max_draft_len + 1] and the draft token tree: + The relationship between [max_total_draft_tokens + 1] and the draft token tree: If the current node is accepted, what is the NEXT token_id that the target model will generate? For example, new_tokens[0, req_idx, 1] indicates the NEXT token_id sampled from the root node in the draft token tree if it is accepted. diff --git a/tensorrt_llm/_torch/speculative/drafter.py b/tensorrt_llm/_torch/speculative/drafter.py index 485934f7b5..3dd8683ba3 100644 --- a/tensorrt_llm/_torch/speculative/drafter.py +++ b/tensorrt_llm/_torch/speculative/drafter.py @@ -29,14 +29,14 @@ class Drafter(ABC): @final def should_use_spec_decode(self, requests: List[LlmRequest], max_batch_size: int, max_num_tokens: int, - max_draft_len: int) -> bool: + max_total_draft_tokens: int) -> bool: """ You probably don't want to override this. ModelEngine assumes that speculation is always on if max_concurrency is not specified by the user's spec config. """ - # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_draft_len>=0 + # Inputs typically validated upstream: max_batch_size>0, max_num_tokens>0, max_total_draft_tokens>=0 if self.max_concurrency is None: return True @@ -45,7 +45,7 @@ class Drafter(ABC): if not requests or max_batch_size <= 0 or max_num_tokens <= 0: return False - tokens_per_request = 1 + max_draft_len + tokens_per_request = 1 + max_total_draft_tokens token_cap = max_num_tokens // tokens_per_request if token_cap <= 0: return False @@ -63,7 +63,7 @@ class Drafter(ABC): scheduled_requests: The scheduled requests to pad """ for req in scheduled_requests.generation_requests: - max_draft_tokens = self.max_draft_tokens + max_draft_tokens = self.max_draft_len num_draft_tokens = get_draft_token_length(req) req.py_draft_tokens.extend( 0 for _ in range(max_draft_tokens - num_draft_tokens)) diff --git a/tensorrt_llm/_torch/speculative/drafting_loops.py b/tensorrt_llm/_torch/speculative/drafting_loops.py index 20fdcf022e..886f0111ef 100644 --- a/tensorrt_llm/_torch/speculative/drafting_loops.py +++ b/tensorrt_llm/_torch/speculative/drafting_loops.py @@ -107,12 +107,14 @@ def prepare_for_generation(attn_metadata: AttentionMetadata, class ChainDrafter(torch.nn.Module): - def __init__(self, max_draft_len: int, draft_model: torch.nn.Module): + def __init__(self, max_draft_len: int, max_total_draft_tokens: int, + draft_model: torch.nn.Module): super().__init__() self.draft_model = draft_model self.config = self.draft_model.config self.model_config = self.draft_model.model_config self.max_draft_len = max_draft_len + self.max_total_draft_tokens = max_total_draft_tokens def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor, attn_metadata: AttentionMetadata, spec_metadata: SpecMetadata, diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index 998f4b28cb..41be42a54d 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -131,6 +131,7 @@ class SpeculativeDecodingMode(IntEnum): is_draft_model: bool, 
         attention_backend: Type[AttentionBackend],
         use_chain_drafter: bool,
+        is_spec_dec_tree: bool,
     ):
         """
         If true, the attention backend kernel needs to run in spec-dec mode (multi-token query mode).
@@ -154,8 +155,10 @@ class SpecMetadata:
     """
     # The max number of requests in a single batch.
     max_num_requests: int
-    # The max number of draft tokens.
+    # The number of draft layers. (Also the number of draft tokens for the linear tree.)
     max_draft_len: int
+    # The max number of draft tokens for the static and dynamic trees.
+    max_total_draft_tokens: int
     # The number of gen-phase sequences in the batch.
     num_generations: int = 0
     # Whether CUDA graph is enabled.
@@ -191,9 +194,13 @@
     # The number of layers
     num_layers: int = 0
-    # if spec-dec tree is a tree or a chain (linear tree)
-    is_spec_dec_tree: bool = False
     # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step.
+    # NOTE: Although the linear tree can be treated as a special case of a static tree,
+    # NOTE: we do not set `is_spec_dec_tree` to True in that case;
+    # NOTE: i.e., for the linear tree, is_spec_dec_tree == False and is_spec_dec_dynamic_tree == False.
+    # whether the spec-dec tokens form a tree (either a static or a dynamic tree).
+    is_spec_dec_tree: bool = False
+    # whether the spec-dec tree is a dynamic tree.
     is_spec_dec_dynamic_tree: bool = False

     def __post_init__(self):
diff --git a/tensorrt_llm/_torch/speculative/model_drafter.py b/tensorrt_llm/_torch/speculative/model_drafter.py
index 974e13130e..75bb2e6980 100644
--- a/tensorrt_llm/_torch/speculative/model_drafter.py
+++ b/tensorrt_llm/_torch/speculative/model_drafter.py
@@ -51,7 +51,8 @@ class ModelDrafter(Drafter):
         self,
         spec_config: "DecodingBaseConfig",
         draft_model_engine: "ModelEngine",
-        max_draft_tokens: int,
+        max_draft_len: int,
+        max_total_draft_tokens: int,
         draft_seq_slot_manager: SeqSlotManager,
         sampler: Sampler,
         spec_resource_manager: Optional[BaseResourceManager] = None,
@@ -62,8 +63,11 @@
         # Validate required parameters
         if draft_model_engine is None:
             raise ValueError("draft_model_engine cannot be None")
-        if max_draft_tokens < 0:
-            raise ValueError("max_draft_tokens must be >= 0")
+        if max_draft_len < 0:
+            raise ValueError("max_draft_len must be >= 0")
+        if max_total_draft_tokens < 0:
+            raise ValueError("max_total_draft_tokens must be >= 0")
+        assert max_draft_len <= max_total_draft_tokens

         # Model and resource management
         self.draft_model_engine = draft_model_engine
@@ -72,7 +76,8 @@
         # Configuration
         self.spec_config = spec_config
-        self.max_draft_tokens = max_draft_tokens
+        self.max_draft_len = max_draft_len
+        self.max_total_draft_tokens = max_total_draft_tokens

         # Sampling
         self.sampler = sampler
         self.guided_decoder = guided_decoder
@@ -153,9 +158,11 @@
         Create a chunked context request for accepted tokens. Only applicable if the
         draft model needs to recompute KV cache for accepted tokens (e.g. eagle 3)
         """
-        # Pad input_tokens to max_draft_tokens
+        # Pad input_tokens to max_draft_len
+        # We use max_draft_len instead of max_total_draft_tokens here,
+        # because at most max_draft_len draft tokens are accepted.
input_tokens.extend( - 0 for _ in range(self.max_draft_tokens - num_accepted_tokens)) + 0 for _ in range(self.max_draft_len - num_accepted_tokens)) new_request = self._create_draft_request(request, input_tokens) new_request.state = LlmRequestState.GENERATION_IN_PROGRESS new_request.py_num_accepted_draft_tokens = request.py_num_accepted_draft_tokens @@ -469,7 +476,7 @@ class ModelDrafter(Drafter): # We already updated the target state, so the new_tokens_lens should be all ones. new_tokens_lens = torch.ones(batch_size, device=device) next_draft_tokens = torch.zeros(batch_size, - self.max_draft_tokens, + self.max_draft_len, device=device) # Create a new SampleStateTensorsMTP object with the additional fields @@ -563,7 +570,7 @@ class ModelDrafter(Drafter): # Chunked prefill request in progress; no need to append draft tokens continue py_draft_logits = [] - for token_idx in range(self.max_draft_tokens): + for token_idx in range(self.max_draft_len): target_model_req.py_draft_tokens.append( draft_tokens_host[token_idx][req_idx]) py_draft_logits.append(draft_logits[token_idx][req_idx]) @@ -646,7 +653,7 @@ class ModelDrafter(Drafter): previous_draft_state = initial_draft_state # Generate remaining draft tokens iteratively - for i in range(self.max_draft_tokens - 1): + for i in range(self.max_draft_len - 1): if len(draft_batch.generation_requests) == 0: break @@ -722,7 +729,7 @@ class ModelDrafter(Drafter): target_inputs, outputs["new_draft_tokens"], draft_position=0, - draft_length=self.max_draft_tokens, + draft_length=self.max_draft_len, draft_batch=draft_batch, req_id_to_old_request=req_id_to_old_request) diff --git a/tensorrt_llm/_torch/speculative/ngram.py b/tensorrt_llm/_torch/speculative/ngram.py index dc23270945..74ec518d60 100644 --- a/tensorrt_llm/_torch/speculative/ngram.py +++ b/tensorrt_llm/_torch/speculative/ngram.py @@ -170,7 +170,7 @@ class NGramDrafter(Drafter): super().__init__(spec_config.max_concurrency) assert ngram_pool_manager is not None, "NGram needs a resource manager to maintain the pool." 
self.spec_config = spec_config - self.max_draft_tokens = spec_config.max_draft_len + self.max_draft_len = spec_config.max_draft_len self.spec_resource_manager = ngram_pool_manager def prepare_draft_tokens( diff --git a/tensorrt_llm/_torch/speculative/save_hidden_state.py b/tensorrt_llm/_torch/speculative/save_hidden_state.py index 202088784f..936ca5d2ca 100644 --- a/tensorrt_llm/_torch/speculative/save_hidden_state.py +++ b/tensorrt_llm/_torch/speculative/save_hidden_state.py @@ -21,6 +21,7 @@ class SaveHiddenStatesDrafter(Drafter): super().__init__(spec_config.max_concurrency) self.spec_config = spec_config self.max_draft_len = spec_config.max_draft_len + self.max_total_draft_tokens = spec_config.max_total_draft_tokens self._iter = 1 self._output_directory = spec_config.output_directory self._file_prefix = spec_config.file_prefix diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py index 8ff5ec8fc6..e0c2097c2b 100644 --- a/tensorrt_llm/_torch/speculative/utils.py +++ b/tensorrt_llm/_torch/speculative/utils.py @@ -23,6 +23,7 @@ def get_spec_metadata(spec_config, if spec_config.spec_dec_mode.is_mtp_one_model(): return MTPSpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, spec_dec_mode=spec_config.spec_dec_mode, mtp_num_modules=spec_config.num_nextn_predict_layers, max_num_requests=max_num_requests, @@ -31,6 +32,7 @@ def get_spec_metadata(spec_config, if spec_config.spec_dec_mode.is_mtp_eagle(): return Eagle3SpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, num_layers=model_config.num_hidden_layers, @@ -45,6 +47,7 @@ def get_spec_metadata(spec_config, if spec_config.spec_dec_mode.is_eagle3(): return Eagle3SpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, num_layers=model_config.num_hidden_layers, @@ -55,14 +58,15 @@ def get_spec_metadata(spec_config, eagle3_resource_manager=spec_resource_manager, layers_to_capture=spec_config.eagle3_layers_to_capture, is_mtp_eagle=False, - max_total_draft_tokens=spec_config.max_total_draft_tokens, eagle_choices=spec_config.eagle_choices, - is_spec_dec_tree=spec_config.eagle_choices is not None, + is_spec_dec_tree=spec_config.eagle_choices is not None + or spec_config.use_dynamic_tree, is_spec_dec_dynamic_tree=spec_config.use_dynamic_tree, ) if spec_config.spec_dec_mode.is_eagle3_one_model(): return Eagle3OneModelSpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, num_layers=model_config.num_hidden_layers, @@ -78,6 +82,7 @@ def get_spec_metadata(spec_config, } return Eagle3SpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=1, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, num_layers=model_config.num_hidden_layers, @@ -87,13 +92,13 @@ def get_spec_metadata(spec_config, is_draft_model=is_draft_model, eagle3_resource_manager=spec_resource_manager, layers_to_capture=spec_config.eagle3_layers_to_capture, - max_total_draft_tokens=1, ) if spec_config.spec_dec_mode.is_draft_target() or \ spec_config.spec_dec_mode.is_ngram() or \ spec_config.spec_dec_mode.is_user_provided(): return 
SpecMetadata( max_draft_len=spec_config.max_draft_len, + max_total_draft_tokens=spec_config.max_total_draft_tokens, spec_dec_mode=spec_config.spec_dec_mode, max_num_requests=max_num_requests, ) @@ -186,6 +191,7 @@ def get_spec_drafter(model_engine, return ModelDrafter(spec_config, draft_model_engine, spec_config.max_draft_len, + spec_config.max_total_draft_tokens, SeqSlotManager(max_num_requests), sampler, spec_resource_manager=spec_resource_manager, diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index 35d02350e9..e4625d5ad7 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -413,7 +413,14 @@ class _ModelFormatKind(Enum): class DecodingBaseConfig(StrictBaseModel): + # The number of the drafter layers. max_draft_len: Optional[int] = None + # The number of draft tokens in the draft tokens tree. + # If it's a linear tree, each draft layer will only generate one draft token. + # In this case, max_draft_len == max_total_draft_tokens. + # If it's a static or dynamic tree, each draft layer may generate more than one draft token. + # In this case, max_total_draft_tokens >= max_draft_len. + max_total_draft_tokens: Optional[int] = None speculative_model_dir: Optional[Union[str, Path]] = None # PyTorch only. @@ -526,6 +533,10 @@ class MedusaDecodingConfig(DecodingBaseConfig): medusa_choices: Optional[List[List[int]]] = None num_medusa_heads: Optional[int] = None + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.max_total_draft_tokens = self.max_draft_len # Current Medusa only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -544,8 +555,6 @@ class EagleDecodingConfig(DecodingBaseConfig): use_dynamic_tree: Optional[bool] = False # The topK value for each layer when enable dynamic tree. dynamic_tree_max_topK: Optional[int] = None - # The number of draft tokens in the draft tokens tree. - max_total_draft_tokens: Optional[int] = None # The number of eagle layer. will not be used in pytorch flow, just for compatibility with TRT flow num_eagle_layers: Optional[int] = None # The number of non-leaves in each layer. 
@@ -580,7 +589,7 @@ class EagleDecodingConfig(DecodingBaseConfig): # Checks whether the input eagle choices is valid # and reset the max_draft_len and num_eagle_layers if necessary if self.eagle_choices is not None: - # If eagle_choices is provided, use_dynamic_tree will not be used + # If eagle_choices is provided, use_dynamic_tree should not be used assert not self.use_dynamic_tree, "If eagle_choices is provided, use_dynamic_tree need to be False" # Get num_eagle_layers from eagle_choices @@ -651,6 +660,12 @@ class EagleDecodingConfig(DecodingBaseConfig): return len(self.eagle3_layers_to_capture) return 3 + @functools.cached_property + def is_linear_tree(self) -> bool: + if self.eagle_choices is None and self.use_dynamic_tree is False: + return True + return False + class SaveHiddenStatesDecodingConfig(DecodingBaseConfig): output_directory: str @@ -703,6 +718,10 @@ class UserProvidedDecodingConfig(DecodingBaseConfig): drafter: object # Type is Drafter resource_manager: object = None # Type is Optional[ResourceManager] + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.max_total_draft_tokens = self.max_draft_len # Current UserProvided only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -735,6 +754,10 @@ class NGramDecodingConfig(DecodingBaseConfig): is_use_oldest: bool = True is_public_pool: bool = True + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.max_total_draft_tokens = self.max_draft_len # Current NGram only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -747,6 +770,10 @@ class NGramDecodingConfig(DecodingBaseConfig): class DraftTargetDecodingConfig(DecodingBaseConfig): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.max_total_draft_tokens = self.max_draft_len # Current DraftTarget only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -776,10 +803,18 @@ class MTPDecodingConfig(DecodingBaseConfig): BEGIN_THINKING_PHASE_TOKEN: int = 128798 END_THINKING_PHASE_TOKEN: int = 128799 + def __init__(self, **kwargs): + super().__init__(**kwargs) + if 'num_nextn_predict_layers' in kwargs: + self.max_draft_len = kwargs['num_nextn_predict_layers'] + self.max_total_draft_tokens = kwargs[ + 'num_nextn_predict_layers'] # Current MTP only support linear tree + @classmethod def from_dict(cls, data: dict): out = cls(**data) out.max_draft_len = out.num_nextn_predict_layers + out.max_total_draft_tokens = out.num_nextn_predict_layers # Current MTP only support linear tree return out decoding_type: ClassVar[str] = "MTP" @@ -814,6 +849,10 @@ class AutoDecodingConfig(DecodingBaseConfig): Attributes that are inherited from the base class are ignored. 
""" + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.max_total_draft_tokens = self.max_draft_len # Current Auto only support linear tree + @classmethod def from_dict(cls, data: dict): return cls(**data) @@ -1156,6 +1195,7 @@ class LookaheadDecodingConfig(DecodingBaseConfig, PybindMirror): def __init__(self, **data): super().__init__(**data) + self.max_total_draft_tokens = self.max_draft_len # Current Lookahead only support linear tree self._check_fields() def calculate_speculative_resource(self): diff --git a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py index 3018c90425..01aa395f95 100644 --- a/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py +++ b/tests/unittest/_torch/speculative/test_dynamic_spec_decode.py @@ -73,7 +73,7 @@ def test_dynamic_spec_decode(enforce_single_worker, # Mock should_use_spec_decode to turn on/off spec decode dynamically. def mock_should_use_spec_decode(requests, max_batch_size, max_num_tokens, - max_draft_len): + max_total_draft_tokens): for req in requests: if req.state != LlmRequestState.GENERATION_IN_PROGRESS: continue @@ -198,40 +198,42 @@ def test_should_use_spec_decode(): assert drafter.should_use_spec_decode(active_requests, max_batch_size=8, max_num_tokens=4096 * 8, - max_draft_len=4) + max_total_draft_tokens=4) # Small batch size ON case: num_effective_requests = min(12, 5, very_large) = 5 <= 6 → True active_requests = [object()] * 12 assert drafter.should_use_spec_decode(active_requests, max_batch_size=5, max_num_tokens=4096 * 8, - max_draft_len=4) + max_total_draft_tokens=4) # Small token budget ON case: token_cap = 28 // (1+4) = 5 → min(12, 8, 5) = 5 <= 6 → True active_requests = [object()] * 12 assert drafter.should_use_spec_decode(active_requests, max_batch_size=8, max_num_tokens=28, - max_draft_len=4) + max_total_draft_tokens=4) # Generic OFF case: num_effective_requests = min(12, 8, very_large) = 8 > 6 → False active_requests = [object()] * 12 assert not drafter.should_use_spec_decode(active_requests, max_batch_size=8, max_num_tokens=4096 * 8, - max_draft_len=4) + max_total_draft_tokens=4) # Edge case - None active requests OFF case active_requests = [] assert not drafter.should_use_spec_decode(active_requests, max_batch_size=8, max_num_tokens=4096 * 8, - max_draft_len=4) + max_total_draft_tokens=4) # Edge case - Token cap equals 0 OFF case: token_cap = 4 // (1+4) = 0 → min(12, 8, 0) = 0 <= 6 → False active_requests = [object()] * 12 - assert not drafter.should_use_spec_decode( - active_requests, max_batch_size=8, max_num_tokens=4, max_draft_len=4) + assert not drafter.should_use_spec_decode(active_requests, + max_batch_size=8, + max_num_tokens=4, + max_total_draft_tokens=4) if __name__ == "__main__": diff --git a/tests/unittest/_torch/speculative/test_mtp.py b/tests/unittest/_torch/speculative/test_mtp.py index c4a9783e79..d3965f477c 100644 --- a/tests/unittest/_torch/speculative/test_mtp.py +++ b/tests/unittest/_torch/speculative/test_mtp.py @@ -309,6 +309,7 @@ class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase): spec_metadata = MTPSpecMetadata(max_num_requests=32, spec_dec_mode=spec_config.spec_dec_mode, max_draft_len=mtp_num_modules, + max_total_draft_tokens=mtp_num_modules, mtp_num_modules=mtp_num_modules) spec_metadata.draft_tokens = draft_tokens @@ -891,6 +892,7 @@ class TestMTPUpdateMTPHiddenStates(unittest.TestCase): max_num_requests=32, spec_dec_mode=spec_config.spec_dec_mode, max_draft_len=num_nextn_predict_layers, + 
max_total_draft_tokens=num_nextn_predict_layers, mtp_num_modules=num_nextn_predict_layers, mtp_hidden_states_manager=spec_manager) spec_metadata.request_ids = request_ids @@ -1386,6 +1388,7 @@ class TestMTPPrepareDrafterInputs(unittest.TestCase): max_num_requests=32, spec_dec_mode=spec_config.spec_dec_mode, max_draft_len=num_nextn_predict_layers, + max_total_draft_tokens=num_nextn_predict_layers, mtp_num_modules=num_nextn_predict_layers, mtp_hidden_states_manager=spec_manager) spec_metadata.request_ids = request_ids
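
Reviewer note (illustration only, not part of the patch): the sketch below shows the relationship between max_draft_len and max_total_draft_tokens that the new DecodingBaseConfig docstring describes. For a linear chain the two values coincide; for a static or dynamic tree, max_total_draft_tokens >= max_draft_len. The helper name and the example choices are hypothetical and not taken from the codebase; EagleDecodingConfig performs its own derivation from eagle_choices.

# Hypothetical helper (illustration only): derive the two limits from an
# eagle_choices-style tree, where each choice is a path of per-level top-k indices.
from typing import List, Tuple


def draft_limits_from_choices(eagle_choices: List[List[int]]) -> Tuple[int, int]:
    """Return (max_draft_len, max_total_draft_tokens) for a draft-token tree.

    max_draft_len:          number of draft layers, i.e. the depth of the deepest path.
    max_total_draft_tokens: number of draft tokens, i.e. the number of tree nodes.
    """
    # Every choice and every prefix of a choice is one node (= one draft token).
    nodes = set()
    for path in eagle_choices:
        for depth in range(1, len(path) + 1):
            nodes.add(tuple(path[:depth]))
    max_draft_len = max(len(path) for path in eagle_choices)
    return max_draft_len, len(nodes)


if __name__ == "__main__":
    # Linear chain: one token per draft layer, so the two values coincide.
    print(draft_limits_from_choices([[0], [0, 0], [0, 0, 0]]))  # (3, 3)
    # Small static tree: 3 draft layers, but 6 draft tokens in total.
    print(draft_limits_from_choices([[0], [1], [0, 0], [0, 1], [1, 0], [0, 0, 0]]))  # (3, 6)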