Fix dense GEMM integration and add scale factor validation

- Fix c_sf shape calculation: use pad_up(m, 128) // 128 for non-128-aligned m
- Change c_sf dtype to uint8 to match fp4_utils.py SF_DTYPE
- Add scale factor shape and value validation in unit test
- Fix test to handle padded scale factors correctly

Signed-off-by: Zongfei Jing <20381269+zongfeijing@users.noreply.github.com>
Zongfei Jing 2026-01-06 01:45:03 -08:00
parent 84dbc447f4
commit 49d887f521
2 changed files with 30 additions and 11 deletions
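
The shape fix in a nutshell: the old block count m // 128 truncates, so any m that is not a multiple of 128 under-allocates the scale factor buffer (for m = 127 it collapses to zero blocks). A minimal sketch of the arithmetic, using a local stand-in assumed to behave like tensorrt_llm.math_utils.pad_up (round up to the nearest multiple):

def pad_up(x: int, align: int) -> int:
    # Local stand-in; assumed equivalent to tensorrt_llm.math_utils.pad_up
    return ((x + align - 1) // align) * align

for m in (127, 128, 256):
    old_blocks = m // 128                # old formula: drops the partial 128-row tile
    new_blocks = pad_up(m, 128) // 128   # fixed formula: keeps the partial tile
    print(f"m={m}: old={old_blocks} block(s), new={new_blocks} block(s)")
# m=127: old=0 (empty c_sf buffer), new=1 (room for the padded 128 rows)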


@@ -2366,12 +2366,10 @@ if IS_CUTLASS_DSL_AVAILABLE:
 c = torch.empty((m, n_out), dtype=c_dtype, device=a.device)
 # Allocate output scale factor (for FP4 output quantization)
-# Shape: (32, 4, m // 128, 4, scale_n_out // 4, l)
+# Shape: (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l)
 scale_n_out = n_out // self.scaling_vector_size
-c_sf_shape = (32, 4, m // 128, 4, scale_n_out // 4, l)
-c_sf = torch.empty(c_sf_shape,
-                   dtype=torch.float8_e4m3fn,
-                   device=a.device)
+c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l)
+c_sf = torch.empty(c_sf_shape, dtype=torch.uint8, device=a.device)
 # Get CUDA stream
 torch_stream = torch.cuda.current_stream()
@@ -2516,7 +2514,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
 # Output scale factor shape
 scale_n_out = n_out // scaling_vector_size
-c_sf_shape = (32, 4, m // 128, 4, scale_n_out // 4, l)
-output_sf = input.new_empty(c_sf_shape, dtype=torch.float8_e4m3fn)
+c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l)
+output_sf = input.new_empty(c_sf_shape, dtype=torch.uint8)
 return output, output_sf
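
As a sanity check on the new allocation: the swizzled layout (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l) holds exactly one byte per padded row, per scaling block, per expert. The sketch below verifies that count and shows how the raw uint8 buffer can be reinterpreted as FP8 when the scale values themselves are needed; the concrete sizes are illustrative only, and reading the bytes as e4m3 is an assumption carried over from the previous float8_e4m3fn dtype.

import torch

def pad_up(x: int, align: int) -> int:
    # Assumed equivalent of tensorrt_llm.math_utils.pad_up
    return ((x + align - 1) // align) * align

m, n_out, scaling_vector_size, l = 127, 512, 16, 1   # illustrative sizes
scale_n_out = n_out // scaling_vector_size

c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, l)
c_sf = torch.empty(c_sf_shape, dtype=torch.uint8)

# One scale byte per padded row, per scaling block, per expert:
assert c_sf.numel() == pad_up(m, 128) * scale_n_out * l

# The buffer is carried as raw bytes; reinterpret without a copy to read scales:
c_sf_fp8 = c_sf.view(torch.float8_e4m3fn)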


@@ -10,6 +10,7 @@ import torch
 from tensorrt_llm._torch.modules.fused_moe.quantization import interleave_linear_and_gate
 from tensorrt_llm._torch.utils import swizzle_sf, unswizzle_sf
 from tensorrt_llm._utils import get_sm_version
+from tensorrt_llm.math_utils import pad_up
 def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
@@ -79,7 +80,7 @@ def nvfp4_dense_gemm_ref(
 )
 @pytest.mark.parametrize("num_expert", [1, 4, 8])
 @pytest.mark.parametrize("weight_per_expert", [256, 512])
-@pytest.mark.parametrize("num_tokens", [128, 256])
+@pytest.mark.parametrize("num_tokens", [127, 256])
 @pytest.mark.parametrize("hidden_size", [256, 512])
 def test_nvfp4_dense_gemm_swiglu_blackwell(
 num_tokens: int, hidden_size: int, num_expert: int, weight_per_expert: int
@@ -179,7 +180,27 @@ def test_nvfp4_dense_gemm_swiglu_blackwell(
 match_ratio = (c.view(torch.uint8) == c_ref_quantized.view(torch.uint8)).sum().item() / c.view(
 torch.uint8
 ).numel()
 # Allow some tolerance due to different computation paths
 print(f"Output match ratio: {match_ratio * 100:.2f}%")
 assert match_ratio > 0.95, f"Only {match_ratio * 100:.2f}% elements match, expected >= 95%"
+# Verify scale factor shape
+scale_n_out = n_out // sf_vec_size
+expected_c_sf_shape = (32, 4, pad_up(m, 128) // 128, 4, scale_n_out // 4, 1)
+assert c_sf.shape == expected_c_sf_shape, (
+    f"Expected c_sf shape {expected_c_sf_shape}, got {c_sf.shape}"
+)
+assert c_sf.dtype == torch.uint8, f"Expected c_sf dtype uint8, got {c_sf.dtype}"
+# Verify scale factor values
+# Unswizzle both c_sf and c_sf_ref for comparison (both are padded to 128)
+c_sf_unswizzled = unswizzle_sf(c_sf.view(-1), pad_up(m, 128), n_out)
+c_sf_ref_unswizzled = unswizzle_sf(c_sf_ref.view(-1), pad_up(m, 128), n_out)
+# Compare only the valid region (first m rows)
+c_sf_valid = c_sf_unswizzled[:m, :]
+c_sf_ref_valid = c_sf_ref_unswizzled[:m, :]
+sf_match_ratio = (
+    c_sf_valid.view(torch.uint8) == c_sf_ref_valid.view(torch.uint8)
+).sum().item() / c_sf_valid.view(torch.uint8).numel()
+assert sf_match_ratio > 0.95, (
+    f"Scale factor: only {sf_match_ratio * 100:.2f}% match, expected >= 95%"
+)
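
The new scale factor check could also be factored into a small reusable helper; a sketch under the same assumptions as the test above (unswizzle_sf called with the flat buffer, the padded row count, and n_out; both buffers padded to a multiple of 128 rows):

import torch
from tensorrt_llm._torch.utils import unswizzle_sf
from tensorrt_llm.math_utils import pad_up

def sf_match_ratio(sf: torch.Tensor, sf_ref: torch.Tensor, m: int, n_out: int) -> float:
    # Unswizzle both padded scale factor buffers, keep only the first m valid
    # rows, and return the fraction of matching bytes.
    rows = pad_up(m, 128)
    a = unswizzle_sf(sf.view(-1), rows, n_out)[:m, :].view(torch.uint8)
    b = unswizzle_sf(sf_ref.view(-1), rows, n_out)[:m, :].view(torch.uint8)
    return (a == b).sum().item() / a.numel()

# Usage mirrors the assertion above, e.g.:
# assert sf_match_ratio(c_sf, c_sf_ref, m, n_out) > 0.95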