Ziyi Xiong 2026-01-13 21:25:09 +08:00 committed by GitHub
commit 227707a50d
17 changed files with 691 additions and 303 deletions

View File

@ -10,6 +10,7 @@ from flashinfer.jit.core import check_cuda_arch
from typing_extensions import Self
from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig
from ..utils import get_global_attrs, get_model_extra_attrs
@ -127,6 +128,11 @@ class FlashInferAttentionMetadata(AttentionMetadata):
def __post_init__(self) -> None:
super().__post_init__()
if self.draft_kv_cache_manager is not None:
logger.warning(
"draft_kv_cache_manager is not supported in FlashInfer backend. "
"One-model speculative decoding with separate KV cache layouts "
"may not work correctly.")
self._post_init_with_buffers(self.cuda_graph_buffers)
def _post_init_with_buffers(self, buffers) -> None:

View File

@ -56,6 +56,8 @@ class AttentionMetadata:
max_num_sequences: Optional[int] = None
# The KV cache manager.
kv_cache_manager: KVCacheManager
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
draft_kv_cache_manager: Optional[KVCacheManager] = None
mapping: Optional[Mapping] = None
enable_flash_mla: bool = False

View File

@ -974,6 +974,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager=None,
):
requests = super().add_dummy_requests(
request_ids=request_ids,
@ -984,6 +985,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager,
)
if prepare_resource:
for req in requests:

View File

@ -796,6 +796,29 @@ class TrtllmAttentionMetadata(AttentionMetadata):
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None
# Allocate separate block offset tensors for draft KV cache manager
# Used in one-model speculative decoding with different KV cache layouts
if self.draft_kv_cache_manager is not None:
self.draft_kv_cache_block_offsets = self.get_empty(
buffers,
[
self.draft_kv_cache_manager.num_pools,
self.max_num_sequences, 2,
self.draft_kv_cache_manager.max_blocks_per_seq
],
cache_name="draft_kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.draft_host_kv_cache_block_offsets = torch.empty_like(
self.draft_kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)
else:
self.draft_kv_cache_block_offsets = None
self.draft_host_kv_cache_block_offsets = None
if self.enable_flash_mla:
self.block_ids_per_seq = self.get_empty(
buffers,
@ -1007,6 +1030,25 @@ class TrtllmAttentionMetadata(AttentionMetadata):
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, error_message
# Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
if self.draft_kv_cache_manager is not None:
# Copy blocks for all context requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(
self.draft_host_kv_cache_block_offsets.shape[0]):
self.draft_kv_cache_block_offsets[
pool_idx, :self.num_seqs].copy_(
self.draft_host_kv_cache_block_offsets[
pool_idx, :self.num_seqs],
non_blocking=True)
self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,
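The hunk above gives the metadata a second pair of block-offset tensors that shadow kv_cache_block_offsets / host_kv_cache_block_offsets but index into the draft KV cache pools: offsets are staged in a host buffer (pinned when possible) and only the rows for the current batch are copied to the device, per pool, with a non-blocking copy. Below is a minimal, self-contained sketch of that shape and copy pattern; the tensor sizes and the num_seqs value are made-up examples, not taken from the code.

import torch

num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 16  # example sizes
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Device-side offsets: [num_pools, max_num_sequences, 2 (K/V), max_blocks_per_seq]
draft_kv_cache_block_offsets = torch.zeros(
    (num_pools, max_num_sequences, 2, max_blocks_per_seq),
    dtype=torch.int32, device=device)
# Host staging buffer with the same shape; pinned only when CUDA is available
# so the asynchronous copy below can overlap with other work.
draft_host_kv_cache_block_offsets = torch.zeros(
    (num_pools, max_num_sequences, 2, max_blocks_per_seq),
    dtype=torch.int32, pin_memory=use_cuda)

num_seqs = 4  # only the first num_seqs rows are valid for the current batch
for pool_idx in range(num_pools):
    draft_kv_cache_block_offsets[pool_idx, :num_seqs].copy_(
        draft_host_kv_cache_block_offsets[pool_idx, :num_seqs],
        non_blocking=True)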

View File

@ -17,7 +17,8 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..speculative import (SpecMetadata, get_spec_worker,
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@ -880,6 +881,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
self.draft_config = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
if spec_config.spec_dec_mode.is_eagle3_one_model():
@ -913,11 +915,16 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo
self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)
self.draft_model = get_draft_model(model_config, self.draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
self.spec_worker = get_spec_worker(
model_config.spec_config,
model_config,
model_config.mapping,
use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)
if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
for key, value in self.draft_config.extra_attrs.items():
@ -934,6 +941,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
@ -978,7 +986,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
logits=logits,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=self.draft_model)
draft_model=self.draft_model,
resource_manager=resource_manager)
else:
logits = self.logits_processor.forward(
hidden_states,

View File

@ -25,7 +25,8 @@ from tensorrt_llm.mapping import CpType, Mapping
from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from ..speculative import (get_num_extra_kv_tokens, get_spec_decoder,
should_use_separate_draft_kv_cache)
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@ -78,6 +79,7 @@ class KvCacheCreator:
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
execution_stream: Optional[torch.cuda.Stream] = None,
draft_config: Optional[ModelConfig] = None,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@ -99,6 +101,7 @@ class KvCacheCreator:
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)
self._execution_stream = execution_stream
self._draft_config = draft_config
def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
@ -111,6 +114,12 @@ class KvCacheCreator:
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
self._draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token
def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
@ -393,9 +402,12 @@ class KvCacheCreator:
# get kv cache stats for both model and draft model
kv_stats = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
kv_stats_draft = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
) if self._draft_model_engine is not None else None
# Get draft KV cache stats if present (either from two-model mode or one-model
# mode with separate draft KV cache)
draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
) if draft_kv_cache_manager is not None else None
# get total allocated bytes
allocated_bytes = kv_stats.allocated_bytes + (
@ -499,6 +511,60 @@ class KvCacheCreator:
return kv_cache_manager
def _should_create_separate_draft_kv_cache(self) -> bool:
"""
Check if we need a separate draft KV cache manager for one-model mode.
Returns True if the speculative config has use_separate_draft_kv_cache=True.
"""
if self._draft_config is None:
return False
return should_use_separate_draft_kv_cache(self._speculative_config)
def _create_one_model_draft_kv_cache_manager(
self,
estimating_kv_cache: bool = False) -> Optional[KVCacheManager]:
"""
Create a KV cache manager for draft model layers in one-model mode
when target and draft have different KV cache layouts.
"""
# Get target model's num_hidden_layers to compute correct layer indices.
# Draft model layers in one-model mode start at target_num_layers.
target_pretrained_config = self._model_engine.model.model_config.pretrained_config
target_num_layers = target_pretrained_config.num_hidden_layers
num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers
# Create layer_mask: False for target layers, True for draft layers.
# This ensures the draft KV cache manager uses the correct layer indices
# (e.g., layers 32, 33, ... instead of 0, 1, ...).
layer_mask = [False] * target_num_layers + [True] * num_draft_layers
# Get the appropriate KV cache manager class for the draft model
draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
self._draft_config)
return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
kv_cache_config=self._kv_cache_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
max_batch_size=self._max_batch_size,
spec_config=self._speculative_config,
sparse_attn_config=None, # Not applicable for draft in one-model mode
max_num_tokens=self._max_num_tokens,
max_beam_width=self._max_beam_width,
kv_connector_manager=None, # Not supported for draft models
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
# One-model draft specific overrides
model_config=self._draft_config,
dtype=self._draft_config.pretrained_config.torch_dtype,
is_draft=True,
layer_mask=layer_mask,
)
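The layer_mask logic above keeps the draft layers at their global one-model indices instead of renumbering them from zero. A minimal sketch of that construction, assuming a 32-layer target model and a single Eagle3 draft layer (both counts are illustrative):

# Illustrative numbers; the real values come from the target and draft
# pretrained configs' num_hidden_layers.
target_num_layers = 32
num_draft_layers = 1

# False for target layers 0..31, True for the trailing draft layers, so the
# draft KV cache manager allocates cache only for the draft layers while
# keeping their one-model layer indices (32, 33, ...).
layer_mask = [False] * target_num_layers + [True] * num_draft_layers

draft_layer_indices = [i for i, is_draft in enumerate(layer_mask) if is_draft]
assert draft_layer_indices == [32]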
def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:
@ -510,9 +576,16 @@ class KvCacheCreator:
raise NotImplementedError(
"Connector manager is not supported for draft model.")
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache
) if self._draft_model_engine is not None else None
draft_kv_cache_manager = None
# Two-model speculative decoding: draft model has separate engine
if self._draft_model_engine is not None:
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache)
# One-model speculative decoding with different KV layouts
elif self._should_create_separate_draft_kv_cache():
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(
estimating_kv_cache)
resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
@ -530,7 +603,7 @@ class KvCacheCreator:
def _create_kv_cache_manager(
model_engine: PyTorchModelEngine,
model_engine: Optional[PyTorchModelEngine],
kv_cache_manager_cls,
mapping: Mapping,
kv_cache_config: KvCacheConfig,
@ -543,13 +616,31 @@ def _create_kv_cache_manager(
max_beam_width: int,
kv_connector_manager: Optional[KvCacheConnectorManager],
estimating_kv_cache: bool,
execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
execution_stream: Optional[torch.cuda.Stream] = None,
# Optional overrides for one-model draft case (when model_engine is None)
model_config: Optional[ModelConfig] = None,
dtype: Optional[torch.dtype] = None,
is_draft: Optional[bool] = None,
layer_mask: Optional[List[bool]] = None) -> KVCacheManager:
"""
Returns:
A KVCacheManager instance for the given model_engine
A KVCacheManager instance for the given model engine or model config
"""
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
# Extract config from model_engine or use provided model_config
if model_config is not None:
config = model_config.pretrained_config
quant_config = model_config.quant_config
_model_config = model_config
else:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
_model_config = model_engine.model.model_config
if dtype is None:
dtype = model_engine.dtype
if is_draft is None:
is_draft = model_engine.is_draft_model
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
@ -565,8 +656,7 @@ def _create_kv_cache_manager(
):
kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))
num_hidden_layers = config.num_hidden_layers
@ -584,12 +674,13 @@ def _create_kv_cache_manager(
dtype=kv_cache_dtype,
spec_config=spec_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
elif is_nemotron_hybrid(config):
if max_beam_width > 1:
@ -601,9 +692,10 @@ def _create_kv_cache_manager(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [char == "*" for char in config.hybrid_override_pattern]
hybrid_layer_mask = [
char == "*" for char in config.hybrid_override_pattern
]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
@ -618,12 +710,13 @@ def _create_kv_cache_manager(
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@ -644,13 +737,12 @@ def _create_kv_cache_manager(
raise NotImplementedError(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
mamba_layer_mask = [
True if i %
config.full_attention_interval != config.full_attention_interval -
1 else False for i in range(num_hidden_layers)
]
layer_mask = [
hybrid_layer_mask = [
False if i %
config.full_attention_interval != config.full_attention_interval -
1 else True for i in range(num_hidden_layers)
@ -668,12 +760,13 @@ def _create_kv_cache_manager(
num_mamba_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@ -689,7 +782,7 @@ def _create_kv_cache_manager(
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
binding_model_config = _model_config.get_bindings_model_config(
tokens_per_block=tokens_per_block) if is_vswa else None
kv_cache_manager = kv_cache_manager_cls(
@ -707,12 +800,13 @@ def _create_kv_cache_manager(
max_num_tokens=max_num_tokens,
model_config=binding_model_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
return kv_cache_manager

View File

@ -16,6 +16,7 @@ from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.eagle3 import Eagle3ResourceManager
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager
@ -439,12 +440,19 @@ class CUDAGraphRunner:
if available_blocks < 1:
return 0
# Get draft KV cache manager only for one-model speculative decoding.
# In two-model mode, each model has its own KV cache manager, so
# draft_kv_cache_manager should be None.
draft_kv_cache_manager = get_draft_kv_cache_manager(
self.spec_config, resource_manager)
self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width)[0]
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

View File

@ -44,8 +44,8 @@ from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
from ..speculative import (SpecMetadata, get_draft_kv_cache_manager,
get_num_extra_kv_tokens, get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,
@ -543,6 +543,15 @@ class PyTorchModelEngine(ModelEngine):
def use_beam_search(self):
return self.max_beam_width > 1
def _get_draft_kv_cache_manager(
self,
resource_manager: ResourceManager) -> Optional[KVCacheManager]:
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
return get_draft_kv_cache_manager(self.spec_config, resource_manager)
@contextmanager
def set_warmup_flag(self):
prev_is_warmup = self.is_warmup
@ -852,6 +861,8 @@ class PyTorchModelEngine(ModelEngine):
"""A context manager to automatically free resources of a dummy batch."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
try:
@ -860,6 +871,8 @@ class PyTorchModelEngine(ModelEngine):
if batch is not None and kv_cache_manager is not None:
for req in batch.all_requests():
kv_cache_manager.free_resources(req)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.free_resources(req)
if spec_resource_manager is not None:
spec_resource_manager.free_resources(req)
@ -882,6 +895,9 @@ class PyTorchModelEngine(ModelEngine):
"""Creates a generic dummy ScheduledRequests object for warmup."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@ -953,7 +969,8 @@ class PyTorchModelEngine(ModelEngine):
is_gen=False,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(
@ -969,7 +986,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(request_ids=list(
range(num_ctx_requests, num_ctx_requests +
@ -991,6 +1009,8 @@ class PyTorchModelEngine(ModelEngine):
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width
@ -1008,9 +1028,15 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)
# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
draft_len)
available_tokens = min(available_tokens, draft_available_tokens)
# Add one dummy request with the maximum possible sequence length.
max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len
@ -1033,7 +1059,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)[0]
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)[0]
# Insert the longest request first to simulate padding for the CUDA graph.
requests.insert(0, max_seq_len_request)
@ -1057,7 +1084,10 @@ class PyTorchModelEngine(ModelEngine):
req.py_is_first_draft = True
req.py_draft_tokens = []
def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
def _set_up_attn_metadata(
self,
kv_cache_manager: KVCacheManager,
draft_kv_cache_manager: Optional[KVCacheManager] = None):
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
@ -1106,6 +1136,7 @@ class PyTorchModelEngine(ModelEngine):
max_num_tokens=self.max_num_tokens,
max_num_sequences=self.batch_size * self.max_beam_width,
kv_cache_manager=kv_cache_manager,
draft_kv_cache_manager=draft_kv_cache_manager,
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
@ -1473,7 +1504,8 @@ class PyTorchModelEngine(ModelEngine):
else:
return self._apply_incremental_update_target(
scheduled_requests, kv_cache_manager, attn_metadata,
spec_metadata, new_tensors_device, num_accepted_tokens_device)
spec_metadata, new_tensors_device, num_accepted_tokens_device,
resource_manager)
@nvtx_range("_prepare_incremental_update_metadata")
def _prepare_incremental_update_metadata(
@ -1739,7 +1771,8 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
num_accepted_tokens_device: Optional[torch.Tensor] = None,
resource_manager: Optional[ResourceManager] = None):
# Extract tensors from new_tensors_device
new_tokens_device = new_tensors_device.new_tokens # [batch, 1 + draft_len]
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
@ -1873,6 +1906,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": [],
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2665,6 +2699,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2730,7 +2765,8 @@ class PyTorchModelEngine(ModelEngine):
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None):
spec_metadata: Optional[SpecMetadata] = None,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@ -2834,7 +2870,8 @@ class PyTorchModelEngine(ModelEngine):
'position_ids':
self.position_ids_cuda[:virtual_num_tokens].unsqueeze(0),
'inputs_embeds': None,
"multimodal_params": multimodal_params_list
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}
if bool(lora_params):
@ -2877,10 +2914,12 @@ class PyTorchModelEngine(ModelEngine):
return inputs, None
def _prepare_star_attention_inputs(self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata):
def _prepare_star_attention_inputs(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@ -3096,7 +3135,8 @@ class PyTorchModelEngine(ModelEngine):
'attn_metadata': attn_metadata,
'input_ids': self.input_ids_cuda[:num_tokens],
'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
'inputs_embeds': None
'inputs_embeds': None,
'resource_manager': resource_manager,
}, gather_ids if is_spec_decode else None
def _get_lora_params_from_requests(self,
@ -3207,7 +3247,8 @@ class PyTorchModelEngine(ModelEngine):
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
return self._prepare_star_attention_inputs(
scheduled_requests, kv_cache_manager, attn_metadata)
scheduled_requests, kv_cache_manager, attn_metadata,
resource_manager)
elif CpType.HELIX == cp_type:
# Take the usual route of _prepare_tp_inputs.
pass
@ -3235,8 +3276,11 @@ class PyTorchModelEngine(ModelEngine):
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
draft_kv_cache_manager)
if self.enable_spec_decode:
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@ -3270,7 +3314,8 @@ class PyTorchModelEngine(ModelEngine):
if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
scheduled_requests, attn_metadata, spec_metadata,
resource_manager)
with MoeLoadBalancerIterContext(moe_load_balancer):
# Special handling for multimodal encoder only mode

View File

@ -610,6 +610,10 @@ def create_py_executor(
if model_engine.model.model_config.is_generation:
# NOTE: non-generation models do not have a KV cache
# Get draft config for one-engine speculative decoding if available
draft_config = getattr(model_engine.model, 'draft_config', None)
kv_cache_creator = KvCacheCreator(
model_engine=model_engine,
draft_model_engine=draft_model_engine,
@ -627,6 +631,7 @@ def create_py_executor(
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
execution_stream=execution_stream,
draft_config=draft_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with allocation_scope(

View File

@ -519,6 +519,7 @@ class KVCacheManager(BaseResourceManager):
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional[BaseResourceManager] = None,
):
beam_width = max_beam_width
requests = []
@ -553,6 +554,12 @@ class KVCacheManager(BaseResourceManager):
for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.impl.add_sequence(
req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
draft_kv_cache_manager.impl.add_token(req_id)
if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
@ -579,6 +586,10 @@ class KVCacheManager(BaseResourceManager):
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)
if draft_kv_cache_manager is not None:
for _ in range(max_num_draft_tokens):
draft_kv_cache_manager.impl.add_token(req_id)
# TODO: Planning to get dummy_data from each model. Before that, we need to add a dummy mrope_config to the request here.
if use_mrope:
dummy_mrope_position_ids = torch.arange(

View File

@ -1,14 +1,15 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata, SpecWorkerBase
from .interface import (SpecMetadata, SpecWorkerBase,
should_use_separate_draft_kv_cache)
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, get_spec_drafter, get_spec_metadata,
get_spec_resource_manager, get_spec_worker,
update_spec_config_from_model_config)
from .utils import (get_draft_kv_cache_manager, get_num_extra_kv_tokens,
get_num_spec_layers, get_spec_decoder, get_spec_drafter,
get_spec_metadata, get_spec_resource_manager,
get_spec_worker, update_spec_config_from_model_config)
__all__ = [
"Eagle3SpecMetadata",
@ -20,6 +21,7 @@ __all__ = [
"SaveHiddenStatesDrafter",
"SpecMetadata",
"SpecWorkerBase",
"get_draft_kv_cache_manager",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
"get_spec_decoder",
@ -27,6 +29,7 @@ __all__ = [
"get_spec_metadata",
"get_spec_resource_manager",
"get_spec_worker",
"should_use_separate_draft_kv_cache",
"update_spec_config_from_model_config",
"suggest_spec_config",
"SpecTreeManager",

View File

@ -358,8 +358,11 @@ class Eagle3OneModelSampler(MTPSampler):
class Eagle3OneModelWorker(SpecWorkerBase):
def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
def __init__(self,
spec_config: "EagleDecodingConfig",
mapping: Mapping,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.mapping = mapping
@ -369,8 +372,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
def forward(self, input_ids, position_ids, hidden_states, logits,
attn_metadata, spec_metadata, draft_model):
def forward(self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
@ -401,81 +411,91 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Predict draft tokens
next_draft_tokens = []
original_all_rank_num_tokens = attn_metadata.all_rank_num_tokens
for i in range(self.max_draft_len):
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.max_draft_len):
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
hidden_states, hidden_states_to_save = draft_model.model(**inputs)
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for all Eagle engine steps except the 1st step.
# Eagle engine takes in draft_len tokens from the previous step, runs spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
hidden_states, hidden_states_to_save = draft_model.model(
**inputs)
# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for all Eagle engine steps except the 1st step.
# Eagle engine takes in draft_len tokens from the previous step, runs spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False
logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
draft_step=i)
new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn_metadata to support cuda graph

View File

@ -1,6 +1,7 @@
import copy
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Type
@ -12,7 +13,8 @@ from tensorrt_llm.logger import logger
from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.resource_manager import (BaseResourceManager,
ResourceManagerType)
if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
@ -21,6 +23,21 @@ if TYPE_CHECKING:
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"
def should_use_separate_draft_kv_cache(spec_config) -> bool:
"""
Check if separate draft KV cache should be used for one-engine speculative decoding.
Args:
spec_config: The speculative decoding config (e.g., EagleDecodingConfig).
Returns:
bool: True if separate draft KV cache should be used.
"""
if spec_config is None:
return False
return getattr(spec_config, 'use_separate_draft_kv_cache', False)
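A standalone usage sketch of the helper above; the function body is repeated here so the snippet runs on its own, and the config objects are SimpleNamespace stand-ins rather than real decoding configs:

from types import SimpleNamespace


def should_use_separate_draft_kv_cache(spec_config) -> bool:
    # Same logic as above: default to False when the config is missing
    # or does not define the flag.
    if spec_config is None:
        return False
    return getattr(spec_config, 'use_separate_draft_kv_cache', False)


assert should_use_separate_draft_kv_cache(None) is False
assert should_use_separate_draft_kv_cache(SimpleNamespace()) is False
assert should_use_separate_draft_kv_cache(
    SimpleNamespace(use_separate_draft_kv_cache=True)) is True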
def get_force_num_accepted_tokens() -> int:
"""
Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable.
@ -367,10 +384,11 @@ class SpecWorkerBase(nn.Module, ABC):
Provides common functionality for sampling and token handling.
"""
def __init__(self):
def __init__(self, use_separate_draft_kv_cache: bool = False):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
self.use_separate_draft_kv_cache = use_separate_draft_kv_cache
@property
@abstractmethod
@ -385,6 +403,47 @@ class SpecWorkerBase(nn.Module, ABC):
self.guided_decoder = guided_decoder
return True
def get_draft_kv_cache_manager(self, resource_manager):
"""
Get the draft KV cache manager if using separate KV cache layouts.
"""
if self.use_separate_draft_kv_cache and resource_manager is not None:
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
return None
@contextmanager
def draft_kv_cache_context(self, attn_metadata, draft_kv_cache_manager):
"""
Context manager to temporarily switch to draft KV cache manager.
This swaps both the kv_cache_manager reference AND the block offset tensors,
since the main and draft KV caches have different block layouts.
"""
if draft_kv_cache_manager is None:
yield
return
# Save main KV cache manager and block offsets
target_kv_cache_manager = attn_metadata.kv_cache_manager
target_kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets
target_host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets
# Switch to draft KV cache manager and its block offsets
attn_metadata.kv_cache_manager = draft_kv_cache_manager
if hasattr(attn_metadata, 'draft_kv_cache_block_offsets'
) and attn_metadata.draft_kv_cache_block_offsets is not None:
attn_metadata.kv_cache_block_offsets = attn_metadata.draft_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = attn_metadata.draft_host_kv_cache_block_offsets
try:
yield
finally:
# Restore main KV cache manager and block offsets
attn_metadata.kv_cache_manager = target_kv_cache_manager
attn_metadata.kv_cache_block_offsets = target_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = target_host_kv_cache_block_offsets
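A self-contained sketch of the swap/restore pattern that draft_kv_cache_context implements, using a SimpleNamespace in place of the real AttentionMetadata (only the manager and one offsets field are swapped here for brevity). It shows that the draft manager and its block offsets are visible only inside the with-block and are restored even if the body raises:

from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
    if draft_kv_cache_manager is None:
        yield
        return
    # Save the target manager and offsets, then swap in the draft ones.
    saved = (attn_metadata.kv_cache_manager,
             attn_metadata.kv_cache_block_offsets)
    attn_metadata.kv_cache_manager = draft_kv_cache_manager
    attn_metadata.kv_cache_block_offsets = (
        attn_metadata.draft_kv_cache_block_offsets)
    try:
        yield
    finally:
        # Always restore the target view, even on exceptions.
        (attn_metadata.kv_cache_manager,
         attn_metadata.kv_cache_block_offsets) = saved


meta = SimpleNamespace(kv_cache_manager="target_mgr",
                       kv_cache_block_offsets="target_offsets",
                       draft_kv_cache_block_offsets="draft_offsets")
with draft_kv_cache_context(meta, "draft_mgr"):
    assert meta.kv_cache_manager == "draft_mgr"
    assert meta.kv_cache_block_offsets == "draft_offsets"
assert meta.kv_cache_manager == "target_mgr"
assert meta.kv_cache_block_offsets == "target_offsets"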
def _sample_tokens_for_batch(
self,
logits: torch.Tensor,

View File

@ -349,8 +349,11 @@ class MTPSampler(TorchSampler):
class MTPWorker(SpecWorkerBase):
def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config=None,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
@ -368,6 +371,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
'''
Example:
@ -501,43 +505,50 @@ class MTPWorker(SpecWorkerBase):
# update attn metadata
if attn_metadata is not None:
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
draft_inputs.update(attn_metadata=attn_metadata)
# Run MTP layers to predict draft tokens
next_draft_tokens = []
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states,
draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)
new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
# restore attn metadata
if attn_metadata is not None:
@ -567,6 +578,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers
@ -1178,8 +1190,9 @@ class MTPEagleWorker(MTPWorker):
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config: Optional[ModelConfig] = None):
super().__init__(spec_config, model_config)
model_config: Optional[ModelConfig] = None,
use_separate_draft_kv_cache: bool = False):
super().__init__(spec_config, model_config, use_separate_draft_kv_cache)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
@ -1201,6 +1214,7 @@ class MTPEagleWorker(MTPWorker):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
@ -1236,123 +1250,134 @@ class MTPEagleWorker(MTPWorker):
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)
# Predict draft tokens
next_draft_tokens = []
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
raise ValueError(
f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(
hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError(
f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head,
attn_metadata, True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)
@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(
logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states,
gather_ids, inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
# restore attn_metadata to support cuda graph
attn_metadata.restore_from_spec_dec()

View File

@ -219,14 +219,19 @@ def get_num_spec_layers(spec_config):
return 0
def get_spec_worker(spec_config, model_config, mapping):
def get_spec_worker(spec_config,
model_config,
mapping,
use_separate_draft_kv_cache: bool = False):
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_vanilla():
return MTPWorker(spec_config, model_config)
return MTPWorker(spec_config, model_config, use_separate_draft_kv_cache)
if spec_dec_mode.is_mtp_eagle_one_model():
return MTPEagleWorker(spec_config, model_config)
return MTPEagleWorker(spec_config, model_config,
use_separate_draft_kv_cache)
if spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return Eagle3OneModelWorker(spec_config, mapping,
use_separate_draft_kv_cache)
return None
@ -242,6 +247,21 @@ def get_num_extra_kv_tokens(spec_config):
return 0
def get_draft_kv_cache_manager(spec_config, resource_manager):
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
from ..pyexecutor.resource_manager import ResourceManagerType
if spec_config is None:
return None
if not spec_config.spec_dec_mode.use_one_engine():
return None
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
def update_spec_config_from_model_config(spec_config, model_config):
if spec_config.spec_dec_mode.is_mtp_one_model():
# Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.

View File

@ -857,6 +857,8 @@ class EagleDecodingConfig(DecodingBaseConfig):
# The model architecture of the eagle3 model.
# choices: llama3, mistral_large3
eagle3_model_arch: str = "llama3"
# Whether to use a separate KV cache manager for the draft model in one-model mode.
use_separate_draft_kv_cache: bool = False
def __init__(self, **kwargs):
super().__init__()
@ -1090,6 +1092,8 @@ class MTPDecodingConfig(DecodingBaseConfig):
relaxed_delta: float = 0.
use_mtp_vanilla: bool = False
mtp_eagle_one_model: bool = True
# Whether to use a separate KV cache manager for the draft model in one-model mode.
use_separate_draft_kv_cache: bool = False
# TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
# Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.
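A hedged usage sketch of the new flag on the user-facing config, mirroring the constructor call in the test file later in this commit. The import path and the model directory are assumptions (placeholders), and any other EagleDecodingConfig fields a real run would need are omitted:

from tensorrt_llm.llmapi import EagleDecodingConfig  # assumed import path

spec_config = EagleDecodingConfig(
    speculative_model_dir="/path/to/eagle3-draft",  # placeholder path
    eagle3_one_model=True,  # one-model (one-engine) Eagle3
    # New in this commit: give the draft layers their own KV cache manager
    # when their KV cache layout differs from the target model's.
    use_separate_draft_kv_cache=True,
)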

View File

@ -92,39 +92,71 @@ def test_kv_lens_runtime_with_eagle3_one_model():
@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_separate_draft_kv_cache",
[
[True, "TRTLLM", True, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False],
[False, "TRTLLM", True, False, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False, False],
[False, "FLASHINFER", True, False, False, False, True, False, False],
[False, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False],
[False, "TRTLLM", False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, True, True, False],
[False, "TRTLLM", False, False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False],
[True, "FLASHINFER", False, False, False, False, True, False, False],
[False, "FLASHINFER", False, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, True, False, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False, False],
[
False, "TRTLLM", True, False, False, False, False, False, False,
False
],
[
True, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[False, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False, False],
[
False, "TRTLLM", False, False, False, False, True, False, False,
False
],
[True, "TRTLLM", False, False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, True, False],
[
False, "TRTLLM", False, False, False, False, False, True, False,
False
],
[True, "TRTLLM", False, False, False, False, True, True, False, False],
[False, "TRTLLM", False, False, False, False, True, True, False, False],
[
True, "TRTLLM", False, False, False, False, False, False, False,
False
],
[
False, "TRTLLM", False, False, False, False, False, False, False,
False
],
[True, "TRTLLM", False, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False, False],
[
True, "FLASHINFER", False, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", False, False, False, False, True, False, False,
False
],
# Test use_separate_draft_kv_cache with one-model mode
[True, "TRTLLM", False, True, True, False, False, True, False, True],
[True, "TRTLLM", True, False, True, False, False, False, False, True],
[False, "TRTLLM", False, True, True, True, False, True, False, True],
[False, "TRTLLM", True, False, True, False, False, True, False, True],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
attention_dp: bool, use_separate_draft_kv_cache: bool):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@ -168,6 +200,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
use_separate_draft_kv_cache=use_separate_draft_kv_cache,
)
spec_config._allow_chain_drafter = use_chain_drafter