TensorRT-LLMs/cpp/tensorrt_llm/kernels/communicationKernels/moeAlltoAllKernels.h

/*
* Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include "tensorrt_llm/common/config.h"

#include <NvInferRuntime.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>

TRTLLM_NAMESPACE_BEGIN

namespace kernels::moe_comm
{

// Configuration constants
static constexpr int kMaxTopK = 22; // Maximum top-k experts per token
static constexpr int kMaxPayloads = 4; // Maximum number of different payload types
static constexpr int kMaxRanks = 64; // Maximum supported EP size
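// Illustrative sketch (not part of this header's API): launch parameters are expected to
// respect the limits above. A caller might guard them like so, where `params` is a
// hypothetical MoeA2ADispatchParams instance:
//
//   assert(params.top_k <= kMaxTopK);
//   assert(params.ep_size <= kMaxRanks);
//   assert(params.num_payloads <= kMaxPayloads);
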
// Describes a single payload type to be communicated
struct PayloadDescriptor
{
    void const* src_data;   // Source data pointer [local_num_tokens, elements_per_token]
    int element_size;       // Size of each element in bytes
    int elements_per_token; // Number of elements per token (e.g., hidden_size, top_k)
};
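
// Illustrative sketch (hypothetical buffer names, assuming half-precision activations):
// a dispatch typically carries the hidden states plus the per-token expert ids as two
// separate payloads, each described independently.
//
//   PayloadDescriptor hidden_desc{hidden_states /* [local_num_tokens, hidden_size] */,
//       sizeof(half), hidden_size};
//   PayloadDescriptor expert_desc{token_selected_experts /* [local_num_tokens, top_k] */,
//       sizeof(int32_t), top_k};
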
// Kernel pointers packed into a struct for device access
// Dispatch kernel pointers - const source data
struct DispatchKernelPointers
{
    // Payload pointers
    void const* src_data_ptrs[kMaxPayloads];     // Array of source data pointers
    void* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers
    int payload_bytes_per_token[kMaxPayloads];   // Bytes per token for each payload

    // Completion flags for synchronization
    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                           // rank has signaled the target rank
    uint32_t* flag_val;                    // The value of the flag for this round (stored on the local rank)

    // Local aux data pointers
    int* send_counters;            // [ep_size] How many tokens have been sent to each target rank
    int* recv_counters[kMaxRanks]; // How many tokens have been received from each source rank. Each rank has
                                   // [ep_size] counters
    int* local_token_counter;      // Atomic counter for completed tokens

    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
    int* topk_target_ranks; // target rank per k, -1 for duplicates
    int* topk_send_indices; // dst index per k, -1 for duplicates
};
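
// Worked example of the compact routing info (values are illustrative): with top_k = 4,
// a token whose selected experts live on ranks [2, 2, 5, 7] is sent to rank 2 only once,
// so the second slot is marked as a duplicate with -1:
//
//   topk_target_ranks[token] = { 2,  -1, 5, 7 }
//   topk_send_indices[token] = { 14, -1, 3, 9 } // slot in each target rank's recv buffer
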
// Combine kernel pointers - non-const output in src_data_ptrs[0], const recv buffers
struct CombineKernelPointers
{
    // Payload pointers
    void* src_data_ptrs[kMaxPayloads];                 // src_data_ptrs[0] is the output
    void const* recv_buffers[kMaxRanks][kMaxPayloads]; // 2D array of receive buffer pointers (const)

    // Completion flags for synchronization
    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                           // rank has signaled the target rank
    uint32_t* flag_val;                    // The value of the flag for this round (stored on the local rank)

    // Top-K compact routing info per local token (size: [local_num_tokens, top_k])
    int const* topk_target_ranks; // target rank per k, -1 for duplicates
    int const* topk_send_indices; // dst index per k, -1 for duplicates
};
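
// Illustrative device-side sketch of the completion-flag handshake described above (a
// sketch, not the kernel's actual code; assumes the flag arrays live in memory visible
// to all ranks): before reading peer data, a thread on rank `ep_rank` could spin until
// every source rank has published this round's flag value.
//
//   uint32_t const expected = *flag_val;
//   for (int src = 0; src < ep_size; ++src)
//   {
//       while (ld_acquire(&completion_flags[ep_rank][src]) != expected) { /* spin */ }
//   }
//
// where `ld_acquire` stands in for an acquire-ordered volatile/atomic load.
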
// Dispatch phase parameters
struct MoeA2ADispatchParams
{
    // Threading policy
    bool one_block_per_token; // true: one block per token; false: one warp per token

    // EP configuration
    int ep_size;              // Number of EP ranks
    int ep_rank;              // Current EP rank
    int num_experts_per_rank; // Number of experts per rank (num_experts / ep_size)

    // Token configuration
    int local_num_tokens;    // Number of tokens on this rank
    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation. TODO: rename to runtime_max_tokens_per_rank
    int top_k;               // Number of experts per token

    // Expert routing information
    int32_t const* token_selected_experts; // [local_num_tokens, top_k]

    // Generic payloads
    int num_payloads;                         // Number of different payload types
    PayloadDescriptor payloads[kMaxPayloads]; // Array of payload descriptors

    // Local aux data
    uint32_t* flag_val;       // The value of the flag for this round (stored on the local rank)
    int* local_token_counter; // Atomic counter for completed tokens on this rank
    int* send_counters;       // [ep_size] atomic counters - tracks tokens sent to each target rank
    int* topk_target_ranks;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]);
                              // target rank per k, -1 for duplicates
    int* topk_send_indices;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]);
                              // dst index per k, -1 for duplicates

    // Distributed aux data and recv buffers
    int* recv_counters[kMaxRanks];         // Tracks tokens received from each source rank. Each rank has
                                           // [ep_size] counters
    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                           // rank has signaled the target rank
    void* recv_buffers[kMaxRanks][kMaxPayloads]; // Per-rank receive buffers for each payload

    // CUDA stream
    cudaStream_t stream;
};

// Dispatch kernels
void moe_a2a_dispatch_launch(MoeA2ADispatchParams const& params);

// Prepare for dispatch: zeroes send_counters and local_token_counter, and increments flag_val
void moe_a2a_prepare_dispatch_launch(MoeA2ADispatchParams const& params);
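
// Illustrative host-side call sequence (hypothetical setup; workspace allocation,
// peer-buffer exchange and error handling omitted):
//
//   MoeA2ADispatchParams params{};
//   params.one_block_per_token = true;
//   params.ep_size = ep_size;
//   params.ep_rank = ep_rank;
//   ...                                      // fill routing info, payloads, aux pointers
//   params.stream = stream;
//   moe_a2a_prepare_dispatch_launch(params); // reset counters, advance *flag_val
//   moe_a2a_dispatch_launch(params);         // scatter payloads into peer recv_buffers
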
// Combine phase parameters
struct MoeA2ACombineParams
{
    // Threading policy
    bool one_block_per_token; // true: one block per token; false: one warp per token

    // EP configuration
    int ep_size; // Number of EP ranks
    int ep_rank; // Current EP rank

    // Token configuration
    int local_num_tokens;    // Number of tokens on this rank
    int max_tokens_per_rank; // Maximum tokens per rank for pre-allocation. TODO: rename to runtime_max_tokens_per_rank
    int top_k;               // Number of experts per token

    // Prepare-only field: original payload tensor pointer used to stage into the workspace
    void const* prepare_payload;

    // Output tensor
    void* output_data; // Output buffer [local_num_tokens, elements_per_token]

    // Payload information
    int elements_per_token;   // Number of elements per token
    nvinfer1::DataType dtype; // Data type for proper summation

    // Local aux data
    uint32_t* flag_val;       // The value of the flag for this round (stored on the local rank)
    int* topk_target_ranks;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]);
                              // target rank per k, -1 for duplicates
    int* topk_send_indices;   // Top-K compact routing info per local token (size: [local_num_tokens, top_k]);
                              // dst index per k, -1 for duplicates
    int const* recv_counters; // [ep_size] number of valid tokens per source rank for this target

    // Distributed aux data and recv buffers
    uint32_t* completion_flags[kMaxRanks]; // If completion_flags[target_rank][source_rank] == *flag_val, then source
                                           // rank has signaled the target rank
    void const* recv_buffers[kMaxRanks];   // Per-rank receive buffers (only for a single payload)

    // CUDA stream
    cudaStream_t stream;
};

// Combine kernels
void moe_a2a_combine_launch(MoeA2ACombineParams const& params);
void moe_a2a_prepare_combine_launch(MoeA2ACombineParams const& params);
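
// Illustrative host-side call sequence for the combine phase (hypothetical setup,
// mirroring dispatch): the prepare step stages `prepare_payload` into the workspace,
// then the combine kernel reduces each token's top_k partial results into `output_data`.
//
//   MoeA2ACombineParams params{};
//   params.output_data = output;              // [local_num_tokens, elements_per_token]
//   params.elements_per_token = hidden_size;
//   params.dtype = nvinfer1::DataType::kHALF; // assuming half-precision activations
//   ...                                       // fill routing info and aux pointers
//   moe_a2a_prepare_combine_launch(params);
//   moe_a2a_combine_launch(params);
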
// Sanitize expert IDs for invalid tokens
// expert_ids: [ep_size, max_tokens_per_rank, top_k] (int32)
// recv_counters: [ep_size] (int32), number of valid tokens per source
// invalid_id: value to fill for invalid tokens' expert ids
void moe_a2a_sanitize_expert_ids_launch(int32_t* expert_ids, int32_t const* recv_counters, int32_t invalid_id,
    int ep_size, int max_tokens_per_rank, int top_k, cudaStream_t stream);
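
// Illustrative use (hypothetical buffer names): after dispatch, slots at or beyond
// recv_counters[src] in each source's [max_tokens_per_rank, top_k] block hold stale
// expert ids; overwriting them with an out-of-range id, presumably so downstream MoE
// kernels ignore those tokens, could look like:
//
//   moe_a2a_sanitize_expert_ids_launch(
//       recv_expert_ids, recv_counters, /*invalid_id=*/-1, ep_size, max_tokens_per_rank, top_k, stream);
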
} // namespace kernels::moe_comm
TRTLLM_NAMESPACE_END