Add initial EAGLE-3 implementation (#3035)

Signed-off-by: Mike Iovine <miovine@nvidia.com>
Mike Iovine 2025-03-29 10:31:24 -04:00 committed by GitHub
parent 9c484b24e6
commit 5416966ddb
27 changed files with 972 additions and 128 deletions

View File

@ -197,6 +197,7 @@ class OPTModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

View File

@ -4,7 +4,7 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.llmapi import MTPDecodingConfig
from tensorrt_llm.llmapi import EagleDecodingConfig, MTPDecodingConfig
example_prompts = [
"Hello, my name is",
@ -84,7 +84,10 @@ def add_llm_args(parser):
parser.add_argument('--load_format', type=str, default='auto')
# Speculative decoding
parser.add_argument('--mtp_nextn', type=int, default=0)
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_nextn', type=int, default=1)
parser.add_argument('--eagle_model_dir', type=str, default=None)
return parser
@ -111,8 +114,18 @@ def setup_llm(args):
free_gpu_memory_fraction=args.kv_cache_fraction,
)
mtp_config = MTPDecodingConfig(
num_nextn_predict_layers=args.mtp_nextn) if args.mtp_nextn > 0 else None
spec_decode_algo = args.spec_decode_algo.upper(
) if args.spec_decode_algo is not None else None
if spec_decode_algo == 'MTP':
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=args.spec_decode_nextn)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_eagle_weights_path=args.eagle_model_dir)
else:
spec_config = None
llm = LLM(model=args.model_dir,
max_seq_len=args.max_seq_len,
@ -126,7 +139,7 @@ def setup_llm(args):
moe_expert_parallel_size=args.moe_ep_size,
moe_tensor_parallel_size=args.moe_tp_size,
enable_chunked_prefill=args.enable_chunked_prefill,
speculative_config=mtp_config)
speculative_config=spec_config)
sampling_params = SamplingParams(
max_tokens=args.max_tokens,
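
For orientation, here is a minimal end-to-end sketch of the new EAGLE3 path this script exposes, written against the LLM API as it appears in this diff and in the test_eagle3.py test added below. The checkpoint paths are placeholders and the call is simplified (kv-cache and backend options omitted), so treat it as a sketch rather than the canonical invocation.

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

# Placeholder checkpoint locations; substitute real paths.
target_model_dir = "/models/Llama-3.1-8B-Instruct"
eagle_model_dir = "/models/EAGLE3-LLaMA3.1-Instruct-8B"

# Equivalent of `--spec_decode_algo eagle3 --spec_decode_nextn 4 --eagle_model_dir ...`
spec_config = EagleDecodingConfig(
    max_draft_len=4,
    pytorch_eagle_weights_path=eagle_model_dir)

llm = LLM(model=target_model_dir, speculative_config=spec_config)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32, temperature=0))
print(outputs[0].outputs[0].text)

From the CLI, the same configuration is reached with --spec_decode_algo eagle3 --spec_decode_nextn 4 --eagle_model_dir <path>, as exercised by the e2e test added later in this commit.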

View File

@ -21,6 +21,7 @@ from ..pyexecutor.resource_manager import KVCacheManager
class AttentionRuntimeFeatures:
chunked_prefill: bool = False
cache_reuse: bool = False
has_speculative_draft_tokens: bool = False
@dataclass(kw_only=True)

View File

@ -630,8 +630,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
)
assert not metadata.is_cross, "TRT-LLM Attention does not support cross attention yet."
use_paged_context_fmha = (metadata.runtime_features.chunked_prefill
or metadata.runtime_features.cache_reuse)
use_paged_context_fmha = (
metadata.runtime_features.chunked_prefill
or metadata.runtime_features.cache_reuse
or metadata.runtime_features.has_speculative_draft_tokens)
if use_paged_context_fmha and self.has_fp8_kv_cache:
# NOTE: W4A8_AWQ can be included too, exclude for now since

View File

@ -11,7 +11,14 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
def from_config(
config: ModelConfig[TConfig],
) -> DecoderModelForCausalLM[TModel, TConfig]:
cls = MODEL_CLASS_MAPPING.get(config.pretrained_config.architectures[0])
model_arch = config.pretrained_config.architectures[0]
# Hack to detect eagle3 checkpoints. TODO: should we provide
# our own checkpoints with the correct arch? It would let us
# avoid nasty stuff like this.
if hasattr(config.pretrained_config, "draft_vocab_size"):
model_arch = "EAGLE3" + model_arch
cls = MODEL_CLASS_MAPPING.get(model_arch)
if cls is None:
raise ValueError(
f"Unknown architecture for AutoModelForCausalLM: {config.pretrained_config.architectures[0]}"

View File

@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple
import torch
from torch import nn
@ -14,8 +14,10 @@ from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, WeightMode, WeightsLoadingConfig
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model, support_pp)
@ -74,6 +76,8 @@ class LlamaDecoderLayer(DecoderLayer):
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.self_attn = LlamaAttention(
model_config,
layer_idx=layer_idx,
@ -89,6 +93,7 @@ class LlamaDecoderLayer(DecoderLayer):
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
@ -99,6 +104,7 @@ class LlamaDecoderLayer(DecoderLayer):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@ -117,6 +123,107 @@ class LlamaDecoderLayer(DecoderLayer):
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
hidden_states, residual)
return hidden_states, residual
class Eagle3LlamaAttention(LlamaAttention):
def __init__(
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: Optional[int] = None,
):
super().__init__(model_config, layer_idx)
model_config = model_config or ModelConfig()
config = model_config.pretrained_config
tp_size = model_config.mapping.tp_size
tp_rank = model_config.mapping.tp_rank
gpus_per_node = model_config.mapping.gpus_per_node
# Override the QKV projection. The number of input features
# is twice as big for EAGLE3 draft models.
self.qkv_proj = Linear(
2 * self.hidden_size,
tp_size * self.q_size + 2 * tp_size * self.kv_size,
bias=config.attention_bias,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=tp_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank),
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=model_config.get_quant_config(),
skip_create_weights=model_config.skip_create_weights,
)
class Eagle3LlamaDecoderLayer(DecoderLayer):
def __init__(
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx
self.self_attn = Eagle3LlamaAttention(
model_config,
layer_idx=layer_idx,
)
self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=config.mlp_bias,
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
def forward(
self,
position_ids: torch.LongTensor,
embeds: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states
embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)
hidden_states = torch.cat([embeds, hidden_states], dim=-1)
hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
)
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
@ -161,6 +268,7 @@ class LlamaModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@ -177,7 +285,8 @@ class LlamaModel(DecoderModel):
hidden_states, residual = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual)
residual=residual,
spec_metadata=spec_metadata)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@ -207,3 +316,123 @@ class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)
class Eagle3LlamaDraftModel(DecoderModel):
def __init__(self, model_config: ModelConfig[LlamaConfig]) -> None:
super().__init__(model_config)
config = model_config.pretrained_config
self.dtype = config.torch_dtype
self.fc = Linear(config.hidden_size * 3,
config.hidden_size,
bias=False,
dtype=config.torch_dtype)
self.midlayer = Eagle3LlamaDecoderLayer(model_config, 0)
self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
dtype=torch.int64),
requires_grad=False)
def forward(
self,
attn_metadata: AttentionMetadata,
embed_tokens: torch.nn.Module,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)
if inputs_embeds is None:
inputs_embeds = embed_tokens(input_ids).to(self.dtype)
assert hidden_states is not None and len(hidden_states) > 0
if len(hidden_states) > 1:
hidden_states = torch.cat(hidden_states, dim=-1)
hidden_states = self.fc(hidden_states.to(self.dtype))
else:
hidden_states = hidden_states[0].to(self.dtype)
hidden_states, residual = self.midlayer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata)
hidden_states, hidden_states_to_save = self.norm(
hidden_states, residual)
assert isinstance(spec_metadata, Eagle3SpecMetadata)
spec_metadata.hidden_states.append(hidden_states_to_save)
return hidden_states
@register_auto_model("EAGLE3LlamaForCausalLM")
class Eagle3LlamaForCausalLM(DecoderModelForCausalLM[Eagle3LlamaDraftModel,
LlamaConfig]):
def __init__(
self,
model_config: ModelConfig[LlamaConfig],
):
super().__init__(
Eagle3LlamaDraftModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.draft_vocab_size)
def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
hidden_states: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if "embed_tokens" not in kwargs:
raise ValueError(
"EAGLE3 checkpoints do not include embed_tokens. "
"The embed_tokens module from the target model therefore needs to "
"be passed explicitly via extra_model_inputs. NOTE: we can "
"get rid of this hack by providing our own custom checkpoint "
"format that includes embed_tokens.")
output = self.model(
input_ids=input_ids,
embed_tokens=kwargs['embed_tokens'],
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
hidden_states=hidden_states,
)
return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
)
def load_weights(self, weights: Dict):
new_weights = {}
for k, v in weights.items():
new_k = "model." + k if 'lm_head' not in k and "embed_tokens" not in k else k
new_weights[new_k] = v
super().load_weights(new_weights)
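
The wiring above involves two concatenations that are easy to lose track of. The shape-only sketch below uses plain PyTorch modules (not the TRT-LLM Linear/RMSNorm wrappers, and with the norms omitted) and made-up sizes to show why Eagle3LlamaDraftModel.fc takes hidden_size * 3 inputs and why Eagle3LlamaAttention widens qkv_proj to 2 * hidden_size.

import torch
from torch import nn

hidden_size, num_tokens = 64, 7  # toy sizes for illustration

# The target model captures hidden states from three layers (see
# Eagle3SpecMetadata.layers_to_capture below) and the draft model fuses
# them back to hidden_size with a 3h -> h projection.
captured = [torch.randn(num_tokens, hidden_size) for _ in range(3)]
fc = nn.Linear(3 * hidden_size, hidden_size, bias=False)
fused = fc(torch.cat(captured, dim=-1))          # (num_tokens, hidden_size)

# Eagle3LlamaDecoderLayer then concatenates the (normed) token embeddings
# with the (normed) fused hidden states, which is why Eagle3LlamaAttention
# overrides qkv_proj to take 2 * hidden_size input features.
embeds = torch.randn(num_tokens, hidden_size)
attn_input = torch.cat([embeds, fused], dim=-1)  # (num_tokens, 2 * hidden_size)
assert attn_input.shape[-1] == 2 * hidden_size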

View File

@ -206,6 +206,7 @@ class MambaHybridModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

View File

@ -211,6 +211,7 @@ class MixtralModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

View File

@ -177,6 +177,7 @@ class NemotronModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

View File

@ -172,6 +172,7 @@ class QwenModel(DecoderModel):
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

View File

@ -22,6 +22,7 @@ from ..modules.linear import Linear, WeightMode
from ..modules.logits_procesor import LogitsProcessor
from ..modules.rms_norm import RMSNorm
from ..pipeline_interface import PipelineInterface
from ..speculative import SpecMetadata
@contextlib.contextmanager
@ -213,6 +214,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@ -469,9 +471,9 @@ class DecoderModelForCausalLM(nn.Module,
inputs_embeds: Optional[torch.FloatTensor] = None,
pipeline_interface: Optional[PipelineInterface] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if self._supports_pp and self.pp_size > 1:
output = self.model(
input_ids=input_ids,
@ -479,6 +481,7 @@ class DecoderModelForCausalLM(nn.Module,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
pipeline_interface=pipeline_interface,
spec_metadata=spec_metadata,
)
# No need to compute logits for non-last PP ranks
@ -490,6 +493,7 @@ class DecoderModelForCausalLM(nn.Module,
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
)
return self.logits_processor.forward(

View File

@ -41,6 +41,7 @@ def check_flash_mla_config(config):
def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config,
mapping: Mapping):
# TODO: take space occupied by draft KV cache manager into account.
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():

View File

@ -276,6 +276,7 @@ class TorchDecoder(Decoder):
generation_requests.append(request)
for request in extend_requests:
num_accepted = 0
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[idx]
num_tokens = request.add_new_token(new_token, beam_idx)
@ -290,12 +291,15 @@ class TorchDecoder(Decoder):
# Reject.
break
num_accepted += 1
new_token = new_tokens_list[idx + i + 1]
num_tokens = request.add_new_token(new_token, beam_idx)
if self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):
break
request.py_num_accepted_draft_tokens = num_accepted
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
idx += len(request.py_draft_tokens) + 1
for request in generation_requests:
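
The acceptance loop above is only partially shown by this hunk. Below is a self-contained sketch of the same greedy rule, assuming the usual convention that a draft token is accepted while it matches the target model's token at that position, together with the rewind bookkeeping from py_rewind_len; the function name and signature are illustrative only.

from typing import List, Tuple

def greedy_accept(draft_tokens: List[int], target_tokens: List[int],
                  pages_allocated: int) -> Tuple[int, int]:
    """Return (num_accepted, rewind_len) for one extend request.

    target_tokens[i] is the token the target model produced at the
    position where draft_tokens[i] was speculated; a draft token is
    accepted only while every earlier draft token matched.
    """
    num_accepted = 0
    for draft, target in zip(draft_tokens, target_tokens):
        if draft != target:
            break  # reject this and all later draft tokens
        num_accepted += 1
    # KV cache pages were allocated for all draft tokens; rewind the rest.
    rewind_len = pages_allocated - num_accepted
    return num_accepted, rewind_len

# Example: 4 drafted tokens, the third one mismatches.
assert greedy_accept([5, 9, 2, 7], [5, 9, 3, 7], pages_allocated=4) == (2, 2)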

View File

@ -47,6 +47,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_batch_idx = None
self.py_rewind_len = 0
self.py_draft_tokens = self.draft_tokens
self.py_last_draft_tokens = None
def convert_wordlist(word_list) -> List[List[int]]:

View File

@ -50,8 +50,9 @@ class ModelEngine(ABC):
@abstractmethod
def forward(self, scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[Dict[str, torch.Tensor]],
resource_manager: ResourceManager):
extra_model_inputs: Optional[Dict[str, Any]]):
raise NotImplementedError
def warmup(self, resource_manager: ResourceManager) -> None:
@ -207,6 +208,10 @@ def initialize_dummy_weights(
param.uniform_(low, high, generator=generator)
KV_CACHE_MANAGER_KEY = 'kv_cache_manager'
DRAFT_KV_CACHE_MANAGER_KEY = 'draft_kv_cache_manager'
class PyTorchModelEngine(ModelEngine):
def __init__(
@ -233,6 +238,10 @@ class PyTorchModelEngine(ModelEngine):
self.dist = dist
self.pytorch_backend_config = pytorch_backend_config
self.spec_config = spec_config
# We keep a reference to the last used spec metadata to
# accommodate certain target/draft model use cases. See
# py_executor.py for how this is used.
self.last_spec_metadata = None
self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
)
@ -347,9 +356,14 @@ class PyTorchModelEngine(ModelEngine):
self.max_draft_len = 0
self.iter_counter = 0
# We look up this key in resource_manager during forward to find the
# kv cache manager. Can be changed to support multiple model engines
# with different KV cache managers.
self.kv_cache_manager_key = KV_CACHE_MANAGER_KEY
def warmup(self, resource_manager: ResourceManager) -> None:
kv_cache_manager = resource_manager.get_resource_manager(
'kv_cache_manager')
self.kv_cache_manager_key)
spec_resource_manager = resource_manager.get_resource_manager(
'spec_resource_manager')
if kv_cache_manager is None:
@ -464,6 +478,11 @@ class PyTorchModelEngine(ModelEngine):
if cp_type == 'star_attention':
return
# TODO: CUDA graph support with eagle.
if self.spec_config is not None and self.spec_config.spec_dec_mode.is_eagle3(
):
return
if self._torch_compile_enabled:
# Disable cuda graph capture here so that we can properly capture it later
with no_cuda_graph():
@ -817,6 +836,7 @@ class PyTorchModelEngine(ModelEngine):
inputs['attn_metadata'].kv_lens_cuda[
num_ctx_requests:num_seqs] += (
self.previous_kv_lens_offsets_cuda[:num_gen_requests])
return inputs
def _prepare_tp_inputs(
@ -918,6 +938,7 @@ class PyTorchModelEngine(ModelEngine):
past_seen_token_num = request.max_beam_num_tokens - 1
position_ids.append(past_seen_token_num)
draft_lens.append(num_draft_tokens)
prompt_lengths.append(num_draft_tokens + 1)
# draft tokens
input_ids.extend(request.py_draft_tokens)
gather_ids.extend(
@ -955,9 +976,9 @@ class PyTorchModelEngine(ModelEngine):
(1 + self.max_draft_len))
num_cached_tokens_per_seq.append(past_seen_token_num +
self.max_draft_len + 1)
prompt_lengths.append(request.py_prompt_len)
request_ids.append(request.py_request_id)
prompt_lengths.append(request.py_prompt_len)
sequence_lengths.extend([1] * len(generation_requests))
gather_ids.extend(
@ -1075,6 +1096,9 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata.request_ids = request_ids
attn_metadata.prompt_lens = prompt_lengths
attn_metadata.num_contexts = len(scheduled_requests.context_requests)
if self.spec_config is not None and self.spec_config.spec_dec_mode.extend_ctx(
):
attn_metadata.num_contexts += len(extend_requests)
attn_metadata.kv_cache_params = KVCacheParams(
use_cache=True,
@ -1108,6 +1132,7 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.num_generations = len(
scheduled_requests.generation_requests)
spec_metadata.num_tokens = total_num_tokens
spec_metadata.seq_lens = sequence_lengths
spec_metadata.prepare()
inputs['spec_metadata'] = spec_metadata
@ -1223,6 +1248,7 @@ class PyTorchModelEngine(ModelEngine):
spec_metadata.num_generations = len(
scheduled_requests.generation_requests)
spec_metadata.num_tokens = num_tokens
spec_metadata.seq_lens = sequence_lengths
spec_metadata.prepare()
inputs['spec_metadata'] = spec_metadata
@ -1505,10 +1531,11 @@ class PyTorchModelEngine(ModelEngine):
def forward(self,
scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[Dict[str, torch.Tensor]] = None):
new_tensors_device: Optional[Dict[str, torch.Tensor]] = None,
extra_model_inputs: Optional[Dict[str, Any]] = None):
kv_cache_manager = resource_manager.get_resource_manager(
'kv_cache_manager')
self.kv_cache_manager_key)
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
if self.spec_config is not None:
@ -1523,6 +1550,10 @@ class PyTorchModelEngine(ModelEngine):
if kv_cache_manager is None:
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
scheduled_requests, attn_metadata, spec_metadata)
if extra_model_inputs is not None:
inputs.update(extra_model_inputs)
self.last_spec_metadata = spec_metadata
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
pp_interface = self._forward_step_intermediate(inputs)
pp_interface.send()
@ -1549,6 +1580,10 @@ class PyTorchModelEngine(ModelEngine):
attn_metadata,
spec_metadata,
new_tensors_device)
if extra_model_inputs is not None:
inputs.update(extra_model_inputs)
self.last_spec_metadata = spec_metadata
self.iter_counter += 1
if maybe_graph is None:

View File

@ -11,7 +11,7 @@ import weakref
from collections import namedtuple
from contextlib import contextmanager
from itertools import chain
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union
import dill # nosec B403
import numpy as np
@ -139,7 +139,8 @@ class PyExecutor:
enable_overlap_scheduler: bool = False,
max_input_len: int = 2048,
max_batch_size: int = 8,
kv_cache_transceiver: KvCacheTransceiver = None):
kv_cache_transceiver: KvCacheTransceiver = None,
draft_model_engine: Optional[ModelEngine] = None):
super(PyExecutor, self).__init__()
self.device_id = torch.cuda.current_device()
self.global_rank = global_mpi_rank()
@ -158,6 +159,9 @@ class PyExecutor:
self.decoder = decoder
self.dist = dist
# Draft model for certain spec decode algorithms, e.g. EAGLE3
self.draft_model_engine = draft_model_engine
# enqueue and _fetch_new_requests used data
self.enqueue_lock = threading.Lock()
self.active = True
@ -178,6 +182,12 @@ class PyExecutor:
"kv_cache_manager")
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
if self.draft_model_engine is not None and self.kv_cache_manager is not None:
if self.kv_cache_manager.enable_block_reuse:
raise NotImplementedError(
"Draft model engine + KV cache reuse is not supported yet. "
"This will be fixed in the near future!")
self.max_input_len = max_input_len
# _executor_loop private data
self.max_num_active_requests = model_engine.get_max_num_sequences()
@ -201,6 +211,8 @@ class PyExecutor:
self.canceled_req_ids = ReqIdsSet()
self.model_engine.warmup(self.resource_manager)
if self.draft_model_engine is not None:
self.draft_model_engine.warmup(self.resource_manager)
self.is_shutdown = False
@ -214,6 +226,12 @@ class PyExecutor:
event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
else:
event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
if self.draft_model_engine is not None and event_loop.__name__ != self._executor_loop.__name__:
raise NotImplementedError(
"Drafting is not supported for selected executor loop. "
"Please disable disagg/pipeline parallelism/overlap scheduler.")
self.worker_thread = threading.Thread(target=event_loop, daemon=True)
self.worker_thread.start()
@ -727,6 +745,10 @@ class PyExecutor:
num_dummy_request = self._get_num_dummy_request()
if num_dummy_request > 0:
self._merge_dummy_request(num_dummy_request)
if self.draft_model_engine is not None:
self._prepare_draft_requests()
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
)
@ -750,6 +772,8 @@ class PyExecutor:
if scheduled_batch.batch_size > 0:
self.resource_manager.prepare_resources(scheduled_batch)
if self.draft_model_engine is not None:
self._prepare_draft_tokens(scheduled_batch)
if self.kv_cache_transceiver:
# For generation requests which have completed KV cache transfer
@ -789,6 +813,35 @@ class PyExecutor:
self.response_cv.notify_all()
self.shutdown_event.set()
def _prepare_draft_requests(self):
try:
# Set draft tokens here to make the KV cache manager
# and scheduler aware of them.
for req in self.active_requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_tokens
max_seq_len = self.model_engine.max_seq_len
# Subtract 1 to account for the token we will add on this forward
# pass.
draft_len = min(max_seq_len - 1 - req.get_num_tokens(0),
max_draft_len)
if draft_len > 0:
req.py_draft_tokens = [0] * draft_len
req.py_draft_pages_allocated = draft_len
else:
req.py_draft_tokens = None
req.py_draft_pages_allocated = 0
except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)
def _executor_loop_overlap(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
@ -1531,6 +1584,189 @@ class PyExecutor:
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)
@nvtx_range("_prepare_draft_batch")
def _prepare_draft_batch(
self, scheduled_requests: ScheduledRequests
) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
"""
Prepares a batch for the draft model engine. Draft tokens are only produced
for generation requests.
The requests are prepared as follows:
1. The first time the draft engine sees a request, it's a context request.
2. Otherwise, if draft tokens were accepted on the last target model decoding
step, it's a chunked context request (we process all the accepted tokens together).
3. Otherwise, it's a generation request.
"""
try:
draft_batch = ScheduledRequests()
req_id_to_num_rejected_tokens = {}
for request in scheduled_requests.generation_requests:
if request.py_draft_pages_allocated == 0:
# No space for draft tokens.
continue
num_draft_tokens = len(
request.py_last_draft_tokens
) if request.py_last_draft_tokens is not None else 0
request.py_draft_tokens = []
num_accepted_tokens = getattr(request,
"py_num_accepted_draft_tokens", 0)
num_rejected_tokens = num_draft_tokens - num_accepted_tokens
assert num_rejected_tokens >= 0
req_id_to_num_rejected_tokens[
request.py_request_id] = num_rejected_tokens
spec_config = self.model_engine.spec_config
beam_idx = 0
input_tokens = spec_config.get_draft_model_prompt(
request.get_tokens()[beam_idx])
if request.max_beam_num_tokens - 1 == request.py_prompt_len:
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
assert num_draft_tokens == 0
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
is_streaming=False)
draft_batch.context_requests.append(new_request)
elif getattr(request, "py_num_accepted_draft_tokens", 0) == 0:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens[:-1],
sampling_config=request.sampling_config,
is_streaming=False)
# Explicitly add the last token so get_last_tokens() returns
# the right value
new_request.add_new_token(input_tokens[-1], beam_idx)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
draft_batch.generation_requests.append(new_request)
else:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
is_streaming=False)
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1
draft_batch.context_requests.append(new_request)
new_request.py_stop_words_list = request.py_stop_words_list
new_request.is_dummy = False
return draft_batch, req_id_to_num_rejected_tokens
except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)
@nvtx_range("_prepare_draft_tokens")
def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
try:
draft_batch, num_rejected_tokens = self._prepare_draft_batch(
scheduled_requests)
if draft_batch.batch_size == 0:
return
req_id_to_old_request = {
req.py_request_id: req
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)
}
spec_metadata = self.model_engine.last_spec_metadata
hidden_states = spec_metadata.get_hidden_states(
draft_batch, num_rejected_tokens)
extra_model_inputs = {'hidden_states': hidden_states}
if spec_metadata.spec_dec_mode.is_eagle3():
# Another eagle3 hack. Eagle3 checkpoints don't have embed_tokens,
# so we need to provide them some other way. We can get rid of this
# hack if we provide our own preprocessed eagle3 checkpoints.
extra_model_inputs[
'embed_tokens'] = self.model_engine.model.model.embed_tokens
outputs = self.draft_model_engine.forward(
draft_batch,
self.resource_manager,
extra_model_inputs=extra_model_inputs)
if spec_metadata.spec_dec_mode.is_eagle3():
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data
self._update_request_states(draft_batch)
self._decode(draft_batch, outputs)
def _process_decoded_tokens():
new_requests = []
for req in chain(draft_batch.context_requests,
draft_batch.generation_requests):
target_model_req = req_id_to_old_request[req.py_request_id]
target_model_req.py_draft_tokens.append(
req.get_last_tokens(0))
if req.state != LlmRequestState.GENERATION_COMPLETE and len(
target_model_req.py_draft_tokens
) < target_model_req.py_draft_pages_allocated:
new_requests.append(req)
return new_requests
new_requests = _process_decoded_tokens()
if not new_requests:
return
draft_batch.generation_requests = new_requests
draft_batch.context_requests = []
for _ in range(spec_metadata.max_draft_tokens - 1):
draft_spec_metadata = self.draft_model_engine.spec_metadata
hidden_states = draft_spec_metadata.get_hidden_states(
draft_batch)
extra_model_inputs = {'hidden_states': hidden_states}
if spec_metadata.spec_dec_mode.is_eagle3():
# See note above.
extra_model_inputs[
'embed_tokens'] = self.model_engine.model.model.embed_tokens
outputs = self.draft_model_engine.forward(
draft_batch,
self.resource_manager,
extra_model_inputs=extra_model_inputs)
if spec_metadata.spec_dec_mode.is_eagle3():
outputs[
'd2t'] = self.draft_model_engine.model.model.d2t.data
self._update_request_states(draft_batch)
self._decode(draft_batch, outputs)
new_requests = _process_decoded_tokens()
if not new_requests:
return
draft_batch.generation_requests = new_requests
except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)
def _handle_errors(self, error_msg: Optional[str] = None):
error_responses = {}
error_msg = error_msg or "error"
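
To make the three cases listed in the _prepare_draft_batch docstring above easier to follow, here is a stripped-down sketch of just the decision logic; the enum and helper names are illustrative and not part of the real LlmRequest API.

from enum import Enum, auto

class DraftRequestKind(Enum):
    CONTEXT = auto()          # case 1: draft engine has never seen the request
    GENERATION = auto()       # case 2: no draft tokens were accepted last step
    CHUNKED_CONTEXT = auto()  # case 3: accepted tokens are processed together

def classify_for_draft_engine(num_target_tokens: int, prompt_len: int,
                              num_accepted_last_step: int) -> DraftRequestKind:
    """Mirrors the branch structure of _prepare_draft_batch (illustrative only)."""
    if num_target_tokens - 1 == prompt_len:
        # The first target decode step just finished: the draft engine sees
        # this request for the first time and runs it as a context request.
        return DraftRequestKind.CONTEXT
    if num_accepted_last_step == 0:
        # Only the single newly decoded token needs to be fed in.
        return DraftRequestKind.GENERATION
    # Accepted draft tokens plus the new token form one context chunk.
    return DraftRequestKind.CHUNKED_CONTEXT

assert classify_for_draft_engine(11, 10, 0) is DraftRequestKind.CONTEXT
assert classify_for_draft_engine(14, 10, 0) is DraftRequestKind.GENERATION
assert classify_for_draft_engine(14, 10, 2) is DraftRequestKind.CHUNKED_CONTEXT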

View File

@ -9,7 +9,7 @@ from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from ..attention_backend.interface import AttentionRuntimeFeatures
from ..speculative import (get_num_spec_layers, get_spec_decoder,
from ..speculative import (Eagle3Config, get_num_spec_layers, get_spec_decoder,
get_spec_resource_manager)
from ._util import check_flash_mla_config, estimate_max_kv_cache_tokens, is_mla
from .config import PyTorchConfig
@ -18,13 +18,76 @@ from .decoder import (EarlyStopDecoder, TorchDecoder, TorchStarAttentionDecoder,
from .distributed import MPIDist
from .guided_decoder import GuidedDecoderResourceManager
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
from .model_engine import PyTorchModelEngine
from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
PyTorchModelEngine)
from .py_executor import PyExecutor
from .resource_manager import KVCacheManager, ResourceManager
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
SimpleScheduler)
def _create_kv_cache_manager(model_engine: PyTorchModelEngine, mapping: Mapping,
executor_config: ExecutorConfig) -> KVCacheManager:
config = model_engine.model.model_config.pretrained_config
quant_config = model_engine.model.model_config.quant_config
spec_config = executor_config.speculative_config
hidden_size = config.hidden_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = getattr(config, 'num_key_value_heads',
num_attention_heads)
head_dim = hidden_size // num_attention_heads
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
num_hidden_layers = len(mapping.pp_layers_torch(config.num_hidden_layers))
# the number of layers using attention in Nemotron5 is lower than the number of hidden layers
if config.architectures[0] == "Nemotron5ForCausalLM":
# attention layers are derived from configuration (hybrid_override_pattern)
num_hidden_layers = config.hybrid_override_pattern.count("*")
if is_mla(config):
if spec_config is not None:
num_hidden_layers += get_num_spec_layers(spec_config)
return KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
num_layers=num_hidden_layers,
num_kv_heads=1,
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
num_extra_kv_tokens=0
if spec_config is None else spec_config.num_extra_kv_tokens,
)
else:
if spec_config is not None:
num_hidden_layers += get_num_spec_layers(spec_config)
return KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
num_extra_kv_tokens=0
if spec_config is None else spec_config.num_extra_kv_tokens,
)
def create_py_executor(executor_config: ExecutorConfig,
checkpoint_dir: str = None,
engine_dir: str = None):
@ -32,7 +95,6 @@ def create_py_executor(executor_config: ExecutorConfig,
executor_config.pytorch_backend_config = PyTorchConfig()
pytorch_backend_config = executor_config.pytorch_backend_config
spec_config = executor_config.speculative_config
if executor_config.mapping is None:
mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(),
@ -65,9 +127,13 @@ def create_py_executor(executor_config: ExecutorConfig,
executor_config.max_num_tokens = 8192
dist = MPIDist(mapping=mapping)
spec_config = executor_config.speculative_config
has_draft_model_engine = isinstance(spec_config, Eagle3Config)
attn_runtime_features = AttentionRuntimeFeatures(
chunked_prefill=executor_config.enable_chunked_context,
cache_reuse=executor_config.kv_cache_config.enable_block_reuse,
has_speculative_draft_tokens=has_draft_model_engine,
)
model_engine = PyTorchModelEngine(
@ -82,6 +148,23 @@ def create_py_executor(executor_config: ExecutorConfig,
spec_config=spec_config,
guided_decoding_config=executor_config.guided_decoding_config,
)
if has_draft_model_engine:
draft_model_engine = PyTorchModelEngine(
spec_config.eagle_weights_path,
pytorch_backend_config,
batch_size=executor_config.max_batch_size,
max_num_tokens=executor_config.max_num_tokens,
max_seq_len=executor_config.max_seq_len,
mapping=mapping,
attn_runtime_features=attn_runtime_features,
dist=dist,
spec_config=copy.copy(spec_config),
)
draft_model_engine.kv_cache_manager_key = DRAFT_KV_CACHE_MANAGER_KEY
else:
draft_model_engine = None
# PyTorchModelEngine modifies these fields, update them to executor_config
if pytorch_backend_config.enable_overlap_scheduler:
max_seq_len = model_engine.max_seq_len + 1
@ -98,19 +181,6 @@ def create_py_executor(executor_config: ExecutorConfig,
#NOTE: non-generation models do not have kv cache
executor_config.pytorch_backend_config.use_kv_cache = False
hidden_size = model_engine.model.config.hidden_size
num_attention_heads = model_engine.model.config.num_attention_heads
num_key_value_heads = getattr(model_engine.model.config,
'num_key_value_heads', num_attention_heads)
head_dim = hidden_size // num_attention_heads
quant_config = model_engine.model.model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
else:
kv_cache_dtype = str_dtype_to_binding(
torch_dtype_to_str(model_engine.dtype))
kv_cache_max_tokens = None
if model_engine.model.model_config.is_generation:
kv_cache_max_tokens = estimate_max_kv_cache_tokens(
@ -131,70 +201,37 @@ def create_py_executor(executor_config: ExecutorConfig,
ctx_chunk_config = None
config = model_engine.model.model_config.pretrained_config
# kv cache manager selection
if is_mla(config):
if check_flash_mla_config(config):
executor_config.tokens_per_block = 64
logger.info(
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
)
executor_config.kv_cache_config.enable_block_reuse = False
executor_config.enable_chunked_context = False
if executor_config.pytorch_backend_config.use_kv_cache:
num_hidden_layers = len(
mapping.pp_layers_torch(
model_engine.model.config.num_hidden_layers))
# has kv cache
if is_mla(config):
if check_flash_mla_config(config):
executor_config.tokens_per_block = 64
logger.info(
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
)
executor_config.kv_cache_config.enable_block_reuse = False
executor_config.enable_chunked_context = False
if spec_config is not None:
num_hidden_layers += get_num_spec_layers(spec_config)
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.
SELFKONLY,
num_layers=num_hidden_layers,
num_kv_heads=1,
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
num_extra_kv_tokens=0
if spec_config is None else spec_config.num_extra_kv_tokens,
)
else:
# the number of layers using attention in Nemotron5 is lower than the number of hidden layers
if model_engine.model.config.architectures[
0] == "Nemotron5ForCausalLM":
# attention layers are derived from configuration (hybrid_override_pattern)
num_hidden_layers = model_engine.model.config.hybrid_override_pattern.count(
"*")
kv_cache_manager = KVCacheManager(
executor_config.kv_cache_config,
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
num_layers=num_hidden_layers,
num_kv_heads=num_key_value_heads,
head_dim=head_dim,
tokens_per_block=executor_config.tokens_per_block,
max_seq_len=executor_config.max_seq_len,
max_batch_size=executor_config.max_batch_size,
mapping=mapping,
dtype=kv_cache_dtype,
num_extra_kv_tokens=0
if spec_config is None else spec_config.num_extra_kv_tokens,
)
kv_cache_manager = _create_kv_cache_manager(model_engine, mapping,
executor_config)
draft_kv_cache_manager = _create_kv_cache_manager(
draft_model_engine, mapping,
executor_config) if draft_model_engine is not None else None
else:
# no kv cache
kv_cache_manager = None
draft_kv_cache_manager = None
# KVCacheManager modifies these fields, update them to executor_config
if kv_cache_manager is not None:
executor_config.max_seq_len = kv_cache_manager.max_seq_len
resources = {
"kv_cache_manager": kv_cache_manager
KV_CACHE_MANAGER_KEY: kv_cache_manager
} if kv_cache_manager is not None else {}
if draft_kv_cache_manager is not None:
resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
if spec_config is not None:
spec_resource_manager = get_spec_resource_manager(
spec_config, model_engine.model.config, model_engine.batch_size * 2)
@ -225,10 +262,11 @@ def create_py_executor(executor_config: ExecutorConfig,
resources[key] = value
resource_manager = ResourceManager(resources)
# Make sure the kv cache manager is always invoked last as it could
# depend on the results of other resource managers.
if kv_cache_manager is not None:
resource_manager.resource_managers.move_to_end("kv_cache_manager",
resource_manager.resource_managers.move_to_end(KV_CACHE_MANAGER_KEY,
last=True)
num_micro_batches = 1
@ -276,5 +314,6 @@ def create_py_executor(executor_config: ExecutorConfig,
enable_overlap_scheduler,
max_batch_size=executor_config.max_batch_size,
max_input_len=executor_config.max_input_len,
kv_cache_transceiver=kv_cache_transceiver)
kv_cache_transceiver=kv_cache_transceiver,
draft_model_engine=draft_model_engine)
return py_executor

View File

@ -210,6 +210,7 @@ class KVCacheManager(BaseResourceManager):
self.kv_cache_pool_mapping = self.impl.get_layer_to_pool_mapping()
self.num_pools = self.impl.num_pools
self.max_blocks_per_seq = self.impl.max_blocks_per_seq
self.enable_block_reuse = kv_cache_config.enable_block_reuse
def shutdown(self):
self.impl.release_pools()

View File

@ -1,3 +1,4 @@
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
from .interface import SpecConfig, SpecMetadata
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,
@ -5,6 +6,7 @@ from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,
__all__ = [
"SpecConfig", "SpecMetadata", "MTPConfig", "MTPWorker", "MTPEagleWorker",
"MTPSpecMetadata", "get_spec_metadata", "get_spec_resource_manager",
"get_spec_decoder", "get_num_spec_layers"
"Eagle3Config", "Eagle3SpecMetadata", "MTPSpecMetadata",
"get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
"get_num_spec_layers"
]

View File

@ -0,0 +1,105 @@
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple
import torch
from ..pyexecutor.decoder import TorchDecoder
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
@dataclass
class Eagle3Config(SpecConfig):
spec_dec_name: str = "EAGLE3"
eagle_weights_path: Optional[str] = None
num_layers: int = 0
def __post_init__(self):
if self.eagle_weights_path is None:
raise ValueError("Path to EAGLE3 weights must be specified.")
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0
def update_from_model_config(self, model_config):
self.num_layers = model_config.num_hidden_layers
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
return input_tokens[1:]
@dataclass
class Eagle3SpecMetadata(SpecMetadata):
hidden_states: List[torch.Tensor] = field(default_factory=list)
num_layers: int = 0
layers_to_capture: Tuple[int, ...] = field(init=False)
target_model_embed_tokens: Optional[torch.nn.Module] = None
def __post_init__(self):
if self.num_layers == 1:
# For the draft model, we have to capture hidden states
# manually outside of the decoder layer.
self.layers_to_capture = ()
else:
if self.num_layers <= 5:
raise ValueError("Not enough hidden layers for EAGLE")
self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 3)
def prepare(self):
self.hidden_states = []
def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
if layer_id in self.layers_to_capture:
# TODO(miovine): write directly into a pre-allocated buffer for
# CUDA graph support.
self.hidden_states.append(hidden_states + residual)
def get_hidden_states(
self,
scheduled_requests,
num_rejected_tokens: Optional[Dict] = None) -> List[torch.Tensor]:
req_id_to_gather_ids = {}
seq_start = 0
for req_id, seqlen in zip(self.request_ids, self.seq_lens):
if num_rejected_tokens is not None:
if req_id in num_rejected_tokens:
req_id_to_gather_ids[req_id] = list(
range(seq_start,
seq_start + seqlen - num_rejected_tokens[req_id]))
else:
req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]
seq_start += seqlen
hidden_states_gather_ids = []
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
hidden_states_gather_ids.extend(
req_id_to_gather_ids[req.py_request_id])
return [h[hidden_states_gather_ids] for h in self.hidden_states]
class Eagle3Decoder(TorchDecoder):
def _batch_decode(self, scheduled_requests, model_outputs):
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
new_tensors_device = {"new_tokens_device": new_tokens_device}
new_tensors_host = {"new_tokens_host": new_tokens_host}
decoder_event = torch.cuda.Event()
decoder_event.record()
return new_tensors_device, new_tensors_host, decoder_event
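
A toy illustration of the d2t lookup used by Eagle3Decoder._batch_decode above: the draft head scores only draft_vocab_size tokens, and d2t holds per-index offsets that map a draft-vocab argmax back to a target-vocab token id. The vocabularies and logits below are invented for the example.

import torch

# Hypothetical tiny vocabularies: the draft vocab keeps target ids 0, 3, 4, 7.
kept_target_ids = torch.tensor([0, 3, 4, 7])
d2t = kept_target_ids - torch.arange(4)          # offsets: target_id - draft_id
draft_logits = torch.tensor([[0.1, 2.0, 0.3, 0.2],
                             [1.5, 0.0, 0.1, 0.4]])

new_tokens = torch.argmax(draft_logits, dim=-1)  # draft-vocab ids: [1, 0]
new_tokens = d2t[new_tokens] + new_tokens        # target-vocab ids: [3, 0]
assert new_tokens.tolist() == [3, 0]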

View File

@ -1,35 +1,28 @@
import copy
from dataclasses import dataclass, field
from enum import IntEnum, auto
from typing import List, Optional
from typing import Dict, List, Optional
import torch
from ..model_config import TConfig
from ..pyexecutor.scheduler import ScheduledRequests
class SpeculativeDecodingMode(IntEnum):
MTP = auto()
MTP_EAGLE = auto()
MEDUSA = auto()
EAGLE = auto()
LOOKAHEAD = auto()
EAGLE3 = auto()
NONE = auto()
def is_mtp(self):
return self == SpeculativeDecodingMode.MTP or SpeculativeDecodingMode.MTP_EAGLE
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
def is_mtp_eagle(self):
return self == SpeculativeDecodingMode.MTP_EAGLE
def is_medusa(self):
return self == SpeculativeDecodingMode.MEDUSA
def is_eagle(self):
return self == SpeculativeDecodingMode.EAGLE
def is_lookahead(self):
return self == SpeculativeDecodingMode.LOOKAHEAD
def is_eagle3(self):
return self == SpeculativeDecodingMode.EAGLE3
def is_none(self):
return self == SpeculativeDecodingMode.NONE
@ -38,22 +31,24 @@ class SpeculativeDecodingMode(IntEnum):
return self.is_mtp()
def needs_kv_cache_rewind(self):
return self.is_mtp() or self.is_eagle() or self.is_lookahead(
) or self.is_medusa()
return self.is_mtp()
def support_overlap_scheduler(self):
return self.is_mtp()
def extend_ctx(self):
"""
If true, treat generation requests with draft tokens as
chunked context requests at the kernel level. Required for
any spec dec mode that uses the SpecExecutor.
"""
return self.is_eagle3()
@staticmethod
def from_string(name: str):
name_map = {
"MTP": SpeculativeDecodingMode.MTP,
"MEDUSA": SpeculativeDecodingMode.MEDUSA,
"EAGLE": SpeculativeDecodingMode.EAGLE,
"LOOKAHEAD": SpeculativeDecodingMode.LOOKAHEAD,
None: SpeculativeDecodingMode.NONE,
}
return name_map[name]
def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
if name is None:
return SpeculativeDecodingMode.NONE
return SpeculativeDecodingMode[name.upper()]
@dataclass
@ -75,6 +70,14 @@ class SpecConfig:
def update_from_model_config(self, model_config: TConfig):
pass
def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Override for spec dec modes that need to preprocess prompt
tokens before passing them to the draft model.
"""
return input_tokens
@dataclass
class SpecMetadata:
@ -98,6 +101,8 @@ class SpecMetadata:
# The request ID of each sequence in the batch.
# The shape is (batch_size).
request_ids: Optional[List[int]] = None
# Sequence length for each request.
seq_lens: Optional[List[int]] = None
# The gather ids for logits.
gather_ids: Optional[torch.Tensor] = None
# The number of tokens for speculative model/layer
@ -111,7 +116,8 @@ class SpecMetadata:
# draft/target layers. But KVCacheManager can only support kv caches with the
# same kv lengths for different layers. Add extra kv token in kv cache manager
# to handle this issue.
num_extra_kv_tokens: Optional[int] = 0
num_extra_kv_tokens: Optional[int] = 0
# Number of layers in target model
num_layers: int = 0
def prepare():
"""
@ -130,3 +136,29 @@ class SpecMetadata:
cuda_graph_metadata.max_num_requests = max_batch_size
cuda_graph_metadata.__post_init__()
return cuda_graph_metadata
def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
"""
Some spec decode algorithms require hidden states from the target
model. Use this method to record them. By default, does nothing.
"""
def get_hidden_states(
self,
scheduled_requests: ScheduledRequests,
num_rejected_tokens: Optional[Dict] = None) -> List[torch.Tensor]:
"""
Return any captured hidden states. Should do any necessary
pre-processing.
num_rejected_tokens is a dictionary mapping request IDs to the
number of tokens rejected for that request. If a request ID isn't
in the dictionary, it means that the request is not needed for drafting.
If the dictionary is not given, this function assumes that the hidden
states are being prepared for running the draft model autoregressively,
and only the last hidden state vector for each sequence is returned.
"""
return []
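
To make the get_hidden_states contract above concrete, the sketch below reproduces the gather-index computation the way the Eagle3 implementation earlier in this commit does it; the request ids, sequence lengths, and rejection counts are example values.

from typing import Dict, List, Optional

def hidden_state_gather_ids(request_ids: List[int], seq_lens: List[int],
                            num_rejected_tokens: Optional[Dict[int, int]] = None
                            ) -> List[int]:
    """Flat indices into the captured hidden-state tensor, per the Eagle3 rules."""
    gather_ids: List[int] = []
    seq_start = 0
    for req_id, seqlen in zip(request_ids, seq_lens):
        if num_rejected_tokens is not None:
            if req_id in num_rejected_tokens:
                # Keep everything except the rejected tail of this sequence.
                gather_ids.extend(
                    range(seq_start,
                          seq_start + seqlen - num_rejected_tokens[req_id]))
            # Requests absent from the dict are not needed for drafting.
        else:
            # Autoregressive drafting: only the last position per sequence.
            gather_ids.append(seq_start + seqlen - 1)
        seq_start += seqlen
    return gather_ids

# Two requests with 5 and 3 tokens of captured hidden states.
assert hidden_state_gather_ids([7, 8], [5, 3], {7: 2, 8: 0}) == [0, 1, 2, 5, 6, 7]
assert hidden_state_gather_ids([7, 8], [5, 3]) == [4, 7]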

View File

@ -1,3 +1,4 @@
from .eagle3 import Eagle3Decoder, Eagle3SpecMetadata
from .mtp import MTPDecoder, MTPHiddenStatesManager, MTPSpecMetadata
@ -11,6 +12,11 @@ def get_spec_metadata(spec_config,
mtp_num_modules=spec_config.num_nextn_predict_layers,
max_num_requests=max_num_requests,
mtp_hidden_states_manager=spec_resource_manager)
elif spec_config.spec_dec_mode.is_eagle3():
return Eagle3SpecMetadata(max_draft_tokens=spec_config.max_draft_tokens,
spec_dec_mode=spec_config.spec_dec_mode,
max_num_requests=max_num_requests,
num_layers=spec_config.num_layers)
else:
return None
@ -29,6 +35,8 @@ def get_spec_resource_manager(spec_config, model_config, max_num_requests):
def get_spec_decoder(max_seq_len, spec_config):
if spec_config.spec_dec_mode.is_mtp():
return MTPDecoder(max_seq_len, spec_config)
if spec_config.spec_dec_mode.is_eagle3():
return Eagle3Decoder(max_seq_len)
else:
return None

View File

@ -216,6 +216,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
dynamic_tree_max_topK: Optional[int] = None
num_eagle_layers: Optional[int] = None
max_non_leaves_per_layer: Optional[int] = None
pytorch_eagle_weights_path: Optional[str] = None
@classmethod
def from_dict(cls, data: dict):
@ -1145,15 +1146,23 @@ class LlmArgs:
self.build_config.max_draft_len = self.speculative_config.max_draft_len
eagle_config = EagleConfig(
self.speculative_config.eagle_choices,
self.speculative_config.greedy_sampling,
self.speculative_config.posterior_threshold,
self.speculative_config.use_dynamic_tree,
self.speculative_config.dynamic_tree_max_topK)
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Eagle(),
eagle_config=eagle_config)
if self.backend != 'pytorch':
eagle_config = EagleConfig(
self.speculative_config.eagle_choices,
self.speculative_config.greedy_sampling,
self.speculative_config.posterior_threshold,
self.speculative_config.use_dynamic_tree,
self.speculative_config.dynamic_tree_max_topK)
self.decoding_config = DecodingConfig(
decoding_mode=DecodingMode.Eagle(),
eagle_config=eagle_config)
else:
from tensorrt_llm._torch.speculative import Eagle3Config
self.speculative_config = Eagle3Config(
max_draft_tokens=self.speculative_config.max_draft_len,
eagle_weights_path=self.speculative_config.
pytorch_eagle_weights_path)
elif isinstance(self.speculative_config, MTPDecodingConfig):
from tensorrt_llm._torch.speculative import MTPConfig
self.speculative_config = MTPConfig(

View File

@ -134,7 +134,8 @@ class OPTModel(Module):
attention_params=None,
prompt_embedding_table=None,
prompt_tasks=None,
prompt_vocab_size=None):
prompt_vocab_size=None,
**kwargs):
args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size
] if prompt_embedding_table is not None else []

View File

@ -1440,13 +1440,37 @@ def test_ptq_quickstart_advanced_mtp(llm_root, llm_venv, model_name,
str(example_root / "quickstart_advanced.py"),
"--enable_overlap_scheduler",
"--use_cuda_graph",
"--mtp_nextn",
"--spec_decode_nextn",
"1", # test 1 MTP module
"--spec_decode_algo",
"MTP",
"--model_dir",
f"{llm_models_root()}/{model_path}",
])
@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [
("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct",
"EAGLE3-LLaMA3.1-Instruct-8B"),
])
def test_ptp_quickstart_advanced_eagle3(llm_root, llm_venv, model_name,
model_path, eagle_model_path):
print(f"Testing {model_name}.")
example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
llm_venv.run_cmd([
str(example_root / "quickstart_advanced.py"),
"--spec_decode_nextn",
"4",
"--spec_decode_algo",
"eagle3",
"--model_dir",
f"{llm_models_root()}/{model_path}",
"--eagle_model_dir",
f"{llm_models_root()}/{eagle_model_path}",
"--kv_cache_enable_block_reuse",
])
@pytest.mark.skip_less_device_memory(110000)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("model_name,model_path", [

View File

@ -14,7 +14,7 @@ l0_b200:
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
- test_e2e.py::test_ptq_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- examples/test_pytorch.py::test_llm_llama_1gpu[llama-3.1-8b-enable_fp4]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)"
@ -23,6 +23,7 @@ l0_b200:
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and nextn0"
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and not nextn0"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/speculative/test_eagle3.py
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-enable_fp4]
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4]
# ------------- TRT tests ---------------

View File

@ -0,0 +1,84 @@
import os
import sys
import unittest
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.llmapi import EagleDecodingConfig
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root
def test_llama_eagle3():
models_path = llm_models_root()
pytorch_config = PyTorchConfig(
enable_overlap_scheduler=False,
use_cuda_graph=False,
)
kv_cache_config = KvCacheConfig(enable_block_reuse=False, )
eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"
draft_len = 4
spec_config = EagleDecodingConfig(
max_draft_len=draft_len, pytorch_eagle_weights_path=eagle_model_dir)
llm_spec = LLM(model=target_model_dir,
pytorch_backend_config=pytorch_config,
kv_cache_config=kv_cache_config,
speculative_config=spec_config)
sampling_params = SamplingParams(
max_tokens=32,
temperature=0,
)
# First make sure the acceptance rate is reasonable.
tok_ids = llm_spec.tokenizer.encode("The future of AI is")
num_tokens = 0
num_drafted = 0
num_accepted = 0
for output in llm_spec.generate_async(tok_ids,
SamplingParams(max_tokens=128,
temperature=0),
streaming=True):
beam = output.outputs[0]
new_tokens = beam.token_ids
num_drafted += draft_len
num_accepted += len(new_tokens) - num_tokens - 1
num_tokens = len(new_tokens)
accept_rate = num_accepted / num_drafted
assert accept_rate > 0.25
prompts = [
"The capital of France is", "The president of the United States is"
]
results_spec = llm_spec.generate(prompts, sampling_params)
generated_text_spec = [result.outputs[0].text for result in results_spec]
del llm_spec
llm_ref = LLM(model=target_model_dir,
pytorch_backend_config=pytorch_config,
kv_cache_config=kv_cache_config)
results_ref = llm_ref.generate(prompts, sampling_params)
generated_text_ref = [result.outputs[0].text for result in results_ref]
for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
# The spec decode algorithm currently guarantees identical results
assert text_spec == text_ref
if __name__ == "__main__":
unittest.main()