[None][feat] Add new helix kernels for MNNVL-based codepath (#11433)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
Balaram Buddharaju 2026-02-13 17:39:24 -08:00 committed by GitHub
parent 4debf153d8
commit 2989bf5b39
9 changed files with 616 additions and 158 deletions

View File

@ -80,8 +80,8 @@ static constexpr int NUM_PRE_LOAD = 8;
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [cp_size, num_tokens, num_heads * kv_lora_rank] (half)
// gathered_stats: [cp_size, num_tokens, num_heads, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void helix_postprocess_kernel(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
@ -217,10 +217,10 @@ static constexpr int MAX_KV_LORA_BYTES = (MAX_THREADS - WARP_SIZE) * BYTES_O_PER
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank] (half)
// gathered_stats: [num_tokens, num_heads, cp_size, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native(
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v1(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
// Each block processes one (token, head)
@ -358,6 +358,153 @@ __global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native(
*reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
}
// Kernel: fused helix post-processing for cp_dim=1 layout (Version 2)
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] (half)
// gathered_stats: [num_tokens, cp_size, num_heads, 2] (fp32)
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v2(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
// Each block processes one (token, head)
// gridDim.x: num_tokens, gridDim.y: num_heads
// there are two separate types of warps:
// warp 0 calculates the correction values (one per CP rank)
// all other warps pre-load the gathered_o elements for the current token/head
// and once warp 0 is done, all other warps can start accumulating the output
static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);
int tok_idx = blockIdx.x;
int head_idx = blockIdx.y;
int num_tokens = gridDim.x;
int num_heads = gridDim.y;
int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
__shared__ float smem_correction[MAX_CP];
int lane_idx = threadIdx.x % WARP_SIZE;
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
// For cp_dim=1 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]
// Pointer offsets differ from the cp_dim=2 layout.
T const* gathered_o_off;
gathered_o_off = gathered_o + tok_idx * cp_size * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
// we subtract WARP_SIZE because the first warp does not participate in the pre-load
gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
// For cp_dim=1: stride between cp entries is num_heads * kv_lora_rank
int gathered_16b_stride = (num_heads * kv_lora_rank) / NUM_O_PER_THREAD;
// For cp_dim=1: stats layout is [num_tokens, cp_size, num_heads, 2]
int stats_offset = tok_idx * cp_size * num_heads + head_idx;
int stats_stride = num_heads;
// here we have to wait for memory operations of the previous kernel to
// complete
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
float max_values[MAX_CP_VAL_PER_THREAD];
float sum_values[MAX_CP_VAL_PER_THREAD];
T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
float final_sum[NUM_O_PER_THREAD];
float corr_vals[NUM_PRE_LOAD];
T output_typed[NUM_O_PER_THREAD];
if (warp_idx == 0)
{
// the warp collectively calculates the correction values
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
auto stats_idx = stats_offset + cp_idx * stats_stride;
float2 stats = cp_idx < cp_size ? gathered_stats[stats_idx] : make_float2(-INFINITY, 0.F);
max_values[cp_val_idx] = stats.x;
sum_values[cp_val_idx] = stats.y;
}
float corrected_values[MAX_CP_VAL_PER_THREAD];
warpReduceCorrectedSum(corrected_values, max_values, sum_values);
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
smem_correction[cp_idx] = corrected_values[cp_val_idx];
}
}
else
{
// all other warps pre-load the gathered_o elements
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
auto val = gathered_o_16b[cp_idx * gathered_16b_stride];
*reinterpret_cast<float4*>(vals[cp_idx]) = val;
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] = 0.F;
}
}
__syncthreads();
// warp 0 exits early
if (warp_idx == 0)
return;
// here we can trigger the dependent kernels to start
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
corr_vals[cp_idx] = smem_correction[cp_idx];
}
for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
{
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
*reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
: make_float4(0.F, 0.F, 0.F, 0.F);
corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
}
auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
*reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
}
} // anonymous namespace
template <typename T>
@ -394,21 +541,21 @@ INSTANTIATE_POST_PROC(__half);
INSTANTIATE_POST_PROC(__nv_bfloat16);
template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream)
void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
// Check that gathered_o is 16-byte aligned
// Check that gathered_o is 16-byte aligned.
TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
"gathered_o must be 16-byte aligned for async memcpy");
// TODO: Figure why this constraint is specific to this implementation and not legacy one.
// TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
"kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
// Check that kv_lora_rank * sizeof(T) is a multiple of 16
// Check that kv_lora_rank * sizeof(T) is a multiple of 16.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
"kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
// Check that cp_size is not larger than the max fallback CP size
// Check that cp_size is not larger than the max fallback CP size.
TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");
auto kernel_instance = helix_postprocess_kernel_native<T>;
auto kernel_instance = helix_postprocess_kernel_native_v1<T>;
cudaLaunchConfig_t config;
config.gridDim = dim3(params.num_tokens, params.num_heads);
config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
@ -423,11 +570,47 @@ void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t s
params.gathered_stats, params.cp_size, params.kv_lora_rank));
}
#define INSTANTIATE_POST_PROC_NATIVE(T) \
template void helixPostProcessNative<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
#define INSTANTIATE_POST_PROC_NATIVE_V1(T) \
template void helixPostProcessNativeV1<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
INSTANTIATE_POST_PROC_NATIVE(__half);
INSTANTIATE_POST_PROC_NATIVE(__nv_bfloat16);
INSTANTIATE_POST_PROC_NATIVE_V1(__half);
INSTANTIATE_POST_PROC_NATIVE_V1(__nv_bfloat16);
template <typename T>
void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
// Check that gathered_o is 16-byte aligned.
TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
"gathered_o must be 16-byte aligned for async memcpy");
// TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
"kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
// Check that kv_lora_rank * sizeof(T) is a multiple of 16.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
"kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
// Check that cp_size is not larger than the max fallback CP size.
TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");
auto kernel_instance = helix_postprocess_kernel_native_v2<T>;
cudaLaunchConfig_t config;
config.gridDim = dim3(params.num_tokens, params.num_heads);
config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
params.gathered_stats, params.cp_size, params.kv_lora_rank));
}
#define INSTANTIATE_POST_PROC_NATIVE_V2(T) \
template void helixPostProcessNativeV2<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
INSTANTIATE_POST_PROC_NATIVE_V2(__half);
INSTANTIATE_POST_PROC_NATIVE_V2(__nv_bfloat16);
} // namespace kernels
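For reference, the correction values that warp 0 writes to smem_correction implement the usual log-sum-exp merge of per-rank softmax statistics. A sketch of the math, consistent with the Python baseline in the unit test further down, where $(m_i, s_i)$ is the (max, sum) pair stored in gathered_stats for CP rank $i$ and $o_i$ is that rank's partial output:

$$M = \max_i m_i, \qquad c_i = \frac{s_i\, e^{m_i - M}}{\sum_j s_j\, e^{m_j - M}}, \qquad o = \sum_i c_i\, o_i.$$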

View File

@ -43,8 +43,17 @@ struct HelixPostProcParams
template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream);
// Version 1: cp_dim=2 layout.
// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank].
// gathered_stats: [num_tokens, num_heads, cp_size, 2].
template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream);
void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream);
// Version 2: cp_dim=1 layout.
// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank].
// gathered_stats: [num_tokens, cp_size, num_heads, 2].
template <typename T>
void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream);
} // namespace kernels

View File

@ -100,20 +100,14 @@ torch::Tensor helix_post_process(torch::Tensor const& gathered_o, torch::Tensor
}
template <typename T, typename Fn>
inline torch::Tensor helix_post_process_native_impl(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int cp_dim, Fn fn)
inline torch::Tensor helix_post_process_native_impl_v1(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
{
CHECK_TH_CUDA(gathered_o);
CHECK_CONTIGUOUS(gathered_o);
CHECK_TH_CUDA(gathered_stats);
CHECK_CONTIGUOUS(gathered_stats);
// Only cp_dim=2 is supported
TORCH_CHECK(cp_dim == 2,
"cp_dim must be 2. Expects tensor shapes to be: \n"
"gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank], \n"
"gathered_stats: [num_tokens, num_heads, cp_size, 2]");
// For cp_dim=2: tokens_dim=0, heads_dim=1
auto tokens_dim = 0;
auto heads_dim = 1;
@ -126,14 +120,14 @@ inline torch::Tensor helix_post_process_native_impl(
auto const cp_size = gathered_stats.sizes()[2];
auto const kv_lora_rank = gathered_o.sizes()[3];
// check remaining input tensor dimensions
// Check remaining input tensor dimensions.
TORCH_CHECK(gathered_o.sizes()[2] == cp_size, "gathered_o cp_size dim must match");
TORCH_CHECK(gathered_o.sizes()[tokens_dim] == num_tokens, "gathered_o tokens_dim must match num_tokens");
TORCH_CHECK(gathered_o.sizes()[heads_dim] == num_heads, "gathered_o heads_dim must match num_heads");
TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");
// Check data types
// Check data types.
TORCH_CHECK(
gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
"gathered_o must be half or bfloat16");
@ -143,15 +137,75 @@ inline torch::Tensor helix_post_process_native_impl(
// memcpy)
TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");
// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16
// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");
// Create output tensor
// Create output tensor.
std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
torch::Tensor output = torch::empty(output_shape, gathered_o.options());
// Get CUDA stream
// Get CUDA stream.
auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());
tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),
static_cast<int>(kv_lora_rank)};
fn(params, stream);
if (scale != 1.0)
{
output *= scale;
}
return output;
}
template <typename T, typename Fn>
inline torch::Tensor helix_post_process_native_impl_v2(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
{
CHECK_TH_CUDA(gathered_o);
CHECK_CONTIGUOUS(gathered_o);
CHECK_TH_CUDA(gathered_stats);
CHECK_CONTIGUOUS(gathered_stats);
// For cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank]
// gathered_stats: [num_tokens, cp_size, num_heads, 2]
TORCH_CHECK(gathered_o.dim() == 4, "gathered_o must be 4D tensor [num_tokens, cp_size, num_heads, kv_lora_rank]");
TORCH_CHECK(gathered_stats.dim() == 4, "gathered_stats must be 4D tensor [num_tokens, cp_size, num_heads, 2]");
auto const num_tokens = gathered_o.sizes()[0];
auto const cp_size = gathered_o.sizes()[1];
auto const num_heads = gathered_o.sizes()[2];
auto const kv_lora_rank = gathered_o.sizes()[3];
// Check remaining input tensor dimensions.
TORCH_CHECK(gathered_stats.sizes()[0] == num_tokens, "gathered_stats num_tokens dim must match");
TORCH_CHECK(gathered_stats.sizes()[1] == cp_size, "gathered_stats cp_size dim must match");
TORCH_CHECK(gathered_stats.sizes()[2] == num_heads, "gathered_stats num_heads dim must match");
TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");
// Check data types.
TORCH_CHECK(
gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
"gathered_o must be half or bfloat16");
TORCH_CHECK(gathered_stats.scalar_type() == at::ScalarType::Float, "gathered_stats must be float32");
// Check alignment requirements for gathered_o (16-byte aligned for async
// memcpy).
TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");
// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");
// Create output tensor.
std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
torch::Tensor output = torch::empty(output_shape, gathered_o.options());
// Get CUDA stream.
auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());
tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
@ -171,24 +225,54 @@ inline torch::Tensor helix_post_process_native_impl(
inline torch::Tensor helix_post_process_native(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int64_t cp_dim)
{
TORCH_CHECK(cp_dim == 2, "cp_dim must be 2. Only cp_dim=2 layout is supported.");
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl<__half>(
gathered_o, gathered_stats, scale, int(cp_dim), tensorrt_llm::kernels::helixPostProcessNative<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
TORCH_CHECK(cp_dim == 1 || cp_dim == 2,
"cp_dim must be 1 or 2. \n"
"cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] \n"
"cp_dim=2: gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank]");
if (cp_dim == 1)
{
// Version 2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl_v2<__half>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
return helix_post_process_native_impl<__nv_bfloat16>(gathered_o, gathered_stats, scale, int(cp_dim),
tensorrt_llm::kernels::helixPostProcessNative<__nv_bfloat16>);
return helix_post_process_native_impl_v2<__nv_bfloat16>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__nv_bfloat16>);
#else
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
#endif
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
}
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
// Version 1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl_v1<__half>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
return helix_post_process_native_impl_v1<__nv_bfloat16>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__nv_bfloat16>);
#else
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
#endif
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
}
}
}
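As a usage illustration (not part of this change), a minimal Python sketch of how the two supported layouts map onto the extended op; the shapes follow the comments above, the sizes are arbitrary, and the random stats only exercise the call path rather than model real softmax statistics:

import torch
import tensorrt_llm  # registers the trtllm custom ops, as in the unit tests below

num_tokens, cp_size, num_heads, kv_lora_rank = 8, 4, 2, 64
device = "cuda"
# V1 layout (cp_dim=2): gathered_o is [num_tokens, num_heads, cp_size, kv_lora_rank].
o_v1 = torch.randn(num_tokens, num_heads, cp_size, kv_lora_rank, dtype=torch.float16, device=device)
stats_v1 = torch.randn(num_tokens, num_heads, cp_size, 2, dtype=torch.float32, device=device)
out_v1 = torch.ops.trtllm.helix_post_process_native(o_v1, stats_v1, 1.0, 2)
# V2 layout (cp_dim=1): gathered_o is [num_tokens, cp_size, num_heads, kv_lora_rank].
o_v2 = torch.randn(num_tokens, cp_size, num_heads, kv_lora_rank, dtype=torch.float16, device=device)
stats_v2 = torch.randn(num_tokens, cp_size, num_heads, 2, dtype=torch.float32, device=device)
out_v2 = torch.ops.trtllm.helix_post_process_native(o_v2, stats_v2, 1.0, 1)
# Either way the result is [num_tokens, num_heads * kv_lora_rank].
assert out_v1.shape == out_v2.shape == (num_tokens, num_heads * kv_lora_rank)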

View File

@ -1195,29 +1195,63 @@ class MLA(nn.Module):
num_tokens = partial_o.shape[0]
cp_size = self.mapping.cp_size
# Reshape for FIFO-based all-to-all.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
# Check which FIFO version to use (default: version 2 for better performance).
fifo_version = self.mapping.cp_config.get("fifo_version", 2)
partial_o = partial_o.view(
num_tokens, cp_size, self.num_heads_tp_cp,
kv_lora_rank).transpose(1, 2).contiguous()
softmax_stats = softmax_stats.view(num_tokens, cp_size,
self.num_heads_tp_cp,
2).transpose(1,
2).contiguous()
if fifo_version == 1:
# Version 1: Uses transpose+contiguous before alltoall, cp_dim=2.
# Reshape for FIFO-based all-to-all. Overlap the two .contiguous() calls on separate streams.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
partial_o, softmax_stats = maybe_execute_in_parallel(
lambda: partial_o.view(num_tokens, cp_size, self.
num_heads_tp_cp, kv_lora_rank).
transpose(1, 2).contiguous(),
lambda: softmax_stats.view(num_tokens, cp_size, self.
num_heads_tp_cp, 2).
transpose(1, 2).contiguous(),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)
# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)
# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)
# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
# cp_dim = 2 (the dimension where cp_size is located)
# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
# cp_dim = 2 (the dimension where cp_size is located)
# Call helix_post_process_native with cp_dim=2.
return torch.ops.trtllm.helix_post_process_native(
partial_o_out, softmax_stats_out, 1.0, 2)
# Call helix_post_process_native V1 (cp_dim=2).
return torch.ops.trtllm.helix_post_process_native(
partial_o_out, softmax_stats_out, 1.0, 2)
else:
# Version 2: Uses simple view (no transpose+contiguous) for better performance.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp * 2]
partial_o = partial_o.view(
num_tokens, cp_size,
self.num_heads_tp_cp * kv_lora_rank)
softmax_stats = softmax_stats.view(num_tokens, cp_size,
self.num_heads_tp_cp * 2)
# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)
# Reshape after alltoall for post-processing.
# partial_o_out: [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats_out: [num_tokens, cp_size, num_heads_tp_cp * 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
gathered_o = partial_o_out.view(num_tokens, cp_size,
self.num_heads_tp_cp,
kv_lora_rank)
gathered_stats = softmax_stats_out.view(
num_tokens, cp_size, self.num_heads_tp_cp, 2)
# Call helix_post_process_native V2 (cp_dim=1).
return torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, 1.0, 1)
else:
attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
return attn_output
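A toy, self-contained sketch (hypothetical sizes, CPU tensors) of why the fifo_version=2 path above is cheaper on the reshape side: the cp_dim=1 layout is reachable with a zero-copy view, whereas the cp_dim=2 layout needs transpose plus contiguous, which materializes a copy before the all-to-all:

import torch

num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank = 4, 2, 8, 64  # toy sizes
partial_o = torch.randn(num_tokens, cp_size * num_heads_tp_cp * kv_lora_rank)
# Version 1 target: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank].
v1 = partial_o.view(num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank).transpose(1, 2)
assert not v1.is_contiguous()  # the .contiguous() below allocates and copies
v1 = v1.contiguous()
# Version 2 target: [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank].
v2 = partial_o.view(num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank)
assert v2.data_ptr() == partial_o.data_ptr()  # pure view, no data movement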

View File

@ -982,10 +982,21 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cudagraph:none", "cudagraph:without_padding",
"cudagraph:with_padding"
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
@pytest.mark.parametrize("comms_medium", ["fifo_v1", "fifo_v2", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp, enable_attention_dp):
use_nccl_for_alltoall = comms_medium == "nccl"
# Parse comms_medium to get use_nccl_for_alltoall and fifo_version.
if comms_medium == "nccl":
use_nccl_for_alltoall = True
fifo_version = 2 # Not used when NCCL is enabled.
elif comms_medium == "fifo_v1":
use_nccl_for_alltoall = False
fifo_version = 1
elif comms_medium == "fifo_v2":
use_nccl_for_alltoall = False
fifo_version = 2
else:
raise ValueError(f"Unknown comms_medium: {comms_medium}")
gen_ep = gen_tp * gen_cp
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
@ -1014,7 +1025,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32,
"use_nccl_for_alltoall": use_nccl_for_alltoall
"use_nccl_for_alltoall": use_nccl_for_alltoall,
"fifo_version": fifo_version,
},
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,

View File

@ -293,13 +293,17 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]

View File

@ -75,10 +75,10 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@ -108,8 +108,8 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)

View File

@ -649,7 +649,8 @@ def _full_test_multi_gpu(
world_size: int,
scenario: Scenario,
gen_steps: int,
comms_medium: str = False,
use_nccl_for_alltoall: bool = False,
fifo_version: int = 2,
):
if scenario.rope_scaling:
rope_scaling = {
@ -825,7 +826,8 @@ def _full_test_multi_gpu(
cp_size=world_size,
cp_config={
"cp_type": CpType.HELIX,
"use_nccl_for_alltoall": comms_medium == "nccl",
"use_nccl_for_alltoall": use_nccl_for_alltoall,
"fifo_version": fifo_version,
},
)
# we use cp_allgather here because there is no broadcast op across CP group
@ -861,7 +863,7 @@ def _run_single_rank(func, *args, **kwargs):
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
@pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"])
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo_v1", "fifo_v2"])
def test_mla_helix_distributed(
scenario: Scenario,
comms_medium: str,
@ -872,11 +874,34 @@ def test_mla_helix_distributed(
world_size = 2
print(f"Testing with comms_medium={comms_medium}.")
gen_steps = scenario.ref_steps if gen_steps is None else gen_steps
# Parse comms_medium to get use_nccl_for_alltoall and fifo_version.
if comms_medium == "nccl":
use_nccl_for_alltoall = True
fifo_version = 2 # Not used when NCCL is enabled.
elif comms_medium == "fifo_v1":
use_nccl_for_alltoall = False
fifo_version = 1
elif comms_medium == "fifo_v2":
use_nccl_for_alltoall = False
fifo_version = 2
else:
raise ValueError(f"Unknown comms_medium: {comms_medium}")
with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(
_run_single_rank,
*zip(
*[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium == "nccl")]
*[
(
_full_test_multi_gpu,
world_size,
scenario,
gen_steps,
use_nccl_for_alltoall,
fifo_version,
)
]
* world_size
),
)
@ -888,7 +913,7 @@ def test_mla_helix_distributed(
if __name__ == "__main__":
for comms_medium in ["fifo", "nccl"]:
for comms_medium in ["fifo_v1", "fifo_v2", "nccl"]:
print(f"\n{'=' * 60}")
print(f"Testing with comms_medium={comms_medium}")
print(f"{'=' * 60}\n")

View File

@ -22,22 +22,40 @@ from parameterized import parameterized
import tensorrt_llm
def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=False):
def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native_v1=False, native_v2=False):
"""Reference implementation (libtorch)
Args:
gathered_o: Input tensor
- native=False: [cp_size, num_tokens, num_heads * kv_lora_rank]
- native=True: [num_tokens, num_heads, cp_size, kv_lora_rank]
- native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads * kv_lora_rank]
- native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, kv_lora_rank]
- native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, kv_lora_rank]
gathered_stats: Stats tensor
- native=False: [cp_size, num_tokens, num_heads, 2]
- native=True: [num_tokens, num_heads, cp_size, 2]
- native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads, 2]
- native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, 2]
- native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, 2]
kv_lora_rank: KV LoRA rank
scale: Scale factor
native: Whether to use native layout (cp_dim=2)
native_v1: Whether to use native V1 layout (cp_dim=2).
native_v2: Whether to use native V2 layout (cp_dim=1).
"""
if native:
# Native layout: cp_dim=2
if native_v2:
# Native V2 layout: cp_dim=1.
# [num_tokens, cp_size, num_heads, 2] -> reduce over cp_size (dim=1).
global_max = gathered_stats[..., 0].max(dim=1, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
corrected_max_exp = torch.exp(corrected_max)
corrected_sum = gathered_stats[..., 1] * corrected_max_exp
global_sum = corrected_sum.sum(dim=1, keepdim=True)
correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1)
gathered_o_fp32 = gathered_o.to(torch.float32)
corrected_o = gathered_o_fp32 * correction
# Sum over cp_size dimension (dim=1), result: [num_tokens, num_heads, kv_lora_rank]
corrected_o = corrected_o.sum(dim=1)
# Reshape to [num_tokens, num_heads * kv_lora_rank]
corrected_o = corrected_o.view(corrected_o.shape[0], -1)
elif native_v1:
# Native V1 layout: cp_dim=2.
# [num_tokens, num_heads, cp_size]
global_max = gathered_stats[..., 0].max(dim=-1, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
@ -75,27 +93,54 @@ class TestHelixPostProcess(unittest.TestCase):
torch.cuda.manual_seed(42)
def _test_helix_postprocess(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native=False
self,
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=False,
native_v2=False,
):
"""Test helix postprocessing with given parameters
"""Test helix postprocessing with given parameters.
Args:
cp_size: Context parallelism size
num_tokens: Number of tokens
num_heads: Number of attention heads
kv_lora_rank: KV LoRA rank
scale: Scale factor
dtype: Data type (float16 or bfloat16)
native: Whether to use native layout (cp_dim=2)
cp_size: Context parallelism size.
num_tokens: Number of tokens.
num_heads: Number of attention heads.
kv_lora_rank: KV LoRA rank.
scale: Scale factor.
dtype: Data type (float16 or bfloat16).
native_v1: Whether to use native V1 layout (cp_dim=2).
native_v2: Whether to use native V2 layout (cp_dim=1).
"""
device = torch.device("cuda")
if native:
# Native layout: [num_tokens, num_heads, cp_size, kv_lora_rank]
if native_v2:
# Native V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
gathered_o = torch.empty(
num_tokens, cp_size, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [num_tokens, cp_size, num_heads, 2].
gathered_stats = torch.empty(
num_tokens, cp_size, num_heads, 2, dtype=torch.float32, device=device
)
gathered_o_max = torch.max(gathered_o, dim=-1, keepdim=True)[0]
gathered_stats[..., 0] = gathered_o_max[..., 0]
gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum
# Call the custom operator with cp_dim=1 (V2).
output = torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, scale, 1
)
elif native_v1:
# Native V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
gathered_o = torch.empty(
num_tokens, num_heads, cp_size, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [num_tokens, num_heads, cp_size, 2]
# gathered_stats: [num_tokens, num_heads, cp_size, 2].
gathered_stats = torch.empty(
num_tokens, num_heads, cp_size, 2, dtype=torch.float32, device=device
)
@ -104,16 +149,16 @@ class TestHelixPostProcess(unittest.TestCase):
gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum
# Call the custom operator with cp_dim=2
# Call the custom operator with cp_dim=2 (V1).
output = torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, scale, 2
)
else:
# Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank]
# Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank].
gathered_o_init = torch.empty(
cp_size, num_tokens, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [cp_size, num_tokens, num_heads, 2]
# gathered_stats: [cp_size, num_tokens, num_heads, 2].
gathered_stats = torch.empty(
cp_size, num_tokens, num_heads, 2, dtype=torch.float32, device=device
)
@ -124,80 +169,124 @@ class TestHelixPostProcess(unittest.TestCase):
gathered_o = gathered_o_init.view(cp_size, num_tokens, num_heads * kv_lora_rank)
# Call the custom operator
# Call the custom operator.
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, scale)
# Compute baseline
expected_output = baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=native)
# Compute baseline.
expected_output = baseline(
gathered_o,
gathered_stats,
kv_lora_rank,
scale,
native_v1=native_v1,
native_v2=native_v2,
)
# Compare results
# Compare results.
torch.testing.assert_close(output, expected_output, atol=1e-3, rtol=1e-2)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(4, 8, 2, 64, 1.0, torch.float16, False),
(8, 16, 4, 128, 0.5, torch.float16, False),
(16, 32, 8, 256, 2.0, torch.float16, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, False),
(4, 8, 2, 64, 1.0, torch.float16, True),
(8, 16, 4, 128, 0.5, torch.float16, True),
(16, 32, 8, 256, 2.0, torch.float16, True),
(4, 8, 2, 64, 1.0, torch.bfloat16, True),
(8, 16, 4, 128, 0.5, torch.bfloat16, True),
(16, 32, 8, 256, 2.0, torch.bfloat16, True),
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
# Original layout.
(4, 8, 2, 64, 1.0, torch.float16, False, False),
(8, 16, 4, 128, 0.5, torch.float16, False, False),
(16, 32, 8, 256, 2.0, torch.float16, False, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, False, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, False, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, False, False),
# Native V1 layout (cp_dim=2).
(4, 8, 2, 64, 1.0, torch.float16, True, False),
(8, 16, 4, 128, 0.5, torch.float16, True, False),
(16, 32, 8, 256, 2.0, torch.float16, True, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, True, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, True, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, True, False),
# Native V2 layout (cp_dim=1).
(4, 8, 2, 64, 1.0, torch.float16, False, True),
(8, 16, 4, 128, 0.5, torch.float16, False, True),
(16, 32, 8, 256, 2.0, torch.float16, False, True),
(4, 8, 2, 64, 1.0, torch.bfloat16, False, True),
(8, 16, 4, 128, 0.5, torch.bfloat16, False, True),
(16, 32, 8, 256, 2.0, torch.bfloat16, False, True),
]
)
def test_helix_postprocess_basic(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test basic helix postprocessing functionality"""
"""Test basic helix postprocessing functionality."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
# Edge cases for non-native layout
(1, 1, 1, 16, 1.0, torch.float16, False), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, False), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, False), # Single token
(4, 8, 1, 16, 1.0, torch.float16, False), # Single head
(4, 8, 2, 2048, 1.0, torch.float16, False), # Large kv_lora_rank
# Edge cases for native layout
(1, 1, 1, 16, 1.0, torch.float16, True), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, True), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, True), # Single token
(4, 8, 1, 16, 1.0, torch.float16, True), # Single head
# Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
# Edge cases for original layout.
(1, 1, 1, 16, 1.0, torch.float16, False, False), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, False, False), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, False, False), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, False, False), # Single head.
(4, 8, 2, 2048, 1.0, torch.float16, False, False), # Large kv_lora_rank.
# Edge cases for native V1 layout.
(1, 1, 1, 16, 1.0, torch.float16, True, False), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, True, False), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, True, False), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, True, False), # Single head.
# Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for the native kernels.
# Edge cases for native V2 layout.
(1, 1, 1, 16, 1.0, torch.float16, False, True), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, False, True), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, False, True), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, False, True), # Single head.
]
)
def test_helix_postprocess_edge_cases(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test edge cases with minimal dimensions"""
"""Test edge cases with minimal dimensions."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(16, 16, 64, 512, 1.0, torch.float16, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, False),
(16, 16, 64, 512, 1.0, torch.float16, True),
(16, 16, 64, 512, 1.0, torch.bfloat16, True),
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
(16, 16, 64, 512, 1.0, torch.float16, False, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, False, False),
(16, 16, 64, 512, 1.0, torch.float16, True, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, True, False),
(16, 16, 64, 512, 1.0, torch.float16, False, True),
(16, 16, 64, 512, 1.0, torch.bfloat16, False, True),
]
)
def test_helix_postprocess_large_inputs(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test with larger inputs to ensure performance and correctness"""
"""Test with larger inputs to ensure performance and correctness."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)
def test_helix_postprocess_invalid_inputs(self):
@ -232,14 +321,14 @@ class TestHelixPostProcess(unittest.TestCase):
"""Test error handling for invalid inputs (native layout)"""
device = torch.device("cuda")
# Test with wrong cp_dim (only cp_dim=2 is supported)
# Test with wrong cp_dim (only cp_dim=1 and cp_dim=2 are supported).
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 0)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 3)
# Test with wrong tensor dimensions (3D instead of 4D)
gathered_o = torch.randn(8, 2, 256, dtype=torch.float16, device=device)
@ -264,19 +353,20 @@ class TestHelixPostProcess(unittest.TestCase):
@parameterized.expand(
[
# (native,)
(False,),
(True,),
# (layout,) — "nccl", "fifo_v1", "fifo_v2".
("nccl",),
("fifo_v1",),
("fifo_v2",),
]
)
def test_helix_postprocess_alignment_requirements(self, native):
"""Test alignment requirements"""
def test_helix_postprocess_alignment_requirements(self, layout):
"""Test alignment requirements for all layouts."""
device = torch.device("cuda")
# For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment
if native:
# This should work (kv_lora_rank = 64 is multiple of 8)
# For float16 (2 bytes), kv_lora_rank must be a multiple of 8 for 16-byte alignment.
if layout == "fifo_v1":
# V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
# This should work (kv_lora_rank = 64 is a multiple of 8).
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
@ -285,13 +375,30 @@ class TestHelixPostProcess(unittest.TestCase):
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
# Test with kv_lora_rank that doesn't satisfy alignment requirements
# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(8, 1, 4, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 1, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
elif layout == "fifo_v2":
# V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
# This should work (kv_lora_rank = 64 is a multiple of 8).
gathered_o = torch.randn(8, 4, 2, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 4, 2, 2, dtype=torch.float32, device=device)
try:
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(8, 4, 1, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 4, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
else:
# This should work (kv_lora_rank = 64 is multiple of 8)
# NCCL path (original layout): [cp_size, num_tokens, num_heads * kv_lora_rank].
# This should work (kv_lora_rank = 64 is a multiple of 8).
gathered_o = torch.randn(4, 8, 2 * 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 2, 2, dtype=torch.float32, device=device)
@ -300,7 +407,7 @@ class TestHelixPostProcess(unittest.TestCase):
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
# Test with kv_lora_rank that doesn't satisfy alignment requirements
# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(4, 8, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):