TensorRT-LLM/cpp/tensorrt_llm/kernels/helixAllToAll.h

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

struct HelixFieldInfo
{
    uint8_t* dataPtr; // Base pointer to the field's data
    int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for field 1)
    int elementSize;  // Size of each element in bytes (2 for half, 8 for float2)
    int stride;       // Stride between rows in bytes
};

struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2]; // The two fields to send (see HelixFieldInfo)
    HelixFieldInfo recvFields[2]; // The two fields to receive into
    int entryCount;               // Number of entries per peer rank to process
    uint64_t* workspace;          // Communication workspace (see initializeHelixWorkspace)
    int workspaceStrideInU64;     // Workspace stride, in uint64_t elements
    int cpRank;                   // Rank of this process within the context-parallel group
    int cpSize;                   // Number of context-parallel ranks
    int channelCount;             // Number of channels to use; 0 to auto-compute
    int maxChannelCount;          // Maximum channel count (see computeHelixMaxChannelCount)
};
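
// A minimal population sketch (illustrative only, not part of this header's
// API; `kvSendBuffer`, `auxSendBuffer`, `kvLoraRank`, `entriesPerPeer`, and
// the stride variables are hypothetical placeholders):
//
//   HelixAllToAllParams params{};
//   // Field 0: kvLoraRank half-precision elements per row.
//   params.sendFields[0] = {reinterpret_cast<uint8_t*>(kvSendBuffer), kvLoraRank,
//                           /*elementSize=*/2, kvRowStrideBytes};
//   // Field 1: a single float2 per row.
//   params.sendFields[1] = {reinterpret_cast<uint8_t*>(auxSendBuffer), 1,
//                           /*elementSize=*/8, auxRowStrideBytes};
//   params.recvFields[0] = {reinterpret_cast<uint8_t*>(kvRecvBuffer), kvLoraRank,
//                           /*elementSize=*/2, kvRowStrideBytes};
//   params.recvFields[1] = {reinterpret_cast<uint8_t*>(auxRecvBuffer), 1,
//                           /*elementSize=*/8, auxRowStrideBytes};
//   params.entryCount = entriesPerPeer;
//   params.cpRank = myCpRank;
//   params.cpSize = cpSize;
//   params.channelCount = 0; // auto-compute
//   params.maxChannelCount = computeHelixMaxChannelCount(cpSize);
//   // workspace / workspaceStrideInU64 are set once the workspace has been
//   // allocated and initialized (see the sketches below).
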
// ============================================================================
// Workspace Management Functions
// ============================================================================
/**
 * Compute the maximum number of channels to use for communication, based on cpSize.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Maximum number of channels to use
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);
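
// Usage sketch (illustrative; the cpSize/smCount values are made up):
//
//   int maxChannels = computeHelixMaxChannelCount(/*cpSize=*/8);                 // auto-detect SM count
//   int maxChannelsCapped = computeHelixMaxChannelCount(/*cpSize=*/8, /*smCount=*/64);
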
/**
* Compute the workspace size required per rank for the all-to-all operation.
*
* @param cpSize Number of context parallel ranks
* @return Size in bytes
*/
size_t computeHelixWorkspaceSizePerRank(int cpSize);
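
// Allocation sketch (illustrative; error handling is omitted, and using a
// plain cudaMalloc for the per-rank view is an assumption, since how per-rank
// views are arranged across ranks is outside this header):
//
//   size_t bytesPerRank = computeHelixWorkspaceSizePerRank(cpSize);
//   uint64_t* workspace = nullptr;
//   cudaMalloc(&workspace, bytesPerRank);
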
/**
* Initialize workspace memory for a given rank.
* Should be called once during setup.
*
* @param workspace Pointer to workspace memory (per-rank view)
* @param cpSize Number of context parallel ranks
* @param stream CUDA stream for asynchronous operations
*/
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);
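
// Initialization sketch (illustrative; `workspace` is the per-rank view
// allocated above and `stream` a previously created CUDA stream):
//
//   initializeHelixWorkspace(workspace, cpSize, stream);
//   // One-time setup; the same workspace can then be reused across launches.
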
/**
* Launch the helix all-to-all kernel.
*
* @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow a variable-size field 1
* @param stream CUDA stream for kernel launch
*/
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);
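
// Launch sketch (illustrative; assumes `params` was populated as above, with
// params.workspace / params.workspaceStrideInU64 pointing into the initialized
// workspace):
//
//   launchHelixAllToAll(params, /*allowVariableField1=*/false, stream);
//   cudaStreamSynchronize(stream); // or keep enqueueing work on `stream`
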
} // namespace kernels

TRTLLM_NAMESPACE_END