/* * SPDX-FileCopyrightText: Copyright (c) 2023-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include #include #include "tensorrt_llm/batch_manager/common.h" #include "tensorrt_llm/batch_manager/llmRequest.h" #include "tensorrt_llm/batch_manager/microBatchScheduler.h" #include using namespace tensorrt_llm::runtime; using namespace tensorrt_llm::batch_manager; using namespace tensorrt_llm::batch_manager::batch_scheduler; using namespace tensorrt_llm::executor; using RequestTable = std::map>; using CudaStreamPtr = std::shared_ptr; class MicroBatchSchedulerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-type-member-init) { protected: MicroBatchSchedulerTest() {} void SetUp() override {} void TearDown() override {} static std::shared_ptr createRequest(int32_t promptLen, int32_t maxNewTokens, std::optional optionalReqId, SizeType32 beamWidth = 1, int32_t draftTokensLen = 0) { tensorrt_llm::runtime::SamplingConfig samplingConfig; samplingConfig.beamWidth = beamWidth; uint64_t reqId = optionalReqId.value_or((rand() % INT64_MAX) + 1); auto inputTokens = std::make_shared>(promptLen, 1); std::optional>> draftTokens = std::nullopt; std::optional draftLogits = std::nullopt; if (draftTokensLen > 0) { draftTokens = std::make_shared>(draftTokensLen, 2); draftLogits = BufferManager::cpu( ITensor::makeShape({draftTokensLen, /* vocabSizePadded*/ 42}), nvinfer1::DataType::kFLOAT); } 
return std::make_shared(reqId, maxNewTokens, inputTokens, samplingConfig, /*isStreaming=*/false, /*endId=*/std::nullopt, /*padId=*/std::nullopt, /*embeddingBias=*/std::nullopt, /*badWordsList=*/std::nullopt, /*stopWordsList=*/std::nullopt, /*positionIds=*/std::nullopt, /*promptEmbeddingTable=*/std::nullopt, /*promptVocabSize=*/std::nullopt, /*multimodalHashes=*/std::nullopt, /*multimodalPos=*/std::nullopt, /*multimodalLength=*/std::nullopt, /*multimodalEmbedding=*/std::nullopt, /*mropeRotaryCosSin=*/std::nullopt, /*mropePositionDeltas*/ std::nullopt, /*loraTaskId=*/std::nullopt, /*loraWeights=*/std::nullopt, /*loraConfig=*/std::nullopt, /*lookaheadConfig=*/std::nullopt, /*kvCacheRetentionConfig=*/std::nullopt, /*returnLogProbs=*/false, /*returnContextLogits=*/false, /*returnGenerationLogits=*/false, draftTokens, draftLogits); } RequestTable forward( RequestVector& activeRequests, SizeType32 maxBatchSizeRuntime, std::optional maxNumTokensRuntime) { for (auto const& [reqId, req] : mContextRequests.at(mRuntimeContextId)) { mInflightReqIds.erase(reqId); } auto const [contextRequests, genRequests] = (*mMicroBatchScheduler)(activeRequests, mInflightReqIds, maxBatchSizeRuntime, maxNumTokensRuntime); for (auto const& requests : {contextRequests, genRequests}) { for (auto const& llmReq : requests) { mInflightReqIds.insert(llmReq->mRequestId); } } // ---------------------------------------- // Mock the behavior of TrtModelInflightBatching // ---------------------------------------- // Append a token for all exec requests for (auto const& llmReq : contextRequests) { llmReq->moveToNextContextChunk(); if (!llmReq->getContextRemainingLength()) { llmReq->setState(LlmRequestState::kGENERATION_IN_PROGRESS); llmReq->addNewTokens({mItCount}); } if (llmReq->getNumTokens(0) == llmReq->mPromptLen + llmReq->mMaxNewTokens) { llmReq->setState(LlmRequestState::kGENERATION_COMPLETE); } } for (auto const& llmReq : genRequests) { llmReq->addNewTokens({mItCount}); if (llmReq->getNumTokens(0) == 
llmReq->mPromptLen + llmReq->mMaxNewTokens) { llmReq->setState(LlmRequestState::kGENERATION_COMPLETE); } } // Remove completed requests auto const activeEnd = std::remove_if(activeRequests.begin(), activeRequests.end(), [](auto const& llmReq) { return llmReq->isGenerationCompleteState(); }); activeRequests.erase(activeEnd, activeRequests.end()); auto& currContextRequests = mContextRequests.at(mRuntimeContextId); currContextRequests.clear(); for (auto const& requests : {contextRequests, genRequests}) { for (auto const& llmReq : requests) { currContextRequests.emplace(llmReq->mRequestId, llmReq); } } // Update the context ID mRuntimeContextId = (mRuntimeContextId + 1) % mNumContexts; mItCount++; return currContextRequests; } ReqIdsSet mInflightReqIds; SizeType32 mItCount{0}; SizeType32 mRuntimeContextId{0}; SizeType32 mMaxInputLen = 1000; SizeType32 mNumContexts; std::vector mContextRequests; std::shared_ptr mMicroBatchScheduler; }; TEST_F(MicroBatchSchedulerTest, SimpleNoOverlap) { constexpr SizeType32 maxBatchSize = 2; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 80; constexpr int32_t promptLen = 10; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3)); for (int it = 0; it < 170; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, std::nullopt); if (it < 80) { EXPECT_EQ(newRequests.size(), 2) << " in iteration " << it; EXPECT_NE(newRequests.find(0), newRequests.end()) << " in iteration " << it; EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), it + 1) << " in iteration " << it; EXPECT_NE(newRequests.find(1), newRequests.end()) << " in 
iteration " << it; EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), it + 1) << " in iteration " << it; } else if (it < 160) { EXPECT_EQ(newRequests.size(), 2) << " in iteration " << it; EXPECT_NE(newRequests.find(2), newRequests.end()) << " in iteration " << it; EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), (it - 80 + 1)) << " in iteration " << it; EXPECT_NE(newRequests.find(3), newRequests.end()) << " in iteration " << it; EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), (it - 80 + 1)) << " in iteration " << it; } else { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, SimpleNoOverlapMaxNumTokens) { constexpr SizeType32 maxBatchSize = 2; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; constexpr SizeType32 maxNumTokens = 7; constexpr SizeType32 chunkUnitSize = 5; constexpr ContextChunkingPolicy ctxChunkPolicy{ContextChunkingPolicy::kEQUAL_PROGRESS}; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(ContextChunkingConfig{ctxChunkPolicy, chunkUnitSize}, std::nullopt); constexpr int32_t maxNewTokens = 5; // Use numbers to represent context tokens and letters to represent generated tokens. 
// Req 0: (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, A, B, C, D, E) // Req 1: (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, A, B, C, D, E) RequestVector activeRequests; constexpr int32_t promptLen0 = 12; constexpr int32_t promptLen1 = 12; activeRequests.push_back(createRequest(promptLen0, maxNewTokens, 0)); activeRequests.push_back(createRequest(promptLen1, maxNewTokens, 1)); for (int it = 0; it < 9; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it == 0) { // When it equals 0: // Req 0: (0, 1, 2, 3, 4) // Req 1: () EXPECT_EQ(newRequests.size(), 1); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), 0); EXPECT_EQ(newRequests.at(0)->getContextCurrentPosition(), 5); } else if (it == 1) { // When it equals 1: // Req 0: The last context chunk can be larger than the chunk unit size // and it also satisfies the total token count limit. // Req 0: 0, 1, 2, 3, 4; (5, 6, 7, 8, 9, 10, 11, A) // Req 1: () EXPECT_EQ(newRequests.size(), 1); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), 1); EXPECT_EQ(newRequests.at(0)->getContextCurrentPosition(), promptLen0); } else if (it == 2) { // When it equals 2: // Req 0: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9, 10, 11, A; (B) // Req 1: (0, 1, 2, 3, 4) EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), 2); EXPECT_EQ(newRequests.at(0)->getContextCurrentPosition(), promptLen0); EXPECT_EQ(newRequests.at(1)->getContextCurrentPosition(), chunkUnitSize); } else if (it == 3) { // When it equals 3: // Req 1: Although the last chunk can be larger than the chunk unit size, it // does not meet the total number of tokens, so it is chunked again. 
// Req 0: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9, 10, 11, A; B; (C) // Req 1: 0, 1, 2, 3, 4; (5, 6, 7, 8, 9) EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), 3); EXPECT_EQ(newRequests.at(1)->getContextCurrentPosition(), chunkUnitSize * 2); } else if (it <= 5) { // When it equals 4: // Req 0: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9, A; B; C; D; // Req 1: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9; (10, 11, A) EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), it); EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), it - 3); EXPECT_EQ(newRequests.at(1)->getContextCurrentPosition(), promptLen1); } else if (it <= 8) { // When it equals 6: // Req 0: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9, A; B; C; D; E; // Req 1: 0, 1, 2, 3, 4; 5, 6, 7, 8, 9, 10, 11, A; B; C; (D) EXPECT_EQ(newRequests.size(), 1); EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), it - 3); } } } TEST_F(MicroBatchSchedulerTest, SimpleNoOverlapMaxContextLength) { constexpr SizeType32 maxBatchSize = 2; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; constexpr SizeType32 chunkUnitSize = 5; constexpr SizeType32 maxContextLength = 12; ContextChunkingPolicy ctxChunkPolicy{ContextChunkingPolicy::kEQUAL_PROGRESS}; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(ContextChunkingConfig{ctxChunkPolicy, chunkUnitSize}, maxContextLength); constexpr int32_t maxNewTokens = 5; // Use numbers to represent context tokens and letters to represent generated tokens. 
// Req 0, 1, 2: (0, 1, ..., 9, A, B, C, D, E) // Req 3: (0, 1, ..., 16, A, B, C, D, E) RequestVector activeRequests; activeRequests.push_back(createRequest(10, maxNewTokens, 0)); activeRequests.push_back(createRequest(10, maxNewTokens, 1)); activeRequests.push_back(createRequest(10, maxNewTokens, 2)); activeRequests.push_back(createRequest(17, maxNewTokens, 3)); RequestTable newRequests; for (int it = 0; it < 12; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, std::nullopt); if (it == 0) { // The context for requests 0 and 1 can be processed in one batch. // Req 0, 1: (0, 1, ..., 9, A) // Req 2, 3: () EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(0)->getContextCurrentPosition(), 10); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), 1); EXPECT_EQ(newRequests.at(1)->getContextCurrentPosition(), 10); EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), 1); } else if (it < 5) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), it + 1); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), it + 1); } else if (it == 5) { // Limited by `maxContextLength`, continued chunking is required for request 3. 
// Req 2: (0, 1, ..., 9, A) // Req 3: (0, 1, ..., 9) EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(2)->getContextCurrentPosition(), 10); EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), 1); EXPECT_EQ(newRequests.at(3)->getContextCurrentPosition(), 10); EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), 0); } else if (it == 6) { // Req 2: 0, 1, ..., 9, A; (B) // Req 3: 0, 1, ..., 9; (10, 11, ..., 16, A) EXPECT_EQ(newRequests.size(), 2); EXPECT_EQ(newRequests.at(2)->getContextCurrentPosition(), 10); EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), 2); EXPECT_EQ(newRequests.at(3)->getContextCurrentPosition(), 17); EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), 1); } else if (it <= 9) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), (it - 4)); EXPECT_NE(newRequests.find(3), newRequests.end()); EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), (it - 5)); } else if (it == 10) { // Req 3: 0, 1, ..., 9; 10, 11, ..., 16; A; B; C; D; (E) EXPECT_EQ(newRequests.size(), 1); EXPECT_NE(newRequests.find(3), newRequests.end()); EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), 5); } else { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, SimpleWithOverlap) { constexpr SizeType32 maxBatchSize = 2; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 2; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 80; constexpr int32_t promptLen = 10; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3)); for (int it = 0; it < 170; ++it) { RequestTable newRequests = 
forward(activeRequests, maxBatchSize, std::nullopt); if (it < 160) { if (it % 2 == 1) { // new: 2,3 EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(2), newRequests.end()); // EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), (it+1) / 2); EXPECT_EQ(newRequests.at(2)->getMaxNumGeneratedTokens(), (it / 2) + 1); EXPECT_NE(newRequests.find(3), newRequests.end()); // EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), (it+1) / 2); EXPECT_EQ(newRequests.at(3)->getMaxNumGeneratedTokens(), (it / 2) + 1); } else { // new: 0,1 EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_EQ(newRequests.at(0)->getMaxNumGeneratedTokens(), (it / 2) + 1); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_EQ(newRequests.at(1)->getMaxNumGeneratedTokens(), (it / 2) + 1); } } else { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, SimpleMaxNumTokensBW1) { constexpr SizeType32 maxNumTokens = 12; constexpr SizeType32 maxBatchSize = 4; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 10; constexpr int32_t promptLen = 10; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3, 1)); for (int it = 0; it < 21; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it < 3) { EXPECT_EQ(newRequests.size(), it + 1); } else if (it < 10) { // Due to limit of 12 tokens, we can only have 1 context and 2 gen, or 3 gen // we can't have 1 context + 3 gen EXPECT_EQ(newRequests.size(), 3); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), 
newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); } // At iteration 10, request 0 is done, 3 finally enters else if (it < 11) { EXPECT_EQ(newRequests.size(), 3); EXPECT_EQ(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); } // At iteration 11, request 1 is done else if (it < 12) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); } // By iteration 20, all requests are done else if (it == 20) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, SimpleMaxNumTokensBW4) { constexpr SizeType32 maxNumTokens = 15; constexpr SizeType32 maxBatchSize = 4; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 10; constexpr int32_t promptLen = 10; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 4)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 4)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 4)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3, 4)); // for (int it = 0; it < 22; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it < 2) { EXPECT_EQ(newRequests.size(), it + 1); } // At iteration 2, we should be limited to only 2 requests (since we have 2 gen, each needing 4 input ids, so // adding one more gen would violate constraint) else if (it < 10) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); } // At iteration 10, request 0 is done, 2 enters else if (it < 11) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(1), 
newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); } // At iteration 11, request 1 is done, 3 enters else if (it < 20) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); } else if (it == 20) { EXPECT_EQ(newRequests.size(), 1); EXPECT_NE(newRequests.find(3), newRequests.end()); } else if (it == 21) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, DraftTokensMaxNumTokens) { // This test checks that draft tokens will not cause maxNumTokens to be exceeded. constexpr SizeType32 maxNumTokens = 4096; constexpr SizeType32 maxBatchSize = 64; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 10; constexpr int32_t promptLen = 1024; constexpr int32_t draftLen = 17; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3, 1, draftLen)); for (int it = 0; it < 12; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it == 0) { // Due to draft tokens, only 3 requests can fit. 
EXPECT_EQ(newRequests.size(), 3); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); } else if (it < 10) { EXPECT_EQ(newRequests.size(), 4); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); } else if (it == 10) { EXPECT_EQ(newRequests.size(), 1); EXPECT_NE(newRequests.find(3), newRequests.end()); } else if (it == 11) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, GenDraftTokensMaxNumTokens) { // This test checks that draft tokens will not cause maxNumTokens to be exceeded. constexpr SizeType32 maxNumTokens = 128; constexpr SizeType32 maxBatchSize = 64; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); mMicroBatchScheduler = std::make_shared(); constexpr int32_t maxNewTokens = 2; constexpr int32_t promptLen = 2; constexpr int32_t genDraftLen = 63; RequestVector activeRequests; // No ctx draft tokens. activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 1)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3, 1)); for (int it = 0; it < 4; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it == 0) { EXPECT_EQ(newRequests.size(), 4); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); // Set gen draft tokens. 
newRequests.find(0)->second->setDraftTokens(std::make_shared>(genDraftLen)); newRequests.find(1)->second->setDraftTokens(std::make_shared>(genDraftLen)); newRequests.find(2)->second->setDraftTokens(std::make_shared>(genDraftLen)); newRequests.find(3)->second->setDraftTokens(std::make_shared>(genDraftLen)); } if (it == 1) { // Due to draft tokens, only 2 gen requests can fit EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); } if (it == 2) { EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); } else if (it == 3) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, ChunkedContextDraftTokensMaxNumTokens) { // This test checks that draft tokens will not cause maxNumTokens to be exceeded. constexpr SizeType32 maxNumTokens = 8192; constexpr SizeType32 maxBatchSize = 64; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); auto ctxChunkConfig = batch_scheduler::ContextChunkingConfig{ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED, 64}; mMicroBatchScheduler = std::make_shared(ctxChunkConfig); constexpr int32_t maxNewTokens = 9; constexpr int32_t promptLen = 2041; constexpr int32_t draftLen = 8; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 3, 1, draftLen)); for (int it = 0; it < 10; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it < 9) { // Some draft tokens are discarded so that all requests can fit. 
EXPECT_EQ(newRequests.size(), 4); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_NE(newRequests.find(3), newRequests.end()); EXPECT_EQ(newRequests.find(0)->second->getNumDraftTokens(), 7); EXPECT_EQ(newRequests.find(1)->second->getNumDraftTokens(), 7); EXPECT_EQ(newRequests.find(2)->second->getNumDraftTokens(), 7); EXPECT_EQ(newRequests.find(3)->second->getNumDraftTokens(), 7); } else if (it == 9) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, ChunkedContextDraftTokensMaxContextLength) { // This test checks that draft tokens will not cause maxContextLength to be exceeded. constexpr SizeType32 maxContextLength = 10; constexpr SizeType32 maxBatchSize = 64; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; constexpr SizeType32 maxNumTokens = 8192; mNumContexts = 1; mContextRequests.resize(mNumContexts); auto ctxChunkConfig = batch_scheduler::ContextChunkingConfig{ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED, 64}; mMicroBatchScheduler = std::make_shared(ctxChunkConfig, maxContextLength); constexpr int32_t maxNewTokens = 6; constexpr int32_t promptLen = 6; constexpr int32_t draftLen = 5; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1, draftLen)); for (int it = 0; it < 7; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it < 6) { // Some draft tokens are discarded so that all requests can fit. 
EXPECT_EQ(newRequests.size(), 2); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_EQ(newRequests.find(0)->second->getNumDraftTokens(), 4); EXPECT_EQ(newRequests.find(1)->second->getNumDraftTokens(), 4); } else if (it == 6) { EXPECT_EQ(newRequests.size(), 0); } } } TEST_F(MicroBatchSchedulerTest, DraftTokensGreaterThanChunkSize) { // This test checks that draft tokens are dropped to fit the request within // a chunk unit. // TODO(tmorris): This behavior may not be desired and might be changed soon. constexpr SizeType32 maxNumTokens = 40; constexpr SizeType32 maxBatchSize = 64; constexpr SizeType32 chunkUnitSize = 16; constexpr uint64_t maxSeqIdleMicroseconds = 60 * 1000 * 1000; mNumContexts = 1; mContextRequests.resize(mNumContexts); auto ctxChunkConfig = batch_scheduler::ContextChunkingConfig{ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED, chunkUnitSize}; mMicroBatchScheduler = std::make_shared(ctxChunkConfig, maxBatchSize); constexpr int32_t maxNewTokens = 21; constexpr int32_t promptLen = 3; constexpr int32_t draftLen = 17; RequestVector activeRequests; activeRequests.push_back(createRequest(promptLen, maxNewTokens, 0, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 1, 1, draftLen)); activeRequests.push_back(createRequest(promptLen, maxNewTokens, 2, 1, draftLen)); for (int it = 0; it < 22; ++it) { RequestTable newRequests = forward(activeRequests, maxBatchSize, maxNumTokens); if (it < 21) { EXPECT_EQ(newRequests.size(), 3); EXPECT_NE(newRequests.find(0), newRequests.end()); EXPECT_NE(newRequests.find(1), newRequests.end()); EXPECT_NE(newRequests.find(2), newRequests.end()); EXPECT_EQ(newRequests.find(0)->second->getNumDraftTokens(), 13); EXPECT_EQ(newRequests.find(1)->second->getNumDraftTokens(), 13); EXPECT_EQ(newRequests.find(2)->second->getNumDraftTokens(), 5); } else if (it == 21) { EXPECT_EQ(newRequests.size(), 0); } } } class ContextChunkingTest : public 
MicroBatchSchedulerTest { protected: using Policy = ContextChunkingPolicy; ~ContextChunkingTest() { for (auto ctxChunkPolicy : {Policy::kEQUAL_PROGRESS, Policy::kFIRST_COME_FIRST_SERVED}) { auto reqs = initContextLengths(mLengths, mDraftLengths); auto const& statesMap = mExpectedStates.at(ctxChunkPolicy); EXPECT_GT(statesMap.size(), 0); SizeType32 const endIter = statesMap.rbegin()->first + 1; for (SizeType32 i = 0; i < endIter; ++i) { forward(ctxChunkPolicy, reqs, i, statesMap); } // Check final draft lengths. if (mExpectedDraftLengthsMap.count(ctxChunkPolicy)) { auto const& expectedDraftLengths = mExpectedDraftLengthsMap.at(ctxChunkPolicy); EXPECT_EQ(reqs.size(), expectedDraftLengths.size()); for (size_t i = 0; i < reqs.size(); i++) { EXPECT_EQ(reqs[i]->getNumDraftTokens(), expectedDraftLengths[i]) << "policy = " << ctxChunkPolicy; } } } } struct ChunkState { SizeType32 mContextCurrentPosition; LlmRequest::RequestIdType mRequestId; friend std::ostream& operator<<(std::ostream& os, ChunkState const& obj) { os << "pos = " << obj.mContextCurrentPosition << ", id = " << obj.mRequestId; return os; } }; static RequestVector initContextLengths( std::vector const& lengths, std::vector const& draftLengths) { RequestVector reqs; constexpr SizeType32 maxNewTokens = 1; for (size_t i = 0; i < lengths.size(); ++i) { auto draftLen = draftLengths.size() > 0 ? 
draftLengths[i] : 0; reqs.push_back(createRequest(lengths[i], maxNewTokens, i, /*beamWidth=*/1, draftLen)); } return reqs; } static std::string debugStr(LlmRequest const& val) { std::ostringstream os; os << "pos = " << val.getContextCurrentPosition() << ", id = " << val.mRequestId; return os.str(); } static bool isEqual(LlmRequest const& source, ChunkState const& target) { bool ret = true; ret = ret && (source.mRequestId == target.mRequestId); ret = ret && (source.getContextCurrentPosition() == target.mContextCurrentPosition); return ret; } template void setExpectedPositions(std::vector> const& positions) { std::vector stateVec; for (size_t iter = 0; iter < positions.size(); ++iter) { for (size_t i = 0; i < positions[iter].size(); ++i) { stateVec.emplace_back(ChunkState{positions[iter][i], i}); } mExpectedStates[tPolicy].insert({iter, std::move(stateVec)}); } } void forward(Policy ctxChunkPolicy, RequestVector const& reqs, SizeType32 itCount, std::map> const& statesMap) { // Don't process already completed requests RequestVector activeRequests; std::copy_if(reqs.begin(), reqs.end(), std::back_inserter(activeRequests), [](auto const& llmReq) { return llmReq->getContextRemainingLength() > 0; }); MicroBatchScheduler::setCtxRequestsChunkSize( activeRequests, ctxChunkPolicy, mCtxTokensCapacity, mChunkUnitSize, mMaxContextLength); for (auto const& req : activeRequests) { req->moveToNextContextChunk(); } auto stateIt = statesMap.find(itCount); if (stateIt != statesMap.end()) { std::vector const& states = stateIt->second; EXPECT_EQ(reqs.size(), states.size()); auto reqIt = reqs.begin(); SizeType32 i = 0; while (reqIt != reqs.end()) { EXPECT_TRUE(isEqual(**reqIt, states[i])) << "policy = " << ctxChunkPolicy << "; mItCount = " << mItCount << "; actual: " << debugStr(**reqIt) << "; expect: " << states[i]; ++reqIt; ++i; } } } void setCtxTokenCapacity(SizeType32 ctxTokensCapacity) { mCtxTokensCapacity = ctxTokensCapacity; } void setChunkUnitSize(SizeType32 chunkUnitSize) { 
mChunkUnitSize = chunkUnitSize; } // tail of a method that begins before this chunk (presumably setChunkUnitSize)

    // Stores the maximum context length used by the test into mMaxContextLength.
    void setMaxContextLength(SizeType32 maxContextLength)
    {
        mMaxContextLength = maxContextLength;
    }

    // Stores the per-request context lengths. The vector is taken by value and
    // moved into mLengths.
    void setContextLengths(std::vector lengths)
    {
        mLengths = std::move(lengths);
    }

    // Stores the per-request draft lengths. The vector is taken by value and
    // moved into mDraftLengths.
    void setDraftLengths(std::vector draftLengths)
    {
        mDraftLengths = std::move(draftLengths);
    }

    // Records the expected final draft lengths under the key `tPolicy` in
    // mExpectedDraftLengthsMap, replacing any value previously stored for that key.
    // NOTE(review): `tPolicy` is not declared anywhere in the visible text and
    // `template` has no parameter list here; the angle-bracket contents appear to
    // be missing from this copy (the same applies to `std::vector`, `std::map` and
    // `std::optional` below) — confirm against the original file.
    template void setExpectedFinalDraftLengths(std::vector expectedDraftLengths)
    {
        mExpectedDraftLengthsMap[tPolicy] = std::move(expectedDraftLengths);
    }

private:
    // Test configuration and expectations. The code that reads these members is
    // not visible in this chunk.
    std::vector mLengths;      // written by setContextLengths
    std::vector mDraftLengths; // written by setDraftLengths
    std::map> mExpectedDraftLengthsMap; // written by setExpectedFinalDraftLengths, keyed by tPolicy
    std::map>> mExpectedStates; // not written in this chunk; presumably filled by setExpectedPositions — confirm
    std::optional mCtxTokensCapacity; // not written in this chunk; presumably filled by setCtxTokenCapacity — confirm
    SizeType32 mChunkUnitSize{0}; // defaults to 0 until assigned
    std::optional mMaxContextLength; // written by setMaxContextLength
};

// ---------------------------------------------------------------------------
// ContextChunkingTest cases.
// Each body below only calls fixture setters; no assertion appears inside the
// test bodies in this chunk, so the comparison against the expected values
// presumably happens in fixture code outside this view — confirm.
// NOTE(review): every test calls setExpectedPositions (and, where present,
// setExpectedFinalDraftLengths) twice with no visible template argument. Since
// setExpectedFinalDraftLengths indexes its map by `tPolicy`, the two calls
// presumably target two different policies whose template arguments are missing
// from this copy — confirm against the original file before relying on the order.
// Presumably each inner list passed to setExpectedPositions holds one expected
// position per request for one scheduling step — confirm against the fixture.
// ---------------------------------------------------------------------------

// No context-token capacity and no max context length are configured; the
// single expected row {25, 25} equals the context lengths.
TEST_F(ContextChunkingTest, NoLimit)
{
    setContextLengths({25, 25});
    setChunkUnitSize(20);
    setExpectedPositions({{25, 25}});
    setExpectedPositions({{25, 25}});
}

// Max context length (20) is smaller than both context lengths (25) and the
// chunk unit (100) is larger than both; expected positions remain {0, 0}.
TEST_F(ContextChunkingTest, ContextLengthNeverSatisfied)
{
    setContextLengths({25, 25});
    setMaxContextLength(20);
    setChunkUnitSize(100);
    setExpectedPositions({{0, 0}});
    setExpectedPositions({{0, 0}});
}

// Chunk unit (30) exceeds the context lengths (25), which equal the max
// context length (25); the expected row is the full {25, 25}.
TEST_F(ContextChunkingTest, ChunkLongerThanContext)
{
    setContextLengths({25, 25});
    setMaxContextLength(25);
    setChunkUnitSize(30);
    setExpectedPositions({{25, 25}});
    setExpectedPositions({{25, 25}});
}

// Context lengths {10, 25} with max context length 20 and chunk unit 10:
// expected rows are {10, 20} followed by {10, 25}.
TEST_F(ContextChunkingTest, ContextLengthSatisfied)
{
    setContextLengths({10, 25});
    setMaxContextLength(20);
    setChunkUnitSize(10);
    setExpectedPositions({{10, 20}, {10, 25}});
    setExpectedPositions({{10, 20}, {10, 25}});
}

// Token capacity (20) is below a single context length (25); four expected
// rows are listed, ending at {25, 25}.
TEST_F(ContextChunkingTest, TokenCapacitySmallerThanContext)
{
    setContextLengths({25, 25});
    setChunkUnitSize(20);
    setCtxTokenCapacity(20);
    setExpectedPositions({{20, 0}, {25, 0}, {25, 20}, {25, 25}});
    setExpectedPositions({{20, 0}, {25, 0}, {25, 20}, {25, 25}});
}

// Token capacity (10) is below the chunk unit (20); expected positions remain
// {0, 0}.
TEST_F(ContextChunkingTest, TokenCapacitySmallerThanChunkUnit)
{
    setContextLengths({25, 25});
    setChunkUnitSize(20);
    setCtxTokenCapacity(10);
    setExpectedPositions({{0, 0}});
    setExpectedPositions({{0, 0}});
}

// Capacity 30 with chunk unit 5 and two contexts of length 25: the two
// expectation calls differ in their first row ({15, 15} vs {25, 5}).
TEST_F(ContextChunkingTest, SchedulingOrder)
{
    setContextLengths({25, 25});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setExpectedPositions({{15, 15}, {25, 25}});
    setExpectedPositions({{25, 5}, {25, 25}});
}

// Same capacity and chunk unit as SchedulingOrder, but the second context is
// shorter (15); the two expectation calls again differ only in their first row.
TEST_F(ContextChunkingTest, CompletionOrder)
{
    setContextLengths({25, 15});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setExpectedPositions({{15, 15}, {25, 15}});
    setExpectedPositions({{25, 5}, {25, 15}});
}

// Adds a max context length of 10 to the CompletionOrder configuration; both
// expectation calls list the same three rows.
TEST_F(ContextChunkingTest, LongFirstShortLater)
{
    setContextLengths({25, 15});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setMaxContextLength(10);
    setExpectedPositions({{10, 10}, {20, 15}, {25, 15}});
    setExpectedPositions({{10, 10}, {20, 15}, {25, 15}});
}

// Capacity 15 with chunk unit 5: the two expectation calls differ in their
// first two rows and agree on the last two.
TEST_F(ContextChunkingTest, FrontPriority)
{
    setContextLengths({25, 25});
    setChunkUnitSize(5);
    setCtxTokenCapacity(15);
    setExpectedPositions({{10, 5}, {20, 10}, {25, 20}, {25, 25}});
    setExpectedPositions({{15, 0}, {25, 5}, {25, 20}, {25, 25}});
}

// Draft lengths start at {5, 5}; expected final draft lengths are {3, 3}.
// (27 + 3 = 30 is a multiple of the chunk unit 5 — presumably drafts are
// trimmed to a chunk-unit boundary; confirm.)
TEST_F(ContextChunkingTest, DraftTokensDiscard)
{
    setContextLengths({27, 27});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setExpectedPositions({{15, 15}, {27, 27}});
    setExpectedPositions({{27, 0}, {27, 27}});
    setExpectedFinalDraftLengths({3, 3});
    setExpectedFinalDraftLengths({3, 3});
}

// DraftTokensDiscard configuration with context lengths {17, 17} (17 + 3 = 20).
// Here the expectation calls are interleaved (positions, draft lengths,
// positions, draft lengths) rather than grouped.
TEST_F(ContextChunkingTest, DraftTokensDiscard2)
{
    setContextLengths({17, 17});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setExpectedPositions({{15, 15}, {17, 17}});
    setExpectedFinalDraftLengths({3, 3});
    setExpectedPositions({{17, 10}, {17, 17}});
    setExpectedFinalDraftLengths({3, 3});
}

// DraftTokensDiscard configuration with the token capacity lowered to 20;
// three expected rows per call.
TEST_F(ContextChunkingTest, DraftTokensDiscard3)
{
    setContextLengths({27, 27});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(20);
    setExpectedPositions({{10, 10}, {20, 20}, {27, 27}});
    setExpectedPositions({{20, 0}, {27, 10}, {27, 27}});
    setExpectedFinalDraftLengths({3, 3});
    setExpectedFinalDraftLengths({3, 3});
}

// Context lengths {23, 17} sum to 40 against a capacity of 20; expected final
// draft lengths are {0, 0}. Expectation calls are interleaved as in
// DraftTokensDiscard2.
TEST_F(ContextChunkingTest, DraftTokensDiscardDueToTokenCapacity)
{
    setContextLengths({23, 17});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(20);
    setExpectedPositions({{10, 10}, {23, 17}});
    setExpectedFinalDraftLengths({0, 0});
    setExpectedPositions({{20, 0}, {23, 17}});
    setExpectedFinalDraftLengths({0, 0});
}

// Max context length 10 with context lengths {6, 6}: expected final draft
// lengths are {4, 4} (6 + 4 = 10).
TEST_F(ContextChunkingTest, DraftTokensDiscardDueToMaxContextLength)
{
    setContextLengths({6, 6});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(30);
    setMaxContextLength(10);
    setExpectedPositions({{6, 6}});
    setExpectedPositions({{6, 6}});
    setExpectedFinalDraftLengths({4, 4});
    setExpectedFinalDraftLengths({4, 4});
}

// Context lengths (25) are already a multiple of the chunk unit (5) and the
// capacity (50) equals their sum; expected final draft lengths are {0, 0}.
TEST_F(ContextChunkingTest, DraftTokensDiscardAll)
{
    setContextLengths({25, 25});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(50);
    setExpectedPositions({{25, 25}});
    setExpectedPositions({{25, 25}});
    setExpectedFinalDraftLengths({0, 0});
    setExpectedFinalDraftLengths({0, 0});
}

// DraftTokensDiscardAll configuration with the token capacity lowered from 50
// to 25; two expected rows per call.
TEST_F(ContextChunkingTest, DraftTokensDiscardAll2)
{
    setContextLengths({25, 25});
    setDraftLengths({5, 5});
    setChunkUnitSize(5);
    setCtxTokenCapacity(25);
    setExpectedPositions({{15, 10}, {25, 25}});
    setExpectedPositions({{25, 0}, {25, 25}});
    setExpectedFinalDraftLengths({0, 0});
    setExpectedFinalDraftLengths({0, 0});
}

// Chunk unit 10 and capacity 30: expected final draft lengths equal the
// initial {5, 5} (25 + 5 = 30 is a multiple of 10).
TEST_F(ContextChunkingTest, DraftTokensNoDiscard)
{
    setContextLengths({25, 25});
    setDraftLengths({5, 5});
    setChunkUnitSize(10);
    setCtxTokenCapacity(30);
    setExpectedPositions({{20, 10}, {25, 25}});
    setExpectedPositions({{25, 0}, {25, 25}});
    setExpectedFinalDraftLengths({5, 5});
    setExpectedFinalDraftLengths({5, 5});
}

// Single context of length 4128, equal to the max context length; no token
// capacity is set. Expected final draft length is 0.
TEST_F(ContextChunkingTest, DraftTokensNoChunkingDiscardAll)
{
    setContextLengths({4128});
    setDraftLengths({3});
    setChunkUnitSize(64);
    setMaxContextLength(4128);
    setExpectedPositions({{4128}});
    setExpectedPositions({{4128}});
    setExpectedFinalDraftLengths({0});
    setExpectedFinalDraftLengths({0});
}

// Context length 4127 with 3 draft tokens and max context length 4128:
// expected final draft length is 1 (4127 + 1 = 4128).
TEST_F(ContextChunkingTest, DraftTokensNoChunkingDiscardSome)
{
    setContextLengths({4127});
    setDraftLengths({3});
    setChunkUnitSize(64);
    setMaxContextLength(4128);
    setExpectedPositions({{4127}});
    setExpectedPositions({{4127}});
    setExpectedFinalDraftLengths({1});
    setExpectedFinalDraftLengths({1});
}

// Context length 4125 with 3 draft tokens: 4125 + 3 = 4128 equals the max
// context length; the expected final draft length stays 3.
TEST_F(ContextChunkingTest, DraftTokensNoChunkingDiscardNone)
{
    setContextLengths({4125});
    setDraftLengths({3});
    setChunkUnitSize(64);
    setMaxContextLength(4128);
    setExpectedPositions({{4125}});
    setExpectedPositions({{4125}});
    setExpectedFinalDraftLengths({3});
    setExpectedFinalDraftLengths({3});
}