[None][chore] Enable tvm_ffi for cute dsl nvfp4_gemm to reduce host overhead. (#9690)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Li Min 2025-12-08 13:28:11 +08:00 committed by GitHub
parent 2f526583fb
commit a422d70be6
4 changed files with 149 additions and 82 deletions


@ -73,3 +73,5 @@ nvidia-cutlass-dsl==4.3.1; python_version >= "3.10"
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
apache-tvm-ffi==0.1.4 # used to reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used to reduce nvidia-cutlass-dsl host overhead; optional package that improves torch tensor calling performance
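These two optional dependencies back the new use_tvm_ffi fast path. As a minimal sketch (not part of this commit, and assuming the package imports as tvm_ffi), a caller could gate the fast path on package availability and fall back to the default cute-pointer launch otherwise:

import importlib.util

def tvm_ffi_available() -> bool:
    # apache-tvm-ffi is optional; the module name "tvm_ffi" is an assumption here.
    return importlib.util.find_spec("tvm_ffi") is not None

# Hypothetical call site: only request the TVM-FFI launch path when the
# optional package is installed.
use_tvm_ffi = tvm_ffi_available()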


@ -240,7 +240,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
def __init__(self,
output_dtype: torch.dtype,
to_userbuffers: bool = False):
to_userbuffers: bool = False,
use_tvm_ffi: bool = True):
super().__init__()
if output_dtype != torch.bfloat16:
@ -249,17 +250,19 @@ if IS_CUTLASS_DSL_AVAILABLE:
)
self.output_dtype = output_dtype
self.to_userbuffers = to_userbuffers
self.use_tvm_ffi = use_tvm_ffi
def unique_id(self):
return (self.output_dtype, self.to_userbuffers)
return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)
def __hash__(self):
return hash((self.output_dtype, self.to_userbuffers))
return hash(
(self.output_dtype, self.to_userbuffers, self.use_tvm_ffi))
def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers
return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers and self.use_tvm_ffi == other.use_tvm_ffi
def get_valid_tactics(
self,
@ -464,51 +467,94 @@ if IS_CUTLASS_DSL_AVAILABLE:
f"CuteDSL: weight scale factor size mismatch. "
f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
if alpha_tensor.numel() != 1:
raise ValueError(f"CuteDSL: alpha size mismatch. "
f"Expected 1, got {alpha_tensor.numel()}")
# Reshape to CuteDSL's expected format (just a view, no copy)
a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)
a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
cutlass.BFloat16, 16)
# Create pointer to alpha on device
alpha_ptr = self.make_cute_dsl_global_pointer(
alpha_tensor, cutlass.Float32, 4)
if not self.use_tvm_ffi:
a_ptr = self.make_cute_dsl_global_pointer(
a_tensor, cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(
b_tensor, cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(
c_tensor, cutlass.BFloat16, 16)
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)
# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
use_prefetch)
if swap_ab:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
kernel_m = n
kernel_n = m
kernel_sf_m = sf_n
kernel_sf_n = sf_m
kernel_a_tensor = b_tensor
kernel_a_sf_tensor = b_sf_tensor
kernel_b_tensor = a_tensor
kernel_b_sf_tensor = a_sf_tensor
if not self.use_tvm_ffi:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
else:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr
kernel_m = m
kernel_n = n
kernel_sf_m = sf_m
kernel_sf_n = sf_n
kernel_a_tensor = a_tensor
kernel_a_sf_tensor = a_sf_tensor
kernel_b_tensor = b_tensor
kernel_b_sf_tensor = b_sf_tensor
if not self.use_tvm_ffi:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr
if cache_key not in self.__class__.kernel_cache:
if self.use_tvm_ffi:
a_ptr = self.make_cute_dsl_global_pointer(
a_tensor, cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(
b_tensor, cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(
c_tensor, cutlass.BFloat16, 16)
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)
# make a fake stream that uses the TVM-FFI environment stream
stream = cute.runtime.make_fake_stream(
use_tvm_ffi_env_stream=True)
if swap_ab:
kernel_a_ptr = b_ptr
kernel_a_sf_ptr = b_sf_ptr
kernel_b_ptr = a_ptr
kernel_b_sf_ptr = a_sf_ptr
else:
kernel_a_ptr = a_ptr
kernel_a_sf_ptr = a_sf_ptr
kernel_b_ptr = b_ptr
kernel_b_sf_ptr = b_sf_ptr
gemm = self.__class__.kernel_class(
sf_vec_size,
mma_tiler_mn,
@ -520,6 +566,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
# Note: when the tvm_ffi fake stream is used, at least one parameter should be a tensor type,
# so we pass alpha as a cute.Tensor in the jit func.
compiled_gemm = cute.compile(
gemm.wrapper,
kernel_m,
@ -528,17 +576,18 @@ if IS_CUTLASS_DSL_AVAILABLE:
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
1,
1, # batch
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_ptr, # Pass alpha as device pointer
alpha_cute_tensor,
max_active_clusters,
stream,
swap_ab,
options=f"--opt-level 2",
options=f"--opt-level 2 --enable-tvm-ffi"
if self.use_tvm_ffi else "--opt-level 2",
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
@ -546,21 +595,39 @@ if IS_CUTLASS_DSL_AVAILABLE:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_ptr, # Pass alpha as device pointer
stream,
)
if self.use_tvm_ffi:
# Call with torch data pointers; no stream needs to be passed.
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_tensor.data_ptr(),
kernel_b_tensor.data_ptr(),
kernel_a_sf_tensor.data_ptr(),
kernel_b_sf_tensor.data_ptr(),
c_tensor.data_ptr(),
alpha_tensor,
)
else:
# Call with cute pointer types; the torch stream must be passed explicitly.
compiled_gemm(
kernel_m,
kernel_n,
real_k,
kernel_sf_m // 128,
kernel_sf_n // 128,
sf_k // 4,
kernel_a_ptr,
kernel_b_ptr,
kernel_a_sf_ptr,
kernel_b_sf_ptr,
c_ptr,
alpha_cute_tensor,
stream,
)
if swap_ab:
c_tensor = c_tensor.permute(1, 0)
@ -578,6 +645,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
alpha: torch.Tensor,
output_dtype: torch.dtype,
to_userbuffers: bool = False,
use_tvm_ffi: bool = True,
) -> torch.Tensor:
"""CuteDSL-based NVFP4 GEMM optimized for Blackwell.
@ -589,6 +657,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
alpha: Scaling factor
output_dtype: Output data type (must be bfloat16)
to_userbuffers: Whether to allocate output from UserBuffers pool
use_tvm_ffi: Whether to use TVM-FFI to call the kernel. Enabling this option can help reduce the kernel's host launch overhead.
Note:
This function is primarily used internally by nvfp4_gemm.
@ -604,7 +673,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
tuner = AutoTuner.get()
runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers,
use_tvm_ffi)
inputs = [input, weight, input_scale, weight_scale, alpha]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_gemm_blackwell",
@ -625,6 +695,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
alpha: torch.Tensor, # Match custom op signature
output_dtype: torch.dtype,
to_userbuffers: bool = False,
use_tvm_ffi: bool = True,
):
# [m, k]
shape = list(mat_a.shape)
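For reference, a hedged end-to-end usage sketch of the custom op after this change (shapes, global scale factors, and the alpha value are illustrative; the quantize call follows the perf test below): inputs are NVFP4-quantized, alpha is a one-element float32 CUDA tensor, and the trailing flags select UserBuffers allocation and the TVM-FFI launch path.

import torch
import tensorrt_llm  # assumed to register the trtllm custom ops

m, n, k = 128, 7168, 16384
scaling_vector_size = 16
x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")
x_sf_global = torch.tensor(1.0, device="cuda")  # illustrative global scale
w_sf_global = torch.tensor(1.0, device="cuda")  # illustrative global scale
x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
    x, x_sf_global, scaling_vector_size, False)
w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(
    w, w_sf_global, scaling_vector_size, False)
# alpha is now passed as a 1-element float32 device tensor, not a Python float.
alpha_tensor = torch.tensor([1.0 / (w_sf_global * x_sf_global).item()],
                            dtype=torch.float32, device="cuda")
out = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
    x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, torch.bfloat16,
    False,  # to_userbuffers
    True)   # use_tvm_ffi (new in this commit)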


@ -2017,20 +2017,19 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
@cute.jit
def wrapper(
self,
m,
n,
k,
sf_m,
sf_n,
sf_k,
m: cutlass.Int32,
n: cutlass.Int32,
k: cutlass.Int32,
sf_m: cutlass.Int32,
sf_n: cutlass.Int32,
sf_k: cutlass.Int32,
l: cutlass.Constexpr,
a_ptr: cute.Pointer,
b_ptr: cute.Pointer,
a_sf_ptr: cute.Pointer,
b_sf_ptr: cute.Pointer,
c_ptr: cute.Pointer,
alpha: cute.
Pointer, # Device pointer to alpha, will be converted to Tensor
alpha_tensor: cute.Tensor,
max_active_clusters: cutlass.Constexpr,
current_stream: cuda.CUstream,
swap_ab: cutlass.Constexpr = False,
@ -2051,7 +2050,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A.
b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B.
c_ptr (cute.Pointer): Pointer to the C tensor.
alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally).
alpha_tensor (cute.Tensor): Device tensor holding the alpha scaling factor.
max_active_clusters (cutlass.Constexpr): Maximum number of active
clusters.
current_stream (cuda.CUstream): CUDA stream for the operation.
@ -2096,9 +2095,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
(32, 4, sf_n, 4, sf_k, l),
order=(2, 1, 4, 0, 3, 5),
))
alpha_tensor = cute.make_tensor(alpha,
layout=cute.make_ordered_layout(
(1, ), order=(0, )))
self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor,
max_active_clusters, current_stream, epilogue_op)
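Since the wrapper now receives alpha as a cute.Tensor rather than a raw device pointer (the TVM-FFI fake stream requires at least one tensor-typed argument), the host side wraps the torch tensor via DLPack before calling into the JIT'ed wrapper. A minimal sketch, with the import path and alpha value assumed:

import torch
import cutlass.cute as cute  # assumed import path for the CuTe DSL runtime

# alpha lives on device as a 1-element float32 tensor ...
alpha_torch = torch.tensor([0.05], dtype=torch.float32, device="cuda")
# ... and is handed to the wrapper as a cute.Tensor via DLPack, mirroring
# the cute.runtime.from_dlpack call used in the launcher above.
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_torch)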


@ -313,15 +313,17 @@ def nvfp4_gemm_perf_test(
x_sf_block_list = [x_sf_block]
w_sf_block_list = [w_sf_block]
alpha_tensor = torch.tensor([1.0]).cuda()
with torch.inference_mode(), autotune():
with nvtx.annotate(
f"cute_dsl tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}",
color="orange",
):
output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
x_fp4, w_fp4, x_sf_block, w_sf_block, 1.0, dtype)
x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype)
from tensorrt_llm._torch.autotuner import AutoTuner
AutoTuner.get().print_statistics()
alpha_tensor = torch.tensor(1.0).cuda()
if test_ref:
with nvtx.annotate(
f"ref tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}",
@ -342,7 +344,7 @@ def nvfp4_gemm_perf_test(
w_fp4_list[buffer_idx % workspace_count],
x_sf_block_list[buffer_idx % workspace_count],
w_sf_block_list[buffer_idx % workspace_count],
1.0,
alpha_tensor,
dtype,
)
buffer_idx = buffer_idx + 1
@ -356,7 +358,7 @@ def nvfp4_gemm_perf_test(
w_fp4_list[buffer_idx % workspace_count],
x_sf_block_list[buffer_idx % workspace_count],
w_sf_block_list[buffer_idx % workspace_count],
1.0,
alpha_tensor,
dtype,
)
buffer_idx = buffer_idx + 1
@ -457,7 +459,7 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
x, x_sf_global, scaling_vector_size, False)
alpha_ref = 1.0 / (w_sf_global * x_sf_global)
alpha_tensor = torch.tensor(alpha_ref, dtype=torch.float32).cuda()
alpha_tensor = torch.tensor([alpha_ref], dtype=torch.float32).cuda()
# Reference: Use CUTLASS backend explicitly for reference output
with torch.inference_mode():
@ -749,23 +751,19 @@ def test_fp4_linear_cuda_core(dtype, mnk):
if __name__ == "__main__":
# m, n, k
fp4_linear_perf_test(torch.bfloat16, 128, 7168, 16384)
fp4_linear_perf_test(torch.bfloat16, 128, 24576, 1536)
fp4_linear_perf_test(torch.bfloat16, 128, 2112, 7168)
fp4_linear_perf_test(torch.bfloat16, 128, 4096, 7168)
fp4_linear_perf_test(torch.bfloat16, 128, 7168, 2048)
nvfp4_gemm_perf_test(torch.bfloat16, 128, 7168, 16384)
# group-1 test cases
for tokens in [128, 8192]:
nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384)
nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536)
nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168)
nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168)
nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048)
# # group-1 test cases
# for tokens in [128, 8192]:
# nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384)
# nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536)
# nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168)
# nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168)
# nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048)
# group-2 test cases
for m in [128, 256, 512]:
nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168)
nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792)
nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False)
nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664)
# # group-2 test cases
# for m in [128, 256, 512]:
# nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168)
# nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792)
# nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False)
# nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664)
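One consistency check the test could add (a sketch, not part of this commit): launch the same GEMM with and without the TVM-FFI path and compare the outputs, since only the host-side calling convention differs.

import torch

def check_tvm_ffi_consistency(x_fp4, w_fp4, x_sf_block, w_sf_block,
                              alpha_tensor, dtype=torch.bfloat16):
    # to_userbuffers=False; toggle only use_tvm_ffi between the two calls.
    out_ffi = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
        x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype, False, True)
    out_ref = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
        x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype, False, False)
    torch.testing.assert_close(out_ffi, out_ref)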