mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-02-04 18:21:52 +08:00
[None][feat] Perfect routing for Deepseek models (#11127)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent
baf9f7b4dc
commit
531f85dc9b
@ -175,3 +175,4 @@ The perfect router logits are specifically designed for `RenormalizeMoeRoutingMe
|
||||
|
||||
Currently supported:
|
||||
- GPT-OSS (uses `RenormalizeMoeRoutingMethod`)
|
||||
- DeepSeek-V3 / DeepSeek-R1 (uses `DeepSeekV3MoeRoutingMethod`)
|
||||
|
||||
@ -56,7 +56,13 @@ from ..modules.embedding import Embedding
|
||||
from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
                                 MoEWeightLoadingMode, create_moe)
from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE

# isort: off
from ..modules.fused_moe.routing import (Deepseekv3RoutingImpl,
                                         get_cached_perfect_router_logits,
                                         precompute_common_perfect_router_logits)
# isort: on
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
|
||||
@ -952,6 +958,18 @@ class Deepseekv3MoE(nn.Module):
|
||||
for key in [EventType.Main, EventType.MoeShared]
|
||||
}
|
||||
|
||||
# Store config values for perfect routing.
|
||||
self.model_config = model_config
|
||||
self.dtype = dtype
|
||||
|
||||
# Perfect router caching - precompute common logits if enabled.
|
||||
if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1':
|
||||
precompute_common_perfect_router_logits(
|
||||
num_experts=num_experts,
|
||||
experts_per_token=top_k,
|
||||
moe_ep_size=model_config.mapping.moe_ep_size,
|
||||
dtype=dtype)
|
||||
|
||||
def _compute_shared_expert_tp_size(
|
||||
self, intermediate_size: int,
|
||||
block_size: int) -> tuple[int, float | None]:
|
||||
@ -998,6 +1016,22 @@ class Deepseekv3MoE(nn.Module):
|
||||
return model_config.quant_config_dict.get(
|
||||
f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
|
||||
|
||||
def _create_ideal_expert_load_balanced_logits(
        self, num_tokens: int, num_experts: int,
        device: torch.device) -> torch.Tensor:
    """Fetch synthetic router logits that force a load-balanced expert assignment.

    Callers substitute these logits for the learned gate output (testing-only
    "perfect router" mode), so the routing becomes EP-aware and perfectly
    balanced rather than learned. Logits come from a process-wide cache
    populated up front, so repeated (shape, device) requests are cheap.

    Args:
        num_tokens: Number of tokens being routed in this forward pass.
        num_experts: Total number of routed experts.
        device: Device the returned logits tensor must live on.

    Returns:
        A logits tensor in ``self.dtype`` — presumably shaped
        ``(num_tokens, num_experts)`` to match the gate output it replaces;
        confirm against ``get_cached_perfect_router_logits``.
    """
    ep_size = self.model_config.mapping.moe_ep_size
    # Cache lookup instead of rebuilding identical logits every call.
    return get_cached_perfect_router_logits(num_tokens=num_tokens,
                                            num_experts=num_experts,
                                            experts_per_token=self.top_k,
                                            moe_ep_size=ep_size,
                                            device=device,
                                            dtype=self.dtype)
|
||||
|
||||
def compute_routed_output(self, hidden_states, hidden_states_fp4,
|
||||
all_rank_num_tokens, do_finalize):
|
||||
# max-throughput
|
||||
@ -1012,6 +1046,17 @@ class Deepseekv3MoE(nn.Module):
|
||||
|
||||
router_logits = self.gate(hidden_states)
|
||||
|
||||
# Use ideal load balanced logits if enabled, otherwise use gate output.
|
||||
if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1':
|
||||
# WARNING: This discards the learned gate output and uses ideal logits for perfect load balancing.
|
||||
# Only use this for testing load balancing strategies, not for actual inference.
|
||||
# The gate is still computed to maintain realistic performance measurement.
|
||||
num_tokens, num_experts = router_logits.shape
|
||||
router_logits = self._create_ideal_expert_load_balanced_logits(
|
||||
num_tokens=num_tokens,
|
||||
num_experts=num_experts,
|
||||
device=hidden_states.device)
|
||||
|
||||
routed_output = self.experts(
|
||||
hidden_states_fp4
|
||||
if hidden_states_fp4 is not None else hidden_states,
|
||||
|
||||
Loading…
Reference in New Issue
Block a user