From 15281de799b95315e231dda348ae5f13ec605663 Mon Sep 17 00:00:00 2001
From: Yukun He <23156053+hyukn@users.noreply.github.com>
Date: Wed, 14 Jan 2026 14:26:18 +0800
Subject: [PATCH] [None][fix] Reduce host overhead for unified nvfp4 gemm
 tuning path. (#10503)

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
---
 .../_torch/custom_ops/torch_custom_ops.py | 123 ++++++++----------
 1 file changed, 55 insertions(+), 68 deletions(-)

diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
index e3b725b393..06e93eb3e5 100644
--- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
+++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py
@@ -23,6 +23,10 @@
 from ..utils import (ActivationType, fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
 
+if IS_CUTLASS_DSL_AVAILABLE:
+    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
+        CuteDSLNVFP4BlackwellLinear
+
 # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
 @torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", ))
@@ -700,8 +704,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
             0, 0, get_last_power_of_2_num_tokens_buckets,
             last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
-        # nested tuning should always be independent
-        distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
+        distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
     )
 
     def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
@@ -727,12 +730,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
     def get_valid_tactics(self,
                           inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
-                          **kwargs) -> List[Tuple]:
+                          **kwargs) -> List[Tuple[str, int]]:
         # return valid nvfp4 gemm implementations from allowed_backends
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Add CUDA Core backend if available
+        # Add CUDA Core tactics if available
         if self._is_backend_allowed("cuda_core"):
             is_cuda_core_supported = False
             m = act_fp4.shape[0]
@@ -748,7 +751,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 and m <= CudaCoreNVFP4Runner.MAX_M_DIMENSION)
 
             if is_cuda_core_supported:
-                tactics.append("cuda_core")
+                cuda_core_runner = CudaCoreNVFP4Runner(self.to_userbuffers,
+                                                       self.output_dtype)
+                cuda_core_tactics = cuda_core_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cuda_core", tactic)
+                                for tactic in cuda_core_tactics])
             elif self._is_only_backend("cuda_core"):
                 # Explicitly forced but conditions not met - raise error
                 error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
@@ -756,21 +764,30 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 error_msg += "Please add other backends to allowed_backends."
                 raise ValueError(error_msg)
 
-        # Add CUTLASS runner (always available)
+        # Add CUTLASS tactics if available
         if self._is_backend_allowed("cutlass"):
-            tactics.append("cutlass")
+            cutlass_runner = FP4GemmRunner(
+                fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, self.to_userbuffers,
+                self.output_dtype)
+            cutlass_tactics = cutlass_runner.get_valid_tactics(inputs, profile)
+            tactics.extend([("cutlass", tactic) for tactic in cutlass_tactics])
 
-        # Add cuBLASLt runner if available
+        # Add cuBLASLt tactics if available
         if self._is_backend_allowed("cublaslt"):
             if IS_CUBLASLT_AVAILABLE:
-                tactics.append("cublaslt")
+                cublaslt_runner = CublasLtFP4GemmRunner(self.to_userbuffers,
+                                                        self.output_dtype)
+                cublaslt_tactics = cublaslt_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cublaslt", tactic)
+                                for tactic in cublaslt_tactics])
             elif self._is_only_backend("cublaslt"):
                 raise ValueError(
                     "cuBLASLt backend is not available. "
                     "Please check cuBLASLt installation or add other backends to allowed_backends."
                 )
 
-        # Add CuteDSL runner if available
+        # Add CuteDSL tactics if available
         if self._is_backend_allowed("cutedsl"):
             if IS_CUTLASS_DSL_AVAILABLE:
                 # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
@@ -784,8 +801,6 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                         "Please add other backends to allowed_backends.")
                 else:
                     # SM version OK, check if CuteDSL supports the current shape
-                    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
-                        CuteDSLNVFP4BlackwellLinear
                     cutedsl_runner = CuteDSLNVFP4BlackwellLinear(
                         self.output_dtype)
                     cutedsl_tactics = cutedsl_runner.get_valid_tactics(
@@ -793,7 +808,8 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                         inputs, profile)
                     if cutedsl_tactics:
                         # CuteDSL supports this shape
-                        tactics.append("cutedsl")
+                        tactics.extend([("cutedsl", tactic)
+                                        for tactic in cutedsl_tactics])
                 elif self._is_only_backend("cutedsl"):
                     # Explicitly forced CuteDSL but it doesn't support this shape
                     m, n, k = inputs[0].shape[0], inputs[1].shape[
@@ -817,65 +833,36 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
         self,
         inputs: List[torch.Tensor],
         tactic: Union[
-            str, int] = "cutlass",  # str: backend name, or int: -1 for fallback
+            Tuple,
+            int] = -1,  # tuple: (backend name, sub_tactic_id), or int: -1 for fallback
         **kwargs,
     ) -> torch.Tensor:
-        act_fp4, weight, act_sf, weight_scale, alpha = inputs
-
-        # Handle fallback tactic (-1) on cache miss
+        # Handle fallback tactic on cache miss
         if tactic == -1:
-            # Get valid tactics and use first available
-            from tensorrt_llm._torch.autotuner import OptimizationProfile
-            valid_tactics = self.get_valid_tactics(inputs,
-                                                   OptimizationProfile())
-            if valid_tactics:
-                # Prefer cutlass as fallback if available, otherwise use first valid tactic
-                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
-                    0]
-            else:
-                m, n, k = inputs[0].shape[0], inputs[1].shape[
-                    0], inputs[0].shape[1] * 2
-                raise ValueError(
-                    f"No valid backends available for the current shape:\n"
-                    f"  M={m}, N={n}, K={k}\n"
-                    f"  Allowed backends: {self.allowed_backends}")
+            # Prefer cutlass as fallback if available, otherwise use first valid backend
+            assert len(
+                self.allowed_backends) > 0, "No allowed backends available"
+            tactic = ("cutlass",
+                      -1) if "cutlass" in self.allowed_backends else (
+                          self.allowed_backends[0], -1)
 
-        if tactic == "cuda_core":
-            # Unswizzle the activation scale factors
-            # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm
-            m = act_fp4.shape[0]
-            act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
-                act_sf.view((m + 128 - 1) // 128 * 128, -1))
-
-            # Call CUDA Core NVFP4 GEMM
-            return torch.ops.trtllm.cuda_core_nvfp4_gemm(
-                act_fp4,
-                weight,
-                act_sf_unswizzled,
-                weight_scale,
-                alpha,
-                bias=None,
-                out_dtype=self.output_dtype,
-                to_userbuffers=self.to_userbuffers)
-        elif tactic == "cutlass":
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
-        elif tactic == "cublaslt":
-            return torch.ops.trtllm.nvfp4_gemm_cublaslt(act_fp4, weight, act_sf,
-                                                        weight_scale, alpha,
-                                                        self.output_dtype,
-                                                        self.to_userbuffers)
-        elif tactic == "cutedsl":
-            return torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
-                act_fp4, weight, act_sf, weight_scale, alpha, self.output_dtype,
-                self.to_userbuffers)
-        elif tactic == -1:
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
+        backend, sub_tactic = tactic
+        if backend == "cuda_core":
+            return CudaCoreNVFP4Runner(self.to_userbuffers,
+                                       self.output_dtype)(inputs,
+                                                          tactic=sub_tactic)
+        elif backend == "cutlass":
+            return FP4GemmRunner(fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4,
+                                 self.to_userbuffers,
+                                 self.output_dtype)(inputs, tactic=sub_tactic)
+        elif backend == "cublaslt":
+            return CublasLtFP4GemmRunner(self.to_userbuffers,
+                                         self.output_dtype)(inputs,
+                                                            tactic=sub_tactic)
+        elif backend == "cutedsl":
+            return CuteDSLNVFP4BlackwellLinear(
+                self.output_dtype, self.to_userbuffers)(inputs,
+                                                        tactic=sub_tactic)
         else:
             raise ValueError(f"Invalid tactic: {tactic}")
 
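Note: the patch above replaces flat backend-name string tactics with composite
(backend, sub_tactic) tuples, so the unified runner enumerates and dispatches
through the per-backend runners and, on a cache miss, picks a static fallback
instead of re-running get_valid_tactics() on the host. Below is a minimal,
self-contained sketch of this composite-tactic pattern; BackendRunner and
UnifiedRunner are illustrative stand-ins, not the TensorRT-LLM API.

    from typing import Dict, List, Tuple, Union


    class BackendRunner:
        """Illustrative stand-in for a per-backend runner such as
        CudaCoreNVFP4Runner or CublasLtFP4GemmRunner."""

        def __init__(self, name: str):
            self.name = name

        def get_valid_tactics(self) -> List[int]:
            # Real runners filter by shape and SM version; here every
            # backend simply exposes two tactics.
            return [0, 1]

        def __call__(self, data: List[float], tactic: int = -1) -> List[float]:
            # tactic == -1 means "use the backend's default tactic".
            scale = 2.0 if tactic == 1 else 1.0
            return [x * scale for x in data]


    class UnifiedRunner:

        def __init__(self, allowed_backends: List[str]):
            assert len(allowed_backends) > 0, "No allowed backends available"
            self.allowed_backends = allowed_backends
            self._runners: Dict[str, BackendRunner] = {
                name: BackendRunner(name)
                for name in allowed_backends
            }

        def get_valid_tactics(self) -> List[Tuple[str, int]]:
            # Composite tactics: (backend name, backend-local tactic id).
            return [(name, t) for name, runner in self._runners.items()
                    for t in runner.get_valid_tactics()]

        def forward(self,
                    data: List[float],
                    tactic: Union[Tuple[str, int], int] = -1) -> List[float]:
            if tactic == -1:
                # Cache miss: pick a static fallback instead of re-enumerating
                # valid tactics on the host (the overhead this patch removes).
                backend = ("cutlass" if "cutlass" in self.allowed_backends else
                           self.allowed_backends[0])
                tactic = (backend, -1)
            backend, sub_tactic = tactic
            # Delegate execution to the chosen backend's own runner.
            return self._runners[backend](data, tactic=sub_tactic)


    runner = UnifiedRunner(["cutlass", "cublaslt"])
    print(runner.get_valid_tactics())  # [('cutlass', 0), ('cutlass', 1), ...]
    print(runner.forward([1.0, 2.0], tactic=("cublaslt", 1)))  # [2.0, 4.0]
    print(runner.forward([1.0, 2.0]))  # cache miss -> ('cutlass', -1)

Because the fallback no longer calls get_valid_tactics(), the dispatch path does
no per-call enumeration work; the trade-off is that the static fallback may pick
a backend whose constraints a full enumeration would have rejected.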