This commit is contained in:
yifeizhang-c 2026-01-13 21:01:06 +08:00 committed by GitHub
commit 3d19df2f87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 6357 additions and 34 deletions

View File

@ -121,9 +121,8 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
int64_t scaleSizeInBytes = mGemmRunner.getActScaleSize(m, b * n);
int64_t elementSize = scaleSizeInBytes / torch::elementSize(FP8_BLOCK_SCALING_SF_DTYPE);
int m_4_align = (m + 3) / 4 * 4;
at::Tensor scaleFP8SF = at::detail::empty_cuda({b, m_4_align, elementSize / b / m_4_align},
FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt);
at::Tensor scaleFP8SF = at::detail::empty_cuda(
{elementSize}, FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor
__nv_fp8_e4m3* act_buffer = reinterpret_cast<__nv_fp8_e4m3*>(valueE4M3.data_ptr());
float* act_scale_buffer = reinterpret_cast<float*>(scaleFP8SF.data_ptr());
@ -133,6 +132,13 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
auto* output_buffer = reinterpret_cast<__nv_bfloat16 const*>(self.data_ptr());
mGemmRunner.fp8CS1x128Reshape(act_buffer, act_scale_buffer, output_buffer, n, b, m, lda, stream);
// scaleFP8SF = scaleFP8SF[:, 0:num_n_blocks, 0:m_padded]
auto const num_n_blocks = (n + 127) / 128;
auto const act_scal_elesize = b * num_n_blocks * m_padded;
TORCH_CHECK(act_scal_elesize <= scaleFP8SF.numel(), "Scale tensor size mismatch. Expected at least ",
act_scal_elesize, " elements, got ", scaleFP8SF.numel());
scaleFP8SF = scaleFP8SF.slice(0, 0, act_scal_elesize).view({b, num_n_blocks, m_padded}).contiguous();
return {valueE4M3.slice(0, 0, b * m * n).view({b, m, n}), scaleFP8SF};
}
} // namespace torch_ext

View File

@ -130,6 +130,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
int const num_experts_per_node = num_experts_on_rank;
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
int64_t num_moe_inputs = static_cast<int64_t>(experts_per_token * num_rows);
TORCH_CHECK(num_moe_inputs <= std::numeric_limits<int32_t>::max(),
"num_moe_inputs exceeds int32 range (because we use int32 for expert_first_token_offset_tensor). "
"num_moe_inputs = ",
num_moe_inputs);
auto permuted_row_to_unpermuted_row_tensor
= torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
@ -226,6 +230,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
"Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
break;
}
expert_first_token_offset_tensor = expert_first_token_offset_tensor.to(torch::kInt32);
return std::make_tuple(permuted_row_to_unpermuted_row_tensor, permuted_token_selected_experts_tensor,
permuted_data_tensor, expert_first_token_offset_tensor, permuted_token_final_scales_tensor,
unpermuted_row_to_permuted_row_tensor);

View File

@ -5,13 +5,13 @@ import torch
from tensorrt_llm.logger import logger
from ..._utils import get_sm_version
from ..._utils import get_sm_version, is_sm_100f
from ...math_utils import ceil_div, pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
DynamicTensorSpec, OptimizationProfile, TunableRunner,
TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
from ..utils import (fp4_scale_infer_shape, fp8_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
@ -337,6 +337,10 @@ if IS_CUTLASS_DSL_AVAILABLE:
Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \
BlockwiseGemmKernel
from ..cute_dsl_kernels.blackwell.blockwise_gemm.contiguous_offset_grouped_gemm import \
BlockwiseContiguousGroupedGemmKernel
from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
Sm100BlockScaledPersistentDenseGemmKernel
from ..cute_dsl_kernels.blackwell.utils import make_ptr
@ -523,7 +527,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
**kwargs,
) -> torch.Tensor:
"""
Performs fp8 blockwise gemm operation using CuTe DSL.
Performs fp4 blockwise gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
@ -2179,3 +2183,733 @@ if IS_CUTLASS_DSL_AVAILABLE:
dtype=input_scale.dtype,
device=input_scale.device)
return output, output_scale
class CuteDSLFp8BlackwellLinear(TunableRunner):
kernel_class = BlockwiseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
)
def __init__(self):
super().__init__()
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
if not is_sm_100f():
logger.debug(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
)
return []
m = inputs[0].shape[0]
n = inputs[1].shape[0]
k = inputs[0].shape[1]
batch_size = 1
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
use_2cta_instrs_candi = [False, True]
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
cluster_shape_mn_candi = [
(1, 1),
(1, 2),
(1, 4),
(2, 1),
(2, 2),
(2, 4),
(4, 1),
(4, 2),
(4, 4),
]
return [
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
for use_2cta_instrs in use_2cta_instrs_candi
for mma_tiler_mn in mma_tiler_mn_candi
for cluster_shape_mn in cluster_shape_mn_candi
if self.__class__.kernel_class.can_implement(
cutlass.Float8E4M3FN, # ab_dtype,
cutlass.Float32, # acc_dtype,
cutlass.BFloat16, # c_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
k,
batch_size,
a_major,
b_major,
c_major,
)
]
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> torch.Tensor:
"""
Performs fp8 blockwise (deepgemm like) operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (m, k), dtype: fp8.
inputs[1]: Weight tensor of shape (n, k), dtype: fp8.
inputs[2]: Input scale factor tensor of shape (k // 128, m), dtype: fp32.
inputs[3]: Weight scale factor tensor of shape (n // 128, k // 128), dtype: fp32.
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
"""
if isinstance(tactic, tuple):
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
False,
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
m, n, k = a_tensor.shape[0], b_tensor.shape[0], b_tensor.shape[1]
sf_m = m
sf_k = ceil_div(k, 128)
sf_n = ceil_div(n, 128)
c_tensor = torch.empty(*(m, n),
dtype=torch.bfloat16,
device=a_tensor.device)
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
c_ptr = make_ptr(
cutlass.BFloat16,
c_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
# get stream
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache_key = (
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
)
if cache_key not in self.__class__.kernel_cache:
gemm = self.__class__.kernel_class(
cutlass.Float32, # acc_dtype,
use_2cta_instrs=use_2cta_instrs,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm.wrapper,
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
max_active_clusters=max_active_clusters,
stream=stream,
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
stream=stream,
)
return c_tensor
# a/b: fp8, scale: fp32, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_fp8_gemm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
) -> torch.Tensor:
tuner = AutoTuner.get()
runner = CuteDSLFp8BlackwellLinear()
inputs = [input, weight, input_scale, weight_scale]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_fp8_gemm_blackwell::gemm",
[runner],
runner.__class__.tuning_config,
inputs,
)
return runner(inputs, tactic=best_tactic)
@torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
):
# [m, k]
shape = [i for i in mat_a.shape]
# [n, k]
shape[-1] = mat_b.shape[-2]
# output is fixed as bf16
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
return ret
class CuteDSLFp8BlackwellBmm(TunableRunner):
kernel_class = BlockwiseGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 1, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
)
def __init__(self):
super().__init__()
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
if not is_sm_100f():
logger.debug(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
)
return []
# [b, m, k]
batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[
0].shape[2]
# [b, n, k]
n = inputs[1].shape[1]
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
use_2cta_instrs_candi = [False, True]
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
cluster_shape_mn_candi = [
(1, 1),
(1, 2),
(1, 4),
(2, 1),
(2, 2),
(2, 4),
(4, 1),
(4, 2),
(4, 4),
]
return [
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
for use_2cta_instrs in use_2cta_instrs_candi
for mma_tiler_mn in mma_tiler_mn_candi
for cluster_shape_mn in cluster_shape_mn_candi
if self.__class__.kernel_class.can_implement(
cutlass.Float8E4M3FN, # ab_dtype,
cutlass.Float32, # acc_dtype,
cutlass.BFloat16, # c_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
k,
batch_size,
a_major,
b_major,
c_major,
)
]
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> None:
"""
Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (batch_size, m, k), dtype: fp8.
inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: fp8.
inputs[2]: Input scale tensor of shape (batch_size, k // 128, pad_up(m, 4)), dtype: fp32.
inputs[3]: Weight scale tensor of shape (batch_size, n // 128, k // 128), dtype: fp32.
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (batch_size, m, n), dtype: bf16.
"""
if isinstance(tactic, tuple):
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
False,
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, c_tensor = inputs
batch_size = a_tensor.shape[0]
m = a_tensor.shape[1]
k = a_tensor.shape[2]
n = b_tensor.shape[1]
sf_m = pad_up(m, 4)
sf_k = ceil_div(k, 128)
sf_n = ceil_div(n, 128)
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
c_ptr = make_ptr(
cutlass.BFloat16,
c_tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
# get stream
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache_key = (
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
)
if cache_key not in self.__class__.kernel_cache:
gemm = self.__class__.kernel_class(
cutlass.Float32, # acc_dtype,
use_2cta_instrs=use_2cta_instrs,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm.wrapper,
m,
n,
k,
sf_m,
sf_n,
sf_k,
batch_size,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
max_active_clusters=max_active_clusters,
stream=stream,
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
batch_size,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
stream=stream,
)
# a/b: fp8, scale: fp32, out: bf16
@torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_fp8_bmm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
tuner = AutoTuner.get()
runner = CuteDSLFp8BlackwellBmm()
inputs = [input, weight, input_scale, weight_scale, out]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_fp8_bmm_blackwell::gemm",
[runner],
runner.__class__.tuning_config,
inputs,
)
runner(inputs, tactic=best_tactic)
@torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
out: torch.Tensor,
) -> None:
batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2]
n = mat_b.shape[1]
assert out.dtype == torch.bfloat16, "CuTe DSL fp8 bmm output dtype must be bf16"
assert out.shape == (batch_size, m,
n), "CuTe DSL fp8 bmm output shape is incorrect"
class CuteDSLFp8BlackwellGroupGemm(TunableRunner):
kernel_class = BlockwiseContiguousGroupedGemmKernel
kernel_cache = dict()
tuning_config = TuningConfig()
def __init__(self):
super().__init__()
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[int]:
if not is_sm_100f():
logger.debug(
f"CuteDSL: SM version {get_sm_version()} is not supported. "
f"CuteDSL FP8 Group Gemm only supports SM 100 family. Skipping all tactics."
)
return []
# [m, k]
m, k = inputs[0].shape[0], inputs[0].shape[1]
# [group_size, n, k]
group_size, n, k = inputs[1].shape[0], inputs[1].shape[1], inputs[
1].shape[2]
batch_size = 1
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
use_2cta_instrs_candi = [False, True]
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
cluster_shape_mn_candi = [
(1, 1),
(1, 2),
(1, 4),
(2, 1),
(2, 2),
(2, 4),
(4, 1),
(4, 2),
(4, 4),
]
return [
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
for use_2cta_instrs in use_2cta_instrs_candi
for mma_tiler_mn in mma_tiler_mn_candi
for cluster_shape_mn in cluster_shape_mn_candi
if self.__class__.kernel_class.can_implement(
cutlass.Float8E4M3FN, # ab_dtype,
cutlass.Float32, # acc_dtype,
cutlass.BFloat16, # c_dtype,
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
k,
batch_size,
a_major,
b_major,
c_major,
)
]
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> torch.Tensor:
"""
Performs fp8 grouped gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (m, k), dtype: fp8.
inputs[1]: Weight tensor of shape (group_size, n, k), dtype: fp8.
inputs[2]: Input scale tensor of shape (k // 128, m), dtype: fp32.
inputs[3]: Weight scale tensor of shape (group_size, n // 128, k // 128), dtype: fp32.
inputs[4]: Group offset tensor of shape (group_size + 1), dtype: int32.
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
"""
if isinstance(tactic, tuple):
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
False,
(128, 128),
(1, 1),
]
a, b, a_sf, b_sf, group_offset = inputs
m, k, n = a.shape[0], a.shape[1], b.shape[1]
group_size = b.shape[0]
sf_m = m
sf_n = ceil_div(n, 128)
sf_k = ceil_div(k, 128)
c = torch.empty(*(m, n), dtype=torch.bfloat16, device=a.device)
a_ptr = make_ptr(
cutlass.Float8E4M3FN,
a.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_ptr = make_ptr(
cutlass.Float8E4M3FN,
b.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
c_ptr = make_ptr(
cutlass.BFloat16,
c.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
a_sf_ptr = make_ptr(
cutlass.Float32,
a_sf.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
b_sf_ptr = make_ptr(
cutlass.Float32,
b_sf.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=16,
)
group_offset_ptr = make_ptr(
cutlass.Int32,
group_offset.data_ptr(),
cute.AddressSpace.gmem,
)
a_sf_tmp = a_sf.reshape((1, sf_k, sf_m))
a_sf_tmp = a_sf_tmp.permute(2, 1, 0)
mSFA = cute.runtime.from_dlpack(
a_sf_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=0)
# get stream
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
cache_key = (
use_2cta_instrs,
mma_tiler_mn,
cluster_shape_mn,
)
if cache_key not in self.__class__.kernel_cache:
gemm = self.__class__.kernel_class(
cutlass.Float32, # acc_dtype,
use_2cta_instrs=use_2cta_instrs,
mma_tiler_mn=mma_tiler_mn,
cluster_shape_mn=cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm.wrapper,
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
group_size,
a_ptr,
b_ptr,
mSFA,
b_sf_ptr,
c_ptr,
group_offset_ptr,
max_active_clusters,
stream,
)
self.__class__.kernel_cache[cache_key] = compiled_gemm
else:
compiled_gemm = self.__class__.kernel_cache[cache_key]
# launch gemm kernel
compiled_gemm(
m,
n,
k,
sf_m,
sf_n,
sf_k,
1, # batch
group_size,
a_ptr,
b_ptr,
mSFA,
b_sf_ptr,
c_ptr,
group_offset_ptr,
stream=stream,
)
return c
# a/b: fp8, scale: fp32, out: bf16
@torch.library.custom_op(
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_fp8_group_blockwise_gemm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_offset: torch.Tensor,
) -> torch.Tensor:
tuner = AutoTuner.get()
runner = CuteDSLFp8BlackwellGroupGemm()
inputs = [input, weight, input_scale, weight_scale, group_offset]
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell::gemm",
[runner],
runner.__class__.tuning_config,
inputs,
)
return runner(inputs, tactic=best_tactic)
@torch.library.register_fake(
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
group_offset: torch.Tensor,
) -> torch.Tensor:
m, k = mat_a.shape[0], mat_a.shape[1]
num_group, n, k = mat_b.shape[0], mat_b.shape[1], mat_b.shape[2]
return mat_a.new_empty((m, n), dtype=torch.bfloat16)

File diff suppressed because it is too large Load Diff

View File

@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]):
extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)
# cute dsl op configs
use_cute_dsl_blockscaling_mm: bool = False
use_cute_dsl_blockscaling_bmm: bool = False
_frozen: bool = field(default=False, init=False, repr=False)
# If true, ONLY the vision encoder part of the full model is loaded/executed.

View File

@ -255,6 +255,9 @@ class Attention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_key_value_heads * self.head_dim
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
qkv_shard_indices_mapping = {
"q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
"k":
@ -280,7 +283,8 @@ class Attention(nn.Module):
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm,
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping)
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])
@ -299,7 +303,8 @@ class Attention(nn.Module):
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization,
disable_deep_gemm=disable_deep_gemm,
use_custom_cublas_mm=use_custom_cublas_mm)
use_custom_cublas_mm=use_custom_cublas_mm,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
@ -684,6 +689,7 @@ def fp8_block_scaling_bmm_out(
mat2_scale: torch.Tensor,
out: torch.Tensor,
mat2_dequant: Optional[torch.Tensor] = None,
use_cute_dsl_blockscaling_bmm: bool = False,
) -> torch.Tensor:
sm_version = get_sm_version()
if sm_version == 90 or sm_version == 89:
@ -704,7 +710,17 @@ def fp8_block_scaling_bmm_out(
output)
out.copy_(output)
elif is_sm_100f(sm_version):
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
if use_cute_dsl_blockscaling_bmm:
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
mat1)
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
mat1_scale, mat2_scale,
out)
mat1_scale = None
else:
torch.bmm(mat1.transpose(0, 1),
mat2_dequant.transpose(1, 2),
out=out)
else:
raise NotImplementedError(f"SM{sm_version} is not supported")
@ -847,6 +863,9 @@ class MLA(nn.Module):
quant_config = config.get_quant_config()
self.quant_config = quant_config
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
if not self.is_lite:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
@ -856,7 +875,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
eps=rms_norm_eps,
@ -872,7 +892,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
else:
self.kv_a_proj_with_mqa = Linear(
hidden_size,
@ -882,7 +903,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
use_custom_cublas_mm=True,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_proj = Linear(
self.q_lora_rank,
@ -894,7 +916,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
self.q_b_proj = self.q_proj
self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@ -911,7 +934,8 @@ class MLA(nn.Module):
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
# This parameter will view into self.kv_b_proj.weight after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
# Used in forward_absorption only
@ -943,7 +967,8 @@ class MLA(nn.Module):
skip_create_weights_in_init=config.skip_create_weights_in_init,
reduce_output=reduce_output,
allreduce_strategy=config.allreduce_strategy,
force_dynamic_quantization=config.force_dynamic_quantization)
force_dynamic_quantization=config.force_dynamic_quantization,
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1:
@ -1079,7 +1104,7 @@ class MLA(nn.Module):
),
requires_grad=False,
)
if is_sm_100f():
if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm:
assert self.dtype == torch.bfloat16
self.k_b_proj_trans_dequant = nn.Parameter(
torch.empty(
@ -1871,6 +1896,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
),
lambda: self.mqa.mla_rope_generation(
fused_q,
@ -1948,6 +1974,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2003,6 +2030,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2058,6 +2086,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2125,6 +2154,7 @@ class MLA(nn.Module):
self.k_b_proj_trans_scale,
q_nope_out,
self.k_b_proj_trans_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(
@ -2201,6 +2231,7 @@ class MLA(nn.Module):
self.v_b_proj_scale,
attn_output.transpose(0, 1),
self.v_b_proj_dequant,
self.use_cute_dsl_blockscaling_bmm,
)
else:
raise NotImplementedError(

View File

@ -308,6 +308,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
dtype (Optional[torch.dtype]): Data type for the weights.
reduce_results (bool): Whether to reduce the results across devices.
model_config (ModelConfig): Configuration object for the model.
use_cute_dsl_fp8 (bool): Whether to use CuteDSL FP8 blockwise gemm.
"""
def __init__(
@ -328,6 +329,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
layer_idx: Optional[int] = None,
init_load_balancer: bool = True,
without_comm: bool = False,
use_cute_dsl_fp8: bool = False,
):
super().__init__(
routing_method=routing_method,
@ -354,6 +356,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
for key in [EventType.Main, EventType.MoeOutputMemset]:
if key not in self.event_dict:
self.event_dict[key] = torch.cuda.Event()
self.use_cute_dsl_fp8 = use_cute_dsl_fp8
def select_alltoall_method_type(self) -> AlltoallMethodType:
return AlltoallMethodType.NotEnabled
@ -600,22 +603,40 @@ class CuteDslFusedMoE(CutlassFusedMoE):
use_fp8_block_scaling=True,
)
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w3_w1_weight.view(weight_dtype),
a_sf=x_sf,
b_sf=self.quant_scales[0],
offset_array=expert_first_token_offset,
)
if is_sm_100f() and self.use_cute_dsl_fp8:
x = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
input=x,
weight=self.w3_w1_weight.view(weight_dtype),
input_scale=x_sf,
weight_scale=self.quant_scales[0],
group_offset=expert_first_token_offset,
)
else:
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w3_w1_weight.view(weight_dtype),
a_sf=x_sf,
b_sf=self.quant_scales[0],
offset_array=expert_first_token_offset,
)
x = swiglu_fused_moe(x)
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w2_weight.view(weight_dtype),
a_sf=x_sf,
b_sf=self.quant_scales[1],
offset_array=expert_first_token_offset,
)
if is_sm_100f() and self.use_cute_dsl_fp8:
x = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
input=x,
weight=self.w2_weight.view(weight_dtype),
input_scale=x_sf,
weight_scale=self.quant_scales[1],
group_offset=expert_first_token_offset,
)
else:
x = cute_dsl_fp8_group_blockwise_gemm_ref(
a=x,
b=self.w2_weight.view(weight_dtype),
a_sf=x_sf,
b_sf=self.quant_scales[1],
offset_array=expert_first_token_offset,
)
x = torch.ops.trtllm.moe_finalize_scale_op(
x,
None, # biases

View File

@ -740,10 +740,9 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):
if is_sm_100f():
if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
# TODO (@lmin): replace with cute_dsl gemm
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
input)
output = torch.ops.trtllm.fp8_block_scaling_gemm(
output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_input_fp8, module.weight, act_input_sf,
module.weight_scale)
else:

View File

@ -301,6 +301,16 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
return scale_shape * 2
def fp8_scale_infer_shape(input_shapes: List[List[int]]):
"""Calculate the dimensions of the fp8 scale tensor.
"""
input_shape = input_shapes[0]
assert len(input_shape) == 2 or len(input_shape) == 3
has_batch = len(input_shape) == 3
m = input_shape[-2]
return pad_up(m, 4) if has_batch else m
_enable_piecewise_cuda_graph = True

View File

@ -1048,7 +1048,7 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
@skip_pre_blackwell
@pytest.mark.parametrize(
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode, use_cute_dsl_fp8",
product(
[torch.bfloat16],
[72],
@ -1056,6 +1056,7 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
[2560],
[DefaultMoeRoutingMethod],
[MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
[False, True],
),
)
def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
@ -1064,6 +1065,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
hidden_size,
RoutingMethodCls,
WeightLoadingMode,
use_cute_dsl_fp8,
mapping=None):
SEQ_LEN = seq_len
HIDDEN_SIZE = hidden_size
@ -1153,6 +1155,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
reduce_results=True,
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
weight_loading_mode=WeightLoadingMode,
use_cute_dsl_fp8=use_cute_dsl_fp8,
)
fused_moe.cuda()
fused_moe.load_weights([weights])

View File

@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -109,6 +109,50 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(
not isSM100Family(),
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
(2048, 7168), (1024, 1024)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128, 4096],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n):
torch.random.manual_seed(0)
a = torch.randn((m, k), device='cuda', dtype=dtype) / k
b = torch.randn((n, k), device='cuda', dtype=dtype) / k
act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)
output_expected = a @ b.t()
with autotune():
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
# test Cute DSL kernel
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
diff = calc_diff(cute_dsl_output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(cute_dsl_output,
output_expected,
atol=1e-3,
rtol=1e-3)
@pytest.mark.skipif(
getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
reason="The test is for Hopper and Ada only. Current SM is %d." %
@ -171,6 +215,57 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
@pytest.mark.skipif(
not isSM100Family(),
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
"k, n",
[(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
)
@pytest.mark.parametrize(
"m",
[7, 64, 128],
)
@pytest.mark.parametrize(
"num_groups",
[4, 8, 16],
)
@pytest.mark.parametrize(
"dtype",
[torch.bfloat16],
)
def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
torch.random.manual_seed(0)
a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
device='cuda',
dtype=torch.float)
for i in range(num_groups):
b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])
output_expected = torch.einsum('mgk,gnk->gmn', a, b)
output = torch.empty((num_groups, m, n),
device='cuda',
dtype=torch.bfloat16)
# tune
with autotune():
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales,
b_scales, output)
# run the tuned kernel
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales,
b_scales, output)
diff = calc_diff(output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
valsB, dqSfsB, quantizeOutput, tileSize):
for mi in range(mM):

View File

@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import random
import pytest
import torch
from _torch.helpers import calc_diff, per_block_cast_to_fp8
from utils.util import getSMVersion, isSM100Family
from tensorrt_llm._torch.autotuner import autotune
@pytest.mark.skipif(
not isSM100Family(),
reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize("num_experts", [72])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("n", [2560])
@pytest.mark.parametrize("max_tokens_per_group", [10, 50, 100, 128, 256, 512, 1000, 1024])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cute_dsl_fp8_block_scale_group_gemm(dtype, num_experts, k, n, max_tokens_per_group):
random.seed(0)
torch.random.manual_seed(0)
group_m = []
for i in range(num_experts):
group_m.append(random.randint(0, max_tokens_per_group))
group_m = torch.tensor(group_m, dtype=torch.int32, device="cuda")
group_m_cum = torch.cumsum(group_m, dim=0)
group_offset = torch.cat([torch.zeros(1, dtype=torch.int32, device="cuda"), group_m_cum], dim=0)
group_offset = group_offset.to(torch.int32)
m = sum(group_m)
a = torch.randn((m, k), device="cuda", dtype=dtype) / k
b = torch.randn((num_experts, n, k), device="cuda", dtype=dtype) / k
output_expected = torch.zeros((m, n), device="cuda", dtype=dtype)
for i in range(num_experts):
start = group_offset[i]
end = group_offset[i + 1]
output_expected[start:end, :] = torch.einsum("mk,nk->mn", a[start:end, :], b[i, :, :])
a_fp8, a_scale = torch.ops.trtllm.fp8_quantize_1x128(a)
b_fp8 = torch.empty(num_experts, n, k, dtype=torch.float8_e4m3fn, device="cuda")
b_scale = torch.empty(
num_experts, math.ceil(n / 128), math.ceil(k / 128), dtype=torch.float32, device="cuda"
)
for i in range(num_experts):
cur_b, cur_b_scale = per_block_cast_to_fp8(b[i, :, :])
b_fp8[i, :, :] = cur_b
b_scale[i, :, :] = cur_b_scale
with autotune():
output = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
input=a_fp8,
weight=b_fp8,
input_scale=a_scale,
weight_scale=b_scale,
group_offset=group_offset,
)
output = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
input=a_fp8,
weight=b_fp8,
input_scale=a_scale,
weight_scale=b_scale,
group_offset=group_offset,
)
diff = calc_diff(output, output_expected)
assert diff < 1e-3
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)