[TRTLLM-10321][feat] Support different KV cache layout for one-model spec dec (#10502)

Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
Author: Ziyi Xiong, 2026-02-10 05:16:02 +08:00 (committed by GitHub)
Parent: 092f4ce774
Commit: e76b634251
18 changed files with 702 additions and 287 deletions

View File

@ -64,6 +64,8 @@ class AttentionMetadata:
max_num_sequences: Optional[int] = None
# The KV cache manager.
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2]
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
draft_kv_cache_manager: Optional[Union[KVCacheManager, KVCacheManagerV2]] = None
mapping: Optional[Mapping] = None
enable_flash_mla: bool = False
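For orientation, here is a minimal, self-contained sketch (class and helper names are invented, not part of the change) of the pattern this new field enables: the metadata carries an optional draft-side manager and falls back to the target manager when none is set.

from dataclasses import dataclass
from typing import Optional


@dataclass
class _CacheHandles:
    # Stand-ins for the real manager types; any object works for this sketch.
    kv_cache_manager: object
    # Optional draft-side manager, only set for one-model speculative decoding
    # when the draft layers use a different KV cache layout.
    draft_kv_cache_manager: Optional[object] = None

    def manager_for(self, is_draft: bool):
        # Fall back to the target manager when no separate draft cache exists.
        if is_draft and self.draft_kv_cache_manager is not None:
            return self.draft_kv_cache_manager
        return self.kv_cache_manager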

View File

@ -974,6 +974,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager=None,
):
requests = super().add_dummy_requests(
request_ids=request_ids,
@ -984,6 +985,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager,
)
if prepare_resource:
for req in requests:

View File

@ -679,6 +679,12 @@ class TrtllmAttentionMetadata(AttentionMetadata):
helix_is_inactive_rank: Optional[torch.Tensor] = None
helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None
# Block offsets for the target and draft KV caches
kv_cache_block_offsets: Optional[torch.Tensor] = None
host_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_host_kv_cache_block_offsets: Optional[torch.Tensor] = None
@property
def max_seq_len(self) -> int:
"""
@ -786,6 +792,27 @@ class TrtllmAttentionMetadata(AttentionMetadata):
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
# Allocate separate block offset tensors for draft KV cache manager
# Used in one-model speculative decoding with different KV cache layouts
if self.draft_kv_cache_manager is not None:
self.draft_kv_cache_block_offsets = self.get_empty(
buffers,
[
self.draft_kv_cache_manager.num_pools,
self.max_num_sequences, 2,
self.draft_kv_cache_manager.max_blocks_per_seq
],
cache_name="draft_kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.draft_host_kv_cache_block_offsets = torch.empty_like(
self.draft_kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)
if self.enable_flash_mla:
self.block_ids_per_seq = self.get_empty(
buffers,
@ -987,6 +1014,25 @@ class TrtllmAttentionMetadata(AttentionMetadata):
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, error_message
# Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
if self.draft_kv_cache_manager is not None:
# Copy blocks for all context requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(
self.draft_host_kv_cache_block_offsets.shape[0]):
self.draft_kv_cache_block_offsets[
pool_idx, :self.num_seqs].copy_(
self.draft_host_kv_cache_block_offsets[
pool_idx, :self.num_seqs],
non_blocking=True)
self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,
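As a rough sketch of the tensors allocated above (sizes are made up and a CUDA device is assumed), the draft block-offset table is indexed as [pool, sequence, K/V, block] and is staged through a pinned host buffer that is copied to the device per pool.

import torch

# Illustrative sizes standing in for the draft KV cache manager's properties.
num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 16
num_seqs = 4  # sequences in the current batch

# Device-side table, mirroring draft_kv_cache_block_offsets.
draft_offsets = torch.zeros(
    (num_pools, max_num_sequences, 2, max_blocks_per_seq),
    dtype=torch.int32, device="cuda")

# Pinned host staging buffer with the same shape, mirroring
# draft_host_kv_cache_block_offsets.
host_offsets = torch.empty(
    (num_pools, max_num_sequences, 2, max_blocks_per_seq),
    dtype=torch.int32, device="cpu", pin_memory=True)

# Once the draft manager has written block offsets into host_offsets,
# only the rows for active sequences are copied, asynchronously per pool.
for pool_idx in range(host_offsets.shape[0]):
    draft_offsets[pool_idx, :num_seqs].copy_(
        host_offsets[pool_idx, :num_seqs], non_blocking=True)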

View File

@ -1840,16 +1840,18 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
return super().forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits,
spec_metadata=spec_metadata,
resource_manager=resource_manager,
**kwargs)
def load_weights(self, weights: ConsumableWeightsDict):

View File

@ -1401,6 +1401,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
multimodal_params = kwargs.get("multimodal_params", [])
@ -1422,7 +1423,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
position_ids,
inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits)
return_context_logits=return_context_logits,
resource_manager=resource_manager)
def infer_max_seq_len(self):
if self.model_config.attn_backend.upper() != 'TRTLLM':

View File

@ -19,7 +19,8 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..speculative import (SpecMetadata, get_spec_worker,
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@ -931,6 +932,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
self.draft_config = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
if spec_config.spec_dec_mode.is_eagle3_one_model():
@ -964,11 +966,16 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo
self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)
self.draft_model = get_draft_model(model_config, self.draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
self.spec_worker = get_spec_worker(
model_config.spec_config,
model_config,
model_config.mapping,
use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)
self.epilogue.append(self.draft_model)
self.epilogue.append(self.spec_worker)
@ -987,6 +994,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
@ -1030,7 +1038,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
logits=logits,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=self.draft_model)
draft_model=self.draft_model,
resource_manager=resource_manager)
else:
logits = self.logits_processor.forward(
hidden_states,

View File

@ -25,7 +25,8 @@ from tensorrt_llm.mapping import CpType, Mapping
from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from ..speculative import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, should_use_separate_draft_kv_cache)
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@ -80,6 +81,7 @@ class KvCacheCreator:
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
execution_stream: Optional[torch.cuda.Stream] = None,
draft_config: Optional[ModelConfig] = None,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@ -103,6 +105,7 @@ class KvCacheCreator:
self._execution_stream = execution_stream
if self._kv_cache_manager_cls == KVCacheManager and kv_cache_config.use_kv_cache_manager_v2:
self._kv_cache_manager_cls = KVCacheManagerV2
self._draft_config = draft_config
def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
@ -115,6 +118,12 @@ class KvCacheCreator:
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
self._draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token
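A toy version of the accounting above (the helper and byte counts are invented for illustration): when the draft layers live in their own cache with a different layout, the per-token KV footprint used for memory estimation is the sum of the two layouts.

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int) -> int:
    # K and V caches for every attention layer.
    return 2 * num_layers * num_kv_heads * head_dim * bytes_per_elem


# Hypothetical target model: 32 layers, 8 KV heads, head_dim 128, fp16 cache.
target_bytes = kv_bytes_per_token(32, 8, 128, 2)
# Hypothetical draft layer with a smaller layout.
draft_bytes = kv_bytes_per_token(1, 8, 64, 2)

# Mirrors the intent of _get_kv_size_per_token: the estimate covers both caches.
kv_size_per_token = target_bytes + draft_bytes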
def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
@ -397,9 +406,12 @@ class KvCacheCreator:
# get kv cache stats for both model and draft model
kv_stats = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
kv_stats_draft = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
) if self._draft_model_engine is not None else None
# Get draft KV cache stats if present (either from two-model mode or one-model
# mode with separate draft KV cache)
draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
) if draft_kv_cache_manager is not None else None
# get total allocated bytes
allocated_bytes = kv_stats.allocated_bytes + (
@ -466,6 +478,15 @@ class KvCacheCreator:
mapping = self._mapping
assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."
# When using separate draft KV cache in one-model speculative decoding,
# use layer_mask to include only target layers. The draft layers should
# only be in the separate draft KV cache manager.
# We still pass spec_config so that num_extra_kv_tokens is calculated.
spec_dec_layer_mask = None
if self._should_create_separate_draft_kv_cache():
num_target_layers = model_engine.model.model_config.pretrained_config.num_hidden_layers
spec_dec_layer_mask = [True] * num_target_layers
kv_cache_manager = _create_kv_cache_manager(
model_engine=model_engine,
kv_cache_manager_cls=self._kv_cache_manager_cls,
@ -481,6 +502,7 @@ class KvCacheCreator:
kv_connector_manager=self._kv_connector_manager,
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
layer_mask=spec_dec_layer_mask,
)
# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
@ -503,6 +525,64 @@ class KvCacheCreator:
return kv_cache_manager
def _should_create_separate_draft_kv_cache(self) -> bool:
"""
Check if we need a separate draft KV cache manager for one-model mode.
Returns True when a draft config is available and the speculative config
allows a separate draft KV cache (see should_use_separate_draft_kv_cache).
"""
if self._draft_config is None:
return False
return should_use_separate_draft_kv_cache(self._speculative_config)
def _create_one_model_draft_kv_cache_manager(
self,
estimating_kv_cache: bool = False) -> Optional[KVCacheManager]:
"""
Create a KV cache manager for draft model layers in one-model mode
when target and draft have different KV cache layouts.
"""
# Get target model's num_hidden_layers to compute correct layer indices.
# Draft model layers in one-model mode start at target_num_layers.
target_pretrained_config = self._model_engine.model.model_config.pretrained_config
target_num_layers = target_pretrained_config.num_hidden_layers
# Use get_num_spec_layers to get the correct number of draft layers
# for the speculative decoding mode (e.g., num_eagle_layers for Eagle3)
num_draft_layers = get_num_spec_layers(self._speculative_config)
# Create layer_mask: False for target layers, True for draft layers.
# This ensures the draft KV cache manager uses the correct layer indices
# (e.g., layers 32, 33, ... instead of 0, 1, ...).
spec_dec_layer_mask = [False
] * target_num_layers + [True] * num_draft_layers
# Get the appropriate KV cache manager class for the draft model
draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
self._draft_config)
return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
kv_cache_config=self._kv_cache_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
max_batch_size=self._max_batch_size,
spec_config=self._speculative_config,
sparse_attn_config=None, # Not applicable for draft in one-model mode
max_num_tokens=self._max_num_tokens,
max_beam_width=self._max_beam_width,
kv_connector_manager=self._kv_connector_manager,
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
# One-model draft specific overrides
model_config=self._draft_config,
dtype=self._draft_config.pretrained_config.torch_dtype,
is_draft=True,
layer_mask=spec_dec_layer_mask,
num_layers=num_draft_layers,
)
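To make the layer_mask comments above concrete, a small stand-alone sketch with invented layer counts: the target manager masks in only the target layers, while the draft manager masks in only the trailing draft layers so their global indices are preserved.

# Hypothetical sizes: a 32-layer target model with one Eagle3 draft layer.
target_num_layers = 32
num_draft_layers = 1

# Mask for the target KV cache manager: target layers only.
target_layer_mask = [True] * target_num_layers

# Mask for the draft KV cache manager: draft layers only, appended after the
# target layers so they keep global indices 32, 33, ...
draft_layer_mask = [False] * target_num_layers + [True] * num_draft_layers

draft_layer_indices = [i for i, keep in enumerate(draft_layer_mask) if keep]
assert draft_layer_indices == [32]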
def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:
@ -514,9 +594,16 @@ class KvCacheCreator:
raise NotImplementedError(
"Connector manager is not supported for draft model.")
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache
) if self._draft_model_engine is not None else None
draft_kv_cache_manager = None
# Two-model speculative decoding: draft model has separate engine
if self._draft_model_engine is not None:
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache)
# One-model speculative decoding with different KV layouts
elif self._should_create_separate_draft_kv_cache():
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(
estimating_kv_cache)
resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
@ -534,7 +621,7 @@ class KvCacheCreator:
def _create_kv_cache_manager(
model_engine: PyTorchModelEngine,
model_engine: Optional[PyTorchModelEngine],
kv_cache_manager_cls,
mapping: Mapping,
kv_cache_config: KvCacheConfig,
@ -547,13 +634,32 @@ def _create_kv_cache_manager(
max_beam_width: int,
kv_connector_manager: Optional[KvCacheConnectorManager],
estimating_kv_cache: bool,
execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
execution_stream: Optional[torch.cuda.Stream] = None,
# Optional overrides for one-model draft case (when model_engine is None)
model_config: Optional[ModelConfig] = None,
dtype: Optional[torch.dtype] = None,
is_draft: Optional[bool] = None,
layer_mask: Optional[List[bool]] = None,
num_layers: Optional[int] = None) -> KVCacheManager:
"""
Returns:
A KVCacheManager instance for the given model_engine
A KVCacheManager instance for the given model engine or model config
"""
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
# Extract config from model_engine or use provided model_config
if model_config is not None:
config = model_config.pretrained_config
quant_config = model_config.quant_config
_model_config = model_config
else:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
_model_config = model_engine.model.model_config
if dtype is None:
dtype = model_engine.dtype
if is_draft is None:
is_draft = model_engine.is_draft_model
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
@ -569,10 +675,10 @@ def _create_kv_cache_manager(
):
kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))
num_hidden_layers = config.num_hidden_layers
# Use provided num_layers if available, otherwise use config
num_hidden_layers = num_layers if num_layers is not None else config.num_hidden_layers
if is_mla(config):
kv_cache_manager = kv_cache_manager_cls(
@ -589,12 +695,13 @@ def _create_kv_cache_manager(
spec_config=spec_config,
vocab_size=config.vocab_size,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
elif is_nemotron_hybrid(config):
if max_beam_width > 1:
@ -606,9 +713,10 @@ def _create_kv_cache_manager(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [char == "*" for char in config.hybrid_override_pattern]
hybrid_layer_mask = [
char == "*" for char in config.hybrid_override_pattern
]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
@ -623,12 +731,13 @@ def _create_kv_cache_manager(
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@ -649,13 +758,12 @@ def _create_kv_cache_manager(
raise NotImplementedError(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
mamba_layer_mask = [
True if i %
config.full_attention_interval != config.full_attention_interval -
1 else False for i in range(num_hidden_layers)
]
layer_mask = [
hybrid_layer_mask = [
False if i %
config.full_attention_interval != config.full_attention_interval -
1 else True for i in range(num_hidden_layers)
@ -673,12 +781,13 @@ def _create_kv_cache_manager(
num_mamba_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@ -694,7 +803,7 @@ def _create_kv_cache_manager(
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCacheManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
binding_model_config = _model_config.get_bindings_model_config(
tokens_per_block=tokens_per_block) if is_vswa else None
kv_cache_manager = kv_cache_manager_cls(
@ -713,12 +822,13 @@ def _create_kv_cache_manager(
max_num_tokens=max_num_tokens,
model_config=binding_model_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
return kv_cache_manager
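The branch added to build_managers above amounts to the following decision, sketched here with placeholder callables rather than the real factory methods.

from typing import Callable, Optional


def pick_draft_kv_cache_manager(
        draft_model_engine: Optional[object],
        use_separate_one_model_cache: bool,
        create_for_engine: Callable[[object], object],
        create_for_one_model: Callable[[], object]) -> Optional[object]:
    if draft_model_engine is not None:
        # Two-model speculative decoding: the draft engine gets its own cache.
        return create_for_engine(draft_model_engine)
    if use_separate_one_model_cache:
        # One-model speculative decoding with a different draft KV layout.
        return create_for_one_model()
    # One-model speculative decoding with a shared layout: no extra manager.
    return None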

View File

@ -16,6 +16,7 @@ from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.eagle3 import Eagle3ResourceManager
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager
@ -435,12 +436,20 @@ class CUDAGraphRunner:
# respect the requirement just in case that changes in the future.
if self.padding_dummy_request is None:
# Get draft KV cache manager only for one-model speculative decoding.
# In two-model mode, each model has its own KV cache manager, so
# draft_kv_cache_manager should be None.
draft_kv_cache_manager = get_draft_kv_cache_manager(
self.spec_config, resource_manager)
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width)
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)
if self.padding_dummy_request is None:
return 0
else:

View File

@ -46,8 +46,8 @@ from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
from ..speculative import (SpecMetadata, get_draft_kv_cache_manager,
get_num_extra_kv_tokens, get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
@ -567,6 +567,15 @@ class PyTorchModelEngine(ModelEngine):
def use_beam_search(self):
return self.max_beam_width > 1
def _get_draft_kv_cache_manager(
self, resource_manager: ResourceManager
) -> Optional[Union[KVCacheManager, KVCacheManagerV2]]:
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
return get_draft_kv_cache_manager(self.spec_config, resource_manager)
@contextmanager
def set_warmup_flag(self):
prev_is_warmup = self.is_warmup
@ -909,6 +918,8 @@ class PyTorchModelEngine(ModelEngine):
"""A context manager to automatically free resources of a dummy batch."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
try:
@ -917,6 +928,8 @@ class PyTorchModelEngine(ModelEngine):
if batch is not None and kv_cache_manager is not None:
for req in batch.all_requests():
kv_cache_manager.free_resources(req)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.free_resources(req)
if spec_resource_manager is not None:
spec_resource_manager.free_resources(req)
@ -939,6 +952,9 @@ class PyTorchModelEngine(ModelEngine):
"""Creates a generic dummy ScheduledRequests object for warmup."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@ -1011,7 +1027,8 @@ class PyTorchModelEngine(ModelEngine):
is_gen=False,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if ctx_requests is None:
return None
@ -1030,7 +1047,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if gen_requests is None:
for r in ctx_requests:
@ -1058,6 +1076,8 @@ class PyTorchModelEngine(ModelEngine):
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width
@ -1075,7 +1095,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if requests is None:
return None
@ -1083,6 +1104,12 @@ class PyTorchModelEngine(ModelEngine):
available_tokens = kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)
# Add one dummy request with the maximum possible sequence length.
max_seq_len = min(
self.max_seq_len if max_seq_len is None else max_seq_len,
@ -1110,7 +1137,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if max_seq_len_request is None:
for r in requests:
@ -1141,8 +1169,11 @@ class PyTorchModelEngine(ModelEngine):
req.py_is_first_draft = True
req.py_draft_tokens = []
def _set_up_attn_metadata(self, kv_cache_manager: Union[KVCacheManager,
KVCacheManagerV2]):
def _set_up_attn_metadata(
self,
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2],
draft_kv_cache_manager: Optional[Union[KVCacheManager,
KVCacheManagerV2]] = None):
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
@ -1191,6 +1222,7 @@ class PyTorchModelEngine(ModelEngine):
max_num_tokens=self.max_num_tokens,
max_num_sequences=self.batch_size * self.max_beam_width,
kv_cache_manager=kv_cache_manager,
draft_kv_cache_manager=draft_kv_cache_manager,
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
@ -1568,7 +1600,8 @@ class PyTorchModelEngine(ModelEngine):
else:
return self._apply_incremental_update_target(
scheduled_requests, kv_cache_manager, attn_metadata,
spec_metadata, new_tensors_device, num_accepted_tokens_device)
spec_metadata, new_tensors_device, num_accepted_tokens_device,
resource_manager)
@nvtx_range("_prepare_incremental_update_metadata")
def _prepare_incremental_update_metadata(
@ -1834,7 +1867,8 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
num_accepted_tokens_device: Optional[torch.Tensor] = None,
resource_manager: Optional[ResourceManager] = None):
# Extract tensors from new_tensors_device
new_tokens_device = new_tensors_device.new_tokens # [batch, 1 + draft_len]
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
@ -1968,6 +2002,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": [],
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2781,6 +2816,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2846,7 +2882,8 @@ class PyTorchModelEngine(ModelEngine):
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None):
spec_metadata: Optional[SpecMetadata] = None,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@ -2950,7 +2987,8 @@ class PyTorchModelEngine(ModelEngine):
'position_ids':
self.position_ids_cuda[:virtual_num_tokens].unsqueeze(0),
'inputs_embeds': None,
"multimodal_params": multimodal_params_list
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2993,10 +3031,12 @@ class PyTorchModelEngine(ModelEngine):
return inputs, None
def _prepare_star_attention_inputs(self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata):
def _prepare_star_attention_inputs(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@ -3212,7 +3252,8 @@ class PyTorchModelEngine(ModelEngine):
'attn_metadata': attn_metadata,
'input_ids': self.input_ids_cuda[:num_tokens],
'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
'inputs_embeds': None
'inputs_embeds': None,
'resource_manager': resource_manager,
}, gather_ids if is_spec_decode else None
def _get_lora_params_from_requests(
@ -3357,7 +3398,8 @@ class PyTorchModelEngine(ModelEngine):
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
return self._prepare_star_attention_inputs(
scheduled_requests, kv_cache_manager, attn_metadata)
scheduled_requests, kv_cache_manager, attn_metadata,
resource_manager)
elif CpType.HELIX == cp_type:
# Take the usual route of _prepare_tp_inputs.
pass
@ -3384,8 +3426,11 @@ class PyTorchModelEngine(ModelEngine):
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
draft_kv_cache_manager)
if self.enable_spec_decode:
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@ -3423,7 +3468,8 @@ class PyTorchModelEngine(ModelEngine):
if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
scheduled_requests, attn_metadata, spec_metadata,
resource_manager)
with MoeLoadBalancerIterContext(moe_load_balancer):
# Special handling for multimodal encoder only mode
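A minimal sketch, with hypothetical manager objects, of the two rules the warmup and dummy-batch paths above follow once a separate draft KV cache manager exists: dummy resources are released from both caches, and the dummy sequence length is bounded by whichever cache has less room.

def free_dummy_batch(requests, kv_cache_manager, draft_kv_cache_manager=None):
    # Release every dummy request from both caches.
    for req in requests:
        kv_cache_manager.free_resources(req)
        if draft_kv_cache_manager is not None:
            draft_kv_cache_manager.free_resources(req)


def available_tokens(kv_cache_manager, draft_kv_cache_manager=None, **kwargs):
    # The effective capacity is limited by the smaller of the two caches.
    tokens = kv_cache_manager.get_num_available_tokens(**kwargs)
    if draft_kv_cache_manager is not None:
        tokens = min(tokens,
                     draft_kv_cache_manager.get_num_available_tokens(**kwargs))
    return tokens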

View File

@ -313,6 +313,11 @@ def create_py_executor(
has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model()
has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter()
# WAR for https://nvbugs/5807902
# Disable separate draft KV cache in disaggregated mode
if cache_transceiver_config is not None or kv_connector_config is not None:
spec_config._allow_separate_draft_kv_cache = False
# chunk_unit_size may be changed to 64 when using flash mla
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=enable_chunked_context,
@ -623,6 +628,10 @@ def create_py_executor(
if model_engine.model.model_config.is_generation:
#NOTE: non-generation models do not have kv cache
# Get draft config for one-engine speculative decoding if available
draft_config = getattr(model_engine.model, 'draft_config', None)
kv_cache_creator = KvCacheCreator(
model_engine=model_engine,
draft_model_engine=draft_model_engine,
@ -640,6 +649,7 @@ def create_py_executor(
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
execution_stream=execution_stream,
draft_config=draft_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with allocation_scope(

View File

@ -135,7 +135,10 @@ def get_pp_layers(
pp_layers = mapping.pp_layers(total_num_layers)
if layer_mask is not None:
pp_layers = [i for i in pp_layers if layer_mask[i]]
if spec_config is not None:
# Only add speculative layers when layer_mask is not provided.
# When layer_mask is provided, the caller explicitly controls which layers
# to include, so we should not add extra layers automatically.
if spec_config is not None and layer_mask is None:
num_spec_layers = get_num_spec_layers(spec_config)
total_num_layers += num_spec_layers
if mapping.is_last_pp_rank():
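A reduced, self-contained version of the rule the comment above describes (names and the pipeline-parallel details are simplified): speculative layers are appended only when the caller has not pinned the layer set with an explicit mask.

from typing import List, Optional


def select_layers(total_num_layers: int,
                  layer_mask: Optional[List[bool]] = None,
                  num_spec_layers: int = 0,
                  is_last_pp_rank: bool = True) -> List[int]:
    layers = list(range(total_num_layers))
    if layer_mask is not None:
        # The caller controls the layer set explicitly; add nothing extra.
        return [i for i in layers if layer_mask[i]]
    if num_spec_layers and is_last_pp_rank:
        # Otherwise the speculative layers are appended on the last PP rank.
        layers += list(range(total_num_layers,
                             total_num_layers + num_spec_layers))
    return layers


assert select_layers(4, num_spec_layers=1) == [0, 1, 2, 3, 4]
assert select_layers(4, layer_mask=[False, False, True, True],
                     num_spec_layers=1) == [2, 3]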
@ -586,6 +589,7 @@ class KVCacheManager(BaseResourceManager):
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional[BaseResourceManager] = None,
):
available_blocks = self.get_num_free_blocks()
# No padding if not enough KV cache space
@ -625,6 +629,12 @@ class KVCacheManager(BaseResourceManager):
for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.impl.add_sequence(
req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
draft_kv_cache_manager.impl.add_token(req_id)
if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
@ -651,6 +661,10 @@ class KVCacheManager(BaseResourceManager):
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)
if draft_kv_cache_manager is not None:
for _ in range(max_num_draft_tokens):
draft_kv_cache_manager.impl.add_token(req_id)
# TODO: Planning to get dummy_data from each model. Before that, we need to add a dummy mrope_config to the request here.
if use_mrope:
dummy_mrope_position_ids = torch.arange(
@ -1886,7 +1900,8 @@ class KVCacheManagerV2(BaseResourceManager):
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0):
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional['BaseResourceManager'] = None):
beam_width = max_beam_width
requests = []

View File

@ -1,14 +1,15 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata, SpecWorkerBase
from .interface import (SpecMetadata, SpecWorkerBase,
should_use_separate_draft_kv_cache)
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, get_spec_drafter, get_spec_metadata,
get_spec_resource_manager, get_spec_worker,
update_spec_config_from_model_config)
from .utils import (get_draft_kv_cache_manager, get_num_extra_kv_tokens,
get_num_spec_layers, get_spec_decoder, get_spec_drafter,
get_spec_metadata, get_spec_resource_manager,
get_spec_worker, update_spec_config_from_model_config)
__all__ = [
"Eagle3SpecMetadata",
@ -20,6 +21,7 @@ __all__ = [
"SaveHiddenStatesDrafter",
"SpecMetadata",
"SpecWorkerBase",
"get_draft_kv_cache_manager",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
"get_spec_decoder",
@ -27,6 +29,7 @@ __all__ = [
"get_spec_metadata",
"get_spec_resource_manager",
"get_spec_worker",
"should_use_separate_draft_kv_cache",
"update_spec_config_from_model_config",
"suggest_spec_config",
"SpecTreeManager",

View File

@ -358,8 +358,11 @@ class Eagle3OneModelSampler(MTPSampler):
class Eagle3OneModelWorker(SpecWorkerBase):
def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
def __init__(self,
spec_config: "EagleDecodingConfig",
mapping: Mapping,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.mapping = mapping
@ -369,8 +372,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
def forward(self, input_ids, position_ids, hidden_states, logits,
attn_metadata, spec_metadata, draft_model):
def forward(self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
@ -400,81 +410,91 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Predict draft tokens
next_draft_tokens = []
original_all_rank_num_tokens = attn_metadata.all_rank_num_tokens
for i in range(self.max_draft_len):
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
else:
# All of the seq_lens are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.max_draft_len):
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
dim=0)
else:
# All of the seq_lens are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
hidden_states, hidden_states_to_save = draft_model.model(**inputs)
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
hidden_states, hidden_states_to_save = draft_model.model(
**inputs)
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
draft_step=i)
new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn_metadata to support cuda graph
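The restructuring above wraps the whole draft loop in draft_kv_cache_context; the control flow reduces to the following skeleton (all names are placeholders for the real worker state).

def predict_draft_tokens(max_draft_len, attn_metadata, draft_kv_cache_manager,
                         draft_kv_cache_context, run_draft_step):
    next_draft_tokens = []
    # While the context is active, attention reads and writes the draft KV
    # cache; on exit the target cache pointers are restored for CUDA graphs.
    with draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
        for step in range(max_draft_len):
            next_draft_tokens.append(run_draft_step(step, attn_metadata))
    return next_draft_tokens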

View File

@ -1,6 +1,7 @@
import copy
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Type
@ -11,10 +12,12 @@ from torch import nn
from tensorrt_llm.logger import logger
from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention,
TrtllmAttentionMetadata)
from ..cute_dsl_kernels.argmax import argmax as cute_argmax
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.resource_manager import (BaseResourceManager,
ResourceManagerType)
if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
@ -26,6 +29,17 @@ if IS_FLASHINFER_AVAILABLE:
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
def should_use_separate_draft_kv_cache(spec_config) -> bool:
"""
Check if separate draft KV cache should be used for one-engine speculative decoding.
"""
if spec_config is None:
return False
if not spec_config.spec_dec_mode.use_one_engine():
return False
return spec_config._allow_separate_draft_kv_cache
def get_force_num_accepted_tokens() -> int:
"""
Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable.
@ -373,13 +387,14 @@ class SpecWorkerBase(nn.Module, ABC):
Provides common functionality for sampling and token handling.
"""
def __init__(self):
def __init__(self, use_separate_draft_kv_cache: bool = False):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
self.seed = 0
self.offset = 0
self.use_separate_draft_kv_cache = use_separate_draft_kv_cache
@property
@abstractmethod
@ -600,6 +615,52 @@ class SpecWorkerBase(nn.Module, ABC):
else:
return torch.empty(0, dtype=torch.int32, device="cuda")
def get_draft_kv_cache_manager(self, resource_manager):
"""
Get the draft KV cache manager if using separate KV cache layouts.
"""
if self.use_separate_draft_kv_cache and resource_manager is not None:
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
return None
@contextmanager
def draft_kv_cache_context(self, attn_metadata, draft_kv_cache_manager):
"""
Context manager to temporarily switch to draft KV cache manager in one-engine speculative decoding.
This swaps both the kv_cache_manager reference AND the block offset tensors,
since the target and draft KV caches have different block layouts.
"""
# draft_kv_cache_manager is None when using two-engine speculative decoding or when the separate draft KV cache is disabled.
if draft_kv_cache_manager is None:
yield
return
# Only TrtllmAttentionMetadata supports separate draft KV cache layouts
if not isinstance(attn_metadata, TrtllmAttentionMetadata):
yield
return
# Save main KV cache manager and block offsets
target_kv_cache_manager = attn_metadata.kv_cache_manager
target_kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets
target_host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets
# Switch to draft KV cache manager and its block offsets
attn_metadata.kv_cache_manager = draft_kv_cache_manager
attn_metadata.kv_cache_block_offsets = attn_metadata.draft_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = attn_metadata.draft_host_kv_cache_block_offsets
try:
yield
finally:
# Restore main KV cache manager and block offsets
attn_metadata.kv_cache_manager = target_kv_cache_manager
attn_metadata.kv_cache_block_offsets = target_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = target_host_kv_cache_block_offsets
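The same save/swap/restore pattern, reduced to a generic, self-contained context manager (illustrative only; the real method swaps kv_cache_manager plus the two block-offset tensors shown above).

from contextlib import contextmanager


@contextmanager
def swap_attrs(obj, **overrides):
    # Remember the current values, install the overrides, and always restore
    # the originals, even if the wrapped block raises.
    saved = {name: getattr(obj, name) for name in overrides}
    try:
        for name, value in overrides.items():
            setattr(obj, name, value)
        yield obj
    finally:
        for name, value in saved.items():
            setattr(obj, name, value)


class _Metadata:
    kv_cache_manager = "target"


meta = _Metadata()
with swap_attrs(meta, kv_cache_manager="draft"):
    assert meta.kv_cache_manager == "draft"
assert meta.kv_cache_manager == "target"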
def _sample_tokens_for_batch(
self,
logits: torch.Tensor,

View File

@ -367,8 +367,11 @@ class MTPSampler(Sampler[SampleStateMTP]):
class MTPWorker(SpecWorkerBase):
def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config=None,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
@ -386,6 +389,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
'''
Example:
@ -518,43 +522,50 @@ class MTPWorker(SpecWorkerBase):
# update attn metadata
if attn_metadata is not None:
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
draft_inputs.update(attn_metadata=attn_metadata)
# Run MTP layers to predict draft tokens
next_draft_tokens = []
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states,
draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn metadata
if attn_metadata is not None:
@ -573,6 +584,39 @@ class MTPWorker(SpecWorkerBase):
'next_new_tokens': next_new_tokens
}
def skip_forward(
self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
dtype=torch.int,
device=logits.device)
next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
return {
'logits': logits,
'new_tokens': accepted_tokens,
'new_tokens_lens': num_accepted_tokens,
'next_draft_tokens': next_draft_tokens,
'next_new_tokens': next_new_tokens
}
def update_mtp_hidden_states(
self,
input_ids: torch.IntTensor,
@ -1146,8 +1190,9 @@ class MTPEagleWorker(MTPWorker):
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config: Optional[ModelConfig] = None):
super().__init__(spec_config, model_config)
model_config: Optional[ModelConfig] = None,
use_separate_draft_kv_cache: bool = False):
super().__init__(spec_config, model_config, use_separate_draft_kv_cache)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
self._is_mamba_hybrid_cache = None
@ -1170,6 +1215,7 @@ class MTPEagleWorker(MTPWorker):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
@ -1212,123 +1258,134 @@ class MTPEagleWorker(MTPWorker):
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
# Predict draft tokens
next_draft_tokens = []
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
# All of the seq_lens are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
raise ValueError(
"In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
# All of the seq_lens are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(
hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError(
"In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head,
attn_metadata, True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)
@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(
logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states,
gather_ids, inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
# restore attn_metadata to support cuda graph
self._restore_attn_metadata_from_spec_dec(attn_metadata)
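As a shape check for the skip_forward stub added above (sizes are illustrative and tensors are allocated on CPU here, whereas the real code uses the logits device): with K MTP modules the worker hands back K draft tokens and K + 1 accepted/new tokens per sequence.

import torch

batch_size, mtp_num_modules = 4, 3  # K = 3 in this example

accepted_tokens = torch.empty((batch_size, mtp_num_modules + 1), dtype=torch.int)
num_accepted_tokens = torch.ones(batch_size, dtype=torch.int)
next_draft_tokens = torch.empty((batch_size, mtp_num_modules), dtype=torch.int)
next_new_tokens = torch.empty((batch_size, mtp_num_modules + 1), dtype=torch.int)

assert next_draft_tokens.shape == (4, 3)
assert accepted_tokens.shape == next_new_tokens.shape == (4, 4)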

View File

@ -220,14 +220,19 @@ def get_num_spec_layers(spec_config):
return 0
def get_spec_worker(spec_config, model_config, mapping):
def get_spec_worker(spec_config,
model_config,
mapping,
use_separate_draft_kv_cache: bool = False):
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_vanilla():
return MTPWorker(spec_config, model_config)
return MTPWorker(spec_config, model_config, use_separate_draft_kv_cache)
if spec_dec_mode.is_mtp_eagle_one_model():
return MTPEagleWorker(spec_config, model_config)
return MTPEagleWorker(spec_config, model_config,
use_separate_draft_kv_cache)
if spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return Eagle3OneModelWorker(spec_config, mapping,
use_separate_draft_kv_cache)
return None
@ -243,6 +248,21 @@ def get_num_extra_kv_tokens(spec_config):
return 0
def get_draft_kv_cache_manager(spec_config, resource_manager):
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
from ..pyexecutor.resource_manager import ResourceManagerType
if spec_config is None:
return None
if not spec_config.spec_dec_mode.use_one_engine():
return None
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
def update_spec_config_from_model_config(spec_config, model_config):
if spec_config.spec_dec_mode.is_mtp_one_model():
# Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.

View File

@ -720,6 +720,8 @@ class DecodingBaseConfig(StrictBaseModel):
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
# Internal: record decoding_type alias used during parsing (for warnings).
_decoding_type_alias: Optional[str] = PrivateAttr(default=None)
# If True, one-model speculative decoding is allowed to use a separate KV cache for drafting.
_allow_separate_draft_kv_cache: bool = PrivateAttr(True)
@field_validator('draft_len_schedule')
@classmethod

View File

@ -155,8 +155,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, use_hf_speculative_model: bool,
request):
attention_dp: bool, use_hf_speculative_model: bool):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35: