[TRTLLM-9493][feat] Add helixPostProcessNative kernel for cp_dim=2 (#9924)

Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
Balaram Buddharaju 2025-12-12 16:49:25 -08:00 committed by GitHub
parent 6147452158
commit 461446045e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 508 additions and 72 deletions

View File

@ -34,6 +34,9 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels
{
namespace
{
static constexpr int WARP_SIZE = 32;
// Utility: warp-level corrected sum
@ -207,6 +210,156 @@ __global__ void helix_postprocess_kernel(
}
}
static constexpr int MAX_THREADS = 256;
static constexpr int MAX_KV_LORA_BYTES = (MAX_THREADS - WARP_SIZE) * BYTES_O_PER_THREAD;
// Kernel: fused helix post-processing
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank] (half)
// gathered_stats: [num_tokens, num_heads, cp_size, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
template <typename T>
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
// Each block processes one (token, head)
// gridDim.x: num_tokens, gridDim.y: num_heads
// there are two separate types of warps:
// warp 0 calculates the correction values (one per cp_size)
// all other warps pre-load the gathered_o elements for the current token/head
// and once warp 0 is done, all other warps can start accumulating the output
static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);
int tok_idx = blockIdx.x;
int head_idx = blockIdx.y;
int num_tokens = gridDim.x;
int num_heads = gridDim.y;
int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
__shared__ float smem_correction[MAX_CP];
int lane_idx = threadIdx.x % WARP_SIZE;
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);
// all warps except first pre-load the gathered_o elements for the current
// token/head
T const* gathered_o_off;
gathered_o_off = gathered_o + tok_idx * num_heads * cp_size * kv_lora_rank + head_idx * cp_size * kv_lora_rank;
// we subtract WARP_SIZE because first warp is not participating in pre-load
gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
int gathered_16b_stride = (kv_lora_rank) / NUM_O_PER_THREAD;
int stats_offset = tok_idx * num_heads * cp_size + head_idx * cp_size;
int stats_stride = 1;
// here we have to wait for memory operations of the previous kernel to
// complete
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif
float max_values[MAX_CP_VAL_PER_THREAD];
float sum_values[MAX_CP_VAL_PER_THREAD];
T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
float final_sum[NUM_O_PER_THREAD];
float corr_vals[NUM_PRE_LOAD];
T output_typed[NUM_O_PER_THREAD];
if (warp_idx == 0)
{
// the warp collectively calculates the correction values
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
auto stats_idx = stats_offset + cp_idx * stats_stride;
float2 stats = cp_idx < cp_size ? gathered_stats[stats_idx] : make_float2(-INFINITY, 0.F);
max_values[cp_val_idx] = stats.x;
sum_values[cp_val_idx] = stats.y;
}
float corrected_values[MAX_CP_VAL_PER_THREAD];
warpReduceCorrectedSum(corrected_values, max_values, sum_values);
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
smem_correction[cp_idx] = corrected_values[cp_val_idx];
}
}
else
{
// all other warps pre-load the gathered_o elements
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
auto val = gathered_o_16b[cp_idx * gathered_16b_stride];
*reinterpret_cast<float4*>(vals[cp_idx]) = val;
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] = 0.F;
}
}
__syncthreads();
// warp 0 exits early
if (warp_idx == 0)
return;
// here we can trigger the dependent kernels to start
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
corr_vals[cp_idx] = smem_correction[cp_idx];
}
for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
{
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
*reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
: make_float4(0.F, 0.F, 0.F, 0.F);
corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
}
auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
*reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
}
} // anonymous namespace
template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
@ -240,6 +393,42 @@ void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream)
INSTANTIATE_POST_PROC(__half);
INSTANTIATE_POST_PROC(__nv_bfloat16);
template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
// Check that gathered_o is 16-byte aligned
TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
"gathered_o must be 16-byte aligned for async memcpy");
// TODO: Figure why this constraint is specific to this implementation and not legacy one.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
"kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
// Check that kv_lora_rank * sizeof(T) is a multiple of 16
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
"kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
// Check that cp_size is not larger than the max fallback CP size
TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");
auto kernel_instance = helix_postprocess_kernel_native<T>;
cudaLaunchConfig_t config;
config.gridDim = dim3(params.num_tokens, params.num_heads);
config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
params.gathered_stats, params.cp_size, params.kv_lora_rank));
}
#define INSTANTIATE_POST_PROC_NATIVE(T) \
template void helixPostProcessNative<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
INSTANTIATE_POST_PROC_NATIVE(__half);
INSTANTIATE_POST_PROC_NATIVE(__nv_bfloat16);
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -43,6 +43,9 @@ struct HelixPostProcParams
template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream);
template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream);
} // namespace kernels
TRTLLM_NAMESPACE_END

View File

@ -99,14 +99,111 @@ torch::Tensor helix_post_process(torch::Tensor const& gathered_o, torch::Tensor
return output;
}
template <typename T, typename Fn>
inline torch::Tensor helix_post_process_native_impl(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int cp_dim, Fn fn)
{
CHECK_TH_CUDA(gathered_o);
CHECK_CONTIGUOUS(gathered_o);
CHECK_TH_CUDA(gathered_stats);
CHECK_CONTIGUOUS(gathered_stats);
// Only cp_dim=2 is supported
TORCH_CHECK(cp_dim == 2,
"cp_dim must be 2. Expects tensor shapes to be: \n"
"gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank], \n"
"gathered_stats: [num_tokens, num_heads, cp_size, 2]");
// For cp_dim=2: tokens_dim=0, heads_dim=1
auto tokens_dim = 0;
auto heads_dim = 1;
TORCH_CHECK(gathered_o.dim() == 4, "gathered_o must be 4D tensor [num_tokens, num_heads, cp_size, kv_lora_rank]");
TORCH_CHECK(gathered_stats.dim() == 4, "gathered_stats must be 4D tensor [num_tokens, num_heads, cp_size, 2]");
auto const num_tokens = gathered_stats.sizes()[tokens_dim];
auto const num_heads = gathered_stats.sizes()[heads_dim];
auto const cp_size = gathered_stats.sizes()[2];
auto const kv_lora_rank = gathered_o.sizes()[3];
// check remaining input tensor dimensions
TORCH_CHECK(gathered_o.sizes()[2] == cp_size, "gathered_o cp_size dim must match");
TORCH_CHECK(gathered_o.sizes()[tokens_dim] == num_tokens, "gathered_o tokens_dim must match num_tokens");
TORCH_CHECK(gathered_o.sizes()[heads_dim] == num_heads, "gathered_o heads_dim must match num_heads");
TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");
// Check data types
TORCH_CHECK(
gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
"gathered_o must be half or bfloat16");
TORCH_CHECK(gathered_stats.scalar_type() == at::ScalarType::Float, "gathered_stats must be float32");
// Check alignment requirements for gathered_o (16-byte aligned for async
// memcpy)
TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");
// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16
size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");
// Create output tensor
std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
torch::Tensor output = torch::empty(output_shape, gathered_o.options());
// Get CUDA stream
auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());
tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),
static_cast<int>(kv_lora_rank)};
fn(params, stream);
if (scale != 1.0)
{
output *= scale;
}
return output;
}
inline torch::Tensor helix_post_process_native(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int64_t cp_dim)
{
TORCH_CHECK(cp_dim == 2, "cp_dim must be 2. Only cp_dim=2 layout is supported.");
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl<__half>(
gathered_o, gathered_stats, scale, int(cp_dim), tensorrt_llm::kernels::helixPostProcessNative<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
return helix_post_process_native_impl<__nv_bfloat16>(gathered_o, gathered_stats, scale, int(cp_dim),
tensorrt_llm::kernels::helixPostProcessNative<__nv_bfloat16>);
#else
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
#endif
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
}
}
TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def("helix_post_process(Tensor gathered_o, Tensor gathered_stats, float scale) -> Tensor");
m.def(
"helix_post_process_native(Tensor gathered_o, Tensor gathered_stats, float "
"scale, int cp_dim) -> Tensor");
}
TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
{
m.impl("helix_post_process", helix_post_process);
m.impl("helix_post_process_native", &helix_post_process_native);
}
} // namespace torch_ext

View File

@ -756,6 +756,13 @@ def _register_fake():
def _(gathered_o, gathered_stats, scale):
return gathered_o.new_empty(*gathered_o.shape[1:])
@torch.library.register_fake("trtllm::helix_post_process_native")
def _(gathered_o, gathered_stats, scale, cp_dim):
# Remove the dimension at cp_dim (context parallelism dimension)
out_shape = list(gathered_o.shape)
del out_shape[cp_dim]
return gathered_o.new_empty(*out_shape)
@torch.library.register_fake("trtllm::tinygemm2")
def _(input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor):
# input [M, K], weight [N, K], bias [N]

View File

@ -22,21 +22,49 @@ from parameterized import parameterized
import tensorrt_llm
def baseline(gathered_o, gathered_stats, kv_lora_rank, scale):
"""Reference implementation (libtorch)"""
# [cp_size, num_tokens, num_heads]
global_max = gathered_stats[..., 0].max(dim=0, keepdim=True)[0]
# [cp_size, num_tokens, num_heads]
corrected_max = gathered_stats[..., 0] - global_max
corrected_max_exp = torch.exp(corrected_max)
corrected_sum = gathered_stats[..., 1] * corrected_max_exp
global_sum = corrected_sum.sum(dim=0, keepdim=True)
correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1)
# Cast gathered_o to float32 for computation, then cast output to bf16 at the end
gathered_o_fp32 = gathered_o.to(torch.float32).view(*correction.shape[:-1], kv_lora_rank)
corrected_o = gathered_o_fp32 * correction
# [num_tokens, num_heads * kv_lora_rank] (bf16)
corrected_o = corrected_o.view(*gathered_o.shape[:-1], -1).sum(dim=0)
def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=False):
"""Reference implementation (libtorch)
Args:
gathered_o: Input tensor
- native=False: [cp_size, num_tokens, num_heads * kv_lora_rank]
- native=True: [num_tokens, num_heads, cp_size, kv_lora_rank]
gathered_stats: Stats tensor
- native=False: [cp_size, num_tokens, num_heads, 2]
- native=True: [num_tokens, num_heads, cp_size, 2]
kv_lora_rank: KV LoRA rank
scale: Scale factor
native: Whether to use native layout (cp_dim=2)
"""
if native:
# Native layout: cp_dim=2
# [num_tokens, num_heads, cp_size]
global_max = gathered_stats[..., 0].max(dim=-1, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
corrected_max_exp = torch.exp(corrected_max)
corrected_sum = gathered_stats[..., 1] * corrected_max_exp
global_sum = corrected_sum.sum(dim=-1, keepdim=True)
correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1)
gathered_o_fp32 = gathered_o.to(torch.float32)
corrected_o = gathered_o_fp32 * correction
# Sum over cp_size dimension (dim=2), result: [num_tokens, num_heads, kv_lora_rank]
corrected_o = corrected_o.sum(dim=2)
# Reshape to [num_tokens, num_heads * kv_lora_rank]
corrected_o = corrected_o.view(corrected_o.shape[0], -1)
else:
# Original layout: cp_dim=0
# [cp_size, num_tokens, num_heads]
global_max = gathered_stats[..., 0].max(dim=0, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
corrected_max_exp = torch.exp(corrected_max)
corrected_sum = gathered_stats[..., 1] * corrected_max_exp
global_sum = corrected_sum.sum(dim=0, keepdim=True)
correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1)
gathered_o_fp32 = gathered_o.to(torch.float32).view(*correction.shape[:-1], kv_lora_rank)
corrected_o = gathered_o_fp32 * correction
# [num_tokens, num_heads * kv_lora_rank]
corrected_o = corrected_o.view(*gathered_o.shape[:-1], -1).sum(dim=0)
return corrected_o.to(gathered_o.dtype) * scale
@ -46,71 +74,134 @@ class TestHelixPostProcess(unittest.TestCase):
torch.manual_seed(42)
torch.cuda.manual_seed(42)
def _test_helix_postprocess(self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype):
"""Test helix postprocessing with given parameters"""
def _test_helix_postprocess(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native=False
):
"""Test helix postprocessing with given parameters
Args:
cp_size: Context parallelism size
num_tokens: Number of tokens
num_heads: Number of attention heads
kv_lora_rank: KV LoRA rank
scale: Scale factor
dtype: Data type (float16 or bfloat16)
native: Whether to use native layout (cp_dim=2)
"""
device = torch.device("cuda")
# Create test tensors
# gathered_o_init: [cp_size, num_tokens, num_heads, kv_lora_rank]
gathered_o_init = torch.empty(
cp_size, num_tokens, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
if native:
# Native layout: [num_tokens, num_heads, cp_size, kv_lora_rank]
gathered_o = torch.empty(
num_tokens, num_heads, cp_size, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [num_tokens, num_heads, cp_size, 2]
gathered_stats = torch.empty(
num_tokens, num_heads, cp_size, 2, dtype=torch.float32, device=device
)
gathered_o_max = torch.max(gathered_o, dim=-1, keepdim=True)[0]
gathered_stats[..., 0] = gathered_o_max[..., 0]
gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum
# gathered_stats: [cp_size, num_tokens, num_heads, 2]
gathered_stats = torch.empty(
cp_size, num_tokens, num_heads, 2, dtype=torch.float32, device=device
)
gathered_o_max = torch.max(gathered_o_init, dim=-1, keepdim=True)[0]
gathered_stats[..., 0] = gathered_o_max[..., 0]
gathered_o_sum = torch.sum(torch.exp(gathered_o_init - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum
# Call the custom operator with cp_dim=2
output = torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, scale, 2
)
else:
# Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank]
gathered_o_init = torch.empty(
cp_size, num_tokens, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [cp_size, num_tokens, num_heads, 2]
gathered_stats = torch.empty(
cp_size, num_tokens, num_heads, 2, dtype=torch.float32, device=device
)
gathered_o_max = torch.max(gathered_o_init, dim=-1, keepdim=True)[0]
gathered_stats[..., 0] = gathered_o_max[..., 0]
gathered_o_sum = torch.sum(torch.exp(gathered_o_init - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum
gathered_o = gathered_o_init.view(cp_size, num_tokens, num_heads * kv_lora_rank)
gathered_o = gathered_o_init.view(cp_size, num_tokens, num_heads * kv_lora_rank)
# Call the custom operator
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, scale)
# Call the custom operator
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, scale)
# Compute baseline
expected_output = baseline(gathered_o, gathered_stats, kv_lora_rank, scale)
expected_output = baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=native)
# Compare results
torch.testing.assert_close(output, expected_output, atol=1e-3, rtol=1e-2)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype)
(4, 8, 2, 64, 1.0, torch.float16),
(8, 16, 4, 128, 0.5, torch.float16),
(16, 32, 8, 256, 2.0, torch.float16),
(4, 8, 2, 64, 1.0, torch.bfloat16),
(8, 16, 4, 128, 0.5, torch.bfloat16),
(16, 32, 8, 256, 2.0, torch.bfloat16),
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(4, 8, 2, 64, 1.0, torch.float16, False),
(8, 16, 4, 128, 0.5, torch.float16, False),
(16, 32, 8, 256, 2.0, torch.float16, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, False),
(4, 8, 2, 64, 1.0, torch.float16, True),
(8, 16, 4, 128, 0.5, torch.float16, True),
(16, 32, 8, 256, 2.0, torch.float16, True),
(4, 8, 2, 64, 1.0, torch.bfloat16, True),
(8, 16, 4, 128, 0.5, torch.bfloat16, True),
(16, 32, 8, 256, 2.0, torch.bfloat16, True),
]
)
def test_helix_postprocess_basic(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
):
"""Test basic helix postprocessing functionality"""
self._test_helix_postprocess(cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype)
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
)
@parameterized.expand(
[
# Test edge cases
(1, 1, 1, 16, 1.0, torch.float16), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16), # Single token
(4, 8, 1, 16, 1.0, torch.float16), # Single head
(4, 8, 2, 2048, 1.0, torch.float16), # Large kv_lora_rank
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
# Edge cases for non-native layout
(1, 1, 1, 16, 1.0, torch.float16, False), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, False), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, False), # Single token
(4, 8, 1, 16, 1.0, torch.float16, False), # Single head
(4, 8, 2, 2048, 1.0, torch.float16, False), # Large kv_lora_rank
# Edge cases for native layout
(1, 1, 1, 16, 1.0, torch.float16, True), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, True), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, True), # Single token
(4, 8, 1, 16, 1.0, torch.float16, True), # Single head
# Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel
]
)
def test_helix_postprocess_edge_cases(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
):
"""Test edge cases with minimal dimensions"""
self._test_helix_postprocess(cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype)
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(16, 16, 64, 512, 1.0, torch.float16, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, False),
(16, 16, 64, 512, 1.0, torch.float16, True),
(16, 16, 64, 512, 1.0, torch.bfloat16, True),
]
)
def test_helix_postprocess_large_inputs(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
):
"""Test with larger inputs to ensure performance and correctness"""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
)
def test_helix_postprocess_invalid_inputs(self):
"""Test error handling for invalid inputs"""
"""Test error handling for invalid inputs (non-native)"""
device = torch.device("cuda")
# Test with wrong tensor dimensions
@ -137,34 +228,83 @@ class TestHelixPostProcess(unittest.TestCase):
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
def test_helix_postprocess_alignment_requirements(self):
def test_helix_postprocess_native_invalid_inputs(self):
"""Test error handling for invalid inputs (native layout)"""
device = torch.device("cuda")
# Test with wrong cp_dim (only cp_dim=2 is supported)
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 0)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
# Test with wrong tensor dimensions (3D instead of 4D)
gathered_o = torch.randn(8, 2, 256, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
# Test with wrong data types
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float32, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
# Test with non-contiguous tensors
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device).transpose(0, 1)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
@parameterized.expand(
[
# (native,)
(False,),
(True,),
]
)
def test_helix_postprocess_alignment_requirements(self, native):
"""Test alignment requirements"""
device = torch.device("cuda")
# Test with kv_lora_rank that doesn't satisfy alignment requirements
# For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment
# For bfloat16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment
# This should work (kv_lora_rank = 64 is multiple of 8)
gathered_o = torch.randn(4, 8, 2 * 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 2, 2, dtype=torch.float32, device=device)
if native:
# This should work (kv_lora_rank = 64 is multiple of 8)
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)
try:
torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
# Should not raise an error
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
try:
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
# Test with kv_lora_rank that doesn't satisfy alignment requirements
gathered_o = torch.randn(4, 8, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
# Test with kv_lora_rank that doesn't satisfy alignment requirements
gathered_o = torch.randn(8, 1, 4, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 1, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
else:
# This should work (kv_lora_rank = 64 is multiple of 8)
gathered_o = torch.randn(4, 8, 2 * 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 2, 2, dtype=torch.float32, device=device)
def test_helix_postprocess_large_inputs(self):
"""Test with larger inputs to ensure performance and correctness"""
self._test_helix_postprocess(16, 16, 64, 512, 1.0, torch.float16)
self._test_helix_postprocess(16, 16, 64, 512, 1.0, torch.bfloat16)
try:
torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")
# Test with kv_lora_rank that doesn't satisfy alignment requirements
gathered_o = torch.randn(4, 8, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, 1.0)
if __name__ == "__main__":