diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index 6cacd5ffac..f475280f85 100755
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1377,13 +1377,13 @@ class DeepseekV3DecoderLayer(DecoderLayer):
             hidden_states, residual = self.moe_allreduce(
                 fc2_output, all_reduce_params=moe_all_reduce_params)
         else:
-            if self.next_layer_layernorm is not None:
-                hidden_states, residual = self.next_layer_layernorm(
-                    hidden_states, residual)
             if spec_metadata is not None and spec_metadata.is_layer_capture(
                     self.layer_idx):
                 spec_metadata.maybe_capture_hidden_states(
-                    self.layer_idx, hidden_states, None)
+                    self.layer_idx, hidden_states, residual)
+            if self.next_layer_layernorm is not None:
+                hidden_states, residual = self.next_layer_layernorm(
+                    hidden_states, residual)
 
         return hidden_states, residual
 
diff --git a/tensorrt_llm/_torch/models/modeling_speculative.py b/tensorrt_llm/_torch/models/modeling_speculative.py
index 582b3c40b3..dc4b3b1d54 100755
--- a/tensorrt_llm/_torch/models/modeling_speculative.py
+++ b/tensorrt_llm/_torch/models/modeling_speculative.py
@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from ..model_config import ModelConfig, TConfig
-from ..modules.attention import Attention
+from ..modules.attention import MLA, Attention
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer
@@ -76,14 +76,81 @@ class Eagle3Attention(Attention):
         )
 
 
-class Eagle3DecoderLayer(DecoderLayer):
+class Eagle3MLAttention(MLA):
+    """
+    MLA (Multi-head Latent Attention) for Eagle3 draft model (e.g., DeepSeekV3).
+    The first layer takes concatenated [embeds, hidden_states] as input (2x hidden_size),
+    while subsequent layers take regular hidden_states (1x hidden_size).
+    """
 
     def __init__(
         self,
-        model_config: LlamaConfig,
+        model_config: ModelConfig[PretrainedConfig],
+        layer_idx: Optional[int] = None,
+        aux_stream: Optional[torch.cuda.Stream] = None,
+        next_layer_regular: bool = False,
+    ):
+        config = model_config.pretrained_config
+        self._next_layer_regular = next_layer_regular
+
+        predicted_tokens_per_seq = (
+            model_config.spec_config.max_total_draft_tokens +
+            1 if model_config.spec_config is not None else 1)
+
+        super().__init__(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            num_key_value_heads=config.num_key_value_heads,
+            qk_rope_head_dim=config.qk_rope_head_dim,
+            qk_nope_head_dim=config.qk_nope_head_dim,
+            q_lora_rank=config.q_lora_rank,
+            kv_lora_rank=config.kv_lora_rank,
+            v_head_dim=config.v_head_dim,
+            predicted_tokens_per_seq=predicted_tokens_per_seq,
+            max_position_embeddings=config.max_position_embeddings,
+            bias=False,
+            pos_embd_params=PositionalEmbeddingParams(
+                type=PositionEmbeddingType.yarn,
+                rope=RopeParams.from_config(config),
+                is_neox=False,
+            ),
+            layer_idx=layer_idx,
+            dtype=config.torch_dtype,
+            config=model_config,
+            aux_stream=aux_stream,
+        )
+
+        # Override the kv_a_proj_with_mqa projection for first layer.
+        # The number of input features is twice as big for EAGLE3 draft models.
+        if not self._next_layer_regular:
+            quant_config = model_config.get_quant_config()
+            # For Eagle3, first layer takes [embeds, hidden_states] concatenated
+            self.kv_a_proj_with_mqa = Linear(
+                2 * config.hidden_size,  # Double input size
+                self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
+                bias=False,
+                dtype=config.torch_dtype,
+                quant_config=quant_config,
+                skip_create_weights_in_init=model_config.
+                skip_create_weights_in_init,
+                use_custom_cublas_mm=True,
+            )
+
+
+class Eagle3DecoderLayer(DecoderLayer):
+    """
+    Unified decoder layer for Eagle3 speculative decoding.
+    Supports both standard attention (Llama-style) and MLA (DeepSeekV3-style).
+    """
+
+    def __init__(
+        self,
+        model_config: ModelConfig[PretrainedConfig],
         layer_idx: int = 0,
         is_first_layer: bool = True,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        use_mla: bool = False,
+        aux_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
         super().__init__()
         config = model_config.pretrained_config
         eagle_config = config.eagle_config if hasattr(config,
@@ -92,8 +159,18 @@ class Eagle3DecoderLayer(DecoderLayer):
         self._next_layer_regular = (eagle_config.get("next_layer_regular",
             True) and not is_first_layer) or eagle_config.get(
                 "eh_proj_before_attn", False)
-        self.self_attn = Eagle3Attention(model_config, layer_idx,
-                                         self._next_layer_regular)
+
+        # Select attention type based on config
+        if use_mla:
+            self.self_attn = Eagle3MLAttention(
+                model_config,
+                layer_idx,
+                aux_stream=aux_stream,
+                next_layer_regular=self._next_layer_regular,
+            )
+        else:
+            self.self_attn = Eagle3Attention(model_config, layer_idx,
+                                             self._next_layer_regular)
 
         if config.model_type == "llama4_text":
             inter_size = config.intermediate_size_mlp
@@ -109,18 +186,25 @@ class Eagle3DecoderLayer(DecoderLayer):
             overridden_tp_size=1
             if model_config.mapping.enable_attention_dp else None,
         )
+
         if not self._next_layer_regular:
-            self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                           eps=config.rms_norm_eps,
-                                           dtype=config.torch_dtype)
+            self.input_layernorm = RMSNorm(
+                hidden_size=config.hidden_size,
+                eps=config.rms_norm_eps,
+                dtype=config.torch_dtype,
+            )
 
-            self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
-                                       eps=config.rms_norm_eps,
-                                       dtype=config.torch_dtype)
+            self.hidden_norm = RMSNorm(
+                hidden_size=config.hidden_size,
+                eps=config.rms_norm_eps,
+                dtype=config.torch_dtype,
+            )
 
-        self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
-                                                eps=config.rms_norm_eps,
-                                                dtype=config.torch_dtype)
+        self.post_attention_layernorm = RMSNorm(
+            hidden_size=config.hidden_size,
+            eps=config.rms_norm_eps,
+            dtype=config.torch_dtype,
+        )
 
     def forward(
         self,
@@ -157,11 +241,16 @@ class Eagle3DecoderLayer(DecoderLayer):
 
 
 class Eagle3DraftModel(DecoderModel):
+    """
+    Unified Eagle3 draft model supporting both standard attention (Llama-style)
+    and MLA attention (DeepSeekV3-style).
+ """ def __init__( self, - model_config: LlamaConfig, + model_config: ModelConfig[PretrainedConfig], start_layer_idx: int = 0, + use_mla: bool = False, ) -> None: super().__init__(model_config) @@ -175,6 +264,7 @@ class Eagle3DraftModel(DecoderModel): self.num_layers = model_config.pretrained_config.num_hidden_layers self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn", False) + self._use_mla = use_mla if hasattr(config, "target_hidden_size"): self.hidden_size_in = config.target_hidden_size @@ -184,40 +274,60 @@ class Eagle3DraftModel(DecoderModel): self._return_hidden_post_norm = eagle_config.get( "return_hidden_post_norm", False) + # Create auxiliary CUDA stream for MLA operations (only needed for MLA) + self.aux_stream = torch.cuda.Stream() if use_mla else None + if self.spec_config.num_capture_layers > 1: - self.fc = Linear(self.hidden_size_in * - self.spec_config.num_capture_layers, - config.hidden_size, - bias=getattr(config, "bias", False), - dtype=config.torch_dtype) + self.fc = Linear( + self.hidden_size_in * self.spec_config.num_capture_layers, + config.hidden_size, + bias=getattr(config, "bias", False), + dtype=config.torch_dtype, + ) if self.num_layers > 1: self.midlayer = nn.ModuleList([ - Eagle3DecoderLayer(model_config, - start_layer_idx + i, - is_first_layer=(i == 0)) - for i in range(self.num_layers) + Eagle3DecoderLayer( + model_config, + start_layer_idx + i, + is_first_layer=(i == 0), + use_mla=use_mla, + aux_stream=self.aux_stream, + ) for i in range(self.num_layers) ]) else: - self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx) + self.midlayer = Eagle3DecoderLayer( + model_config, + start_layer_idx, + use_mla=use_mla, + aux_stream=self.aux_stream, + ) - self.norm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) + self.norm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) + + if (config.draft_vocab_size is not None + and config.vocab_size != config.draft_vocab_size): + self.d2t = nn.Parameter( + torch.empty((config.draft_vocab_size, ), dtype=torch.int32), + requires_grad=False, + ) - if config.draft_vocab_size is not None and config.vocab_size != config.draft_vocab_size: - self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ), - dtype=torch.int32), - requires_grad=False) if self._eh_proj_before_attn: - self.enorm = RMSNorm(hidden_size=config.hidden_size, - eps=config.rms_norm_eps, - dtype=config.torch_dtype) - self.eh_proj = nn.Linear(config.hidden_size * 2, - config.hidden_size, - bias=eagle_config.get( - "eh_proj_bias", False), - dtype=config.torch_dtype) + self.enorm = RMSNorm( + hidden_size=config.hidden_size, + eps=config.rms_norm_eps, + dtype=config.torch_dtype, + ) + self.eh_proj = nn.Linear( + config.hidden_size * 2, + config.hidden_size, + bias=eagle_config.get("eh_proj_bias", False), + dtype=config.torch_dtype, + ) if self.hidden_size_in != config.hidden_size: if model_config.mapping.enable_attention_dp: @@ -261,7 +371,7 @@ class Eagle3DraftModel(DecoderModel): assert hidden_states is not None # NOTE: If hidden states from the target model have to be concatenated, # ideally, we expect that to happen outside the model definition. This - # helps usavoid data-dependent control flow and gives us better CUDA + # helps us avoid data-dependent control flow and gives us better CUDA # graph coverage. 
         if self._eh_proj_before_attn:
             input_embeds = self.enorm(inputs_embeds)
@@ -273,17 +383,21 @@ class Eagle3DraftModel(DecoderModel):
             for layer in self.midlayer:
                 if residual is not None:
                     hidden_states = hidden_states + residual
-                hidden_states, residual = layer(position_ids=position_ids,
-                                                embeds=inputs_embeds,
-                                                hidden_states=hidden_states,
-                                                attn_metadata=attn_metadata,
-                                                spec_metadata=spec_metadata)
+                hidden_states, residual = layer(
+                    position_ids=position_ids,
+                    embeds=inputs_embeds,
+                    hidden_states=hidden_states,
+                    attn_metadata=attn_metadata,
+                    spec_metadata=spec_metadata,
+                )
         else:
-            hidden_states, residual = self.midlayer(position_ids=position_ids,
-                                                    embeds=inputs_embeds,
-                                                    hidden_states=hidden_states,
-                                                    attn_metadata=attn_metadata,
-                                                    spec_metadata=spec_metadata)
+            hidden_states, residual = self.midlayer(
+                position_ids=position_ids,
+                embeds=inputs_embeds,
+                hidden_states=hidden_states,
+                attn_metadata=attn_metadata,
+                spec_metadata=spec_metadata,
+            )
 
         hidden_states, hidden_states_to_save = self.norm(
             hidden_states, residual)
@@ -294,20 +408,36 @@ class Eagle3DraftModel(DecoderModel):
 
 # We use Llama3 as the base architecture for EAGLE3 draft layers
 @register_auto_model("EAGLE3LlamaForCausalLM")
-class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
+@register_auto_model("Eagle3DeepSeekV3ForCausalLM")
+class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
+                                                PretrainedConfig]):
 
     def __init__(
         self,
-        model_config: LlamaConfig,
+        model_config: ModelConfig[PretrainedConfig],
        start_layer_idx: int = 0,
     ):
         draft_vocab_size = model_config.pretrained_config.vocab_size
         if model_config.pretrained_config.draft_vocab_size is not None:
             draft_vocab_size = model_config.pretrained_config.draft_vocab_size
-        super().__init__(Eagle3DraftModel(model_config, start_layer_idx),
-                         config=model_config,
-                         hidden_size=model_config.pretrained_config.hidden_size,
-                         vocab_size=draft_vocab_size)
+
+        # Determine if we should use MLA attention based on config
+        # MLA is used for DeepSeekV3-style models that have kv_lora_rank
+        config = model_config.pretrained_config
+        self._use_mla = hasattr(config, 'kv_lora_rank') and config.kv_lora_rank
+
+        draft_model = Eagle3DraftModel(
+            model_config,
+            start_layer_idx,
+            use_mla=self._use_mla,
+        )
+
+        super().__init__(
+            draft_model,
+            config=model_config,
+            hidden_size=model_config.pretrained_config.hidden_size,
+            vocab_size=draft_vocab_size,
+        )
         self.load_lm_head_from_target = True
 
     def forward(
@@ -339,6 +469,7 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
         )
 
     def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):
+
         new_weights = {}
         for k, v in weights.items():
             if 'lm_head' not in k:
@@ -347,13 +478,24 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
                     self.load_lm_head_from_target = False
                 new_k = k
             new_weights[new_k] = v
-        if self.load_lm_head_from_target:
-            super().load_weights(weights=new_weights,
-                                 weight_mapper=weight_mapper,
-                                 skip_modules=['lm_head'])
+
+        if self._use_mla:
+            # Use DeepseekV3WeightLoader for proper MLA weight handling
+            from .modeling_deepseekv3 import DeepseekV3WeightLoader
+            weight_loader = DeepseekV3WeightLoader(self, is_draft_model=False)
+            if self.load_lm_head_from_target:
+                weight_loader.load_weights(new_weights,
+                                           skip_modules=['lm_head'])
+            else:
+                weight_loader.load_weights(new_weights)
         else:
-            super().load_weights(weights=new_weights,
-                                 weight_mapper=weight_mapper)
+            if self.load_lm_head_from_target:
+                super().load_weights(weights=new_weights,
+                                     weight_mapper=weight_mapper,
+                                     skip_modules=['lm_head'])
+            else:
+                super().load_weights(weights=new_weights,
+                                     weight_mapper=weight_mapper)
 
     def load_weights_from_target_model(self,
                                        target_model: torch.nn.Module) -> None:
@@ -704,6 +846,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
     spec_dec_mode = model_config.spec_config.spec_dec_mode
     if spec_dec_mode.is_eagle3_one_model():
         if model_config.spec_config.eagle3_model_arch == "llama3":
+            # Eagle3ForCausalLM handles both Llama3 and DeepSeekV3 architectures
             return Eagle3ForCausalLM(
                 draft_config, model_config.pretrained_config.num_hidden_layers)
         elif model_config.spec_config.eagle3_model_arch == "mistral_large3":
@@ -714,6 +857,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
         raise ValueError(
             f"Unsupported eagle3 model architecture: {spec_dec_mode.eagle3_model_arch}"
         )
+
     elif spec_dec_mode.is_mtp_one_model():
         return MTPForCausalLM(model_config,
                               model_config.pretrained_config.num_hidden_layers,
diff --git a/tensorrt_llm/_torch/pyexecutor/config_utils.py b/tensorrt_llm/_torch/pyexecutor/config_utils.py
index e4fa9da6e6..01fda0c689 100644
--- a/tensorrt_llm/_torch/pyexecutor/config_utils.py
+++ b/tensorrt_llm/_torch/pyexecutor/config_utils.py
@@ -33,6 +33,7 @@ class LazyConfigDict(dict):
 
 _CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
     deepseek_v32="DeepseekV3Config",
+    kimi_k2="DeepseekV3Config",
 )
 
 # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index c3da1c7017..6fa7fba858 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -133,7 +133,7 @@ class Eagle3SpecMetadata(SpecMetadata):
             self.layers_to_capture = (self.num_layers - 1, )
         elif self.layers_to_capture is None:
             if self.num_layers == 1 or self.is_mtp_eagle:
-                self.layers_to_capture = (self.num_layers - 1, )
+                self.layers_to_capture = (-1, )
             else:
                 if self.num_layers <= 5:
                     raise ValueError(
diff --git a/tests/unittest/_torch/speculative/test_eagle3.py b/tests/unittest/_torch/speculative/test_eagle3.py
index d763bd50b4..e504c14e23 100644
--- a/tests/unittest/_torch/speculative/test_eagle3.py
+++ b/tests/unittest/_torch/speculative/test_eagle3.py
@@ -462,6 +462,114 @@ def test_deepseek_eagle3():
             pass
 
 
+def test_deepseek_mla_eagle3():
+    use_cuda_graph = True
+    attn_backend = "TRTLLM"
+    disable_overlap_scheduler = False
+    enable_block_reuse = False
+    use_one_model = True
+    enable_chunked_prefill = False
+
+    # Eagle3 one model works with overlap scheduler and block reuse.
+    total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
+    if total_mem_gb < 150:
+        pytest.skip("Not enough memory to load target + draft model")
+
+    models_path = llm_models_root()
+    eagle_config = {
+        "architectures": ["Eagle3DeepseekV3ForCausalLM"],
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "first_k_dense_replace": 1,
+        "hidden_act": "silu",
+        "hidden_size": 2560,
+        "intermediate_size": 8192,
+        "kv_lora_rank": 512,
+        "max_position_embeddings": 4096,
+        "model_type": "kimi_k2",
+        "num_attention_heads": 32,
+        "num_hidden_layers": 1,
+        "num_key_value_heads": 32,
+        "num_nextn_predict_layers": 0,
+        "q_lora_rank": 1536,
+        "qk_nope_head_dim": 128,
+        "qk_rope_head_dim": 64,
+        "rms_norm_eps": 1e-05,
+        "rope_scaling": {
+            "beta_fast": 1.0,
+            "beta_slow": 1.0,
+            "factor": 64.0,
+            "mscale": 1.0,
+            "mscale_all_dim": 1.0,
+            "original_max_position_embeddings": 4096,
+            "type": "yarn"
+        },
+        "rope_theta": 50000.0,
+        "routed_scaling_factor": 2.827,
+        "scoring_func": "sigmoid",
+        "seq_aux": True,
+        "topk_group": 1,
+        "topk_method": "noaux_tc",
+        "torch_dtype": "bfloat16",
+        "torchscript": False,
+        "transformers_version": "4.51.3",
+        "use_bfloat16": False,
+        "use_cache": True,
+        "v_head_dim": 128,
+        "vocab_size": 129280,
+        "draft_vocab_size": 129280,
+        "eagle_config": {
+            "use_aux_hidden_state": True,
+            "use_input_layernorm_in_first_layer": True,
+            "use_last_layernorm": True,
+            "use_mtp_layernorm": False
+        }
+    }
+    with tempfile.TemporaryDirectory() as temp_dir:
+        eagle_model_dir = Path(temp_dir)
+        config_path = eagle_model_dir / "config.json"
+        with config_path.open("w") as f:
+            json.dump(eagle_config, f, indent=2)
+        target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only"
+
+        # bs > 1 gives non-deterministic results when doing IFB. There is a
+        # slight chance that ref and spec do not match 100%.
+        max_batch_size = 16
+        max_draft_len = 3
+        kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
+                                        max_tokens=8192)
+        cuda_graph_config = CudaGraphConfig(
+            batch_sizes=[1]) if use_cuda_graph else None
+
+        llm_common_config = dict(
+            model=target_model_dir,
+            attn_backend=attn_backend,
+            disable_overlap_scheduler=disable_overlap_scheduler,
+            cuda_graph_config=cuda_graph_config,
+            max_batch_size=max_batch_size,
+            max_num_tokens=4096,
+            max_seq_len=4096,
+            kv_cache_config=kv_cache_config,
+            enable_chunked_prefill=enable_chunked_prefill,
+            load_format="dummy",
+        )
+
+        spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
+                                          speculative_model_dir=eagle_model_dir,
+                                          eagle3_one_model=use_one_model,
+                                          load_format="dummy")
+
+        llm_spec = LLM(**llm_common_config, speculative_config=spec_config)
+
+        tok_ids = llm_spec.tokenizer.encode("The future of AI is")
+
+        sampling_params = SamplingParams(max_tokens=32, temperature=0)
+        for output in llm_spec.generate_async(tok_ids,
+                                              sampling_params,
+                                              streaming=True):
+            pass
+
+
 @pytest.mark.parametrize("use_one_model", [True, False])
 def test_multi_eagle3(use_one_model: bool):
     use_cuda_graph = True