[None][fix] Fix CI issue for dsl pkg install (#7784)

Signed-off-by: Mindy Li <11663212+limin2021@users.noreply.github.com>
Co-authored-by: QI JUN <22017000+QiJune@users.noreply.github.com>
Li Min 2025-09-18 13:58:20 +08:00 committed by GitHub
parent 4f0e6b5f96
commit 14e455da3e
7 changed files with 301 additions and 261 deletions

View File

@@ -74,4 +74,4 @@ triton==3.3.1; platform_machine == "x86_64"
tiktoken
blobfile
openai-harmony==0.0.4
nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"
# nvidia-cutlass-dsl==4.1.0; python_version >= "3.12"
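
Commenting out the pin turns nvidia-cutlass-dsl into an optional dependency that is probed at runtime (see the new cute_dsl_utils.py later in this commit) instead of one that must resolve during CI installs. A minimal sketch of an equivalent probe; the manual install command is an assumption:

    # Hypothetical availability check, mirroring what cute_dsl_utils.py below does.
    # The DSL ships the `cutlass` module; opt in manually with
    # `pip install nvidia-cutlass-dsl==4.1.0` (requires Python >= 3.12).
    import importlib.util

    def cutlass_dsl_installed() -> bool:
        return importlib.util.find_spec("cutlass") is not None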

View File

@@ -1,3 +1,4 @@
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..flashinfer_utils import IS_FLASHINFER_AVAILABLE
from ..modules.attention import attn_custom_op_inplace, mla_custom_op_inplace
from .cpp_custom_ops import _register_fake
@@ -15,6 +16,7 @@ __all__ = [
'matmul_to_ub',
'attn_custom_op_inplace',
'mla_custom_op_inplace',
'IS_CUTLASS_DSL_AVAILABLE',
]
if IS_FLASHINFER_AVAILABLE:
@@ -28,3 +30,9 @@ if IS_FLASHINFER_AVAILABLE:
'flashinfer_fused_add_rmsnorm',
'flashinfer_apply_rope_with_cos_sin_cache_inplace',
]
if IS_CUTLASS_DSL_AVAILABLE:
from .cute_dsl_custom_ops import cute_dsl_nvfp4_gemm_blackwell
__all__ += [
'cute_dsl_nvfp4_gemm_blackwell',
]
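
With the conditional export in place, downstream code can guard on the same flag; a caller-side sketch (the fallback assignment is illustrative, not part of this commit):

    # Import the DSL-backed op only when the package is available.
    from tensorrt_llm._torch.custom_ops import IS_CUTLASS_DSL_AVAILABLE

    if IS_CUTLASS_DSL_AVAILABLE:
        from tensorrt_llm._torch.custom_ops import cute_dsl_nvfp4_gemm_blackwell
    else:
        cute_dsl_nvfp4_gemm_blackwell = None  # callers fall back to another GEMM path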

View File

@@ -0,0 +1,264 @@
from typing import List, Tuple
import torch
import triton # type: ignore[import]
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.math_utils import pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import (fp4_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
if IS_CUTLASS_DSL_AVAILABLE:
import cutlass
import cutlass.cute as cute
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
Sm100BlockScaledPersistentDenseGemmKernel,
Sm100BlockScaledPersistentDenseGemmKernelWrapper)
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.utils import make_ptr
try:
from cuda.bindings import driver as cuda
except ImportError:
from cuda import cuda
class CuteDSLNVFP4BlackwellLinear(TunableRunner):
kernel_dict = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
)
def __init__(self, alpha: float, output_dtype: torch.dtype):
super().__init__()
self.alpha = alpha
self.output_dtype = output_dtype
assert output_dtype == torch.bfloat16
if get_sm_version() != 100:
raise ValueError(
f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[Tuple[int, int]]:
assert inputs[0].dim() == 2
assert inputs[1].dim() == 2
m = inputs[0].shape[0]
n = inputs[1].shape[0]
k = inputs[0].shape[1]
        # Note: the input tensors pack two fp4 values per uint8 byte, so the logical K is k * 2
real_k = k * 2
batch_size = 1
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
sf_vec_size = 16
        # full shmoo: sweep every tile / cluster shape candidate
mma_tiler_mn_candidates = [(256, 128), (128, 128), (128, 256),
(256, 256), (256, 64), (128, 64)]
cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4), (2, 1),
(2, 2), (2, 4), (4, 1), (4, 2),
(4, 4)]
return [
(mma_tiler_mn, cluster_shape_mn)
for mma_tiler_mn in mma_tiler_mn_candidates
for cluster_shape_mn in cluster_shape_mn_candidates
if Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
cutlass.Float4E2M1FN, # ab_dtype,
cutlass.Float8E4M3FN, # sf_dtype
sf_vec_size, # sf_vec_size,
cutlass.BFloat16, # c_dtype,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
real_k,
batch_size,
a_major,
b_major,
c_major,
)
]
def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
assumed_align: int):
return make_ptr(
dtype,
tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=assumed_align,
)
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> torch.Tensor:
"""
Performs fp8 blockwise gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (m, k), dtype: fp4.
inputs[1]: Weight tensor of shape (n, k), dtype: fp4.
inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
inputs[4]: Alpha scaling factor. dtype: float32.
inputs[5]: Output dtype, expected to be torch.bfloat16.
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
"""
sf_vec_size = 16
if isinstance(tactic, tuple):
mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
mma_tiler_mn, cluster_shape_mn = [
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
        c_tensor = torch.empty((m, n),
                               dtype=self.output_dtype,
                               device="cuda")
real_k = k * 2
sf_m = pad_up(m, 128)
sf_k = pad_up(real_k // sf_vec_size, 4)
sf_n = pad_up(n, 128)
        # The scale tensors are 1D; make sure they have been padded to the expected sizes
assert a_sf_tensor.shape == (sf_m * sf_k, )
assert b_sf_tensor.shape == (sf_n * sf_k, )
a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(
a_sf_tensor, cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(
b_sf_tensor, cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(c_tensor,
cutlass.BFloat16, 16)
# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
gemm_wrapper_func = Sm100BlockScaledPersistentDenseGemmKernelWrapper
CACHE_KEY = (
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
)
if CACHE_KEY not in CuteDSLNVFP4BlackwellLinear.kernel_dict:
gemm = gemm_wrapper_func(
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm,
m,
n,
real_k,
sf_m // 128,
sf_n // 128,
sf_k // 4,
1,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
self.alpha,
max_active_clusters,
stream,
)
CuteDSLNVFP4BlackwellLinear.kernel_dict[
CACHE_KEY] = compiled_gemm
else:
compiled_gemm = CuteDSLNVFP4BlackwellLinear.kernel_dict[
CACHE_KEY]
# launch gemm kernel
compiled_gemm(m, n, real_k, sf_m // 128, sf_n // 128, sf_k // 4,
a_ptr, b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, self.alpha,
stream)
return c_tensor
# a/b: fp4, scale: fp8, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_nvfp4_gemm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_nvfp4_gemm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
output_dtype: torch.dtype,
) -> torch.Tensor:
tuner = AutoTuner.get()
cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear(
alpha, output_dtype)
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_gemm_blackwell",
[cute_dsl_nvfp4_gemm_blackwell_runner],
CuteDSLNVFP4BlackwellLinear.tuning_config,
[input, weight, input_scale, weight_scale],
)
return cute_dsl_nvfp4_gemm_blackwell_runner(
inputs=[input, weight, input_scale, weight_scale],
tactic=best_tactic,
)
@torch.library.register_fake("trtllm::cute_dsl_nvfp4_gemm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
output_dtype: torch.dtype,
):
# [m, k]
shape = list(mat_a.shape)
# [n, k]
shape[-1] = mat_b.shape[-2]
# output is fixed as bf16
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
return ret
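
A hedged end-to-end sketch of calling the new op. The GEMM shape is taken from the test parametrization at the bottom of this commit; the dummy tensors and padding arithmetic are assumptions that merely satisfy the asserts in forward(). A real run needs an SM 100 GPU, the DSL installed, and properly quantized inputs (e.g. from torch.ops.trtllm.fp4_quantize):

    import torch

    # Logical GEMM shape; K is stored packed, two fp4 values per uint8 byte.
    m, n, k = 128, 7168, 16384
    a = torch.randint(0, 255, (m, k // 2), dtype=torch.uint8, device="cuda")
    b = torch.randint(0, 255, (n, k // 2), dtype=torch.uint8, device="cuda")

    # 1D block-scale tensors, padded as forward() expects:
    # sf_m = pad_up(m, 128) = 128, sf_k = pad_up(k // 16, 4) = 1024,
    # sf_n = pad_up(n, 128) = 7168.
    a_sf = torch.ones(128 * 1024, dtype=torch.float8_e4m3fn, device="cuda")
    b_sf = torch.ones(7168 * 1024, dtype=torch.float8_e4m3fn, device="cuda")

    out = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
        a, b, a_sf, b_sf, 1.0, torch.bfloat16)
    assert out.shape == (m, n) and out.dtype == torch.bfloat16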

View File

@@ -1,4 +1,3 @@
import sys
from functools import lru_cache
from typing import List, Mapping, Optional, Tuple
@@ -9,7 +8,6 @@ import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
from tensorrt_llm import deep_gemm
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.math_utils import pad_up
from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
OptimizationProfile, TunableRunner, TuningConfig)
@@ -19,27 +17,6 @@ from ..utils import (fp4_scale_infer_shape,
get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2)
try:
if sys.version_info >= (3, 12):
HAS_CUTLASS_DSL = True
import cutlass
import cutlass.cute as cute
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import (
Sm100BlockScaledPersistentDenseGemmKernel,
Sm100BlockScaledPersistentDenseGemmKernelWrapper)
from tensorrt_llm._torch.cute_dsl_kernels.blackwell.utils import \
make_ptr
else:
HAS_CUTLASS_DSL = False
except ImportError:
HAS_CUTLASS_DSL = False
try:
from cuda.bindings import driver as cuda
except ImportError:
from cuda import cuda
# Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous.
@torch.library.custom_op("trtllm::bmm_out", mutates_args=("out", ))
@@ -1059,241 +1036,6 @@ def _(
return x.new_empty((b, d), dtype=o_dtype)
class CuteDSLNVFP4BlackwellLinear(TunableRunner):
kernel_dict = dict()
tuning_config = TuningConfig(
dynamic_tensor_specs=(DynamicTensorSpec(
0, 0, get_last_power_of_2_num_tokens_buckets,
last_positive_power_of_2), ),
constraint_specs=(ConstraintSpec(2, 0, fp4_scale_infer_shape), ),
)
def __init__(self, alpha: float, output_dtype: torch.dtype):
super().__init__()
self.alpha = alpha
self.output_dtype = output_dtype
assert output_dtype == torch.bfloat16
if get_sm_version() != 100:
raise ValueError(
f"SM version {get_sm_version()} is not supported for CuteDSLNVFP4BlackwellLinear, it only supports SM 100"
)
def get_valid_tactics(
self,
inputs: List[torch.Tensor],
profile: OptimizationProfile,
**kwargs,
) -> List[Tuple[int, int]]:
assert inputs[0].dim() == 2
assert inputs[1].dim() == 2
m = inputs[0].shape[0]
n = inputs[1].shape[0]
k = inputs[0].shape[1]
        # Note: the input tensors pack two fp4 values per uint8 byte, so the logical K is k * 2
real_k = k * 2
batch_size = 1
# m,k
a_major = "k"
# n, k
b_major = "k"
# m, n
c_major = "n"
sf_vec_size = 16
        # full shmoo: sweep every tile / cluster shape candidate
mma_tiler_mn_candidates = [(256, 128), (128, 128), (128, 256),
(256, 256), (256, 64), (128, 64)]
cluster_shape_mn_candidates = [(1, 1), (1, 2), (1, 4), (2, 1), (2, 2),
(2, 4), (4, 1), (4, 2), (4, 4)]
return [
(mma_tiler_mn, cluster_shape_mn)
for mma_tiler_mn in mma_tiler_mn_candidates
for cluster_shape_mn in cluster_shape_mn_candidates
if Sm100BlockScaledPersistentDenseGemmKernel.can_implement(
cutlass.Float4E2M1FN, # ab_dtype,
cutlass.Float8E4M3FN, # sf_dtype
sf_vec_size, # sf_vec_size,
cutlass.BFloat16, # c_dtype,
mma_tiler_mn,
cluster_shape_mn,
m,
n,
real_k,
batch_size,
a_major,
b_major,
c_major,
)
]
def make_cute_dsl_global_pointer(self, tensor: torch.Tensor, dtype,
assumed_align: int):
return make_ptr(
dtype,
tensor.data_ptr(),
cute.AddressSpace.gmem,
assumed_align=assumed_align,
)
def forward(
self,
inputs: List[torch.Tensor],
tactic,
) -> torch.Tensor:
"""
Performs fp8 blockwise gemm operation using CuTe DSL.
Args:
inputs (List[torch.Tensor]):
inputs[0]: Input tensor of shape (m, k), dtype: fp4.
inputs[1]: Weight tensor of shape (n, k), dtype: fp4.
inputs[2]: Input scale tensor of shape (k//16, m), dtype: fp8.
inputs[3]: Weight scale tensor of shape (n, k//16), dtype: fp8.
inputs[4]: Alpha scaling factor. dtype: float32.
inputs[5]: Output dtype, expected to be torch.bfloat16.
tactic: Tiling and cluster strategy, typically a tuple (mma_tiler_mn, cluster_shape_mn).
Returns:
torch.Tensor: Output tensor of shape (m, n), dtype: bf16.
"""
sf_vec_size = 16
if isinstance(tactic, tuple):
mma_tiler_mn, cluster_shape_mn = tactic
else:
# fallback to default tactic
mma_tiler_mn, cluster_shape_mn = [
(128, 128),
(1, 1),
]
a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs
m, k, n = a_tensor.shape[0], a_tensor.shape[1], b_tensor.shape[0]
        c_tensor = torch.empty((m, n), dtype=self.output_dtype, device="cuda")
real_k = k * 2
sf_m = pad_up(m, 128)
sf_k = pad_up(real_k // sf_vec_size, 4)
sf_n = pad_up(n, 128)
        # The scale tensors are 1D; make sure they have been padded to the expected sizes
assert a_sf_tensor.shape == (sf_m * sf_k, )
assert b_sf_tensor.shape == (sf_n * sf_k, )
a_ptr = self.make_cute_dsl_global_pointer(a_tensor,
cutlass.Float4E2M1FN, 32)
b_ptr = self.make_cute_dsl_global_pointer(b_tensor,
cutlass.Float4E2M1FN, 32)
a_sf_ptr = self.make_cute_dsl_global_pointer(a_sf_tensor,
cutlass.Float8E4M3FN, 16)
b_sf_ptr = self.make_cute_dsl_global_pointer(b_sf_tensor,
cutlass.Float8E4M3FN, 16)
c_ptr = self.make_cute_dsl_global_pointer(c_tensor, cutlass.BFloat16,
16)
# get stream
torch_stream = torch.cuda.current_stream()
stream = cuda.CUstream(torch_stream.cuda_stream)
gemm_wrapper_func = Sm100BlockScaledPersistentDenseGemmKernelWrapper
CACHE_KEY = (
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
)
if CACHE_KEY not in CuteDSLNVFP4BlackwellLinear.kernel_dict:
gemm = gemm_wrapper_func(
sf_vec_size,
mma_tiler_mn,
cluster_shape_mn,
)
# Compute max active clusters on current device
hardware_info = cutlass.utils.HardwareInfo()
max_active_clusters = hardware_info.get_max_active_clusters(
cluster_shape_mn[0] * cluster_shape_mn[1])
compiled_gemm = cute.compile(
gemm,
m,
n,
real_k,
sf_m // 128,
sf_n // 128,
sf_k // 4,
1,
a_ptr,
b_ptr,
a_sf_ptr,
b_sf_ptr,
c_ptr,
self.alpha,
max_active_clusters,
stream,
)
CuteDSLNVFP4BlackwellLinear.kernel_dict[CACHE_KEY] = compiled_gemm
else:
compiled_gemm = CuteDSLNVFP4BlackwellLinear.kernel_dict[CACHE_KEY]
# launch gemm kernel
compiled_gemm(m, n, real_k, sf_m // 128, sf_n // 128, sf_k // 4, a_ptr,
b_ptr, a_sf_ptr, b_sf_ptr, c_ptr, self.alpha, stream)
return c_tensor
# a/b: fp4, scale: fp8, output: bf16
@torch.library.custom_op("trtllm::cute_dsl_nvfp4_gemm_blackwell",
mutates_args=(),
device_types="cuda")
def cute_dsl_nvfp4_gemm_blackwell(
input: torch.Tensor,
weight: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
output_dtype: torch.dtype,
) -> torch.Tensor:
if not HAS_CUTLASS_DSL:
raise RuntimeError("nvidia-cutlass-dsl 4.1.0 requires Python >=3.12")
tuner = AutoTuner.get()
cute_dsl_nvfp4_gemm_blackwell_runner = CuteDSLNVFP4BlackwellLinear(
alpha, output_dtype)
_, best_tactic = tuner.choose_one(
"trtllm::cute_dsl_nvfp4_gemm_blackwell",
[cute_dsl_nvfp4_gemm_blackwell_runner],
CuteDSLNVFP4BlackwellLinear.tuning_config,
[input, weight, input_scale, weight_scale],
)
return cute_dsl_nvfp4_gemm_blackwell_runner(
inputs=[input, weight, input_scale, weight_scale],
tactic=best_tactic,
)
@torch.library.register_fake("trtllm::cute_dsl_nvfp4_gemm_blackwell")
def _(
mat_a: torch.Tensor,
mat_b: torch.Tensor,
input_scale: torch.Tensor,
weight_scale: torch.Tensor,
alpha: float,
output_dtype: torch.dtype,
):
# [m, k]
shape = list(mat_a.shape)
# [n, k]
shape[-1] = mat_b.shape[-2]
# output is fixed as bf16
ret = mat_a.new_empty(shape, dtype=torch.bfloat16)
return ret
def get_event(event_idx: int):
from ..utils import get_model_extra_attrs
extra_attrs = get_model_extra_attrs()

View File

@@ -0,0 +1,18 @@
import platform
import traceback
from ..logger import logger
IS_CUTLASS_DSL_AVAILABLE = False
if platform.system() != "Windows":
try:
import cutlass # noqa
import cutlass.cute as cute # noqa
logger.info(f"cutlass dsl is available")
IS_CUTLASS_DSL_AVAILABLE = True
except ImportError:
traceback.print_exc()
        logger.warning(
            "cutlass dsl is not installed properly; try `pip install nvidia-cutlass-dsl`"
        )

View File

@@ -23,6 +23,7 @@ from tensorrt_llm.quantization.mode import QuantAlgo
from ..._utils import is_sm_100f
from ...models.modeling_utils import QuantConfig
from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from ..utils import Fp4QuantizedTensor
@@ -766,7 +767,7 @@ class NVFP4LinearMethod(LinearMethodBase):
act_fp4, act_sf = torch.ops.trtllm.fp4_quantize(
input, module.input_scale, module.scaling_vector_size, False)
if module.use_cute_dsl_nvfp4_blockscaling_mm:
if IS_CUTLASS_DSL_AVAILABLE and module.use_cute_dsl_nvfp4_blockscaling_mm:
output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
act_fp4, module.weight, act_sf, module.weight_scale,
module.scalar_alpha, module.dtype)
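
The added guard means the DSL path runs only when the package is importable; otherwise execution presumably falls through to the pre-existing NVFP4 GEMM branch, which this hunk truncates. A hypothetical control-flow sketch, where default_nvfp4_gemm is a placeholder for that unseen branch:

    if IS_CUTLASS_DSL_AVAILABLE and module.use_cute_dsl_nvfp4_blockscaling_mm:
        output = torch.ops.trtllm.cute_dsl_nvfp4_gemm_blackwell(
            act_fp4, module.weight, act_sf, module.weight_scale,
            module.scalar_alpha, module.dtype)
    else:
        # Placeholder: the real fallback op lives outside this hunk.
        output = default_nvfp4_gemm(act_fp4, module.weight, act_sf,
                                    module.weight_scale, module.scalar_alpha,
                                    module.dtype)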

View File

@@ -6,7 +6,9 @@ from utils.util import skip_pre_blackwell
import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
from tensorrt_llm._torch.autotuner import autotune
from tensorrt_llm._torch.cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE
from tensorrt_llm._torch.modules.linear import Linear
from tensorrt_llm._utils import get_sm_version
from tensorrt_llm.models.modeling_utils import QuantAlgo, QuantConfig
scaling_vector_size = 16
@@ -86,7 +88,12 @@ def pad_up(x, pad_size):
@pytest.mark.skipif(sys.version_info < (3, 12),
reason="cutlass-dsl 4.1.0 requires Python 3.12+")
@skip_pre_blackwell
@pytest.mark.skipif(
get_sm_version() != 100,
reason="This test is only supported in Blackwell architecture",
)
@pytest.mark.skipif(not IS_CUTLASS_DSL_AVAILABLE,
reason="cutlass-dsl is not available")
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mnk", [(128, 7168, 16384), (128, 24576, 1536),
(128, 2112, 7168), (128, 4096, 7168),