[#8694][fix] fix AutoDeploy CUDA memory access failure in nvidia/NVIDIA-Nemotron-Nano-31B-A3-v3 (#8696)
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
commit e051a05e6c
parent b37a8a9a74
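Summary of the change, as read from the diff below: the AutoDeploy Triton MoE paths previously converted selected_experts to int32 top-k IDs directly. Under expert-parallel (EP) sharding the router can emit negative IDs for experts owned by other ranks, which the fused kernel then uses as out-of-range offsets into the stacked [E, ...] weights. Both the unquantized and FP8 paths now clamp the IDs into [0, num_experts - 1] before the kernel call; the clamped slots are neutralized by their zero routing weights (see the sketch after the diff).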
@@ -517,7 +517,16 @@ def triton_fused_moe(
     """Triton unquantized MoE with 2-layer MLP and ReLU^2 activation."""
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    topk_ids = selected_experts.to(torch.int32).contiguous()
+
+    # Get number of local experts from weight shape
+    num_experts = w1_stacked_weight.shape[0]
+
+    # Clamp expert IDs to valid range to handle EP sharding
+    # After EP sharding, some expert IDs may be negative (for experts on other ranks)
+    # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
+    selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_experts - 1)
+
+    topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
     topk_weights = routing_weights.to(torch.float32).contiguous()
 
     out2d = _fused_moe_mlp_relu2(x2d, w1_stacked_weight, w2_stacked_weight, topk_ids, topk_weights)
@@ -565,7 +574,16 @@ def triton_quant_fp8_moe(
 
     x_shape = x.shape
     x2d = x.view(-1, x_shape[-1])
-    topk_ids = selected_experts.to(torch.int32).contiguous()
+
+    # Get number of local experts from weight shape
+    num_experts = w1_weight.shape[0]
+
+    # Clamp expert IDs to valid range to handle EP sharding
+    # After EP sharding, some expert IDs may be negative (for experts on other ranks)
+    # Clamp them to 0 (first expert) - these will be masked by routing_weights=0 anyway
+    selected_experts_clamped = torch.clamp(selected_experts, min=0, max=num_experts - 1)
+
+    topk_ids = selected_experts_clamped.to(torch.int32).contiguous()
     topk_weights = routing_weights.to(torch.float32).contiguous()
 
     # Weights are already stacked [E, ...] - just ensure contiguous and extract scales
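Why the clamp is safe: the diff comments note that after EP sharding, slots pointing at experts on other ranks already carry routing_weights=0, so redirecting their IDs to expert 0 changes nothing in the combined output. Below is a minimal eager-PyTorch sketch of that invariant; the tensor names and shapes are illustrative stand-ins, not the fused Triton kernel itself.

import torch

num_local_experts = 4
hidden = 8

# Router output after EP sharding: an expert owned by another rank shows up as a
# negative ID, and its routing weight has already been zeroed out.
selected_experts = torch.tensor([[2, -1], [0, 3]])
routing_weights = torch.tensor([[0.7, 0.0], [0.4, 0.6]])

# Inside the fused kernel the ID becomes a raw offset into the stacked [E, ...]
# weights, so a negative value reads out of bounds (the reported CUDA memory
# access failure). Clamping keeps every index inside [0, E - 1].
topk_ids = torch.clamp(selected_experts, min=0, max=num_local_experts - 1)

# Clamped entries now point at expert 0, but their contribution is multiplied
# by a routing weight of 0, so the combined MoE output is unchanged.
per_expert_out = torch.randn(num_local_experts, hidden)  # stand-in for expert MLP outputs
combined = (per_expert_out[topk_ids] * routing_weights.unsqueeze(-1)).sum(dim=1)
assert torch.allclose(combined[0], 0.7 * per_expert_out[2])

The same reasoning covers both hunks; the two functions differ only in the weight-stack argument names (w1_stacked_weight vs. w1_weight).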