Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Merge 42b52c38e8 into 6df2c8a074
This commit is contained in: commit 227707a50d

@@ -10,6 +10,7 @@ from flashinfer.jit.core import check_cuda_arch
from typing_extensions import Self

from tensorrt_llm.functional import AttentionMaskType
from tensorrt_llm.logger import logger
from tensorrt_llm.models.modeling_utils import QuantConfig

from ..utils import get_global_attrs, get_model_extra_attrs

@@ -127,6 +128,11 @@ class FlashInferAttentionMetadata(AttentionMetadata):

def __post_init__(self) -> None:
super().__post_init__()
if self.draft_kv_cache_manager is not None:
logger.warning(
"draft_kv_cache_manager is not supported in FlashInfer backend. "
"One-model speculative decoding with separate KV cache layouts "
"may not work correctly.")
self._post_init_with_buffers(self.cuda_graph_buffers)

def _post_init_with_buffers(self, buffers) -> None:

@@ -56,6 +56,8 @@ class AttentionMetadata:
max_num_sequences: Optional[int] = None
# The KV cache manager.
kv_cache_manager: KVCacheManager
# Draft KV cache manager for one-model speculative decoding with separate KV cache layouts
draft_kv_cache_manager: Optional[KVCacheManager] = None
mapping: Optional[Mapping] = None

enable_flash_mla: bool = False

@@ -974,6 +974,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope: bool = False,
max_beam_width: int = 1,
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager=None,
):
requests = super().add_dummy_requests(
request_ids=request_ids,

@@ -984,6 +985,7 @@ class RocketKVCacheManager(KVCacheManager):
use_mrope=use_mrope,
max_beam_width=max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager,
)
if prepare_resource:
for req in requests:

@@ -796,6 +796,29 @@ class TrtllmAttentionMetadata(AttentionMetadata):
)
self.block_ids_per_seq = None
self.kv_block_ids_per_seq = None

# Allocate separate block offset tensors for draft KV cache manager
# Used in one-model speculative decoding with different KV cache layouts
if self.draft_kv_cache_manager is not None:
self.draft_kv_cache_block_offsets = self.get_empty(
buffers,
[
self.draft_kv_cache_manager.num_pools,
self.max_num_sequences, 2,
self.draft_kv_cache_manager.max_blocks_per_seq
],
cache_name="draft_kv_cache_block_offsets",
dtype=torch.int32,
capture_graph=capture_graph,
)
self.draft_host_kv_cache_block_offsets = torch.empty_like(
self.draft_kv_cache_block_offsets,
device='cpu',
pin_memory=True,
)
else:
self.draft_kv_cache_block_offsets = None
self.draft_host_kv_cache_block_offsets = None
if self.enable_flash_mla:
self.block_ids_per_seq = self.get_empty(
buffers,

@@ -1007,6 +1030,25 @@ class TrtllmAttentionMetadata(AttentionMetadata):
assert self.kv_lens[:self.num_seqs].max(
) <= self.kv_cache_manager.max_seq_len, error_message

# Also prepare draft KV cache block offsets if draft_kv_cache_manager exists
if self.draft_kv_cache_manager is not None:
# Copy blocks for all context requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[:self.num_contexts], 1, 0)
# Copy blocks for all generation requests
self.draft_kv_cache_manager.impl.copy_batch_block_offsets(
self.draft_host_kv_cache_block_offsets,
self.request_ids[self.num_contexts:], self.beam_width,
self.num_contexts)
for pool_idx in range(
self.draft_host_kv_cache_block_offsets.shape[0]):
self.draft_kv_cache_block_offsets[
pool_idx, :self.num_seqs].copy_(
self.draft_host_kv_cache_block_offsets[
pool_idx, :self.num_seqs],
non_blocking=True)

self.kv_lens_cuda_runtime = self.kv_lens_cuda[:self.num_seqs]
# Don't use self.kv_lens here because it includes extra tokens.
# Use actual KV length (without extra tokens) for kv_lens_runtime,

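The draft block-offset handling above follows the same staging pattern as the primary KV cache offsets: fill a pinned host mirror on the CPU, then issue an asynchronous host-to-device copy. A minimal standalone sketch of that pattern is below; the shapes and tensor names are illustrative only (not the TensorRT-LLM API), and it assumes a CUDA device is available.

```python
import torch

# Hypothetical layout: 1 pool, 4 sequences, K/V pair, 8 blocks per sequence.
shape = (1, 4, 2, 8)

# Device tensor that attention kernels would read.
device_offsets = torch.zeros(shape, dtype=torch.int32, device="cuda")

# Pinned host mirror, so the copy below can run asynchronously.
host_offsets = torch.empty(shape, dtype=torch.int32, pin_memory=True)
host_offsets.fill_(0)  # stand-in for per-request block-offset bookkeeping

# Async H2D copy; overlapping is only safe because the host buffer is pinned.
device_offsets.copy_(host_offsets, non_blocking=True)
```
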
@@ -17,7 +17,8 @@ from ..modules.linear import (Linear, TensorParallelMode, WeightMode,
WeightsLoadingConfig)
from ..modules.rms_norm import RMSNorm
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder
from ..speculative import SpecMetadata, get_spec_worker
from ..speculative import (SpecMetadata, get_spec_worker,
should_use_separate_draft_kv_cache)
from ..utils import AuxStreamType
from .checkpoints.base_weight_mapper import BaseWeightMapper
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM, TModel,

@@ -880,6 +881,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
vocab_size=model_config.pretrained_config.vocab_size)
self.draft_model = None
self.draft_config = None
self.use_separate_draft_kv_cache = False
spec_config = getattr(model_config, 'spec_config', None)
if spec_config and spec_config.spec_dec_mode.use_one_engine():
if spec_config.spec_dec_mode.is_eagle3_one_model():

@@ -913,11 +915,16 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
self.draft_config.quant_config.kv_cache_quant_algo = \
model_config.quant_config.kv_cache_quant_algo

self.use_separate_draft_kv_cache = should_use_separate_draft_kv_cache(
spec_config)

self.draft_model = get_draft_model(model_config, self.draft_config,
self.lm_head, self.model)
self.spec_worker = get_spec_worker(model_config.spec_config,
model_config,
model_config.mapping)
self.spec_worker = get_spec_worker(
model_config.spec_config,
model_config,
model_config.mapping,
use_separate_draft_kv_cache=self.use_separate_draft_kv_cache)

if self.draft_config is not None and model_config.spec_config.eagle3_model_arch == "llama3":
for key, value in self.draft_config.extra_attrs.items():

@@ -934,6 +941,7 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
resource_manager=None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(

@@ -978,7 +986,8 @@ class SpecDecOneEngineForCausalLM(DecoderModelForCausalLM[TModel, TConfig],
logits=logits,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
draft_model=self.draft_model)
draft_model=self.draft_model,
resource_manager=resource_manager)
else:
logits = self.logits_processor.forward(
hidden_states,

@@ -25,7 +25,8 @@ from tensorrt_llm.mapping import CpType, Mapping

from ..attention_backend import get_sparse_attn_kv_cache_manager
from ..model_config import ModelConfig
from ..speculative import get_num_extra_kv_tokens, get_spec_decoder
from ..speculative import (get_num_extra_kv_tokens, get_spec_decoder,
should_use_separate_draft_kv_cache)
from .config_utils import is_mla, is_nemotron_hybrid, is_qwen3_next
from .guided_decoder import GuidedDecoder
from .kv_cache_connector import KvCacheConnectorManager

@@ -78,6 +79,7 @@ class KvCacheCreator:
sparse_attention_config: SparseAttentionConfig,
profiling_stage_data: Optional[dict],
execution_stream: Optional[torch.cuda.Stream] = None,
draft_config: Optional[ModelConfig] = None,
):
self._model_engine = model_engine
self._draft_model_engine = draft_model_engine

@@ -99,6 +101,7 @@ class KvCacheCreator:
self._kv_cache_manager_cls = get_kv_cache_manager_cls(
model_engine.model.model_config)
self._execution_stream = execution_stream
self._draft_config = draft_config

def _get_kv_size_per_token(self):
model_config = self._model_engine.model.model_config

@@ -111,6 +114,12 @@ class KvCacheCreator:
draft_model_config,
mapping,
tokens_per_block=self._tokens_per_block)
elif self._should_create_separate_draft_kv_cache():
# One-model draft with separate KV cache layout
kv_size_per_token += self._kv_cache_manager_cls.get_cache_size_per_token(
self._draft_config,
mapping,
tokens_per_block=self._tokens_per_block)
return kv_size_per_token

def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,

@@ -393,9 +402,12 @@ class KvCacheCreator:
# get kv cache stats for both model and draft model
kv_stats = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.KV_CACHE_MANAGER).get_kv_cache_stats()
kv_stats_draft = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER).get_kv_cache_stats(
) if self._draft_model_engine is not None else None
# Get draft KV cache stats if present (either from two-model mode or one-model
# mode with separate draft KV cache)
draft_kv_cache_manager = py_executor.resource_manager.resource_managers.get(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
kv_stats_draft = draft_kv_cache_manager.get_kv_cache_stats(
) if draft_kv_cache_manager is not None else None

# get total allocated bytes
allocated_bytes = kv_stats.allocated_bytes + (

@@ -499,6 +511,60 @@ class KvCacheCreator:

return kv_cache_manager

def _should_create_separate_draft_kv_cache(self) -> bool:
"""
Check if we need a separate draft KV cache manager for one-model mode.
Returns True if the speculative config has use_separate_draft_kv_cache=True.
"""
if self._draft_config is None:
return False

return should_use_separate_draft_kv_cache(self._speculative_config)

def _create_one_model_draft_kv_cache_manager(
self,
estimating_kv_cache: bool = False) -> Optional[KVCacheManager]:
"""
Create a KV cache manager for draft model layers in one-model mode
when target and draft have different KV cache layouts.
"""
# Get target model's num_hidden_layers to compute correct layer indices.
# Draft model layers in one-model mode start at target_num_layers.
target_pretrained_config = self._model_engine.model.model_config.pretrained_config
target_num_layers = target_pretrained_config.num_hidden_layers
num_draft_layers = self._draft_config.pretrained_config.num_hidden_layers

# Create layer_mask: False for target layers, True for draft layers.
# This ensures the draft KV cache manager uses the correct layer indices
# (e.g., layers 32, 33, ... instead of 0, 1, ...).
layer_mask = [False] * target_num_layers + [True] * num_draft_layers

# Get the appropriate KV cache manager class for the draft model
draft_kv_cache_manager_cls = get_kv_cache_manager_cls(
self._draft_config)

return _create_kv_cache_manager(
model_engine=None,
kv_cache_manager_cls=draft_kv_cache_manager_cls,
mapping=self._mapping,
kv_cache_config=self._kv_cache_config,
tokens_per_block=self._tokens_per_block,
max_seq_len=self._max_seq_len,
max_batch_size=self._max_batch_size,
spec_config=self._speculative_config,
sparse_attn_config=None,  # Not applicable for draft in one-model mode
max_num_tokens=self._max_num_tokens,
max_beam_width=self._max_beam_width,
kv_connector_manager=None,  # Not supported for draft models
estimating_kv_cache=estimating_kv_cache,
execution_stream=self._execution_stream,
# One-model draft specific overrides
model_config=self._draft_config,
dtype=self._draft_config.pretrained_config.torch_dtype,
is_draft=True,
layer_mask=layer_mask,
)

def build_managers(self,
resources: Dict,
estimating_kv_cache: bool = False) -> None:

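The layer-mask construction described in the comments above can be shown in isolation. The sketch below assumes a hypothetical 32-layer target model with a single Eagle-style draft layer appended after it, so the draft KV cache manager owns global layer index 32 rather than index 0:

```python
# Hypothetical layer counts, for illustration only.
target_num_layers = 32  # hidden layers of the target model
num_draft_layers = 1    # draft layers appended after the target layers

# False for target layers, True for draft layers, mirroring
# _create_one_model_draft_kv_cache_manager above.
layer_mask = [False] * target_num_layers + [True] * num_draft_layers

# The draft KV cache manager sees only the trailing (draft) layer indices.
draft_layer_indices = [i for i, is_draft in enumerate(layer_mask) if is_draft]
assert draft_layer_indices == [32]
```
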
@@ -510,9 +576,16 @@ class KvCacheCreator:
raise NotImplementedError(
"Connector manager is not supported for draft model.")

draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache
) if self._draft_model_engine is not None else None
draft_kv_cache_manager = None

# Two-model speculative decoding: draft model has separate engine
if self._draft_model_engine is not None:
draft_kv_cache_manager = self._create_kv_cache_manager(
self._draft_model_engine, estimating_kv_cache)
# One-model speculative decoding with different KV layouts
elif self._should_create_separate_draft_kv_cache():
draft_kv_cache_manager = self._create_one_model_draft_kv_cache_manager(
estimating_kv_cache)

resources[ResourceManagerType.KV_CACHE_MANAGER] = kv_cache_manager
resources[

@@ -530,7 +603,7 @@ class KvCacheCreator:


def _create_kv_cache_manager(
model_engine: PyTorchModelEngine,
model_engine: Optional[PyTorchModelEngine],
kv_cache_manager_cls,
mapping: Mapping,
kv_cache_config: KvCacheConfig,

@@ -543,13 +616,31 @@ def _create_kv_cache_manager(
max_beam_width: int,
kv_connector_manager: Optional[KvCacheConnectorManager],
estimating_kv_cache: bool,
execution_stream: Optional[torch.cuda.Stream] = None) -> KVCacheManager:
execution_stream: Optional[torch.cuda.Stream] = None,
# Optional overrides for one-model draft case (when model_engine is None)
model_config: Optional[ModelConfig] = None,
dtype: Optional[torch.dtype] = None,
is_draft: Optional[bool] = None,
layer_mask: Optional[List[bool]] = None) -> KVCacheManager:
"""
Returns:
A KVCacheManager instance for the given model_engine
A KVCacheManager instance for the given model engine or model config
"""
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
# Extract config from model_engine or use provided model_config
if model_config is not None:
config = model_config.pretrained_config
quant_config = model_config.quant_config
_model_config = model_config
else:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
_model_config = model_engine.model.model_config

if dtype is None:
dtype = model_engine.dtype

if is_draft is None:
is_draft = model_engine.is_draft_model

hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads

@@ -565,8 +656,7 @@ def _create_kv_cache_manager(
):
kv_cache_dtype = tensorrt_llm.bindings.DataType.NVFP4
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
kv_cache_dtype = str_dtype_to_binding(torch_dtype_to_str(dtype))

num_hidden_layers = config.num_hidden_layers

@@ -584,12 +674,13 @@ def _create_kv_cache_manager(
dtype=kv_cache_dtype,
spec_config=spec_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
elif is_nemotron_hybrid(config):
if max_beam_width > 1:

@@ -601,9 +692,10 @@ def _create_kv_cache_manager(
"Connector manager is not supported for MambaHybridCacheManager."
)

config = model_engine.model.model_config.pretrained_config
num_layers = config.hybrid_override_pattern.count("*")
layer_mask = [char == "*" for char in config.hybrid_override_pattern]
hybrid_layer_mask = [
char == "*" for char in config.hybrid_override_pattern
]
mamba_num_layers = config.hybrid_override_pattern.count("M")
mamba_layer_mask = [
char == "M" for char in config.hybrid_override_pattern

@@ -618,12 +710,13 @@ def _create_kv_cache_manager(
mamba_num_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,

@@ -644,13 +737,12 @@ def _create_kv_cache_manager(
raise NotImplementedError(
"Connector manager is not supported for MambaHybridCacheManager."
)
config = model_engine.model.model_config.pretrained_config
mamba_layer_mask = [
True if i %
config.full_attention_interval != config.full_attention_interval -
1 else False for i in range(num_hidden_layers)
]
layer_mask = [
hybrid_layer_mask = [
False if i %
config.full_attention_interval != config.full_attention_interval -
1 else True for i in range(num_hidden_layers)

@@ -668,12 +760,13 @@ def _create_kv_cache_manager(
num_mamba_layers,
mamba_layer_mask,
config.torch_dtype,
model_engine.model.model_config.quant_config.mamba_ssm_cache_dtype,
quant_config.mamba_ssm_cache_dtype
if quant_config is not None else None,
# kv cache parameters
kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_layers,
layer_mask=layer_mask,
layer_mask=hybrid_layer_mask,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=tokens_per_block,

@@ -689,7 +782,7 @@ def _create_kv_cache_manager(
# NOTE: this is a workaround for VSWA to switch to calculate_max_num_blocks_from_cpp in KVCacheManager
is_vswa = kv_cache_config.max_attention_window is not None and len(
set(kv_cache_config.max_attention_window)) > 1
binding_model_config = model_engine.model.model_config.get_bindings_model_config(
binding_model_config = _model_config.get_bindings_model_config(
tokens_per_block=tokens_per_block) if is_vswa else None

kv_cache_manager = kv_cache_manager_cls(

@@ -707,12 +800,13 @@ def _create_kv_cache_manager(
max_num_tokens=max_num_tokens,
model_config=binding_model_config,
max_beam_width=max_beam_width,
is_draft=model_engine.is_draft_model,
is_draft=is_draft,
kv_connector_manager=kv_connector_manager
if not estimating_kv_cache else None,
sparse_attn_config=sparse_attn_config,
is_estimating_kv_cache=estimating_kv_cache,
execution_stream=execution_stream,
layer_mask=layer_mask,
)
return kv_cache_manager

@@ -16,6 +16,7 @@ from ..memory_buffer_utils import get_memory_buffers
from ..modules.multi_stream_utils import with_multi_stream
from ..speculative.eagle3 import Eagle3ResourceManager
from ..speculative.mtp import SampleStateTensorsMTP
from ..speculative.utils import get_draft_kv_cache_manager
from ..utils import make_weak_ref, piecewise_cuda_graph
from .llm_request import get_draft_token_length
from .mamba_cache_manager import MambaCacheManager

@@ -439,12 +440,19 @@ class CUDAGraphRunner:
if available_blocks < 1:
return 0

# Get draft KV cache manager only for one-model speculative decoding.
# In two-model mode, each model has its own KV cache manager, so
# draft_kv_cache_manager should be None.
draft_kv_cache_manager = get_draft_kv_cache_manager(
self.spec_config, resource_manager)

self.padding_dummy_request = kv_cache_manager.add_dummy_requests(
[CUDA_GRAPH_DUMMY_REQUEST_ID],
is_gen=True,
max_num_draft_tokens=runtime_draft_len,
use_mrope=self.config.use_mrope,
max_beam_width=self.config.max_beam_width)[0]
max_beam_width=self.config.max_beam_width,
draft_kv_cache_manager=draft_kv_cache_manager)[0]
self.padding_dummy_request.is_cuda_graph_dummy = True
spec_res_mgr = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -44,8 +44,8 @@ from ..models.modeling_multimodal_utils import filter_mm_token_from_input_ids
from ..models.modeling_utils import DecoderModelForCausalLM
from ..modules.fused_moe.moe_load_balancer import (MoeLoadBalancer,
MoeLoadBalancerIterContext)
from ..speculative import (SpecMetadata, get_num_extra_kv_tokens,
get_spec_metadata,
from ..speculative import (SpecMetadata, get_draft_kv_cache_manager,
get_num_extra_kv_tokens, get_spec_metadata,
update_spec_config_from_model_config)
from ..speculative.drafting_loops import BaseDraftingLoopWrapper
from ..speculative.eagle3 import (Eagle3OneModelSpecMetadata,

@@ -543,6 +543,15 @@ class PyTorchModelEngine(ModelEngine):
def use_beam_search(self):
return self.max_beam_width > 1

def _get_draft_kv_cache_manager(
self,
resource_manager: ResourceManager) -> Optional[KVCacheManager]:
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
return get_draft_kv_cache_manager(self.spec_config, resource_manager)

@contextmanager
def set_warmup_flag(self):
prev_is_warmup = self.is_warmup

@@ -852,6 +861,8 @@ class PyTorchModelEngine(ModelEngine):
"""A context manager to automatically free resources of a dummy batch."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
try:

@@ -860,6 +871,8 @@ class PyTorchModelEngine(ModelEngine):
if batch is not None and kv_cache_manager is not None:
for req in batch.all_requests():
kv_cache_manager.free_resources(req)
if draft_kv_cache_manager is not None:
draft_kv_cache_manager.free_resources(req)
if spec_resource_manager is not None:
spec_resource_manager.free_resources(req)

@@ -882,6 +895,9 @@ class PyTorchModelEngine(ModelEngine):
"""Creates a generic dummy ScheduledRequests object for warmup."""
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -953,7 +969,8 @@ class PyTorchModelEngine(ModelEngine):
is_gen=False,
max_num_draft_tokens=self.runtime_draft_len,
use_mrope=self.use_mrope,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(

@@ -969,7 +986,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=self.max_total_draft_tokens,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)
if spec_resource_manager is not None:
spec_resource_manager.add_dummy_requests(request_ids=list(
range(num_ctx_requests, num_ctx_requests +

@@ -991,6 +1009,8 @@ class PyTorchModelEngine(ModelEngine):
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

available_blocks = kv_cache_manager.get_num_free_blocks(
) // self.max_beam_width

@@ -1008,9 +1028,15 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)

available_tokens = kv_cache_manager.get_num_available_tokens(draft_len)
# Also consider draft KV cache capacity when it exists
if draft_kv_cache_manager is not None:
draft_available_tokens = draft_kv_cache_manager.get_num_available_tokens(
draft_len)
available_tokens = min(available_tokens, draft_available_tokens)

# Add one dummy request with the maximum possible sequence length.
max_seq_len = self.max_seq_len if max_seq_len is None else max_seq_len

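When a separate draft KV cache exists, the warmup sizing above clamps the usable token budget to whichever cache is smaller. A tiny sketch of that check; the helper name and the numbers are made up, only the min() logic comes from the diff:

```python
from typing import Optional

def usable_warmup_tokens(target_available: int,
                         draft_available: Optional[int]) -> int:
    """Token budget limited by the smaller of the target and draft KV caches."""
    if draft_available is None:  # no separate draft KV cache manager
        return target_available
    return min(target_available, draft_available)

assert usable_warmup_tokens(4096, None) == 4096
assert usable_warmup_tokens(4096, 1024) == 1024
```
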
@@ -1033,7 +1059,8 @@ class PyTorchModelEngine(ModelEngine):
max_num_draft_tokens=draft_len,
use_mrope=self.use_mrope,
max_beam_width=self.max_beam_width,
num_extra_decoding_steps=num_extra_decoding_steps)[0]
num_extra_decoding_steps=num_extra_decoding_steps,
draft_kv_cache_manager=draft_kv_cache_manager)[0]

# Insert the longest request first to simulate padding for the CUDA graph.
requests.insert(0, max_seq_len_request)

@@ -1057,7 +1084,10 @@ class PyTorchModelEngine(ModelEngine):
req.py_is_first_draft = True
req.py_draft_tokens = []

def _set_up_attn_metadata(self, kv_cache_manager: KVCacheManager):
def _set_up_attn_metadata(
self,
kv_cache_manager: KVCacheManager,
draft_kv_cache_manager: Optional[KVCacheManager] = None):
enable_context_mla_with_cached_kv = is_mla(
self.model.model_config.pretrained_config) and (
self.attn_runtime_features.cache_reuse

@@ -1106,6 +1136,7 @@ class PyTorchModelEngine(ModelEngine):
max_num_tokens=self.max_num_tokens,
max_num_sequences=self.batch_size * self.max_beam_width,
kv_cache_manager=kv_cache_manager,
draft_kv_cache_manager=draft_kv_cache_manager,
mapping=self.mapping,
runtime_features=self.attn_runtime_features,
enable_flash_mla=self.model.model_config.enable_flash_mla,

@@ -1473,7 +1504,8 @@ class PyTorchModelEngine(ModelEngine):
else:
return self._apply_incremental_update_target(
scheduled_requests, kv_cache_manager, attn_metadata,
spec_metadata, new_tensors_device, num_accepted_tokens_device)
spec_metadata, new_tensors_device, num_accepted_tokens_device,
resource_manager)

@nvtx_range("_prepare_incremental_update_metadata")
def _prepare_incremental_update_metadata(

@@ -1739,7 +1771,8 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None,
new_tensors_device: Optional[SampleStateTensors] = None,
num_accepted_tokens_device: Optional[torch.Tensor] = None):
num_accepted_tokens_device: Optional[torch.Tensor] = None,
resource_manager: Optional[ResourceManager] = None):
# Extract tensors from new_tensors_device
new_tokens_device = new_tensors_device.new_tokens  # [batch, 1 + draft_len]
new_tokens_lens_device = new_tensors_device.new_tokens_lens  # [batch]

@@ -1873,6 +1906,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": [],
'resource_manager': resource_manager,
}

if bool(lora_params):

@@ -2665,6 +2699,7 @@ class PyTorchModelEngine(ModelEngine):
'position_ids': final_position_ids,
'inputs_embeds': None,
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}

if bool(lora_params):

@@ -2730,7 +2765,8 @@ class PyTorchModelEngine(ModelEngine):
self,
scheduled_requests: ScheduledRequests,
attn_metadata: AttentionMetadata,
spec_metadata: Optional[SpecMetadata] = None):
spec_metadata: Optional[SpecMetadata] = None,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""

@@ -2834,7 +2870,8 @@ class PyTorchModelEngine(ModelEngine):
'position_ids':
self.position_ids_cuda[:virtual_num_tokens].unsqueeze(0),
'inputs_embeds': None,
"multimodal_params": multimodal_params_list
"multimodal_params": multimodal_params_list,
'resource_manager': resource_manager,
}

if bool(lora_params):

@@ -2877,10 +2914,12 @@ class PyTorchModelEngine(ModelEngine):

return inputs, None

def _prepare_star_attention_inputs(self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata):
def _prepare_star_attention_inputs(
self,
scheduled_requests: ScheduledRequests,
kv_cache_manager,
attn_metadata: AttentionMetadata,
resource_manager: Optional[ResourceManager] = None):
"""
Prepare inputs for Pytorch Model.
"""

@@ -3096,7 +3135,8 @@ class PyTorchModelEngine(ModelEngine):
'attn_metadata': attn_metadata,
'input_ids': self.input_ids_cuda[:num_tokens],
'position_ids': self.position_ids_cuda[:num_tokens].unsqueeze(0),
'inputs_embeds': None
'inputs_embeds': None,
'resource_manager': resource_manager,
}, gather_ids if is_spec_decode else None

def _get_lora_params_from_requests(self,

@@ -3207,7 +3247,8 @@ class PyTorchModelEngine(ModelEngine):
cp_type = self.mapping.cp_config['cp_type']
if CpType.STAR == cp_type:
return self._prepare_star_attention_inputs(
scheduled_requests, kv_cache_manager, attn_metadata)
scheduled_requests, kv_cache_manager, attn_metadata,
resource_manager)
elif CpType.HELIX == cp_type:
# Take the usual route of _prepare_tp_inputs.
pass

@@ -3235,8 +3276,11 @@ class PyTorchModelEngine(ModelEngine):
req_id_to_old_request: Optional[Dict[int, LlmRequest]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
self.kv_cache_manager_key)
draft_kv_cache_manager = self._get_draft_kv_cache_manager(
resource_manager)

attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager,
draft_kv_cache_manager)
if self.enable_spec_decode:
spec_resource_manager = resource_manager.get_resource_manager(
ResourceManagerType.SPEC_RESOURCE_MANAGER)

@@ -3270,7 +3314,8 @@ class PyTorchModelEngine(ModelEngine):

if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
scheduled_requests, attn_metadata, spec_metadata,
resource_manager)

with MoeLoadBalancerIterContext(moe_load_balancer):
# Special handling for multimodal encoder only mode

@@ -610,6 +610,10 @@ def create_py_executor(

if model_engine.model.model_config.is_generation:
#NOTE: non-generation models do not have kv cache

# Get draft config for one-engine speculative decoding if available
draft_config = getattr(model_engine.model, 'draft_config', None)

kv_cache_creator = KvCacheCreator(
model_engine=model_engine,
draft_model_engine=draft_model_engine,

@@ -627,6 +631,7 @@ def create_py_executor(
profiling_stage_data=profiling_stage_data,
sparse_attention_config=sparse_attention_config,
execution_stream=execution_stream,
draft_config=draft_config,
)
estimating_kv_cache = kv_cache_creator.try_prepare_estimation()
with allocation_scope(

@@ -519,6 +519,7 @@ class KVCacheManager(BaseResourceManager):
# we need to make the KV cache manager aware that multiple autoregressive steps will
# occur.
num_extra_decoding_steps: int = 0,
draft_kv_cache_manager: Optional[BaseResourceManager] = None,
):
beam_width = max_beam_width
requests = []

@@ -553,6 +554,12 @@ class KVCacheManager(BaseResourceManager):
for _ in range(num_extra_decoding_steps):
self.impl.add_token(req_id)

if draft_kv_cache_manager is not None:
draft_kv_cache_manager.impl.add_sequence(
req_id, token_num, beam_width, req)
for _ in range(self.num_extra_kv_tokens):
draft_kv_cache_manager.impl.add_token(req_id)

if is_gen:
req.state = LlmRequestState.GENERATION_IN_PROGRESS
req.prompt_len = token_num - 1

@@ -579,6 +586,10 @@ class KVCacheManager(BaseResourceManager):
for _ in range(max_num_draft_tokens):
self.impl.add_token(req_id)

if draft_kv_cache_manager is not None:
for _ in range(max_num_draft_tokens):
draft_kv_cache_manager.impl.add_token(req_id)

# TODO: Planning to get dummy_data from each model. Before that, we need to add dummy mrop_config to the request here.
if use_mrope:
dummy_mrope_position_ids = torch.arange(

@@ -1,14 +1,15 @@
from .auto_heuristic import suggest_spec_config
from .eagle3 import Eagle3SpecMetadata
from .interface import SpecMetadata, SpecWorkerBase
from .interface import (SpecMetadata, SpecWorkerBase,
should_use_separate_draft_kv_cache)
from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .ngram import NGramDrafter, NGramPoolManager
from .save_hidden_state import SaveHiddenStatesDrafter
from .spec_tree_manager import SpecTreeManager
from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
get_spec_decoder, get_spec_drafter, get_spec_metadata,
get_spec_resource_manager, get_spec_worker,
update_spec_config_from_model_config)
from .utils import (get_draft_kv_cache_manager, get_num_extra_kv_tokens,
get_num_spec_layers, get_spec_decoder, get_spec_drafter,
get_spec_metadata, get_spec_resource_manager,
get_spec_worker, update_spec_config_from_model_config)

__all__ = [
"Eagle3SpecMetadata",

@@ -20,6 +21,7 @@ __all__ = [
"SaveHiddenStatesDrafter",
"SpecMetadata",
"SpecWorkerBase",
"get_draft_kv_cache_manager",
"get_num_extra_kv_tokens",
"get_num_spec_layers",
"get_spec_decoder",

@@ -27,6 +29,7 @@ __all__ = [
"get_spec_metadata",
"get_spec_resource_manager",
"get_spec_worker",
"should_use_separate_draft_kv_cache",
"update_spec_config_from_model_config",
"suggest_spec_config",
"SpecTreeManager",

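The newly exported should_use_separate_draft_kv_cache helper only reads one boolean off the speculative config. A hedged usage sketch follows; the SimpleNamespace stands in for a real decoding config object, and only the attribute name comes from the diff:

```python
from types import SimpleNamespace

def should_use_separate_draft_kv_cache(spec_config) -> bool:
    # Same logic as the helper added in ..speculative.interface.
    if spec_config is None:
        return False
    return getattr(spec_config, 'use_separate_draft_kv_cache', False)

# Stand-in for an EagleDecodingConfig-like object with the flag enabled.
spec_config = SimpleNamespace(use_separate_draft_kv_cache=True)
assert should_use_separate_draft_kv_cache(spec_config) is True
assert should_use_separate_draft_kv_cache(None) is False
```
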
@@ -358,8 +358,11 @@ class Eagle3OneModelSampler(MTPSampler):

class Eagle3OneModelWorker(SpecWorkerBase):

def __init__(self, spec_config: "EagleDecodingConfig", mapping: Mapping):
super().__init__()
def __init__(self,
spec_config: "EagleDecodingConfig",
mapping: Mapping,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.mapping = mapping

@@ -369,8 +372,15 @@ class Eagle3OneModelWorker(SpecWorkerBase):

# Skip torch.compile for now since current Torch is not compatible with Triton 3.4
# @torch.compile(options={"max-autotune": True})
def forward(self, input_ids, position_ids, hidden_states, logits,
attn_metadata, spec_metadata, draft_model):
def forward(self,
input_ids,
position_ids,
hidden_states,
logits,
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None):
batch_size = attn_metadata.num_seqs
num_contexts = attn_metadata.num_contexts
num_gens = batch_size - num_contexts

@@ -401,81 +411,91 @@ class Eagle3OneModelWorker(SpecWorkerBase):
# Predict draft tokens
next_draft_tokens = []
original_all_rank_num_tokens = attn_metadata.all_rank_num_tokens
for i in range(self.max_draft_len):
if i == 0:
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[spec_metadata.gather_ids[:num_contexts], gather_ids_gen],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
# Get the draft KV cache manager if using separate layouts
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)

# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.max_draft_len):
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.max_draft_len + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat([
spec_metadata.gather_ids[:num_contexts], gather_ids_gen
],
dim=0)
else:
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

hidden_states, hidden_states_to_save = draft_model.model(**inputs)

# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False

logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len - num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
# Update attn_metadata.all_rank_num_tokens for attention DP
if original_all_rank_num_tokens is not None:
if i == 0:
attn_metadata.all_rank_num_tokens = original_all_rank_num_tokens
elif spec_metadata.all_rank_num_seqs is not None:
attn_metadata.all_rank_num_tokens = spec_metadata.all_rank_num_seqs

hidden_states, hidden_states_to_save = draft_model.model(
**inputs)

# FIXME (jhaotingc): Currently we disable use_spec_decoding mode for Eagle engine nth steps except 1st step.
# Eagle engine takes in draft_len tokens from the previous step, run spec-dec mode with those tokens,
# then the following step can use regular decoding mode to generate 1 token per step.
# Currently the spec-dec mask for chained tree is not implemented yet.
# When token tree is supported, this can be removed and all steps may use spec-dec mode as well.
attn_metadata.use_spec_decoding = False

logits = draft_model.logits_processor(hidden_states[gather_ids],
draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
d2t = getattr(draft_model.model, "d2t", None)
self.guided_decoder.execute_draft_batch(logits,
d2t,
draft_step=i)

new_draft_token = self.draft_decoder(logits, draft_model)
next_draft_tokens.append(new_draft_token)
# update inputs
hidden_states = hidden_states_to_save[gather_ids]
position_ids = inputs["position_ids"][gather_ids] + 1
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
if inputs["attn_metadata"].kv_cache_manager is not None:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.max_draft_len -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
elif hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[:batch_size] += 1
# support attention dp
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
"spec_metadata": spec_metadata,
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

# restore attn_metadata to support cuda graph

@@ -1,6 +1,7 @@
import copy
import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import TYPE_CHECKING, List, Optional, Type

@@ -12,7 +13,8 @@ from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ..attention_backend.trtllm import AttentionBackend, TrtllmAttention
from ..pyexecutor.resource_manager import BaseResourceManager
from ..pyexecutor.resource_manager import (BaseResourceManager,
ResourceManagerType)

if TYPE_CHECKING:
from ..pyexecutor.guided_decoder import CapturableGuidedDecoder

@@ -21,6 +23,21 @@ if TYPE_CHECKING:
FORCE_NUM_ACCEPTED_TOKENS_ENV_VAR = "TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS"


def should_use_separate_draft_kv_cache(spec_config) -> bool:
"""
Check if separate draft KV cache should be used for one-engine speculative decoding.

Args:
spec_config: The speculative decoding config (e.g., EagleDecodingConfig).

Returns:
bool: True if separate draft KV cache should be used.
"""
if spec_config is None:
return False
return getattr(spec_config, 'use_separate_draft_kv_cache', False)


def get_force_num_accepted_tokens() -> int:
"""
Read and parse the TLLM_SPEC_DECODE_FORCE_NUM_ACCEPTED_TOKENS environment variable.

@@ -367,10 +384,11 @@ class SpecWorkerBase(nn.Module, ABC):
Provides common functionality for sampling and token handling.
"""

def __init__(self):
def __init__(self, use_separate_draft_kv_cache: bool = False):
super().__init__()
self.guided_decoder: Optional["CapturableGuidedDecoder"] = None
self.force_num_accepted_tokens = get_force_num_accepted_tokens()
self.use_separate_draft_kv_cache = use_separate_draft_kv_cache

@property
@abstractmethod

@@ -385,6 +403,47 @@ class SpecWorkerBase(nn.Module, ABC):
self.guided_decoder = guided_decoder
return True

def get_draft_kv_cache_manager(self, resource_manager):
"""
Get the draft KV cache manager if using separate KV cache layouts.
"""
if self.use_separate_draft_kv_cache and resource_manager is not None:
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)
return None

@contextmanager
def draft_kv_cache_context(self, attn_metadata, draft_kv_cache_manager):
"""
Context manager to temporarily switch to draft KV cache manager.

This swaps both the kv_cache_manager reference AND the block offset tensors,
since the main and draft KV caches have different block layouts.
"""
if draft_kv_cache_manager is None:
yield
return

# Save main KV cache manager and block offsets
target_kv_cache_manager = attn_metadata.kv_cache_manager
target_kv_cache_block_offsets = attn_metadata.kv_cache_block_offsets
target_host_kv_cache_block_offsets = attn_metadata.host_kv_cache_block_offsets

# Switch to draft KV cache manager and its block offsets
attn_metadata.kv_cache_manager = draft_kv_cache_manager
if hasattr(attn_metadata, 'draft_kv_cache_block_offsets'
) and attn_metadata.draft_kv_cache_block_offsets is not None:
attn_metadata.kv_cache_block_offsets = attn_metadata.draft_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = attn_metadata.draft_host_kv_cache_block_offsets

try:
yield
finally:
# Restore main KV cache manager and block offsets
attn_metadata.kv_cache_manager = target_kv_cache_manager
attn_metadata.kv_cache_block_offsets = target_kv_cache_block_offsets
attn_metadata.host_kv_cache_block_offsets = target_host_kv_cache_block_offsets

def _sample_tokens_for_batch(
self,
logits: torch.Tensor,

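draft_kv_cache_context above is a plain save/swap/restore context manager. The self-contained sketch below reproduces the same pattern on a toy metadata object; the attribute names mirror the diff, while the classes and string values are purely illustrative:

```python
from contextlib import contextmanager

class ToyAttnMetadata:
    def __init__(self):
        self.kv_cache_manager = "target-manager"
        self.kv_cache_block_offsets = "target-offsets"

@contextmanager
def swap_kv_cache(metadata, draft_manager, draft_offsets):
    if draft_manager is None:
        yield  # nothing to swap
        return
    saved_manager = metadata.kv_cache_manager
    saved_offsets = metadata.kv_cache_block_offsets
    metadata.kv_cache_manager = draft_manager
    metadata.kv_cache_block_offsets = draft_offsets
    try:
        yield
    finally:
        # Always restore, even if the draft forward pass raises.
        metadata.kv_cache_manager = saved_manager
        metadata.kv_cache_block_offsets = saved_offsets

md = ToyAttnMetadata()
with swap_kv_cache(md, "draft-manager", "draft-offsets"):
    assert md.kv_cache_manager == "draft-manager"
assert md.kv_cache_manager == "target-manager"
```
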
@@ -349,8 +349,11 @@ class MTPSampler(TorchSampler):

class MTPWorker(SpecWorkerBase):

def __init__(self, spec_config: "MTPDecodingConfig", model_config=None):
super().__init__()
def __init__(self,
spec_config: "MTPDecodingConfig",
model_config=None,
use_separate_draft_kv_cache: bool = False):
super().__init__(use_separate_draft_kv_cache)
self.spec_config = spec_config
self.model_config = model_config
self.is_thop = False

@@ -368,6 +371,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
'''
Example:

@@ -501,43 +505,50 @@ class MTPWorker(SpecWorkerBase):
# update attn metadata
if attn_metadata is not None:
self.change_attn_metadata(num_accepted_tokens, attn_metadata)
draft_inputs.update(attn_metadata=attn_metadata)

# Run MTP layers to predict draft tokens
next_draft_tokens = []
last_tokens_idx = torch.cumsum(
attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)
logits = mtp_layer.shared_head(hidden_states, draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
resource_manager)

new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i, mtp_layer in enumerate(draft_model.mtp_layers):
if self.guided_decoder is not None:
new_tokens = draft_inputs['input_ids'][last_tokens_idx]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)

hidden_states = mtp_layer(embed_tokens=draft_model.embed_tokens,
**draft_inputs)

logits = mtp_layer.shared_head(hidden_states,
draft_model.lm_head,
attn_metadata).float()
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)

new_draft_token = self.draft_sampler(logits)
next_draft_tokens.append(new_draft_token)
# shift input_ids and hidden_states
input_ids = draft_inputs["input_ids"]
input_ids[:-1] = input_ids[1:].clone()
input_ids[last_tokens_idx] = new_draft_token
draft_hidden_states = draft_inputs["hidden_states"]
draft_hidden_states[:-1] = draft_hidden_states[1:].clone()
draft_hidden_states[last_tokens_idx] = hidden_states[
last_tokens_idx, :]
draft_inputs = {
"input_ids": input_ids,
"position_ids": draft_inputs["position_ids"],
"hidden_states": draft_hidden_states,
"attn_metadata": draft_inputs["attn_metadata"],
}
next_draft_tokens = torch.stack(next_draft_tokens, dim=1)

# restore attn metadata
if attn_metadata is not None:

@@ -567,6 +578,7 @@ class MTPWorker(SpecWorkerBase):
attn_metadata,
spec_metadata,
draft_model,
resource_manager=None,
):
batch_size = attn_metadata.num_seqs
mtp_num_modules = self.spec_config.num_nextn_predict_layers

@ -1178,8 +1190,9 @@ class MTPEagleWorker(MTPWorker):
|
||||
|
||||
def __init__(self,
|
||||
spec_config: "MTPDecodingConfig",
|
||||
model_config: Optional[ModelConfig] = None):
|
||||
super().__init__(spec_config, model_config)
|
||||
model_config: Optional[ModelConfig] = None,
|
||||
use_separate_draft_kv_cache: bool = False):
|
||||
super().__init__(spec_config, model_config, use_separate_draft_kv_cache)
|
||||
self.model_config = model_config
|
||||
self.mtp_num_modules = spec_config.num_nextn_predict_layers
|
||||
|
||||
@ -1201,6 +1214,7 @@ class MTPEagleWorker(MTPWorker):
|
||||
attn_metadata,
|
||||
spec_metadata,
|
||||
draft_model,
|
||||
resource_manager=None,
|
||||
):
|
||||
batch_size = attn_metadata.num_seqs
|
||||
num_contexts = attn_metadata.num_contexts
|
||||
@ -1236,123 +1250,134 @@ class MTPEagleWorker(MTPWorker):
|
||||
attn_metadata=attn_metadata,
|
||||
spec_metadata=spec_metadata)
|
||||
|
||||
# Get the draft KV cache manager if using separate layouts
|
||||
draft_kv_cache_manager = self.get_draft_kv_cache_manager(
|
||||
resource_manager)
|
||||
|
||||
# Predict draft tokens
|
||||
next_draft_tokens = []
|
||||
for i in range(self.mtp_num_modules):
|
||||
if i == 0:
|
||||
hidden_states = draft_model.mtp_layers[0](
|
||||
embed_tokens=draft_model.embed_tokens,
|
||||
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
|
||||
**inputs)
|
||||
start_ids_gen = (spec_metadata.batch_indices_cuda[:num_gens] *
|
||||
(self.mtp_num_modules + 1)).long()
|
||||
gather_ids_gen = (start_ids_gen +
|
||||
num_accepted_tokens[num_contexts:] - 1 +
|
||||
attn_metadata.num_ctx_tokens)
|
||||
gather_ids = torch.concat(
|
||||
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
|
||||
else:
|
||||
hidden_states = draft_model.mtp_layers[0](
|
||||
embed_tokens=draft_model.embed_tokens,
|
||||
all_rank_num_tokens=spec_metadata.
|
||||
subseq_all_rank_num_tokens,
|
||||
**inputs)
|
||||
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
|
||||
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]
|
||||
with self.draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
for i in range(self.mtp_num_modules):
if i == 0:
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.all_rank_num_tokens,
**inputs)

if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
start_ids_gen = (
spec_metadata.batch_indices_cuda[:num_gens] *
(self.mtp_num_modules + 1)).long()
gather_ids_gen = (start_ids_gen +
num_accepted_tokens[num_contexts:] - 1 +
attn_metadata.num_ctx_tokens)
gather_ids = torch.concat(
[last_tokens_idx[:num_contexts], gather_ids_gen], dim=0)
else:
raise ValueError(
"In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head, attn_metadata,
True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits, draft_step=i)
hidden_states = draft_model.mtp_layers[0](
embed_tokens=draft_model.embed_tokens,
all_rank_num_tokens=spec_metadata.
subseq_all_rank_num_tokens,
**inputs)

if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)
# All of the seq_len are 1, use batch_indices_cuda as gather_ids
gather_ids = spec_metadata.batch_indices_cuda[:batch_size]

hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states, gather_ids,
inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):
if self.guided_decoder is not None:
new_tokens = inputs["input_ids"][gather_ids]
self.guided_decoder.add_draft_batch(new_tokens,
num_accepted_tokens,
draft_step=i)
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
hidden_states_gathered = hidden_states[gather_ids]
token_count = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]).shape[0]
max_num_requests = spec_metadata.max_num_requests
pad_len = max_num_requests - token_count
if pad_len > 0:
padded_hidden_states = F.pad(
hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1]),
(0, 0, 0, pad_len),
mode="constant",
value=0)
elif pad_len == 0:
padded_hidden_states = hidden_states_gathered.view(
-1, hidden_states_gathered.shape[-1])
else:
raise ValueError(
"In MTPEagleWorker.forward(), token_count > max_num_requests, which is not supported"
)
logits = draft_model.mtp_layers[0].shared_head(
padded_hidden_states, draft_model.lm_head,
attn_metadata, True)
else:
logits = draft_model.mtp_layers[0].shared_head(
hidden_states[gather_ids], draft_model.lm_head,
attn_metadata, True)
if self.guided_decoder is not None:
self.guided_decoder.execute_draft_batch(logits,
draft_step=i)

@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1
if self.model_config.mapping.enable_attention_dp and \
getattr(self.model_config.mapping, 'enable_lm_head_tp_in_adp', False):
mapping_lm_head_tp = draft_model.mtp_layers[
0].shared_head.mapping_lm_head_tp
new_draft_token = self.draft_sampler(
logits, mapping_lm_head_tp)
new_draft_token = new_draft_token[:token_count]
else:
new_draft_token = self.draft_sampler(logits)

update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}
hidden_states, position_ids = self.update_draft_tokens(
next_draft_tokens, new_draft_token, hidden_states,
gather_ids, inputs)
# update attn_metadata
if i == 0:
attn_metadata._seq_lens[:batch_size].fill_(1)
attn_metadata._seq_lens_cuda[:batch_size].fill_(1)
attn_metadata.on_update()
# cannot run generation if there is no kv cache
has_kv_cache = inputs[
"attn_metadata"].kv_cache_manager is not None
if has_kv_cache:
attn_metadata.host_request_types[:attn_metadata.
num_contexts].fill_(1)
attn_metadata.num_contexts = 0
# update kv_lens_cuda
if hasattr(attn_metadata, 'kv_lens_cuda'):
attn_metadata.kv_lens_cuda[num_contexts:batch_size] -= (
self.mtp_num_modules -
num_accepted_tokens[num_contexts:])
attn_metadata.kv_lens_cuda[:num_contexts] += 1
# update metadata for flash mla
if has_kv_cache and num_contexts > 0 and attn_metadata.enable_flash_mla:
reorder_block_ids_per_seq = torch.cat([
attn_metadata.
kv_block_ids_per_seq[num_contexts:batch_size],
attn_metadata.kv_block_ids_per_seq[:num_contexts]
])
attn_metadata.block_ids_per_seq[:batch_size, :].copy_(
reorder_block_ids_per_seq, non_blocking=True)
# update metadata
# some attention metadata needs to be updated when changing seq_lens/kv_lens
attn_metadata.update_for_spec_dec()
elif hasattr(attn_metadata, 'kv_lens_cuda'):

@torch.compile(options={"max-autotune": True})
def update_kv_lens(kv_lens_cuda, batch_size):
kv_lens_cuda[:batch_size] += 1

update_kv_lens(attn_metadata.kv_lens_cuda, batch_size)
# update metadata
# some attention metadata needs to be updated when changing kv_lens
attn_metadata.update_for_spec_dec()
inputs = {
"input_ids": new_draft_token,
"position_ids": position_ids,
"hidden_states": hidden_states,
"attn_metadata": attn_metadata,
}

# restore attn_metadata to support cuda graph
attn_metadata.restore_from_spec_dec()

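The whole drafting loop now executes under self.draft_kv_cache_context(...). That context manager is not shown in this excerpt; a minimal sketch of the behavior it plausibly provides, assuming it temporarily points the attention metadata at the draft KV cache manager and always restores the target manager afterwards (so, for example, CUDA graph replay sees the original state). The swapped attribute shown here is an assumption:

import contextlib

# Hedged sketch only; the real draft_kv_cache_context is not part of this diff.
@contextlib.contextmanager
def draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
    if draft_kv_cache_manager is None:
        # No separate draft pool registered: run the loop unchanged.
        yield
        return
    saved_manager = attn_metadata.kv_cache_manager
    attn_metadata.kv_cache_manager = draft_kv_cache_manager
    try:
        # Draft layers now read and write blocks owned by the draft manager.
        yield
    finally:
        # Always restore the target manager on exit.
        attn_metadata.kv_cache_manager = saved_manager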
@@ -219,14 +219,19 @@ def get_num_spec_layers(spec_config):
return 0


def get_spec_worker(spec_config, model_config, mapping):
def get_spec_worker(spec_config,
model_config,
mapping,
use_separate_draft_kv_cache: bool = False):
spec_dec_mode = spec_config.spec_dec_mode
if spec_dec_mode.is_mtp_vanilla():
return MTPWorker(spec_config, model_config)
return MTPWorker(spec_config, model_config, use_separate_draft_kv_cache)
if spec_dec_mode.is_mtp_eagle_one_model():
return MTPEagleWorker(spec_config, model_config)
return MTPEagleWorker(spec_config, model_config,
use_separate_draft_kv_cache)
if spec_dec_mode.is_eagle3_one_model():
return Eagle3OneModelWorker(spec_config, mapping)
return Eagle3OneModelWorker(spec_config, mapping,
use_separate_draft_kv_cache)
return None


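The extended signature defaults to False, so existing callers keep working unchanged. A hedged sketch of how a call site in the model engine could forward the flag from the decoding config (the getattr fallback is an assumption for configs that predate the field):

# Hedged sketch of a call site; the actual model-engine wiring is not shown here.
use_separate = getattr(spec_config, "use_separate_draft_kv_cache", False)
spec_worker = get_spec_worker(spec_config,
                              model_config,
                              mapping,
                              use_separate_draft_kv_cache=use_separate)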
@@ -242,6 +247,21 @@ def get_num_extra_kv_tokens(spec_config):
return 0


def get_draft_kv_cache_manager(spec_config, resource_manager):
"""
Returns the draft KV cache manager only in one-model speculative decoding
mode where the target model manages a separate draft KV cache.
"""
from ..pyexecutor.resource_manager import ResourceManagerType

if spec_config is None:
return None
if not spec_config.spec_dec_mode.use_one_engine():
return None
return resource_manager.get_resource_manager(
ResourceManagerType.DRAFT_KV_CACHE_MANAGER)


def update_spec_config_from_model_config(spec_config, model_config):
if spec_config.spec_dec_mode.is_mtp_one_model():
# Use `max_draft_len` for several low-level APIs. TODO: Remove this after distinguishing them.

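The helper returns a manager only when a single engine serves both target and draft (use_one_engine()) and a DRAFT_KV_CACHE_MANAGER has been registered; otherwise it yields None and the drafting path keeps using the target KV cache. A short hedged usage sketch, mirroring the worker code earlier in this diff (predict_draft_tokens is a placeholder for the MTP/Eagle loop):

# Hedged usage sketch: resolve the draft manager once per forward pass.
draft_kv_cache_manager = get_draft_kv_cache_manager(spec_config,
                                                    resource_manager)
# None means a two-model flow or no separate draft pool was registered;
# draft_kv_cache_context then leaves the target KV cache in place.
with draft_kv_cache_context(attn_metadata, draft_kv_cache_manager):
    draft_tokens = predict_draft_tokens()  # placeholder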
@@ -857,6 +857,8 @@ class EagleDecodingConfig(DecodingBaseConfig):
# The model architecture of the eagle3 model.
# choices: llama3, mistral_large3
eagle3_model_arch: str = "llama3"
# Whether to use a separate KV cache manager for the draft model in one-model mode.
use_separate_draft_kv_cache: bool = False

def __init__(self, **kwargs):
super().__init__()
@@ -1090,6 +1092,8 @@ class MTPDecodingConfig(DecodingBaseConfig):
relaxed_delta: float = 0.
use_mtp_vanilla: bool = False
mtp_eagle_one_model: bool = True
# Whether to use a separate KV cache manager for the draft model in one-model mode.
use_separate_draft_kv_cache: bool = False

# TODO: remove this after distinguishing `max_draft_len` and `num_nextn_predict_layers`
# Now we need a flag when MTPDecodingConfig is updated by PyTorchModelEngine.

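Both one-model config classes now carry the same opt-in flag, defaulting to False. A hedged example of switching it on through the LLM API, with placeholder checkpoint paths and an illustrative draft length:

from tensorrt_llm import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# Placeholders: substitute real target and EAGLE3 draft checkpoints.
spec_config = EagleDecodingConfig(
    max_draft_len=3,
    speculative_model_dir="<eagle3-draft-checkpoint>",
    eagle3_one_model=True,
    # New in this change: keep the draft KV cache in its own manager.
    use_separate_draft_kv_cache=True,
)
llm = LLM(model="<target-checkpoint>", speculative_config=spec_config)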
@@ -92,39 +92,71 @@ def test_kv_lens_runtime_with_eagle3_one_model():


@pytest.mark.parametrize(
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp",
"use_cuda_graph,attn_backend,disable_overlap_scheduler,enable_block_reuse,use_one_model,enable_chunked_prefill,use_chain_drafter,multi_batch,attention_dp,use_separate_draft_kv_cache",
[
[True, "TRTLLM", True, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False],
[False, "TRTLLM", True, False, False, False, False, False, False],
[True, "FLASHINFER", True, False, False, False, True, False, False],
[False, "FLASHINFER", True, False, False, False, True, False, False],
[False, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False],
[False, "TRTLLM", False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, False, True, True],
[False, "TRTLLM", False, False, False, False, False, True, False],
[True, "TRTLLM", False, False, False, False, True, True, False],
[False, "TRTLLM", False, False, False, False, True, True, False],
[True, "TRTLLM", False, False, False, False, False, False, False],
[False, "TRTLLM", False, False, False, False, False, False, False],
[True, "TRTLLM", False, False, False, True, True, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False],
[True, "FLASHINFER", False, False, False, False, True, False, False],
[False, "FLASHINFER", False, False, False, False, True, False, False],
[True, "TRTLLM", True, False, False, False, True, False, False, False],
[True, "TRTLLM", True, False, False, False, False, False, False, False],
[False, "TRTLLM", True, False, False, False, True, False, False, False],
[
False, "TRTLLM", True, False, False, False, False, False, False,
False
],
[
True, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", True, False, False, False, True, False, False,
False
],
[False, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", False, True, True, False, True, False, False, False],
[True, "TRTLLM", True, False, True, True, True, False, False, False],
[True, "TRTLLM", True, False, True, False, True, False, False, False],
[True, "TRTLLM", True, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, False, True, False, False, False],
[
False, "TRTLLM", False, False, False, False, True, False, False,
False
],
[True, "TRTLLM", False, False, False, False, False, True, False, False],
[True, "TRTLLM", False, False, False, False, False, True, True, False],
[
False, "TRTLLM", False, False, False, False, False, True, False,
False
],
[True, "TRTLLM", False, False, False, False, True, True, False, False],
[False, "TRTLLM", False, False, False, False, True, True, False, False],
[
True, "TRTLLM", False, False, False, False, False, False, False,
False
],
[
False, "TRTLLM", False, False, False, False, False, False, False,
False
],
[True, "TRTLLM", False, False, False, True, True, False, False, False],
[True, "TRTLLM", False, False, False, True, False, False, False, False],
[
True, "FLASHINFER", False, False, False, False, True, False, False,
False
],
[
False, "FLASHINFER", False, False, False, False, True, False, False,
False
],
# Test use_separate_draft_kv_cache with one-model mode
[True, "TRTLLM", False, True, True, False, False, True, False, True],
[True, "TRTLLM", True, False, True, False, False, False, False, True],
[False, "TRTLLM", False, True, True, True, False, True, False, True],
[False, "TRTLLM", True, False, True, False, False, True, False, True],
])
@pytest.mark.high_cuda_memory
def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
disable_overlap_scheduler: bool, enable_block_reuse: bool,
use_one_model: bool, enable_chunked_prefill: bool,
use_chain_drafter: bool, multi_batch: bool,
attention_dp: bool, request):
attention_dp: bool, use_separate_draft_kv_cache: bool):
# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 35:
@@ -168,6 +200,7 @@ def test_llama_eagle3(use_cuda_graph: bool, attn_backend: str,
speculative_model_dir=eagle_model_dir,
# Llama 3 does not support one model eagle.
eagle3_one_model=use_one_model,
use_separate_draft_kv_cache=use_separate_draft_kv_cache,
)
spec_config._allow_chain_drafter = use_chain_drafter


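To exercise the updated parametrization on its own, something like the following should work, assuming the test file sits at its usual location under tests/unittest/_torch/speculative/:

# Hedged: the test path is an assumption based on the repository layout.
import pytest

pytest.main([
    "tests/unittest/_torch/speculative/test_eagle3.py",
    "-k", "test_llama_eagle3",
])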