From 5521c7b7e72eee8569356012833c2c6568ab4e5a Mon Sep 17 00:00:00 2001 From: yifeizhang-c <219273404+yifeizhang-c@users.noreply.github.com> Date: Fri, 6 Feb 2026 09:49:30 +0800 Subject: [PATCH] [TRTLLM-9457][feat] Add cute dsl fp8 gemm for Blackwell (#10130) Added FP8 cute dsl gemm and batch gemm. Signed-off-by: Yifei Zhang <219273404+yifeizhang-c@users.noreply.github.com> --- cpp/tensorrt_llm/thop/fp8Quantize.cpp | 12 +- tensorrt_llm/_torch/compilation/utils.py | 7 +- .../_torch/custom_ops/cute_dsl_custom_ops.py | 640 +++- .../_torch/custom_ops/torch_custom_ops.py | 6 +- .../blackwell/blockwise_gemm/__init__.py | 0 .../blockwise_gemm/blockwise_gemm.py | 2565 +++++++++++++++++ tensorrt_llm/_torch/model_config.py | 4 + .../_torch/models/modeling_deepseekv3.py | 34 +- tensorrt_llm/_torch/modules/attention.py | 51 +- tensorrt_llm/_torch/modules/linear.py | 3 +- .../_torch/pyexecutor/model_loader.py | 7 +- tensorrt_llm/_torch/utils.py | 10 + tensorrt_llm/llmapi/llm_args.py | 12 + .../defs/accuracy/test_llm_api_pytorch.py | 4 + .../parallel/test_fp8_block_scale_gemm.py | 113 +- .../api_stability/references/llm.yaml | 8 + 16 files changed, 3439 insertions(+), 37 deletions(-) create mode 100644 tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py create mode 100644 tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py diff --git a/cpp/tensorrt_llm/thop/fp8Quantize.cpp b/cpp/tensorrt_llm/thop/fp8Quantize.cpp index 91746a321b..9e94e02950 100644 --- a/cpp/tensorrt_llm/thop/fp8Quantize.cpp +++ b/cpp/tensorrt_llm/thop/fp8Quantize.cpp @@ -121,9 +121,8 @@ std::tuple fp8_batched_quantize_1x128_permute102(at::Ten int64_t scaleSizeInBytes = mGemmRunner.getActScaleSize(m, b * n); int64_t elementSize = scaleSizeInBytes / torch::elementSize(FP8_BLOCK_SCALING_SF_DTYPE); - int m_4_align = (m + 3) / 4 * 4; - at::Tensor scaleFP8SF = at::detail::empty_cuda({b, m_4_align, elementSize / b / m_4_align}, - FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); + at::Tensor scaleFP8SF = at::detail::empty_cuda( + {elementSize}, FP8_BLOCK_SCALING_SF_DTYPE, self.device(), /* stride */ std::nullopt); // 1D tensor __nv_fp8_e4m3* act_buffer = reinterpret_cast<__nv_fp8_e4m3*>(valueE4M3.data_ptr()); float* act_scale_buffer = reinterpret_cast(scaleFP8SF.data_ptr()); @@ -133,6 +132,13 @@ std::tuple fp8_batched_quantize_1x128_permute102(at::Ten auto* output_buffer = reinterpret_cast<__nv_bfloat16 const*>(self.data_ptr()); mGemmRunner.fp8CS1x128Reshape(act_buffer, act_scale_buffer, output_buffer, n, b, m, lda, stream); + // scaleFP8SF = scaleFP8SF[:, 0:num_n_blocks, 0:m_padded] + auto const num_n_blocks = (n + 127) / 128; + auto const act_scal_elesize = b * num_n_blocks * m_padded; + TORCH_CHECK(act_scal_elesize <= scaleFP8SF.numel(), "Scale tensor size mismatch. 
Expected at least ", + act_scal_elesize, " elements, got ", scaleFP8SF.numel()); + scaleFP8SF = scaleFP8SF.slice(0, 0, act_scal_elesize).view({b, num_n_blocks, m_padded}).contiguous(); + return {valueE4M3.slice(0, 0, b * m * n).view({b, m, n}), scaleFP8SF}; } } // namespace torch_ext diff --git a/tensorrt_llm/_torch/compilation/utils.py b/tensorrt_llm/_torch/compilation/utils.py index d5aa808f6d..2f8b7b19de 100644 --- a/tensorrt_llm/_torch/compilation/utils.py +++ b/tensorrt_llm/_torch/compilation/utils.py @@ -81,14 +81,14 @@ def inplace_info(): 1: "logits" }, torch.ops.trtllm.moe_unpermute_inplace.default: { - 2: "output" + 1: "output" }, torch.ops.trtllm.moe_output_memset_inplace.default: { 1: "input" }, torch.ops.trtllm.cute_dsl_nvfp4_grouped_gemm_finalize_inplace_blackwell.default: { - 6: "output" + 1: "output" }, torch.ops.trtllm.pp_recv_tensors.default: { 1: "tensors" @@ -96,6 +96,9 @@ def inplace_info(): torch.ops.trtllm.pp_send_tensors.default: { 1: "tensors" }, + torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell.default: { + 1: "output" + } } if IS_CUDA_TILE_AVAILABLE: # cuda.tile availability depends on GPU capability thus runtime check. diff --git a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py index a75b9aeddf..a7504a8b85 100644 --- a/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/cute_dsl_custom_ops.py @@ -6,13 +6,13 @@ import torch from tensorrt_llm.logger import logger -from ..._utils import get_sm_version +from ..._utils import get_sm_version, is_sm_100f from ...math_utils import ceil_div, pad_up from ..autotuner import (AutoTuner, ConstraintSpec, DistributedTuningStrategy, DynamicTensorSpec, OptimizationProfile, TunableRunner, TuningConfig) from ..cute_dsl_utils import IS_CUTLASS_DSL_AVAILABLE -from ..utils import (fp4_scale_infer_shape, +from ..utils import (fp4_scale_infer_shape, fp8_scale_infer_shape, get_last_power_of_2_num_tokens_buckets, last_positive_power_of_2) @@ -314,11 +314,13 @@ if IS_CUTLASS_DSL_AVAILABLE: Sm100BlockScaledContiguousGroupedGemmFinalizeFusionKernel from ..cute_dsl_kernels.blackwell.blockscaled_contiguous_grouped_gemm_swiglu_fusion import \ Sm100BlockScaledContiguousGroupedGemmSwigluFusionKernel + from ..cute_dsl_kernels.blackwell.blockwise_gemm.blockwise_gemm import \ + Sm100BlockwiseGemmKernel from ..cute_dsl_kernels.blackwell.dense_blockscaled_gemm_persistent import \ Sm100BlockScaledPersistentDenseGemmKernel from ..cute_dsl_kernels.blackwell.utils import make_ptr - class CuteDSLNVFP4BlackwellLinear(TunableRunner): + class CuteDSLNVFP4BlackwellRunner(TunableRunner): kernel_class = Sm100BlockScaledPersistentDenseGemmKernel kernel_cache = dict() tuning_config = TuningConfig( @@ -500,7 +502,7 @@ if IS_CUTLASS_DSL_AVAILABLE: **kwargs, ) -> torch.Tensor: """ - Performs fp8 blockwise gemm operation using CuTe DSL. + Performs fp4 blockwise gemm operation using CuTe DSL. 
Args: inputs (List[torch.Tensor]): @@ -590,7 +592,7 @@ if IS_CUTLASS_DSL_AVAILABLE: stream = cuda.CUstream(torch_stream.cuda_stream) cache_key = (sf_vec_size, mma_tiler_mn, cluster_shape_mn, swap_ab, - use_prefetch) + use_prefetch, self.use_tvm_ffi) if swap_ab: kernel_m = n kernel_n = m @@ -770,7 +772,7 @@ if IS_CUTLASS_DSL_AVAILABLE: tuner = AutoTuner.get() - runner = CuteDSLNVFP4BlackwellLinear(output_dtype, to_userbuffers, + runner = CuteDSLNVFP4BlackwellRunner(output_dtype, to_userbuffers, use_tvm_ffi) inputs = [input, weight, input_scale, weight_scale, alpha] _, best_tactic = tuner.choose_one( @@ -2161,3 +2163,629 @@ if IS_CUTLASS_DSL_AVAILABLE: dtype=input_scale.dtype, device=input_scale.device) return output, output_scale + + class CuteDSLFp8BlackwellRunner(TunableRunner): + kernel_class = Sm100BlockwiseGemmKernel + kernel_cache = dict() + + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 0, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(2, 1, fp8_scale_infer_shape), ), + ) + + def __init__(self, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True): + super().__init__() + if output_dtype != torch.bfloat16: + raise ValueError( + f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}" + ) + self.output_dtype = output_dtype + self.use_tvm_ffi = use_tvm_ffi + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics." + ) + return [] + + m = inputs[0].shape[0] + n = inputs[1].shape[0] + k = inputs[0].shape[1] + batch_size = 1 + # m,k + a_major = "k" + # n, k + b_major = "k" + # m, n + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if self.__class__.kernel_class.can_implement( + cutlass.Float8E4M3FN, # ab_dtype, + cutlass.Float32, # acc_dtype, + cutlass.BFloat16, # c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> torch.Tensor: + """ + Performs fp8 blockwise (deepgemm like) operation using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (m, k), dtype: fp8. + inputs[1]: Weight tensor of shape (n, k), dtype: fp8. + inputs[2]: Input scale factor tensor of shape (k // 128, m), dtype: fp32. + inputs[3]: Weight scale factor tensor of shape (n // 128, k // 128), dtype: fp32. + tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + + Returns: + torch.Tensor: Output tensor of shape (m, n), dtype: bf16. 
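+
+ Example (illustrative sketch only; requires an SM100-class GPU with the
+ CUTLASS DSL available, and uses unit scale factors as placeholders for
+ real per-block quantization scales):
+ >>> runner = CuteDSLFp8BlackwellRunner()
+ >>> m, n, k = 128, 256, 512
+ >>> a = torch.randn(m, k, device="cuda").to(torch.float8_e4m3fn)
+ >>> b = torch.randn(n, k, device="cuda").to(torch.float8_e4m3fn)
+ >>> a_sf = torch.ones(k // 128, m, device="cuda", dtype=torch.float32)
+ >>> b_sf = torch.ones(n // 128, k // 128, device="cuda", dtype=torch.float32)
+ >>> out = runner([a, b, a_sf, b_sf], tactic=(False, (128, 128), (1, 1)))
+ >>> out.shape
+ torch.Size([128, 256])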
+ """ + if isinstance(tactic, tuple): + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic + else: + # fallback to default tactic + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [ + False, + (128, 128), + (1, 1), + ] + a_tensor, b_tensor, a_sf_tensor, b_sf_tensor = inputs + m, n, k = a_tensor.shape[0], b_tensor.shape[0], b_tensor.shape[1] + sf_m = m + sf_k = ceil_div(k, 128) + sf_n = ceil_div(n, 128) + c_tensor = torch.empty(*(m, n), + dtype=torch.bfloat16, + device=a_tensor.device) + c_tmp = c_tensor.view((1, m, n)) + c_tmp = c_tmp.permute(1, 2, 0) + + if not self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.Float8E4M3FN, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.Float8E4M3FN, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + a_sf_ptr = make_ptr( + cutlass.Float32, + a_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_sf_ptr = make_ptr( + cutlass.Float32, + b_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + + # get stream + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + cache_key = ( + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + self.use_tvm_ffi, + ) + if cache_key not in self.__class__.kernel_cache: + if self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.Float8E4M3FN, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.Float8E4M3FN, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + a_sf_ptr = make_ptr( + cutlass.Float32, + a_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_sf_ptr = make_ptr( + cutlass.Float32, + b_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + # Convert c_tensor to cute tensor for TVM FFI for env stream detection + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + stream = cute.runtime.make_fake_stream( + use_tvm_ffi_env_stream=True) + + gemm = self.__class__.kernel_class( + cutlass.Float32, # acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + + compiled_gemm = cute.compile( + gemm.wrapper, + m, + n, + k, + sf_m, + sf_n, + sf_k, + 1, # batch + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_cute_tensor, + max_active_clusters=max_active_clusters, + stream=stream, + options=f"--opt-level 2 --enable-tvm-ffi" + if self.use_tvm_ffi else "--opt-level 2", + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + # launch gemm kernel + if self.use_tvm_ffi: + # call with torch pointer types and no need to pass stream. + compiled_gemm( + m, + n, + k, + sf_m, + sf_n, + sf_k, + 1, # batch + a_tensor.data_ptr(), + b_tensor.data_ptr(), + a_sf_tensor.data_ptr(), + b_sf_tensor.data_ptr(), + c_tmp, + ) + else: + # call with cute types and need to pass torch stream. 
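+ # In this branch the compiled kernel receives the cute gmem pointers
+ # built above for A/B and their scale factors plus the cute output
+ # tensor, and the launch runs asynchronously on the torch CUDA stream
+ # captured earlier in this method.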
+ compiled_gemm( + m, + n, + k, + sf_m, + sf_n, + sf_k, + 1, # batch + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_cute_tensor, + stream=stream, + ) + return c_tensor + + # a/b: fp8, scale: fp32, output: bf16 + @torch.library.custom_op("trtllm::cute_dsl_fp8_gemm_blackwell", + mutates_args=(), + device_types="cuda") + def cute_dsl_fp8_gemm_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True, + ) -> torch.Tensor: + if output_dtype != torch.bfloat16: + raise ValueError( + f"CuteDSL FP8 GEMM only supports bfloat16 output, got {output_dtype}" + ) + if not is_sm_100f(): + raise ValueError( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL FP8 GEMM only supports SM 100 family. Skipping all tactics." + ) + tuner = AutoTuner.get() + + runner = CuteDSLFp8BlackwellRunner(output_dtype=output_dtype, + use_tvm_ffi=use_tvm_ffi) + + inputs = [input, weight, input_scale, weight_scale] + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_fp8_gemm_blackwell::gemm", + [runner], + runner.__class__.tuning_config, + inputs, + ) + return runner(inputs, tactic=best_tactic) + + @torch.library.register_fake("trtllm::cute_dsl_fp8_gemm_blackwell") + def _( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True, + ): + # [m, k] + shape = list(mat_a.shape) + # [n, k] + shape[-1] = mat_b.shape[-2] + # output is fixed as bf16 + ret = mat_a.new_empty(shape, dtype=torch.bfloat16) + return ret + + class CuteDSLFp8BlackwellBmmRunner(TunableRunner): + kernel_class = Sm100BlockwiseGemmKernel + kernel_cache = dict() + + tuning_config = TuningConfig( + dynamic_tensor_specs=(DynamicTensorSpec( + 0, 1, get_last_power_of_2_num_tokens_buckets, + last_positive_power_of_2), ), + constraint_specs=(ConstraintSpec(2, 2, fp8_scale_infer_shape), ), + ) + + def __init__(self, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True): + super().__init__() + if output_dtype != torch.bfloat16: + raise ValueError( + f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}" + ) + self.output_dtype = output_dtype + self.use_tvm_ffi = use_tvm_ffi + + def get_valid_tactics( + self, + inputs: List[torch.Tensor], + profile: OptimizationProfile, + **kwargs, + ) -> List[int]: + + if not is_sm_100f(): + logger.debug( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics." 
+ ) + return [] + # [b, m, k] + batch_size, m, k = inputs[0].shape[0], inputs[0].shape[1], inputs[ + 0].shape[2] + # [b, n, k] + n = inputs[1].shape[1] + # m,k + a_major = "k" + # n, k + b_major = "k" + # m, n + c_major = "n" + + use_2cta_instrs_candi = [False, True] + mma_tiler_mn_candi = [(64, 128), (128, 128), (256, 128)] + cluster_shape_mn_candi = [ + (1, 1), + (1, 2), + (1, 4), + (2, 1), + (2, 2), + (2, 4), + (4, 1), + (4, 2), + (4, 4), + ] + return [ + (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn) + for use_2cta_instrs in use_2cta_instrs_candi + for mma_tiler_mn in mma_tiler_mn_candi + for cluster_shape_mn in cluster_shape_mn_candi + if self.__class__.kernel_class.can_implement( + cutlass.Float8E4M3FN, # ab_dtype, + cutlass.Float32, # acc_dtype, + cutlass.BFloat16, # c_dtype, + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + batch_size, + a_major, + b_major, + c_major, + ) + ] + + def forward( + self, + inputs: List[torch.Tensor], + tactic, + ) -> None: + """ + Performs fp8 blockwise (deepgemm like) batched gemm operation using CuTe DSL. + + Args: + inputs (List[torch.Tensor]): + inputs[0]: Input tensor of shape (batch_size, m, k), dtype: fp8. + inputs[1]: Weight tensor of shape (batch_size, n, k), dtype: fp8. + inputs[2]: Input scale tensor of shape (batch_size, k // 128, pad_up(m, 4)), dtype: fp32. + inputs[3]: Weight scale tensor of shape (batch_size, n // 128, k // 128), dtype: fp32. + tactic: Tiling and cluster strategy, typically a tuple (use_2cta_instrs, mma_tiler_mn, cluster_shape_mn). + + Returns: + torch.Tensor: Output tensor of shape (batch_size, m, n), dtype: bf16. + """ + if isinstance(tactic, tuple): + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = tactic + else: + # fallback to default tactic + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn = [ + False, + (128, 128), + (1, 1), + ] + + a_tensor, b_tensor, a_sf_tensor, b_sf_tensor, c_tensor = inputs + c_tmp = c_tensor.permute(1, 2, 0) + + batch_size = a_tensor.shape[0] + m = a_tensor.shape[1] + k = a_tensor.shape[2] + n = b_tensor.shape[1] + sf_m = pad_up(m, 4) + sf_k = ceil_div(k, 128) + sf_n = ceil_div(n, 128) + + if not self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.Float8E4M3FN, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.Float8E4M3FN, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + a_sf_ptr = make_ptr( + cutlass.Float32, + a_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_sf_ptr = make_ptr( + cutlass.Float32, + b_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + c_cute_tensor = cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + + # get stream + stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + cache_key = ( + use_2cta_instrs, + mma_tiler_mn, + cluster_shape_mn, + self.use_tvm_ffi, + ) + if cache_key not in self.__class__.kernel_cache: + if self.use_tvm_ffi: + a_ptr = make_ptr( + cutlass.Float8E4M3FN, + a_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_ptr = make_ptr( + cutlass.Float8E4M3FN, + b_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + a_sf_ptr = make_ptr( + cutlass.Float32, + a_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + b_sf_ptr = make_ptr( + cutlass.Float32, + b_sf_tensor.data_ptr(), + cute.AddressSpace.gmem, + assumed_align=16, + ) + # Convert c_tensor to cute tensor for TVM FFI for env stream detection) + c_cute_tensor = 
cute.runtime.from_dlpack( + c_tmp).mark_layout_dynamic(leading_dim=1) + # make faked stream for TVM FFI + stream = cute.runtime.make_fake_stream( + use_tvm_ffi_env_stream=True) + + gemm = self.__class__.kernel_class( + cutlass.Float32, # acc_dtype, + use_2cta_instrs=use_2cta_instrs, + mma_tiler_mn=mma_tiler_mn, + cluster_shape_mn=cluster_shape_mn, + ) + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1]) + + compiled_gemm = cute.compile( + gemm.wrapper, + m, + n, + k, + sf_m, + sf_n, + sf_k, + batch_size, + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_cute_tensor, + max_active_clusters=max_active_clusters, + stream=stream, + options=f"--opt-level 2 --enable-tvm-ffi" + if self.use_tvm_ffi else "--opt-level 2", + ) + self.__class__.kernel_cache[cache_key] = compiled_gemm + else: + compiled_gemm = self.__class__.kernel_cache[cache_key] + + # launch gemm kernel + if self.use_tvm_ffi: + # call with torch pointer types and no need to pass stream. + compiled_gemm( + m, + n, + k, + sf_m, + sf_n, + sf_k, + batch_size, + a_tensor.data_ptr(), + b_tensor.data_ptr(), + a_sf_tensor.data_ptr(), + b_sf_tensor.data_ptr(), + c_tmp, + ) + else: + # call with cute types and need to pass torch stream. + compiled_gemm( + m, + n, + k, + sf_m, + sf_n, + sf_k, + batch_size, + a_ptr, + b_ptr, + a_sf_ptr, + b_sf_ptr, + c_cute_tensor, + stream=stream, + ) + + # a/b: fp8, scale: fp32, output: bf16 + @torch.library.custom_op("trtllm::cute_dsl_fp8_bmm_blackwell", + mutates_args=("output", ), + device_types="cuda") + def cute_dsl_fp8_bmm_blackwell( + input: torch.Tensor, + weight: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True, + ) -> None: + if output_dtype != torch.bfloat16: + raise ValueError( + f"CuteDSL FP8 BMM only supports bfloat16 output, got {output_dtype}" + ) + if not is_sm_100f(): + raise ValueError( + f"CuteDSL: SM version {get_sm_version()} is not supported. " + f"CuteDSL FP8 BMM only supports SM 100 family. Skipping all tactics." 
+ ) + + tuner = AutoTuner.get() + + runner = CuteDSLFp8BlackwellBmmRunner(output_dtype=output_dtype, + use_tvm_ffi=use_tvm_ffi) + + inputs = [input, weight, input_scale, weight_scale, output] + + _, best_tactic = tuner.choose_one( + "trtllm::cute_dsl_fp8_bmm_blackwell::gemm", + [runner], + runner.__class__.tuning_config, + inputs, + ) + runner(inputs, tactic=best_tactic) + + @torch.library.register_fake("trtllm::cute_dsl_fp8_bmm_blackwell") + def _( + mat_a: torch.Tensor, + mat_b: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + output: torch.Tensor, + output_dtype: torch.dtype = torch.bfloat16, + use_tvm_ffi: bool = True, + ) -> None: + batch_size, m, k = mat_a.shape[0], mat_a.shape[1], mat_a.shape[2] + n = mat_b.shape[1] + assert output.dtype == torch.bfloat16, "CuTe DSL fp8 bmm output dtype must be bf16" + assert output.shape == (batch_size, m, + n), "CuTe DSL fp8 bmm output shape is incorrect" diff --git a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py index 4804e097dd..a8201cf714 100644 --- a/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py +++ b/tensorrt_llm/_torch/custom_ops/torch_custom_ops.py @@ -25,7 +25,7 @@ from ..utils import (ActivationType, fp4_scale_infer_shape, if IS_CUTLASS_DSL_AVAILABLE: from tensorrt_llm._torch.custom_ops.cute_dsl_custom_ops import \ - CuteDSLNVFP4BlackwellLinear + CuteDSLNVFP4BlackwellRunner # Used to WAR an issue in torch.bmm that it would break the graph when the out is not contiguous. @@ -819,7 +819,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner): "Please add other backends to allowed_backends.") else: # SM version OK, check if CuteDSL supports the current shape - cutedsl_runner = CuteDSLNVFP4BlackwellLinear( + cutedsl_runner = CuteDSLNVFP4BlackwellRunner( self.output_dtype) cutedsl_tactics = cutedsl_runner.get_valid_tactics( inputs, profile) @@ -878,7 +878,7 @@ class NVFP4GemmUnifiedRunner(TunableRunner): self.output_dtype)(inputs, tactic=sub_tactic) elif backend == "cutedsl": - return CuteDSLNVFP4BlackwellLinear( + return CuteDSLNVFP4BlackwellRunner( self.output_dtype, self.to_userbuffers)(inputs, tactic=sub_tactic) else: diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py new file mode 100644 index 0000000000..43b64cbeb2 --- /dev/null +++ b/tensorrt_llm/_torch/cute_dsl_kernels/blackwell/blockwise_gemm/blockwise_gemm.py @@ -0,0 +1,2565 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This file is copied and modified from cutlass example https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/blockwise_gemm/blockwise_gemm.py + +import math +from typing import Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.nvgpu import cpasync, tcgen05 +from cutlass.pipeline import pipeline_init_arrive, pipeline_init_wait + +""" +High-performance persistent blockwise dense GEMM (C = (SFA * A) * (SFB * B)) example for the NVIDIA Blackwell +architecture using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") +- Matrix B is NxKxL, L is batch dimension, B can be column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") +- Each block will apply the scale factor A +- Each row will apply the scale factor B +- For each iteration, the kernel will compute C = A * B and then apply the scale factor C *= SFA * SFB + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. SCALE warp: Load scaleA and scaleB matrices from global memory (GMEM) to shared memory (SMEM) using non-TMA + operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. 
EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Apply the scale factor and update the final accumulator Final = C * SFA * SFB + Final + - Type convert Final matrix to output type. + - Store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. + +.. code-block:: bash + + python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/blockwise_gemm/blockwise_gemm.py \ + --ab_dtype Float8E4M3FN --c_dtype BFloat16 --acc_dtype Float32 \ + --scale_dtype Float32 \ + --mma_tiler_mn 128,128 --cluster_shape_mn 1,2 \ + --mnkl 4096,4096,4096,4 + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp8 (e4m3fn) + see detailed valid dtype combinations in below Sm100BlockwiseGemmKernel class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128/256 +* Mma tiler N must be 128, align with the scaleB requirement +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned +""" + + +class Sm100BlockwiseGemmKernel: + """This class implements batched matrix multiplication (C = (SFA * A) * (SFB * B)) with support for fp8 (e4m3fn, + e5m2) and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use CTA group 2 for advanced thread cooperation + :type use_2cta_instrs: bool + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: Supported A/B data types: + - Float8E4M3FN + + :note: Supported accumulator data types: + - Float32 + + :note: Supported C data types: + - Float16/BFloat16 + - Other data types are not supported for accuracy issues + + :note: Constraints: + - MMA tiler M must be 64/128/256 + - MMA tiler N must be 128 + - Cluster shape M must be multiple of 2 + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = Sm100BlockwiseGemmKernel( + ... acc_dtype=cutlass.Float32, + ... use_2cta_instrs=True, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2), + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, sfa_tensor, sfb_tensor, max_active_clusters, stream) + """ + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ): + """Initializes the configuration for a Blackwell blockwise dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. 
+ - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param use_2cta_instrs: Boolean, True to use cta_group=2 MMA variant. + :type use_2cta_instrs: bool + :param cluster_shape_mn: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mn: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = use_2cta_instrs + self.cluster_shape_mn = cluster_shape_mn + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + + self.cta_group = tcgen05.CtaGroup.TWO if use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.occupancy = 1 + # Set specialized warp ids + self.acc_update_warp_id = (0, 1, 2, 3) + self.epilog_warp_id = (4, 5, 6, 7) + self.mma_warp_id = 8 + self.tma_warp_id = 9 + self.scale_warp_id = 10 + self.sched_warp_id = 11 + self.threads_per_warp = 32 + self.threads_per_cta = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + self.sched_warp_id, + ) + ) + self.threads_wo_sched = self.threads_per_warp * len( + ( + *self.acc_update_warp_id, + *self.epilog_warp_id, + self.mma_warp_id, + self.tma_warp_id, + self.scale_warp_id, + ) + ) + self.num_regs_uniform_warps = 64 + self.num_regs_sched_warps = 64 + self.num_regs_epilogue_warps = 216 + self.num_regs_acc_update_warps = 216 + + # Set barrier for epilogue sync and tmem ptr sync + self.epilog_sync_barrier = pipeline.NamedBarrier( + barrier_id=1, + num_threads=32 * len(self.epilog_warp_id), + ) + self.tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=2, + num_threads=32 + * len((self.mma_warp_id, *self.epilog_warp_id, *self.acc_update_warp_id)), + ) + self.sched_sync_barrier = pipeline.NamedBarrier( + barrier_id=3, + num_threads=self.threads_per_warp, + ) + self.num_smem_capacity = utils.get_smem_capacity_in_bytes("sm_100") + # TMEM offset for final accumulator + self.tmem_final_offset = 384 + + def _setup_attributes(self): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Configure tiled mma + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_shape_k = cute.size(tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + 
cute.make_layout((*self.cluster_shape_mn, 1)), + (tiled_mma.thr_id.shape,), + ) + + self.scale_granularity_m = 1 + self.scale_granularity_n = 128 + self.scale_granularity_k = 128 + self.scale_m_per_tile = self.cta_tile_shape_mnk[0] // self.scale_granularity_m + self.scale_n_per_tile = self.cta_tile_shape_mnk[1] // self.scale_granularity_n + self.scale_k_per_tile = self.cta_tile_shape_mnk[2] // self.scale_granularity_k + + if self.scale_k_per_tile != 1: + raise ValueError("scale_k_per_tile must be 1") + if self.scale_m_per_tile != self.cta_tile_shape_mnk[0]: + raise ValueError("scale_m_per_tile must be cta_tile_m") + if self.scale_n_per_tile != 1: + raise ValueError("scale_n_per_tile must be 1") + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.c_layout, + self.c_dtype, + ) + + # Setup A/B/C/Scale stage count in shared memory and ACC stage count in tensor memory + ( + self.num_acc_stage, + self.num_ab_stage, + self.num_c_stage, + self.num_scale_stage, + self.num_tile_stage, + ) = self._compute_stages( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.b_dtype, + self.epi_tile, + self.c_dtype, + self.c_layout, + self.sfa_dtype, + self.sfb_dtype, + self.scale_m_per_tile * self.scale_k_per_tile, + self.scale_n_per_tile * self.scale_k_per_tile, + self.num_smem_capacity, + self.occupancy, + ) + + # Compute A/B/C/Scale shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + tiled_mma, + self.mma_tiler, + self.a_dtype, + self.num_ab_stage, + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + tiled_mma, + self.mma_tiler, + self.b_dtype, + self.num_ab_stage, + ) + self.c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, + self.c_layout, + self.epi_tile, + self.num_c_stage, + ) + self.sfa_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_m_per_tile, + ), + ) + self.sfb_smem_layout_staged = cute.make_layout( + ( + (self.scale_granularity_n, self.scale_n_per_tile), + (self.scale_granularity_k, self.scale_k_per_tile), + self.num_scale_stage, + ), + stride=( + (0, self.scale_k_per_tile), + (0, 1), + self.scale_k_per_tile * self.scale_n_per_tile, + ), + ) + + # Compute the number of tensor memory allocation columns + self.num_tmem_alloc_cols = 512 + + @cute.jit + def __call__( + self, + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + epilogue_op: cutlass.Constexpr = lambda x: x, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param a: Input tensor A + :type a: cute.Tensor + :param b: Input tensor B + :type b: cute.Tensor + :param c: Output tensor C + :type c: cute.Tensor + :param sfa: Scale factor tensor A + 
:type sfa: cute.Tensor + :param sfb: Scale factor tensor B + :type sfb: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :param epilogue_op: Optional elementwise lambda function to apply to the output tensor + :type epilogue_op: cutlass.Constexpr + :raises TypeError: If input data types are incompatible with the MMA instruction. + """ + # Setup static attributes before smem/grid/tma computation + self.a_dtype: Type[cutlass.Numeric] = a.element_type + self.b_dtype: Type[cutlass.Numeric] = b.element_type + self.c_dtype: Type[cutlass.Numeric] = c.element_type + self.sfa_dtype: Type[cutlass.Numeric] = sfa.element_type + self.sfb_dtype: Type[cutlass.Numeric] = sfb.element_type + self.a_major_mode = utils.LayoutEnum.from_tensor(a).mma_major_mode() + self.b_major_mode = utils.LayoutEnum.from_tensor(b).mma_major_mode() + self.c_layout = utils.LayoutEnum.from_tensor(c) + + # Check if input data types are compatible with MMA instruction + if cutlass.const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + # Setup attributes that dependent on gemm inputs + self._setup_attributes() + + tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + atom_thr_size = cute.size(tiled_mma.thr_id.shape) + + # Setup TMA load for A + a_op = self._get_tma_atom_kind(atom_thr_size, self.is_a_mcast) + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + a, + a_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if a.element_type is cutlass.Float32 else None), + ) + + # Setup TMA load for B + b_op = self._get_tma_atom_kind(atom_thr_size, self.is_b_mcast) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + b, + b_smem_layout, + self.mma_tiler, + tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if b.element_type is cutlass.Float32 else None), + ) + + a_copy_size = cute.size_in_bytes(self.a_dtype, a_smem_layout) + b_copy_size = cute.size_in_bytes(self.b_dtype, b_smem_layout) + self.num_tma_load_bytes = (a_copy_size + b_copy_size) * atom_thr_size + + # Setup TMA store for C + tma_atom_c = None + tma_tensor_c = None + c_cta_v_layout = cute.composition(cute.make_identity_layout(c.shape), self.epi_tile) + epi_smem_layout = cute.slice_(self.c_smem_layout_staged, (None, None, 0)) + tma_atom_c, tma_tensor_c = cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileS2GOp(), + c, + epi_smem_layout, + c_cta_v_layout, + ) + + tensor_sfa = cute.make_tensor( + sfa.iterator, + cute.make_layout( + ( + (self.scale_granularity_m, sfa.shape[0]), + (self.scale_granularity_k, sfa.shape[1]), + sfa.shape[2], + ), + stride=( + (0, sfa.layout.stride[0]), + (0, sfa.layout.stride[1]), + sfa.layout.stride[2], + ), + ), + ) + tensor_sfb = cute.make_tensor( + sfb.iterator, + cute.make_layout( + ( + (self.scale_granularity_n, sfb.shape[0]), + (self.scale_granularity_k, sfb.shape[1]), + sfb.shape[2], + ), + stride=( + (0, sfb.layout.stride[0]), + (0, sfb.layout.stride[1]), + sfb.layout.stride[2], + ), + ), + ) + + # Compute grid size + 
self.tile_sched_params, grid = self._compute_grid( + c, self.cta_tile_shape_mnk, self.cluster_shape_mn, max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + c_smem_size = cute.cosize(self.c_smem_layout_staged.outer) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + # (bidx, bidy, bidz, valid) + sInfo: cute.struct.Align[ + cute.struct.MemRange[cutlass.Int32, 4 * self.num_tile_stage], + # 1 byte alignment + 1, + ] + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_ab_stage * 2] + scale_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_scale_stage * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + tile_info_mbar_ptr: cute.struct.MemRange[cutlass.Int64, self.num_tile_stage * 2] + epi_mbar_ptr: cute.struct.MemRange[cutlass.Int64, 1 * 2] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: cutlass.Int32 + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype, + c_smem_size, + ], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA: cute.struct.Align[ + cute.struct.MemRange[self.sfa_dtype, cute.cosize(self.sfa_smem_layout_staged)], + self.buffer_align_bytes, + ] + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB: cute.struct.Align[ + cute.struct.MemRange[self.sfb_dtype, cute.cosize(self.sfb_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + tiled_mma, + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + tensor_sfa, + tensor_sfb, + self.cluster_layout_vmnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_tile, + self.tile_sched_params, + epilogue_op, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=(*self.cluster_shape_mn, 1), + smem=self.shared_storage.size_in_bytes(), + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_c: cute.CopyAtom, + mC_mnl: cute.Tensor, + mSFA_mkl: cute.Tensor, + mSFB_nkl: cute.Tensor, + cluster_layout_vmnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + c_smem_layout_staged: Union[cute.Layout, cute.ComposedLayout, None], + sfa_smem_layout_staged: cute.Layout, + sfb_smem_layout_staged: cute.Layout, + epi_tile: cute.Tile, + tile_sched_params: utils.PersistentTileSchedulerParams, + epilogue_op: cutlass.Constexpr, + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. 
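+
+ Warp specialization (see the module docstring for details): a scheduler
+ warp publishes tile coordinates, a TMA warp loads A/B tiles, a scale warp
+ loads SFA/SFB with cp.async, an MMA warp issues tcgen05.mma, and the
+ accumulator-update and epilogue warp groups apply the SFA * SFB scaling,
+ accumulate the final result, and store C back to GMEM via TMA.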
+ """ + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + lane_idx = cute.arch.lane_idx() + + # + # Prefetch tma desc + # + if warp_idx == self.tma_warp_id: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + cpasync.prefetch_descriptor(tma_atom_c) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, bidy, bidz = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Initialize mainloop ab_pipeline (barrier) and states + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_tma_producer = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_tma_producer + ) + ab_pipeline = pipeline.PipelineTmaUmma.create( + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + num_stages=self.num_ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize mainloop scale_pipeline (barrier) and states + scale_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + scale_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + scale_pipeline = pipeline.PipelineCpAsync.create( + barrier_storage=storage.scale_mbar_ptr.data_ptr(), + num_stages=self.num_scale_stage, + producer_group=scale_pipeline_producer_group, + consumer_group=scale_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize acc_pipeline (barrier) and states + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = len(self.epilog_warp_id) * (2 if use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + acc_pipeline = pipeline.PipelineUmmaAsync.create( + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + defer_sync=True, + ) + + # Initialize epilogue pipeline (barrier) and states + epi_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.acc_update_warp_id), + ) + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * len(self.epilog_warp_id), + ) + epi_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.epi_mbar_ptr.data_ptr(), + num_stages=1, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + defer_sync=True, + ) + + # Initialize tile info pipeline 
(barrier) and states + tile_info_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_per_warp * 1, + ) + tile_info_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + self.threads_wo_sched, + ) + tile_info_pipeline = pipeline.PipelineAsync.create( + barrier_storage=storage.tile_info_mbar_ptr.data_ptr(), + num_stages=self.num_tile_stage, + producer_group=tile_info_pipeline_producer_group, + consumer_group=tile_info_pipeline_consumer_group, + defer_sync=True, + ) + + # Tensor memory dealloc barrier init + tmem = utils.TmemAllocator( + tmem_holding_buf, + barrier_for_retrieve=self.tmem_alloc_barrier, + allocator_warp_id=self.epilog_warp_id[0], + is_two_cta=use_2cta_instrs, + two_cta_tmem_dealloc_mbar_ptr=tmem_dealloc_mbar_ptr, + ) + + # Cluster arrive after barrier init + pipeline_init_arrive(cluster_shape_mn=self.cluster_shape_mn, is_relaxed=True) + + # + # Setup smem tensor A/B/C/Scale + # + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sC = storage.sC.get_tensor(c_smem_layout_staged.outer, swizzle=c_smem_layout_staged.inner) + # (MMA, MMA_M, MMA_K, STAGE) + sA = storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + # (granularity_m, repeat_m), (granularity_k, repeat_k), num_scale_stage) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout_staged) + # (granularity_n, repeat_n), (granularity_k, repeat_k), num_scale_stage) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout_staged) + # (bidx, bidy, bidz, valid) + info_layout = cute.make_layout((4, self.num_tile_stage), stride=(1, 4)) + sInfo = storage.sInfo.get_tensor(info_layout) + + # + # Compute multicast mask for A/B buffer full + # + a_full_mcast_mask = None + b_full_mcast_mask = None + if cutlass.const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_full_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + + # + # Local_tile partition global tensors + # + # (bM, bK, loopM, loopK, loopL) + gA_mkl = cute.local_tile( + mA_mkl, cute.slice_(self.mma_tiler, (None, 0, None)), (None, None, None) + ) + # (bN, bK, loopN, loopK, loopL) + gB_nkl = cute.local_tile( + mB_nkl, cute.slice_(self.mma_tiler, (0, None, None)), (None, None, None) + ) + # (bM, bN, loopM, loopN, loopL) + gC_mnl = cute.local_tile( + mC_mnl, cute.slice_(self.mma_tiler, (None, None, 0)), (None, None, None) + ) + # (bM, bK, loopM, loopK, loopL) + gSFA_mkl = cute.local_tile( + mSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + gSFB_nkl = cute.local_tile( + mSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + # coordinate + cSFA_mkl = cute.make_identity_tensor(cute.shape(mSFA_mkl)) + cSFB_nkl = cute.make_identity_tensor(cute.shape(mSFB_nkl)) + # (bM, bK, loopM, loopK, loopL) + cSFA = cute.local_tile( + cSFA_mkl, + cute.slice_(self.cta_tile_shape_mnk, (None, 0, None)), + (None, None, None), + ) + # (bN, bK, loopN, loopK, loopL) + cSFB = cute.local_tile( + cSFB_nkl, + cute.slice_(self.cta_tile_shape_mnk, (0, None, None)), + (None, None, None), + ) + k_tile_cnt = cute.size(gA_mkl, mode=[3]) + + # + # Partition global tensor for TiledMMA_A/B/C + # + thr_mma 
= tiled_mma.get_slice(mma_tile_coord_v) + # (MMA, MMA_M, MMA_K, loopM, loopK, loopL) + tCgA = thr_mma.partition_A(gA_mkl) + # (MMA, MMA_N, MMA_K, loopN, loopK, loopL) + tCgB = thr_mma.partition_B(gB_nkl) + # (MMA, MMA_M, MMA_N, loopM, loopN, loopL) + tCgC = thr_mma.partition_C(gC_mnl) + + # scale viewed as C tensor + sSFA_view_as_C_layout = cute.make_layout( + ( + (self.scale_granularity_m, self.scale_m_per_tile), + self.cta_tile_shape_mnk[1], + self.num_scale_stage, + ), + stride=((0, 1), 0, self.scale_m_per_tile), + ) + sSFB_view_as_C_layout = cute.make_layout( + ( + self.cta_tile_shape_mnk[0], + (self.scale_granularity_n, self.scale_n_per_tile), + self.num_scale_stage, + ), + stride=(0, (0, 1), self.scale_n_per_tile), + ) + sSFA_view_as_C = cute.make_tensor(sSFA.iterator, sSFA_view_as_C_layout) + sSFB_view_as_C = cute.make_tensor(sSFB.iterator, sSFB_view_as_C_layout) + + # + # Partition global/shared tensor for TMA load A/B + # + # TMA load A partition_S/D + a_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAsA, tAgA = cpasync.tma_partition( + tma_atom_a, + block_in_cluster_coord_vmnk[2], + a_cta_layout, + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + # TMA load B partition_S/D + b_cta_layout = cute.make_layout(cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tBsB, tBgB = cpasync.tma_partition( + tma_atom_b, + block_in_cluster_coord_vmnk[1], + b_cta_layout, + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # + # Partition global/shared tensor for TMA load A/B + # + # load scaleA/scaleB + atom_copy = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(), + mSFA_mkl.element_type, + num_bits_per_copy=mSFA_mkl.element_type.width, + ) + tiled_copy_sfa = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + tiled_copy_sfb = cute.make_tiled_copy_tv( + atom_copy, cute.make_layout((32,)), cute.make_layout((1,)) + ) + thr_copy_sfa = tiled_copy_sfa.get_slice(lane_idx) + thr_copy_sfb = tiled_copy_sfb.get_slice(lane_idx) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopM, loopK, loopL) + tAgSFA_mkl = thr_copy_sfa.partition_S(gSFA_mkl) + tAsSFA = thr_copy_sfa.partition_D(sSFA) + tAcSFA = thr_copy_sfa.partition_S(cSFA) + # ((atom_v, rest_v), STAGE) + # ((atom_v, rest_v), loopN, loopK, loopL) + tBgSFB_nkl = thr_copy_sfb.partition_S(gSFB_nkl) + tBsSFB = thr_copy_sfb.partition_D(sSFB) + tBcSFB = thr_copy_sfb.partition_S(cSFB) + + # + # Partition shared/tensor memory tensor for TiledMMA_A/B/C + # + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + # + # Cluster wait before tensor memory alloc + # + pipeline_init_wait(cluster_shape_mn=self.cluster_shape_mn) + + # + # Specialized Schedule warp + # + if warp_idx == self.sched_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_sched_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + 
tile_info_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_tile_stage + ) + + while work_tile.is_valid_tile: + # query next tile + tile_sched.advance_to_next_work() + work_tile = tile_sched.get_current_work() + + # acquire tile info pipeline + tile_info_pipeline.producer_acquire(tile_info_producer_state) + + # store the tile info + cur_tile_coord = work_tile.tile_idx + with cute.arch.elect_one(): + sInfo[(0, tile_info_producer_state.index)] = cur_tile_coord[0] + sInfo[(1, tile_info_producer_state.index)] = cur_tile_coord[1] + sInfo[(2, tile_info_producer_state.index)] = cur_tile_coord[2] + sInfo[(3, tile_info_producer_state.index)] = cutlass.Int32( + work_tile.is_valid_tile + ) + + # fence view async shared + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.sched_sync_barrier.arrive_and_wait() + # commit tile info pipeline + tile_info_pipeline.producer_commit(tile_info_producer_state) + tile_info_producer_state.advance() + + tile_info_pipeline.producer_tail(tile_info_producer_state) + + # + # Specialized TMA load warp + # + if warp_idx == self.tma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_ab_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + # ((atom_v, rest_v), loopK) + tAgA_slice = tAgA[(None, mma_tile_coord_mnl[0], None, mma_tile_coord_mnl[2])] + # ((atom_v, rest_v), loopK) + tBgB_slice = tBgB[(None, mma_tile_coord_mnl[1], None, mma_tile_coord_mnl[2])] + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + ab_producer_state.reset_count() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # + # Tma load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + tAgA_k = tAgA_slice[(None, ab_producer_state.count)] + tBgB_k = tBgB_slice[(None, ab_producer_state.count)] + tAsA_pipe = tAsA[(None, ab_producer_state.index)] + tBsB_pipe = tBsB[(None, ab_producer_state.index)] + + tma_bar = ab_pipeline.producer_get_barrier(ab_producer_state) + + # Conditionally wait for AB buffer empty + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + + # Peek (try_wait) AB buffer empty for k_tile = prefetch_k_tile_cnt + k_tile + 1 + ab_producer_state.advance() + peek_ab_empty_status = cutlass.Boolean(1) + if ab_producer_state.count < k_tile_cnt: + peek_ab_empty_status = 
ab_pipeline.producer_try_acquire(ab_producer_state) + + # TMA load A/B + cute.copy( + tma_atom_a, + tAgA_k, + tAsA_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=a_full_mcast_mask, + ) + cute.copy( + tma_atom_b, + tBgB_k, + tBsB_pipe, + tma_bar_ptr=tma_bar, + mcast_mask=b_full_mcast_mask, + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait A/B buffer empty + # + ab_pipeline.producer_tail(ab_producer_state) + + # + # Specialized Scale load warp + # + if warp_idx == self.scale_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + # First tile + work_tile = tile_sched.initial_work_tile_info() + + scale_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_scale_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # + # Prepare the mask for scaleA/scaleB + # + tApSFA = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tAsSFA, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + tBpSFB = cute.make_rmem_tensor( + cute.make_layout( + cute.filter_zeros(cute.slice_(tBsSFB, (None, None, None, 0))).shape + ), + cutlass.Boolean, + ) + + # Peek (try_wait) SCALE buffer empty + scale_producer_state.reset_count() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # load loop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # + # Slice to per mma tile index + # + tAsSFA_pipe = cute.filter_zeros( + tAsSFA[(None, None, None, scale_producer_state.index)] + ) + tBsSFB_pipe = cute.filter_zeros( + tBsSFB[(None, None, None, scale_producer_state.index)] + ) + tAgSFA_k = cute.filter_zeros( + tAgSFA_mkl[ + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + tBgSFB_k = cute.filter_zeros( + tBgSFB_nkl[ + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ) + ] + ) + + tAcSFA_compact = cute.filter_zeros( + cute.slice_( + tAcSFA, + ( + None, + None, + None, + tile_info[0], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + tBcSFB_compact = cute.filter_zeros( + cute.slice_( + tBcSFB, + ( + None, + None, + None, + tile_info[1], + scale_producer_state.count, + tile_info[2], + ), + ) + ) + for i in cutlass.range_constexpr(cute.size(tApSFA, mode=[1])): + tApSFA[((0, 0), i, (0, 0))] = 
cute.elem_less( + tAcSFA_compact[(i)][0], mSFA_mkl.shape[0] + ) + for i in cutlass.range_constexpr(cute.size(tBpSFB, mode=[1])): + tBpSFB[((0, 0), i, (0, 0))] = cute.elem_less( + tBcSFB_compact[(i)][0], mSFB_nkl.shape[0] + ) + + # Conditionally wait for Scale buffer empty + scale_pipeline.producer_acquire(scale_producer_state, peek_scale_empty_status) + + # load scaleA/scaleB + cute.copy(tiled_copy_sfa, tAgSFA_k, tAsSFA_pipe, pred=tApSFA) + cute.copy(tiled_copy_sfb, tBgSFB_k, tBsSFB_pipe, pred=tBpSFB) + + scale_pipeline.producer_commit(scale_producer_state) + + # Peek (try_wait) Scale buffer empty + scale_producer_state.advance() + peek_scale_empty_status = cutlass.Boolean(1) + if scale_producer_state.count < k_tile_cnt: + peek_scale_empty_status = scale_pipeline.producer_try_acquire( + scale_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait Scale buffer empty + # + scale_pipeline.producer_tail(scale_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_uniform_warps) + # + # Bar sync for retrieve tensor memory ptr from shared mem + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + ab_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # Peek (try_wait) AB buffer full for k_tile = 0 + ab_consumer_state.reset_count() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = 0 + acc_producer_state.reset_count() + peek_acc_empty_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt and is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire(acc_producer_state) + + # + # Mma mainloop + # + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[(None, None, None, 
acc_producer_state.index)] + + # + # Wait for accumulator buffer empty + # + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state, peek_acc_empty_status) + + # + # Reset the ACCUMULATE field for each tile + # + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + + # tCtAcc += tCrA * tCrB + num_kblocks = cute.size(tCrA, mode=[2]) + for kblock_idx in cutlass.range(num_kblocks, unroll_full=True): + kblock_coord = ( + None, + None, + kblock_idx, + ab_consumer_state.index, + ) + + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[kblock_coord], + tCrB[kblock_coord], + tCtAcc, + ) + # Enable accumulate on tCtAcc after first kblock + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + ab_consumer_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if ab_consumer_state.count < k_tile_cnt: + if is_leader_cta: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + + # + # Async arrive accumulator buffer full(each kblock) + # + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + + # Peek (try_wait) Acc buffer empty for k_tile = k_tile + 1 + acc_producer_state.advance() + if acc_producer_state.count < k_tile_cnt: + if is_leader_cta: + peek_acc_empty_status = acc_pipeline.producer_try_acquire( + acc_producer_state + ) + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Wait for accumulator buffer empty + # + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized acc update warps + # + if warp_idx <= self.acc_update_warp_id[-1]: + cute.arch.warpgroup_reg_alloc(self.num_regs_acc_update_warps) + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base.iterator + self.tmem_final_offset, tCtAcc_base.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc_base, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc, + tRT_tAcc_base, + ) = self.acc_update_tmem_copy_and_partition( + epi_tidx, + tCtAcc_base, + tCtAcc_final, + tCgC, + sSFA_view_as_C, + sSFB_view_as_C, + epi_tile, + ) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + + scale_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_scale_stage + ) + + epi_producer_state = 
pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + while is_valid_tile: + # initialize the final accumulator + tTR_rAcc_final.fill(0.0) + + tTR_rSFA = cute.make_rmem_tensor( + cute.slice_(tTR_sSFA, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + tTR_rSFB = cute.make_rmem_tensor( + cute.slice_(tTR_sSFB, (None, None, None, 0, None, 0)).shape, + self.acc_dtype, + ) + + scale_consumer_state.reset_count() + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait(scale_consumer_state) + + acc_consumer_state.reset_count() + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + for k_tile in cutlass.range(0, k_tile_cnt, 1, unroll=1): + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[ + (None, None, None, None, None, acc_consumer_state.index) + ] + + # + # Wait for scale buffer full + # + scale_pipeline.consumer_wait(scale_consumer_state, peek_scale_full_status) + + tTR_sSFA_slice = cute.slice_( + tTR_sSFA, + (None, None, None, 0, None, scale_consumer_state.index), + ) + tTR_sSFB_slice = cute.slice_( + tTR_sSFB, + (None, None, None, 0, None, scale_consumer_state.index), + ) + + scale_atom_copy = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + self.acc_dtype, + num_bits_per_copy=self.acc_dtype.width, + ) + + cute.copy(scale_atom_copy, tTR_sSFA_slice, tTR_rSFA) + cute.copy(scale_atom_copy, tTR_sSFB_slice, tTR_rSFB) + + # + # Wait for accumulator buffer full + # + + acc_pipeline.consumer_wait(acc_consumer_state, peek_acc_full_status) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + + # + # Update accumulator by scale factor in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Update accumulator by scale factor + # + tTR_rAcc_subtile = tTR_rAcc_final[(None, None, None, subtile_idx)] + tTR_rSFA_subtile = tTR_rSFA[(None, None, None, subtile_idx)] + tTR_rSFB_subtile = tTR_rSFB[(None, None, None, subtile_idx)] + + acc_vec = tTR_rAcc.load() + final_vec = tTR_rAcc_subtile.load() + scale_a = tTR_rSFA_subtile.load() + scale_b = tTR_rSFB_subtile.load() + scale = scale_a * scale_b + final_vec = acc_vec * scale + final_vec + tTR_rAcc_subtile.store(final_vec.to(self.acc_dtype)) + + # + # Async arrive accumulator buffer empty + # + scale_pipeline.consumer_release(scale_consumer_state) + scale_consumer_state.advance() + + peek_scale_full_status = cutlass.Boolean(1) + if scale_consumer_state.count < k_tile_cnt: + peek_scale_full_status = scale_pipeline.consumer_try_wait( + scale_consumer_state + ) + # + # Async arrive 
accumulator buffer empty + # + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + peek_acc_full_status = cutlass.Boolean(1) + if acc_consumer_state.count < k_tile_cnt: + peek_acc_full_status = acc_pipeline.consumer_try_wait(acc_consumer_state) + + tRT_tAcc = tRT_tAcc_base[(None, None, None, None, None, 0)] + tRT_tAcc = cute.group_modes(tRT_tAcc, 3, cute.rank(tRT_tAcc)) + + # + # Wait for epilogue buffer empty + # + epi_pipeline.producer_acquire(epi_producer_state) + + # copy the accumulator to tensor memory buffer + cute.copy(tiled_copy_r2t, tTR_rAcc_final, tRT_tAcc) + cute.arch.fence_view_async_tmem_store() + + # + # Async arrive epilogue buffer full + # + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Specialized epilogue warps + # + if warp_idx <= self.epilog_warp_id[-1] and warp_idx >= self.epilog_warp_id[0]: + cute.arch.warpgroup_reg_alloc(self.num_regs_epilogue_warps) + # + # Alloc tensor memory buffer + # + tmem.allocate(self.num_tmem_alloc_cols) + + # + # Bar sync for retrieve tensor memory ptr from shared memory + # + tmem.wait_for_alloc() + + # + # Retrieving tensor memory ptr and make accumulator tensor + # + tmem_ptr = tmem.retrieve_ptr(self.acc_dtype) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base_ = cute.make_tensor(tmem_ptr, tCtAcc_fake.layout) + tCtAcc_final = cute.make_tensor( + tCtAcc_base_.iterator + self.tmem_final_offset, tCtAcc_base_.layout + ) + + # + # Partition for epilogue + # + epi_tidx = tidx % 128 + ( + tiled_copy_t2r, + tTR_tAcc_base, + tTR_rAcc, + ) = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_final, tCgC, epi_tile, use_2cta_instrs + ) + + tTR_rC = None + tiled_copy_r2s = None + # simt_atom = None + tRS_rC = None + tRS_sC = None + bSG_sC = None + bSG_gC_partitioned = None + tTR_rC = cute.make_rmem_tensor(tTR_rAcc.shape, self.c_dtype) + tiled_copy_r2s, tRS_rC, tRS_sC = self.epilog_smem_copy_and_partition( + tiled_copy_t2r, tTR_rC, epi_tidx, sC + ) + ( + tma_atom_c, + bSG_sC, + bSG_gC_partitioned, + ) = self.epilog_gmem_copy_and_partition(epi_tidx, tma_atom_c, tCgC, epi_tile, sC) + + # + # Persistent tile scheduling loop + # + tile_sched = utils.StaticPersistentTileScheduler.create( + tile_sched_params, cute.arch.block_idx(), cute.arch.grid_dim() + ) + work_tile = tile_sched.initial_work_tile_info() + + epi_consumer_state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Consumer, 1) + + c_pipeline = None + # Threads/warps participating in tma store pipeline + c_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, + 32 * len(self.epilog_warp_id), + ) + c_pipeline = pipeline.PipelineTmaStore.create( + num_stages=self.num_c_stage, + producer_group=c_producer_group, + ) + + tile_info_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_tile_stage + ) + + # get the first tile info + tile_info = cute.make_rmem_tensor((4,), cutlass.Int32) + # Get tile coord from tile scheduler + cur_tile_coord = work_tile.tile_idx + # initialize the tile info + 
tile_info[0] = cur_tile_coord[0] + tile_info[1] = cur_tile_coord[1] + tile_info[2] = cur_tile_coord[2] + tile_info[3] = work_tile.is_valid_tile + + is_valid_tile = cutlass.Boolean(1) + is_valid_tile = tile_info[3] == 1 + + num_prev_subtiles = cutlass.Int32(0) + + while is_valid_tile: + mma_tile_coord_mnl = ( + tile_info[0] // cute.size(tiled_mma.thr_id.shape), + tile_info[1], + tile_info[2], + ) + # + # Slice to per mma tile index + # + bSG_gC = None + # ((ATOM_V, REST_V), EPI_M, EPI_N) + bSG_gC = bSG_gC_partitioned[ + ( + None, + None, + None, + mma_tile_coord_mnl[0], + mma_tile_coord_mnl[1], + mma_tile_coord_mnl[2], + ) + ] + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[(None, None, None, None, None, epi_consumer_state.index)] + + # + # Wait for accumulator buffer full + # + epi_pipeline.consumer_wait(epi_consumer_state) + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + bSG_gC = cute.group_modes(bSG_gC, 1, cute.rank(bSG_gC)) + + # + # Store accumulator to global memory in subtiles + # + subtile_cnt = cute.size(tTR_tAcc.shape, mode=[3]) + for subtile_idx in cutlass.range(subtile_cnt): + # + # Load accumulator from tensor memory buffer to register + # + tTR_tAcc_mn = tTR_tAcc[(None, None, None, subtile_idx)] + cute.copy(tiled_copy_t2r, tTR_tAcc_mn, tTR_rAcc) + + # + # Convert to C type + # + acc_vec = tiled_copy_r2s.retile(tTR_rAcc).load() + acc_vec = epilogue_op(acc_vec.to(self.c_dtype)) + tRS_rC.store(acc_vec) + + # + # Store C to shared memory + # + num_prev_subtiles = num_prev_subtiles + 1 + c_buffer = num_prev_subtiles % self.num_c_stage + cute.copy( + tiled_copy_r2s, + tRS_rC, + tRS_sC[(None, None, None, c_buffer)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + self.epilog_sync_barrier.arrive_and_wait() + + # + # TMA store C to global memory + # + if warp_idx == self.epilog_warp_id[0]: + cute.copy( + tma_atom_c, + bSG_sC[(None, c_buffer)], + bSG_gC[(None, subtile_idx)], + ) + # Fence and barrier to make sure shared memory store is visible to TMA store + c_pipeline.producer_commit() + c_pipeline.producer_acquire() + self.epilog_sync_barrier.arrive_and_wait() + + # + # Async arrive accumulator buffer empty + # + epi_pipeline.consumer_release(epi_consumer_state) + epi_consumer_state.advance() + + # + # Advance to next tile + # + tile_info_pipeline.consumer_wait(tile_info_consumer_state) + for idx in cutlass.range(4, unroll_full=True): + tile_info[idx] = sInfo[(idx, tile_info_consumer_state.index)] + is_valid_tile = tile_info[3] == 1 + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, + space=cute.arch.SharedSpace.shared_cta, + ) + tile_info_pipeline.consumer_release(tile_info_consumer_state) + tile_info_consumer_state.advance() + + # + # Dealloc the tensor memory buffer + # + tmem.relinquish_alloc_permit() + self.epilog_sync_barrier.arrive_and_wait() + tmem.free(tmem_ptr) + # + # Wait for C store complete + # + c_pipeline.producer_tail() + + def acc_update_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + tAcc_final: cute.Tensor, + gC_mnl: cute.Tensor, + sSFA: cute.Tensor, + sSFB: cute.Tensor, + epi_tile: cute.Tile, + ) -> Tuple[ + cute.TiledCopy, + cute.TiledCopy, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + cute.Tensor, + ]: + """ + Make tiledCopy for tensor memory load, then 
use it to partition tensor memory (source) and + register array (destination). + Make tiledCopy for tensor memory store, then use it to partition register array (source) and + tensor memory (destination). + Partition the scale factor tensor for related copy operations. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param tAcc_final: The final accumulator tensor to be copied and partitioned + :type tAcc_final: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param sSFA: The scale factor tensor for A + :type sSFA: cute.Tensor + :param sSFB: The scale factor tensor for B + :type sSFB: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tiled_copy_r2t: The tiled copy operation for register to tmem copy(r2t) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + - tTR_rAcc_final: The accumulated tensor in register used to hold all t2r results + - tTR_sSFA: The partitioned tensor SFA by tiled_copy_t2r + - tTR_sSFB: The partitioned tensor SFB by tiled_copy_t2r + - tRT_rAcc_final: The accumulated tensor in register used to hold all r2t results + - tRT_tAcc_final: The partitioned accumulator tensor by tiled_copy_r2t + :rtype: Tuple[cute.TiledCopy, cute.TiledCopy, cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor, + cute.Tensor, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + tmem_load_atom = None + tmem_store_atom = None + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_load_atom = cute.make_copy_atom( + tcgen05.copy.Ld32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + if cutlass.const_expr(self.mma_tiler[0] == 64): + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St16x256bOp(tcgen05.copy.Repetition(8)), + self.acc_dtype, + ) + else: + tmem_store_atom = cute.make_copy_atom( + tcgen05.copy.St32x32bOp(tcgen05.copy.Repetition(32)), + self.acc_dtype, + ) + + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + tAcc_final_epi = cute.flat_divide(tAcc_final[((None, None), 0, 0, None)], epi_tile) + + tiled_copy_t2r = tcgen05.make_tmem_copy(tmem_load_atom, tAcc_epi[(None, None, 0, 0, 0)]) + tiled_copy_r2t = tcgen05.make_tmem_copy( + tmem_store_atom, tAcc_final_epi[(None, None, 0, 0, 0)] + ) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + thr_copy_r2t = tiled_copy_r2t.get_slice(tidx) + + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + sSFA_epi = cute.flat_divide(sSFA, epi_tile) + sSFB_epi = cute.flat_divide(sSFB, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + tTR_sSFA = thr_copy_t2r.partition_D(sSFA_epi) + tTR_sSFB = thr_copy_t2r.partition_D(sSFB_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_rAcc_final_ = cute.make_rmem_tensor( + tTR_gC[(None, 
None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + tTR_rAcc_final = cute.group_modes(tTR_rAcc_final_, 3, cute.rank(tTR_rAcc_final_)) + + tRT_gC = thr_copy_r2t.partition_S(gC_mnl_epi) + tRT_tAcc_final = thr_copy_r2t.partition_D(tAcc_final_epi) + # (R2T, R2T_M, R2T_N, EPI_M, EPI_N, loopM, loopN, loopL) + tRT_rAcc_final_ = cute.make_rmem_tensor( + tRT_gC[(None, None, None, None, None, 0, 0, 0)].shape, self.acc_dtype + ) + # (R2T, R2T_M, R2T_N, (EPI_M, EPI_N)) + tRT_rAcc_final = cute.group_modes(tRT_rAcc_final_, 3, cute.rank(tRT_rAcc_final_)) + + return ( + tiled_copy_t2r, + tiled_copy_r2t, + tTR_tAcc, + tTR_rAcc, + tTR_rAcc_final, + tTR_sSFA, + tTR_sSFB, + tRT_rAcc_final, + tRT_tAcc_final, + ) + + def epilog_tmem_copy_and_partition( + self, + tidx: cutlass.Int32, + tAcc: cute.Tensor, + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[cutlass.Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and + register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.c_layout, + self.c_dtype, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide( + tAcc[((None, None), 0, 0, None)], + epi_tile, + ) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_mnl_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N, loopM, loopN, loopL) + tTR_gC = thr_copy_t2r.partition_D(gC_mnl_epi) + + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_rmem_tensor( + tTR_gC[(None, None, None, 0, 0, 0, 0, 0)].shape, self.acc_dtype + ) + + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_copy_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + tTR_rC: cute.Tensor, + tidx: cutlass.Int32, + sC: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) + and shared memory (destination). 
+ + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rC: The partitioned accumulator tensor + :type tTR_rC: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rC, tRS_sC) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rC: The partitioned tensor C (register source) + - tRS_sC: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + self.c_layout, self.c_dtype, self.acc_dtype, tiled_copy_t2r + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sC = thr_copy_r2s.partition_D(sC) + # (R2S, R2S_M, R2S_N) + tRS_rC = tiled_copy_r2s.retile(tTR_rC) + return tiled_copy_r2s, tRS_rC, tRS_sC + + def epilog_gmem_copy_and_partition( + self, + tidx: cutlass.Int32, + atom: Union[cute.CopyAtom, cute.TiledCopy], + gC_mnl: cute.Tensor, + epi_tile: cute.Tile, + sC: cute.Tensor, + ) -> Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor]: + """Make tiledCopy for global memory store, then use it to: + - partition register array (source) and global memory (destination) for none TMA store version; + - partition shared memory (source) and global memory (destination) for TMA store version. + + :param tidx: The thread index in epilogue warp groups + :type tidx: cutlass.Int32 + :param atom: The copy_atom_c to be used for TMA store version, or tiled_copy_t2r for none TMA store version + :type atom: cute.CopyAtom or cute.TiledCopy + :param gC_mnl: The global tensor C + :type gC_mnl: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param sC: The shared memory tensor to be copied and partitioned + :type sC: cute.Tensor + + :return: A tuple containing : + - For TMA store: (tma_atom_c, bSG_sC, bSG_gC) where: + - tma_atom_c: The TMA copy atom + - bSG_sC: The partitioned shared memory tensor C + - bSG_gC: The partitioned global tensor C + :rtype: Tuple[cute.CopyAtom, cute.Tensor, cute.Tensor] + """ + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, loopM, loopN, loopL) + gC_epi = cute.flat_divide(gC_mnl[((None, None), 0, 0, None, None, None)], epi_tile) + tma_atom_c = atom + sC_for_tma_partition = cute.group_modes(sC, 0, 2) + gC_for_tma_partition = cute.group_modes(gC_epi, 0, 2) + # ((ATOM_V, REST_V), EPI_M, EPI_N) + # ((ATOM_V, REST_V), EPI_M, EPI_N, loopM, loopN, loopL) + bSG_sC, bSG_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + sC_for_tma_partition, + gC_for_tma_partition, + ) + return tma_atom_c, bSG_sC, bSG_gC + + @staticmethod + def _compute_stages( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + epi_tile: cute.Tile, + c_dtype: Type[cutlass.Numeric], + c_layout: utils.LayoutEnum, + sfa_dtype: Type[cutlass.Numeric], + sfb_dtype: Type[cutlass.Numeric], + sfa_count: int, + sfb_count: int, + num_smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. 
+ :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param c_dtype: Data type of operand C (output). + :type c_dtype: type[cutlass.Numeric] + :param c_layout: Layout of operand C. + :type c_layout: utils.LayoutEnum + :param num_smem_capacity: Total available shared memory capacity in bytes. + :type num_smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). + :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + # Default ACC stages + num_acc_stage = 3 if mma_tiler_mnk[0] / tiled_mma.thr_id.shape == 128 else 6 + + # Default C stages + num_c_stage = 2 + + # Default ScaleA/B stages + num_scale_stage = 10 + + # Default Tile info stages + num_tile_stage = 2 + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_stage_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + c_smem_layout_staged_one = sm100_utils.make_smem_layout_epi( + c_dtype, + c_layout, + epi_tile, + 1, + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_stage_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + # 1024B alignment + mbar_helpers_bytes = 1024 + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + c_bytes = c_bytes_per_stage * num_c_stage + sfa_bytes = sfa_count * (sfa_dtype.width // 8) * num_scale_stage + sfb_bytes = sfb_count * (sfb_dtype.width // 8) * num_scale_stage + scale_bytes = math.ceil((sfa_bytes + sfb_bytes) / 1024) * 1024 + + # Calculate A/B stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B stage + num_ab_stage = ( + num_smem_capacity // occupancy - (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + num_c_stage += ( + num_smem_capacity + - occupancy * ab_bytes_per_stage * num_ab_stage + - occupancy * (mbar_helpers_bytes + c_bytes + scale_bytes) + ) // (occupancy * c_bytes_per_stage) + return num_acc_stage, num_ab_stage, num_c_stage, num_scale_stage, num_tile_stage + + @staticmethod + def _compute_grid( + c: cute.Tensor, + cta_tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mn: Tuple[int, int], + max_active_clusters: cutlass.Constexpr, + ) -> Tuple[utils.PersistentTileSchedulerParams, Tuple[int, int, int]]: + """Use persistent tile scheduler to compute the grid size for the output tensor C. + + :param c: The output tensor C + :type c: cute.Tensor + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: tuple[int, int, int] + :param cluster_shape_mn: Shape of each cluster in M, N dimensions. + :type cluster_shape_mn: tuple[int, int] + :param max_active_clusters: Maximum number of active clusters. 
+ :type max_active_clusters: cutlass.Constexpr + + :return: A tuple containing: + - tile_sched_params: Parameters for the persistent tile scheduler. + - grid: Grid shape for kernel launch. + :rtype: Tuple[utils.PersistentTileSchedulerParams, tuple[int, int, int]] + """ + c_shape = cute.slice_(cta_tile_shape_mnk, (None, None, 0)) + gc = cute.zipped_divide(c, tiler=c_shape) + num_ctas_mnl = gc[(0, (None, None, None))].shape + cluster_shape_mnl = (*cluster_shape_mn, 1) + + tile_sched_params = utils.PersistentTileSchedulerParams(num_ctas_mnl, cluster_shape_mnl) + grid = utils.StaticPersistentTileScheduler.get_grid_shape( + tile_sched_params, max_active_clusters + ) + + return tile_sched_params, grid + + @staticmethod + def _get_tma_atom_kind( + atom_sm_cnt: cutlass.Int32, mcast: cutlass.Boolean + ) -> Union[cpasync.CopyBulkTensorTileG2SMulticastOp, cpasync.CopyBulkTensorTileG2SOp]: + """ + Select the appropriate TMA copy atom based on the number of SMs and the multicast flag. + + :param atom_sm_cnt: The number of SMs + :type atom_sm_cnt: cutlass.Int32 + :param mcast: The multicast flag + :type mcast: cutlass.Boolean + + :return: The appropriate TMA copy atom kind + :rtype: cpasync.CopyBulkTensorTileG2SMulticastOp or cpasync.CopyBulkTensorTileG2SOp + + :raise ValueError: If the atom_sm_cnt is invalid + """ + if atom_sm_cnt == 2 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 2 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.TWO) + elif atom_sm_cnt == 1 and mcast: + return cpasync.CopyBulkTensorTileG2SMulticastOp(tcgen05.CtaGroup.ONE) + elif atom_sm_cnt == 1 and not mcast: + return cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + + raise ValueError(f"Invalid atom_sm_cnt: {atom_sm_cnt} and {mcast}") + + @staticmethod + def is_valid_dtypes( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if ab_dtype not in { + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {cutlass.Float32}: + is_valid = False + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + is_valid = False + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if not ( + (not use_2cta_instrs and mma_tiler_mn[0] in [64, 128]) + or (use_2cta_instrs and mma_tiler_mn[0] in [128, 256]) + ): + is_valid = False + # Skip invalid mma tile n + if mma_tiler_mn[1] not in 
(128,): + is_valid = False + # Skip illegal cluster shape + if cluster_shape_mn[0] % (2 if use_2cta_instrs else 1) != 0: + is_valid = False + + # Skip invalid cluster shape + def is_power_of_2(x): + return x > 0 and (x & (x - 1)) == 0 + + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + batch_size: int, + ab_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param batch_size: Batch dimension size + :type batch_size: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type c_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, batch_size)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, batch_size)) + or not check_contigous_16B_alignment(c_dtype, c_major == "m", (m, n, batch_size)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + c_dtype: Type[cutlass.Numeric], + use_2cta_instrs: bool, + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + batch_size: int, + a_major: str, + b_major: str, + c_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param c_dtype: The data type of the output tensor + :type c_dtype: Type[cutlass.Numeric] + :param use_2cta_instrs: Whether to use 2 CTA groups + :type use_2cta_instrs: bool + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param batch_size: Batch dimension size + :type batch_size: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param c_major: The major axis of the C tensor + :type 
c_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not Sm100BlockwiseGemmKernel.is_valid_dtypes(ab_dtype, acc_dtype, c_dtype): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not Sm100BlockwiseGemmKernel.is_valid_mma_tiler_and_cluster_shape( + use_2cta_instrs, mma_tiler_mn, cluster_shape_mn + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not Sm100BlockwiseGemmKernel.is_valid_tensor_alignment( + m, n, k, batch_size, ab_dtype, c_dtype, a_major, b_major, c_major + ): + can_implement = False + # Skip unsupported A/B layout + if not (a_major == "k" and b_major == "k"): + can_implement = False + return can_implement + + @cute.jit + def wrapper( + self, + m: cutlass.Int32, + n: cutlass.Int32, + k: cutlass.Int32, + sf_m: cutlass.Int32, + sf_n: cutlass.Int32, + sf_k: cutlass.Int32, + batch_size: cutlass.Int32, + a_ptr: cute.Pointer, + b_ptr: cute.Pointer, + a_sf_ptr: cute.Pointer, + b_sf_ptr: cute.Pointer, + c_tensor: cute.Tensor, + max_active_clusters: cutlass.Constexpr, + stream: cuda.CUstream, + ): + """Executes the wrapped GEMM kernel with dynamically shaped tensors. + + Args: + m (int): The M dimension of the GEMM problem. + n (int): The N dimension of the GEMM problem. + k (int): The K dimension of the GEMM problem. + sf_m (int): The M dimension of the scale factor tensor for A. + sf_n (int): The N dimension of the scale factor tensor for B. + sf_k (int): The K dimension of the scale factor tensor. + batch_size (int): The batch dimension of the GEMM problem. + a_ptr (cute.Pointer): Pointer to the A tensor. + b_ptr (cute.Pointer): Pointer to the B tensor. + a_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for A. + b_sf_ptr (cute.Pointer): Pointer to the scale factor tensor for B. + c_tensor (cute.Tensor): Specially set as cute.Tensor for TVM FFI stream detection. + max_active_clusters (cutlass.Constexpr): Maximum number of active + clusters. + stream (cuda.CUstream): CUDA stream for the operation. + """ + + # m, k, batch_size with inner most dimension as k + a_tensor = cute.make_tensor( + a_ptr, + layout=cute.make_ordered_layout((m, k, batch_size), order=(1, 0, 2)), + ) + # n, k, batch_size with inner most dimension as k + b_tensor = cute.make_tensor( + b_ptr, + layout=cute.make_ordered_layout( + (n, k, batch_size), + order=(1, 0, 2), + ), + ) + # sf_m, sf_k, batch_size + sfa_tensor = cute.make_tensor( + a_sf_ptr, + layout=cute.make_ordered_layout( + (sf_m, sf_k, batch_size), + order=(0, 1, 2), + ), + ) + # sf_n, sf_k, batch_size + sfb_tensor = cute.make_tensor( + b_sf_ptr, + layout=cute.make_ordered_layout( + (sf_n, sf_k, batch_size), + order=(1, 0, 2), + ), + ) + + self( + a_tensor, + b_tensor, + c_tensor, + sfa_tensor, + sfb_tensor, + max_active_clusters, + stream, + ) diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py index 5da1629876..39a7289fee 100644 --- a/tensorrt_llm/_torch/model_config.py +++ b/tensorrt_llm/_torch/model_config.py @@ -122,6 +122,10 @@ class ModelConfig(Generic[TConfig]): extra_attrs: Dict = field(default_factory=dict, repr=False, init=False) + # cute dsl op configs + use_cute_dsl_blockscaling_mm: bool = False + use_cute_dsl_blockscaling_bmm: bool = False + _frozen: bool = field(default=False, init=False, repr=False) # If true, ONLY the vision encoder part of the full model is loaded/executed. 
diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 1171bb23f6..921eafe0a8 100755 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -672,6 +672,7 @@ class DeepseekV3Linear(Linear): reduce_output: bool = True, # ROW parallel only skip_create_weights_in_init: bool = False, use_custom_cublas_mm: bool = False, + use_cute_dsl_blockscaling_mm: bool = False, lora: Optional[LoraLayer] = None, ): super().__init__( @@ -688,6 +689,7 @@ class DeepseekV3Linear(Linear): skip_create_weights_in_init, use_custom_cublas_mm, lora, + use_cute_dsl_blockscaling_mm=use_cute_dsl_blockscaling_mm, ) def apply_linear(self, @@ -748,7 +750,10 @@ class DeepseekV3Attention(MLA): quant_config=model_config.get_quant_config(), skip_create_weights_in_init=model_config. skip_create_weights_in_init, - use_custom_cublas_mm=True) + use_custom_cublas_mm=True, + use_cute_dsl_blockscaling_mm=model_config. + use_cute_dsl_blockscaling_mm, + ) class DeepseekV32Attention(MLA): @@ -925,6 +930,7 @@ class Deepseekv3MoE(nn.Module): config = model_config.pretrained_config self.top_k = top_k self.use_dp = model_config.mapping.enable_attention_dp + self.use_cute_dsl_blockscaling_mm = model_config.use_cute_dsl_blockscaling_mm gate_cls = DeepseekV3Gate if hasattr(model_config.pretrained_config, "gate_cls"): gate_cls = model_config.pretrained_config.gate_cls @@ -977,7 +983,9 @@ class Deepseekv3MoE(nn.Module): dtype=dtype, config=model_config, overridden_tp_size=shared_tp_size, - reduce_output=False) + reduce_output=False, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm, + ) self.allreduce = None if not self.use_dp and self.mapping.tp_size > 1: @@ -1262,13 +1270,17 @@ class DeepseekV3DecoderLayer(DecoderLayer): self.fusion_config.PRE_MLP_FUSION = self.enable_fusion and has_mlp_tp and self.is_nvfp4 self.fusion_config.POST_MLP_FUSION = self.enable_fusion and has_mlp_tp - self.mlp = GatedMLP(hidden_size=config.hidden_size, - intermediate_size=config.intermediate_size, - bias=False, - dtype=config.torch_dtype, - config=model_config, - overridden_tp_size=self.mlp_tp_size, - reduce_output=has_mlp_tp) + self.mlp = GatedMLP( + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + bias=False, + dtype=config.torch_dtype, + config=model_config, + overridden_tp_size=self.mlp_tp_size, + reduce_output=has_mlp_tp, + use_cute_dsl_blockscaling_mm=model_config. + use_cute_dsl_blockscaling_mm, + ) self.input_layernorm = RMSNorm(hidden_size=config.hidden_size, eps=config.rms_norm_eps, @@ -1564,6 +1576,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer): dtype=config.torch_dtype, skip_create_weights_in_init=model_config. skip_create_weights_in_init, + use_cute_dsl_blockscaling_mm=model_config. + use_cute_dsl_blockscaling_mm, ) else: self.eh_proj = Linear( @@ -1576,6 +1590,8 @@ class DeepseekV3MTP(DeepseekV3DecoderLayer): reduce_output=True, skip_create_weights_in_init=model_config. skip_create_weights_in_init, + use_cute_dsl_blockscaling_mm=model_config. 
+ use_cute_dsl_blockscaling_mm, ) self.shared_head = DeepseekV3MTPHead(model_config) diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index 3f571d9d7e..c0739b7c69 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -255,6 +255,9 @@ class Attention(nn.Module): self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_key_value_heads * self.head_dim + self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm + self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + qkv_shard_indices_mapping = { "q": (0, self.q_size * (2 if self.attn_output_gate else 1)), "k": @@ -280,7 +283,8 @@ class Attention(nn.Module): force_dynamic_quantization=config.force_dynamic_quantization, disable_deep_gemm=disable_deep_gemm, use_custom_cublas_mm=use_custom_cublas_mm, - fused_weight_shard_indices_mapping=qkv_shard_indices_mapping) + fused_weight_shard_indices_mapping=qkv_shard_indices_mapping, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE], [self.hidden_size]) @@ -299,7 +303,8 @@ class Attention(nn.Module): allreduce_strategy=config.allreduce_strategy, force_dynamic_quantization=config.force_dynamic_quantization, disable_deep_gemm=disable_deep_gemm, - use_custom_cublas_mm=use_custom_cublas_mm) + use_custom_cublas_mm=use_custom_cublas_mm, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.quant_config = config.get_quant_config() self.attn_backend = config.attn_backend @@ -686,6 +691,7 @@ def fp8_block_scaling_bmm_out( mat2_scale: torch.Tensor, out: torch.Tensor, mat2_dequant: Optional[torch.Tensor] = None, + use_cute_dsl_blockscaling_bmm: bool = False, ) -> torch.Tensor: sm_version = get_sm_version() if sm_version == 90 or sm_version == 89: @@ -706,7 +712,17 @@ def fp8_block_scaling_bmm_out( output) out.copy_(output) elif is_sm_100f(sm_version): - torch.bmm(mat1.transpose(0, 1), mat2_dequant.transpose(1, 2), out=out) + if use_cute_dsl_blockscaling_bmm: + mat1_fp8, mat1_scale = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102( + mat1) + torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(mat1_fp8, mat2_fp8, + mat1_scale, mat2_scale, + out) + mat1_scale = None + else: + torch.bmm(mat1.transpose(0, 1), + mat2_dequant.transpose(1, 2), + out=out) else: raise NotImplementedError(f"SM{sm_version} is not supported") @@ -851,6 +867,9 @@ class MLA(nn.Module): quant_config = config.get_quant_config() self.quant_config = quant_config + self.use_cute_dsl_blockscaling_mm = config.use_cute_dsl_blockscaling_mm + self.use_cute_dsl_blockscaling_bmm = config.use_cute_dsl_blockscaling_bmm + if not self.is_lite: self.kv_a_proj_with_mqa = Linear( hidden_size, @@ -860,7 +879,8 @@ class MLA(nn.Module): quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_a_layernorm = RMSNorm(hidden_size=self.q_lora_rank, eps=rms_norm_eps, @@ -876,7 +896,8 @@ class MLA(nn.Module): quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + 
use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) else: self.kv_a_proj_with_mqa = Linear( hidden_size, @@ -886,7 +907,8 @@ class MLA(nn.Module): quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, use_custom_cublas_mm=True, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_proj = Linear( self.q_lora_rank, @@ -898,7 +920,8 @@ class MLA(nn.Module): quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) self.q_b_proj = self.q_proj self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank, @@ -915,7 +938,8 @@ class MLA(nn.Module): quant_config=quant_config, skip_create_weights_in_init=config.skip_create_weights_in_init, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) # This parameter will view into self.kv_b_proj.weight after loading weights. # For dummy weight initialization, this parameter is initialized with empty tensor. # Used in forward_absorption only @@ -947,7 +971,8 @@ class MLA(nn.Module): skip_create_weights_in_init=config.skip_create_weights_in_init, reduce_output=reduce_output, allreduce_strategy=config.allreduce_strategy, - force_dynamic_quantization=config.force_dynamic_quantization) + force_dynamic_quantization=config.force_dynamic_quantization, + use_cute_dsl_blockscaling_mm=self.use_cute_dsl_blockscaling_mm) def yarn_get_mscale(scale=1, mscale=1): if scale <= 1: @@ -1083,7 +1108,7 @@ class MLA(nn.Module): ), requires_grad=False, ) - if is_sm_100f(): + if is_sm_100f() and not self.use_cute_dsl_blockscaling_bmm: assert self.dtype == torch.bfloat16 self.k_b_proj_trans_dequant = nn.Parameter( torch.empty( @@ -1875,6 +1900,7 @@ class MLA(nn.Module): self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ), lambda: self.mqa.mla_rope_generation( fused_q, @@ -1952,6 +1978,7 @@ class MLA(nn.Module): self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2007,6 +2034,7 @@ class MLA(nn.Module): self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2062,6 +2090,7 @@ class MLA(nn.Module): self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2129,6 +2158,7 @@ class MLA(nn.Module): self.k_b_proj_trans_scale, q_nope_out, self.k_b_proj_trans_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( @@ -2205,6 +2235,7 @@ class MLA(nn.Module): self.v_b_proj_scale, attn_output.transpose(0, 1), self.v_b_proj_dequant, + self.use_cute_dsl_blockscaling_bmm, ) else: raise NotImplementedError( diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py index f697ea5e15..61ce2de2b7 100644 --- a/tensorrt_llm/_torch/modules/linear.py +++ 
b/tensorrt_llm/_torch/modules/linear.py
@@ -982,10 +982,9 @@ class FP8BlockScalesLinearMethod(UnquantizedLinearMethod):
 
         if is_sm_100f():
             if module.use_cute_dsl_blockscaling_mm or module.disable_deep_gemm:
-                # TODO (@lmin): replace with cute_dsl gemm
                 act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                     input)
-                output = torch.ops.trtllm.fp8_block_scaling_gemm(
+                output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
                     act_input_fp8, module.weight, act_input_sf,
                     module.weight_scale)
             else:
diff --git a/tensorrt_llm/_torch/pyexecutor/model_loader.py b/tensorrt_llm/_torch/pyexecutor/model_loader.py
index 8e90c8a278..d4f73e1c15 100644
--- a/tensorrt_llm/_torch/pyexecutor/model_loader.py
+++ b/tensorrt_llm/_torch/pyexecutor/model_loader.py
@@ -364,7 +364,12 @@ class ModelLoader:
                 use_low_precision_moe_combine=self.llm_args.moe_config.
                 use_low_precision_moe_combine,
                 nvfp4_gemm_allowed_backends=self.llm_args.nvfp4_gemm_config.
-                allowed_backends)
+                allowed_backends,
+                use_cute_dsl_blockscaling_mm=self.llm_args.
+                use_cute_dsl_blockscaling_mm,
+                use_cute_dsl_blockscaling_bmm=self.llm_args.
+                use_cute_dsl_blockscaling_bmm,
+            )
 
             # Only pass model_kwargs if it's explicitly set (not None)
             if self.llm_args.model_kwargs is not None:
diff --git a/tensorrt_llm/_torch/utils.py b/tensorrt_llm/_torch/utils.py
index 1f3b1cc9e9..d441cbd1da 100644
--- a/tensorrt_llm/_torch/utils.py
+++ b/tensorrt_llm/_torch/utils.py
@@ -301,6 +301,16 @@ def fp4_unswizzled_scale_infer_shape(input_shapes: List[List[int]]):
     return scale_shape * 2
 
 
+def fp8_scale_infer_shape(input_shapes: List[List[int]]):
+    """Calculate the dimensions of the fp8 scale tensor.
+    """
+    input_shape = input_shapes[0]
+    assert len(input_shape) == 2 or len(input_shape) == 3
+    has_batch = len(input_shape) == 3
+    m = input_shape[-2]
+    return pad_up(m, 4) if has_batch else m
+
+
 _enable_piecewise_cuda_graph = True
diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py
index 90010b6fa3..39f1cdc80a 100644
--- a/tensorrt_llm/llmapi/llm_args.py
+++ b/tensorrt_llm/llmapi/llm_args.py
@@ -3050,6 +3050,18 @@ class TorchLlmArgs(BaseLlmArgs):
         "Only enable it if you intend to use this feature.",
         status="prototype")
 
+    # fp8 cute dsl configs
+    use_cute_dsl_blockscaling_mm: bool = Field(
+        default=False,
+        description="If true, use CuTe DSL fp8 blockscaling mm implementation.",
+        status="prototype",
+    )
+    use_cute_dsl_blockscaling_bmm: bool = Field(
+        default=False,
+        description="If true, use CuTe DSL fp8 blockscaling bmm implementation.",
+        status="prototype",
+    )
+
     # PrivateVars
     _quant_config: Optional[QuantConfig] = PrivateAttr(default=None)
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 672906b05d..84c21313ba 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -1533,6 +1533,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
             moe_config=MoeConfig(backend="CUTEDSL"),
+            use_cute_dsl_blockscaling_mm=True,
+            use_cute_dsl_blockscaling_bmm=True,
         )
 
         if fp8kv:
@@ -1695,6 +1697,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
             cuda_graph_config=CudaGraphConfig() if cuda_graph else None,
             torch_compile_config=torch_compile_config,
             moe_config=MoeConfig(backend="CUTEDSL"),
+            use_cute_dsl_blockscaling_mm=True,
+            use_cute_dsl_blockscaling_bmm=True,
         )
 
         if fp8kv:
diff --git a/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py b/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
index 74401c3818..a6d9cfe0b6 100644
--- a/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
+++ b/tests/unittest/_torch/thop/parallel/test_fp8_block_scale_gemm.py
@@ -1,4 +1,4 @@
-# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: Apache-2.0
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -109,6 +109,54 @@ def test_fp8_block_scale_gemm(dtype, m, k, n):
     torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skipif(
+    not isSM100Family(),
+    reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
+)
+@pytest.mark.parametrize(
+    "k, n",
+    [(7168, 2112), (1536, 24576), (512, 32768), (16384, 7168), (7168, 4096),
+     (2048, 7168), (1024, 1024)],
+)
+@pytest.mark.parametrize(
+    "m",
+    [7, 64, 128, 4096],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "use_tvm_ffi",
+    [True, False],
+)
+def test_cute_dsl_fp8_block_scale_gemm(dtype, m, k, n, use_tvm_ffi):
+
+    torch.random.manual_seed(0)
+    a = torch.randn((m, k), device='cuda', dtype=dtype) / k
+    b = torch.randn((n, k), device='cuda', dtype=dtype) / k
+
+    act_a_fp8, act_a_sf = torch.ops.trtllm.fp8_quantize_1x128(a)
+    act_b_fp8, act_b_sf = per_block_cast_to_fp8(b)
+
+    output_expected = a @ b.t()
+
+    with autotune():
+        cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
+            act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)
+
+    # test Cute DSL kernel
+    cute_dsl_output = torch.ops.trtllm.cute_dsl_fp8_gemm_blackwell(
+        act_a_fp8, act_b_fp8, act_a_sf, act_b_sf, use_tvm_ffi=use_tvm_ffi)
+
+    diff = calc_diff(cute_dsl_output, output_expected)
+    assert diff < 1e-3
+    torch.testing.assert_close(cute_dsl_output,
+                               output_expected,
+                               atol=1e-3,
+                               rtol=1e-3)
+
+
 @pytest.mark.skipif(
     getSMVersion() != 90 and getSMVersion() != 89 and getSMVersion() != 120,
     reason="The test is for Hopper and Ada only. Current SM is %d." %
@@ -171,6 +219,69 @@ def test_fp8_block_scale_bmm(dtype, m, k, n, num_groups):
     torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
 
 
+@pytest.mark.skipif(
+    not isSM100Family(),
+    reason="The test is for Blackwell. Current SM is %d." % getSMVersion(),
+)
+@pytest.mark.parametrize(
+    "k, n",
+    [(7168, 2112), (512, 32768), (16384, 7168), (2048, 7168)],
+)
+@pytest.mark.parametrize(
+    "m",
+    [7, 64, 128],
+)
+@pytest.mark.parametrize(
+    "num_groups",
+    [4, 8, 16],
+)
+@pytest.mark.parametrize(
+    "dtype",
+    [torch.bfloat16],
+)
+@pytest.mark.parametrize(
+    "use_tvm_ffi",
+    [True, False],
+)
+def test_cute_dsl_fp8_block_scale_bmm(dtype, m, k, n, num_groups, use_tvm_ffi):
+
+    torch.random.manual_seed(0)
+    a = torch.randn((m, num_groups, k), device='cuda', dtype=dtype) / k
+    a_fp8, a_scales = torch.ops.trtllm.fp8_batched_quantize_1x128_permute102(a)
+
+    b = torch.randn((num_groups, n, k), device='cuda', dtype=dtype) / k
+    b_fp8 = torch.zeros_like(b, device='cuda', dtype=torch.float8_e4m3fn)
+    b_scales = torch.zeros((num_groups, (n + 127) // 128, (k + 127) // 128),
+                           device='cuda',
+                           dtype=torch.float)
+
+    for i in range(num_groups):
+        b_fp8[i], b_scales[i] = per_block_cast_to_fp8(b[i])
+
+    output_expected = torch.einsum('mgk,gnk->gmn', a, b)
+    output = torch.empty((num_groups, m, n),
+                         device='cuda',
+                         dtype=torch.bfloat16)
+    # tune
+    with autotune():
+        torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
+                                                    b_fp8,
+                                                    a_scales,
+                                                    b_scales,
+                                                    output,
+                                                    use_tvm_ffi=use_tvm_ffi)
+    # run the tuned kernel
+    torch.ops.trtllm.cute_dsl_fp8_bmm_blackwell(a_fp8,
+                                                b_fp8,
+                                                a_scales,
+                                                b_scales,
+                                                output,
+                                                use_tvm_ffi=use_tvm_ffi)
+    diff = calc_diff(output, output_expected)
+    assert diff < 1e-3
+    torch.testing.assert_close(output, output_expected, atol=1e-3, rtol=1e-3)
+
+
 def deepSeekFp8ComputeGemmReference(mM, mN, mK, valsC, dqSfsC, valsA, dqSfsA,
                                     valsB, dqSfsB, quantizeOutput, tileSize):
     for mi in range(mM):
diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml
index 3f7deb867f..fc34086889 100644
--- a/tests/unittest/api_stability/references/llm.yaml
+++ b/tests/unittest/api_stability/references/llm.yaml
@@ -239,6 +239,14 @@ methods:
      annotation: Optional[Dict[str, Any]]
      default: null
      status: prototype
+    use_cute_dsl_blockscaling_mm:
+      annotation: bool
+      default: False
+      status: prototype
+    use_cute_dsl_blockscaling_bmm:
+      annotation: bool
+      default: False
+      status: prototype
   return_annotation: None
  generate:
    parameters:
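
Usage note (illustrative sketch, not part of the patch): the two prototype knobs above are exercised through the LLM API the same way the DeepSeek-V3-Lite integration test does. The sketch below assumes a Blackwell (SM100) GPU and an FP8 block-scale checkpoint; the checkpoint path is a placeholder, and the MoeConfig import path is assumed to match the existing integration tests.

    # Hypothetical example; both flags default to False and carry prototype status.
    from tensorrt_llm import LLM
    from tensorrt_llm.llmapi import MoeConfig

    llm = LLM(
        model="<path-to-fp8-block-scale-checkpoint>",  # placeholder path
        moe_config=MoeConfig(backend="CUTEDSL"),
        use_cute_dsl_blockscaling_mm=True,   # CuTe DSL fp8 blockscaling mm path
        use_cute_dsl_blockscaling_bmm=True,  # CuTe DSL fp8 blockscaling bmm path
    )
    outputs = llm.generate(["Hello, my name is"])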