Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][chore] Enable tvm_ffi for cute dsl nvfp4_gemm to reduce host overhead. (#9690)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
This commit is contained in:
parent: 2f526583fb
commit: a422d70be6
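At the call site, the visible effect of this commit is that alpha is passed as a 1-element float32 device tensor and a use_tvm_ffi flag (default True) is threaded through the custom op. Below is a minimal sketch of how a caller exercises the op after this change; the inputs, global scales, and the scaling vector size of 16 are illustrative assumptions, while the op name, the alpha tensor, and the bfloat16 output dtype come from the diff that follows:

import torch

m, n, k = 128, 7168, 16384
dtype = torch.bfloat16

x = torch.randn(m, k, dtype=dtype, device="cuda")
w = torch.randn(n, k, dtype=dtype, device="cuda")

# Illustrative NVFP4 quantization; the global scales and block size are assumptions.
x_sf_global = torch.tensor(1.0, device="cuda")
w_sf_global = torch.tensor(1.0, device="cuda")
x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(x, x_sf_global, 16, False)
w_fp4, w_sf_block = torch.ops.trtllm.fp4_quantize(w, w_sf_global, 16, False)

# alpha is a 1-element float32 device tensor rather than a Python float.
alpha_tensor = torch.tensor([1.0 / (x_sf_global * w_sf_global).item()],
                            dtype=torch.float32, device="cuda")

# use_tvm_ffi defaults to True, so the TVM-FFI launch path is taken unless the
# caller opts out.
out = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
    x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype)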
@@ -73,3 +73,5 @@ nvidia-cutlass-dsl==4.3.1; python_version >= "3.10"
plotly
numexpr<2.14.0 # WAR for attempted use of nonexistent numpy.typing
partial_json_parser
apache-tvm-ffi==0.1.4 # used to reduce nvidia-cutlass-dsl host overhead
torch-c-dlpack-ext==0.1.3 # used to reduce nvidia-cutlass-dsl host overhead; optional package for improved torch tensor calling perf
@@ -240,7 +240,8 @@ if IS_CUTLASS_DSL_AVAILABLE:

    def __init__(self,
                 output_dtype: torch.dtype,
                 to_userbuffers: bool = False):
                 to_userbuffers: bool = False,
                 use_tvm_ffi: bool = True):
        super().__init__()

        if output_dtype != torch.bfloat16:
@@ -249,17 +250,19 @@ if IS_CUTLASS_DSL_AVAILABLE:
            )
        self.output_dtype = output_dtype
        self.to_userbuffers = to_userbuffers
        self.use_tvm_ffi = use_tvm_ffi

    def unique_id(self):
        return (self.output_dtype, self.to_userbuffers)
        return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)

    def __hash__(self):
        return hash((self.output_dtype, self.to_userbuffers))
        return hash(
            (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi))

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers
        return self.output_dtype == other.output_dtype and self.to_userbuffers == other.to_userbuffers and self.use_tvm_ffi == other.use_tvm_ffi

    def get_valid_tactics(
        self,
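The hunk above follows the autotuner's runner-identity convention: any option that changes the compiled kernel, here the new use_tvm_ffi flag, must be part of unique_id, __hash__, and __eq__, otherwise a tuning profile or cached kernel from one mode could be reused for the other. A standalone sketch of that pattern, with the class name invented for illustration:

import torch


class RunnerKey:
    """Illustrative stand-in for the runner identity consumed by the autotuner."""

    def __init__(self,
                 output_dtype: torch.dtype,
                 to_userbuffers: bool = False,
                 use_tvm_ffi: bool = True):
        self.output_dtype = output_dtype
        self.to_userbuffers = to_userbuffers
        self.use_tvm_ffi = use_tvm_ffi

    def unique_id(self):
        # Every behavior-affecting flag belongs in the identity tuple.
        return (self.output_dtype, self.to_userbuffers, self.use_tvm_ffi)

    def __hash__(self):
        return hash(self.unique_id())

    def __eq__(self, other):
        return (isinstance(other, self.__class__)
                and self.unique_id() == other.unique_id())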
@@ -464,51 +467,94 @@ if IS_CUTLASS_DSL_AVAILABLE:
                f"CuteDSL: weight scale factor size mismatch. "
                f"Expected {expected_b_sf_size} (sf_n={sf_n} * sf_k={sf_k}), "
                f"got {b_sf_tensor.numel()} for shape N={n}, K={real_k}")
        if alpha_tensor.numel() != 1:
            raise ValueError(f"CuteDSL: alpha size mismatch. "
                             f"Expected 1, got {alpha_tensor.numel()}")

        # Reshape to CuteDSL's expected format (just a view, no copy)
        a_sf_tensor = a_sf_tensor.reshape(sf_m * sf_k)
        b_sf_tensor = b_sf_tensor.reshape(sf_n * sf_k)

        a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                  cutlass.Float4E2M1FN, 32)
        b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
                                                  cutlass.Float4E2M1FN, 32)
        a_sf_ptr = self.make_cute_dsl_global_pointer(
            a_sf_tensor, cutlass.Float8E4M3FN, 16)
        b_sf_ptr = self.make_cute_dsl_global_pointer(
            b_sf_tensor, cutlass.Float8E4M3FN, 16)
        c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
                                                  cutlass.BFloat16, 16)
        # Create pointer to alpha on device
        alpha_ptr = self.make_cute_dsl_global_pointer(
            alpha_tensor, cutlass.Float32, 4)
        if not self.use_tvm_ffi:
            a_ptr = self.make_cute_dsl_global_pointer(
                a_tensor, cutlass.Float4E2M1FN, 32)
            b_ptr = self.make_cute_dsl_global_pointer(
                b_tensor, cutlass.Float4E2M1FN, 32)
            a_sf_ptr = self.make_cute_dsl_global_pointer(
                a_sf_tensor, cutlass.Float8E4M3FN, 16)
            b_sf_ptr = self.make_cute_dsl_global_pointer(
                b_sf_tensor, cutlass.Float8E4M3FN, 16)
            c_ptr = self.make_cute_dsl_global_pointer(
                c_tensor, cutlass.BFloat16, 16)
            alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)

        # get stream
        torch_stream = torch.cuda.current_stream()
        stream = cuda.CUstream(torch_stream.cuda_stream)
            # get stream
            torch_stream = torch.cuda.current_stream()
            stream = cuda.CUstream(torch_stream.cuda_stream)

        cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
                     use_prefetch)
        if swap_ab:
            kernel_a_ptr = b_ptr
            kernel_a_sf_ptr = b_sf_ptr
            kernel_b_ptr = a_ptr
            kernel_b_sf_ptr = a_sf_ptr
            kernel_m = n
            kernel_n = m
            kernel_sf_m = sf_n
            kernel_sf_n = sf_m

            kernel_a_tensor = b_tensor
            kernel_a_sf_tensor = b_sf_tensor
            kernel_b_tensor = a_tensor
            kernel_b_sf_tensor = a_sf_tensor

            if not self.use_tvm_ffi:
                kernel_a_ptr = b_ptr
                kernel_a_sf_ptr = b_sf_ptr
                kernel_b_ptr = a_ptr
                kernel_b_sf_ptr = a_sf_ptr
        else:
            kernel_a_ptr = a_ptr
            kernel_a_sf_ptr = a_sf_ptr
            kernel_b_ptr = b_ptr
            kernel_b_sf_ptr = b_sf_ptr
            kernel_m = m
            kernel_n = n
            kernel_sf_m = sf_m
            kernel_sf_n = sf_n

            kernel_a_tensor = a_tensor
            kernel_a_sf_tensor = a_sf_tensor
            kernel_b_tensor = b_tensor
            kernel_b_sf_tensor = b_sf_tensor

            if not self.use_tvm_ffi:
                kernel_a_ptr = a_ptr
                kernel_a_sf_ptr = a_sf_ptr
                kernel_b_ptr = b_ptr
                kernel_b_sf_ptr = b_sf_ptr

        if cache_key not in self.__class__.kernel_cache:
            if self.use_tvm_ffi:
                a_ptr = self.make_cute_dsl_global_pointer(
                    a_tensor, cutlass.Float4E2M1FN, 32)
                b_ptr = self.make_cute_dsl_global_pointer(
                    b_tensor, cutlass.Float4E2M1FN, 32)
                a_sf_ptr = self.make_cute_dsl_global_pointer(
                    a_sf_tensor, cutlass.Float8E4M3FN, 16)
                b_sf_ptr = self.make_cute_dsl_global_pointer(
                    b_sf_tensor, cutlass.Float8E4M3FN, 16)
                c_ptr = self.make_cute_dsl_global_pointer(
                    c_tensor, cutlass.BFloat16, 16)
                alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)
                # make a fake stream
                stream = cute.runtime.make_fake_stream(
                    use_tvm_ffi_env_stream=True)

                if swap_ab:
                    kernel_a_ptr = b_ptr
                    kernel_a_sf_ptr = b_sf_ptr
                    kernel_b_ptr = a_ptr
                    kernel_b_sf_ptr = a_sf_ptr
                else:
                    kernel_a_ptr = a_ptr
                    kernel_a_sf_ptr = a_sf_ptr
                    kernel_b_ptr = b_ptr
                    kernel_b_sf_ptr = b_sf_ptr

            gemm = self.__class__.kernel_class(
                sf_vec_size,
                mma_tiler_mn,
@@ -520,6 +566,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
            max_active_clusters = hardware_info.get_max_active_clusters(
                cluster_shape_mn[0] * cluster_shape_mn[1])

            # Note: when the tvm_ffi fake stream is used, at least one parameter should be a tensor type,
            # so we pass alpha as the cute.Tensor type in the jit func.
            compiled_gemm = cute.compile(
                gemm.wrapper,
                kernel_m,
@@ -528,17 +576,18 @@ if IS_CUTLASS_DSL_AVAILABLE:
                kernel_sf_m // 128,
                kernel_sf_n // 128,
                sf_k // 4,
                1,
                1,  # batch
                kernel_a_ptr,
                kernel_b_ptr,
                kernel_a_sf_ptr,
                kernel_b_sf_ptr,
                c_ptr,
                alpha_ptr,  # Pass alpha as device pointer
                alpha_cute_tensor,
                max_active_clusters,
                stream,
                swap_ab,
                options=f"--opt-level 2",
                options=f"--opt-level 2 --enable-tvm-ffi"
                if self.use_tvm_ffi else "--opt-level 2",
            )

            self.__class__.kernel_cache[cache_key] = compiled_gemm
@@ -546,21 +595,39 @@ if IS_CUTLASS_DSL_AVAILABLE:
            compiled_gemm = self.__class__.kernel_cache[cache_key]

        # launch gemm kernel
        compiled_gemm(
            kernel_m,
            kernel_n,
            real_k,
            kernel_sf_m // 128,
            kernel_sf_n // 128,
            sf_k // 4,
            kernel_a_ptr,
            kernel_b_ptr,
            kernel_a_sf_ptr,
            kernel_b_sf_ptr,
            c_ptr,
            alpha_ptr,  # Pass alpha as device pointer
            stream,
        )
        if self.use_tvm_ffi:
            # call with torch data pointers; no need to pass the stream.
            compiled_gemm(
                kernel_m,
                kernel_n,
                real_k,
                kernel_sf_m // 128,
                kernel_sf_n // 128,
                sf_k // 4,
                kernel_a_tensor.data_ptr(),
                kernel_b_tensor.data_ptr(),
                kernel_a_sf_tensor.data_ptr(),
                kernel_b_sf_tensor.data_ptr(),
                c_tensor.data_ptr(),
                alpha_tensor,
            )
        else:
            # call with cute pointer types; the torch stream must be passed.
            compiled_gemm(
                kernel_m,
                kernel_n,
                real_k,
                kernel_sf_m // 128,
                kernel_sf_n // 128,
                sf_k // 4,
                kernel_a_ptr,
                kernel_b_ptr,
                kernel_a_sf_ptr,
                kernel_b_sf_ptr,
                c_ptr,
                alpha_cute_tensor,
                stream,
            )

        if swap_ab:
            c_tensor = c_tensor.permute(1, 0)
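The swap_ab branches above together with the final permute rely on a plain transpose identity: running the kernel with the A and B operands exchanged yields the transposed output, which permute(1, 0) undoes. A small sanity check of that identity (the shapes are arbitrary and this does not reproduce the kernel's NVFP4 layouts):

import torch

m, n, k = 4, 8, 16
a = torch.randn(m, k)
b = torch.randn(n, k)

c = a @ b.t()          # normal orientation: [m, n]
c_swapped = b @ a.t()  # operands swapped:   [n, m]

# Swapping the operands transposes the result, so a permute recovers C.
assert torch.allclose(c, c_swapped.permute(1, 0), atol=1e-5)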
@@ -578,6 +645,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
        alpha: torch.Tensor,
        output_dtype: torch.dtype,
        to_userbuffers: bool = False,
        use_tvm_ffi: bool = True,
    ) -> torch.Tensor:
        """CuteDSL-based NVFP4 GEMM optimized for Blackwell.

@@ -589,6 +657,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
            alpha: Scaling factor
            output_dtype: Output data type (must be bfloat16)
            to_userbuffers: Whether to allocate output from UserBuffers pool
            use_tvm_ffi: Whether to use TVM-FFI to call the kernel. Enabling this option can help reduce the kernel host launch overhead.

        Note:
            This function is primarily used internally by nvfp4_gemm.
@@ -604,7 +673,8 @@ if IS_CUTLASS_DSL_AVAILABLE:

        tuner = AutoTuner.get()

        runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers)
        runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers,
                                             use_tvm_ffi)
        inputs = [input, weight, input_scale, weight_scale, alpha]
        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_nvfp4_gemm_blackwell",
@@ -625,6 +695,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
        alpha: torch.Tensor,  # Match custom op signature
        output_dtype: torch.dtype,
        to_userbuffers: bool = False,
        use_tvm_ffi: bool = True,
    ):
        # [m, k]
        shape = list(mat_a.shape)

@@ -2017,20 +2017,19 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
    @cute.jit
    def wrapper(
        self,
        m,
        n,
        k,
        sf_m,
        sf_n,
        sf_k,
        m: cutlass.Int32,
        n: cutlass.Int32,
        k: cutlass.Int32,
        sf_m: cutlass.Int32,
        sf_n: cutlass.Int32,
        sf_k: cutlass.Int32,
        l: cutlass.Constexpr,
        a_ptr: cute.Pointer,
        b_ptr: cute.Pointer,
        a_sf_ptr: cute.Pointer,
        b_sf_ptr: cute.Pointer,
        c_ptr: cute.Pointer,
        alpha: cute.Pointer,  # Device pointer to alpha, will be converted to Tensor
        alpha_tensor: cute.Tensor,
        max_active_clusters: cutlass.Constexpr,
        current_stream: cuda.CUstream,
        swap_ab: cutlass.Constexpr = False,
@@ -2051,7 +2050,7 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
            a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A.
            b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B.
            c_ptr (cute.Pointer): Pointer to the C tensor.
            alpha (cute.Pointer): Device pointer to alpha scaling factor (converted to Tensor internally).
            alpha_tensor (cute.Tensor): Device tensor holding the alpha scaling factor.
            max_active_clusters (cutlass.Constexpr): Maximum number of active
                clusters.
            current_stream (cuda.CUstream): CUDA stream for the operation.
@@ -2096,9 +2095,6 @@ class Sm100BlockScaledPersistentDenseGemmKernel:
                (32, 4, sf_n, 4, sf_k, l),
                order=(2, 1, 4, 0, 3, 5),
            ))
        alpha_tensor = cute.make_tensor(alpha,
                                        layout=cute.make_ordered_layout(
                                            (1, ), order=(0, )))

        self(a_tensor, b_tensor, sfa_tensor, sfb_tensor, c_tensor, alpha_tensor,
             max_active_clusters, current_stream, epilogue_op)

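The wrapper now receives alpha as a cute.Tensor instead of a raw device pointer, which satisfies the constraint noted earlier that the TVM-FFI fake stream needs at least one tensor-typed parameter. On the host side the conversion is the from_dlpack call shown in this diff; a minimal sketch, with the import path an assumption about the nvidia-cutlass-dsl package rather than something stated in the diff:

import torch
import cutlass.cute as cute  # assumed import path for the CuTe DSL runtime

# A 1-element float32 CUDA tensor viewed as a cute.Tensor via DLPack (no copy),
# so the @cute.jit wrapper can take alpha as a cute.Tensor argument.
alpha_tensor = torch.tensor([1.0], dtype=torch.float32, device="cuda")
alpha_cute_tensor = cute.runtime.from_dlpack(alpha_tensor)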
@@ -313,15 +313,17 @@ def nvfp4_gemm_perf_test(
    x_sf_block_list = [x_sf_block]
    w_sf_block_list = [w_sf_block]

    alpha_tensor = torch.tensor([1.0]).cuda()
    with torch.inference_mode(), autotune():
        with nvtx.annotate(
                f"cute_dsl tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}",
                color="orange",
        ):
            output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
                x_fp4, w_fp4, x_sf_block, w_sf_block, 1.0, dtype)
                x_fp4, w_fp4, x_sf_block, w_sf_block, alpha_tensor, dtype)
    from tensorrt_llm._torch.autotuner import AutoTuner
    AutoTuner.get().print_statistics()

    alpha_tensor = torch.tensor(1.0).cuda()
    if test_ref:
        with nvtx.annotate(
                f"ref tune, m={SEQ_LEN}, k={HIDDEN_SIZE}, n={OUTPUT_SIZE}",
@@ -342,7 +344,7 @@ def nvfp4_gemm_perf_test(
                    w_fp4_list[buffer_idx % workspace_count],
                    x_sf_block_list[buffer_idx % workspace_count],
                    w_sf_block_list[buffer_idx % workspace_count],
                    1.0,
                    alpha_tensor,
                    dtype,
                )
                buffer_idx = buffer_idx + 1
@@ -356,7 +358,7 @@ def nvfp4_gemm_perf_test(
                    w_fp4_list[buffer_idx % workspace_count],
                    x_sf_block_list[buffer_idx % workspace_count],
                    w_sf_block_list[buffer_idx % workspace_count],
                    1.0,
                    alpha_tensor,
                    dtype,
                )
                buffer_idx = buffer_idx + 1
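Since the commit targets host launch overhead rather than different math, a CUDA-event loop such as the one below is one way to compare the use_tvm_ffi=True and use_tvm_ffi=False paths. It is not part of the test file; run_gemm stands for a hypothetical closure over pre-quantized inputs like those prepared above:

import torch


def time_launches(run_gemm, iters=100, warmup=10):
    # Warm up so autotuning and kernel compilation stay out of the measurement.
    for _ in range(warmup):
        run_gemm()
    torch.cuda.synchronize()

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        run_gemm()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average milliseconds per call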
@@ -457,7 +459,7 @@ def test_nvfp4_gemm_unified_all_tactics(dtype, mnk):
    x_fp4, x_sf_block = torch.ops.trtllm.fp4_quantize(
        x, x_sf_global, scaling_vector_size, False)
    alpha_ref = 1.0 / (w_sf_global * x_sf_global)
    alpha_tensor = torch.tensor(alpha_ref, dtype=torch.float32).cuda()
    alpha_tensor = torch.tensor([alpha_ref], dtype=torch.float32).cuda()

    # Reference: Use CUTLASS backend explicitly for reference output
    with torch.inference_mode():
@@ -749,23 +751,19 @@ def test_fp4_linear_cuda_core(dtype, mnk):

if __name__ == "__main__":
    # m, n, k
    fp4_linear_perf_test(torch.bfloat16, 128, 7168, 16384)
    fp4_linear_perf_test(torch.bfloat16, 128, 24576, 1536)
    fp4_linear_perf_test(torch.bfloat16, 128, 2112, 7168)
    fp4_linear_perf_test(torch.bfloat16, 128, 4096, 7168)
    fp4_linear_perf_test(torch.bfloat16, 128, 7168, 2048)
    nvfp4_gemm_perf_test(torch.bfloat16, 128, 7168, 16384)

    # group-1 test cases
    for tokens in [128, 8192]:
        nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384)
        nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536)
        nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168)
        nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168)
        nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048)
    # # group-1 test cases
    # for tokens in [128, 8192]:
    #     nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 16384)
    #     nvfp4_gemm_perf_test(torch.bfloat16, tokens, 24576, 1536)
    #     nvfp4_gemm_perf_test(torch.bfloat16, tokens, 2112, 7168)
    #     nvfp4_gemm_perf_test(torch.bfloat16, tokens, 4096, 7168)
    #     nvfp4_gemm_perf_test(torch.bfloat16, tokens, 7168, 2048)

    # group-2 test cases
    for m in [128, 256, 512]:
        nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168)
        nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792)
        nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False)
        nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664)
    # # group-2 test cases
    # for m in [128, 256, 512]:
    #     nvfp4_gemm_perf_test(torch.bfloat16, m, 131584, 7168)
    #     nvfp4_gemm_perf_test(torch.bfloat16, m, 7168, 65792)
    #     nvfp4_gemm_perf_test(torch.bfloat16, m, 227368, 2560, test_ref=False)
    #     nvfp4_gemm_perf_test(torch.bfloat16, m, 2560, 113664)