Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][feat] Eagle: MLA Based Eagle (#9677)
Signed-off-by: Izzy Putterman <iputterman@nvidia.com>
This commit is contained in:
parent f3dd6da080
commit bdf6953ddc
@@ -1377,13 +1377,13 @@ class DeepseekV3DecoderLayer(DecoderLayer):
hidden_states, residual = self.moe_allreduce(
fc2_output, all_reduce_params=moe_all_reduce_params)
else:
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)
if spec_metadata is not None and spec_metadata.is_layer_capture(
self.layer_idx):
spec_metadata.maybe_capture_hidden_states(
self.layer_idx, hidden_states, None)
self.layer_idx, hidden_states, residual)
if self.next_layer_layernorm is not None:
hidden_states, residual = self.next_layer_layernorm(
hidden_states, residual)

return hidden_states, residual

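Not part of the diff: a toy sketch of the capture interface exercised above, only to show the call shape. The real SpecMetadata.maybe_capture_hidden_states in TensorRT-LLM has its own storage and layout logic; the point of the change in this hunk is simply that the pending residual is now forwarded instead of None, so the capture path has it available. Folding the residual into the captured tensor below is an assumption made for illustration.

    # Illustrative stand-in, not TensorRT-LLM's implementation.
    from typing import Optional
    import torch

    class ToySpecMetadata:
        def __init__(self, layers_to_capture=(0,)):
            self.layers_to_capture = set(layers_to_capture)
            self.captured = {}

        def is_layer_capture(self, layer_idx: int) -> bool:
            return layer_idx in self.layers_to_capture

        def maybe_capture_hidden_states(self, layer_idx: int,
                                        hidden_states: torch.Tensor,
                                        residual: Optional[torch.Tensor]) -> None:
            # Record what the layer hands over; when a residual is still
            # pending it can be folded in here (assumption for illustration).
            if residual is not None:
                hidden_states = hidden_states + residual
            self.captured[layer_idx] = hidden_states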
@@ -8,7 +8,7 @@ from ...functional import PositionEmbeddingType
from ..attention_backend import AttentionMetadata
from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
from ..model_config import ModelConfig, TConfig
from ..modules.attention import Attention
from ..modules.attention import MLA, Attention
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import moe_load_balancer_set_repeated_for_next_layer
@@ -76,14 +76,81 @@ class Eagle3Attention(Attention):
)


class Eagle3DecoderLayer(DecoderLayer):
class Eagle3MLAttention(MLA):
"""
MLA (Multi-head Latent Attention) for Eagle3 draft model (e.g., DeepSeekV3).
The first layer takes concatenated [embeds, hidden_states] as input (2x hidden_size),
while subsequent layers take regular hidden_states (1x hidden_size).
"""

def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
layer_idx: Optional[int] = None,
aux_stream: Optional[torch.cuda.Stream] = None,
next_layer_regular: bool = False,
):
config = model_config.pretrained_config
self._next_layer_regular = next_layer_regular

predicted_tokens_per_seq = (
model_config.spec_config.max_total_draft_tokens +
1 if model_config.spec_config is not None else 1)

super().__init__(
hidden_size=config.hidden_size,
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
qk_rope_head_dim=config.qk_rope_head_dim,
qk_nope_head_dim=config.qk_nope_head_dim,
q_lora_rank=config.q_lora_rank,
kv_lora_rank=config.kv_lora_rank,
v_head_dim=config.v_head_dim,
predicted_tokens_per_seq=predicted_tokens_per_seq,
max_position_embeddings=config.max_position_embeddings,
bias=False,
pos_embd_params=PositionalEmbeddingParams(
type=PositionEmbeddingType.yarn,
rope=RopeParams.from_config(config),
is_neox=False,
),
layer_idx=layer_idx,
dtype=config.torch_dtype,
config=model_config,
aux_stream=aux_stream,
)

# Override the kv_a_proj_with_mqa projection for first layer.
# The number of input features is twice as big for EAGLE3 draft models.
if not self._next_layer_regular:
quant_config = model_config.get_quant_config()
# For Eagle3, first layer takes [embeds, hidden_states] concatenated
self.kv_a_proj_with_mqa = Linear(
2 * config.hidden_size,  # Double input size
self.q_lora_rank + self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
dtype=config.torch_dtype,
quant_config=quant_config,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_custom_cublas_mm=True,
)


class Eagle3DecoderLayer(DecoderLayer):
"""
Unified decoder layer for Eagle3 speculative decoding.
Supports both standard attention (Llama-style) and MLA (DeepSeekV3-style).
"""

def __init__(
self,
model_config: ModelConfig[PretrainedConfig],
layer_idx: int = 0,
is_first_layer: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
use_mla: bool = False,
aux_stream: Optional[torch.cuda.Stream] = None,
) -> None:
super().__init__()
config = model_config.pretrained_config
eagle_config = config.eagle_config if hasattr(config,
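Not part of the diff: a minimal, self-contained PyTorch sketch (hypothetical sizes) of the shape logic described in the Eagle3MLAttention docstring above. The first draft layer consumes [embeds, hidden_states] concatenated along the feature dimension, so its low-rank input projection (kv_a_proj_with_mqa) takes 2 * hidden_size input features, while later ("regular") layers keep the usual 1 * hidden_size input.

    import torch
    import torch.nn as nn

    hidden_size, q_lora_rank, kv_lora_rank, qk_rope_head_dim = 2560, 1536, 512, 64
    num_tokens = 8

    # First draft layer: input is [embeds, target_hidden_states] concatenated -> 2x width.
    first_layer_in = nn.Linear(2 * hidden_size,
                               q_lora_rank + kv_lora_rank + qk_rope_head_dim,
                               bias=False)
    # Subsequent ("regular") layers keep the usual 1x width.
    regular_layer_in = nn.Linear(hidden_size,
                                 q_lora_rank + kv_lora_rank + qk_rope_head_dim,
                                 bias=False)

    embeds = torch.randn(num_tokens, hidden_size)
    target_hidden = torch.randn(num_tokens, hidden_size)

    x0 = torch.cat([embeds, target_hidden], dim=-1)  # (num_tokens, 2 * hidden_size)
    q_kv_pe = first_layer_in(x0)                     # (num_tokens, q_lora + kv_lora + rope)
    print(q_kv_pe.shape)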
@@ -92,8 +159,18 @@ class Eagle3DecoderLayer(DecoderLayer):
self._next_layer_regular = (eagle_config.get("next_layer_regular", True)
and not is_first_layer) or eagle_config.get(
"eh_proj_before_attn", False)
self.self_attn = Eagle3Attention(model_config, layer_idx,
self._next_layer_regular)

# Select attention type based on config
if use_mla:
self.self_attn = Eagle3MLAttention(
model_config,
layer_idx,
aux_stream=aux_stream,
next_layer_regular=self._next_layer_regular,
)
else:
self.self_attn = Eagle3Attention(model_config, layer_idx,
self._next_layer_regular)

if config.model_type == "llama4_text":
inter_size = config.intermediate_size_mlp
@@ -109,18 +186,25 @@ class Eagle3DecoderLayer(DecoderLayer):
overridden_tp_size=1
if model_config.mapping.enable_attention_dp else None,
)

if not self._next_layer_regular:
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.input_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

self.hidden_norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.hidden_norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.post_attention_layernorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

def forward(
self,
@@ -157,11 +241,16 @@ class Eagle3DecoderLayer(DecoderLayer):


class Eagle3DraftModel(DecoderModel):
"""
Unified Eagle3 draft model supporting both standard attention (Llama-style)
and MLA attention (DeepSeekV3-style).
"""

def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
use_mla: bool = False,
) -> None:
super().__init__(model_config)
@@ -175,6 +264,7 @@ class Eagle3DraftModel(DecoderModel):
self.num_layers = model_config.pretrained_config.num_hidden_layers
self._eh_proj_before_attn = eagle_config.get("eh_proj_before_attn",
False)
self._use_mla = use_mla

if hasattr(config, "target_hidden_size"):
self.hidden_size_in = config.target_hidden_size
@@ -184,40 +274,60 @@
self._return_hidden_post_norm = eagle_config.get(
"return_hidden_post_norm", False)

# Create auxiliary CUDA stream for MLA operations (only needed for MLA)
self.aux_stream = torch.cuda.Stream() if use_mla else None

if self.spec_config.num_capture_layers > 1:
self.fc = Linear(self.hidden_size_in *
self.spec_config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype)
self.fc = Linear(
self.hidden_size_in * self.spec_config.num_capture_layers,
config.hidden_size,
bias=getattr(config, "bias", False),
dtype=config.torch_dtype,
)

if self.num_layers > 1:
self.midlayer = nn.ModuleList([
Eagle3DecoderLayer(model_config,
start_layer_idx + i,
is_first_layer=(i == 0))
for i in range(self.num_layers)
Eagle3DecoderLayer(
model_config,
start_layer_idx + i,
is_first_layer=(i == 0),
use_mla=use_mla,
aux_stream=self.aux_stream,
) for i in range(self.num_layers)
])
else:
self.midlayer = Eagle3DecoderLayer(model_config, start_layer_idx)
self.midlayer = Eagle3DecoderLayer(
model_config,
start_layer_idx,
use_mla=use_mla,
aux_stream=self.aux_stream,
)

self.norm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.norm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)

if (config.draft_vocab_size is not None
and config.vocab_size != config.draft_vocab_size):
self.d2t = nn.Parameter(
torch.empty((config.draft_vocab_size, ), dtype=torch.int32),
requires_grad=False,
)

if config.draft_vocab_size is not None and config.vocab_size != config.draft_vocab_size:
self.d2t = nn.Parameter(torch.empty((config.draft_vocab_size, ),
dtype=torch.int32),
requires_grad=False)
if self._eh_proj_before_attn:
self.enorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.eh_proj = nn.Linear(config.hidden_size * 2,
config.hidden_size,
bias=eagle_config.get(
"eh_proj_bias", False),
dtype=config.torch_dtype)
self.enorm = RMSNorm(
hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
dtype=config.torch_dtype,
)
self.eh_proj = nn.Linear(
config.hidden_size * 2,
config.hidden_size,
bias=eagle_config.get("eh_proj_bias", False),
dtype=config.torch_dtype,
)

if self.hidden_size_in != config.hidden_size:
if model_config.mapping.enable_attention_dp:
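Not part of the diff: a small PyTorch sketch (illustrative dimensions) of the fc projection set up above. When the target model captures hidden states from several layers, they are concatenated feature-wise and projected back down to the draft model's hidden size; the specific layer choice and ordering are handled elsewhere in TensorRT-LLM, so this only shows the shape contract. The d2t buffer defined in the same hunk is, in common EAGLE3 implementations, an offset table mapping draft-vocab token ids back into the target vocabulary (roughly target_id = draft_id + d2t[draft_id]); that lookup happens outside this hunk and is noted here only as hedged background.

    import torch
    import torch.nn as nn

    hidden_size_in, hidden_size, num_capture_layers, num_tokens = 7168, 2560, 3, 4

    fc = nn.Linear(hidden_size_in * num_capture_layers, hidden_size, bias=False)

    # Hidden states captured from num_capture_layers target layers, one tensor each.
    captured = [torch.randn(num_tokens, hidden_size_in) for _ in range(num_capture_layers)]

    combined = torch.cat(captured, dim=-1)  # (num_tokens, hidden_size_in * num_capture_layers)
    draft_input = fc(combined)              # (num_tokens, hidden_size)
    print(draft_input.shape)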
@@ -261,7 +371,7 @@ class Eagle3DraftModel(DecoderModel):
assert hidden_states is not None
# NOTE: If hidden states from the target model have to be concatenated,
# ideally, we expect that to happen outside the model definition. This
# helps usavoid data-dependent control flow and gives us better CUDA
# helps us avoid data-dependent control flow and gives us better CUDA
# graph coverage.
if self._eh_proj_before_attn:
input_embeds = self.enorm(inputs_embeds)
@@ -273,17 +383,21 @@
for layer in self.midlayer:
if residual is not None:
hidden_states = hidden_states + residual
hidden_states, residual = layer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
hidden_states, residual = layer(
position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
)
else:
hidden_states, residual = self.midlayer(position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata)
hidden_states, residual = self.midlayer(
position_ids=position_ids,
embeds=inputs_embeds,
hidden_states=hidden_states,
attn_metadata=attn_metadata,
spec_metadata=spec_metadata,
)

hidden_states, hidden_states_to_save = self.norm(
hidden_states, residual)
@@ -294,20 +408,36 @@

# We use Llama3 as the base architecture for EAGLE3 draft layers
@register_auto_model("EAGLE3LlamaForCausalLM")
class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
@register_auto_model("Eagle3DeepSeekV3ForCausalLM")
class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel,
PretrainedConfig]):

def __init__(
self,
model_config: LlamaConfig,
model_config: ModelConfig[PretrainedConfig],
start_layer_idx: int = 0,
):
draft_vocab_size = model_config.pretrained_config.vocab_size
if model_config.pretrained_config.draft_vocab_size is not None:
draft_vocab_size = model_config.pretrained_config.draft_vocab_size
super().__init__(Eagle3DraftModel(model_config, start_layer_idx),
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=draft_vocab_size)

# Determine if we should use MLA attention based on config
# MLA is used for DeepSeekV3-style models that have kv_lora_rank
config = model_config.pretrained_config
self._use_mla = hasattr(config, 'kv_lora_rank') and config.kv_lora_rank

draft_model = Eagle3DraftModel(
model_config,
start_layer_idx,
use_mla=self._use_mla,
)

super().__init__(
draft_model,
config=model_config,
hidden_size=model_config.pretrained_config.hidden_size,
vocab_size=draft_vocab_size,
)
self.load_lm_head_from_target = True

def forward(
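Not part of the diff: the two comments above describe the dispatch rule, and this standalone sketch (with stand-in config objects) shows that check in isolation. Any pretrained config exposing a truthy kv_lora_rank (DeepSeekV3 / Kimi-K2 style) routes the draft model to MLA attention; otherwise it falls back to the standard Eagle3 attention.

    from types import SimpleNamespace

    def wants_mla(pretrained_config) -> bool:
        # Mirrors the check in the diff: MLA is selected when the config
        # carries a truthy kv_lora_rank (DeepSeekV3-style models).
        return bool(getattr(pretrained_config, "kv_lora_rank", None))

    deepseek_like = SimpleNamespace(kv_lora_rank=512, q_lora_rank=1536)
    llama_like = SimpleNamespace(hidden_size=4096)

    print(wants_mla(deepseek_like))  # True  -> Eagle3MLAttention
    print(wants_mla(llama_like))     # False -> Eagle3Attention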
@@ -339,6 +469,7 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
)

def load_weights(self, weights: Dict, weight_mapper: BaseWeightMapper):

new_weights = {}
for k, v in weights.items():
if 'lm_head' not in k:
@@ -347,13 +478,24 @@ class Eagle3ForCausalLM(DecoderModelForCausalLM[Eagle3DraftModel, LlamaConfig]):
self.load_lm_head_from_target = False
new_k = k
new_weights[new_k] = v
if self.load_lm_head_from_target:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper,
skip_modules=['lm_head'])

if self._use_mla:
# Use DeepseekV3WeightLoader for proper MLA weight handling
from .modeling_deepseekv3 import DeepseekV3WeightLoader
weight_loader = DeepseekV3WeightLoader(self, is_draft_model=False)
if self.load_lm_head_from_target:
weight_loader.load_weights(new_weights,
skip_modules=['lm_head'])
else:
weight_loader.load_weights(new_weights)
else:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper)
if self.load_lm_head_from_target:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper,
skip_modules=['lm_head'])
else:
super().load_weights(weights=new_weights,
weight_mapper=weight_mapper)

def load_weights_from_target_model(self,
target_model: torch.nn.Module) -> None:
@@ -704,6 +846,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
spec_dec_mode = model_config.spec_config.spec_dec_mode
if spec_dec_mode.is_eagle3_one_model():
if model_config.spec_config.eagle3_model_arch == "llama3":
# Eagle3ForCausalLM handles both Llama3 and DeepSeekV3 architectures
return Eagle3ForCausalLM(
draft_config, model_config.pretrained_config.num_hidden_layers)
elif model_config.spec_config.eagle3_model_arch == "mistral_large3":
@@ -714,6 +857,7 @@ def get_draft_model(model_config, draft_config, lm_head, model):
raise ValueError(
f"Unsupported eagle3 model architecture: {spec_dec_mode.eagle3_model_arch}"
)

elif spec_dec_mode.is_mtp_one_model():
return MTPForCausalLM(model_config,
model_config.pretrained_config.num_hidden_layers,
@@ -33,6 +33,7 @@ class LazyConfigDict(dict):

_CONFIG_REGISTRY: dict[str, type[transformers.PretrainedConfig]] = LazyConfigDict(
deepseek_v32="DeepseekV3Config",
kimi_k2="DeepseekV3Config",
)  # NOTE: HF config.json uses deepseek_v32 as model_type but with same DSV3 config class
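Not part of the diff: a hedged sketch of what registering kimi_k2 against DeepseekV3Config buys. The registry maps an HF model_type string to a transformers config class, so a Kimi-K2 checkpoint (and, per the note above, deepseek_v32) is parsed with the DeepSeek-V3 config. LazyConfigDict itself is not shown in this hunk; the lazy name resolution below is an assumption about its behaviour, and DeepseekV3Config availability depends on the installed transformers version.

    # Illustrative only: resolve a model_type string to a transformers config
    # class on first access, roughly what a lazy registry needs to do.
    import transformers

    class LazyConfigDictSketch(dict):
        def __getitem__(self, key):
            value = super().__getitem__(key)
            if isinstance(value, str):                # stored as a class name
                value = getattr(transformers, value)  # resolved lazily
                super().__setitem__(key, value)
            return value

    registry = LazyConfigDictSketch(
        deepseek_v32="DeepseekV3Config",
        kimi_k2="DeepseekV3Config",
    )
    # registry["kimi_k2"]  # -> transformers.DeepseekV3Config (if present in this version)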
@@ -133,7 +133,7 @@ class Eagle3SpecMetadata(SpecMetadata):
self.layers_to_capture = (self.num_layers - 1, )
elif self.layers_to_capture is None:
if self.num_layers == 1 or self.is_mtp_eagle:
self.layers_to_capture = (self.num_layers - 1, )
self.layers_to_capture = (-1, )
else:
if self.num_layers <= 5:
raise ValueError(
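Not part of the diff: the hunk above switches the single-layer / MTP-Eagle case from (self.num_layers - 1, ) to (-1, ). A brief hedged sketch of why a relative index is handy: -1 can be resolved against the target model's layer count at capture time, which keeps working when the draft config's num_layers does not line up with the target model's numbering. The exact resolution logic lives elsewhere in TensorRT-LLM; Python-style negative indexing is assumed here purely for illustration.

    def resolve_capture_layers(layers_to_capture, num_target_layers):
        # Assumption: negative entries index from the end of the target model,
        # like normal Python indexing.
        return tuple(idx if idx >= 0 else num_target_layers + idx
                     for idx in layers_to_capture)

    print(resolve_capture_layers((-1,), num_target_layers=61))        # (60,)
    print(resolve_capture_layers((1, 30, 59), num_target_layers=61))  # unchanged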
@@ -462,6 +462,114 @@ def test_deepseek_eagle3():
pass


def test_deepseek_mla_eagle3():
use_cuda_graph = True
attn_backend = "TRTLLM"
disable_overlap_scheduler = False
enable_block_reuse = False
use_one_model = True
enable_chunked_prefill = False

# Eagle3 one model works with overlap scheduler and block reuse.
total_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
if total_mem_gb < 150:
pytest.skip("Not enough memory to load target + draft model")

models_path = llm_models_root()
eagle_config = {
"architectures": ["Eagle3DeepseekV3ForCausalLM"],
"attention_bias": False,
"attention_dropout": 0.0,
"first_k_dense_replace": 1,
"hidden_act": "silu",
"hidden_size": 2560,
"intermediate_size": 8192,
"kv_lora_rank": 512,
"max_position_embeddings": 4096,
"model_type": "kimi_k2",
"num_attention_heads": 32,
"num_hidden_layers": 1,
"num_key_value_heads": 32,
"num_nextn_predict_layers": 0,
"q_lora_rank": 1536,
"qk_nope_head_dim": 128,
"qk_rope_head_dim": 64,
"rms_norm_eps": 1e-05,
"rope_scaling": {
"beta_fast": 1.0,
"beta_slow": 1.0,
"factor": 64.0,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn"
},
"rope_theta": 50000.0,
"routed_scaling_factor": 2.827,
"scoring_func": "sigmoid",
"seq_aux": True,
"topk_group": 1,
"topk_method": "noaux_tc",
"torch_dtype": "bfloat16",
"torchscript": False,
"transformers_version": "4.51.3",
"use_bfloat16": False,
"use_cache": True,
"v_head_dim": 128,
"vocab_size": 129280,
"draft_vocab_size": 129280,
"eagle_config": {
"use_aux_hidden_state": True,
"use_input_layernorm_in_first_layer": True,
"use_last_layernorm": True,
"use_mtp_layernorm": False
}
}
with tempfile.TemporaryDirectory() as temp_dir:
eagle_model_dir = Path(temp_dir)
config_path = eagle_model_dir / "config.json"
with config_path.open("w") as f:
json.dump(eagle_config, f, indent=2)
target_model_dir = f"{models_path}/DeepSeek-V3-Lite/nvfp4_moe_only"

# bs > 1 gives non-deterministic when doing IFB. There are slight chances
# that ref and spec does not match 100%
max_batch_size = 16
max_draft_len = 3
kv_cache_config = KvCacheConfig(enable_block_reuse=enable_block_reuse,
max_tokens=8192)
cuda_graph_config = CudaGraphConfig(
batch_sizes=[1]) if use_cuda_graph else None

llm_common_config = dict(
model=target_model_dir,
attn_backend=attn_backend,
disable_overlap_scheduler=disable_overlap_scheduler,
cuda_graph_config=cuda_graph_config,
max_batch_size=max_batch_size,
max_num_tokens=4096,
max_seq_len=4096,
kv_cache_config=kv_cache_config,
enable_chunked_prefill=enable_chunked_prefill,
load_format="dummy",
)

spec_config = EagleDecodingConfig(max_draft_len=max_draft_len,
speculative_model_dir=eagle_model_dir,
eagle3_one_model=use_one_model,
load_format="dummy")

llm_spec = LLM(**llm_common_config, speculative_config=spec_config)

tok_ids = llm_spec.tokenizer.encode("The future of AI is")

sampling_params = SamplingParams(max_tokens=32, temperature=0)
for output in llm_spec.generate_async(tok_ids,
sampling_params,
streaming=True):
pass


@pytest.mark.parametrize("use_one_model", [True, False])
def test_multi_eagle3(use_one_model: bool):
use_cuda_graph = True