/*
 * SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef TOP_LEVEL_DIR
#error "Define TOP_LEVEL_DIR"
#endif

#include <gtest/gtest.h>

#include <nlohmann/json.hpp>

#include <cstdint>
#include <filesystem>
#include <fstream>

#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/batch_manager/guidedDecoder.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/executor/executor.h"

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;

namespace texec = tensorrt_llm::executor;

namespace
{
auto const TEST_RESOURCE_PATH = std::filesystem::path{TOP_LEVEL_DIR} / "cpp/tests/resources";
auto const DATA_PATH = TEST_RESOURCE_PATH / "data";
auto const GPT_XGRAMMAR_TOKENIZER_INFO_PATH = DATA_PATH / "gpt2" / "xgrammar_tokenizer_info.json";
auto const LLAMA_XGRAMMAR_TOKENIZER_INFO_PATH = DATA_PATH / "Llama-3.2-1B" / "xgrammar_tokenizer_info.json";
} // namespace

class GuidedDecoderTest : public ::testing::Test
{
public:
    using TensorPtr = ITensor::SharedPtr;
    using VecTokens = std::vector<TokenIdType>;
    using RequestIdType = std::uint64_t;
    using RequestVector = std::vector<std::shared_ptr<LlmRequest>>;

    void SetUp() override
    {
        mStream = std::make_shared<CudaStream>();
        mRuntimeBufferManager = std::make_shared<BufferManager>(mStream);
    }

    void TearDown() override {}

    void initData(std::filesystem::path tokenizerInfoPath, SizeType32 vocabSizePadded, VecTokens outputIds,
        std::vector<SizeType32> expectedNumRejected)
    {
        mLogitsDtype = nvinfer1::DataType::kFLOAT;
        mMaxNumRequests = 16;
        mVocabSizePadded = vocabSizePadded;

        // Load the serialized tokenizer info and build an XGrammar-backed guided decoder from it.
        auto const tokenizerInfo = nlohmann::json::parse(std::ifstream{tokenizerInfoPath});
        auto const encodedVocab = tokenizerInfo["encoded_vocab"].template get<std::vector<std::string>>();
        auto const tokenizerStr = tokenizerInfo["tokenizer_str"].template get<std::string>();
        auto const stopTokenIds = tokenizerInfo["stop_token_ids"].template get<std::vector<TokenIdType>>();
        texec::GuidedDecodingConfig guidedDecodingConfig(
            texec::GuidedDecodingConfig::GuidedDecodingBackend::kXGRAMMAR, encodedVocab, tokenizerStr, stopTokenIds);
        mGuidedDecoder = std::make_shared<GuidedDecoder>(
            guidedDecodingConfig, mMaxNumRequests, mVocabSizePadded, mLogitsDtype, *mRuntimeBufferManager);

        // One device logits tensor and one pinned host mirror per request slot.
        mLogits.resize(mMaxNumRequests);
        mLogitsHost.resize(mMaxNumRequests);
        for (int i = 0; i < mMaxNumRequests; i++)
        {
            mLogits[i] = mRuntimeBufferManager->gpu(ITensor::makeShape({mVocabSizePadded}), mLogitsDtype);
            mLogitsHost[i] = BufferManager::pinned(ITensor::makeShape({mVocabSizePadded}), mLogitsDtype);
        }

        mOutputIds = outputIds;
        mExpectedNumRejected = expectedNumRejected;
    }

    // Zero all host logits and copy them to the device.
    void resetLogits()
    {
        for (int i = 0; i < mMaxNumRequests; i++)
        {
            auto logitsHostData = bufferCast<float>(*mLogitsHost[i]);
            for (int j = 0; j < mVocabSizePadded; j++)
            {
                logitsHostData[j] = 0.0f;
            }
            mRuntimeBufferManager->copy(*(mLogitsHost[i]), *(mLogits[i]));
        }
    }

    // Copy the (possibly masked) device logits back to the host for inspection.
    void syncLogitsToHost()
    {
        for (int i = 0; i < mMaxNumRequests; i++)
        {
            mRuntimeBufferManager->copy(*(mLogits[i]), *(mLogitsHost[i]));
        }
    }
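    // Count the vocabulary entries of request i whose host logits were pushed below -1e6,
    // i.e. the tokens rejected by the grammar mask at the current step.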
    int32_t countRejected(int i)
    {
        int32_t numRejected = 0;
        auto logitsHostData = bufferCast<float>(*mLogitsHost[i]);
        for (int j = 0; j < mVocabSizePadded; j++)
        {
            if (logitsHostData[j] < -1e6)
            {
                numRejected++;
            }
        }
        return numRejected;
    }

    void runTest()
    {
        // llmReq1 is constrained to produce JSON; llmReq2 runs unconstrained as a control.
        auto llmReq1
            = std::make_shared<LlmRequest>(1, 100, std::make_shared<VecTokens>(10), SamplingConfig(), false);
        texec::GuidedDecodingParams guidedDecodingParams(texec::GuidedDecodingParams::GuideType::kJSON);
        llmReq1->setGuidedDecodingParams(guidedDecodingParams);
        llmReq1->mSeqSlot = 1;
        auto llmReq2
            = std::make_shared<LlmRequest>(1, 100, std::make_shared<VecTokens>(10), SamplingConfig(), false);
        llmReq2->mSeqSlot = 2;

        RequestVector contextRequests{llmReq1, llmReq2};
        RequestVector generationRequests{};
        ScheduledRequests scheduledRequests{contextRequests, generationRequests};

        DecoderInputBuffers decoderInputBuffers(mMaxNumRequests, 1, *mRuntimeBufferManager);
        for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
        {
            for (auto const& llmReq : requests)
            {
                decoderInputBuffers.decoderRequests.push_back(llmReq);
            }
        }
        decoderInputBuffers.decoderLogits = mLogits;

        // Context phase
        resetLogits();
        mGuidedDecoder->build(scheduledRequests);
        mGuidedDecoder->execute(decoderInputBuffers, *mRuntimeBufferManager);
        syncLogitsToHost();
        mRuntimeBufferManager->getStream().synchronize();

        // Move both requests to the generation phase
        contextRequests.pop_back();
        contextRequests.pop_back();
        llmReq1->setState(LlmRequestState::kGENERATION_IN_PROGRESS);
        generationRequests.push_back(llmReq1);
        llmReq2->setState(LlmRequestState::kGENERATION_IN_PROGRESS);
        generationRequests.push_back(llmReq2);
        decoderInputBuffers.decoderRequests.clear();
        for (auto const& requests : {scheduledRequests.contextRequests, scheduledRequests.generationRequests})
        {
            for (auto const& llmReq : requests)
            {
                decoderInputBuffers.decoderRequests.push_back(llmReq);
            }
        }

        // The guided request is masked from the very first step; the control request never is.
        EXPECT_EQ(countRejected(0), mExpectedNumRejected[0]);
        EXPECT_EQ(countRejected(1), 0);

        // Generation phase
        for (size_t i = 0; i < mOutputIds.size(); i++)
        {
            llmReq1->addNewToken(mOutputIds[i], 0);
            llmReq2->addNewToken(mOutputIds[i], 0);

            resetLogits();
            mGuidedDecoder->build(scheduledRequests);
            mGuidedDecoder->execute(decoderInputBuffers, *mRuntimeBufferManager);
            syncLogitsToHost();
            mRuntimeBufferManager->getStream().synchronize();

            EXPECT_EQ(countRejected(0), mExpectedNumRejected[i + 1]);
            EXPECT_EQ(countRejected(1), 0);
        }
    }

private:
    SizeType32 mMaxNumRequests;
    SizeType32 mVocabSizePadded;
    nvinfer1::DataType mLogitsDtype;

    std::vector<TensorPtr> mLogits;     // mMaxNumRequests tensors of shape [mVocabSizePadded]
    std::vector<TensorPtr> mLogitsHost; // mMaxNumRequests tensors of shape [mVocabSizePadded]

    std::shared_ptr<BufferManager> mRuntimeBufferManager;
    std::shared_ptr<CudaStream> mStream;
    std::shared_ptr<GuidedDecoder> mGuidedDecoder;

    VecTokens mOutputIds;
    std::vector<SizeType32> mExpectedNumRejected;
};

// JSON-guided decoding with the GPT-2 tokenizer. expectedNumRejected holds the expected
// number of masked-out vocabulary entries at each step: the first entry corresponds to the
// context step, entry i + 1 to the step after outputIds[i] was added.
TEST_F(GuidedDecoderTest, GptTokenizer)
{
    VecTokens outputIds{4895, 824, 312, 1298, 366, 27743, 7934, 49793, 1600, 366, 12961, 19703, 4668, 1298, 366, 54,
        4537, 17, 12, 17469, 7919, 1600, 366, 3903, 10394, 1298, 366, 1485, 405, 41022, 20662};
    std::vector<SizeType32> expectedNumRejected{50251, 219, 219, 219, 48558, 219, 219, 219, 219, 50191, 219, 219, 219,
        219, 48558, 219, 219, 219, 219, 219, 219, 219, 50191, 219, 219, 219, 48558, 219, 219, 219, 219, 50256};
    initData(GPT_XGRAMMAR_TOKENIZER_INFO_PATH, 50257, outputIds, expectedNumRejected);
    runTest();
}
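// The same JSON-guided check with the Llama-3.2-1B tokenizer info and a larger padded vocabulary.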
TEST_F(GuidedDecoderTest, LlamaTokenizer)
{
    VecTokens outputIds{6377, 893, 333, 1115, 376, 27247, 6779, 7898, 545, 613, 376, 8926, 17830, 1115, 376, 29956,
        7228, 29906, 29899, 10399, 7734, 613, 376, 4980, 2103, 1115, 376, 29896, 29941, 29900, 29900, 341, 29890,
        567, 9092};
    std::vector<SizeType32> expectedNumRejected{128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235,
        128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235,
        128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235, 128235,
        128235, 128235};
    initData(LLAMA_XGRAMMAR_TOKENIZER_INFO_PATH, 128256, outputIds, expectedNumRejected);
    runTest();
}