[TRTLLM-9457][feat] Add cute dsl fp8 gemm for Blackwell (#10130)
Added FP8 cute dsl gemm and batch gemm.

Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com>

Parent: 712dcd31a9
Commit: 5521c7b7e7
@@ -121,9 +121,8 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
    int64_t scaleSizeInBytes = mGemmRunner.getActScaleSize(m, b * n);
    int64_t elementSize = scaleSizeInBytes / torch::elementSize(FP8_BLOCK_SCALING_SF_DTYPE);
    int m_4_align = (m + 3) / 4 * 4;
    at::Tensor scaleFP8SF = at::detail::empty_cuda({b, m_4_align, elementSize / b / m_4_align},
        FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt);
    at::Tensor scaleFP8SF = at::detail::empty_cuda(
        {elementSize}, FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor

    __nv_fp8_e4m3* act_buffer = reinterpret_cast<__nv_fp8_e4m3*>(valueE4M3.data_ptr());
    float* act_scale_buffer = reinterpret_cast<float*>(scaleFP8SF.data_ptr());
@@ -133,6 +132,13 @@ std::tuple<at::Tensor, at::Tensor> fp8_batched_quantize_1x128_permute102(at::Ten
    auto* output_buffer = reinterpret_cast<__nv_bfloat16 const*>(self.data_ptr());
    mGemmRunner.fp8CS1x128Reshape(act_buffer, act_scale_buffer, output_buffer, n, b, m, lda, stream);

    // scaleFP8SF = scaleFP8SF[:, 0:num_n_blocks, 0:m_padded]
    auto const num_n_blocks = (n + 127) / 128;
    auto const act_scal_elesize = b * num_n_blocks * m_padded;
    TORCH_CHECK(act_scal_elesize <= scaleFP8SF.numel(), "Scale tensor size mismatch. Expected at least ",
        act_scal_elesize, " elements, got ", scaleFP8SF.numel());
    scaleFP8SF = scaleFP8SF.slice(0, 0, act_scal_elesize).view({b, num_n_blocks, m_padded}).contiguous();

    return {valueE4M3.slice(0, 0, b * m * n).view({b, m, n}), scaleFP8SF};
}

} // namespace torch_ext
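A minimal shape check (not part of the diff) for the reworked scale allocation above, assuming the op is exposed the way the unit tests further down use it:

    # Hypothetical sizes; mirrors test_cute_dsl_fp8_block_scale_bmm below.
    import torch
    m, num_groups, k = 7, 4, 7168
    a = torch.randn((m, num_groups, k), device="cuda", dtype=torch.bfloat16) / k
    a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
    # Values come back batch-major; scales are one fp32 per 1x128 block with m padded to 4,
    # i.e. the [b, num_n_blocks, m_padded] view built in the C++ above.
    assert a_fp8.shape == (num_groups, m, k)
    assert a_scales.shape == (num_groups, (k + 127) // 128, (m + 3) // 4 * 4)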
@@ -81,14 +81,14 @@ def inplace_info():
            1: "logits"
        },
        torch.ops.trtllm.moe_unpermute_inplace.default: {
            2: "output"
            1: "output"
        },
        torch.ops.trtllm.moe_output_memset_inplace.default: {
            1: "input"
        },
        torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default:
        {
            6: "output"
            1: "output"
        },
        torch.ops.trtllm.pp_recv_tensors.default: {
            1: "tensors"
@@ -96,6 +96,9 @@ def inplace_info():
        torch.ops.trtllm.pp_send_tensors.default: {
            1: "tensors"
        },
        torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell.default: {
            1: "output"
        }
    }
    if IS_CUDA_TILE_AVAILABLE:
        # cuda.tile availability depends on GPU capability thus runtime check.
@@ -6,13 +6,13 @@ import torch

from tensorrt_llm.logger import logger

from ..._utils import get_sm_version
from ..._utils import get_sm_version, is_sm_100f
from ...math_utils import ceil_div, pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy,
                         DynamicTensorSpec, OptimizationProfile, TunableRunner,
                         TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
from ..utils import (fp4_scale_infer_shape, fp8_scale_infer_shape,
                     get_last_power_of_2_num_tokens_buckets,
                     last_positive_power_of_2)
@@ -314,11 +314,13 @@ if IS_CUTLASS_DSL_AVAILABLE:
        Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel
    from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \
        Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel
    from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \
        Sm100BlockwiseGemmKernel
    from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \
        Sm100BlockScaledPersistentDenseGemmKernel
    from ..cute_dsl_kernels.blackwell.utils import make_ptr

    class CuteDSLNVFP4BlackwellLinear(TunableRunner):
    class CuteDSLNVFP4BlackwellRunner(TunableRunner):
        kernel_class = Sm100BlockScaledPersistentDenseGemmKernel
        kernel_cache = dict()
        tuning_config = TuningConfig(
@@ -500,7 +502,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
            **kwargs,
        ) -> torch.Tensor:
            """
            Performs fp8 blockwise gemm operation using CuTe DSL.
            Performs fp4 blockwise gemm operation using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
@@ -590,7 +592,7 @@ if IS_CUTLASS_DSL_AVAILABLE:
            stream = cuda.CUstream(torch_stream.cuda_stream)

            cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab,
                         use_prefetch)
                         use_prefetch, self.use_tvm_ffi)
            if swap_ab:
                kernel_m = n
                kernel_n = m
@@ -770,7 +772,7 @@ if IS_CUTLASS_DSL_AVAILABLE:

        tuner = AutoTuner.get()

        runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers,
        runner = CuteDSLNVFP4BlackwellRunner(output_dtype, to_userbuffers,
                                             use_tvm_ffi)
        inputs = [input, weight, input_scale, weight_scale, alpha]
        _, best_tactic = tuner.choose_one(
@@ -2161,3 +2163,629 @@ if IS_CUTLASS_DSL_AVAILABLE:
                dtype=input_scale.dtype,
                device=input_scale.device)
        return output, output_scale

    class CuteDSLFp8BlackwellRunner(TunableRunner):
        kernel_class = Sm100BlockwiseGemmKernel
        kernel_cache = dict()

        tuning_config = TuningConfig(
            dynamic_tensor_specs=(DynamicTensorSpec(
                0, 0, get_last_power_of_2_num_tokens_buckets,
                last_positive_power_of_2), ),
            constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ),
        )

        def __init__(self,
                     output_dtype: torch.dtype = torch.bfloat16,
                     use_tvm_ffi: bool = True):
            super().__init__()
            if output_dtype != torch.bfloat16:
                raise ValueError(
                    f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}"
                )
            self.output_dtype = output_dtype
            self.use_tvm_ffi = use_tvm_ffi

        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
            **kwargs,
        ) -> List[int]:
            if not is_sm_100f():
                logger.debug(
                    f"CuteDSL: SM version {get_sm_version()} is not supported. "
                    f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
                )
                return []

            m = inputs[0].shape[0]
            n = inputs[1].shape[0]
            k = inputs[0].shape[1]
            batch_size = 1
            # m,k
            a_major = "k"
            # n, k
            b_major = "k"
            # m, n
            c_major = "n"

            use_2cta_instrs_candi = [False, True]
            mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
            cluster_shape_mn_candi = [
                (1, 1), (1, 2), (1, 4),
                (2, 1), (2, 2), (2, 4),
                (4, 1), (4, 2), (4, 4),
            ]
            return [
                (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
                for use_2cta_instrs in use_2cta_instrs_candi
                for mma_tiler_mn in mma_tiler_mn_candi
                for cluster_shape_mn in cluster_shape_mn_candi
                if self.__class__.kernel_class.can_implement(
                    cutlass.Float8E4M3FN,  # ab_dtype,
                    cutlass.Float32,  # acc_dtype,
                    cutlass.BFloat16,  # c_dtype,
                    use_2cta_instrs, mma_tiler_mn, cluster_shape_mn,
                    m, n, k, batch_size, a_major, b_major, c_major,
                )
            ]

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic,
        ) -> torch.Tensor:
            """
            Performs fp8 blockwise (deepgemm like) operation using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
                    inputs[0]: Input tensor of shape (m, k), dtype: fp8.
                    inputs[1]: Weight tensor of shape (n, k), dtype: fp8.
                    inputs[2]: Input scale factor tensor of shape (k // 128, m), dtype: fp32.
                    inputs[3]: Weight scale factor tensor of shape (n // 128, k // 128), dtype: fp32.
                tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).

            Returns:
                torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
            """
            if isinstance(tactic, tuple):
                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
            else:
                # fallback to default tactic
                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
                    False,
                    (128, 128),
                    (1, 1),
                ]
            a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
            m, n, k = a_tensor.shape[0], b_tensor.shape[0], b_tensor.shape[1]
            sf_m = m
            sf_k = ceil_div(k, 128)
            sf_n = ceil_div(n, 128)
            c_tensor = torch.empty(*(m, n), dtype=torch.bfloat16, device=a_tensor.device)
            c_tmp = c_tensor.view((1, m, n))
            c_tmp = c_tmp.permute(1, 2, 0)

            if not self.use_tvm_ffi:
                a_ptr = make_ptr(
                    cutlass.Float8E4M3FN,
                    a_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_ptr = make_ptr(
                    cutlass.Float8E4M3FN,
                    b_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                a_sf_ptr = make_ptr(
                    cutlass.Float32,
                    a_sf_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_sf_ptr = make_ptr(
                    cutlass.Float32,
                    b_sf_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                c_cute_tensor = cute.runtime.from_dlpack(
                    c_tmp).mark_layout_dynamic(leading_dim=1)

                # get stream
                stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

            cache_key = (
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                self.use_tvm_ffi,
            )
            if cache_key not in self.__class__.kernel_cache:
                if self.use_tvm_ffi:
                    a_ptr = make_ptr(
                        cutlass.Float8E4M3FN,
                        a_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    b_ptr = make_ptr(
                        cutlass.Float8E4M3FN,
                        b_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    a_sf_ptr = make_ptr(
                        cutlass.Float32,
                        a_sf_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    b_sf_ptr = make_ptr(
                        cutlass.Float32,
                        b_sf_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    # Convert c_tensor to cute tensor for TVM FFI for env stream detection
                    c_cute_tensor = cute.runtime.from_dlpack(
                        c_tmp).mark_layout_dynamic(leading_dim=1)
                    stream = cute.runtime.make_fake_stream(
                        use_tvm_ffi_env_stream=True)

                gemm = self.__class__.kernel_class(
                    cutlass.Float32,  # acc_dtype,
                    use_2cta_instrs=use_2cta_instrs,
                    mma_tiler_mn=mma_tiler_mn,
                    cluster_shape_mn=cluster_shape_mn,
                )
                # Compute max active clusters on current device
                hardware_info = cutlass.utils.HardwareInfo()
                max_active_clusters = hardware_info.get_max_active_clusters(
                    cluster_shape_mn[0] * cluster_shape_mn[1])

                compiled_gemm = cute.compile(
                    gemm.wrapper,
                    m, n, k, sf_m, sf_n, sf_k,
                    1,  # batch
                    a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_cute_tensor,
                    max_active_clusters=max_active_clusters,
                    stream=stream,
                    options=f"--opt-level 2 --enable-tvm-ffi"
                    if self.use_tvm_ffi else "--opt-level 2",
                )
                self.__class__.kernel_cache[cache_key] = compiled_gemm
            else:
                compiled_gemm = self.__class__.kernel_cache[cache_key]

            # launch gemm kernel
            if self.use_tvm_ffi:
                # call with torch pointer types and no need to pass stream.
                compiled_gemm(
                    m, n, k, sf_m, sf_n, sf_k,
                    1,  # batch
                    a_tensor.data_ptr(), b_tensor.data_ptr(),
                    a_sf_tensor.data_ptr(), b_sf_tensor.data_ptr(),
                    c_tmp,
                )
            else:
                # call with cute types and need to pass torch stream.
                compiled_gemm(
                    m, n, k, sf_m, sf_n, sf_k,
                    1,  # batch
                    a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_cute_tensor,
                    stream=stream,
                )
            return c_tensor

    # a/b: fp8, scale: fp32, output: bf16
    @torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell",
                             mutates_args=(),
                             device_types="cuda")
    def cute_dsl_fp8_gemm_blackwell(
        input: torch.Tensor,
        weight: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        output_dtype: torch.dtype = torch.bfloat16,
        use_tvm_ffi: bool = True,
    ) -> torch.Tensor:
        if output_dtype != torch.bfloat16:
            raise ValueError(
                f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}"
            )
        if not is_sm_100f():
            raise ValueError(
                f"CuteDSL: SM version {get_sm_version()} is not supported. "
                f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics."
            )
        tuner = AutoTuner.get()

        runner = CuteDSLFp8BlackwellRunner(output_dtype=output_dtype,
                                           use_tvm_ffi=use_tvm_ffi)

        inputs = [input, weight, input_scale, weight_scale]
        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_fp8_gemm_blackwell::gemm",
            [runner],
            runner.__class__.tuning_config,
            inputs,
        )
        return runner(inputs, tactic=best_tactic)

    @torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell")
    def _(
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        output_dtype: torch.dtype = torch.bfloat16,
        use_tvm_ffi: bool = True,
    ):
        # [m, k]
        shape = list(mat_a.shape)
        # [n, k]
        shape[-1] = mat_b.shape[-2]
        # output is fixed as bf16
        ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
        return ret
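The sequence below (not part of the diff) sketches how the new op is meant to be driven, mirroring test_cute_dsl_fp8_block_scale_gemm later in this commit; per_block_cast_to_fp8 is the unit-test helper for 128x128 weight scales, and the autotune import path is an assumption:

    import torch
    from tensorrt_llm._torch.autotuner import autotune  # import path assumed

    m, n, k = 128, 4096, 7168
    a = torch.randn((m, k), device="cuda", dtype=torch.bfloat16) / k
    w = torch.randn((n, k), device="cuda", dtype=torch.bfloat16) / k
    a_fp8, a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)  # per-token 1x128 scales
    w_fp8, w_sf = per_block_cast_to_fp8(w)  # test helper: 128x128 block scales

    with autotune():  # tuning pass caches the best tactic
        torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(a_fp8, w_fp8, a_sf, w_sf)
    out = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(a_fp8, w_fp8, a_sf, w_sf)  # (m, n) bf16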
    class CuteDSLFp8BlackwellBmmRunner(TunableRunner):
        kernel_class = Sm100BlockwiseGemmKernel
        kernel_cache = dict()

        tuning_config = TuningConfig(
            dynamic_tensor_specs=(DynamicTensorSpec(
                0, 1, get_last_power_of_2_num_tokens_buckets,
                last_positive_power_of_2), ),
            constraint_specs=(ConstraintSpec(2, 2, fp8_scale_infer_shape), ),
        )

        def __init__(self,
                     output_dtype: torch.dtype = torch.bfloat16,
                     use_tvm_ffi: bool = True):
            super().__init__()
            if output_dtype != torch.bfloat16:
                raise ValueError(
                    f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}"
                )
            self.output_dtype = output_dtype
            self.use_tvm_ffi = use_tvm_ffi

        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
            **kwargs,
        ) -> List[int]:

            if not is_sm_100f():
                logger.debug(
                    f"CuteDSL: SM version {get_sm_version()} is not supported. "
                    f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
                )
                return []
            # [b, m, k]
            batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[
                0].shape[2]
            # [b, n, k]
            n = inputs[1].shape[1]
            # m,k
            a_major = "k"
            # n, k
            b_major = "k"
            # m, n
            c_major = "n"

            use_2cta_instrs_candi = [False, True]
            mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)]
            cluster_shape_mn_candi = [
                (1, 1), (1, 2), (1, 4),
                (2, 1), (2, 2), (2, 4),
                (4, 1), (4, 2), (4, 4),
            ]
            return [
                (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn)
                for use_2cta_instrs in use_2cta_instrs_candi
                for mma_tiler_mn in mma_tiler_mn_candi
                for cluster_shape_mn in cluster_shape_mn_candi
                if self.__class__.kernel_class.can_implement(
                    cutlass.Float8E4M3FN,  # ab_dtype,
                    cutlass.Float32,  # acc_dtype,
                    cutlass.BFloat16,  # c_dtype,
                    use_2cta_instrs, mma_tiler_mn, cluster_shape_mn,
                    m, n, k, batch_size, a_major, b_major, c_major,
                )
            ]

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic,
        ) -> None:
            """
            Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL.

            Args:
                inputs (List[torch.Tensor]):
                    inputs[0]: Input tensor of shape (batch_size, m, k), dtype: fp8.
                    inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: fp8.
                    inputs[2]: Input scale tensor of shape (batch_size, k // 128, pad_up(m, 4)), dtype: fp32.
                    inputs[3]: Weight scale tensor of shape (batch_size, n // 128, k // 128), dtype: fp32.
                tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn).

            Returns:
                torch.Tensor: Output tensor of shape (batch_size, m, n), dtype: bf16.
            """
            if isinstance(tactic, tuple):
                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic
            else:
                # fallback to default tactic
                use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [
                    False,
                    (128, 128),
                    (1, 1),
                ]

            a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, c_tensor = inputs
            c_tmp = c_tensor.permute(1, 2, 0)

            batch_size = a_tensor.shape[0]
            m = a_tensor.shape[1]
            k = a_tensor.shape[2]
            n = b_tensor.shape[1]
            sf_m = pad_up(m, 4)
            sf_k = ceil_div(k, 128)
            sf_n = ceil_div(n, 128)

            if not self.use_tvm_ffi:
                a_ptr = make_ptr(
                    cutlass.Float8E4M3FN,
                    a_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_ptr = make_ptr(
                    cutlass.Float8E4M3FN,
                    b_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                a_sf_ptr = make_ptr(
                    cutlass.Float32,
                    a_sf_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                b_sf_ptr = make_ptr(
                    cutlass.Float32,
                    b_sf_tensor.data_ptr(),
                    cute.AddressSpace.gmem,
                    assumed_align=16,
                )
                c_cute_tensor = cute.runtime.from_dlpack(
                    c_tmp).mark_layout_dynamic(leading_dim=1)

                # get stream
                stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream)

            cache_key = (
                use_2cta_instrs,
                mma_tiler_mn,
                cluster_shape_mn,
                self.use_tvm_ffi,
            )
            if cache_key not in self.__class__.kernel_cache:
                if self.use_tvm_ffi:
                    a_ptr = make_ptr(
                        cutlass.Float8E4M3FN,
                        a_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    b_ptr = make_ptr(
                        cutlass.Float8E4M3FN,
                        b_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    a_sf_ptr = make_ptr(
                        cutlass.Float32,
                        a_sf_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    b_sf_ptr = make_ptr(
                        cutlass.Float32,
                        b_sf_tensor.data_ptr(),
                        cute.AddressSpace.gmem,
                        assumed_align=16,
                    )
                    # Convert c_tensor to cute tensor for TVM FFI for env stream detection
                    c_cute_tensor = cute.runtime.from_dlpack(
                        c_tmp).mark_layout_dynamic(leading_dim=1)
                    # make faked stream for TVM FFI
                    stream = cute.runtime.make_fake_stream(
                        use_tvm_ffi_env_stream=True)

                gemm = self.__class__.kernel_class(
                    cutlass.Float32,  # acc_dtype,
                    use_2cta_instrs=use_2cta_instrs,
                    mma_tiler_mn=mma_tiler_mn,
                    cluster_shape_mn=cluster_shape_mn,
                )
                # Compute max active clusters on current device
                hardware_info = cutlass.utils.HardwareInfo()
                max_active_clusters = hardware_info.get_max_active_clusters(
                    cluster_shape_mn[0] * cluster_shape_mn[1])

                compiled_gemm = cute.compile(
                    gemm.wrapper,
                    m, n, k, sf_m, sf_n, sf_k,
                    batch_size,
                    a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_cute_tensor,
                    max_active_clusters=max_active_clusters,
                    stream=stream,
                    options=f"--opt-level 2 --enable-tvm-ffi"
                    if self.use_tvm_ffi else "--opt-level 2",
                )
                self.__class__.kernel_cache[cache_key] = compiled_gemm
            else:
                compiled_gemm = self.__class__.kernel_cache[cache_key]

            # launch gemm kernel
            if self.use_tvm_ffi:
                # call with torch pointer types and no need to pass stream.
                compiled_gemm(
                    m, n, k, sf_m, sf_n, sf_k,
                    batch_size,
                    a_tensor.data_ptr(), b_tensor.data_ptr(),
                    a_sf_tensor.data_ptr(), b_sf_tensor.data_ptr(),
                    c_tmp,
                )
            else:
                # call with cute types and need to pass torch stream.
                compiled_gemm(
                    m, n, k, sf_m, sf_n, sf_k,
                    batch_size,
                    a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_cute_tensor,
                    stream=stream,
                )

    # a/b: fp8, scale: fp32, output: bf16
    @torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell",
                             mutates_args=("output", ),
                             device_types="cuda")
    def cute_dsl_fp8_bmm_blackwell(
        input: torch.Tensor,
        weight: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        output: torch.Tensor,
        output_dtype: torch.dtype = torch.bfloat16,
        use_tvm_ffi: bool = True,
    ) -> None:
        if output_dtype != torch.bfloat16:
            raise ValueError(
                f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}"
            )
        if not is_sm_100f():
            raise ValueError(
                f"CuteDSL: SM version {get_sm_version()} is not supported. "
                f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics."
            )

        tuner = AutoTuner.get()

        runner = CuteDSLFp8BlackwellBmmRunner(output_dtype=output_dtype,
                                              use_tvm_ffi=use_tvm_ffi)

        inputs = [input, weight, input_scale, weight_scale, output]

        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_fp8_bmm_blackwell::gemm",
            [runner],
            runner.__class__.tuning_config,
            inputs,
        )
        runner(inputs, tactic=best_tactic)

    @torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell")
    def _(
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        output: torch.Tensor,
        output_dtype: torch.dtype = torch.bfloat16,
        use_tvm_ffi: bool = True,
    ) -> None:
        batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2]
        n = mat_b.shape[1]
        assert output.dtype == torch.bfloat16, "CuTe DSL fp8 bmm output dtype must be bf16"
        assert output.shape == (batch_size, m,
                                n), "CuTe DSL fp8 bmm output shape is incorrect"
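A minimal driving sketch (not part of the diff) for the in-place BMM op above; shapes follow test_cute_dsl_fp8_block_scale_bmm later in this commit, and the identity weight scales are placeholders for the 128x128 block quantization the test performs with its per_block_cast_to_fp8 helper:

    import torch

    groups, m, n, k = 4, 7, 2112, 7168
    a = torch.randn((m, groups, k), device="cuda", dtype=torch.bfloat16) / k
    a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)

    b = torch.randn((groups, n, k), device="cuda", dtype=torch.bfloat16) / k
    b_fp8 = b.to(torch.float8_e4m3fn)  # placeholder cast; real code uses block quantization
    b_scales = torch.ones((groups, (n + 127) // 128, (k + 127) // 128),
                          device="cuda", dtype=torch.float)

    out = torch.empty((groups, m, n), device="cuda", dtype=torch.bfloat16)
    torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8, b_fp8, a_scales, b_scales, out)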
@@ -25,7 +25,7 @@ from ..utils import (ActivationType, fp4_scale_infer_shape,

if IS_CUTLASS_DSL_AVAILABLE:
    from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \
        CuteDSLNVFP4BlackwellLinear
        CuteDSLNVFP4BlackwellRunner


# Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
@@ -819,7 +819,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                    "Please add other backends to allowed_backends.")
            else:
                # SM version OK, check if CuteDSL supports the current shape
                cutedsl_runner = CuteDSLNVFP4BlackwellLinear(
                cutedsl_runner = CuteDSLNVFP4BlackwellRunner(
                    self.output_dtype)
                cutedsl_tactics = cutedsl_runner.get_valid_tactics(
                    inputs, profile)
@@ -878,7 +878,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner):
                self.output_dtype)(inputs,
                                   tactic=sub_tactic)
        elif backend == "cutedsl":
            return CuteDSLNVFP4BlackwellLinear(
            return CuteDSLNVFP4BlackwellRunner(
                self.output_dtype, self.to_userbuffers)(inputs,
                                                        tactic=sub_tactic)
        else:
(File diff suppressed because it is too large.)
@@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]):

    extra_attrs: Dict = field(default_factory=dict, repr=False, init=False)

    # cute dsl op configs
    use_cute_dsl_blockscaling_mm: bool = False
    use_cute_dsl_blockscaling_bmm: bool = False

    _frozen: bool = field(default=False, init=False, repr=False)

    # If true, ONLY the vision encoder part of the full model is loaded/executed.
@@ -672,6 +672,7 @@ class DeepseekV3Linear(Linear):
        reduce_output: bool = True,  # ROW parallel only
        skip_create_weights_in_init: bool = False,
        use_custom_cublas_mm: bool = False,
        use_cute_dsl_blockscaling_mm: bool = False,
        lora: Optional[LoraLayer] = None,
    ):
        super().__init__(
@@ -688,6 +689,7 @@ class DeepseekV3Linear(Linear):
            skip_create_weights_in_init,
            use_custom_cublas_mm,
            lora,
            use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm,
        )

    def apply_linear(self,
@@ -748,7 +750,10 @@ class DeepseekV3Attention(MLA):
            quant_config=model_config.get_quant_config(),
            skip_create_weights_in_init=model_config.
            skip_create_weights_in_init,
            use_custom_cublas_mm=True)
            use_custom_cublas_mm=True,
            use_cute_dsl_blockscaling_mm=model_config.
            use_cute_dsl_blockscaling_mm,
        )


class DeepseekV32Attention(MLA):
@@ -925,6 +930,7 @@ class Deepseekv3MoE(nn.Module):
        config = model_config.pretrained_config
        self.top_k = top_k
        self.use_dp = model_config.mapping.enable_attention_dp
        self.use_cute_dsl_blockscaling_mm = model_config.use_cute_dsl_blockscaling_mm
        gate_cls = DeepseekV3Gate
        if hasattr(model_config.pretrained_config, "gate_cls"):
            gate_cls = model_config.pretrained_config.gate_cls
@@ -977,7 +983,9 @@ class Deepseekv3MoE(nn.Module):
            dtype=dtype,
            config=model_config,
            overridden_tp_size=shared_tp_size,
            reduce_output=False)
            reduce_output=False,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm,
        )

        self.allreduce = None
        if not self.use_dp and self.mapping.tp_size > 1:
@@ -1262,13 +1270,17 @@ class DeepseekV3DecoderLayer(DecoderLayer):
        self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4
        self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp

        self.mlp = GatedMLP(hidden_size=config.hidden_size,
                            intermediate_size=config.intermediate_size,
                            bias=False,
                            dtype=config.torch_dtype,
                            config=model_config,
                            overridden_tp_size=self.mlp_tp_size,
                            reduce_output=has_mlp_tp)
        self.mlp = GatedMLP(
            hidden_size=config.hidden_size,
            intermediate_size=config.intermediate_size,
            bias=False,
            dtype=config.torch_dtype,
            config=model_config,
            overridden_tp_size=self.mlp_tp_size,
            reduce_output=has_mlp_tp,
            use_cute_dsl_blockscaling_mm=model_config.
            use_cute_dsl_blockscaling_mm,
        )

        self.input_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                       eps=config.rms_norm_eps,
@@ -1564,6 +1576,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
                dtype=config.torch_dtype,
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
                use_cute_dsl_blockscaling_mm=model_config.
                use_cute_dsl_blockscaling_mm,
            )
        else:
            self.eh_proj = Linear(
@@ -1576,6 +1590,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer):
                reduce_output=True,
                skip_create_weights_in_init=model_config.
                skip_create_weights_in_init,
                use_cute_dsl_blockscaling_mm=model_config.
                use_cute_dsl_blockscaling_mm,
            )

        self.shared_head = DeepseekV3MTPHead(model_config)
@@ -255,6 +255,9 @@ class Attention(nn.Module):
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_key_value_heads * self.head_dim

        self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
        self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm

        qkv_shard_indices_mapping = {
            "q": (0, self.q_size * (2 if self.attn_output_gate else 1)),
            "k":
@@ -280,7 +283,8 @@ class Attention(nn.Module):
            force_dynamic_quantization=config.force_dynamic_quantization,
            disable_deep_gemm=disable_deep_gemm,
            use_custom_cublas_mm=use_custom_cublas_mm,
            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping)
            fused_weight_shard_indices_mapping=qkv_shard_indices_mapping,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

        self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
                                [self.hidden_size])
@@ -299,7 +303,8 @@ class Attention(nn.Module):
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization,
            disable_deep_gemm=disable_deep_gemm,
            use_custom_cublas_mm=use_custom_cublas_mm)
            use_custom_cublas_mm=use_custom_cublas_mm,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

        self.quant_config = config.get_quant_config()
        self.attn_backend = config.attn_backend
@@ -686,6 +691,7 @@ def fp8_block_scaling_bmm_out(
    mat2_scale: torch.Tensor,
    out: torch.Tensor,
    mat2_dequant: Optional[torch.Tensor] = None,
    use_cute_dsl_blockscaling_bmm: bool = False,
) -> torch.Tensor:
    sm_version = get_sm_version()
    if sm_version == 90 or sm_version == 89:
@@ -706,7 +712,17 @@ def fp8_block_scaling_bmm_out(
            output)
        out.copy_(output)
    elif is_sm_100f(sm_version):
        torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out)
        if use_cute_dsl_blockscaling_bmm:
            mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(
                mat1)
            torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8,
                                                        mat1_scale, mat2_scale,
                                                        out)
            mat1_scale = None
        else:
            torch.bmm(mat1.transpose(0, 1),
                      mat2_dequant.transpose(1, 2),
                      out=out)
    else:
        raise NotImplementedError(f"SM{sm_version} is not supported")

@@ -851,6 +867,9 @@ class MLA(nn.Module):
        quant_config = config.get_quant_config()
        self.quant_config = quant_config

        self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm
        self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm

        if not self.is_lite:
            self.kv_a_proj_with_mqa = Linear(
                hidden_size,
@@ -860,7 +879,8 @@ class MLA(nn.Module):
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True,
                force_dynamic_quantization=config.force_dynamic_quantization)
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

            self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank,
                                         eps=rms_norm_eps,
@@ -876,7 +896,8 @@ class MLA(nn.Module):
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                allreduce_strategy=config.allreduce_strategy,
                force_dynamic_quantization=config.force_dynamic_quantization)
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
        else:
            self.kv_a_proj_with_mqa = Linear(
                hidden_size,
@@ -886,7 +907,8 @@ class MLA(nn.Module):
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                use_custom_cublas_mm=True,
                force_dynamic_quantization=config.force_dynamic_quantization)
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)

            self.q_proj = Linear(
                self.q_lora_rank,
@@ -898,7 +920,8 @@ class MLA(nn.Module):
                quant_config=quant_config,
                skip_create_weights_in_init=config.skip_create_weights_in_init,
                allreduce_strategy=config.allreduce_strategy,
                force_dynamic_quantization=config.force_dynamic_quantization)
                force_dynamic_quantization=config.force_dynamic_quantization,
                use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
            self.q_b_proj = self.q_proj

        self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@@ -915,7 +938,8 @@ class MLA(nn.Module):
            quant_config=quant_config,
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization)
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)
        # This parameter will view into self.kv_b_proj.weight after loading weights.
        # For dummy weight initialization, this parameter is initialized with empty tensor.
        # Used in forward_absorption only
@@ -947,7 +971,8 @@ class MLA(nn.Module):
            skip_create_weights_in_init=config.skip_create_weights_in_init,
            reduce_output=reduce_output,
            allreduce_strategy=config.allreduce_strategy,
            force_dynamic_quantization=config.force_dynamic_quantization)
            force_dynamic_quantization=config.force_dynamic_quantization,
            use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm)


def yarn_get_mscale(scale=1, mscale=1):
    if scale <= 1:
@@ -1083,7 +1108,7 @@ class MLA(nn.Module):
                ),
                requires_grad=False,
            )
            if is_sm_100f():
            if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm:
                assert self.dtype == torch.bfloat16
                self.k_b_proj_trans_dequant = nn.Parameter(
                    torch.empty(
@@ -1875,6 +1900,7 @@ class MLA(nn.Module):
                self.k_b_proj_trans_scale,
                q_nope_out,
                self.k_b_proj_trans_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            ),
            lambda: self.mqa.mla_rope_generation(
                fused_q,
@@ -1952,6 +1978,7 @@ class MLA(nn.Module):
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
@@ -2007,6 +2034,7 @@ class MLA(nn.Module):
                self.k_b_proj_trans_scale,
                q_nope_out,
                self.k_b_proj_trans_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
@@ -2062,6 +2090,7 @@ class MLA(nn.Module):
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
@@ -2129,6 +2158,7 @@ class MLA(nn.Module):
                self.k_b_proj_trans_scale,
                q_nope_out,
                self.k_b_proj_trans_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
@@ -2205,6 +2235,7 @@ class MLA(nn.Module):
                self.v_b_proj_scale,
                attn_output.transpose(0, 1),
                self.v_b_proj_dequant,
                self.use_cute_dsl_blockscaling_bmm,
            )
        else:
            raise NotImplementedError(
@@ -982,10 +982,9 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):

        if is_sm_100f():
            if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
                # TODO (@lmin): replace with cute_dsl gemm
                act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                    input)
                output = torch.ops.trtllm.fp8_block_scaling_gemm(
                output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
                    act_input_fp8, module.weight, act_input_sf,
                    module.weight_scale)
            else:
@@ -364,7 +364,12 @@ class ModelLoader:
                use_low_precision_moe_combine=self.llm_args.moe_config.
                use_low_precision_moe_combine,
                nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
                allowed_backends)
                allowed_backends,
                use_cute_dsl_blockscaling_mm=self.llm_args.
                use_cute_dsl_blockscaling_mm,
                use_cute_dsl_blockscaling_bmm=self.llm_args.
                use_cute_dsl_blockscaling_bmm,
            )

            # Only pass model_kwargs if it's explicitly set (not None)
            if self.llm_args.model_kwargs is not None:
@@ -301,6 +301,16 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
    return scale_shape * 2


def fp8_scale_infer_shape(input_shapes: List[List[int]]):
    """Calculate the dimensions of the fp8 scale tensor.
    """
    input_shape = input_shapes[0]
    assert len(input_shape) == 2 or len(input_shape) == 3
    has_batch = len(input_shape) == 3
    m = input_shape[-2]
    return pad_up(m, 4) if has_batch else m


_enable_piecewise_cuda_graph = True
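For reference (not part of the diff), the fp8_scale_infer_shape hook above infers the dynamic extent of the activation-scale tensor, padding m to a multiple of 4 only in the batched case:

    fp8_scale_infer_shape([[300, 7168]])   # -> 300  (2D gemm input: scale dim matches m)
    fp8_scale_infer_shape([[8, 7, 7168]])  # -> 8    (3D bmm input: pad_up(7, 4))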
@@ -3050,6 +3050,18 @@ class TorchLlmArgs(BaseLlmArgs):
        "Only enable it if you intend to use this feature.",
        status="prototype")

    # fp8 cute dsl configs
    use_cute_dsl_blockscaling_mm: bool = Field(
        default=False,
        description="If true, use CuTe DSL fp8 blockscaling mm implementation.",
        status="prototype",
    )
    use_cute_dsl_blockscaling_bmm: bool = Field(
        default=False,
        description="If true, use CuTe DSL fp8 blockscaling bmm implementation.",
        status="prototype",
    )

    # PrivateVars
    _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
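A minimal sketch (not part of the diff) of enabling the new prototype flags through the LLM API, following the keyword arguments used by the DeepSeek-V3-Lite accuracy tests below; the model path is hypothetical:

    from tensorrt_llm import LLM

    llm = LLM(
        model="path/to/DeepSeek-V3-Lite-fp8",  # hypothetical checkpoint
        use_cute_dsl_blockscaling_mm=True,     # CuTe DSL fp8 block-scaling GEMM
        use_cute_dsl_blockscaling_bmm=True,    # CuTe DSL fp8 block-scaling BMM (MLA path)
    )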
@@ -1533,6 +1533,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
            torch_compile_config=torch_compile_config,
            moe_config=MoeConfig(backend="CUTEDSL"),
            use_cute_dsl_blockscaling_mm=True,
            use_cute_dsl_blockscaling_bmm=True,
        )

        if fp8kv:
@@ -1695,6 +1697,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
            cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
            torch_compile_config=torch_compile_config,
            moe_config=MoeConfig(backend="CUTEDSL"),
            use_cute_dsl_blockscaling_mm=True,
            use_cute_dsl_blockscaling_bmm=True,
        )

        if fp8kv:
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -109,6 +109,54 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)


@pytest.mark.skipif(
    not isSM100Family(),
    reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
    "k, n",
    [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
     (2048, 7168), (1024, 1024)],
)
@pytest.mark.parametrize(
    "m",
    [7, 64, 128, 4096],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16],
)
@pytest.mark.parametrize(
    "use_tvm_ffi",
    [True, False],
)
def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n, use_tvm_ffi):

    torch.random.manual_seed(0)
    a = torch.randn((m, k), device='cuda', dtype=dtype) / k
    b = torch.randn((n, k), device='cuda', dtype=dtype) / k

    act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
    act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)

    output_expected = a @ b.t()

    with autotune():
        cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
            act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)

    # test Cute DSL kernel
    cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
        act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)

    diff = calc_diff(cute_dsl_output, output_expected)
    assert diff < 1e-3
    torch.testing.assert_close(cute_dsl_output,
                               output_expected,
                               atol=1e-3,
                               rtol=1e-3)


@pytest.mark.skipif(
    getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
    reason="The test is for Hopper and Ada only. Current SM is %d." %
@@ -171,6 +219,69 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)


@pytest.mark.skipif(
    not isSM100Family(),
    reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
)
@pytest.mark.parametrize(
    "k, n",
    [(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
)
@pytest.mark.parametrize(
    "m",
    [7, 64, 128],
)
@pytest.mark.parametrize(
    "num_groups",
    [4, 8, 16],
)
@pytest.mark.parametrize(
    "dtype",
    [torch.bfloat16],
)
@pytest.mark.parametrize(
    "use_tvm_ffi",
    [True, False],
)
def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups, use_tvm_ffi):

    torch.random.manual_seed(0)
    a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
    a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)

    b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
    b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
    b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
                           device='cuda',
                           dtype=torch.float)

    for i in range(num_groups):
        b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])

    output_expected = torch.einsum('mgk,gnk->gmn', a, b)
    output = torch.empty((num_groups, m, n),
                         device='cuda',
                         dtype=torch.bfloat16)
    # tune
    with autotune():
        torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
                                                    b_fp8,
                                                    a_scales,
                                                    b_scales,
                                                    output,
                                                    use_tvm_ffi=use_tvm_ffi)
    # run the tuned kernel
    torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
                                                b_fp8,
                                                a_scales,
                                                b_scales,
                                                output,
                                                use_tvm_ffi=use_tvm_ffi)
    diff = calc_diff(output, output_expected)
    assert diff < 1e-3
    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)


def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
                                    valsB, dqSfsB, quantizeOutput, tileSize):
    for mi in range(mM):
@@ -239,6 +239,14 @@ methods:
        annotation: Optional[Dict[str, Any]]
        default: null
        status: prototype
      use_cute_dsl_blockscaling_mm:
        annotation: bool
        default: False
        status: prototype
      use_cute_dsl_blockscaling_bmm:
        annotation: bool
        default: False
        status: prototype
    return_annotation: None
  generate:
    parameters: