Update TensorRT-LLM (#2333)

* Update TensorRT-LLM

---------

Co-authored-by: Puneesh Khanna <puneesh.khanna@tii.ae>
Co-authored-by: Ethan Zhang <26497102+ethnzhng@users.noreply.github.com>
Kaiyu Xie 2024-10-15 15:28:40 +08:00 committed by GitHub
parent 8681b3a4c0
commit 75057cd036
251 changed files with 8119 additions and 1528 deletions

View File

@ -8,7 +8,7 @@ TensorRT-LLM
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.5.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.4.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.14.0.dev-green)](./tensorrt_llm/version.py)
[![version](https://img.shields.io/badge/release-0.15.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)
[Architecture](./docs/source/architecture/overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Results](./docs/source/performance/perf-overview.md)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Examples](./examples/)&nbsp;&nbsp;&nbsp;|&nbsp;&nbsp;&nbsp;[Documentation](./docs/source/)
@ -17,11 +17,14 @@ TensorRT-LLM
<div align="left">
## Latest News
* [2024/10/07] 🚀🚀🚀Optimizing Microsoft Bing Visual Search with NVIDIA Accelerated Libraries
[➡️ link](https://developer.nvidia.com/blog/optimizing-microsoft-bing-visual-search-with-nvidia-accelerated-libraries/)
<div align="center">
<img src="docs/source/media/image-10-07-2024.png" width="50%">
<div align="left">
* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
<div align="center">
<img src="docs/source/media/image-09-29-2024.png" width="50%">
<div align="left">
* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)

View File

@ -426,6 +426,7 @@ public:
void initialize()
{
mStart = std::chrono::steady_clock::now();
mRequestsQueueingLatencies.clear();
}
void finalize()
@ -433,6 +434,11 @@ public:
mEnd = std::chrono::steady_clock::now();
}
void recordQueueLatency(std::vector<float> const& latencies)
{
mRequestsQueueingLatencies.insert(mRequestsQueueingLatencies.end(), latencies.begin(), latencies.end());
}
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
auto const inputLength = request->getInputIds()->getSize();
@ -677,6 +683,16 @@ public:
mMaxGenT2TLatency = genT2TLatencies.back();
mMinGenT2TLatency = genT2TLatencies.front();
}
mAvgReqQueueingLatency
= std::accumulate(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end(), 0.F)
/ mRequestsQueueingLatencies.size();
std::sort(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end());
mP99ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 99);
mP90ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 90);
mP50ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 50);
mMaxReqQueueingLatency = mRequestsQueueingLatencies.back();
mMinReqQueueingLatency = mRequestsQueueingLatencies.front();
}
}
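The new queueing-latency statistics reuse the recorder's `calcPercentile` helper on the sorted latency vector. That helper is not part of this hunk; below is a minimal sketch of the nearest-rank style lookup it presumably performs (an assumption for illustration, the real helper in gptManagerBenchmark.cpp may differ).

```cpp
// Hypothetical sketch of a percentile lookup over an already-sorted vector,
// matching how calcPercentile is used in the hunk above.
#include <algorithm>
#include <vector>

float calcPercentileSketch(std::vector<float> const& sortedLatencies, int percentile)
{
    if (sortedLatencies.empty())
    {
        return 0.F;
    }
    auto const n = static_cast<int>(sortedLatencies.size());
    // Approximate nearest-rank index, clamped to the last element.
    auto const idx = std::min(n - 1, percentile * n / 100);
    return sortedLatencies[idx];
}
```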
@ -713,6 +729,13 @@ public:
printf("[BENCHMARK] p99_inter_token_latency(ms) %.2f\n", mP99GenT2TLatency);
printf("[BENCHMARK] p90_inter_token_latency(ms) %.2f\n", mP90GenT2TLatency);
printf("[BENCHMARK] p50_inter_token_latency(ms) %.2f\n\n", mP50GenT2TLatency);
printf("[BENCHMARK] avg_request_queueing_latency(ms) %.2f\n", mAvgReqQueueingLatency);
printf("[BENCHMARK] max_request_queueing_latency(ms) %.2f\n", mMaxReqQueueingLatency);
printf("[BENCHMARK] min_request_queueing_latency(ms) %.2f\n", mMinReqQueueingLatency);
printf("[BENCHMARK] p99_request_queueing_latency(ms) %.2f\n", mP99ReqQueueingLatency);
printf("[BENCHMARK] p90_request_queueing_latency(ms) %.2f\n", mP90ReqQueueingLatency);
printf("[BENCHMARK] p50_request_queueing_latency(ms) %.2f\n\n", mP50ReqQueueingLatency);
}
}
@ -820,6 +843,13 @@ private:
float mP50GenT2TLatency{};
float mMaxGenT2TLatency{};
float mMinGenT2TLatency{};
float mAvgReqQueueingLatency{};
float mP99ReqQueueingLatency{};
float mP90ReqQueueingLatency{};
float mP50ReqQueueingLatency{};
float mMaxReqQueueingLatency{};
float mMinReqQueueingLatency{};
std::vector<float> mRequestsQueueingLatencies{};
std::string mOpCsvFile;
bool mStreaming;
@ -846,6 +876,7 @@ public:
, mActiveCount(0)
, mNumFinished(0)
, mShutdown(false)
, mLogIterationData(logIterationData)
{
texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy);
@ -899,7 +930,9 @@ public:
TLLM_LOG_ERROR("not a supported executor model type in executor server.");
}
if (logIterationData)
auto const& world = tensorrt_llm::mpi::MpiComm::world();
auto worldRank = world.getRank();
if (worldRank == 0)
{
mCollectStatsThread = std::thread(&ExecutorServer::collectStats, this);
}
@ -988,7 +1021,18 @@ public:
auto iterStats = mExecutor->getLatestIterationStats();
for (auto const& iterStat : iterStats)
{
TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
SizeType32 numNewActiveRequests = iterStat.numNewActiveRequests;
if (numNewActiveRequests > 0)
{
float avgQueueingTime
= static_cast<float>(iterStat.newActiveRequestsQueueLatencyMS / numNewActiveRequests);
std::vector<float> requestsQueueLatencyMS(numNewActiveRequests, avgQueueingTime);
mRecorder->recordQueueLatency(requestsQueueLatencyMS);
}
if (mLogIterationData)
{
TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
}
}
auto const waitSleep = std::chrono::milliseconds(50);
std::this_thread::sleep_for(waitSleep);
@ -1005,6 +1049,7 @@ private:
std::atomic<uint64_t> mActiveCount;
std::atomic<uint64_t> mNumFinished;
std::atomic<bool> mShutdown;
bool mLogIterationData;
}; // class ExecutorServer
class GptServer
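Note that `collectStats` only sees the aggregate `newActiveRequestsQueueLatencyMS` per iteration, so the benchmark attributes the same average queueing time to every request that became active in that iteration. A standalone sketch of that approximation (illustrative only; it restates the logic added to `collectStats` above):

```cpp
// Sketch: derive per-request queueing latencies from one iteration's stats.
// The executor reports only the iteration-level sum, so the mean is assigned
// to each of the newly active requests.
#include <vector>

std::vector<float> approximateQueueLatencies(double newActiveRequestsQueueLatencyMS, int numNewActiveRequests)
{
    if (numNewActiveRequests <= 0)
    {
        return {};
    }
    auto const avgMs = static_cast<float>(newActiveRequestsQueueLatencyMS / numNewActiveRequests);
    return std::vector<float>(numNewActiveRequests, avgMs);
}
```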

View File

@ -201,6 +201,7 @@ public:
, mDecodingIter(0)
, mPriority(req.getPriority())
, mFinishReasons(mSamplingConfig.beamWidth)
, mEncoderInputFeatures(std::nullopt)
, mEncoderOutputLength(req.getEncoderOutputLength())
, mContextPhaseParams(req.getContextPhaseParams())
, mInputTokenExtraIds(std::nullopt)
@ -263,7 +264,8 @@ public:
auto pTuningConfig = req.getPromptTuningConfig();
if (pTuningConfig)
{
mPromptEmbeddingTable = executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable());
mPromptEmbeddingTable = tensorrt_llm::runtime::ITensor::view(
executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable()));
TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2);
mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0];
mPromptEmbeddingTable.value()->unsqueeze(0);
@ -1438,6 +1440,36 @@ public:
0.0, std::chrono::duration<double, std::milli>(mKvCacheTransferEnd - mKvCacheTransferStart).count());
}
void updateAllocTotalBlocksPerRequest(SizeType32 allocTotalBlocksPerRequest)
{
mAllocTotalBlocksPerRequest += allocTotalBlocksPerRequest;
}
[[nodiscard]] SizeType32 getAllocTotalBlocksPerRequest() const
{
return mAllocTotalBlocksPerRequest;
}
void updateAllocNewBlocksPerRequest(SizeType32 allocNewBlocksPerRequest)
{
mAllocNewBlocksPerRequest += allocNewBlocksPerRequest;
}
[[nodiscard]] SizeType32 getAllocNewBlocksPerRequest() const
{
return mAllocNewBlocksPerRequest;
}
void updateReusedBlocksPerRequest(SizeType32 reusedBlocksPerRequest)
{
mReusedBlocksPerRequest += reusedBlocksPerRequest;
}
[[nodiscard]] SizeType32 getReusedBlocksPerRequest() const
{
return mReusedBlocksPerRequest;
}
RequestIdType mRequestId;
SizeType32 mPromptLen;
SizeType32 mMaxNewTokens;
@ -1545,6 +1577,10 @@ protected:
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferStart;
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferEnd;
SizeType32 mAllocTotalBlocksPerRequest{0};
SizeType32 mAllocNewBlocksPerRequest{0};
SizeType32 mReusedBlocksPerRequest{0};
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{

View File

@ -297,6 +297,8 @@ struct IterationStats
double iterLatencyMS;
/// @brief The total time spent in queue by the requests that became active in this iteration (ms)
double newActiveRequestsQueueLatencyMS;
/// @brief Number of newly fetched active requests
SizeType32 numNewActiveRequests;
/// @brief Number of active requests
SizeType32 numActiveRequests;
/// @brief Number of queued requests
@ -364,6 +366,12 @@ struct RequestStats
bool paused;
/// @brief Stats specific to disaggregated serving
std::optional<DisServingRequestStats> disServingStats;
/// @brief Number of total allocated blocks per request
SizeType32 allocTotalBlocksPerRequest;
/// @brief Number of newly allocated blocks per request
SizeType32 allocNewBlocksPerRequest;
/// @brief Number of reused blocks per request
SizeType32 reusedBlocksPerRequest;
};
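These fields mirror the per-request counters accumulated in `LlmRequest` via the new `updateAllocTotalBlocksPerRequest`, `updateAllocNewBlocksPerRequest`, and `updateReusedBlocksPerRequest` helpers shown earlier. A hedged sketch of logging them once a `RequestStats` entry has been obtained from the executor's per-iteration request statistics (the retrieval call and the exact header path are not shown in this excerpt and are assumed here):

```cpp
// Sketch only: print the KV-cache block counters added to RequestStats.
// The include path and namespace are assumed; only the three new fields
// documented above are used.
#include "tensorrt_llm/executor/types.h" // path assumed
#include <cstdio>

void logBlockStats(tensorrt_llm::executor::RequestStats const& stats)
{
    std::printf("totalBlocks=%d newBlocks=%d reusedBlocks=%d\n",
        stats.allocTotalBlocksPerRequest, stats.allocNewBlocksPerRequest, stats.reusedBlocksPerRequest);
}
```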
/// @brief Struct that holds the stats of all requests in an iteration

View File

@ -115,7 +115,6 @@ public:
std::optional<SizeType32> genMicroBatchSize = std::nullopt;
std::optional<executor::DecodingMode> decodingMode = std::nullopt;
bool normalizeLogProbs = true;
std::optional<std::filesystem::path> enginePath;
};
//! @brief Optional profiler class to profile the generation phase of an inference request

View File

@ -127,6 +127,7 @@ public:
, mContextFMHA(false)
, mPagedContextFMHA(false)
, mUseXQA{false}
, mPpReduceScatter{false}
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mUseCrossAttention(false)
@ -468,6 +469,16 @@ public:
return mUseXQA;
}
void constexpr setPpReduceScatter(bool ppReduceScatter) noexcept
{
mPpReduceScatter = ppReduceScatter;
}
[[nodiscard]] bool constexpr getPpReduceScatter() const noexcept
{
return mPpReduceScatter;
}
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
{
return mUseLoraPlugin;
@ -759,6 +770,7 @@ private:
bool mContextFMHA;
bool mPagedContextFMHA;
bool mUseXQA;
bool mPpReduceScatter;
bool mUseLoraPlugin;
std::vector<LoraModule> mLoraModules;

View File

@ -50,6 +50,11 @@ public:
return SpeculativeDecodingMode{kExplicitDraftTokens};
}
static auto constexpr Eagle()
{
return SpeculativeDecodingMode{kEagle};
}
[[nodiscard]] bool constexpr isNone() const
{
return anyBitSet(kNone);
@ -75,29 +80,34 @@ public:
return anyBitSet(kExplicitDraftTokens);
}
[[nodiscard]] bool constexpr isEagle() const
{
return anyBitSet(kEagle);
}
[[nodiscard]] bool constexpr updatesPositionIds() const
{
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr requiresAttentionMask() const
{
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr predictsDraftTokens() const
{
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr needsKVCacheRewind() const
{
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
}
[[nodiscard]] bool constexpr variableDraftLength() const
{
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding);
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding | kEagle);
}
[[nodiscard]] bool constexpr hasDraftLogits() const
@ -107,7 +117,7 @@ public:
[[nodiscard]] bool constexpr needsDecoderPrologue() const
{
return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding);
return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding | kEagle);
}
using UnderlyingType = std::uint8_t;
@ -129,6 +139,7 @@ private:
static UnderlyingType constexpr kMedusa{1U << 2U};
static UnderlyingType constexpr kLookaheadDecoding{1U << 3U};
static UnderlyingType constexpr kExplicitDraftTokens{1U << 4U};
static UnderlyingType constexpr kEagle{1U << 5U};
[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
{
@ -173,4 +184,11 @@ static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isDraftTokensExter
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isMedusa());
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isLookaheadDecoding());
static_assert(SpeculativeDecodingMode::Eagle().isEagle());
static_assert(!SpeculativeDecodingMode::Eagle().isNone());
static_assert(!SpeculativeDecodingMode::Eagle().isDraftTokensExternal());
static_assert(!SpeculativeDecodingMode::Eagle().isMedusa());
static_assert(!SpeculativeDecodingMode::Eagle().isExplicitDraftTokens());
static_assert(!SpeculativeDecodingMode::Eagle().isLookaheadDecoding());
} // namespace tensorrt_llm::runtime
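With the new `kEagle` bit, EAGLE participates in the same capability queries as the other speculative-decoding modes, as the updated bit masks above show. A small illustrative check using only the accessors visible in this header (the include path is assumed):

```cpp
// Illustrative only: querying the capabilities implied by the Eagle mode.
#include "tensorrt_llm/runtime/speculativeDecodingMode.h" // path assumed

using tensorrt_llm::runtime::SpeculativeDecodingMode;

void configureForMode(SpeculativeDecodingMode mode)
{
    if (mode.isEagle())
    {
        // Eagle behaves like the other tree-based speculative modes for these checks.
        bool const rewind = mode.needsKVCacheRewind();    // true for Eagle
        bool const predicts = mode.predictsDraftTokens(); // true for Eagle
        bool const variable = mode.variableDraftLength(); // true for Eagle
        (void) rewind;
        (void) predicts;
        (void) variable;
    }
}
```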

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1a292517d802f2297c5d12d5d14ab597f47f46ebd31412fac044ceb9ca51a482
size 5160586
oid sha256:a55035628e0035141b4ea79b946f49ad77893d6e5d1ab47c402e1a9b95fbbb6c
size 5160128

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8575fb58200701ae30feb4b8bd3f325f8018aac5505167fdba42e269adb3bd8c
size 5271836
oid sha256:ed219fad83caf000a40f0688fdb20cb8593a5fe8096316d645229ee160c42514
size 5271480

View File

@ -1,3 +1,3 @@
954182e0c057f71f858a84f746201044 libtensorrt_llm_batch_manager_static.a
dfe6ca360cf1d24a3dcae0a2bf8589c0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
d7508bec7b6f112a2eac04cbeaf8b5da libtensorrt_llm_batch_manager_static.a
d8969624b327af844d9ffba910084b93 libtensorrt_llm_batch_manager_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8fe84073b7ccff8dc361fdee64c3ef30bc523909e0bf9c16547f76a05a53fb5c
size 5009886
oid sha256:36479d1577d131e36ca03549467a6cfe4822868ca0f3dda3b5d254ee4680341f
size 5009646

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6e565c2c3ce58656742772591d992aca91c7e46eb9fc711599d2d51928b88b48
size 4970532
oid sha256:b5caef410133f1552418978aa20cc1d3f7b6500b1dbc8b9f44232554b7cc8390
size 4971234

View File

@ -1,3 +1,3 @@
61fd34e765788884d42f4ba27f085520 libtensorrt_llm_batch_manager_static.a
e8a64dd19a234304483ef6756e67fd40 libtensorrt_llm_batch_manager_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
7029ee9cb0a921a3603e98815da18985 libtensorrt_llm_batch_manager_static.a
0e7fe69b6621fe6dabcc0b372c3440f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:200a6721aa1d6e009c94866adab36ac686eb1beef02df267af7e18e31e11612b
size 32436708
oid sha256:b86e215e86c7b0f8b0c9618fb655e6e4f31cc731f778cf0ca12fde93c7afbcab
size 32389592

View File

@ -1,2 +1,2 @@
9485cfa635b17378f23d1624b3acfbaf tensorrt_llm_batch_manager_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
afac175cfda36b14d76e17517bad8b24 tensorrt_llm_batch_manager_static.lib
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -92,7 +92,7 @@ template <
typename Policy_,
/// Number of stages,
int Stages,
/// Converter for B matrix applited immediately after the LDS
/// Converter for B matrix applied immediately after the LDS
typename TransformBAfterLDS_,
/// The quantization operator being used
WeightOnlyQuantOp QuantOp_,

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:809a1da76123ec4c640d63efc902209585223b66e23d887db9a198c5836986a2
size 3349066
oid sha256:414606be5b56f592fc7bd25f1e9fbf958c900dd2b01e01907029dfe19408ce59
size 3349230

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6846ecefa017d03ab7d853908794c884ab4e92a500e223278b1d64eab59ed061
size 3376088
oid sha256:682cf952def054fce6116983a3b5686994b71744fcc85a65e3c9a6e44549c82d
size 3377832

View File

@ -1,3 +1,3 @@
5a771664fdb75d99ba5fb90249ac26f0 libtensorrt_llm_executor_static.a
3b433ea93b7d1d6fa471b457980f2680 libtensorrt_llm_executor_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
dc9b4081af6357227886180a1b9a6d8d libtensorrt_llm_executor_static.a
8291552cf3e8da9dc368c8c37cd35abe libtensorrt_llm_executor_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:479e86f410763445357f5d879cc666d210352dda9709ab5ab56e73591a9e8af8
size 7851266
oid sha256:88810c1dac205a1111fc833c0fe0d38486152b4b878fd972585eec2ac27d5160
size 7857242

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6473c77d18929fa75342d63ffc591df39e8aeba1dda0b920b0187d4888710559
size 7767384
oid sha256:c023d6bad569fb3b3c528f3e003afa6a5f11a045bdccb06ca875607a6c781ade
size 7769728

View File

@ -1,3 +1,3 @@
5424fb0f82076e03b5316f73aed04434 libtensorrt_llm_executor_static.a
d0b1236baf61fc5c43383bbc1cd50fa8 libtensorrt_llm_executor_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
fd9cb10c300350266f65957475404bff libtensorrt_llm_executor_static.a
b8b0ae2861ef66853330441752ab1e32 libtensorrt_llm_executor_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:dee57c9257a6678833e3c0d83e8df07aff25c185bc085db75938cec6652044c0
size 24568210
oid sha256:baf4dd1bacd75c4eae6d98fe411bbb5d478dc5905a298d4238db3db21121ebca
size 24630026

View File

@ -1,2 +1,2 @@
305fac5d046a574ded2d46d968f746b0 tensorrt_llm_executor_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
30d62c80211e4a2dc38bbe9dc5257839 tensorrt_llm_executor_static.lib
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -74,7 +74,6 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
int const threadblock_count = multi_processor_count * occupancy;
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
GemmType gemm;
using Arguments = typename GemmType::Arguments;
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),

View File

@ -1,2 +1,2 @@
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,2 +1,2 @@
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db512d533ab4e4a4abd0047a65d891dfd6e1522f2d34c90f29296c3239fd3cc1
oid sha256:3bc495e1e677616db2756eb7d56d1161c34ae723896db34487883a955e2b3442
size 1128448

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845
oid sha256:1a6c03470aaa69378d4989971ab9dd00ee427f7e14a85ba5e114ea0594c4de5e
size 3488

View File

@ -1,3 +1,3 @@
b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
d89a0a140d2d427af13c3794a4b21e2c tensorrt_llm_nvrtc_wrapper.dll
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
de4b2f87f8eb1027f89c0f5cb05ca047 tensorrt_llm_nvrtc_wrapper.dll
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0814af36fed752bbe70d953cefbb78dd306c42f3d9f6848b7043a865e48f9662
oid sha256:80dbb6e3a34380bf4e375901ad9b71df24ec97cddcaa9f226bc0a278d11cbdd6
size 25364090

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ee46f2d1c9162f4302a1031f778fcb7c7110c84110427f97af6532ed9bd342fd
oid sha256:31e5cd6ef9e3599d55501ab0484b81f82ef1f22a79360a2699cd4a62c4928115
size 25768990

View File

@ -1,3 +1,3 @@
90740ead1def66f350e14c133278463d libtensorrt_llm_internal_cutlass_kernels_static.a
b0104227ffd1ce19fc1fdb45e349df36 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
1febd9d1bf244163deb269e2bebcd1e3 libtensorrt_llm_internal_cutlass_kernels_static.a
8fdb39f871225dedd32ca6651f1944ba libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4d9ba0f8b95cf64227cb0b17654fb7c9bc1741fe003889658b305750b388a4dc
oid sha256:3431f91bcb2cadb8a2641c4ea54d1f8f90c5aa7648591510e3a27865c94169ea
size 44173632

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f848d5beebbd69792047a96b16f7145f8e1e3e311d2a19789ce639ad8149b0e
oid sha256:1dedd4dd1df76a57576e749b4105a5d5f5070a6f7ee30d11944105742fea9b4b
size 43561206

View File

@ -1,3 +1,3 @@
2aaf05cb84f52b024e89d4fa634d6900 libtensorrt_llm_internal_cutlass_kernels_static.a
f17ce186e9105c594e39d252777ce4c7 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
8683b15e77bf62ee9f57a2507e21e6a7 libtensorrt_llm_internal_cutlass_kernels_static.a
a065a7b6a11b079ee544664dddcf59a6 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c429687e335c75f08186bcd8f629b50467cb0f2e484d755834c5b1cdbb9ecaf3
size 88140796
oid sha256:c7afdf2c313685b0e31f4e5572e20cd11d94227177849784ce7405e15a3587f6
size 88140804

View File

@ -1,2 +1,2 @@
4f663be2b768088805ccec6dc33545fc tensorrt_llm_internal_cutlass_kernels_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
7eee845e969cfb8d589074d81288b700 tensorrt_llm_internal_cutlass_kernels_static.lib
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit

View File

@ -48,7 +48,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
auto const tokenIdx = static_cast<SizeType32>(blockIdx.y);
auto const batchId = bid / BLOCKS_PER_BEAM_; // row id for logProbs
auto const batchSlot = batchSlots[batchId];
auto const batchSlot = batchSlots == nullptr ? batchId : batchSlots[batchId];
if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot])
{
return;
@ -63,7 +63,6 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
auto const logBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize;
auto logProbsSlot
= logProbsPtrs == nullptr ? logProbs + logBufIndex : logProbsPtrs[batchId * maxTokensPerStep + tokenIdx];
auto const blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam
auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index
@ -77,7 +76,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
if (finished != nullptr && finishState.isFinished())
{
if (tid < k)
if (tid < k && endIds != nullptr) // if returnAllSelectedToken, endIds would not be an input
{
auto const index = tmpTopKBufIndex + tid;
if (blockLane == 0 && tid == 0)
@ -134,7 +133,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
float const* topPs, curandState_t* curandState, TokenIdType const* endIds, SizeType32 vocabSize,
bool const* skipDecode, SizeType32 const* batchSlots, SizeType32 maxBatchSize, bool normalizeLogProbs,
bool logitHasProbs, SizeType32 const* tokensPerStep, SizeType32 maxTokensPerStep, SizeType32 maxSeqLen,
bool returnAllTopK)
bool returnAllSelectedTokens)
{
bool const IS_FP16 = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@ -142,7 +141,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto const tid = static_cast<SizeType32>(threadIdx.x);
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const tokenIdx = static_cast<SizeType32>(blockIdx.y);
auto const batchSlot = batchSlots[batchIdx];
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty();
if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding()))
{
@ -215,13 +214,16 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
if (tid == 0)
{
auto randNum = static_cast<float>(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
// if we want to return all top k indices, we should not do random sampling for probThreshold
auto randNum = (returnAllSelectedTokens || curandState == nullptr)
? static_cast<float>(probThreshold * sSum)
: static_cast<float>(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot];
for (SizeType32 ki = 0; ki < k; ki++)
{
auto expLogit = sVal2[ki];
randNum = randNum - expLogit;
if (randNum <= 0.0f || ki == k - 1 || returnAllTopK)
if (randNum <= 0.0f || ki == k - 1 || returnAllSelectedTokens)
{
auto idx = sId[ki];
// If sId is -1 here we force output token to the last from vocabulary to get vivid indicator of smth
@ -230,10 +232,10 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize
: vocabSize - 1;
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;
// cum log prob is not supported with returnAllTopK
if (!returnAllTopK)
// cum log prob is not supported with returnAllSelectedTokens
if (!returnAllSelectedTokens)
{
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
@ -255,9 +257,17 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
}
break;
}
if (returnAllSelectedTokens && randNum <= 0.0f)
{
if (ki < k - 1)
{ // not the last k, write a -1 to log the top p tokens boundary for external draft token masking
outputIdsRequestPtr[outIdx + 1] = -1;
}
break;
}
}
}
if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr
if (maxTokensPerStep == 1 && !returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr
&& endIds != nullptr)
{
auto const seqLen = sequenceLengths[batchSlot];
@ -297,7 +307,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
params.maxTopK, params.topKs, params.maxTopP, params.topPs, params.curandState, params.endIds, \
params.vocabSizePadded, params.skipDecode, params.batchSlots, params.maxBatchSize, \
params.normalizeLogProbs, params.logitsHasProbs, params.tokensPerStep, params.maxTokensPerStep, \
params.maxSeqLen, params.returnAllTopK); \
params.maxSeqLen, params.returnAllSelectedTokens); \
} \
} while (0)
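When `returnAllSelectedTokens` is set, the kernel writes the selected ids contiguously at `tokenIdx * maxTopK + ki` and, if the top-P mass is exhausted before `k` tokens, terminates the row with a `-1` sentinel (consumed later by the external-draft-token masking kernel). A hedged host-side sketch of reading one such row back (names are illustrative, not library API):

```cpp
// Sketch: recover the selected token ids for one decoding position from a row
// written with returnAllSelectedTokens == true; layout and the -1 sentinel
// follow the kernel above.
#include <vector>

std::vector<int> readSelectedTokens(int const* outputIdsRow, int maxTopK)
{
    std::vector<int> selected;
    for (int ki = 0; ki < maxTopK; ++ki)
    {
        if (outputIdsRow[ki] == -1)
        {
            break; // top-P boundary reached before k tokens were selected
        }
        selected.push_back(outputIdsRow[ki]);
    }
    return selected;
}
```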

View File

@ -34,8 +34,8 @@ struct TopKSamplingKernelParams
//! Log probabilities of each token in the vocab. If logitsHasProbs is true,
//! logProbs must contain **just** probabilities instead of log probabilities.
T const* logProbs{nullptr};
//! input buffer [batchSize][vocabSizePadded] array of pointers to logits.
//! If nullptr, logProbs is used. Only maxTokensPerStep == 1 is supported.
//! input buffer [batchSize][tokensPerStep, vocabSizePadded] array of pointers to logits.
//! If nullptr, logProbs is used.
T const* const* logProbsPtrs{nullptr};
//! output buffer [maxBatchSize][maxSeqLen], optional. Contains pointers to rows
@ -82,7 +82,8 @@ struct TopKSamplingKernelParams
//! Ignored if nullptr.
float* outputLogProbs{nullptr};
//! input buffer [maxBatchSize]. Initialized curand states
//! input buffer [maxBatchSize], optional. Initialized curand states.
//! If nullptr, no random sampling is performed (the random draw is treated as 1).
curandState_t* curandState{nullptr};
//! input buffer [maxBatchSize]. K for topK sampling per request.
//! Supported K is in range [1; 1024]. Where K=1 is greedy search.
@ -106,8 +107,8 @@ struct TopKSamplingKernelParams
bool normalizeLogProbs{false};
//! flag to highlight that logProbs contains probabilities
bool logitsHasProbs{false};
//! flag to return all selectedTopK results
bool returnAllTopK{false};
//! flag to return all selected TopK results
bool returnAllSelectedTokens{false};
void checkParams() const
{
@ -131,13 +132,12 @@ struct TopKSamplingKernelParams
}
TLLM_CHECK(workspace);
TLLM_CHECK(curandState);
TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || sequenceLengths);
TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || endIds);
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || sequenceLengths);
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || endIds);
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllTopK);
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllSelectedTokens);
}
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
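For the external-draft-token use case enabled by this change, the kernel can now run without curand states and without batch slots while returning every selected candidate instead of sampling one. A sketch of the relevant switches (namespace and header path are assumed; only fields visible in this excerpt are set, and the remaining required buffers are omitted):

```cpp
// Illustrative configuration only; not a complete, runnable kernel setup.
#include "tensorrt_llm/kernels/samplingTopKKernels.h" // path assumed

void configureForDraftTokenMasking(tensorrt_llm::kernels::TopKSamplingKernelParams<float>& params)
{
    params.returnAllSelectedTokens = true; // keep all selected candidates per position
    params.curandState = nullptr;          // now optional: no random draw is performed
    params.batchSlots = nullptr;           // now optional: the batch index is used directly
    // logProbs / logProbsPtrs, workspace, topKs / topPs, and the output buffers
    // still have to be provided as before (not shown here).
}
```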

View File

@ -200,7 +200,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
SizeType32* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize,
curandState_t* curandState, float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize,
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllTopP, SizeType32 maxSeqLen)
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllSelectedTokens, SizeType32 maxSeqLen)
{
/**
* Each block processes one request row sorted in descending order by probabilities.
@ -244,7 +244,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
if (threadIdx.x == 0)
{
// if we want to return all top p indices, we should not do random sampling for probThreshold
randNumS = returnAllTopP ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
randNumS = returnAllSelectedTokens ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
}
// if beginOffsetBuf and offsetBuf of sorting have same value,
@ -255,7 +255,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
if (tid == 0)
{
auto offset = batchId * vocabSize;
if (returnAllTopP)
if (returnAllSelectedTokens)
{
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
}
@ -294,7 +294,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
}
}
if (returnAllTopP)
if (returnAllSelectedTokens)
{
__shared__ SizeType32 sharedSelectedTokenId;
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
@ -403,7 +403,7 @@ void invokeBatchTopPSampling(TopPSamplingKernelParams<T> const& params, cudaStre
params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput,
params.cumLogProbs, params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded,
params.curandState, params.topPs, params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots,
params.returnAllTopP, params.maxSeqLen);
params.returnAllSelectedTokens, params.maxSeqLen);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);

View File

@ -80,7 +80,7 @@ struct TopPSamplingKernelParams
runtime::SizeType32 vocabSizePadded{-1};
runtime::SizeType32 maxSeqLen{-1};
bool returnAllTopP{false};
bool returnAllSelectedTokens{false};
void checkParams() const
{
@ -91,7 +91,7 @@ struct TopPSamplingKernelParams
TLLM_CHECK(probs);
TLLM_CHECK(outputIds || outputIdsPtrs);
TLLM_CHECK(workspace);
TLLM_CHECK((sequenceLength != nullptr) || returnAllTopP);
TLLM_CHECK((sequenceLength != nullptr) || returnAllSelectedTokens);
TLLM_CHECK(curandState);
TLLM_CHECK(topPs);

View File

@ -0,0 +1,136 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
#ifndef CUDART_VERSION
#error CUDART_VERSION Undefined!
#elif (CUDART_VERSION >= 11050)
#include <cub/cub.cuh>
#else
#include "3rdparty/cub/cub.cuh"
#endif
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::kernels::speculative_decoding
{
namespace
{
template <typename T, int BLOCK_SIZE>
__global__ void assembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
SizeType32 vocabSizePadded)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage tempStorage;
auto const tix = static_cast<SizeType32>(threadIdx.x);
SizeType32 numDecodingTokens{0};
if (tix < batchSize)
{
numDecodingTokens = draftDecodingTokens[tix] + 1;
decodingTokens[tix] = numDecodingTokens;
}
SizeType32 logitsOffset{0};
BlockScan(tempStorage).ExclusiveSum(numDecodingTokens, logitsOffset);
if (tix < batchSize)
{
for (SizeType32 ti = 0; ti < numDecodingTokens; ++ti)
{
logitsPtrs[tix * maxDecodingTokens + ti] = logits + (logitsOffset + ti) * vocabSizePadded;
}
}
}
} // namespace
template <typename T>
void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
SizeType32 vocabSizePadded, cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
TLLM_CHECK_WITH_INFO(
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
assembleTargetLogitsOffsets<T, BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(
logitsPtrs, decodingTokens, logits, draftDecodingTokens, batchSize, maxDecodingTokens, vocabSizePadded);
sync_check_cuda_error();
}
template void invokeAssembleTargetLogitsOffsets(float const** logitsPtrs, SizeType32* decodingTokens,
float const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
SizeType32 vocabSizePadded, cudaStream_t stream);
template void invokeAssembleTargetLogitsOffsets(__half const** logitsPtrs, SizeType32* decodingTokens,
__half const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
SizeType32 vocabSizePadded, cudaStream_t stream);
namespace
{
template <int BLOCK_SIZE>
__global__ void selectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen)
{
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
__shared__ typename BlockScan::TempStorage tempStorage;
auto const tix = static_cast<SizeType32>(threadIdx.x);
SizeType32 decodingTokens{0};
SizeType32 lastTokenId{0};
if (tix < batchSize)
{
auto const acceptedLen = acceptedLengths[tix];
lastAcceptedTokenIds[tix] = acceptedTokenIds[tix * maxPathLen + acceptedLen - 1];
auto const bestPathId = bestPathIds[tix];
auto const pathIdx = flat_index3(tix, bestPathId, acceptedLen - 1, maxDecodingTokens, maxPathLen);
lastTokenId = paths[pathIdx];
decodingTokens = draftDecodingTokens[tix] + 1;
}
BlockScan(tempStorage).ExclusiveSum(decodingTokens, decodingTokens);
if (tix < batchSize)
{
exclusiveSumLastAcceptedIndices[tix] = decodingTokens + lastTokenId;
}
}
} // namespace
void invokeSelectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen,
cudaStream_t stream)
{
SizeType32 constexpr BLOCK_SIZE = 512;
TLLM_CHECK_WITH_INFO(
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
selectLastAccTokenAndComputeIndicesCumSum<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(lastAcceptedTokenIds,
exclusiveSumLastAcceptedIndices, draftDecodingTokens, acceptedTokenIds, acceptedLengths, bestPathIds, paths,
batchSize, maxDecodingTokens, maxPathLen);
}
} // namespace tensorrt_llm::kernels::speculative_decoding
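The first kernel packs, per request, `draftDecodingTokens[i] + 1` consecutive logits pointers, offset by an exclusive prefix sum over those counts. A CPU reference of the same index arithmetic (a sketch for clarity; the library only ships the CUDA version above):

```cpp
// CPU sketch of what assembleTargetLogitsOffsets computes.
// logits is a flattened [numTokens, vocabSizePadded] buffer.
#include <vector>

template <typename T>
void assembleTargetLogitsOffsetsRef(std::vector<T const*>& logitsPtrs, std::vector<int>& decodingTokens,
    T const* logits, std::vector<int> const& draftDecodingTokens, int maxDecodingTokens, int vocabSizePadded)
{
    int offset = 0; // running exclusive prefix sum of per-request token counts
    for (size_t bi = 0; bi < draftDecodingTokens.size(); ++bi)
    {
        int const numTokens = draftDecodingTokens[bi] + 1;
        decodingTokens[bi] = numTokens;
        for (int ti = 0; ti < numTokens; ++ti)
        {
            logitsPtrs[bi * maxDecodingTokens + ti] = logits + static_cast<size_t>(offset + ti) * vocabSizePadded;
        }
        offset += numTokens;
    }
}
```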

View File

@ -0,0 +1,68 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
#include "tensorrt_llm/runtime/common.h"
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <curand_kernel.h>
namespace tensorrt_llm::kernels::speculative_decoding
{
//! \brief Sets pointers to logits in logitsPtrs according to the draftDecodingTokens.
//! \param logitsPtrs [batchSize][vocabSizePadded]
//! \param decodingTokens [batchSize], on GPU. draftDecodingTokens + 1.
//! \param logits [numTokens, vocabSizePadded], on GPU. Continuous logits in memory.
//! \param draftDecodingTokens [batchSize], on GPU. 0 for context requests, and actual draft len for gen requests
//! \param batchSize batch size. Only batch size <= 512 is supported at the moment
//! \param maxDecodingTokens maximum number of decoding tokens per step per request
//! \param vocabSizePadded vocab size of the logits
//! \param stream cuda stream
template <typename T>
void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32* decodingTokens, T const* logits,
runtime::SizeType32 const* draftDecodingTokens, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
//! \brief Sets last accepted token ids and computes exclusive sum of the indices of the last accepted tokens in
//! flattened input_ids tensor.
//! \param lastAcceptedTokenIds [batchSize], on GPU. Token ids of the last accepted tokens.
//! \param exclusiveSumLastAcceptedIndices [batchSize], on GPU. Exclusive sum of the positions of the last accepted
//! tokens in the original flattened draft sequence.
//! \param draftDecodingTokens [batchSize], on GPU. 0 for context
//! requests, and actual draft len for gen requests.
//! \param acceptedTokenIds [batchSize, maxPathLen], on GPU. Ids of the
//! accepted tokens per request.
//! \param acceptedLengths [batchSize], on GPU. Lengths of the accepted draft sequences
//! per request.
//! \param bestPathIds [batchSize], on GPU. Selected path id per request
//! \param paths [batchSize,
//! maxDecodingTokens, maxPathLen], on GPU. Indices of the draft sequences
//! \param batchSize batch size. Only batch size
//! <= 512 is supported at the moment
//! \param maxDecodingTokens maximum number of decoding tokens per step per request
//! \param maxPathLen maximum path len of the draft sequence
//! \param stream cuda stream
void invokeSelectLastAccTokenAndComputeIndicesCumSum(runtime::TokenIdType* lastAcceptedTokenIds,
runtime::SizeType32* exclusiveSumLastAcceptedIndices, runtime::SizeType32 const* draftDecodingTokens,
runtime::TokenIdType const* acceptedTokenIds, runtime::SizeType32 const* acceptedLengths,
runtime::SizeType32 const* bestPathIds, runtime::SizeType32 const* paths, runtime::SizeType32 batchSize,
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen, cudaStream_t stream);
} // namespace tensorrt_llm::kernels::speculative_decoding
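To make the second declaration concrete, here is a CPU reference of its index arithmetic: the position of the last accepted token is the exclusive prefix sum of per-request token counts plus the within-request index read from the accepted path (a sketch for clarity, not part of the library):

```cpp
// CPU sketch of invokeSelectLastAccTokenAndComputeIndicesCumSum.
#include <vector>

void selectLastAcceptedRef(std::vector<int>& lastAcceptedTokenIds, std::vector<int>& exclusiveSumLastAcceptedIndices,
    std::vector<int> const& draftDecodingTokens, std::vector<int> const& acceptedTokenIds,
    std::vector<int> const& acceptedLengths, std::vector<int> const& bestPathIds, std::vector<int> const& paths,
    int maxDecodingTokens, int maxPathLen)
{
    int offset = 0; // exclusive prefix sum over (draftDecodingTokens[bi] + 1)
    for (size_t bi = 0; bi < draftDecodingTokens.size(); ++bi)
    {
        int const acceptedLen = acceptedLengths[bi];
        // Token id of the last accepted token for this request.
        lastAcceptedTokenIds[bi] = acceptedTokenIds[bi * maxPathLen + acceptedLen - 1];
        // Within-request index of that token, taken from the best accepted path.
        int const pathIdx = (static_cast<int>(bi) * maxDecodingTokens + bestPathIds[bi]) * maxPathLen + acceptedLen - 1;
        exclusiveSumLastAcceptedIndices[bi] = offset + paths[pathIdx];
        offset += draftDecodingTokens[bi] + 1;
    }
}
```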

View File

@ -60,7 +60,7 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
auto* outputIdsAfterSamplingPtr = outputIdsAfterSampling + batchSlot * vocabSize;
auto const useDraftLogits = batchUseDraftLogits[batchSlot];
if (finishedState.isSkipDecoding())
if (finishedState.isSkipDecoding() || finishedState.isFinished())
{
return;
}
@ -75,8 +75,8 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
{
if (tokensToMask == 0 && outputIdsAfterSamplingPtr[vIdx] == -1)
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0
if (outputIdsAfterSamplingPtr[vIdx] == -1)
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0 or number of topP indices < topK
tokensToMask = vIdx;
}
maskBuffer[vIdx] = false;
@ -124,12 +124,21 @@ __global__ void acceptDraftTokensKernel(T const* draftProbs, T* targetProbs, Siz
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
auto const useDraftLogits = batchUseDraftLogits[batchSlotBeamWidth];
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding())
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding()
|| finishedInput[batchSlot].isFinished())
{
if (tid == 0)
{
batchIsAccepted[batchSlot] = true;
// either finished or skipped decoding in the previous step, so this step doesn't need decoding
finishedOutput[batchSlot].setSkipDecoding();
// if previous step is finished, write the state to next step too
if (finishedInput[batchSlot].isFinished())
{
finishedOutput[batchSlot] = finishedInput[batchSlot];
}
}
return;
}
@ -214,7 +223,8 @@ __global__ void forwardAcceptedTokensKernel(SizeType32 batchSize, SizeType32 con
for (SizeType32 bi = index; bi < batchSize; bi += static_cast<SizeType32>(gridDim.x * blockDim.x))
{
auto const batchSlot = batchSlots[bi];
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding())
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding()
&& !finishedOutput[batchSlot].isFinished())
{
auto const curSeqLen = sequenceLengths[batchSlot];
auto const draftTokenIdx = step;

View File

@ -46,22 +46,22 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
T const** medusaLogits, T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep,
SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen,
SizeType32 maxNumHeads, SizeType32 maxDecodingTokens)
SizeType32 maxDraftPathLen, SizeType32 maxDecodingTokens)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
auto const batchSlot = batchSlots[batchIdx];
auto const inputLength = sequenceLengths[batchSlot];
auto const endId = endIds[batchSlot];
auto const numTokensPerStep = curTokensPerStep[batchSlot];
auto const maxNumDraftTokens = maxNumHeads + 1;
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const inputLength = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const endId = endIds == nullptr ? -1 : endIds[batchSlot];
auto const numTokensPerStep = curTokensPerStep == nullptr ? maxDecodingTokens : curTokensPerStep[batchSlot];
auto const maxPathLen = maxDraftPathLen + 1;
int4 partialMax{-1, -1, 0, 0};
// Go over different paths and construct implicit sequences
for (auto pathIdx = static_cast<SizeType32>(threadIdx.x); pathIdx < maxDecodingTokens;
pathIdx += static_cast<SizeType32>(blockDim.x))
{
auto acceptedLength = maxNumDraftTokens;
auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
auto acceptedLength = maxPathLen;
auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxPathLen);
bool hasEnd = false;
auto const tokenId = paths[pathOffset];
@ -75,13 +75,14 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto nextIdx = tokenId;
// Go along the path
for (SizeType32 ti = 1; ti < maxNumDraftTokens; ++ti)
for (SizeType32 ti = 1; ti < maxPathLen; ++ti)
{
auto const tokenId = paths[pathOffset + ti];
// Break if path terminates
if (tokenId == -1)
{
hasEnd = targetToken == endId; // check if last token is EOS when path terminates.
hasEnd = endIds == nullptr ? false
: targetToken == endId; // check if last token is EOS when path terminates.
acceptedLength = hasEnd ? ti - 1 : ti;
break;
}
@ -91,7 +92,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
// Check if draft tokens are the same as target tokens
bool const accepted = draftToken == targetToken;
hasEnd = targetToken == endId;
hasEnd = endIds == nullptr ? false : targetToken == endId;
if (!accepted || hasEnd)
{
acceptedLength = hasEnd ? ti - 1 : ti;
@ -126,7 +127,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
auto const acceptedLength = totalShared.x;
auto const bestPathIdx = totalShared.y;
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxPathLen);
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < acceptedLength; ti += static_cast<SizeType32>(blockDim.x))
{
auto const tokenId = paths[pathOffset + ti];
@ -142,15 +143,18 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
{
auto const hasEnd = totalShared.z;
// Set end condition
if (hasEnd)
if (hasEnd && finishedFinal)
{
finishedFinal[batchSlot].setFinishedEOS();
}
// Make correction to the sequence length
sequenceLengths[batchSlot] += acceptedLength;
if (sequenceLengths)
{
sequenceLengths[batchSlot] += acceptedLength;
}
acceptedLengths[batchSlot] = acceptedLength;
// In Medusa decoding step, number of draft tokens is 0 and must be updated for the next steps
if (numTokensPerStep == 1)
if (curTokensPerStep && targetTokensPerStep && numTokensPerStep == 1)
{
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
}
@ -158,45 +162,33 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
}
// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
for (auto hi = static_cast<SizeType32>(threadIdx.x); hi < maxNumHeads; hi += static_cast<SizeType32>(blockDim.x))
if (medusaLogits && logitsPtrs)
{
logitsPtrs[batchIdx * maxNumHeads + hi]
= medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
for (auto hi = static_cast<SizeType32>(threadIdx.x); hi < maxDraftPathLen;
hi += static_cast<SizeType32>(blockDim.x))
{
logitsPtrs[batchIdx * maxDraftPathLen + hi]
= medusaLogits[batchSlot * maxDraftPathLen + hi] + flat_index2(bestNextIdx, 0, vocabSize);
}
}
}
} // namespace
template <typename T>
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32* sequenceLengths, SizeType32* acceptedLengths, FinishedState* finishedFinal,
SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds, T const** medusaLogits,
T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds,
SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads,
SizeType32 maxDecodingTokens, cudaStream_t stream)
void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<T> const& params)
{
constexpr SizeType32 BLOCK_SIZE = 256;
dim3 block(BLOCK_SIZE);
dim3 grid(batchSize);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, draftIds, targetIds,
sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxSeqLen, maxNumHeads,
maxDecodingTokens);
dim3 grid(params.batchSize);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, params.stream>>>(params.outputIds, params.draftIds,
params.targetIds, params.sequenceLengths, params.acceptedLengths, params.finishedFinal, params.batchSlots,
params.paths, params.endIds, params.medusaLogits, params.logitsPtrs, params.curTokensPerStep,
params.targetTokensPerStep, params.bestPathIds, params.batchSize, params.vocabSize, params.maxBatchSize,
params.maxSeqLen, params.maxDraftPathLen, params.maxDecodingTokens);
}
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
float const** medusaLogits, float const** logitsPtrs, SizeType32* curTokensPerStep,
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
cudaStream_t stream);
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
half const** medusaLogits, half const** logitsPtrs, SizeType32* curTokensPerStep,
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
cudaStream_t stream);
template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<float> const& params);
template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<__half> const& params);
namespace
{

View File

@ -26,46 +26,87 @@
namespace tensorrt_llm::kernels::speculative_decoding
{
template <typename T>
struct AcceptDraftTokensByIdsWithPathsParams
{
//! output buffer [maxBatchSize, maxSeqLen], input tokens.
runtime::TokenIdType* outputIds{nullptr};
//! input buffer [maxBatchSize, maxDecodingTokens], draft tokens
runtime::TokenIdType const* draftIds{nullptr};
//! input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
runtime::TokenIdType const* targetIds{nullptr};
//! input/output buffer [maxBatchSize], optional.
//! Length of the data in outputIds without draft tokens.
//! If set, incremented according to the accepted length.
runtime::SizeType32* sequenceLengths{nullptr};
//! output buffer [maxBatchSize], length of the data accepted tokens
runtime::SizeType32* acceptedLengths{nullptr};
//! input buffer [maxBatchSize], optional. Finished states per request
FinishedState* finishedFinal{nullptr};
//! input buffer [batchSize], optional. Address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize].
//! If nullptr, batchIdx is used.
runtime::SizeType32 const* batchSlots{nullptr};
//! input buffer [maxBatchSize, maxDecodingTokens, maxDraftPathLen+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not
//! path.
runtime::SizeType32 const* paths{nullptr};
//! input buffer [maxBatchSize], optional. EOS ids per request.
//! No EOS checks if nullptr.
runtime::TokenIdType const* endIds{nullptr};
//! input buffer [maxDraftPathLen, maxBatchSize, maxDecodingTokens, vocabSize], optional.
//! Pointer to the logits from medusa heads.
T const** medusaLogits{nullptr};
//! output buffer [batchSize, maxDraftPathLen], optional. Contains pointers to the
//! respective rows of the medusaLogits for the next after the accepted token
T const** logitsPtrs{nullptr};
//! current tokens to compute per step will be updated to
//! targetTokensPerStep if curTokensPerStep == 1
runtime::SizeType32* curTokensPerStep{nullptr};
//! target values of tokens to compute per step
runtime::SizeType32 const* targetTokensPerStep{nullptr};
//! output buffer [maxBatchSize], indices of the selected paths
runtime::SizeType32* bestPathIds{nullptr};
//! current batch size
runtime::SizeType32 batchSize{0};
//! maximum batch size
runtime::SizeType32 maxBatchSize{0};
//! vocab size
runtime::SizeType32 vocabSize{0};
//! maximum sequence length of output ids
runtime::SizeType32 maxSeqLen{0};
//! maximum number of medusa heads
runtime::SizeType32 maxDraftPathLen{0};
//! maximum number of tokens per step configured in the system
runtime::SizeType32 maxDecodingTokens{0};
//! stream
cudaStream_t stream;
void checkParams() const
{
TLLM_CHECK(outputIds);
TLLM_CHECK(draftIds);
TLLM_CHECK(targetIds);
TLLM_CHECK(acceptedLengths);
TLLM_CHECK(paths);
TLLM_CHECK(bestPathIds);
TLLM_CHECK((curTokensPerStep == nullptr) ^ (targetTokensPerStep == nullptr) == 0);
TLLM_CHECK((medusaLogits == nullptr) ^ (logitsPtrs == nullptr) == 0);
TLLM_CHECK(batchSize > 0);
TLLM_CHECK(batchSize <= maxBatchSize);
TLLM_CHECK(vocabSize > 0);
TLLM_CHECK(maxSeqLen > 0);
TLLM_CHECK(maxDraftPathLen > 0);
TLLM_CHECK(maxDecodingTokens > 0);
}
};
//! \brief verifies draft medusa tokens given target tokens. Modifies outputIds tensor accordingly filling it with
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
//! to the next after the last accepted token.
//!
//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
//! \param draftIds input buffer [maxBatchSize, maxDecodingTokens], draft tokens
//! \param targetIds input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
//! Incrememnted according to the accepted length
//! \param acceptedLengths output buffer [maxBatchSize], length of the data accepted tokens
//! \param finishedFinal input buffer [maxBatchSize], finished states per request
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param paths input buffer [maxBatchSize, maxDecodingTokens, maxNumHeads+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
//! \param endIds input buffer [maxBatchSize], EOS ids per request
//! \param medusaLogits input buffer [maxNumHeads, maxBatchSize, maxDecodingTokens, vocabSize], pointer
//! to the logits from medusa heads
//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
//! respective rows of the medusaLogits for the next after the accepted token
//! \param curTokensPerStep current tokens to compute per step will be updated to
//! targetTokensPerStep if curTokensPerStep == 1
//! \param targetTokensPerStep target values of tokens to compute per step
//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param vocabSize vocab size
//! \param maxSeqLen maximum sequence length of output ids
//! \param maxNumHeads maximum number of medusa heads
//! \param maxDecodingTokens maximum number of tokens per step configured in the system
//! \param stream stream
template <typename T>
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
runtime::TokenIdType const* targetIds, runtime::SizeType32* sequenceLengths, runtime::SizeType32* acceptedLengths,
FinishedState* finishedFinal, runtime::SizeType32 const* batchSlots, runtime::SizeType32 const* paths,
runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
runtime::SizeType32* curTokensPerStep, runtime::SizeType32 const* targetTokensPerStep,
runtime::SizeType32* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
runtime::SizeType32 vocabSize, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxNumHeads,
runtime::SizeType32 maxDecodingTokens, cudaStream_t stream);
void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<T> const&);
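// Illustrative reference only (not part of this change): a host-side sketch of the path-based acceptance rule
// described above, for a single request. It assumes a row-major paths layout of [maxDecodingTokens, maxPathLen]
// with -1 marking unused entries, that entry i of a path is the flattened tree position of the i-th token on that
// path, and that targetIds[p] holds the target model's prediction for the token following position p. The names
// PathAcceptanceResult and acceptPathsReference are hypothetical helpers, not part of the kernel API.
struct PathAcceptanceResult
{
    runtime::SizeType32 bestPathId;
    runtime::SizeType32 acceptedLen;
};

inline PathAcceptanceResult acceptPathsReference(runtime::TokenIdType const* draftIds,
    runtime::TokenIdType const* targetIds, runtime::SizeType32 const* paths, runtime::SizeType32 maxDecodingTokens,
    runtime::SizeType32 maxPathLen)
{
    PathAcceptanceResult best{0, 0};
    for (runtime::SizeType32 pathId = 0; pathId < maxDecodingTokens; ++pathId)
    {
        runtime::SizeType32 acceptedLen{0};
        for (runtime::SizeType32 ti = 1; ti < maxPathLen; ++ti)
        {
            auto const node = paths[pathId * maxPathLen + ti];
            auto const parent = paths[pathId * maxPathLen + ti - 1];
            // A draft token stays accepted while it equals the target model's prediction for its parent node.
            if (node < 0 || parent < 0 || draftIds[node] != targetIds[parent])
            {
                break;
            }
            ++acceptedLen;
        }
        if (acceptedLen > best.acceptedLen)
        {
            best = {pathId, acceptedLen};
        }
    }
    // The real kernel additionally appends one token sampled from the target head after the accepted prefix and
    // selects the Medusa logits pointers for the chosen path.
    return best;
}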
//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
//!

View File

@ -507,15 +507,12 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
VecType k_to_cache = params.position_shift_enabled ? k_wo_pos : k;
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * params.q_hidden_size + hidden_idx;
QuantizedEltType* quantized_q_ptr = STORE_QKV
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
VecType* q_ptr = STORE_QKV ? reinterpret_ptr<T, VecType>(params.QKV, src_q_idx)
: reinterpret_ptr<T, VecType>(params.Q, dst_q_idx);
// Cast float scale to dst data type.
using TScale = typename mmha::kv_cache_scale_type_t<T, TCache>::Type;
TScale scaleOrigQuant;
[[maybe_unused]] TScale scaleOrigQuant;
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
{
mmha::convert_from_float(
@ -525,6 +522,9 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
if constexpr (FP8_OUTPUT)
{
// Quant the vec to fp8 vec with the scale.
QuantizedEltType* quantized_q_ptr = STORE_QKV
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
}
else
@ -813,15 +813,12 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
if (valid_token)
{
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * params.q_hidden_size + hidden_idx;
QuantizedEltType* quantized_q_ptr = STORE_QKV
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
VecT* q_ptr = STORE_QKV ? reinterpret_ptr<T, VecT>(params.QKV, src_q_idx)
: reinterpret_ptr<T, VecT>(params.Q, dst_q_idx);
// Cast float scale to dst data type.
using TScale = typename mmha::kv_cache_scale_type_t<T, TCache>::Type;
TScale scaleOrigQuant;
[[maybe_unused]] TScale scaleOrigQuant;
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
{
mmha::convert_from_float(&scaleOrigQuant, params.kvScaleOrigQuant ? params.kvScaleOrigQuant[0] : 1.0f);
@ -830,6 +827,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
if constexpr (FP8_OUTPUT)
{
// Quant the vec to fp8 vec with the scale.
QuantizedEltType* quantized_q_ptr = STORE_QKV
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
}
else

View File

@ -32,6 +32,8 @@ namespace weight_only
{
enum class KernelType
{
FP16Int8Groupwise,
BF16Int8Groupwise,
FP16Int4Groupwise,
BF16Int4Groupwise,
FP16Int8PerChannel,
@ -49,6 +51,8 @@ struct kernel_type_traits;
static constexpr bool isGroupwise = _isGroupwise; \
static constexpr bool isInt4 = _isInt4; \
};
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true);
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false);
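// For orientation (hedged reading, not part of this change): judging from the macro tail shown above, each registry
// line presumably expands to a kernel_type_traits specialization along these lines, so the new groupwise INT8
// entries advertise isGroupwise = true and isInt4 = false:
//
//   template <>
//   struct kernel_type_traits<KernelType::FP16Int8Groupwise>
//   {
//       static constexpr bool isGroupwise = true;
//       static constexpr bool isInt4 = false;
//   };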

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,29 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -61,6 +61,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
{
EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
}
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
@ -70,6 +72,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
}
else if (arch >= 90)
{
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
@ -98,6 +102,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
}
else if (arch >= 80 && arch < 90)
{
SUPPORT(KernelType::FP16Int8Groupwise);
SUPPORT(KernelType::BF16Int8Groupwise);
SUPPORT(KernelType::FP16Int4Groupwise);
SUPPORT(KernelType::BF16Int4Groupwise);
SUPPORT(KernelType::FP16Int8PerChannel);
@ -107,6 +113,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
}
else if (arch >= 90)
{
SUPPORT(KernelType::FP16Int8Groupwise);
SUPPORT(KernelType::BF16Int8Groupwise);
SUPPORT(KernelType::FP16Int4Groupwise);
SUPPORT(KernelType::BF16Int4Groupwise);
SUPPORT(KernelType::FP16Int8PerChannel);
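// Hedged caller-side sketch (not part of this change): the new groupwise INT8 kernel types are reached through the
// same query-then-launch pattern as the existing ones; smVersion, params and stream are illustrative placeholders.
//
//   namespace wo = tensorrt_llm::kernels::weight_only;
//   if (wo::is_supported(smVersion, wo::KernelType::BF16Int8Groupwise))
//   {
//       wo::kernel_launcher(smVersion, params, stream);
//   }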

View File

@ -431,7 +431,7 @@ void ExternalDraftTokensLayer<T>::getAllTopKs(std::shared_ptr<BaseDecodingOutput
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.maxTokensPerStep = 1;
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.returnAllTopK = true;
params.returnAllSelectedTokens = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded(); // workaround for returning all topKs with outputIds
params.logitsHasProbs = inputs->probsComputed;
@ -475,7 +475,7 @@ void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutput
params.batchSize = batchSize;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.returnAllTopP = true;
params.returnAllSelectedTokens = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();
invokeBatchTopPSampling<T>(params, getStream());

View File

@ -76,6 +76,8 @@ LookaheadDecodingLayer<T>::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD
ITensor::makeShape({maxTokensPerStep, maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32);
mPathsOffsets
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
mPathsOffsetsBatch
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32);
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
@ -220,7 +222,7 @@ void LookaheadDecodingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.batchSize = batchSize;
params.maxTopK = 1;
params.returnAllTopK = true;
params.returnAllSelectedTokens = true;
params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
@ -321,6 +323,7 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
BufferLocation<SizeType32> pathsOffsetLocation(*mCpuAlgo->mPathsOffsets);
BufferLocation<SizeType32> pathsOffsetBatchLocation(*mCpuAlgo->mPathsOffsetsBatch);
BufferLocation<TokenIdType> outputIdsLocation(*mCpuAlgo->mOutputIds);
mBufferManager->setZero(*mCpuAlgo->mPathsOffsets);
@ -394,20 +397,22 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
D(accepted).values().c_str(), D(draft).values().c_str());
}
numNewTokensCumSumRange[0] = 0;
SizeType32 pi = 0;
for (SizeType32 bi = 0; bi < numNewTokensRange.size(); bi++)
numNewTokensCumSumRange[0] = 0;
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
SizeType32 acceptedDraftLen = numNewTokensRange[bi] <= 1 ? 0 : (numNewTokensRange[bi] - 1);
SizeType32 gbi = batchSlotsRange[bi];
SizeType32 acceptedDraftLen = numNewTokensRange[gbi] <= 1 ? 0 : (numNewTokensRange[gbi] - 1);
numNewTokensCumSumRange[bi + 1] = numNewTokensCumSumRange[bi] + acceptedDraftLen;
for (SizeType32 tj = 0; tj < acceptedDraftLen; tj++)
{
pathsOffsetLocation[pi++] = pathsOffsetLocation.at(bi, tj);
pathsOffsetBatchLocation[pi++] = pathsOffsetLocation.at(gbi, tj);
}
}
for (; pi < pathsOffsetLocation.size(); pi++)
for (; pi < pathsOffsetBatchLocation.size(); pi++)
{
pathsOffsetLocation[pi++] = 0;
pathsOffsetBatchLocation[pi++] = 0;
}
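// Worked example (illustrative): with batchSize = 3, batchSlots = [0, 1, 2] and numNewTokens = [3, 1, 2], the
// accepted draft lengths are [2, 0, 1], numNewTokensCumSum becomes [0, 2, 2, 3], and pathsOffsetsBatch packs the
// two offsets of request 0 followed by the single offset of request 2, with the remainder zero-filled by the loop above.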
TLLM_CHECK(outputs->numNewTokens);
@ -415,8 +420,8 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
mBufferManager->copy(*mCpuAlgo->mSequenceLengths, *outputs->sequenceLength.value());
mBufferManager->copy(*mCpuAlgo->mNewTokens, *outputs->newTokens);
mBufferManager->copy(*mCpuAlgo->mPathsOffsets, *outputs->pathsOffsets);
mBufferManager->copy(*mCpuAlgo->mNumNewTokens, *outputs->numNewTokens.value());
mBufferManager->copy(*mCpuAlgo->mPathsOffsetsBatch, *outputs->pathsOffsets);
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); //
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens);

View File

@ -70,6 +70,7 @@ private:
TensorPtr mOutputIds;
TensorPtr mPathsOffsets;
TensorPtr mPathsOffsetsBatch;
TensorPtr mNumNewTokens;
TensorPtr mNumNewTokensCumSum;
TensorPtr mNewTokens;

View File

@ -329,11 +329,33 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(SpeculativeDecodingOutputs const&
auto medusaInputLogitsPtrsPtr = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
auto medusaSelectedLogitsPtrsDevicePtr
= const_cast<T const**>(bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice));
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, targetTokensDevicePtr, sequenceLengths, numNewTokens,
finishedStatesPtr, workspace->getDeviceBatchSlotsPtr(), paths, endIds, medusaInputLogitsPtrsPtr,
medusaSelectedLogitsPtrsDevicePtr, curTokensPerStepDevice, targetTokensPerStepDevice, bestPathIdsDevicePtr,
batchSize, mDecoderDomain.getVocabSize(), mDecoderDomain.getBatchSize(), maxSeqLen, maxDraftPathLen,
mDecoderDomain.getMaxDecodingTokens(), getStream());
AcceptDraftTokensByIdsWithPathsParams<T> params;
params.outputIds = outputIds;
params.draftIds = draftIds;
params.targetIds = targetTokensDevicePtr;
params.sequenceLengths = sequenceLengths;
params.acceptedLengths = numNewTokens;
params.finishedFinal = finishedStatesPtr;
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
params.paths = paths;
params.endIds = endIds;
params.medusaLogits = medusaInputLogitsPtrsPtr;
params.logitsPtrs = medusaSelectedLogitsPtrsDevicePtr;
params.curTokensPerStep = curTokensPerStepDevice;
params.targetTokensPerStep = targetTokensPerStepDevice;
params.bestPathIds = bestPathIdsDevicePtr;
params.batchSize = batchSize;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSize = mDecoderDomain.getVocabSize();
params.maxSeqLen = maxSeqLen;
params.maxDraftPathLen = maxDraftPathLen;
params.maxDecodingTokens = mDecoderDomain.getMaxDecodingTokens();
params.stream = getStream();
params.checkParams();
acceptDraftTokensByIdsWithPaths(params);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -390,7 +412,7 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(SpeculativeDecodingOutputs con
params.maxBatchSize = maxBatchSizeHeadNums;
params.maxTokensPerStep = 1;
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.returnAllTopK = true;
params.returnAllSelectedTokens = true;
invokeBatchTopKSampling(params, getStream());

View File

@ -54,7 +54,8 @@ set(PLUGIN_LISTS
mambaConv1dPlugin
lruPlugin
cumsumLastDimPlugin
lowLatencyGemmPlugin)
lowLatencyGemmPlugin
eaglePlugin)
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})

View File

@ -39,6 +39,9 @@
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
#endif // ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h"
#include "tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h"
#include "tensorrt_llm/plugins/lowLatencyGemmPlugin/lowLatencyGemmPlugin.h"
#include "tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h"
#include "tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h"
@ -201,6 +204,10 @@ extern "C"
static tensorrt_llm::plugins::lruPluginCreator lruPluginCreator;
static tensorrt_llm::plugins::CumsumLastDimPluginCreator cumsumLastDimPluginCreator;
static tensorrt_llm::plugins::LowLatencyGemmPluginCreator lowLatencyGemmPluginCreator;
static tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator eagleDecodeDraftTokensPluginCreator;
static tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator
eagleSampleAndAcceptDraftTokensPluginCreator;
static tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator eaglePrepareDrafterInputsPluginCreator;
static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@ -231,6 +238,9 @@ extern "C"
creatorPtr(lruPluginCreator),
creatorPtr(cumsumLastDimPluginCreator),
creatorPtr(lowLatencyGemmPluginCreator),
creatorPtr(eagleDecodeDraftTokensPluginCreator),
creatorPtr(eagleSampleAndAcceptDraftTokensPluginCreator),
creatorPtr(eaglePrepareDrafterInputsPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();

View File

@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)

View File

@ -0,0 +1,228 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "eagleDecodeDraftTokensPlugin.h"
using namespace nvinfer1;
using tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator;
using tensorrt_llm::plugins::EagleDecodeDraftTokensPlugin;
static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME{"EagleDecodeDraftTokens"};
PluginFieldCollection EagleDecodeDraftTokensPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EagleDecodeDraftTokensPluginCreator::mPluginAttributes;
EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx)
: mDtype(type)
, mLayerIdx(layerIdx)
{
}
// Parameterized constructor
EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mDtype);
read(d, mLayerIdx);
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT-LLM version to build "
"engine and run engine.",
static_cast<int>(length), static_cast<int>(d - a));
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EagleDecodeDraftTokensPlugin::clone() const noexcept
{
auto* plugin = new EagleDecodeDraftTokensPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(outputIndex < 2);
TLLM_CHECK(nbInputs == 5);
return inputs[outputIndex + 1];
}
bool EagleDecodeDraftTokensPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == 0) // logits
{
return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (pos == 3) // rand_data_sample
{
return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else // next_draft_tokens, next_draft_lens, paths, tree_indices
{
return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
}
}
void EagleDecodeDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
size_t EagleDecodeDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
}
int EagleDecodeDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
// TODO fill me
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType EagleDecodeDraftTokensPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index < 2);
return inputTypes[index + 1];
}
// IPluginV2 Methods
char const* EagleDecodeDraftTokensPlugin::getPluginType() const noexcept
{
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
}
char const* EagleDecodeDraftTokensPlugin::getPluginVersion() const noexcept
{
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
}
int EagleDecodeDraftTokensPlugin::getNbOutputs() const noexcept
{
return 2;
}
int EagleDecodeDraftTokensPlugin::initialize() noexcept
{
return 0;
}
void EagleDecodeDraftTokensPlugin::terminate() noexcept {}
size_t EagleDecodeDraftTokensPlugin::getSerializationSize() const noexcept
{
return sizeof(mDtype) + sizeof(mLayerIdx);
}
void EagleDecodeDraftTokensPlugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
// Keep the write order in sync with the read order of the deserializing constructor: mDtype, then mLayerIdx.
write(d, mDtype);
write(d, mLayerIdx);
assert(d == a + getSerializationSize());
}
void EagleDecodeDraftTokensPlugin::destroy() noexcept
{
// This gets called when the network containing plugin is destroyed
delete this;
}
///////////////
EagleDecodeDraftTokensPluginCreator::EagleDecodeDraftTokensPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* EagleDecodeDraftTokensPluginCreator::getPluginName() const noexcept
{
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
}
char const* EagleDecodeDraftTokensPluginCreator::getPluginVersion() const noexcept
{
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
}
PluginFieldCollection const* EagleDecodeDraftTokensPluginCreator::getFieldNames() noexcept
{
return &mFC;
}
IPluginV2* EagleDecodeDraftTokensPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int32_t layerIdx;
nvinfer1::DataType type;
// Read the configuration from each field
for (int i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
if (!strcmp(attrName, "layer_idx"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
layerIdx = *static_cast<int32_t const*>(fields[i].data);
}
else if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
}
}
try
{
auto* obj = new EagleDecodeDraftTokensPlugin(type, layerIdx);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
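// Hedged usage sketch (not part of this change): a network-building caller would pass the two fields read above to
// createPlugin roughly as follows; the field values and the plugin layer name are illustrative, and creator stands
// for an EagleDecodeDraftTokensPluginCreator instance.
//
//   int32_t typeId = static_cast<int32_t>(nvinfer1::DataType::kHALF);
//   int32_t layerIdx = 0;
//   std::vector<nvinfer1::PluginField> fields;
//   fields.emplace_back("type_id", &typeId, nvinfer1::PluginFieldType::kINT32, 1);
//   fields.emplace_back("layer_idx", &layerIdx, nvinfer1::PluginFieldType::kINT32, 1);
//   nvinfer1::PluginFieldCollection fc{static_cast<int32_t>(fields.size()), fields.data()};
//   auto* plugin = creator.createPlugin("eagle_decode_draft_tokens", &fc);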
IPluginV2* EagleDecodeDraftTokensPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call EagleDecodeDraftTokensPlugin::destroy()
try
{
auto* obj = new EagleDecodeDraftTokensPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

View File

@ -0,0 +1,90 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <set>
#include <string>
#include <vector>
namespace tensorrt_llm::plugins
{
class EagleDecodeDraftTokensPlugin : public BasePlugin
{
public:
EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx);
EagleDecodeDraftTokensPlugin(void const* data, size_t length);
~EagleDecodeDraftTokensPlugin() override = default;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
// IPluginV2 Methods
char const* getPluginType() const noexcept override;
char const* getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
private:
nvinfer1::DataType mDtype;
int32_t mLayerIdx;
};
class EagleDecodeDraftTokensPluginCreator : public BaseCreator
{
public:
EagleDecodeDraftTokensPluginCreator();
char const* getPluginName() const noexcept override;
char const* getPluginVersion() const noexcept override;
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
} // namespace tensorrt_llm::plugins

View File

@ -0,0 +1,272 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "eaglePrepareDrafterInputsPlugin.h"
using namespace nvinfer1;
using tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator;
using tensorrt_llm::plugins::EaglePrepareDrafterInputsPlugin;
static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION{"1"};
static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME{"EaglePrepareDrafterInputs"};
PluginFieldCollection EaglePrepareDrafterInputsPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EaglePrepareDrafterInputsPluginCreator::mPluginAttributes;
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx)
: mDtype(type)
, mLayerIdx(layerIdx)
{
}
// Parameterized constructor
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mDtype);
read(d, mLayerIdx);
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT-LLM version to build "
"engine and run engine.",
static_cast<int>(length), static_cast<int>(d - a));
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EaglePrepareDrafterInputsPlugin::clone() const noexcept
{
auto* plugin = new EaglePrepareDrafterInputsPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs EaglePrepareDrafterInputsPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(outputIndex < 9);
TLLM_CHECK(nbInputs == 7);
auto const batchSizeExpr = inputs[nbInputs - 2].d[0];
auto const maxDraftLenExpr = inputs[nbInputs - 2].d[1];
nvinfer1::DimsExprs ret;
switch (outputIndex)
{
case 0: // sequence_length
case 1: // host_request_types
case 2: // host_past_key_value_lengths
ret = inputs[outputIndex];
break;
case 3: // spec_decoding_generation_lengths
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
case 4: // spec_decoding_position_offsets
case 5: // input_ids
case 6: // position_ids
// FIXME input_ids should have real value, not maxDraftLen
ret.nbDims = 1;
ret.d[0] = maxDraftLenExpr;
break;
case 7: // spec_decoding_packed_mask
// FIXME
ret.nbDims = 3;
ret.d[0] = batchSizeExpr;
ret.d[1] = maxDraftLenExpr;
ret.d[2] = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *maxDraftLenExpr, *exprBuilder.constant(32));
break;
case 8: // hidden_dim
ret.nbDims = 2;
// FIXME real dim instead of max draft len
ret.d[0] = maxDraftLenExpr;
ret.d[1] = inputs[4].d[1];
break;
}
return ret;
}
bool EaglePrepareDrafterInputsPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == nbInputs - 1 || pos == nbInputs + nbOutputs - 1) // hidden_states
{
return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (pos == 3) // kv cache pool pointers
{
return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
}
else // all other tensors
{
return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
}
}
void EaglePrepareDrafterInputsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
size_t EaglePrepareDrafterInputsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
return 0;
}
int EaglePrepareDrafterInputsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
// TODO fill me
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType EaglePrepareDrafterInputsPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index < 9);
if (index < 8)
{
return inputTypes[0]; // type of sequence_length
}
else // hidden_states
{
return inputTypes[nbInputs - 1]; // type of hidden_states
}
}
// IPluginV2 Methods
char const* EaglePrepareDrafterInputsPlugin::getPluginType() const noexcept
{
return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
}
char const* EaglePrepareDrafterInputsPlugin::getPluginVersion() const noexcept
{
return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
}
int EaglePrepareDrafterInputsPlugin::getNbOutputs() const noexcept
{
return 9;
}
int EaglePrepareDrafterInputsPlugin::initialize() noexcept
{
return 0;
}
void EaglePrepareDrafterInputsPlugin::terminate() noexcept {}
size_t EaglePrepareDrafterInputsPlugin::getSerializationSize() const noexcept
{
return sizeof(mDtype) + sizeof(mLayerIdx);
}
void EaglePrepareDrafterInputsPlugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
// Keep the write order in sync with the read order of the deserializing constructor: mDtype, then mLayerIdx.
write(d, mDtype);
write(d, mLayerIdx);
assert(d == a + getSerializationSize());
}
void EaglePrepareDrafterInputsPlugin::destroy() noexcept
{
// This gets called when the network containing plugin is destroyed
delete this;
}
///////////////
EaglePrepareDrafterInputsPluginCreator::EaglePrepareDrafterInputsPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* EaglePrepareDrafterInputsPluginCreator::getPluginName() const noexcept
{
return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
}
char const* EaglePrepareDrafterInputsPluginCreator::getPluginVersion() const noexcept
{
return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
}
PluginFieldCollection const* EaglePrepareDrafterInputsPluginCreator::getFieldNames() noexcept
{
return &mFC;
}
IPluginV2* EaglePrepareDrafterInputsPluginCreator::createPlugin(
char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int32_t layerIdx;
nvinfer1::DataType type;
// Read the configuration from each field
for (int i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
if (!strcmp(attrName, "layer_idx"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
layerIdx = *static_cast<int32_t const*>(fields[i].data);
}
else if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
}
}
try
{
auto* obj = new EaglePrepareDrafterInputsPlugin(type, layerIdx);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
IPluginV2* EaglePrepareDrafterInputsPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call EaglePrepareDrafterInputsPlugin::destroy()
try
{
auto* obj = new EaglePrepareDrafterInputsPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

View File

@ -0,0 +1,90 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <set>
#include <string>
#include <vector>
namespace tensorrt_llm::plugins
{
class EaglePrepareDrafterInputsPlugin : public BasePlugin
{
public:
EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx);
EaglePrepareDrafterInputsPlugin(void const* data, size_t length);
~EaglePrepareDrafterInputsPlugin() override = default;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
// IPluginV2 Methods
char const* getPluginType() const noexcept override;
char const* getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
private:
nvinfer1::DataType mDtype;
int32_t mLayerIdx;
};
class EaglePrepareDrafterInputsPluginCreator : public BaseCreator
{
public:
EaglePrepareDrafterInputsPluginCreator();
char const* getPluginName() const noexcept override;
char const* getPluginVersion() const noexcept override;
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
} // namespace tensorrt_llm::plugins

View File

@ -0,0 +1,515 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "eagleSampleAndAcceptDraftTokensPlugin.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
#include "tensorrt_llm/runtime/common.h"
using namespace nvinfer1;
using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator;
using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPlugin;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::runtime;
namespace tc = tensorrt_llm::common;
static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME{"EagleSampleAndAcceptDraftTokens"};
PluginFieldCollection EagleSampleAndAcceptDraftTokensPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> EagleSampleAndAcceptDraftTokensPluginCreator::mPluginAttributes;
EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(
nvinfer1::DataType type, bool greedySampling)
: mDtype(type)
, mGreedySampling(greedySampling)
{
TLLM_CHECK_WITH_INFO(mGreedySampling, "Non-greedy sampling is not supported yet.");
}
// Parameterized constructor
EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length)
{
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mDtype);
read(d, mGreedySampling);
TLLM_CHECK_WITH_INFO(d == a + length,
"Expected length (%d) != real length (%d). This is often "
"caused by using different TensorRT-LLM version to build "
"engine and run engine.",
(int) length, (int) (d - a));
}
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* EagleSampleAndAcceptDraftTokensPlugin::clone() const noexcept
{
auto* plugin = new EagleSampleAndAcceptDraftTokensPlugin(*this);
plugin->setPluginNamespace(mNamespace.c_str());
return plugin;
}
nvinfer1::DimsExprs EagleSampleAndAcceptDraftTokensPlugin::getOutputDimensions(
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
TLLM_CHECK(nbInputs == 6);
TLLM_CHECK(outputIndex < 7);
auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[0];
auto const maxDecodingDraftTokensExpr = inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)].d[1];
auto const maxPathLenExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[2];
nvinfer1::DimsExprs ret;
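// Shape summary derived from the switch below, with B = batchSize, D = maxDecodingDraftTokens, L = maxPathLen:
// accepted_tokens [B, L], num_accepted_tokens [B], accepted_paths [B], last_accepted_tokens [B],
// exclusive_sum_last_accepted_indices [B], next_draft_tokens [B, D], next_draft_lens [B].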
switch (outputIndex)
{
case 0: // accepted_tokens
ret.nbDims = 2;
ret.d[0] = batchSizeExpr;
ret.d[1] = maxPathLenExpr;
break;
case 1: // num_accepted_tokens
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
case 2: // accepted_paths
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
case 3: // last_accepted_tokens
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
case 4: // exclusive_sum_last_accepted_indices
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
case 5: // next_draft_tokens
ret.nbDims = 2;
ret.d[0] = batchSizeExpr;
ret.d[1] = maxDecodingDraftTokensExpr;
break;
case 6: // next_draft_lens
ret.nbDims = 1;
ret.d[0] = batchSizeExpr;
break;
}
return ret;
}
bool EagleSampleAndAcceptDraftTokensPlugin::supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
if (pos == getIdx(InputIdxEntry::LOGITS)) // logits
{
return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else if (pos == getIdx(InputIdxEntry::TEMPERATURE)
|| pos == getIdx(InputIdxEntry::RAND_VALIDATION)) // temperature, rand_validation
{
return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else // everything else
{
return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
}
}
void EagleSampleAndAcceptDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
template <typename T>
size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs,
int nbInputs, nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
size_t workspaceSize{0};
auto const vocabSizePadded = inputs[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
auto const batchSize = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[0];
auto const maxDecodingTokens = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[1];
// Greedy sampling
{
// Top1 sampling workspace
auto const primarySamplingWorkspaceSize
= getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
// Target output ids
auto const targetOutputIdsSize = batchSize * maxDecodingTokens * sizeof(TokenIdType);
// Logits ptrs
auto const logitsPtrsSize = batchSize * maxDecodingTokens * sizeof(T*);
SizeType32 constexpr NUM_BUFFERS{4};
size_t workspaces[NUM_BUFFERS];
workspaces[0] = primarySamplingWorkspaceSize;
workspaces[1] = targetOutputIdsSize;
workspaces[2] = logitsPtrsSize;
workspaces[3] = batchSize * sizeof(SizeType32);
workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
}
return workspaceSize;
}
size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
auto const logitsType = inputs[getIdx(InputIdxEntry::LOGITS)].type;
if (logitsType == nvinfer1::DataType::kFLOAT)
{
return getWorkspaceSizeType<float>(inputs, nbInputs, outputs, nbOutputs);
}
else if (logitsType == nvinfer1::DataType::kHALF)
{
return getWorkspaceSizeType<__half>(inputs, nbInputs, outputs, nbOutputs);
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
}
return 0;
}
template <typename T>
void EagleSampleAndAcceptDraftTokensPlugin::samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
auto logits = static_cast<T const*>(inputs[getIdx(InputIdxEntry::LOGITS)]);
auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
size_t offset{0};
auto const samplingWorkspaceSize
= getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
void* workspaceSampling
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
T const** logitsPtrs = reinterpret_cast<T const**>(
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(T*)));
SizeType32* decodingTokens
= reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * sizeof(SizeType32)));
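// The carve-out order and sizes above mirror the four buffers summed in getWorkspaceSizeType(); keeping the two in
// sync is what keeps these offsets inside the workspace TensorRT allocated for the plugin.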
// Assemble pointers to logits
invokeAssembleTargetLogitsOffsets(
logitsPtrs, decodingTokens, logits, prevDraftLens, batchSize, maxDecodingTokens, vocabSizePadded, stream);
sync_check_cuda_error();
TopKSamplingKernelParams<T> params;
params.logProbsPtrs = logitsPtrs;
params.outputIds = outputIds;
params.workspace = workspaceSampling;
params.maxTopK = 1;
params.batchSize = batchSize;
params.maxBatchSize = batchSize;
params.tokensPerStep = decodingTokens;
params.maxTokensPerStep = maxDecodingTokens;
params.maxSeqLen = maxDecodingTokens;
params.vocabSizePadded = vocabSizePadded;
invokeBatchTopKSampling(params, stream);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void EagleSampleAndAcceptDraftTokensPlugin::acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];
auto const maxDraftPathLen = maxPathLen - 1;
int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
size_t offset{0};
auto const samplingWorkspaceSize
= getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
void* workspaceSampling
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
AcceptDraftTokensByIdsWithPathsParams<T> params;
params.outputIds = reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
params.draftIds = reinterpret_cast<TokenIdType const*>(inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)]);
params.targetIds = outputIds;
params.acceptedLengths = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
params.paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
params.bestPathIds = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
params.batchSize = batchSize;
params.maxBatchSize = batchSize;
params.vocabSize = vocabSizePadded;
params.maxSeqLen = maxPathLen;
params.maxDraftPathLen = maxDraftPathLen;
params.maxDecodingTokens = maxDecodingTokens;
params.stream = stream;
params.checkParams();
acceptDraftTokensByIdsWithPaths(params);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void EagleSampleAndAcceptDraftTokensPlugin::doGreedy(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Sample all main head tokens with Top-1.
samplePrimeHeadTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
// Greedy accept tokens based on token ids, write the best path and best token id.
acceptDraftTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void EagleSampleAndAcceptDraftTokensPlugin::selectLastAccTokenAndComputeIndicesCumSum(
nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];
auto lastAcceptedTokenIds
= reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::LAST_ACCEPTED_TOKEN_IDS)]);
auto exclusiveSumLastAcceptedIndices
= reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::EXCLUSIVE_SUM_LAST_TOKEN_INDICES)]);
auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
auto acceptedTokenIds = reinterpret_cast<TokenIdType const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
auto acceptedLengths = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
auto bestPathIds = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
auto paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
invokeSelectLastAccTokenAndComputeIndicesCumSum(lastAcceptedTokenIds, exclusiveSumLastAcceptedIndices,
prevDraftLens, acceptedTokenIds, acceptedLengths, bestPathIds, paths, batchSize, maxDecodingTokens, maxPathLen,
stream);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void EagleSampleAndAcceptDraftTokensPlugin::enqueueType(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// TODO split batch into greedy and non-greedy and execute both paths
if (mGreedySampling)
{
doGreedy<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
else
{
// TODO fill me
TLLM_CHECK_WITH_INFO(false, "Non-greedy sampling is not supported yet");
}
// Find last accepted tokens and do cumulative sum of accepted indices.
selectLastAccTokenAndComputeIndicesCumSum(inputDesc, outputDesc, inputs, outputs, workspace, stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
int EagleSampleAndAcceptDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept
{
auto const logitsType = inputDesc[getIdx(InputIdxEntry::LOGITS)].type;
if (logitsType == nvinfer1::DataType::kFLOAT)
{
enqueueType<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
else if (logitsType == nvinfer1::DataType::kHALF)
{
enqueueType<__half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
}
return 0;
}
// IPluginV2Ext Methods
nvinfer1::DataType EagleSampleAndAcceptDraftTokensPlugin::getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
TLLM_CHECK(index < 7);
// input 1 is draft tokens now of int32 type. All outputs are int32_t as well.
return inputTypes[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)];
}
// IPluginV2 Methods
char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginType() const noexcept
{
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
}
char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginVersion() const noexcept
{
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
}
int EagleSampleAndAcceptDraftTokensPlugin::getNbOutputs() const noexcept
{
return 7;
}
int EagleSampleAndAcceptDraftTokensPlugin::initialize() noexcept
{
return 0;
}
void EagleSampleAndAcceptDraftTokensPlugin::terminate() noexcept {}
size_t EagleSampleAndAcceptDraftTokensPlugin::getSerializationSize() const noexcept
{
return sizeof(mDtype) + sizeof(mGreedySampling);
}
void EagleSampleAndAcceptDraftTokensPlugin::serialize(void* buffer) const noexcept
{
char *d = static_cast<char*>(buffer), *a = d;
write(d, mDtype);
write(d, mGreedySampling);
assert(d == a + getSerializationSize());
}
void EagleSampleAndAcceptDraftTokensPlugin::destroy() noexcept
{
// This gets called when the network containing the plugin is destroyed
delete this;
}
///////////////
EagleSampleAndAcceptDraftTokensPluginCreator::EagleSampleAndAcceptDraftTokensPluginCreator()
{
// Fill PluginFieldCollection with PluginField arguments metadata
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("greedy_sampling", nullptr, PluginFieldType::kINT32, 1));
mFC.nbFields = mPluginAttributes.size();
mFC.fields = mPluginAttributes.data();
}
char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginName() const noexcept
{
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
}
char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginVersion() const noexcept
{
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
}
PluginFieldCollection const* EagleSampleAndAcceptDraftTokensPluginCreator::getFieldNames() noexcept
{
return &mFC;
}
IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::createPlugin(
char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
nvinfer1::DataType type;
bool greedySampling;
// Read configuration from each field
for (int i = 0; i < fc->nbFields; ++i)
{
char const* attrName = fields[i].name;
if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
}
else if (!strcmp(attrName, "greedy_sampling"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
greedySampling = static_cast<bool>(*static_cast<int32_t const*>(fields[i].data));
}
}
try
{
auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(type, greedySampling);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}
IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept
{
// This object will be deleted when the network is destroyed, which will
// call EagleSampleAndAcceptDraftTokensPlugin::destroy()
try
{
auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(serialData, serialLength);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
catch (std::exception const& e)
{
caughtError(e);
}
return nullptr;
}

View File

@ -0,0 +1,163 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/plugins/common/plugin.h"
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include <vector>
namespace tensorrt_llm::plugins
{
class EagleSampleAndAcceptDraftTokensPlugin : public BasePlugin
{
public:
EagleSampleAndAcceptDraftTokensPlugin(nvinfer1::DataType type, bool greedySampling);
EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length);
~EagleSampleAndAcceptDraftTokensPlugin() override = default;
// IPluginV2DynamicExt Methods
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
bool supportsFormatCombination(
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
// IPluginV2Ext Methods
nvinfer1::DataType getOutputDataType(
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
// IPluginV2 Methods
char const* getPluginType() const noexcept override;
char const* getPluginVersion() const noexcept override;
int getNbOutputs() const noexcept override;
int initialize() noexcept override;
void terminate() noexcept override;
size_t getSerializationSize() const noexcept override;
void serialize(void* buffer) const noexcept override;
void destroy() noexcept override;
private:
enum class InputIdxEntry : int32_t
{
//! [num_tokens, vocab_size_padded]
LOGITS = 0,
//! [batch_size, max_decoding_draft_tokens]
DRAFT_TOKEN_IDS,
//! [batch_size]
DRAFT_LENS,
//! [batch_size]
TEMPERATURE,
//! []?
RAND_VALIDATION,
//! [batch_size, max_decoding_tokens, max_path_len]
PATHS
};
enum class OutputIdxEntry : int32_t
{
//! [batch_size, max_draft_path_len]
ACCEPTED_TOKENS = 0,
//! [batch_size]
ACCEPTED_LEN,
//! [batch_size]
BEST_ACCEPTED_PATHS,
//! [batch_size]
LAST_ACCEPTED_TOKEN_IDS,
//! [batch_size]
EXCLUSIVE_SUM_LAST_TOKEN_INDICES,
//! [batch_size, max_decoding_draft_tokens]
NEXT_DRAFT_TOKEN_IDS,
//! [batch_size]
NEXT_DRAFT_LENS
};
int32_t getIdx(InputIdxEntry idx) const
{
return static_cast<int32_t>(idx);
}
int32_t getIdx(OutputIdxEntry idx) const
{
return static_cast<int32_t>(idx);
}
private:
template <typename T>
size_t getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept;
template <typename T>
void samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept;
template <typename T>
void acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
template <typename T>
void doGreedy(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
void selectLastAccTokenAndComputeIndicesCumSum(nvinfer1::PluginTensorDesc const* inputDesc,
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
cudaStream_t stream) noexcept;
template <typename T>
void enqueueType(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
private:
nvinfer1::DataType mDtype;
bool mGreedySampling;
};
class EagleSampleAndAcceptDraftTokensPluginCreator : public BaseCreator
{
public:
EagleSampleAndAcceptDraftTokensPluginCreator();
char const* getPluginName() const noexcept override;
char const* getPluginVersion() const noexcept override;
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
nvinfer1::IPluginV2* deserializePlugin(
char const* name, void const* serialData, size_t serialLength) noexcept override;
private:
static nvinfer1::PluginFieldCollection mFC;
static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
} // namespace tensorrt_llm::plugins

View File

@ -24,14 +24,17 @@ using namespace tensorrt_llm::kernels::cutlass_kernels;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPluginCreator;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPlugin;
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantGemmPluginProfiler;
using tensorrt_llm::plugins::WeightOnlyGemmRunnerPtr;
// Flags for indicating whether the corresponding inputs are applied in mQuantAlgo
// mQuantAlgo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS
// Here pre_quant_scale, zero and bias are boolean type
// mQuantAlgo = int8_weight * INT8_WEIGHT + use_w4a8_awq * FP8_ALPHA + pre_quant_scale * PRE_QUANT_SCALE
// + zero * ZERO + bias * BIAS
// Here int8_weight, use_w4a8_awq, pre_quant_scale, zero and bias are boolean type
static constexpr int BIAS = int(1) << 0;
static constexpr int ZERO = int(1) << 1;
static constexpr int PRE_QUANT_SCALE = int(1) << 2;
static constexpr int FP8_ALPHA = int(1) << 3;
static constexpr int INT8_WEIGHT = int(1) << 4;
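// Illustrative example (not from this commit): an INT8-weight groupwise GEMM with zero points
// but no bias would set mQuantAlgo = INT8_WEIGHT | ZERO (i.e. 16 + 2 = 18).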
using tensorrt_llm::plugins::read;
using tensorrt_llm::plugins::write;
@ -43,11 +46,10 @@ std::vector<nvinfer1::PluginField> WeightOnlyGroupwiseQuantMatmulPluginCreator::
void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream)
{
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
int const originalN = n * FP16_INT4_RATIO;
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
half* actPtr = reinterpret_cast<half*>(workspace);
cutlass::uint4b_t* weightPtr = reinterpret_cast<cutlass::uint4b_t*>(
nextWorkspacePtr(reinterpret_cast<int8_t*>(actPtr), m * k * sizeof(half)));
void* weightPtr = nextWorkspacePtr(reinterpret_cast<int8_t*>(actPtr), m * k * sizeof(half));
half* inputScalesPtr
= reinterpret_cast<half*>(nextWorkspacePtr(reinterpret_cast<int8_t*>(weightPtr), n * k * sizeof(float)));
half* zerosPtr = reinterpret_cast<half*>(
@ -69,15 +71,22 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
}
int const wsSize = mRunner->getWorkspaceSize(m, originalN, k);
mRunner->gemm(actPtr, weightPtr, inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m, originalN, k, mGroupSize,
tactic, workspacePtr, wsSize, stream);
if (mQuantAlgo & INT8_WEIGHT)
{
mRunner->gemm(actPtr, reinterpret_cast<int8_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m,
originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
}
else
{
mRunner->gemm(actPtr, reinterpret_cast<cutlass::uint4b_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr,
outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
}
}
void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k)
{
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
int const originalN = n * FP16_INT4_RATIO;
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
std::vector<size_t> workspaces = {
maxM * k * sizeof(half), // A
k * n * sizeof(float), // B
@ -129,6 +138,38 @@ WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(
(int) length, (int) (d - a));
}
template <typename ActivationType, typename WeightType, typename OutputType, cutlass::WeightOnlyQuantOp QuantOp>
using GemmRunner = tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp,
OutputType, OutputType, OutputType>;
template <typename ActivationType, typename WeightType, typename OutputType>
WeightOnlyGemmRunnerPtr selectGemmRunnerForZERO(int quant_algo)
{
if (quant_algo & ZERO)
{
return std::make_shared<GemmRunner<ActivationType, WeightType, OutputType,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
}
else
{
return std::make_shared<
GemmRunner<ActivationType, WeightType, OutputType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
}
}
template <typename ActivationType>
WeightOnlyGemmRunnerPtr selectGemmRunnerForWeightType(int quant_algo)
{
if (quant_algo & INT8_WEIGHT)
{
return selectGemmRunnerForZERO<ActivationType, uint8_t, ActivationType>(quant_algo);
}
else
{
return selectGemmRunnerForZERO<ActivationType, cutlass::uint4b_t, ActivationType>(quant_algo);
}
}
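// For example (illustrative): quant_algo = INT8_WEIGHT | ZERO with ActivationType = half resolves to
// GemmRunner<half, uint8_t, half, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>.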
void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int quant_algo, int group_size)
{
mArch = tensorrt_llm::common::getSMVersion();
@ -136,7 +177,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
mQuantAlgo = quant_algo;
mGroupSize = group_size;
// quant_algo = fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
// quant_algo = int8_weight * 16 + fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
mPreQuantScaleInputIdx = (quant_algo & PRE_QUANT_SCALE) ? 1 : 0;
mWeightInputIdx = mPreQuantScaleInputIdx + 1;
mScalesInputIdx = mWeightInputIdx + 1;
@ -146,6 +187,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
if (mType == nvinfer1::DataType::kHALF)
{
// CUTLASS kernel selection
if (quant_algo & FP8_ALPHA)
{
// Ada & Hopper style kernels
@ -153,45 +195,34 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
{
TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
}
if (quant_algo & ZERO)
{
// has zeros
m_weightOnlyGroupwiseGemmRunner = std::make_shared<
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
}
else
{
// no zeros
m_weightOnlyGroupwiseGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
}
assert(!(quant_algo & INT8_WEIGHT) && "W4A(fp)8 kernel requires INT4 weight!");
m_weightOnlyGroupwiseGemmRunner
= selectGemmRunnerForZERO<__nv_fp8_e4m3, cutlass::uint4b_t, half>(quant_algo);
}
else
{
if (quant_algo & ZERO)
{
// has zeros
m_weightOnlyGroupwiseGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
}
else
{
// no zeros
m_weightOnlyGroupwiseGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
}
m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<half>(quant_algo);
}
// CUDA kernel selection
if (quant_algo & INT8_WEIGHT)
{
// INT8 weight
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise;
}
else
{
// INT4 weight
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
}
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
}
#if defined(ENABLE_BF16)
else if (mType == nvinfer1::DataType::kBF16)
{
// CUTLASS kernel selection
if (quant_algo & FP8_ALPHA)
{
// FP8 requires at least sm89 devices
@ -203,24 +234,23 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
}
else
{
if (quant_algo & ZERO)
{
// has zeros
m_weightOnlyGroupwiseGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
}
else
{
// no zeros
m_weightOnlyGroupwiseGemmRunner
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
}
m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<__nv_bfloat16>(quant_algo);
}
// CUDA kernel selection
if (quant_algo & INT8_WEIGHT)
{
// INT8 weight
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise;
}
else
{
// INT4 weight
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
}
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
}
#endif
else
@ -273,8 +303,9 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions(
ret.d[ii] = inputs[0].d[ii];
}
// int4 weight only quant
ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * FP16_INT4_RATIO);
// int4/int8 weight only quant (INT4*4 -> FP16, INT8*2 -> FP16)
int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * weight_multiplier);
return ret;
}
@ -320,11 +351,12 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(nvinfer1::DynamicPlug
int const maxK = in[0].max.d[in[0].max.nbDims - 1];
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
int const maxN = in[mWeightInputIdx].max.d[1] * FP16_INT4_RATIO;
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
int const maxN = in[mWeightInputIdx].max.d[1] * weight_multiplier;
auto const K = maxK;
auto const N = maxN / FP16_INT4_RATIO;
auto const N = maxN / weight_multiplier;
if (!mDims.isInitialized())
{
@ -424,8 +456,9 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc con
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGroupwiseQuantMatmul configuration");
#endif
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
int real_n = n * FP16_INT4_RATIO;
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
int real_n = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
if (use_cuda_kernel)
{
void const* pre_quant_scale_ptr = nullptr;

View File

@ -46,6 +46,7 @@ constexpr int32_t INT8_BITS = 8;
constexpr int32_t INT4_BITS = 4;
constexpr int32_t INT8_INT4_RATIO = INT8_BITS / INT4_BITS;
constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS;
constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS;
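// One FP16 element packs FP16_INT8_RATIO (= 2) INT8 weights or FP16_INT4_RATIO (= 4) INT4 weights.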
inline int32_t getWeightTypeMultiplier(WeightTypeId weightTypeId)
{

View File

@ -140,6 +140,7 @@ void InitBindings(pybind11::module_& m)
.def_readwrite("iter", &tle::IterationStats::iter)
.def_readwrite("iter_latency_ms", &tle::IterationStats::iterLatencyMS)
.def_readwrite("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS)
.def_readwrite("num_new_active_requests", &tle::IterationStats::numNewActiveRequests)
.def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests)
.def_readwrite("num_queued_requests", &tle::IterationStats::numQueuedRequests)
.def_readwrite("num_completed_requests", &tle::IterationStats::numCompletedRequests)
@ -180,6 +181,9 @@ void InitBindings(pybind11::module_& m)
.def_readwrite("scheduled", &tle::RequestStats::scheduled)
.def_readwrite("paused", &tle::RequestStats::paused)
.def_readwrite("dis_serving_stats", &tle::RequestStats::disServingStats)
.def_readwrite("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest)
.def_readwrite("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest)
.def_readwrite("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest)
.def("to_json_str",
[](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); });

View File

@ -266,6 +266,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
auto const manageWeightsType = parseJsonFieldOr<bool>(pluginConfig, "manage_weights", false)
? ModelConfig::ManageWeightsType::kEnabled
: ModelConfig::ManageWeightsType::kDisabled;
auto const ppReduceScatter = parseJsonFieldOr<bool>(pluginConfig, "pp_reduce_scatter", false);
TLLM_CHECK_WITH_INFO(
!removeInputPadding || modelConfig.getMaxNumTokens(), "Padding removal requires max_num_tokens to be set.");
@ -283,6 +284,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
modelConfig.setPagedContextFMHA(pagedContextFMHA);
modelConfig.useXQA(useXQA);
modelConfig.setManageWeightsType(manageWeightsType);
modelConfig.setPpReduceScatter(ppReduceScatter);
}
void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,

View File

@ -72,7 +72,6 @@ auto const kProfileMbIdxs = populateMicrobatchIndexes();
GptSession::Config setPath(GptSession::Config const& original, std::string const& path)
{
GptSession::Config config = original;
config.enginePath = std::filesystem::path(path);
return config;
}

View File

@ -408,9 +408,7 @@ void TllmRuntime::loadManagedWeights(RawEngine const& rawEngine, int localRank)
{
TLLM_LOG_DEBUG("Loading managed weight: %s", name.c_str());
auto iTensor = tensorrt_llm::executor::detail::toITensor(weight);
auto weightsDevice = std::shared_ptr<ITensor>{
manager.allocate(MemoryType::kGPU, iTensor->getShape(), iTensor->getDataType())};
manager.copy(iTensor->data(), *weightsDevice, MemoryType::kCPU);
auto weightsDevice = std::shared_ptr<ITensor>{manager.copyFrom(*iTensor, MemoryType::kGPU)};
mManagedWeightsMap.insert(std::make_pair(name, weightsDevice));
}
}

View File

@ -1326,16 +1326,34 @@ public:
void callAcceptByIdsWithPaths()
{
tksp::acceptDraftTokensByIdsWithPaths(bufferCast<SizeType32>(*mOutputTokens),
bufferCast<SizeType32>(*mDraftTokens), bufferCast<SizeType32>(*mTargetTokens),
bufferCast<SizeType32>(*mSequenceLengths), bufferCast<SizeType32>(*mAcceptedLengths),
reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal)),
bufferCast<SizeType32>(*mBatchSlots), bufferCast<SizeType32>(*mPaths), bufferCast<SizeType32>(*mEndIds),
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs)),
bufferCast<SizeType32>(*mTokensPerStep), bufferCast<SizeType32>(*mTokensPerStep),
bufferCast<SizeType32>(*mBestPaths), mBatchSize, mMaxBatchSize, mVocabSize, mMaxSeqLen, mMaxNumHeads,
mMaxDraftSeqPerStep, mStream->get());
tksp::AcceptDraftTokensByIdsWithPathsParams<T> params;
params.outputIds = bufferCast<SizeType32>(*mOutputTokens);
params.draftIds = bufferCast<SizeType32>(*mDraftTokens);
params.targetIds = bufferCast<SizeType32>(*mTargetTokens);
params.sequenceLengths = bufferCast<SizeType32>(*mSequenceLengths);
params.acceptedLengths = bufferCast<SizeType32>(*mAcceptedLengths);
params.finishedFinal
= reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));
params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
params.paths = bufferCast<SizeType32>(*mPaths);
params.endIds = bufferCast<SizeType32>(*mEndIds);
params.medusaLogits = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
params.logitsPtrs = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs));
params.curTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
params.targetTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
params.bestPathIds = bufferCast<SizeType32>(*mBestPaths);
params.batchSize = mBatchSize;
params.maxBatchSize = mMaxBatchSize;
params.vocabSize = mVocabSize;
params.maxSeqLen = mMaxSeqLen;
params.maxDraftPathLen = mMaxNumHeads;
params.maxDecodingTokens = mMaxDraftSeqPerStep;
params.stream = mStream->get();
params.checkParams();
tksp::acceptDraftTokensByIdsWithPaths(params);
}
void callTestedKernel()

View File

@ -91,54 +91,59 @@ TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes);
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessSmallP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessAncestral)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabSmallP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f));
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabLargeP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f));
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessSmallP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f).setDeterministicTopP(true));
this->runTest(
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f).setDeterministicTopP(true));
this->runTest(
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessAncestral)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f).setDeterministicTopP(true));
this->runTest(
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setDeterministicTopP(true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabSmallP)
{
this->runTest(
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f).setDeterministicTopP(true));
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setDeterministicTopP(
true));
};
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabLargeP)
{
this->runTest(
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f).setDeterministicTopP(true));
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setDeterministicTopP(
true));
};
class AirTopPSamplingKernelUtilsTest : public SamplingKernelTest<float>

View File

@ -110,6 +110,8 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
auto const topK = param.topK;
auto const topP = param.topP;
// TopK == 0 case (TopP kernel)
auto const topKDistUpperBound = std::max(topK, static_cast<unsigned int>(1));
std::mt19937 gen(42);
@ -133,7 +135,7 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
0, vocabSize - 1); // -1 because uniform_int_distribution generates closed interval
std::uniform_real_distribution<> skipDecodeDist(0, 1);
std::uniform_real_distribution<> topPDist(0, topP);
std::uniform_int_distribution<> topKDist(1, topK);
std::uniform_int_distribution<> topKDist(1, topKDistUpperBound);
std::uniform_int_distribution<> tokensPerStepDist(1, maxTokensPerStep);
std::uniform_int_distribution<> seqLenDist(0, mMaxSeqLen - maxTokensPerStep);
std::uniform_real_distribution<> logProbDist(-3.f, 3.f);
@ -158,7 +160,7 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
endIdsHostPtr[bi] = endIdsDistr(gen);
skipDecodeHostPtr[bi] = skipDecodeDist(gen) > 0.8;
topPsHostPtr[bi] = topPDist(gen);
topKsHostPtr[bi] = topKDist(gen);
topKsHostPtr[bi] = topK == 0 ? 0 : topKDist(gen);
tokensPerStepPtr[bi] = tokensPerStepDist(gen);
finishedHostPtr[bi] = finishedDist(gen) > 0.8 ? tk::FinishedState::finished() : tk::FinishedState::empty();
}
@ -196,9 +198,9 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
// Init logits randomly
auto logitsHostPtr = bufferCast<T>(*mLogitsHost);
initRandom(logitsHostPtr, batchSize * maxTokensPerStep * vocabSize, -3.0f, 3.0f);
// Only in greedy search we can guarantee the selected token and stop by condition
if (topK == 1)
// TopK == 1 means greedy sampling for the TopK kernel; TopK == 0 means the TopP kernel is used
if (topK <= 1)
{
for (SizeType32 bi = 0; bi < batchSize; ++bi)
{
@ -231,13 +233,29 @@ std::vector<SizeType32> SamplingKernelTest<T>::computeTopKTopPVariants(
auto topK = bufferCast<int32_t>(*mTopKsHost)[batchSlot];
auto topP = bufferCast<float>(*mTopPsHost)[batchSlot];
allowedTokens.insert(allowedTokens.begin(), indices.begin(), indices.begin() + topK);
if (topK > 0) // handling top K kernel, top P result based on topK tokens
{
float sSum = 0.f; // sSum as in samplingTopKKernels.cu
for (auto ki = 0; ki < topK; ki++)
{
sSum += static_cast<float>(probsPtr[indices[ki]]);
}
topP *= sSum; // the adjusted topP in the selected topK distribution
}
float totalProb = 0.f;
SizeType32 idx = 0;
while (totalProb < topP && idx < vocabSize)
{
allowedTokens.push_back(indices[idx]);
totalProb += static_cast<float>(probsPtr[indices[idx++]]);
// CUDA may select a different index among tokens with the same probability during the kernel reduction; the test accepts any of them
while (idx < vocabSize
&& static_cast<float>(probsPtr[indices[idx]]) == static_cast<float>(probsPtr[indices[idx - 1]]))
{
allowedTokens.push_back(indices[idx]);
totalProb += static_cast<float>(probsPtr[indices[idx++]]);
}
}
return allowedTokens;
}
@ -284,12 +302,15 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
auto const tokensPerStep = tokensPerStepPtr[batchSlot];
for (SizeType32 ti = 0; ti < tokensPerStep; ++ti)
{
auto kResults = param.returnAllTopK ? bufferCast<int32_t>(*mTopKsHost)[batchSlot] : 1;
for (SizeType32 ki = 0; ki < kResults; ++ki)
auto topK = bufferCast<int32_t>(*mTopKsHost)[batchSlot];
auto kResults = param.returnAllSelectedTokens ? (topK == 0 ? vocabSize : topK) : 1;
auto topKTopPVariants = computeTopKTopPVariants(bi, batchSlot, ti, maxTokensPerStep, vocabSize);
SizeType32 ki;
for (ki = 0; ki < kResults && ki < topKTopPVariants.size(); ++ki)
{
// Set reference finished state to true if we finished before or at current step
auto const idsIdx = param.returnAllTopK ? ti * mMaxTopK + ki : seqLengthsOrigHostPtr[batchSlot] + ti;
auto const idsIdx
= param.returnAllSelectedTokens ? ti * mMaxTopK + ki : seqLengthsOrigHostPtr[batchSlot] + ti;
auto const outputId = outputIdsHostPtr[batchSlot * mMaxSeqLen + idsIdx];
// Check the range of the returned token ([0, vocabSize))
EXPECT_TRUE((outputId >= 0) && (outputId < vocabSize));
@ -299,7 +320,7 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
if (!skipDecodeHostPtr[batchSlot] && !finishedOrigHostPtr[batchSlot].isFinished()
&& !finishedOrigHostPtr[batchSlot].isSkipDecoding())
{
if (maxTokensPerStep == 1 && !param.returnAllTopK)
if (maxTokensPerStep == 1 && !param.returnAllSelectedTokens)
{
if (generatedEOS)
{
@ -314,8 +335,6 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
}
}
auto topKTopPVariants = computeTopKTopPVariants(bi, batchSlot, ti, maxTokensPerStep, vocabSize);
bool found = false;
for (auto const& var : topKTopPVariants)
{
@ -340,11 +359,24 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
EXPECT_EQ(finishedHostPtr[batchSlot].isFinished(), finishedOrigHostPtr[batchSlot].isFinished());
}
}
// Boundary check for returnAllSelectedTokens: covers the TopP kernel and the case where TopP selects
// fewer indices than topK in the TopK kernel.
if (!skipDecodeHostPtr[batchSlot] && !finishedOrigHostPtr[batchSlot].isFinished()
&& !finishedOrigHostPtr[batchSlot].isSkipDecoding())
{
if (param.returnAllSelectedTokens && (topK == 0 || ki != topK))
{
auto const idsIdx = ti * mMaxTopK + ki;
auto const outputId = outputIdsHostPtr[batchSlot * mMaxSeqLen + idsIdx];
EXPECT_EQ(outputId, -1);
}
}
}
}
// Cumulative log probs are not supported with multiple tokens per step or when all selected tokens are returned
if (maxTokensPerStep == 1 && !param.returnAllTopK)
if (maxTokensPerStep == 1 && !param.returnAllSelectedTokens)
{
for (int32_t bi = 0; bi < batchSize; ++bi)
{

View File

@ -194,7 +194,7 @@ struct SamplingKernelTestParam
bool normalizeLogProbs{false};
bool logitsHasProbs{true};
int32_t maxTokensPerStep{1};
bool returnAllTopK{false};
bool returnAllSelectedTokens{false};
bool useLogitsPtrs{false};
bool isDeterministicTopP{false};
@ -228,9 +228,9 @@ struct SamplingKernelTestParam
return *this;
}
SamplingKernelTestParam& setReturnAllTopK()
SamplingKernelTestParam& setReturnAllSelectedTokens()
{
returnAllTopK = true;
returnAllSelectedTokens = true;
return *this;
}

View File

@ -70,10 +70,10 @@ protected:
kernelParams.finishedOutput = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*this->mFinishedDevice));
kernelParams.skipDecode = bufferCast<bool>(*this->mSkipDecodeDevice);
kernelParams.cumLogProbs = params.returnAllTopK || params.maxTokensPerStep > 1
kernelParams.cumLogProbs = params.returnAllSelectedTokens || params.maxTokensPerStep > 1
? nullptr
: bufferCast<float>(*this->mCumLogProbsDevice);
kernelParams.outputLogProbs = params.returnAllTopK || params.maxTokensPerStep > 1
kernelParams.outputLogProbs = params.returnAllSelectedTokens || params.maxTokensPerStep > 1
? nullptr
: bufferCast<float>(*this->mOutputLogProbsDevice);
kernelParams.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*this->mCurandStatesDevice));
@ -84,7 +84,7 @@ protected:
kernelParams.vocabSizePadded = params.vocabSize;
kernelParams.normalizeLogProbs = params.normalizeLogProbs;
kernelParams.logitsHasProbs = params.logitsHasProbs;
kernelParams.returnAllTopK = params.returnAllTopK;
kernelParams.returnAllSelectedTokens = params.returnAllSelectedTokens;
// Perform batched TopK sampling
tk::invokeBatchTopKSampling(kernelParams, this->mStream->get());
@ -136,7 +136,7 @@ TYPED_TEST(TopKSamplingKernelTest, CorrectnessTopKMaxTokensPerStep)
SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(63).setTopP(1.0f).setMaxTokensPerStep(4));
};
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllTopK)
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllSelectedTokens)
{
this->runTest(SamplingKernelTestParam()
.setBatchSize(16)
@ -144,7 +144,18 @@ TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllTopK)
.setTopK(10)
.setTopP(1.0f)
.setMaxTokensPerStep(4)
.setReturnAllTopK());
.setReturnAllSelectedTokens());
};
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllSelectedTokensSmallP)
{
this->runTest(SamplingKernelTestParam()
.setBatchSize(16)
.setVocabSize(50)
.setTopK(20)
.setTopP(0.3f)
.setMaxTokensPerStep(4)
.setReturnAllSelectedTokens());
};
TYPED_TEST(TopKSamplingKernelTest, CorrectnessLogitsPtrs)

View File

@ -64,12 +64,15 @@ private:
kernelParams.finishedOutput = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*this->mFinishedDevice));
kernelParams.skipDecode = bufferCast<bool>(*this->mSkipDecodeDevice);
kernelParams.cumLogProbs = bufferCast<float>(*this->mCumLogProbsDevice);
kernelParams.outputLogProbs = bufferCast<float>(*this->mOutputLogProbsDevice);
kernelParams.cumLogProbs
= params.returnAllSelectedTokens ? nullptr : bufferCast<float>(*this->mCumLogProbsDevice);
kernelParams.outputLogProbs
= params.returnAllSelectedTokens ? nullptr : bufferCast<float>(*this->mOutputLogProbsDevice);
kernelParams.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*this->mCurandStatesDevice));
kernelParams.batchSize = params.batchSize;
kernelParams.maxBatchSize = maxBatchSize;
kernelParams.vocabSizePadded = params.vocabSize;
kernelParams.returnAllSelectedTokens = params.returnAllSelectedTokens;
// Perform batched TopP sampling
tk::invokeBatchTopPSampling<T>(kernelParams, this->mStream->get());
@ -80,26 +83,36 @@ TYPED_TEST_SUITE(TopPSamplingKernelTest, FloatAndHalfTypes);
TYPED_TEST(TopPSamplingKernelTest, CorrectnessSmallP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(TopPSamplingKernelTest, CorrectnessAncestral)
{
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f));
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
};
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabSmallP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f));
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
};
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabLargeP)
{
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f));
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
};
TYPED_TEST(TopPSamplingKernelTest, CorrectnessReturnAllSelectedTokens)
{
this->runTest(SamplingKernelTestParam()
.setBatchSize(16)
.setVocabSize(50)
.setTopK(0)
.setTopP(0.8f)
.setReturnAllSelectedTokens());
};
} // end of namespace

View File

@ -164,6 +164,10 @@ struct cutlassTypeMapper
return ss.str(); \
} \
};
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", half, uint8_t, 8,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, "BF16Int8Groupwise", __nv_bfloat16, uint8_t, 8,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4,
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t,
@ -367,8 +371,8 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
d_out.copy_to(h_out2.data());
float quant_scale = 1.f / (1 << (WSizeInBits - 1));
bool pass = compare<AType>(h_out1.data(), h_out2.data(), m * n, quant_scale);
printf(
"cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1);
printf("cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n\n", time1, time2,
time2 / time1);
return pass;
}
@ -392,6 +396,10 @@ TEST(Kernel, WeightOnly)
EXPECT_TRUE(pass);
if (arch >= 75)
{
pass = benchmark_and_verify<wo::KernelType::FP16Int8Groupwise>(m, n, k, 64, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::FP16Int8Groupwise>(m, n, k, 128, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::FP16Int4Groupwise>(m, n, k, 64, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::FP16Int4Groupwise>(m, n, k, 128, warmup, iter);
@ -399,6 +407,10 @@ TEST(Kernel, WeightOnly)
#if defined(ENABLE_BF16)
if (arch >= 80)
{
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 64, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 128, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 64, warmup, iter);
EXPECT_TRUE(pass);
pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 128, warmup, iter);

View File

@ -73,6 +73,8 @@ void BaseSamplingLayerTest<T>::setup(uint64_t seed, TestSamplingParams const& pa
trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream);
trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream);
trk::invokeFill(*mEndIdsDevice, int32_t{mEndId}, *mStream);
tk::invokeCurandInitialize(reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice)), nullptr,
mMaxBatchSize, seed, mStream->get());
auto batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
for (SizeType32 bi = 0; bi < mBatchSize; ++bi)

View File

@ -720,17 +720,17 @@ void LookaheadDecodingLayerTest::verifyDecode()
BufferRange<SizeType32> cumSumRange(*mNumNewTokensCumSum);
BufferRange<SizeType32> pathOffsetsRange(*mPathsOffsets);
PRINT_VALUES(mNumNewTokensCumSum);
for (SizeType32 gbi = 0; gbi < mTestParam.maxBatchSize; gbi++)
for (SizeType32 bi = 0; bi < batchSize; bi++)
{
SizeType32 pathOffsetBegin = cumSumRange[gbi];
SizeType32 pathOffsetEnd = cumSumRange[gbi + 1];
auto gbi = BufferRange<SizeType32>(*mBatchSlots)[bi];
SizeType32 pathOffsetBegin = cumSumRange[bi];
SizeType32 pathOffsetEnd = cumSumRange[bi + 1];
TensorPtr golden = ITensor::at(mGoldenSampledTokens, {gbi});
auto sequenceLength = BufferLocation<SizeType32>(*mSequenceLengths).at(gbi);
auto numNewTokens = BufferLocation<SizeType32>(*mNumNewTokens).at(gbi);
TensorPtr newTokens = ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens);
BufferRange<SizeType32> goldenRange(*ITensor::at(mGoldenSampledTokens, {gbi}));
BufferRange<TokenIdType> newTokensRange(
*ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens));
BufferRange<TokenIdType> newTokensRange(*newTokens);
SizeType32 ni = 1;
for (SizeType32 poi = pathOffsetBegin; poi < pathOffsetEnd; poi++)

View File

@ -207,7 +207,7 @@ TEST(LookaheadRandomllm, gpuSampling)
kernelParams.vocabSizePadded = vocabSize;
kernelParams.normalizeLogProbs = false;
kernelParams.logitsHasProbs = false;
kernelParams.returnAllTopK = false;
kernelParams.returnAllSelectedTokens = false;
PRINT_TOKENS(mEndIds);
PRINT_VALUES(mTokensPerStep);

View File

@ -101,6 +101,9 @@ def add_parallel_info(report, parallel):
document.write(report, encoding="UTF-8", xml_declaration=True)
default_test_parallel = 2
def parallel_run_ctest(
command: _tp.Sequence[str],
cwd: _pl.Path,
@ -108,7 +111,7 @@ def parallel_run_ctest(
shell=False,
env=None,
timeout=None,
parallel=2,
parallel=default_test_parallel,
) -> None:
if parallel == 1:
return run_command(command,
@ -576,7 +579,16 @@ def run_unit_tests(build_dir: _pl.Path, timeout=1800):
excluded_tests.append("Encoder")
excluded_tests.append("EncDec")
ctest.extend(["-E", "|".join(excluded_tests)])
parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout)
parallel = default_test_parallel
if parallel_override := _os.environ.get("LLM_TEST_PARALLEL_OVERRIDE", None):
parallel = int(parallel_override)
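# For example (assumption): setting LLM_TEST_PARALLEL_OVERRIDE=8 in the environment runs up to 8 ctest jobs in parallel.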
parallel_run_ctest(ctest,
cwd=build_dir,
env=cpp_env,
timeout=timeout,
parallel=parallel)
def run_single_gpu_tests(build_dir: _pl.Path,
@ -634,7 +646,17 @@ def run_single_gpu_tests(build_dir: _pl.Path,
ctest.extend(["-R", "|".join(included_tests)])
if excluded_tests:
ctest.extend(["-E", "|".join(excluded_tests)])
parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout)
parallel = default_test_parallel
if parallel_override := _os.environ.get("LLM_TEST_PARALLEL_OVERRIDE",
None):
parallel = int(parallel_override)
parallel_run_ctest(ctest,
cwd=build_dir,
env=cpp_env,
timeout=timeout,
parallel=parallel)
if run_gpt:
xml_output_file = build_dir / "results-single-gpu-disagg-executor_gpt.xml"
trt_model_test = produce_mpirun_command(

View File

@ -62,7 +62,7 @@ COPY benchmarks benchmarks
COPY scripts scripts
COPY tensorrt_llm tensorrt_llm
COPY 3rdparty 3rdparty
COPY setup.py requirements.txt requirements-dev.txt ./
COPY .gitmodules setup.py requirements.txt requirements-dev.txt ./
# Create cache directories for pip and ccache
RUN mkdir -p /root/.cache/pip /root/.cache/ccache

View File

@ -6,6 +6,10 @@ set -ex
# and closest to the version specified in
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-07.html#rel-24-07
TORCH_VERSION="2.4.0"
# Check the compatible torchvision from
# https://github.com/pytorch/vision/tree/main?tab=readme-ov-file#installation
# and also confirm with https://pypi.org/pypi/torchvision/0.19.0/json
TORCHVISION_VERSION="0.19.0"
SYSTEM_ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
prepare_environment() {
@ -35,29 +39,44 @@ restore_environment() {
install_from_source() {
if [[ $SYSTEM_ID == *"centos"* ]]; then
VERSION_ID=$(grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"')
if [[ $VERSION_ID == "7" ]]; then
echo "Installation from PyTorch source codes cannot be supported..."
exit 1
fi
VERSION_ID=$(grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"')
if [[ $VERSION_ID == "7" ]]; then
echo "Installation from PyTorch source codes cannot be supported..."
exit 1
fi
fi
prepare_environment $1
export _GLIBCXX_USE_CXX11_ABI=$1
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
export _GLIBCXX_USE_CXX11_ABI=$1
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
export PYTORCH_BUILD_VERSION=${TORCH_VERSION}
export PYTORCH_BUILD_NUMBER=0
pip3 uninstall -y torch
cd /tmp
git clone --depth 1 --branch v$TORCH_VERSION https://github.com/pytorch/pytorch
git clone --depth 1 --branch v${TORCH_VERSION} https://github.com/pytorch/pytorch
cd pytorch
git submodule sync && git submodule update --init --recursive
pip3 install -r requirements.txt
python3 setup.py install
cd /tmp && rm -rf /tmp/pytorch
export PYTORCH_VERSION=${PYTORCH_BUILD_VERSION}
export FORCE_CUDA=1
export BUILD_VERSION=${TORCHVISION_VERSION}
pip3 uninstall -y torchvision
cd /tmp
git clone --depth 1 --branch v${TORCHVISION_VERSION} https://github.com/pytorch/vision
cd vision
python3 setup.py install
cd /tmp && rm -rf /tmp/vision
restore_environment $1
}
install_from_pypi() {
pip3 install torch==${TORCH_VERSION}
pip3 uninstall -y torch torchvision
pip3 install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION}
}
case "$1" in

View File

@ -15,37 +15,6 @@ The following sections provide an overview of the main classes defined in the Ex
The `Executor` class is responsible for receiving requests from the client and providing responses for those requests. The executor is constructed by providing a path to a directory containing the TensorRT-LLM engine, or buffers containing the engine and the model JSON configuration. The client can create requests and enqueue those requests for execution using the `enqueueRequest` or `enqueueRequests` methods of the `Executor` class. Enqueued requests will be scheduled for execution by the executor, and multiple independent requests can be batched together at every iteration of the main execution loop (a process often referred to as continuous batching or iteration-level batching). Responses for a particular request can be awaited by calling the `awaitResponses` method and providing the request id. Alternatively, responses for any request can be awaited by omitting the request id when calling `awaitResponses`. The `Executor` class also allows requests to be cancelled using the `cancelRequest` method and per-iteration and per-request statistics to be obtained using the `getLatestIterationStats` method.
#### Logits Post-Processor (optional)
Users can alter the logits produced by the network, by providing a map of named callbacks of the form:
```
std::unordered_map<std::string, function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>>
```
to an instance of `LogitsPostProcessorConfig`. The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any.
The first argument to the callback is the request id, second is the logits tensor, third are the tokens produced by the request so far, fourth is the operation stream used by the logits tensor, and last one is an optional client id. The callback returns a modified tensor of logits.
Users *must* use the stream to access the logits tensor. For example, performing a addition with a bias tensor should be enqueued on that stream.
Alternatively, users may call `stream->synchronize()`, however, that will slow down the entire execution pipeline.
Multiple requests can share same client id and callback can use different logic based on client id.
We also provide a batched version that allows altering logits of multiple requests in a batch. This allows further optimizations and reduces callback overheads.
```
std::function<void(std::vector<IdType> const&, std::vector<Tensor>&, std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&, std::vector<std::optional<IdType>> const&)>
```
A single batched callback can be specified in `LogitsPostProcessorConfig`. Each request can opt to apply this callback by specifying the name of the logits
post-processor as `Request::kBatchedPostProcessorName`.
Note: Neither callback variant is supported with the `STATIC` batching type for the moment.
In a multi-GPU run, callback is invoked on all tensor parallel ranks (in last pipeline rank) by default.
For correct execution, user should replicate client-side state accessed by callback on all tensor parallel ranks.
If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setReplicate(false)` to invoke callback only on first tensor parallel rank.
### The Request Class
The `Request` class is used to define properties of the request, such as the input token ids and the maximum number of tokens to generate. The `streaming` parameter can be used to indicate if the request should generate a response for each new generated tokens (`streaming = true`) or only after all tokens have been generated (`streaming = false`). Other mandatory parameters of the request include the sampling configuration (defined by the `SamplingConfig` class) which contains parameters controlling the decoding process and the output configuration (defined by the `OutputConfig` class) which controls what information should be included in the `Result` for a particular response.
@ -83,6 +52,32 @@ The executor can process requests with different beam widths if the following co
The request queue of the executor must be empty to allow it to reconfigure itself for a new beam width. This reconfiguration will happen automatically when requests with a new beam width are enqueued. If requests with different beam widths are enqueued at the same time, the executor will encounter an error and terminate all requests prematurely.
### Controlling output with Logits Post-Processor
Optionally, you can alter the logits produced by the network by providing an instance of `Executor::LogitsPostProcessorConfig`. For instance, this feature can be used to generate JSON formatted output. {cpp:class}`Executor::LogitsPostProcessorConfig <tensorrt_llm::executor::LogitsPostProcessorConfig>` specifies a map of named callbacks in the following form
```cpp
std::unordered_map<std::string, function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>>
```
The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any.
The first argument to the callback is the request id, the second is the logits tensor, the third is the list of tokens produced by the request so far, the fourth is the operation stream used by the logits tensor, and the last one is an optional client id. The callback returns a modified tensor of logits. Multiple requests can share the same client id, and the callback can apply different logic based on the client id.
You must use the stream to access the logits tensor. For example, to perform an addition with a bias tensor, the addition operation must be enqueued on that stream. Alternatively, you can call `stream->synchronize()`; however, that will slow down the entire execution pipeline.
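For illustration only (this sketch is not part of the TensorRT-LLM sources), a named callback matching the map signature shown above could be assembled as follows. The processor name `"mask_tokens"`, the helper `makeLogitsProcessorMap`, and the alias `NamedLogitsCallback` are hypothetical, and the actual tensor arithmetic is left as a comment because it depends on the client code:
```cpp
#include <functional>
#include <optional>
#include <string>
#include <unordered_map>

#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

// Matches the callback signature documented above (assumption: executor.h provides the
// Tensor, IdType, BeamTokens and StreamPtr types as used here).
using NamedLogitsCallback = std::function<tle::Tensor(
    tle::IdType, tle::Tensor&, tle::BeamTokens const&, tle::StreamPtr const&, std::optional<tle::IdType>)>;

std::unordered_map<std::string, NamedLogitsCallback> makeLogitsProcessorMap()
{
    std::unordered_map<std::string, NamedLogitsCallback> processors;
    // "mask_tokens" is a hypothetical name; a request opts in by referencing this name.
    processors["mask_tokens"] = [](tle::IdType requestId, tle::Tensor& logits, tle::BeamTokens const& tokens,
                                    tle::StreamPtr const& stream, std::optional<tle::IdType> clientId) -> tle::Tensor
    {
        // Any modification of `logits` (for example, adding a bias tensor that bans certain
        // token ids) must be enqueued on `stream`; the per-client logic can branch on
        // `clientId` and on the `tokens` generated so far.
        return logits;
    };
    return processors;
}
```
The resulting map would then be handed to `LogitsPostProcessorConfig` as described above.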
The executor also includes a {cpp:class}`LogitsPostProcessorBatched <tensorrt_llm::executor::LogitsPostProcessorBatched>` method that enables altering logits of multiple requests in a batch. The batched method allows further optimizations and reduces callback overheads.
```cpp
std::function<void(std::vector<IdType> const&, std::vector<Tensor>&, std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&, std::vector<std::optional<IdType>> const&)>
```
A single batched callback can be specified in `LogitsPostProcessorConfig`. Each request can opt to apply this callback by specifying the name of the logits post-processor as `Request::kBatchedPostProcessorName`.
Note: Neither callback variant is supported with the `STATIC` batching type for the moment.
In a multi-GPU run, the callback is invoked on all ranks in the first tensor-parallel group, by default. To ensure correct execution, replicate the client-side state that is accessed by the callback on these ranks. If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setReplicate(false)` to invoke the callback only on rank 0. The executor broadcasts the sampled tokens internally to ensure correct execution.
## C++ Executor API Example
Two C++ examples that show how to use the Executor API are provided in the [`examples/cpp/executor`](source:examples/cpp/executor/) folder.

View File

@ -304,7 +304,7 @@ For guidance on constructing and executing Medusa with the Python runtime, consu
- TensorRT-LLM supports Medusa only for Vicuna (fine tuned LLaMA).
However, similar to any new model, you can follow the same approach to define your own Medusa model and deploy with TensorRT-LLM.
- We match only tokens during the validation phasem that is `medusa_temperature=0`.
- We match only tokens during the validation phase, that is, `medusa_temperature=0`.
- Beam search is **not** compatible with Medusa.

View File

@ -93,13 +93,13 @@ def generate_llmapi():
doc_dir.mkdir(exist_ok=True)
doc_path = doc_dir / "index.rst"
hlapi_all_file = root_dir / "tensorrt_llm/hlapi/__init__.py"
public_classes_names = extract_all_and_eval(hlapi_all_file)['__all__']
llmapi_all_file = root_dir / "tensorrt_llm/llmapi/__init__.py"
public_classes_names = extract_all_and_eval(llmapi_all_file)['__all__']
content = underline("API Reference", "-") + "\n\n"
for cls_name in public_classes_names:
cls_name = cls_name.strip()
content += (f".. autoclass:: tensorrt_llm.hlapi.{cls_name}\n"
content += (f".. autoclass:: tensorrt_llm.llmapi.{cls_name}\n"
" :members:\n"
" :undoc-members:\n"
" :special-members: __init__\n"

View File

@ -71,3 +71,7 @@ We recommend checking out the [v0.13.0 tag](https://github.com/NVIDIA/TensorRT-L
This may be caused by an outdated Microsoft Visual C++ Redistributable Version. Please install
[the latest MSVC](https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170#latest-microsoft-visual-c-redistributable-version)
and retry. Check the system path to make sure the latest version installed in `System32` is searched first. Check dependencies to make sure no other packages are using an outdated version (e.g. package `pyarrow` might contain an outdated MSVC DLL).
2. OSError: [WinError 126] The specified module could not be found. Error loading “...\Lib\site-packages\torch\lib\fbgemm.dll” or one of its dependencies.
Installing the latest [Build Tools for Visual Studio 2022](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2022) will resolve the issue.

View File

@ -5,7 +5,7 @@
TensorRT-LLM can quantize the Hugging Face model automatically by setting the appropriate flags in the `LLM` instance. For example, to perform an Int4 AWQ quantization, the following code triggers the model quantization. Please refer to the complete list of [supported flags](https://nvidia.github.io/TensorRT-LLM/_modules/tensorrt_llm/quantization/mode.html#QuantAlgo) and acceptable values.
``` python
from tensorrt_llm.hlapi import QuantConfig, QuantAlgo
from tensorrt_llm.llmapi import QuantConfig, QuantAlgo
quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
@ -14,12 +14,12 @@ llm = LLM(<model-dir>, quant_config=quant_config)
## Sampling
SamplingParams can customize the sampling strategy to control LLM generated responses, such as beam search, temperature, and [others](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/hlapi/utils.py#L55-L76).
SamplingParams can customize the sampling strategy to control LLM generated responses, such as beam search, temperature, and [others](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/llmapi/utils.py#L55-L76).
As an example, to enable beam search with a beam size of 4, set the `sampling_params` as follows:
```python
from tensorrt_llm.hlapi import LLM, SamplingParams, BuildConfig
from tensorrt_llm.llmapi import LLM, SamplingParams, BuildConfig
build_config = BuildConfig()
build_config.max_beam_width = 4
@ -38,7 +38,7 @@ for output in llm.generate(<prompt>, sampling_params=sampling_params):
* [SamplingConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#_CPPv4N12tensorrt_llm7runtime14SamplingConfigE)
* [OutputConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/executor.html#_CPPv4N12tensorrt_llm8executor12OutputConfigE)
Refer to the [class documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/index.html#tensorrt_llm.hlapi.SamplingParams) for more details.
Refer to the [class documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/index.html#tensorrt_llm.llmapi.SamplingParams) for more details.
## Build Configuration
@ -55,11 +55,11 @@ Refer to the [buildconfig documentation](https://github.com/NVIDIA/TensorRT-LLM/
## Runtime Customization
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other [arguments](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/hlapi/llm_utils.py#L186-L223) borrowed from the lower-level APIs. These runtime configuration options provide additional flexibility with respect to KV cache management, GPU memory allocation and so on. Refer to the following example:
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other [arguments](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/llmapi/llm_utils.py#L186-L223) borrowed from the lower-level APIs. These runtime configuration options provide additional flexibility with respect to KV cache management, GPU memory allocation and so on. Refer to the following example:
```python
from tensorrt_llm.hlapi import LLM, KvCacheConfig
from tensorrt_llm.llmapi import LLM, KvCacheConfig
llm = LLM(<llama_model_path>,
kv_cache_config=KvCacheConfig(

View File

@ -13,6 +13,7 @@ The LLM API can be used for both offline or online usage. See more examples of t
* [LLM Generate Async Streaming](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_generate_async_streaming.html)
* [LLM Quantization](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_quantization.html)
* [LLM Auto Parallel](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_auto_parallel.html)
* [LLM Logits Processor](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_logits_processor.html)
For more details on how to fully utilize this API, check out:

Binary file not shown (Before: 175 KiB)

Binary file not shown (After: 372 KiB)

View File

@ -30,7 +30,8 @@ TensorRT-LLM consists of pre and post-processing steps and multi-GPU multi-no
### Latest GPU Support
TensorRT-LLM supports GPUs based on the NVIDIA Hopper, NVIDIA Ada Lovelace, NVIDIA Ampere, NVIDIA Turing, and NVIDIA Volta architectures. Certain limitations may, however, apply. Refer to the {ref}`support-matrix` for more information.
TensorRT-LLM supports GPUs based on the NVIDIA Hopper, NVIDIA Ada Lovelace, and NVIDIA Ampere architectures.
Certain limitations might apply. Refer to the {ref}`support-matrix` for more information.
### Native Windows Support

View File

@ -34,83 +34,191 @@ and shows the throughput client-server scenario under maximum load.
The performance numbers below were collected using the steps described in this document.
**All data in the table below was generated using version 0.12.0 and presents token throughput in tokens/second.**
**All data in the table below was generated using version 0.13.0 and presents token throughput in tokens/second.**
| | | | | | | | | |
| ------------ | ------------------------ | ------------- | --------------- | ----------- | -------------- | -------------- | -------------- | ------- |
| | | **GPU** | H200 141GB HBM3 | GH200 120GB | H100 80GB HBM3 | H100 80GB HBM3 | A100-SXM4-80GB | L40S |
| | | **Precision** | FP8 | FP8 | FP8 | FP16 | FP16 | FP8 |
| **Model** | **Input/Output Lengths** | **TP** | | | | | | |
| GPTJ 6B | 128/128 | 1 | 24834.76 | 22454.79 | 24429.55 | 13085.91 | 5864.81 | 7647.24 |
| | 128/2048 | 1 | 8348.93 | 6656.25 | 7831.38 | 3882.21 | 2194.57 | 1843.91 |
| | 128/4096 | 1 | 5062.80 | 3678.91 | 3968.98 | 2046.53 | 1118.22 | 980.67 |
| | 2048/128 | 1 | 2776.53 | 2491.03 | 2724.38 | 1488.56 | 657.01 | 741.06 |
| | 2048/2048 | 1 | 3631.54 | 2994.81 | 3004.17 | 1280.54 | 854.37 | 754.16 |
| LLaMA v2 7B | 128/128 | 1 | 19706.35 | 17803.58 | 19068.99 | 11393.48 | 5272.39 | 6345.72 |
| | 128/2048 | 1 | 7651.12 | 5472.34 | 6610.03 | 2964.65 | 1785.79 | 1551.37 |
| | 128/4096 | 1 | 4424.90 | 3271.61 | 3649.38 | 1596.87 | 957.12 | 817.24 |
| | 2048/128 | 1 | 2385.54 | 2035.42 | 2271.63 | 1189.06 | 564.77 | 625.09 |
| | 2048/2048 | 1 | 3191.34 | 2726.29 | 2802.41 | 1243.96 | 735.19 | 641.56 |
| LLaMA v3 8B | 128/128 | 1 | 28288.75 | 25420.52 | 27399.75 | 15567.44 | 6586.88 | 8745.80 |
| | 128/2048 | 1 | 23230.62 | 16426.68 | 19198.73 | 8817.39 | 4882.13 | 5084.49 |
| | 128/4096 | 1 | 16144.44 | 9832.66 | 12084.97 | 5352.37 | 3079.90 | 2755.13 |
| | 2048/128 | 1 | 3623.79 | 3290.22 | 3463.26 | 1852.48 | 781.63 | 980.86 |
| | 2048/2048 | 1 | 11093.62 | 7573.35 | 8894.11 | 3986.83 | 2268.13 | 2051.79 |
| Mistral 7B | 128/128 | 1 | 30223.01 | 27696.90 | 29788.46 | 16319.25 | 6807.02 | 9612.58 |
| | 128/2048 | 1 | 24989.54 | 17942.29 | 20509.72 | 9982.01 | 5296.02 | 5444.89 |
| | 128/4096 | 1 | 17036.14 | 10846.03 | 12807.80 | 5718.89 | 3241.33 | 2931.17 |
| | 2048/128 | 1 | 3678.80 | 3294.02 | 3521.71 | 1887.75 | 786.43 | 1002.49 |
| | 2048/2048 | 1 | 11510.54 | 8357.75 | 9214.61 | 4284.82 | 2363.25 | 2154.26 |
| Mixtral 8x7B | 128/128 | 2 | 24895.03 | 8785.80 | 24394.71 | 15529.86 | 5921.41 | |
| | | 4 | 42014.24 | 38828.53 | 40197.42 | 28132.17 | 11414.95 | 6820.26 |
| | 128/2048 | 2 | 29389.21 | 5474.69 | 20873.02 | 7066.02 | 4306.98 | |
| | | 4 | 52348.10 | 41573.66 | 40588.05 | 21285.72 | 10974.83 | 7467.15 |
| | 128/4096 | 2 | 21480.27 | 2277.66 | 12838.28 | 3986.01 | 2400.11 | |
| | | 4 | 39182.04 | 28626.55 | 28337.31 | 12447.13 | 7278.89 | 5233.43 |
| | 2048/128 | 2 | 2934.44 | 1003.51 | 2898.27 | 1834.77 | 693.51 | |
| | | 4 | 5152.40 | 4724.01 | 5028.61 | 3393.18 | 1362.93 | 805.49 |
| | 2048/2048 | 2 | 14029.17 | 2671.88 | 10479.45 | 3531.31 | 1945.88 | |
| | | 4 | 25436.05 | 20302.56 | 19971.72 | 9622.66 | 5221.74 | 3616.30 |
| LLaMA v3 70B | 128/128 | 2 | 5386.88 | | | 2959.22 | 1301.14 | |
| | | 4 | 8944.26 | 8587.01 | 8642.05 | 5966.47 | 2413.95 | |
| | | 8 | 16125.20 | | 15397.47 | 10406.55 | 4548.32 | 1364.08 |
| | 128/2048 | 2 | 7007.27 | | | 720.73 | 500.83 | |
| | | 4 | 12906.75 | 10761.53 | 8978.95 | 4736.61 | 2380.02 | |
| | | 8 | 19417.37 | | 14822.93 | 6672.14 | 3815.08 | 1809.40 |
| | 128/4096 | 2 | 6183.85 | | | 369.29 | 251.24 | |
| | | 4 | 8859.54 | 7270.77 | 6073.48 | 2969.99 | 1634.82 | |
| | | 8 | 13969.95 | | 10094.57 | 4358.77 | 2847.54 | 1313.78 |
| | 2048/128 | 2 | 696.59 | | | 301.46 | 140.88 | |
| | | 4 | 1044.35 | 1000.55 | 1022.06 | 681.72 | 278.76 | |
| | | 8 | 2018.47 | | 1933.15 | 1279.46 | 543.73 | 163.36 |
| | 2048/2048 | 2 | 3525.18 | | | | 87.54 | |
| | | 4 | 6550.76 | 4859.38 | 4870.26 | 2379.66 | 1209.69 | |
| | | 8 | 9706.95 | | 7670.04 | 3692.41 | 2192.28 | 895.23 |
| LLaMA v2 70B | 128/128 | 2 | 6355.16 | | | 2927.71 | 1374.05 | |
| | | 4 | 10818.97 | 10819.19 | 10754.99 | 6603.10 | 2765.94 | |
| | | 8 | 16667.25 | | 16074.84 | 11369.11 | 4796.89 | 1402.92 |
| | 128/2048 | 2 | 6185.77 | | | 668.52 | 445.04 | |
| | | 4 | 12884.76 | 11356.48 | 8870.71 | 5067.06 | 2710.53 | |
| | | 8 | 19053.13 | | 17534.62 | 8805.16 | 5665.93 | 2203.33 |
| | 128/4096 | 2 | 4873.24 | | | 334.10 | 215.70 | |
| | | 4 | 8664.90 | 6311.85 | 7564.99 | 3354.02 | 1884.46 | |
| | | 8 | 15110.32 | | 10584.03 | 5373.10 | 3672.80 | 1787.76 |
| | 2048/128 | 2 | 732.09 | | | 302.49 | 141.70 | |
| | | 4 | 1272.90 | 1269.58 | 1265.80 | 774.93 | 320.79 | |
| | | 8 | 2015.77 | | 1943.96 | 1355.78 | 569.48 | 165.52 |
| | 2048/2048 | 2 | 3508.50 | | | 321.95 | 212.97 | |
| | | 4 | 6642.69 | 5545.83 | 4889.26 | 2439.10 | 1276.58 | |
| | | 8 | 10178.71 | | 8071.77 | 4275.74 | 2589.60 | 1083.45 |
| Falcon 180B | 128/128 | 4 | 5129.55 | | | | | |
| | | 8 | 8370.98 | | 8268.72 | | | |
| | 128/2048 | 4 | 7823.79 | | | | | |
| | | 8 | 13278.59 | | 13107.48 | | | |
| | 128/4096 | 4 | 6374.10 | | | | | |
| | | 8 | 12660.89 | | 10493.79 | | | |
| | 2048/128 | 4 | 601.67 | | | | | |
| | | 8 | 1002.57 | | 991.22 | | | |
| | 2048/2048 | 4 | 3869.76 | | | | | |
| | | 8 | 7134.33 | | 6386.83 | | | |
| | | | | | | | | |
| --------------- | ------------------------ | ------------- | ------------------- | --------------- | ------------------ | ------------------ | ------------------ | -------- |
| | | **GPU** | **H200 141GB HBM3** | **GH200 120GB** | **H100 80GB HBM3** | **H100 80GB HBM3** | **A100-SXM4-80GB** | **L40S** |
| | | **Precision** | **FP8** | **FP8** | **FP8** | **FP16** | **FP16** | **FP8** |
| **Model** | **Input/Output Lengths** | **TP** | | | | | | |
| GPTJ 6B | 128/128 | 1 | 24,533.54 | 22,368.50 | 24,318.61 | 12,936.63 | 5,964.19 | 7,688.44 |
| | 128/2048 | 1 | 8,375.67 | 6,588.73 | 7,829.91 | 3,931.61 | 2,215.88 | 1,842.82 |
| | 128/4096 | 1 | 5,048.59 | 3,662.81 | 3,955.28 | 2,041.06 | 1,118.12 | 980.23 |
| | 2048/128 | 1 | 2,770.27 | 2,520.37 | 2,698.08 | 1,479.48 | 650.09 | 746.54 |
| | 5000/500 | 1 | 1,791.39 | 1,449.23 | 1,623.17 | 818.80 | 436.85 | 413.33 |
| | 500/2000 | 1 | 6,770.60 | 5,565.62 | 6,149.65 | 3,030.03 | 1,673.05 | 1,538.45 |
| | 1000/1000 | 1 | 6,465.73 | 5,580.37 | 6,078.80 | 2,797.48 | 1,673.45 | 1,531.57 |
| | 2048/2048 | 1 | 3,637.42 | 2,998.01 | 3,060.80 | 1,285.08 | 845.83 | 753.55 |
| LLaMA v3.1 8B | 128/128 | 1 | 28,125.59 | 26,045.60 | 27,147.22 | 15,647.83 | 6,687.04 | 8,548.90 |
| | 128/2048 | 1 | 22,989.20 | 16,497.79 | 19,221.02 | 8,882.95 | 4,918.53 | 4,988.61 |
| | 128/4096 | 1 | 16,077.62 | 9,637.91 | 11,856.11 | 5,462.96 | 3,054.46 | 2,768.91 |
| | 2048/128 | 1 | 3,625.83 | 3,357.60 | 3,497.30 | 1,859.37 | 796.17 | 1,000.90 |
| | 5000/500 | 1 | 3,823.76 | 3,217.40 | 3,276.69 | 1,687.74 | 788.66 | 872.14 |
| | 500/2000 | 1 | 19,382.37 | 15,128.77 | 13,996.05 | 6,834.76 | 3,929.83 | 3,911.14 |
| | 1000/1000 | 1 | 16,435.21 | 12,355.41 | 13,411.43 | 7,160.92 | 3,592.16 | 3,648.21 |
| | 2048/2048 | 1 | 11,072.97 | 7,850.75 | 8,851.23 | 4,152.21 | 2,269.78 | 2,055.78 |
| | 20000/2000 | 1 | 1,634.98 | 1,200.89 | 1,278.04 | 595.89 | 316.43 | 263.75 |
| LLaMA v3 8B | 128/128 | 1 | 27,940.47 | 26,117.13 | 27,156.81 | 15,489.11 | 6,656.98 | 8,734.57 |
| | 128/2048 | 1 | 23,228.98 | 16,417.04 | 19,209.17 | 8,901.43 | 4,967.37 | 5,004.93 |
| | 128/4096 | 1 | 15,980.94 | 9,351.95 | 11,889.67 | 5,455.91 | 3,053.27 | 2,768.15 |
| | 2048/128 | 1 | 3,631.45 | 3,339.90 | 3,476.37 | 1,918.56 | 796.28 | 1,050.68 |
| | 5000/500 | 1 | 3,836.98 | 3,186.22 | 3,279.24 | 1,668.42 | 792.95 | 860.31 |
| | 500/2000 | 1 | 19,725.45 | 15,241.74 | 14,218.30 | 6,816.62 | 3,899.64 | 3,990.73 |
| | 1000/1000 | 1 | 16,201.60 | 12,049.81 | 13,371.60 | 7,041.47 | 3,617.10 | 3,679.10 |
| | 2048/2048 | 1 | 11,097.69 | 7,255.55 | 8,852.87 | 4,251.45 | 2,269.68 | 2,048.94 |
| LLaMA v2 7B | 128/128 | 1 | 19,549.13 | 17,823.45 | 19,298.99 | 11,436.31 | 5,238.68 | 6,396.62 |
| | 128/2048 | 1 | 7,675.14 | 5,438.53 | 6,607.33 | 2,985.61 | 1,807.39 | 1,566.03 |
| | 128/4096 | 1 | 4,397.83 | 3,310.09 | 3,628.46 | 1,575.35 | 957.24 | 821.83 |
| | 2048/128 | 1 | 2,392.31 | 2,064.18 | 2,304.02 | 1,157.55 | 560.35 | 619.83 |
| | 5000/500 | 1 | 1,570.37 | 1,250.11 | 1,419.09 | 624.75 | 366.39 | 347.03 |
| | 500/2000 | 1 | 6,044.15 | 4,717.51 | 5,188.69 | 2,382.75 | 1,408.58 | 1,231.78 |
| | 1000/1000 | 1 | 5,896.10 | 4,825.24 | 5,208.97 | 2,462.65 | 1,431.92 | 1,277.79 |
| | 2048/2048 | 1 | 3,193.42 | 2,693.21 | 2,792.53 | 1,263.11 | 734.38 | 641.47 |
| Mistral 7B | 128/128 | 1 | 30,152.19 | 27,738.08 | 29,672.75 | 16,711.12 | 6,863.59 | 9,676.88 |
| | 128/2048 | 1 | 24,742.09 | 17,528.14 | 20,318.60 | 9,774.11 | 5,321.44 | 5,437.25 |
| | 128/4096 | 1 | 16,905.49 | 10,671.38 | 12,715.46 | 5,740.41 | 3,257.23 | 2,941.08 |
| | 2048/128 | 1 | 3,676.37 | 3,369.77 | 3,502.83 | 1,893.42 | 796.00 | 996.65 |
| | 5000/500 | 1 | 3,890.07 | 3,401.45 | 3,358.65 | 1,740.69 | 807.07 | 904.45 |
| | 500/2000 | 1 | 20,788.70 | 15,035.59 | 15,962.94 | 7,494.80 | 4,168.89 | 4,088.52 |
| | 1000/1000 | 1 | 17,620.46 | 13,362.84 | 14,213.48 | 7,281.07 | 3,794.31 | 3,972.63 |
| | 2048/2048 | 1 | 11,747.88 | 8,599.03 | 9,200.19 | 4,349.39 | 2,320.50 | 2,170.16 |
| | 20000/2000 | 1 | 1,693.41 | 1,271.85 | 1,299.05 | 609.91 | 324.52 | 276.19 |
| LLaMA v3.1 405B | 128/128 | 8 | 3,734.50 | | | | | |
| | 128/2048 | 8 | 3,039.70 | | | | | |
| | 128/4096 | 8 | 3,144.97 | | | | | |
| | 2048/128 | 8 | 454.17 | | | | | |
| | 5000/500 | 8 | 459.91 | | | | | |
| | 500/2000 | 8 | 2,967.98 | | | | | |
| | 1000/1000 | 8 | 2,259.32 | | | | | |
| | 2048/2048 | 8 | 2,067.15 | | | | | |
| | 20000/2000 | 8 | 447.67 | | | | | |
| LLaMA v3.1 70B | 128/128 | 1 | 3,923.61 | 2,998.99 | 2,168.72 | | | |
| | | 2 | 5,358.16 | 1,839.02 | 5,215.12 | 3,156.10 | 1,340.20 | |
| | | 4 | 8,969.59 | 8,655.98 | 8,677.59 | 5,845.53 | 2,426.46 | 1,434.63 |
| | | 8 | 16,449.68 | | 15,711.60 | 10,643.75 | 4,491.42 | 1,365.36 |
| | 128/2048 | 1 | 3,503.59 | 1,343.53 | 344.22 | | | |
| | | 2 | 7,068.42 | 1,146.08 | 5,654.43 | 801.82 | 498.44 | |
| | | 4 | 12,890.95 | 10,358.10 | 9,377.87 | 4,791.11 | 2,460.91 | 1,748.87 |
| | | 8 | 19,947.02 | | 15,168.97 | 6,892.18 | 4,148.33 | 1,890.62 |
| | 128/4096 | 1 | 2,314.83 | | | | | |
| | | 2 | 6,227.19 | 896.56 | 3,302.41 | 413.22 | 268.86 | |
| | | 4 | 10,059.64 | 6,628.22 | 6,501.69 | 3,056.98 | 1,660.93 | 1,180.87 |
| | | 8 | 14,393.28 | | 9,699.99 | 4,238.15 | 2,705.77 | 1,417.60 |
| | 2048/128 | 1 | 459.73 | 372.44 | 211.51 | | | |
| | | 2 | 689.30 | 280.61 | 690.05 | 323.66 | 143.39 | |
| | | 4 | 1,047.96 | 1,015.14 | 1,016.24 | 672.37 | 278.87 | 167.87 |
| | | 8 | 2,061.19 | | 1,964.49 | 1,273.97 | 539.57 | 163.91 |
| | 5000/500 | 1 | 534.79 | 283.19 | 112.21 | | | |
| | | 2 | 943.78 | 337.04 | 897.36 | 224.31 | 115.63 | |
| | | 4 | 1,437.45 | 1,383.61 | 1,329.82 | 851.12 | 361.39 | 235.90 |
| | | 8 | 2,795.95 | | 2,472.69 | 1,438.10 | 679.27 | 224.33 |
| | 500/2000 | 1 | 2,758.24 | 1,083.48 | | | | |
| | | 2 | 6,063.53 | 851.46 | 4,347.69 | 652.34 | 423.06 | |
| | | 4 | 10,061.89 | 9,090.78 | 8,378.16 | 3,441.34 | 2,072.88 | 1,436.41 |
| | | 8 | 16,139.49 | | 10,790.85 | 5,792.17 | 3,115.20 | 1,512.78 |
| | 1000/1000 | 1 | 2,539.65 | 728.79 | | | | |
| | | 2 | 4,572.03 | 1,223.92 | 3,880.41 | 737.40 | 451.82 | |
| | | 4 | 7,612.56 | 6,705.02 | 6,553.00 | 3,655.64 | 1,731.86 | 1,113.18 |
| | | 8 | 12,660.86 | | 11,121.10 | 5,599.45 | 3,013.95 | 1,120.73 |
| | 2048/2048 | 1 | 1,753.58 | 611.08 | 161.60 | | | |
| | | 2 | 3,407.26 | 626.26 | 2,432.55 | | 108.91 | |
| | | 4 | 6,565.77 | 4,864.55 | 4,948.83 | 2,396.06 | 1,220.93 | 855.44 |
| | | 8 | 9,948.56 | | 8,527.52 | 3,819.60 | 2,103.68 | 924.89 |
| | 20000/2000 | 1 | 262.82 | 88.89 | | | | |
| | | 2 | 598.19 | 177.04 | 414.17 | | | |
| | | 4 | 1,047.27 | 958.88 | 856.31 | 375.85 | 187.42 | 140.73 |
| | | 8 | 1,793.52 | | 1,359.27 | 650.78 | 344.41 | 122.04 |
| LLaMA v3 70B | 128/128 | 1 | 3,924.02 | 3,161.73 | 2,177.84 | | | |
| | | 2 | 5,388.22 | 1,551.84 | 5,205.80 | 3,186.61 | 1,321.55 | |
| | | 4 | 8,958.95 | 8,618.55 | 8,678.68 | 5,857.16 | 2,424.68 | 1,432.46 |
| | | 8 | 16,375.41 | | 15,703.26 | 10,627.36 | 4,490.19 | 1,333.09 |
| | 128/2048 | 1 | 3,519.24 | 1,346.37 | 353.68 | | | |
| | | 2 | 7,071.54 | 862.54 | 5,878.06 | 802.98 | 512.11 | |
| | | 4 | 12,876.38 | 10,015.23 | 8,929.23 | 4,768.27 | 2,458.73 | 1,737.31 |
| | | 8 | 20,013.92 | | 15,171.91 | 6,875.97 | 3,906.35 | 1,892.41 |
| | 128/4096 | 1 | 2,310.85 | | | | | |
| | | 2 | 6,199.95 | 602.98 | 3,311.05 | 413.29 | 269.02 | |
| | | 4 | 9,633.49 | 7,370.19 | 6,489.95 | 3,053.89 | 1,677.51 | 1,199.71 |
| | | 8 | 14,552.09 | | 9,632.02 | 4,259.39 | 2,697.61 | 1,358.34 |
| | 2048/128 | 1 | 458.75 | 371.70 | 210.27 | | | |
| | | 2 | 694.00 | 277.85 | 692.74 | 321.71 | 144.61 | |
| | | 4 | 1,048.84 | 1,016.03 | 1,022.77 | 690.10 | 279.06 | 168.52 |
| | | 8 | 2,072.33 | | 1,976.76 | 1,273.41 | 542.93 | 158.63 |
| | 5000/500 | 1 | 533.37 | 303.33 | 112.68 | | | |
| | | 2 | 936.82 | 379.62 | 899.29 | 224.65 | 115.00 | |
| | | 4 | 1,442.76 | 1,384.62 | 1,326.95 | 853.73 | 361.06 | 235.19 |
| | | 8 | 2,797.36 | | 2,483.56 | 1,437.15 | 678.70 | 225.15 |
| | 500/2000 | 1 | 2,763.89 | 1,074.62 | 293.47 | | | |
| | | 2 | 6,054.46 | 1,109.13 | 4,356.55 | 683.11 | 423.82 | |
| | | 4 | 10,103.08 | 7,325.93 | 8,370.32 | 3,436.29 | 2,064.47 | 1,412.78 |
| | | 8 | 16,857.45 | | 10,760.65 | 5,665.02 | 3,159.89 | 1,517.76 |
| | 1000/1000 | 1 | 2,540.45 | 1,164.45 | | | | |
| | | 2 | 4,590.38 | 1,040.64 | 3,879.25 | 768.53 | 453.73 | |
| | | 4 | 7,606.92 | 6,655.61 | 6,547.23 | 3,655.19 | 1,732.86 | 1,117.53 |
| | | 8 | 12,660.32 | | 11,155.47 | 5,617.24 | 2,894.58 | 1,126.50 |
| | 2048/2048 | 1 | 1,746.77 | 610.87 | 162.10 | | | |
| | | 2 | 3,405.72 | 738.51 | 2,548.70 | | 108.66 | |
| | | 4 | 6,571.34 | 4,880.28 | 5,060.39 | 2,391.55 | 1,222.11 | 854.65 |
| | | 8 | 9,923.96 | | 8,480.48 | 3,826.38 | 2,181.07 | 927.54 |
| LLaMA v2 70B | 128/128 | 1 | 3,969.25 | 3,502.35 | 3,413.82 | | | |
| | | 2 | 6,394.64 | 3,252.69 | 6,432.82 | 3,170.28 | 1,336.48 | |
| | | 4 | 11,031.42 | 11,126.95 | 10,865.42 | 6,420.88 | 2,766.00 | 1,487.71 |
| | | 8 | 17,060.04 | | 16,384.83 | 11,146.15 | 4,742.74 | 1,404.99 |
| | 128/2048 | 1 | 3,742.99 | 1,660.81 | | | | |
| | | 2 | 6,453.25 | 1,335.80 | 5,775.34 | 757.21 | 476.46 | |
| | | 4 | 13,869.67 | 11,098.69 | 9,536.82 | 5,274.27 | 2,686.16 | 1,880.22 |
| | | 8 | 19,220.48 | | 17,715.01 | 8,904.94 | 5,520.41 | 2,186.68 |
| | 128/4096 | 1 | 2,459.63 | | 446.60 | | | |
| | | 2 | 4,831.03 | 684.68 | 3,354.60 | 385.98 | 235.22 | |
| | | 4 | 8,988.84 | 8,397.13 | 7,619.62 | 3,228.36 | 1,941.07 | 1,318.51 |
| | | 8 | 15,115.41 | | 12,506.95 | 5,996.81 | 3,539.36 | 1,782.93 |
| | 2048/128 | 1 | 458.88 | 400.31 | 328.90 | | | |
| | | 2 | 745.71 | 457.57 | 742.17 | 308.02 | 138.81 | |
| | | 4 | 1,297.10 | 1,330.90 | 1,270.78 | 755.30 | 321.72 | 171.67 |
| | | 8 | 2,060.53 | | 2,009.57 | 1,348.71 | 561.71 | 160.37 |
| | 5000/500 | 1 | 548.46 | 364.00 | 224.17 | | | |
| | | 2 | 1,020.86 | 335.07 | 885.67 | 212.20 | 112.43 | |
| | | 4 | 1,759.69 | 1,683.26 | 1,590.94 | 837.57 | 386.78 | 231.54 |
| | | 8 | 2,839.69 | | 2,546.12 | 1,570.91 | 709.66 | 238.59 |
| | 500/2000 | 1 | 3,019.28 | 1,364.66 | 716.54 | | | |
| | | 2 | 6,402.94 | 1,292.24 | 4,462.98 | 629.21 | 387.61 | |
| | | 4 | 12,429.18 | 8,951.07 | 8,753.09 | 4,012.41 | 2,158.17 | 1,517.53 |
| | | 8 | 16,789.12 | | 15,260.29 | 7,384.79 | 4,104.80 | 1,739.28 |
| | 1000/1000 | 1 | 2,706.04 | 1,449.83 | | | | |
| | | 2 | 4,693.24 | 960.39 | 3,958.45 | 736.68 | 425.70 | |
| | | 4 | 8,557.11 | 7,278.64 | 6,817.41 | 3,866.05 | 1,876.40 | 1,188.91 |
| | | 8 | 13,483.04 | | 11,511.74 | 6,543.96 | 3,285.82 | 1,241.42 |
| | 2048/2048 | 1 | 1,911.20 | 798.50 | 412.37 | | | |
| | | 2 | 3,408.82 | 767.24 | 2,551.21 | 388.82 | 226.60 | |
| | | 4 | 6,702.46 | 5,354.80 | 5,212.02 | 2,512.22 | 1,316.92 | 891.95 |
| | | 8 | 10,348.65 | | 8,016.14 | 4,414.75 | 2,492.09 | 1,083.26 |
| Mixtral 8x7B | 128/128 | 2 | 25,135.25 | 8,512.51 | 24,572.90 | 15,395.59 | 5,927.88 | |
| | | 4 | 42,394.61 | 40,148.01 | 40,309.25 | 27,747.43 | 11,205.51 | 6,784.44 |
| | | 8 | 54,648.80 | | 51,683.16 | 40,116.51 | 18,496.66 | 6,437.72 |
| | 128/2048 | 2 | 29,412.17 | 3,271.02 | 20,938.80 | 7,391.51 | 4,278.79 | |
| | | 4 | 52,603.13 | 43,071.34 | 40,580.94 | 21,332.15 | 10,946.58 | 7,475.05 |
| | | 8 | 70,427.00 | | 64,161.64 | 41,101.18 | 21,235.99 | 9,955.21 |
| | 128/4096 | 2 | 21,312.11 | 2,254.56 | | 3,896.02 | 2,388.14 | |
| | | 4 | 39,353.01 | 30,065.77 | | | 7,108.03 | 5,232.44 |
| | | 8 | 32,992.62 | | 47,860.65 | 27,261.67 | 15,943.70 | 8,081.21 |
| | 2048/128 | 2 | 2,946.01 | 921.87 | 2,894.09 | 1,790.49 | 684.71 | |
| | | 4 | 5,237.58 | 5,056.60 | 4,988.14 | 3,354.89 | 1,338.54 | 803.50 |
| | | 8 | 7,053.32 | | 6,559.63 | 5,072.46 | 2,244.39 | 753.39 |
| | 5000/500 | 2 | 3,848.10 | 997.06 | 3,630.24 | 1,656.04 | 739.84 | |
| | | 4 | 6,877.65 | 6,466.39 | 6,237.22 | 3,607.46 | 1,619.49 | 1,048.60 |
| | | 8 | 9,531.26 | | 8,709.34 | 6,237.96 | 2,927.13 | 1,109.25 |
| | 500/2000 | 2 | 23,539.24 | 2,773.86 | 16,886.30 | 5,773.33 | 3,325.73 | |
| | | 4 | 40,035.05 | 33,478.35 | 32,047.73 | 16,897.03 | 8,908.09 | 6,153.32 |
| | | 8 | 60,572.77 | | 41,597.80 | 31,392.32 | 16,954.54 | 7,980.34 |
| | 1000/1000 | 2 | 18,644.51 | 4,540.15 | 14,154.95 | 5,826.43 | 3,289.27 | |
| | | 4 | 32,709.62 | 29,046.16 | 25,291.30 | 14,307.91 | 7,461.63 | 4,697.19 |
| | | 8 | 44,072.88 | | 40,628.46 | 27,633.48 | 13,741.62 | 5,706.17 |
| | 2048/2048 | 2 | 14,017.70 | 2,870.77 | 10,448.79 | 3,535.21 | 1,954.32 | |
| | | 4 | 25,550.44 | 21,488.32 | 19,977.11 | 9,620.99 | 5,191.30 | 3,593.18 |
| | | 8 | 24,999.94 | | 31,678.85 | 19,372.52 | 10,572.07 | 4,860.61 |
| | 20000/2000 | 2 | 2,195.84 | 367.81 | 1,583.86 | 626.60 | 320.41 | |
| | | 4 | 4,086.41 | 3,301.28 | 2,982.42 | 1,586.09 | 807.67 | 579.49 |
| | | 8 | 5,797.57 | | 5,163.91 | 3,106.98 | 1,653.55 | 821.64 |
*TP stands for Tensor Parallelism*
## Reproducing Benchmarked Results
@ -169,7 +277,10 @@ remain in the system longer and therefore require less requests to achieve stead
| 128 | 4096 | 4224 | 1500 |
| 2048 | 128 | 2176 | 3000 |
| 2048 | 2048 | 4096 | 1500 |
| 5000 | 500 | 5500 | 1500 |
| 1000 | 1000 | 2000 | 3000 |
| 500 | 2000 | 2500 | 3000 |
| 20000 | 2000 | 22000 | 1000 |
## Engine Building

View File

@ -75,7 +75,8 @@ TensorRT-LLM optimizes the performance of a range of well-known models on NVIDIA
The following table shows the supported hardware for TensorRT-LLM.
If a GPU is not listed, it is important to note that TensorRT-LLM is expected to work on GPUs based on the Volta, Turing, Ampere, Hopper, and Ada Lovelace architectures. Certain limitations may, however, apply.
If a GPU architecture is not listed, the TensorRT-LLM team does not develop or test the software on that architecture, and support is limited to community contributions.
In addition, older architectures can have limitations for newer software releases.
```{list-table}
:header-rows: 1
@ -90,8 +91,6 @@ If a GPU is not listed, it is important to note that TensorRT-LLM is expected to
- [NVIDIA Hopper Architecture](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/)
- [NVIDIA Ada Lovelace Architecture](https://www.nvidia.com/en-us/technologies/ada-architecture/)
- [NVIDIA Ampere Architecture](https://www.nvidia.com/en-us/data-center/ampere-architecture/)
- [NVIDIA Turing Architecture](https://www.nvidia.com/en-us/geforce/turing/)
- [NVIDIA Volta Architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (experimental)
```
(support-matrix-software)=
@ -114,14 +113,8 @@ The following table shows the supported software for TensorRT-LLM.
- Hopper (SM90) - FP32, FP16, BF16, FP8, INT8, INT4
- Ada Lovelace (SM89) - FP32, FP16, BF16, FP8, INT8, INT4
- Ampere (SM80, SM86) - FP32, FP16, BF16, INT8, INT4[^smgte89]
- Turing (SM75) - FP32, FP16, INT8[^smooth], INT4
- Volta (SM70) - FP32, FP16, INT8[^smooth], INT4[^smlt75]
```
[^smooth]: INT8 SmoothQuant is not supported on SM70 and SM75.
[^smlt75]: INT4 AWQ and GPTQ are not supported on SM < 75.
[^smgte89]: INT4 AWQ and GPTQ with FP8 activations require SM >= 89.
[^encdec]: Encoder-Decoder provides general encoder-decoder functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, and so on.

View File

@ -258,8 +258,13 @@ SLURM, depending upon the SLURM version you are using:
Please configure as appropriate and try again.
--------------------------------------------------------------------------
```
You may experience other problems, such as the program hanging at startup.
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm
node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a
dedicated MPI environment, not the one provided by your Slurm allocation.
For example: `mpirun -n 1 python3 examples/gpt/build.py ...`
It is critical to always use `-n 1`, regardless of how many GPUs are being used; using `-n 2` for a 2-GPU program will not work. `mpirun` here is not used to orchestrate multiple processes, but to invoke the right environment on SLURM. The internal MPI implementation handles spawning the additional processes.

View File

@ -179,13 +179,13 @@ All published functionality in the Release Notes has been fully tested and verif
- Moved the most commonly used options into the explicit arg list and hid the expert options in the kwargs.
- Exposed `model` to accept either HuggingFace model name or local HuggingFace model/TensorRT-LLM checkpoint/TensorRT-LLM engine.
- Support downloading model from HuggingFace model hub, currently only Llama variants are supported.
- Support build cache to reuse the built TensorRT-LLM engines by setting environment variable `TLLM_HLAPI_BUILD_CACHE=1` or passing `enable_build_cache=True` to `LLM` class.
- Support build cache to reuse the built TensorRT-LLM engines by setting environment variable `TLLM_LLMAPI_BUILD_CACHE=1` or passing `enable_build_cache=True` to `LLM` class.
- Exposed low-level options including `BuildConfig`, `SchedulerConfig` and so on in the kwargs, ideally you should be able to configure details about the build and runtime phase.
- Refactored `LLM.generate()` and `LLM.generate_async()` API.
- Removed `SamplingConfig`.
- Added `SamplingParams` with more extensive parameters, see `tensorrt_llm/hlapi/utils.py`.
- Added `SamplingParams` with more extensive parameters, see `tensorrt_llm/llmapi/utils.py`.
- The new `SamplingParams` contains and manages fields from Python bindings of `SamplingConfig`, `OutputConfig`, and so on.
- Refactored `LLM.generate()` output as `RequestOutput`, see `tensorrt_llm/hlapi/llm.py`.
- Refactored `LLM.generate()` output as `RequestOutput`, see `tensorrt_llm/llmapi/llm.py`.
- Updated the `apps` examples, specially by rewriting both `chat.py` and `fastapi_server.py` using the `LLM` APIs, please refer to the `examples/apps/README.md` for details.
- Updated the `chat.py` to support multi-turn conversation, allowing users to chat with a model in the terminal.
- Fixed the `fastapi_server.py` and eliminate the need for `mpirun` in multi-GPU scenarios.
@ -481,7 +481,7 @@ All published functionality in the Release Notes has been fully tested and verif
Refer to the {ref}`support-matrix-software` section for a list of supported models.
* API
- Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
- Add a set of LLM APIs for end-to-end generation tasks (see examples/llm-api/README.md)
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/new_workflow.md)
- **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and removed corresponding build parameters
- **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager

View File

@ -1,4 +1,4 @@
# Apps examples with GenerationExecutor / High-level API
# Apps examples with GenerationExecutor / LLM API
## OpenAI API
[openai_server.py](./openai_server.py) is an OpenAI compatible server which supports `v1/version`, `v1/completions` and `v1/chat/completions`. [openai_client.py](./openai_client.py) is a simple example using OpenAI client to query your model. To start the server, you can run
```

Some files were not shown because too many files have changed in this diff.