[TRTLLM-9457][feat] Add cute dsl fp8 gemm for Blackwell (#10130)

Added FP8 cute dsl gemm and batch gemm.

Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>
This commit is contained in:
yifeizhang-c 2026-02-06 09:49:30 +08:00 committed by GitHub
parent 712dcd31a9
commit 5521c7b7e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 3439 additions and 37 deletions

View File

@ -121,9 +121,8 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
int64_t scaleSizeInBytes = mGemmRunner.getActScaleSize(m, b * n);
int64_t elementSize = scaleSizeInBytes / torch::elementSize(FP8_BLOCK_SCALING_SF_DTYPE);
int m_4_align = (m + 3) / 4 * 4;
at::Tensor scaleFP8SF = at::detail::empty_cuda({b, m_4_align, elementSize / b / m_4_align},
FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt);
at::Tensor scaleFP8SF = at::detail::empty_cuda(
{elementSize}, FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor
__nv_fp8_e4m3* act_buffer = reinterpret_cast<__nv_fp8_e4m3*>(valueE4M3.data_ptr());
float* act_scale_buffer = reinterpret_cast<float*>(scaleFP8SF.data_ptr());
@ -133,6 +132,13 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
auto* output_buffer = reinterpret_cast<__nv_bfloat16 const*>(self.data_ptr());
mGemmRunner.fp8CS1x128Reshape(act_buffer, act_scale_buffer, output_buffer, n, b, m, lda, stream);
// scaleFP8SF = scaleFP8SF[:, 0:num_n_blocks, 0:m_padded]
auto const num_n_blocks = (n + 127) / 128;
auto const act_scal_elesize = b * num_n_blocks * m_padded;
TORCH_CHECK(act_scal_elesize <= scaleFP8SF.numel(), "Scale tensor size mismatch. Expected at least ",
act_scal_elesize, " elements, got ", scaleFP8SF.numel());
scaleFP8SF = scaleFP8SF.slice(0, 0, act_scal_elesize).view({b, num_n_blocks, m_padded}).contiguous();
return {valueE4M3.slice(0, 0, b * m * n).view({b, m, n}), scaleFP8SF};
}
} // namespace torch_ext

View File

@ -81,14 +81,14 @@ def inplace_info():
1: "logits"
},
torch.ops.trtllm.moe_unpermute_inplace.default: {
2: "output"
1: "output"
},
torch.ops.trtllm.moe_output_memset_inplace.default: {
1: "input"
},
torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default:
{
6: "output"
1: "output"
},
torch.ops.trtllm.pp_recv_tensors.default: {
1: "tensors"
@ -96,6 +96,9 @@ def inplace_info():
torch.ops.trtllm.pp_send_tensors.default: {
1: "tensors"
},
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell.default: {
1: "output"
}
}
if IS_CUDA_TILE_AVAILABLE:
# cuda.tile availability depends on GPU capability thus runtime check.

View File

@ -6,13 +6,13 @@ import torch
from tensorrt_llm.logger import logger
from ..._utils import get_sm_version
from ..._utils import get_sm_version, is_sm_100f
from ...math_utils import ceil_div, pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
DynamicTensorSpec, OptimizationProfile, TunableRunner,
TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
from ..utils import (fp4_scale_infer_shape, fp8_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
@ -314,11 +314,13 @@ if IS_CUTLASS_DSL_AVAILABLE:
Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \
Sm100BlockwiseGemmKernel
from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
Sm100BlockScaledPersistentDenseGemmKernel
from ..cute_dsl_kernels.blackwell.utils import make_ptr
class CuteDSLNVFP4BlackwellLinear(TunableRunner):
class CuteDSLNVFP4BlackwellRunner(TunableRunner):
kernel_class = Sm100BlockScaledPersistentDenseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
@ -500,7 +502,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
**kwargs,
) -> torch.Tensor:
"""
Performs fp8 blockwise gemm operation using CuTe DSL.
Performs fp4 blockwise gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
@ -590,7 +592,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
stream = cuda.CUstream(torch_stream.cuda_stream)
cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
use_prefetch)
use_prefetch, self.use_tvm_ffi)
if swap_ab:
kernel_m = n
kernel_n = m
@ -770,7 +772,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
tuner = AutoTuner.get()
runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers,
runner = CuteDSLNVFP4BlackwellRunner(output_dtype, to_userbuffers,
use_tvm_ffi)
inputs = [input, weight, input_scale, weight_scale, alpha]
_, best_tactic = tuner.choose_one(
@ -2161,3 +2163,629 @@ if IS_CUTLASS_DSL_AVAILABLE:
dtype=input_scale.dtype,
device=input_scale.device)
return output, output_scale
class CuteDSLFp8BlackwellRunner(TunableRunner):
kernel_class = Sm100BlockwiseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
)
def __init__(self,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True):
super().__init__()
if output_dtype != torch.bfloat16:
raise ValueError(
f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}"
)
self.output_dtype = output_dtype
self.use_tvm_ffi = use_tvm_ffi
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
if not is_sm_100f():
logger.debug(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
)
return []
m = inputs[0].shape[0]
n = inputs[1].shape[0]
k = inputs[0].shape[1]
batch_size = 1
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
use_2cta_instrs_candi = [False, True]
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
cluster_shape_mn_candi = [
(1, 1),
(1, 2),
(1, 4),
(2, 1),
(2, 2),
(2, 4),
(4, 1),
(4, 2),
(4, 4),
]
return [
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
for use_2cta_instrs in use_2cta_instrs_candi
for mma_tiler_mn in mma_tiler_mn_candi
for cluster_shape_mn in cluster_shape_mn_candi
if self.__class__.kernel_class.can_implement(
cutlass.Float8E4M3FN, # ab_dtype,
cutlass.Float32, # acc_dtype,
cutlass.BFloat16, # c_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
k,
batch_size,
a_major,
b_major,
c_major,
)
]
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> torch.Tensor:
"""
Performs fp8 blockwise (deepgemm like) operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (m, k), dtype: fp8.
inputs[1]: Weight tensor of shape (n, k), dtype: fp8.
inputs[2]: Input scale factor tensor of shape (k // 128, m), dtype: fp32.
inputs[3]: Weight scale factor tensor of shape (n // 128, k // 128), dtype: fp32.
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
"""
if isinstance(tactic, tuple):
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
False,
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
m, n, k = a_tensor.shape[0], b_tensor.shape[0], b_tensor.shape[1]
sf_m = m
sf_k = ceil_div(k, 128)
sf_n = ceil_div(n, 128)
c_tensor = torch.empty(*(m, n),
dtype=torch.bfloat16,
device=a_tensor.device)
c_tmp = c_tensor.view((1, m, n))
c_tmp = c_tmp.permute(1, 2, 0)
if not self.use_tvm_ffi:
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
c_cute_tensor = cute.runtime.from_dlpack(
c_tmp).mark_layout_dynamic(leading_dim=1)
# get stream
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache_key = (
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
self.use_tvm_ffi,
)
if cache_key not in self.__class__.kernel_cache:
if self.use_tvm_ffi:
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
# Convert c_tensor to cute tensor for TVM FFI for env stream detection
c_cute_tensor = cute.runtime.from_dlpack(
c_tmp).mark_layout_dynamic(leading_dim=1)
stream = cute.runtime.make_fake_stream(
use_tvm_ffi_env_stream=True)
gemm = self.__class__.kernel_class(
cutlass.Float32, # acc_dtype,
use_2cta_instrs=use_2cta_instrs,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm.wrapper,
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_cute_tensor,
max_active_clusters=max_active_clusters,
stream=stream,
options=f"--opt-level 2 --enable-tvm-ffi"
if self.use_tvm_ffi else "--opt-level 2",
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
if self.use_tvm_ffi:
# call with torch pointer types and no need to pass stream.
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
a_tensor.data_ptr(),
b_tensor.data_ptr(),
a_sf_tensor.data_ptr(),
b_sf_tensor.data_ptr(),
c_tmp,
)
else:
# call with cute types and need to pass torch stream.
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_cute_tensor,
stream=stream,
)
return c_tensor
# a/b: fp8, scale: fp32, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_fp8_gemm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True,
) -> torch.Tensor:
if output_dtype != torch.bfloat16:
raise ValueError(
f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}"
)
if not is_sm_100f():
raise ValueError(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
)
tuner = AutoTuner.get()
runner = CuteDSLFp8BlackwellRunner(output_dtype=output_dtype,
use_tvm_ffi=use_tvm_ffi)
inputs = [input, weight, input_scale, weight_scale]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_fp8_gemm_blackwell::gemm",
[runner],
runner.__class__.tuning_config,
inputs,
)
return runner(inputs, tactic=best_tactic)
@torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True,
):
# [m, k]
shape = list(mat_a.shape)
# [n, k]
shape[-1] = mat_b.shape[-2]
# output is fixed as bf16
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
return ret
class CuteDSLFp8BlackwellBmmRunner(TunableRunner):
kernel_class = Sm100BlockwiseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 1, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 2, fp8_scale_infer_shape), ),
)
def __init__(self,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True):
super().__init__()
if output_dtype != torch.bfloat16:
raise ValueError(
f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}"
)
self.output_dtype = output_dtype
self.use_tvm_ffi = use_tvm_ffi
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
if not is_sm_100f():
logger.debug(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
)
return []
# [b, m, k]
batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[
0].shape[2]
# [b, n, k]
n = inputs[1].shape[1]
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
use_2cta_instrs_candi = [False, True]
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
cluster_shape_mn_candi = [
(1, 1),
(1, 2),
(1, 4),
(2, 1),
(2, 2),
(2, 4),
(4, 1),
(4, 2),
(4, 4),
]
return [
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
for use_2cta_instrs in use_2cta_instrs_candi
for mma_tiler_mn in mma_tiler_mn_candi
for cluster_shape_mn in cluster_shape_mn_candi
if self.__class__.kernel_class.can_implement(
cutlass.Float8E4M3FN, # ab_dtype,
cutlass.Float32, # acc_dtype,
cutlass.BFloat16, # c_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
k,
batch_size,
a_major,
b_major,
c_major,
)
]
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> None:
"""
Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (batch_size, m, k), dtype: fp8.
inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: fp8.
inputs[2]: Input scale tensor of shape (batch_size, k // 128, pad_up(m, 4)), dtype: fp32.
inputs[3]: Weight scale tensor of shape (batch_size, n // 128, k // 128), dtype: fp32.
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (batch_size, m, n), dtype: bf16.
"""
if isinstance(tactic, tuple):
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
False,
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, c_tensor = inputs
c_tmp = c_tensor.permute(1, 2, 0)
batch_size = a_tensor.shape[0]
m = a_tensor.shape[1]
k = a_tensor.shape[2]
n = b_tensor.shape[1]
sf_m = pad_up(m, 4)
sf_k = ceil_div(k, 128)
sf_n = ceil_div(n, 128)
if not self.use_tvm_ffi:
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
c_cute_tensor = cute.runtime.from_dlpack(
c_tmp).mark_layout_dynamic(leading_dim=1)
# get stream
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache_key = (
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
self.use_tvm_ffi,
)
if cache_key not in self.__class__.kernel_cache:
if self.use_tvm_ffi:
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
# Convert c_tensor to cute tensor for TVM FFI for env stream detection)
c_cute_tensor = cute.runtime.from_dlpack(
c_tmp).mark_layout_dynamic(leading_dim=1)
# make faked stream for TVM FFI
stream = cute.runtime.make_fake_stream(
use_tvm_ffi_env_stream=True)
gemm = self.__class__.kernel_class(
cutlass.Float32, # acc_dtype,
use_2cta_instrs=use_2cta_instrs,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm.wrapper,
m,
n,
k,
sf_m,
sf_n,
sf_k,
batch_size,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_cute_tensor,
max_active_clusters=max_active_clusters,
stream=stream,
options=f"--opt-level 2 --enable-tvm-ffi"
if self.use_tvm_ffi else "--opt-level 2",
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
if self.use_tvm_ffi:
# call with torch pointer types and no need to pass stream.
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
batch_size,
a_tensor.data_ptr(),
b_tensor.data_ptr(),
a_sf_tensor.data_ptr(),
b_sf_tensor.data_ptr(),
c_tmp,
)
else:
# call with cute types and need to pass torch stream.
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
batch_size,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_cute_tensor,
stream=stream,
)
# a/b: fp8, scale: fp32, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell",
mutates_args=("output", ),
device_types="cuda")
def cute_dsl_fp8_bmm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True,
) -> None:
if output_dtype != torch.bfloat16:
raise ValueError(
f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}"
)
if not is_sm_100f():
raise ValueError(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
)
tuner = AutoTuner.get()
runner = CuteDSLFp8BlackwellBmmRunner(output_dtype=output_dtype,
use_tvm_ffi=use_tvm_ffi)
inputs = [input, weight, input_scale, weight_scale, output]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_fp8_bmm_blackwell::gemm",
[runner],
runner.__class__.tuning_config,
inputs,
)
runner(inputs, tactic=best_tactic)
@torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
output: torch.Tensor,
output_dtype: torch.dtype = torch.bfloat16,
use_tvm_ffi: bool = True,
) -> None:
batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2]
n = mat_b.shape[1]
assert output.dtype == torch.bfloat16, "CuTe DSL fp8 bmm output dtype must be bf16"
assert output.shape == (batch_size, m,
n), "CuTe DSL fp8 bmm output shape is incorrect"

View File

@ -25,7 +25,7 @@ from ..utils import (ActivationType, fp4_scale_infer_shape,
if IS_CUTLASS_DSL_AVAILABLE:
from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
CuteDSLNVFP4BlackwellLinear
CuteDSLNVFP4BlackwellRunner
# Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
@ -819,7 +819,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
"Please add other backends to allowed_backends.")
else:
# SM version OK, check if CuteDSL supports the current shape
cutedsl_runner = CuteDSLNVFP4BlackwellLinear(
cutedsl_runner = CuteDSLNVFP4BlackwellRunner(
self.output_dtype)
cutedsl_tactics = cutedsl_runner.get_valid_tactics(
inputs, profile)
@ -878,7 +878,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
self.output_dtype)(inputs,
tactic=sub_tactic)
elif backend == "cutedsl":
return CuteDSLNVFP4BlackwellLinear(
return CuteDSLNVFP4BlackwellRunner(
self.output_dtype, self.to_userbuffers)(inputs,
tactic=sub_tactic)
else:

File diff suppressed because it is too large Load Diff

View File

@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]):
extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
# cute dsl op configs
use_cute_dsl_blockscaling_mm: bool = False
use_cute_dsl_blockscaling_bmm: bool = False
_frozen: bool = field(default=False, init=False, repr=False)
# If true, ONLY the vision encoder part of the full model is loaded/executed.

View File

@ -672,6 +672,7 @@ class DeepseekV3Linear(Linear):
reduce_output: bool = True, # ROW parallel only
skip_create_weights_in_init: bool = False,
use_custom_cublas_mm: bool = False,
use_cute_dsl_blockscaling_mm: bool = False,
lora: Optional[LoraLayer] = None,
):
super().__init__(
@ -688,6 +689,7 @@ class DeepseekV3Linear(Linear):
skip_create_weights_in_init,
use_custom_cublas_mm,
lora,
use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
)
def apply_linear(self,
@ -748,7 +750,10 @@ class DeepseekV3Attention(MLA):
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_custom_cublas_mm=True)
use_custom_cublas_mm=True,
use_cute_dsl_blockscaling_mm=model_config.
use_cute_dsl_blockscaling_mm,
)
class DeepseekV32Attention(MLA):
@ -925,6 +930,7 @@ class Deepseekv3MoE(nn.Module):
config = model_config.pretrained_config
self.top_k = top_k
self.use_dp = model_config.mapping.enable_attention_dp
self.use_cute_dsl_blockscaling_mm = model_config.use_cute_dsl_blockscaling_mm
gate_cls = DeepseekV3Gate
if hasattr(model_config.pretrained_config, "gate_cls"):
gate_cls = model_config.pretrained_config.gate_cls
@ -977,7 +983,9 @@ class Deepseekv3MoE(nn.Module):
dtype=dtype,
config=model_config,
overridden_tp_size=shared_tp_size,
reduce_output=False)
reduce_output=False,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
)
self.allreduce = None
if not self.use_dp and self.mapping.tp_size > 1:
@ -1262,13 +1270,17 @@ class DeepseekV3DecoderLayer(DecoderLayer):
self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp
self.mlp = GatedMLP(hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
dtype=config.torch_dtype,
config=model_config,
overridden_tp_size=self.mlp_tp_size,
reduce_output=has_mlp_tp)
self.mlp = GatedMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
bias=False,
dtype=config.torch_dtype,
config=model_config,
overridden_tp_size=self.mlp_tp_size,
reduce_output=has_mlp_tp,
use_cute_dsl_blockscaling_mm=model_config.
use_cute_dsl_blockscaling_mm,
)
self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
eps=config.rms_norm_eps,
@ -1564,6 +1576,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
dtype=config.torch_dtype,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_cute_dsl_blockscaling_mm=model_config.
use_cute_dsl_blockscaling_mm,
)
else:
self.eh_proj = Linear(
@ -1576,6 +1590,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
reduce_output=True,
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
use_cute_dsl_blockscaling_mm=model_config.
use_cute_dsl_blockscaling_mm,
)
self.shared_head = DeepseekV3MTPHead(model_config)

View File

@ -255,6 +255,9 @@ class Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
qkv_shard_indices_mapping = {
"q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
"k":
@ -280,7 +283,8 @@ class Attention(nn.Module):
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm,
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping)
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
@ -299,7 +303,8 @@ class Attention(nn.Module):
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm)
use_custom_cublas_mm=use_custom_cublas_mm,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
@ -686,6 +691,7 @@ def fp8_block_scaling_bmm_out(
mat2_scale: torch.Tensor,
out: torch.Tensor,
mat2_dequant: Optional[torch.Tensor] = None,
use_cute_dsl_blockscaling_bmm: bool = False,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90 or sm_version == 89:
@ -706,7 +712,17 @@ def fp8_block_scaling_bmm_out(
output)
out.copy_(output)
elif is_sm_100f(sm_version):
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
if use_cute_dsl_blockscaling_bmm:
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
mat1_scale, mat2_scale,
out)
mat1_scale = None
else:
torch.bmm(mat1.transpose(0, 1),
mat2_dequant.transpose(1, 2),
out=out)
else:
raise NotImplementedError(f"SM{sm_version} is not supported")
@ -851,6 +867,9 @@ class MLA(nn.Module):
quant_config = config.get_quant_config()
self.quant_config = quant_config
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
if not self.is_lite:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
@ -860,7 +879,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
eps=rms_norm_eps,
@ -876,7 +896,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
else:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
@ -886,7 +907,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_proj = Linear(
self.q_lora_rank,
@ -898,7 +920,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_b_proj = self.q_proj
self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@ -915,7 +938,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
# This parameter will view into self.kv_b_proj.weight after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
# Used in forward_absorption only
@ -947,7 +971,8 @@ class MLA(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
reduce_output=reduce_output,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
@ -1083,7 +1108,7 @@ class MLA(nn.Module):
),
requires_grad=False,
)
if is_sm_100f():
if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm:
assert self.dtype == torch.bfloat16
self.k_b_proj_trans_dequant = nn.Parameter(
torch.empty(
@ -1875,6 +1900,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
),
lambda: self.mqa.mla_rope_generation(
fused_q,
@ -1952,6 +1978,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2007,6 +2034,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2062,6 +2090,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2129,6 +2158,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2205,6 +2235,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(

View File

@ -982,10 +982,9 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
if is_sm_100f():
if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_input_fp8, module.weight, act_input_sf,
module.weight_scale)
else:

View File

@ -364,7 +364,12 @@ class ModelLoader:
use_low_precision_moe_combine=self.llm_args.moe_config.
use_low_precision_moe_combine,
nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
allowed_backends)
allowed_backends,
use_cute_dsl_blockscaling_mm=self.llm_args.
use_cute_dsl_blockscaling_mm,
use_cute_dsl_blockscaling_bmm=self.llm_args.
use_cute_dsl_blockscaling_bmm,
)
# Only pass model_kwargs if it's explicitly set (not None)
if self.llm_args.model_kwargs is not None:

View File

@ -301,6 +301,16 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
return scale_shape * 2
def fp8_scale_infer_shape(input_shapes: List[List[int]]):
"""Calculate the dimensions of the fp8 scale tensor.
"""
input_shape = input_shapes[0]
assert len(input_shape) == 2 or len(input_shape) == 3
has_batch = len(input_shape) == 3
m = input_shape[-2]
return pad_up(m, 4) if has_batch else m
_enable_piecewise_cuda_graph = True

View File

@ -3050,6 +3050,18 @@ class TorchLlmArgs(BaseLlmArgs):
"Only enable it if you intend to use this feature.",
status="prototype")
# fp8 cute dsl configs
use_cute_dsl_blockscaling_mm: bool = Field(
default=False,
description="If true, use CuTe DSL fp8 blockscaling mm implementation.",
status="prototype",
)
use_cute_dsl_blockscaling_bmm: bool = Field(
default=False,
description="If true, use CuTe DSL fp8 blockscaling bmm implementation.",
status="prototype",
)
# PrivateVars
_quant_config: Optional[QuantConfig] = PrivateAttr(default=None)

View File

@ -1533,6 +1533,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
use_cute_dsl_blockscaling_mm=True,
use_cute_dsl_blockscaling_bmm=True,
)
if fp8kv:
@ -1695,6 +1697,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
torch_compile_config=torch_compile_config,
moe_config=MoeConfig(backend="CUTEDSL"),
use_cute_dsl_blockscaling_mm=True,
use_cute_dsl_blockscaling_bmm=True,
)
if fp8kv:

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -109,6 +109,54 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(
not isSM100Family(),
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
(2048, 7168), (1024, 1024)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128, 4096],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n, use_tvm_ffi):
torch.random.manual_seed(0)
a = torch.randn((m, k), device='cuda', dtype=dtype) / k
b = torch.randn((n, k), device='cuda', dtype=dtype) / k
act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)
output_expected = a @ b.t()
with autotune():
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)
# test Cute DSL kernel
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)
diff = calc_diff(cute_dsl_output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(cute_dsl_output,
output_expected,
atol=1e-3,
rtol=1e-3)
@pytest.mark.skipif(
getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
reason="The test is for Hopper and Ada only. Current SM is %d." %
@ -171,6 +219,69 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(
not isSM100Family(),
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128],
)
@pytest.mark.parametrize(
"num_groups",
[4, 8, 16],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
@pytest.mark.parametrize(
"use_tvm_ffi",
[True, False],
)
def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups, use_tvm_ffi):
torch.random.manual_seed(0)
a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
device='cuda',
dtype=torch.float)
for i in range(num_groups):
b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])
output_expected = torch.einsum('mgk,gnk->gmn', a, b)
output = torch.empty((num_groups, m, n),
device='cuda',
dtype=torch.bfloat16)
# tune
with autotune():
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
b_fp8,
a_scales,
b_scales,
output,
use_tvm_ffi=use_tvm_ffi)
# run the tuned kernel
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
b_fp8,
a_scales,
b_scales,
output,
use_tvm_ffi=use_tvm_ffi)
diff = calc_diff(output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
valsB, dqSfsB, quantizeOutput, tileSize):
for mi in range(mM):

View File

@ -239,6 +239,14 @@ methods:
annotation: Optional[Dict[str, Any]]
default: null
status: prototype
use_cute_dsl_blockscaling_mm:
annotation: bool
default: False
status: prototype
use_cute_dsl_blockscaling_bmm:
annotation: bool
default: False
status: prototype
return_annotation: None
generate:
parameters: