mirror of https://github.com/NVIDIA/TensorRT-LLM.git
[TRTLLM-6747][feat] Merge add sparse exp and shared exp into local reduction (#7369)
Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
parent: ec595a8e29
commit: a7ed26dd8b
@@ -599,6 +599,7 @@ class MnnvlMoe:
         top_k: int,
         token_count: int,
         use_low_precision_combine: bool = False,
+        do_reduce: bool = True,
     ):
         assert x.dim() == 2, "2D tensor supported, please reshape."
         output_tensors = torch.ops.trtllm.moe_comm(
@@ -614,7 +615,8 @@ class MnnvlMoe:
             [True],
             use_low_precision_combine,
         )
-        output_tensor = output_tensors[0]
-        return torch.sum(
-            output_tensor.reshape(token_count, top_k, x.shape[1]), dim=1, keepdim=False
-        )
+        output_tensor = output_tensors[0].reshape(token_count, top_k, x.shape[1])
+        if do_reduce:
+            return torch.sum(output_tensor, dim=1, keepdim=False)
+        else:
+            return output_tensor
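For reference, a minimal sketch of what the new do_reduce flag changes for callers of the combine helper. The sizes and the combined tensor below are illustrative assumptions, and the sketch does not invoke torch.ops.trtllm.moe_comm; it only shows the reduction semantics.

import torch

token_count, top_k, hidden_size = 4, 2, 8                  # illustrative sizes, not from the diff
combined = torch.randn(token_count, top_k, hidden_size)    # stand-in for output_tensors[0].reshape(...)

# do_reduce=True (default, previous behavior): sum over the top_k dimension and
# return a (token_count, hidden_size) tensor.
reduced = torch.sum(combined, dim=1, keepdim=False)

# do_reduce=False (new): return the unreduced (token_count, top_k, hidden_size)
# tensor so the caller can fuse the reduction with the shared-expert add.
unreduced = combined

assert torch.allclose(reduced, unreduced.sum(dim=1))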
@@ -128,6 +128,12 @@ def weight_dequant(x: torch.Tensor,
     return y
 
 
+@torch.compile(dynamic=True)
+def moe_reduce_add_shared_output(routed_output, shared_output):
+    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
+    return shared_output + routed_output
+
+
 class DeepseekV3MTPHead(nn.Module):
 
     def __init__(self, model_config: ModelConfig[PretrainedConfig]):
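A short usage sketch of the compiled helper added above, with illustrative shapes (the sizes are assumptions): it is numerically equivalent to reducing over the top_k dimension and then adding the shared-expert output in eager mode.

import torch

# Definition from the hunk above, repeated so the sketch is self-contained.
@torch.compile(dynamic=True)
def moe_reduce_add_shared_output(routed_output, shared_output):
    routed_output = torch.sum(routed_output, dim=1, keepdim=False)
    return shared_output + routed_output

token_count, top_k, hidden_size = 4, 2, 8                      # illustrative sizes
routed_output = torch.randn(token_count, top_k, hidden_size)   # unreduced routed-expert outputs
shared_output = torch.randn(token_count, hidden_size)          # shared-expert output

fused = moe_reduce_add_shared_output(routed_output, shared_output)
eager = shared_output + routed_output.sum(dim=1)               # eager reference
assert torch.allclose(fused, eager)

# The fused path only applies to the unreduced 3-D routed output, which is what the
# numel check in the Deepseekv3MoE hunk below verifies:
assert shared_output.numel() * top_k == routed_output.numel()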
@@ -585,6 +591,8 @@ class Deepseekv3MoE(nn.Module):
                 do_finalize)
             return routed_output
 
+        # NOTE: define compiled helpers at module scope to avoid defining decorators inside compiled frames
+
         routed_output, shared_output = maybe_execute_in_parallel(
             _compute_routed_output, _compute_shared_output,
             self.event_dict[EventType.Main],
@@ -593,9 +601,17 @@ class Deepseekv3MoE(nn.Module):
         if not do_finalize:
             return [shared_output, *routed_output]
         else:
-            assert shared_output.size() == routed_output.size(
-            ), f'unmatched tensor shape'
-            final_hidden_states = shared_output + routed_output
+            if routed_output.dim() == 3:
+                assert shared_output.numel(
+                ) * self.top_k == routed_output.numel(
+                ), 'unmatched tensor shape'
+                final_hidden_states = moe_reduce_add_shared_output(
+                    routed_output, shared_output)
+            else:
+                assert shared_output.size() == routed_output.size(
+                ), 'unmatched tensor shape'
+                final_hidden_states = shared_output + routed_output
+
         if not self.use_dp and self.mapping.tp_size > 1:
             final_hidden_states = self.allreduce(
                 final_hidden_states,
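The NOTE in the first Deepseekv3MoE hunk refers to a torch.compile usage pattern; below is a minimal sketch under assumed names (_fused_reduce_add and forward_step are illustrative, not part of the diff).

import torch

# Pattern used by the commit: the decorated helper lives at module scope, so its
# torch.compile wrapper is created once at import time and reused on every call.
@torch.compile(dynamic=True)
def _fused_reduce_add(routed, shared):
    return shared + torch.sum(routed, dim=1, keepdim=False)

def forward_step(routed, shared):
    return _fused_reduce_add(routed, shared)

# What the NOTE avoids: applying the decorator inside a frame that may itself be
# compiled constructs a new wrapper object on every call.
def forward_step_discouraged(routed, shared):
    fused = torch.compile(dynamic=True)(lambda r, s: s + torch.sum(r, dim=1))
    return fused(routed, shared)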
@@ -906,7 +906,8 @@ class WideEPMoE(MoE):
                 ep_size=self.ep_size,
                 top_k=top_k,
                 token_count=token_count,
-                use_low_precision_combine=self.use_low_precision_combine)
+                use_low_precision_combine=self.use_low_precision_combine,
+                do_reduce=False)
 
         return final_hidden_states
 
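Taken together, the change moves the top-k reduction out of the combine call and into the shared-expert add. A minimal end-to-end sketch of that equivalence with dummy tensors (the sizes and the combined stand-in are assumptions; the real path goes through torch.ops.trtllm.moe_comm):

import torch

token_count, top_k, hidden_size = 4, 2, 8                  # illustrative sizes
combined = torch.randn(token_count, top_k, hidden_size)    # stand-in for the combine output
shared = torch.randn(token_count, hidden_size)             # shared-expert output

# Old path: reduce inside the combine helper, then add the shared output in the model.
old_result = shared + torch.sum(combined, dim=1, keepdim=False)

# New path: combine returns the unreduced tensor (do_reduce=False) and the model
# reduces and adds in one compiled helper (moe_reduce_add_shared_output).
new_result = shared + combined.sum(dim=1)

assert torch.allclose(old_result, new_result)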