/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "tests/unit_tests/kernels/routing/routingTest.h"

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

namespace tk = tensorrt_llm::kernels;
namespace btg = batchedGemm::trtllm::gen;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::tests::kernels::routing;

namespace
{

template <typename T>
class RoutingLlama4KernelTest : public RoutingKernelTest<T>
{
protected:
    using RoutingKernelTest<T>::mSeed;
    using RoutingKernelTest<T>::mStream;
    using RoutingKernelTest<T>::mBufferManager;
    using typename RoutingKernelTest<T>::PackedType;

private:
    // Host-side reference implementation of the top-k expert selection,
    // checked against the kernel output by the base-class test harness.
    void computeTopKExperts(RoutingKernelTestParam const& param) override
    {
        for (int it = 0; it < param.numTokens; ++it)
        {
            std::vector<PackedFloat> expWeightsIdx(param.numExperts);
            std::vector<PackedFloat> expIdx(param.topK);
            for (int ie = 0; ie < param.numExperts; ++ie)
            {
                auto const score
                    = static_cast<float>(bufferCast<T>(*this->mPtrScoresHost)[it * param.numExperts + ie]);
                expWeightsIdx[ie] = PackedFloat{score, static_cast<int16_t>(ie)};
            }

            // Select the top-k scores and indices: sort descending by score,
            // breaking ties by the smaller expert index.
            std::partial_sort_copy(expWeightsIdx.begin(), expWeightsIdx.end(), expIdx.begin(), expIdx.end(),
                [](PackedFloat const& a, PackedFloat const& b)
                { return (a.score > b.score) || (a.score == b.score && a.idx < b.idx); });

            // Apply sigmoid to the top-k scores.
            for (int ie = 0; ie < param.topK; ++ie)
            {
                expIdx[ie].score = 1.F / (1.F + std::exp(-expIdx[ie].score));
            }

            // Convert back to the io dtype and store the top-k expert results in mPtrExpertIdxHost.
            for (int ie = 0; ie < param.topK; ++ie)
            {
                PackedType si{static_cast<T>(expIdx[ie].score), expIdx[ie].idx};
                reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrExpertIdxHost))[it * param.topK + ie] = si;
                if (param.getExpWeights)
                {
                    bufferCast<T>(*this->mPtrExpertWeightsHost)[it * param.topK + ie]
                        = static_cast<T>(expIdx[ie].score);
                }
            }
        }
    }

    void allocateBuffers(RoutingKernelTestParam const& param) override
    {
        RoutingKernelTest<T>::allocateBuffers(param);
        int64_t const scoresSize = param.numTokens * param.numExperts;
        this->mPtrScoresHost = mBufferManager->pinned(ITensor::makeShape({scoresSize}), TRTDataType<T>::value);
        this->mPtrScoresDevice = mBufferManager->gpu(ITensor::makeShape({scoresSize}), TRTDataType<T>::value);
    }

    template <typename RoutingData>
    void setParams(RoutingKernelTestParam const& param, RoutingData& routingData)
    {
        RoutingKernelTest<T>::setCommonParams(param, routingData);
        routingData.mDtypeExpW = btg::Dtype::Bfloat16;
        routingData.mPtrScores = bufferCast<T>(*this->mPtrScoresDevice);
        routingData.mPtrExpertIdx
            = reinterpret_cast<PackedType*>(bufferCast<int8_t>(*this->mPtrExpertIdxDevice));
    }

    void callTestedFunction(
        RoutingKernelTestParam const& param, tensorrt_llm::runtime::ITensor::SharedPtr& workspaceDevice) override
    {
        moe::dev::routing::routingLlama4::Data routingData;
        setParams(param, routingData);
        moe::dev::routing::routingLlama4::run(routingData, mStream->get());
    }
};
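
// The three tests below differ mainly in numTokens; presumably the routing
// kernel picks its launch strategy from the token count, so the small (3),
// medium (10), and large (300) batches exercise the warp-, cluster-, and
// device-level code paths respectively. The cluster-level variant requires
// compute capability 9, since thread-block clusters are a Hopper feature.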

TYPED_TEST_SUITE(RoutingLlama4KernelTest, Bf16Types);

TYPED_TEST(RoutingLlama4KernelTest, WarpLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/3,
        /*numExperts=*/128, /*topK=*/1,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup=*/0, /*topkGroup=*/0, /*routedScalingFactor=*/0.0f,
        /*requiredComputeCapability=*/8);
    this->runTest(param);
}

TYPED_TEST(RoutingLlama4KernelTest, ClusterLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/10,
        /*numExperts=*/128, /*topK=*/1,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup=*/0, /*topkGroup=*/0, /*routedScalingFactor=*/1.0f,
        /*requiredComputeCapability=*/9);
    this->runTest(param);
}

TYPED_TEST(RoutingLlama4KernelTest, DeviceLevelParallelization)
{
    RoutingKernelTestParam param(RoutingMethodType::Llama4, /*numTokens=*/300,
        /*numExperts=*/128, /*topK=*/1,
        /*expertParallelization=*/1, /*expertParallelizationId=*/0,
        /*paddingLog2=*/3, /*localExpertsStrideLog2=*/0,
        /*usePdl=*/true, /*getExpWeights=*/true,
        /*nGroup=*/0, /*topkGroup=*/0, /*routedScalingFactor=*/1.0f,
        /*requiredComputeCapability=*/8);
    this->runTest(param);
}

} // namespace