Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00
[None][fix] Rename: slot_count -> invalid_expert_id (#8783)
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
parent 89e0117097
commit 4c5a8f4ec6
@@ -280,7 +280,7 @@ __global__ void computeCumsumDevice(int* sendCountsCumsum, int* recvCountsCumsum
 }
 
 __global__ void memsetExpertIdsDevice(
-    int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount, int rankCount)
+    int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId, int rankCount)
 {
     int maxTokenCount = maxTokenCountPerRank * rankCount;
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
@@ -291,7 +291,7 @@ __global__ void memsetExpertIdsDevice(
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + totalRecvTokenCount * topK < maxTokenCount * topK;
          i += gridDim.x * blockDim.x)
     {
-        *(expertIds + i + totalRecvTokenCount * topK) = slotCount;
+        *(expertIds + i + totalRecvTokenCount * topK) = invalidExpertId;
     }
 }
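
Context for the rename: memsetExpertIdsDevice pads the unused tail of the expert-id buffer (every entry past totalRecvTokenCount * topK) with a sentinel value, and that sentinel used to be the slot count rather than a dedicated invalid id. Below is a minimal standalone sketch of the same grid-stride fill pattern; fillTailDevice and the sizes in main are illustrative stand-ins, not TensorRT-LLM's API.

// Sketch: fill the unused tail of an expert-id buffer with a sentinel.
// fillTailDevice and the sizes below are illustrative, not TensorRT-LLM code.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fillTailDevice(int* expertIds, int validCount, int totalCount, int invalidExpertId)
{
    // Grid-stride loop: threads cooperatively cover indices [validCount, totalCount).
    for (int i = blockIdx.x * blockDim.x + threadIdx.x; i + validCount < totalCount;
         i += gridDim.x * blockDim.x)
    {
        expertIds[i + validCount] = invalidExpertId;
    }
}

int main()
{
    int const topK = 4, maxTokenCount = 8, validTokenCount = 5;
    int const total = maxTokenCount * topK, valid = validTokenCount * topK;
    int* expertIds = nullptr;
    cudaMallocManaged(&expertIds, total * sizeof(int));
    fillTailDevice<<<2, 128>>>(expertIds, valid, total, /*invalidExpertId=*/-1);
    cudaDeviceSynchronize();
    std::printf("expertIds[%d] = %d\n", valid, expertIds[valid]); // prints -1
    cudaFree(expertIds);
    return 0;
}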
@@ -355,7 +355,7 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
         maxTokenCountPerRank);
 }
 
-void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
+void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId,
     int rankCount, cudaStream_t stream)
 {
     int smCount = tensorrt_llm::common::getMultiProcessorCount();
@@ -364,7 +364,7 @@ void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPer
     dim3 grid(smCount);
 
     launchWithPdlWhenEnabled("memsetExpertIds", memsetExpertIdsDevice, grid, block, 0, stream, expertIds,
-        recvCountsCumsum, maxTokenCountPerRank, topK, slotCount, rankCount);
+        recvCountsCumsum, maxTokenCountPerRank, topK, invalidExpertId, rankCount);
 }
 
 size_t getMoePrepareWorkspaceSize(int epSize)
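
The host wrapper sizes the grid to the SM count so the grid-stride loop covers the buffer in a single wave of blocks. A hedged sketch of that sizing using plain CUDA runtime calls; launchTailFill is a hypothetical wrapper around the fillTailDevice sketch above, whereas the real code goes through tensorrt_llm::common::getMultiProcessorCount() and launchWithPdlWhenEnabled.

// Sketch: size the grid to the SM count for a one-wave grid-stride launch.
// launchTailFill is hypothetical; the real code routes the launch through
// launchWithPdlWhenEnabled to opt into programmatic dependent launch on Hopper.
#include <cuda_runtime.h>

void launchTailFill(int* expertIds, int valid, int total, int invalidExpertId, cudaStream_t stream)
{
    int device = 0, smCount = 0;
    cudaGetDevice(&device);
    // Plain-CUDA equivalent of getMultiProcessorCount().
    cudaDeviceGetAttribute(&smCount, cudaDevAttrMultiProcessorCount, device);
    dim3 block(256);
    dim3 grid(smCount);
    fillTailDevice<<<grid, block, 0, stream>>>(expertIds, valid, total, invalidExpertId);
}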
@@ -80,7 +80,7 @@ void moveIndice(int* sendCountsCumsum, int* recvCountsCumsum, int* sendIndice, i
     int* backwardIndice, int* gatherBackwardIndice, int* recvIndice, int* gatherRecvIndice, int rankId, int rankCount,
     int maxTokenCountPerRank, cudaStream_t stream);
 
-void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int slotCount,
+void memsetExpertIds(int* expertIds, int* recvCountsCumsum, int maxTokenCountPerRank, int topK, int invalidExpertId,
     int epSize, cudaStream_t stream);
 
 size_t getMoePrepareWorkspaceSize(int epSize);
@@ -228,7 +228,7 @@ moePrepareOp(torch::Tensor expertsIds, c10::optional<torch::Tensor> expertsStati
 }
 
 void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum, int64_t maxTokenCountPerRank,
-    int64_t topK, int64_t slotCount, int64_t epSize)
+    int64_t topK, int64_t invalidExpertId, int64_t epSize)
 {
     CHECK_INPUT(expertsIds, torch::kInt32);
     TORCH_CHECK(expertsIds.dim() == 2, "expertsIds must be a 2D tensor");
@@ -243,7 +243,7 @@ void memsetExpertIds(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum
     auto stream = at::cuda::getCurrentCUDAStream();
 
     tensorrt_llm::kernels::moe_prepare::memsetExpertIds(expertsIds.data_ptr<int>(), recvRankCountCumSum.data_ptr<int>(),
-        static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(slotCount),
+        static_cast<int>(maxTokenCountPerRank), static_cast<int>(topK), static_cast<int>(invalidExpertId),
        static_cast<int>(epSize), stream);
 }
@@ -310,7 +310,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
     m.def(
         "memset_expert_ids(Tensor(a!) experts_ids, Tensor recv_rank_count_cumsum, int max_token_count_per_rank, int "
         "top_k, "
-        "int slot_count, int ep_size) -> ()");
+        "int invalid_expert_id, int ep_size) -> ()");
 }
 
 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
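
Because the schema string is the operator's Python-facing contract, the rename has to land in three places at once: this schema, the C++ signature bound to it, and the fake registration below. A hedged sketch of how a schema and its CUDA implementation are paired with the standard torch library macros; the "my_ns" namespace and memsetExpertIdsStub are stand-ins, not the trtllm binding.

// Sketch of the schema/implementation pairing; "my_ns" and the stub are
// illustrative stand-ins for the actual trtllm binding.
#include <torch/library.h>

void memsetExpertIdsStub(torch::Tensor expertsIds, torch::Tensor recvRankCountCumSum,
    int64_t maxTokenCountPerRank, int64_t topK, int64_t invalidExpertId, int64_t epSize)
{
    // Kernel launch elided in this sketch.
}

TORCH_LIBRARY_FRAGMENT(my_ns, m)
{
    // Tensor(a!) marks experts_ids as mutated in place. Schema argument names
    // are what Python keyword calls see, so renaming slot_count here is a
    // user-visible change, not just an internal one.
    m.def(
        "memset_expert_ids(Tensor(a!) experts_ids, Tensor recv_rank_count_cumsum, "
        "int max_token_count_per_rank, int top_k, int invalid_expert_id, int ep_size) -> ()");
}

TORCH_LIBRARY_IMPL(my_ns, CUDA, m)
{
    m.impl("memset_expert_ids", &memsetExpertIdsStub);
}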
@@ -283,7 +283,7 @@ def _register_fake():
 
     @torch.library.register_fake("trtllm::memset_expert_ids")
     def _(experts_ids: torch.Tensor, recv_rank_count_cumsum: torch.Tensor,
-          max_token_count_per_rank: int, top_k: int, slot_count: int,
+          max_token_count_per_rank: int, top_k: int, invalid_expert_id: int,
           ep_size: int):
         pass
@@ -370,7 +370,7 @@ class TRTLLMGenFusedMoE(MoE):
             alltoall_info.recv_rank_count_cumsum,
             max_num_token,
             top_k,
-            self.num_slots,
+            -1,  # Trtllm Gen uses -1 as invalid expert id
             self.ep_size,
         )
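
Why -1 is a safe sentinel at this call site: consumers only need padded entries to fall outside the valid expert range [0, num_experts), and per the comment in the diff the TRTLLM-Gen path expects -1 specifically, whereas the old code passed self.num_slots (one past the last valid slot). A hypothetical consumer-side guard illustrating the invariant:

// Hypothetical consumer-side guard: any id outside [0, numExperts) is treated
// as padding and skipped. Both -1 and numExperts satisfy this check, but -1 is
// the convention the TRTLLM-Gen kernels expect.
inline bool isValidExpert(int expertId, int numExperts)
{
    return expertId >= 0 && expertId < numExperts;
}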