Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
Commit 84bf06a416: Merge 53c6f63fa8 into 38296a472b
@@ -34,7 +34,9 @@ if IS_FLASHINFER_AVAILABLE:
     ]
 
 if IS_CUTLASS_DSL_AVAILABLE:
-    from .cute_dsl_custom_ops import cute_dsl_nvfp4_gemm_blackwell
+    from .cute_dsl_custom_ops import (cute_dsl_bf16_gemm_blackwell,
+                                      cute_dsl_nvfp4_gemm_blackwell)
     __all__ += [
         'cute_dsl_nvfp4_gemm_blackwell',
+        'cute_dsl_bf16_gemm_blackwell',
     ]
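This hunk exports the new BF16 op next to the existing NVFP4 one, so callers can import it from the package's public __all__ rather than from the private submodule. A hedged sketch of the import site (the absolute package path is an assumption; the diff only shows the relative module):

# Hypothetical import path; the hunk above only shows the relative module.
from tensorrt_llm._torch.custom_ops import cute_dsl_bf16_gemm_blackwell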
@@ -339,6 +339,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
         Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
     from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
         Sm100BlockScaledPersistentDenseGemmKernel
+    from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import (
+        Sm100PersistentDenseGemmKernel, Sm100PersistentDenseGemmKernelWrapper)
     from ..cute_dsl_kernels.blackwell.utils import make_ptr
 
     class CuteDSLNVFP4BlackwellLinear(TunableRunner):
@@ -825,6 +827,295 @@ if IS_CUTLASS_DSL_AVAILABLE:
        ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
        return ret

    class CuteDSLBF16BlackwellGemm(TunableRunner):
        """TunableRunner for BF16 dense GEMM on Blackwell using CuTe DSL."""
        kernel_class = Sm100PersistentDenseGemmKernel
        kernel_cache = dict()

        tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ), )

        def __init__(self, output_dtype: torch.dtype):
            super().__init__()
            self.output_dtype = output_dtype

            if (sm_version := get_sm_version()) not in (100, 103):
                raise ValueError(
                    f"SM version {sm_version} is not supported for CuteDSLBF16BlackwellGemm; only SM 100/103 (Blackwell) is supported"
                )

        def __hash__(self):
            return hash((self.output_dtype, ))

        def __eq__(self, other):
            if not isinstance(other, CuteDSLBF16BlackwellGemm):
                return False
            return self.output_dtype == other.output_dtype
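tuning_config marks dimension 0 of the first input (the token count m) as dynamic and buckets it by powers of two, so a tactic tuned once per bucket is reused for every m that falls into the same bucket. A standalone sketch of that rounding rule (hypothetical helper for illustration; not the library's implementation):

def last_positive_power_of_2_sketch(x: int) -> int:
    # Round x down to the nearest positive power of two.
    p = 1
    while p * 2 <= x:
        p *= 2
    return p

# e.g. every m in [64, 127] maps to the same 64 bucket.
assert [last_positive_power_of_2_sketch(v) for v in (1, 3, 8, 1000)] == [1, 2, 8, 512]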
        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
            **kwargs,
        ) -> List[Tuple[Tuple[int, int], Tuple[int, int], bool]]:
            assert inputs[0].dim() == 2
            assert inputs[1].dim() == 2

            m = inputs[0].shape[0]
            n = inputs[1].shape[0]
            k = inputs[0].shape[1]
            batch_size = 1
            a_major = "k"
            b_major = "k"

            # Map the torch dtype to the corresponding cutlass dtype.
            if inputs[0].dtype == torch.bfloat16:
                ab_dtype = cutlass.BFloat16
            elif inputs[0].dtype == torch.float16:
                ab_dtype = cutlass.Float16
            else:
                ab_dtype = cutlass.BFloat16

            if self.output_dtype == torch.bfloat16:
                c_dtype = cutlass.BFloat16
            elif self.output_dtype == torch.float16:
                c_dtype = cutlass.Float16
            elif self.output_dtype == torch.float32:
                c_dtype = cutlass.Float32
            else:
                c_dtype = cutlass.BFloat16

            mma_tiler_mn_candidates = [
                (256, 128),
                (128, 128),
                (128, 256),
                (256, 256),
                (256, 64),
                (128, 64),
            ]
            cluster_shape_mn_candidates = [
                (1, 1),
                (1, 2),
                (1, 4),
                (2, 1),
                (2, 2),
                (2, 4),
                (4, 1),
                (4, 2),
                (4, 4),
            ]
            swap_ab_candidates = [True, False]

            valid_tactics = []
            for swap_ab in swap_ab_candidates:
                for mma_tiler_mn in mma_tiler_mn_candidates:
                    for cluster_shape_mn in cluster_shape_mn_candidates:
                        if swap_ab:
                            c_major = "m"
                            kernel_m = n
                            kernel_n = m
                        else:
                            c_major = "n"
                            kernel_m = m
                            kernel_n = n

                        use_2cta_instrs = mma_tiler_mn[0] == 256
                        kernel = Sm100PersistentDenseGemmKernel(
                            cutlass.Float32,
                            use_2cta_instrs,
                            mma_tiler_mn,
                            cluster_shape_mn,
                        )
                        if kernel.can_implement(
                            (kernel_m, kernel_n, k, batch_size),
                            ab_dtype,
                            c_dtype,
                            a_major,
                            b_major,
                            c_major,
                        ):
                            valid_tactics.append(
                                (mma_tiler_mn, cluster_shape_mn, swap_ab))

            return valid_tactics
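The swap_ab tactics double the search space by exploiting the transpose identity (A·Bᵀ)ᵀ = B·Aᵀ: exchanging the operands and writing the output m-major yields the same matrix, which can fit the tile shapes better when m is small. A minimal check of the identity in plain torch (illustrative, not part of the diff):

import torch

m, n, k = 8, 16, 32
a = torch.randn(m, k)    # activations, (m, k)
b = torch.randn(n, k)    # weights, (n, k)

c = a @ b.T              # normal layout: (m, n)
c_swapped = (b @ a.T).T  # swap_ab layout, transposed back: (m, n)
torch.testing.assert_close(c, c_swapped)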
        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic,
        ) -> torch.Tensor:
            """
            Performs a BF16 dense GEMM using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
                    inputs[0]: Input tensor of shape (m, k), dtype: bf16/fp16.
                    inputs[1]: Weight tensor of shape (n, k), dtype: bf16/fp16.
                tactic: Tiling and cluster strategy, a tuple (mma_tiler_mn, cluster_shape_mn, swap_ab).

            Returns:
                torch.Tensor: Output tensor of shape (m, n).
            """
            if isinstance(tactic, tuple):
                mma_tiler_mn, cluster_shape_mn, swap_ab = tactic
            else:
                # Fall back to a safe default tactic when none was tuned.
                mma_tiler_mn, cluster_shape_mn, swap_ab = (128, 128), (1, 1), False

            a_tensor, b_tensor = inputs
            m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
            c_tensor = torch.empty((m, n),
                                   dtype=self.output_dtype,
                                   device="cuda")

            if swap_ab:
                # View the (m, n) output as (n, m) so the kernel writes it m-major.
                c_tensor = c_tensor.permute(1, 0)

            # Map torch dtypes to cutlass dtypes.
            if a_tensor.dtype == torch.bfloat16:
                ab_cutlass_dtype = cutlass.BFloat16
            elif a_tensor.dtype == torch.float16:
                ab_cutlass_dtype = cutlass.Float16
            else:
                ab_cutlass_dtype = cutlass.BFloat16

            if self.output_dtype == torch.bfloat16:
                c_cutlass_dtype = cutlass.BFloat16
            elif self.output_dtype == torch.float16:
                c_cutlass_dtype = cutlass.Float16
            elif self.output_dtype == torch.float32:
                c_cutlass_dtype = cutlass.Float32
            else:
                c_cutlass_dtype = cutlass.BFloat16

            a_ptr = make_ptr(ab_cutlass_dtype,
                             a_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)
            b_ptr = make_ptr(ab_cutlass_dtype,
                             b_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)
            c_ptr = make_ptr(c_cutlass_dtype,
                             c_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)

            torch_stream = torch.cuda.current_stream()
            stream = cuda.CUstream(torch_stream.cuda_stream)

            use_2cta_instrs = mma_tiler_mn[0] == 256
            cache_key = (
                ab_cutlass_dtype,
                c_cutlass_dtype,
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                swap_ab,
            )

            if swap_ab:
                kernel_a_ptr = b_ptr
                kernel_b_ptr = a_ptr
                kernel_m = n
                kernel_n = m
            else:
                kernel_a_ptr = a_ptr
                kernel_b_ptr = b_ptr
                kernel_m = m
                kernel_n = n

            if cache_key not in CuteDSLBF16BlackwellGemm.kernel_cache:
                gemm = Sm100PersistentDenseGemmKernelWrapper(
                    cutlass.Float32,
                    use_2cta_instrs,
                    mma_tiler_mn,
                    cluster_shape_mn,
                )
                hardware_info = cutlass.utils.HardwareInfo()
                max_active_clusters = hardware_info.get_max_active_clusters(
                    cluster_shape_mn[0] * cluster_shape_mn[1])

                compiled_gemm = cute.compile(
                    gemm,
                    kernel_m,
                    kernel_n,
                    k,
                    1,  # batch_size
                    kernel_a_ptr,
                    kernel_b_ptr,
                    c_ptr,
                    max_active_clusters,
                    stream,
                    swap_ab,
                )
                CuteDSLBF16BlackwellGemm.kernel_cache[cache_key] = compiled_gemm
            else:
                compiled_gemm = CuteDSLBF16BlackwellGemm.kernel_cache[cache_key]

            compiled_gemm(
                kernel_m,
                kernel_n,
                k,
                kernel_a_ptr,
                kernel_b_ptr,
                c_ptr,
                stream,
            )

            if swap_ab:
                # Undo the earlier permute so the caller sees an (m, n) tensor.
                c_tensor = c_tensor.permute(1, 0)
            return c_tensor
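Note that permute returns a view, so neither permute call in forward() copies data: the swap_ab path allocates the (m, n) output once, lets the kernel write through an (n, m) view of the same storage, and views it back before returning. A quick illustration of the view semantics (plain torch, illustrative only):

import torch

out = torch.empty(4, 8)
view = out.permute(1, 0)  # (8, 4) view; shares storage with out
view.fill_(1.0)           # writing through the view...
assert view.data_ptr() == out.data_ptr()
assert bool((out == 1.0).all())  # ...fills out itself; no copy was made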
    # input/weight: bf16/fp16, output: bf16/fp16/fp32
    @torch.library.custom_op("trtllm::cute_dsl_bf16_gemm_blackwell",
                             mutates_args=(),
                             device_types="cuda")
    def cute_dsl_bf16_gemm_blackwell(
        input: torch.Tensor,
        weight: torch.Tensor,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Dense GEMM for BF16/FP16 inputs on Blackwell.

        Args:
            input: Input tensor of shape (m, k), dtype bf16 or fp16.
            weight: Weight tensor of shape (n, k), dtype bf16 or fp16.
            output_dtype: Output tensor dtype (bf16, fp16, or fp32).

        Returns:
            Output tensor of shape (m, n).
        """
        tuner = AutoTuner.get()

        runner = CuteDSLBF16BlackwellGemm(output_dtype)
        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_bf16_gemm_blackwell",
            [runner],
            CuteDSLBF16BlackwellGemm.tuning_config,
            [input, weight],
        )
        return runner(
            inputs=[input, weight],
            tactic=best_tactic,
        )

    @torch.library.register_fake("trtllm::cute_dsl_bf16_gemm_blackwell")
    def _(
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
        output_dtype: torch.dtype,
    ):
        # mat_a: [m, k], mat_b: [n, k] -> output: [m, n]
        shape = [mat_a.shape[0], mat_b.shape[0]]
        return mat_a.new_empty(shape, dtype=output_dtype)
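Once registered, the op is reachable through the torch.ops namespace, and the fake implementation above gives it a meta kernel so it traces under torch.compile. A minimal usage sketch (illustrative; assumes an SM 100/103 GPU and a TensorRT-LLM build with the CUTLASS DSL available):

import torch

x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")    # (m, k)
w = torch.randn(11008, 4096, dtype=torch.bfloat16, device="cuda")  # (n, k)

# The AutoTuner picks the best (mma_tiler_mn, cluster_shape_mn, swap_ab) tactic.
y = torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell(x, w, torch.bfloat16)
assert y.shape == (128, 11008)  # (m, n)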
    class Sm100BlockScaledContiguousGroupedGemmRunner(TunableRunner):
        kernel_class = Sm100BlockScaledContiguousGroupedGemmKernel
        kernel_cache = dict()
File diff suppressed because it is too large