[None][feat] Perfect routing for Deepseek models (#11127)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Author: Balaram Buddharaju
Date:   2026-01-30 20:46:35 -08:00 (committed by GitHub)
Parent: baf9f7b4dc
Commit: 531f85dc9b
2 changed files with 47 additions and 1 deletion

@@ -175,3 +175,4 @@ The perfect router logits are specifically designed for `RenormalizeMoeRoutingMethod`
 Currently supported:
 - GPT-OSS (uses `RenormalizeMoeRoutingMethod`)
+- DeepSeek-V3 / DeepSeek-R1 (uses `DeepSeekV3MoeRoutingMethod`)

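For reference, a minimal sketch of how to exercise this path. The only switch shown in the diff is the `ENABLE_PERFECT_ROUTER` environment variable, which the model code reads both at layer construction (to precompute logits) and on every forward pass:

```python
import os

# Sketch: enable the perfect router before building the model, so the
# precompute branch in Deepseekv3MoE.__init__ (see the diff below) runs.
# Testing-only: routing then ignores the learned gate entirely.
os.environ['ENABLE_PERFECT_ROUTER'] = '1'

# ... construct and run the DeepSeek-V3 / DeepSeek-R1 model as usual ...
```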
@@ -56,7 +56,13 @@ from ..modules.embedding import Embedding
 from ..modules.fused_moe import (DeepSeekV3MoeRoutingMethod, MoE,
                                  MoEWeightLoadingMode, create_moe)
 from ..modules.fused_moe.fused_moe_wide_ep import WideEPMoE
-from ..modules.fused_moe.routing import Deepseekv3RoutingImpl
+# isort: off
+from ..modules.fused_moe.routing import (Deepseekv3RoutingImpl,
+                                         get_cached_perfect_router_logits,
+                                         precompute_common_perfect_router_logits
+                                         )
+# isort: on
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -952,6 +958,18 @@ class Deepseekv3MoE(nn.Module):
             for key in [EventType.Main, EventType.MoeShared]
         }
+
+        # Store config values for perfect routing.
+        self.model_config = model_config
+        self.dtype = dtype
+
+        # Perfect router caching - precompute common logits if enabled.
+        if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1':
+            precompute_common_perfect_router_logits(
+                num_experts=num_experts,
+                experts_per_token=top_k,
+                moe_ep_size=model_config.mapping.moe_ep_size,
+                dtype=dtype)
 
     def _compute_shared_expert_tp_size(
             self, intermediate_size: int,
             block_size: int) -> tuple[int, float | None]:
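For intuition, here is a hedged sketch of what GPU-aware load-balanced logits could look like. The helper name `build_perfect_logits` and the round-robin-with-EP-stride scheme are assumptions for illustration, not the actual body of `precompute_common_perfect_router_logits`:

```python
import torch

def build_perfect_logits(num_tokens: int, num_experts: int,
                         experts_per_token: int, moe_ep_size: int,
                         device: torch.device,
                         dtype: torch.dtype) -> torch.Tensor:
    """Hypothetical construction of perfect-router logits.

    Token t picks experts (t + k * stride) % num_experts for
    k in [0, experts_per_token), with stride = experts per EP rank,
    so each token's picks land on distinct EP ranks and every expert
    receives the same number of tokens up to rounding.
    """
    stride = num_experts // moe_ep_size
    t = torch.arange(num_tokens, device=device).unsqueeze(1)         # [T, 1]
    k = torch.arange(experts_per_token, device=device).unsqueeze(0)  # [1, K]
    picked = (t + k * stride) % num_experts                          # [T, K]
    logits = torch.zeros(num_tokens, num_experts, device=device, dtype=dtype)
    logits.scatter_(1, picked, 1.0)  # chosen experts get the highest logit
    return logits
```

Top-k routing over these logits then recovers exactly the intended assignment, regardless of the routing method's score normalization.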
@@ -998,6 +1016,22 @@ class Deepseekv3MoE(nn.Module):
         return model_config.quant_config_dict.get(
             f"model.layers.{layer_idx}.mlp.experts", model_config.quant_config)
 
+    def _create_ideal_expert_load_balanced_logits(
+            self, num_tokens: int, num_experts: int,
+            device: torch.device) -> torch.Tensor:
+        """
+        Create ideal logits that produce GPU-aware load balanced expert assignment.
+        This method uses the global cache to access precomputed logits to optimize performance.
+        """
+        # Use global cached logits.
+        return get_cached_perfect_router_logits(
+            num_tokens=num_tokens,
+            num_experts=num_experts,
+            experts_per_token=self.top_k,
+            moe_ep_size=self.model_config.mapping.moe_ep_size,
+            device=device,
+            dtype=self.dtype)
+
     def compute_routed_output(self, hidden_states, hidden_states_fp4,
                               all_rank_num_tokens, do_finalize):
         # max-throughput
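Likewise, a sketch of the global cache that `get_cached_perfect_router_logits` suggests; the key layout and row-slicing behavior are assumptions, and `build_perfect_logits` is the hypothetical builder from the sketch above:

```python
import torch

# Hypothetical global cache: keyed on everything except num_tokens, so one
# sufficiently large tensor can serve any smaller batch by slicing rows.
_LOGITS_CACHE: dict = {}

def get_cached_logits(num_tokens: int, num_experts: int,
                      experts_per_token: int, moe_ep_size: int,
                      device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    key = (num_experts, experts_per_token, moe_ep_size, str(device), dtype)
    cached = _LOGITS_CACHE.get(key)
    if cached is None or cached.shape[0] < num_tokens:
        # Miss, or the cached tensor is too small: rebuild and keep it.
        cached = build_perfect_logits(num_tokens, num_experts,
                                      experts_per_token, moe_ep_size,
                                      device, dtype)
        _LOGITS_CACHE[key] = cached
    return cached[:num_tokens]
```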
@@ -1012,6 +1046,17 @@ class Deepseekv3MoE(nn.Module):
         router_logits = self.gate(hidden_states)
 
+        # Use ideal load balanced logits if enabled, otherwise use gate output.
+        if os.environ.get('ENABLE_PERFECT_ROUTER', '0') == '1':
+            # WARNING: This discards the learned gate output and uses ideal logits for perfect load balancing.
+            # Only use this for testing load balancing strategies, not for actual inference.
+            # The gate is still computed to maintain realistic performance measurement.
+            num_tokens, num_experts = router_logits.shape
+            router_logits = self._create_ideal_expert_load_balanced_logits(
+                num_tokens=num_tokens,
+                num_experts=num_experts,
+                device=hidden_states.device)
+
         routed_output = self.experts(
             hidden_states_fp4
             if hidden_states_fp4 is not None else hidden_states,
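Finally, a self-contained check that top-k routing over such ideal logits spreads load exactly evenly; the construction repeats the round-robin sketch above with illustrative numbers:

```python
import torch

# 64 tokens, 8 experts, top-2 routing, 4 EP ranks (2 experts per rank).
num_tokens, num_experts, top_k, ep_size = 64, 8, 2, 4
stride = num_experts // ep_size
t = torch.arange(num_tokens).unsqueeze(1)
k = torch.arange(top_k).unsqueeze(0)
logits = torch.zeros(num_tokens, num_experts)
logits.scatter_(1, (t + k * stride) % num_experts, 1.0)

experts = logits.topk(k=top_k, dim=-1).indices
print(torch.bincount(experts.flatten(), minlength=num_experts))
# tensor([16, 16, 16, 16, 16, 16, 16, 16]) -> every expert sees the same load
```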