Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Reduce host overhead for unified nvfp4 gemm tuning path.
Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
parent: 3bd319dc8e
commit: 9550c969ee
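In short: the runner used to expose only backend names as tactics and let each backend re-resolve its own kernel at call time; after this change it enumerates (backend, sub_tactic) tuples up front, so one flat autotuning pass covers every kernel and forward() dispatches directly. A minimal, self-contained sketch of that flattening pattern (stub names, not the TensorRT-LLM API):

from typing import Dict, List, Tuple

# Hypothetical stand-ins for the per-backend runners; each knows
# only its own kernel variants ("sub-tactics").
class _StubRunner:
    def __init__(self, sub_tactics: List[int]):
        self._sub_tactics = sub_tactics

    def get_valid_tactics(self) -> List[int]:
        return self._sub_tactics

    def run(self, x: float, tactic: int) -> float:
        # Pretend each sub-tactic is a different kernel.
        return x * (tactic + 1)

class UnifiedRunner:
    def __init__(self, backends: Dict[str, _StubRunner]):
        self.backends = backends

    def get_valid_tactics(self) -> List[Tuple[str, int]]:
        # Flatten nested tuning: one (backend, sub_tactic) tuple per kernel,
        # so the autotuner profiles everything in a single pass.
        return [(name, t) for name, runner in self.backends.items()
                for t in runner.get_valid_tactics()]

    def forward(self, x: float, tactic: Tuple[str, int]) -> float:
        backend, sub_tactic = tactic
        return self.backends[backend].run(x, sub_tactic)

runner = UnifiedRunner({
    "cutlass": _StubRunner([0, 1]),
    "cublaslt": _StubRunner([0]),
})
assert runner.get_valid_tactics() == [("cutlass", 0), ("cutlass", 1),
                                      ("cublaslt", 0)]
assert runner.forward(2.0, ("cutlass", 1)) == 4.0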
@@ -23,6 +23,10 @@ from ..utils import (ActivationType, fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
 
+if IS_CUTLASS_DSL_AVAILABLE:
+    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
+        CuteDSLNVFP4BlackwellLinear
+
 
 # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
 @torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", ))
@@ -700,8 +704,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 0, 0, get_last_power_of_2_num_tokens_buckets,
                 last_positive_power_of_2), ),
         constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
-        # nested tuning should always be independent
-        distributed_tuning_strategy=DistributedTuningStrategy.INDEPENDENT,
+        distributed_tuning_strategy=DistributedTuningStrategy.PARALLEL,
     )
 
     def __init__(self, to_userbuffers: bool, output_dtype: torch.dtype,
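Note on the strategy switch: with each backend's sub-tactics now enumerated into one flat list (see the get_valid_tactics hunks below), there is no nested tuning left to keep independent, which appears to be why the old comment and the INDEPENDENT strategy are dropped in favor of PARALLEL.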
@@ -727,12 +730,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
 
     def get_valid_tactics(self, inputs: List[torch.Tensor],
                           profile: OptimizationProfile,
-                          **kwargs) -> List[Tuple]:
+                          **kwargs) -> List[Tuple[str, int]]:
         # return valid nvfp4 gemm implementations from allowed_backends
         tactics = []
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Add CUDA Core backend if available
+        # Add CUDA Core tactics if available
         if self._is_backend_allowed("cuda_core"):
             is_cuda_core_supported = False
             m = act_fp4.shape[0]
@@ -748,7 +751,12 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                                        and m <= CudaCoreNVFP4Runner.MAX_M_DIMENSION)
 
             if is_cuda_core_supported:
-                tactics.append("cuda_core")
+                cuda_core_runner = CudaCoreNVFP4Runner(self.to_userbuffers,
+                                                       self.output_dtype)
+                cuda_core_tactics = cuda_core_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cuda_core", tactic)
+                                for tactic in cuda_core_tactics])
             elif self._is_only_backend("cuda_core"):
                 # Explicitly forced but conditions not met - raise error
                 error_msg = f"CUDA Core backend requires SM >= {CudaCoreNVFP4Runner.MIN_SM_VERSION} and M <= {CudaCoreNVFP4Runner.MAX_M_DIMENSION}. "
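The `_is_backend_allowed` / `_is_only_backend` guards are not shown in this diff. A plausible shape for them, assuming `allowed_backends` is a list of backend-name strings (hypothetical; the real implementations live elsewhere in the class and may differ):

# Hypothetical sketch of the guards used in the hunk above; not from this patch.
class _BackendGuards:
    def __init__(self, allowed_backends):
        self.allowed_backends = allowed_backends

    def _is_backend_allowed(self, name: str) -> bool:
        return name in self.allowed_backends

    def _is_only_backend(self, name: str) -> bool:
        # True when the caller explicitly forced a single backend.
        return self.allowed_backends == [name]

guards = _BackendGuards(["cuda_core"])
assert guards._is_backend_allowed("cuda_core")
assert guards._is_only_backend("cuda_core")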
@@ -756,21 +764,30 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                 error_msg += "Please add other backends to allowed_backends."
                 raise ValueError(error_msg)
 
-        # Add CUTLASS runner (always available)
+        # Add CUTLASS tactics if available
         if self._is_backend_allowed("cutlass"):
-            tactics.append("cutlass")
+            cutlass_runner = FP4GemmRunner(
+                fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4, self.to_userbuffers,
+                self.output_dtype)
+            cutlass_tactics = cutlass_runner.get_valid_tactics(inputs, profile)
+            tactics.extend([("cutlass", tactic) for tactic in cutlass_tactics])
 
-        # Add cuBLASLt runner if available
+        # Add cuBLASLt tactics if available
         if self._is_backend_allowed("cublaslt"):
             if IS_CUBLASLT_AVAILABLE:
-                tactics.append("cublaslt")
+                cublaslt_runner = CublasLtFP4GemmRunner(self.to_userbuffers,
+                                                        self.output_dtype)
+                cublaslt_tactics = cublaslt_runner.get_valid_tactics(
+                    inputs, profile)
+                tactics.extend([("cublaslt", tactic)
+                                for tactic in cublaslt_tactics])
             elif self._is_only_backend("cublaslt"):
                 raise ValueError(
                     "cuBLASLt backend is not available. "
                     "Please check cuBLASLt installation or add other backends to allowed_backends."
                 )
 
-        # Add CuteDSL runner if available
+        # Add CuteDSL tactics if available
         if self._is_backend_allowed("cutedsl"):
             if IS_CUTLASS_DSL_AVAILABLE:
                 # Check SM version first - CuteDSL NVFP4 only supports SM 100 (B200)
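The same pattern repeats for each backend: check availability, instantiate the concrete runner, collect its sub-tactics, and tag each one with the backend name; a backend that was explicitly forced but is unavailable still raises instead of silently falling through.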
@@ -784,8 +801,6 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                         "Please add other backends to allowed_backends.")
                 else:
                     # SM version OK, check if CuteDSL supports the current shape
-                    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
-                        CuteDSLNVFP4BlackwellLinear
                     cutedsl_runner = CuteDSLNVFP4BlackwellLinear(
                         self.output_dtype)
                     cutedsl_tactics = cutedsl_runner.get_valid_tactics(
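This hunk pairs with the new module-level import at the top of the file: the CuteDSL import moves out of the per-call path, so `get_valid_tactics` no longer pays a `sys.modules` lookup and name binding on every invocation. A small sketch of the pattern, using `math` as a stand-in for the optional dependency:

# Per-call import (the pattern removed here): even a cached module costs
# a sys.modules lookup plus a local name binding on every call.
def tactics_with_local_import(x: float) -> float:
    import math  # stand-in for the CuteDSL import
    return math.sqrt(x)

# Hoisted, guarded import (the pattern adopted at module scope): pay the
# lookup once, and degrade gracefully when the optional dep is missing.
try:
    import math
    IS_MATH_AVAILABLE = True
except ImportError:
    IS_MATH_AVAILABLE = False

def tactics_with_hoisted_import(x: float) -> float:
    return math.sqrt(x) if IS_MATH_AVAILABLE else x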
@@ -793,7 +808,8 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
 
                     if cutedsl_tactics:
                         # CuteDSL supports this shape
-                        tactics.append("cutedsl")
+                        tactics.extend([("cutedsl", tactic)
+                                        for tactic in cutedsl_tactics])
                     elif self._is_only_backend("cutedsl"):
                         # Explicitly forced CuteDSL but it doesn't support this shape
                         m, n, k = inputs[0].shape[0], inputs[1].shape[
@@ -817,65 +833,36 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
         self,
         inputs: List[torch.Tensor],
         tactic: Union[
-            str, int] = "cutlass",  # str: backend name, or int: -1 for fallback
+            Tuple,
+            int] = -1,  # tuple: (backend name, sub_tactic_id), or int: -1 for fallback
         **kwargs,
     ) -> torch.Tensor:
         act_fp4, weight, act_sf, weight_scale, alpha = inputs
 
-        # Handle fallback tactic (-1) on cache miss
+        # Handle fallback tactic on cache miss
         if tactic == -1:
-            # Get valid tactics and use first available
-            from tensorrt_llm._torch.autotuner import OptimizationProfile
-            valid_tactics = self.get_valid_tactics(inputs,
-                                                   OptimizationProfile())
-            if valid_tactics:
-                # Prefer cutlass as fallback if available, otherwise use first valid tactic
-                tactic = "cutlass" if "cutlass" in valid_tactics else valid_tactics[
-                    0]
-            else:
-                m, n, k = inputs[0].shape[0], inputs[1].shape[
-                    0], inputs[0].shape[1] * 2
-                raise ValueError(
-                    f"No valid backends available for the current shape:\n"
-                    f"  M={m}, N={n}, K={k}\n"
-                    f"  Allowed backends: {self.allowed_backends}")
+            # Prefer cutlass as fallback if available, otherwise use first valid backend
+            assert len(
+                self.allowed_backends) > 0, "No allowed backends available"
+            tactic = ("cutlass",
+                      -1) if "cutlass" in self.allowed_backends else (
+                          self.allowed_backends[0], -1)
 
-        if tactic == "cuda_core":
-            # Unswizzle the activation scale factors
-            # act_sf is swizzled, need to reverse it for cuda_core_nvfp4_gemm
-            m = act_fp4.shape[0]
-            act_sf_unswizzled = torch.ops.trtllm.block_scale_interleave_reverse(
-                act_sf.view((m + 128 - 1) // 128 * 128, -1))
-
-            # Call CUDA Core NVFP4 GEMM
-            return torch.ops.trtllm.cuda_core_nvfp4_gemm(
-                act_fp4,
-                weight,
-                act_sf_unswizzled,
-                weight_scale,
-                alpha,
-                bias=None,
-                out_dtype=self.output_dtype,
-                to_userbuffers=self.to_userbuffers)
-        elif tactic == "cutlass":
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
-        elif tactic == "cublaslt":
-            return torch.ops.trtllm.nvfp4_gemm_cublaslt(act_fp4, weight, act_sf,
-                                                        weight_scale, alpha,
-                                                        self.output_dtype,
-                                                        self.to_userbuffers)
-        elif tactic == "cutedsl":
-            return torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
-                act_fp4, weight, act_sf, weight_scale, alpha, self.output_dtype,
-                self.to_userbuffers)
-        elif tactic == -1:
-            return torch.ops.trtllm.nvfp4_gemm_cutlass(act_fp4, weight, act_sf,
-                                                       weight_scale, alpha,
-                                                       self.output_dtype,
-                                                       self.to_userbuffers)
+        backend, sub_tactic = tactic
+        if backend == "cuda_core":
+            return CudaCoreNVFP4Runner(self.to_userbuffers,
+                                       self.output_dtype)(inputs,
+                                                          tactic=sub_tactic)
+        elif backend == "cutlass":
+            return FP4GemmRunner(fp4_utils.FP4GemmType.W4A4_NVFP4_NVFP4,
+                                 self.to_userbuffers,
+                                 self.output_dtype)(inputs, tactic=sub_tactic)
+        elif backend == "cublaslt":
+            return CublasLtFP4GemmRunner(self.to_userbuffers,
+                                         self.output_dtype)(inputs,
+                                                            tactic=sub_tactic)
+        elif backend == "cutedsl":
+            return CuteDSLNVFP4BlackwellLinear(
+                self.output_dtype, self.to_userbuffers)(inputs,
+                                                        tactic=sub_tactic)
         else:
             raise ValueError(f"Invalid tactic: {tactic}")
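The forward-path change is where most of the host overhead goes away: the old cache-miss fallback constructed a fresh OptimizationProfile and re-ran get_valid_tactics inside forward(), while the new one picks a default (backend, -1) tuple straight from allowed_backends and lets the concrete runner resolve -1 to its own default kernel. The activation-scale unswizzling that used to live inline in the "cuda_core" branch presumably moves into CudaCoreNVFP4Runner's own call path. A minimal sketch of the new selection logic:

# Minimal sketch of the new fallback selection; assumes allowed_backends
# is a non-empty sequence of backend-name strings.
def pick_fallback(allowed_backends):
    assert len(allowed_backends) > 0, "No allowed backends available"
    # Prefer cutlass; sub_tactic -1 means "use the runner's default kernel".
    backend = "cutlass" if "cutlass" in allowed_backends else allowed_backends[0]
    return (backend, -1)

assert pick_fallback(["cublaslt", "cutlass"]) == ("cutlass", -1)
assert pick_fallback(["cuda_core"]) == ("cuda_core", -1)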