From 2989bf5b39674365fe2c4dcd489bf5f673fe94b8 Mon Sep 17 00:00:00 2001 From: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> Date: Fri, 13 Feb 2026 17:39:24 -0800 Subject: [PATCH] [None][feat] Add new helix kernels for MNNVL-based codepath (#11433) Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com> --- cpp/tensorrt_llm/kernels/helixKernels.cu | 213 +++++++++++++- cpp/tensorrt_llm/kernels/helixKernels.h | 11 +- cpp/tensorrt_llm/thop/helixPostProcessOp.cpp | 132 +++++++-- tensorrt_llm/_torch/modules/attention.py | 72 +++-- .../accuracy/test_disaggregated_serving.py | 18 +- .../test_lists/qa/llm_function_core.txt | 12 +- .../test_lists/test-db/l0_dgx_b200.yml | 12 +- .../unittest/_torch/modules/test_mla_helix.py | 35 ++- .../thop/parallel/test_helix_postprocess.py | 269 ++++++++++++------ 9 files changed, 616 insertions(+), 158 deletions(-) diff --git a/cpp/tensorrt_llm/kernels/helixKernels.cu b/cpp/tensorrt_llm/kernels/helixKernels.cu index ed4e80a808..4612620bfd 100644 --- a/cpp/tensorrt_llm/kernels/helixKernels.cu +++ b/cpp/tensorrt_llm/kernels/helixKernels.cu @@ -80,8 +80,8 @@ static constexpr int NUM_PRE_LOAD = 8; // output: [num_tokens, num_heads * kv_lora_rank] (half) // gathered_o: [cp_size, num_tokens, num_heads * kv_lora_rank] (half) // gathered_stats: [cp_size, num_tokens, num_heads, 2] (fp32) -// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc -// which may have longer latency +// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc, +// which may have longer latency. template __global__ void helix_postprocess_kernel( T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank) @@ -217,10 +217,10 @@ static constexpr int MAX_KV_LORA_BYTES = (MAX_THREADS - WARP_SIZE) * BYTES_O_PER // output: [num_tokens, num_heads * kv_lora_rank] (half) // gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank] (half) // gathered_stats: [num_tokens, num_heads, cp_size, 2] (fp32) -// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc -// which may have longer latency +// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc, +// which may have longer latency. template -__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native( +__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v1( T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank) { // Each block processes one (token, head) @@ -358,6 +358,153 @@ __global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native( *reinterpret_cast(output_off) = *reinterpret_cast(output_typed); } +// Kernel: fused helix post-processing for cp_dim=1 layout (Version 2) +// output: [num_tokens, num_heads * kv_lora_rank] (half) +// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] (half) +// gathered_stats: [num_tokens, cp_size, num_heads, 2] (fp32) +// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc, +// which may have longer latency. 
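+// The merged result is sum_i(correction_i * o_i) over CP ranks, where each rank contributes its
+// partial output o_i together with softmax stats (max_i, sum_i), and
+//   correction_i = exp(max_i - global_max) * sum_i / sum_j(exp(max_j - global_max) * sum_j),
+// matching the baseline() reference in tests/unittest/_torch/thop/parallel/test_helix_postprocess.py.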
+template <typename T>
+__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v2(
+    T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
+{
+    // Each block processes one (token, head)
+    // gridDim.x: num_tokens, gridDim.y: num_heads
+    // there are two separate types of warps:
+    // warp 0 calculates the correction values (one per cp_size)
+    // all other warps pre-load the gathered_o elements for the current token/head
+    // and once warp 0 is done, all other warps can start accumulating the output
+    static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);
+
+    int tok_idx = blockIdx.x;
+    int head_idx = blockIdx.y;
+    int num_tokens = gridDim.x;
+    int num_heads = gridDim.y;
+
+    int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
+    __shared__ float smem_correction[MAX_CP];
+
+    int lane_idx = threadIdx.x % WARP_SIZE;
+    int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
+
+    // For cp_dim=1 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]
+    // Pointer offsets differ from cp_dim=2
+    T const* gathered_o_off;
+    gathered_o_off = gathered_o + tok_idx * cp_size * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
+    // we subtract WARP_SIZE because first warp is not participating in pre-load
+    gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
+    float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
+    // For cp_dim=1: stride between cp entries is num_heads * kv_lora_rank
+    int gathered_16b_stride = (num_heads * kv_lora_rank) / NUM_O_PER_THREAD;
+    // For cp_dim=1: stats layout is [num_tokens, cp_size, num_heads, 2]
+    int stats_offset = tok_idx * cp_size * num_heads + head_idx;
+    int stats_stride = num_heads;
+
+    // here we have to wait for memory operations of the previous kernel to
+    // complete
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaGridDependencySynchronize();
+#endif
+
+    float max_values[MAX_CP_VAL_PER_THREAD];
+    float sum_values[MAX_CP_VAL_PER_THREAD];
+    T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
+    float final_sum[NUM_O_PER_THREAD];
+    float corr_vals[NUM_PRE_LOAD];
+    T output_typed[NUM_O_PER_THREAD];
+
+    if (warp_idx == 0)
+    {
+        // the warp collectively calculates the correction values
+#pragma unroll
+        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
+        {
+            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
+            auto stats_idx = stats_offset + cp_idx * stats_stride;
+            float2 stats = cp_idx < cp_size ? gathered_stats[stats_idx] : make_float2(-INFINITY, 0.F);
+            max_values[cp_val_idx] = stats.x;
+            sum_values[cp_val_idx] = stats.y;
+        }
+        float corrected_values[MAX_CP_VAL_PER_THREAD];
+        warpReduceCorrectedSum(corrected_values, max_values, sum_values);
+#pragma unroll
+        for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
+        {
+            auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
+            smem_correction[cp_idx] = corrected_values[cp_val_idx];
+        }
+    }
+    else
+    {
+        // all other warps pre-load the gathered_o elements
+#pragma unroll
+        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
+        {
+            auto val = gathered_o_16b[cp_idx * gathered_16b_stride];
+            *reinterpret_cast<float4*>(vals[cp_idx]) = val;
+        }
+#pragma unroll
+        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
+        {
+            final_sum[o_idx] = 0.F;
+        }
+    }
+    __syncthreads();
+
+    // warp 0 exits early
+    if (warp_idx == 0)
+        return;
+
+    // here we can trigger the dependent kernels to start
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
+
+#pragma unroll
+    for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
+    {
+        corr_vals[cp_idx] = smem_correction[cp_idx];
+    }
+
+    for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
+    {
+#pragma unroll
+        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
+        {
+#pragma unroll
+            for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
+            {
+                final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
+            }
+        }
+#pragma unroll
+        for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
+        {
+            *reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
+                ? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
+                : make_float4(0.F, 0.F, 0.F, 0.F);
+            corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
+        }
+    }
+#pragma unroll
+    for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
+    {
+#pragma unroll
+        for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
+        {
+            final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
+        }
+    }
+#pragma unroll
+    for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
+    {
+        output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
+    }
+    auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
+    output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
+    *reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
+}
+
 } // anonymous namespace
 
 template <typename T>
@@ -394,21 +541,21 @@ INSTANTIATE_POST_PROC(__half);
 INSTANTIATE_POST_PROC(__nv_bfloat16);
 
 template <typename T>
-void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream)
+void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream)
 {
-    // Check that gathered_o is 16-byte aligned
+    // Check that gathered_o is 16-byte aligned.
     TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
         "gathered_o must be 16-byte aligned for async memcpy");
-    // TODO: Figure why this constraint is specific to this implementation and not legacy one.
+    // TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
     TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
         "kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
-    // Check that kv_lora_rank * sizeof(T) is a multiple of 16
+    // Check that kv_lora_rank * sizeof(T) is a multiple of 16.
     TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
         "kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
-    // Check that cp_size is not larger than the max fallback CP size
+    // Check that cp_size is not larger than the max fallback CP size.
     TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");
 
-    auto kernel_instance = helix_postprocess_kernel_native<T>;
+    auto kernel_instance = helix_postprocess_kernel_native_v1<T>;
     cudaLaunchConfig_t config;
     config.gridDim = dim3(params.num_tokens, params.num_heads);
     config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
@@ -423,11 +570,47 @@ void helixPostProcessNative(HelixPostProcParams<T> const& s
         params.gathered_stats, params.cp_size, params.kv_lora_rank));
 }
 
-#define INSTANTIATE_POST_PROC_NATIVE(T) \
-    template void helixPostProcessNative<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
+#define INSTANTIATE_POST_PROC_NATIVE_V1(T) \
+    template void helixPostProcessNativeV1<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
 
-INSTANTIATE_POST_PROC_NATIVE(__half);
-INSTANTIATE_POST_PROC_NATIVE(__nv_bfloat16);
+INSTANTIATE_POST_PROC_NATIVE_V1(__half);
+INSTANTIATE_POST_PROC_NATIVE_V1(__nv_bfloat16);
+
+template <typename T>
+void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream)
+{
+    // Check that gathered_o is 16-byte aligned.
+    TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
+        "gathered_o must be 16-byte aligned for async memcpy");
+    // TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
+    TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
+        "kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
+    // Check that kv_lora_rank * sizeof(T) is a multiple of 16.
+    TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
+        "kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
+    // Check that cp_size is not larger than the max fallback CP size.
+    TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");
+
+    auto kernel_instance = helix_postprocess_kernel_native_v2<T>;
+    cudaLaunchConfig_t config;
+    config.gridDim = dim3(params.num_tokens, params.num_heads);
+    config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
+    config.dynamicSmemBytes = 0;
+    config.stream = stream;
+    cudaLaunchAttribute attrs[1];
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+    attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
+    config.numAttrs = 1;
+    config.attrs = attrs;
+    TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
+        params.gathered_stats, params.cp_size, params.kv_lora_rank));
+}
+
+#define INSTANTIATE_POST_PROC_NATIVE_V2(T) \
+    template void helixPostProcessNativeV2<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
+
+INSTANTIATE_POST_PROC_NATIVE_V2(__half);
+INSTANTIATE_POST_PROC_NATIVE_V2(__nv_bfloat16);
 
 } // namespace kernels
diff --git a/cpp/tensorrt_llm/kernels/helixKernels.h b/cpp/tensorrt_llm/kernels/helixKernels.h
index 12036438b7..8b2f5f6bde 100644
--- a/cpp/tensorrt_llm/kernels/helixKernels.h
+++ b/cpp/tensorrt_llm/kernels/helixKernels.h
@@ -43,8 +43,17 @@ struct HelixPostProcParams
 
 template <typename T>
 void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream);
 
+// Version 1: cp_dim=2 layout.
+// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank].
+// gathered_stats: [num_tokens, num_heads, cp_size, 2].
 template <typename T>
-void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream);
+void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream);
+
+// Version 2: cp_dim=1 layout.
+// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank].
+// gathered_stats: [num_tokens, cp_size, num_heads, 2].
+template <typename T>
+void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream);
 
 } // namespace kernels
diff --git a/cpp/tensorrt_llm/thop/helixPostProcessOp.cpp b/cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
index b0d25e38c9..77b97904b3 100644
--- a/cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
+++ b/cpp/tensorrt_llm/thop/helixPostProcessOp.cpp
@@ -100,20 +100,14 @@ torch::Tensor helix_post_process(torch::Tensor const& gathered_o, torch::Tensor
 }
 
 template <typename T, typename Fn>
-inline torch::Tensor helix_post_process_native_impl(
-    torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int cp_dim, Fn fn)
+inline torch::Tensor helix_post_process_native_impl_v1(
+    torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
 {
     CHECK_TH_CUDA(gathered_o);
     CHECK_CONTIGUOUS(gathered_o);
     CHECK_TH_CUDA(gathered_stats);
     CHECK_CONTIGUOUS(gathered_stats);
 
-    // Only cp_dim=2 is supported
-    TORCH_CHECK(cp_dim == 2,
-        "cp_dim must be 2. Expects tensor shapes to be: \n"
-        "gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank], \n"
-        "gathered_stats: [num_tokens, num_heads, cp_size, 2]");
-
     // For cp_dim=2: tokens_dim=0, heads_dim=1
     auto tokens_dim = 0;
     auto heads_dim = 1;
@@ -126,14 +120,14 @@ inline torch::Tensor helix_post_process_native_impl(
     auto const cp_size = gathered_stats.sizes()[2];
     auto const kv_lora_rank = gathered_o.sizes()[3];
 
-    // check remaining input tensor dimensions
+    // Check remaining input tensor dimensions.
     TORCH_CHECK(gathered_o.sizes()[2] == cp_size, "gathered_o cp_size dim must match");
     TORCH_CHECK(gathered_o.sizes()[tokens_dim] == num_tokens, "gathered_o tokens_dim must match num_tokens");
     TORCH_CHECK(gathered_o.sizes()[heads_dim] == num_heads, "gathered_o heads_dim must match num_heads");
     TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");
 
-    // Check data types
+    // Check data types.
     TORCH_CHECK(
         gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
         "gathered_o must be half or bfloat16");
@@ -143,15 +137,75 @@ inline torch::Tensor helix_post_process_native_impl(
     // memcpy)
     TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");
 
-    // Check that kv_lora_rank * sizeof(data_type) is a multiple of 16
+    // Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
     size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
     TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");
 
-    // Create output tensor
+    // Create output tensor.
     std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
     torch::Tensor output = torch::empty(output_shape, gathered_o.options());
 
-    // Get CUDA stream
+    // Get CUDA stream.
+    auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());
+
+    tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
+        reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
+        static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),
+        static_cast<int>(kv_lora_rank)};
+    fn(params, stream);
+
+    if (scale != 1.0)
+    {
+        output *= scale;
+    }
+
+    return output;
+}
+
+template <typename T, typename Fn>
+inline torch::Tensor helix_post_process_native_impl_v2(
+    torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
+{
+    CHECK_TH_CUDA(gathered_o);
+    CHECK_CONTIGUOUS(gathered_o);
+    CHECK_TH_CUDA(gathered_stats);
+    CHECK_CONTIGUOUS(gathered_stats);
+
+    // For cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank]
+    //               gathered_stats: [num_tokens, cp_size, num_heads, 2]
+    TORCH_CHECK(gathered_o.dim() == 4, "gathered_o must be 4D tensor [num_tokens, cp_size, num_heads, kv_lora_rank]");
+    TORCH_CHECK(gathered_stats.dim() == 4, "gathered_stats must be 4D tensor [num_tokens, cp_size, num_heads, 2]");
+
+    auto const num_tokens = gathered_o.sizes()[0];
+    auto const cp_size = gathered_o.sizes()[1];
+    auto const num_heads = gathered_o.sizes()[2];
+    auto const kv_lora_rank = gathered_o.sizes()[3];
+
+    // Check remaining input tensor dimensions.
+    TORCH_CHECK(gathered_stats.sizes()[0] == num_tokens, "gathered_stats num_tokens dim must match");
+    TORCH_CHECK(gathered_stats.sizes()[1] == cp_size, "gathered_stats cp_size dim must match");
+    TORCH_CHECK(gathered_stats.sizes()[2] == num_heads, "gathered_stats num_heads dim must match");
+    TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");
+
+    // Check data types.
+    TORCH_CHECK(
+        gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
+        "gathered_o must be half or bfloat16");
+    TORCH_CHECK(gathered_stats.scalar_type() == at::ScalarType::Float, "gathered_stats must be float32");
+
+    // Check alignment requirements for gathered_o (16-byte aligned for async
+    // memcpy).
+    TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");
+
+    // Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
+    size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
+    TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");
+
+    // Create output tensor.
+    std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
+    torch::Tensor output = torch::empty(output_shape, gathered_o.options());
+
+    // Get CUDA stream.
     auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());
 
     tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
         reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
@@ -171,24 +225,54 @@ inline torch::Tensor helix_post_process_native_impl(
 inline torch::Tensor helix_post_process_native(
     torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int64_t cp_dim)
 {
-    TORCH_CHECK(cp_dim == 2, "cp_dim must be 2. Only cp_dim=2 layout is supported.");
-    if (gathered_o.scalar_type() == at::ScalarType::Half)
-    {
-        return helix_post_process_native_impl<__half>(
-            gathered_o, gathered_stats, scale, int(cp_dim), tensorrt_llm::kernels::helixPostProcessNative<__half>);
-    }
-    else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
+    TORCH_CHECK(cp_dim == 1 || cp_dim == 2,
+        "cp_dim must be 1 or 2. 
\n" + "cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] \n" + "cp_dim=2: gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank]"); + + if (cp_dim == 1) { + // Version 2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]. + if (gathered_o.scalar_type() == at::ScalarType::Half) + { + return helix_post_process_native_impl_v2<__half>( + gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__half>); + } + else if (gathered_o.scalar_type() == at::ScalarType::BFloat16) + { #ifdef ENABLE_BF16 - return helix_post_process_native_impl<__nv_bfloat16>(gathered_o, gathered_stats, scale, int(cp_dim), - tensorrt_llm::kernels::helixPostProcessNative<__nv_bfloat16>); + return helix_post_process_native_impl_v2<__nv_bfloat16>( + gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__nv_bfloat16>); #else - TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors."); + TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors."); #endif + } + else + { + TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors."); + } } else { - TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors."); + // Version 1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank]. + if (gathered_o.scalar_type() == at::ScalarType::Half) + { + return helix_post_process_native_impl_v1<__half>( + gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__half>); + } + else if (gathered_o.scalar_type() == at::ScalarType::BFloat16) + { +#ifdef ENABLE_BF16 + return helix_post_process_native_impl_v1<__nv_bfloat16>( + gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__nv_bfloat16>); +#else + TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors."); +#endif + } + else + { + TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors."); + } } } diff --git a/tensorrt_llm/_torch/modules/attention.py b/tensorrt_llm/_torch/modules/attention.py index f944f50126..59f7f1493c 100644 --- a/tensorrt_llm/_torch/modules/attention.py +++ b/tensorrt_llm/_torch/modules/attention.py @@ -1195,29 +1195,63 @@ class MLA(nn.Module): num_tokens = partial_o.shape[0] cp_size = self.mapping.cp_size - # Reshape for FIFO-based all-to-all. - # partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank] - # softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2] + # Check which FIFO version to use (default: version 2 for better performance). + fifo_version = self.mapping.cp_config.get("fifo_version", 2) - partial_o = partial_o.view( - num_tokens, cp_size, self.num_heads_tp_cp, - kv_lora_rank).transpose(1, 2).contiguous() - softmax_stats = softmax_stats.view(num_tokens, cp_size, - self.num_heads_tp_cp, - 2).transpose(1, - 2).contiguous() + if fifo_version == 1: + # Version 1: Uses transpose+contiguous before alltoall, cp_dim=2. + # Reshape for FIFO-based all-to-all. Overlap the two .contiguous() calls on separate streams. + # partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank] + # softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2] + partial_o, softmax_stats = maybe_execute_in_parallel( + lambda: partial_o.view(num_tokens, cp_size, self. + num_heads_tp_cp, kv_lora_rank). 
+ transpose(1, 2).contiguous(), + lambda: softmax_stats.view(num_tokens, cp_size, self. + num_heads_tp_cp, 2). + transpose(1, 2).contiguous(), + self.ln_events[0], + self.ln_events[1], + self.aux_stream, + ) - # Call FIFO-based helixAllToAll. - partial_o_out, softmax_stats_out = helix.alltoall_native( - partial_o, softmax_stats) + # Call FIFO-based helixAllToAll. + partial_o_out, softmax_stats_out = helix.alltoall_native( + partial_o, softmax_stats) - # partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank] - # softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2] - # cp_dim = 2 (the dimension where cp_size is located) + # partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank] + # softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2] + # cp_dim = 2 (the dimension where cp_size is located) - # Call helix_post_process_native with cp_dim=2. - return torch.ops.trtllm.helix_post_process_native( - partial_o_out, softmax_stats_out, 1.0, 2) + # Call helix_post_process_native V1 (cp_dim=2). + return torch.ops.trtllm.helix_post_process_native( + partial_o_out, softmax_stats_out, 1.0, 2) + else: + # Version 2: Uses simple view (no transpose+contiguous) for better performance. + # partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank] + # softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp * 2] + partial_o = partial_o.view( + num_tokens, cp_size, + self.num_heads_tp_cp * kv_lora_rank) + softmax_stats = softmax_stats.view(num_tokens, cp_size, + self.num_heads_tp_cp * 2) + + # Call FIFO-based helixAllToAll. + partial_o_out, softmax_stats_out = helix.alltoall_native( + partial_o, softmax_stats) + + # Reshape after alltoall for post-processing. + # partial_o_out: [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank] + # softmax_stats_out: [num_tokens, cp_size, num_heads_tp_cp * 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2] + gathered_o = partial_o_out.view(num_tokens, cp_size, + self.num_heads_tp_cp, + kv_lora_rank) + gathered_stats = softmax_stats_out.view( + num_tokens, cp_size, self.num_heads_tp_cp, 2) + + # Call helix_post_process_native V2 (cp_dim=1). + return torch.ops.trtllm.helix_post_process_native( + gathered_o, gathered_stats, 1.0, 1) else: attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs) return attn_output diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index c5e5d584d7..9c31d8bddc 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -982,10 +982,21 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): "cudagraph:none", "cudagraph:without_padding", "cudagraph:with_padding" ]) - @pytest.mark.parametrize("comms_medium", ["fifo", "nccl"]) + @pytest.mark.parametrize("comms_medium", ["fifo_v1", "fifo_v2", "nccl"]) def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config, gen_pp, gen_tp, gen_cp, enable_attention_dp): - use_nccl_for_alltoall = comms_medium == "nccl" + # Parse comms_medium to get use_nccl_for_alltoall and fifo_version. + if comms_medium == "nccl": + use_nccl_for_alltoall = True + fifo_version = 2 # Not used when NCCL is enabled. 
+ elif comms_medium == "fifo_v1": + use_nccl_for_alltoall = False + fifo_version = 1 + elif comms_medium == "fifo_v2": + use_nccl_for_alltoall = False + fifo_version = 2 + else: + raise ValueError(f"Unknown comms_medium: {comms_medium}") gen_ep = gen_tp * gen_cp kv_cache_config = { "free_gpu_memory_fraction": 0.5, @@ -1014,7 +1025,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness): "cp_config": { "cp_type": "HELIX", "tokens_per_block": 32, - "use_nccl_for_alltoall": use_nccl_for_alltoall + "use_nccl_for_alltoall": use_nccl_for_alltoall, + "fifo_version": fifo_version, }, "disable_overlap_scheduler": True, "kv_cache_config": kv_cache_config, diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index de47aab6ec..99368ccc95 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -293,13 +293,17 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp1cp4] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2] -accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1dp2cp2] +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] diff 
--git a/tests/integration/test_lists/test-db/l0_dgx_b200.yml b/tests/integration/test_lists/test-db/l0_dgx_b200.yml index 1592d1247f..fc4bbd824a 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b200.yml @@ -75,10 +75,10 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60) @@ -108,8 +108,8 @@ l0_dgx_b200: backend: pytorch orchestrator: mpi tests: - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) - - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60) + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60) - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60) diff --git a/tests/unittest/_torch/modules/test_mla_helix.py b/tests/unittest/_torch/modules/test_mla_helix.py index a6a0d5202e..5a386b6f7f 100644 --- a/tests/unittest/_torch/modules/test_mla_helix.py +++ b/tests/unittest/_torch/modules/test_mla_helix.py @@ -649,7 +649,8 @@ def _full_test_multi_gpu( world_size: int, scenario: Scenario, gen_steps: int, - comms_medium: str = False, + use_nccl_for_alltoall: bool = False, + fifo_version: int = 2, ): if scenario.rope_scaling: rope_scaling = { @@ -825,7 +826,8 @@ def _full_test_multi_gpu( cp_size=world_size, cp_config={ "cp_type": CpType.HELIX, - "use_nccl_for_alltoall": comms_medium == "nccl", + "use_nccl_for_alltoall": use_nccl_for_alltoall, + 
"fifo_version": fifo_version, }, ) # we use cp_allgather here because there is no broadcast op across CP group @@ -861,7 +863,7 @@ def _run_single_rank(func, *args, **kwargs): @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test") @pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}") -@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"]) +@pytest.mark.parametrize("comms_medium", ["nccl", "fifo_v1", "fifo_v2"]) def test_mla_helix_distributed( scenario: Scenario, comms_medium: str, @@ -872,11 +874,34 @@ def test_mla_helix_distributed( world_size = 2 print(f"Testing with comms_medium={comms_medium}.") gen_steps = scenario.ref_steps if gen_steps is None else gen_steps + + # Parse comms_medium to get use_nccl_for_alltoall and fifo_version. + if comms_medium == "nccl": + use_nccl_for_alltoall = True + fifo_version = 2 # Not used when NCCL is enabled. + elif comms_medium == "fifo_v1": + use_nccl_for_alltoall = False + fifo_version = 1 + elif comms_medium == "fifo_v2": + use_nccl_for_alltoall = False + fifo_version = 2 + else: + raise ValueError(f"Unknown comms_medium: {comms_medium}") + with MPIPoolExecutor(max_workers=world_size) as executor: results = executor.map( _run_single_rank, *zip( - *[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium == "nccl")] + *[ + ( + _full_test_multi_gpu, + world_size, + scenario, + gen_steps, + use_nccl_for_alltoall, + fifo_version, + ) + ] * world_size ), ) @@ -888,7 +913,7 @@ def test_mla_helix_distributed( if __name__ == "__main__": - for comms_medium in ["fifo", "nccl"]: + for comms_medium in ["fifo_v1", "fifo_v2", "nccl"]: print(f"\n{'=' * 60}") print(f"Testing with comms_medium={comms_medium}") print(f"{'=' * 60}\n") diff --git a/tests/unittest/_torch/thop/parallel/test_helix_postprocess.py b/tests/unittest/_torch/thop/parallel/test_helix_postprocess.py index 879ddb2b5b..9723bc547b 100644 --- a/tests/unittest/_torch/thop/parallel/test_helix_postprocess.py +++ b/tests/unittest/_torch/thop/parallel/test_helix_postprocess.py @@ -22,22 +22,40 @@ from parameterized import parameterized import tensorrt_llm -def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=False): +def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native_v1=False, native_v2=False): """Reference implementation (libtorch) Args: gathered_o: Input tensor - - native=False: [cp_size, num_tokens, num_heads * kv_lora_rank] - - native=True: [num_tokens, num_heads, cp_size, kv_lora_rank] + - native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads * kv_lora_rank] + - native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, kv_lora_rank] + - native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, kv_lora_rank] gathered_stats: Stats tensor - - native=False: [cp_size, num_tokens, num_heads, 2] - - native=True: [num_tokens, num_heads, cp_size, 2] + - native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads, 2] + - native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, 2] + - native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, 2] kv_lora_rank: KV LoRA rank scale: Scale factor - native: Whether to use native layout (cp_dim=2) + native_v1: Whether to use native V1 layout (cp_dim=2). + native_v2: Whether to use native V2 layout (cp_dim=1). """ - if native: - # Native layout: cp_dim=2 + if native_v2: + # Native V2 layout: cp_dim=1. + # [num_tokens, cp_size, num_heads, 2] -> reduce over cp_size (dim=1). 
+ global_max = gathered_stats[..., 0].max(dim=1, keepdim=True)[0] + corrected_max = gathered_stats[..., 0] - global_max + corrected_max_exp = torch.exp(corrected_max) + corrected_sum = gathered_stats[..., 1] * corrected_max_exp + global_sum = corrected_sum.sum(dim=1, keepdim=True) + correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1) + gathered_o_fp32 = gathered_o.to(torch.float32) + corrected_o = gathered_o_fp32 * correction + # Sum over cp_size dimension (dim=1), result: [num_tokens, num_heads, kv_lora_rank] + corrected_o = corrected_o.sum(dim=1) + # Reshape to [num_tokens, num_heads * kv_lora_rank] + corrected_o = corrected_o.view(corrected_o.shape[0], -1) + elif native_v1: + # Native V1 layout: cp_dim=2. # [num_tokens, num_heads, cp_size] global_max = gathered_stats[..., 0].max(dim=-1, keepdim=True)[0] corrected_max = gathered_stats[..., 0] - global_max @@ -75,27 +93,54 @@ class TestHelixPostProcess(unittest.TestCase): torch.cuda.manual_seed(42) def _test_helix_postprocess( - self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native=False + self, + cp_size, + num_tokens, + num_heads, + kv_lora_rank, + scale, + dtype, + native_v1=False, + native_v2=False, ): - """Test helix postprocessing with given parameters + """Test helix postprocessing with given parameters. Args: - cp_size: Context parallelism size - num_tokens: Number of tokens - num_heads: Number of attention heads - kv_lora_rank: KV LoRA rank - scale: Scale factor - dtype: Data type (float16 or bfloat16) - native: Whether to use native layout (cp_dim=2) + cp_size: Context parallelism size. + num_tokens: Number of tokens. + num_heads: Number of attention heads. + kv_lora_rank: KV LoRA rank. + scale: Scale factor. + dtype: Data type (float16 or bfloat16). + native_v1: Whether to use native V1 layout (cp_dim=2). + native_v2: Whether to use native V2 layout (cp_dim=1). """ device = torch.device("cuda") - if native: - # Native layout: [num_tokens, num_heads, cp_size, kv_lora_rank] + if native_v2: + # Native V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]. + gathered_o = torch.empty( + num_tokens, cp_size, num_heads, kv_lora_rank, dtype=dtype, device=device + ).uniform_(-1, 1) + # gathered_stats: [num_tokens, cp_size, num_heads, 2]. + gathered_stats = torch.empty( + num_tokens, cp_size, num_heads, 2, dtype=torch.float32, device=device + ) + gathered_o_max = torch.max(gathered_o, dim=-1, keepdim=True)[0] + gathered_stats[..., 0] = gathered_o_max[..., 0] + gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1) + gathered_stats[..., 1] = gathered_o_sum + + # Call the custom operator with cp_dim=1 (V2). + output = torch.ops.trtllm.helix_post_process_native( + gathered_o, gathered_stats, scale, 1 + ) + elif native_v1: + # Native V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank]. gathered_o = torch.empty( num_tokens, num_heads, cp_size, kv_lora_rank, dtype=dtype, device=device ).uniform_(-1, 1) - # gathered_stats: [num_tokens, num_heads, cp_size, 2] + # gathered_stats: [num_tokens, num_heads, cp_size, 2]. gathered_stats = torch.empty( num_tokens, num_heads, cp_size, 2, dtype=torch.float32, device=device ) @@ -104,16 +149,16 @@ class TestHelixPostProcess(unittest.TestCase): gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1) gathered_stats[..., 1] = gathered_o_sum - # Call the custom operator with cp_dim=2 + # Call the custom operator with cp_dim=2 (V1). 
output = torch.ops.trtllm.helix_post_process_native( gathered_o, gathered_stats, scale, 2 ) else: - # Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank] + # Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank]. gathered_o_init = torch.empty( cp_size, num_tokens, num_heads, kv_lora_rank, dtype=dtype, device=device ).uniform_(-1, 1) - # gathered_stats: [cp_size, num_tokens, num_heads, 2] + # gathered_stats: [cp_size, num_tokens, num_heads, 2]. gathered_stats = torch.empty( cp_size, num_tokens, num_heads, 2, dtype=torch.float32, device=device ) @@ -124,80 +169,124 @@ class TestHelixPostProcess(unittest.TestCase): gathered_o = gathered_o_init.view(cp_size, num_tokens, num_heads * kv_lora_rank) - # Call the custom operator + # Call the custom operator. output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, scale) - # Compute baseline - expected_output = baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=native) + # Compute baseline. + expected_output = baseline( + gathered_o, + gathered_stats, + kv_lora_rank, + scale, + native_v1=native_v1, + native_v2=native_v2, + ) - # Compare results + # Compare results. torch.testing.assert_close(output, expected_output, atol=1e-3, rtol=1e-2) @parameterized.expand( [ - # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native) - (4, 8, 2, 64, 1.0, torch.float16, False), - (8, 16, 4, 128, 0.5, torch.float16, False), - (16, 32, 8, 256, 2.0, torch.float16, False), - (4, 8, 2, 64, 1.0, torch.bfloat16, False), - (8, 16, 4, 128, 0.5, torch.bfloat16, False), - (16, 32, 8, 256, 2.0, torch.bfloat16, False), - (4, 8, 2, 64, 1.0, torch.float16, True), - (8, 16, 4, 128, 0.5, torch.float16, True), - (16, 32, 8, 256, 2.0, torch.float16, True), - (4, 8, 2, 64, 1.0, torch.bfloat16, True), - (8, 16, 4, 128, 0.5, torch.bfloat16, True), - (16, 32, 8, 256, 2.0, torch.bfloat16, True), + # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2) + # Original layout. + (4, 8, 2, 64, 1.0, torch.float16, False, False), + (8, 16, 4, 128, 0.5, torch.float16, False, False), + (16, 32, 8, 256, 2.0, torch.float16, False, False), + (4, 8, 2, 64, 1.0, torch.bfloat16, False, False), + (8, 16, 4, 128, 0.5, torch.bfloat16, False, False), + (16, 32, 8, 256, 2.0, torch.bfloat16, False, False), + # Native V1 layout (cp_dim=2). + (4, 8, 2, 64, 1.0, torch.float16, True, False), + (8, 16, 4, 128, 0.5, torch.float16, True, False), + (16, 32, 8, 256, 2.0, torch.float16, True, False), + (4, 8, 2, 64, 1.0, torch.bfloat16, True, False), + (8, 16, 4, 128, 0.5, torch.bfloat16, True, False), + (16, 32, 8, 256, 2.0, torch.bfloat16, True, False), + # Native V2 layout (cp_dim=1). 
+ (4, 8, 2, 64, 1.0, torch.float16, False, True), + (8, 16, 4, 128, 0.5, torch.float16, False, True), + (16, 32, 8, 256, 2.0, torch.float16, False, True), + (4, 8, 2, 64, 1.0, torch.bfloat16, False, True), + (8, 16, 4, 128, 0.5, torch.bfloat16, False, True), + (16, 32, 8, 256, 2.0, torch.bfloat16, False, True), ] ) def test_helix_postprocess_basic( - self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2 ): - """Test basic helix postprocessing functionality""" + """Test basic helix postprocessing functionality.""" self._test_helix_postprocess( - cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + cp_size, + num_tokens, + num_heads, + kv_lora_rank, + scale, + dtype, + native_v1=native_v1, + native_v2=native_v2, ) @parameterized.expand( [ - # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native) - # Edge cases for non-native layout - (1, 1, 1, 16, 1.0, torch.float16, False), # Minimal sizes - (256, 1, 1, 16, 1.0, torch.float16, False), # Max cp_size - (128, 1, 1, 16, 1.0, torch.float16, False), # Single token - (4, 8, 1, 16, 1.0, torch.float16, False), # Single head - (4, 8, 2, 2048, 1.0, torch.float16, False), # Large kv_lora_rank - # Edge cases for native layout - (1, 1, 1, 16, 1.0, torch.float16, True), # Minimal sizes - (256, 1, 1, 16, 1.0, torch.float16, True), # Max cp_size - (128, 1, 1, 16, 1.0, torch.float16, True), # Single token - (4, 8, 1, 16, 1.0, torch.float16, True), # Single head - # Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel + # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2) + # Edge cases for original layout. + (1, 1, 1, 16, 1.0, torch.float16, False, False), # Minimal sizes. + (256, 1, 1, 16, 1.0, torch.float16, False, False), # Max cp_size. + (128, 1, 1, 16, 1.0, torch.float16, False, False), # Single token. + (4, 8, 1, 16, 1.0, torch.float16, False, False), # Single head. + (4, 8, 2, 2048, 1.0, torch.float16, False, False), # Large kv_lora_rank. + # Edge cases for native V1 layout. + (1, 1, 1, 16, 1.0, torch.float16, True, False), # Minimal sizes. + (256, 1, 1, 16, 1.0, torch.float16, True, False), # Max cp_size. + (128, 1, 1, 16, 1.0, torch.float16, True, False), # Single token. + (4, 8, 1, 16, 1.0, torch.float16, True, False), # Single head. + # Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel. + # Edge cases for native V2 layout. + (1, 1, 1, 16, 1.0, torch.float16, False, True), # Minimal sizes. + (256, 1, 1, 16, 1.0, torch.float16, False, True), # Max cp_size. + (128, 1, 1, 16, 1.0, torch.float16, False, True), # Single token. + (4, 8, 1, 16, 1.0, torch.float16, False, True), # Single head. 
] ) def test_helix_postprocess_edge_cases( - self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2 ): - """Test edge cases with minimal dimensions""" + """Test edge cases with minimal dimensions.""" self._test_helix_postprocess( - cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + cp_size, + num_tokens, + num_heads, + kv_lora_rank, + scale, + dtype, + native_v1=native_v1, + native_v2=native_v2, ) @parameterized.expand( [ - # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native) - (16, 16, 64, 512, 1.0, torch.float16, False), - (16, 16, 64, 512, 1.0, torch.bfloat16, False), - (16, 16, 64, 512, 1.0, torch.float16, True), - (16, 16, 64, 512, 1.0, torch.bfloat16, True), + # (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2) + (16, 16, 64, 512, 1.0, torch.float16, False, False), + (16, 16, 64, 512, 1.0, torch.bfloat16, False, False), + (16, 16, 64, 512, 1.0, torch.float16, True, False), + (16, 16, 64, 512, 1.0, torch.bfloat16, True, False), + (16, 16, 64, 512, 1.0, torch.float16, False, True), + (16, 16, 64, 512, 1.0, torch.bfloat16, False, True), ] ) def test_helix_postprocess_large_inputs( - self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2 ): - """Test with larger inputs to ensure performance and correctness""" + """Test with larger inputs to ensure performance and correctness.""" self._test_helix_postprocess( - cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native + cp_size, + num_tokens, + num_heads, + kv_lora_rank, + scale, + dtype, + native_v1=native_v1, + native_v2=native_v2, ) def test_helix_postprocess_invalid_inputs(self): @@ -232,14 +321,14 @@ class TestHelixPostProcess(unittest.TestCase): """Test error handling for invalid inputs (native layout)""" device = torch.device("cuda") - # Test with wrong cp_dim (only cp_dim=2 is supported) + # Test with wrong cp_dim (only cp_dim=1 and cp_dim=2 are supported). gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device) gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device) with pytest.raises(RuntimeError): torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 0) with pytest.raises(RuntimeError): - torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1) + torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 3) # Test with wrong tensor dimensions (3D instead of 4D) gathered_o = torch.randn(8, 2, 256, dtype=torch.float16, device=device) @@ -264,19 +353,20 @@ class TestHelixPostProcess(unittest.TestCase): @parameterized.expand( [ - # (native,) - (False,), - (True,), + # (layout,) — "nccl", "fifo_v1", "fifo_v2". + ("nccl",), + ("fifo_v1",), + ("fifo_v2",), ] ) - def test_helix_postprocess_alignment_requirements(self, native): - """Test alignment requirements""" + def test_helix_postprocess_alignment_requirements(self, layout): + """Test alignment requirements for all layouts.""" device = torch.device("cuda") - # For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment - - if native: - # This should work (kv_lora_rank = 64 is multiple of 8) + # For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment. + if layout == "fifo_v1": + # V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank]. 
+ # This should work (kv_lora_rank = 64 is multiple of 8). gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device) gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device) @@ -285,13 +375,30 @@ class TestHelixPostProcess(unittest.TestCase): except RuntimeError as e: pytest.fail(f"Should not raise error for valid alignment: {e}") - # Test with kv_lora_rank that doesn't satisfy alignment requirements + # Test with kv_lora_rank that doesn't satisfy alignment requirements. gathered_o = torch.randn(8, 1, 4, 4, dtype=torch.float16, device=device) gathered_stats = torch.randn(8, 1, 4, 2, dtype=torch.float32, device=device) with pytest.raises(RuntimeError): torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2) + elif layout == "fifo_v2": + # V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]. + # This should work (kv_lora_rank = 64 is multiple of 8). + gathered_o = torch.randn(8, 4, 2, 64, dtype=torch.float16, device=device) + gathered_stats = torch.randn(8, 4, 2, 2, dtype=torch.float32, device=device) + + try: + torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1) + except RuntimeError as e: + pytest.fail(f"Should not raise error for valid alignment: {e}") + + # Test with kv_lora_rank that doesn't satisfy alignment requirements. + gathered_o = torch.randn(8, 4, 1, 4, dtype=torch.float16, device=device) + gathered_stats = torch.randn(8, 4, 1, 2, dtype=torch.float32, device=device) + with pytest.raises(RuntimeError): + torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1) else: - # This should work (kv_lora_rank = 64 is multiple of 8) + # NCCL layout. + # This should work (kv_lora_rank = 64 is multiple of 8). gathered_o = torch.randn(4, 8, 2 * 64, dtype=torch.float16, device=device) gathered_stats = torch.randn(4, 8, 2, 2, dtype=torch.float32, device=device) @@ -300,7 +407,7 @@ class TestHelixPostProcess(unittest.TestCase): except RuntimeError as e: pytest.fail(f"Should not raise error for valid alignment: {e}") - # Test with kv_lora_rank that doesn't satisfy alignment requirements + # Test with kv_lora_rank that doesn't satisfy alignment requirements. gathered_o = torch.randn(4, 8, 4, dtype=torch.float16, device=device) gathered_stats = torch.randn(4, 8, 1, 2, dtype=torch.float32, device=device) with pytest.raises(RuntimeError):
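Illustrative usage sketch (editor's addition, not part of the patch): the snippet below exercises the new cp_dim=1 ("V2") code path end to end against the same softmax-stats merge used by baseline() above. It assumes a CUDA device and a TensorRT-LLM build in which the trtllm custom ops are registered; tensor shapes and the op signature follow the layouts introduced in this patch.

import torch
import tensorrt_llm  # noqa: F401  # importing registers torch.ops.trtllm custom ops

cp_size, num_tokens, num_heads, kv_lora_rank = 4, 8, 2, 64
device, dtype = torch.device("cuda"), torch.float16

# V2 layout: gathered_o [num_tokens, cp_size, num_heads, kv_lora_rank],
#            gathered_stats [num_tokens, cp_size, num_heads, 2] holding (max, sum) per CP rank.
gathered_o = torch.empty(
    num_tokens, cp_size, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
gathered_stats = torch.empty(num_tokens, cp_size, num_heads, 2, dtype=torch.float32, device=device)
gathered_stats[..., 0] = gathered_o.float().amax(dim=-1)
gathered_stats[..., 1] = torch.exp(gathered_o.float() - gathered_stats[..., 0:1]).sum(dim=-1)

# Fused kernel path: scale=1.0, cp_dim=1 selects the V2 kernel.
out = torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)

# Reference merge over the cp_size dimension (dim=1), as in baseline(..., native_v2=True).
global_max = gathered_stats[..., 0].amax(dim=1, keepdim=True)
corr = gathered_stats[..., 1] * torch.exp(gathered_stats[..., 0] - global_max)
corr = (corr / corr.sum(dim=1, keepdim=True)).unsqueeze(-1)
ref = (gathered_o.float() * corr).sum(dim=1).view(num_tokens, -1).to(dtype)
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-2)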