cuteDSL dense gemm bf16

Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
Author: Faraz Khoubsirat
Date: 2025-12-16 20:40:23 +00:00
Parent: bd13957e70
Commit: 53c6f63fa8
GPG Key ID: 15733A5323348457
3 changed files with 1435 additions and 1 deletion


@@ -34,7 +34,9 @@ if IS_FLASHINFER_AVAILABLE:
    ]
if IS_CUTLASS_DSL_AVAILABLE:
    from .cute_dsl_custom_ops import cute_dsl_nvfp4_gemm_blackwell
    from .cute_dsl_custom_ops import (cute_dsl_bf16_gemm_blackwell,
                                      cute_dsl_nvfp4_gemm_blackwell)
    __all__ += [
        'cute_dsl_nvfp4_gemm_blackwell',
        'cute_dsl_bf16_gemm_blackwell',
    ]


@@ -354,6 +354,8 @@ if IS_CUTLASS_DSL_AVAILABLE:
        Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
    from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
        Sm100BlockScaledPersistentDenseGemmKernel
    from ..cute_dsl_kernels.blackwell.dense_gemm_persistent import (
        Sm100PersistentDenseGemmKernel, Sm100PersistentDenseGemmKernelWrapper)
    from ..cute_dsl_kernels.blackwell.utils import make_ptr
    class CuteDSLNVFP4BlackwellLinear(TunableRunner):
@@ -835,6 +837,295 @@ if IS_CUTLASS_DSL_AVAILABLE:
        ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
        return ret
    class CuteDSLBF16BlackwellGemm(TunableRunner):
        """TunableRunner for BF16 dense GEMM on Blackwell using cuteDSL."""
        kernel_class = Sm100PersistentDenseGemmKernel
        kernel_cache = dict()
        tuning_config = TuningConfig(dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ), )
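        # Editorial note (not part of this commit): DynamicTensorSpec(0, 0, ...) marks
        # dim 0 of input 0 (the token/m dimension) as dynamic; the two helper names
        # suggest m is bucketed to powers of two, so tuned tactics are reused across
        # nearby shapes instead of re-profiling every distinct m.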
        def __init__(self, output_dtype: torch.dtype):
            super().__init__()
            self.output_dtype = output_dtype
            if (sm_version := get_sm_version()) not in (100, 103):
                raise ValueError(
                    f"SM version {sm_version} is not supported for CuteDSLBF16BlackwellGemm; only SM 100/103 (Blackwell) is supported"
                )
        def __hash__(self):
            return hash((self.output_dtype, ))
        def __eq__(self, other):
            if not isinstance(other, CuteDSLBF16BlackwellGemm):
                return False
            return self.output_dtype == other.output_dtype
        def get_valid_tactics(
                self,
                inputs: List[torch.Tensor],
                profile: OptimizationProfile,
                **kwargs,
        ) -> List[Tuple[Tuple[int, int], Tuple[int, int], bool]]:
            assert inputs[0].dim() == 2
            assert inputs[1].dim() == 2
            m = inputs[0].shape[0]
            n = inputs[1].shape[0]
            k = inputs[0].shape[1]
            batch_size = 1
            a_major = "k"
            b_major = "k"
            # Determine cutlass dtype from torch dtype
            if inputs[0].dtype == torch.bfloat16:
                ab_dtype = cutlass.BFloat16
            elif inputs[0].dtype == torch.float16:
                ab_dtype = cutlass.Float16
            else:
                ab_dtype = cutlass.BFloat16
            if self.output_dtype == torch.bfloat16:
                c_dtype = cutlass.BFloat16
            elif self.output_dtype == torch.float16:
                c_dtype = cutlass.Float16
            elif self.output_dtype == torch.float32:
                c_dtype = cutlass.Float32
            else:
                c_dtype = cutlass.BFloat16
            mma_tiler_mn_candidates = [
                (256, 128),
                (128, 128),
                (128, 256),
                (256, 256),
                (256, 64),
                (128, 64),
            ]
            cluster_shape_mn_candidates = [
                (1, 1),
                (1, 2),
                (1, 4),
                (2, 1),
                (2, 2),
                (2, 4),
                (4, 1),
                (4, 2),
                (4, 4),
            ]
            swap_ab_candidates = [True, False]
            valid_tactics = []
            for swap_ab in swap_ab_candidates:
                for mma_tiler_mn in mma_tiler_mn_candidates:
                    for cluster_shape_mn in cluster_shape_mn_candidates:
                        if swap_ab:
                            c_major = "m"
                            kernel_m = n
                            kernel_n = m
                        else:
                            c_major = "n"
                            kernel_m = m
                            kernel_n = n
                        use_2cta_instrs = mma_tiler_mn[0] == 256
                        kernel = Sm100PersistentDenseGemmKernel(
                            cutlass.Float32,
                            use_2cta_instrs,
                            mma_tiler_mn,
                            cluster_shape_mn,
                        )
                        if kernel.can_implement(
                                (kernel_m, kernel_n, k, batch_size),
                                ab_dtype,
                                c_dtype,
                                a_major,
                                b_major,
                                c_major,
                        ):
                            valid_tactics.append(
                                (mma_tiler_mn, cluster_shape_mn, swap_ab))
            return valid_tactics
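        # Editorial note (not part of this commit): each tactic is the tuple
        # (mma_tiler_mn, cluster_shape_mn, swap_ab); the AutoTuner benchmarks the
        # candidates returned above and hands the fastest one to forward() below.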
        def forward(
                self,
                inputs: List[torch.Tensor],
                tactic,
        ) -> torch.Tensor:
            """
            Performs a BF16 dense GEMM using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
                    inputs[0]: Input tensor of shape (m, k), dtype: bf16/fp16.
                    inputs[1]: Weight tensor of shape (n, k), dtype: bf16/fp16.
                tactic: Tiling and cluster strategy, tuple (mma_tiler_mn, cluster_shape_mn, swap_ab).

            Returns:
                torch.Tensor: Output tensor of shape (m, n).
            """
            if isinstance(tactic, tuple):
                mma_tiler_mn, cluster_shape_mn, swap_ab = tactic
            else:
                mma_tiler_mn, cluster_shape_mn, swap_ab = [
                    (128, 128),
                    (1, 1),
                    False,
                ]
            a_tensor, b_tensor = inputs
            m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
            c_tensor = torch.empty(*(m, n),
                                   dtype=self.output_dtype,
                                   device="cuda")
            if swap_ab:
                c_tensor = c_tensor.permute(1, 0)
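            # Editorial note (not part of this commit): with swap_ab the kernel solves the
            # transposed problem (A/B and M/N exchanged), so the (m, n) row-major buffer is
            # handed to it as an (n, m) view; that view is m-major from the kernel's point
            # of view, matching c_major = "m" above, and it is permuted back before return.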
            # Determine cutlass dtypes
            if a_tensor.dtype == torch.bfloat16:
                ab_cutlass_dtype = cutlass.BFloat16
            elif a_tensor.dtype == torch.float16:
                ab_cutlass_dtype = cutlass.Float16
            else:
                ab_cutlass_dtype = cutlass.BFloat16
            if self.output_dtype == torch.bfloat16:
                c_cutlass_dtype = cutlass.BFloat16
            elif self.output_dtype == torch.float16:
                c_cutlass_dtype = cutlass.Float16
            elif self.output_dtype == torch.float32:
                c_cutlass_dtype = cutlass.Float32
            else:
                c_cutlass_dtype = cutlass.BFloat16
            a_ptr = make_ptr(ab_cutlass_dtype,
                             a_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)
            b_ptr = make_ptr(ab_cutlass_dtype,
                             b_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)
            c_ptr = make_ptr(c_cutlass_dtype,
                             c_tensor.data_ptr(),
                             cute.AddressSpace.gmem,
                             assumed_align=16)
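            # Editorial note (not part of this commit): assumed_align=16 promises the kernel
            # 16-byte-aligned global-memory pointers; freshly allocated PyTorch tensors
            # normally satisfy this, but arbitrarily sliced views may not.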
            torch_stream = torch.cuda.current_stream()
            stream = cuda.CUstream(torch_stream.cuda_stream)
            use_2cta_instrs = mma_tiler_mn[0] == 256
            CACHE_KEY = (
                ab_cutlass_dtype,
                c_cutlass_dtype,
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                swap_ab,
            )
            if swap_ab:
                kernel_a_ptr = b_ptr
                kernel_b_ptr = a_ptr
                kernel_m = n
                kernel_n = m
            else:
                kernel_a_ptr = a_ptr
                kernel_b_ptr = b_ptr
                kernel_m = m
                kernel_n = n
            if CACHE_KEY not in CuteDSLBF16BlackwellGemm.kernel_cache:
                gemm = Sm100PersistentDenseGemmKernelWrapper(
                    cutlass.Float32,
                    use_2cta_instrs,
                    mma_tiler_mn,
                    cluster_shape_mn,
                )
                hardware_info = cutlass.utils.HardwareInfo()
                max_active_clusters = hardware_info.get_max_active_clusters(
                    cluster_shape_mn[0] * cluster_shape_mn[1])
                compiled_gemm = cute.compile(
                    gemm,
                    kernel_m,
                    kernel_n,
                    k,
                    1,  # batch_size
                    kernel_a_ptr,
                    kernel_b_ptr,
                    c_ptr,
                    max_active_clusters,
                    stream,
                    swap_ab,
                )
                CuteDSLBF16BlackwellGemm.kernel_cache[CACHE_KEY] = compiled_gemm
            else:
                compiled_gemm = CuteDSLBF16BlackwellGemm.kernel_cache[CACHE_KEY]
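            # Editorial note (not part of this commit): the cache key deliberately excludes
            # m/n/k, so one compiled kernel per (dtype, tile, cluster, swap_ab) combination
            # is reused and only re-launched with the current problem sizes below.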
            compiled_gemm(
                kernel_m,
                kernel_n,
                k,
                kernel_a_ptr,
                kernel_b_ptr,
                c_ptr,
                stream,
            )
            if swap_ab:
                c_tensor = c_tensor.permute(1, 0)
            return c_tensor
    # input/weight: bf16/fp16, output: bf16/fp16/fp32
    @torch.library.custom_op("trtllm::cute_dsl_bf16_gemm_blackwell",
                             mutates_args=(),
                             device_types="cuda")
    def cute_dsl_bf16_gemm_blackwell(
        input: torch.Tensor,
        weight: torch.Tensor,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:
        """
        Dense GEMM for BF16/FP16 inputs on Blackwell.

        Args:
            input: Input tensor of shape (m, k), dtype bf16 or fp16
            weight: Weight tensor of shape (n, k), dtype bf16 or fp16
            output_dtype: Output tensor dtype (bf16, fp16, or fp32)

        Returns:
            Output tensor of shape (m, n)
        """
        tuner = AutoTuner.get()
        runner = CuteDSLBF16BlackwellGemm(output_dtype)
        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_bf16_gemm_blackwell",
            [runner],
            CuteDSLBF16BlackwellGemm.tuning_config,
            [input, weight],
        )
        return runner(
            inputs=[input, weight],
            tactic=best_tactic,
        )
    @torch.library.register_fake("trtllm::cute_dsl_bf16_gemm_blackwell")
    def _(
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
        output_dtype: torch.dtype,
    ):
        # mat_a: [m, k], mat_b: [n, k] -> output: [m, n]
        shape = [mat_a.shape[0], mat_b.shape[0]]
        return mat_a.new_empty(shape, dtype=output_dtype)
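    # Editorial usage sketch (not part of this commit; shapes and dtypes are assumed):
    # once registered, the op is callable through torch.ops, and the register_fake
    # above lets shape/dtype inference run without launching the kernel.
    #
    #     x = torch.randn(128, 4096, dtype=torch.bfloat16, device="cuda")    # (m, k)
    #     w = torch.randn(11008, 4096, dtype=torch.bfloat16, device="cuda")  # (n, k)
    #     y = torch.ops.trtllm.cute_dsl_bf16_gemm_blackwell(x, w, torch.bfloat16)
    #     assert y.shape == (128, 11008)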
    class Sm100BlockScaledContiguousGroupedGemmRunner(TunableRunner):
        kernel_class = Sm100BlockScaledContiguousGroupedGemmKernel
        kernel_cache = dict()

File diff suppressed because it is too large.