diff --git a/tensorrt_llm/_torch/attention_backend/interface.py b/tensorrt_llm/_torch/attention_backend/interface.py index cfa33def9d..0f4c5bac42 100644 --- a/tensorrt_llm/_torch/attention_backend/interface.py +++ b/tensorrt_llm/_torch/attention_backend/interface.py @@ -64,6 +64,8 @@ class AttentionMetadata: max_num_sequences: Optional[int] = None # The KV cache manager. kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2] + # Draft KV cache manager for one-model speculative decoding with separate KV cache layouts + draft_kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2] = None mapping: Optional[Mapping] = None enable_flash_mla: bool = False diff --git a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py index 9bacfed2b0..4e75c0904d 100644 --- a/tensorrt_llm/_torch/attention_backend/sparse/rocket.py +++ b/tensorrt_llm/_torch/attention_backend/sparse/rocket.py @@ -974,6 +974,7 @@ class RocketKVCacheManager(KVCacheManager): use_mrope: bool = False, max_beam_width: int = 1, num_extra_decoding_steps: int = 0, + draft_kv_cache_manager=None, ): requests = super().add_dummy_requests( request_ids=request_ids, @@ -984,6 +985,7 @@ class RocketKVCacheManager(KVCacheManager): use_mrope=use_mrope, max_beam_width=max_beam_width, num_extra_decoding_steps=num_extra_decoding_steps, + draft_kv_cache_manager=draft_kv_cache_manager, ) if prepare_resource: for req in requests: diff --git a/tensorrt_llm/_torch/attention_backend/trtllm.py b/tensorrt_llm/_torch/attention_backend/trtllm.py index 47f433e1fc..9620746137 100644 --- a/tensorrt_llm/_torch/attention_backend/trtllm.py +++ b/tensorrt_llm/_torch/attention_backend/trtllm.py @@ -679,6 +679,12 @@ class TrtllmAttentionMetadata(AttentionMetadata): helix_is_inactive_rank: Optional[torch.Tensor] = None helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None + # Block offsets for the target and draft KV caches + kv_cache_block_offsets: Optional[torch.Tensor] = None + host_kv_cache_block_offsets: Optional[torch.Tensor] = None + draft_kv_cache_block_offsets: Optional[torch.Tensor] = None + draft_host_kv_cache_block_offsets: Optional[torch.Tensor] = None + @property def max_seq_len(self) -> int: """ @@ -786,6 +792,27 @@ class TrtllmAttentionMetadata(AttentionMetadata): ) self.block_ids_per_seq = None self.kv_block_ids_per_seq = None + + # Allocate separate block offset tensors for draft KV cache manager + # Used in one-model speculative decoding with different KV cache layouts + if self.draft_kv_cache_manager is not None: + self.draft_kv_cache_block_offsets = self.get_empty( + buffers, + [ + self.draft_kv_cache_manager.num_pools, + self.max_num_sequences, 2, + self.draft_kv_cache_manager.max_blocks_per_seq + ], + cache_name="draft_kv_cache_block_offsets", + dtype=torch.int32, + capture_graph=capture_graph, + ) + self.draft_host_kv_cache_block_offsets = torch.empty_like( + self.draft_kv_cache_block_offsets, + device='cpu', + pin_memory=True, + ) + if self.enable_flash_mla: self.block_ids_per_seq = self.get_empty( buffers, @@ -987,6 +1014,25 @@ class TrtllmAttentionMetadata(AttentionMetadata): assert self.kv_lens[:self.num_seqs].max( ) <= self.kv_cache_manager.max_seq_len, error_message + # Also prepare draft KV cache block offsets if draft_kv_cache_manager exists + if self.draft_kv_cache_manager is not None: + # Copy blocks for all context requests + self.draft_kv_cache_manager.impl.copy_batch_block_offsets( + self.draft_host_kv_cache_block_offsets, + self.request_ids[:self.num_contexts], 1, 
0) + # Copy blocks for all generation requests + self.draft_kv_cache_manager.impl.copy_batch_block_offsets( + self.draft_host_kv_cache_block_offsets, + self.request_ids[self.num_contexts:], self.beam_width, + self.num_contexts) + for pool_idx in range( + self.draft_host_kv_cache_block_offsets.shape[0]): + self.draft_kv_cache_block_offsets[ + pool_idx, :self.num_seqs].copy_( + self.draft_host_kv_cache_block_offsets[ + pool_idx, :self.num_seqs], + non_blocking=True) + self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs] # Don't use self.kv_lens here because it includes extra tokens. # Use actual KV length (without extra tokens) for kv_lens_runtime, diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 921eafe0a8..d8d339921a 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1840,16 +1840,18 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model, input_ids: torch.IntTensor = None, position_ids: Optional[torch.IntTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, - spec_metadata: Optional[SpecMetadata] = None, return_context_logits: bool = False, + spec_metadata: Optional[SpecMetadata] = None, + resource_manager=None, **kwargs, ) -> torch.Tensor: return super().forward(attn_metadata=attn_metadata, input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds, - spec_metadata=spec_metadata, return_context_logits=return_context_logits, + spec_metadata=spec_metadata, + resource_manager=resource_manager, **kwargs) def load_weights(self, weights: ConsumableWeightsDict): diff --git a/tensorrt_llm/_torch/models/modeling_llama.py b/tensorrt_llm/_torch/models/modeling_llama.py index a6c32766af..54193a32c0 100644 --- a/tensorrt_llm/_torch/models/modeling_llama.py +++ b/tensorrt_llm/_torch/models/modeling_llama.py @@ -1401,6 +1401,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model, inputs_embeds: Optional[torch.FloatTensor] = None, return_context_logits: bool = False, spec_metadata: Optional[SpecMetadata] = None, + resource_manager=None, **kwargs, ) -> torch.Tensor: multimodal_params = kwargs.get("multimodal_params", []) @@ -1422,7 +1423,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model, position_ids, inputs_embeds, spec_metadata=spec_metadata, - return_context_logits=return_context_logits) + return_context_logits=return_context_logits, + resource_manager=resource_manager) def infer_max_seq_len(self): if self.model_config.attn_backend.upper() != 'TRTLLM': diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py index a22e196a93..1d062e8413 100755 --- a/tensorrt_llm/_torch/models/modeling_speculative.py +++ b/tensorrt_llm/_torch/models/modeling_speculative.py @@ -19,7 +19,8 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode, WeightsLoadingConfig) from ..modules.rms_norm import RMSNorm from ..pyexecutor.guided_decoder import CapturableGuidedDecoder -from ..speculative import SpecMetadata, get_spec_worker +from ..speculative import (SpecMetadata, get_spec_worker, + should_use_separate_draft_kv_cache) from ..utils import AuxStreamType from .checkpoints.base_weight_mapper import BaseWeightMapper from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel, @@ -931,6 +932,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig], 
vocab_size=model_config.pretrained_config.vocab_size) self.draft_model = None self.draft_config = None + self.use_separate_draft_kv_cache = False spec_config = getattr(model_config, 'spec_config', None) if spec_config and spec_config.spec_dec_mode.use_one_engine(): if spec_config.spec_dec_mode.is_eagle3_one_model(): @@ -964,11 +966,16 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig], self.draft_config.quant_config.kv_cache_quant_algo = \ model_config.quant_config.kv_cache_quant_algo + self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache( + spec_config) + self.draft_model = get_draft_model(model_config, self.draft_config, self.lm_head, self.model) - self.spec_worker = get_spec_worker(model_config.spec_config, - model_config, - model_config.mapping) + self.spec_worker = get_spec_worker( + model_config.spec_config, + model_config, + model_config.mapping, + use_separate_draft_kv_cache=self.use_separate_draft_kv_cache) self.epilogue.append(self.draft_model) self.epilogue.append(self.spec_worker) @@ -987,6 +994,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig], inputs_embeds: Optional[torch.FloatTensor] = None, return_context_logits: bool = False, spec_metadata: Optional[SpecMetadata] = None, + resource_manager=None, **kwargs, ) -> torch.Tensor: hidden_states = self.model( @@ -1030,7 +1038,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig], logits=logits, attn_metadata=attn_metadata, spec_metadata=spec_metadata, - draft_model=self.draft_model) + draft_model=self.draft_model, + resource_manager=resource_manager) else: logits = self.logits_processor.forward( hidden_states, diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index 497e92d861..0343f746ad 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -25,7 +25,8 @@ from tensorrt_llm.mapping import CpType, Mapping from ..attention_backend import get_sparse_attn_kv_cache_manager from ..model_config import ModelConfig -from ..speculative import get_num_extra_kv_tokens, get_spec_decoder +from ..speculative import (get_num_extra_kv_tokens, get_num_spec_layers, + get_spec_decoder, should_use_separate_draft_kv_cache) from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next from .guided_decoder import GuidedDecoder from .kv_cache_connector import KvCacheConnectorManager @@ -80,6 +81,7 @@ class KvCacheCreator: sparse_attention_config: SparseAttentionConfig, profiling_stage_data: Optional[dict], execution_stream: Optional[torch.cuda.Stream] = None, + draft_config: Optional[ModelConfig] = None, ): self._model_engine = model_engine self._draft_model_engine = draft_model_engine @@ -103,6 +105,7 @@ class KvCacheCreator: self._execution_stream = execution_stream if self._kv_cache_manager_cls == KVCacheManager and kv_cache_config.use_kv_cache_manager_v2: self._kv_cache_manager_cls = KVCacheManagerV2 + self._draft_config = draft_config def _get_kv_size_per_token(self): model_config = self._model_engine.model.model_config @@ -115,6 +118,12 @@ class KvCacheCreator: draft_model_config, mapping, tokens_per_block=self._tokens_per_block) + elif self._should_create_separate_draft_kv_cache(): + # One-model draft with separate KV cache layout + kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token( + self._draft_config, + mapping, + tokens_per_block=self._tokens_per_block) return kv_size_per_token def _cal_max_memory(self, peak_memory, 
total_gpu_memory, fraction, @@ -397,9 +406,12 @@ class KvCacheCreator: # get kv cache stats for both model and draft model kv_stats = py_executor.resource_manager.resource_managers.get( ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats() - kv_stats_draft = py_executor.resource_manager.resource_managers.get( - ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats( - ) if self._draft_model_engine is not None else None + # Get draft KV cache stats if present (either from two-model mode or one-model + # mode with separate draft KV cache) + draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get( + ResourceManagerType.DRAFT_KV_CACHE_MANAGER) + kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats( + ) if draft_kv_cache_manager is not None else None # get total allocated bytes allocated_bytes = kv_stats.allocated_bytes + ( @@ -466,6 +478,15 @@ class KvCacheCreator: mapping = self._mapping assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models." + # When using separate draft KV cache in one-model speculative decoding, + # use layer_mask to include only target layers. The draft layers should + # only be in the separate draft KV cache manager. + # We still pass spec_config so that num_extra_kv_tokens is calculated. + spec_dec_layer_mask = None + if self._should_create_separate_draft_kv_cache(): + num_target_layers = model_engine.model.model_config.pretrained_config.num_hidden_layers + spec_dec_layer_mask = [True] * num_target_layers + kv_cache_manager = _create_kv_cache_manager( model_engine=model_engine, kv_cache_manager_cls=self._kv_cache_manager_cls, @@ -481,6 +502,7 @@ class KvCacheCreator: kv_connector_manager=self._kv_connector_manager, estimating_kv_cache=estimating_kv_cache, execution_stream=self._execution_stream, + layer_mask=spec_dec_layer_mask, ) # KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self @@ -503,6 +525,64 @@ class KvCacheCreator: return kv_cache_manager + def _should_create_separate_draft_kv_cache(self) -> bool: + """ + Check if we need a separate draft KV cache manager for one-model mode. + Returns True if the speculative config has use_separate_draft_kv_cache=True. + """ + if self._draft_config is None: + return False + + return should_use_separate_draft_kv_cache(self._speculative_config) + + def _create_one_model_draft_kv_cache_manager( + self, + estimating_kv_cache: bool = False) -> Optional[KVCacheManager]: + """ + Create a KV cache manager for draft model layers in one-model mode + when target and draft have different KV cache layouts. + """ + # Get target model's num_hidden_layers to compute correct layer indices. + # Draft model layers in one-model mode start at target_num_layers. + target_pretrained_config = self._model_engine.model.model_config.pretrained_config + target_num_layers = target_pretrained_config.num_hidden_layers + # Use get_num_spec_layers to get the correct number of draft layers + # for the speculative decoding mode (e.g., num_eagle_layers for Eagle3) + num_draft_layers = get_num_spec_layers(self._speculative_config) + + # Create layer_mask: False for target layers, True for draft layers. + # This ensures the draft KV cache manager uses the correct layer indices + # (e.g., layers 32, 33, ... instead of 0, 1, ...). 
+ spec_dec_layer_mask = [False + ] * target_num_layers + [True] * num_draft_layers + + # Get the appropriate KV cache manager class for the draft model + draft_kv_cache_manager_cls = get_kv_cache_manager_cls( + self._draft_config) + + return _create_kv_cache_manager( + model_engine=None, + kv_cache_manager_cls=draft_kv_cache_manager_cls, + mapping=self._mapping, + kv_cache_config=self._kv_cache_config, + tokens_per_block=self._tokens_per_block, + max_seq_len=self._max_seq_len, + max_batch_size=self._max_batch_size, + spec_config=self._speculative_config, + sparse_attn_config=None, # Not applicable for draft in one-model mode + max_num_tokens=self._max_num_tokens, + max_beam_width=self._max_beam_width, + kv_connector_manager=self._kv_connector_manager, + estimating_kv_cache=estimating_kv_cache, + execution_stream=self._execution_stream, + # One-model draft specific overrides + model_config=self._draft_config, + dtype=self._draft_config.pretrained_config.torch_dtype, + is_draft=True, + layer_mask=spec_dec_layer_mask, + num_layers=num_draft_layers, + ) + def build_managers(self, resources: Dict, estimating_kv_cache: bool = False) -> None: @@ -514,9 +594,16 @@ class KvCacheCreator: raise NotImplementedError( "Connector manager is not supported for draft model.") - draft_kv_cache_manager = self._create_kv_cache_manager( - self._draft_model_engine, estimating_kv_cache - ) if self._draft_model_engine is not None else None + draft_kv_cache_manager = None + + # Two-model speculative decoding: draft model has separate engine + if self._draft_model_engine is not None: + draft_kv_cache_manager = self._create_kv_cache_manager( + self._draft_model_engine, estimating_kv_cache) + # One-model speculative decoding with different KV layouts + elif self._should_create_separate_draft_kv_cache(): + draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager( + estimating_kv_cache) resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager resources[ @@ -534,7 +621,7 @@ class KvCacheCreator: def _create_kv_cache_manager( - model_engine: PyTorchModelEngine, + model_engine: Optional[PyTorchModelEngine], kv_cache_manager_cls, mapping: Mapping, kv_cache_config: KvCacheConfig, @@ -547,13 +634,32 @@ def _create_kv_cache_manager( max_beam_width: int, kv_connector_manager: Optional[KvCacheConnectorManager], estimating_kv_cache: bool, - execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager: + execution_stream: Optional[torch.cuda.Stream] = None, + # Optional overrides for one-model draft case (when model_engine is None) + model_config: Optional[ModelConfig] = None, + dtype: Optional[torch.dtype] = None, + is_draft: Optional[bool] = None, + layer_mask: Optional[List[bool]] = None, + num_layers: Optional[int] = None) -> KVCacheManager: """ Returns: - A KVCacheManager instance for the given model_engine + A KVCacheManager instance for the given model engine or model config """ - config = model_engine.model.model_config.pretrained_config - quant_config = model_engine.model.model_config.quant_config + # Extract config from model_engine or use provided model_config + if model_config is not None: + config = model_config.pretrained_config + quant_config = model_config.quant_config + _model_config = model_config + else: + config = model_engine.model.model_config.pretrained_config + quant_config = model_engine.model.model_config.quant_config + _model_config = model_engine.model.model_config + + if dtype is None: + dtype = model_engine.dtype + + if is_draft is None: + is_draft = 
model_engine.is_draft_model hidden_size = config.hidden_size num_attention_heads = config.num_attention_heads @@ -569,10 +675,10 @@ def _create_kv_cache_manager( ): kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4 else: - kv_cache_dtype = str_dtype_to_binding( - torch_dtype_to_str(model_engine.dtype)) + kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype)) - num_hidden_layers = config.num_hidden_layers + # Use provided num_layers if available, otherwise use config + num_hidden_layers = num_layers if num_layers is not None else config.num_hidden_layers if is_mla(config): kv_cache_manager = kv_cache_manager_cls( @@ -589,12 +695,13 @@ def _create_kv_cache_manager( spec_config=spec_config, vocab_size=config.vocab_size, max_beam_width=max_beam_width, - is_draft=model_engine.is_draft_model, + is_draft=is_draft, kv_connector_manager=kv_connector_manager if not estimating_kv_cache else None, sparse_attn_config=sparse_attn_config, is_estimating_kv_cache=estimating_kv_cache, execution_stream=execution_stream, + layer_mask=layer_mask, ) elif is_nemotron_hybrid(config): if max_beam_width > 1: @@ -606,9 +713,10 @@ def _create_kv_cache_manager( "Connector manager is not supported for MambaHybridCacheManager." ) - config = model_engine.model.model_config.pretrained_config num_layers = config.hybrid_override_pattern.count("*") - layer_mask = [char == "*" for char in config.hybrid_override_pattern] + hybrid_layer_mask = [ + char == "*" for char in config.hybrid_override_pattern + ] mamba_num_layers = config.hybrid_override_pattern.count("M") mamba_layer_mask = [ char == "M" for char in config.hybrid_override_pattern @@ -623,12 +731,13 @@ def _create_kv_cache_manager( mamba_num_layers, mamba_layer_mask, config.torch_dtype, - model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype, + quant_config.mamba_ssm_cache_dtype + if quant_config is not None else None, # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, num_layers=num_layers, - layer_mask=layer_mask, + layer_mask=hybrid_layer_mask, num_kv_heads=num_key_value_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -649,13 +758,12 @@ def _create_kv_cache_manager( raise NotImplementedError( "Connector manager is not supported for MambaHybridCacheManager." 
) - config = model_engine.model.model_config.pretrained_config mamba_layer_mask = [ True if i % config.full_attention_interval != config.full_attention_interval - 1 else False for i in range(num_hidden_layers) ] - layer_mask = [ + hybrid_layer_mask = [ False if i % config.full_attention_interval != config.full_attention_interval - 1 else True for i in range(num_hidden_layers) @@ -673,12 +781,13 @@ def _create_kv_cache_manager( num_mamba_layers, mamba_layer_mask, config.torch_dtype, - model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype, + quant_config.mamba_ssm_cache_dtype + if quant_config is not None else None, # kv cache parameters kv_cache_config, tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF, num_layers=num_layers, - layer_mask=layer_mask, + layer_mask=hybrid_layer_mask, num_kv_heads=num_key_value_heads, head_dim=head_dim, tokens_per_block=tokens_per_block, @@ -694,7 +803,7 @@ def _create_kv_cache_manager( # NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCahceManager is_vswa = kv_cache_config.max_attention_window is not None and len( set(kv_cache_config.max_attention_window)) > 1 - binding_model_config = model_engine.model.model_config.get_bindings_model_config( + binding_model_config = _model_config.get_bindings_model_config( tokens_per_block=tokens_per_block) if is_vswa else None kv_cache_manager = kv_cache_manager_cls( @@ -713,12 +822,13 @@ def _create_kv_cache_manager( max_num_tokens=max_num_tokens, model_config=binding_model_config, max_beam_width=max_beam_width, - is_draft=model_engine.is_draft_model, + is_draft=is_draft, kv_connector_manager=kv_connector_manager if not estimating_kv_cache else None, sparse_attn_config=sparse_attn_config, is_estimating_kv_cache=estimating_kv_cache, execution_stream=execution_stream, + layer_mask=layer_mask, ) return kv_cache_manager diff --git a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py index 9b0b51fa31..7878215855 100644 --- a/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py +++ b/tensorrt_llm/_torch/pyexecutor/cuda_graph_runner.py @@ -16,6 +16,7 @@ from ..memory_buffer_utils import get_memory_buffers from ..modules.multi_stream_utils import with_multi_stream from ..speculative.eagle3 import Eagle3ResourceManager from ..speculative.mtp import SampleStateTensorsMTP +from ..speculative.utils import get_draft_kv_cache_manager from ..utils import make_weak_ref, piecewise_cuda_graph from .llm_request import get_draft_token_length from .mamba_cache_manager import MambaCacheManager @@ -435,12 +436,20 @@ class CUDAGraphRunner: # respect the requirement just in case that changes in the future. if self.padding_dummy_request is None: + # Get draft KV cache manager only for one-model speculative decoding. + # In two-model mode, each model has its own KV cache manager, so + # draft_kv_cache_manager should be None. 
+ draft_kv_cache_manager = get_draft_kv_cache_manager( + self.spec_config, resource_manager) + self.padding_dummy_request = kv_cache_manager.add_dummy_requests( [CUDA_GRAPH_DUMMY_REQUEST_ID], is_gen=True, max_num_draft_tokens=runtime_draft_len, use_mrope=self.config.use_mrope, - max_beam_width=self.config.max_beam_width) + max_beam_width=self.config.max_beam_width, + draft_kv_cache_manager=draft_kv_cache_manager) + if self.padding_dummy_request is None: return 0 else: diff --git a/tensorrt_llm/_torch/pyexecutor/model_engine.py b/tensorrt_llm/_torch/pyexecutor/model_engine.py index e6ff77c993..de14006bfd 100644 --- a/tensorrt_llm/_torch/pyexecutor/model_engine.py +++ b/tensorrt_llm/_torch/pyexecutor/model_engine.py @@ -46,8 +46,8 @@ from ..models.modeling_utils import DecoderModelForCausalLM from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer, MoeLoadBalancerIterContext) from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager -from ..speculative import (SpecMetadata, get_num_extra_kv_tokens, - get_spec_metadata, +from ..speculative import (SpecMetadata, get_draft_kv_cache_manager, + get_num_extra_kv_tokens, get_spec_metadata, update_spec_config_from_model_config) from ..speculative.drafting_loops import BaseDraftingLoopWrapper from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata @@ -567,6 +567,15 @@ class PyTorchModelEngine(ModelEngine): def use_beam_search(self): return self.max_beam_width > 1 + def _get_draft_kv_cache_manager( + self, resource_manager: ResourceManager + ) -> Optional[Union[KVCacheManager, KVCacheManagerV2]]: + """ + Returns the draft KV cache manager only in one-model speculative decoding + mode where the target model manages a separate draft KV cache. + """ + return get_draft_kv_cache_manager(self.spec_config, resource_manager) + @contextmanager def set_warmup_flag(self): prev_is_warmup = self.is_warmup @@ -909,6 +918,8 @@ class PyTorchModelEngine(ModelEngine): """A context manager to automatically free resources of a dummy batch.""" kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) + draft_kv_cache_manager = self._get_draft_kv_cache_manager( + resource_manager) spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) try: @@ -917,6 +928,8 @@ class PyTorchModelEngine(ModelEngine): if batch is not None and kv_cache_manager is not None: for req in batch.all_requests(): kv_cache_manager.free_resources(req) + if draft_kv_cache_manager is not None: + draft_kv_cache_manager.free_resources(req) if spec_resource_manager is not None: spec_resource_manager.free_resources(req) @@ -939,6 +952,9 @@ class PyTorchModelEngine(ModelEngine): """Creates a generic dummy ScheduledRequests object for warmup.""" kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) + draft_kv_cache_manager = self._get_draft_kv_cache_manager( + resource_manager) + spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) @@ -1011,7 +1027,8 @@ class PyTorchModelEngine(ModelEngine): is_gen=False, max_num_draft_tokens=self.runtime_draft_len, use_mrope=self.use_mrope, - num_extra_decoding_steps=num_extra_decoding_steps) + num_extra_decoding_steps=num_extra_decoding_steps, + draft_kv_cache_manager=draft_kv_cache_manager) if ctx_requests is None: return None @@ -1030,7 +1047,8 @@ class PyTorchModelEngine(ModelEngine): max_num_draft_tokens=self.max_total_draft_tokens, use_mrope=self.use_mrope, 
max_beam_width=self.max_beam_width, - num_extra_decoding_steps=num_extra_decoding_steps) + num_extra_decoding_steps=num_extra_decoding_steps, + draft_kv_cache_manager=draft_kv_cache_manager) if gen_requests is None: for r in ctx_requests: @@ -1058,6 +1076,8 @@ class PyTorchModelEngine(ModelEngine): self.kv_cache_manager_key) spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) + draft_kv_cache_manager = self._get_draft_kv_cache_manager( + resource_manager) available_blocks = kv_cache_manager.get_num_free_blocks( ) // self.max_beam_width @@ -1075,7 +1095,8 @@ class PyTorchModelEngine(ModelEngine): max_num_draft_tokens=draft_len, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, - num_extra_decoding_steps=num_extra_decoding_steps) + num_extra_decoding_steps=num_extra_decoding_steps, + draft_kv_cache_manager=draft_kv_cache_manager) if requests is None: return None @@ -1083,6 +1104,12 @@ class PyTorchModelEngine(ModelEngine): available_tokens = kv_cache_manager.get_num_available_tokens( batch_size=batch_size, max_num_draft_tokens=draft_len) + # Also consider draft KV cache capacity when it exists + if draft_kv_cache_manager is not None: + draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens( + batch_size=batch_size, max_num_draft_tokens=draft_len) + available_tokens = min(available_tokens, draft_available_tokens) + # Add one dummy request with the maximum possible sequence length. max_seq_len = min( self.max_seq_len if max_seq_len is None else max_seq_len, @@ -1110,7 +1137,8 @@ class PyTorchModelEngine(ModelEngine): max_num_draft_tokens=draft_len, use_mrope=self.use_mrope, max_beam_width=self.max_beam_width, - num_extra_decoding_steps=num_extra_decoding_steps) + num_extra_decoding_steps=num_extra_decoding_steps, + draft_kv_cache_manager=draft_kv_cache_manager) if max_seq_len_request is None: for r in requests: @@ -1141,8 +1169,11 @@ class PyTorchModelEngine(ModelEngine): req.py_is_first_draft = True req.py_draft_tokens = [] - def _set_up_attn_metadata(self, kv_cache_manager: Union[KVCacheManager, - KVCacheManagerV2]): + def _set_up_attn_metadata( + self, + kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2], + draft_kv_cache_manager: Optional[Union[KVCacheManager, + KVCacheManagerV2]] = None): enable_context_mla_with_cached_kv = is_mla( self.model.model_config.pretrained_config) and ( self.attn_runtime_features.cache_reuse @@ -1191,6 +1222,7 @@ class PyTorchModelEngine(ModelEngine): max_num_tokens=self.max_num_tokens, max_num_sequences=self.batch_size * self.max_beam_width, kv_cache_manager=kv_cache_manager, + draft_kv_cache_manager=draft_kv_cache_manager, mapping=self.mapping, runtime_features=self.attn_runtime_features, enable_flash_mla=self.model.model_config.enable_flash_mla, @@ -1568,7 +1600,8 @@ class PyTorchModelEngine(ModelEngine): else: return self._apply_incremental_update_target( scheduled_requests, kv_cache_manager, attn_metadata, - spec_metadata, new_tensors_device, num_accepted_tokens_device) + spec_metadata, new_tensors_device, num_accepted_tokens_device, + resource_manager) @nvtx_range("_prepare_incremental_update_metadata") def _prepare_incremental_update_metadata( @@ -1834,7 +1867,8 @@ class PyTorchModelEngine(ModelEngine): attn_metadata: AttentionMetadata, spec_metadata: Optional[SpecMetadata] = None, new_tensors_device: Optional[SampleStateTensors] = None, - num_accepted_tokens_device: Optional[torch.Tensor] = None): + num_accepted_tokens_device: Optional[torch.Tensor] = None, + 
resource_manager: Optional[ResourceManager] = None): # Extract tensors from new_tensors_device new_tokens_device = new_tensors_device.new_tokens # [batch, 1 + draft_len] new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch] @@ -1968,6 +2002,7 @@ class PyTorchModelEngine(ModelEngine): 'position_ids': final_position_ids, 'inputs_embeds': None, "multimodal_params": [], + 'resource_manager': resource_manager, } if bool(lora_params): @@ -2781,6 +2816,7 @@ class PyTorchModelEngine(ModelEngine): 'position_ids': final_position_ids, 'inputs_embeds': None, "multimodal_params": multimodal_params_list, + 'resource_manager': resource_manager, } if bool(lora_params): @@ -2846,7 +2882,8 @@ class PyTorchModelEngine(ModelEngine): self, scheduled_requests: ScheduledRequests, attn_metadata: AttentionMetadata, - spec_metadata: Optional[SpecMetadata] = None): + spec_metadata: Optional[SpecMetadata] = None, + resource_manager: Optional[ResourceManager] = None): """ Prepare inputs for Pytorch Model. """ @@ -2950,7 +2987,8 @@ class PyTorchModelEngine(ModelEngine): 'position_ids': self.position_ids_cuda[:virtual_num_tokens].unsqueeze(0), 'inputs_embeds': None, - "multimodal_params": multimodal_params_list + "multimodal_params": multimodal_params_list, + 'resource_manager': resource_manager, } if bool(lora_params): @@ -2993,10 +3031,12 @@ class PyTorchModelEngine(ModelEngine): return inputs, None - def _prepare_star_attention_inputs(self, - scheduled_requests: ScheduledRequests, - kv_cache_manager, - attn_metadata: AttentionMetadata): + def _prepare_star_attention_inputs( + self, + scheduled_requests: ScheduledRequests, + kv_cache_manager, + attn_metadata: AttentionMetadata, + resource_manager: Optional[ResourceManager] = None): """ Prepare inputs for Pytorch Model. """ @@ -3212,7 +3252,8 @@ class PyTorchModelEngine(ModelEngine): 'attn_metadata': attn_metadata, 'input_ids': self.input_ids_cuda[:num_tokens], 'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0), - 'inputs_embeds': None + 'inputs_embeds': None, + 'resource_manager': resource_manager, }, gather_ids if is_spec_decode else None def _get_lora_params_from_requests( @@ -3357,7 +3398,8 @@ class PyTorchModelEngine(ModelEngine): cp_type = self.mapping.cp_config['cp_type'] if CpType.STAR == cp_type: return self._prepare_star_attention_inputs( - scheduled_requests, kv_cache_manager, attn_metadata) + scheduled_requests, kv_cache_manager, attn_metadata, + resource_manager) elif CpType.HELIX == cp_type: # Take the usual route of _prepare_tp_inputs. 
pass @@ -3384,8 +3426,11 @@ class PyTorchModelEngine(ModelEngine): req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None): kv_cache_manager = resource_manager.get_resource_manager( self.kv_cache_manager_key) + draft_kv_cache_manager = self._get_draft_kv_cache_manager( + resource_manager) - attn_metadata = self._set_up_attn_metadata(kv_cache_manager) + attn_metadata = self._set_up_attn_metadata(kv_cache_manager, + draft_kv_cache_manager) if self.enable_spec_decode: spec_resource_manager = resource_manager.get_resource_manager( ResourceManagerType.SPEC_RESOURCE_MANAGER) @@ -3423,7 +3468,8 @@ class PyTorchModelEngine(ModelEngine): if kv_cache_manager is None: inputs, gather_ids = self._prepare_tp_inputs_no_cache( - scheduled_requests, attn_metadata, spec_metadata) + scheduled_requests, attn_metadata, spec_metadata, + resource_manager) with MoeLoadBalancerIterContext(moe_load_balancer): # Special handling for multimodal encoder only mode diff --git a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py index 6dafff2978..944192eeac 100644 --- a/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py +++ b/tensorrt_llm/_torch/pyexecutor/py_executor_creator.py @@ -313,6 +313,11 @@ def create_py_executor( has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model() has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter() + # WAR for https://nvbugs/5807902 + # Disable separate draft KV cache in disaggregated mode + if cache_transceiver_config is not None or kv_connector_config is not None: + spec_config._allow_separate_draft_kv_cache = False + # chunk_unit_size may be changed to 64 when using flash mla attn_runtime_features = AttentionRuntimeFeatures( chunked_prefill=enable_chunked_context, @@ -623,6 +628,10 @@ def create_py_executor( if model_engine.model.model_config.is_generation: #NOTE: non-generation models do not have kv cache + + # Get draft config for one-engine speculative decoding if available + draft_config = getattr(model_engine.model, 'draft_config', None) + kv_cache_creator = KvCacheCreator( model_engine=model_engine, draft_model_engine=draft_model_engine, @@ -640,6 +649,7 @@ def create_py_executor( profiling_stage_data=profiling_stage_data, sparse_attention_config=sparse_attention_config, execution_stream=execution_stream, + draft_config=draft_config, ) estimating_kv_cache = kv_cache_creator.try_prepare_estimation() with allocation_scope( diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 2b973af2e6..df80d786de 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -135,7 +135,10 @@ def get_pp_layers( pp_layers = mapping.pp_layers(total_num_layers) if layer_mask is not None: pp_layers = [i for i in pp_layers if layer_mask[i]] - if spec_config is not None: + # Only add speculative layers when layer_mask is not provided. + # When layer_mask is provided, the caller explicitly controls which layers + # to include, so we should not add extra layers automatically. + if spec_config is not None and layer_mask is None: num_spec_layers = get_num_spec_layers(spec_config) total_num_layers += num_spec_layers if mapping.is_last_pp_rank(): @@ -586,6 +589,7 @@ class KVCacheManager(BaseResourceManager): # we need to make the KV cache manager aware that multiple autoregressive steps will # occur. 
num_extra_decoding_steps: int = 0, + draft_kv_cache_manager: Optional[BaseResourceManager] = None, ): available_blocks = self.get_num_free_blocks() # No padding if not enough KV cache space @@ -625,6 +629,12 @@ class KVCacheManager(BaseResourceManager): for _ in range(num_extra_decoding_steps): self.impl.add_token(req_id) + if draft_kv_cache_manager is not None: + draft_kv_cache_manager.impl.add_sequence( + req_id, token_num, beam_width, req) + for _ in range(self.num_extra_kv_tokens): + draft_kv_cache_manager.impl.add_token(req_id) + if is_gen: req.state = LlmRequestState.GENERATION_IN_PROGRESS req.prompt_len = token_num - 1 @@ -651,6 +661,10 @@ class KVCacheManager(BaseResourceManager): for _ in range(max_num_draft_tokens): self.impl.add_token(req_id) + if draft_kv_cache_manager is not None: + for _ in range(max_num_draft_tokens): + draft_kv_cache_manager.impl.add_token(req_id) + # TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here. if use_mrope: dummy_mrope_position_ids = torch.arange( @@ -1886,7 +1900,8 @@ class KVCacheManagerV2(BaseResourceManager): max_num_draft_tokens: int = 0, use_mrope: bool = False, max_beam_width: int = 1, - num_extra_decoding_steps: int = 0): + num_extra_decoding_steps: int = 0, + draft_kv_cache_manager: Optional['BaseResourceManager'] = None): beam_width = max_beam_width requests = [] diff --git a/tensorrt_llm/_torch/speculative/__init__.py b/tensorrt_llm/_torch/speculative/__init__.py index a99a88b514..80307164be 100644 --- a/tensorrt_llm/_torch/speculative/__init__.py +++ b/tensorrt_llm/_torch/speculative/__init__.py @@ -1,14 +1,15 @@ from .auto_heuristic import suggest_spec_config from .eagle3 import Eagle3SpecMetadata -from .interface import SpecMetadata, SpecWorkerBase +from .interface import (SpecMetadata, SpecWorkerBase, + should_use_separate_draft_kv_cache) from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker from .ngram import NGramDrafter, NGramPoolManager from .save_hidden_state import SaveHiddenStatesDrafter from .spec_tree_manager import SpecTreeManager -from .utils import (get_num_extra_kv_tokens, get_num_spec_layers, - get_spec_decoder, get_spec_drafter, get_spec_metadata, - get_spec_resource_manager, get_spec_worker, - update_spec_config_from_model_config) +from .utils import (get_draft_kv_cache_manager, get_num_extra_kv_tokens, + get_num_spec_layers, get_spec_decoder, get_spec_drafter, + get_spec_metadata, get_spec_resource_manager, + get_spec_worker, update_spec_config_from_model_config) __all__ = [ "Eagle3SpecMetadata", @@ -20,6 +21,7 @@ __all__ = [ "SaveHiddenStatesDrafter", "SpecMetadata", "SpecWorkerBase", + "get_draft_kv_cache_manager", "get_num_extra_kv_tokens", "get_num_spec_layers", "get_spec_decoder", @@ -27,6 +29,7 @@ __all__ = [ "get_spec_metadata", "get_spec_resource_manager", "get_spec_worker", + "should_use_separate_draft_kv_cache", "update_spec_config_from_model_config", "suggest_spec_config", "SpecTreeManager", diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py index fab4fb7b65..ebc621e649 100644 --- a/tensorrt_llm/_torch/speculative/eagle3.py +++ b/tensorrt_llm/_torch/speculative/eagle3.py @@ -358,8 +358,11 @@ class Eagle3OneModelSampler(MTPSampler): class Eagle3OneModelWorker(SpecWorkerBase): - def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping): - super().__init__() + def __init__(self, + spec_config: "EagleDecodingConfig", + mapping: Mapping, + use_separate_draft_kv_cache: bool = 
False): + super().__init__(use_separate_draft_kv_cache) self.spec_config = spec_config self.mapping = mapping @@ -369,8 +372,15 @@ class Eagle3OneModelWorker(SpecWorkerBase): # Skip torch.compile for now since current Torch is not compatible with Triton 3.4 # @torch.compile(options={"max-autotune": True}) - def forward(self, input_ids, position_ids, hidden_states, logits, - attn_metadata, spec_metadata, draft_model): + def forward(self, + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + resource_manager=None): batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts num_gens = batch_size - num_contexts @@ -400,81 +410,91 @@ class Eagle3OneModelWorker(SpecWorkerBase): # Predict draft tokens next_draft_tokens = [] original_all_rank_num_tokens = attn_metadata.all_rank_num_tokens - for i in range(self.max_draft_len): - if i == 0: - start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] * - (self.max_draft_len + 1)).long() - gather_ids_gen = (start_ids_gen + - num_accepted_tokens[num_contexts:] - 1 + - attn_metadata.num_ctx_tokens) - gather_ids = torch.concat( - [spec_metadata.gather_ids[:num_contexts], gather_ids_gen], - dim=0) - else: - # All of the seq_len are 1, use batch_indices_cuda as gather_ids - gather_ids = spec_metadata.batch_indices_cuda[:batch_size] - if self.guided_decoder is not None: - new_tokens = inputs["input_ids"][gather_ids] - self.guided_decoder.add_draft_batch(new_tokens, - num_accepted_tokens, - draft_step=i) + # Get the draft KV cache manager if using separate layouts + draft_kv_cache_manager = self.get_draft_kv_cache_manager( + resource_manager) - # Update attn_metadata.all_rank_num_tokens for attention DP - if original_all_rank_num_tokens is not None: + with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): + for i in range(self.max_draft_len): if i == 0: - attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens - elif spec_metadata.all_rank_num_seqs is not None: - attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs + start_ids_gen = ( + spec_metadata.batch_indices_cuda[:num_gens] * + (self.max_draft_len + 1)).long() + gather_ids_gen = (start_ids_gen + + num_accepted_tokens[num_contexts:] - 1 + + attn_metadata.num_ctx_tokens) + gather_ids = torch.concat([ + spec_metadata.gather_ids[:num_contexts], gather_ids_gen + ], + dim=0) + else: + # All of the seq_len are 1, use batch_indices_cuda as gather_ids + gather_ids = spec_metadata.batch_indices_cuda[:batch_size] - hidden_states, hidden_states_to_save = draft_model.model(**inputs) - - # FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step. - # Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens, - # then the following step can use regular decoding mode to generate 1 tokens per step. - # Currently the spec-dec mask for chained tree is not implemented yet. - # When token tree is supported, this can be removed and all steps may use spec-dec mode as well. 
- attn_metadata.use_spec_decoding = False - - logits = draft_model.logits_processor(hidden_states[gather_ids], - draft_model.lm_head, - attn_metadata, True) - if self.guided_decoder is not None: - d2t = getattr(draft_model.model, "d2t", None) - self.guided_decoder.execute_draft_batch(logits, - d2t, + if self.guided_decoder is not None: + new_tokens = inputs["input_ids"][gather_ids] + self.guided_decoder.add_draft_batch(new_tokens, + num_accepted_tokens, draft_step=i) - new_draft_token = self.draft_decoder(logits, draft_model) - next_draft_tokens.append(new_draft_token) - # update inputs - hidden_states = hidden_states_to_save[gather_ids] - position_ids = inputs["position_ids"][gather_ids] + 1 - # update attn_metadata - if i == 0: - attn_metadata._seq_lens[:batch_size].fill_(1) - attn_metadata._seq_lens_cuda[:batch_size].fill_(1) - attn_metadata.on_update() - # cannot run generation if their is no kv cache - if inputs["attn_metadata"].kv_cache_manager is not None: - attn_metadata.host_request_types[:attn_metadata. - num_contexts].fill_(1) - attn_metadata.num_contexts = 0 - # update kv_lens_cuda - if hasattr(attn_metadata, 'kv_lens_cuda'): - attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( - self.max_draft_len - num_accepted_tokens[num_contexts:]) - attn_metadata.kv_lens_cuda[:num_contexts] += 1 - elif hasattr(attn_metadata, 'kv_lens_cuda'): - attn_metadata.kv_lens_cuda[:batch_size] += 1 - # support attention dp - inputs = { - "input_ids": new_draft_token, - "position_ids": position_ids, - "hidden_states": hidden_states, - "attn_metadata": attn_metadata, - "spec_metadata": spec_metadata, - } + # Update attn_metadata.all_rank_num_tokens for attention DP + if original_all_rank_num_tokens is not None: + if i == 0: + attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens + elif spec_metadata.all_rank_num_seqs is not None: + attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs + + hidden_states, hidden_states_to_save = draft_model.model( + **inputs) + + # FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step. + # Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens, + # then the following step can use regular decoding mode to generate 1 tokens per step. + # Currently the spec-dec mask for chained tree is not implemented yet. + # When token tree is supported, this can be removed and all steps may use spec-dec mode as well. + attn_metadata.use_spec_decoding = False + + logits = draft_model.logits_processor(hidden_states[gather_ids], + draft_model.lm_head, + attn_metadata, True) + if self.guided_decoder is not None: + d2t = getattr(draft_model.model, "d2t", None) + self.guided_decoder.execute_draft_batch(logits, + d2t, + draft_step=i) + + new_draft_token = self.draft_decoder(logits, draft_model) + next_draft_tokens.append(new_draft_token) + # update inputs + hidden_states = hidden_states_to_save[gather_ids] + position_ids = inputs["position_ids"][gather_ids] + 1 + # update attn_metadata + if i == 0: + attn_metadata._seq_lens[:batch_size].fill_(1) + attn_metadata._seq_lens_cuda[:batch_size].fill_(1) + attn_metadata.on_update() + # cannot run generation if their is no kv cache + if inputs["attn_metadata"].kv_cache_manager is not None: + attn_metadata.host_request_types[:attn_metadata. 
+ num_contexts].fill_(1) + attn_metadata.num_contexts = 0 + # update kv_lens_cuda + if hasattr(attn_metadata, 'kv_lens_cuda'): + attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= ( + self.max_draft_len - + num_accepted_tokens[num_contexts:]) + attn_metadata.kv_lens_cuda[:num_contexts] += 1 + elif hasattr(attn_metadata, 'kv_lens_cuda'): + attn_metadata.kv_lens_cuda[:batch_size] += 1 + # support attention dp + inputs = { + "input_ids": new_draft_token, + "position_ids": position_ids, + "hidden_states": hidden_states, + "attn_metadata": attn_metadata, + "spec_metadata": spec_metadata, + } next_draft_tokens = torch.stack(next_draft_tokens, dim=1) # restore attn_metadata to support cuda graph diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py index eeee063a7f..9ace3c3e98 100644 --- a/tensorrt_llm/_torch/speculative/interface.py +++ b/tensorrt_llm/_torch/speculative/interface.py @@ -1,6 +1,7 @@ import copy import os from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass, field from enum import IntEnum, auto from typing import TYPE_CHECKING, List, Optional, Type @@ -11,10 +12,12 @@ from torch import nn from tensorrt_llm.logger import logger from ..._utils import get_sm_version -from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention +from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention, + TrtllmAttentionMetadata) from ..cute_dsl_kernels.argmax import argmax as cute_argmax from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE -from ..pyexecutor.resource_manager import BaseResourceManager +from ..pyexecutor.resource_manager import (BaseResourceManager, + ResourceManagerType) if TYPE_CHECKING: from ..pyexecutor.guided_decoder import CapturableGuidedDecoder @@ -26,6 +29,17 @@ if IS_FLASHINFER_AVAILABLE: FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS" +def should_use_separate_draft_kv_cache(spec_config) -> bool: + """ + Check if separate draft KV cache should be used for one-engine speculative decoding. + """ + if spec_config is None: + return False + if not spec_config.spec_dec_mode.use_one_engine(): + return False + return spec_config._allow_separate_draft_kv_cache + + def get_force_num_accepted_tokens() -> int: """ Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable. @@ -373,13 +387,14 @@ class SpecWorkerBase(nn.Module, ABC): Provides common functionality for sampling and token handling. """ - def __init__(self): + def __init__(self, use_separate_draft_kv_cache: bool = False): super().__init__() self.guided_decoder: Optional["CapturableGuidedDecoder"] = None self.force_num_accepted_tokens = get_force_num_accepted_tokens() self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0" self.seed = 0 self.offset = 0 + self.use_separate_draft_kv_cache = use_separate_draft_kv_cache @property @abstractmethod @@ -600,6 +615,52 @@ class SpecWorkerBase(nn.Module, ABC): else: return torch.empty(0, dtype=torch.int32, device="cuda") + def get_draft_kv_cache_manager(self, resource_manager): + """ + Get the draft KV cache manager if using separate KV cache layouts. 
+ """ + if self.use_separate_draft_kv_cache and resource_manager is not None: + return resource_manager.get_resource_manager( + ResourceManagerType.DRAFT_KV_CACHE_MANAGER) + return None + + @contextmanager + def draft_kv_cache_context(self, attn_metadata, draft_kv_cache_manager): + """ + Context manager to temporarily switch to draft KV cache manager in one-engine speculative decoding. + + This swaps both the kv_cache_manager reference AND the block offset tensors, + since the target and draft KV caches have different block layouts. + """ + + # draft_kv_cache_manager is None if using two-engine speculative decoding or not enabling separate draft KV cache. + if draft_kv_cache_manager is None: + yield + return + + # Only TrtllmAttentionMetadata supports separate draft KV cache layouts + if not isinstance(attn_metadata, TrtllmAttentionMetadata): + yield + return + + # Save main KV cache manager and block offsets + target_kv_cache_manager = attn_metadata.kv_cache_manager + target_kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets + target_host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets + + # Switch to draft KV cache manager and its block offsets + attn_metadata.kv_cache_manager = draft_kv_cache_manager + attn_metadata.kv_cache_block_offsets = attn_metadata.draft_kv_cache_block_offsets + attn_metadata.host_kv_cache_block_offsets = attn_metadata.draft_host_kv_cache_block_offsets + + try: + yield + finally: + # Restore main KV cache manager and block offsets + attn_metadata.kv_cache_manager = target_kv_cache_manager + attn_metadata.kv_cache_block_offsets = target_kv_cache_block_offsets + attn_metadata.host_kv_cache_block_offsets = target_host_kv_cache_block_offsets + def _sample_tokens_for_batch( self, logits: torch.Tensor, diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py index a49c7f9b5e..9eb15e5780 100644 --- a/tensorrt_llm/_torch/speculative/mtp.py +++ b/tensorrt_llm/_torch/speculative/mtp.py @@ -367,8 +367,11 @@ class MTPSampler(Sampler[SampleStateMTP]): class MTPWorker(SpecWorkerBase): - def __init__(self, spec_config: "MTPDecodingConfig", model_config=None): - super().__init__() + def __init__(self, + spec_config: "MTPDecodingConfig", + model_config=None, + use_separate_draft_kv_cache: bool = False): + super().__init__(use_separate_draft_kv_cache) self.spec_config = spec_config self.model_config = model_config self.is_thop = False @@ -386,6 +389,7 @@ class MTPWorker(SpecWorkerBase): attn_metadata, spec_metadata, draft_model, + resource_manager=None, ): ''' Example: @@ -518,43 +522,50 @@ class MTPWorker(SpecWorkerBase): # update attn metadata if attn_metadata is not None: self.change_attn_metadata(num_accepted_tokens, attn_metadata) - draft_inputs.update(attn_metadata=attn_metadata) # Run MTP layers to predict draft tokens next_draft_tokens = [] last_tokens_idx = torch.cumsum( attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 - for i, mtp_layer in enumerate(draft_model.mtp_layers): - if self.guided_decoder is not None: - new_tokens = draft_inputs['input_ids'][last_tokens_idx] - self.guided_decoder.add_draft_batch(new_tokens, - num_accepted_tokens, - draft_step=i) - hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens, - **draft_inputs) - logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head, - attn_metadata).float() - if self.guided_decoder is not None: - self.guided_decoder.execute_draft_batch(logits, draft_step=i) + draft_kv_cache_manager = self.get_draft_kv_cache_manager( + 
resource_manager) - new_draft_token = self.draft_sampler(logits) - next_draft_tokens.append(new_draft_token) - # shift input_ids and hidden_states - input_ids = draft_inputs["input_ids"] - input_ids[:-1] = input_ids[1:].clone() - input_ids[last_tokens_idx] = new_draft_token - draft_hidden_states = draft_inputs["hidden_states"] - draft_hidden_states[:-1] = draft_hidden_states[1:].clone() - draft_hidden_states[last_tokens_idx] = hidden_states[ - last_tokens_idx, :] - draft_inputs = { - "input_ids": input_ids, - "position_ids": draft_inputs["position_ids"], - "hidden_states": draft_hidden_states, - "attn_metadata": draft_inputs["attn_metadata"], - } - next_draft_tokens = torch.stack(next_draft_tokens, dim=1) + with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager): + for i, mtp_layer in enumerate(draft_model.mtp_layers): + if self.guided_decoder is not None: + new_tokens = draft_inputs['input_ids'][last_tokens_idx] + self.guided_decoder.add_draft_batch(new_tokens, + num_accepted_tokens, + draft_step=i) + + hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens, + **draft_inputs) + + logits = mtp_layer.shared_head(hidden_states, + draft_model.lm_head, + attn_metadata).float() + if self.guided_decoder is not None: + self.guided_decoder.execute_draft_batch(logits, + draft_step=i) + + new_draft_token = self.draft_sampler(logits) + next_draft_tokens.append(new_draft_token) + # shift input_ids and hidden_states + input_ids = draft_inputs["input_ids"] + input_ids[:-1] = input_ids[1:].clone() + input_ids[last_tokens_idx] = new_draft_token + draft_hidden_states = draft_inputs["hidden_states"] + draft_hidden_states[:-1] = draft_hidden_states[1:].clone() + draft_hidden_states[last_tokens_idx] = hidden_states[ + last_tokens_idx, :] + draft_inputs = { + "input_ids": input_ids, + "position_ids": draft_inputs["position_ids"], + "hidden_states": draft_hidden_states, + "attn_metadata": draft_inputs["attn_metadata"], + } + next_draft_tokens = torch.stack(next_draft_tokens, dim=1) # restore attn metadata if attn_metadata is not None: @@ -573,6 +584,39 @@ class MTPWorker(SpecWorkerBase): 'next_new_tokens': next_new_tokens } + def skip_forward( + self, + input_ids, + position_ids, + hidden_states, + logits, + attn_metadata, + spec_metadata, + draft_model, + resource_manager=None, + ): + batch_size = attn_metadata.num_seqs + mtp_num_modules = self.spec_config.num_nextn_predict_layers + accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + dtype=torch.int, + device=logits.device) + num_accepted_tokens = torch.ones(batch_size, + dtype=torch.int, + device=logits.device) + next_draft_tokens = torch.empty((batch_size, mtp_num_modules), + dtype=torch.int, + device=logits.device) + next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)), + dtype=torch.int, + device=logits.device) + return { + 'logits': logits, + 'new_tokens': accepted_tokens, + 'new_tokens_lens': num_accepted_tokens, + 'next_draft_tokens': next_draft_tokens, + 'next_new_tokens': next_new_tokens + } + def update_mtp_hidden_states( self, input_ids: torch.IntTensor, @@ -1146,8 +1190,9 @@ class MTPEagleWorker(MTPWorker): def __init__(self, spec_config: "MTPDecodingConfig", - model_config: Optional[ModelConfig] = None): - super().__init__(spec_config, model_config) + model_config: Optional[ModelConfig] = None, + use_separate_draft_kv_cache: bool = False): + super().__init__(spec_config, model_config, use_separate_draft_kv_cache) self.model_config = model_config self.mtp_num_modules = 
         self._is_mamba_hybrid_cache = None
 
@@ -1170,6 +1215,7 @@ class MTPEagleWorker(MTPWorker):
         attn_metadata,
         spec_metadata,
         draft_model,
+        resource_manager=None,
     ):
         batch_size = attn_metadata.num_seqs
         num_contexts = attn_metadata.num_contexts
@@ -1212,123 +1258,134 @@ class MTPEagleWorker(MTPWorker):
             attn_metadata=attn_metadata,
             spec_metadata=spec_metadata)
 
+        # Get the draft KV cache manager if using separate layouts
+        draft_kv_cache_manager = self.get_draft_kv_cache_manager(
+            resource_manager)
+
         # Predict draft tokens
         next_draft_tokens = []
-        for i in range(self.mtp_num_modules):
-            if i == 0:
-                hidden_states = draft_model.mtp_layers[0](
-                    embed_tokens=draft_model.embed_tokens,
-                    all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
-                    **inputs)
-                start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
-                                 (self.mtp_num_modules + 1)).long()
-                gather_ids_gen = (start_ids_gen +
-                                  num_accepted_tokens[num_contexts:] - 1 +
-                                  attn_metadata.num_ctx_tokens)
-                gather_ids = torch.concat(
-                    [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
-            else:
-                hidden_states = draft_model.mtp_layers[0](
-                    embed_tokens=draft_model.embed_tokens,
-                    all_rank_num_tokens=spec_metadata.
-                    subseq_all_rank_num_tokens,
-                    **inputs)
-                # All of the seq_len are 1, use batch_indices_cuda as gather_ids
-                gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
+        with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
+            for i in range(self.mtp_num_modules):
+                if i == 0:
+                    hidden_states = draft_model.mtp_layers[0](
+                        embed_tokens=draft_model.embed_tokens,
+                        all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
+                        **inputs)
+                    start_ids_gen = (
+                        spec_metadata.batch_indices_cuda[:num_gens] *
+                        (self.mtp_num_modules + 1)).long()
+                    gather_ids_gen = (start_ids_gen +
+                                      num_accepted_tokens[num_contexts:] - 1 +
+                                      attn_metadata.num_ctx_tokens)
+                    gather_ids = torch.concat(
+                        [last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
+                else:
+                    hidden_states = draft_model.mtp_layers[0](
+                        embed_tokens=draft_model.embed_tokens,
+                        all_rank_num_tokens=spec_metadata.
+                        subseq_all_rank_num_tokens,
+                        **inputs)
+                    # All of the seq_len are 1, use batch_indices_cuda as gather_ids
+                    gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
 
-            if self.guided_decoder is not None:
-                new_tokens = inputs["input_ids"][gather_ids]
-                self.guided_decoder.add_draft_batch(new_tokens,
-                                                    num_accepted_tokens,
-                                                    draft_step=i)
-            if self.model_config.mapping.enable_attention_dp and \
-                getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
-                hidden_states_gathered = hidden_states[gather_ids]
-                token_count = hidden_states_gathered.view(
-                    -1, hidden_states_gathered.shape[-1]).shape[0]
-                max_num_requests = spec_metadata.max_num_requests
-                pad_len = max_num_requests - token_count
-                if pad_len > 0:
-                    padded_hidden_states = F.pad(hidden_states_gathered.view(
-                        -1, hidden_states_gathered.shape[-1]),
-                                                 (0, 0, 0, pad_len),
-                                                 mode="constant",
-                                                 value=0)
-                elif pad_len == 0:
-                    padded_hidden_states = hidden_states_gathered.view(
-                        -1, hidden_states_gathered.shape[-1])
-                else:
-                    raise ValueError(
-                        f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
-                    )
-                logits = draft_model.mtp_layers[0].shared_head(
-                    padded_hidden_states, draft_model.lm_head, attn_metadata,
-                    True)
-            else:
-                logits = draft_model.mtp_layers[0].shared_head(
-                    hidden_states[gather_ids], draft_model.lm_head,
-                    attn_metadata, True)
-            if self.guided_decoder is not None:
-                self.guided_decoder.execute_draft_batch(logits, draft_step=i)
+                if self.guided_decoder is not None:
+                    new_tokens = inputs["input_ids"][gather_ids]
+                    self.guided_decoder.add_draft_batch(new_tokens,
+                                                        num_accepted_tokens,
+                                                        draft_step=i)
+                if self.model_config.mapping.enable_attention_dp and \
+                    getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
+                    hidden_states_gathered = hidden_states[gather_ids]
+                    token_count = hidden_states_gathered.view(
+                        -1, hidden_states_gathered.shape[-1]).shape[0]
+                    max_num_requests = spec_metadata.max_num_requests
+                    pad_len = max_num_requests - token_count
+                    if pad_len > 0:
+                        padded_hidden_states = F.pad(
+                            hidden_states_gathered.view(
+                                -1, hidden_states_gathered.shape[-1]),
+                            (0, 0, 0, pad_len),
+                            mode="constant",
+                            value=0)
+                    elif pad_len == 0:
+                        padded_hidden_states = hidden_states_gathered.view(
+                            -1, hidden_states_gathered.shape[-1])
+                    else:
+                        raise ValueError(
+                            f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
+                        )
+                    logits = draft_model.mtp_layers[0].shared_head(
+                        padded_hidden_states, draft_model.lm_head,
+                        attn_metadata, True)
+                else:
+                    logits = draft_model.mtp_layers[0].shared_head(
+                        hidden_states[gather_ids], draft_model.lm_head,
+                        attn_metadata, True)
+                if self.guided_decoder is not None:
+                    self.guided_decoder.execute_draft_batch(logits,
+                                                            draft_step=i)
 
-            if self.model_config.mapping.enable_attention_dp and \
-                getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
-                mapping_lm_head_tp = draft_model.mtp_layers[
-                    0].shared_head.mapping_lm_head_tp
-                new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
-                new_draft_token = new_draft_token[:token_count]
-            else:
-                new_draft_token = self.draft_sampler(logits)
+                if self.model_config.mapping.enable_attention_dp and \
+                    getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
+                    mapping_lm_head_tp = draft_model.mtp_layers[
+                        0].shared_head.mapping_lm_head_tp
+                    new_draft_token = self.draft_sampler(
+                        logits, mapping_lm_head_tp)
+                    new_draft_token = new_draft_token[:token_count]
+                else:
+                    new_draft_token = self.draft_sampler(logits)
 
-            hidden_states, position_ids = self.update_draft_tokens(
-                next_draft_tokens, new_draft_token, hidden_states, gather_ids,
-                inputs)
-            # update attn_metadata
-            if i == 0:
-                attn_metadata._seq_lens[:batch_size].fill_(1)
-                attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
-                attn_metadata.on_update()
-                # cannot run generation if their is no kv cache
-                has_kv_cache = inputs[
-                    "attn_metadata"].kv_cache_manager is not None
-                if has_kv_cache:
-                    attn_metadata.host_request_types[:attn_metadata.
-                                                     num_contexts].fill_(1)
-                    attn_metadata.num_contexts = 0
-                    # update kv_lens_cuda
-                    if hasattr(attn_metadata, 'kv_lens_cuda'):
-                        attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
-                            self.mtp_num_modules -
-                            num_accepted_tokens[num_contexts:])
-                        attn_metadata.kv_lens_cuda[:num_contexts] += 1
-                # update metadata for flash mla
-                if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
-                    reorder_block_ids_per_seq = torch.cat([
-                        attn_metadata.
-                        kv_block_ids_per_seq[num_contexts:batch_size],
-                        attn_metadata.kv_block_ids_per_seq[:num_contexts]
-                    ])
-                    attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
-                        reorder_block_ids_per_seq, non_blocking=True)
-                # update metadata
-                # some attention metadata needs to be updated when changing seq_lens/kv_lens
-                attn_metadata.update_for_spec_dec()
-            elif hasattr(attn_metadata, 'kv_lens_cuda'):
-
-                @torch.compile(options={"max-autotune": True})
-                def update_kv_lens(kv_lens_cuda, batch_size):
-                    kv_lens_cuda[:batch_size] += 1
-
-                update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
-                # update metadata
-                # some attention metadata needs to be updated when changing kv_lens
-                attn_metadata.update_for_spec_dec()
-            inputs = {
-                "input_ids": new_draft_token,
-                "position_ids": position_ids,
-                "hidden_states": hidden_states,
-                "attn_metadata": attn_metadata,
-            }
+                hidden_states, position_ids = self.update_draft_tokens(
+                    next_draft_tokens, new_draft_token, hidden_states,
+                    gather_ids, inputs)
+                # update attn_metadata
+                if i == 0:
+                    attn_metadata._seq_lens[:batch_size].fill_(1)
+                    attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
+                    attn_metadata.on_update()
+                    # cannot run generation if there is no kv cache
+                    has_kv_cache = inputs[
+                        "attn_metadata"].kv_cache_manager is not None
+                    if has_kv_cache:
+                        attn_metadata.host_request_types[:attn_metadata.
+                                                         num_contexts].fill_(1)
+                        attn_metadata.num_contexts = 0
+                        # update kv_lens_cuda
+                        if hasattr(attn_metadata, 'kv_lens_cuda'):
+                            attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
+                                self.mtp_num_modules -
+                                num_accepted_tokens[num_contexts:])
+                            attn_metadata.kv_lens_cuda[:num_contexts] += 1
+                    # update metadata for flash mla
+                    if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
+                        reorder_block_ids_per_seq = torch.cat([
+                            attn_metadata.
+                            kv_block_ids_per_seq[num_contexts:batch_size],
+                            attn_metadata.kv_block_ids_per_seq[:num_contexts]
+                        ])
+                        attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
+                            reorder_block_ids_per_seq, non_blocking=True)
+                    # update metadata
+                    # some attention metadata needs to be updated when changing seq_lens/kv_lens
+                    attn_metadata.update_for_spec_dec()
+                elif hasattr(attn_metadata, 'kv_lens_cuda'):
+
+                    @torch.compile(options={"max-autotune": True})
+                    def update_kv_lens(kv_lens_cuda, batch_size):
+                        kv_lens_cuda[:batch_size] += 1
+
+                    update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
+                    # update metadata
+                    # some attention metadata needs to be updated when changing kv_lens
+                    attn_metadata.update_for_spec_dec()
+                inputs = {
+                    "input_ids": new_draft_token,
+                    "position_ids": position_ids,
+                    "hidden_states": hidden_states,
+                    "attn_metadata": attn_metadata,
+                }
 
         # restore attn_metadata to support cuda graph
         self._restore_attn_metadata_from_spec_dec(attn_metadata)
diff --git a/tensorrt_llm/_torch/speculative/utils.py b/tensorrt_llm/_torch/speculative/utils.py
index 1eb036de99..cd35fa8609 100644
--- a/tensorrt_llm/_torch/speculative/utils.py
+++ b/tensorrt_llm/_torch/speculative/utils.py
@@ -220,14 +220,19 @@ def get_num_spec_layers(spec_config):
     return 0
 
 
-def get_spec_worker(spec_config, model_config, mapping):
+def get_spec_worker(spec_config,
+                    model_config,
+                    mapping,
+                    use_separate_draft_kv_cache: bool = False):
     spec_dec_mode = spec_config.spec_dec_mode
     if spec_dec_mode.is_mtp_vanilla():
-        return MTPWorker(spec_config, model_config)
+        return MTPWorker(spec_config, model_config, use_separate_draft_kv_cache)
     if spec_dec_mode.is_mtp_eagle_one_model():
-        return MTPEagleWorker(spec_config, model_config)
+        return MTPEagleWorker(spec_config, model_config,
+                              use_separate_draft_kv_cache)
     if spec_dec_mode.is_eagle3_one_model():
-        return Eagle3OneModelWorker(spec_config, mapping)
+        return Eagle3OneModelWorker(spec_config, mapping,
+                                    use_separate_draft_kv_cache)
     return None
 
 
@@ -243,6 +248,21 @@ def get_num_extra_kv_tokens(spec_config):
     return 0
 
 
+def get_draft_kv_cache_manager(spec_config, resource_manager):
+    """
+    Returns the draft KV cache manager only in one-model speculative decoding
+    mode where the target model manages a separate draft KV cache.
+    """
+    from ..pyexecutor.resource_manager import ResourceManagerType
+
+    if spec_config is None:
+        return None
+    if not spec_config.spec_dec_mode.use_one_engine():
+        return None
+    return resource_manager.get_resource_manager(
+        ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
+
+
 def update_spec_config_from_model_config(spec_config, model_config):
     if spec_config.spec_dec_mode.is_mtp_one_model():
         # Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 39f1cdc80a..7fa2d335dc 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -720,6 +720,8 @@ class DecodingBaseConfig(StrictBaseModel):
     _allow_greedy_draft_tokens: bool = PrivateAttr(True)
     # Internal: record decoding_type alias used during parsing (for warnings).
     _decoding_type_alias: Optional[str] = PrivateAttr(default=None)
+    # If set, drafting will use separate KV cache in one-model speculative decoding.
+    _allow_separate_draft_kv_cache: bool = PrivateAttr(True)
 
     @field_validator('draft_len_schedule')
     @classmethod
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index c8ede3ed8c..c8d8880c31 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -155,8 +155,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
                       disable_overlap_scheduler: bool, enable_block_reuse: bool,
                       use_one_model: bool, enable_chunked_prefill: bool,
                       use_chain_drafter: bool, multi_batch: bool,
-                      attention_dp: bool, use_hf_speculative_model: bool,
-                      request):
+                      attention_dp: bool, use_hf_speculative_model: bool):
     # Eagle3 one model works with overlap scheduler and block reuse.
     total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
     if total_mem_gb < 35: