[TRTLLM-6747][feat] Merge add sparse exp and shared exp into local reduction (#7369)

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Zongfei Jing 2025-09-01 09:20:00 +08:00 committed by GitHub
parent ec595a8e29
commit a7ed26dd8b
3 changed files with 27 additions and 8 deletions


@@ -599,6 +599,7 @@ class MnnvlMoe:
         top_k: int,
         token_count: int,
         use_low_precision_combine: bool = False,
+        do_reduce: bool = True,
     ):
         assert x.dim() == 2, "2D tensor supported, please reshape."
         output_tensors = torch.ops.trtllm.moe_comm(
@@ -614,7 +615,8 @@ class MnnvlMoe:
             [True],
             use_low_precision_combine,
         )
-        output_tensor = output_tensors[0]
-        return torch.sum(
-            output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
-        )
+        output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
+        if do_reduce:
+            return torch.sum(output_tensor, dim=1, keepdim=False)
+        else:
+            return output_tensor
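The new flag only changes what the combine returns: a minimal sketch of the two shapes, with hypothetical sizes and `combined` standing in for `output_tensors[0]`:

```python
import torch

token_count, top_k, hidden = 4, 2, 8
combined = torch.randn(token_count * top_k, hidden)  # stand-in for output_tensors[0]

# do_reduce=True (previous behavior): sum over top_k, return [token_count, hidden].
reduced = torch.sum(combined.reshape(token_count, top_k, hidden), dim=1, keepdim=False)

# do_reduce=False (new option): return the unreduced [token_count, top_k, hidden]
# view and let the caller fuse the reduction with the shared-expert addition.
unreduced = combined.reshape(token_count, top_k, hidden)

print(reduced.shape, unreduced.shape)  # torch.Size([4, 8]) torch.Size([4, 2, 8])
```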


@@ -128,6 +128,12 @@ def weight_dequant(x: torch.Tensor,
     return y
 
 
+@torch.compile(dynamic=True)
+def moe_reduce_add_shared_output(routed_output, shared_output):
+    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
+    return shared_output + routed_output
+
+
 class DeepseekV3MTPHead(nn.Module):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
@@ -585,6 +591,8 @@ class Deepseekv3MoE(nn.Module):
                 do_finalize)
             return routed_output
 
+        # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames
+
         routed_output, shared_output = maybe_execute_in_parallel(
             _compute_routed_output, _compute_shared_output,
             self.event_dict[EventType.Main],
@@ -593,9 +601,17 @@ class Deepseekv3MoE(nn.Module):
         if not do_finalize:
             return [shared_output, *routed_output]
         else:
-            assert shared_output.size() == routed_output.size(
-            ), f'unmatched tensor shape'
-            final_hidden_states = shared_output + routed_output
+            if routed_output.dim() == 3:
+                assert shared_output.numel(
+                ) * self.top_k == routed_output.numel(
+                ), 'unmatched tensor shape'
+                final_hidden_states = moe_reduce_add_shared_output(
+                    routed_output, shared_output)
+            else:
+                assert shared_output.size() == routed_output.size(
+                ), 'unmatched tensor shape'
+                final_hidden_states = shared_output + routed_output
+
         if not self.use_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
                 final_hidden_states,
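In the new 3-D branch, `moe_reduce_add_shared_output` keeps the top-k reduction and the shared-expert addition inside one `torch.compile` region so the two steps can be fused into a single kernel. A self-contained usage sketch (tensor sizes are illustrative only):

```python
import torch

@torch.compile(dynamic=True)
def moe_reduce_add_shared_output(routed_output, shared_output):
    # Sum the routed (sparse) expert outputs over the top_k dimension, then add
    # the shared-expert output; both steps live in one compiled graph.
    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
    return shared_output + routed_output

# routed: [tokens, top_k, hidden], shared: [tokens, hidden]
routed = torch.randn(4, 2, 8)
shared = torch.randn(4, 8)
fused = moe_reduce_add_shared_output(routed, shared)
assert torch.allclose(fused, shared + routed.sum(dim=1), atol=1e-6)
```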


@@ -906,7 +906,8 @@ class WideEPMoE(MoE):
                 ep_size=self.ep_size,
                 top_k=top_k,
                 token_count=token_count,
-                use_low_precision_combine=self.use_low_precision_combine)
+                use_low_precision_combine=self.use_low_precision_combine,
+                do_reduce=False)
 
         return final_hidden_states
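With `do_reduce=False`, WideEPMoE hands the unreduced routed output back to Deepseekv3MoE, which then merges the local reduction with the shared-expert addition. A rough equivalence sketch of the old versus the deferred path, using a hypothetical `combine` stand-in for the all-to-all combine's post-processing (the real op also performs communication):

```python
import torch

def combine(output_flat, token_count, top_k, hidden, do_reduce=True):
    # Stand-in for the combine's reshape/reduce step only.
    out = output_flat.reshape(token_count, top_k, hidden)
    return out.sum(dim=1) if do_reduce else out

token_count, top_k, hidden = 4, 2, 8
flat = torch.randn(token_count * top_k, hidden)
shared = torch.randn(token_count, hidden)

# Old path: reduce inside the combine, then add the shared-expert output.
old = shared + combine(flat, token_count, top_k, hidden, do_reduce=True)

# New path: defer the reduction and fold it into the final add in the caller.
new = shared + combine(flat, token_count, top_k, hidden, do_reduce=False).sum(dim=1)

assert torch.allclose(old, new)
```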