/*
 * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/config.h"

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

TRTLLM_NAMESPACE_BEGIN

namespace kernels
{

struct HelixFieldInfo
{
    uint8_t* dataPtr;
    int elementCount; // Number of elements (e.g., kv_lora_rank for field 0, 1 for field 1)
    int elementSize;  // Size of each element in bytes (2 for half, 8 for float2)
    int stride;       // Stride between rows in bytes
};

struct HelixAllToAllParams
{
    HelixFieldInfo sendFields[2];
    HelixFieldInfo recvFields[2];
    int entryCount; // Number of entries per peer rank to process
    uint64_t* workspace;
    int workspaceStrideInU64;
    int cpRank;
    int cpSize;
    int channelCount; // use 0 to auto-compute
    int maxChannelCount;
};

// ============================================================================
// Workspace Management Functions
// ============================================================================

/**
 * Compute the number of channels to use for communication based on cpSize.
 *
 * @param cpSize Number of context parallel ranks
 * @param smCount Number of SMs available (0 = auto-detect)
 * @return Number of channels to use
 */
int computeHelixMaxChannelCount(int cpSize, int smCount = 0);

/**
 * Compute the workspace size required per rank for the all-to-all operation.
 *
 * @param cpSize Number of context parallel ranks
 * @return Size in bytes
 */
size_t computeHelixWorkspaceSizePerRank(int cpSize);

/**
 * Initialize workspace memory for a given rank.
 * Should be called once during setup.
 *
 * @param workspace Pointer to workspace memory (per-rank view)
 * @param cpSize Number of context parallel ranks
 * @param stream CUDA stream for asynchronous operations
 */
void initializeHelixWorkspace(uint64_t* workspace, int cpSize, cudaStream_t stream);

/**
 * Launch the helix all-to-all kernel.
 *
 * @param params Kernel parameters including field info and workspace
 * @param allowVariableField1 Whether to allow variable field 1
 * @param stream CUDA stream for kernel launch
 */
void launchHelixAllToAll(HelixAllToAllParams const& params, bool allowVariableField1, cudaStream_t stream);

} // namespace kernels

TRTLLM_NAMESPACE_END
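
// ----------------------------------------------------------------------------
// Usage sketch (illustrative only, not part of the API). A minimal call
// sequence, assuming the caller has already selected the per-rank CUDA device
// and provides `cpRank`, `cpSize`, `entryCount`, the four HelixFieldInfo
// descriptors, and `stream`; the namespace alias and the value assigned to
// workspaceStrideInU64 below are assumptions, not guarantees of this header.
//
//   namespace tk = tensorrt_llm::kernels; // assumed expansion of the macros
//
//   int const maxChannels = tk::computeHelixMaxChannelCount(cpSize);
//   size_t const wsBytes = tk::computeHelixWorkspaceSizePerRank(cpSize);
//
//   uint64_t* workspace = nullptr;
//   cudaMalloc(&workspace, wsBytes);                          // per-rank view
//   tk::initializeHelixWorkspace(workspace, cpSize, stream);  // once at setup
//
//   tk::HelixAllToAllParams params{};
//   params.sendFields[0] = sendField0; // caller-provided field descriptors
//   params.sendFields[1] = sendField1;
//   params.recvFields[0] = recvField0;
//   params.recvFields[1] = recvField1;
//   params.entryCount = entryCount;
//   params.workspace = workspace;
//   params.workspaceStrideInU64 = static_cast<int>(wsBytes / sizeof(uint64_t));
//   params.cpRank = cpRank;
//   params.cpSize = cpSize;
//   params.channelCount = 0; // 0 = auto-compute
//   params.maxChannelCount = maxChannels;
//
//   tk::launchHelixAllToAll(params, /*allowVariableField1=*/false, stream);
// ----------------------------------------------------------------------------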