[None][fix] Reduce host overhead for unified nvfp4 gemm tuning path.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
Yukun He 2026-01-07 12:40:20 +00:00
parent 3bd319dc8e
commit 9550c969ee


@@ -23,6 +23,10 @@ from ..utils import (ActivationType, fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
 
+if IS_CUTLASS_DSL_AVAILABLE:
+    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
+        CuteDSLNVFP4BlackwellLinear
+
 
 # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
 @torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", ))
@@ -700,8 +704,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
-        # nested tuning should always be independent
-        distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
+        distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
     )
 
     def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
@@ -727,12 +730,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
-                          **kwargs) -> List[Tuple]:
+                          **kwargs) -> List[Tuple[str, int]]:
         # return valid nvfp4 gemm implementations from allowed_backends
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Add CUDA Core backend if available
+        # Add CUDA Core tactics if available
         if self._is_backend_allowed("cuda_core"):
             is_cuda_core_supported = False
             m = act_fp4.shape[0]
@@ -748,7 +751,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 and m <= CudaCoreNVFP4Runner.MAX_M_DIMENSION)
             if is_cuda_core_supported:
-                tactics.append("cuda_core")
+                cuda_core_runner = CudaCoreNVFP4Runner(self.to_userbuffers,
+                                                       self.output_dtype)
+                cuda_core_tactics = cuda_core_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cuda_core", tactic)
+                                for tactic in cuda_core_tactics])
             elif self._is_only_backend("cuda_core"):
                 # Explicitly forced but conditions not met - raise error
                 error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
@@ -756,21 +764,30 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 error_msg += "Please add other backends to allowed_backends."
                 raise ValueError(error_msg)
 
-        # Add CUTLASS runner (always available)
+        # Add CUTLASS tactics if available
         if self._is_backend_allowed("cutlass"):
-            tactics.append("cutlass")
+            cutlass_runner = FP4GemmRunner(
+                fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, self.to_userbuffers,
+                self.output_dtype)
+            cutlass_tactics = cutlass_runner.get_valid_tactics(inputs, profile)
+            tactics.extend([("cutlass", tactic) for tactic in cutlass_tactics])
 
-        # Add cuBLASLt runner if available
+        # Add cuBLASLt tactics if available
         if self._is_backend_allowed("cublaslt"):
             if IS_CUBLASLT_AVAILABLE:
-                tactics.append("cublaslt")
+                cublaslt_runner = CublasLtFP4GemmRunner(self.to_userbuffers,
+                                                        self.output_dtype)
+                cublaslt_tactics = cublaslt_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cublaslt", tactic)
+                                for tactic in cublaslt_tactics])
             elif self._is_only_backend("cublaslt"):
                 raise ValueError(
                     "cuBLASLt backend is not available. "
                     "Please check cuBLASLt installation or add other backends to allowed_backends."
                 )
 
-        # Add CuteDSL runner if available
+        # Add CuteDSL tactics if available
         if self._is_backend_allowed("cutedsl"):
             if IS_CUTLASS_DSL_AVAILABLE:
                 # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
@@ -784,8 +801,6 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                         "Please add other backends to allowed_backends.")
                 else:
                     # SM version OK, check if CuteDSL supports the current shape
-                    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
-                        CuteDSLNVFP4BlackwellLinear
                     cutedsl_runner = CuteDSLNVFP4BlackwellLinear(
                         self.output_dtype)
                     cutedsl_tactics = cutedsl_runner.get_valid_tactics(
@@ -793,7 +808,8 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                     if cutedsl_tactics:
                         # CuteDSL supports this shape
-                        tactics.append("cutedsl")
+                        tactics.extend([("cutedsl", tactic)
+                                        for tactic in cutedsl_tactics])
                     elif self._is_only_backend("cutedsl"):
                         # Explicitly forced CuteDSL but it doesn't support this shape
                         m, n, k = inputs[0].shape[0], inputs[1].shape[
                             0], inputs[0].shape[1] * 2
@@ -817,65 +833,36 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
         self,
         inputs: List[torch.Tensor],
         tactic: Union[
-            str, int] = "cutlass",  # str: backend name, or int: -1 for fallback
+            Tuple,
+            int] = -1,  # tuple: (backend name, sub_tactic_id), or int: -1 for fallback
         **kwargs,
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Handle fallback tactic (-1) on cache miss
+        # Handle fallback tactic on cache miss
         if tactic == -1:
-            # Get valid tactics and use first available
-            from tensorrt_llm._torch.autotuner import OptimizationProfile
-            valid_tactics = self.get_valid_tactics(inputs,
-                                                   OptimizationProfile())
-            if valid_tactics:
-                # Prefer cutlass as fallback if available, otherwise use first valid tactic
-                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
-                    0]
-            else:
-                m, n, k = inputs[0].shape[0], inputs[1].shape[
-                    0], inputs[0].shape[1] * 2
-                raise ValueError(
-                    f"No valid backends available for the current shape:\n"
-                    f"  M={m}, N={n}, K={k}\n"
-                    f"  Allowed backends: {self.allowed_backends}")
+            # Prefer cutlass as fallback if available, otherwise use first valid backend
+            assert len(
+                self.allowed_backends) > 0, "No allowed backends available"
+            tactic = ("cutlass",
+                      -1) if "cutlass" in self.allowed_backends else (
+                          self.allowed_backends[0], -1)
 
-        if tactic == "cuda_core":
-            # Unswizzle the activation scale factors
-            # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm
-            m = act_fp4.shape[0]
-            act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
-                act_sf.view((m + 128 - 1) // 128 * 128, -1))
-            # Call CUDA Core NVFP4 GEMM
-            return torch.ops.trtllm.cuda_core_nvfp4_gemm(
-                act_fp4,
-                weight,
-                act_sf_unswizzled,
-                weight_scale,
-                alpha,
-                bias=None,
-                out_dtype=self.output_dtype,
-                to_userbuffers=self.to_userbuffers)
-        elif tactic == "cutlass":
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
-        elif tactic == "cublaslt":
-            return torch.ops.trtllm.nvfp4_gemm_cublaslt(act_fp4, weight, act_sf,
-                                                        weight_scale, alpha,
-                                                        self.output_dtype,
-                                                        self.to_userbuffers)
-        elif tactic == "cutedsl":
-            return torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
-                act_fp4, weight, act_sf, weight_scale, alpha, self.output_dtype,
-                self.to_userbuffers)
-        elif tactic == -1:
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
+        backend, sub_tactic = tactic
+        if backend == "cuda_core":
+            return CudaCoreNVFP4Runner(self.to_userbuffers,
+                                       self.output_dtype)(inputs,
+                                                          tactic=sub_tactic)
+        elif backend == "cutlass":
+            return FP4GemmRunner(fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4,
+                                 self.to_userbuffers,
+                                 self.output_dtype)(inputs, tactic=sub_tactic)
+        elif backend == "cublaslt":
+            return CublasLtFP4GemmRunner(self.to_userbuffers,
+                                         self.output_dtype)(inputs,
+                                                            tactic=sub_tactic)
+        elif backend == "cutedsl":
+            return CuteDSLNVFP4BlackwellLinear(
+                self.output_dtype, self.to_userbuffers)(inputs,
+                                                        tactic=sub_tactic)
         else:
             raise ValueError(f"Invalid tactic: {tactic}")
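
For illustration only (not part of the commit, and not TensorRT-LLM code): a minimal, self-contained Python sketch of the pattern this diff adopts. The names _FakeBackendRunner and UnifiedRunnerSketch are hypothetical stand-ins; the sketch only shows how per-backend sub-tactics are flattened into (backend, sub_tactic) tuples and how a cache miss falls back to a (backend, -1) pair in constant time instead of re-enumerating tactics.

from typing import Dict, List, Tuple, Union


class _FakeBackendRunner:
    """Illustrative stand-in for a per-backend GEMM runner (hypothetical)."""

    def __init__(self, name: str, num_tactics: int):
        self.name = name
        self._num_tactics = num_tactics

    def get_valid_tactics(self) -> List[int]:
        # A real runner would inspect the problem shape; here we simply
        # expose a fixed list of backend-local tactic ids.
        return list(range(self._num_tactics))

    def __call__(self, x: float, tactic: int = -1) -> float:
        # Stand-in for launching the backend's GEMM with the chosen tactic.
        return x


class UnifiedRunnerSketch:
    """Flattens per-backend tactics into (backend, sub_tactic) tuples."""

    def __init__(self, backends: List[_FakeBackendRunner]):
        self.backends: Dict[str, _FakeBackendRunner] = {
            b.name: b
            for b in backends
        }

    def get_valid_tactics(self) -> List[Tuple[str, int]]:
        tactics: List[Tuple[str, int]] = []
        for name, runner in self.backends.items():
            # Each backend enumerates its own sub-tactics once; the unified
            # runner only tags them with the backend name.
            tactics.extend((name, t) for t in runner.get_valid_tactics())
        return tactics

    def forward(self,
                x: float,
                tactic: Union[Tuple[str, int], int] = -1) -> float:
        if tactic == -1:
            # Cache miss: pick a preferred backend with its own default
            # tactic (-1) in constant time instead of re-enumerating all
            # valid tactics on the host.
            preferred = ("cutlass" if "cutlass" in self.backends else next(
                iter(self.backends)))
            tactic = (preferred, -1)
        backend, sub_tactic = tactic
        return self.backends[backend](x, tactic=sub_tactic)


if __name__ == "__main__":
    runner = UnifiedRunnerSketch(
        [_FakeBackendRunner("cutlass", 3),
         _FakeBackendRunner("cublaslt", 2)])
    print(runner.get_valid_tactics())  # [('cutlass', 0), ..., ('cublaslt', 1)]
    print(runner.forward(1.0))  # falls back to ('cutlass', -1)
    print(runner.forward(1.0, ("cublaslt", 1)))

As in the diff above, the host-side savings come from two changes: the cache-miss fallback no longer calls get_valid_tactics() (and thus no longer probes every backend) on each forward call, and the CuteDSLNVFP4BlackwellLinear import is resolved once at module scope instead of inside get_valid_tactics().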