Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Commit 5416966ddb (parent 9c484b24e6)
Add initial EAGLE-3 implementation (#3035)
Signed-off-by: Mike Iovine <miovine@nvidia.com>
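
Editor's note: the sketch below is not part of the commit; it only illustrates how the new EAGLE-3 path is driven through the PyTorch-backend LLM API, following the changes to examples/pytorch/quickstart_advanced.py in this diff. Checkpoint paths and the draft length are placeholders.

# Minimal usage sketch (assumes the llmapi generate() interface used by the
# quickstart example; paths below are placeholders).
from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=4,                                    # maps to --spec_decode_nextn
    pytorch_eagle_weights_path="/path/to/eagle3-head")  # maps to --eagle_model_dir

llm = LLM(model="/path/to/target-llama-checkpoint",
          speculative_config=spec_config)
outputs = llm.generate(["Hello, my name is"],
                       sampling_params=SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
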
@@ -197,6 +197,7 @@ class OPTModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
**kwargs,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(

@@ -4,7 +4,7 @@ from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.llmapi import MTPDecodingConfig
from tensorrt_llm.llmapi import EagleDecodingConfig, MTPDecodingConfig

example_prompts = [
"Hello, my name is",
@@ -84,7 +84,10 @@ def add_llm_args(parser):
parser.add_argument('--load_format', type=str, default='auto')

# Speculative decoding
parser.add_argument('--mtp_nextn', type=int, default=0)
parser.add_argument('--spec_decode_algo', type=str, default=None)
parser.add_argument('--spec_decode_nextn', type=int, default=1)
parser.add_argument('--eagle_model_dir', type=str, default=None)

return parser


@@ -111,8 +114,18 @@ def setup_llm(args):
free_gpu_memory_fraction=args.kv_cache_fraction,
)

mtp_config = MTPDecodingConfig(
num_nextn_predict_layers=args.mtp_nextn) if args.mtp_nextn > 0 else None
spec_decode_algo = args.spec_decode_algo.upper(
) if args.spec_decode_algo is not None else None

if spec_decode_algo == 'MTP':
spec_config = MTPDecodingConfig(
num_nextn_predict_layers=args.spec_decode_nextn)
elif spec_decode_algo == "EAGLE3":
spec_config = EagleDecodingConfig(
max_draft_len=args.spec_decode_nextn,
pytorch_eagle_weights_path=args.eagle_model_dir)
else:
spec_config = None

llm = LLM(model=args.model_dir,
max_seq_len=args.max_seq_len,
@@ -126,7 +139,7 @@ def setup_llm(args):
moe_expert_parallel_size=args.moe_ep_size,
moe_tensor_parallel_size=args.moe_tp_size,
enable_chunked_prefill=args.enable_chunked_prefill,
speculative_config=mtp_config)
speculative_config=spec_config)

sampling_params = SamplingParams(
max_tokens=args.max_tokens,
@@ -21,6 +21,7 @@ from ..pyexecutor.resource_manager import KVCacheManager
class AttentionRuntimeFeatures:
chunked_prefill: bool = False
cache_reuse: bool = False
has_speculative_draft_tokens: bool = False


@dataclass(kw_only=True)

@@ -630,8 +630,10 @@ class TrtllmAttention(AttentionBackend[TrtllmAttentionMetadata]):
)
assert not metadata.is_cross, "TRT-LLM Attention does not support cross attention yet."

use_paged_context_fmha = (metadata.runtime_features.chunked_prefill
or metadata.runtime_features.cache_reuse)
use_paged_context_fmha = (
metadata.runtime_features.chunked_prefill
or metadata.runtime_features.cache_reuse
or metadata.runtime_features.has_speculative_draft_tokens)

if use_paged_context_fmha and self.has_fp8_kv_cache:
# NOTE: W4A8_AWQ can be included too, exclude for now since
@@ -11,7 +11,14 @@ class AutoModelForCausalLM(Generic[TModel, TConfig]):
def from_config(
config: ModelConfig[TConfig],
) -> DecoderModelForCausalLM[TModel, TConfig]:
cls = MODEL_CLASS_MAPPING.get(config.pretrained_config.architectures[0])
model_arch = config.pretrained_config.architectures[0]
# Hack to detect eagle3 checkpoints. TODO: should we provide
# our own checkpoints with the correct arch? It would let us
# avoid nasty stuff like this.
if hasattr(config.pretrained_config, "draft_vocab_size"):
model_arch = "EAGLE3" + model_arch

cls = MODEL_CLASS_MAPPING.get(model_arch)
if cls is None:
raise ValueError(
f"Unknown architecture for AutoModelForCausalLM: {config.pretrained_config.architectures[0]}"
@@ -1,4 +1,4 @@
from typing import Optional, Tuple
from typing import Dict, Optional, Tuple

import torch
from torch import nn
@@ -14,8 +14,10 @@ from ..modules.attention import Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, WeightMode, WeightsLoadingConfig
from ..modules.rms_norm import RMSNorm
from ..modules.rotary_embedding import RotaryEmbedding
from ..speculative import Eagle3SpecMetadata, SpecMetadata
from .modeling_utils import (DecoderModel, DecoderModelForCausalLM,
register_auto_model, support_pp)

@@ -74,6 +76,8 @@ class LlamaDecoderLayer(DecoderLayer):
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx

self.self_attn = LlamaAttention(
model_config,
layer_idx=layer_idx,
@@ -89,6 +93,7 @@ class LlamaDecoderLayer(DecoderLayer):
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
@@ -99,6 +104,7 @@ class LlamaDecoderLayer(DecoderLayer):
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
residual: Optional[torch.Tensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
**kwargs,
) -> torch.Tensor:
if residual is None:
@@ -117,6 +123,107 @@ class LlamaDecoderLayer(DecoderLayer):
)

# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
if spec_metadata is not None:
spec_metadata.maybe_capture_hidden_states(self.layer_idx,
hidden_states, residual)
return hidden_states, residual


class Eagle3LlamaAttention(LlamaAttention):

def __init__(
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: Optional[int] = None,
):
super().__init__(model_config, layer_idx)

model_config = model_config or ModelConfig()
config = model_config.pretrained_config

tp_size = model_config.mapping.tp_size
tp_rank = model_config.mapping.tp_rank
gpus_per_node = model_config.mapping.gpus_per_node

# Override the QKV projection. The number of input features
# is twice as big for EAGLE3 draft models.
self.qkv_proj = Linear(
2 * self.hidden_size,
tp_size * self.q_size + 2 * tp_size * self.kv_size,
bias=config.attention_bias,
dtype=config.torch_dtype,
parallel_config=ParallelConfig(
tensor_parallel_size=tp_size,
tensor_parallel_rank=tp_rank,
tensor_parallel_mode=TensorParallelMode.COLUMN,
gpus_per_node=gpus_per_node,
pipeline_parallel_size=model_config.mapping.pp_size,
parallel_rank=model_config.mapping.rank),
weights_loading_config=WeightsLoadingConfig(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=model_config.get_quant_config(),
skip_create_weights=model_config.skip_create_weights,
)


class Eagle3LlamaDecoderLayer(DecoderLayer):

def __init__(
self,
model_config: ModelConfig[LlamaConfig],
layer_idx: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
super().__init__()
config = model_config.pretrained_config
self.layer_idx = layer_idx

self.self_attn = Eagle3LlamaAttention(
model_config,
layer_idx=layer_idx,
)

self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=config.mlp_bias,
dtype=config.torch_dtype,
config=model_config,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

def forward(
self,
position_ids: torch.LongTensor,
embeds: torch.Tensor,
hidden_states: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
residual = hidden_states

embeds = self.input_layernorm(embeds)
hidden_states = self.hidden_norm(hidden_states)

hidden_states = torch.cat([embeds, hidden_states], dim=-1)

hidden_states = self.self_attn(
position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
)

hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states = self.mlp(hidden_states)
@@ -161,6 +268,7 @@ class LlamaModel(DecoderModel):
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
@@ -177,7 +285,8 @@ class LlamaModel(DecoderModel):
hidden_states, residual = decoder_layer(position_ids=position_ids,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
residual=residual)
residual=residual,
spec_metadata=spec_metadata)

hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
@@ -207,3 +316,123 @@ class MistralForCausalLM(DecoderModelForCausalLM[LlamaModel, LlamaConfig]):
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.vocab_size)


class Eagle3LlamaDraftModel(DecoderModel):

def __init__(self, model_config: ModelConfig[LlamaConfig]) -> None:
super().__init__(model_config)

config = model_config.pretrained_config
self.dtype = config.torch_dtype

self.fc = Linear(config.hidden_size * 3,
config.hidden_size,
bias=False,
dtype=config.torch_dtype)

self.midlayer = Eagle3LlamaDecoderLayer(model_config, 0)

self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)

self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
dtype=torch.int64),
requires_grad=False)

def forward(
self,
attn_metadata: AttentionMetadata,
embed_tokens: torch.nn.Module,
input_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
spec_metadata: Optional[SpecMetadata] = None,
hidden_states: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if (input_ids is None) ^ (inputs_embeds is not None):
raise ValueError(
"You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
)

if inputs_embeds is None:
inputs_embeds = embed_tokens(input_ids).to(self.dtype)

assert hidden_states is not None and len(hidden_states) > 0

if len(hidden_states) > 1:
hidden_states = torch.cat(hidden_states, dim=-1)
hidden_states = self.fc(hidden_states.to(self.dtype))
else:
hidden_states = hidden_states[0].to(self.dtype)

hidden_states, residual = self.midlayer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata)

hidden_states, hidden_states_to_save = self.norm(
hidden_states, residual)
assert isinstance(spec_metadata, Eagle3SpecMetadata)
spec_metadata.hidden_states.append(hidden_states_to_save)
return hidden_states


@register_auto_model("EAGLE3LlamaForCausalLM")
class Eagle3LlamaForCausalLM(DecoderModelForCausalLM[Eagle3LlamaDraftModel,
LlamaConfig]):

def __init__(
self,
model_config: ModelConfig[LlamaConfig],
):
super().__init__(
Eagle3LlamaDraftModel(model_config),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=model_config.pretrained_config.draft_vocab_size)

def forward(
self,
attn_metadata: AttentionMetadata,
input_ids: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
return_context_logits: bool = False,
spec_metadata: Optional[SpecMetadata] = None,
hidden_states: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
if "embed_tokens" not in kwargs:
raise ValueError(
"EAGLE3 checkpoints do not include embed_tokens. "
"The embed_tokens module from the target model therefore needs to "
"be passed explicitly via extra_model_inputs. NOTE: we can "
"get rid of this hack by providing our own custom checkpoint "
"format that includes embed_tokens.")

output = self.model(
input_ids=input_ids,
embed_tokens=kwargs['embed_tokens'],
attn_metadata=attn_metadata,
position_ids=position_ids,
inputs_embeds=inputs_embeds,
spec_metadata=spec_metadata,
hidden_states=hidden_states,
)

return self.logits_processor.forward(
output,
self.lm_head,
attn_metadata,
return_context_logits,
)

def load_weights(self, weights: Dict):
new_weights = {}
for k, v in weights.items():
new_k = "model." + k if 'lm_head' not in k and "embed_tokens" not in k else k
new_weights[new_k] = v

super().load_weights(new_weights)
@ -206,6 +206,7 @@ class MambaHybridModel(DecoderModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
|
||||
@ -211,6 +211,7 @@ class MixtralModel(DecoderModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
|
||||
@ -177,6 +177,7 @@ class NemotronModel(DecoderModel):
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
|
||||
@ -172,6 +172,7 @@ class QwenModel(DecoderModel):
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
mrope_config: Optional[Tuple[torch.Tensor, int]] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
|
||||
@ -22,6 +22,7 @@ from ..modules.linear import Linear, WeightMode
|
||||
from ..modules.logits_procesor import LogitsProcessor
|
||||
from ..modules.rms_norm import RMSNorm
|
||||
from ..pipeline_interface import PipelineInterface
|
||||
from ..speculative import SpecMetadata
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
@ -213,6 +214,7 @@ class DecoderModel(nn.Module, metaclass=PPInitCaller):
|
||||
input_ids: torch.LongTensor = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
if (input_ids is None) ^ (inputs_embeds is not None):
|
||||
raise ValueError(
|
||||
@ -469,9 +471,9 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
pipeline_interface: Optional[PipelineInterface] = None,
|
||||
return_context_logits: bool = False,
|
||||
spec_metadata: Optional[SpecMetadata] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if self._supports_pp and self.pp_size > 1:
|
||||
output = self.model(
|
||||
input_ids=input_ids,
|
||||
@ -479,6 +481,7 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
pipeline_interface=pipeline_interface,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
|
||||
# No need to compute logits for non-last PP ranks
|
||||
@ -490,6 +493,7 @@ class DecoderModelForCausalLM(nn.Module,
|
||||
attn_metadata=attn_metadata,
|
||||
position_ids=position_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
spec_metadata=spec_metadata,
|
||||
)
|
||||
|
||||
return self.logits_processor.forward(
|
||||
|
||||
@@ -41,6 +41,7 @@ def check_flash_mla_config(config):

def cal_max_tokens(peak_memory, total_gpu_memory, fraction, model_config,
mapping: Mapping):
# TODO: take space occupied by draft KV cache manager into account.
mem_per_token = 2
quant_config = model_config.quant_config
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():

@@ -276,6 +276,7 @@ class TorchDecoder(Decoder):
generation_requests.append(request)

for request in extend_requests:
num_accepted = 0
if request.state != LlmRequestState.GENERATION_COMPLETE:
new_token = new_tokens_list[idx]
num_tokens = request.add_new_token(new_token, beam_idx)
@@ -290,12 +291,15 @@ class TorchDecoder(Decoder):
# Reject.
break

num_accepted += 1
new_token = new_tokens_list[idx + i + 1]
num_tokens = request.add_new_token(new_token, beam_idx)

if self._handle_stop_criteria(request, new_token,
num_tokens, beam_idx):
break
request.py_num_accepted_draft_tokens = num_accepted
request.py_rewind_len = request.py_draft_pages_allocated - num_accepted
idx += len(request.py_draft_tokens) + 1

for request in generation_requests:

@@ -47,6 +47,7 @@ class LlmRequest(tensorrt_llm.bindings.internal.batch_manager.LlmRequest):
self.py_batch_idx = None
self.py_rewind_len = 0
self.py_draft_tokens = self.draft_tokens
self.py_last_draft_tokens = None


def convert_wordlist(word_list) -> List[List[int]]:

@@ -50,8 +50,9 @@ class ModelEngine(ABC):

@abstractmethod
def forward(self, scheduled_requests: ScheduledRequests,
resource_manager: ResourceManager,
new_tensors_device: Optional[Dict[str, torch.Tensor]],
resource_manager: ResourceManager):
extra_model_inputs: Optional[Dict[str, Any]]):
raise NotImplementedError

def warmup(self, resource_manager: ResourceManager) -> None:
@ -207,6 +208,10 @@ def initialize_dummy_weights(
|
||||
param.uniform_(low, high, generator=generator)
|
||||
|
||||
|
||||
KV_CACHE_MANAGER_KEY = 'kv_cache_manager'
|
||||
DRAFT_KV_CACHE_MANAGER_KEY = 'draft_kv_cache_manager'
|
||||
|
||||
|
||||
class PyTorchModelEngine(ModelEngine):
|
||||
|
||||
def __init__(
|
||||
@ -233,6 +238,10 @@ class PyTorchModelEngine(ModelEngine):
|
||||
self.dist = dist
|
||||
self.pytorch_backend_config = pytorch_backend_config
|
||||
self.spec_config = spec_config
|
||||
# We keep a reference to the last used spec metadata to
|
||||
# accommodate certain target/draft model use cases. See
|
||||
# py_executor.py for how this is used.
|
||||
self.last_spec_metadata = None
|
||||
|
||||
self.attn_runtime_features = attn_runtime_features or AttentionRuntimeFeatures(
|
||||
)
|
||||
@ -347,9 +356,14 @@ class PyTorchModelEngine(ModelEngine):
|
||||
self.max_draft_len = 0
|
||||
self.iter_counter = 0
|
||||
|
||||
# We look up this key in resource_manager during forward to find the
|
||||
# kv cache manager. Can be changed to support multiple model engines
|
||||
# with different KV cache managers.
|
||||
self.kv_cache_manager_key = KV_CACHE_MANAGER_KEY
|
||||
|
||||
def warmup(self, resource_manager: ResourceManager) -> None:
|
||||
kv_cache_manager = resource_manager.get_resource_manager(
|
||||
'kv_cache_manager')
|
||||
self.kv_cache_manager_key)
|
||||
spec_resource_manager = resource_manager.get_resource_manager(
|
||||
'spec_resource_manager')
|
||||
if kv_cache_manager is None:
|
||||
@ -464,6 +478,11 @@ class PyTorchModelEngine(ModelEngine):
|
||||
if cp_type == 'star_attention':
|
||||
return
|
||||
|
||||
# TODO: CUDA graph support with eagle.
|
||||
if self.spec_config is not None and self.spec_config.spec_dec_mode.is_eagle3(
|
||||
):
|
||||
return
|
||||
|
||||
if self._torch_compile_enabled:
|
||||
# Disable cuda graph capture here so that we can properly capture it later
|
||||
with no_cuda_graph():
|
||||
@ -817,6 +836,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
inputs['attn_metadata'].kv_lens_cuda[
|
||||
num_ctx_requests:num_seqs] += (
|
||||
self.previous_kv_lens_offsets_cuda[:num_gen_requests])
|
||||
|
||||
return inputs
|
||||
|
||||
def _prepare_tp_inputs(
|
||||
@ -918,6 +938,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
past_seen_token_num = request.max_beam_num_tokens - 1
|
||||
position_ids.append(past_seen_token_num)
|
||||
draft_lens.append(num_draft_tokens)
|
||||
prompt_lengths.append(num_draft_tokens + 1)
|
||||
# draft tokens
|
||||
input_ids.extend(request.py_draft_tokens)
|
||||
gather_ids.extend(
|
||||
@ -955,9 +976,9 @@ class PyTorchModelEngine(ModelEngine):
|
||||
(1 + self.max_draft_len))
|
||||
num_cached_tokens_per_seq.append(past_seen_token_num +
|
||||
self.max_draft_len + 1)
|
||||
prompt_lengths.append(request.py_prompt_len)
|
||||
|
||||
request_ids.append(request.py_request_id)
|
||||
prompt_lengths.append(request.py_prompt_len)
|
||||
|
||||
sequence_lengths.extend([1] * len(generation_requests))
|
||||
gather_ids.extend(
|
||||
@ -1075,6 +1096,9 @@ class PyTorchModelEngine(ModelEngine):
|
||||
attn_metadata.request_ids = request_ids
|
||||
attn_metadata.prompt_lens = prompt_lengths
|
||||
attn_metadata.num_contexts = len(scheduled_requests.context_requests)
|
||||
if self.spec_config is not None and self.spec_config.spec_dec_mode.extend_ctx(
|
||||
):
|
||||
attn_metadata.num_contexts += len(extend_requests)
|
||||
|
||||
attn_metadata.kv_cache_params = KVCacheParams(
|
||||
use_cache=True,
|
||||
@ -1108,6 +1132,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
spec_metadata.num_generations = len(
|
||||
scheduled_requests.generation_requests)
|
||||
spec_metadata.num_tokens = total_num_tokens
|
||||
spec_metadata.seq_lens = sequence_lengths
|
||||
spec_metadata.prepare()
|
||||
inputs['spec_metadata'] = spec_metadata
|
||||
|
||||
@ -1223,6 +1248,7 @@ class PyTorchModelEngine(ModelEngine):
|
||||
spec_metadata.num_generations = len(
|
||||
scheduled_requests.generation_requests)
|
||||
spec_metadata.num_tokens = num_tokens
|
||||
spec_metadata.seq_lens = sequence_lengths
|
||||
spec_metadata.prepare()
|
||||
inputs['spec_metadata'] = spec_metadata
|
||||
|
||||
@ -1505,10 +1531,11 @@ class PyTorchModelEngine(ModelEngine):
|
||||
def forward(self,
|
||||
scheduled_requests: ScheduledRequests,
|
||||
resource_manager: ResourceManager,
|
||||
new_tensors_device: Optional[Dict[str, torch.Tensor]] = None):
|
||||
new_tensors_device: Optional[Dict[str, torch.Tensor]] = None,
|
||||
extra_model_inputs: Optional[Dict[str, Any]] = None):
|
||||
|
||||
kv_cache_manager = resource_manager.get_resource_manager(
|
||||
'kv_cache_manager')
|
||||
self.kv_cache_manager_key)
|
||||
|
||||
attn_metadata = self._set_up_attn_metadata(kv_cache_manager)
|
||||
if self.spec_config is not None:
|
||||
@ -1523,6 +1550,10 @@ class PyTorchModelEngine(ModelEngine):
|
||||
if kv_cache_manager is None:
|
||||
inputs, gather_ids = self._prepare_tp_inputs_no_cache(
|
||||
scheduled_requests, attn_metadata, spec_metadata)
|
||||
if extra_model_inputs is not None:
|
||||
inputs.update(extra_model_inputs)
|
||||
self.last_spec_metadata = spec_metadata
|
||||
|
||||
if self.mapping.has_pp() and not self.mapping.is_last_pp_rank():
|
||||
pp_interface = self._forward_step_intermediate(inputs)
|
||||
pp_interface.send()
|
||||
@ -1549,6 +1580,10 @@ class PyTorchModelEngine(ModelEngine):
|
||||
attn_metadata,
|
||||
spec_metadata,
|
||||
new_tensors_device)
|
||||
if extra_model_inputs is not None:
|
||||
inputs.update(extra_model_inputs)
|
||||
self.last_spec_metadata = spec_metadata
|
||||
|
||||
self.iter_counter += 1
|
||||
|
||||
if maybe_graph is None:
|
||||
|
||||
@ -11,7 +11,7 @@ import weakref
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from itertools import chain
|
||||
from typing import Dict, List, Optional, Union
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
|
||||
import dill # nosec B403
|
||||
import numpy as np
|
||||
@ -139,7 +139,8 @@ class PyExecutor:
|
||||
enable_overlap_scheduler: bool = False,
|
||||
max_input_len: int = 2048,
|
||||
max_batch_size: int = 8,
|
||||
kv_cache_transceiver: KvCacheTransceiver = None):
|
||||
kv_cache_transceiver: KvCacheTransceiver = None,
|
||||
draft_model_engine: Optional[ModelEngine] = None):
|
||||
super(PyExecutor, self).__init__()
|
||||
self.device_id = torch.cuda.current_device()
|
||||
self.global_rank = global_mpi_rank()
|
||||
@ -158,6 +159,9 @@ class PyExecutor:
|
||||
self.decoder = decoder
|
||||
self.dist = dist
|
||||
|
||||
# Draft model for certain spec decode algorithms, e.g. EAGLE3
|
||||
self.draft_model_engine = draft_model_engine
|
||||
|
||||
# enqueue and _fetch_new_requests used data
|
||||
self.enqueue_lock = threading.Lock()
|
||||
self.active = True
|
||||
@ -178,6 +182,12 @@ class PyExecutor:
|
||||
"kv_cache_manager")
|
||||
self.enable_kv_cache_events = self.kv_cache_manager is not None and self.kv_cache_manager.event_buffer_max_size > 0
|
||||
|
||||
if self.draft_model_engine is not None and self.kv_cache_manager is not None:
|
||||
if self.kv_cache_manager.enable_block_reuse:
|
||||
raise NotImplementedError(
|
||||
"Draft model engine + KV cache reuse is not supported yet. "
|
||||
"This will be fixed in the near future!")
|
||||
|
||||
self.max_input_len = max_input_len
|
||||
# _executor_loop private data
|
||||
self.max_num_active_requests = model_engine.get_max_num_sequences()
|
||||
@ -201,6 +211,8 @@ class PyExecutor:
|
||||
self.canceled_req_ids = ReqIdsSet()
|
||||
|
||||
self.model_engine.warmup(self.resource_manager)
|
||||
if self.draft_model_engine is not None:
|
||||
self.draft_model_engine.warmup(self.resource_manager)
|
||||
|
||||
self.is_shutdown = False
|
||||
|
||||
@ -214,6 +226,12 @@ class PyExecutor:
|
||||
event_loop = self._executor_loop_pp_overlap if enable_overlap_scheduler else self._executor_loop_pp
|
||||
else:
|
||||
event_loop = self._executor_loop_overlap if enable_overlap_scheduler else self._executor_loop
|
||||
|
||||
if self.draft_model_engine is not None and event_loop.__name__ != self._executor_loop.__name__:
|
||||
raise NotImplementedError(
|
||||
"Drafting is not supported for selected executor loop. "
|
||||
"Please disable disagg/pipeline parallelism/overlap scheduler.")
|
||||
|
||||
self.worker_thread = threading.Thread(target=event_loop, daemon=True)
|
||||
self.worker_thread.start()
|
||||
|
||||
@ -727,6 +745,10 @@ class PyExecutor:
|
||||
num_dummy_request = self._get_num_dummy_request()
|
||||
if num_dummy_request > 0:
|
||||
self._merge_dummy_request(num_dummy_request)
|
||||
|
||||
if self.draft_model_engine is not None:
|
||||
self._prepare_draft_requests()
|
||||
|
||||
scheduled_batch, fitting_disagg_gen_init_requests, num_fitting_reqs = self._schedule(
|
||||
)
|
||||
|
||||
@ -750,6 +772,8 @@ class PyExecutor:
|
||||
|
||||
if scheduled_batch.batch_size > 0:
|
||||
self.resource_manager.prepare_resources(scheduled_batch)
|
||||
if self.draft_model_engine is not None:
|
||||
self._prepare_draft_tokens(scheduled_batch)
|
||||
|
||||
if self.kv_cache_transceiver:
|
||||
# For generation requests which have completed KV cache transfer
|
||||
@@ -789,6 +813,35 @@ class PyExecutor:
self.response_cv.notify_all()
self.shutdown_event.set()

def _prepare_draft_requests(self):
try:
# Set draft tokens here to make the KV cache manager
# and scheduler aware of them.
for req in self.active_requests:
if req.state != LlmRequestState.GENERATION_IN_PROGRESS:
continue
req.py_last_draft_tokens = req.py_draft_tokens
max_draft_len = self.model_engine.spec_config.max_draft_tokens
max_seq_len = self.model_engine.max_seq_len

# Subtract 1 to account for the token we will add on this forward
# pass.
draft_len = min(max_seq_len - 1 - req.get_num_tokens(0),
max_draft_len)

if draft_len > 0:
req.py_draft_tokens = [0] * draft_len
req.py_draft_pages_allocated = draft_len
else:
req.py_draft_tokens = None
req.py_draft_pages_allocated = 0

except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

def _executor_loop_overlap(self):
torch.cuda.set_device(self.device_id)
got_finish_signal = False
@@ -1531,6 +1584,189 @@ class PyExecutor:
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

@nvtx_range("_prepare_draft_batch")
def _prepare_draft_batch(
self, scheduled_requests: ScheduledRequests
) -> Tuple[ScheduledRequests, Dict[int, LlmRequest]]:
"""
Prepares a batch for the draft model engine. Draft tokens are only produced
for generation requests.

The requests are prepared as follows:
1. The first time the draft engine sees a request, it's a context request.
2. Otherwise, if draft tokens were accepted on the last target model decoding
step, it's a chunked context request (we process all the accepted tokens together).
3. Otherwise, it's a generation request.
"""
try:
draft_batch = ScheduledRequests()
req_id_to_num_rejected_tokens = {}

for request in scheduled_requests.generation_requests:
if request.py_draft_pages_allocated == 0:
# No space for draft tokens.
continue

num_draft_tokens = len(
request.py_last_draft_tokens
) if request.py_last_draft_tokens is not None else 0
request.py_draft_tokens = []

num_accepted_tokens = getattr(request,
"py_num_accepted_draft_tokens", 0)
num_rejected_tokens = num_draft_tokens - num_accepted_tokens
assert num_rejected_tokens >= 0
req_id_to_num_rejected_tokens[
request.py_request_id] = num_rejected_tokens

spec_config = self.model_engine.spec_config
beam_idx = 0
input_tokens = spec_config.get_draft_model_prompt(
request.get_tokens()[beam_idx])

if request.max_beam_num_tokens - 1 == request.py_prompt_len:
# This is the first time the draft model is seeing this request.
# Prepare a context request. We discard the first token and take
# the newly decoded one - this is the convention for EAGLE 2 and 3.
assert num_draft_tokens == 0
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
is_streaming=False)

draft_batch.context_requests.append(new_request)
elif getattr(request, "py_num_accepted_draft_tokens", 0) == 0:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens[:-1],
sampling_config=request.sampling_config,
is_streaming=False)
# Explicitly add the last token so get_last_tokens() returns
# the right value
new_request.add_new_token(input_tokens[-1], beam_idx)
new_request.state = LlmRequestState.GENERATION_IN_PROGRESS
draft_batch.generation_requests.append(new_request)
else:
new_request = LlmRequest(
request_id=request.py_request_id,
max_new_tokens=request.py_max_new_tokens,
input_tokens=input_tokens,
sampling_config=request.sampling_config,
is_streaming=False)
new_request.context_chunk_size = num_accepted_tokens + 1
new_request.context_current_position = len(
input_tokens) - num_accepted_tokens - 1

draft_batch.context_requests.append(new_request)

new_request.py_stop_words_list = request.py_stop_words_list
new_request.is_dummy = False

return draft_batch, req_id_to_num_rejected_tokens

except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

@nvtx_range("_prepare_draft_tokens")
def _prepare_draft_tokens(self, scheduled_requests: ScheduledRequests):
try:
draft_batch, num_rejected_tokens = self._prepare_draft_batch(
scheduled_requests)

if draft_batch.batch_size == 0:
return

req_id_to_old_request = {
req.py_request_id: req
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests)
}

spec_metadata = self.model_engine.last_spec_metadata

hidden_states = spec_metadata.get_hidden_states(
draft_batch, num_rejected_tokens)

extra_model_inputs = {'hidden_states': hidden_states}

if spec_metadata.spec_dec_mode.is_eagle3():
# Another eagle3 hack. Eagle3 checkpoints don't have embed_tokens,
# so we need to provide them some other way. We can get rid of this
# hack if we provide our own preprocessed eagle3 checkpoints.
extra_model_inputs[
'embed_tokens'] = self.model_engine.model.model.embed_tokens

outputs = self.draft_model_engine.forward(
draft_batch,
self.resource_manager,
extra_model_inputs=extra_model_inputs)

if spec_metadata.spec_dec_mode.is_eagle3():
outputs['d2t'] = self.draft_model_engine.model.model.d2t.data

self._update_request_states(draft_batch)

self._decode(draft_batch, outputs)

def _process_decoded_tokens():
new_requests = []
for req in chain(draft_batch.context_requests,
draft_batch.generation_requests):
target_model_req = req_id_to_old_request[req.py_request_id]
target_model_req.py_draft_tokens.append(
req.get_last_tokens(0))
if req.state != LlmRequestState.GENERATION_COMPLETE and len(
target_model_req.py_draft_tokens
) < target_model_req.py_draft_pages_allocated:
new_requests.append(req)

return new_requests

new_requests = _process_decoded_tokens()
if not new_requests:
return

draft_batch.generation_requests = new_requests
draft_batch.context_requests = []

for _ in range(spec_metadata.max_draft_tokens - 1):
draft_spec_metadata = self.draft_model_engine.spec_metadata
hidden_states = draft_spec_metadata.get_hidden_states(
draft_batch)
extra_model_inputs = {'hidden_states': hidden_states}
if spec_metadata.spec_dec_mode.is_eagle3():
# See note above.
extra_model_inputs[
'embed_tokens'] = self.model_engine.model.model.embed_tokens

outputs = self.draft_model_engine.forward(
draft_batch,
self.resource_manager,
extra_model_inputs=extra_model_inputs)

if spec_metadata.spec_dec_mode.is_eagle3():
outputs[
'd2t'] = self.draft_model_engine.model.model.d2t.data
self._update_request_states(draft_batch)
self._decode(draft_batch, outputs)

new_requests = _process_decoded_tokens()
if not new_requests:
return
draft_batch.generation_requests = new_requests

except Exception as e:
traceback.print_exc()
error_msg = str(e)
logger.error(f"Encountered an error in decode: {error_msg}")
self._handle_errors(error_msg)

def _handle_errors(self, error_msg: Optional[str] = None):
error_responses = {}
error_msg = error_msg or "error"
@ -9,7 +9,7 @@ from tensorrt_llm.logger import logger
|
||||
from tensorrt_llm.mapping import Mapping
|
||||
|
||||
from ..attention_backend.interface import AttentionRuntimeFeatures
|
||||
from ..speculative import (get_num_spec_layers, get_spec_decoder,
|
||||
from ..speculative import (Eagle3Config, get_num_spec_layers, get_spec_decoder,
|
||||
get_spec_resource_manager)
|
||||
from ._util import check_flash_mla_config, estimate_max_kv_cache_tokens, is_mla
|
||||
from .config import PyTorchConfig
|
||||
@ -18,13 +18,76 @@ from .decoder import (EarlyStopDecoder, TorchDecoder, TorchStarAttentionDecoder,
|
||||
from .distributed import MPIDist
|
||||
from .guided_decoder import GuidedDecoderResourceManager
|
||||
from .kv_cache_transceiver import AttentionTypeCpp, create_kv_cache_transceiver
|
||||
from .model_engine import PyTorchModelEngine
|
||||
from .model_engine import (DRAFT_KV_CACHE_MANAGER_KEY, KV_CACHE_MANAGER_KEY,
|
||||
PyTorchModelEngine)
|
||||
from .py_executor import PyExecutor
|
||||
from .resource_manager import KVCacheManager, ResourceManager
|
||||
from .scheduler import (BindCapacityScheduler, BindMicroBatchScheduler,
|
||||
SimpleScheduler)
|
||||
|
||||
|
||||
def _create_kv_cache_manager(model_engine: PyTorchModelEngine, mapping: Mapping,
|
||||
executor_config: ExecutorConfig) -> KVCacheManager:
|
||||
|
||||
config = model_engine.model.model_config.pretrained_config
|
||||
quant_config = model_engine.model.model_config.quant_config
|
||||
spec_config = executor_config.speculative_config
|
||||
|
||||
hidden_size = config.hidden_size
|
||||
num_attention_heads = config.num_attention_heads
|
||||
num_key_value_heads = getattr(config, 'num_key_value_heads',
|
||||
num_attention_heads)
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
|
||||
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
|
||||
else:
|
||||
kv_cache_dtype = str_dtype_to_binding(
|
||||
torch_dtype_to_str(model_engine.dtype))
|
||||
|
||||
num_hidden_layers = len(mapping.pp_layers_torch(config.num_hidden_layers))
|
||||
# the number of layers using attention in Nemotron5 is lower than the number of hidden layers
|
||||
if config.architectures[0] == "Nemotron5ForCausalLM":
|
||||
# attention layers are derived from configuration (hybrid_override_pattern)
|
||||
num_hidden_layers = config.hybrid_override_pattern.count("*")
|
||||
|
||||
if is_mla(config):
|
||||
if spec_config is not None:
|
||||
num_hidden_layers += get_num_spec_layers(spec_config)
|
||||
|
||||
return KVCacheManager(
|
||||
executor_config.kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELFKONLY,
|
||||
num_layers=num_hidden_layers,
|
||||
num_kv_heads=1,
|
||||
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
tokens_per_block=executor_config.tokens_per_block,
|
||||
max_seq_len=executor_config.max_seq_len,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
num_extra_kv_tokens=0
|
||||
if spec_config is None else spec_config.num_extra_kv_tokens,
|
||||
)
|
||||
else:
|
||||
if spec_config is not None:
|
||||
num_hidden_layers += get_num_spec_layers(spec_config)
|
||||
return KVCacheManager(
|
||||
executor_config.kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=num_hidden_layers,
|
||||
num_kv_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
tokens_per_block=executor_config.tokens_per_block,
|
||||
max_seq_len=executor_config.max_seq_len,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
num_extra_kv_tokens=0
|
||||
if spec_config is None else spec_config.num_extra_kv_tokens,
|
||||
)
|
||||
|
||||
|
||||
def create_py_executor(executor_config: ExecutorConfig,
|
||||
checkpoint_dir: str = None,
|
||||
engine_dir: str = None):
|
||||
@ -32,7 +95,6 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
executor_config.pytorch_backend_config = PyTorchConfig()
|
||||
|
||||
pytorch_backend_config = executor_config.pytorch_backend_config
|
||||
spec_config = executor_config.speculative_config
|
||||
|
||||
if executor_config.mapping is None:
|
||||
mapping = Mapping(world_size=tensorrt_llm.mpi_world_size(),
|
||||
@ -65,9 +127,13 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
executor_config.max_num_tokens = 8192
|
||||
dist = MPIDist(mapping=mapping)
|
||||
|
||||
spec_config = executor_config.speculative_config
|
||||
has_draft_model_engine = isinstance(spec_config, Eagle3Config)
|
||||
|
||||
attn_runtime_features = AttentionRuntimeFeatures(
|
||||
chunked_prefill=executor_config.enable_chunked_context,
|
||||
cache_reuse=executor_config.kv_cache_config.enable_block_reuse,
|
||||
has_speculative_draft_tokens=has_draft_model_engine,
|
||||
)
|
||||
|
||||
model_engine = PyTorchModelEngine(
|
||||
@ -82,6 +148,23 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
spec_config=spec_config,
|
||||
guided_decoding_config=executor_config.guided_decoding_config,
|
||||
)
|
||||
|
||||
if has_draft_model_engine:
|
||||
draft_model_engine = PyTorchModelEngine(
|
||||
spec_config.eagle_weights_path,
|
||||
pytorch_backend_config,
|
||||
batch_size=executor_config.max_batch_size,
|
||||
max_num_tokens=executor_config.max_num_tokens,
|
||||
max_seq_len=executor_config.max_seq_len,
|
||||
mapping=mapping,
|
||||
attn_runtime_features=attn_runtime_features,
|
||||
dist=dist,
|
||||
spec_config=copy.copy(spec_config),
|
||||
)
|
||||
draft_model_engine.kv_cache_manager_key = DRAFT_KV_CACHE_MANAGER_KEY
|
||||
else:
|
||||
draft_model_engine = None
|
||||
|
||||
# PyTorchModelEngine modifies these fields, update them to executor_config
|
||||
if pytorch_backend_config.enable_overlap_scheduler:
|
||||
max_seq_len = model_engine.max_seq_len + 1
|
||||
@ -98,19 +181,6 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
#NOTE: non-generation models do not have kv cache
|
||||
executor_config.pytorch_backend_config.use_kv_cache = False
|
||||
|
||||
hidden_size = model_engine.model.config.hidden_size
|
||||
num_attention_heads = model_engine.model.config.num_attention_heads
|
||||
num_key_value_heads = getattr(model_engine.model.config,
|
||||
'num_key_value_heads', num_attention_heads)
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
|
||||
quant_config = model_engine.model.model_config.quant_config
|
||||
if quant_config is not None and quant_config.quant_mode.has_fp8_kv_cache():
|
||||
kv_cache_dtype = tensorrt_llm.bindings.DataType.FP8
|
||||
else:
|
||||
kv_cache_dtype = str_dtype_to_binding(
|
||||
torch_dtype_to_str(model_engine.dtype))
|
||||
|
||||
kv_cache_max_tokens = None
|
||||
if model_engine.model.model_config.is_generation:
|
||||
kv_cache_max_tokens = estimate_max_kv_cache_tokens(
|
||||
@ -131,70 +201,37 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
ctx_chunk_config = None
|
||||
|
||||
config = model_engine.model.model_config.pretrained_config
|
||||
# kv cache manager selection
|
||||
if is_mla(config):
|
||||
if check_flash_mla_config(config):
|
||||
executor_config.tokens_per_block = 64
|
||||
logger.info(
|
||||
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
|
||||
)
|
||||
executor_config.kv_cache_config.enable_block_reuse = False
|
||||
executor_config.enable_chunked_context = False
|
||||
|
||||
if executor_config.pytorch_backend_config.use_kv_cache:
|
||||
num_hidden_layers = len(
|
||||
mapping.pp_layers_torch(
|
||||
model_engine.model.config.num_hidden_layers))
|
||||
# has kv cache
|
||||
if is_mla(config):
|
||||
if check_flash_mla_config(config):
|
||||
executor_config.tokens_per_block = 64
|
||||
logger.info(
|
||||
f"Change tokens_per_block to: {executor_config.tokens_per_block} for using FlashMLA"
|
||||
)
|
||||
executor_config.kv_cache_config.enable_block_reuse = False
|
||||
executor_config.enable_chunked_context = False
|
||||
if spec_config is not None:
|
||||
num_hidden_layers += get_num_spec_layers(spec_config)
|
||||
kv_cache_manager = KVCacheManager(
|
||||
executor_config.kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.
|
||||
SELFKONLY,
|
||||
num_layers=num_hidden_layers,
|
||||
num_kv_heads=1,
|
||||
head_dim=config.kv_lora_rank + config.qk_rope_head_dim,
|
||||
tokens_per_block=executor_config.tokens_per_block,
|
||||
max_seq_len=executor_config.max_seq_len,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
num_extra_kv_tokens=0
|
||||
if spec_config is None else spec_config.num_extra_kv_tokens,
|
||||
)
|
||||
else:
|
||||
# the number of layers using attention in Nemotron5 is lower from the number of hidden layers
|
||||
if model_engine.model.config.architectures[
|
||||
0] == "Nemotron5ForCausalLM":
|
||||
# attention layers are derived from configuration (hybrid_override_pattern)
|
||||
num_hidden_layers = model_engine.model.config.hybrid_override_pattern.count(
|
||||
"*")
|
||||
kv_cache_manager = KVCacheManager(
|
||||
executor_config.kv_cache_config,
|
||||
tensorrt_llm.bindings.internal.batch_manager.CacheType.SELF,
|
||||
num_layers=num_hidden_layers,
|
||||
num_kv_heads=num_key_value_heads,
|
||||
head_dim=head_dim,
|
||||
tokens_per_block=executor_config.tokens_per_block,
|
||||
max_seq_len=executor_config.max_seq_len,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
mapping=mapping,
|
||||
dtype=kv_cache_dtype,
|
||||
num_extra_kv_tokens=0
|
||||
if spec_config is None else spec_config.num_extra_kv_tokens,
|
||||
)
|
||||
kv_cache_manager = _create_kv_cache_manager(model_engine, mapping,
|
||||
executor_config)
|
||||
|
||||
draft_kv_cache_manager = _create_kv_cache_manager(
|
||||
draft_model_engine, mapping,
|
||||
executor_config) if draft_model_engine is not None else None
|
||||
else:
|
||||
# no kv cache
|
||||
kv_cache_manager = None
|
||||
draft_kv_cache_manager = None
|
||||
|
||||
# KVCacheManager modifies these fields, update them to executor_config
|
||||
if kv_cache_manager is not None:
|
||||
executor_config.max_seq_len = kv_cache_manager.max_seq_len
|
||||
|
||||
resources = {
|
||||
"kv_cache_manager": kv_cache_manager
|
||||
KV_CACHE_MANAGER_KEY: kv_cache_manager
|
||||
} if kv_cache_manager is not None else {}
|
||||
|
||||
if draft_kv_cache_manager is not None:
|
||||
resources[DRAFT_KV_CACHE_MANAGER_KEY] = draft_kv_cache_manager
|
||||
|
||||
if spec_config is not None:
|
||||
spec_resource_manager = get_spec_resource_manager(
|
||||
spec_config, model_engine.model.config, model_engine.batch_size * 2)
|
||||
@ -225,10 +262,11 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
resources[key] = value
|
||||
|
||||
resource_manager = ResourceManager(resources)
|
||||
|
||||
# Make sure the kv cache manager is always invoked last as it could
|
||||
# depend on the results of other resource managers.
|
||||
if kv_cache_manager is not None:
|
||||
resource_manager.resource_managers.move_to_end("kv_cache_manager",
|
||||
resource_manager.resource_managers.move_to_end(KV_CACHE_MANAGER_KEY,
|
||||
last=True)
|
||||
|
||||
num_micro_batches = 1
|
||||
@ -276,5 +314,6 @@ def create_py_executor(executor_config: ExecutorConfig,
|
||||
enable_overlap_scheduler,
|
||||
max_batch_size=executor_config.max_batch_size,
|
||||
max_input_len=executor_config.max_input_len,
|
||||
kv_cache_transceiver=kv_cache_transceiver)
|
||||
kv_cache_transceiver=kv_cache_transceiver,
|
||||
draft_model_engine=draft_model_engine)
|
||||
return py_executor
|
||||
|
||||
@@ -210,6 +210,7 @@ class KVCacheManager(BaseResourceManager):
self.kv_cache_pool_mapping = self.impl.get_layer_to_pool_mapping()
self.num_pools = self.impl.num_pools
self.max_blocks_per_seq = self.impl.max_blocks_per_seq
self.enable_block_reuse = kv_cache_config.enable_block_reuse

def shutdown(self):
self.impl.release_pools()

@@ -1,3 +1,4 @@
from .eagle3 import Eagle3Config, Eagle3SpecMetadata
from .interface import SpecConfig, SpecMetadata
from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,
@@ -5,6 +6,7 @@ from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_metadata,

__all__ = [
"SpecConfig", "SpecMetadata", "MTPConfig", "MTPWorker", "MTPEagleWorker",
"MTPSpecMetadata", "get_spec_metadata", "get_spec_resource_manager",
"get_spec_decoder", "get_num_spec_layers"
"Eagle3Config", "Eagle3SpecMetadata", "MTPSpecMetadata",
"get_spec_metadata", "get_spec_resource_manager", "get_spec_decoder",
"get_num_spec_layers"
]

tensorrt_llm/_torch/speculative/eagle3.py (new file, 105 lines)
@@ -0,0 +1,105 @@
from dataclasses import dataclass, field
from itertools import chain
from typing import Dict, List, Optional, Tuple

import torch

from ..pyexecutor.decoder import TorchDecoder
from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode


@dataclass
class Eagle3Config(SpecConfig):
spec_dec_name: str = "EAGLE3"
eagle_weights_path: Optional[str] = None
num_layers: int = 0

def __post_init__(self):
if self.eagle_weights_path is None:
raise ValueError("Path to EAGLE3 weights must be specified.")

self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0

def update_from_model_config(self, model_config):
self.num_layers = model_config.num_hidden_layers

def get_draft_model_prompt(self,
input_tokens: torch.Tensor) -> torch.Tensor:
"""
Eagle3 always throws away the first token when processing draft inputs
"""
return input_tokens[1:]


@dataclass
class Eagle3SpecMetadata(SpecMetadata):
hidden_states: List[torch.Tensor] = field(default_factory=list)
num_layers: int = 0
layers_to_capture: Tuple[int, ...] = field(init=False)
target_model_embed_tokens: Optional[torch.nn.Module] = None

def __post_init__(self):
if self.num_layers == 1:
# For the draft model, we have to capture hiddens states
# manually outside of the decoder layer.
self.layers_to_capture = ()
else:
if self.num_layers <= 5:
raise ValueError("Not enough hidden layers for EAGLE")

self.layers_to_capture = (1, self.num_layers // 2 - 1,
self.num_layers - 3)

def prepare(self):
self.hidden_states = []

def maybe_capture_hidden_states(self, layer_id: int,
hidden_states: torch.Tensor,
residual: torch.Tensor) -> None:
if layer_id in self.layers_to_capture:
# TODO(miovine): write directly into a pre-allocated buffer for
# CUDA graph support.
self.hidden_states.append(hidden_states + residual)

def get_hidden_states(
self,
scheduled_requests,
num_rejected_tokens: Optional[Dict] = None) -> List[torch.Tensor]:
req_id_to_gather_ids = {}
seq_start = 0
for req_id, seqlen in zip(self.request_ids, self.seq_lens):
if num_rejected_tokens is not None:
if req_id in num_rejected_tokens:
req_id_to_gather_ids[req_id] = list(
range(seq_start,
seq_start + seqlen - num_rejected_tokens[req_id]))
else:
req_id_to_gather_ids[req_id] = [seq_start + seqlen - 1]

seq_start += seqlen

hidden_states_gather_ids = []
for req in chain(scheduled_requests.context_requests,
scheduled_requests.generation_requests):
hidden_states_gather_ids.extend(
req_id_to_gather_ids[req.py_request_id])

return [h[hidden_states_gather_ids] for h in self.hidden_states]


class Eagle3Decoder(TorchDecoder):

def _batch_decode(self, scheduled_requests, model_outputs):
logits = model_outputs["logits"]
new_tokens_device = torch.argmax(logits, dim=-1)
if "d2t" in model_outputs:
d2t = model_outputs["d2t"]
new_tokens_device = d2t[new_tokens_device] + new_tokens_device
new_tokens_host = new_tokens_device.to('cpu', non_blocking=True)
new_tensors_device = {"new_tokens_device": new_tokens_device}
new_tensors_host = {"new_tokens_host": new_tokens_host}
decoder_event = torch.cuda.Event()
decoder_event.record()
return new_tensors_device, new_tensors_host, decoder_event
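
Editor's note (not part of the commit): in Eagle3Decoder._batch_decode above, d2t appears to hold a per-draft-vocab-id offset into the target vocabulary, so a draft token id is remapped as target_id = draft_id + d2t[draft_id]. A tiny hypothetical illustration with made-up offsets:

import torch

d2t = torch.tensor([0, 7, 42])           # hypothetical offsets, one per draft-vocab id
draft_ids = torch.tensor([2, 0, 1])      # e.g. argmax over the draft model's logits
target_ids = d2t[draft_ids] + draft_ids  # same remapping as in _batch_decode above
# target_ids -> tensor([44, 0, 8])
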
@ -1,35 +1,28 @@
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from ..model_config import TConfig
|
||||
from ..pyexecutor.scheduler import ScheduledRequests
|
||||
|
||||
|
||||
class SpeculativeDecodingMode(IntEnum):
|
||||
MTP = auto()
|
||||
MTP_EAGLE = auto()
|
||||
MEDUSA = auto()
|
||||
EAGLE = auto()
|
||||
LOOKAHEAD = auto()
|
||||
EAGLE3 = auto()
|
||||
NONE = auto()
|
||||
|
||||
def is_mtp(self):
|
||||
return self == SpeculativeDecodingMode.MTP or SpeculativeDecodingMode.MTP_EAGLE
|
||||
return self == SpeculativeDecodingMode.MTP or self == SpeculativeDecodingMode.MTP_EAGLE
|
||||
|
||||
def is_mtp_eagle(self):
|
||||
return self == SpeculativeDecodingMode.MTP_EAGLE
|
||||
|
||||
def is_medusa(self):
|
||||
return self == SpeculativeDecodingMode.MEDUSA
|
||||
|
||||
def is_eagle(self):
|
||||
return self == SpeculativeDecodingMode.EAGLE
|
||||
|
||||
def is_lookahead(self):
|
||||
return self == SpeculativeDecodingMode.LOOKAHEAD
|
||||
def is_eagle3(self):
|
||||
return self == SpeculativeDecodingMode.EAGLE3
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeDecodingMode.NONE
|
||||
@ -38,22 +31,24 @@ class SpeculativeDecodingMode(IntEnum):
        return self.is_mtp()

    def needs_kv_cache_rewind(self):
        return self.is_mtp() or self.is_eagle() or self.is_lookahead(
        ) or self.is_medusa()
        return self.is_mtp()

    def support_overlap_scheduler(self):
        return self.is_mtp()

    def extend_ctx(self):
        """
        If true, treat generation requests with draft tokens as
        chunked context requests at the kernel level. Required for
        any spec dec mode that uses the SpecExecutor.
        """
        return self.is_eagle3()

    @staticmethod
    def from_string(name: str):
        name_map = {
            "MTP": SpeculativeDecodingMode.MTP,
            "MEDUSA": SpeculativeDecodingMode.MEDUSA,
            "EAGLE": SpeculativeDecodingMode.EAGLE,
            "LOOKAHEAD": SpeculativeDecodingMode.LOOKAHEAD,
            None: SpeculativeDecodingMode.NONE,
        }
        return name_map[name]
    def from_string(name: Optional[str]) -> "SpeculativeDecodingMode":
        if name is None:
            return SpeculativeDecodingMode.NONE
        return SpeculativeDecodingMode[name.upper()]


@dataclass
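For reference, the behavior of the rewritten from_string can be summarized with a few checks; the snippet below only exercises code visible in the hunk above and is not part of the commit:

# Illustrative usage of the new from_string.
assert SpeculativeDecodingMode.from_string(None) is SpeculativeDecodingMode.NONE
assert SpeculativeDecodingMode.from_string("eagle3") is SpeculativeDecodingMode.EAGLE3
assert SpeculativeDecodingMode.from_string("mtp").is_mtp()
# Unknown names raise KeyError; the rewrite also covers members the old
# name_map was missing, such as EAGLE3 and MTP_EAGLE.
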
@ -75,6 +70,14 @@ class SpecConfig:
    def update_from_model_config(self, model_config: TConfig):
        pass

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
        """
        Override for spec dec modes that need to preprocess prompt
        tokens before passing them to the draft model.
        """
        return input_tokens


@dataclass
class SpecMetadata:
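The get_draft_model_prompt hook above is where a mode-specific config can reshape the prompt seen by the draft model. A hypothetical override is sketched below; the class name and the drop-first-token behavior are illustrative assumptions, not code from this commit:

import torch

# Hypothetical subclass of the SpecConfig dataclass defined above.
class ShiftedPromptSpecConfig(SpecConfig):

    def get_draft_model_prompt(self,
                               input_tokens: torch.Tensor) -> torch.Tensor:
        # Assumption for illustration: feed the draft model the prompt shifted
        # by one token so it lines up with the hidden states it conditions on.
        return input_tokens[1:]
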
@ -98,6 +101,8 @@ class SpecMetadata:
    # The request ID of each sequence in the batch.
    # The shape is (batch_size).
    request_ids: Optional[List[int]] = None
    # Sequence length for each request.
    seq_lens: Optional[List[int]] = None
    # The gather ids for logits.
    gather_ids: Optional[torch.Tensor] = None
    # The number of tokens for speculative model/layer
@ -111,7 +116,8 @@ class SpecMetadata:
    # draft/target layers. But KVCacheManager can only support kv caches with the
    # same kv lengths for different layers. Add extra kv token in kv cache manager
    # to handle this issue.
    num_extra_kv_tokens: Optional[int] = 0
    # Number of layers in target model
    num_layers: int = 0

    def prepare():
        """
@ -130,3 +136,29 @@ class SpecMetadata:
        cuda_graph_metadata.max_num_requests = max_batch_size
        cuda_graph_metadata.__post_init__()
        return cuda_graph_metadata

    def maybe_capture_hidden_states(self, layer_id: int,
                                    hidden_states: torch.Tensor,
                                    residual: torch.Tensor) -> None:
        """
        Some spec decode algorithms require hidden states from the target
        model. Use this method to record them. By default, does nothing.
        """

    def get_hidden_states(
            self,
            scheduled_requests: ScheduledRequests,
            num_rejected_tokens: Optional[Dict] = None) -> List[torch.Tensor]:
        """
        Return any captured hidden states. Should do any necessary
        pre-processing.

        num_rejected_tokens is a dictionary mapping request IDs to the
        number of tokens rejected for that request. If a request ID isn't
        in the dictionary, it means that the request is not needed for drafting.

        If the dictionary is not given, this function assumes that the hidden
        states are being prepared for running the draft model autoregressively,
        and only the last hidden state vector for each sequence is returned.
        """
        return []

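For orientation, a minimal sketch of the call pattern these two defaults imply: the target model hands each layer's output to maybe_capture_hidden_states, and the drafting step later reads it back through get_hidden_states. The decoder_layers, hidden_states, residual, and spec_metadata names below are placeholders, not code from this commit:

# Illustrative call pattern only.
for layer_id, layer in enumerate(decoder_layers):
    hidden_states, residual = layer(hidden_states, residual)
    spec_metadata.maybe_capture_hidden_states(layer_id, hidden_states, residual)

# Later, when preparing draft-model inputs:
captured = spec_metadata.get_hidden_states(scheduled_requests)
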
@ -1,3 +1,4 @@
from .eagle3 import Eagle3Decoder, Eagle3SpecMetadata
from .mtp import MTPDecoder, MTPHiddenStatesManager, MTPSpecMetadata


@ -11,6 +12,11 @@ def get_spec_metadata(spec_config,
            mtp_num_modules=spec_config.num_nextn_predict_layers,
            max_num_requests=max_num_requests,
            mtp_hidden_states_manager=spec_resource_manager)
    elif spec_config.spec_dec_mode.is_eagle3():
        return Eagle3SpecMetadata(max_draft_tokens=spec_config.max_draft_tokens,
                                  spec_dec_mode=spec_config.spec_dec_mode,
                                  max_num_requests=max_num_requests,
                                  num_layers=spec_config.num_layers)
    else:
        return None

@ -29,6 +35,8 @@ def get_spec_resource_manager(spec_config, model_config, max_num_requests):
def get_spec_decoder(max_seq_len, spec_config):
    if spec_config.spec_dec_mode.is_mtp():
        return MTPDecoder(max_seq_len, spec_config)
    if spec_config.spec_dec_mode.is_eagle3():
        return Eagle3Decoder(max_seq_len)
    else:
        return None

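A short note on how these factories are dispatched: the executor selects a decoder and metadata once, based on spec_config.spec_dec_mode. Only get_spec_decoder's full signature is visible in this hunk, so the sketch below sticks to it; spec_config stands for any config object carrying a spec_dec_mode attribute:

# Illustrative dispatch through the factory shown above.
decoder = get_spec_decoder(max_seq_len=4096, spec_config=spec_config)
# MTP configs yield an MTPDecoder, EAGLE3 configs an Eagle3Decoder,
# and every other mode falls through to None.
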
@ -216,6 +216,7 @@ class EagleDecodingConfig(DecodingBaseConfig):
    dynamic_tree_max_topK: Optional[int] = None
    num_eagle_layers: Optional[int] = None
    max_non_leaves_per_layer: Optional[int] = None
    pytorch_eagle_weights_path: Optional[str] = None

    @classmethod
    def from_dict(cls, data: dict):
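The new pytorch_eagle_weights_path field is what points the PyTorch backend at an EAGLE-3 draft checkpoint. The unit test added later in this commit builds the config roughly as follows; the paths here are placeholders:

from tensorrt_llm._torch import LLM
from tensorrt_llm.llmapi import EagleDecodingConfig

spec_config = EagleDecodingConfig(
    max_draft_len=4,
    pytorch_eagle_weights_path="/path/to/EAGLE3-LLaMA3.1-Instruct-8B")
llm = LLM(model="/path/to/Llama-3.1-8B-Instruct",
          speculative_config=spec_config)
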
@ -1145,15 +1146,23 @@ class LlmArgs:

            self.build_config.max_draft_len = self.speculative_config.max_draft_len

            eagle_config = EagleConfig(
                self.speculative_config.eagle_choices,
                self.speculative_config.greedy_sampling,
                self.speculative_config.posterior_threshold,
                self.speculative_config.use_dynamic_tree,
                self.speculative_config.dynamic_tree_max_topK)
            self.decoding_config = DecodingConfig(
                decoding_mode=DecodingMode.Eagle(),
                eagle_config=eagle_config)
            if self.backend != 'pytorch':
                eagle_config = EagleConfig(
                    self.speculative_config.eagle_choices,
                    self.speculative_config.greedy_sampling,
                    self.speculative_config.posterior_threshold,
                    self.speculative_config.use_dynamic_tree,
                    self.speculative_config.dynamic_tree_max_topK)
                self.decoding_config = DecodingConfig(
                    decoding_mode=DecodingMode.Eagle(),
                    eagle_config=eagle_config)
            else:
                from tensorrt_llm._torch.speculative import Eagle3Config
                self.speculative_config = Eagle3Config(
                    max_draft_tokens=self.speculative_config.max_draft_len,
                    eagle_weights_path=self.speculative_config.
                    pytorch_eagle_weights_path)

        elif isinstance(self.speculative_config, MTPDecodingConfig):
            from tensorrt_llm._torch.speculative import MTPConfig
            self.speculative_config = MTPConfig(

@ -134,7 +134,8 @@ class OPTModel(Module):
                attention_params=None,
                prompt_embedding_table=None,
                prompt_tasks=None,
                prompt_vocab_size=None):
                prompt_vocab_size=None,
                **kwargs):

        args = [prompt_embedding_table, prompt_tasks, prompt_vocab_size
                ] if prompt_embedding_table is not None else []

@ -1440,13 +1440,37 @@ def test_ptq_quickstart_advanced_mtp(llm_root, llm_venv, model_name,
        str(example_root / "quickstart_advanced.py"),
        "--enable_overlap_scheduler",
        "--use_cuda_graph",
        "--mtp_nextn",
        "--spec_decode_nextn",
        "1",  # test 1 MTP module
        "--spec_decode_algo",
        "MTP",
        "--model_dir",
        f"{llm_models_root()}/{model_path}",
    ])


@pytest.mark.parametrize("model_name,model_path,eagle_model_path", [
    ("Llama-3.1-8b-Instruct", "llama-3.1-model/Llama-3.1-8B-Instruct",
     "EAGLE3-LLaMA3.1-Instruct-8B"),
])
def test_ptp_quickstart_advanced_eagle3(llm_root, llm_venv, model_name,
                                        model_path, eagle_model_path):
    print(f"Testing {model_name}.")
    example_root = Path(os.path.join(llm_root, "examples", "pytorch"))
    llm_venv.run_cmd([
        str(example_root / "quickstart_advanced.py"),
        "--spec_decode_nextn",
        "4",
        "--spec_decode_algo",
        "eagle3",
        "--model_dir",
        f"{llm_models_root()}/{model_path}",
        "--eagle_model_dir",
        f"{llm_models_root()}/{eagle_model_path}",
        "--kv_cache_enable_block_reuse",
    ])


@pytest.mark.skip_less_device_memory(110000)
@pytest.mark.skip_less_device(8)
@pytest.mark.parametrize("model_name,model_path", [

@ -14,7 +14,7 @@ l0_b200:
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-NVFP4-nvfp4-quantized/Meta-Llama-3.1-8B]
- test_e2e.py::test_ptp_quickstart_advanced[Llama3.1-8B-FP8-llama-3.1-model/Llama-3.1-8B-Instruct-FP8]
- test_e2e.py::test_ptq_quickstart_advanced_mtp[DeepSeek-V3-Lite-BF16-DeepSeek-V3-Lite/bf16]
- test_e2e.py::test_ptp_quickstart_advanced_mixed_precision
- test_e2e.py::test_ptp_quickstart_advanced_eagle3[Llama-3.1-8b-Instruct-llama-3.1-model/Llama-3.1-8B-Instruct-EAGLE3-LLaMA3.1-Instruct-8B]
- examples/test_pytorch.py::test_llm_llama_1gpu[llama-3.1-8b-enable_fp4]
- test_e2e.py::test_trtllm_bench_pytorch_backend_sanity[meta-llama/Llama-3.1-8B-llama-3.1-8b-False-False]
- unittest/_torch -k "not (modeling or multi_gpu or auto_deploy)"
@ -23,6 +23,7 @@ l0_b200:
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and nextn0"
- unittest/_torch/multi_gpu_modeling -k "deepseek and tp1 and not nextn0"
- unittest/_torch/auto_deploy/unit/singlegpu
- unittest/_torch/speculative/test_eagle3.py
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-enable_fp4]
- examples/test_pytorch.py::test_llm_deepseek_1gpu[deepseek-v3-lite-disable_fp8-disable_fp4]
# ------------- TRT tests ---------------

tests/unittest/_torch/speculative/test_eagle3.py (new file, 84 lines)
@ -0,0 +1,84 @@
import os
import sys
import unittest

from tensorrt_llm import SamplingParams
from tensorrt_llm._torch import LLM
from tensorrt_llm._torch.pyexecutor.config import PyTorchConfig
from tensorrt_llm.bindings.executor import KvCacheConfig
from tensorrt_llm.llmapi import EagleDecodingConfig

sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from utils.llm_data import llm_models_root


def test_llama_eagle3():
    models_path = llm_models_root()

    pytorch_config = PyTorchConfig(
        enable_overlap_scheduler=False,
        use_cuda_graph=False,
    )

    kv_cache_config = KvCacheConfig(enable_block_reuse=False, )

    eagle_model_dir = f"{models_path}/EAGLE3-LLaMA3.1-Instruct-8B"
    target_model_dir = f"{models_path}/llama-3.1-model/Llama-3.1-8B-Instruct"

    draft_len = 4
    spec_config = EagleDecodingConfig(
        max_draft_len=draft_len, pytorch_eagle_weights_path=eagle_model_dir)

    llm_spec = LLM(model=target_model_dir,
                   pytorch_backend_config=pytorch_config,
                   kv_cache_config=kv_cache_config,
                   speculative_config=spec_config)

    sampling_params = SamplingParams(
        max_tokens=32,
        temperature=0,
    )

    # First make sure the acceptance rate is reasonable.
    tok_ids = llm_spec.tokenizer.encode("The future of AI is")
    num_tokens = 0

    num_drafted = 0
    num_accepted = 0

    for output in llm_spec.generate_async(tok_ids,
                                          SamplingParams(max_tokens=128,
                                                         temperature=0),
                                          streaming=True):
        beam = output.outputs[0]
        new_tokens = beam.token_ids

        num_drafted += draft_len
        num_accepted += len(new_tokens) - num_tokens - 1

        num_tokens = len(new_tokens)

    accept_rate = num_accepted / num_drafted
    assert accept_rate > 0.25

    prompts = [
        "The capital of France is", "The president of the United States is"
    ]
    results_spec = llm_spec.generate(prompts, sampling_params)
    generated_text_spec = [result.outputs[0].text for result in results_spec]

    del llm_spec
    llm_ref = LLM(model=target_model_dir,
                  pytorch_backend_config=pytorch_config,
                  kv_cache_config=kv_cache_config)

    results_ref = llm_ref.generate(prompts, sampling_params)
    generated_text_ref = [result.outputs[0].text for result in results_ref]

    for text_spec, text_ref in zip(generated_text_spec, generated_text_ref):
        # The spec decode algorithm currently guarantees identical results
        assert text_spec == text_ref


if __name__ == "__main__":
    unittest.main()