diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index 6f4f7f85b8..1d759ce2bb 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -2366,12 +2366,10 @@ if IS_CUTLASS_DSL_AVAILABLE: c = torch.empty((m, n_out), dtype=c_dtype, device=a.device) # Allocate output scale factor (for FP4 output quantization) - # Shape: (32, 4, m // 128, 4, scale_n_out // 4, l) + # Shape: (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l) scale_n_out = n_out // self.scaling_vector_size - c_sf_shape = (32, 4, m // 128, 4, scale_n_out // 4, l) - c_sf = torch.empty(c_sf_shape, - dtype=torch.float8_e4m3fn, - device=a.device) + c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l) + c_sf = torch.empty(c_sf_shape, dtype=torch.uint8, device=a.device) # Get CUDA stream torch_stream = torch.cuda.current_stream() @@ -2516,7 +2514,7 @@ if IS_CUTLASS_DSL_AVAILABLE: # Output scale factor shape scale_n_out = n_out // scaling_vector_size - c_sf_shape = (32, 4, m // 128, 4, scale_n_out // 4, l) - output_sf = input.new_empty(c_sf_shape, dtype=torch.float8_e4m3fn) + c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l) + output_sf = input.new_empty(c_sf_shape, dtype=torch.uint8) return output, output_sf diff --git a/tests/unittest/_torch/thop/parallel/test_cute_dsl_dense_gemm_swiglu.py b/tests/unittest/_torch/thop/parallel/test_cute_dsl_dense_gemm_swiglu.py index f6c3c286df..d9493563b9 100644 --- a/tests/unittest/_torch/thop/parallel/test_cute_dsl_dense_gemm_swiglu.py +++ b/tests/unittest/_torch/thop/parallel/test_cute_dsl_dense_gemm_swiglu.py @@ -10,6 +10,7 @@ import torch from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf from tensorrt_llm._utils import get_sm_version +from tensorrt_llm.math_utils import pad_up def swiglu_ref(x: torch.Tensor) -> torch.Tensor: @@ -79,7 +80,7 @@ def nvfp4_dense_gemm_ref( ) @pytest.mark.parametrize("num_expert", [1, 4, 8]) @pytest.mark.parametrize("weight_per_expert", [256, 512]) -@pytest.mark.parametrize("num_tokens", [128, 256]) +@pytest.mark.parametrize("num_tokens", [127, 256]) @pytest.mark.parametrize("hidden_size", [256, 512]) def test_nvfp4_dense_gemm_swiglu_blackwell( num_tokens: int, hidden_size: int, num_expert: int, weight_per_expert: int @@ -179,7 +180,27 @@ def test_nvfp4_dense_gemm_swiglu_blackwell( match_ratio = (c.view(torch.uint8) == c_ref_quantized.view(torch.uint8)).sum().item() / c.view( torch.uint8 ).numel() - - # Allow some tolerance due to different computation paths - print(f"Output match ratio: {match_ratio * 100:.2f}%") assert match_ratio > 0.95, f"Only {match_ratio * 100:.2f}% elements match, expected >= 95%" + + # Verify scale factor shape + scale_n_out = n_out // sf_vec_size + expected_c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, 1) + assert c_sf.shape == expected_c_sf_shape, ( + f"Expected c_sf shape {expected_c_sf_shape}, got {c_sf.shape}" + ) + assert c_sf.dtype == torch.uint8, f"Expected c_sf dtype uint8, got {c_sf.dtype}" + + # Verify scale factor values + # Unswizzle both c_sf and c_sf_ref for comparison (both are padded to 128) + c_sf_unswizzled = unswizzle_sf(c_sf.view(-1), pad_up(m, 128), n_out) + c_sf_ref_unswizzled = unswizzle_sf(c_sf_ref.view(-1), pad_up(m, 128), n_out) + + # Compare only the valid region (first m rows) + c_sf_valid = c_sf_unswizzled[:m, :] + c_sf_ref_valid = c_sf_ref_unswizzled[:m, :] + sf_match_ratio = ( + c_sf_valid.view(torch.uint8) == c_sf_ref_valid.view(torch.uint8) + ).sum().item() / c_sf_valid.view(torch.uint8).numel() + assert sf_match_ratio > 0.95, ( + f"Scale factor: only {sf_match_ratio * 100:.2f}% match, expected >= 95%" + )