mirror of https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#2460)

This commit is contained in:
parent c629546ce4
commit 535c9cc673

README.md (14 changes)
@@ -8,7 +8,7 @@ TensorRT-LLM
 [](https://www.python.org/downloads/release/python-31012/)
 [](https://developer.nvidia.com/cuda-downloads)
 [](https://developer.nvidia.com/tensorrt)
-[](./tensorrt_llm/version.py)
+[](./tensorrt_llm/version.py)
 [](./LICENSE)

 [Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/) | [Roadmap](https://docs.google.com/presentation/d/1gycPmtdh7uUcH6laOvW65Dbp9F1McUkGDIcAyjicBZs/edit?usp=sharing)
@@ -18,12 +18,18 @@ TensorRT-LLM

 ## Latest News

+* [2024/11/09] 🚀🚀🚀 3x Faster AllReduce with NVSwitch and TensorRT-LLM MultiShot
+[➡️ link](https://developer.nvidia.com/blog/3x-faster-allreduce-with-nvswitch-and-tensorrt-llm-multishot/)
+<div align="center">
+<img src="https://developer-blogs.nvidia.com/wp-content/uploads/2024/08/HGX-H200-tech-blog-1920x1080-1.jpg" width="50%">
+<div align="left">
+
+* [2024/11/09] ✨ NVIDIA advances the AI ecosystem with the AI model of LG AI Research 🙌
+[➡️ link](https://blogs.nvidia.co.kr/blog/nvidia-lg-ai-research/)
+
 * [2024/11/02] 🌟🌟🌟 NVIDIA and LlamaIndex Developer Contest
 🙌 Enter for a chance to win prizes including an NVIDIA® GeForce RTX™ 4080 SUPER GPU, DLI credits, and more 🙌
 [➡️ link](https://developer.nvidia.com/llamaindex-developer-contest)
 <div align="center">
 <img src="docs/source/media/image-11-02-2024.png" width="50%">
 <div align="left">

 * [2024/10/28] 🏎️🏎️🏎️ NVIDIA GH200 Superchip Accelerates Inference by 2x in Multiturn Interactions with Llama Models
 [➡️ link](https://developer.nvidia.com/blog/nvidia-gh200-superchip-accelerates-inference-by-2x-in-multiturn-interactions-with-llama-models/)
@@ -664,7 +664,7 @@ class ExecutorServer
 {
 public:
     ExecutorServer(std::optional<std::filesystem::path> const& decoderTrtEnginePath,
-        std::optional<std::filesystem::path> const& encoderTrtEnginePath, TrtGptModelType modelType,
+        std::optional<std::filesystem::path> const& encoderTrtEnginePath, texec::BatchingType batchingType,
         int32_t maxBeamWidth, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
         BenchmarkParams const& benchmarkParams, std::shared_ptr<Recorder> recorder, std::chrono::milliseconds waitSleep,
         bool logIterationData, texec::ModelType executorModelType)
@@ -692,8 +692,7 @@ public:
             maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
         executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
         executorConfig.setPeftCacheConfig(peftCacheConfig);
-        executorConfig.setBatchingType(
-            modelType == TrtGptModelType::V1 ? texec::BatchingType::kSTATIC : texec::BatchingType::kINFLIGHT);
+        executorConfig.setBatchingType(batchingType);
         if (benchmarkParams.maxBatchSize)
         {
             executorConfig.setMaxBatchSize(benchmarkParams.maxBatchSize.value());
@@ -947,6 +946,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamWidth,
         std::nullopt,    // embeddingBias
         std::nullopt,    // speculativeDecoding
         std::nullopt,    // pTuning
+        std::nullopt,    // mRopeConfig
         loraConfig,      // loraConfig
         lookaheadConfig, // lookaheadConfig
         std::nullopt,    // kvCacheRetentionConfig
@@ -955,7 +955,7 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamWidth,
 }

 void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
-    std::optional<std::filesystem::path> const& encoderEngineDir, TrtGptModelType modelType,
+    std::optional<std::filesystem::path> const& encoderEngineDir, texec::BatchingType batchingType,
     std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
     std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
     texec::CapacitySchedulerPolicy capacitySchedulerPolicy, std::chrono::milliseconds waitSleep,
@@ -977,16 +977,17 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
     {
         TLLM_CHECK_WITH_INFO(
             decoderEngineDir.has_value(), "decoder models require a path to decoder engine in executor benchmark.");
-        executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, modelType, beamWidth,
-            capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
+        executorServer
+            = std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, batchingType, beamWidth,
+                capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
     }
     else if (executorModelType == texec::ModelType::kENCODER_DECODER)
     {
         TLLM_CHECK_WITH_INFO(encoderEngineDir.has_value(),
             "encoder-decoder models require a path to encoder engine in executor benchmark.");
-        executorServer
-            = std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType, beamWidth,
-                capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
+        executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(),
+            batchingType, beamWidth, capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData,
+            executorModelType);
         try
         {
             std::ifstream decoderJsonConfigPath(decoderEngineDir.value() / "config.json");
@@ -1011,8 +1012,9 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
     {
         TLLM_CHECK_WITH_INFO(
             encoderEngineDir.has_value(), "encoder models require a path to encoder engine in executor benchmark.");
-        executorServer = std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), modelType, beamWidth,
-            capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
+        executorServer
+            = std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), batchingType, beamWidth,
+                capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
     }
     else
     {
@@ -1219,8 +1221,9 @@ int main(int argc, char* argv[])
         "encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
     options.add_options()(
         "api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("executor"));
-    options.add_options()("type", "Batching type: IFB, UIFB (unfused IFB) or V1 (non-IFB) batching.",
-        cxxopts::value<std::string>()->default_value("IFB"));
+    options.add_options()("type",
+        "Batching type: choose between inflight/static. (IFB/V1 options are going to be deprecated)",
+        cxxopts::value<std::string>()->default_value("inflight"));
     options.add_options()("dataset", "Dataset that is used for benchmarking BatchManager.",
         cxxopts::value<std::string>()->default_value(""));
     options.add_options()(
@@ -1332,18 +1335,22 @@ int main(int argc, char* argv[])

     // Argument: Batching Type
     auto const type = result["type"].as<std::string>();
-    TrtGptModelType modelType{TrtGptModelType::V1};
-    if (type == "V1")
+    texec::BatchingType batchingType{texec::BatchingType::kINFLIGHT};
+    if (type == "V1" || type == "static")
     {
-        modelType = TrtGptModelType::V1;
+        if (type == "V1")
+        {
+            TLLM_LOG_WARNING("type option \"V1\" is going to be renamed to \"static\".");
+        }
+        batchingType = texec::BatchingType::kSTATIC;
     }
-    else if (type == "UIFB")
+    else if (type == "IFB" || type == "inflight")
     {
-        modelType = TrtGptModelType::InflightBatching;
-    }
-    else if (type == "IFB")
-    {
-        modelType = TrtGptModelType::InflightFusedBatching;
+        if (type == "IFB")
+        {
+            TLLM_LOG_WARNING("type option \"IFB\" is going to be renamed to \"inflight\".");
+        }
+        batchingType = texec::BatchingType::kINFLIGHT;
     }
     else
     {
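Net effect of the hunks above: the benchmark's `--type` option now takes `inflight`/`static`, keeping `IFB`/`V1` as deprecated aliases (the `UIFB` mode is dropped), and the parsed value flows through the code as `texec::BatchingType` instead of `TrtGptModelType`. A condensed sketch of the resulting mapping; `parseBatchingType` is a hypothetical helper, not a function from the commit:

```cpp
#include <stdexcept>
#include <string>

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Hypothetical helper condensing the new --type handling: both the legacy
// and the new spellings map onto the two executor batching modes.
texec::BatchingType parseBatchingType(std::string const& type)
{
    if (type == "V1" || type == "static") // "V1" is the deprecated alias
    {
        return texec::BatchingType::kSTATIC;
    }
    if (type == "IFB" || type == "inflight") // "IFB" is the deprecated alias
    {
        return texec::BatchingType::kINFLIGHT;
    }
    throw std::invalid_argument("Unexpected batching type: " + type);
}
```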
@@ -1604,7 +1611,7 @@ int main(int argc, char* argv[])
     {
         TLLM_CHECK_WITH_INFO(api == "executor", "encoder-decoder only support executor api.");
         TLLM_CHECK_WITH_INFO(
-            modelType == TrtGptModelType::InflightFusedBatching, "encoder-decoder only support inflight batching.");
+            batchingType == texec::BatchingType::kINFLIGHT, "encoder-decoder only support inflight batching.");
         executorModelType = texec::ModelType::kENCODER_DECODER;
         encoderEngineDir = result["encoder_engine_dir"].as<std::string>();
         decoderEngineDir = result["decoder_engine_dir"].as<std::string>();
@@ -1621,7 +1628,7 @@ int main(int argc, char* argv[])
     }
     try
     {
-        benchmarkExecutor(decoderEngineDir, encoderEngineDir, modelType, datasetPath, opCsvFile, maxNumSamples,
+        benchmarkExecutor(decoderEngineDir, encoderEngineDir, batchingType, datasetPath, opCsvFile, maxNumSamples,
             beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, capacitySchedulerPolicy,
             waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData,
             maxPromptLen, executorModelType);
@@ -50,7 +50,7 @@ def print_dataset(input_ids, output_lens):
     for i, input_tokens in enumerate(input_ids):
         d = {
             "task_id": i,
-            "logits": input_tokens,
+            "input_ids": input_tokens,
             "output_tokens": output_lens[i]
         }
         print(json.dumps(d, separators=(',', ':'), ensure_ascii=False))
@@ -85,7 +85,7 @@ public:
     std::optional<SizeType32> sinkTokenLength;
     std::optional<float> freeGpuMemoryFraction;
     bool enableBlockReuse;
-    static constexpr auto kDefaultGpuMemFraction = 0.9f;
+    static constexpr auto kDefaultGpuMemFraction = 0.9F;
     bool useUvm;
     std::optional<size_t> hostCacheSize;
     bool onboardBlocks;

@@ -835,7 +835,7 @@ public:
         * 2 * modelConfig.getSizePerHead();
     }

-    [[nodiscard]] static std::tuple<SizeType32, SizeType32> const calculateMaxNumBlocks(KvCacheConfig const& config,
+    [[nodiscard]] static std::tuple<SizeType32, SizeType32> calculateMaxNumBlocks(KvCacheConfig const& config,
         nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
         tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
@@ -92,6 +92,8 @@ public:
         std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
         std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
         std::optional<SizeType32> promptVocabSize = std::nullopt,
+        std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
+        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
         std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
         std::optional<TensorPtr> loraConfig = std::nullopt,
         std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@@ -131,6 +133,8 @@ public:
         , mPositionIds(std::move(positionIds))
         , mPromptEmbeddingTable(std::move(promptEmbeddingTable))
         , mPromptVocabSize(promptVocabSize)
+        , mMropeRotarySinCos(std::move(mropeRotarySinCos))
+        , mMropePositionDeltas(std::move(mropePositionDeltas))
         , mLoraTaskId(loraTaskId)
         , mLoraWeights(std::move(loraWeights))
         , mLoraConfig(std::move(loraConfig))
@@ -188,6 +192,8 @@ public:
         , mPositionIds(std::nullopt)
         , mPromptEmbeddingTable(std::nullopt)
         , mPromptVocabSize(std::nullopt)
+        , mMropeRotarySinCos(std::nullopt)
+        , mMropePositionDeltas(std::nullopt)
         , mLoraTaskId(std::nullopt)
         , mLoraWeights(std::nullopt)
         , mLoraConfig(std::nullopt)
@@ -285,6 +291,12 @@ public:
                     = std::make_shared<VecTokenExtraIds>(pTuningConfig->getInputTokenExtraIds().value());
             }
         }
+        auto mRopeConfig = req.getMropeConfig();
+        if (mRopeConfig)
+        {
+            mMropeRotarySinCos = executor::detail::toITensor(mRopeConfig.value().getMRopeRotarySinCos());
+            mMropePositionDeltas = mRopeConfig.value().getMRopePositionDeltas();
+        }

         auto loraConfig = req.getLoraConfig();
         if (loraConfig)
@@ -447,16 +459,6 @@ public:
         mContextPhaseParams = std::move(contextPhaseParams);
     }

-    [[nodiscard]] bool isLayerWiseKvCacheEnabled() const
-    {
-        return isContextOnlyRequest() && mLayerWiseKvCacheEnabled;
-    }
-
-    void setLayerWiseKvCacheEnabled(bool enabled)
-    {
-        mLayerWiseKvCacheEnabled = enabled;
-    }
-
     /// @brief Get the state params of the context
     /// @return The state params of the context
     [[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
@@ -798,6 +800,16 @@ public:
         return mPromptVocabSize;
     }

+    [[nodiscard]] std::optional<TensorPtr> getMropeRotarySinCos() const
+    {
+        return mMropeRotarySinCos;
+    }
+
+    [[nodiscard]] std::optional<SizeType32> getMropePositionDeltas() const
+    {
+        return mMropePositionDeltas;
+    }
+
     [[nodiscard]] std::optional<LoraTaskIdType> getLoraTaskId() const
     {
         return mLoraTaskId;
@@ -1604,6 +1616,8 @@ protected:

     std::optional<TensorPtr> mPromptEmbeddingTable;
     std::optional<SizeType32> mPromptVocabSize;
+    std::optional<TensorPtr> mMropeRotarySinCos;
+    std::optional<SizeType32> mMropePositionDeltas;

     std::optional<LoraTaskIdType> mLoraTaskId;
     std::optional<TensorPtr> mLoraWeights;
@@ -1654,7 +1668,6 @@ protected:
     std::optional<TensorPtr> mCrossAttentionMask; // Input cross attention mask
     LlmRequestType mLlmRequestType;
     std::optional<executor::ContextPhaseParams> mContextPhaseParams;
-    bool mLayerWiseKvCacheEnabled = false;

     std::optional<std::shared_ptr<VecTokenExtraIds>> mInputTokenExtraIds;
     BeamUniqueTokens mUniqueTokens;
@@ -1819,6 +1832,8 @@ public:
         std::optional<std::shared_ptr<std::vector<SizeType32>>> positionIds = std::nullopt,
         std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
         std::optional<SizeType32> promptVocabSize = std::nullopt,
+        std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
+        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
         std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
         std::optional<TensorPtr> loraConfig = std::nullopt,
         std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@@ -1840,7 +1855,8 @@ public:
         std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
         : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
             std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
-            std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
+            std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotarySinCos),
+            std::move(mropePositionDeltas), loraTaskId, std::move(loraWeights), std::move(loraConfig),
             std::move(lookaheadConfig), std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits,
             returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput,
             std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, std::move(encoderInputTokens),
@@ -1857,6 +1873,8 @@ public:
         std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
         std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
         std::optional<SizeType32> promptVocabSize = std::nullopt,
+        std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
+        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
         std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
         std::optional<TensorPtr> loraConfig = std::nullopt,
         std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
@@ -1879,7 +1897,8 @@ public:
         std::move(stopWordsList),
         positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
                                 : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
-        std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
+        std::move(promptEmbeddingTable), promptVocabSize, std::move(mropeRotarySinCos),
+        std::move(mropePositionDeltas), loraTaskId, std::move(loraWeights), std::move(loraConfig),
        std::move(lookaheadConfig), std::move(kvCacheRetentionConfig), returnLogProbs, returnContextLogits,
        returnGenerationLogits,
        draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
@@ -35,11 +35,11 @@
 #include <mpi.h>
 #else
 // Dummy defines to avoid #if in wider places.
-typedef int MPI_Datatype;
-typedef int MPI_Comm;
-typedef int MPI_Request;
-typedef int MPI_Message;
-typedef int MPI_Op;
+typedef void* MPI_Datatype;
+typedef void* MPI_Comm;
+typedef void* MPI_Request;
+typedef void* MPI_Message;
+typedef void* MPI_Op;

 typedef struct MPI_Status
 {
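Switching the dummy handle types from `int` to `void*` matters because the mpiUtils.cpp hunks further down start zero-initializing handles with `nullptr`, which only compiles when the dummies are pointer types. A minimal standalone sketch of the idea (not the actual header):

```cpp
// Sketch: with a pointer-typed dummy, the nullptr initializations used in
// mpiUtils.cpp compile in both the MPI and the non-MPI build.
typedef void* MPI_Comm; // dummy stand-in for the real MPI handle

int main()
{
    MPI_Comm localComm = nullptr; // fine for void*; ill-formed for `typedef int MPI_Comm;`
    return localComm == nullptr ? 0 : 1;
}
```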
@@ -250,6 +250,24 @@ private:
     std::optional<VecTokenExtraIds> mInputTokenExtraIds;
 };

+/// @brief Configuration for mrope
+class MropeConfig
+{
+public:
+    explicit MropeConfig(Tensor mropeRoratySinCos, SizeType32 mropePositionDeltas);
+
+    [[nodiscard]] Tensor getMRopeRotarySinCos() const;
+    [[nodiscard]] SizeType32 getMRopePositionDeltas() const;
+
+private:
+    friend class Serialization;
+    /// @brief The mrope rotary sin and cos cache. Expected shape: [maxPositionEmbeddings * rotaryEmbeddingDim].
+    /// Data type must be float32.
+    Tensor mMRopeRotarySinCos;
+    /// @brief The mrope position deltas
+    SizeType32 mMRopePositionDeltas;
+};
+
 /// @brief Configuration for LoRA
 class LoraConfig
 {
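Together with the `Request` accessors added below (`setMropeConfig`/`getMropeConfig`), the new class is wired in roughly like this. A minimal sketch: the value 64 for max new tokens is illustrative, and the sin/cos cache is assumed to be a float32 tensor of shape [maxPositionEmbeddings * rotaryEmbeddingDim], as the class comment requires:

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Sketch: attach an mrope configuration to an executor request.
texec::Request makeMropeRequest(texec::VecTokens inputTokens, texec::Tensor sinCosCache)
{
    texec::Request request(std::move(inputTokens), 64 /* max new tokens */);
    // Per the class comment: float32 cache of shape
    // [maxPositionEmbeddings * rotaryEmbeddingDim]; the delta shifts the
    // rotary position index for this sequence.
    texec::MropeConfig mropeConfig(std::move(sinCosCache), /* mropePositionDeltas */ 0);
    request.setMropeConfig(mropeConfig);
    return request;
}
```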
@@ -330,9 +348,10 @@ public:
     ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state);

     ContextPhaseParams(ContextPhaseParams const&);
-    ContextPhaseParams(ContextPhaseParams&&);
+    ContextPhaseParams(ContextPhaseParams&&) noexcept;
     ContextPhaseParams& operator=(ContextPhaseParams const&);
-    ContextPhaseParams& operator=(ContextPhaseParams&&);
+    ContextPhaseParams& operator=(ContextPhaseParams&&) noexcept;
     ~ContextPhaseParams();

     [[nodiscard]] bool operator==(ContextPhaseParams const&) const noexcept;

@@ -511,7 +530,7 @@ public:
         std::optional<Tensor> embeddingBias = std::nullopt,
         std::optional<ExternalDraftTokensConfig> externalDraftTokensConfig = std::nullopt,
         std::optional<PromptTuningConfig> pTuningConfig = std::nullopt,
-        std::optional<LoraConfig> loraConfig = std::nullopt,
+        std::optional<MropeConfig> mRopeConfig = std::nullopt, std::optional<LoraConfig> loraConfig = std::nullopt,
         std::optional<LookaheadDecodingConfig> lookaheadConfig = std::nullopt,
         std::optional<KvCacheRetentionConfig> kvCacheRetentionConfig = std::nullopt,
         std::optional<std::string> logitsPostProcessorName = std::nullopt,
@@ -548,6 +567,7 @@ public:
     [[nodiscard]] std::optional<Tensor> getEmbeddingBias() const;
     [[nodiscard]] std::optional<ExternalDraftTokensConfig> getExternalDraftTokensConfig() const;
     [[nodiscard]] std::optional<PromptTuningConfig> getPromptTuningConfig() const;
+    [[nodiscard]] std::optional<MropeConfig> getMropeConfig() const;
     [[nodiscard]] std::optional<LoraConfig> getLoraConfig() const;
     [[nodiscard]] std::optional<LookaheadDecodingConfig> getLookaheadConfig() const;
     [[nodiscard]] std::optional<KvCacheRetentionConfig> getKvCacheRetentionConfig() const;
@@ -576,6 +596,7 @@ public:
     void setEmbeddingBias(Tensor const& embeddingBias);
     void setExternalDraftTokensConfig(ExternalDraftTokensConfig const& externalDraftTokensConfig);
     void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
+    void setMropeConfig(MropeConfig const& mRopeConfig);
     void setLoraConfig(LoraConfig const& loraConfig);
     void setLookaheadConfig(LookaheadDecodingConfig const& lookaheadConfig);
     void setKvCacheRetentionConfig(KvCacheRetentionConfig const& kvCacheRetentionConfig);
@@ -648,7 +669,10 @@ struct Result
     /// @brief The params of the context phase.
     std::optional<ContextPhaseParams> contextPhaseParams;

-    /// @brief The decoding iterations it takes.
+    /// @brief The number of decoding iterations used to generate the result.
+    /// In autoregressive decoding, it equals the maximum length of the beam in outputTokenIds.
+    /// In speculative decoding, it may be less than that maximum, since more than
+    /// one token can be generated per iteration. Used for speculative decoding statistics.
     SizeType32 decodingIter{0};

     /// @brief The index of the output sequence of this result where 0 <= sequenceIndex < numReturnSequences.
@@ -55,6 +55,11 @@ public:
     static void serialize(PromptTuningConfig const& config, std::ostream& os);
     [[nodiscard]] static size_t serializedSize(PromptTuningConfig const& config);

+    // MropeConfig
+    [[nodiscard]] static MropeConfig deserializeMropeConfig(std::istream& is);
+    static void serialize(MropeConfig const& config, std::ostream& os);
+    [[nodiscard]] static size_t serializedSize(MropeConfig const& config);
+
     // LoraConfig
     [[nodiscard]] static LoraConfig deserializeLoraConfig(std::istream& is);
     static void serialize(LoraConfig const& config, std::ostream& os);
@@ -206,6 +211,33 @@ public:
     static void serialize(IterationStats const& iterStats, std::ostream& os);
     static std::vector<char> serialize(IterationStats const& iterStats);
     static size_t serializedSize(IterationStats const& iterStats);
     static std::vector<char> serialize(std::vector<IterationStats> const& iterStatsVec);
     static std::vector<IterationStats> deserializeIterationStatsVec(std::vector<char>& buffer);

+    // DisServingStats
+    [[nodiscard]] static DisServingRequestStats deserializeDisServingRequestStats(std::istream& is);
+    static void serialize(DisServingRequestStats const& state, std::ostream& os);
+    [[nodiscard]] static size_t serializedSize(DisServingRequestStats const& state);
+
+    // RequestStage
+    [[nodiscard]] static RequestStage deserializeRequestStage(std::istream& is);
+    static void serialize(RequestStage const& state, std::ostream& os);
+    [[nodiscard]] static size_t serializedSize(RequestStage const& state);
+
+    // RequestStats
+    [[nodiscard]] static RequestStats deserializeRequestStats(std::istream& is);
+    static void serialize(RequestStats const& state, std::ostream& os);
+    [[nodiscard]] static size_t serializedSize(RequestStats const& state);
+
+    // RequestStatsPerIteration
+    [[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::istream& is);
+    [[nodiscard]] static RequestStatsPerIteration deserializeRequestStatsPerIteration(std::vector<char>& buffer);
+    static void serialize(RequestStatsPerIteration const& state, std::ostream& os);
+    [[nodiscard]] static std::vector<char> serialize(RequestStatsPerIteration const& state);
+    [[nodiscard]] static size_t serializedSize(RequestStatsPerIteration const& state);
+    [[nodiscard]] static std::vector<char> serialize(std::vector<RequestStatsPerIteration> const& requestStatsVec);
+    [[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
+        std::vector<char>& buffer);
+
     // String
     static std::string deserializeString(std::istream& is);
@@ -77,6 +77,8 @@ public:
     TensorPtr eagleNetGenContextLengthsHost;
+    //! [maxBatchSize] or [numSequences]
     TensorPtr eagleNetGenPastKeyValueLengthsHost;
+    //! [maxBatchSize * maxDecodingTokens] or [numSequences * maxDecodingTokens]
     TensorPtr inputGenTokensHost;

     void create(SizeType32 maxNumSequences, runtime::TllmRuntime const& runtime,
         runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig);
@@ -67,7 +67,7 @@ public:
     static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
         size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
         BufferManager::CudaStreamPtr const& stream,
-        std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);
+        std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
 };

 template <typename T>
@@ -110,7 +110,7 @@ private:
 inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
     size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
     BufferManager::CudaStreamPtr const& stream,
-    std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule)
+    std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
 {
     switch (dtype)
     {
@@ -128,10 +128,9 @@ inline std::unique_ptr<IGptDecoder> IGptDecoder::create(...)

 /// @brief Helper function to produce batch slots [0, 1, ..., batchSize - 1] for paths that do not explicitly provide
 /// batch slots to the decoder.
-inline runtime::ITensor::SharedConstPtr getDefaultBatchSlots(
-    runtime::SizeType32 batchSize, runtime::BufferManager const& bufferManager)
+inline runtime::ITensor::SharedConstPtr getDefaultBatchSlots(runtime::SizeType32 batchSize)
 {
-    auto defaultBatchSlots = bufferManager.pinnedPool(
+    auto defaultBatchSlots = runtime::BufferManager::pinnedPool(
         runtime::ITensor::makeShape({batchSize}), runtime::TRTDataType<runtime::SizeType32>::value);
     auto range = runtime::BufferRange<runtime::SizeType32>(*defaultBatchSlots);
     std::iota(range.begin(), range.end(), 0);
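Callers of the helper no longer need a live `BufferManager`, since the pinned allocation now goes through the static `BufferManager::pinnedPool` overload. A minimal usage sketch (header path and namespace placement of the helper are assumptions):

```cpp
#include "tensorrt_llm/runtime/gptDecoder.h" // assumed location of getDefaultBatchSlots

// Sketch: build the default batch slots for a batch of 8 sequences.
void exampleDefaultBatchSlots()
{
    auto batchSlots = getDefaultBatchSlots(8);
    // batchSlots is a pinned-host ITensor holding [0, 1, ..., 7]; previously a
    // BufferManager instance had to be threaded in just for this allocation.
    (void) batchSlots;
}
```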
@@ -36,6 +36,7 @@ public:
     // users tune that, we can support that by writing and reading the
     // points in `config.json`.
     static constexpr std::array kOPT_PROFILES_SPLIT_POINTS{64, 128, 256, 512, 1024};
+    static constexpr SizeType32 kDEFAULT_NUM_TOKENS_PER_BLOCK = 64;

     enum class ModelVariant : std::int32_t
     {
@@ -80,18 +81,16 @@ public:
     {
         return KVCacheType::kCONTINUOUS;
     }
-    else if (value == "PAGED")
+    if (value == "PAGED")
     {
         return KVCacheType::kPAGED;
     }
-    else if (value == "DISABLED")
+    if (value == "DISABLED")
     {
         return KVCacheType::kDISABLED;
     }
-    else
-    {
-        throw std::invalid_argument("Invalid KV cache type: " + value);
-    }
+
+    throw std::invalid_argument("Invalid KV cache type: " + value);
 }

 enum class ManageWeightsType : std::int32_t
@@ -113,7 +112,7 @@ public:
     , mUseGptAttentionPlugin(false)
     , mUseMambaConv1dPlugin(false)
     , mInputPacked{false}
-    , mTokensPerBlock{64}
+    , mTokensPerBlock{kDEFAULT_NUM_TOKENS_PER_BLOCK}
     , mQuantMode{common::QuantMode::none()}
     , mMaxBatchSize(0)
     , mMaxBeamWidth(0)
@@ -124,6 +123,9 @@ public:
     , mComputeGenerationLogits(false)
     , mModelVariant(ModelVariant::kGpt)
     , mMaxPromptEmbeddingTableSize(0)
+    , mUseMrope{false}
+    , mMaxPositionEmbeddings(0)
+    , mRotaryEmbeddingDim(0)
     , mContextFMHA(false)
     , mPagedContextFMHA(false)
     , mUseXQA{false}
@@ -396,6 +398,36 @@ public:
         return mMaxPromptEmbeddingTableSize > 0;
     }

+    [[nodiscard]] bool constexpr useMrope() const noexcept
+    {
+        return mUseMrope;
+    }
+
+    void constexpr setUseMrope(bool useMrope) noexcept
+    {
+        mUseMrope = useMrope;
+    }
+
+    [[nodiscard]] SizeType32 constexpr getMaxPositionEmbeddings() const noexcept
+    {
+        return mMaxPositionEmbeddings;
+    }
+
+    void constexpr setMaxPositionEmbeddings(SizeType32 maxPositionEmbeddings) noexcept
+    {
+        mMaxPositionEmbeddings = maxPositionEmbeddings;
+    }
+
+    [[nodiscard]] SizeType32 constexpr getRotaryEmbeddingDim() const noexcept
+    {
+        return mRotaryEmbeddingDim;
+    }
+
+    void constexpr setRotaryEmbeddingDim(SizeType32 rotaryEmbeddingDim) noexcept
+    {
+        mRotaryEmbeddingDim = rotaryEmbeddingDim;
+    }
+
     [[nodiscard]] SizeType32 constexpr getMaxPromptEmbeddingTableSize() const noexcept
     {
         return mMaxPromptEmbeddingTableSize;
@@ -622,14 +654,12 @@ public:
     {
         return nvinfer1::DataType::kFP8;
     }
-    else if (getQuantMode().hasInt8KvCache())
+    if (getQuantMode().hasInt8KvCache())
     {
         return nvinfer1::DataType::kINT8;
     }
-    else
-    {
-        return getDataType();
-    }
+
+    return getDataType();
 }

 [[nodiscard]] bool constexpr isTransformerBased() const noexcept
@@ -733,6 +763,7 @@ public:
         : mNumKvHeadsPerAttentionLayer.cbegin() + numPrevAttnLayers;
     auto const numLocalAttentionLayers
         = countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
+    TLLM_LOG_TRACE("%s stop: %d", __PRETTY_FUNCTION__);
     return std::make_pair(firstLocalAttentionLayerIt, firstLocalAttentionLayerIt + numLocalAttentionLayers);
 }

@@ -797,6 +828,9 @@ private:
     ModelVariant mModelVariant;

     SizeType32 mMaxPromptEmbeddingTableSize;
+    bool mUseMrope;
+    SizeType32 mMaxPositionEmbeddings;
+    SizeType32 mRotaryEmbeddingDim;

     bool mContextFMHA;
     bool mPagedContextFMHA;
@@ -26,4 +26,7 @@ bool tensorHasNan(ITensor const& tensor, BufferManager const& manager, std::string const& infoStr);
 bool tensorHasNan(
     size_t M, size_t K, nvinfer1::DataType type, void const* data, cudaStream_t stream, std::string const& infoStr);

+int stallStream(
+    char const* name, std::optional<cudaStream_t> stream = std::nullopt, std::optional<int> delay = std::nullopt);
+
 } // namespace tensorrt_llm::runtime::utils
@@ -35,6 +35,7 @@ struct TreeNode
 SizeType32 initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
     std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32>& topKs,
     ITensor::SharedPtr generationInputLengths, ITensor::SharedPtr positionOffsets, ITensor::SharedPtr treeIds,
-    ITensor::SharedPtr paths, ITensor::SharedPtr packedMask);
+    ITensor::SharedPtr paths, ITensor::SharedPtr packedMask,
+    std::optional<SizeType32> maxNonLeafNodesPerLayer = std::nullopt);

 } // namespace tensorrt_llm::runtime::utils
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:748a53a5f70813f0ddb5bb54a56cd07a4b9146917c12ec34504dc4384b00610b
-size 5882210
+oid sha256:93114cc9b3f67d302800ef751a71a87f549ad1fb436d7983cedea7edaf3cdc34
+size 6001292
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2350b7f07b5f30179ebf24f6e103dc17d4a656c95c171eaca684529120ca245a
-size 6001974
+oid sha256:c65e18c28264cf19543f94e5e529e00e1cfda12e6d02c7a2141960f11e1020e8
+size 6121836
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:8b28f05452036c1722a37ac625921cf4902cfb6c04fb01b9d958b9f40ff9be0b
+oid sha256:be60169ba0f4d8a526427d942ddb7e657a075f82b9bde186f339d92e5baefedd
 size 1958384
@@ -1,2 +1,2 @@
-0066a5a67ec747f565158bbbc398cca9 libtensorrt_llm_ucx_wrapper.so
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+f0f55c8f4b75991abd3ff2fce878cbea libtensorrt_llm_ucx_wrapper.so
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0132b1d4544101465ac37993ae20324c0c49ae978b0a3c8c95a03a08a17b5b36
-size 5692876
+oid sha256:7cb62faee8fdf912738a7e144f55fbbc8348dcee342224a6e07890b5ec3cac05
+size 5793796
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:15ff5d0aeae4d3e776fdf3bb68af0cc5896b14f435b66a11fecc2111668fd089
-size 5659602
+oid sha256:85ff3b9cefda7fc15bcd8a92fbcfb6a7200f13fbfa47e133caa103c9c6a77e4c
+size 5763878
@@ -1,2 +1,2 @@
-1598761c1df1fd35b2180b599ad34f58 libtensorrt_llm_ucx_wrapper.so
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:f975b781b240c8489a48243a94dfdf0be6bfe6b862cf6ec6cbeacd5c66fae7af
-size 36139148
+oid sha256:89b13b0625a17c038545a6a2c00e1e752d366c0afddcceb411c7903459e45911
+size 36266224
@@ -1,2 +1,2 @@
-f9557afc965818430dcae14ae7542adf tensorrt_llm_batch_manager_static.lib
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+32ed5e6a9704a91e56cff30c1ebe7211 tensorrt_llm_batch_manager_static.lib
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -38,6 +38,13 @@ static std::optional<int32_t> getIntEnv(char const* name)
     return {val};
 };

+// Returns true if the env variable exists and is set to "1"
+static bool getBoolEnv(char const* name)
+{
+    char const* env = std::getenv(name);
+    return env && env[0] == '1' && env[1] == '\0';
+}
+
 // XQA kernels (optimized kernels for generation phase).
 bool forceXQAKernels()
 {
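The helper treats only the exact string "1" as true; "10", "true", or an empty value all fail the `env[1] == '\0'` check. A standalone illustration of that semantics (the `TRTLLM_DEMO_FLAG` variable is made up for the example):

```cpp
#include <cstdlib>

// Same logic as the new helper: true only when the variable is exactly "1".
static bool getBoolEnv(char const* name)
{
    char const* env = std::getenv(name);
    return env && env[0] == '1' && env[1] == '\0';
}

int main()
{
    setenv("TRTLLM_DEMO_FLAG", "1", 1);
    bool const a = getBoolEnv("TRTLLM_DEMO_FLAG"); // true
    setenv("TRTLLM_DEMO_FLAG", "10", 1);
    bool const b = getBoolEnv("TRTLLM_DEMO_FLAG"); // false: only "1" counts
    bool const c = getBoolEnv("TRTLLM_NEVER_SET"); // false: unset
    return (a && !b && !c) ? 0 : 1;
}
```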
@@ -143,15 +150,8 @@ bool getEnvEnablePDL()
         // PDL only available when arch >= 90
         if (getSMVersion() >= 90)
         {
-            char const* enable_pdl = std::getenv("TRTLLM_ENABLE_PDL");
-            if (enable_pdl)
-            {
-                // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
-                if (enable_pdl[0] == '1' && enable_pdl[1] == '\0')
-                {
-                    enablePDL = true;
-                }
-            }
+            // PDL will be enabled by setting the env variables `TRTLLM_ENABLE_PDL` to `1`
+            enablePDL = getBoolEnv("TRTLLM_ENABLE_PDL");
         }
     }
     return enablePDL;
@@ -159,22 +159,7 @@ bool getEnvEnablePDL()

 bool getEnvUseUCXKvCache()
 {
-    static bool init = false;
-    static bool useUCXKVCache = false;
-    if (!init)
-    {
-        init = true;
-        {
-            char const* use_ucx_kv_cache = std::getenv("TRTLLM_USE_UCX_KVCACHE");
-            if (use_ucx_kv_cache)
-            {
-                if (use_ucx_kv_cache[0] == '1' && use_ucx_kv_cache[1] == '\0')
-                {
-                    useUCXKVCache = true;
-                }
-            }
-        }
-    }
+    static bool const useUCXKVCache = getBoolEnv("TRTLLM_USE_UCX_KVCACHE");
     return useUCXKVCache;
 }

@@ -195,4 +180,11 @@ std::string getEnvUCXInterface()
     }
     return ucxInterface;
 }

+bool getEnvDisaggLayerwise()
+{
+    static bool const disaggLayerwise = getBoolEnv("TRTLLM_DISAGG_LAYERWISE");
+    return disaggLayerwise;
+}
+
 } // namespace tensorrt_llm::common
@@ -45,4 +45,6 @@ bool getEnvUseUCXKvCache();

 std::string getEnvUCXInterface();

+bool getEnvDisaggLayerwise();
+
 } // namespace tensorrt_llm::common
@@ -14,7 +14,6 @@
  * limitations under the License.
  */

-#include <algorithm>
 #include <numeric>
 #include <unordered_set>

@@ -101,7 +100,7 @@ std::recursive_mutex mpiMutex;
 MpiComm initLocalSession()
 {
 #if ENABLE_MULTI_DEVICE
-    MPI_Comm localComm;
+    MPI_Comm localComm = nullptr;
     MPI_Comm_split_type(COMM_SESSION, OMPI_COMM_TYPE_HOST, COMM_SESSION.getRank(), MPI_INFO_NULL, &localComm);
     MpiComm localSession{localComm, false};
 #else
@@ -115,14 +114,16 @@ MpiComm initLocalSession()
 std::vector<int> getWorldRanks(MpiComm const& comm)
 {
 #if ENABLE_MULTI_DEVICE
-    MPI_Group group, worldGroup;
+    MPI_Group group = nullptr;
+    MPI_Group worldGroup = nullptr;

     MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
     MPICHECK(MPI_Comm_group(comm, &group));

-    int groupSize;
+    int groupSize = 0;
     MPICHECK(MPI_Group_size(group, &groupSize));
-    std::vector<int> ranks(groupSize), worldRanks(groupSize);
+    std::vector<int> ranks(groupSize);
+    std::vector<int> worldRanks(groupSize);
     std::iota(ranks.begin(), ranks.end(), 0);

     MPICHECK(MPI_Group_translate_ranks(group, groupSize, ranks.data(), worldGroup, worldRanks.data()));
@@ -152,7 +153,7 @@ void initialize(MpiThreadSupport threadMode, bool forwardAbortToParent)
     if (!initialized)
     {
         TLLM_LOG_INFO("Initializing MPI with thread mode %d", threadMode);
-        int providedMode;
+        int providedMode = 0;
         auto requiredMode = static_cast<int>(threadMode);
         MPICHECK(MPI_Init_thread(nullptr, nullptr, requiredMode, &providedMode));
         TLLM_CHECK_WITH_INFO(providedMode >= requiredMode, "MPI_Init_thread failed");
@@ -287,7 +288,7 @@ MPI_Status MpiComm::recv(runtime::IBuffer& buf, int source, int tag) const

 MpiComm MpiComm::split(int color, int key) const
 {
-    MPI_Comm splitComm;
+    MPI_Comm splitComm = nullptr;
 #if ENABLE_MULTI_DEVICE
     MPICHECK(MPI_Comm_split(mComm, color, key, &splitComm));
 #else
@@ -431,11 +432,11 @@ void MpiComm::refreshLocalSession()
         }
     }

-    MPI_Group worldGroup;
+    MPI_Group worldGroup = nullptr;
     MPICHECK(MPI_Comm_group(MPI_COMM_WORLD, &worldGroup));
-    MPI_Group localGroup;
+    MPI_Group localGroup = nullptr;
     MPICHECK(MPI_Group_incl(worldGroup, intersectionRanks.size(), intersectionRanks.data(), &localGroup));
-    MPI_Comm localComm;
+    MPI_Comm localComm = nullptr;
     MPICHECK(MPI_Comm_create_group(MPI_COMM_WORLD, localGroup, intersectionRanks.front(), &localComm));
     MpiComm::mutableLocalSession().mFreeComm = true;
     MpiComm::mutableLocalSession() = MpiComm{localComm, false};
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:33f66dba2f3024d979e38cf1aae4d10802c5a1fb0f4c801108c35824339eae5d
-size 2419566
+oid sha256:ee2f55f3882f75eec0e91ea8392899356b67feecf415945b5e8fee80045a1c97
+size 2493128
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:d224780476ce5f398f30ffbfa0d61bbd0aae5cb1538c8d4c0a16cdf8945ba5d3
-size 2449532
+oid sha256:6ff7a82a0e772a5e9cfc9c02ca012bae35d0b98d00841df22df6745232c8bd96
+size 2523762
@@ -1,3 +1,3 @@
-ee532edbf35321d4ac0aadf8a3c6a3a5 libtensorrt_llm_executor_static.a
-0bf468a19d4c353dcf421fc3e05a9d7d libtensorrt_llm_executor_static.pre_cxx11.a
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+e31839682c2e28f1ee65d0a4ea5bdbde libtensorrt_llm_executor_static.a
+544b31144b0cb90d4256e83b0c10bdfe libtensorrt_llm_executor_static.pre_cxx11.a
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9b21e2488bdb5c1e18e7aa129acb18087d031eea4f5b063910081ca09a3041a5
-size 3494984
+oid sha256:d07532b2d05dc3a69ad98cf8fdfab870a67a57863eb812721d74ff3fd4c740dc
+size 3563598
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:94964aa02020e38e869bf9ca18385ae379c8b9d1819ad02e10b23d8175cc9d82
-size 3412104
+oid sha256:5b7c5ab992090c7456adc23c38bbdd85ad0a5992bb2231763bda7fa2211ddadd
+size 3485776
@@ -1,3 +1,3 @@
-ba01eba908f38eb582c22c1f822cfedf libtensorrt_llm_executor_static.a
-ffe68ec0af94d364ec8db50a24ae0e8c libtensorrt_llm_executor_static.pre_cxx11.a
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+a3216154712c98a922de02d68fea8456 libtensorrt_llm_executor_static.a
+9d0a93f148854b0fbd76aa36e7d54d8e libtensorrt_llm_executor_static.pre_cxx11.a
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:67f59341edab284c309d39f2a0ad39e91f8afe198c4cf6ba838ae7adb54ad01d
-size 23192460
+oid sha256:abfb1b75d8675abba0a8fc59b74d3f94e2ba3eaea4f28e0c4b2beea9cc182316
+size 23865270
@@ -1,2 +1,2 @@
-e3cd49147c73b0066dcb759df9556191 tensorrt_llm_executor_static.lib
-1c2eb102257f836cd50faf985e693241d7a84dbe commit
+8e6a97333e79684aca515112982f2624 tensorrt_llm_executor_static.lib
+0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -17,18 +17,18 @@
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
 #include "tensorrt_llm/common/cudaBf16Wrapper.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cutlass/gemm/gemm.h"
 #include "cutlass/numeric_types.h"
 #include "tensorrt_llm/common/assert.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC

 #include <cuda_runtime_api.h>
 #include <set>
@@ -16,10 +16,10 @@

 #pragma once

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 // clang-format off
 #include "cutlass/cutlass.h"
@@ -29,9 +29,9 @@
 #include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
 // clang-format on

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 using namespace cute;

@@ -22,10 +22,10 @@

 #pragma once

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cute/tensor.hpp"
 #include "cutlass/conv/convolution.h"
@@ -43,9 +43,9 @@
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 using namespace cute;

@@ -22,10 +22,10 @@

 #pragma once

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cute/tensor.hpp"
 #include "cutlass/conv/convolution.h"
@@ -33,9 +33,9 @@
 #include "cutlass/util/packed_stride.hpp"
 #include "cutlass_extensions/gemm_configs.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC

 #include "fp8_rowwise_gemm.h"
 #include "fp8_rowwise_gemm_kernel_template_sm89.h"
@@ -87,6 +87,9 @@ protected:
     static constexpr int SPLIT_K_LIMIT = 7;
     static constexpr int MIN_M_TILE = 16;
     static constexpr int MIN_N_TILE = 64;
+
+    static constexpr int MAX_M_TILE_SM90 = 128;
+    static constexpr int MAX_N_TILE_SM90 = 256;
 };

 template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp,
@@ -14,10 +14,10 @@
  * limitations under the License.
  */

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cutlass/gemm/kernel/default_gemm.h"
 #include "cutlass_extensions/compute_occupancy.h"
@@ -29,9 +29,9 @@
 #include "cutlass_extensions/gemm/threadblock/default_mma.h"
 #include "cutlass_extensions/gemm_configs.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
@@ -568,6 +568,27 @@ CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType>::getWorkspaceSize(
     int const m, int const n, int const k)
 {
     TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
+    // For Hopper, we have to allocate a large memory size in case of stream-K
+    if (sm_ == 90)
+    {
+        // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L878-L892
+        // The lines above say sk_tiles = output_tiles - (static_cast<uint32_t>(output_tiles / ctas_per_wave) - 1) * ctas_per_wave.
+        // This means sk_tiles is at most 2 * ctas_per_wave, which is 2 * multi_processor_count_.
+        int const max_sk_tiles = 2 * multi_processor_count_;
+
+        // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L939
+        // The line above says uint64_t sk_units = platform::min(ctas_per_sk_wave, min_sized_sk_units);
+        // that means sk_units is at most ctas_per_sk_wave, which is multi_processor_count_.
+        int const max_sk_units = multi_processor_count_;
+
+        // https://github.com/NVIDIA/cutlass/blob/19b4c5e065e7e5bbc8082dfc7dbd792bdac850fc/include/cutlass/gemm/kernel/tile_scheduler_params.h#L505
+        // The lines above scale sk_tiles by a factor of static_cast<uint32_t>(sk_units / sk_tiles + 2),
+        // so the final sk_tiles is at most 2 * max_sk_tiles + max_sk_units.
+        int const max_sk_tiles_with_seperate_reduction = 2 * max_sk_tiles + max_sk_units;
+
+        return static_cast<size_t>(
+            max_sk_tiles_with_seperate_reduction * MAX_M_TILE_SM90 * MAX_N_TILE_SM90 * sizeof(float));
+    }
     // These are the min tile sizes for each config, which would launch the maximum number of blocks
     int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
     int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
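For a sense of the magnitude this bound reserves (illustrative numbers, not from the commit): on a hypothetical 132-SM Hopper part, it works out as follows.

```cpp
#include <cstddef>

// Worked example of the SM90 stream-K workspace bound above, assuming a
// hypothetical multi_processor_count_ of 132 (H100-class GPU).
constexpr int kMultiProcessorCount = 132;
constexpr int kMaxSkTiles = 2 * kMultiProcessorCount;                           // 264
constexpr int kMaxSkUnits = kMultiProcessorCount;                               // 132
constexpr int kMaxSkTilesWithSeparateReduction = 2 * kMaxSkTiles + kMaxSkUnits; // 660
constexpr std::size_t kWorkspaceBytes
    = static_cast<std::size_t>(kMaxSkTilesWithSeparateReduction) * 128 * 256 * sizeof(float);
static_assert(kWorkspaceBytes == 86507520, "660 * 128 * 256 * 4 B, roughly 82.5 MiB");
```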
@@ -14,10 +14,10 @@
  * limitations under the License.
  */

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cutlass/epilogue/collective/default_epilogue.hpp"
 #include "cutlass/epilogue/thread/linear_combination.h"
@@ -34,9 +34,9 @@
 #include "cutlass_extensions/epilogue_helpers.h"
 #include "cutlass_extensions/gemm_configs.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
@@ -161,8 +161,11 @@ void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
             sizeof(typename CollectiveEpilogue::SharedStorage))>,
         KernelSchedule>::CollectiveOp;

+    using TileScheduler = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{}, cutlass::gemm::PersistentScheduler,
+        cutlass::gemm::StreamKScheduler>;
+
     using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
-        CollectiveMainloop, CollectiveEpilogue>;
+        CollectiveMainloop, CollectiveEpilogue, TileScheduler>;

     if (occupancy != nullptr)
     {
@@ -16,10 +16,10 @@

 #pragma once

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cute/tensor.hpp"
 #include "cutlass/conv/convolution.h"
@@ -38,9 +38,9 @@
 #include "cutlass/epilogue/collective/collective_builder.hpp"
 #include "cutlass/gemm/device/gemm_universal_adapter.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 using namespace cute;

@@ -16,10 +16,10 @@

 #pragma once

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "cute/tensor.hpp"
 #include "cutlass/conv/convolution.h"
@@ -27,9 +27,9 @@
 #include "cutlass/util/packed_stride.hpp"
 #include "cutlass_extensions/gemm_configs.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC

 #include "fused_gated_gemm.h"
 #include "fused_gated_gemm_kernel_template_sm90.h"
@@ -14,10 +14,10 @@
  * limitations under the License.
  */

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 // clang-format off
 #include <cutlass/gemm/device/default_gemm_configuration.h>
@@ -36,9 +36,9 @@
 #include "cutlass_extensions/gemm/kernel/default_int8_traits.h"
 #include "cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h"

-#ifndef _WIN32
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
-#endif // #ifndef _WIN32
+#endif // __GNUC__

 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
@@ -42,8 +42,6 @@
 #include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
 #include "cutlass_extensions/gemm/threadblock/default_mma.h"

-#pragma GCC diagnostic pop
-
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
@@ -15,8 +15,10 @@
 */

 // Ignore CUTLASS warnings about type punning
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif

 #include "cutlass/array.h"
 #include "cutlass/numeric_conversion.h"
@@ -44,7 +46,9 @@
 #include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
 #include "cutlass_extensions/gemm/threadblock/default_mma.h"

+#ifdef __GNUC__ // Restore GCC-specific diagnostics
 #pragma GCC diagnostic pop
+#endif

 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
@@ -15,8 +15,10 @@
 */

 // Ignore CUTLASS warnings about type punning
+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic push
 #pragma GCC diagnostic ignored "-Wstrict-aliasing"
+#endif // __GNUC__

 #include "cutlass/array.h"
 #include "cutlass/numeric_conversion.h"
@@ -44,7 +46,9 @@
 #include "cutlass_extensions/gemm/kernel/moe_cutlass_kernel.h"
 #include "cutlass_extensions/gemm/threadblock/default_mma.h"

+#ifdef __GNUC__ // Check if the compiler is GCC or Clang
 #pragma GCC diagnostic pop
+#endif // __GNUC__

 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/cudaUtils.h"
@@ -197,6 +197,7 @@ struct Multihead_attention_params_base
     int* block_counter = nullptr;

     int const* memory_length_per_sample = nullptr;
+    int32_t const* mrope_position_deltas = nullptr;
 };

 template <typename T, bool USE_CROSS_ATTENTION = false>

@@ -75,7 +75,8 @@ inline size_t smem_size_in_bytes(Multihead_attention_params<T, DO_CROSS_ATTENTION> const& params,
     size_t transpose_rotary_size = 0;
     if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
-        || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE)
+        || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE
+        || params.position_embedding_type == PositionEmbeddingType::kROPE_M)
     {
         assert(params.rotary_embedding_dim > 0);
         transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(Tk);
@@ -416,7 +417,8 @@ void mmha_launch_kernel(KernelParamsType const& params, KVCacheBuffer const& kv_cache_buffer,
     assert((params.rotary_embedding_dim != 0)
         == (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
             || params.position_embedding_type == PositionEmbeddingType::kROPE_GPTJ
-            || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE));
+            || params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE
+            || params.position_embedding_type == PositionEmbeddingType::kROPE_M));
     if (params.beam_width == 1)
     {
         mmha_launch_kernel_dispatch<T, KVCacheBuffer, KernelParamsType, Dh, false, BLOCK_SPARSE_ATTN,
@@ -1500,7 +1500,7 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) masked_multihead_attention_kernel(
     int const beam0_context_length
         = HAS_BEAMS && tlength > cyclic_kv_cache_len ? 0 : params.input_lengths[batch_beam_idx];
     // The position of the current timestep, and it is used to apply the position embedding
-    int const current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;
+    int current_pos_idx = (!POS_SHIFT || DO_CROSS_ATTENTION) ? tlength : kv_loop_length;

     // The offset in the Q and K buffer also accounts for the batch.
     auto const qk_vec_idx = tidx * QK_VEC_SIZE;
@@ -1667,6 +1667,7 @@ masked_multihead_attention_kernel(
             break;
         }
         case PositionEmbeddingType::kLONG_ROPE:
+        case PositionEmbeddingType::kROPE_M:
         case PositionEmbeddingType::kROPE_GPT_NEOX:
         {
             bool const do_rotary = is_valid_qk_vec && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;
@@ -1680,6 +1681,10 @@ masked_multihead_attention_kernel(
             int const smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts

             assert(half_rotary_dim % QK_VEC_SIZE == 0);
+            if (params.position_embedding_type == PositionEmbeddingType::kROPE_M)
+            {
+                current_pos_idx += params.mrope_position_deltas[batch_idx];
+            }

             if (do_rotary)
             {
@ -221,8 +221,9 @@ void DecoderXQAImplJIT::runImpl(XQAParams const& xqaParams, KVCacheBuffer const&
|
||||
xqaParams.spec_decoding_generation_lengths, xqaParams.sequence_lengths, /* encoder_seqlens */ nullptr,
|
||||
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr, /* cu_kv_seqlens */ nullptr,
|
||||
launchParams.rotary_inv_freq_buf, (float2 const*) nullptr, xqaParams.kv_scale_orig_quant,
|
||||
xqaParams.spec_decoding_position_offsets, int(batch_beam_size), xqaParams.generation_input_length,
|
||||
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
xqaParams.spec_decoding_position_offsets, xqaParams.mrope_rotary_sin_cos, xqaParams.mrope_position_deltas,
|
||||
int(batch_beam_size), xqaParams.generation_input_length, xqaParams.timestep,
|
||||
xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
|
||||
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
|
||||
/*remove_padding*/ true, /*cross_attention*/ false, xqaParams.num_q_heads, xqaParams.num_kv_heads,
|
||||
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
|
||||
|
||||
@ -73,7 +73,7 @@ bool supportConfigCommon(XQAParams const& xqaParams, bool forConfigurePlugin)
|
||||
return false;
|
||||
}
|
||||
if (!contains({PositionEmbeddingType::kROPE_GPTJ, PositionEmbeddingType::kROPE_GPT_NEOX,
|
||||
PositionEmbeddingType::kLONG_ROPE},
|
||||
PositionEmbeddingType::kROPE_M, PositionEmbeddingType::kLONG_ROPE},
|
||||
xqaParams.position_embedding_type))
|
||||
{
|
||||
return false;
|
||||
|
||||
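The support check above relies on a small membership helper. The real contains() lives elsewhere in the codebase; a plausible stand-in, with the enum trimmed to the values this hunk mentions, could look like this:

    #include <cstdint>
    #include <initializer_list>

    enum class PositionEmbeddingType : int8_t
    {
        kROPE_GPTJ,
        kROPE_GPT_NEOX,
        kROPE_M,
        kLONG_ROPE,
    };

    // Hypothetical shape of the contains() helper used in supportConfigCommon.
    bool contains(std::initializer_list<PositionEmbeddingType> set, PositionEmbeddingType value)
    {
        for (auto item : set)
        {
            if (item == value)
            {
                return true;
            }
        }
        return false;
    }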
@@ -1,2 +1,2 @@
90df70c216d9aa2c85b8b097c853e4ba libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,2 +1,2 @@
232f492424a31204a2be2e67be299aef libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bed2713947315cf941533dd12b5b98270a2aabd584cc33bc2092be6dbf879959
oid sha256:ed682d2c566aa703cfd1276f1329d3db3b291aed7e325e521bbe6f3406d7cd84
size 1128448

@@ -1,3 +1,3 @@
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
aaa20992c207e46eab50dd90bcf3c405 tensorrt_llm_nvrtc_wrapper.dll
1c2eb102257f836cd50faf985e693241d7a84dbe commit
6ed8db839864ffdd8f22fff0d627ee9d tensorrt_llm_nvrtc_wrapper.dll
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -204,9 +204,10 @@ public:
xqaParams.sequence_lengths, /* encoder_seqlens */ nullptr,
xqaParams.multi_query_tokens ? launchParams.cu_seq_lens : nullptr,
/* cu_kv_seqlens */ nullptr, launchParams.rotary_inv_freq_buf, (float2 const*) nullptr,
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, int(batch_beam_size),
xqaParams.generation_input_length, xqaParams.timestep, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
xqaParams.kv_scale_orig_quant, xqaParams.spec_decoding_position_offsets, xqaParams.mrope_rotary_sin_cos,
xqaParams.mrope_position_deltas, int(batch_beam_size), xqaParams.generation_input_length,
xqaParams.timestep, xqaParams.cyclic_attention_window_size, xqaParams.sink_token_length,
int(xqaParams.batch_size * beam_width * xqaParams.generation_input_length),
/*remove_padding*/ true, /*cross_attention*/ false, xqaParams.num_q_heads, xqaParams.num_kv_heads,
xqaParams.num_q_heads / xqaParams.num_kv_heads, xqaParams.head_size, xqaParams.rotary_embedding_dim,
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,

@@ -54,6 +54,8 @@ struct XQAParams
int const* spec_decoding_generation_lengths; // variable input lengths.
bool spec_decoding_is_generation_length_variable; // whether the generation lengths actually vary
int32_t spec_decoding_max_generation_length; // max possible input length
float2 const* mrope_rotary_sin_cos = nullptr;
int32_t const* mrope_position_deltas = nullptr;

// almost copy from GPTAttentionPluginCommon.
// maybe use one struct for parameters in GPTAttentionPluginCommon and share the same here.

@@ -59,6 +59,7 @@ enum class PositionEmbeddingType : int8_t
kRELATIVE = 6,
kCHATGLM = 7,
kYARN = 8,
kROPE_M = 9,
};

enum class RotaryScalingType : int8_t

@@ -1,3 +1,3 @@
f1820d73fc5cac7fa324d71933e5412a libtensorrt_llm_internal_cutlass_kernels_static.a
a8785db1cc11e3b571bc071d7abec1a8 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b6917794ec6e67989fdcd0af3cc4d84713f3d8d4dcd822d2df2272117c66d6b
oid sha256:14fa77c4e77a1a6c6955539f1721a773663785d628181e3bb30fedcccb676dfc
size 36626184

@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e6d2f3c25a8ce88917ba512eba804f14827703fab6f9ac8d63043e2d95b6b281
oid sha256:da7b96feeabce735db2a0d0524c4d057407bc7e7201b92383d4f7693db4aba7f
size 36080026

@@ -1,3 +1,3 @@
9e6ff6d826caeea1e6e19c71f5d0986b libtensorrt_llm_internal_cutlass_kernels_static.a
2fca4e76d5f21089f00b1d39f624b80a libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
5b0663234aecc8eb59d51f1e0a9206d6 libtensorrt_llm_internal_cutlass_kernels_static.a
5da2c92d3a859d10cf65bf4edab96c62 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2f34df6d47b7b2b6629358bb03b33eb193db067188e8b980598027b0ff85392
size 2669968
oid sha256:17d5d559c2a9748cf03b24d82270739430eee53c1d4dc41442ae05745724af84
size 2669962

@@ -1,2 +1,2 @@
95c2f50347d4de94e2e09cbf0cf99582 tensorrt_llm_internal_cutlass_kernels_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit
133f764ad845f62a4641bffca9c436b2 tensorrt_llm_internal_cutlass_kernels_static.lib
0397a251a647dd5d25f0de5279170ba35d82c50d commit
@@ -26,8 +26,10 @@
#include <sstream>

// Ignore CUTLASS warnings about type punning
#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif

#include "cute/tensor.hpp"
#include "cutlass/conv/convolution.h"
@@ -41,7 +43,9 @@

#include "cutlass_extensions/epilogue/thread/fused_activations.h"

#ifdef __GNUC__ // Check if the compiler is GCC or Clang
#pragma GCC diagnostic pop
#endif

#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/dataType.h"
@@ -235,8 +235,23 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;
// cum log prob is not supported with returnAllSelectedTokens
if (!returnAllSelectedTokens)

if (returnAllSelectedTokens)
{
// 'outputLogProbs' is the probability induced by the top-k sampling:
// NOT normalized (same way as OpenAI does):
// log_prob = log P(i | i is in vocab) = log(expLogit)
// normalized:
// log_prob = log P(i | i is in top-k) = log(expLogit / sum)
if (outputLogProbs != nullptr)
{
// outputLogProbs shape: [maxBatchSize, maxTopK]
auto logProb = logf(expLogit);
auto const normalizedProb = normalizeLogProbs ? logProb - logf(sSum) : logProb;
outputLogProbs[batchSlot * maxTopK + ki] = normalizedProb;
}
}
else
{
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
@@ -247,17 +262,14 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
}
if (outputLogProbs != nullptr)
{
// 'outputLogProbs' is the probability induced by the top-k sampling:
// NOT normalized (same way as OpenAI does):
// log_prob = log P(i | i is in vocab) = log(expLogit)
// normalized:
// log_prob = log P(i | i is in top-k) = log(expLogit / sum)
outputLogProbs[curSeqLen * maxBatchSize + batchSlot]
= normalizeLogProbs ? logProb - logf(sSum) : logProb;
auto const normalizedProb = normalizeLogProbs ? logProb - logf(sSum) : logProb;
// outputLogProbs shape: [maxSeqLen, maxBatchSize]
outputLogProbs[curSeqLen * maxBatchSize + batchSlot] = normalizedProb;
}
}
break;
}

if (returnAllSelectedTokens && randNum <= 0.0f)
{
if (ki < k - 1)
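The normalization the kernel comments describe reduces to a one-line identity: log P(i | i is in top-k) = log(expLogit / sSum) = log(expLogit) - log(sSum). A hypothetical host-side illustration, not the kernel itself (topKLogProb is an invented name):

    #include <cmath>
    #include <cstdio>

    // unnormalized: log P(i | i in vocab) = log(expLogit)
    // normalized:   log P(i | i in top-k) = log(expLogit) - log(sSum)
    float topKLogProb(float expLogit, float sSum, bool normalizeLogProbs)
    {
        float const logProb = std::log(expLogit);
        return normalizeLogProbs ? logProb - std::log(sSum) : logProb;
    }

    int main()
    {
        // Toy top-2 set with exp-logits 0.6 and 0.2, so sSum = 0.8.
        std::printf("unnormalized: %f\n", topKLogProb(0.6f, 0.8f, false)); // log(0.6)
        std::printf("normalized:   %f\n", topKLogProb(0.6f, 0.8f, true));  // log(0.75)
        return 0;
    }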
@@ -74,7 +74,9 @@ struct TopKSamplingKernelParams
//! input/output buffer [maxBatchSize], optional.
//! Cumulative log probability of selected tokens. Ignored if nullptr
float* cumLogProbs{nullptr};
//! output buffer [maxBatchSize]. Log probs is the probability induced by the top-k sampling.
//! output buffer
//! [maxBatchSize, maxTopK] when returnAllSelectedTokens, otherwise [maxSeqLen, maxBatchSize]
//! Log probs is the probability induced by the top-k sampling.
//! If normalizeLogProbs is true, we normalize the probability 'expLogit' of the selected token
//! by the probability 's_sum' of a set of top-k tokens, meaning the logProb is the probability
//! of the selected token, conditioned on the event that it is selected,
@@ -137,8 +139,13 @@ struct TopKSamplingKernelParams
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || endIds);
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllSelectedTokens);
TLLM_CHECK(maxTokensPerStep == 1);
if (cumLogProbs != nullptr)
{
TLLM_CHECK(!returnAllSelectedTokens);
}
}

TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);

TLLM_CHECK(0 < maxTopP && maxTopP <= 1.f);
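The two outputLogProbs layouts documented above imply two different flattened indices. A sketch with invented helper names, assuming the shapes stated in the doc comment:

    #include <cstddef>

    std::size_t allSelectedIdx(std::size_t batchSlot, std::size_t ki, std::size_t maxTopK)
    {
        // [maxBatchSize, maxTopK] layout, used when returnAllSelectedTokens is set
        return batchSlot * maxTopK + ki;
    }

    std::size_t perStepIdx(std::size_t curSeqLen, std::size_t batchSlot, std::size_t maxBatchSize)
    {
        // [maxSeqLen, maxBatchSize] layout, used otherwise
        return curSeqLen * maxBatchSize + batchSlot;
    }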
@@ -98,11 +98,11 @@ namespace
template <int BLOCK_SIZE>
__global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengths, SizeType32* eagleNetContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* hiddenStatesIndices, SizeType32* lastTokenIndices,
SizeType32* numOutputTokens, SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts,
TokenIdType const* inputIds, SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts, TokenIdType const* inputIds,
SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
TokenIdType const* acceptedTokens, SizeType32 const* acceptedLens, SizeType32 const* prevDraftLens,
SizeType32 const* prevPaths, SizeType32 const* bestPathIds, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens)
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage tempStorage;
@@ -135,6 +135,11 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
}
}

for (SizeType32 ii = bid; ii < maxNonLeavesPerLayer * batchSize; ii += BLOCK_SIZE)
{
lastTokenIndices[ii] = 1;
}

SizeType32 outputStartPos{0};
SizeType32 inputIndexBase{0};
SizeType32 lastTokenIndex{0};
@@ -215,9 +220,9 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
// The last thread writes number of flattened tokens.
if (bid == BLOCK_SIZE - 1)
{
numOutputTokens[0] = outputStartPos + numDecodingTokens;
// After the first EagleNet we predict exactly one set of logits per request.
numLastTokenIndices[0] = batchSize;

// Set last hiddenSizeBatchLevelStarts.
hiddenSizeBatchLevelStarts[batchSize] = batchSize;
}
@@ -226,19 +231,20 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt

void invokePrepareCtxEagleNetInputs(SizeType32* eagleNetSequenceLengths, SizeType32* eagleNetContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* hiddenStatesIndices, SizeType32* lastTokenIndices,
SizeType32* numOutputTokens, SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts,
TokenIdType const* inputIds, SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
SizeType32* numLastTokenIndices, SizeType32* hiddenSizeBatchLevelStarts, TokenIdType const* inputIds,
SizeType32 const* baseNetSequenceLengths, SizeType32 const* baseNetContextLengths,
TokenIdType const* acceptedTokens, SizeType32 const* acceptedLens, SizeType32 const* prevDraftLens,
SizeType32 const* prevPaths, SizeType32 const* bestPathIds, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens, cudaStream_t stream)
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
TLLM_CHECK_WITH_INFO(
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
prepareCtxEagleNetInputsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(eagleNetSequenceLengths,
eagleNetContextLengths, outputIds, positionIds, hiddenStatesIndices, lastTokenIndices, numOutputTokens,
numLastTokenIndices, hiddenSizeBatchLevelStarts, inputIds, baseNetSequenceLengths, baseNetContextLengths,
acceptedTokens, acceptedLens, prevDraftLens, prevPaths, bestPathIds, batchSize, maxPathLen, maxDecodingTokens);
eagleNetContextLengths, outputIds, positionIds, hiddenStatesIndices, lastTokenIndices, numLastTokenIndices,
hiddenSizeBatchLevelStarts, inputIds, baseNetSequenceLengths, baseNetContextLengths, acceptedTokens,
acceptedLens, prevDraftLens, prevPaths, bestPathIds, batchSize, maxPathLen, maxDecodingTokens,
maxNonLeavesPerLayer);
}

namespace
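The new block-stride loop above prefills all of lastTokenIndices with the neutral index 1, so only valid entries need to be overwritten later. A host-side model of that padding pass, with an invented helper name:

    #include <cstddef>
    #include <vector>

    // Every slot of the [batchSize * maxNonLeavesPerLayer] buffer starts at 1;
    // padded tail slots simply keep that value.
    void padLastTokenIndices(std::vector<int>& lastTokenIndices, int batchSize, int maxNonLeavesPerLayer)
    {
        lastTokenIndices.assign(static_cast<std::size_t>(batchSize) * maxNonLeavesPerLayer, 1);
    }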
@@ -431,13 +437,13 @@ template <int BLOCK_SIZE>
__global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths, SizeType32* nextContextLengths,
TokenIdType* outputIds, SizeType32* positionIds, SizeType32* specDecodingGenLengths,
SizeType32* specDecodingPositionOffsets, SizeType32* specDecodingPackedMasks, SizeType32* hiddenStatesIndices,
SizeType32* lastTokenIndices, SizeType32* numOutputTokens, SizeType32* numLastTokenIndices,
SizeType32* outputHiddenSizeBatchStartsPerLevel, SizeType32* cumSumGenerationLengths,
SizeType32* maxGenerationLength, TokenIdType const* nextDraftIds, SizeType32 const* selectedDraftIndices,
SizeType32 const* selectedDraftPosIds, SizeType32 const* numSelectedDraftIndices,
SizeType32 const* eagleNet0SequenceLengths, SizeType32 const* prevContextLengths,
SizeType32 const* inputHiddenSizeBatchStartsPerLevel, SizeType32 const* parentNonLeafInLevelOffset,
SizeType32 levelIdx, SizeType32 batchSize, SizeType32 maxPathLen, SizeType32 maxDecodingTokens)
SizeType32* lastTokenIndices, SizeType32* numLastTokenIndices, SizeType32* outputHiddenSizeBatchStartsPerLevel,
SizeType32* cumSumGenerationLengths, SizeType32* maxGenerationLength, TokenIdType const* nextDraftIds,
SizeType32 const* selectedDraftIndices, SizeType32 const* selectedDraftPosIds,
SizeType32 const* numSelectedDraftIndices, SizeType32 const* eagleNet0SequenceLengths,
SizeType32 const* prevContextLengths, SizeType32 const* inputHiddenSizeBatchStartsPerLevel,
SizeType32 const* parentNonLeafInLevelOffset, SizeType32 levelIdx, SizeType32 batchSize, SizeType32 maxPathLen,
SizeType32 maxDecodingTokens, SizeType32 maxNonLeavesPerLayer)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
typedef cub::BlockReduce<SizeType32, BLOCK_SIZE> BlockReduce;
@@ -551,7 +557,7 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,

auto const lastStart = inputHiddenSizeBatchStartsPerLevel[(levelIdx - 1) * batchSize + batchSize];
// Set new layer idx.
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + bid] = lastStart + outputLastIndicesBase;
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + bid] = lastStart + bid * maxNonLeavesPerLayer;
}

__syncthreads();
@@ -559,12 +565,11 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// The last valid thread fills the number of tokens.
if (bid == batchSize - 1)
{
numOutputTokens[0] = outputIndexBase + nextDraftLen;
// Set the total number of logits needed after the next iteration.
numLastTokenIndices[0] = lastIndices;
// Set last outputHiddenSizeBatchStartsPerLevel.
outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize]
= outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize - 1] + numNextLogits;
= outputHiddenSizeBatchStartsPerLevel[levelIdx * batchSize + batchSize - 1] + maxNonLeavesPerLayer;
}
}

@@ -774,12 +779,12 @@ void invokePrepareGenEagleNetInputs(PrepareGenEagleNetInputsParams const& params
prepareGenEagleNetInputsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, params.stream>>>(params.nextSequenceLengths,
params.nextContextLengths, params.outputIds, params.positionIds, params.specDecodingGenLengths,
params.specDecodingPositionOffsets, params.specDecodingPackedMasks, params.hiddenStatesIndices,
params.lastTokenIndices, params.numOutputTokens, params.numLastTokenIndices,
params.outputHiddenSizeBatchStartsPerLevel, params.cumSumGenerationLengths, params.maxGenerationLength,
params.nextDraftIds, params.selectedDraftIndices, params.selectedDraftPosOffsets,
params.numSelectedDraftIndices, params.eagleNet0SequenceLengths, params.prevContextLengths,
params.inputHiddenSizeBatchStartsPerLevel, params.parentNonLeafInLevelOffset, params.levelIdx,
params.batchSize, params.maxPathLen, params.maxDecodingTokens);
params.lastTokenIndices, params.numLastTokenIndices, params.outputHiddenSizeBatchStartsPerLevel,
params.cumSumGenerationLengths, params.maxGenerationLength, params.nextDraftIds,
params.selectedDraftIndices, params.selectedDraftPosOffsets, params.numSelectedDraftIndices,
params.eagleNet0SequenceLengths, params.prevContextLengths, params.inputHiddenSizeBatchStartsPerLevel,
params.parentNonLeafInLevelOffset, params.levelIdx, params.batchSize, params.maxPathLen,
params.maxDecodingTokens, params.maxNonLeavesPerLayer);

sync_check_cuda_error();
}
@@ -803,11 +808,13 @@ namespace

template <typename T>
__global__ void assembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, TokenIdType** outputIdsPtrs,
TokenIdType* outputIds, SizeType32 numInputLogits, SizeType32 maxDecodingDraftTokens, SizeType32 vocabSizePadded)
TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits, SizeType32 batchSize,
SizeType32 maxDecodingDraftTokens, SizeType32 vocabSizePadded)
{
auto const tix = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
auto const isValid{tix < numValidLogits[0]};

if (tix < numInputLogits)
if (isValid)
{
// logits: [numInputLogits, vocab_size]
// logitsPtrs: [numInputLogits][1, vocab_size]
@@ -817,29 +824,34 @@ __global__ void assembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits
// outputIdsPtrs: [numInputLogits][maxDecodingDraftTokens]
outputIdsPtrs[tix] = outputIds + tix * maxDecodingDraftTokens;
}

skipDecode[tix] = !isValid;
}

} // namespace

template <typename T>
void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, runtime::TokenIdType** outputIdsPtrs,
runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens,
runtime::TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits,
runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
runtime::SizeType32 vocabSizePadded, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;

assembleDraftLogitsOffsets<T><<<divUp(numInputLogits, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(
logitsPtrs, logits, outputIdsPtrs, outputIds, numInputLogits, maxDecodingDraftTokens, vocabSizePadded);
assembleDraftLogitsOffsets<T><<<divUp(numInputLogits, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(logitsPtrs, logits,
outputIdsPtrs, outputIds, skipDecode, numValidLogits, batchSize, maxDecodingDraftTokens, vocabSizePadded);

sync_check_cuda_error();
}

template void invokeAssembleDraftLogitsOffsets(float const** logitsPtrs, float const* logits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, bool* skipDecode,
runtime::SizeType32 const* numValidLogits, runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

template void invokeAssembleDraftLogitsOffsets(__half const** logitsPtrs, __half const* logits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits,
runtime::TokenIdType** outputIdsPtrs, runtime::TokenIdType* outputIds, bool* skipDecode,
runtime::SizeType32 const* numValidLogits, runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingDraftTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
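The reworked assembleDraftLogitsOffsets amounts to: valid slots receive pointers into the flat logits/outputIds buffers, invalid slots only have their skipDecode flag raised. A plain C++ host analog of the CUDA kernel, with invented names:

    #include <cstddef>
    #include <vector>

    void assembleOffsetsHost(float const* logits, int* outputIds, int numInputLogits, int numValidLogits,
        int vocabSizePadded, int maxDecodingDraftTokens, std::vector<float const*>& logitsPtrs,
        std::vector<int*>& outputIdsPtrs, std::vector<bool>& skipDecode)
    {
        logitsPtrs.assign(numInputLogits, nullptr);
        outputIdsPtrs.assign(numInputLogits, nullptr);
        skipDecode.assign(numInputLogits, true);
        for (int tix = 0; tix < numInputLogits; ++tix)
        {
            bool const isValid = tix < numValidLogits;
            if (isValid)
            {
                // Each valid slot views one row of the flat buffers.
                logitsPtrs[tix] = logits + static_cast<std::size_t>(tix) * vocabSizePadded;
                outputIdsPtrs[tix] = outputIds + static_cast<std::size_t>(tix) * maxDecodingDraftTokens;
            }
            skipDecode[tix] = !isValid;
        }
    }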
@@ -970,8 +982,8 @@ __global__ void extracTopKsFromSuccessorsArray(SizeType32* topKs, SizeType32* to
// Extract topKs from paths and layerId
void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeType32* topKs,
runtime::SizeType32* topKOffset, runtime::SizeType32* numSuccessorsForEachNode, runtime::SizeType32 layerId,
runtime::SizeType32 batchSize, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingTokens,
runtime::SizeType32 maxPathLen, cudaStream_t stream)
runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
cudaStream_t stream)
{

TLLM_CHECK_WITH_INFO(
@@ -999,8 +1011,8 @@ namespace
{
__global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 const* topKs,
SizeType32 const* topKOffset, TokenIdType const* pluginInputDraftIdsPtrs, SizeType32 const* pluginInputDraftLens,
TokenIdType* pluginOutputDraftIdsPtrs, SizeType32* pluginOutputDraftLens, SizeType32 layerId, SizeType32 batchSize,
SizeType32 numInputLogits, SizeType32 maxDecodingDraftTokens)
SizeType32 const* numValidLogits, TokenIdType* pluginOutputDraftIdsPtrs, SizeType32* pluginOutputDraftLens,
SizeType32 layerId, SizeType32 batchSize, SizeType32 maxDecodingDraftTokens)
{
// tmpOutputIdsPtrs: shape [numInputLogits][maxDecodingDraftTokens]
// topKs: shape [numInputLogits]
@@ -1030,7 +1042,7 @@ __global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 c

// Compute the topK offset
SizeType32 startTopKOffset = topKOffset[tix];
SizeType32 endTopkOffset = tix + 1 < batchSize ? topKOffset[tix + 1] : numInputLogits;
SizeType32 endTopkOffset = tix + 1 < batchSize ? topKOffset[tix + 1] : numValidLogits[0];

for (SizeType32 ii = startTopKOffset; ii < endTopkOffset; ii++)
{
@@ -1051,15 +1063,16 @@ __global__ void copyOutputTokensIds(TokenIdType** tmpOutputIdsPtrs, SizeType32 c
// Copy output draft token ids from temporary buffer to plugin output buffer, also update the draft token length
void invokeCopyOutputTokensIds(runtime::TokenIdType** tmpOutputIdsPtrs, runtime::SizeType32 const* topKs,
runtime::SizeType32 const* topKOffset, runtime::TokenIdType const* pluginInputDraftIdsPtrs,
runtime::SizeType32 const* pluginInputDraftLens, runtime::TokenIdType* pluginOutputDraftIdsPtrs,
runtime::SizeType32* pluginOutputDraftLens, runtime::SizeType32 layerId, runtime::SizeType32 batchSize,
runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens, cudaStream_t stream)
runtime::SizeType32 const* pluginInputDraftLens, runtime::SizeType32 const* numValidLogits,
runtime::TokenIdType* pluginOutputDraftIdsPtrs, runtime::SizeType32* pluginOutputDraftLens,
runtime::SizeType32 layerId, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;

copyOutputTokensIds<<<divUp(batchSize, BLOCK_SIZE), BLOCK_SIZE, 0, stream>>>(tmpOutputIdsPtrs, topKs, topKOffset,
pluginInputDraftIdsPtrs, pluginInputDraftLens, pluginOutputDraftIdsPtrs, pluginOutputDraftLens, layerId,
batchSize, numInputLogits, maxDecodingDraftTokens);
pluginInputDraftIdsPtrs, pluginInputDraftLens, numValidLogits, pluginOutputDraftIdsPtrs, pluginOutputDraftLens,
layerId, batchSize, maxDecodingDraftTokens);
}

namespace

@@ -41,21 +41,25 @@ void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

//! FIXME: We may get rid of this kernel in future optimization
//! \brief Set the logitsPtrs[numInputLogits][1, vocabSizePadded] from logits [numInputLogits * vocabSizePadded]
//! and outputIdsPtrs[numInputLogits][maxDecodingDraftTokens] from outputIds[numInputLogits * maxDecodingDraftTokens]
//! \brief Set the logitsPtrs[numValidLogits][1, vocabSizePadded] from logits [numInputLogits * vocabSizePadded]
//! and outputIdsPtrs[numValidLogits][maxDecodingDraftTokens] from outputIds[numInputLogits * maxDecodingDraftTokens]
//! Can be merged into other kernels
//! \param logitsPtrs [numInputLogits][1, vocabSizePadded], on GPU. The logits pointer array that will be used in topK
//! \param logitsPtrs [numValidLogits][1, vocabSizePadded], on GPU. The logits pointer array that will be used in topK
//! sampling.
//! \param logits [numInputLogits * vocabSizePadded], on GPU. Flatten logits, generated by the EagleNet.
//! \param outputIdsPtrs [numInputLogits][maxDecodingDraftTokens], on GPU. The output buffer of the topK sampling.
//! \param outputIdsPtrs [numValidLogits][maxDecodingDraftTokens], on GPU. The output buffer of the topK sampling.
//! \param outputIds [numInputLogits * maxDecodingDraftTokens], on GPU. The flatten output buffer.
//! \param skipDecode [batchSize * maxNonLeavesPerLayer], on GPU. Flag whether to skip decoding or not.
//! First batchSize * sum(numValidLogitsPerRequest[:]) are set to true, the rest is false.
//! \param numValidLogits [1], on GPU. Number of valid logits.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param maxDecodingDraftTokens maximum number of decoding draft tokens per step per request
//! \param vocabSizePadded vocab size of the logits
//! \param stream cuda stream
template <typename T>
void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, runtime::TokenIdType** outputIdsPtrs,
runtime::TokenIdType* outputIds, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens,
runtime::TokenIdType* outputIds, bool* skipDecode, runtime::SizeType32 const* numValidLogits,
runtime::SizeType32 numInputLogits, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
runtime::SizeType32 vocabSizePadded, cudaStream_t stream);

//! \brief Prepares data for ctx stage EagleNet (EagleNet0).
@@ -74,12 +78,11 @@ void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, run
//! E.g. With 3 requests where the first two are context requests with lengths 5 and 3 respectively and the 3rd
//! is gen request with draftDecodingTokens=8 and acceptedLength=3 and the best path is [0, 2, 5].
//! hiddenStatesIndices equals to [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 13].
//! \param lastTokenIndices output buffer [numLastTokenIndices],
//! \param numOutputTokens output buffer [1], single number equals to the total
//! number of select tokens summed over all batches.
//! \param numLastTokenIndices output buffer [1], single number equals to the total
//! number of logits predicted by the next EagleNet iteration tokens summed over all batches.
//! For EagleNet0 it is equal to the batchSize.
//! \param lastTokenIndices output buffer [batchSize * maxNonLeavesPerLayer],
//! Indices (starting with 1) of the logits of interest after the EagleNet prediction.
//! Used for index_select of the hidden_states in the end of the EagleNet. Padded to maxNonLeavesPerLayer with 1s.
//! \param numLastTokenIndices output buffer [1], number of logits predicted by the next EagleNet
//! iteration. For EagleNet0 each value is batchSize.
//! \param hiddenSizeBatchLevelStarts output buffer [batchSize * maxDraftPathLen + 1]
//! Exclusive sum of the hidden states produced per batch per level.
//! For EagleNet0 it is just cum sum of 1s for batchSize.
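The worked example in the doc comment above can be recomputed directly: context requests contribute consecutive indices, and the generation request keeps only its best-path offsets starting at the next free base. A toy host-side verification (all names invented):

    #include <cstdio>
    #include <vector>

    int main()
    {
        std::vector<int> hiddenStatesIndices;
        int base = 0;
        for (int len : {5, 3}) // the two context requests
        {
            for (int i = 0; i < len; ++i)
            {
                hiddenStatesIndices.push_back(base + i);
            }
            base += len;
        }
        for (int offset : {0, 2, 5}) // gen request: best path within its 8 draft tokens
        {
            hiddenStatesIndices.push_back(base + offset);
        }
        for (int idx : hiddenStatesIndices)
        {
            std::printf("%d ", idx); // prints: 0 1 2 3 4 5 6 7 8 10 13
        }
        std::printf("\n");
        return 0;
    }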
@@ -96,17 +99,18 @@ void invokeAssembleDraftLogitsOffsets(T const** logitsPtrs, T const* logits, run
//! \param batchSize batch size
//! \param maxPathLen Max number of accepted tokens per step
//! \param maxDecodingTokens Max number of draft tokens + 1
//! \param maxNonLeavesPerLayer Maximum number of non-leaf nodes per layer
//! \param stream cuda stream.
void invokePrepareCtxEagleNetInputs(runtime::SizeType32* eagleNetSequenceLengths,
runtime::SizeType32* eagleNetContextLengths, runtime::TokenIdType* outputIds, runtime::SizeType32* positionIds,
runtime::SizeType32* hiddenStatesIndices, runtime::SizeType32* lastTokenIndices,
runtime::SizeType32* numOutputTokens, runtime::SizeType32* numLastTokenIndices,
runtime::SizeType32* hiddenSizeBatchLevelStarts, runtime::TokenIdType const* inputIds,
runtime::SizeType32 const* baseNetSequenceLengths, runtime::SizeType32 const* baseNetContextLengths,
runtime::TokenIdType const* acceptedTokens, runtime::SizeType32 const* acceptedLens,
runtime::SizeType32 const* prevDraftLens, runtime::SizeType32 const* prevPaths,
runtime::SizeType32 const* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxPathLen,
runtime::SizeType32 maxDecodingTokens, cudaStream_t stream);
runtime::SizeType32* numLastTokenIndices, runtime::SizeType32* hiddenSizeBatchLevelStarts,
runtime::TokenIdType const* inputIds, runtime::SizeType32 const* baseNetSequenceLengths,
runtime::SizeType32 const* baseNetContextLengths, runtime::TokenIdType const* acceptedTokens,
runtime::SizeType32 const* acceptedLens, runtime::SizeType32 const* prevDraftLens,
runtime::SizeType32 const* prevPaths, runtime::SizeType32 const* bestPathIds, runtime::SizeType32 batchSize,
runtime::SizeType32 maxPathLen, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxNonLeavesPerLayer,
cudaStream_t stream);

struct PrepareGenEagleNetInputsParams
{
@@ -136,14 +140,11 @@ struct PrepareGenEagleNetInputsParams
//! output buffer [numOutputTokens]
//! Indices of the hidden states for selected tokens for the next EagleNet iteration.
runtime::SizeType32* hiddenStatesIndices{nullptr};
//! output buffer [numLastTokenIndices]
//! output buffer [batchSize * maxNonLeavesPerLayer]
//! Indices of the hidden states where to sample logits from after the next EagleNet iteration.
runtime::SizeType32* lastTokenIndices{nullptr};
//! output buffer [1]
//! Single number equals to the total number of select tokens summed over all batches.
runtime::SizeType32* numOutputTokens{nullptr};
//! output buffer [1]
//! Single number equals to the total number of logits to be predicted by the next EagleNet summed over all batches.
//! Number of logits predicted by the next EagleNet iteration.
runtime::SizeType32* numLastTokenIndices{nullptr};
//! input buffer [(maxPathLen - 1) * batchSize + 1]
//! Exclusive sum of the hidden states produced per batch per level.
@@ -205,6 +206,8 @@ struct PrepareGenEagleNetInputsParams
runtime::SizeType32 maxPathLen{0};
//! Max number of draft tokens + 1
runtime::SizeType32 maxDecodingTokens{0};
//! Maximum number of non-leaf nodes per layer
runtime::SizeType32 maxNonLeavesPerLayer{0};
cudaStream_t stream;

void checkParams()
@@ -218,7 +221,6 @@ struct PrepareGenEagleNetInputsParams
TLLM_CHECK(specDecodingPackedMasks);
TLLM_CHECK(hiddenStatesIndices);
TLLM_CHECK(lastTokenIndices);
TLLM_CHECK(numOutputTokens);
TLLM_CHECK(numLastTokenIndices);
TLLM_CHECK(outputHiddenSizeBatchStartsPerLevel);

@@ -242,6 +244,7 @@ struct PrepareGenEagleNetInputsParams
TLLM_CHECK(maxPathLen > 0);
TLLM_CHECK(maxDecodingTokens > 0);
TLLM_CHECK(0 < levelIdx && levelIdx < maxPathLen - 1);
TLLM_CHECK(maxNonLeavesPerLayer > 0);
}
};

@@ -511,8 +514,8 @@ void invokeGetPackedMaskFromPath(int32_t* specDecodingPackedMasks, runtime::Size
//! \param stream cuda stream.
void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeType32* topKs,
runtime::SizeType32* topKOffset, runtime::SizeType32* numSuccessorsForEachNode, runtime::SizeType32 layerId,
runtime::SizeType32 batchSize, runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingTokens,
runtime::SizeType32 maxPathLen, cudaStream_t stream);
runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen,
cudaStream_t stream);

//! \brief Copy the output draft token from input buffer (generated from previous EagleNets)
//! and new draft tokens generated by this layers to the output buffer of this plugin
@@ -526,6 +529,7 @@ void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeT
//! which contains draft tokens generated by previous EagleNets.
//! \param pluginInputDraftLens [batchSize], on GPU. The
//! plugin's input buffer, which contains the draft length from previous EagleNets.
//! \param numValidLogits [1], on GPU. The number of valid logits.
//! \param pluginOutputDraftIdsPtrs [batchSize * maxDecodingDraftTokens], on GPU. The plugin's output buffer,
//! which will contains all the draft tokens generated by this and previous EagleNets.
//! \param pluginOutputDraftLens [batchSize], on GPU. The plugin's input buffer,
@@ -533,13 +537,13 @@ void invokeExtractTopKsFromPath(runtime::SizeType32 const* paths, runtime::SizeT
//! \param layerId SizeType32. The layerId of the EagleNet. Will
//! be used to traverse a specific level of the tree.
//! \param batchSize SizeType32. Batch size.
//! \param numInputLogits SizeType32. Number of logits from all the requests.
//! \param maxDecodingDraftTokens maximum number of decoding draft tokens per step per request.
//! \param stream cuda stream.
void invokeCopyOutputTokensIds(runtime::TokenIdType** tmpOutputIdsPtrs, runtime::SizeType32 const* topKs,
runtime::SizeType32 const* topKOffset, runtime::TokenIdType const* pluginInputDraftIdsPtrs,
runtime::SizeType32 const* pluginInputDraftLens, runtime::TokenIdType* pluginOutputDraftIdsPtrs,
runtime::SizeType32* pluginOutputDraftLens, runtime::SizeType32 layerId, runtime::SizeType32 batchSize,
runtime::SizeType32 numInputLogits, runtime::SizeType32 maxDecodingDraftTokens, cudaStream_t stream);
runtime::SizeType32 const* pluginInputDraftLens, runtime::SizeType32 const* numValidLogits,
runtime::TokenIdType* pluginOutputDraftIdsPtrs, runtime::SizeType32* pluginOutputDraftLens,
runtime::SizeType32 layerId, runtime::SizeType32 batchSize, runtime::SizeType32 maxDecodingDraftTokens,
cudaStream_t stream);

} // namespace tensorrt_llm::kernels::speculative_decoding

@@ -1411,6 +1411,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T* q_buf, T* k_buf, T* v_buf,
break;
}
case PositionEmbeddingType::kLONG_ROPE:
case PositionEmbeddingType::kROPE_M:
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
bool const do_rotary = !is_masked && vec_size * tidx < rotary_embedding_dim;
@@ -1531,6 +1532,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf, T* k_buf, T* v_buf, T* QKV, T cons
dim3 grid(token_num, std::max(head_num, 1));
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| position_embedding_type == PositionEmbeddingType::kROPE_M
? 2 * rotary_embedding_dim * sizeof(T)
: 0);
// NOTE: add offset for rotary embedding
@@ -1923,6 +1925,7 @@ __global__ void shiftKCache(KVCacheBuffer kvCacheBuffer, KVLinearBuffer shiftKCa
break;
}
case PositionEmbeddingType::kLONG_ROPE:
case PositionEmbeddingType::kROPE_M:
case PositionEmbeddingType::kROPE_GPT_NEOX:
{
bool const do_rotary = vec_size * tidx < rotary_embedding_dim;
@@ -1983,6 +1986,7 @@ void invokeShiftKCache(KVCacheBuffer const& kvCacheBuffer, KVLinearBuffer const&
dim3 grid(token_num_in_k, kv_head_num, batch_beam);
size_t smem_size = (position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX
|| position_embedding_type == PositionEmbeddingType::kLONG_ROPE
|| position_embedding_type == PositionEmbeddingType::kROPE_M
? 2 * rotary_embedding_dim * sizeof(T)
: 0);

@@ -106,6 +106,9 @@ struct QKVPreprocessingParams
float const* kvScaleOrigQuant{nullptr};
int const* spec_decoding_position_offsets{nullptr};

float2 const* mrope_rotary_sin_cos{nullptr};
int32_t const* mrope_position_deltas{nullptr};

// Scalars.
int batch_size{0};
int max_input_seq_len{0};

@@ -383,10 +383,12 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff

// NOTE: only spec decoding needs the position offsets.
// In the generation phase, we assume all sequences should have the same input length.
int const rotary_position = params.spec_decoding_position_offsets != nullptr
? (params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
+ cache_seq_len - actual_seq_len)
: token_idx_in_seq;
int const rotary_position
= (params.spec_decoding_position_offsets != nullptr ? (
params.spec_decoding_position_offsets[local_token_idx + batch_idx * params.max_input_seq_len]
+ cache_seq_len - actual_seq_len)
: token_idx_in_seq)
+ (params.mrope_position_deltas != nullptr ? params.mrope_position_deltas[batch_idx] : 0);

if (!valid_token)
{
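The reworked rotary_position expression reads more easily as two steps: pick a base position (speculative-decoding offset if provided, else the token's index in its sequence), then shift it by the per-sequence mRoPE delta when one exists. A condensed restatement with an invented helper:

    int rotaryPosition(bool hasSpecOffsets, int specAdjustedPos, int tokenIdxInSeq, bool hasMropeDelta, int mropeDelta)
    {
        int const base = hasSpecOffsets ? specAdjustedPos : tokenIdxInSeq;
        return base + (hasMropeDelta ? mropeDelta : 0);
    }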
@@ -753,8 +755,18 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
}

// Cos/sin cache.
[[maybe_unused]] float2 const* rotary_coef_cache_buffer
= params.rotary_coef_cache_buffer + static_cast<size_t>(rotary_position) * params.half_rotary_dim;
[[maybe_unused]] float2 const* rotary_coef_cache_buffer = nullptr;
if (params.mrope_rotary_sin_cos != nullptr)
{
rotary_coef_cache_buffer = params.mrope_rotary_sin_cos + batch_idx * params.rotary_embedding_max_positions
+ static_cast<size_t>(rotary_position) * params.half_rotary_dim;
}
else
{
rotary_coef_cache_buffer
= params.rotary_coef_cache_buffer + static_cast<size_t>(rotary_position) * params.half_rotary_dim;
}

if constexpr (ROTARY_TYPE == RotaryPositionEmbeddingType::GPT_NEOX)
{
rotary_coef_cache_buffer += gptneox_rotary_dim_idx;
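With mRoPE each sequence carries its own sin/cos table, so the batch index selects the table before the rotary position indexes into it; otherwise one shared cache is indexed by position alone. A plain C++ sketch of that selection (float stands in for the kernel's float2; the helper name is invented):

    #include <cstddef>

    float const* selectRotaryCache(float const* mropeSinCos, float const* sharedCache, std::size_t batchIdx,
        std::size_t maxPositions, std::size_t rotaryPosition, std::size_t halfRotaryDim)
    {
        if (mropeSinCos != nullptr)
        {
            // Per-sequence table, then position within it.
            return mropeSinCos + batchIdx * maxPositions + rotaryPosition * halfRotaryDim;
        }
        // Single shared cache indexed by position only.
        return sharedCache + rotaryPosition * halfRotaryDim;
    }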
@@ -892,7 +904,8 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
grid.z = std::min(int(divUp(params.multi_processor_count * WARPS_PER_SM, grid.x * grid.y)), \
int(divUp(params.batch_size, MIN_SEQUENCES_PER_WARP))); \
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE \
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M) \
{ \
applyBiasRopeUpdateKVCache<T, TCache, Dh_MAX, ADD_BIAS, STORE_QKV, KVCacheBuffer, \
RotaryPositionEmbeddingType::GPT_NEOX, DYNAMIC_ROTARY_SCALING, FP8_OUTPUT> \
@@ -946,7 +959,8 @@ void kernelDispatchHeadSize(QKVPreprocessingParams<T, KVCacheBuffer> params, cud
constexpr int VEC_SIZE = Rotary_vec_t<T, Dh_MAX>::size;
// Make sure we have multiple of paired vectors so that the access is aligned.
TLLM_CHECK_WITH_INFO((params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE
&& params.position_embedding_type == PositionEmbeddingType::kROPE_M)
|| params.half_rotary_dim % VEC_SIZE == 0,
"Rotary dim size is not supported.");

@@ -1002,7 +1016,8 @@ void kernelV1Dispatch(QKVPreprocessingParams<T, KVCacheBuffer> params, cudaStrea
dim3 block(BLOCK_SIZE); \
dim3 grid(int(divUp(params.max_input_seq_len, tokens_per_cuda_block)), params.batch_size, params.head_num); \
if (params.position_embedding_type == PositionEmbeddingType::kROPE_GPT_NEOX \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE) \
|| params.position_embedding_type == PositionEmbeddingType::kLONG_ROPE \
|| params.position_embedding_type == PositionEmbeddingType::kROPE_M) \
{ \
applyBiasRopeUpdateKVCacheV2<T, TCache, BLOCK_SIZE, Dh, ADD_BIAS, STORE_QKV, FP8_OUTPUT, KVCacheBuffer, \
RotaryPositionEmbeddingType::GPT_NEOX><<<grid, block, 0, stream>>>(params); \
@@ -1279,7 +1294,8 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(QKVPreprocessingParams<T, KVCacheB
bool const has_sink_tokens = params.sink_token_len > 0;
// V2 implementation requires multiple of paired 16 bytes for gpt-neox rotation.
bool const support_rotary_for_v2 = (params.position_embedding_type != PositionEmbeddingType::kROPE_GPT_NEOX
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE)
&& params.position_embedding_type != PositionEmbeddingType::kLONG_ROPE
&& params.position_embedding_type == PositionEmbeddingType::kROPE_M)
|| params.rotary_embedding_dim % 16 == 0;

// Use v2 kernel for absolute_position_embedding.

@@ -37,14 +37,20 @@ namespace tensorrt_llm::layers

template <typename T>
ExternalDraftTokensLayer<T>::ExternalDraftTokensLayer(executor::DecodingMode const& mode,
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager, bool isDeterministic,
bool isAirTopP)
: BaseLayer(decoderDomain, bufferManager)
, mDecodingMode(mode)
, mIsDeterministic(isDeterministic)
, mIsAirTopP(isAirTopP)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "ExternalDraftTokensLayer does not support Beam search mode");

auto const deviceId = getDevice();
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&mDeviceProp, deviceId));

allocateBuffer(decoderDomain.getBatchSize());

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@@ -58,11 +64,15 @@ void ExternalDraftTokensLayer<T>::allocateBuffer(SizeType32 batchSize)
// top k workspace size
auto workspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// top p workspace size
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// multinomial (top p == 1) workspace size
workspaceSize = getTopPWorkspaceSize<float>(batchSize, mDecoderDomain.getVocabSizePadded());
// top p and multinomial (top p == 1) workspace size
if (!mIsAirTopP)
{
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
}
else
{
workspaceSize = getAirTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded(), mIsDeterministic);
}
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);

// batchsize here is maxBatchSize
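The allocation logic above keeps one workspace buffer large enough for whichever sampling kernel actually runs: TopK always, then either classic TopP or AirTopP depending on the flag. A reduced model of that sizing pattern; the three query functions are invented placeholders, only the max() accumulation mirrors the diff:

    #include <algorithm>
    #include <cstddef>

    std::size_t topKWorkspaceSize(int batchSize, int vocabSize)
    {
        return static_cast<std::size_t>(batchSize) * vocabSize; // placeholder
    }

    std::size_t topPWorkspaceSize(int batchSize, int vocabSize)
    {
        return static_cast<std::size_t>(batchSize) * vocabSize; // placeholder
    }

    std::size_t airTopPWorkspaceSize(int batchSize, int vocabSize, bool deterministic)
    {
        return static_cast<std::size_t>(batchSize) * vocabSize + (deterministic ? 1024 : 0); // placeholder
    }

    std::size_t sharedWorkspaceSize(int batchSize, int vocabSize, bool isAirTopP, bool isDeterministic)
    {
        std::size_t workspace = topKWorkspaceSize(batchSize, vocabSize);
        std::size_t const topP = isAirTopP ? airTopPWorkspaceSize(batchSize, vocabSize, isDeterministic)
                                           : topPWorkspaceSize(batchSize, vocabSize);
        return std::max(workspace, topP);
    }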
@@ -169,6 +179,20 @@ void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
{topKsPtr, runtimeTopK.front(), bufferCast<SizeType32>(*mRuntimeTopKHost)}, {}, //
skipDecodeTopKHostPtr, skipDecodeTopPHostPtr, batchSlotsHostPtr, false);

if (mIsAirTopP)
{
auto smCnt = mDeviceProp.multiProcessorCount;
if (smCnt <= 0)
{
auto const deviceId = getDevice();
cudaDeviceProp prop{};
TLLM_CUDA_CHECK(cudaGetDeviceProperties(&prop, deviceId));
smCnt = prop.multiProcessorCount;
}
mAirTopPBlockNum
= calcAirTopPBlockNum<T>(batchSize, mDecoderDomain.getVocabSizePadded(), smCnt, mIsDeterministic);
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

@@ -330,7 +354,17 @@ void ExternalDraftTokensLayer<T>::multinomialSampling(std::shared_ptr<BaseDecodi
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();

invokeBatchTopPSampling<T>(params, getStream());
if (!mIsAirTopP)
{
invokeBatchTopPSampling<T>(params, getStream());
}
else
{
params.blockNum = mAirTopPBlockNum;
params.isDeterministic = mIsDeterministic;
invokeBatchAirTopPSampling<T>(params, getStream());
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

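The same dispatch appears in every sampling entry point: the AirTopP path needs a precomputed block count and a determinism flag that the classic kernel ignores. A reduced model with invented types and placeholder kernels:

    struct SamplingParams
    {
        int blockNum{0};
        bool isDeterministic{true};
    };

    void batchTopP(SamplingParams const&) {}    // placeholder for invokeBatchTopPSampling
    void batchAirTopP(SamplingParams const&) {} // placeholder for invokeBatchAirTopPSampling

    void sample(SamplingParams params, bool isAirTopP, int airTopPBlockNum, bool deterministic)
    {
        if (!isAirTopP)
        {
            batchTopP(params);
            return;
        }
        params.blockNum = airTopPBlockNum; // block count derived from the SM count at setup time
        params.isDeterministic = deterministic;
        batchAirTopP(params);
    }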
@@ -423,7 +457,17 @@ void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutput
params.returnAllSelectedTokens = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();

invokeBatchTopPSampling<T>(params, getStream());
if (!mIsAirTopP)
{
invokeBatchTopPSampling<T>(params, getStream());
}
else
{
params.blockNum = mAirTopPBlockNum;
params.isDeterministic = mIsDeterministic;
invokeBatchAirTopPSampling<T>(params, getStream());
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

@@ -35,7 +35,7 @@ public:
using Base = BaseLayer;

ExternalDraftTokensLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
std::shared_ptr<runtime::BufferManager> bufferManager);
std::shared_ptr<runtime::BufferManager> bufferManager, bool isDeterministic = true, bool isAirTopP = true);

void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, TensorConstPtr batchSlots,
std::shared_ptr<BaseSetupParams> const& setupParams,
@@ -74,6 +74,12 @@ private:

TensorPtr mTargetLogits;

// AirTopP
cudaDeviceProp mDeviceProp;
runtime::SizeType32 mAirTopPBlockNum{0};
bool mIsDeterministic{true};
bool mIsAirTopP{false};

private:
void allocateBuffer(runtime::SizeType32 batchSize);
void acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,

@@ -54,11 +54,12 @@ protected:

TensorPtr mSkipDecodeDevice;
TensorPtr mSkipDecodeHost;
runtime::SizeType32 mAirTopPBlockNum{0};
size_t mWorkspaceSize{0};
size_t mSetupWorkspaceSize{0};

// AirTopP
cudaDeviceProp mDeviceProp;
runtime::SizeType32 mAirTopPBlockNum{0};
bool mIsDeterministic{true};
bool mIsAirTopP{false};

@@ -58,7 +58,8 @@ set(PLUGIN_LISTS
lowLatencyGemmPlugin
eaglePlugin
lowLatencyGemmSwigluPlugin
qserveGemmPlugin)
qserveGemmPlugin
cudaStreamPlugin)

foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})

@@ -38,6 +38,7 @@
#include "tensorrt_llm/plugins/ncclPlugin/reduceScatterPlugin.h"
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
#endif // ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.h"
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h"
@@ -234,6 +235,7 @@ extern "C"
static tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator eagleDecodeDraftTokensPluginCreator;
static tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator
eagleSampleAndAcceptDraftTokensPluginCreator;
static tensorrt_llm::plugins::CudaStreamPluginCreator cudaStreamPluginCreator;

static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@@ -268,7 +270,9 @@ extern "C"
creatorPtr(lowLatencyGemmPluginCreator),
creatorPtr(eagleDecodeDraftTokensPluginCreator),
creatorPtr(eagleSampleAndAcceptDraftTokensPluginCreator),
creatorPtr(lowLatencyGemmSwigluPluginCreator) };
creatorPtr(lowLatencyGemmSwigluPluginCreator),
creatorPtr(cudaStreamPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();
}

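Registering the new plugin follows the existing pattern: one static creator object per plugin, collected into a static array whose pointer and size are handed back to TensorRT. A minimal model with invented types, sketching only the shape of that registry:

    #include <array>
    #include <cstddef>

    struct Creator { char const* name; };

    static Creator identityCreator{"Identity"};
    static Creator cudaStreamCreator{"CudaStream"}; // the creator this commit adds

    Creator** getCreators(std::size_t& nbCreators)
    {
        static std::array creators = {&identityCreator, &cudaStreamCreator};
        nbCreators = creators.size();
        return creators.data();
    }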
21
cpp/tensorrt_llm/plugins/cudaStreamPlugin/CMakeLists.txt
Normal file
@@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)
294
cpp/tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.cpp
Normal file
@@ -0,0 +1,294 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "cudaStreamPlugin.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <cassert>

using namespace nvinfer1;
using tensorrt_llm::plugins::CudaStreamPluginCreator;
using tensorrt_llm::plugins::CudaStreamPlugin;

static char const* CUDA_STREAM_PLUGIN_VERSION{"1"};
static char const* CUDA_STREAM_PLUGIN_NAME{"CudaStream"};
PluginFieldCollection CudaStreamPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> CudaStreamPluginCreator::mPluginAttributes;

CudaStreamPlugin::CudaStreamPlugin(int sideStreamId, int nbInputs, nvinfer1::DataType type)
: mSideStreamId(sideStreamId)
, mNbInputs(nbInputs)
, mType(type)
{
init();
}

CudaStreamPlugin::CudaStreamPlugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mSideStreamId);
read(d, mNbInputs);
read(d, mType);

init();

TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT-LLM version to build "
"engine and run engine.",
(int) length, (int) (d - a));
}

CudaStreamPlugin::CudaStreamPlugin(CudaStreamPlugin const& other)
: mSideStreamId(other.mSideStreamId)
, mNbInputs(other.mNbInputs)
, mType(other.mType)
{
init();
}

void CudaStreamPlugin::init()
{
mSideStreamPtr = nullptr;
}

// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* CudaStreamPlugin::clone() const noexcept
{
auto* plugin = new CudaStreamPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}

nvinfer1::DimsExprs CudaStreamPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
assert(outputIndex == 0);
return inputs[outputIndex];
}

bool CudaStreamPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
TLLM_CHECK_WITH_INFO(nbInputs == mNbInputs, "CudaStreamPlugin only accepts mNbInputs inputs");
TLLM_CHECK_WITH_INFO(nbOutputs == 1, "CudaStreamPlugin only accepts 1 output");

auto const& desc = inOut[pos];
if (desc.format != TensorFormat::kLINEAR)
{
return false;
}

if (pos > 0 && pos < nbInputs)
{
return true;
}
return desc.type == mType;
}

void CudaStreamPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}

size_t CudaStreamPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
}

int CudaStreamPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
if (!mSideStreamPtr)
{
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
nvinfer1::pluginInternal::SideStream side_stream{};
mSideStreamPtr = reinterpret_cast<nvinfer1::pluginInternal::SideStream*>(
getPluginRegistry()->acquirePluginResource(resource_name, &side_stream));
}
mSideStreamPtr->waitSideStreamOnMainStream(stream);
size_t count = 1;
for (int i = 0; i < inputDesc[0].dims.nbDims; ++i)
{
count *= inputDesc[0].dims.d[i];
|
||||
}
|
||||
count *= tensorrt_llm::runtime::BufferDataType(inputDesc[0].type).getSize();
|
||||
TLLM_CUDA_CHECK(cudaMemcpyAsync(outputs[0], inputs[0], count, cudaMemcpyDeviceToDevice, stream));
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType CudaStreamPlugin::getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_CHECK(index == 0);
|
||||
return mType;
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
char const* CudaStreamPlugin::getPluginType() const noexcept
|
||||
{
|
||||
return CUDA_STREAM_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* CudaStreamPlugin::getPluginVersion() const noexcept
|
||||
{
|
||||
return CUDA_STREAM_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int CudaStreamPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
int CudaStreamPlugin::initialize() noexcept
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void CudaStreamPlugin::terminate() noexcept
|
||||
{
|
||||
if (mSideStreamPtr)
|
||||
{
|
||||
auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
|
||||
getPluginRegistry()->releasePluginResource(resource_name);
|
||||
mSideStreamPtr = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
size_t CudaStreamPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mSideStreamId) + sizeof(mNbInputs) + sizeof(mType);
|
||||
}
|
||||
|
||||
void CudaStreamPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mSideStreamId);
|
||||
write(d, mNbInputs);
|
||||
write(d, mType);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
void CudaStreamPlugin::destroy() noexcept
|
||||
{
|
||||
delete this;
|
||||
}
|
||||
|
||||
///////////////
|
||||
|
||||
CudaStreamPluginCreator::CudaStreamPluginCreator()
|
||||
{
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("side_stream_id", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("num_inputs", nullptr, PluginFieldType::kINT32, 0));
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
char const* CudaStreamPluginCreator::getPluginName() const noexcept
|
||||
{
|
||||
return CUDA_STREAM_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* CudaStreamPluginCreator::getPluginVersion() const noexcept
|
||||
{
|
||||
return CUDA_STREAM_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
PluginFieldCollection const* CudaStreamPluginCreator::getFieldNames() noexcept
|
||||
{
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2* CudaStreamPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
int sideStreamId;
|
||||
int nbInputs;
|
||||
int type;
|
||||
|
||||
// Read configurations from each fields
|
||||
struct MapPair
|
||||
{
|
||||
char const* key;
|
||||
int& field;
|
||||
bool optional = false;
|
||||
bool set = false;
|
||||
};
|
||||
|
||||
std::array input_map{
|
||||
MapPair{"side_stream_id", std::ref(sideStreamId)},
|
||||
MapPair{"num_inputs", std::ref(nbInputs)},
|
||||
MapPair{"type_id", std::ref(type)},
|
||||
};
|
||||
bool typeSet = false;
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
char const* attrName = fields[i].name;
|
||||
for (auto& item : input_map)
|
||||
{
|
||||
if (!strcmp(item.key, attrName))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == nvinfer1::PluginFieldType::kINT32);
|
||||
TLLM_CHECK_WITH_INFO(!item.set, "Parameter %s was set twice", item.key);
|
||||
item.field = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
|
||||
item.set = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& item : input_map)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(item.set || item.optional, "Parameter %s is required but not set", item.key);
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto* obj = new CudaStreamPlugin(sideStreamId, nbInputs, static_cast<nvinfer1::DataType>(type));
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IPluginV2* CudaStreamPluginCreator::deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept
|
||||
{
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call CudaStreamPlugin::destroy()
|
||||
try
|
||||
{
|
||||
auto* obj = new CudaStreamPlugin(serialData, serialLength);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
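A minimal sketch of instantiating this plugin through the creator above, using the field names it declares (side_stream_id, num_inputs, type_id); the surrounding builder context and the chosen values are illustrative:

    std::vector<nvinfer1::PluginField> fields;
    int32_t sideStreamId = 0;
    int32_t numInputs = 2;
    int32_t typeId = static_cast<int32_t>(nvinfer1::DataType::kHALF);
    fields.emplace_back("side_stream_id", &sideStreamId, nvinfer1::PluginFieldType::kINT32, 1);
    fields.emplace_back("num_inputs", &numInputs, nvinfer1::PluginFieldType::kINT32, 1);
    fields.emplace_back("type_id", &typeId, nvinfer1::PluginFieldType::kINT32, 1);
    nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
    CudaStreamPluginCreator creator;
    nvinfer1::IPluginV2* plugin = creator.createPlugin("cuda_stream", &fc);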
266 cpp/tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.h Normal file
@ -0,0 +1,266 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

#include "NvInferPlugin.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/runtime/cudaMemPool.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <memory>
#include <string>
#include <vector>

namespace nvinfer1
{
namespace pluginInternal
{
class SideWorkspace
{
public:
    SideWorkspace(cudaStream_t stream)
        : mWorkspaceSize{0}
        , mWorkspacePtr{nullptr}
        , mStream{stream}
    {
    }

    ~SideWorkspace()
    {
        if (mWorkspacePtr)
        {
            TLLM_CUDA_CHECK(cudaFreeAsync(mWorkspacePtr, mStream));
        }
    }

    void* get(size_t workspaceSize)
    {
        if (mWorkspacePtr && mWorkspaceSize < workspaceSize)
        {
            TLLM_CUDA_CHECK(cudaFreeAsync(mWorkspacePtr, mStream));
            mWorkspacePtr = nullptr;
        }
        if (!mWorkspacePtr)
        {
            mWorkspaceSize = workspaceSize;
            auto pool_ptr
                = tensorrt_llm::runtime::CudaMemPool::getPrimaryPoolForDevice(tensorrt_llm::common::getDevice());
            TLLM_CUDA_CHECK(cudaMallocFromPoolAsync(&mWorkspacePtr, mWorkspaceSize, pool_ptr->getPool(), mStream));
        }
        return mWorkspacePtr;
    }

private:
    size_t mWorkspaceSize;
    void* mWorkspacePtr;
    cudaStream_t mStream;
};

class SideStream : public IPluginResource
{
public:
    SideStream(bool init = false)
        : mStream{}
        , mMainEvent{}
        , mSideEvent{}
        , mWorkspace{}
        , mInit{init}
    {
        // The object passed to acquirePluginResource should use the default value init=false
        if (init)
        {
            TLLM_CUDA_CHECK(cudaStreamCreate(&mStream));
            TLLM_CUDA_CHECK(cudaEventCreateWithFlags(&mMainEvent, cudaEventDisableTiming));
            TLLM_CUDA_CHECK(cudaEventCreateWithFlags(&mSideEvent, cudaEventDisableTiming));
            mWorkspace = std::make_shared<SideWorkspace>(mStream);
        }
    }

    void free()
    {
        if (mInit)
        {
            mWorkspace = nullptr;
            TLLM_CUDA_CHECK(cudaStreamSynchronize(mStream));
            TLLM_CUDA_CHECK(cudaStreamDestroy(mStream));
            TLLM_CUDA_CHECK(cudaEventDestroy(mMainEvent));
            TLLM_CUDA_CHECK(cudaEventDestroy(mSideEvent));
            mInit = false;
        }
    }

    int32_t release() noexcept override
    {
        try
        {
            free();
        }
        catch (std::exception const& e)
        {
            return -1;
        }
        return 0;
    }

    IPluginResource* clone() noexcept override
    {
        // An object is cloned only when calling acquirePluginResource for the first time for each key
        std::unique_ptr<SideStream> cloned{};
        try
        {
            if (!mInit)
            {
                cloned = std::make_unique<SideStream>(/* init */ true);
            }
            else
            {
                return nullptr;
            }
        }
        catch (std::exception const& e)
        {
            return nullptr;
        }
        return cloned.release();
    }

    ~SideStream() override
    {
        free();
    }

    void* getWorkspacePtr(size_t workspaceSize)
    {
        return mWorkspace->get(workspaceSize);
    }

    cudaStream_t getStream() const
    {
        return mStream;
    }

    void waitMainStreamOnSideStream(cudaStream_t const stream) const
    {
        TLLM_CUDA_CHECK(cudaEventRecord(mMainEvent, stream));
        TLLM_CUDA_CHECK(cudaStreamWaitEvent(mStream, mMainEvent));
    }

    void waitSideStreamOnMainStream(cudaStream_t const stream) const
    {
        TLLM_CUDA_CHECK(cudaEventRecord(mSideEvent, mStream));
        TLLM_CUDA_CHECK(cudaStreamWaitEvent(stream, mSideEvent));
    }

    void stallMainStream(char const* name, cudaStream_t const stream, std::optional<int> delay = std::nullopt) const
    {
        tensorrt_llm::runtime::utils::stallStream(name, stream, delay);
    }

    void stallSideStream(char const* name, std::optional<int> delay = std::nullopt) const
    {
        tensorrt_llm::runtime::utils::stallStream(name, mStream, delay);
    }

    static char const* getResourceKey(int const stream_id)
    {
        // Keep the backing storage alive across the return: returning the c_str()
        // of a function-local std::string would hand callers a dangling pointer.
        static thread_local std::string keyString;
        keyString = "side_stream_" + std::to_string(stream_id);
        return keyString.c_str();
    }

private:
    cudaStream_t mStream;
    cudaEvent_t mMainEvent;
    cudaEvent_t mSideEvent;
    std::shared_ptr<SideWorkspace> mWorkspace;
    bool mInit;
};

} // namespace pluginInternal
} // namespace nvinfer1

namespace tensorrt_llm::plugins
{

class CudaStreamPlugin : public BasePlugin
{
public:
    CudaStreamPlugin(int sideStreamId, int nbInputs, nvinfer1::DataType type);

    CudaStreamPlugin(void const* data, size_t length);

    CudaStreamPlugin(CudaStreamPlugin const&);

    void init();

    ~CudaStreamPlugin() override = default;

    // IPluginV2DynamicExt Methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
    int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // IPluginV2Ext Methods
    nvinfer1::DataType getOutputDataType(
        int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;

    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    int initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;

private:
    const std::string mLayerName;
    int mSideStreamId;
    int mNbInputs;
    nvinfer1::DataType mType;
    nvinfer1::pluginInternal::SideStream* mSideStreamPtr;
};

class CudaStreamPluginCreator : public BaseCreator
{
public:
    CudaStreamPluginCreator();

    char const* getPluginName() const noexcept override;

    char const* getPluginVersion() const noexcept override;

    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;

    nvinfer1::IPluginV2* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;

private:
    static nvinfer1::PluginFieldCollection mFC;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
};

} // namespace tensorrt_llm::plugins
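The SideStream helper above reduces to the classic CUDA two-stream fork/join handshake; a standalone sketch with illustrative stream and event names:

    cudaEvent_t forkEvent, joinEvent;
    cudaEventCreateWithFlags(&forkEvent, cudaEventDisableTiming);
    cudaEventCreateWithFlags(&joinEvent, cudaEventDisableTiming);
    // Fork: the side stream may not start until the main stream reaches this point.
    cudaEventRecord(forkEvent, mainStream);
    cudaStreamWaitEvent(sideStream, forkEvent);
    // ... enqueue side-stream work that consumes the main stream's results ...
    // Join: the main stream resumes only after the side-stream work completes.
    cudaEventRecord(joinEvent, sideStream);
    cudaStreamWaitEvent(mainStream, joinEvent);

waitMainStreamOnSideStream implements the fork half and waitSideStreamOnMainStream the join half, each reusing one persistent event per direction.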
@ -71,7 +71,7 @@ nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    TLLM_CHECK(outputIndex < 2);
    TLLM_CHECK(nbInputs == 5);
    TLLM_CHECK(nbInputs == 6);
    auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[0];
    auto const maxDecodingTokensExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[1];
    auto const maxDecodingDraftTokensExpr
@ -102,7 +102,7 @@ nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
bool EagleDecodeDraftTokensPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    TLLM_CHECK(nbInputs == 5 && nbOutputs == 2);
    TLLM_CHECK(nbInputs == 6 && nbOutputs == 2);
    TLLM_CHECK(pos < nbInputs + nbOutputs);

    if (pos == getIdx(InputIdxEntry::LOGITS)) // logits (input)
@ -164,7 +164,12 @@ size_t EagleDecodeDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensor
    // [batchSize * maxDecodingTokens]
    auto const numSuccessorsForEachNodeSize = batchSize * maxDecodingTokens * sizeof(SizeType32);

    SizeType32 constexpr NUM_BUFFERS{7};
    // 7. Flags for whether to skip decoding. SamplingTopK is done for numInputLogits tokens,
    // but only sum(numValidLogitsPerRequest[:]) of them are valid.
    // [numInputLogits]
    auto const skipDecodeSize = numInputLogits * sizeof(bool);

    SizeType32 constexpr NUM_BUFFERS{8};
    size_t workspaces[NUM_BUFFERS];
    workspaces[0] = draftTokenSamplingWorkspaceSize;
    workspaces[1] = topKsSize;
@ -173,6 +178,7 @@ size_t EagleDecodeDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensor
    workspaces[4] = outputIdsPtrsSize;
    workspaces[5] = outputIdsSize;
    workspaces[6] = numSuccessorsForEachNodeSize;
    workspaces[7] = skipDecodeSize;
    workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
else
@ -221,6 +227,7 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con

    // Plugin inputs
    auto logits = static_cast<T const*>(inputs[getIdx(InputIdxEntry::LOGITS)]);
    auto numValidLogits = static_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::NUM_VALID_LOGITS)]);
    auto randSample = static_cast<float const*>(inputs[getIdx(InputIdxEntry::RAND_SAMPLE)]);
    auto paths = static_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
    auto pluginInputDraftTokenIdsPtrs
@ -269,12 +276,16 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con
    SizeType32* numSuccessorsForEachNode = reinterpret_cast<SizeType32*>(
        tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(SizeType32)));

    // Workspace 7: skip decoding mask [numInputLogits]
    bool* skipDecode
        = reinterpret_cast<bool*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, numInputLogits * sizeof(bool)));

    // Fill logitsPtrs from logits, fill outputIdsPtrs from outputIdsFlatten and fill decodingTokens
    invokeAssembleDraftLogitsOffsets(logitsPtrs, logits, outputIdsPtrs, outputIdsFlatten, numInputLogits,
        maxDecodingDraftTokens, vocabSizePadded, stream);
    invokeAssembleDraftLogitsOffsets(logitsPtrs, logits, outputIdsPtrs, outputIdsFlatten, skipDecode, numValidLogits,
        numInputLogits, batchSize, maxDecodingDraftTokens, vocabSizePadded, stream);
    sync_check_cuda_error();

    invokeExtractTopKsFromPath(paths, topKs, topKOffset, numSuccessorsForEachNode, mLayerIdx, batchSize, numInputLogits,
    invokeExtractTopKsFromPath(paths, topKs, topKOffset, numSuccessorsForEachNode, mLayerIdx, batchSize,
        maxDecodingTokens, maxPathLen, stream);
    sync_check_cuda_error();

@ -289,13 +300,14 @@ void EagleDecodeDraftTokensPlugin::doTopKSampling(nvinfer1::PluginTensorDesc con
    params.maxTokensPerStep = 1;
    params.vocabSizePadded = vocabSizePadded;
    params.returnAllSelectedTokens = true;
    params.skipDecode = skipDecode;

    invokeBatchTopKSampling(params, stream);
    sync_check_cuda_error();

    // Copy output token id from outputIdsPtrs to the plugin output buffer
    invokeCopyOutputTokensIds(outputIdsPtrs, topKs, topKOffset, pluginInputDraftTokenIdsPtrs, pluginInputDraftLens,
        pluginOutputDraftTokenIdsPtrs, pluginOutputDraftLens, mLayerIdx, batchSize, numInputLogits,
        numValidLogits, pluginOutputDraftTokenIdsPtrs, pluginOutputDraftLens, mLayerIdx, batchSize,
        maxDecodingDraftTokens, stream);
    sync_check_cuda_error();

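The workspace bookkeeping in this hunk follows a fixed pattern: the same per-buffer sizes are declared once in getWorkspaceSizeType and walked again in the enqueue path to carve aligned slices. A condensed two-buffer sketch (names illustrative; the tc:: helpers are the tensorrt_llm::common utilities already used in this file):

    SizeType32 constexpr NUM_BUFFERS{2};
    size_t workspaces[NUM_BUFFERS];
    workspaces[0] = batchSize * sizeof(SizeType32); // e.g. per-request top-K values
    workspaces[1] = numInputLogits * sizeof(bool);  // e.g. the new skipDecode flags
    size_t const workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);

    // Later, inside enqueue, walking the same sizes in the same order yields the pointers:
    size_t offset{0};
    auto* topKs = reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, workspaces[0]));
    auto* skipDecode = reinterpret_cast<bool*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, workspaces[1]));

Keeping the declaration and the carving in lockstep is what makes adding the new eighth buffer a two-line change on each side.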
@ -70,6 +70,8 @@ private:
    RAND_SAMPLE,
    // [batch_size, max_decoding_tokens, max_path_len]
    PATHS,
    // [1]
    NUM_VALID_LOGITS,

    // [batch_size, max_decoding_draft_tokens]
    INPUT_DRAFT_TOKEN_IDS,

@ -36,8 +36,11 @@ static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME{"EaglePrepareDrafter
PluginFieldCollection EaglePrepareDrafterInputsPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EaglePrepareDrafterInputsPluginCreator::mPluginAttributes;

EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(int32_t layerIdx)
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(
    int32_t layerIdx, int32_t numLayers, int32_t maxNonLeavesPerLayer)
    : mLayerIdx(layerIdx)
    , mNumLayers(numLayers)
    , mMaxNonLeavesPerLayer(maxNonLeavesPerLayer)
{
}

@ -45,6 +48,9 @@ void EaglePrepareDrafterInputsPlugin::initFieldsToSerialize()
{
    mDataToSerialize.clear();
    mDataToSerialize.emplace_back(PluginField("layer_idx", &mLayerIdx, PluginFieldType::kINT32, 1));
    mDataToSerialize.emplace_back(PluginField("num_layers", &mNumLayers, PluginFieldType::kINT32, 1));
    mDataToSerialize.emplace_back(
        PluginField("max_non_leaves_per_layer", &mMaxNonLeavesPerLayer, PluginFieldType::kINT32, 1));
    mFCToSerialize.nbFields = mDataToSerialize.size();
    mFCToSerialize.fields = mDataToSerialize.data();
}
@ -99,7 +105,7 @@ char const* EaglePrepareDrafterInputsPlugin::getPluginNamespace() const noexcept
// IPluginV3OneBuild methods
int32_t EaglePrepareDrafterInputsPlugin::getNbOutputs() const noexcept
{
    return 12;
    return 11;
}

int32_t EaglePrepareDrafterInputsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int32_t nbInputs,
@ -136,11 +142,13 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
    nvinfer1::DimsExprs const* shapeInputs, int32_t nbShapeInputs, nvinfer1::DimsExprs* outputs, int32_t nbOutputs,
    nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    TLLM_CHECK(nbOutputs == 12);
    TLLM_CHECK(nbInputs == 12);
    TLLM_CHECK(nbOutputs == 11);
    TLLM_CHECK(nbInputs == 14);
    TLLM_CHECK(nbShapeInputs == 0);
    auto const numTokens = inputs[getIdx(InputIdxEntry::INPUT_IDS)].d[0];
    auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[0];
    auto const numGenRequestsExpr = inputs[getIdx(InputIdxEntry::SPEC_DECODING_GENERATION_LENGTHS)].d[0];
    auto const numInputGenTokensExpr = inputs[getIdx(InputIdxEntry::INPUT_GEN_TOKENS)].d[0];
    auto const maxDecodingLenExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[1];
    auto const maxPathLenExpr = inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)].d[2];

@ -171,14 +179,29 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
        || outputIndex == getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)
        || (mLayerIdx == 0 && outputIndex == getIdx(OutputIdxEntry::POSITION_IDS)))
    {
        auto optValue = exprBuilder.operation(DimensionOperation::kPROD, *maxDecodingLenExpr, *batchSizeExpr);
        auto upperBound = exprBuilder.operation(DimensionOperation::kSUM, *optValue, *numTokens);

        SizeType32 outSizeIndex = getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS);
        auto outSizeTensor = exprBuilder.declareSizeTensor(outSizeIndex, *optValue, *upperBound);

        outputs[outputIndex].nbDims = 1;
        outputs[outputIndex].d[0] = outSizeTensor;
        if (mLayerIdx == 0)
        {
            // We have at most numGenRequests * (mNumLayers + 1) accepted tokens per step for gen requests and
            // input_ids - numGenTokens tokens for context requests.
            auto numOutputGenTokensExpr = exprBuilder.operation(
                DimensionOperation::kPROD, *numGenRequestsExpr, *exprBuilder.constant(mNumLayers + 1));
            auto numInputCtxTokensExpr
                = exprBuilder.operation(DimensionOperation::kSUB, *numTokens, *numInputGenTokensExpr);
            outputs[outputIndex].nbDims = 1;
            outputs[outputIndex].d[0] = exprBuilder.operation(DimensionOperation::kMAX, *exprBuilder.constant(1),
                *exprBuilder.operation(DimensionOperation::kSUM, *numOutputGenTokensExpr, *numInputCtxTokensExpr));
        }
        else
        {
            // At most we have mMaxNonLeavesPerLayer non-leaves at this layer.
            // And in total we pass all non-leaves + all their preceding nodes.
            // batchSize * mMaxNonLeavesPerLayer * layerIdx
            outputs[outputIndex].nbDims = 1;
            outputs[outputIndex].d[0] = exprBuilder.operation(DimensionOperation::kPROD,
                *exprBuilder.operation(DimensionOperation::kPROD, *exprBuilder.constant(mLayerIdx),
                    *exprBuilder.constant(mMaxNonLeavesPerLayer)),
                *batchSizeExpr);
        }
    }
    else if (mLayerIdx > 0 && outputIndex == getIdx(OutputIdxEntry::POSITION_IDS))
    {
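As a concrete instance of the layer-0 sizing expression above (hypothetical values): with num_tokens = 100, numInputGenTokens = 24, numGenRequests = 8 and mNumLayers = 3, the output holds max(1, (100 - 24) + 8 * (3 + 1)) = 108 entries; for a later layer, say mLayerIdx = 2 with batchSize = 8 and mMaxNonLeavesPerLayer = 4, it holds 2 * 4 * 8 = 64.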
@ -187,20 +210,14 @@ int32_t EaglePrepareDrafterInputsPlugin::getOutputShapes(nvinfer1::DimsExprs con
    }
    else if (outputIndex == getIdx(OutputIdxEntry::LAST_TOKEN_INDICES))
    {
        auto upperBound = exprBuilder.operation(DimensionOperation::kPROD, *maxDecodingLenExpr, *batchSizeExpr);
        auto optValue = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *upperBound, *exprBuilder.constant(2));

        SizeType32 outSizeIndex = getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES);
        auto outSizeTensor = exprBuilder.declareSizeTensor(outSizeIndex, *optValue, *upperBound);

        outputs[outputIndex].nbDims = 1;
        outputs[outputIndex].d[0] = outSizeTensor;
        outputs[outputIndex].d[0] = exprBuilder.operation(
            DimensionOperation::kPROD, *exprBuilder.constant(mMaxNonLeavesPerLayer), *batchSizeExpr);
    }
    else if (outputIndex == getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)
        || outputIndex == getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES))
    else if (outputIndex == getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES))
    {
        // size tensors must be declared as 0-D
        outputs[outputIndex].nbDims = 0;
        outputs[outputIndex].nbDims = 1;
        outputs[outputIndex].d[0] = exprBuilder.constant(1);
    }
    else if (outputIndex == getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS))
    {
@ -265,6 +282,11 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputDesc[getIdx(InputIdxEntry::SEQUENCE_LENGTHS)].dims.d[0];

    auto const numTokens = inputDesc[getIdx(InputIdxEntry::INPUT_IDS)].dims.d[0];
    auto const numGenRequests = inputDesc[getIdx(InputIdxEntry::SPEC_DECODING_GENERATION_LENGTHS)].dims.d[0];
    auto const numInputGenTokens = inputDesc[getIdx(InputIdxEntry::INPUT_GEN_TOKENS)].dims.d[0];

    auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::ACCEPTED_TOKENS)].dims.d[1];
    auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::NEXT_DRAFT_PATHS)].dims.d[1];

@ -274,7 +296,6 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
    auto positionIds = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::POSITION_IDS)]);
    auto hiddenStatesIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)]);
    auto lastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::LAST_TOKEN_INDICES)]);
    auto numOutputTokens = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)]);
    auto numLastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES)]);
    auto hiddenSizeBatchLevelStarts
        = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS)]);
@ -288,12 +309,14 @@ void EaglePrepareDrafterInputsPlugin::prepareCtxEagleNetData(nvinfer1::PluginTen
    auto prevPaths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PREV_DRAFT_PATHS)]);
    auto bestPathIds = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::ACCEPTED_PATHS)]);

    invokePrepareCtxEagleNetInputs(eagleNetSequenceLengths, eagleNetContextLengths, outputIds, positionIds,
        hiddenStatesIndices, lastTokenIndices, numOutputTokens, numLastTokenIndices, hiddenSizeBatchLevelStarts,
        inputIds, baseNetSequenceLengths, baseNetContextLengths, acceptedTokens, acceptedLens, prevDraftLens, prevPaths,
        bestPathIds, batchSize, maxPathLen, maxDecodingTokens, stream);
    auto const numOutputTokens = (numTokens - numInputGenTokens) + (numGenRequests * (mNumLayers + 1));
    cudaMemsetAsync(positionIds, 0, numOutputTokens * sizeof(SizeType32), stream);
    cudaMemsetAsync(hiddenStatesIndices, 0, numOutputTokens * sizeof(SizeType32), stream);

    auto const numTokens = inputDesc[getIdx(InputIdxEntry::INPUT_IDS)].dims.d[0];
    invokePrepareCtxEagleNetInputs(eagleNetSequenceLengths, eagleNetContextLengths, outputIds, positionIds,
        hiddenStatesIndices, lastTokenIndices, numLastTokenIndices, hiddenSizeBatchLevelStarts, inputIds,
        baseNetSequenceLengths, baseNetContextLengths, acceptedTokens, acceptedLens, prevDraftLens, prevPaths,
        bestPathIds, batchSize, maxPathLen, maxDecodingTokens, mMaxNonLeavesPerLayer, stream);

    sync_check_cuda_error();

@ -322,7 +345,6 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
        = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::SPEC_DECODING_PACKED_MASK)]);
    auto hiddenStatesIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_STATES_INDICES)]);
    auto lastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::LAST_TOKEN_INDICES)]);
    auto numOutputTokens = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_OUTPUT_TOKENS)]);
    auto numLastTokenIndices = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::NUM_LAST_TOKEN_INDICES)]);
    auto outputHiddenSizeBatchStartsPerLevel
        = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::HIDDEN_SIZE_BATCH_LEVEL_STARTS)]);
@ -357,6 +379,7 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
    SizeType32* maxGenerationLength
        = reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, 1 * sizeof(SizeType32)));

    cudaMemsetAsync(hiddenStatesIndices, 0, batchSize * mMaxNonLeavesPerLayer * mLayerIdx * sizeof(SizeType32), stream);
    cudaMemsetAsync(selectedMasks, 0, batchSize * maxDecodingTokens * maxDecodingTokens * sizeof(int8_t), stream);
    // Prefill mask setting all to leaves.
    cudaMemsetAsync(isLeafMask, 1, batchSize * maxDecodingTokens * sizeof(int8_t), stream);
@ -371,7 +394,6 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
    params.specDecodingPackedMasks = specDecodingPackedMasks;
    params.hiddenStatesIndices = hiddenStatesIndices;
    params.lastTokenIndices = lastTokenIndices;
    params.numOutputTokens = numOutputTokens;
    params.numLastTokenIndices = numLastTokenIndices;
    params.outputHiddenSizeBatchStartsPerLevel = outputHiddenSizeBatchStartsPerLevel;

@ -395,6 +417,7 @@ void EaglePrepareDrafterInputsPlugin::prepareGenEagleNetData(nvinfer1::PluginTen
    params.batchSize = batchSize;
    params.maxPathLen = maxPathLen;
    params.maxDecodingTokens = maxDecodingTokens;
    params.maxNonLeavesPerLayer = mMaxNonLeavesPerLayer;
    params.stream = stream;

    params.checkParams();
@ -459,7 +482,9 @@ EaglePrepareDrafterInputsPluginCreator::EaglePrepareDrafterInputsPluginCreator()
{
    // Fill PluginFieldCollection with PluginField arguments metadata
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("num_layers", nullptr, PluginFieldType::kINT32, 1));
    mPluginAttributes.emplace_back(PluginField("max_non_leaves_per_layer", nullptr, PluginFieldType::kINT32, 1));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
@ -485,6 +510,8 @@ nvinfer1::IPluginV3* EaglePrepareDrafterInputsPluginCreator::createPlugin(
    try
    {
        int32_t layerIdx{0};
        int32_t numLayers{0};
        int32_t maxNonLeavesPerLayer{0};
        // Read the configuration from each field
        for (int i = 0; i < fc->nbFields; ++i)
        {
@ -494,8 +521,18 @@ nvinfer1::IPluginV3* EaglePrepareDrafterInputsPluginCreator::createPlugin(
                TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
                layerIdx = *static_cast<int32_t const*>(fc->fields[i].data);
            }
            else if (!strcmp(attrName, "num_layers"))
            {
                TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
                numLayers = *static_cast<int32_t const*>(fc->fields[i].data);
            }
            else if (!strcmp(attrName, "max_non_leaves_per_layer"))
            {
                TLLM_CHECK(fc->fields[i].type == PluginFieldType::kINT32);
                maxNonLeavesPerLayer = *static_cast<int32_t const*>(fc->fields[i].data);
            }
        }
        return new EaglePrepareDrafterInputsPlugin(layerIdx);
        return new EaglePrepareDrafterInputsPlugin(layerIdx, numLayers, maxNonLeavesPerLayer);
    }
    catch (std::exception const& e)
    {

@ -33,7 +33,7 @@ class EaglePrepareDrafterInputsPlugin : public nvinfer1::IPluginV3,
public:
    EaglePrepareDrafterInputsPlugin(EaglePrepareDrafterInputsPlugin const& p) = default;

    EaglePrepareDrafterInputsPlugin(int32_t layerIdx);
    EaglePrepareDrafterInputsPlugin(int32_t layerIdx, int32_t numLayers, int32_t maxNonLeavesPerLayer);

    nvinfer1::IPluginV3* clone() noexcept override;

@ -98,6 +98,10 @@ private:
    PREV_DRAFT_PATHS,
    //! [(max_path_len - 1) * batch_size + 1]
    HIDDEN_SIZE_BATCH_LEVEL_STARTS,
    //! [num_gen_tokens]
    INPUT_GEN_TOKENS,
    //! [num_gen_requests]
    SPEC_DECODING_GENERATION_LENGTHS,
};

enum class OutputIdxEntry : int32_t
@ -112,17 +116,18 @@ private:
    SPEC_DECODING_POSITION_OFFSETS,
    //! [batchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
    SPEC_DECODING_PACKED_MASK,
    //! [NUM_OUTPUT_TOKENS]
    //! [batchSize * mMaxNonLeavesPerLayer * layerIdx] for layerIdx > 0
    //! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
    OUTPUT_IDS,
    //! [NUM_OUTPUT_TOKENS]
    //! [batchSize] for layerIdx > 0
    //! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
    POSITION_IDS,
    //! [NUM_OUTPUT_TOKENS]
    //! [batchSize * mMaxNonLeavesPerLayer * layerIdx] for layerIdx > 0
    //! [num_tokens - numGenTokens + numGenRequests * (mNumLayers + 1)] for layerIdx == 0
    HIDDEN_STATES_INDICES,
    //! [NUM_LAST_TOKEN_INDICES]
    //! [batchSize * mMaxNonLeavesPerLayer]
    LAST_TOKEN_INDICES,
    //! [1]
    NUM_OUTPUT_TOKENS,
    //! [1]
    NUM_LAST_TOKEN_INDICES,
    //! [(max_path_len - 1) * batch_size + 1]
    HIDDEN_SIZE_BATCH_LEVEL_STARTS,
@ -148,7 +153,9 @@ private:
        cudaStream_t stream) noexcept;

private:
    int32_t mLayerIdx;
    int32_t mLayerIdx{0};
    int32_t mNumLayers{0};
    int32_t mMaxNonLeavesPerLayer{0};
    std::vector<nvinfer1::PluginField> mDataToSerialize;
    nvinfer1::PluginFieldCollection mFCToSerialize;
};

@ -116,6 +116,7 @@ struct FusedQKVMaskedAttentionDispatchParams
    int max_distance = 0;
    bool block_sparse_attention = false;
    BlockSparseParams block_sparse_params;
    int32_t const* mrope_position_deltas;
};

template <typename T, typename KVCacheBuffer>
@ -251,6 +252,9 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
        = generationsParams.spec_decoding_is_generation_length_variable;
    xqaParams.spec_decoding_max_generation_length = generationsParams.spec_decoding_max_generation_length;

    xqaParams.mrope_rotary_sin_cos = generationsParams.mrope_rotary_sin_cos;
    xqaParams.mrope_position_deltas = generationsParams.mrope_position_deltas;

    xqaParams.total_num_input_tokens = generationsParams.total_num_input_tokens;
    xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr);
    return true;
@ -376,6 +380,8 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS

    // cross attn
    params.memory_length_per_sample = input_params.memory_length_per_sample;

    params.mrope_position_deltas = input_params.mrope_position_deltas;
    sync_check_cuda_error();

    masked_multihead_attention(params, input_params.kv_block_array, input_params.shift_k_cache_buffer, stream);
@ -1248,6 +1254,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T> const& para
    preprocessingParams.cu_kv_seq_lens = cu_kv_seqlens;
    preprocessingParams.rotary_embedding_inv_freq = rotary_inv_freq_buf;
    preprocessingParams.rotary_coef_cache_buffer = params.rotary_cos_sin;
    preprocessingParams.mrope_rotary_sin_cos = params.mrope_rotary_sin_cos;
    preprocessingParams.mrope_position_deltas = params.mrope_position_deltas;
    preprocessingParams.kvScaleOrigQuant = params.kv_scale_orig_quant;
    preprocessingParams.spec_decoding_position_offsets = nullptr;

@ -1861,6 +1869,7 @@ int GPTAttentionPluginCommon::enqueueGeneration(EnqueueGenerationParams<T> const
    dispatch_params.memory_length_per_sample = params.encoder_input_lengths;
    dispatch_params.block_sparse_attention = mMaskType == AttentionMaskType::BLOCKSPARSE;
    dispatch_params.block_sparse_params = mBlockSparseParams;
    dispatch_params.mrope_position_deltas = params.mrope_position_deltas;

    using DataType = typename SATypeConverter<T>::Type;
    if (!isCrossAttention())

@ -135,6 +135,9 @@ protected:
    int32_t max_blocks_per_sequence;
    int32_t const* host_context_lengths;
    void* workspace;
    float2 const* mrope_rotary_sin_cos = nullptr;
    int32_t const* mrope_position_deltas = nullptr;

    // optional when relative position
    T const* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
@ -237,6 +240,9 @@ protected:
    int32_t* semaphores;
    void* workspace;
    int32_t const* host_past_key_value_lengths;
    float2 const* mrope_rotary_sin_cos = nullptr;
    int32_t const* mrope_position_deltas = nullptr;

    // optional when relative position
    T const* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
@ -293,7 +299,8 @@ protected:
    {
        return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPTJ
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_GPT_NEOX
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE;
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kLONG_ROPE
            || mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_M;
    }

    bool isLongRoPE() const
@ -306,6 +313,11 @@ protected:
        return !mEnableContextFMHA && mCrossAttention;
    }

    bool isMRoPE() const
    {
        return mPositionEmbeddingType == tensorrt_llm::kernels::PositionEmbeddingType::kROPE_M;
    }

    bool isCrossAttention() const
    {
        return mCrossAttention;
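The new kROPE_M value is what gates the two extra mRoPE inputs wired in below: isMRoPE() feeds isEntryUsed, and unused entries are skipped when input indices are assigned. A simplified sketch of that dense-indexing convention (the real getIdx may be implemented differently):

    int32_t getIdx(IdxEntry const& entry) const
    {
        int32_t idx = 0;
        for (int32_t e = 0; e < static_cast<int32_t>(entry); ++e)
        {
            idx += isEntryUsed(static_cast<IdxEntry>(e)) ? 1 : 0;
        }
        return idx;
    }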
@ -162,6 +162,8 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
    case IdxEntry::SPEC_DECODING_GENERATION_LENGTHS: return mIsSpecDecodingEnabled;
    case IdxEntry::SPEC_DECODING_PACKED_MASK: return mIsSpecDecodingEnabled;
    case IdxEntry::SPEC_DECODING_POSITION_OFFSETS: return mIsSpecDecodingEnabled;
    case IdxEntry::MROPE_ROTARY_SIN_COS: return isMRoPE();
    case IdxEntry::MROPE_POSITION_DELTAS: return isMRoPE();
    case IdxEntry::HOST_RUNTIME_PERF_KNOBS: return true;
    case IdxEntry::HOST_CONTEXT_PROGRESS: return true;
    case IdxEntry::MLA_FUSED_Q_PROJ_TENSOR: return mIsMLAEnabled;
@ -266,6 +268,14 @@ bool GPTAttentionPlugin::supportsFormatCombination(
        posCaseLine = __LINE__;
        result = inOut[pos].type == nvinfer1::DataType::kINT32;
    }
    else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_ROTARY_SIN_COS)))
    {
        return inOut[pos].type == nvinfer1::DataType::kFLOAT;
    }
    else if (isMRoPE() && (pos == getIdx(IdxEntry::MROPE_POSITION_DELTAS)))
    {
        return inOut[pos].type == nvinfer1::DataType::kINT32;
    }
    else if (pos == getIdx(IdxEntry::HOST_RUNTIME_PERF_KNOBS) || pos == getIdx(IdxEntry::HOST_CONTEXT_PROGRESS))
    {
        posCaseLine = __LINE__;
@ -411,13 +421,15 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
    int const num_requests = 256;
    int const sink_token_length = 0;

    EnqueueGenerationParams<T> enqueueParams{/*attention_input=*/nullptr,
    EnqueueGenerationParams<T> enqueueParams{
        /*attention_input=*/nullptr,
        /*qkv_bias=*/nullptr,
        /*attention_mask*/ nullptr,
        /*rotary_inv_freq*/ nullptr,
        /*input_seq_length=*/0,
        /*sequence_lengths=*/nullptr,
        /*past_kv_length=*/0, beamWidth,
        /*past_kv_length=*/0,
        beamWidth,
        /*context_lengths=*/nullptr,
        /*kv_scale_orig_quant=*/nullptr,
        /*kv_scale_quant_orig=*/nullptr,
@ -428,12 +440,19 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
        /*block_offsets=*/nullptr,
        /*host_primary_pool_pointer=*/nullptr,
        /*host_secondary_pool_pointer=*/nullptr,
        /*attention_mask_stride*/ 0, max_attention_window_size, cyclic_attention_window_size, sink_token_length,
        /*attention_mask_stride*/ 0,
        max_attention_window_size,
        cyclic_attention_window_size,
        sink_token_length,
        num_requests,
        /*max_blocks_per_sequence=*/0,
        /*cache_indir=*/nullptr,
        /*workspace=*/nullptr,
        /*max_context_kv_len_list=*/nullptr};
        /*max_context_kv_len_list=*/nullptr,
        /*mrope_rotary_sin_cos*/ nullptr,
        /*mrope_position_deltas*/ nullptr,

    };

    prepareEnqueueGeneration<T, KVCacheBuffer>(enqueueParams);

@ -699,6 +718,12 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
        rotary_cos_sin = reinterpret_cast<float2 const*>(inputs[getIdx(IdxEntry::ROTARY_COS_SIN)]);
    }

    auto const mrope_rotary_sin_cos
        = isMRoPE() ? reinterpret_cast<float2 const*>(inputs[getIdx(IdxEntry::MROPE_ROTARY_SIN_COS)]) : nullptr;

    auto const mrope_position_deltas
        = isMRoPE() ? reinterpret_cast<int32_t const*>(inputs[getIdx(IdxEntry::MROPE_POSITION_DELTAS)]) : nullptr;

    if (mUnfuseQkvGemm)
    {
        int const max_seqlen = inputDesc[getIdx(IdxEntry::QKV_TENSOR)].dims.d[mRemovePadding ? 0 : 1];
@ -932,7 +957,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
        cyclic_attention_window_size, sink_token_length, context_q_lengths, sequence_kv_length, kv_scale_orig_quant,
        kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
        block_offsets, host_block_offsets, host_primary_pool_pointer, host_secondary_pool_pointer, batch_size,
        localNbTokens, max_blocks_per_sequence, host_context_lengths, workspace};
        localNbTokens, max_blocks_per_sequence, host_context_lengths, workspace, mrope_rotary_sin_cos,
        mrope_position_deltas};

    enqueue_params.runtime_perf_knobs = runtime_perf_knobs;
    if (isRelativePosition())
    {
@ -1007,7 +1034,8 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
        kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
        block_offsets, host_primary_pool_pointer, host_secondary_pool_pointer, attention_mask_stride,
        max_attention_window_size, cyclic_attention_window_size, sink_token_length, num_requests,
        max_blocks_per_sequence, cache_indir, mMultiBlockSemaphores.get(), workspace, max_context_kv_len_list};
        max_blocks_per_sequence, cache_indir, mMultiBlockSemaphores.get(), workspace, max_context_kv_len_list,
        mrope_rotary_sin_cos, mrope_position_deltas};
    enqueue_params.host_context_lengths = host_context_lengths;
    enqueue_params.runtime_perf_knobs = runtime_perf_knobs;
    if (isRelativePosition())

@ -220,6 +220,8 @@ private:
    SPEC_DECODING_GENERATION_LENGTHS,
    SPEC_DECODING_PACKED_MASK,
    SPEC_DECODING_POSITION_OFFSETS,
    MROPE_ROTARY_SIN_COS,
    MROPE_POSITION_DELTAS,
    HOST_RUNTIME_PERF_KNOBS,
    HOST_CONTEXT_PROGRESS,
    MLA_FUSED_Q_PROJ_TENSOR,

@ -18,6 +18,8 @@
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/utils/debugUtils.h"
#include <numeric>

using namespace nvinfer1;
@ -42,8 +44,8 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
    nvinfer1::DataType type, nvinfer1::DataType weight_type, nvinfer1::DataType output_type, QuantMode quant_mode,
    bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
    MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
    MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
    LoraPluginProfilerPtr lora_profiler, int max_low_rank)
    int side_stream_id, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora,
    nvinfer1::DataType lora_type, LoraPluginProfilerPtr lora_profiler, int max_low_rank)
    : mRemoveInputPadding(remove_input_padding)
    , mNumExperts(number_of_experts)
    , mK(top_k)
@ -60,6 +62,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(bool remove_input_padding, int nu
    , mNormalizationMode(normalization_mode)
    , mSparseMixerEpsilon(sparse_mixer_epsilon)
    , mUseDeterministicKernels(force_determinism)
    , mSideStreamId(side_stream_id)
    , mGemmProfiler(std::move(gemm_profiler_ptr))
    , mUseLora(use_lora)
    , mLoraType(lora_type)
@ -90,6 +93,7 @@ tensorrt_llm::plugins::MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(MixtureOfE
    , mGemmId2(other.mGemmId2)
    , mSparseMixerEpsilon(other.mSparseMixerEpsilon)
    , mUseDeterministicKernels(other.mUseDeterministicKernels)
    , mSideStreamId(other.mSideStreamId)
    , mGemmProfiler(other.mGemmProfiler)
    , mUseLora(other.mUseLora)
    , mLoraType(other.mLoraType)
@ -111,8 +115,8 @@ size_t MixtureOfExpertsPlugin::getSerializationSize() const noexcept
        + sizeof(mExpertInterSize) + sizeof(mActivationType) + sizeof(mType) + sizeof(mWeightType) + sizeof(mOutputType)
        + sizeof(QuantMode::BaseType) + sizeof(mUseFinished) + sizeof(mUseBias) + sizeof(mParallelismConfig)
        + sizeof(mNormalizationMode) + sizeof(mSparseMixerEpsilon) + sizeof(mDims) + sizeof(mUseDeterministicKernels)
        + mGemmProfiler->getSerializationSize(mGemmId1) + mGemmProfiler->getSerializationSize(mGemmId2)
        + sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);
        + sizeof(mSideStreamId) + mGemmProfiler->getSerializationSize(mGemmId1)
        + mGemmProfiler->getSerializationSize(mGemmId2) + sizeof(mUseLora) + sizeof(mLoraType) + sizeof(mMaxLowRank);

    if (hasLora())
    {
@ -149,6 +153,7 @@ MixtureOfExpertsPlugin::MixtureOfExpertsPlugin(void const* data, size_t length,
    read(d, mSparseMixerEpsilon);
    read(d, mDims);
    read(d, mUseDeterministicKernels);
    read(d, mSideStreamId);
    read(d, mUseLora);
    read(d, mLoraType);
    read(d, mMaxLowRank);
@ -193,6 +198,7 @@ void MixtureOfExpertsPlugin::serialize(void* buffer) const noexcept
    write(d, mSparseMixerEpsilon);
    write(d, mDims);
    write(d, mUseDeterministicKernels);
    write(d, mSideStreamId);
    write(d, mUseLora);
    write(d, mLoraType);
    write(d, mMaxLowRank);
@ -308,6 +314,9 @@ void MixtureOfExpertsPlugin::init()

        TLLM_CUDA_CHECK(cudaEventCreate(&mMemcpyEvent));
    }
    mSideStreamPtr = nullptr;
    mDebugStallMain = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_MAIN");
    mDebugStallSide = tensorrt_llm::runtime::utils::stallStream("TLLM_DEBUG_MOE_STALL_SIDE");
}

// IPluginV2DynamicExt Methods
@ -321,7 +330,7 @@ nvinfer1::IPluginV2DynamicExt* MixtureOfExpertsPlugin::clone() const noexcept
nvinfer1::DimsExprs MixtureOfExpertsPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    assert(outputIndex == getOutputTensorIndex());
    assert(outputIndex == getOutputTensorIndex() || outputIndex == getOutputDummyTensorIndex());
    return inputs[getInputTensorIndex()];
}

@ -359,6 +368,10 @@ bool MixtureOfExpertsPlugin::supportsFormatCombination(
    {
        return inOut[pos].type == mOutputType;
    }
    else if (useSideStream() && pos == nbInputs + getOutputDummyTensorIndex())
    {
        return inOut[pos].type == mType;
    }
    else if (hasExpertFp8QuantScales() && getExpertFP8Dequant1Index() <= pos && pos <= getExpertFP8QuantFinalIndex())
    {
        return inOut[pos].type == DataType::kFLOAT;
@ -514,6 +527,10 @@ size_t MixtureOfExpertsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const
    TLLM_CHECK_WITH_INFO(nbInputs == getNbInputs(), "Required input to plugin is missing");
    TLLM_CHECK_WITH_INFO(nbOutputs == getNbOutputs(), "Required output to plugin is missing");

    if (useSideStream())
    {
        return 0;
    }
    int const num_tokens = getNumTokens(inputs);
    int const num_lora_reqs = getNumLoraRequests(inputs);
    return setupWorkspace(nullptr, num_tokens, num_lora_reqs).size;
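Note the asymmetry this hunk introduces: with a side stream active, getWorkspaceSize reports 0 to TensorRT, and the buffer is instead allocated lazily from the side stream's memory pool via SideWorkspace. As the enqueue comment below spells out, a TensorRT-managed workspace is only guaranteed to stay valid for work enqueued on the main stream, so the side-stream kernels need storage whose lifetime the side stream itself controls.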
@ -661,6 +678,36 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
    int64_t const num_reqs = getNumLoraRequests(inputDesc);
    int64_t const num_not_finished = num_tokens; // TODO Take this as an input

    if (useSideStream())
    {
        // Prepare the side stream
        if (!mSideStreamPtr)
        {
            auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
            nvinfer1::pluginInternal::SideStream side_stream{};
            mSideStreamPtr = reinterpret_cast<nvinfer1::pluginInternal::SideStream*>(
                getPluginRegistry()->acquirePluginResource(resource_name, &side_stream));
        }
        // Debug the code with the main stream stalled (only executed when the environment variable
        // TLLM_DEBUG_MOE_STALL_MAIN is set and has a positive value)
        mSideStreamPtr->stallMainStream("TLLM_DEBUG_MOE_STALL_MAIN", stream, mDebugStallMain);
        // The side stream waits for the inputs managed by the main stream to be ready
        mSideStreamPtr->waitMainStreamOnSideStream(stream);
        // Provide a data dependency for the shared experts running after this plugin by copying the inputs
        // on the main stream
        size_t count = 1;
        for (int i = 0; i < inputDesc[getInputTensorIndex()].dims.nbDims; ++i)
        {
            count *= inputDesc[getInputTensorIndex()].dims.d[i];
        }
        count *= tensorrt_llm::runtime::BufferDataType(inputDesc[getInputTensorIndex()].type).getSize();
        TLLM_CUDA_CHECK(cudaMemcpyAsync(outputs[getOutputDummyTensorIndex()], inputs[getInputTensorIndex()], count,
            cudaMemcpyDeviceToDevice, stream));
        // Switch from the main stream to the side stream
        stream = mSideStreamPtr->getStream();
        // The workspace is managed by the side stream (otherwise, the lifetime of the workspace may be incorrect)
        auto const workspace_size = setupWorkspace(nullptr, num_tokens, num_reqs).size;
        workspace_ptr = mSideStreamPtr->getWorkspacePtr(workspace_size);
    }
    auto workspace = setupWorkspace(workspace_ptr, num_tokens, num_reqs);

    auto w1_desc = inputDesc[getExpertWeights1Index()];
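Taken together with the CudaStream plugin, the fork/join choreography sketched by this enqueue is roughly:

    // main stream: ...--[copy input to dummy output]--[shared experts]--...--[CudaStream plugin: join]-->
    //                   \                                                     ^
    // side stream:       \--[wait on main]--[MoE expert kernels]-------------/

The dummy output exists purely to give the shared experts a data dependency on the main stream while the expert kernels run concurrently on the side stream.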
@ -728,6 +775,13 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
        static_cast<int*>(workspace.selected_experts), mSparseMixerEpsilon, mParallelismConfig, mNormalizationMode,
        hasLora(), lora_params, stream);

    if (useSideStream())
    {
        // Debug the code with the side stream stalled (only executed when the environment variable
        // TLLM_DEBUG_MOE_STALL_SIDE is set and has a positive value)
        mSideStreamPtr->stallSideStream("TLLM_DEBUG_MOE_STALL_SIDE", mDebugStallSide);
    }

    return 0;
}

@ -735,8 +789,12 @@ int MixtureOfExpertsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::DataType MixtureOfExpertsPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    TLLM_CHECK(index == getOutputTensorIndex());
    TLLM_CHECK(index == getOutputTensorIndex() || index == getOutputDummyTensorIndex());
    TLLM_CHECK(inputTypes[getInputTensorIndex()] == mType);
    if (useSideStream() && index == getOutputDummyTensorIndex())
    {
        return mType;
    }
    return mOutputType;
}
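Note: the "dummy" output exists purely to publish the input copy made on the main stream. Layers that run after the MoE (the shared experts) consume it, which gives TensorRT a genuine data dependency on that copy while the expert kernels execute on the side stream; its element type therefore follows the input (mType) rather than the MoE result (mOutputType).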
@ -769,7 +827,15 @@ int MixtureOfExpertsPlugin::initialize() noexcept
    return 0;
}

void MixtureOfExpertsPlugin::terminate() noexcept {}
void MixtureOfExpertsPlugin::terminate() noexcept
{
    if (mSideStreamPtr)
    {
        auto const resource_name = nvinfer1::pluginInternal::SideStream::getResourceKey(mSideStreamId);
        getPluginRegistry()->releasePluginResource(resource_name);
        mSideStreamPtr = nullptr;
    }
}
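Note: TensorRT's plugin-resource registry reference-counts these handles: the first acquirePluginResource call for a key registers a clone of the passed prototype, later acquires return the same instance, and releasePluginResource frees it once the last holder lets go. Releasing in terminate() and nulling the pointer therefore keeps a single side stream shared across all plugin instances created with the same side_stream_id.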

void MixtureOfExpertsPlugin::destroy() noexcept
{
@ -836,6 +902,7 @@ MixtureOfExpertsPluginCreator::MixtureOfExpertsPluginCreator()
        static_cast<int>(MOEExpertScaleNormalizationMode::NONE)));
    mPluginAttributes.emplace_back(
        nvinfer1::PluginField("sparse_mixer_epsilon", nullptr, PluginFieldType::kFLOAT32, 0));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("side_stream_id", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("use_lora", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("lora_type_id", nullptr, PluginFieldType::kINT32, 0));
    mPluginAttributes.emplace_back(nvinfer1::PluginField("max_low_rank", nullptr, PluginFieldType::kINT32, 0));

@ -865,6 +932,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
    int mEPRank{};
    int mNormalizationMode{};
    int mRequiresDeterminism{0};
    int mSideStreamId{0};
    int mUseLora{};
    int mLoraType{INT_MAX};
    int mMaxLowRank{0};

@ -902,6 +970,7 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
        MapPair{"use_bias", std::ref(mUseBias), true},
        MapPair{"output_type_id", std::ref(mOutputType), true},
        MapPair{"force_determinism", std::ref(mRequiresDeterminism), true},
        MapPair{"side_stream_id", std::ref(mSideStreamId), true},
        MapPair{"lora_type_id", std::ref(mLoraType), true},
        MapPair{"max_low_rank", std::ref(mMaxLowRank), true},
    };

@ -962,8 +1031,8 @@ IPluginV2* MixtureOfExpertsPluginCreator::createPlugin(
        static_cast<nvinfer1::DataType>(mWeightType), static_cast<nvinfer1::DataType>(mOutputType),
        QuantMode(mQuantMode), mUseFinished != 0, mUseBias != 0, mTPSize, mTPRank, mEPSize, mEPRank,
        static_cast<MOEExpertScaleNormalizationMode>(mNormalizationMode), mSparseMixerEpsilon,
        mRequiresDeterminism != 0, gemmProfiler, mUseLora != 0, static_cast<nvinfer1::DataType>(mLoraType),
        loraProfiler, mMaxLowRank);
        mRequiresDeterminism != 0, mSideStreamId, gemmProfiler, mUseLora != 0,
        static_cast<nvinfer1::DataType>(mLoraType), loraProfiler, mMaxLowRank);
    obj->setPluginNamespace(mNamespace.c_str());
    return obj;
}

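Note: side_stream_id defaults to 0, which keeps the single-stream behavior; any positive value enables the side-stream path, and plugins created with the same id share one stream via the resource key shown earlier. A sketch of supplying the field when creating the plugin (values are illustrative):

    // Sketch: populating the new creator field before calling createPlugin().
    int32_t sideStreamId = 1; // > 0 enables the side stream; equal ids share one stream
    nvinfer1::PluginField sideStreamField{"side_stream_id", &sideStreamId, nvinfer1::PluginFieldType::kINT32, 1};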
@ -24,6 +24,7 @@
#include "tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.h"
#include "tensorrt_llm/plugins/common/gemmPluginProfiler.h"
#include "tensorrt_llm/plugins/common/plugin.h"
#include "tensorrt_llm/plugins/cudaStreamPlugin/cudaStreamPlugin.h"
#include "tensorrt_llm/plugins/gemmPlugin/gemmPlugin.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <cassert>

@ -106,8 +107,8 @@ public:
        nvinfer1::DataType weight_type, nvinfer1::DataType output_type, tensorrt_llm::common::QuantMode quant_mode,
        bool use_finished, bool use_bias, int tp_size, int tp_rank, int ep_size, int ep_rank,
        MOEExpertScaleNormalizationMode normalization_mode, float sparse_mixer_epsilon, bool force_determinism,
        MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora, nvinfer1::DataType lora_type,
        LoraPluginProfilerPtr lora_profiler, int max_low_rank);
        int side_stream_id, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr, bool use_lora,
        nvinfer1::DataType lora_type, LoraPluginProfilerPtr lora_profiler, int max_low_rank);
    MixtureOfExpertsPlugin(void const* data, size_t length, MixtureOfExpertsPluginProfilerPtr gemm_profiler_ptr,
        LoraPluginProfilerPtr lora_profiler);
    MixtureOfExpertsPlugin(MixtureOfExpertsPlugin const&);
@ -139,7 +140,7 @@ public:

    int getNbOutputs() const noexcept override
    {
        return 1;
        return 1 + useSideStream();
    }

    int initialize() noexcept override;

@ -170,6 +171,10 @@ private:

    GemmDims mDims{};
    bool mUseDeterministicKernels = false;
    int mSideStreamId = 0;

    int mDebugStallMain = 0;
    int mDebugStallSide = 0;

    GemmIDMoe mGemmId1{};
    GemmIDMoe mGemmId2{};

@ -197,6 +202,7 @@ private:
    std::vector<int32_t> mLoraExpandGatedRanks{};

    cudaEvent_t mMemcpyEvent;
    nvinfer1::pluginInternal::SideStream* mSideStreamPtr;

    // The below are not serialised
    std::string const mLayerName{};

@ -279,6 +285,11 @@ private:
        return hasExpertFp8QuantScales() && mOutputType == nvinfer1::DataType::kFP8;
    }

    bool useSideStream() const
    {
        return mSideStreamId > 0;
    }

    bool hasLora() const
    {
        return mUseLora;

@ -390,6 +401,11 @@ private:
        return 0;
    }

    IndexType getOutputDummyTensorIndex() const
    {
        return getOutputTensorIndex() + useSideStream();
    }

    /**
     * Get the index of the expert shape tuple that represents the inner dimension
     */

@ -416,34 +416,33 @@ std::set<int> getLocalGroup(std::set<int> const& group)
            ranks.push_back(myRank);
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(rank, *it, 0);
                COMM_SESSION.recvValue(rank, *it, 0);
                ranks.push_back(rank);
            }
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
                COMM_SESSION.send(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
            }

            localRanks.clear();
            localRanks.push_back(myLocalRank);
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.recvValue(rank, *it, 0);
                COMM_SESSION.recvValue(rank, *it, 0);
                localRanks.push_back(rank);
            }
            for (auto it = std::next(std::begin(group), 1); it != group.end(); ++it)
            {
                LOCAL_COMM_SESSION.send(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
                COMM_SESSION.send(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *it, 0);
            }
        }
        else
        {
            LOCAL_COMM_SESSION.sendValue(myRank, *group.begin(), 0);
            LOCAL_COMM_SESSION.recv(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
            COMM_SESSION.sendValue(myRank, *group.begin(), 0);
            COMM_SESSION.recv(ranks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);

            LOCAL_COMM_SESSION.sendValue(myLocalRank, *group.begin(), 0);
            LOCAL_COMM_SESSION.recv(
                localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
            COMM_SESSION.sendValue(myLocalRank, *group.begin(), 0);
            COMM_SESSION.recv(localRanks.data(), localSize, tensorrt_llm::mpi::MpiType::kINT32, *group.begin(), 0);
        }
    }

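Note: throughout this handshake the node-local LOCAL_COMM_SESSION is replaced with the global COMM_SESSION, presumably because the peer ids stored in `group` are global ranks; addressing them through the local communicator could route messages to the wrong processes once the group spans multiple nodes.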
@ -732,7 +731,6 @@ IPluginV2* AllreducePluginCreator::createPlugin(char const* name, PluginFieldCol
            bias = *static_cast<int8_t const*>(fields[i].data);
        }
    }

    try
    {
        auto* obj = new AllreducePlugin(group, type, strategy, config, fusion_op, counter, eps, affine, bias);

@ -36,6 +36,8 @@ target_compile_definitions(
if(NOT WIN32)
  set_target_properties(
    ${TRTLLM_PYBIND_MODULE}
    PROPERTIES LINK_FLAGS
               "-Wl,-rpath,'$ORIGIN/libs' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}")
    PROPERTIES
      LINK_FLAGS
        "-Wl,-rpath,'$ORIGIN/libs' -Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib' ${AS_NEEDED_FLAG} ${UNDEFINED_FLAG}"
  )
endif()

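Note: the extra `-Wl,-rpath,'$ORIGIN/../nvidia/nccl/lib'` entry points the Python extension at the NCCL shared library installed by the NCCL pip wheel (e.g. `nvidia-nccl-cu12`) under `site-packages/nvidia/nccl/lib`, so the module resolves NCCL relative to its own install location instead of relying on the system library path.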
@ -237,6 +237,8 @@ void initBindings(pybind11::module_& m)
            std::optional<std::vector<tb::LlmRequest::SizeType32>> position_ids,
            std::optional<at::Tensor> prompt_embedding_table,
            std::optional<tb::LlmRequest::SizeType32> prompt_vocab_size,
            std::optional<at::Tensor> mrope_rotary_sin_cos,
            std::optional<tb::LlmRequest::SizeType32> mrope_position_deltas,
            std::optional<LoraTaskIdType> lora_task_id, std::optional<at::Tensor> lora_weights,
            std::optional<at::Tensor> lora_config,
            std::optional<executor::LookaheadDecodingConfig> lookahead_config,

@ -269,6 +271,7 @@ void initBindings(pybind11::module_& m)
            auto stop_words_list_tensor_ptr = makeOptionalTensor(stop_words_list);
            auto prompt_embedding_table_tensor_ptr = makeOptionalTensor(prompt_embedding_table);
            auto lora_weights_tensor_ptr = makeOptionalTensor(lora_weights);
            auto mrope_rotary_sin_cos_tensor_ptr = makeOptionalTensor(mrope_rotary_sin_cos);
            auto lora_config_tensor_ptr = makeOptionalTensor(lora_config);
            auto draft_logits_tensor_ptr = makeOptionalTensor(draft_logits);
            auto encoder_input_features_tensor_ptr = makeOptionalTensor(encoder_input_features);

@ -277,19 +280,20 @@ void initBindings(pybind11::module_& m)
            return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming,
                end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr,
                stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size,
                lora_task_id, lora_weights_tensor_ptr, lora_config_tensor_ptr, lookahead_config,
                kv_cache_retention_config, return_log_probs, return_context_logits, return_generation_logits,
                draft_tokens, draft_logits_tensor_ptr, exclude_input_from_output, logits_post_processor,
                apply_logits_post_processor_batched, encoder_input_tokens, return_encoder_output, client_id,
                priority, encoder_input_features_tensor_ptr, encoder_output_length,
                cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids,
                num_return_sequences};
                mrope_rotary_sin_cos_tensor_ptr, mrope_position_deltas, lora_task_id, lora_weights_tensor_ptr,
                lora_config_tensor_ptr, lookahead_config, kv_cache_retention_config, return_log_probs,
                return_context_logits, return_generation_logits, draft_tokens, draft_logits_tensor_ptr,
                exclude_input_from_output, logits_post_processor, apply_logits_post_processor_batched,
                encoder_input_tokens, return_encoder_output, client_id, priority,
                encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr,
                llm_request_type, input_token_extra_ids, num_return_sequences};
        }),
        py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
        py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
        py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
        py::arg("stop_words_list") = std::nullopt, py::arg("position_ids") = std::nullopt,
        py::arg("prompt_embedding_table") = std::nullopt, py::arg("prompt_vocab_size") = std::nullopt,
        py::arg("mrope_rotary_sin_cos") = std::nullopt, py::arg("mrope_position_deltas") = std::nullopt,
        py::arg("lora_task_id") = std::nullopt, py::arg("lora_weights") = std::nullopt,
        py::arg("lora_config") = std::nullopt, py::arg("lookahead_config") = std::nullopt,
        py::arg("kv_cache_retention_config") = std::nullopt, py::arg("return_log_probs") = false,

@ -73,6 +73,8 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
    auto badWordsList = from_torch(mBadWordsList);
    auto stopWordsList = from_torch(mStopWordsList);
    auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable);
    auto mropeRotarySinCos = from_torch(mMropeRotarySinCos);

    auto loraWeights = from_torch(mLoraWeights);
    auto loraConfig = from_torch(mLoraConfig);
    auto draftLogits = from_torch(mDraftLogits);

@ -82,11 +84,11 @@ std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const

    return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
        std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
        embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize, mLoraTaskId,
        loraWeights, loraConfig, mLookaheadConfig, mKvCacheRetentionConfig, returnLogProbs(), mReturnContextLogits,
        mReturnGenerationLogits, mDraftTokens, draftLogits, mExcludeInputFromOutput,
        callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched, mEncoderTokens, mReturnEncoderOutput,
        mClientId, mPriority, encoderInputFeatures, mEncoderOutputLength, crossAttentionMask,
        tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, mInputTokenExtraIds, mNumReturnSequences,
        std::nullopt, skipCrossAttnBlocks);
        embeddingBias, badWordsList, stopWordsList, mPositionIds, promptEmbeddingTable, mPromptVocabSize,
        mropeRotarySinCos, mMropePositionDeltas, mLoraTaskId, loraWeights, loraConfig, mLookaheadConfig,
        mKvCacheRetentionConfig, returnLogProbs(), mReturnContextLogits, mReturnGenerationLogits, mDraftTokens,
        draftLogits, mExcludeInputFromOutput, callbackAdapter(mLogitsPostProcessor), mApplyLogitsPostProcessorBatched,
        mEncoderTokens, mReturnEncoderOutput, mClientId, mPriority, encoderInputFeatures, mEncoderOutputLength,
        crossAttentionMask, tb::LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, mInputTokenExtraIds,
        mNumReturnSequences, std::nullopt, skipCrossAttnBlocks);
}

@ -57,6 +57,8 @@ public:
        std::optional<std::vector<SizeType32>> positionIds = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType32> promptVocabSize = std::nullopt,
        std::optional<TensorPtr> mropeRotarySinCos = std::nullopt,
        std::optional<SizeType32> mropePositionDeltas = std::nullopt,
        std::optional<LoraTaskIdType> loraTaskId = std::nullopt, std::optional<TensorPtr> loraWeights = std::nullopt,
        std::optional<TensorPtr> loraConfig = std::nullopt,
        std::optional<executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt,

@ -76,8 +78,9 @@ public:
            samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList,
            positionIds.has_value() ? std::make_shared<std::vector<SizeType32>>(std::move(positionIds.value()))
                                    : std::optional<std::shared_ptr<std::vector<SizeType32>>>(std::nullopt),
            promptEmbeddingTable, promptVocabSize, loraTaskId, loraWeights, loraConfig, lookaheadConfig,
            kvCacheRetentionConfig, returnLogProbs, returnContextLogits, returnGenerationLogits,
            promptEmbeddingTable, promptVocabSize, mropeRotarySinCos, mropePositionDeltas, loraTaskId, loraWeights,
            loraConfig, lookaheadConfig, kvCacheRetentionConfig, returnLogProbs, returnContextLogits,
            returnGenerationLogits,
            draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
                                    : std::make_shared<VecTokens>(),
            draftLogits, excludeInputFromOutput, logitsPostProcessor, applyLogitsPostProcessorBatched,

@ -245,6 +245,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
        .def_property("max_prompt_embedding_table_size", &tr::ModelConfig::getMaxPromptEmbeddingTableSize,
            &tr::ModelConfig::setMaxPromptEmbeddingTableSize)
        .def_property_readonly("use_prompt_tuning", &tr::ModelConfig::usePromptTuning)
        .def_property_readonly("use_mrope", &tr::ModelConfig::useMrope)
        .def_property("use_lora_plugin", py::overload_cast<>(&tr::ModelConfig::useLoraPlugin, py::const_),
            py::overload_cast<bool>(&tr::ModelConfig::useLoraPlugin))
        .def_property("compute_context_logits", py::overload_cast<>(&tr::ModelConfig::computeContextLogits, py::const_),

@ -314,6 +314,11 @@ void InitBindings(pybind11::module_& m)
        .def_property_readonly("embedding_table", &tle::PromptTuningConfig::getEmbeddingTable)
        .def_property_readonly("input_token_extra_ids", &tle::PromptTuningConfig::getInputTokenExtraIds);

    py::class_<tle::MropeConfig>(m, "MropeConfig")
        .def(py::init<Tensor, SizeType32>(), py::arg("mrope_rotary_sin_cos"), py::arg("mrope_position_deltas"))
        .def_property_readonly("mrope_rotary_sin_cos", &tle::MropeConfig::getMRopeRotarySinCos)
        .def_property_readonly("mrope_position_deltas", &tle::MropeConfig::getMRopePositionDeltas);

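Note: for orientation, a hedged sketch of attaching the new config to an executor request from C++; `sinCosTensor` and `inputTokenIds` stand in for a real tle::Tensor and token vector, and the values are illustrative:

    // Sketch: MropeConfig bundles the rotary sin/cos table with the position
    // deltas; Request exposes it via the mrope_config property bound below.
    tle::MropeConfig mropeConfig{sinCosTensor, /*mropePositionDeltas=*/0};
    tle::Request request{inputTokenIds, /*maxTokens=*/64};
    request.setMropeConfig(mropeConfig);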
    py::class_<tle::LoraConfig>(m, "LoraConfig")
        .def(py::init<uint64_t, std::optional<Tensor>, std::optional<Tensor>>(), py::arg("task_id"),
            py::arg("weights") = py::none(), py::arg("config") = py::none())
@ -391,7 +396,8 @@ void InitBindings(pybind11::module_& m)
            std::optional<std::list<tle::VecTokens>> badWords,
            std::optional<std::list<tle::VecTokens>> stopWords, std::optional<tle::Tensor> embeddingBias,
            std::optional<tle::ExternalDraftTokensConfig> externalDraftTokensConfig,
            std::optional<tle::PromptTuningConfig> pTuningConfig, std::optional<tle::LoraConfig> loraConfig,
            std::optional<tle::PromptTuningConfig> pTuningConfig, std::optional<tle::MropeConfig> mRopeConfig,
            std::optional<tle::LoraConfig> loraConfig,
            std::optional<tle::LookaheadDecodingConfig> lookaheadConfig,
            std::optional<tle::KvCacheRetentionConfig> kvCacheRetentionConfig,
            std::optional<std::string> logitsPostProcessorName,

@ -413,10 +419,10 @@ void InitBindings(pybind11::module_& m)
            TLLM_CHECK_WITH_INFO(maxTokens.has_value(), "missing required argument max_tokens");
            return std::make_unique<tle::Request>(inputTokenIds, maxTokens.value(), streaming, samplingConfig,
                outputConfig, endId, padId, positionIds, badWords, stopWords, embeddingBias,
                externalDraftTokensConfig, pTuningConfig, loraConfig, lookaheadConfig, kvCacheRetentionConfig,
                logitsPostProcessorName, encoderInputTokenIds, clientId, returnAllGeneratedTokens, priority,
                type, contextPhaseParams, encoderInputFeatures, encoderOutputLength, crossAttentionMask,
                numReturnSequences, eagleConfig, skipCrossAttnBlocks);
                externalDraftTokensConfig, pTuningConfig, mRopeConfig, loraConfig, lookaheadConfig,
                kvCacheRetentionConfig, logitsPostProcessorName, encoderInputTokenIds, clientId,
                returnAllGeneratedTokens, priority, type, contextPhaseParams, encoderInputFeatures,
                encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks);
        }),
        py::arg("input_token_ids"), py::kw_only(), py::arg("max_tokens") = py::none(),
        py::arg("max_new_tokens") = py::none(), py::arg("streaming") = false,

@ -425,10 +431,11 @@ void InitBindings(pybind11::module_& m)
        py::arg("pad_id") = py::none(), py::arg("position_ids") = py::none(), py::arg("bad_words") = py::none(),
        py::arg("stop_words") = py::none(), py::arg("embedding_bias") = py::none(),
        py::arg("external_draft_tokens_config") = py::none(), py::arg("prompt_tuning_config") = py::none(),
        py::arg("lora_config") = py::none(), py::arg("lookahead_config") = py::none(),
        py::arg("kv_cache_retention_config") = py::none(), py::arg("logits_post_processor_name") = py::none(),
        py::arg("encoder_input_token_ids") = py::none(), py::arg("client_id") = py::none(),
        py::arg("return_all_generated_tokens") = false, py::arg("priority") = tle::Request::kDefaultPriority,
        py::arg("mrope_config") = py::none(), py::arg("lora_config") = py::none(),
        py::arg("lookahead_config") = py::none(), py::arg("kv_cache_retention_config") = py::none(),
        py::arg("logits_post_processor_name") = py::none(), py::arg("encoder_input_token_ids") = py::none(),
        py::arg("client_id") = py::none(), py::arg("return_all_generated_tokens") = false,
        py::arg("priority") = tle::Request::kDefaultPriority,
        py::arg_v("type", tle::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION,
            "RequestType.REQUEST_TYPE_CONTEXT_AND_GENERATION"),
        py::arg("context_phase_params") = py::none(), py::arg("encoder_input_features") = py::none(),

@ -451,6 +458,7 @@ void InitBindings(pybind11::module_& m)
            &tle::Request::setExternalDraftTokensConfig)
        .def_property(
            "prompt_tuning_config", &tle::Request::getPromptTuningConfig, &tle::Request::setPromptTuningConfig)
        .def_property("mrope_config", &tle::Request::getMropeConfig, &tle::Request::setMropeConfig)
        .def_property("lora_config", &tle::Request::getLoraConfig, &tle::Request::setLoraConfig)
        .def_property("lookahead_config", &tle::Request::getLookaheadConfig, &tle::Request::setLookaheadConfig)
        .def_property("kv_cache_retention_config", &tle::Request::getKvCacheRetentionConfig,

@ -69,6 +69,8 @@ void EagleBuffers::Inputs::create(SizeType32 maxNumSequences, TllmRuntime const&
        = manager.pinnedPool(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
    eagleNetGenPastKeyValueLengthsHost
        = manager.pinnedPool(ITensor::makeShape({maxNumSequences}), nvinfer1::DataType::kINT32);
    inputGenTokensHost
        = manager.pinnedPool(ITensor::makeShape({maxNumSequences * maxDecodingTokens}), nvinfer1::DataType::kINT32);
}

EagleBuffers::EagleBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, runtime::BufferManager const& manager,

@ -119,6 +121,7 @@ EagleBuffers::EagleBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, run
        = manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
    engineInputs.eagleNetGenPastKeyValueLengthsHost
        = manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
    engineInputs.inputGenTokensHost = manager.emptyTensor(runtime::MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);

    // output tensors
    engineOutputs.nextDraftTokens

@ -187,6 +190,7 @@ void EagleBuffers::reshape(
    engineInputs.eagleNetGenRequestTypesHost->reshape(ITensor::makeShape({numSequences}));
    engineInputs.eagleNetGenContextLengthsHost->reshape(ITensor::makeShape({numSequences}));
    engineInputs.eagleNetGenPastKeyValueLengthsHost->reshape(ITensor::makeShape({numSequences}));
    engineInputs.inputGenTokensHost->reshape(ITensor::makeShape({numSequences * maxDecodingTokens}));

    cumSumGenerationLengths->reshape(ITensor::makeShape({numSequences + 1}));

@ -260,6 +264,7 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe

    // Pack host data.
    SizeType32 maxGenerationLengthHostValue{-1};
    SizeType32 numGenerationTokens{0};
    for (SizeType32 bi = 0; bi < params.batchSize; ++bi)
    {
        auto const batchSlot = params.batchSlots[bi];

@ -276,8 +281,10 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe
        bufferCast<SizeType32>(*engineInputs.eagleNetGenPastKeyValueLengthsHost)[bi]
            = bufferCast<SizeType32>(*draftBuffers.eagleNetGenPastKeyValueLengthsHost)[batchSlot];

        maxGenerationLengthHostValue = std::max(maxGenerationLengthHostValue,
            bufferCast<SizeType32>(*draftBuffers.specDecodingGenerationLengthsHost)[batchSlot]);
        auto const generationLength
            = bufferCast<SizeType32>(*draftBuffers.specDecodingGenerationLengthsHost)[batchSlot];
        maxGenerationLengthHostValue = std::max(maxGenerationLengthHostValue, generationLength);
        numGenerationTokens += generationLength;
    }

    if (maxGenerationLengthHostValue <= 0)

@ -289,6 +296,10 @@ void EagleBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType32 numGenSe
    specDecodingPositionOffsetsShape.d[1] = maxGenerationLengthHostValue;
    engineInputs.specDecodingPositionOffsets->reshape(specDecodingPositionOffsetsShape);

    auto inputGenTokensHostShape = engineInputs.inputGenTokensHost->getShape();
    inputGenTokensHostShape.d[0] = numGenerationTokens;
    engineInputs.inputGenTokensHost->reshape(inputGenTokensHostShape);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

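Note: the new inputGenTokensHost buffer follows the same capacity-then-shrink pattern as the other Eagle inputs: create() allocates for the worst case (maxNumSequences * maxDecodingTokens), and each step the tensor is reshaped down to the exact numGenerationTokens accumulated in the loop above. A generic sketch of the pattern (illustrative names):

    // Sketch: shrink a worst-case allocation to the step's actual row count.
    // `buffer` is an ITensor created with `capacity` rows; actualCount <= capacity,
    // so reshape only adjusts the logical shape (typically no reallocation).
    auto shape = buffer->getShape();
    shape.d[0] = actualCount;
    buffer->reshape(shape);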
@ -350,6 +361,7 @@ void EagleBuffers::insertInputTensors(
    inputBuffers.insert_or_assign("host_gen_eagle_net_context_lengths", engineInputs.eagleNetGenContextLengthsHost);
    inputBuffers.insert_or_assign(
        "host_gen_eagle_net_past_key_value_lengths", engineInputs.eagleNetGenPastKeyValueLengthsHost);
    inputBuffers.insert_or_assign("input_gen_tokens", engineInputs.inputGenTokensHost);

    // outputs
    outputBuffers.insert_or_assign("next_draft_tokens", engineOutputs.nextDraftTokens);

@ -28,15 +28,16 @@ class EagleModule : public SpeculativeDecodingModule
public:
    // Number of paths is maxDecodingTokens = maxDecodingDraftTokens + 1 to account for very flat trees with
    // depth 1.
    explicit EagleModule(
        SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens, SizeType32 numTransformersLayer) noexcept
    explicit EagleModule(SizeType32 maxDraftPathLen, SizeType32 maxDecodingDraftTokens, SizeType32 numTransformersLayer,
        SizeType32 maxNonLeafNodesPerLayer) noexcept
        : SpeculativeDecodingModule(maxDraftPathLen, maxDecodingDraftTokens, maxDecodingDraftTokens + 1)
        , mNumTransformersLayer(numTransformersLayer)
        , mMaxNonLeafNodesPerLayer(maxNonLeafNodesPerLayer)
    {
    }

    explicit EagleModule() noexcept
        : EagleModule(0, 0, 0)
        : EagleModule(0, 0, 0, 0)
    {
    }

@ -50,8 +51,14 @@ public:
        return mNumTransformersLayer;
    }

    [[nodiscard]] SizeType32 getMaxNonLeafNodesPerLayer() const noexcept
    {
        return mMaxNonLeafNodesPerLayer;
    }

private:
    SizeType32 mNumTransformersLayer;
    SizeType32 mMaxNonLeafNodesPerLayer;

    // We use mc_sim_7b_63 from the official Medusa implementation, i.e. one of the best trees with 63 nodes found
    // for the 7B Vicuna model. We use it as the default if no other trees are specified per request or at the
    // server level.

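Note: since the constructor gained a fourth parameter, call sites now pass the per-layer non-leaf budget explicitly. A hedged construction sketch (values illustrative, not defaults from the source):

    // Sketch: the extended constructor; maxNonLeafNodesPerLayer presumably caps
    // how many nodes per tree level may expand children in the draft tree.
    EagleModule eagleModule{/*maxDraftPathLen=*/4, /*maxDecodingDraftTokens=*/63,
        /*numTransformersLayer=*/1, /*maxNonLeafNodesPerLayer=*/10};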
Some files were not shown because too many files have changed in this diff.