Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
Commit 3d19df2f87: Merge 0ed300bb2c into c1b0b7350f
@ -121,9 +121,8 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
    int64_t scaleSizeInBytes = mGemmRunner.getActScaleSize(m, b * n);
    int64_t elementSize = scaleSizeInBytes / torch::elementSize(FP8_BLOCK_SCALING_SF_DTYPE);
    int m_4_align = (m + 3) / 4 * 4;
    at::Tensor scaleFP8SF = at::detail::empty_cuda({b, m_4_align, elementSize / b / m_4_align},
        FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt);
    at::Tensor scaleFP8SF = at::detail::empty_cuda(
        {elementSize}, FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor

    __nv_fp8_e4m3* act_buffer = reinterpret_cast<__nv_fp8_e4m3*>(valueE4M3.data_ptr());
    float* act_scale_buffer = reinterpret_cast<float*>(scaleFP8SF.data_ptr());
@ -133,6 +132,13 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
    auto* output_buffer = reinterpret_cast<__nv_bfloat16 const*>(self.data_ptr());
    mGemmRunner.fp8CS1x128Reshape(act_buffer, act_scale_buffer, output_buffer, n, b, m, lda, stream);

    // scaleFP8SF = scaleFP8SF[:, 0:num_n_blocks, 0:m_padded]
    auto const num_n_blocks = (n + 127) / 128;
    auto const act_scal_elesize = b * num_n_blocks * m_padded;
    TORCH_CHECK(act_scal_elesize <= scaleFP8SF.numel(), "Scale tensor size mismatch. Expected at least ",
        act_scal_elesize, " elements, got ", scaleFP8SF.numel());
    scaleFP8SF = scaleFP8SF.slice(0, 0, act_scal_elesize).view({b, num_n_blocks, m_padded}).contiguous();

    return {valueE4M3.slice(0, 0, b * m * n).view({b, m, n}), scaleFP8SF};
}
} // namespace torch_ext
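For reference, the scale-tensor bookkeeping above can be summarized in a few lines of Python. This sketch is not part of the patch; the helper name and the pad-to-4 assumption for m_padded are illustrative only.

```python
import math

def fp8_1x128_scale_shape(b: int, m: int, n: int):
    """Expected scale shape returned for an (m, b, n) bf16 input, per the reshape above."""
    num_n_blocks = math.ceil(n / 128)   # one fp32 scale per 1x128 block along n
    m_padded = math.ceil(m / 4) * 4     # assumed pad_up(m, 4), matching m_4_align
    return (b, num_n_blocks, m_padded)

assert fp8_1x128_scale_shape(b=8, m=7, n=512) == (8, 4, 8)
```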
@ -130,6 +130,10 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
    int const num_experts_per_node = num_experts_on_rank;
    auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
    int64_t num_moe_inputs = static_cast<int64_t>(experts_per_token * num_rows);
    TORCH_CHECK(num_moe_inputs <= std::numeric_limits<int32_t>::max(),
        "num_moe_inputs exceeds int32 range (because we use int32 for expert_first_token_offset_tensor). "
        "num_moe_inputs = ",
        num_moe_inputs);

    auto permuted_row_to_unpermuted_row_tensor
        = torch::empty({num_moe_inputs}, torch::dtype(torch::kInt32).device(torch::kCUDA).requires_grad(false));
@ -226,6 +230,7 @@ std::tuple<torch::Tensor, torch::Tensor, torch::Tensor, torch::Tensor, torch::Te
        "Invalid dtype, only supports input tensor with float32, float16 and bfloat16 dtype");
        break;
    }
    expert_first_token_offset_tensor = expert_first_token_offset_tensor.to(torch::kInt32);
    return std::make_tuple(permuted_row_to_unpermuted_row_tensor, permuted_token_selected_experts_tensor,
        permuted_data_tensor, expert_first_token_offset_tensor, permuted_token_final_scales_tensor,
        unpermuted_row_to_permuted_row_tensor);
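The TORCH_CHECK added above guards against int32 overflow of the flattened MoE input count; the equivalent check in Python is just the following (sketch, not from the patch):

```python
INT32_MAX = 2**31 - 1

def moe_inputs_fit_int32(num_rows: int, experts_per_token: int) -> bool:
    # num_moe_inputs = experts_per_token * num_rows must stay within int32
    return num_rows * experts_per_token <= INT32_MAX

assert moe_inputs_fit_int32(num_rows=1_000_000, experts_per_token=8)
assert not moe_inputs_fit_int32(num_rows=300_000_000, experts_per_token=8)
```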
@ -5,13 +5,13 @@ import torch

from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ..._utils import get_sm_version, is_sm_100f
from ...math_utils import ceil_div, pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
                         DynamicTensorSpec, OptimizationProfile, TunableRunner,
                         TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
from ..utils import (fp4_scale_infer_shape, fp8_scale_infer_shape,
                     get_last_power_of_2_num_tokens_buckets,
                     last_positive_power_of_2)

@ -337,6 +337,10 @@ if IS_CUTLASS_DSL_AVAILABLE:
        Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
    from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
        Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
    from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \
        BlockwiseGemmKernel
    from ..cute_dsl_kernels.blackwell.blockwise_gemm.contiguous_offset_grouped_gemm import \
        BlockwiseContiguousGroupedGemmKernel
    from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
        Sm100BlockScaledPersistentDenseGemmKernel
    from ..cute_dsl_kernels.blackwell.utils import make_ptr
@ -523,7 +527,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
            **kwargs,
        ) -> torch.Tensor:
            """
            Performs fp8 blockwise gemm operation using CuTe DSL.
            Performs fp4 blockwise gemm operation using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
@ -2179,3 +2183,733 @@ if IS_CUTLASS_DSL_AVAILABLE:
                dtype=input_scale.dtype,
                device=input_scale.device)
            return output, output_scale
class CuteDSLFp8BlackwellLinear(TunableRunner):
|
||||
kernel_class = BlockwiseGemmKernel
|
||||
kernel_cache = dict()
|
||||
|
||||
tuning_config = TuningConfig(
|
||||
dynamic_tensor_specs=(DynamicTensorSpec(
|
||||
0, 0, get_last_power_of_2_num_tokens_buckets,
|
||||
last_positive_power_of_2), ),
|
||||
constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
if not is_sm_100f():
|
||||
logger.debug(
|
||||
f"CuteDSL: SM version {get_sm_version()} is not supported. "
|
||||
f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
|
||||
)
|
||||
return []
|
||||
|
||||
m = inputs[0].shape[0]
|
||||
n = inputs[1].shape[0]
|
||||
k = inputs[0].shape[1]
|
||||
batch_size = 1
|
||||
# m,k
|
||||
a_major = "k"
|
||||
# n, k
|
||||
b_major = "k"
|
||||
# m, n
|
||||
c_major = "n"
|
||||
|
||||
use_2cta_instrs_candi = [False, True]
|
||||
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
|
||||
cluster_shape_mn_candi = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
]
|
||||
return [
|
||||
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
|
||||
for use_2cta_instrs in use_2cta_instrs_candi
|
||||
for mma_tiler_mn in mma_tiler_mn_candi
|
||||
for cluster_shape_mn in cluster_shape_mn_candi
|
||||
if self.__class__.kernel_class.can_implement(
|
||||
cutlass.Float8E4M3FN, # ab_dtype,
|
||||
cutlass.Float32, # acc_dtype,
|
||||
cutlass.BFloat16, # c_dtype,
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_size,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
tactic,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs fp8 blockwise (deepgemm like) operation using CuTe DSL.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]):
|
||||
inputs[0]: Input tensor of shape (m, k), dtype: fp8.
|
||||
inputs[1]: Weight tensor of shape (n, k), dtype: fp8.
|
||||
inputs[2]: Input scale factor tensor of shape (k // 128, m), dtype: fp32.
|
||||
inputs[3]: Weight scale factor tensor of shape (n // 128, k // 128), dtype: fp32.
|
||||
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
|
||||
"""
|
||||
if isinstance(tactic, tuple):
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
|
||||
else:
|
||||
# fallback to default tactic
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
|
||||
False,
|
||||
(128, 128),
|
||||
(1, 1),
|
||||
]
|
||||
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
|
||||
m, n, k = a_tensor.shape[0], b_tensor.shape[0], b_tensor.shape[1]
|
||||
sf_m = m
|
||||
sf_k = ceil_div(k, 128)
|
||||
sf_n = ceil_div(n, 128)
|
||||
c_tensor = torch.empty(*(m, n),
|
||||
dtype=torch.bfloat16,
|
||||
device=a_tensor.device)
|
||||
|
||||
a_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
a_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
b_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
a_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
a_sf_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
b_sf_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
c_ptr = make_ptr(
|
||||
cutlass.BFloat16,
|
||||
c_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
# get stream
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
cache_key = (
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
)
|
||||
if cache_key not in self.__class__.kernel_cache:
|
||||
gemm = self.__class__.kernel_class(
|
||||
cutlass.Float32, # acc_dtype,
|
||||
use_2cta_instrs=use_2cta_instrs,
|
||||
mma_tiler_mn=mma_tiler_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
# Compute max active clusters on current device
|
||||
hardware_info = cutlass.utils.HardwareInfo()
|
||||
max_active_clusters = hardware_info.get_max_active_clusters(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1])
|
||||
|
||||
compiled_gemm = cute.compile(
|
||||
gemm.wrapper,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
1, # batch
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
max_active_clusters=max_active_clusters,
|
||||
stream=stream,
|
||||
)
|
||||
self.__class__.kernel_cache[cache_key] = compiled_gemm
|
||||
else:
|
||||
compiled_gemm = self.__class__.kernel_cache[cache_key]
|
||||
|
||||
# launch gemm kernel
|
||||
compiled_gemm(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
1, # batch
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
stream=stream,
|
||||
)
|
||||
return c_tensor
|
||||
|
||||
# a/b: fp8, scale: fp32, output: bf16
|
||||
@torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell",
|
||||
mutates_args=(),
|
||||
device_types="cuda")
|
||||
def cute_dsl_fp8_gemm_blackwell(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
tuner = AutoTuner.get()
|
||||
|
||||
runner = CuteDSLFp8BlackwellLinear()
|
||||
|
||||
inputs = [input, weight, input_scale, weight_scale]
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_fp8_gemm_blackwell::gemm",
|
||||
[runner],
|
||||
runner.__class__.tuning_config,
|
||||
inputs,
|
||||
)
|
||||
return runner(inputs, tactic=best_tactic)
|
||||
|
||||
@torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell")
|
||||
def _(
|
||||
mat_a: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
):
|
||||
# [m, k]
|
||||
shape = [i for i in mat_a.shape]
|
||||
# [n, k]
|
||||
shape[-1] = mat_b.shape[-2]
|
||||
# output is fixed as bf16
|
||||
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
|
||||
return ret
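A usage sketch of the new op, mirroring the unit test added later in this change; it assumes a Blackwell (SM100-family) GPU and borrows per_block_cast_to_fp8 from the test helpers, so treat the import paths and sizes as illustrative.

```python
import torch

from tensorrt_llm._torch.autotuner import autotune
from _torch.helpers import per_block_cast_to_fp8  # test helper, path as used in the tests

m, n, k = 128, 4096, 7168
a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / k
b = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) / k

a_fp8, a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)  # activation: 1x128 block scales
b_fp8, b_sf = per_block_cast_to_fp8(b)                # weight: 128x128 block scales

with autotune():  # first call tunes the tactic; later calls reuse the cached kernel
    out = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(a_fp8, b_fp8, a_sf, b_sf)
assert out.shape == (m, n) and out.dtype == torch.bfloat16
```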
|
||||
|
||||
class CuteDSLFp8BlackwellBmm(TunableRunner):
|
||||
kernel_class = BlockwiseGemmKernel
|
||||
kernel_cache = dict()
|
||||
|
||||
tuning_config = TuningConfig(
|
||||
dynamic_tensor_specs=(DynamicTensorSpec(
|
||||
0, 1, get_last_power_of_2_num_tokens_buckets,
|
||||
last_positive_power_of_2), ),
|
||||
constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
|
||||
if not is_sm_100f():
|
||||
logger.debug(
|
||||
f"CuteDSL: SM version {get_sm_version()} is not supported. "
|
||||
f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
|
||||
)
|
||||
return []
|
||||
# [b, m, k]
|
||||
batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[
|
||||
0].shape[2]
|
||||
# [b, n, k]
|
||||
n = inputs[1].shape[1]
|
||||
# m,k
|
||||
a_major = "k"
|
||||
# n, k
|
||||
b_major = "k"
|
||||
# m, n
|
||||
c_major = "n"
|
||||
|
||||
use_2cta_instrs_candi = [False, True]
|
||||
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
|
||||
cluster_shape_mn_candi = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
]
|
||||
return [
|
||||
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
|
||||
for use_2cta_instrs in use_2cta_instrs_candi
|
||||
for mma_tiler_mn in mma_tiler_mn_candi
|
||||
for cluster_shape_mn in cluster_shape_mn_candi
|
||||
if self.__class__.kernel_class.can_implement(
|
||||
cutlass.Float8E4M3FN, # ab_dtype,
|
||||
cutlass.Float32, # acc_dtype,
|
||||
cutlass.BFloat16, # c_dtype,
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_size,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
tactic,
|
||||
) -> None:
|
||||
"""
|
||||
Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]):
|
||||
inputs[0]: Input tensor of shape (batch_size, m, k), dtype: fp8.
|
||||
inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: fp8.
|
||||
inputs[2]: Input scale tensor of shape (batch_size, k // 128, pad_up(m, 4)), dtype: fp32.
|
||||
inputs[3]: Weight scale tensor of shape (batch_size, n // 128, k // 128), dtype: fp32.
inputs[4]: Preallocated output tensor of shape (batch_size, m, n), dtype: bf16; written in place.
|
||||
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
|
||||
|
||||
Returns:
|
||||
None. The result is written into inputs[4], the preallocated output of shape (batch_size, m, n), dtype: bf16.
|
||||
"""
|
||||
if isinstance(tactic, tuple):
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
|
||||
else:
|
||||
# fallback to default tactic
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
|
||||
False,
|
||||
(128, 128),
|
||||
(1, 1),
|
||||
]
|
||||
|
||||
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, c_tensor = inputs
|
||||
|
||||
batch_size = a_tensor.shape[0]
|
||||
m = a_tensor.shape[1]
|
||||
k = a_tensor.shape[2]
|
||||
n = b_tensor.shape[1]
|
||||
sf_m = pad_up(m, 4)
|
||||
sf_k = ceil_div(k, 128)
|
||||
sf_n = ceil_div(n, 128)
|
||||
|
||||
a_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
a_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
b_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
a_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
a_sf_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
b_sf_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
c_ptr = make_ptr(
|
||||
cutlass.BFloat16,
|
||||
c_tensor.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
|
||||
# get stream
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
cache_key = (
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
)
|
||||
if cache_key not in self.__class__.kernel_cache:
|
||||
gemm = self.__class__.kernel_class(
|
||||
cutlass.Float32, # acc_dtype,
|
||||
use_2cta_instrs=use_2cta_instrs,
|
||||
mma_tiler_mn=mma_tiler_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
# Compute max active clusters on current device
|
||||
hardware_info = cutlass.utils.HardwareInfo()
|
||||
max_active_clusters = hardware_info.get_max_active_clusters(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1])
|
||||
|
||||
compiled_gemm = cute.compile(
|
||||
gemm.wrapper,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
batch_size,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
max_active_clusters=max_active_clusters,
|
||||
stream=stream,
|
||||
)
|
||||
self.__class__.kernel_cache[cache_key] = compiled_gemm
|
||||
else:
|
||||
compiled_gemm = self.__class__.kernel_cache[cache_key]
|
||||
|
||||
# launch gemm kernel
|
||||
compiled_gemm(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
batch_size,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
a_sf_ptr,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
stream=stream,
|
||||
)
|
||||
|
||||
# a/b: fp8, scale: fp32, out: bf16
|
||||
@torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell",
|
||||
mutates_args=(),
|
||||
device_types="cuda")
|
||||
def cute_dsl_fp8_bmm_blackwell(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
tuner = AutoTuner.get()
|
||||
|
||||
runner = CuteDSLFp8BlackwellBmm()
|
||||
|
||||
inputs = [input, weight, input_scale, weight_scale, out]
|
||||
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_fp8_bmm_blackwell::gemm",
|
||||
[runner],
|
||||
runner.__class__.tuning_config,
|
||||
inputs,
|
||||
)
|
||||
runner(inputs, tactic=best_tactic)
|
||||
|
||||
@torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell")
|
||||
def _(
|
||||
mat_a: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
) -> None:
|
||||
batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2]
|
||||
n = mat_b.shape[1]
|
||||
assert out.dtype == torch.bfloat16, "CuTe DSL fp8 bmm output dtype must be bf16"
|
||||
assert out.shape == (batch_size, m,
|
||||
n), "CuTe DSL fp8 bmm output shape is incorrect"
|
||||
|
||||
class CuteDSLFp8BlackwellGroupGemm(TunableRunner):
|
||||
kernel_class = BlockwiseContiguousGroupedGemmKernel
|
||||
kernel_cache = dict()
|
||||
|
||||
tuning_config = TuningConfig()
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def get_valid_tactics(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
profile: OptimizationProfile,
|
||||
**kwargs,
|
||||
) -> List[int]:
|
||||
if not is_sm_100f():
|
||||
logger.debug(
|
||||
f"CuteDSL: SM version {get_sm_version()} is not supported. "
|
||||
f"CuteDSL FP8 Group Gemm only supports SM 100 family. Skipping all tactics."
|
||||
)
|
||||
return []
|
||||
# [m, k]
|
||||
m, k = inputs[0].shape[0], inputs[0].shape[1]
|
||||
# [group_size, n, k]
|
||||
group_size, n, k = inputs[1].shape[0], inputs[1].shape[1], inputs[
|
||||
1].shape[2]
|
||||
batch_size = 1
|
||||
# m,k
|
||||
a_major = "k"
|
||||
# n, k
|
||||
b_major = "k"
|
||||
# m, n
|
||||
c_major = "n"
|
||||
|
||||
use_2cta_instrs_candi = [False, True]
|
||||
mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
|
||||
cluster_shape_mn_candi = [
|
||||
(1, 1),
|
||||
(1, 2),
|
||||
(1, 4),
|
||||
(2, 1),
|
||||
(2, 2),
|
||||
(2, 4),
|
||||
(4, 1),
|
||||
(4, 2),
|
||||
(4, 4),
|
||||
]
|
||||
return [
|
||||
(use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
|
||||
for use_2cta_instrs in use_2cta_instrs_candi
|
||||
for mma_tiler_mn in mma_tiler_mn_candi
|
||||
for cluster_shape_mn in cluster_shape_mn_candi
|
||||
if self.__class__.kernel_class.can_implement(
|
||||
cutlass.Float8E4M3FN, # ab_dtype,
|
||||
cutlass.Float32, # acc_dtype,
|
||||
cutlass.BFloat16, # c_dtype,
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
batch_size,
|
||||
a_major,
|
||||
b_major,
|
||||
c_major,
|
||||
)
|
||||
]
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs: List[torch.Tensor],
|
||||
tactic,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Performs fp8 grouped gemm operation using CuTe DSL.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]):
|
||||
inputs[0]: Input tensor of shape (m, k), dtype: fp8.
|
||||
inputs[1]: Weight tensor of shape (group_size, n, k), dtype: fp8.
|
||||
inputs[2]: Input scale tensor of shape (k // 128, m), dtype: fp32.
|
||||
inputs[3]: Weight scale tensor of shape (group_size, n // 128, k // 128), dtype: fp32.
|
||||
inputs[4]: Group offset tensor of shape (group_size + 1), dtype: int32.
|
||||
tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
|
||||
"""
|
||||
|
||||
if isinstance(tactic, tuple):
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
|
||||
else:
|
||||
# fallback to default tactic
|
||||
use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
|
||||
False,
|
||||
(128, 128),
|
||||
(1, 1),
|
||||
]
|
||||
|
||||
a, b, a_sf, b_sf, group_offset = inputs
|
||||
m, k, n = a.shape[0], a.shape[1], b.shape[1]
|
||||
group_size = b.shape[0]
|
||||
sf_m = m
|
||||
sf_n = ceil_div(n, 128)
|
||||
sf_k = ceil_div(k, 128)
|
||||
c = torch.empty(*(m, n), dtype=torch.bfloat16, device=a.device)
|
||||
|
||||
a_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
a.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_ptr = make_ptr(
|
||||
cutlass.Float8E4M3FN,
|
||||
b.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
c_ptr = make_ptr(
|
||||
cutlass.BFloat16,
|
||||
c.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
a_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
a_sf.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
b_sf_ptr = make_ptr(
|
||||
cutlass.Float32,
|
||||
b_sf.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
assumed_align=16,
|
||||
)
|
||||
group_offset_ptr = make_ptr(
|
||||
cutlass.Int32,
|
||||
group_offset.data_ptr(),
|
||||
cute.AddressSpace.gmem,
|
||||
)
|
||||
|
||||
a_sf_tmp = a_sf.reshape((1, sf_k, sf_m))
|
||||
a_sf_tmp = a_sf_tmp.permute(2, 1, 0)
|
||||
|
||||
mSFA = cute.runtime.from_dlpack(
|
||||
a_sf_tmp, assumed_align=16).mark_layout_dynamic(leading_dim=0)
|
||||
|
||||
# get stream
|
||||
stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)
|
||||
|
||||
cache_key = (
|
||||
use_2cta_instrs,
|
||||
mma_tiler_mn,
|
||||
cluster_shape_mn,
|
||||
)
|
||||
if cache_key not in self.__class__.kernel_cache:
|
||||
gemm = self.__class__.kernel_class(
|
||||
cutlass.Float32, # acc_dtype,
|
||||
use_2cta_instrs=use_2cta_instrs,
|
||||
mma_tiler_mn=mma_tiler_mn,
|
||||
cluster_shape_mn=cluster_shape_mn,
|
||||
)
|
||||
# Compute max active clusters on current device
|
||||
hardware_info = cutlass.utils.HardwareInfo()
|
||||
max_active_clusters = hardware_info.get_max_active_clusters(
|
||||
cluster_shape_mn[0] * cluster_shape_mn[1])
|
||||
|
||||
compiled_gemm = cute.compile(
|
||||
gemm.wrapper,
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
1, # batch
|
||||
group_size,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
mSFA,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
group_offset_ptr,
|
||||
max_active_clusters,
|
||||
stream,
|
||||
)
|
||||
self.__class__.kernel_cache[cache_key] = compiled_gemm
|
||||
else:
|
||||
compiled_gemm = self.__class__.kernel_cache[cache_key]
|
||||
|
||||
# launch gemm kernel
|
||||
compiled_gemm(
|
||||
m,
|
||||
n,
|
||||
k,
|
||||
sf_m,
|
||||
sf_n,
|
||||
sf_k,
|
||||
1, # batch
|
||||
group_size,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
mSFA,
|
||||
b_sf_ptr,
|
||||
c_ptr,
|
||||
group_offset_ptr,
|
||||
stream=stream,
|
||||
)
|
||||
return c
|
||||
|
||||
# a/b: fp8, scale: fp32, out: bf16
|
||||
@torch.library.custom_op(
|
||||
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell",
|
||||
mutates_args=(),
|
||||
device_types="cuda")
|
||||
def cute_dsl_fp8_group_blockwise_gemm_blackwell(
|
||||
input: torch.Tensor,
|
||||
weight: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_offset: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
|
||||
tuner = AutoTuner.get()
|
||||
runner = CuteDSLFp8BlackwellGroupGemm()
|
||||
|
||||
inputs = [input, weight, input_scale, weight_scale, group_offset]
|
||||
_, best_tactic = tuner.choose_one(
|
||||
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell::gemm",
|
||||
[runner],
|
||||
runner.__class__.tuning_config,
|
||||
inputs,
|
||||
)
|
||||
return runner(inputs, tactic=best_tactic)
|
||||
|
||||
@torch.library.register_fake(
|
||||
"trtllm::cute_dsl_fp8_group_blockwise_gemm_blackwell")
|
||||
def _(
|
||||
mat_a: torch.Tensor,
|
||||
mat_b: torch.Tensor,
|
||||
input_scale: torch.Tensor,
|
||||
weight_scale: torch.Tensor,
|
||||
group_offset: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
m, k = mat_a.shape[0], mat_a.shape[1]
|
||||
num_group, n, k = mat_b.shape[0], mat_b.shape[1], mat_b.shape[2]
|
||||
return mat_a.new_empty((m, n), dtype=torch.bfloat16)
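The group_offset argument is the per-expert prefix sum over token counts (a leading zero followed by the cumulative sum), as built in the new grouped-gemm unit test; a minimal sketch with made-up counts:

```python
import torch

tokens_per_expert = torch.tensor([3, 0, 5, 2], dtype=torch.int32, device="cuda")
group_offset = torch.cat([
    torch.zeros(1, dtype=torch.int32, device="cuda"),
    torch.cumsum(tokens_per_expert, dim=0),
]).to(torch.int32)
# group_offset == [0, 3, 3, 8, 10]: rows [offset[i], offset[i+1]) of the (m, k)
# activation belong to expert i and are multiplied by weight[i] of shape (n, k).
```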
|
||||
|
||||
File diff suppressed because it is too large
File diff suppressed because it is too large
@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]):

    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

    # cute dsl op configs
    use_cute_dsl_blockscaling_mm: bool = False
    use_cute_dsl_blockscaling_bmm: bool = False

    _frozen: bool = field(default=False, init=False, repr=False)

    # If true, ONLY the vision encoder part of the full model is loaded/executed.
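Both flags default to False. A hedged sketch of opting in (the import path and the bare constructor call are assumptions; the flags themselves are the new fields above):

```python
from tensorrt_llm._torch.model_config import ModelConfig  # assumed import path

config = ModelConfig(
    use_cute_dsl_blockscaling_mm=True,   # Linear layers route to cute_dsl_fp8_gemm_blackwell
    use_cute_dsl_blockscaling_bmm=True,  # MLA BMMs route to cute_dsl_fp8_bmm_blackwell
)
```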
@ -255,6 +255,9 @@ class Attention(nn.Module):
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_key_value_heads * self.head_dim
|
||||
|
||||
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
|
||||
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
|
||||
|
||||
qkv_shard_indices_mapping = {
|
||||
"q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
|
||||
"k":
|
||||
@ -280,7 +283,8 @@ class Attention(nn.Module):
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
disable_deep_gemm=disable_deep_gemm,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping)
|
||||
fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
|
||||
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
|
||||
[self.hidden_size])
|
||||
@ -299,7 +303,8 @@ class Attention(nn.Module):
|
||||
allreduce_strategy=config.allreduce_strategy,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
disable_deep_gemm=disable_deep_gemm,
|
||||
use_custom_cublas_mm=use_custom_cublas_mm)
|
||||
use_custom_cublas_mm=use_custom_cublas_mm,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
|
||||
self.quant_config = config.get_quant_config()
|
||||
self.attn_backend = config.attn_backend
|
||||
@ -684,6 +689,7 @@ def fp8_block_scaling_bmm_out(
|
||||
mat2_scale: torch.Tensor,
|
||||
out: torch.Tensor,
|
||||
mat2_dequant: Optional[torch.Tensor] = None,
|
||||
use_cute_dsl_blockscaling_bmm: bool = False,
|
||||
) -> torch.Tensor:
|
||||
sm_version = get_sm_version()
|
||||
if sm_version == 90 or sm_version == 89:
|
||||
@ -704,7 +710,17 @@ def fp8_block_scaling_bmm_out(
|
||||
output)
|
||||
out.copy_(output)
|
||||
elif is_sm_100f(sm_version):
|
||||
torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
|
||||
if use_cute_dsl_blockscaling_bmm:
|
||||
mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
|
||||
mat1)
|
||||
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
|
||||
mat1_scale, mat2_scale,
|
||||
out)
|
||||
mat1_scale = None
|
||||
else:
|
||||
torch.bmm(mat1.transpose(0, 1),
|
||||
mat2_dequant.transpose(1, 2),
|
||||
out=out)
|
||||
else:
|
||||
raise NotImplementedError(f"SM{sm_version} is not supported")
|
||||
|
||||
@ -847,6 +863,9 @@ class MLA(nn.Module):
|
||||
quant_config = config.get_quant_config()
|
||||
self.quant_config = quant_config
|
||||
|
||||
self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
|
||||
self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm
|
||||
|
||||
if not self.is_lite:
|
||||
self.kv_a_proj_with_mqa = Linear(
|
||||
hidden_size,
|
||||
@ -856,7 +875,8 @@ class MLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
use_custom_cublas_mm=True,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
|
||||
self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
|
||||
eps=rms_norm_eps,
|
||||
@ -872,7 +892,8 @@ class MLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
allreduce_strategy=config.allreduce_strategy,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
else:
|
||||
self.kv_a_proj_with_mqa = Linear(
|
||||
hidden_size,
|
||||
@ -882,7 +903,8 @@ class MLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
use_custom_cublas_mm=True,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
|
||||
self.q_proj = Linear(
|
||||
self.q_lora_rank,
|
||||
@ -894,7 +916,8 @@ class MLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
allreduce_strategy=config.allreduce_strategy,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
self.q_b_proj = self.q_proj
|
||||
|
||||
self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
|
||||
@ -911,7 +934,8 @@ class MLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
allreduce_strategy=config.allreduce_strategy,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
# This parameter will view into self.kv_b_proj.weight after loading weights.
|
||||
# For dummy weight initialization, this parameter is initialized with empty tensor.
|
||||
# Used in forward_absorption only
|
||||
@ -943,7 +967,8 @@ class MLA(nn.Module):
|
||||
skip_create_weights_in_init=config.skip_create_weights_in_init,
|
||||
reduce_output=reduce_output,
|
||||
allreduce_strategy=config.allreduce_strategy,
|
||||
force_dynamic_quantization=config.force_dynamic_quantization)
|
||||
force_dynamic_quantization=config.force_dynamic_quantization,
|
||||
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
|
||||
|
||||
def yarn_get_mscale(scale=1, mscale=1):
|
||||
if scale <= 1:
|
||||
@ -1079,7 +1104,7 @@ class MLA(nn.Module):
|
||||
),
|
||||
requires_grad=False,
|
||||
)
|
||||
if is_sm_100f():
|
||||
if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm:
|
||||
assert self.dtype == torch.bfloat16
|
||||
self.k_b_proj_trans_dequant = nn.Parameter(
|
||||
torch.empty(
|
||||
@ -1871,6 +1896,7 @@ class MLA(nn.Module):
|
||||
self.k_b_proj_trans_scale,
|
||||
q_nope_out,
|
||||
self.k_b_proj_trans_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
),
|
||||
lambda: self.mqa.mla_rope_generation(
|
||||
fused_q,
|
||||
@ -1948,6 +1974,7 @@ class MLA(nn.Module):
|
||||
self.v_b_proj_scale,
|
||||
attn_output.transpose(0, 1),
|
||||
self.v_b_proj_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -2003,6 +2030,7 @@ class MLA(nn.Module):
|
||||
self.k_b_proj_trans_scale,
|
||||
q_nope_out,
|
||||
self.k_b_proj_trans_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -2058,6 +2086,7 @@ class MLA(nn.Module):
|
||||
self.v_b_proj_scale,
|
||||
attn_output.transpose(0, 1),
|
||||
self.v_b_proj_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -2125,6 +2154,7 @@ class MLA(nn.Module):
|
||||
self.k_b_proj_trans_scale,
|
||||
q_nope_out,
|
||||
self.k_b_proj_trans_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
@ -2201,6 +2231,7 @@ class MLA(nn.Module):
|
||||
self.v_b_proj_scale,
|
||||
attn_output.transpose(0, 1),
|
||||
self.v_b_proj_dequant,
|
||||
self.use_cute_dsl_blockscaling_bmm,
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
|
||||
@ -308,6 +308,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
dtype (Optional[torch.dtype]): Data type for the weights.
|
||||
reduce_results (bool): Whether to reduce the results across devices.
|
||||
model_config (ModelConfig): Configuration object for the model.
|
||||
use_cute_dsl_fp8 (bool): Whether to use CuteDSL FP8 blockwise gemm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -328,6 +329,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
layer_idx: Optional[int] = None,
|
||||
init_load_balancer: bool = True,
|
||||
without_comm: bool = False,
|
||||
use_cute_dsl_fp8: bool = False,
|
||||
):
|
||||
super().__init__(
|
||||
routing_method=routing_method,
|
||||
@ -354,6 +356,7 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
for key in [EventType.Main, EventType.MoeOutputMemset]:
|
||||
if key not in self.event_dict:
|
||||
self.event_dict[key] = torch.cuda.Event()
|
||||
self.use_cute_dsl_fp8 = use_cute_dsl_fp8
|
||||
|
||||
def select_alltoall_method_type(self) -> AlltoallMethodType:
|
||||
return AlltoallMethodType.NotEnabled
|
||||
@ -600,22 +603,40 @@ class CuteDslFusedMoE(CutlassFusedMoE):
|
||||
use_fp8_block_scaling=True,
|
||||
)
|
||||
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
|
||||
x = cute_dsl_fp8_group_blockwise_gemm_ref(
|
||||
a=x,
|
||||
b=self.w3_w1_weight.view(weight_dtype),
|
||||
a_sf=x_sf,
|
||||
b_sf=self.quant_scales[0],
|
||||
offset_array=expert_first_token_offset,
|
||||
)
|
||||
if is_sm_100f() and self.use_cute_dsl_fp8:
|
||||
x = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
|
||||
input=x,
|
||||
weight=self.w3_w1_weight.view(weight_dtype),
|
||||
input_scale=x_sf,
|
||||
weight_scale=self.quant_scales[0],
|
||||
group_offset=expert_first_token_offset,
|
||||
)
|
||||
else:
|
||||
x = cute_dsl_fp8_group_blockwise_gemm_ref(
|
||||
a=x,
|
||||
b=self.w3_w1_weight.view(weight_dtype),
|
||||
a_sf=x_sf,
|
||||
b_sf=self.quant_scales[0],
|
||||
offset_array=expert_first_token_offset,
|
||||
)
|
||||
x = swiglu_fused_moe(x)
|
||||
x, x_sf = torch.ops.trtllm.fp8_quantize_1x128(x)
|
||||
x = cute_dsl_fp8_group_blockwise_gemm_ref(
|
||||
a=x,
|
||||
b=self.w2_weight.view(weight_dtype),
|
||||
a_sf=x_sf,
|
||||
b_sf=self.quant_scales[1],
|
||||
offset_array=expert_first_token_offset,
|
||||
)
|
||||
if is_sm_100f() and self.use_cute_dsl_fp8:
|
||||
x = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
|
||||
input=x,
|
||||
weight=self.w2_weight.view(weight_dtype),
|
||||
input_scale=x_sf,
|
||||
weight_scale=self.quant_scales[1],
|
||||
group_offset=expert_first_token_offset,
|
||||
)
|
||||
else:
|
||||
x = cute_dsl_fp8_group_blockwise_gemm_ref(
|
||||
a=x,
|
||||
b=self.w2_weight.view(weight_dtype),
|
||||
a_sf=x_sf,
|
||||
b_sf=self.quant_scales[1],
|
||||
offset_array=expert_first_token_offset,
|
||||
)
|
||||
x = torch.ops.trtllm.moe_finalize_scale_op(
|
||||
x,
|
||||
None, # biases
|
||||
|
||||
@ -740,10 +740,9 @@ class FP8BlockScalesLinearMethod(LinearMethodBase):

        if is_sm_100f():
            if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
                # TODO (@lmin): replace with cute_dsl gemm
                act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                    input)
                output = torch.ops.trtllm.fp8_block_scaling_gemm(
                output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
                    act_input_fp8, module.weight, act_input_sf,
                    module.weight_scale)
            else:

@ -301,6 +301,16 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
    return scale_shape * 2


def fp8_scale_infer_shape(input_shapes: List[List[int]]):
    """Calculate the dimensions of the fp8 scale tensor.
    """
    input_shape = input_shapes[0]
    assert len(input_shape) == 2 or len(input_shape) == 3
    has_batch = len(input_shape) == 3
    m = input_shape[-2]
    return pad_up(m, 4) if has_batch else m


_enable_piecewise_cuda_graph = True
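The fp8_scale_infer_shape helper above computes the m-dependent dimension of the fp8 activation-scale tensor for the autotuner; a standalone sketch of its behavior (pad_up restated locally so the asserts run on their own):

```python
def pad_up(x: int, y: int) -> int:
    return ((x + y - 1) // y) * y

def fp8_scale_infer_shape(input_shapes):
    input_shape = input_shapes[0]
    assert len(input_shape) in (2, 3)
    has_batch = len(input_shape) == 3
    m = input_shape[-2]
    return pad_up(m, 4) if has_batch else m

assert fp8_scale_infer_shape([[300, 7168]]) == 300     # dense gemm: scale is (k // 128, m)
assert fp8_scale_infer_shape([[16, 30, 7168]]) == 32   # bmm: scale is (b, k // 128, pad_up(m, 4))
```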
@ -1048,7 +1048,7 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
|
||||
|
||||
@skip_pre_blackwell
|
||||
@pytest.mark.parametrize(
|
||||
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode",
|
||||
"dtype, num_experts, seq_len, hidden_size, RoutingMethodCls, WeightLoadingMode, use_cute_dsl_fp8",
|
||||
product(
|
||||
[torch.bfloat16],
|
||||
[72],
|
||||
@ -1056,6 +1056,7 @@ def test_fused_moe_fp8_blockwise_deepgemm(dtype,
|
||||
[2560],
|
||||
[DefaultMoeRoutingMethod],
|
||||
[MoEWeightLoadingMode.VANILLA, MoEWeightLoadingMode.FUSED_GATE_UP_PROJ],
|
||||
[False, True],
|
||||
),
|
||||
)
|
||||
def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
|
||||
@ -1064,6 +1065,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
|
||||
hidden_size,
|
||||
RoutingMethodCls,
|
||||
WeightLoadingMode,
|
||||
use_cute_dsl_fp8,
|
||||
mapping=None):
|
||||
SEQ_LEN = seq_len
|
||||
HIDDEN_SIZE = hidden_size
|
||||
@ -1153,6 +1155,7 @@ def test_fused_moe_fp8_blockwise_cute_dsl(dtype,
|
||||
reduce_results=True,
|
||||
model_config=ModelConfig(quant_config=quant_config, mapping=mapping),
|
||||
weight_loading_mode=WeightLoadingMode,
|
||||
use_cute_dsl_fp8=use_cute_dsl_fp8,
|
||||
)
|
||||
fused_moe.cuda()
|
||||
fused_moe.load_weights([weights])
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -109,6 +109,50 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
|
||||
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not isSM100Family(),
|
||||
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"k, n",
|
||||
[(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
|
||||
(2048, 7168), (1024, 1024)],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"m",
|
||||
[7, 64, 128, 4096],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[torch.bfloat16],
|
||||
)
|
||||
def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n):
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
a = torch.randn((m, k), device='cuda', dtype=dtype) / k
|
||||
b = torch.randn((n, k), device='cuda', dtype=dtype) / k
|
||||
|
||||
act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
|
||||
act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)
|
||||
|
||||
output_expected = a @ b.t()
|
||||
|
||||
with autotune():
|
||||
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
|
||||
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
|
||||
|
||||
# test Cute DSL kernel
|
||||
cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
|
||||
act_a_fp8, act_b_fp8, act_a_sf, act_b_sf)
|
||||
|
||||
diff = calc_diff(cute_dsl_output, output_expected)
|
||||
assert diff < 1e-3
|
||||
torch.testing.assert_close(cute_dsl_output,
|
||||
output_expected,
|
||||
atol=1e-3,
|
||||
rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
|
||||
reason="The test is for Hopper and Ada only. Current SM is %d." %
|
||||
@ -171,6 +215,57 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
|
||||
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not isSM100Family(),
|
||||
reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"k, n",
|
||||
[(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"m",
|
||||
[7, 64, 128],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"num_groups",
|
||||
[4, 8, 16],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"dtype",
|
||||
[torch.bfloat16],
|
||||
)
|
||||
def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
|
||||
|
||||
torch.random.manual_seed(0)
|
||||
a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
|
||||
a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
|
||||
|
||||
b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
|
||||
b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
|
||||
b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
|
||||
device='cuda',
|
||||
dtype=torch.float)
|
||||
|
||||
for i in range(num_groups):
|
||||
b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])
|
||||
|
||||
output_expected = torch.einsum('mgk,gnk->gmn', a, b)
|
||||
output = torch.empty((num_groups, m, n),
|
||||
device='cuda',
|
||||
dtype=torch.bfloat16)
|
||||
# tune
|
||||
with autotune():
|
||||
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales,
|
||||
b_scales, output)
|
||||
# run the tuned kernel
|
||||
torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales,
|
||||
b_scales, output)
|
||||
diff = calc_diff(output, output_expected)
|
||||
assert diff < 1e-3
|
||||
torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
|
||||
|
||||
|
||||
def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
|
||||
valsB, dqSfsB, quantizeOutput, tileSize):
|
||||
for mi in range(mM):
|
||||
|
||||
@ -0,0 +1,87 @@
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import math
import random

import pytest
import torch
from _torch.helpers import calc_diff, per_block_cast_to_fp8
from utils.util import getSMVersion, isSM100Family

from tensorrt_llm._torch.autotuner import autotune


@pytest.mark.skipif(
    not isSM100Family(),
    reason="The test is for Blackwell only. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize("num_experts", [72])
@pytest.mark.parametrize("k", [1536])
@pytest.mark.parametrize("n", [2560])
@pytest.mark.parametrize("max_tokens_per_group", [10, 50, 100, 128, 256, 512, 1000, 1024])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
def test_cute_dsl_fp8_block_scale_group_gemm(dtype, num_experts, k, n, max_tokens_per_group):
    random.seed(0)
    torch.random.manual_seed(0)

    group_m = []
    for i in range(num_experts):
        group_m.append(random.randint(0, max_tokens_per_group))
    group_m = torch.tensor(group_m, dtype=torch.int32, device="cuda")
    group_m_cum = torch.cumsum(group_m, dim=0)
    group_offset = torch.cat([torch.zeros(1, dtype=torch.int32, device="cuda"), group_m_cum], dim=0)
    group_offset = group_offset.to(torch.int32)

    m = sum(group_m)
    a = torch.randn((m, k), device="cuda", dtype=dtype) / k
    b = torch.randn((num_experts, n, k), device="cuda", dtype=dtype) / k
    output_expected = torch.zeros((m, n), device="cuda", dtype=dtype)

    for i in range(num_experts):
        start = group_offset[i]
        end = group_offset[i + 1]
        output_expected[start:end, :] = torch.einsum("mk,nk->mn", a[start:end, :], b[i, :, :])

    a_fp8, a_scale = torch.ops.trtllm.fp8_quantize_1x128(a)
    b_fp8 = torch.empty(num_experts, n, k, dtype=torch.float8_e4m3fn, device="cuda")
    b_scale = torch.empty(
        num_experts, math.ceil(n / 128), math.ceil(k / 128), dtype=torch.float32, device="cuda"
    )
    for i in range(num_experts):
        cur_b, cur_b_scale = per_block_cast_to_fp8(b[i, :, :])
        b_fp8[i, :, :] = cur_b
        b_scale[i, :, :] = cur_b_scale

    with autotune():
        output = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
            input=a_fp8,
            weight=b_fp8,
            input_scale=a_scale,
            weight_scale=b_scale,
            group_offset=group_offset,
        )
    output = torch.ops.trtllm.cute_dsl_fp8_group_blockwise_gemm_blackwell(
        input=a_fp8,
        weight=b_fp8,
        input_scale=a_scale,
        weight_scale=b_scale,
        group_offset=group_offset,
    )

    diff = calc_diff(output, output_expected)
    assert diff < 1e-3
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)