mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
fix: Improve chunking test and skip empty kernel calls (#5710)
Signed-off-by: Robin Kobus <19427718+Funatiq@users.noreply.github.com>
This commit is contained in:
parent
b8fef809ae
commit
07f9cf1519
@ -136,8 +136,11 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
|
||||
std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
|
||||
[](auto const& llmReq) { return llmReq->isLastContextChunk(); });
|
||||
|
||||
copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth,
|
||||
bufferManager, runtimeStream);
|
||||
if (!finishedContextRequests.empty())
|
||||
{
|
||||
copySequenceLengths(finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth,
|
||||
bufferManager, runtimeStream);
|
||||
}
|
||||
|
||||
auto [lookaheadPrompt, lookaheadAlgoConfigs] = createDecoderRequests(finishedContextRequests,
|
||||
inputBuffers.inputsIds, decodingConfig, decoderState, bufferManager, logitsType, modelConfig, worldConfig,
|
||||
|
||||
@ -815,10 +815,13 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request
|
||||
auto contextInputsIds = ITensor::slice(inputsIds, 0, numContextTokens);
|
||||
manager.copy(inputHost.data(), *contextInputsIds);
|
||||
|
||||
auto generationInputsIds = ITensor::slice(inputsIds, numContextTokens);
|
||||
auto seqSlotsDeviceSlice = ITensor::slice(seqSlotsDevice, numContextRequests);
|
||||
runtime::kernels::invokeGatherBatch(
|
||||
*generationInputsIds, *newOutputTokens, *seqSlotsDeviceSlice, maxBeamWidth, stream);
|
||||
if (!genRequests.empty())
|
||||
{
|
||||
auto generationInputsIds = ITensor::slice(inputsIds, numContextTokens);
|
||||
auto seqSlotsDeviceSlice = ITensor::slice(seqSlotsDevice, numContextRequests);
|
||||
runtime::kernels::invokeGatherBatch(
|
||||
*generationInputsIds, *newOutputTokens, *seqSlotsDeviceSlice, maxBeamWidth, stream);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -885,6 +888,8 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request
|
||||
}
|
||||
}
|
||||
|
||||
sync_check_cuda_error(stream.get());
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
|
||||
@ -1024,6 +1024,7 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
|
||||
{
|
||||
prepareDistGenBufferAndDecoder(currRequests.generationRequests);
|
||||
}
|
||||
sync_check_cuda_error(mRuntime->getStream().get());
|
||||
|
||||
executeBatch(currRequests);
|
||||
if (mWorldConfig.isLastPipelineParallelRank() && mGuidedDecoder)
|
||||
@ -1063,6 +1064,8 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
|
||||
? std::make_optional(decoderStepAsync(currRequests))
|
||||
: std::nullopt;
|
||||
|
||||
sync_check_cuda_error(mRuntime->getStream().get());
|
||||
|
||||
mLastIterationStatsIFB = fillIterationStats(currRequests, requestsToPause);
|
||||
for (auto const& requests : {currRequests.contextRequests, currRequests.generationRequests})
|
||||
{
|
||||
|
||||
@ -66,7 +66,6 @@ void initBindings(py::module_& m)
|
||||
.def("use_draft_tokens_external_decoding", &ModelSpec::useDraftTokensExternalDecoding)
|
||||
.def("use_logits", &ModelSpec::useLogits)
|
||||
.def("use_multiple_profiles", &ModelSpec::useMultipleProfiles)
|
||||
.def("set_batch_sizes", &ModelSpec::setBatchSizes)
|
||||
.def("set_max_input_length", &ModelSpec::setMaxInputLength)
|
||||
.def("set_max_output_length", &ModelSpec::setMaxOutputLength)
|
||||
.def("set_quant_method", &ModelSpec::setQuantMethod)
|
||||
|
||||
@ -206,12 +206,6 @@ public:
|
||||
return mEnableContextFMHAFp32Acc;
|
||||
}
|
||||
|
||||
ModelSpec& setBatchSizes(std::vector<SizeType32> batchSizes)
|
||||
{
|
||||
mBatchSizes = std::move(batchSizes);
|
||||
return *this;
|
||||
}
|
||||
|
||||
ModelSpec& setMaxInputLength(SizeType32 maxInputLength)
|
||||
{
|
||||
mMaxInputLength = maxInputLength;
|
||||
@ -338,7 +332,6 @@ public:
|
||||
QuantMethod mQuantMethod{QuantMethod::kNONE};
|
||||
|
||||
SpeculativeDecodingMode mSpecDecodingMode{SpeculativeDecodingMode::None()};
|
||||
std::vector<SizeType32> mBatchSizes{1, 2, 8};
|
||||
|
||||
std::optional<tensorrt_llm::executor::CapacitySchedulerPolicy> mCapacitySchedulerPolicy{std::nullopt};
|
||||
|
||||
|
||||
@ -708,14 +708,16 @@ struct BeamConfig
|
||||
} // namespace
|
||||
|
||||
using ParamType = std::tuple<ModelParams, ModelSpec, TrtGptModelType, TrtGptModelIfbTestType, BeamConfig, // id: 0-4
|
||||
std::optional<int32_t>, // 5. maxTokensInPagedKvCache
|
||||
std::optional<float>, // 6. freeGpuMemoryFraction
|
||||
bool, // 7. enableTrtOverlap
|
||||
bool, // 8. enableChunkedContext
|
||||
bool, // 9. enableStreamingMode
|
||||
bool, // 10. enableCudaGraphMode
|
||||
std::optional<size_t>, // 11. hostCacheSize
|
||||
bool // 12. useRandomEndId
|
||||
std::optional<int32_t>, // 5. maxTokensInPagedKvCache
|
||||
std::optional<float>, // 6. freeGpuMemoryFraction
|
||||
bool, // 7. enableTrtOverlap
|
||||
bool, // 8. enableChunkedContext
|
||||
bool, // 9. enableStreamingMode
|
||||
bool, // 10. enableCudaGraphMode
|
||||
std::optional<size_t>, // 11. hostCacheSize
|
||||
bool, // 12. useRandomEndId
|
||||
std::vector<SizeType32>, // 13. batchSizes
|
||||
std::optional<SizeType32> // 14. maxNumTokens
|
||||
>;
|
||||
|
||||
std::string generateTestName(testing::TestParamInfo<ParamType> const& info)
|
||||
@ -866,7 +868,7 @@ TEST_P(ParamTest, Test)
|
||||
|
||||
auto const useRandomEndId = std::get<12>(GetParam());
|
||||
|
||||
std::vector<int32_t> batchSizes = modelSpec.mBatchSizes;
|
||||
auto const batchSizes = std::get<13>(GetParam());
|
||||
|
||||
std::ostringstream gpuSizePath;
|
||||
gpuSizePath << "tp" << modelSpec.mTPSize << "-pp" << modelSpec.mPPSize << "-cp" << modelSpec.mCPSize;
|
||||
@ -935,6 +937,11 @@ TEST_P(ParamTest, Test)
|
||||
|
||||
executorConfig.setEnableTrtOverlap(std::get<7>(GetParam()));
|
||||
executorConfig.setEnableChunkedContext(std::get<8>(GetParam()));
|
||||
auto const maxNumTokens = std::get<14>(GetParam());
|
||||
if (maxNumTokens.has_value())
|
||||
{
|
||||
executorConfig.setMaxNumTokens(maxNumTokens.value());
|
||||
}
|
||||
executorConfig.setNormalizeLogProbs(false);
|
||||
executorConfig.setMaxBeamWidth(beamConfig.maxBeamWidth);
|
||||
executorConfig.setGatherGenerationLogits(modelSpec.mCollectGenerationLogits);
|
||||
@ -1047,16 +1054,18 @@ INSTANTIATE_TEST_SUITE_P(GptTests, ParamTest,
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1073,16 +1082,18 @@ INSTANTIATE_TEST_SUITE_P(GptRandomEndIdTests, ParamTest,
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true) // useRandomEndId
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1099,14 +1110,16 @@ INSTANTIATE_TEST_SUITE_P(GptKVOffloadingTest, ParamTest,
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(256), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(100000000), // hostCacheSize
|
||||
testing::Values(false, true) // useRandomEndId
|
||||
testing::Values(256), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(100000000), // hostCacheSize
|
||||
testing::Values(false, true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1129,16 +1142,18 @@ INSTANTIATE_TEST_SUITE_P(GptCudaGraphTests, ParamTest,
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1149,23 +1164,24 @@ INSTANTIATE_TEST_SUITE_P(GptSwitchBwTests, ParamTest,
|
||||
ModelSpec{INPUT_FILE, nvinfer1::DataType::kHALF}
|
||||
.useGptAttentionPlugin()
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.usePackedInput()
|
||||
.setBatchSizes({4})),
|
||||
.usePackedInput()),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{2, {1}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{2, {1}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{4}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1179,16 +1195,18 @@ INSTANTIATE_TEST_SUITE_P(GptNProfilesTests, ParamTest,
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching), testing::Values(TrtGptModelIfbTestType::BULK),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true) // useRandomEndId
|
||||
testing::Values(std::nullopt, 1280), // maxTokensInPagedKvCache
|
||||
testing::Values(std::nullopt, 0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1219,16 +1237,18 @@ INSTANTIATE_TEST_SUITE_P(GptSqTests, ParamTest,
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
// FIXME: disabled flaky beam search tests (https://nvbugspro.nvidia.com/bug/4646234)
|
||||
BeamConfig{1, {1}} //, BeamConfig{2, {2}}
|
||||
BeamConfig{1, {1}} //, BeamConfig{2, {2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1243,16 +1263,18 @@ INSTANTIATE_TEST_SUITE_P(DISABLED_GptChunkedContextTests, ParamTest,
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.setMaxInputLength(128)),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(TrtGptModelIfbTestType::BULK), // TrtGptModelIfbTestType
|
||||
testing::Values(BeamConfig{1, {1}}), // beam config
|
||||
testing::Values(257), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(TrtGptModelIfbTestType::BULK), // TrtGptModelIfbTestType
|
||||
testing::Values(BeamConfig{1, {1}}), // beam config
|
||||
testing::Values(257), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1273,16 +1295,18 @@ INSTANTIATE_TEST_SUITE_P(GptChunkedLongContextTests, ParamTest,
|
||||
.setDraftTokens(5)),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT,
|
||||
TrtGptModelIfbTestType::RANDOM), // TrtGptModelIfbTestType
|
||||
testing::Values(BeamConfig{1, {1}}), // beam config
|
||||
testing::Values(std::nullopt, 1024), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
TrtGptModelIfbTestType::RANDOM), // TrtGptModelIfbTestType
|
||||
testing::Values(BeamConfig{1, {1}}), // beam config
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(true), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(64) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1312,15 +1336,17 @@ INSTANTIATE_TEST_SUITE_P(GptDraftTests, ParamTest,
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false, true) // useRandomEndId
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false, true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1339,14 +1365,16 @@ INSTANTIATE_TEST_SUITE_P(GptLogitsTests, ParamTest,
|
||||
testing::Values(TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT,
|
||||
TrtGptModelIfbTestType::RANDOM), // testType
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false, true), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false, true), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1372,7 +1400,9 @@ INSTANTIATE_TEST_SUITE_P(GptLogProbsTests, ParamTest,
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1396,16 +1426,18 @@ INSTANTIATE_TEST_SUITE_P(GptjTests, ParamTest,
|
||||
/* , TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM */),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1427,14 +1459,16 @@ INSTANTIATE_TEST_SUITE_P(MambaTests, ParamTest,
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1450,14 +1484,16 @@ INSTANTIATE_TEST_SUITE_P(RecurrentGemmaTests, ParamTest,
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1492,16 +1528,18 @@ INSTANTIATE_TEST_SUITE_P(LlamaTests, ParamTest,
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when mixed beam width is supported
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}}, BeamConfig{2, {2}} // , BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1517,14 +1555,16 @@ INSTANTIATE_TEST_SUITE_P(ChatGlmTests, ParamTest,
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false, true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false, true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1541,14 +1581,16 @@ INSTANTIATE_TEST_SUITE_P(ChatGlm0Tests, ParamTest,
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1562,18 +1604,19 @@ INSTANTIATE_TEST_SUITE_P(MedusaTests, ParamTest,
|
||||
.useGptAttentionPlugin()
|
||||
.usePackedInput()
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.useMedusa()
|
||||
.setBatchSizes({8})),
|
||||
.useMedusa()),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching), testing::Values(TrtGptModelIfbTestType::BULK),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true, false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true, false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1585,19 +1628,20 @@ INSTANTIATE_TEST_SUITE_P(EagleTests, ParamTest,
|
||||
.useGptAttentionPlugin()
|
||||
.usePackedInput()
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.useEagle()
|
||||
.setBatchSizes({8})),
|
||||
.useEagle()),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT),
|
||||
testing::Values(BeamConfig{1, {1}}),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true, false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(true, false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
@ -1609,20 +1653,21 @@ INSTANTIATE_TEST_SUITE_P(LlamaLookaheadDecodingTests, ParamTest,
|
||||
.useGptAttentionPlugin()
|
||||
.usePackedInput()
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.useLookaheadDecoding()
|
||||
.setBatchSizes({1, 16})),
|
||||
.useLookaheadDecoding()),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching),
|
||||
testing::Values(
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true) // useRandomEndId
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(false), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(true), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 16}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
|
||||
generateTestName);
|
||||
@ -1636,18 +1681,19 @@ INSTANTIATE_TEST_SUITE_P(ExplicitDraftTokensDecodingTests, ParamTest,
|
||||
.usePackedInput()
|
||||
.setKVCacheType(KVCacheType::kPAGED)
|
||||
.useExplicitDraftTokensDecoding()
|
||||
.setMaxOutputLength(128)
|
||||
.setBatchSizes({8})),
|
||||
.setMaxOutputLength(128)),
|
||||
testing::Values(TrtGptModelType::InflightFusedBatching), testing::Values(TrtGptModelIfbTestType::BULK),
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(BeamConfig{1, {1}}), // beamConfig
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
|
||||
generateTestName);
|
||||
@ -1669,16 +1715,18 @@ INSTANTIATE_TEST_SUITE_P(GptjFP8Tests, ParamTest,
|
||||
TrtGptModelIfbTestType::BULK, TrtGptModelIfbTestType::WAVEFRONT, TrtGptModelIfbTestType::RANDOM),
|
||||
testing::Values(
|
||||
// TODO: enable more tests when supported
|
||||
BeamConfig{1, {1}} // , BeamConfig{2, {2}}, BeamConfig{2, {1, 2}}
|
||||
BeamConfig{1, {1}} // , BeamConfig{2, {2}}, BeamConfig{2, {1, 2}}
|
||||
),
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false) // useRandomEndId
|
||||
testing::Values(std::nullopt), // maxTokensInPagedKvCache
|
||||
testing::Values(0.4), // freeGpuMemoryFraction
|
||||
testing::Values(false), // enableTrtOverlap
|
||||
testing::Values(true), // enableChunkedContext
|
||||
testing::Values(false), // enableStreamingMode
|
||||
testing::Values(false), // enableCudaGraphMode
|
||||
testing::Values(std::nullopt), // hostCacheSize
|
||||
testing::Values(false), // useRandomEndId
|
||||
testing::Values(std::vector<SizeType32>{1, 2, 8}), // batchSizes
|
||||
testing::Values(std::nullopt) // maxNumTokens
|
||||
),
|
||||
generateTestName);
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user