Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-14 06:27:45 +08:00

/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

struct HelixFieldInfo
{
    uint8_t* dataPtr;
    int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for field 1)
    int elementSize;  // Size of each element in bytes (2 for half, 8 for float2)
    int stride;       // Stride between rows in bytes
};

struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2];
    HelixFieldInfo recvFields[2];
    int entryCount; // Number of entries per peer rank to process
    uint64_t* workspace;
    int workspaceStrideInU64;
    int cpRank;
    int cpSize;
    int channelCount; // use 0 to auto-compute
    int maxChannelCount;
};
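
// The example below is a sketch of how the two fields might be populated,
// assuming an MLA-style layout in which field 0 carries kv_lora_rank half
// elements per row and field 1 carries a single float2 per row (matching the
// comments in HelixFieldInfo). The buffer pointers and kvLoraRank are
// hypothetical names, not part of this header.
//
//   HelixFieldInfo field0;
//   field0.dataPtr = reinterpret_cast<uint8_t*>(kvBuffer); // hypothetical device buffer
//   field0.elementCount = kvLoraRank;                      // e.g., the MLA latent dimension
//   field0.elementSize = 2;                                // half precision
//   field0.stride = kvLoraRank * 2;                        // densely packed rows
//
//   HelixFieldInfo field1;
//   field1.dataPtr = reinterpret_cast<uint8_t*>(ropeBuffer); // hypothetical device buffer
//   field1.elementCount = 1;                                 // one element per row
//   field1.elementSize = 8;                                  // a single float2
//   field1.stride = 8;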

// ============================================================================
// Workspace Management Functions
// ============================================================================

/**
 * Compute the maximum number of channels to use for communication, based on cpSize.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Maximum number of channels to use
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);

/**
 * Compute the workspace size required per rank for the all-to-all operation.
 *
 * @param cpSize Number of context parallel ranks
 * @return Size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize);

/**
 * Initialize workspace memory for a given rank.
 * Should be called once during setup.
 *
 * @param workspace Pointer to workspace memory (per-rank view)
 * @param cpSize Number of context parallel ranks
 * @param stream CUDA stream for asynchronous operations
 */
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);

/**
 * Launch the helix all-to-all kernel.
 *
 * @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow variable field 1
 * @param stream CUDA stream for kernel launch
 */
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);
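
// The sequence below sketches a plausible host-side call order, under stated
// assumptions: a single rank's view, a valid CUDA stream, and send/recv fields
// populated as in the HelixFieldInfo example above. Deriving
// workspaceStrideInU64 from the per-rank workspace size is an assumption, and
// error checking is omitted.
//
//   int const maxChannels = computeHelixMaxChannelCount(cpSize);
//   size_t const bytesPerRank = computeHelixWorkspaceSizePerRank(cpSize);
//
//   uint64_t* workspace = nullptr;
//   cudaMalloc(&workspace, bytesPerRank);                // per-rank view
//   initializeHelixWorkspace(workspace, cpSize, stream); // once, during setup
//
//   HelixAllToAllParams params{};
//   // ... fill params.sendFields / params.recvFields ...
//   params.entryCount = entryCount; // entries per peer rank
//   params.workspace = workspace;
//   params.workspaceStrideInU64 = static_cast<int>(bytesPerRank / sizeof(uint64_t)); // assumption
//   params.cpRank = cpRank;
//   params.cpSize = cpSize;
//   params.channelCount = 0; // 0 = auto-compute
//   params.maxChannelCount = maxChannels;
//
//   launchHelixAllToAll(params, /*allowVariableField1=*/false, stream);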

} // namespace kernels

TRTLLM_NAMESPACE_END