Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)
[None][fix] Fix CI issue for dsl pkg install (#7784)
Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
This commit is contained in:
parent 4f0e6b5f96
commit 14e455da3e
@@ -74,4 +74,4 @@ triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
-nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"
+# nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"
@@ -1,3 +1,4 @@
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..modules.attention import attn_custom_op_inplace, mla_custom_op_inplace
from .cpp_custom_ops import _register_fake
@@ -15,6 +16,7 @@ __all__ = [
    'matmul_to_ub',
    'attn_custom_op_inplace',
    'mla_custom_op_inplace',
    'IS_CUTLASS_DSL_AVAILABLE',
]

if IS_FLASHINFER_AVAILABLE:
@@ -28,3 +30,9 @@ if IS_FLASHINFER_AVAILABLE:
        'flashinfer_fused_add_rmsnorm',
        'flashinfer_apply_rope_with_cos_sin_cache_inplace',
    ]

if IS_CUTLASS_DSL_AVAILABLE:
    from .cute_dsl_custom_ops import cute_dsl_nvfp4_gemm_blackwell
    __all__ += [
        'cute_dsl_nvfp4_gemm_blackwell',
    ]
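Because the op is exported from tensorrt_llm._torch.custom_ops only when the CuTe DSL imports cleanly, callers are expected to gate on the flag. A minimal sketch of that pattern (the fallback branch here is a hypothetical placeholder, not something this commit adds):

import torch

from tensorrt_llm._torch.cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE


def nvfp4_gemm_dispatch(act_fp4, weight, act_sf, weight_scale, alpha, out_dtype):
    # Use the CuTe DSL Blackwell kernel only when the optional package is importable.
    if IS_CUTLASS_DSL_AVAILABLE:
        return torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
            act_fp4, weight, act_sf, weight_scale, alpha, out_dtype)
    # Hypothetical fallback: whatever non-DSL NVFP4 GEMM path the caller already has.
    raise NotImplementedError("CuTe DSL is unavailable; use the default NVFP4 GEMM path")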
264  tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py  Normal file
@@ -0,0 +1,264 @@
from typing import List, Tuple

import torch
import triton  # type: ignore[import]

from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.math_utils import pad_up

from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                         OptimizationProfile, TunableRunner, TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
                     get_last_power_of_2_num_tokens_buckets,
                     last_positive_power_of_2)

if IS_CUTLASS_DSL_AVAILABLE:

    import cutlass
    import cutlass.cute as cute

    from tensorrt_llm._torch.cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
        Sm100BlockScaledPersistentDenseGemmKernel,
        Sm100BlockScaledPersistentDenseGemmKernelWrapper)
    from tensorrt_llm._torch.cute_dsl_kernels.blackwell.utils import make_ptr

    try:
        from cuda.bindings import driver as cuda
    except ImportError:
        from cuda import cuda

    class CuteDSLNVFP4BlackwellLinear(TunableRunner):
        kernel_dict = dict()

        tuning_config = TuningConfig(
            dynamic_tensor_specs=(DynamicTensorSpec(
                0, 0, get_last_power_of_2_num_tokens_buckets,
                last_positive_power_of_2), ),
            constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
        )

        def __init__(self, alpha: float, output_dtype: torch.dtype):
            super().__init__()
            self.alpha = alpha
            self.output_dtype = output_dtype
            assert output_dtype == torch.bfloat16

            if get_sm_version() != 100:
                raise ValueError(
                    f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear; it only supports SM 100"
                )

        def get_valid_tactics(
            self,
            inputs: List[torch.Tensor],
            profile: OptimizationProfile,
            **kwargs,
        ) -> List[Tuple[int, int]]:
            assert inputs[0].dim() == 2
            assert inputs[1].dim() == 2

            m = inputs[0].shape[0]
            n = inputs[1].shape[0]
            k = inputs[0].shape[1]
            # Note: the input tensors store fp4 values packed two per uint8 byte, so real_k is k * 2
            real_k = k * 2
            batch_size = 1
            # (m, k)
            a_major = "k"
            # (n, k)
            b_major = "k"
            # (m, n)
            c_major = "n"
            sf_vec_size = 16

            # full shmoo over the tiler and cluster-shape candidates
            mma_tiler_mn_candidates = [(256, 128), (128, 128), (128, 256),
                                       (256, 256), (256, 64), (128, 64)]
            cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4), (2, 1),
                                           (2, 2), (2, 4), (4, 1), (4, 2),
                                           (4, 4)]
            return [
                (mma_tiler_mn, cluster_shape_mn)
                for mma_tiler_mn in mma_tiler_mn_candidates
                for cluster_shape_mn in cluster_shape_mn_candidates
                if Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
                    cutlass.Float4E2M1FN,  # ab_dtype
                    cutlass.Float8E4M3FN,  # sf_dtype
                    sf_vec_size,  # sf_vec_size
                    cutlass.BFloat16,  # c_dtype
                    mma_tiler_mn,
                    cluster_shape_mn,
                    m,
                    n,
                    real_k,
                    batch_size,
                    a_major,
                    b_major,
                    c_major,
                )
            ]

        def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
                                         assumed_align: int):
            return make_ptr(
                dtype,
                tensor.data_ptr(),
                cute.AddressSpace.gmem,
                assumed_align=assumed_align,
            )

        def forward(
            self,
            inputs: List[torch.Tensor],
            tactic,
        ) -> torch.Tensor:
"""
|
||||
Performs fp8 blockwise gemm operation using CuTe DSL.
|
||||
|
||||
Args:
|
||||
inputs (List[torch.Tensor]):
|
||||
inputs[0]: Input tensor of shape (m, k), dtype: fp4.
|
||||
inputs[1]: Weight tensor of shape (n, k), dtype: fp4.
|
||||
inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
|
||||
inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
|
||||
inputs[4]: Alpha scaling factor. dtype: float32.
|
||||
inputs[5]: Output dtype, expected to be torch.bfloat16.
|
||||
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
|
||||
"""
|
||||
            sf_vec_size = 16

            if isinstance(tactic, tuple):
                mma_tiler_mn, cluster_shape_mn = tactic
            else:
                # fall back to the default tactic
                mma_tiler_mn, cluster_shape_mn = [
                    (128, 128),
                    (1, 1),
                ]

            a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
            m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
            c_tensor = torch.empty(*(m, n),
                                   dtype=self.output_dtype,
                                   device="cuda")

            real_k = k * 2
            sf_m = pad_up(m, 128)
            sf_k = pad_up(real_k // sf_vec_size, 4)
            sf_n = pad_up(n, 128)

            # The scale tensors are 1D; make sure they have been padded to the expected sizes.
            assert a_sf_tensor.shape == (sf_m * sf_k, )
            assert b_sf_tensor.shape == (sf_n * sf_k, )

            a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                      cutlass.Float4E2M1FN, 32)
            b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
                                                      cutlass.Float4E2M1FN, 32)
            a_sf_ptr = self.make_cute_dsl_global_pointer(
                a_sf_tensor, cutlass.Float8E4M3FN, 16)
            b_sf_ptr = self.make_cute_dsl_global_pointer(
                b_sf_tensor, cutlass.Float8E4M3FN, 16)
            c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
                                                      cutlass.BFloat16, 16)

            # get the current CUDA stream
            torch_stream = torch.cuda.current_stream()
            stream = cuda.CUstream(torch_stream.cuda_stream)

            gemm_wrapper_func = Sm100BlockScaledPersistentDenseGemmKernelWrapper
            CACHE_KEY = (
                sf_vec_size,
                mma_tiler_mn,
                cluster_shape_mn,
            )
            if CACHE_KEY not in CuteDSLNVFP4BlackwellLinear.kernel_dict:
                gemm = gemm_wrapper_func(
                    sf_vec_size,
                    mma_tiler_mn,
                    cluster_shape_mn,
                )
                # Compute max active clusters on the current device
                hardware_info = cutlass.utils.HardwareInfo()
                max_active_clusters = hardware_info.get_max_active_clusters(
                    cluster_shape_mn[0] * cluster_shape_mn[1])

                compiled_gemm = cute.compile(
                    gemm,
                    m,
                    n,
                    real_k,
                    sf_m // 128,
                    sf_n // 128,
                    sf_k // 4,
                    1,
                    a_ptr,
                    b_ptr,
                    a_sf_ptr,
                    b_sf_ptr,
                    c_ptr,
                    self.alpha,
                    max_active_clusters,
                    stream,
                )

                CuteDSLNVFP4BlackwellLinear.kernel_dict[
                    CACHE_KEY] = compiled_gemm
            else:
                compiled_gemm = CuteDSLNVFP4BlackwellLinear.kernel_dict[
                    CACHE_KEY]

            # launch the gemm kernel
            compiled_gemm(m, n, real_k, sf_m // 128, sf_n // 128, sf_k // 4,
                          a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, self.alpha,
                          stream)
            return c_tensor

    # a/b: fp4, scale: fp8, output: bf16
    @torch.library.custom_op("trtllm::cute_dsl_nvfp4_gemm_blackwell",
                             mutates_args=(),
                             device_types="cuda")
    def cute_dsl_nvfp4_gemm_blackwell(
        input: torch.Tensor,
        weight: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        alpha: float,
        output_dtype: torch.dtype,
    ) -> torch.Tensor:

        tuner = AutoTuner.get()

        cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear(
            alpha, output_dtype)
        _, best_tactic = tuner.choose_one(
            "trtllm::cute_dsl_nvfp4_gemm_blackwell",
            [cute_dsl_nvfp4_gemm_blackwell_runner],
            CuteDSLNVFP4BlackwellLinear.tuning_config,
            [input, weight, input_scale, weight_scale],
        )
        return cute_dsl_nvfp4_gemm_blackwell_runner(
            inputs=[input, weight, input_scale, weight_scale],
            tactic=best_tactic,
        )

    @torch.library.register_fake("trtllm::cute_dsl_nvfp4_gemm_blackwell")
    def _(
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
        input_scale: torch.Tensor,
        weight_scale: torch.Tensor,
        alpha: float,
        output_dtype: torch.dtype,
    ):
        # [m, k]
        shape = list(mat_a.shape)
        # [n, k]
        shape[-1] = mat_b.shape[-2]
        # output is fixed as bf16
        ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
        return ret
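For orientation, a rough usage sketch of the op registered above, modeled on the NVFP4LinearMethod call site later in this diff. The weight quantization, global scales, and alpha computation here are illustrative assumptions (in the real path they come from the checkpoint and the quant config); the unit test touched at the bottom of this commit is the authoritative reference:

import torch

import tensorrt_llm._torch.custom_ops  # noqa: F401  (assumption: importing this registers the trtllm ops)

m, n, k = 128, 7168, 16384
scaling_vector_size = 16

x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
w = torch.randn(n, k, dtype=torch.bfloat16, device="cuda")

# Stand-in global scales; real ones are derived from tensor statistics during quantization.
x_global_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
w_global_scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")

# Same quantization call the linear layer uses: packed fp4 values plus fp8 block scales.
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(x, x_global_scale,
                                                scaling_vector_size, False)
w_fp4, w_sf = torch.ops.trtllm.fp4_quantize(w, w_global_scale,
                                            scaling_vector_size, False)

# Assumed alpha convention: undo both global scales on the bf16 output.
alpha = 1.0 / (x_global_scale.item() * w_global_scale.item())

out = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(act_fp4, w_fp4, act_sf,
                                                     w_sf, alpha, torch.bfloat16)
print(out.shape, out.dtype)  # expected: torch.Size([128, 7168]), torch.bfloat16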
@@ -1,4 +1,3 @@
import sys
from functools import lru_cache
from typing import List, Mapping, Optional, Tuple

@@ -9,7 +8,6 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.math_utils import pad_up

from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                         OptimizationProfile, TunableRunner, TuningConfig)
@@ -19,27 +17,6 @@ from ..utils import (fp4_scale_infer_shape,
                     get_last_power_of_2_num_tokens_buckets,
                     last_positive_power_of_2)

try:
    if sys.version_info >= (3, 12):
        HAS_CUTLASS_DSL = True
        import cutlass
        import cutlass.cute as cute

        from tensorrt_llm._torch.cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
            Sm100BlockScaledPersistentDenseGemmKernel,
            Sm100BlockScaledPersistentDenseGemmKernelWrapper)
        from tensorrt_llm._torch.cute_dsl_kernels.blackwell.utils import \
            make_ptr
    else:
        HAS_CUTLASS_DSL = False
except ImportError:
    HAS_CUTLASS_DSL = False

try:
    from cuda.bindings import driver as cuda
except ImportError:
    from cuda import cuda


# Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
@torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", ))
@@ -1059,241 +1036,6 @@ def _(
    return x.new_empty((b, d), dtype=o_dtype)


class CuteDSLNVFP4BlackwellLinear(TunableRunner):
    kernel_dict = dict()

    tuning_config = TuningConfig(
        dynamic_tensor_specs=(DynamicTensorSpec(
            0, 0, get_last_power_of_2_num_tokens_buckets,
            last_positive_power_of_2), ),
        constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
    )

    def __init__(self, alpha: float, output_dtype: torch.dtype):
        super().__init__()
        self.alpha = alpha
        self.output_dtype = output_dtype
        assert output_dtype == torch.bfloat16

        if get_sm_version() != 100:
            raise ValueError(
                f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
            )

    def get_valid_tactics(
        self,
        inputs: List[torch.Tensor],
        profile: OptimizationProfile,
        **kwargs,
    ) -> List[Tuple[int, int]]:
        assert inputs[0].dim() == 2
        assert inputs[1].dim() == 2

        m = inputs[0].shape[0]
        n = inputs[1].shape[0]
        k = inputs[0].shape[1]
        # Note: the input tensor use uint8 to store fp4, so the real_k is k * 2
        real_k = k * 2
        batch_size = 1
        # m,k
        a_major = "k"
        # n, k
        b_major = "k"
        # m, n
        c_major = "n"
        sf_vec_size = 16

        # full shamoo
        mma_tiler_mn_candidates = [(256, 128), (128, 128), (128, 256),
                                   (256, 256), (256, 64), (128, 64)]
        cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2),
                                       (2, 4), (4, 1), (4, 2), (4, 4)]
        return [
            (mma_tiler_mn, cluster_shape_mn)
            for mma_tiler_mn in mma_tiler_mn_candidates
            for cluster_shape_mn in cluster_shape_mn_candidates
            if Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
                cutlass.Float4E2M1FN,  # ab_dtype,
                cutlass.Float8E4M3FN,  # sf_dtype
                sf_vec_size,  # sf_vec_size,
                cutlass.BFloat16,  # c_dtype,
                mma_tiler_mn,
                cluster_shape_mn,
                m,
                n,
                real_k,
                batch_size,
                a_major,
                b_major,
                c_major,
            )
        ]

    def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
                                     assumed_align: int):
        return make_ptr(
            dtype,
            tensor.data_ptr(),
            cute.AddressSpace.gmem,
            assumed_align=assumed_align,
        )

    def forward(
        self,
        inputs: List[torch.Tensor],
        tactic,
    ) -> torch.Tensor:
        """
        Performs fp8 blockwise gemm operation using CuTe DSL.

        Args:
            inputs (List[torch.Tensor]):
                inputs[0]: Input tensor of shape (m, k), dtype: fp4.
                inputs[1]: Weight tensor of shape (n, k), dtype: fp4.
                inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
                inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
                inputs[4]: Alpha scaling factor. dtype: float32.
                inputs[5]: Output dtype, expected to be torch.bfloat16.
            tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).

        Returns:
            torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
        """
        sf_vec_size = 16

        if isinstance(tactic, tuple):
            mma_tiler_mn, cluster_shape_mn = tactic
        else:
            # fallback to default tactic
            mma_tiler_mn, cluster_shape_mn = [
                (128, 128),
                (1, 1),
            ]

        a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
        m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
        c_tensor = torch.empty(*(m, n), dtype=self.output_dtype, device="cuda")

        real_k = k * 2
        sf_m = pad_up(m, 128)
        sf_k = pad_up(real_k // sf_vec_size, 4)
        sf_n = pad_up(n, 128)

        # the scaling tensor is 1D. we need to make sure it has been padded to the correct shape
        assert a_sf_tensor.shape == (sf_m * sf_k, )
        assert b_sf_tensor.shape == (sf_n * sf_k, )

        a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
                                                  cutlass.Float4E2M1FN, 32)
        b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
                                                  cutlass.Float4E2M1FN, 32)
        a_sf_ptr = self.make_cute_dsl_global_pointer(a_sf_tensor,
                                                     cutlass.Float8E4M3FN, 16)
        b_sf_ptr = self.make_cute_dsl_global_pointer(b_sf_tensor,
                                                     cutlass.Float8E4M3FN, 16)
        c_ptr = self.make_cute_dsl_global_pointer(c_tensor, cutlass.BFloat16,
                                                  16)

        # get stream
        torch_stream = torch.cuda.current_stream()
        stream = cuda.CUstream(torch_stream.cuda_stream)

        gemm_wrapper_func = Sm100BlockScaledPersistentDenseGemmKernelWrapper
        CACHE_KEY = (
            sf_vec_size,
            mma_tiler_mn,
            cluster_shape_mn,
        )
        if CACHE_KEY not in CuteDSLNVFP4BlackwellLinear.kernel_dict:
            gemm = gemm_wrapper_func(
                sf_vec_size,
                mma_tiler_mn,
                cluster_shape_mn,
            )
            # Compute max active clusters on current device
            hardware_info = cutlass.utils.HardwareInfo()
            max_active_clusters = hardware_info.get_max_active_clusters(
                cluster_shape_mn[0] * cluster_shape_mn[1])

            compiled_gemm = cute.compile(
                gemm,
                m,
                n,
                real_k,
                sf_m // 128,
                sf_n // 128,
                sf_k // 4,
                1,
                a_ptr,
                b_ptr,
                a_sf_ptr,
                b_sf_ptr,
                c_ptr,
                self.alpha,
                max_active_clusters,
                stream,
            )

            CuteDSLNVFP4BlackwellLinear.kernel_dict[CACHE_KEY] = compiled_gemm
        else:
            compiled_gemm = CuteDSLNVFP4BlackwellLinear.kernel_dict[CACHE_KEY]

        # launch gemm kernel
        compiled_gemm(m, n, real_k, sf_m // 128, sf_n // 128, sf_k // 4, a_ptr,
                      b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, self.alpha, stream)
        return c_tensor


# a/b: fp4, scale: fp8, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_nvfp4_gemm_blackwell",
                         mutates_args=(),
                         device_types="cuda")
def cute_dsl_nvfp4_gemm_blackwell(
    input: torch.Tensor,
    weight: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: float,
    output_dtype: torch.dtype,
) -> torch.Tensor:

    if not HAS_CUTLASS_DSL:
        raise RuntimeError("nvidia-cutlass-dsl 4.1.0 requires Python >=3.12")

    tuner = AutoTuner.get()

    cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear(
        alpha, output_dtype)
    _, best_tactic = tuner.choose_one(
        "trtllm::cute_dsl_nvfp4_gemm_blackwell",
        [cute_dsl_nvfp4_gemm_blackwell_runner],
        CuteDSLNVFP4BlackwellLinear.tuning_config,
        [input, weight, input_scale, weight_scale],
    )
    return cute_dsl_nvfp4_gemm_blackwell_runner(
        inputs=[input, weight, input_scale, weight_scale],
        tactic=best_tactic,
    )


@torch.library.register_fake("trtllm::cute_dsl_nvfp4_gemm_blackwell")
def _(
    mat_a: torch.Tensor,
    mat_b: torch.Tensor,
    input_scale: torch.Tensor,
    weight_scale: torch.Tensor,
    alpha: float,
    output_dtype: torch.dtype,
):
    # [m, k]
    shape = list(mat_a.shape)
    # [n, k]
    shape[-1] = mat_b.shape[-2]
    # output is fixed as bf16
    ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
    return ret


def get_event(event_idx: int):
    from ..utils import get_model_extra_attrs
    extra_attrs = get_model_extra_attrs()
18  tensorrt_llm/_torch/cute_dsl_utils.py  Normal file
@@ -0,0 +1,18 @@
import platform
import traceback

from ..logger import logger

IS_CUTLASS_DSL_AVAILABLE = False

if platform.system() != "Windows":
    try:
        import cutlass  # noqa
        import cutlass.cute as cute  # noqa
        logger.info("cutlass dsl is available")
        IS_CUTLASS_DSL_AVAILABLE = True
    except ImportError:
        traceback.print_exc()
        print(
            "cutlass dsl is not installed properly; please try `pip install nvidia-cutlass-dsl`"
        )
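The module above is the whole availability mechanism: the import is attempted once at import time and the result is cached in a module-level flag, which is what lets requirements.txt drop the hard pin. A self-contained sketch of the same idiom, handy for checking an environment interactively (the helper name is illustrative; only the `cutlass` package name comes from the code above):

import importlib


def optional_import_ok(module_name: str) -> bool:
    # Attempt the import the same way cute_dsl_utils does and report the outcome.
    try:
        importlib.import_module(module_name)
        return True
    except ImportError:
        return False


if __name__ == "__main__":
    print("cutlass importable:", optional_import_ok("cutlass"))
    print("cutlass.cute importable:", optional_import_ok("cutlass.cute"))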
@@ -23,6 +23,7 @@ from tensorrt_llm.quantization.mode import QuantAlgo

from ..._utils import is_sm_100f
from ...models.modeling_utils import QuantConfig
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import Fp4QuantizedTensor


@@ -766,7 +767,7 @@ class NVFP4LinearMethod(LinearMethodBase):
        act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
            input, module.input_scale, module.scaling_vector_size, False)

-        if module.use_cute_dsl_nvfp4_blockscaling_mm:
+        if IS_CUTLASS_DSL_AVAILABLE and module.use_cute_dsl_nvfp4_blockscaling_mm:
            output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
                act_fp4, module.weight, act_sf, module.weight_scale,
                module.scalar_alpha, module.dtype)

@@ -6,7 +6,9 @@ from utils.util import skip_pre_blackwell

import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig

scaling_vector_size = 16
@@ -86,7 +88,12 @@ def pad_up(x, pad_size):

@pytest.mark.skipif(sys.version_info < (3, 12),
                    reason="cutlass-dsl 4.1.0 requires Python 3.12+")
@skip_pre_blackwell
@pytest.mark.skipif(
    get_sm_version() != 100,
    reason="This test is only supported in Blackwell architecture",
)
@pytest.mark.skipif(not IS_CUTLASS_DSL_AVAILABLE,
                    reason="cutlass-dsl is not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mnk", [(128, 7168, 16384), (128, 24576, 1536),
                                 (128, 2112, 7168), (128, 4096, 7168),