Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-02-16 15:55:08 +08:00

[TRTLLM-10321][feat] Support different KV cache layout for one-model spec dec (#10502)
Signed-off-by: ziyixiong-nv <219238287+ziyixiong-nv@users.noreply.github.com>
This commit is contained in: parent 092f4ce774, commit e76b634251
@@ -64,6 +64,8 @@ class AttentionMetadata:
max_num_sequences: Optional[int] = None
# The KV cache manager.
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2]
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
draft_kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2] = None
mapping: Optional[Mapping] = None

enable_flash_mla: bool = False

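The hunk above threads an optional draft KV cache manager through AttentionMetadata next to the existing target manager. A minimal, self-contained sketch of the idea follows; the field names mirror the diff, while the simplified dataclass and the active_manager() helper are purely illustrative and not part of TensorRT-LLM.

from dataclasses import dataclass
from typing import Optional

@dataclass
class MiniAttentionMetadata:
    # Target-model KV cache manager (always present).
    kv_cache_manager: object
    # Draft-model KV cache manager; only set for one-model speculative decoding
    # when the draft layers need a KV cache layout of their own.
    draft_kv_cache_manager: Optional[object] = None

    def active_manager(self, use_draft: bool) -> object:
        # Fall back to the target manager when no separate draft cache exists.
        if use_draft and self.draft_kv_cache_manager is not None:
            return self.draft_kv_cache_manager
        return self.kv_cache_manager

meta = MiniAttentionMetadata(kv_cache_manager="target-manager")
assert meta.active_manager(use_draft=True) == "target-manager"
meta.draft_kv_cache_manager = "draft-manager"
assert meta.active_manager(use_draft=True) == "draft-manager"
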
@@ -974,6 +974,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager=None,
):
requests = super().add_dummy_requests(
request_ids=request_ids,
@@ -984,6 +985,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager,
)
if prepare_resource:
for req in requests:

@@ -679,6 +679,12 @@ class TrtllmAttentionMetadata(AttentionMetadata):
helix_is_inactive_rank: Optional[torch.Tensor] = None
helix_is_inactive_rank_cpu: Optional[torch.Tensor] = None

# Block offsets for the target and draft KV caches
kv_cache_block_offsets: Optional[torch.Tensor] = None
host_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_kv_cache_block_offsets: Optional[torch.Tensor] = None
draft_host_kv_cache_block_offsets: Optional[torch.Tensor] = None

@property
def max_seq_len(self) -> int:
"""
@@ -786,6 +792,27 @@ class TrtllmAttentionMetadata(AttentionMetadata):
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None

# Allocate separate block offset tensors for draft KV cache manager
# Used in one-model speculative decoding with different KV cache layouts
if self.draft_kv_cache_manager is not None:
self.draft_kv_cache_block_offsets = self.get_empty(
buffers,
[
self.draft_kv_cache_manager.num_pools,
self.max_num_sequences, 2,
self.draft_kv_cache_manager.max_blocks_per_seq
],
cache_name="draft_kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.draft_host_kv_cache_block_offsets = torch.empty_like(
self.draft_kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)

if self.enable_flash_mla:
self.block_ids_per_seq = self.get_empty(
buffers,
@@ -987,6 +1014,25 @@ class TrtllmAttentionMetadata(AttentionMetadata):
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, error_message

# Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
if self.draft_kv_cache_manager is not None:
# Copy blocks for all context requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(
self.draft_host_kv_cache_block_offsets.shape[0]):
self.draft_kv_cache_block_offsets[
pool_idx, :self.num_seqs].copy_(
self.draft_host_kv_cache_block_offsets[
pool_idx, :self.num_seqs],
non_blocking=True)

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,

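The two TrtllmAttentionMetadata hunks above allocate draft block-offset buffers shaped exactly like the target ones and upload them pool by pool during prepare(). Below is a small standalone sketch of that allocate-and-copy pattern; the sizes, the CPU fallback, and the local names are assumptions for illustration, since the real values come from the draft KV cache manager.

import torch

use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

# Hypothetical sizes; in the diff these come from the draft KV cache manager.
num_pools, max_num_sequences, max_blocks_per_seq = 1, 8, 16

draft_kv_cache_block_offsets = torch.empty(
    (num_pools, max_num_sequences, 2, max_blocks_per_seq),
    dtype=torch.int32, device=device)
draft_host_kv_cache_block_offsets = torch.empty(
    draft_kv_cache_block_offsets.shape,
    dtype=torch.int32, device="cpu", pin_memory=use_cuda)

# At prepare() time the host buffer is filled by the draft manager, then copied
# to the device slice covering the active sequences, one pool at a time.
num_seqs = 4
for pool_idx in range(draft_host_kv_cache_block_offsets.shape[0]):
    draft_kv_cache_block_offsets[pool_idx, :num_seqs].copy_(
        draft_host_kv_cache_block_offsets[pool_idx, :num_seqs],
        non_blocking=True)
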
@@ -1840,16 +1840,18 @@ class DeepseekV3ForCausalLM(SpecDecOneEngineForCausalLM[DeepseekV3Model,
input_ids: torch.IntTensor = None,
position_ids: Optional[torch.IntTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
return super().forward(attn_metadata=attn_metadata,
input_ids=input_ids,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits,
spec_metadata=spec_metadata,
resource_manager=resource_manager,
**kwargs)

def load_weights(self, weights: ConsumableWeightsDict):

@@ -1401,6 +1401,7 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
multimodal_params = kwargs.get("multimodal_params", [])
@@ -1422,7 +1423,8 @@ class Llama4ForConditionalGeneration(SpecDecOneEngineForCausalLM[Llama4Model,
position_ids,
inputs_embeds,
spec_metadata=spec_metadata,
return_context_logits=return_context_logits)
return_context_logits=return_context_logits,
resource_manager=resource_manager)

def infer_max_seq_len(self):
if self.model_config.attn_backend.upper() != 'TRTLLM':

@@ -19,7 +19,8 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..speculative import (SpecMetadata, get_spec_worker,
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,
@@ -931,6 +932,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
self.draft_config = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
if spec_config.spec_dec_mode.is_eagle3_one_model():
@@ -964,11 +966,16 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo

self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)

self.draft_model = get_draft_model(model_config, self.draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
self.spec_worker = get_spec_worker(
model_config.spec_config,
model_config,
model_config.mapping,
use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)
self.epilogue.append(self.draft_model)
self.epilogue.append(self.spec_worker)

@@ -987,6 +994,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
@@ -1030,7 +1038,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
logits=logits,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=self.draft_model)
draft_model=self.draft_model,
resource_manager=resource_manager)
else:
logits = self.logits_processor.forward(
hidden_states,

@@ -25,7 +25,8 @@ from tensorrt_llm.mapping import CpType, Mapping

from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from ..speculative import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, should_use_separate_draft_kv_cache)
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager
@@ -80,6 +81,7 @@ class KvCacheCreator:
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
execution_stream: Optional[torch.cuda.Stream] = None,
draft_config: Optional[ModelConfig] = None,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine
@@ -103,6 +105,7 @@ class KvCacheCreator:
self._execution_stream = execution_stream
if self._kv_cache_manager_cls == KVCacheManager and kv_cache_config.use_kv_cache_manager_v2:
self._kv_cache_manager_cls = KVCacheManagerV2
self._draft_config = draft_config

def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config
@@ -115,6 +118,12 @@ class KvCacheCreator:
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
self._draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token

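The new elif branch above adds the draft layers' per-token KV footprint to the target model's when estimating how much GPU memory the caches need. A back-of-the-envelope sketch of that accounting follows; the layer/head counts, head size, and 2-byte KV dtype below are hypothetical, not values taken from the diff.

def kv_bytes_per_token(num_layers: int, num_kv_heads: int, head_dim: int,
                       bytes_per_elem: int = 2) -> int:
    # Both K and V are cached for every layer, hence the factor of 2.
    return num_layers * num_kv_heads * head_dim * 2 * bytes_per_elem

target_bytes = kv_bytes_per_token(num_layers=32, num_kv_heads=8, head_dim=128)
draft_bytes = kv_bytes_per_token(num_layers=1, num_kv_heads=8, head_dim=128)

# Mirrors kv_size_per_token += get_cache_size_per_token(self._draft_config, ...):
# the memory estimate must cover both caches for every token.
total_bytes = target_bytes + draft_bytes
assert (target_bytes, draft_bytes, total_bytes) == (131072, 4096, 135168)
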
def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
@@ -397,9 +406,12 @@ class KvCacheCreator:
# get kv cache stats for both model and draft model
kv_stats = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
kv_stats_draft = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
) if self._draft_model_engine is not None else None
# Get draft KV cache stats if present (either from two-model mode or one-model
# mode with separate draft KV cache)
draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
) if draft_kv_cache_manager is not None else None

# get total allocated bytes
allocated_bytes = kv_stats.allocated_bytes + (
@@ -466,6 +478,15 @@ class KvCacheCreator:
mapping = self._mapping
assert model_engine.model.model_config.is_generation, "Only construct KV cache for generation models."

# When using separate draft KV cache in one-model speculative decoding,
# use layer_mask to include only target layers. The draft layers should
# only be in the separate draft KV cache manager.
# We still pass spec_config so that num_extra_kv_tokens is calculated.
spec_dec_layer_mask = None
if self._should_create_separate_draft_kv_cache():
num_target_layers = model_engine.model.model_config.pretrained_config.num_hidden_layers
spec_dec_layer_mask = [True] * num_target_layers

kv_cache_manager = _create_kv_cache_manager(
model_engine=model_engine,
kv_cache_manager_cls=self._kv_cache_manager_cls,
@@ -481,6 +502,7 @@ class KvCacheCreator:
kv_connector_manager=self._kv_connector_manager,
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
layer_mask=spec_dec_layer_mask,
)

# KVCacheManager (Non-draft) modifies the max_seq_len field, update it to self
@@ -503,6 +525,64 @@ class KvCacheCreator:

return kv_cache_manager

def _should_create_separate_draft_kv_cache(self) -> bool:
"""
Check if we need a separate draft KV cache manager for one-model mode.
Returns True if the speculative config has use_separate_draft_kv_cache=True.
"""
if self._draft_config is None:
return False

return should_use_separate_draft_kv_cache(self._speculative_config)

def _create_one_model_draft_kv_cache_manager(
self,
estimating_kv_cache: bool = False) -> Optional[KVCacheManager]:
"""
Create a KV cache manager for draft model layers in one-model mode
when target and draft have different KV cache layouts.
"""
# Get target model's num_hidden_layers to compute correct layer indices.
# Draft model layers in one-model mode start at target_num_layers.
target_pretrained_config = self._model_engine.model.model_config.pretrained_config
target_num_layers = target_pretrained_config.num_hidden_layers
# Use get_num_spec_layers to get the correct number of draft layers
# for the speculative decoding mode (e.g., num_eagle_layers for Eagle3)
num_draft_layers = get_num_spec_layers(self._speculative_config)

# Create layer_mask: False for target layers, True for draft layers.
# This ensures the draft KV cache manager uses the correct layer indices
# (e.g., layers 32, 33, ... instead of 0, 1, ...).
spec_dec_layer_mask = [False
] * target_num_layers + [True] * num_draft_layers

# Get the appropriate KV cache manager class for the draft model
draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
self._draft_config)

return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
kv_cache_config=self._kv_cache_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
max_batch_size=self._max_batch_size,
spec_config=self._speculative_config,
sparse_attn_config=None, # Not applicable for draft in one-model mode
max_num_tokens=self._max_num_tokens,
max_beam_width=self._max_beam_width,
kv_connector_manager=self._kv_connector_manager,
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
# One-model draft specific overrides
model_config=self._draft_config,
dtype=self._draft_config.pretrained_config.torch_dtype,
is_draft=True,
layer_mask=spec_dec_layer_mask,
num_layers=num_draft_layers,
)

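_create_one_model_draft_kv_cache_manager above keys the draft cache by the draft layers' real indices, which start right after the target layers. A tiny worked example, assuming a 32-layer target with one Eagle3 draft layer (both counts are illustrative):

target_num_layers = 32   # e.g. pretrained_config.num_hidden_layers
num_draft_layers = 1     # e.g. get_num_spec_layers(spec_config) for Eagle3

# Mask handed to the draft manager: False for target layers, True for draft layers.
spec_dec_layer_mask = [False] * target_num_layers + [True] * num_draft_layers
draft_layer_indices = [i for i, is_draft in enumerate(spec_dec_layer_mask) if is_draft]
assert draft_layer_indices == [32]  # draft KV cache is addressed as layer 32, not 0

# The target-side manager built in _create_kv_cache_manager gets the complementary
# view, [True] * 32, so it only allocates cache for layers 0..31.
target_layer_mask = [True] * target_num_layers
assert sum(target_layer_mask) + len(draft_layer_indices) == 33
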
def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:
@@ -514,9 +594,16 @@ class KvCacheCreator:
raise NotImplementedError(
"Connector manager is not supported for draft model.")

draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache
) if self._draft_model_engine is not None else None
draft_kv_cache_manager = None

# Two-model speculative decoding: draft model has separate engine
if self._draft_model_engine is not None:
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache)
# One-model speculative decoding with different KV layouts
elif self._should_create_separate_draft_kv_cache():
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(
estimating_kv_cache)

resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[
@@ -534,7 +621,7 @@ class KvCacheCreator:


def _create_kv_cache_manager(
model_engine: PyTorchModelEngine,
model_engine: Optional[PyTorchModelEngine],
kv_cache_manager_cls,
mapping: Mapping,
kv_cache_config: KvCacheConfig,
@@ -547,13 +634,32 @@ def _create_kv_cache_manager(
max_beam_width: int,
kv_connector_manager: Optional[KvCacheConnectorManager],
estimating_kv_cache: bool,
execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
execution_stream: Optional[torch.cuda.Stream] = None,
# Optional overrides for one-model draft case (when model_engine is None)
model_config: Optional[ModelConfig] = None,
dtype: Optional[torch.dtype] = None,
is_draft: Optional[bool] = None,
layer_mask: Optional[List[bool]] = None,
num_layers: Optional[int] = None) -> KVCacheManager:
"""
Returns:
A KVCacheManager instance for the given model_engine
A KVCacheManager instance for the given model engine or model config
"""
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
# Extract config from model_engine or use provided model_config
if model_config is not None:
config = model_config.pretrained_config
quant_config = model_config.quant_config
_model_config = model_config
else:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
_model_config = model_engine.model.model_config

if dtype is None:
dtype = model_engine.dtype

if is_draft is None:
is_draft = model_engine.is_draft_model

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
@@ -569,10 +675,10 @@ def _create_kv_cache_manager(
):
kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))

num_hidden_layers = config.num_hidden_layers
# Use provided num_layers if available, otherwise use config
num_hidden_layers = num_layers if num_layers is not None else config.num_hidden_layers

if is_mla(config):
kv_cache_manager = kv_cache_manager_cls(
@@ -589,12 +695,13 @@ def _create_kv_cache_manager(
spec_config=spec_config,
vocab_size=config.vocab_size,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
elif is_nemotron_hybrid(config):
if max_beam_width > 1:
@@ -606,9 +713,10 @@ def _create_kv_cache_manager(
"Connector manager is not supported for MambaHybridCacheManager."
)

config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [char == "*" for char in config.hybrid_override_pattern]
hybrid_layer_mask = [
char == "*" for char in config.hybrid_override_pattern
]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern
@@ -623,12 +731,13 @@ def _create_kv_cache_manager(
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@@ -649,13 +758,12 @@ def _create_kv_cache_manager(
raise NotImplementedError(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
mamba_layer_mask = [
True if i %
config.full_attention_interval != config.full_attention_interval -
1 else False for i in range(num_hidden_layers)
]
layer_mask = [
hybrid_layer_mask = [
False if i %
config.full_attention_interval != config.full_attention_interval -
1 else True for i in range(num_hidden_layers)
@@ -673,12 +781,13 @@ def _create_kv_cache_manager(
num_mamba_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,
@@ -694,7 +803,7 @@ def _create_kv_cache_manager(
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_for_vswa in KVCahceManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
binding_model_config = _model_config.get_bindings_model_config(
tokens_per_block=tokens_per_block) if is_vswa else None

kv_cache_manager = kv_cache_manager_cls(
@@ -713,12 +822,13 @@ def _create_kv_cache_manager(
max_num_tokens=max_num_tokens,
model_config=binding_model_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
return kv_cache_manager


@@ -16,6 +16,7 @@ from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.eagle3 import Eagle3ResourceManager
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager
@@ -435,12 +436,20 @@ class CUDAGraphRunner:
# respect the requirement just in case that changes in the future.
if self.padding_dummy_request is None:

# Get draft KV cache manager only for one-model speculative decoding.
# In two-model mode, each model has its own KV cache manager, so
# draft_kv_cache_manager should be None.
draft_kv_cache_manager = get_draft_kv_cache_manager(
self.spec_config, resource_manager)

self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width)
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)

if self.padding_dummy_request is None:
return 0
else:

@@ -46,8 +46,8 @@ from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..peft.lora.cuda_graph_lora_manager import CudaGraphLoraManager
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
from ..speculative import (SpecMetadata, get_draft_kv_cache_manager,
get_num_extra_kv_tokens, get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import Eagle3ResourceManager, Eagle3SpecMetadata
@@ -567,6 +567,15 @@ class PyTorchModelEngine(ModelEngine):
def use_beam_search(self):
return self.max_beam_width > 1

def _get_draft_kv_cache_manager(
self, resource_manager: ResourceManager
) -> Optional[Union[KVCacheManager, KVCacheManagerV2]]:
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
return get_draft_kv_cache_manager(self.spec_config, resource_manager)

@contextmanager
def set_warmup_flag(self):
prev_is_warmup = self.is_warmup
@@ -909,6 +918,8 @@ class PyTorchModelEngine(ModelEngine):
"""A context manager to automatically free resources of a dummy batch."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
try:
@@ -917,6 +928,8 @@ class PyTorchModelEngine(ModelEngine):
if batch is not None and kv_cache_manager is not None:
for req in batch.all_requests():
kv_cache_manager.free_resources(req)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.free_resources(req)
if spec_resource_manager is not None:
spec_resource_manager.free_resources(req)

@@ -939,6 +952,9 @@ class PyTorchModelEngine(ModelEngine):
"""Creates a generic dummy ScheduledRequests object for warmup."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -1011,7 +1027,8 @@ class PyTorchModelEngine(ModelEngine):
is_gen=False,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

if ctx_requests is None:
return None
@@ -1030,7 +1047,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

if gen_requests is None:
for r in ctx_requests:
@@ -1058,6 +1076,8 @@ class PyTorchModelEngine(ModelEngine):
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width
@@ -1075,7 +1095,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

if requests is None:
return None
@@ -1083,6 +1104,12 @@ class PyTorchModelEngine(ModelEngine):
available_tokens = kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)

# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
batch_size=batch_size, max_num_draft_tokens=draft_len)
available_tokens = min(available_tokens, draft_available_tokens)

# Add one dummy request with the maximum possible sequence length.
max_seq_len = min(
self.max_seq_len if max_seq_len is None else max_seq_len,
@@ -1110,7 +1137,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

if max_seq_len_request is None:
for r in requests:
@@ -1141,8 +1169,11 @@ class PyTorchModelEngine(ModelEngine):
req.py_is_first_draft = True
req.py_draft_tokens = []

def _set_up_attn_metadata(self, kv_cache_manager: Union[KVCacheManager,
KVCacheManagerV2]):
def _set_up_attn_metadata(
self,
kv_cache_manager: Union[KVCacheManager, KVCacheManagerV2],
draft_kv_cache_manager: Optional[Union[KVCacheManager,
KVCacheManagerV2]] = None):
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse
@@ -1191,6 +1222,7 @@ class PyTorchModelEngine(ModelEngine):
max_num_tokens=self.max_num_tokens,
max_num_sequences=self.batch_size * self.max_beam_width,
kv_cache_manager=kv_cache_manager,
draft_kv_cache_manager=draft_kv_cache_manager,
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,
@@ -1568,7 +1600,8 @@ class PyTorchModelEngine(ModelEngine):
else:
return self._apply_incremental_update_target(
scheduled_requests, kv_cache_manager, attn_metadata,
spec_metadata, new_tensors_device, num_accepted_tokens_device)
spec_metadata, new_tensors_device, num_accepted_tokens_device,
resource_manager)

@nvtx_range("_prepare_incremental_update_metadata")
def _prepare_incremental_update_metadata(
@@ -1834,7 +1867,8 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
num_accepted_tokens_device: Optional[torch.Tensor] = None,
resource_manager: Optional[ResourceManager] = None):
# Extract tensors from new_tensors_device
new_tokens_device = new_tensors_device.new_tokens # [batch, 1 + draft_len]
new_tokens_lens_device = new_tensors_device.new_tokens_lens # [batch]
@@ -1968,6 +2002,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": [],
'resource_manager': resource_manager,
}

if bool(lora_params):
@@ -2781,6 +2816,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}

if bool(lora_params):
@@ -2846,7 +2882,8 @@ class PyTorchModelEngine(ModelEngine):
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None):
spec_metadata: Optional[SpecMetadata] = None,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@@ -2950,7 +2987,8 @@ class PyTorchModelEngine(ModelEngine):
'position_ids':
self.position_ids_cuda[:virtual_num_tokens].unsqueeze(0),
'inputs_embeds': None,
"multimodal_params": multimodal_params_list
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}

if bool(lora_params):
@@ -2993,10 +3031,12 @@ class PyTorchModelEngine(ModelEngine):

return inputs, None

def _prepare_star_attention_inputs(self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata):
def _prepare_star_attention_inputs(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""
@@ -3212,7 +3252,8 @@ class PyTorchModelEngine(ModelEngine):
'attn_metadata': attn_metadata,
'input_ids': self.input_ids_cuda[:num_tokens],
'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
'inputs_embeds': None
'inputs_embeds': None,
'resource_manager': resource_manager,
}, gather_ids if is_spec_decode else None

def _get_lora_params_from_requests(
@@ -3357,7 +3398,8 @@ class PyTorchModelEngine(ModelEngine):
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
return self._prepare_star_attention_inputs(
scheduled_requests, kv_cache_manager, attn_metadata)
scheduled_requests, kv_cache_manager, attn_metadata,
resource_manager)
elif CpType.HELIX == cp_type:
# Take the usual route of _prepare_tp_inputs.
pass
@@ -3384,8 +3426,11 @@ class PyTorchModelEngine(ModelEngine):
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
draft_kv_cache_manager)
if self.enable_spec_decode:
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
@@ -3423,7 +3468,8 @@ class PyTorchModelEngine(ModelEngine):

if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
scheduled_requests, attn_metadata, spec_metadata,
resource_manager)

with MoeLoadBalancerIterContext(moe_load_balancer):
# Special handling for multimodal encoder only mode

@@ -313,6 +313,11 @@ def create_py_executor(
has_draft_model_engine = spec_config.spec_dec_mode.has_draft_model()
has_spec_drafter = spec_config.spec_dec_mode.has_spec_drafter()

# WAR for https://nvbugs/5807902
# Disable separate draft KV cache in disaggregated mode
if cache_transceiver_config is not None or kv_connector_config is not None:
spec_config._allow_separate_draft_kv_cache = False

# chunk_unit_size may be changed to 64 when using flash mla
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=enable_chunked_context,
@@ -623,6 +628,10 @@ def create_py_executor(

if model_engine.model.model_config.is_generation:
#NOTE: non-generation models do not have kv cache

# Get draft config for one-engine speculative decoding if available
draft_config = getattr(model_engine.model, 'draft_config', None)

kv_cache_creator = KvCacheCreator(
model_engine=model_engine,
draft_model_engine=draft_model_engine,
@@ -640,6 +649,7 @@ def create_py_executor(
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
execution_stream=execution_stream,
draft_config=draft_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with allocation_scope(

@@ -135,7 +135,10 @@ def get_pp_layers(
pp_layers = mapping.pp_layers(total_num_layers)
if layer_mask is not None:
pp_layers = [i for i in pp_layers if layer_mask[i]]
if spec_config is not None:
# Only add speculative layers when layer_mask is not provided.
# When layer_mask is provided, the caller explicitly controls which layers
# to include, so we should not add extra layers automatically.
if spec_config is not None and layer_mask is None:
num_spec_layers = get_num_spec_layers(spec_config)
total_num_layers += num_spec_layers
if mapping.is_last_pp_rank():
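
With the get_pp_layers change above, an explicit layer_mask fully determines which layers a KV cache manager owns, and the speculative layers are appended automatically only when no mask is given. The stand-in function below is a simplified single-rank illustration of that rule, not the real get_pp_layers:

def pick_layers(total_num_layers, layer_mask=None, num_spec_layers=0):
    # Single-rank stand-in for mapping.pp_layers(total_num_layers).
    if layer_mask is None and num_spec_layers:
        total_num_layers += num_spec_layers   # old path: draft layers added implicitly
    layers = list(range(total_num_layers))
    if layer_mask is not None:
        layers = [i for i in layers if layer_mask[i]]  # the mask decides everything
    return layers

target_mask = [True] * 32 + [False] * 1   # target cache: layers 0..31 only
draft_mask = [False] * 32 + [True] * 1    # draft cache: layer 32 only

assert pick_layers(33, layer_mask=target_mask) == list(range(32))
assert pick_layers(33, layer_mask=draft_mask) == [32]
assert pick_layers(32, num_spec_layers=1) == list(range(33))  # no mask: old behaviour
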
@@ -586,6 +589,7 @@ class KVCacheManager(BaseResourceManager):
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional[BaseResourceManager] = None,
):
available_blocks = self.get_num_free_blocks()
# No padding if not enough KV cache space
@@ -625,6 +629,12 @@ class KVCacheManager(BaseResourceManager):
for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)

if draft_kv_cache_manager is not None:
draft_kv_cache_manager.impl.add_sequence(
req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
draft_kv_cache_manager.impl.add_token(req_id)

if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1
@@ -651,6 +661,10 @@ class KVCacheManager(BaseResourceManager):
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)

if draft_kv_cache_manager is not None:
for _ in range(max_num_draft_tokens):
draft_kv_cache_manager.impl.add_token(req_id)

# TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here.
if use_mrope:
dummy_mrope_position_ids = torch.arange(
@@ -1886,7 +1900,8 @@ class KVCacheManagerV2(BaseResourceManager):
max_num_draft_tokens: int = 0,
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0):
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional['BaseResourceManager'] = None):

beam_width = max_beam_width
requests = []

@@ -1,14 +1,15 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata, SpecWorkerBase
from .interface import (SpecMetadata, SpecWorkerBase,
should_use_separate_draft_kv_cache)
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, get_spec_drafter, get_spec_metadata,
get_spec_resource_manager, get_spec_worker,
update_spec_config_from_model_config)
from .utils import (get_draft_kv_cache_manager, get_num_extra_kv_tokens,
get_num_spec_layers, get_spec_decoder, get_spec_drafter,
get_spec_metadata, get_spec_resource_manager,
get_spec_worker, update_spec_config_from_model_config)

__all__ = [
"Eagle3SpecMetadata",
@@ -20,6 +21,7 @@ __all__ = [
"SaveHiddenStatesDrafter",
"SpecMetadata",
"SpecWorkerBase",
"get_draft_kv_cache_manager",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
"get_spec_decoder",
@@ -27,6 +29,7 @@ __all__ = [
"get_spec_metadata",
"get_spec_resource_manager",
"get_spec_worker",
"should_use_separate_draft_kv_cache",
"update_spec_config_from_model_config",
"suggest_spec_config",
"SpecTreeManager",

@@ -358,8 +358,11 @@ class Eagle3OneModelSampler(MTPSampler):

class Eagle3OneModelWorker(SpecWorkerBase):

def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
def __init__(self,
spec_config: "EagleDecodingConfig",
mapping: Mapping,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.mapping = mapping

@@ -369,8 +372,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):

# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
def forward(self, input_ids, position_ids, hidden_states, logits,
attn_metadata, spec_metadata, draft_model):
def forward(self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts
@@ -400,81 +410,91 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Predict draft tokens
next_draft_tokens = []
original_all_rank_num_tokens = attn_metadata.all_rank_num_tokens
for i in range(self.max_draft_len):
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)

# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.max_draft_len):
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

hidden_states, hidden_states_to_save = draft_model.model(**inputs)

# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 tokens per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False

logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if their is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs

hidden_states, hidden_states_to_save = draft_model.model(
**inputs)

# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 tokens per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False

logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
draft_step=i)

new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if their is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

# restore attn_metadata to support cuda graph

@@ -1,6 +1,7 @@
import copy
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Type
@@ -11,10 +12,12 @@ from torch import nn
from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..attention_backend.trtllm import (AttentionBackend, TrtllmAttention,
TrtllmAttentionMetadata)
from ..cute_dsl_kernels.argmax import argmax as cute_argmax
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.resource_manager import (BaseResourceManager,
ResourceManagerType)

if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
@@ -26,6 +29,17 @@ if IS_FLASHINFER_AVAILABLE:
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"


def should_use_separate_draft_kv_cache(spec_config) -> bool:
"""
Check if separate draft KV cache should be used for one-engine speculative decoding.
"""
if spec_config is None:
return False
if not spec_config.spec_dec_mode.use_one_engine():
return False
return spec_config._allow_separate_draft_kv_cache


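should_use_separate_draft_kv_cache above gates the whole feature: it needs a spec config, one-engine mode, and the private _allow_separate_draft_kv_cache flag that create_py_executor clears for disaggregated serving. A quick sketch of how the three checks combine, using stub config objects invented for illustration (only the attribute names follow the diff):

class StubSpecDecMode:
    def __init__(self, one_engine: bool):
        self._one_engine = one_engine

    def use_one_engine(self) -> bool:
        return self._one_engine

class StubSpecConfig:
    def __init__(self, one_engine: bool = True, allow: bool = True):
        self.spec_dec_mode = StubSpecDecMode(one_engine)
        self._allow_separate_draft_kv_cache = allow

assert should_use_separate_draft_kv_cache(None) is False
assert should_use_separate_draft_kv_cache(StubSpecConfig(one_engine=False)) is False
assert should_use_separate_draft_kv_cache(StubSpecConfig()) is True
# Mirrors the WAR in create_py_executor: disaggregated mode turns the flag off.
assert should_use_separate_draft_kv_cache(StubSpecConfig(allow=False)) is False
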
def get_force_num_accepted_tokens() -> int:
"""
Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable.
@@ -373,13 +387,14 @@ class SpecWorkerBase(nn.Module, ABC):
Provides common functionality for sampling and token handling.
"""

def __init__(self):
def __init__(self, use_separate_draft_kv_cache: bool = False):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
self.use_flashinfer = IS_FLASHINFER_AVAILABLE and flashinfer.__version__ >= "0.6.0"
self.seed = 0
self.offset = 0
self.use_separate_draft_kv_cache = use_separate_draft_kv_cache

@property
@abstractmethod
@@ -600,6 +615,52 @@ class SpecWorkerBase(nn.Module, ABC):
else:
return torch.empty(0, dtype=torch.int32, device="cuda")

def get_draft_kv_cache_manager(self, resource_manager):
"""
Get the draft KV cache manager if using separate KV cache layouts.
"""
if self.use_separate_draft_kv_cache and resource_manager is not None:
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
return None

@contextmanager
def draft_kv_cache_context(self, attn_metadata, draft_kv_cache_manager):
"""
Context manager to temporarily switch to draft KV cache manager in one-engine speculative decoding.

This swaps both the kv_cache_manager reference AND the block offset tensors,
since the target and draft KV caches have different block layouts.
"""

# draft_kv_cache_manager is None if using two-engine speculative decoding or not enabling separate draft KV cache.
if draft_kv_cache_manager is None:
yield
return

# Only TrtllmAttentionMetadata supports separate draft KV cache layouts
if not isinstance(attn_metadata, TrtllmAttentionMetadata):
yield
return

# Save main KV cache manager and block offsets
target_kv_cache_manager = attn_metadata.kv_cache_manager
target_kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets
target_host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets

# Switch to draft KV cache manager and its block offsets
attn_metadata.kv_cache_manager = draft_kv_cache_manager
attn_metadata.kv_cache_block_offsets = attn_metadata.draft_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = attn_metadata.draft_host_kv_cache_block_offsets

try:
yield
finally:
# Restore main KV cache manager and block offsets
attn_metadata.kv_cache_manager = target_kv_cache_manager
attn_metadata.kv_cache_block_offsets = target_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = target_host_kv_cache_block_offsets

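draft_kv_cache_context above is the heart of the change: while a draft layer runs, attention must read and write the draft cache through the draft block offsets, and everything must be put back afterwards so CUDA graph replay still sees the target state. The standalone sketch below re-states that swap-and-restore pattern with string placeholders instead of real managers and tensors, and without the TrtllmAttentionMetadata type check:

from contextlib import contextmanager

class FakeMetadata:
    def __init__(self):
        self.kv_cache_manager = "target-manager"
        self.kv_cache_block_offsets = "target-offsets"
        self.host_kv_cache_block_offsets = "target-host-offsets"
        self.draft_kv_cache_block_offsets = "draft-offsets"
        self.draft_host_kv_cache_block_offsets = "draft-host-offsets"

@contextmanager
def draft_kv_cache_context(meta, draft_manager):
    # No-op when there is no separate draft cache (two-model mode, or feature off).
    if draft_manager is None:
        yield
        return
    saved = (meta.kv_cache_manager, meta.kv_cache_block_offsets,
             meta.host_kv_cache_block_offsets)
    meta.kv_cache_manager = draft_manager
    meta.kv_cache_block_offsets = meta.draft_kv_cache_block_offsets
    meta.host_kv_cache_block_offsets = meta.draft_host_kv_cache_block_offsets
    try:
        yield
    finally:
        (meta.kv_cache_manager, meta.kv_cache_block_offsets,
         meta.host_kv_cache_block_offsets) = saved

meta = FakeMetadata()
with draft_kv_cache_context(meta, "draft-manager"):
    assert meta.kv_cache_manager == "draft-manager"
    assert meta.kv_cache_block_offsets == "draft-offsets"
assert meta.kv_cache_manager == "target-manager"  # restored for graph replay
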
def _sample_tokens_for_batch(
self,
logits: torch.Tensor,

@ -367,8 +367,11 @@ class MTPSampler(Sampler[SampleStateMTP]):

class MTPWorker(SpecWorkerBase):

def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config=None,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False
@ -386,6 +389,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
'''
Example:
@ -518,43 +522,50 @@ class MTPWorker(SpecWorkerBase):
# update attn metadata
if attn_metadata is not None:
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
draft_inputs.update(attn_metadata=attn_metadata)

# Run MTP layers to predict draft tokens
next_draft_tokens = []
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)

new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)

logits = mtp_layer.shared_head(hidden_states,
draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)

new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

# restore attn metadata
if attn_metadata is not None:
@ -573,6 +584,39 @@ class MTPWorker(SpecWorkerBase):
'next_new_tokens': next_new_tokens
}

def skip_forward(
self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers
accepted_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
num_accepted_tokens = torch.ones(batch_size,
dtype=torch.int,
device=logits.device)
next_draft_tokens = torch.empty((batch_size, mtp_num_modules),
dtype=torch.int,
device=logits.device)
next_new_tokens = torch.empty((batch_size, (mtp_num_modules + 1)),
dtype=torch.int,
device=logits.device)
return {
'logits': logits,
'new_tokens': accepted_tokens,
'new_tokens_lens': num_accepted_tokens,
'next_draft_tokens': next_draft_tokens,
'next_new_tokens': next_new_tokens
}

def update_mtp_hidden_states(
self,
input_ids: torch.IntTensor,
@ -1146,8 +1190,9 @@ class MTPEagleWorker(MTPWorker):

def __init__(self,
spec_config: "MTPDecodingConfig",
model_config: Optional[ModelConfig] = None):
super().__init__(spec_config, model_config)
model_config: Optional[ModelConfig] = None,
use_separate_draft_kv_cache: bool = False):
super().__init__(spec_config, model_config, use_separate_draft_kv_cache)
self.model_config = model_config
self.mtp_num_modules = spec_config.num_nextn_predict_layers
self._is_mamba_hybrid_cache = None
@ -1170,6 +1215,7 @@ class MTPEagleWorker(MTPWorker):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
@ -1212,123 +1258,134 @@ class MTPEagleWorker(MTPWorker):
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)

# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)

# Predict draft tokens
next_draft_tokens = []
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)

if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
raise ValueError(
f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)

if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(
hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError(
f"In MTPEagleWorker.forward(), token_count < max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head,
attn_metadata, True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)

@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(
logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)

update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states,
gather_ids, inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):

@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1

update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}

# restore attn_metadata to support cuda graph
self._restore_attn_metadata_from_spec_dec(attn_metadata)

@ -220,14 +220,19 @@ def get_num_spec_layers(spec_config):
return 0


def get_spec_worker(spec_config, model_config, mapping):
def get_spec_worker(spec_config,
model_config,
mapping,
use_separate_draft_kv_cache: bool = False):
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_vanilla():
return MTPWorker(spec_config, model_config)
return MTPWorker(spec_config, model_config, use_separate_draft_kv_cache)
if spec_dec_mode.is_mtp_eagle_one_model():
return MTPEagleWorker(spec_config, model_config)
return MTPEagleWorker(spec_config, model_config,
use_separate_draft_kv_cache)
if spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return Eagle3OneModelWorker(spec_config, mapping,
use_separate_draft_kv_cache)
return None

@ -243,6 +248,21 @@ def get_num_extra_kv_tokens(spec_config):
return 0


def get_draft_kv_cache_manager(spec_config, resource_manager):
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
from ..pyexecutor.resource_manager import ResourceManagerType

if spec_config is None:
return None
if not spec_config.spec_dec_mode.use_one_engine():
return None
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)

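As a rough usage sketch (not from the commit: the enum, the dict-backed resource manager, and the config stubs below are illustrative stand-ins), the helper returns a draft KV cache manager only when one-engine speculative decoding is active and such a manager was registered; every other path yields None:

from enum import Enum, auto

class ResourceManagerType(Enum):      # stand-in for the real enum
    DRAFT_KV_CACHE_MANAGER = auto()

class DictResourceManager:            # stand-in: maps resource types to managers
    def __init__(self, managers):
        self._managers = managers
    def get_resource_manager(self, kind):
        return self._managers.get(kind)

class OneEngineMode:                  # stand-in spec_dec_mode
    def use_one_engine(self):
        return True

class SpecConfigStub:
    spec_dec_mode = OneEngineMode()

def lookup_draft_kv_cache_manager(spec_config, resource_manager):
    # Mirrors the selection logic above: no spec config or two-engine mode -> None.
    if spec_config is None or not spec_config.spec_dec_mode.use_one_engine():
        return None
    return resource_manager.get_resource_manager(
        ResourceManagerType.DRAFT_KV_CACHE_MANAGER)

rm = DictResourceManager(
    {ResourceManagerType.DRAFT_KV_CACHE_MANAGER: "draft_kv_cache_manager"})
assert lookup_draft_kv_cache_manager(SpecConfigStub(), rm) == "draft_kv_cache_manager"
assert lookup_draft_kv_cache_manager(None, rm) is None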
def update_spec_config_from_model_config(spec_config, model_config):
if spec_config.spec_dec_mode.is_mtp_one_model():
# Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.

@ -720,6 +720,8 @@ class DecodingBaseConfig(StrictBaseModel):
_allow_greedy_draft_tokens: bool = PrivateAttr(True)
# Internal: record decoding_type alias used during parsing (for warnings).
_decoding_type_alias: Optional[str] = PrivateAttr(default=None)
# If set, drafting will use a separate KV cache in one-model speculative decoding.
_allow_separate_draft_kv_cache: bool = PrivateAttr(True)

@field_validator('draft_len_schedule')
@classmethod

@ -155,8 +155,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, use_hf_speculative_model: bool,
request):
attention_dp: bool, use_hf_speculative_model: bool):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35: