[https://nvbugs/5753788][chore] fix empty tensor cutlass moe

Signed-off-by: leslie-fang25 <leslief@nvidia.com>
This commit is contained in:
leslie-fang25 2025-12-24 22:13:51 -08:00
parent 910a633066
commit 2ae080ca38
2 changed files with 29 additions and 5 deletions

View File

@ -146,7 +146,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
// The number of blocks for m. The m dimension will be padded to 128 for swizzled layout.
int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m;
dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM));
int gridSize = std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM);
// Ensure gridSize is not zero.
gridSize = std::max(1, gridSize);
dim3 grid(gridSize);
// Launch the cvt kernel.
auto* kernel_instance = useUE8M0
@ -165,7 +168,10 @@ void invokeFP4Quantization(int b, int m, int n, T const* input, float const* SFS
int const numBlocksPerSM = std::max(1u, 2048u / block.x);
// The number of blocks for m. The m dimension will be padded to 128 for swizzled layout.
int numBlocksForM = layout == QuantizationSFLayout::SWIZZLED ? PadUpFn(m, 128) : m;
dim3 grid(std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM));
int gridSize = std::min(numBlocksForM, multiProcessorCount * numBlocksPerSM);
// Ensure gridSize is not zero.
gridSize = std::max(1, gridSize);
dim3 grid(gridSize);
// Launch the cvt kernel.
auto* kernel_instance = useUE8M0

View File

@ -319,11 +319,17 @@ class CutlassFusedMoE(MoE):
x_row = x.shape[0]
else:
x_row = x.shape[0]
hidden_size = x.shape[-1]
x, x_sf = torch.ops.trtllm.fp4_quantize(
x, self.fc31_input_scale, self.scaling_vector_size,
False, False)
if x_sf.numel() == 0 and x_sf.dim() == 1:
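# An empty 1-D tensor (torch.Size([0])) cannot be viewed as (0, -1) because the
# inferred -1 dimension is ambiguous, so give the second dimension explicitly.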
x_sf = x_sf.view(
(0,
hidden_size // int(self.scaling_vector_size)))
# Reshape x_sf to 2D for post-quant communication
if x_sf is not None:
if x_sf is not None and x_sf.numel() != 0:
x_sf = x_sf.view((x_row, -1))
else:
if not isinstance(x, Fp4QuantizedTensor):
@ -494,8 +500,20 @@ class CutlassFusedMoE(MoE):
self._load_balancer_start_wait_gpu_stage(is_first_call)
# apply routing
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
if router_logits.numel() == 0:
# For dtype, refer to https://github.com/NVIDIA/TensorRT-LLM/blob/55f3cda66d05a2e5686c9c7512721beb522bc8b7/tensorrt_llm/_torch/modules/fused_moe/routing.py#L327
token_selected_experts = torch.empty(
(0, self.routing_method.experts_per_token),
dtype=torch.int32,
device=router_logits.device)
token_final_scales = torch.empty(
(0, self.routing_method.experts_per_token),
dtype=torch.float32,
device=router_logits.device)
else:
token_selected_experts, token_final_scales = self.routing_method.apply(
router_logits)
assert token_selected_experts.shape[
1] == self.routing_method.experts_per_token
assert token_selected_experts.shape == token_final_scales.shape