[None][feat] Add new helix kernels for MNNVL-based codepath (#11433)
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
This commit is contained in:
parent 4debf153d8, commit 2989bf5b39
@@ -80,8 +80,8 @@ static constexpr int NUM_PRE_LOAD = 8;
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [cp_size, num_tokens, num_heads * kv_lora_rank] (half)
// gathered_stats: [cp_size, num_tokens, num_heads, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void helix_postprocess_kernel(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
@@ -217,10 +217,10 @@ static constexpr int MAX_KV_LORA_BYTES = (MAX_THREADS - WARP_SIZE) * BYTES_O_PER
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank] (half)
// gathered_stats: [num_tokens, num_heads, cp_size, 2] (fp32)
// note: we explicitly avoid using restrict here, to avoid getting ld.global.nc
// which may have longer latency
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native(
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v1(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
// Each block processes one (token, head)
@@ -358,6 +358,153 @@ __global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native(
*reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
}

// Kernel: fused helix post-processing for cp_dim=1 layout (Version 2)
// output: [num_tokens, num_heads * kv_lora_rank] (half)
// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] (half)
// gathered_stats: [num_tokens, cp_size, num_heads, 2] (fp32)
// Note: we explicitly avoid using restrict here, to avoid getting ld.global.nc,
// which may have longer latency.
template <typename T>
__global__ void __launch_bounds__(MAX_THREADS) helix_postprocess_kernel_native_v2(
T* output, T const* gathered_o, float2 const* gathered_stats, int cp_size, int kv_lora_rank)
{
// Each block processes one (token, head)
// gridDim.x: num_tokens, gridDim.y: num_heads
// there are two separate types of warps:
// warp 0 calculates the correction values (one per cp_size)
// all other warps pre-load the gathered_o elements for the current token/head
// and once warp 0 is done, all other warps can start accumulating the output
static constexpr int NUM_O_PER_THREAD = BYTES_O_PER_THREAD / sizeof(T);

int tok_idx = blockIdx.x;
int head_idx = blockIdx.y;
int num_tokens = gridDim.x;
int num_heads = gridDim.y;

int const cp_size_aligned = ((cp_size + NUM_PRE_LOAD - 1) / NUM_PRE_LOAD) * NUM_PRE_LOAD;
__shared__ float smem_correction[MAX_CP];

int lane_idx = threadIdx.x % WARP_SIZE;
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / WARP_SIZE, 0);

// For cp_dim=1 layout: [num_tokens, cp_size, num_heads, kv_lora_rank]
// Pointer offsets differ from cp_dim=2
T const* gathered_o_off;
gathered_o_off = gathered_o + tok_idx * cp_size * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
// we subtract WARP_SIZE because first warp is not participating in pre-load
gathered_o_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
float4 const* gathered_o_16b = reinterpret_cast<float4 const*>(gathered_o_off);
// For cp_dim=1: stride between cp entries is num_heads * kv_lora_rank
int gathered_16b_stride = (num_heads * kv_lora_rank) / NUM_O_PER_THREAD;
// For cp_dim=1: stats layout is [num_tokens, cp_size, num_heads, 2]
int stats_offset = tok_idx * cp_size * num_heads + head_idx;
int stats_stride = num_heads;

// here we have to wait for memory operations of the previous kernel to
// complete
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaGridDependencySynchronize();
#endif

float max_values[MAX_CP_VAL_PER_THREAD];
float sum_values[MAX_CP_VAL_PER_THREAD];
T vals[NUM_PRE_LOAD][NUM_O_PER_THREAD];
float final_sum[NUM_O_PER_THREAD];
float corr_vals[NUM_PRE_LOAD];
T output_typed[NUM_O_PER_THREAD];

if (warp_idx == 0)
{
// the warp collectively calculates the correction values
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
auto stats_idx = stats_offset + cp_idx * stats_stride;
float2 stats = cp_idx < cp_size ? gathered_stats[stats_idx] : make_float2(-INFINITY, 0.F);
max_values[cp_val_idx] = stats.x;
sum_values[cp_val_idx] = stats.y;
}
float corrected_values[MAX_CP_VAL_PER_THREAD];
warpReduceCorrectedSum(corrected_values, max_values, sum_values);
#pragma unroll
for (int cp_val_idx = 0; cp_val_idx < MAX_CP_VAL_PER_THREAD; ++cp_val_idx)
{
auto cp_idx = cp_val_idx * WARP_SIZE + lane_idx;
smem_correction[cp_idx] = corrected_values[cp_val_idx];
}
}
else
{
// all other warps pre-load the gathered_o elements
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
auto val = gathered_o_16b[cp_idx * gathered_16b_stride];
*reinterpret_cast<float4*>(vals[cp_idx]) = val;
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] = 0.F;
}
}
__syncthreads();

// warp 0 exits early
if (warp_idx == 0)
return;

// here we can trigger the dependent kernels to start
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
cudaTriggerProgrammaticLaunchCompletion();
#endif

#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
corr_vals[cp_idx] = smem_correction[cp_idx];
}

for (int cp_idx_base = NUM_PRE_LOAD; cp_idx_base < cp_size_aligned; cp_idx_base += NUM_PRE_LOAD)
{
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD; ++cp_idx)
{
*reinterpret_cast<float4*>(vals[cp_idx]) = cp_idx_base + cp_idx < cp_size
? gathered_o_16b[(cp_idx_base + cp_idx) * gathered_16b_stride]
: make_float4(0.F, 0.F, 0.F, 0.F);
corr_vals[cp_idx] = cp_idx_base + cp_idx < cp_size ? smem_correction[cp_idx_base + cp_idx] : 0.F;
}
}
#pragma unroll
for (int cp_idx = 0; cp_idx < NUM_PRE_LOAD && cp_idx < cp_size; ++cp_idx)
{
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
final_sum[o_idx] += static_cast<float>(vals[cp_idx][o_idx]) * corr_vals[cp_idx];
}
}
#pragma unroll
for (int o_idx = 0; o_idx < NUM_O_PER_THREAD; ++o_idx)
{
output_typed[o_idx] = static_cast<T>(final_sum[o_idx]);
}
auto* output_off = output + tok_idx * num_heads * kv_lora_rank + head_idx * kv_lora_rank;
output_off += (threadIdx.x - WARP_SIZE) * NUM_O_PER_THREAD;
*reinterpret_cast<float4*>(output_off) = *reinterpret_cast<float4*>(output_typed);
}
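For reference, the correction that warp 0 computes via warpReduceCorrectedSum, and that the accumulation loops then apply, is the usual log-sum-exp merge of per-rank partial attention outputs; it is the same math as the baseline() reference in the unit tests further down. A sketch of the formula, with (m_i, s_i) the per-rank max/sum statistics from gathered_stats and o_i the per-rank partial outputs:

    m = \max_i m_i,            S = \sum_j s_j \, e^{m_j - m}
    c_i = s_i \, e^{m_i - m} / S,    o = \sum_i c_i \, o_i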
} // anonymous namespace

template <typename T>
@@ -394,21 +541,21 @@ INSTANTIATE_POST_PROC(__half);
INSTANTIATE_POST_PROC(__nv_bfloat16);

template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream)
void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
// Check that gathered_o is 16-byte aligned
// Check that gathered_o is 16-byte aligned.
TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
"gathered_o must be 16-byte aligned for async memcpy");
// TODO: Figure why this constraint is specific to this implementation and not legacy one.
// TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
"kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
// Check that kv_lora_rank * sizeof(T) is a multiple of 16
// Check that kv_lora_rank * sizeof(T) is a multiple of 16.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
"kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
// Check that cp_size is not larger than the max fallback CP size
// Check that cp_size is not larger than the max fallback CP size.
TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");

auto kernel_instance = helix_postprocess_kernel_native<T>;
auto kernel_instance = helix_postprocess_kernel_native_v1<T>;
cudaLaunchConfig_t config;
config.gridDim = dim3(params.num_tokens, params.num_heads);
config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
@@ -423,11 +570,47 @@ void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t s
params.gathered_stats, params.cp_size, params.kv_lora_rank));
}

#define INSTANTIATE_POST_PROC_NATIVE(T) \
template void helixPostProcessNative<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);
#define INSTANTIATE_POST_PROC_NATIVE_V1(T) \
template void helixPostProcessNativeV1<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);

INSTANTIATE_POST_PROC_NATIVE(__half);
INSTANTIATE_POST_PROC_NATIVE(__nv_bfloat16);
INSTANTIATE_POST_PROC_NATIVE_V1(__half);
INSTANTIATE_POST_PROC_NATIVE_V1(__nv_bfloat16);

template <typename T>
void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream)
{
// Check that gathered_o is 16-byte aligned.
TLLM_CHECK_WITH_INFO(reinterpret_cast<uintptr_t>(params.gathered_o) % 16 == 0,
"gathered_o must be 16-byte aligned for async memcpy");
// TODO: Figure out why this constraint is specific to this implementation and not the legacy one.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) <= MAX_KV_LORA_BYTES,
"kv_lora_rank * sizeof(T) must be <= %zu bytes", MAX_KV_LORA_BYTES);
// Check that kv_lora_rank * sizeof(T) is a multiple of 16.
TLLM_CHECK_WITH_INFO((params.kv_lora_rank * sizeof(T)) % 16 == 0,
"kv_lora_rank * sizeof(T) must be a multiple of 16 for async memcpy");
// Check that cp_size is not larger than the max fallback CP size.
TLLM_CHECK_WITH_INFO(params.cp_size <= MAX_CP, "cp_size > fallback max CP size");

auto kernel_instance = helix_postprocess_kernel_native_v2<T>;
cudaLaunchConfig_t config;
config.gridDim = dim3(params.num_tokens, params.num_heads);
config.blockDim = WARP_SIZE + params.kv_lora_rank * sizeof(T) / 16;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = common::getEnvEnablePDL();
config.numAttrs = 1;
config.attrs = attrs;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, kernel_instance, params.output, params.gathered_o,
params.gathered_stats, params.cp_size, params.kv_lora_rank));
}

#define INSTANTIATE_POST_PROC_NATIVE_V2(T) \
template void helixPostProcessNativeV2<T>(HelixPostProcParams<T> const& params, cudaStream_t stream);

INSTANTIATE_POST_PROC_NATIVE_V2(__half);
INSTANTIATE_POST_PROC_NATIVE_V2(__nv_bfloat16);

} // namespace kernels
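Both launchers size their blocks the same way: one correction warp plus one thread per 16 bytes of a single head's kv_lora_rank slice. A quick sanity check of that formula in Python (a sketch only; WARP_SIZE = 32 is an assumption matching the usual CUDA warp size):

    WARP_SIZE = 32  # assumed warp size

    def block_dim(kv_lora_rank: int, elem_size_bytes: int) -> int:
        # One warp computes correction values; every other thread handles a 16-byte vector.
        return WARP_SIZE + kv_lora_rank * elem_size_bytes // 16

    assert block_dim(512, 2) == 96  # kv_lora_rank=512, fp16/bf16
    assert block_dim(128, 2) == 48  # smaller rank, fewer worker threads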
@@ -43,8 +43,17 @@ struct HelixPostProcParams
template <typename T>
void helixPostProcess(HelixPostProcParams<T> const& params, cudaStream_t stream);

// Version 1: cp_dim=2 layout.
// gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank].
// gathered_stats: [num_tokens, num_heads, cp_size, 2].
template <typename T>
void helixPostProcessNative(HelixPostProcParams<T> const& params, cudaStream_t stream);
void helixPostProcessNativeV1(HelixPostProcParams<T> const& params, cudaStream_t stream);

// Version 2: cp_dim=1 layout.
// gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank].
// gathered_stats: [num_tokens, cp_size, num_heads, 2].
template <typename T>
void helixPostProcessNativeV2(HelixPostProcParams<T> const& params, cudaStream_t stream);

} // namespace kernels
@@ -100,20 +100,14 @@ torch::Tensor helix_post_process(torch::Tensor const& gathered_o, torch::Tensor
}

template <typename T, typename Fn>
inline torch::Tensor helix_post_process_native_impl(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int cp_dim, Fn fn)
inline torch::Tensor helix_post_process_native_impl_v1(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
{
CHECK_TH_CUDA(gathered_o);
CHECK_CONTIGUOUS(gathered_o);
CHECK_TH_CUDA(gathered_stats);
CHECK_CONTIGUOUS(gathered_stats);

// Only cp_dim=2 is supported
TORCH_CHECK(cp_dim == 2,
"cp_dim must be 2. Expects tensor shapes to be: \n"
"gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank], \n"
"gathered_stats: [num_tokens, num_heads, cp_size, 2]");

// For cp_dim=2: tokens_dim=0, heads_dim=1
auto tokens_dim = 0;
auto heads_dim = 1;
@@ -126,14 +120,14 @@ inline torch::Tensor helix_post_process_native_impl(
auto const cp_size = gathered_stats.sizes()[2];
auto const kv_lora_rank = gathered_o.sizes()[3];

// check remaining input tensor dimensions
// Check remaining input tensor dimensions.
TORCH_CHECK(gathered_o.sizes()[2] == cp_size, "gathered_o cp_size dim must match");
TORCH_CHECK(gathered_o.sizes()[tokens_dim] == num_tokens, "gathered_o tokens_dim must match num_tokens");
TORCH_CHECK(gathered_o.sizes()[heads_dim] == num_heads, "gathered_o heads_dim must match num_heads");

TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");

// Check data types
// Check data types.
TORCH_CHECK(
gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
"gathered_o must be half or bfloat16");
@@ -143,15 +137,75 @@ inline torch::Tensor helix_post_process_native_impl(
// memcpy)
TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");

// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16
// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");

// Create output tensor
// Create output tensor.
std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
torch::Tensor output = torch::empty(output_shape, gathered_o.options());

// Get CUDA stream
// Get CUDA stream.
auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());

tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
reinterpret_cast<T const*>(gathered_o.data_ptr()), reinterpret_cast<float2 const*>(gathered_stats.data_ptr()),
static_cast<int>(cp_size), static_cast<int>(num_tokens), static_cast<int>(num_heads),
static_cast<int>(kv_lora_rank)};
fn(params, stream);

if (scale != 1.0)
{
output *= scale;
}

return output;
}

template <typename T, typename Fn>
inline torch::Tensor helix_post_process_native_impl_v2(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, Fn fn)
{
CHECK_TH_CUDA(gathered_o);
CHECK_CONTIGUOUS(gathered_o);
CHECK_TH_CUDA(gathered_stats);
CHECK_CONTIGUOUS(gathered_stats);

// For cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank]
// gathered_stats: [num_tokens, cp_size, num_heads, 2]
TORCH_CHECK(gathered_o.dim() == 4, "gathered_o must be 4D tensor [num_tokens, cp_size, num_heads, kv_lora_rank]");
TORCH_CHECK(gathered_stats.dim() == 4, "gathered_stats must be 4D tensor [num_tokens, cp_size, num_heads, 2]");

auto const num_tokens = gathered_o.sizes()[0];
auto const cp_size = gathered_o.sizes()[1];
auto const num_heads = gathered_o.sizes()[2];
auto const kv_lora_rank = gathered_o.sizes()[3];

// Check remaining input tensor dimensions.
TORCH_CHECK(gathered_stats.sizes()[0] == num_tokens, "gathered_stats num_tokens dim must match");
TORCH_CHECK(gathered_stats.sizes()[1] == cp_size, "gathered_stats cp_size dim must match");
TORCH_CHECK(gathered_stats.sizes()[2] == num_heads, "gathered_stats num_heads dim must match");
TORCH_CHECK(gathered_stats.sizes()[3] == 2, "gathered_stats last dimension must be 2");

// Check data types.
TORCH_CHECK(
gathered_o.scalar_type() == at::ScalarType::Half || gathered_o.scalar_type() == at::ScalarType::BFloat16,
"gathered_o must be half or bfloat16");
TORCH_CHECK(gathered_stats.scalar_type() == at::ScalarType::Float, "gathered_stats must be float32");

// Check alignment requirements for gathered_o (16-byte aligned for async
// memcpy).
TORCH_CHECK(reinterpret_cast<uintptr_t>(gathered_o.data_ptr()) % 16 == 0, "gathered_o must be 16-byte aligned");

// Check that kv_lora_rank * sizeof(data_type) is a multiple of 16.
size_t data_type_size = torch::elementSize(gathered_o.scalar_type());
TORCH_CHECK((kv_lora_rank * data_type_size) % 16 == 0, "kv_lora_rank * sizeof(data_type) must be a multiple of 16");

// Create output tensor.
std::vector<int64_t> output_shape = {num_tokens, num_heads * kv_lora_rank};
torch::Tensor output = torch::empty(output_shape, gathered_o.options());

// Get CUDA stream.
auto stream = at::cuda::getCurrentCUDAStream(gathered_o.get_device());

tensorrt_llm::kernels::HelixPostProcParams<T> params{reinterpret_cast<T*>(output.mutable_data_ptr()),
@@ -171,24 +225,54 @@ inline torch::Tensor helix_post_process_native_impl(
inline torch::Tensor helix_post_process_native(
torch::Tensor const& gathered_o, torch::Tensor const& gathered_stats, double scale, int64_t cp_dim)
{
TORCH_CHECK(cp_dim == 2, "cp_dim must be 2. Only cp_dim=2 layout is supported.");
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl<__half>(
gathered_o, gathered_stats, scale, int(cp_dim), tensorrt_llm::kernels::helixPostProcessNative<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
TORCH_CHECK(cp_dim == 1 || cp_dim == 2,
"cp_dim must be 1 or 2. \n"
"cp_dim=1: gathered_o: [num_tokens, cp_size, num_heads, kv_lora_rank] \n"
"cp_dim=2: gathered_o: [num_tokens, num_heads, cp_size, kv_lora_rank]");

if (cp_dim == 1)
{
// Version 2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl_v2<__half>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
return helix_post_process_native_impl<__nv_bfloat16>(gathered_o, gathered_stats, scale, int(cp_dim),
tensorrt_llm::kernels::helixPostProcessNative<__nv_bfloat16>);
return helix_post_process_native_impl_v2<__nv_bfloat16>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV2<__nv_bfloat16>);
#else
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
#endif
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
}
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
// Version 1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
if (gathered_o.scalar_type() == at::ScalarType::Half)
{
return helix_post_process_native_impl_v1<__half>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__half>);
}
else if (gathered_o.scalar_type() == at::ScalarType::BFloat16)
{
#ifdef ENABLE_BF16
return helix_post_process_native_impl_v1<__nv_bfloat16>(
gathered_o, gathered_stats, scale, tensorrt_llm::kernels::helixPostProcessNativeV1<__nv_bfloat16>);
#else
TLLM_THROW("BFloat16 must be enabled to use helix_post_process_native with bf16 tensors.");
#endif
}
else
{
TLLM_THROW("helix_post_process_native only supports half and bfloat16 tensors.");
}
}
}
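As a usage sketch of the updated operator (shapes are the point here, the stats values are not meaningful; importing tensorrt_llm to register the trtllm torch ops is assumed, as the unit tests below do):

    import torch
    import tensorrt_llm  # registers the torch.ops.trtllm ops

    num_tokens, cp_size, num_heads, kv_lora_rank = 8, 4, 2, 64

    # V2 path, cp_dim=1: gathered_o is [num_tokens, cp_size, num_heads, kv_lora_rank].
    gathered_o = torch.randn(num_tokens, cp_size, num_heads, kv_lora_rank,
                             dtype=torch.float16, device="cuda")
    gathered_stats = torch.randn(num_tokens, cp_size, num_heads, 2,
                                 dtype=torch.float32, device="cuda")
    out_v2 = torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)

    # V1 path, cp_dim=2: gathered_o is [num_tokens, num_heads, cp_size, kv_lora_rank].
    out_v1 = torch.ops.trtllm.helix_post_process_native(
        gathered_o.transpose(1, 2).contiguous(),
        gathered_stats.transpose(1, 2).contiguous(), 1.0, 2)

    # Both calls return a [num_tokens, num_heads * kv_lora_rank] tensor.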
@@ -1195,29 +1195,63 @@ class MLA(nn.Module):
num_tokens = partial_o.shape[0]
cp_size = self.mapping.cp_size

# Reshape for FIFO-based all-to-all.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
# Check which FIFO version to use (default: version 2 for better performance).
fifo_version = self.mapping.cp_config.get("fifo_version", 2)

partial_o = partial_o.view(
num_tokens, cp_size, self.num_heads_tp_cp,
kv_lora_rank).transpose(1, 2).contiguous()
softmax_stats = softmax_stats.view(num_tokens, cp_size,
self.num_heads_tp_cp,
2).transpose(1,
2).contiguous()
if fifo_version == 1:
# Version 1: Uses transpose+contiguous before alltoall, cp_dim=2.
# Reshape for FIFO-based all-to-all. Overlap the two .contiguous() calls on separate streams.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
partial_o, softmax_stats = maybe_execute_in_parallel(
lambda: partial_o.view(num_tokens, cp_size, self.
num_heads_tp_cp, kv_lora_rank).
transpose(1, 2).contiguous(),
lambda: softmax_stats.view(num_tokens, cp_size, self.
num_heads_tp_cp, 2).
transpose(1, 2).contiguous(),
self.ln_events[0],
self.ln_events[1],
self.aux_stream,
)

# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)
# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)

# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
# cp_dim = 2 (the dimension where cp_size is located)
# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
# cp_dim = 2 (the dimension where cp_size is located)

# Call helix_post_process_native with cp_dim=2.
return torch.ops.trtllm.helix_post_process_native(
partial_o_out, softmax_stats_out, 1.0, 2)
# Call helix_post_process_native V1 (cp_dim=2).
return torch.ops.trtllm.helix_post_process_native(
partial_o_out, softmax_stats_out, 1.0, 2)
else:
# Version 2: Uses simple view (no transpose+contiguous) for better performance.
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank]
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp * 2]
partial_o = partial_o.view(
num_tokens, cp_size,
self.num_heads_tp_cp * kv_lora_rank)
softmax_stats = softmax_stats.view(num_tokens, cp_size,
self.num_heads_tp_cp * 2)

# Call FIFO-based helixAllToAll.
partial_o_out, softmax_stats_out = helix.alltoall_native(
partial_o, softmax_stats)

# Reshape after alltoall for post-processing.
# partial_o_out: [num_tokens, cp_size, num_heads_tp_cp * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
# softmax_stats_out: [num_tokens, cp_size, num_heads_tp_cp * 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
gathered_o = partial_o_out.view(num_tokens, cp_size,
self.num_heads_tp_cp,
kv_lora_rank)
gathered_stats = softmax_stats_out.view(
num_tokens, cp_size, self.num_heads_tp_cp, 2)

# Call helix_post_process_native V2 (cp_dim=1).
return torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, 1.0, 1)
else:
attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
return attn_output
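The FIFO version is selected through cp_config. A minimal sketch of the relevant configuration, mirroring what the accuracy test below builds (values here are illustrative):

    cp_config = {
        "cp_type": "HELIX",
        "tokens_per_block": 32,
        "use_nccl_for_alltoall": False,  # take the FIFO-based all-to-all path
        "fifo_version": 2,  # default; set to 1 to force the V1 (cp_dim=2) post-processing path
    }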
@@ -982,10 +982,21 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cudagraph:none", "cudagraph:without_padding",
"cudagraph:with_padding"
])
@pytest.mark.parametrize("comms_medium", ["fifo", "nccl"])
@pytest.mark.parametrize("comms_medium", ["fifo_v1", "fifo_v2", "nccl"])
def test_auto_dtype_with_helix(self, comms_medium, cuda_graph_config,
gen_pp, gen_tp, gen_cp, enable_attention_dp):
use_nccl_for_alltoall = comms_medium == "nccl"
# Parse comms_medium to get use_nccl_for_alltoall and fifo_version.
if comms_medium == "nccl":
use_nccl_for_alltoall = True
fifo_version = 2 # Not used when NCCL is enabled.
elif comms_medium == "fifo_v1":
use_nccl_for_alltoall = False
fifo_version = 1
elif comms_medium == "fifo_v2":
use_nccl_for_alltoall = False
fifo_version = 2
else:
raise ValueError(f"Unknown comms_medium: {comms_medium}")
gen_ep = gen_tp * gen_cp
kv_cache_config = {
"free_gpu_memory_fraction": 0.5,
@@ -1014,7 +1025,8 @@ class TestDeepSeekV3Lite(LlmapiAccuracyTestHarness):
"cp_config": {
"cp_type": "HELIX",
"tokens_per_block": 32,
"use_nccl_for_alltoall": use_nccl_for_alltoall
"use_nccl_for_alltoall": use_nccl_for_alltoall,
"fifo_version": fifo_version,
},
"disable_overlap_scheduler": True,
"kv_cache_config": kv_cache_config,
@@ -293,13 +293,17 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
@@ -75,10 +75,10 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v2-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_mtp] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_bs8_mtp] TIMEOUT (60)
@@ -108,8 +108,8 @@ l0_dgx_b200:
backend: pytorch
orchestrator: mpi
tests:
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp2tp1cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[fifo_v1-cudagraph:with_padding-pp1tp2cp2] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1tp1cp4] TIMEOUT (60)
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix[nccl-cudagraph:with_padding-pp1dp2cp2] TIMEOUT (60)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus_corner_case TIMEOUT (60)
@@ -649,7 +649,8 @@ def _full_test_multi_gpu(
world_size: int,
scenario: Scenario,
gen_steps: int,
comms_medium: str = False,
use_nccl_for_alltoall: bool = False,
fifo_version: int = 2,
):
if scenario.rope_scaling:
rope_scaling = {
@@ -825,7 +826,8 @@ def _full_test_multi_gpu(
cp_size=world_size,
cp_config={
"cp_type": CpType.HELIX,
"use_nccl_for_alltoall": comms_medium == "nccl",
"use_nccl_for_alltoall": use_nccl_for_alltoall,
"fifo_version": fifo_version,
},
)
# we use cp_allgather here because there is no broadcast op across CP group
@@ -861,7 +863,7 @@ def _run_single_rank(func, *args, **kwargs):

@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="needs 2 GPUs to run this test")
@pytest.mark.parametrize("scenario", test_scenarios, ids=lambda x: f"scenario: {x}")
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo"])
@pytest.mark.parametrize("comms_medium", ["nccl", "fifo_v1", "fifo_v2"])
def test_mla_helix_distributed(
scenario: Scenario,
comms_medium: str,
@@ -872,11 +874,34 @@ def test_mla_helix_distributed(
world_size = 2
print(f"Testing with comms_medium={comms_medium}.")
gen_steps = scenario.ref_steps if gen_steps is None else gen_steps

# Parse comms_medium to get use_nccl_for_alltoall and fifo_version.
if comms_medium == "nccl":
use_nccl_for_alltoall = True
fifo_version = 2 # Not used when NCCL is enabled.
elif comms_medium == "fifo_v1":
use_nccl_for_alltoall = False
fifo_version = 1
elif comms_medium == "fifo_v2":
use_nccl_for_alltoall = False
fifo_version = 2
else:
raise ValueError(f"Unknown comms_medium: {comms_medium}")

with MPIPoolExecutor(max_workers=world_size) as executor:
results = executor.map(
_run_single_rank,
*zip(
*[(_full_test_multi_gpu, world_size, scenario, gen_steps, comms_medium == "nccl")]
*[
(
_full_test_multi_gpu,
world_size,
scenario,
gen_steps,
use_nccl_for_alltoall,
fifo_version,
)
]
* world_size
),
)
@@ -888,7 +913,7 @@ def test_mla_helix_distributed(


if __name__ == "__main__":
for comms_medium in ["fifo", "nccl"]:
for comms_medium in ["fifo_v1", "fifo_v2", "nccl"]:
print(f"\n{'=' * 60}")
print(f"Testing with comms_medium={comms_medium}")
print(f"{'=' * 60}\n")
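The comms_medium parsing above is duplicated between this test and the accuracy test; a possible shared helper (a sketch only, the name parse_comms_medium is hypothetical):

    def parse_comms_medium(comms_medium: str) -> tuple[bool, int]:
        """Map a comms_medium test id to (use_nccl_for_alltoall, fifo_version)."""
        if comms_medium == "nccl":
            return True, 2  # fifo_version is ignored when NCCL is used
        if comms_medium == "fifo_v1":
            return False, 1
        if comms_medium == "fifo_v2":
            return False, 2
        raise ValueError(f"Unknown comms_medium: {comms_medium}")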
@@ -22,22 +22,40 @@ from parameterized import parameterized
import tensorrt_llm


def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=False):
def baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native_v1=False, native_v2=False):
"""Reference implementation (libtorch)

Args:
gathered_o: Input tensor
- native=False: [cp_size, num_tokens, num_heads * kv_lora_rank]
- native=True: [num_tokens, num_heads, cp_size, kv_lora_rank]
- native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads * kv_lora_rank]
- native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, kv_lora_rank]
- native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, kv_lora_rank]
gathered_stats: Stats tensor
- native=False: [cp_size, num_tokens, num_heads, 2]
- native=True: [num_tokens, num_heads, cp_size, 2]
- native_v1=False, native_v2=False: [cp_size, num_tokens, num_heads, 2]
- native_v1=True (cp_dim=2): [num_tokens, num_heads, cp_size, 2]
- native_v2=True (cp_dim=1): [num_tokens, cp_size, num_heads, 2]
kv_lora_rank: KV LoRA rank
scale: Scale factor
native: Whether to use native layout (cp_dim=2)
native_v1: Whether to use native V1 layout (cp_dim=2).
native_v2: Whether to use native V2 layout (cp_dim=1).
"""
if native:
# Native layout: cp_dim=2
if native_v2:
# Native V2 layout: cp_dim=1.
# [num_tokens, cp_size, num_heads, 2] -> reduce over cp_size (dim=1).
global_max = gathered_stats[..., 0].max(dim=1, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
corrected_max_exp = torch.exp(corrected_max)
corrected_sum = gathered_stats[..., 1] * corrected_max_exp
global_sum = corrected_sum.sum(dim=1, keepdim=True)
correction = (gathered_stats[..., 1] * corrected_max_exp / global_sum).unsqueeze(-1)
gathered_o_fp32 = gathered_o.to(torch.float32)
corrected_o = gathered_o_fp32 * correction
# Sum over cp_size dimension (dim=1), result: [num_tokens, num_heads, kv_lora_rank]
corrected_o = corrected_o.sum(dim=1)
# Reshape to [num_tokens, num_heads * kv_lora_rank]
corrected_o = corrected_o.view(corrected_o.shape[0], -1)
elif native_v1:
# Native V1 layout: cp_dim=2.
# [num_tokens, num_heads, cp_size]
global_max = gathered_stats[..., 0].max(dim=-1, keepdim=True)[0]
corrected_max = gathered_stats[..., 0] - global_max
@@ -75,27 +93,54 @@ class TestHelixPostProcess(unittest.TestCase):
torch.cuda.manual_seed(42)

def _test_helix_postprocess(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native=False
self,
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=False,
native_v2=False,
):
"""Test helix postprocessing with given parameters
"""Test helix postprocessing with given parameters.

Args:
cp_size: Context parallelism size
num_tokens: Number of tokens
num_heads: Number of attention heads
kv_lora_rank: KV LoRA rank
scale: Scale factor
dtype: Data type (float16 or bfloat16)
native: Whether to use native layout (cp_dim=2)
cp_size: Context parallelism size.
num_tokens: Number of tokens.
num_heads: Number of attention heads.
kv_lora_rank: KV LoRA rank.
scale: Scale factor.
dtype: Data type (float16 or bfloat16).
native_v1: Whether to use native V1 layout (cp_dim=2).
native_v2: Whether to use native V2 layout (cp_dim=1).
"""
device = torch.device("cuda")

if native:
# Native layout: [num_tokens, num_heads, cp_size, kv_lora_rank]
if native_v2:
# Native V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
gathered_o = torch.empty(
num_tokens, cp_size, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [num_tokens, cp_size, num_heads, 2].
gathered_stats = torch.empty(
num_tokens, cp_size, num_heads, 2, dtype=torch.float32, device=device
)
gathered_o_max = torch.max(gathered_o, dim=-1, keepdim=True)[0]
gathered_stats[..., 0] = gathered_o_max[..., 0]
gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum

# Call the custom operator with cp_dim=1 (V2).
output = torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, scale, 1
)
elif native_v1:
# Native V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
gathered_o = torch.empty(
num_tokens, num_heads, cp_size, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [num_tokens, num_heads, cp_size, 2]
# gathered_stats: [num_tokens, num_heads, cp_size, 2].
gathered_stats = torch.empty(
num_tokens, num_heads, cp_size, 2, dtype=torch.float32, device=device
)
@@ -104,16 +149,16 @@ class TestHelixPostProcess(unittest.TestCase):
gathered_o_sum = torch.sum(torch.exp(gathered_o - gathered_o_max), dim=-1)
gathered_stats[..., 1] = gathered_o_sum

# Call the custom operator with cp_dim=2
# Call the custom operator with cp_dim=2 (V1).
output = torch.ops.trtllm.helix_post_process_native(
gathered_o, gathered_stats, scale, 2
)
else:
# Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank]
# Original layout: [cp_size, num_tokens, num_heads, kv_lora_rank].
gathered_o_init = torch.empty(
cp_size, num_tokens, num_heads, kv_lora_rank, dtype=dtype, device=device
).uniform_(-1, 1)
# gathered_stats: [cp_size, num_tokens, num_heads, 2]
# gathered_stats: [cp_size, num_tokens, num_heads, 2].
gathered_stats = torch.empty(
cp_size, num_tokens, num_heads, 2, dtype=torch.float32, device=device
)
@@ -124,80 +169,124 @@ class TestHelixPostProcess(unittest.TestCase):

gathered_o = gathered_o_init.view(cp_size, num_tokens, num_heads * kv_lora_rank)

# Call the custom operator
# Call the custom operator.
output = torch.ops.trtllm.helix_post_process(gathered_o, gathered_stats, scale)

# Compute baseline
expected_output = baseline(gathered_o, gathered_stats, kv_lora_rank, scale, native=native)
# Compute baseline.
expected_output = baseline(
gathered_o,
gathered_stats,
kv_lora_rank,
scale,
native_v1=native_v1,
native_v2=native_v2,
)

# Compare results
# Compare results.
torch.testing.assert_close(output, expected_output, atol=1e-3, rtol=1e-2)
@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(4, 8, 2, 64, 1.0, torch.float16, False),
(8, 16, 4, 128, 0.5, torch.float16, False),
(16, 32, 8, 256, 2.0, torch.float16, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, False),
(4, 8, 2, 64, 1.0, torch.float16, True),
(8, 16, 4, 128, 0.5, torch.float16, True),
(16, 32, 8, 256, 2.0, torch.float16, True),
(4, 8, 2, 64, 1.0, torch.bfloat16, True),
(8, 16, 4, 128, 0.5, torch.bfloat16, True),
(16, 32, 8, 256, 2.0, torch.bfloat16, True),
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
# Original layout.
(4, 8, 2, 64, 1.0, torch.float16, False, False),
(8, 16, 4, 128, 0.5, torch.float16, False, False),
(16, 32, 8, 256, 2.0, torch.float16, False, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, False, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, False, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, False, False),
# Native V1 layout (cp_dim=2).
(4, 8, 2, 64, 1.0, torch.float16, True, False),
(8, 16, 4, 128, 0.5, torch.float16, True, False),
(16, 32, 8, 256, 2.0, torch.float16, True, False),
(4, 8, 2, 64, 1.0, torch.bfloat16, True, False),
(8, 16, 4, 128, 0.5, torch.bfloat16, True, False),
(16, 32, 8, 256, 2.0, torch.bfloat16, True, False),
# Native V2 layout (cp_dim=1).
(4, 8, 2, 64, 1.0, torch.float16, False, True),
(8, 16, 4, 128, 0.5, torch.float16, False, True),
(16, 32, 8, 256, 2.0, torch.float16, False, True),
(4, 8, 2, 64, 1.0, torch.bfloat16, False, True),
(8, 16, 4, 128, 0.5, torch.bfloat16, False, True),
(16, 32, 8, 256, 2.0, torch.bfloat16, False, True),
]
)
def test_helix_postprocess_basic(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test basic helix postprocessing functionality"""
"""Test basic helix postprocessing functionality."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)

@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
# Edge cases for non-native layout
(1, 1, 1, 16, 1.0, torch.float16, False), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, False), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, False), # Single token
(4, 8, 1, 16, 1.0, torch.float16, False), # Single head
(4, 8, 2, 2048, 1.0, torch.float16, False), # Large kv_lora_rank
# Edge cases for native layout
(1, 1, 1, 16, 1.0, torch.float16, True), # Minimal sizes
(256, 1, 1, 16, 1.0, torch.float16, True), # Max cp_size
(128, 1, 1, 16, 1.0, torch.float16, True), # Single token
(4, 8, 1, 16, 1.0, torch.float16, True), # Single head
# Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
# Edge cases for original layout.
(1, 1, 1, 16, 1.0, torch.float16, False, False), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, False, False), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, False, False), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, False, False), # Single head.
(4, 8, 2, 2048, 1.0, torch.float16, False, False), # Large kv_lora_rank.
# Edge cases for native V1 layout.
(1, 1, 1, 16, 1.0, torch.float16, True, False), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, True, False), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, True, False), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, True, False), # Single head.
# Note: Large kv_lora_rank (2048) exceeds MAX_KV_LORA_BYTES for native kernel.
# Edge cases for native V2 layout.
(1, 1, 1, 16, 1.0, torch.float16, False, True), # Minimal sizes.
(256, 1, 1, 16, 1.0, torch.float16, False, True), # Max cp_size.
(128, 1, 1, 16, 1.0, torch.float16, False, True), # Single token.
(4, 8, 1, 16, 1.0, torch.float16, False, True), # Single head.
]
)
def test_helix_postprocess_edge_cases(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test edge cases with minimal dimensions"""
"""Test edge cases with minimal dimensions."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)

@parameterized.expand(
[
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native)
(16, 16, 64, 512, 1.0, torch.float16, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, False),
(16, 16, 64, 512, 1.0, torch.float16, True),
(16, 16, 64, 512, 1.0, torch.bfloat16, True),
# (cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2)
(16, 16, 64, 512, 1.0, torch.float16, False, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, False, False),
(16, 16, 64, 512, 1.0, torch.float16, True, False),
(16, 16, 64, 512, 1.0, torch.bfloat16, True, False),
(16, 16, 64, 512, 1.0, torch.float16, False, True),
(16, 16, 64, 512, 1.0, torch.bfloat16, False, True),
]
)
def test_helix_postprocess_large_inputs(
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
self, cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native_v1, native_v2
):
"""Test with larger inputs to ensure performance and correctness"""
"""Test with larger inputs to ensure performance and correctness."""
self._test_helix_postprocess(
cp_size, num_tokens, num_heads, kv_lora_rank, scale, dtype, native
cp_size,
num_tokens,
num_heads,
kv_lora_rank,
scale,
dtype,
native_v1=native_v1,
native_v2=native_v2,
)
def test_helix_postprocess_invalid_inputs(self):
@@ -232,14 +321,14 @@ class TestHelixPostProcess(unittest.TestCase):
"""Test error handling for invalid inputs (native layout)"""
device = torch.device("cuda")

# Test with wrong cp_dim (only cp_dim=2 is supported)
# Test with wrong cp_dim (only cp_dim=1 and cp_dim=2 are supported).
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)

with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 0)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 3)

# Test with wrong tensor dimensions (3D instead of 4D)
gathered_o = torch.randn(8, 2, 256, dtype=torch.float16, device=device)
@@ -264,19 +353,20 @@ class TestHelixPostProcess(unittest.TestCase):

@parameterized.expand(
[
# (native,)
(False,),
(True,),
# (layout,) — "nccl", "fifo_v1", "fifo_v2".
("nccl",),
("fifo_v1",),
("fifo_v2",),
]
)
def test_helix_postprocess_alignment_requirements(self, native):
"""Test alignment requirements"""
def test_helix_postprocess_alignment_requirements(self, layout):
"""Test alignment requirements for all layouts."""
device = torch.device("cuda")

# For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment

if native:
# This should work (kv_lora_rank = 64 is multiple of 8)
# For float16 (2 bytes), kv_lora_rank must be multiple of 8 for 16-byte alignment.
if layout == "fifo_v1":
# V1 layout: [num_tokens, num_heads, cp_size, kv_lora_rank].
# This should work (kv_lora_rank = 64 is multiple of 8).
gathered_o = torch.randn(8, 2, 4, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 2, 4, 2, dtype=torch.float32, device=device)

@@ -285,13 +375,30 @@ class TestHelixPostProcess(unittest.TestCase):
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")

# Test with kv_lora_rank that doesn't satisfy alignment requirements
# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(8, 1, 4, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 1, 4, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 2)
elif layout == "fifo_v2":
# V2 layout: [num_tokens, cp_size, num_heads, kv_lora_rank].
# This should work (kv_lora_rank = 64 is multiple of 8).
gathered_o = torch.randn(8, 4, 2, 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 4, 2, 2, dtype=torch.float32, device=device)

try:
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")

# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(8, 4, 1, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(8, 4, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):
torch.ops.trtllm.helix_post_process_native(gathered_o, gathered_stats, 1.0, 1)
else:
# This should work (kv_lora_rank = 64 is multiple of 8)
# NCCL layout.
# This should work (kv_lora_rank = 64 is multiple of 8).
gathered_o = torch.randn(4, 8, 2 * 64, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 2, 2, dtype=torch.float32, device=device)

@@ -300,7 +407,7 @@ class TestHelixPostProcess(unittest.TestCase):
except RuntimeError as e:
pytest.fail(f"Should not raise error for valid alignment: {e}")

# Test with kv_lora_rank that doesn't satisfy alignment requirements
# Test with kv_lora_rank that doesn't satisfy alignment requirements.
gathered_o = torch.randn(4, 8, 4, dtype=torch.float16, device=device)
gathered_stats = torch.randn(4, 8, 1, 2, dtype=torch.float32, device=device)
with pytest.raises(RuntimeError):