mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Update TensorRT-LLM (#2333)
* Update TensorRT-LLM --------- Co-authored-by: Puneesh Khanna <puneesh.khanna@tii.ae> Co-authored-by: Ethan Zhang <26497102+ethnzhng@users.noreply.github.com>
This commit is contained in:
parent
8681b3a4c0
commit
75057cd036
11
README.md
11
README.md
@ -8,7 +8,7 @@ TensorRT-LLM
|
||||
[](https://www.python.org/downloads/release/python-31012/)
|
||||
[](https://developer.nvidia.com/cuda-downloads)
|
||||
[](https://developer.nvidia.com/tensorrt)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./LICENSE)
|
||||
|
||||
[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
|
||||
@ -17,11 +17,14 @@ TensorRT-LLM
|
||||
<div align="left">
|
||||
|
||||
## Latest News
|
||||
* [2024/10/07] 🚀🚀🚀Optimizing Microsoft Bing Visual Search with NVIDIA Accelerated Libraries
|
||||
[➡️ link](https://developer.nvidia.com/blog/optimizing-microsoft-bing-visual-search-with-nvidia-accelerated-libraries/)
|
||||
<div align="center">
|
||||
<img src="docs/source/media/image-10-07-2024.png" width="50%">
|
||||
<div align="left">
|
||||
|
||||
* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
|
||||
[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
|
||||
<div align="center">
|
||||
<img src="docs/source/media/image-09-29-2024.png" width="50%">
|
||||
<div align="left">
|
||||
|
||||
* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
|
||||
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)
|
||||
|
||||
@ -426,6 +426,7 @@ public:
|
||||
void initialize()
|
||||
{
|
||||
mStart = std::chrono::steady_clock::now();
|
||||
mRequestsQueueingLatencies.clear();
|
||||
}
|
||||
|
||||
void finalize()
|
||||
@ -433,6 +434,11 @@ public:
|
||||
mEnd = std::chrono::steady_clock::now();
|
||||
}
|
||||
|
||||
void recordQueueLatency(std::vector<float> const& latencies)
|
||||
{
|
||||
mRequestsQueueingLatencies.insert(mRequestsQueueingLatencies.end(), latencies.begin(), latencies.end());
|
||||
}
|
||||
|
||||
void recordStart(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
auto const inputLength = request->getInputIds()->getSize();
|
||||
@ -677,6 +683,16 @@ public:
|
||||
mMaxGenT2TLatency = genT2TLatencies.back();
|
||||
mMinGenT2TLatency = genT2TLatencies.front();
|
||||
}
|
||||
|
||||
mAvgReqQueueingLatency
|
||||
= std::accumulate(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end(), 0.F)
|
||||
/ mRequestsQueueingLatencies.size();
|
||||
std::sort(mRequestsQueueingLatencies.begin(), mRequestsQueueingLatencies.end());
|
||||
mP99ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 99);
|
||||
mP90ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 90);
|
||||
mP50ReqQueueingLatency = calcPercentile(mRequestsQueueingLatencies, 50);
|
||||
mMaxReqQueueingLatency = mRequestsQueueingLatencies.back();
|
||||
mMinReqQueueingLatency = mRequestsQueueingLatencies.front();
|
||||
}
|
||||
}
|
||||
|
||||
@ -713,6 +729,13 @@ public:
|
||||
printf("[BENCHMARK] p99_inter_token_latency(ms) %.2f\n", mP99GenT2TLatency);
|
||||
printf("[BENCHMARK] p90_inter_token_latency(ms) %.2f\n", mP90GenT2TLatency);
|
||||
printf("[BENCHMARK] p50_inter_token_latency(ms) %.2f\n\n", mP50GenT2TLatency);
|
||||
|
||||
printf("[BENCHMARK] avg_request_queueing_latency(ms) %.2f\n", mAvgReqQueueingLatency);
|
||||
printf("[BENCHMARK] max_request_queueing_latency(ms) %.2f\n", mMaxReqQueueingLatency);
|
||||
printf("[BENCHMARK] min_request_queueing_latency(ms) %.2f\n", mMinReqQueueingLatency);
|
||||
printf("[BENCHMARK] p99_request_queueing_latency(ms) %.2f\n", mP99ReqQueueingLatency);
|
||||
printf("[BENCHMARK] p90_request_queueing_latency(ms) %.2f\n", mP90ReqQueueingLatency);
|
||||
printf("[BENCHMARK] p50_request_queueing_latency(ms) %.2f\n\n", mP50ReqQueueingLatency);
|
||||
}
|
||||
}
|
||||
|
||||
@ -820,6 +843,13 @@ private:
|
||||
float mP50GenT2TLatency{};
|
||||
float mMaxGenT2TLatency{};
|
||||
float mMinGenT2TLatency{};
|
||||
float mAvgReqQueueingLatency{};
|
||||
float mP99ReqQueueingLatency{};
|
||||
float mP90ReqQueueingLatency{};
|
||||
float mP50ReqQueueingLatency{};
|
||||
float mMaxReqQueueingLatency{};
|
||||
float mMinReqQueueingLatency{};
|
||||
std::vector<float> mRequestsQueueingLatencies{};
|
||||
|
||||
std::string mOpCsvFile;
|
||||
bool mStreaming;
|
||||
@ -846,6 +876,7 @@ public:
|
||||
, mActiveCount(0)
|
||||
, mNumFinished(0)
|
||||
, mShutdown(false)
|
||||
, mLogIterationData(logIterationData)
|
||||
{
|
||||
|
||||
texec::SchedulerConfig schedulerConfig(capacitySchedulerPolicy);
|
||||
@ -899,7 +930,9 @@ public:
|
||||
TLLM_LOG_ERROR("not a supported executor model type in executor server.");
|
||||
}
|
||||
|
||||
if (logIterationData)
|
||||
auto const& world = tensorrt_llm::mpi::MpiComm::world();
|
||||
auto worldRank = world.getRank();
|
||||
if (worldRank == 0)
|
||||
{
|
||||
mCollectStatsThread = std::thread(&ExecutorServer::collectStats, this);
|
||||
}
|
||||
@ -988,7 +1021,18 @@ public:
|
||||
auto iterStats = mExecutor->getLatestIterationStats();
|
||||
for (auto const& iterStat : iterStats)
|
||||
{
|
||||
TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
|
||||
SizeType32 numNewActiveRequests = iterStat.numNewActiveRequests;
|
||||
if (numNewActiveRequests > 0)
|
||||
{
|
||||
float avgQueueingTime
|
||||
= static_cast<float>(iterStat.newActiveRequestsQueueLatencyMS / numNewActiveRequests);
|
||||
std::vector<float> requestsQueueLatencyMS(numNewActiveRequests, avgQueueingTime);
|
||||
mRecorder->recordQueueLatency(requestsQueueLatencyMS);
|
||||
}
|
||||
if (mLogIterationData)
|
||||
{
|
||||
TLLM_LOG_INFO(texec::JsonSerialization::toJsonStr(iterStat));
|
||||
}
|
||||
}
|
||||
auto const waitSleep = std::chrono::milliseconds(50);
|
||||
std::this_thread::sleep_for(waitSleep);
|
||||
@ -1005,6 +1049,7 @@ private:
|
||||
std::atomic<uint64_t> mActiveCount;
|
||||
std::atomic<uint64_t> mNumFinished;
|
||||
std::atomic<bool> mShutdown;
|
||||
bool mLogIterationData;
|
||||
}; // class ExecutorServer
|
||||
|
||||
class GptServer
|
||||
|
||||
@ -201,6 +201,7 @@ public:
|
||||
, mDecodingIter(0)
|
||||
, mPriority(req.getPriority())
|
||||
, mFinishReasons(mSamplingConfig.beamWidth)
|
||||
, mEncoderInputFeatures(std::nullopt)
|
||||
, mEncoderOutputLength(req.getEncoderOutputLength())
|
||||
, mContextPhaseParams(req.getContextPhaseParams())
|
||||
, mInputTokenExtraIds(std::nullopt)
|
||||
@ -263,7 +264,8 @@ public:
|
||||
auto pTuningConfig = req.getPromptTuningConfig();
|
||||
if (pTuningConfig)
|
||||
{
|
||||
mPromptEmbeddingTable = executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable());
|
||||
mPromptEmbeddingTable = tensorrt_llm::runtime::ITensor::view(
|
||||
executor::detail::toITensor(pTuningConfig.value().getEmbeddingTable()));
|
||||
TLLM_CHECK(mPromptEmbeddingTable.value()->getShape().nbDims == 2);
|
||||
mPromptVocabSize = mPromptEmbeddingTable.value()->getShape().d[0];
|
||||
mPromptEmbeddingTable.value()->unsqueeze(0);
|
||||
@ -1438,6 +1440,36 @@ public:
|
||||
0.0, std::chrono::duration<double, std::milli>(mKvCacheTransferEnd - mKvCacheTransferStart).count());
|
||||
}
|
||||
|
||||
void updateAllocTotalBlocksPerRequest(SizeType32 allocTotalBlocksPerRequest)
|
||||
{
|
||||
mAllocTotalBlocksPerRequest += allocTotalBlocksPerRequest;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getAllocTotalBlocksPerRequest() const
|
||||
{
|
||||
return mAllocTotalBlocksPerRequest;
|
||||
}
|
||||
|
||||
void updateAllocNewBlocksPerRequest(SizeType32 allocNewBlocksPerRequest)
|
||||
{
|
||||
mAllocNewBlocksPerRequest += allocNewBlocksPerRequest;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getAllocNewBlocksPerRequest() const
|
||||
{
|
||||
return mAllocNewBlocksPerRequest;
|
||||
}
|
||||
|
||||
void updateReusedBlocksPerRequest(SizeType32 reusedBlocksPerRequest)
|
||||
{
|
||||
mReusedBlocksPerRequest += reusedBlocksPerRequest;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getReusedBlocksPerRequest() const
|
||||
{
|
||||
return mReusedBlocksPerRequest;
|
||||
}
|
||||
|
||||
RequestIdType mRequestId;
|
||||
SizeType32 mPromptLen;
|
||||
SizeType32 mMaxNewTokens;
|
||||
@ -1545,6 +1577,10 @@ protected:
|
||||
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferStart;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mKvCacheTransferEnd;
|
||||
|
||||
SizeType32 mAllocTotalBlocksPerRequest{0};
|
||||
SizeType32 mAllocNewBlocksPerRequest{0};
|
||||
SizeType32 mReusedBlocksPerRequest{0};
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
|
||||
@ -297,6 +297,8 @@ struct IterationStats
|
||||
double iterLatencyMS;
|
||||
/// @brief The total time spent in queue by the requests that became active in this iteration (ms)
|
||||
double newActiveRequestsQueueLatencyMS;
|
||||
/// @brief Number of new fetched active requests
|
||||
SizeType32 numNewActiveRequests;
|
||||
/// @brief Number of active requests
|
||||
SizeType32 numActiveRequests;
|
||||
/// @brief Number of queued requests
|
||||
@ -364,6 +366,12 @@ struct RequestStats
|
||||
bool paused;
|
||||
/// @brief Stats specific to disaggregated serving
|
||||
std::optional<DisServingRequestStats> disServingStats;
|
||||
/// @brief Number of total allocated blocks per request
|
||||
SizeType32 allocTotalBlocksPerRequest;
|
||||
/// @brief Number of newly allocated blocks per request
|
||||
SizeType32 allocNewBlocksPerRequest;
|
||||
/// @brief Number of reused blocks per request
|
||||
SizeType32 reusedBlocksPerRequest;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the stats of all requests in an iteration
|
||||
|
||||
@ -115,7 +115,6 @@ public:
|
||||
std::optional<SizeType32> genMicroBatchSize = std::nullopt;
|
||||
std::optional<executor::DecodingMode> decodingMode = std::nullopt;
|
||||
bool normalizeLogProbs = true;
|
||||
std::optional<std::filesystem::path> enginePath;
|
||||
};
|
||||
|
||||
//! @brief Optional profiler class to profile the generation phase of an inference request
|
||||
|
||||
@ -127,6 +127,7 @@ public:
|
||||
, mContextFMHA(false)
|
||||
, mPagedContextFMHA(false)
|
||||
, mUseXQA{false}
|
||||
, mPpReduceScatter{false}
|
||||
, mUseLoraPlugin(false)
|
||||
, mMlpHiddenSize(0)
|
||||
, mUseCrossAttention(false)
|
||||
@ -468,6 +469,16 @@ public:
|
||||
return mUseXQA;
|
||||
}
|
||||
|
||||
void constexpr setPpReduceScatter(bool ppReduceScatter) noexcept
|
||||
{
|
||||
mPpReduceScatter = ppReduceScatter;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr getPpReduceScatter() const noexcept
|
||||
{
|
||||
return mPpReduceScatter;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr useLoraPlugin() const noexcept
|
||||
{
|
||||
return mUseLoraPlugin;
|
||||
@ -759,6 +770,7 @@ private:
|
||||
bool mContextFMHA;
|
||||
bool mPagedContextFMHA;
|
||||
bool mUseXQA;
|
||||
bool mPpReduceScatter;
|
||||
|
||||
bool mUseLoraPlugin;
|
||||
std::vector<LoraModule> mLoraModules;
|
||||
|
||||
@ -50,6 +50,11 @@ public:
|
||||
return SpeculativeDecodingMode{kExplicitDraftTokens};
|
||||
}
|
||||
|
||||
static auto constexpr Eagle()
|
||||
{
|
||||
return SpeculativeDecodingMode{kEagle};
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isNone() const
|
||||
{
|
||||
return anyBitSet(kNone);
|
||||
@ -75,29 +80,34 @@ public:
|
||||
return anyBitSet(kExplicitDraftTokens);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isEagle() const
|
||||
{
|
||||
return anyBitSet(kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr updatesPositionIds() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
|
||||
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens | kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr requiresAttentionMask() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr predictsDraftTokens() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr needsKVCacheRewind() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens);
|
||||
return anyBitSet(kLookaheadDecoding | kMedusa | kExplicitDraftTokens | kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr variableDraftLength() const
|
||||
{
|
||||
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding);
|
||||
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding | kEagle);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr hasDraftLogits() const
|
||||
@ -107,7 +117,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr needsDecoderPrologue() const
|
||||
{
|
||||
return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding);
|
||||
return anyBitSet(kExplicitDraftTokens | kLookaheadDecoding | kEagle);
|
||||
}
|
||||
|
||||
using UnderlyingType = std::uint8_t;
|
||||
@ -129,6 +139,7 @@ private:
|
||||
static UnderlyingType constexpr kMedusa{1U << 2U};
|
||||
static UnderlyingType constexpr kLookaheadDecoding{1U << 3U};
|
||||
static UnderlyingType constexpr kExplicitDraftTokens{1U << 4U};
|
||||
static UnderlyingType constexpr kEagle{1U << 5U};
|
||||
|
||||
[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
{
|
||||
@ -173,4 +184,11 @@ static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isDraftTokensExter
|
||||
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isMedusa());
|
||||
static_assert(!SpeculativeDecodingMode::ExplicitDraftTokens().isLookaheadDecoding());
|
||||
|
||||
static_assert(SpeculativeDecodingMode::Eagle().isEagle());
|
||||
static_assert(!SpeculativeDecodingMode::Eagle().isNone());
|
||||
static_assert(!SpeculativeDecodingMode::Eagle().isDraftTokensExternal());
|
||||
static_assert(!SpeculativeDecodingMode::Eagle().isMedusa());
|
||||
static_assert(!SpeculativeDecodingMode::Eagle().isExplicitDraftTokens());
|
||||
static_assert(!SpeculativeDecodingMode::Eagle().isLookaheadDecoding());
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1a292517d802f2297c5d12d5d14ab597f47f46ebd31412fac044ceb9ca51a482
|
||||
size 5160586
|
||||
oid sha256:a55035628e0035141b4ea79b946f49ad77893d6e5d1ab47c402e1a9b95fbbb6c
|
||||
size 5160128
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8575fb58200701ae30feb4b8bd3f325f8018aac5505167fdba42e269adb3bd8c
|
||||
size 5271836
|
||||
oid sha256:ed219fad83caf000a40f0688fdb20cb8593a5fe8096316d645229ee160c42514
|
||||
size 5271480
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
954182e0c057f71f858a84f746201044 libtensorrt_llm_batch_manager_static.a
|
||||
dfe6ca360cf1d24a3dcae0a2bf8589c0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
d7508bec7b6f112a2eac04cbeaf8b5da libtensorrt_llm_batch_manager_static.a
|
||||
d8969624b327af844d9ffba910084b93 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8fe84073b7ccff8dc361fdee64c3ef30bc523909e0bf9c16547f76a05a53fb5c
|
||||
size 5009886
|
||||
oid sha256:36479d1577d131e36ca03549467a6cfe4822868ca0f3dda3b5d254ee4680341f
|
||||
size 5009646
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6e565c2c3ce58656742772591d992aca91c7e46eb9fc711599d2d51928b88b48
|
||||
size 4970532
|
||||
oid sha256:b5caef410133f1552418978aa20cc1d3f7b6500b1dbc8b9f44232554b7cc8390
|
||||
size 4971234
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
61fd34e765788884d42f4ba27f085520 libtensorrt_llm_batch_manager_static.a
|
||||
e8a64dd19a234304483ef6756e67fd40 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
7029ee9cb0a921a3603e98815da18985 libtensorrt_llm_batch_manager_static.a
|
||||
0e7fe69b6621fe6dabcc0b372c3440f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:200a6721aa1d6e009c94866adab36ac686eb1beef02df267af7e18e31e11612b
|
||||
size 32436708
|
||||
oid sha256:b86e215e86c7b0f8b0c9618fb655e6e4f31cc731f778cf0ca12fde93c7afbcab
|
||||
size 32389592
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
9485cfa635b17378f23d1624b3acfbaf tensorrt_llm_batch_manager_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
afac175cfda36b14d76e17517bad8b24 tensorrt_llm_batch_manager_static.lib
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -92,7 +92,7 @@ template <
|
||||
typename Policy_,
|
||||
/// Number of stages,
|
||||
int Stages,
|
||||
/// Converter for B matrix applited immediately after the LDS
|
||||
/// Converter for B matrix applied immediately after the LDS
|
||||
typename TransformBAfterLDS_,
|
||||
/// The quantization operator being used
|
||||
WeightOnlyQuantOp QuantOp_,
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:809a1da76123ec4c640d63efc902209585223b66e23d887db9a198c5836986a2
|
||||
size 3349066
|
||||
oid sha256:414606be5b56f592fc7bd25f1e9fbf958c900dd2b01e01907029dfe19408ce59
|
||||
size 3349230
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6846ecefa017d03ab7d853908794c884ab4e92a500e223278b1d64eab59ed061
|
||||
size 3376088
|
||||
oid sha256:682cf952def054fce6116983a3b5686994b71744fcc85a65e3c9a6e44549c82d
|
||||
size 3377832
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
5a771664fdb75d99ba5fb90249ac26f0 libtensorrt_llm_executor_static.a
|
||||
3b433ea93b7d1d6fa471b457980f2680 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
dc9b4081af6357227886180a1b9a6d8d libtensorrt_llm_executor_static.a
|
||||
8291552cf3e8da9dc368c8c37cd35abe libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:479e86f410763445357f5d879cc666d210352dda9709ab5ab56e73591a9e8af8
|
||||
size 7851266
|
||||
oid sha256:88810c1dac205a1111fc833c0fe0d38486152b4b878fd972585eec2ac27d5160
|
||||
size 7857242
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6473c77d18929fa75342d63ffc591df39e8aeba1dda0b920b0187d4888710559
|
||||
size 7767384
|
||||
oid sha256:c023d6bad569fb3b3c528f3e003afa6a5f11a045bdccb06ca875607a6c781ade
|
||||
size 7769728
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
5424fb0f82076e03b5316f73aed04434 libtensorrt_llm_executor_static.a
|
||||
d0b1236baf61fc5c43383bbc1cd50fa8 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
fd9cb10c300350266f65957475404bff libtensorrt_llm_executor_static.a
|
||||
b8b0ae2861ef66853330441752ab1e32 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dee57c9257a6678833e3c0d83e8df07aff25c185bc085db75938cec6652044c0
|
||||
size 24568210
|
||||
oid sha256:baf4dd1bacd75c4eae6d98fe411bbb5d478dc5905a298d4238db3db21121ebca
|
||||
size 24630026
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
305fac5d046a574ded2d46d968f746b0 tensorrt_llm_executor_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
30d62c80211e4a2dc38bbe9dc5257839 tensorrt_llm_executor_static.lib
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -74,7 +74,6 @@ void sm80_generic_fused_moe_gemm_kernelLauncher(ElementType_ const* A, CutlassWe
|
||||
int occupancy = std::min(2, fused_moe::fused_gemm_maximum_active_blocks<GemmType>());
|
||||
int const threadblock_count = multi_processor_count * occupancy;
|
||||
TLLM_CHECK_WITH_INFO(occupancy > 0, "GPU lacks the shared memory resources to run fused_moe kernel");
|
||||
GemmType gemm;
|
||||
using Arguments = typename GemmType::Arguments;
|
||||
Arguments args{{const_cast<ElementType_*>(A), const_cast<CutlassWeightType_*>(B), const_cast<ElementType_*>(biases),
|
||||
reinterpret_cast<ElementType_*>(C), total_tokens_including_expert, static_cast<int>(gemm_n),
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,2 +1,2 @@
|
||||
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:db512d533ab4e4a4abd0047a65d891dfd6e1522f2d34c90f29296c3239fd3cc1
|
||||
oid sha256:3bc495e1e677616db2756eb7d56d1161c34ae723896db34487883a955e2b3442
|
||||
size 1128448
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845
|
||||
oid sha256:1a6c03470aaa69378d4989971ab9dd00ee427f7e14a85ba5e114ea0594c4de5e
|
||||
size 3488
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
|
||||
d89a0a140d2d427af13c3794a4b21e2c tensorrt_llm_nvrtc_wrapper.dll
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
|
||||
de4b2f87f8eb1027f89c0f5cb05ca047 tensorrt_llm_nvrtc_wrapper.dll
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0814af36fed752bbe70d953cefbb78dd306c42f3d9f6848b7043a865e48f9662
|
||||
oid sha256:80dbb6e3a34380bf4e375901ad9b71df24ec97cddcaa9f226bc0a278d11cbdd6
|
||||
size 25364090
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ee46f2d1c9162f4302a1031f778fcb7c7110c84110427f97af6532ed9bd342fd
|
||||
oid sha256:31e5cd6ef9e3599d55501ab0484b81f82ef1f22a79360a2699cd4a62c4928115
|
||||
size 25768990
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
90740ead1def66f350e14c133278463d libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
b0104227ffd1ce19fc1fdb45e349df36 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
1febd9d1bf244163deb269e2bebcd1e3 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
8fdb39f871225dedd32ca6651f1944ba libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4d9ba0f8b95cf64227cb0b17654fb7c9bc1741fe003889658b305750b388a4dc
|
||||
oid sha256:3431f91bcb2cadb8a2641c4ea54d1f8f90c5aa7648591510e3a27865c94169ea
|
||||
size 44173632
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4f848d5beebbd69792047a96b16f7145f8e1e3e311d2a19789ce639ad8149b0e
|
||||
oid sha256:1dedd4dd1df76a57576e749b4105a5d5f5070a6f7ee30d11944105742fea9b4b
|
||||
size 43561206
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
2aaf05cb84f52b024e89d4fa634d6900 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
f17ce186e9105c594e39d252777ce4c7 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
8683b15e77bf62ee9f57a2507e21e6a7 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
a065a7b6a11b079ee544664dddcf59a6 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c429687e335c75f08186bcd8f629b50467cb0f2e484d755834c5b1cdbb9ecaf3
|
||||
size 88140796
|
||||
oid sha256:c7afdf2c313685b0e31f4e5572e20cd11d94227177849784ce7405e15a3587f6
|
||||
size 88140804
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
4f663be2b768088805ccec6dc33545fc tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
7eee845e969cfb8d589074d81288b700 tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
3eeadd9a4a9ca2558b3a2f2089419f8d285744e5 commit
|
||||
@ -48,7 +48,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
|
||||
auto const tokenIdx = static_cast<SizeType32>(blockIdx.y);
|
||||
|
||||
auto const batchId = bid / BLOCKS_PER_BEAM_; // row id for logProbs
|
||||
auto const batchSlot = batchSlots[batchId];
|
||||
auto const batchSlot = batchSlots == nullptr ? batchId : batchSlots[batchId];
|
||||
if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot])
|
||||
{
|
||||
return;
|
||||
@ -63,7 +63,6 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
|
||||
auto const logBufIndex = batchId * maxTokensPerStep * vocabSize + tokenIdx * vocabSize;
|
||||
auto logProbsSlot
|
||||
= logProbsPtrs == nullptr ? logProbs + logBufIndex : logProbsPtrs[batchId * maxTokensPerStep + tokenIdx];
|
||||
|
||||
auto const blockLane = bid % BLOCKS_PER_BEAM_; // block id for a beam
|
||||
auto const k = (topKs != nullptr) ? topKs[batchSlot] : maxTopK; // batchId = batch index
|
||||
|
||||
@ -77,7 +76,7 @@ __global__ void topKStage1(T const* __restrict logProbs, T const* const* __restr
|
||||
|
||||
if (finished != nullptr && finishState.isFinished())
|
||||
{
|
||||
if (tid < k)
|
||||
if (tid < k && endIds != nullptr) // if returnAllSelectedToken, endIds would not be an input
|
||||
{
|
||||
auto const index = tmpTopKBufIndex + tid;
|
||||
if (blockLane == 0 && tid == 0)
|
||||
@ -134,7 +133,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
float const* topPs, curandState_t* curandState, TokenIdType const* endIds, SizeType32 vocabSize,
|
||||
bool const* skipDecode, SizeType32 const* batchSlots, SizeType32 maxBatchSize, bool normalizeLogProbs,
|
||||
bool logitHasProbs, SizeType32 const* tokensPerStep, SizeType32 maxTokensPerStep, SizeType32 maxSeqLen,
|
||||
bool returnAllTopK)
|
||||
bool returnAllSelectedTokens)
|
||||
{
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
@ -142,7 +141,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
auto const tid = static_cast<SizeType32>(threadIdx.x);
|
||||
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
|
||||
auto const tokenIdx = static_cast<SizeType32>(blockIdx.y);
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
FinishedState const finishState = finishedInput != nullptr ? finishedInput[batchSlot] : FinishedState::empty();
|
||||
if ((skipDecode != nullptr && skipDecode[batchSlot]) || (finishState.isSkipDecoding()))
|
||||
{
|
||||
@ -215,13 +214,16 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
auto randNum = static_cast<float>(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
|
||||
// if we want to return all top k indices, we should not do random sampling for probThreshold
|
||||
auto randNum = (returnAllSelectedTokens || curandState == nullptr)
|
||||
? static_cast<float>(probThreshold * sSum)
|
||||
: static_cast<float>(curand_uniform(curandState + batchSlot) * probThreshold * sSum);
|
||||
auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot];
|
||||
for (SizeType32 ki = 0; ki < k; ki++)
|
||||
{
|
||||
auto expLogit = sVal2[ki];
|
||||
randNum = randNum - expLogit;
|
||||
if (randNum <= 0.0f || ki == k - 1 || returnAllTopK)
|
||||
if (randNum <= 0.0f || ki == k - 1 || returnAllSelectedTokens)
|
||||
{
|
||||
auto idx = sId[ki];
|
||||
// If sId is -1 here we force output token to the last from vocabulary to get vivid indicator of smth
|
||||
@ -230,10 +232,10 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize
|
||||
: vocabSize - 1;
|
||||
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
|
||||
auto const outIdx = returnAllTopK ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
|
||||
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
|
||||
outputIdsRequestPtr[outIdx] = outputId;
|
||||
// cum log prob is not supported with returnAllTopK
|
||||
if (!returnAllTopK)
|
||||
// cum log prob is not supported with returnAllSelectedTokens
|
||||
if (!returnAllSelectedTokens)
|
||||
{
|
||||
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
|
||||
{
|
||||
@ -255,9 +257,17 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
}
|
||||
break;
|
||||
}
|
||||
if (returnAllSelectedTokens && randNum <= 0.0f)
|
||||
{
|
||||
if (ki < k - 1)
|
||||
{ // not the last k, write a -1 to to log top p tokens boundary for external draft token masking
|
||||
outputIdsRequestPtr[outIdx + 1] = -1;
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (maxTokensPerStep == 1 && !returnAllTopK && sequenceLengths != nullptr && finishedOutput != nullptr
|
||||
if (maxTokensPerStep == 1 && !returnAllSelectedTokens && sequenceLengths != nullptr && finishedOutput != nullptr
|
||||
&& endIds != nullptr)
|
||||
{
|
||||
auto const seqLen = sequenceLengths[batchSlot];
|
||||
@ -297,7 +307,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
|
||||
params.maxTopK, params.topKs, params.maxTopP, params.topPs, params.curandState, params.endIds, \
|
||||
params.vocabSizePadded, params.skipDecode, params.batchSlots, params.maxBatchSize, \
|
||||
params.normalizeLogProbs, params.logitsHasProbs, params.tokensPerStep, params.maxTokensPerStep, \
|
||||
params.maxSeqLen, params.returnAllTopK); \
|
||||
params.maxSeqLen, params.returnAllSelectedTokens); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
@ -34,8 +34,8 @@ struct TopKSamplingKernelParams
|
||||
//! Log probabilities of each token in the vocab. If logitsHasProbs is true,
|
||||
//! logProbs must contain **just** probabilities instead of log probabilities.
|
||||
T const* logProbs{nullptr};
|
||||
//! input buffer [batchSize][vocabSizePadded] array of pointers to logits.
|
||||
//! If nullptr, logProbs is used. Only maxTokensPerStep == 1 is supported.
|
||||
//! input buffer [batchSize][tokensPerStep, vocabSizePadded] array of pointers to logits.
|
||||
//! If nullptr, logProbs is used.
|
||||
T const* const* logProbsPtrs{nullptr};
|
||||
|
||||
//! output buffer [maxBatchSize][maxSeqLen], optional. Contains pointers to rows
|
||||
@ -82,7 +82,8 @@ struct TopKSamplingKernelParams
|
||||
//! Ignored if nullptr.
|
||||
float* outputLogProbs{nullptr};
|
||||
|
||||
//! input buffer [maxBatchSize]. Initialized curand states
|
||||
//! input buffer [maxBatchSize], optional. Initialized curand states.
|
||||
//! If nullptr, 1 is always used.
|
||||
curandState_t* curandState{nullptr};
|
||||
//! input buffer [maxBatchSize]. K for topK sampling per request.
|
||||
//! Supported K is in range [1; 1024]. Where K=1 is greedy search.
|
||||
@ -106,8 +107,8 @@ struct TopKSamplingKernelParams
|
||||
bool normalizeLogProbs{false};
|
||||
//! flag to highlight that logProbs contains probabilities
|
||||
bool logitsHasProbs{false};
|
||||
//! flag to return all selectedTopK results
|
||||
bool returnAllTopK{false};
|
||||
//! flag to return all selected TopK results
|
||||
bool returnAllSelectedTokens{false};
|
||||
|
||||
void checkParams() const
|
||||
{
|
||||
@ -131,13 +132,12 @@ struct TopKSamplingKernelParams
|
||||
}
|
||||
|
||||
TLLM_CHECK(workspace);
|
||||
TLLM_CHECK(curandState);
|
||||
|
||||
TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || sequenceLengths);
|
||||
TLLM_CHECK(maxTokensPerStep != 1 || returnAllTopK || endIds);
|
||||
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || sequenceLengths);
|
||||
TLLM_CHECK(maxTokensPerStep != 1 || returnAllSelectedTokens || endIds);
|
||||
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
|
||||
{
|
||||
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllTopK);
|
||||
TLLM_CHECK(maxTokensPerStep == 1 && !returnAllSelectedTokens);
|
||||
}
|
||||
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
|
||||
|
||||
|
||||
@ -200,7 +200,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
SizeType32* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize,
|
||||
curandState_t* curandState, float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize,
|
||||
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllTopP, SizeType32 maxSeqLen)
|
||||
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllSelectedTokens, SizeType32 maxSeqLen)
|
||||
{
|
||||
/**
|
||||
* Each block processes one request row sorted in descending order by probabilities.
|
||||
@ -244,7 +244,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
// if we want to return all top p indices, we should not do random sampling for probThreshold
|
||||
randNumS = returnAllTopP ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
|
||||
randNumS = returnAllSelectedTokens ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
|
||||
}
|
||||
|
||||
// if beginOffsetBuf and offsetBuf of sorting have same value,
|
||||
@ -255,7 +255,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
if (tid == 0)
|
||||
{
|
||||
auto offset = batchId * vocabSize;
|
||||
if (returnAllTopP)
|
||||
if (returnAllSelectedTokens)
|
||||
{
|
||||
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
|
||||
}
|
||||
@ -294,7 +294,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
}
|
||||
}
|
||||
|
||||
if (returnAllTopP)
|
||||
if (returnAllSelectedTokens)
|
||||
{
|
||||
__shared__ SizeType32 sharedSelectedTokenId;
|
||||
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
|
||||
@ -403,7 +403,7 @@ void invokeBatchTopPSampling(TopPSamplingKernelParams<T> const& params, cudaStre
|
||||
params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput,
|
||||
params.cumLogProbs, params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded,
|
||||
params.curandState, params.topPs, params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots,
|
||||
params.returnAllTopP, params.maxSeqLen);
|
||||
params.returnAllSelectedTokens, params.maxSeqLen);
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -80,7 +80,7 @@ struct TopPSamplingKernelParams
|
||||
runtime::SizeType32 vocabSizePadded{-1};
|
||||
runtime::SizeType32 maxSeqLen{-1};
|
||||
|
||||
bool returnAllTopP{false};
|
||||
bool returnAllSelectedTokens{false};
|
||||
|
||||
void checkParams() const
|
||||
{
|
||||
@ -91,7 +91,7 @@ struct TopPSamplingKernelParams
|
||||
TLLM_CHECK(probs);
|
||||
TLLM_CHECK(outputIds || outputIdsPtrs);
|
||||
TLLM_CHECK(workspace);
|
||||
TLLM_CHECK((sequenceLength != nullptr) || returnAllTopP);
|
||||
TLLM_CHECK((sequenceLength != nullptr) || returnAllSelectedTokens);
|
||||
TLLM_CHECK(curandState);
|
||||
TLLM_CHECK(topPs);
|
||||
|
||||
|
||||
@ -0,0 +1,136 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
|
||||
#ifndef CUDART_VERSION
|
||||
#error CUDART_VERSION Undefined!
|
||||
#elif (CUDART_VERSION >= 11050)
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include "3rdparty/cub/cub.cuh"
|
||||
#endif
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tensorrt_llm::kernels::speculative_decoding
|
||||
{
|
||||
namespace
|
||||
{
|
||||
template <typename T, int BLOCK_SIZE>
|
||||
__global__ void assembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
|
||||
SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
|
||||
SizeType32 vocabSizePadded)
|
||||
{
|
||||
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
|
||||
__shared__ typename BlockScan::TempStorage tempStorage;
|
||||
|
||||
auto const tix = static_cast<SizeType32>(threadIdx.x);
|
||||
|
||||
SizeType32 numDecodingTokens{0};
|
||||
if (tix < batchSize)
|
||||
{
|
||||
numDecodingTokens = draftDecodingTokens[tix] + 1;
|
||||
decodingTokens[tix] = numDecodingTokens;
|
||||
}
|
||||
|
||||
SizeType32 logitsOffset{0};
|
||||
BlockScan(tempStorage).ExclusiveSum(numDecodingTokens, logitsOffset);
|
||||
|
||||
if (tix < batchSize)
|
||||
{
|
||||
for (SizeType32 ti = 0; ti < numDecodingTokens; ++ti)
|
||||
{
|
||||
logitsPtrs[tix * maxDecodingTokens + ti] = logits + (logitsOffset + ti) * vocabSizePadded;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, SizeType32* decodingTokens, T const* logits,
|
||||
SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
|
||||
SizeType32 vocabSizePadded, cudaStream_t stream)
|
||||
{
|
||||
SizeType32 constexpr BLOCK_SIZE = 512;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
|
||||
assembleTargetLogitsOffsets<T, BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(
|
||||
logitsPtrs, decodingTokens, logits, draftDecodingTokens, batchSize, maxDecodingTokens, vocabSizePadded);
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template void invokeAssembleTargetLogitsOffsets(float const** logitsPtrs, SizeType32* decodingTokens,
|
||||
float const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
|
||||
SizeType32 vocabSizePadded, cudaStream_t stream);
|
||||
template void invokeAssembleTargetLogitsOffsets(__half const** logitsPtrs, SizeType32* decodingTokens,
|
||||
__half const* logits, SizeType32 const* draftDecodingTokens, SizeType32 batchSize, SizeType32 maxDecodingTokens,
|
||||
SizeType32 vocabSizePadded, cudaStream_t stream);
|
||||
|
||||
namespace
|
||||
{
|
||||
template <int BLOCK_SIZE>
|
||||
__global__ void selectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
|
||||
SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
|
||||
TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
|
||||
SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen)
|
||||
{
|
||||
typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
|
||||
__shared__ typename BlockScan::TempStorage tempStorage;
|
||||
|
||||
auto const tix = static_cast<SizeType32>(threadIdx.x);
|
||||
SizeType32 decodingTokens{0};
|
||||
SizeType32 lastTokenId{0};
|
||||
if (tix < batchSize)
|
||||
{
|
||||
auto const acceptedLen = acceptedLengths[tix];
|
||||
lastAcceptedTokenIds[tix] = acceptedTokenIds[tix * maxPathLen + acceptedLen - 1];
|
||||
auto const bestPathId = bestPathIds[tix];
|
||||
auto const pathIdx = flat_index3(tix, bestPathId, acceptedLen - 1, maxDecodingTokens, maxPathLen);
|
||||
lastTokenId = paths[pathIdx];
|
||||
decodingTokens = draftDecodingTokens[tix] + 1;
|
||||
}
|
||||
|
||||
BlockScan(tempStorage).ExclusiveSum(decodingTokens, decodingTokens);
|
||||
|
||||
if (tix < batchSize)
|
||||
{
|
||||
exclusiveSumLastAcceptedIndices[tix] = decodingTokens + lastTokenId;
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void invokeSelectLastAccTokenAndComputeIndicesCumSum(TokenIdType* lastAcceptedTokenIds,
|
||||
SizeType32* exclusiveSumLastAcceptedIndices, SizeType32 const* draftDecodingTokens,
|
||||
TokenIdType const* acceptedTokenIds, SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds,
|
||||
SizeType32 const* paths, SizeType32 batchSize, SizeType32 maxDecodingTokens, SizeType32 maxPathLen,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
SizeType32 constexpr BLOCK_SIZE = 512;
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
batchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
|
||||
selectLastAccTokenAndComputeIndicesCumSum<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(lastAcceptedTokenIds,
|
||||
exclusiveSumLastAcceptedIndices, draftDecodingTokens, acceptedTokenIds, acceptedLengths, bestPathIds, paths,
|
||||
batchSize, maxDecodingTokens, maxPathLen);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::speculative_decoding
|
||||
@ -0,0 +1,68 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/common.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
namespace tensorrt_llm::kernels::speculative_decoding
|
||||
{
|
||||
|
||||
//! \brief Sets pointers to logits in logitsPtrs according to the draftDecodingTokens.
|
||||
//! \param logitsPtrs [batchSize][vocabSizePadded]
|
||||
//! \param decodingTokens [batchSize], on GPU. draftDecodingTokens + 1.
|
||||
//! \param logits [numTokens, vocabSizePadded], on GPU. Continuous logits in memory.
|
||||
//! \param draftDecodingTokens [batchSize], on GPU. 0 for context requests, and actual draft len for gen requests
|
||||
//! \param batchSize batch size. Only batch size <= 512 is supported at the moment
|
||||
//! \param maxDecodingTokens maximum number of decoding tokens per step per request
|
||||
//! \param vocabSizePadded vocab size of the logits
|
||||
//! \param stream cuda stream
|
||||
template <typename T>
|
||||
void invokeAssembleTargetLogitsOffsets(T const** logitsPtrs, runtime::SizeType32* decodingTokens, T const* logits,
|
||||
runtime::SizeType32 const* draftDecodingTokens, runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 vocabSizePadded, cudaStream_t stream);
|
||||
|
||||
//! \brief Sets last accepted token ids and computes inclusive sum of the indices of the last accepted tokens in
|
||||
//! flattened input_ids tensor.
|
||||
//! \param lastAcceptedTokenIds [batchSize], on GPU. Token ids of the last accepted tokens.
|
||||
//! \param exclusiveSumLastAcceptedIndices [batchSize], on GPU. Exclusive sum of the positions of the last accepted
|
||||
//! tokens in the original flattened draft sequence.
|
||||
//! \param draftDecodingTokens [batchSize], on GPU. 0 for context
|
||||
//! requests, and actual draft len for gen requests.
|
||||
//! \param acceptedTokenIds [batchSize, maxPathLen], on GPU. Ids of the
|
||||
//! accepted tokens per request.
|
||||
//! \param acceptedLengths [batchSize], on GPU. Lengths of the accepted draft sequences
|
||||
//! per request.
|
||||
//! \param bestPathIds [batchSize], on GPU. Selected path id per request
|
||||
//! \param paths [batchSize,
|
||||
//! maxDecodingTokens, maxPathLen], on GPU. Indices of the draft sequences
|
||||
//! \param batchSize batch size. Only batch size
|
||||
//! <= 512 is supported at the moment
|
||||
//! \param maxDecodingTokens maximum number of decoding tokens per step per request
|
||||
//! \param maxPathLen maximum path len of the draft sequence
|
||||
//! \param stream cuda stream
|
||||
void invokeSelectLastAccTokenAndComputeIndicesCumSum(runtime::TokenIdType* lastAcceptedTokenIds,
|
||||
runtime::SizeType32* exclusiveSumLastAcceptedIndices, runtime::SizeType32 const* draftDecodingTokens,
|
||||
runtime::TokenIdType const* acceptedTokenIds, runtime::SizeType32 const* acceptedLengths,
|
||||
runtime::SizeType32 const* bestPathIds, runtime::SizeType32 const* paths, runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 maxDecodingTokens, runtime::SizeType32 maxPathLen, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels::speculative_decoding
|
||||
@ -60,7 +60,7 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
|
||||
auto* outputIdsAfterSamplingPtr = outputIdsAfterSampling + batchSlot * vocabSize;
|
||||
auto const useDraftLogits = batchUseDraftLogits[batchSlot];
|
||||
|
||||
if (finishedState.isSkipDecoding())
|
||||
if (finishedState.isSkipDecoding() || finishedState.isFinished())
|
||||
{
|
||||
return;
|
||||
}
|
||||
@ -75,8 +75,8 @@ __global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchS
|
||||
|
||||
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
if (tokensToMask == 0 && outputIdsAfterSamplingPtr[vIdx] == -1)
|
||||
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0
|
||||
if (outputIdsAfterSamplingPtr[vIdx] == -1)
|
||||
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0 or number of topP indices < topK
|
||||
tokensToMask = vIdx;
|
||||
}
|
||||
maskBuffer[vIdx] = false;
|
||||
@ -124,12 +124,21 @@ __global__ void acceptDraftTokensKernel(T const* draftProbs, T* targetProbs, Siz
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
|
||||
auto const useDraftLogits = batchUseDraftLogits[batchSlotBeamWidth];
|
||||
|
||||
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding())
|
||||
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding()
|
||||
|| finishedInput[batchSlot].isFinished())
|
||||
{
|
||||
if (tid == 0)
|
||||
{
|
||||
batchIsAccepted[batchSlot] = true;
|
||||
|
||||
// either finished or skip decode in previous step, this step don't need decoding
|
||||
finishedOutput[batchSlot].setSkipDecoding();
|
||||
|
||||
// if previous step is finished, write the state to next step too
|
||||
if (finishedInput[batchSlot].isFinished())
|
||||
{
|
||||
finishedOutput[batchSlot] = finishedInput[batchSlot];
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -214,7 +223,8 @@ __global__ void forwardAcceptedTokensKernel(SizeType32 batchSize, SizeType32 con
|
||||
for (SizeType32 bi = index; bi < batchSize; bi += static_cast<SizeType32>(gridDim.x * blockDim.x))
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding())
|
||||
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding()
|
||||
&& !finishedOutput[batchSlot].isFinished())
|
||||
{
|
||||
auto const curSeqLen = sequenceLengths[batchSlot];
|
||||
auto const draftTokenIdx = step;
|
||||
|
||||
@ -46,22 +46,22 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
|
||||
T const** medusaLogits, T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep,
|
||||
SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen,
|
||||
SizeType32 maxNumHeads, SizeType32 maxDecodingTokens)
|
||||
SizeType32 maxDraftPathLen, SizeType32 maxDecodingTokens)
|
||||
{
|
||||
auto const batchIdx = static_cast<SizeType32>(blockIdx.x);
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
auto const inputLength = sequenceLengths[batchSlot];
|
||||
auto const endId = endIds[batchSlot];
|
||||
auto const numTokensPerStep = curTokensPerStep[batchSlot];
|
||||
auto const maxNumDraftTokens = maxNumHeads + 1;
|
||||
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
auto const inputLength = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
|
||||
auto const endId = endIds == nullptr ? -1 : endIds[batchSlot];
|
||||
auto const numTokensPerStep = curTokensPerStep == nullptr ? maxDecodingTokens : curTokensPerStep[batchSlot];
|
||||
auto const maxPathLen = maxDraftPathLen + 1;
|
||||
|
||||
int4 partialMax{-1, -1, 0, 0};
|
||||
// Go over different paths and construct implicit sequences
|
||||
for (auto pathIdx = static_cast<SizeType32>(threadIdx.x); pathIdx < maxDecodingTokens;
|
||||
pathIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
auto acceptedLength = maxNumDraftTokens;
|
||||
auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
|
||||
auto acceptedLength = maxPathLen;
|
||||
auto const pathOffset = flat_index3(batchSlot, pathIdx, 0, maxDecodingTokens, maxPathLen);
|
||||
bool hasEnd = false;
|
||||
|
||||
auto const tokenId = paths[pathOffset];
|
||||
@ -75,13 +75,14 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
auto nextIdx = tokenId;
|
||||
|
||||
// Go along the path
|
||||
for (SizeType32 ti = 1; ti < maxNumDraftTokens; ++ti)
|
||||
for (SizeType32 ti = 1; ti < maxPathLen; ++ti)
|
||||
{
|
||||
auto const tokenId = paths[pathOffset + ti];
|
||||
// Break if path terminates
|
||||
if (tokenId == -1)
|
||||
{
|
||||
hasEnd = targetToken == endId; // check if last token is EOS when path terminates.
|
||||
hasEnd = endIds == nullptr ? false
|
||||
: targetToken == endId; // check if last token is EOS when path terminates.
|
||||
acceptedLength = hasEnd ? ti - 1 : ti;
|
||||
break;
|
||||
}
|
||||
@ -91,7 +92,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
|
||||
// Check if draft tokens are the same as target tokens
|
||||
bool const accepted = draftToken == targetToken;
|
||||
hasEnd = targetToken == endId;
|
||||
hasEnd = endIds == nullptr ? false : targetToken == endId;
|
||||
if (!accepted || hasEnd)
|
||||
{
|
||||
acceptedLength = hasEnd ? ti - 1 : ti;
|
||||
@ -126,7 +127,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
auto const acceptedLength = totalShared.x;
|
||||
auto const bestPathIdx = totalShared.y;
|
||||
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
|
||||
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxNumDraftTokens);
|
||||
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxDecodingTokens, maxPathLen);
|
||||
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < acceptedLength; ti += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
auto const tokenId = paths[pathOffset + ti];
|
||||
@ -142,15 +143,18 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
{
|
||||
auto const hasEnd = totalShared.z;
|
||||
// Set end condition
|
||||
if (hasEnd)
|
||||
if (hasEnd && finishedFinal)
|
||||
{
|
||||
finishedFinal[batchSlot].setFinishedEOS();
|
||||
}
|
||||
// Make correction to the sequence length
|
||||
sequenceLengths[batchSlot] += acceptedLength;
|
||||
if (sequenceLengths)
|
||||
{
|
||||
sequenceLengths[batchSlot] += acceptedLength;
|
||||
}
|
||||
acceptedLengths[batchSlot] = acceptedLength;
|
||||
// In Medusa decoding step, number of draft tokens is 0 and must be updated for the next steps
|
||||
if (numTokensPerStep == 1)
|
||||
if (curTokensPerStep && targetTokensPerStep && numTokensPerStep == 1)
|
||||
{
|
||||
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
|
||||
}
|
||||
@ -158,45 +162,33 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
|
||||
}
|
||||
|
||||
// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
|
||||
for (auto hi = static_cast<SizeType32>(threadIdx.x); hi < maxNumHeads; hi += static_cast<SizeType32>(blockDim.x))
|
||||
if (medusaLogits && logitsPtrs)
|
||||
{
|
||||
logitsPtrs[batchIdx * maxNumHeads + hi]
|
||||
= medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
|
||||
for (auto hi = static_cast<SizeType32>(threadIdx.x); hi < maxDraftPathLen;
|
||||
hi += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
logitsPtrs[batchIdx * maxDraftPathLen + hi]
|
||||
= medusaLogits[batchSlot * maxDraftPathLen + hi] + flat_index2(bestNextIdx, 0, vocabSize);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
|
||||
SizeType32* sequenceLengths, SizeType32* acceptedLengths, FinishedState* finishedFinal,
|
||||
SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds, T const** medusaLogits,
|
||||
T const** logitsPtrs, SizeType32* curTokensPerStep, SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds,
|
||||
SizeType32 batchSize, SizeType32 vocabSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads,
|
||||
SizeType32 maxDecodingTokens, cudaStream_t stream)
|
||||
void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<T> const& params)
|
||||
{
|
||||
constexpr SizeType32 BLOCK_SIZE = 256;
|
||||
dim3 block(BLOCK_SIZE);
|
||||
dim3 grid(batchSize);
|
||||
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, draftIds, targetIds,
|
||||
sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
|
||||
curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxSeqLen, maxNumHeads,
|
||||
maxDecodingTokens);
|
||||
dim3 grid(params.batchSize);
|
||||
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, params.stream>>>(params.outputIds, params.draftIds,
|
||||
params.targetIds, params.sequenceLengths, params.acceptedLengths, params.finishedFinal, params.batchSlots,
|
||||
params.paths, params.endIds, params.medusaLogits, params.logitsPtrs, params.curTokensPerStep,
|
||||
params.targetTokensPerStep, params.bestPathIds, params.batchSize, params.vocabSize, params.maxBatchSize,
|
||||
params.maxSeqLen, params.maxDraftPathLen, params.maxDecodingTokens);
|
||||
}
|
||||
|
||||
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
|
||||
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
|
||||
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
|
||||
float const** medusaLogits, float const** logitsPtrs, SizeType32* curTokensPerStep,
|
||||
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
|
||||
SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
|
||||
cudaStream_t stream);
|
||||
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
|
||||
TokenIdType const* targetIds, SizeType32* sequenceLengths, SizeType32* acceptedLengths,
|
||||
FinishedState* finishedFinal, SizeType32 const* batchSlots, SizeType32 const* paths, TokenIdType const* endIds,
|
||||
half const** medusaLogits, half const** logitsPtrs, SizeType32* curTokensPerStep,
|
||||
SizeType32 const* targetTokensPerStep, SizeType32* bestPathIds, SizeType32 batchSize, SizeType32 vocabSize,
|
||||
SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxNumHeads, SizeType32 maxDecodingTokens,
|
||||
cudaStream_t stream);
|
||||
template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<float> const& params);
|
||||
template void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<__half> const& params);
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
@ -26,46 +26,87 @@
|
||||
namespace tensorrt_llm::kernels::speculative_decoding
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct AcceptDraftTokensByIdsWithPathsParams
|
||||
{
|
||||
//! output buffer [maxBatchSize, maxSeqLen], input tokens.
|
||||
runtime::TokenIdType* outputIds{nullptr};
|
||||
//! input buffer [maxBatchSize, maxDecodingTokens], draft tokens
|
||||
runtime::TokenIdType const* draftIds{nullptr};
|
||||
//! input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
|
||||
runtime::TokenIdType const* targetIds{nullptr};
|
||||
//! input/output buffer [maxBatchSize], optional.
|
||||
//! Length of the data in outputIds without draft tokens.
|
||||
//! If set, incrememnted according to the accepted length.
|
||||
runtime::SizeType32* sequenceLengths{nullptr};
|
||||
//! output buffer [maxBatchSize], length of the data accepted tokens
|
||||
runtime::SizeType32* acceptedLengths{nullptr};
|
||||
//! input buffer [maxBatchSize], optional. Finished states per request
|
||||
FinishedState* finishedFinal{nullptr};
|
||||
//! input buffer [batchSize], optional. Address map from local index
|
||||
//! to global index [0, batchSize] -> [0, maxBatchSize].
|
||||
//! If nullptr, batchIdx is used.
|
||||
runtime::SizeType32 const* batchSlots{nullptr};
|
||||
//! input buffer [maxBatchSize, maxDecodingTokens, maxDraftPathLen+1],
|
||||
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not
|
||||
//! path.
|
||||
runtime::SizeType32 const* paths{nullptr};
|
||||
//! input buffer [maxBatchSize], optional. EOS ids per request.
|
||||
//! No EOS checks if nullptr.
|
||||
runtime::TokenIdType const* endIds{nullptr};
|
||||
//! input buffer [maxDraftPathLen, maxBatchSize, maxDecodingTokens, vocabSize], optional.
|
||||
//! Pointer to the logits from medusa heads.
|
||||
T const** medusaLogits{nullptr};
|
||||
//! output buffer [batchSize, maxDraftPathLen], optional. Contains pointers to the
|
||||
//! respective rows of the medusaLogits for the next after the accepted token
|
||||
T const** logitsPtrs{nullptr};
|
||||
//! current tokens to compute per step will be updated to
|
||||
//! targetTokensPerStep if curTokensPerStep == 1
|
||||
runtime::SizeType32* curTokensPerStep{nullptr};
|
||||
//! target values of tokens to compute per step
|
||||
runtime::SizeType32 const* targetTokensPerStep{nullptr};
|
||||
//! output buffer [maxBatchSize], indices of the selected paths
|
||||
runtime::SizeType32* bestPathIds{nullptr};
|
||||
//! current batch size
|
||||
runtime::SizeType32 batchSize{0};
|
||||
//! maximum batch size
|
||||
runtime::SizeType32 maxBatchSize{0};
|
||||
//! vocab size
|
||||
runtime::SizeType32 vocabSize{0};
|
||||
//! maximum sequence length of output ids
|
||||
runtime::SizeType32 maxSeqLen{0};
|
||||
//! maximum number of medusa heads
|
||||
runtime::SizeType32 maxDraftPathLen{0};
|
||||
//! maximum number of tokens per step configured in the system
|
||||
runtime::SizeType32 maxDecodingTokens{0};
|
||||
//! stream
|
||||
cudaStream_t stream;
|
||||
|
||||
void checkParams() const
|
||||
{
|
||||
TLLM_CHECK(outputIds);
|
||||
TLLM_CHECK(draftIds);
|
||||
TLLM_CHECK(targetIds);
|
||||
TLLM_CHECK(acceptedLengths);
|
||||
TLLM_CHECK(paths);
|
||||
TLLM_CHECK(bestPathIds);
|
||||
TLLM_CHECK((curTokensPerStep == nullptr) ^ (targetTokensPerStep == nullptr) == 0);
|
||||
TLLM_CHECK((medusaLogits == nullptr) ^ (logitsPtrs == nullptr) == 0);
|
||||
|
||||
TLLM_CHECK(batchSize > 0);
|
||||
TLLM_CHECK(batchSize <= maxBatchSize);
|
||||
TLLM_CHECK(vocabSize > 0);
|
||||
TLLM_CHECK(maxSeqLen > 0);
|
||||
TLLM_CHECK(maxDraftPathLen > 0);
|
||||
TLLM_CHECK(maxDecodingTokens > 0);
|
||||
}
|
||||
};
|
||||
|
||||
//! \brief verifies draft medusa tokens given target tokens. Modifies outputIds tensor accordingly filling it with
|
||||
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
|
||||
//! to the next after the last accepted token.
|
||||
//!
|
||||
//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
|
||||
//! \param draftIds input buffer [maxBatchSize, maxDecodingTokens], draft tokens
|
||||
//! \param targetIds input buffer [maxBatchSize, maxDecodingTokens], tokens predicted from the target medusa head
|
||||
//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
|
||||
//! Incrememnted according to the accepted length
|
||||
//! \param acceptedLengths output buffer [maxBatchSize], length of the data accepted tokens
|
||||
//! \param finishedFinal input buffer [maxBatchSize], finished states per request
|
||||
//! \param batchSlots input buffer [batchSize], address map from local index
|
||||
//! to global index [0, batchSize] -> [0, maxBatchSize]
|
||||
//! \param paths input buffer [maxBatchSize, maxDecodingTokens, maxNumHeads+1],
|
||||
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not path.
|
||||
//! \param endIds input buffer [maxBatchSize], EOS ids per request
|
||||
//! \param medusaLogits input buffer [maxNumHeads, maxBatchSize, maxDecodingTokens, vocabSize], pointer
|
||||
//! to the logits from medusa heads
|
||||
//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
|
||||
//! respective rows of the medusaLogits for the next after the accepted token
|
||||
//! \param curTokensPerStep current tokens to compute per step will be updated to
|
||||
//! targetTokensPerStep if curTokensPerStep == 1
|
||||
//! \param targetTokensPerStep target values of tokens to compute per step
|
||||
//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
|
||||
//! \param batchSize current batch size
|
||||
//! \param maxBatchSize maximum batch size
|
||||
//! \param vocabSize vocab size
|
||||
//! \param maxSeqLen maximum sequence length of output ids
|
||||
//! \param maxNumHeads maximum number of medusa heads
|
||||
//! \param maxDecodingTokens maximum number of tokens per step configured in the system
|
||||
//! \param stream stream
|
||||
template <typename T>
|
||||
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
|
||||
runtime::TokenIdType const* targetIds, runtime::SizeType32* sequenceLengths, runtime::SizeType32* acceptedLengths,
|
||||
FinishedState* finishedFinal, runtime::SizeType32 const* batchSlots, runtime::SizeType32 const* paths,
|
||||
runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
|
||||
runtime::SizeType32* curTokensPerStep, runtime::SizeType32 const* targetTokensPerStep,
|
||||
runtime::SizeType32* bestPathIds, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
|
||||
runtime::SizeType32 vocabSize, runtime::SizeType32 maxSeqLen, runtime::SizeType32 maxNumHeads,
|
||||
runtime::SizeType32 maxDecodingTokens, cudaStream_t stream);
|
||||
void acceptDraftTokensByIdsWithPaths(AcceptDraftTokensByIdsWithPathsParams<T> const&);
|
||||
|
||||
//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
|
||||
//!
|
||||
|
||||
@ -507,15 +507,12 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
VecType k_to_cache = params.position_shift_enabled ? k_wo_pos : k;
|
||||
|
||||
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * params.q_hidden_size + hidden_idx;
|
||||
QuantizedEltType* quantized_q_ptr = STORE_QKV
|
||||
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
|
||||
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
|
||||
VecType* q_ptr = STORE_QKV ? reinterpret_ptr<T, VecType>(params.QKV, src_q_idx)
|
||||
: reinterpret_ptr<T, VecType>(params.Q, dst_q_idx);
|
||||
|
||||
// Cast float scale to dst data type.
|
||||
using TScale = typename mmha::kv_cache_scale_type_t<T, TCache>::Type;
|
||||
TScale scaleOrigQuant;
|
||||
[[maybe_unused]] TScale scaleOrigQuant;
|
||||
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
|
||||
{
|
||||
mmha::convert_from_float(
|
||||
@ -525,6 +522,9 @@ __global__ void applyBiasRopeUpdateKVCache(QKVPreprocessingParams<T, KVCacheBuff
|
||||
if constexpr (FP8_OUTPUT)
|
||||
{
|
||||
// Quant the vec to fp8 vec with the scale.
|
||||
QuantizedEltType* quantized_q_ptr = STORE_QKV
|
||||
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
|
||||
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
|
||||
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
|
||||
}
|
||||
else
|
||||
@ -813,15 +813,12 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
|
||||
if (valid_token)
|
||||
{
|
||||
auto const dst_q_idx = static_cast<size_t>(global_token_idx) * params.q_hidden_size + hidden_idx;
|
||||
QuantizedEltType* quantized_q_ptr = STORE_QKV
|
||||
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
|
||||
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
|
||||
VecT* q_ptr = STORE_QKV ? reinterpret_ptr<T, VecT>(params.QKV, src_q_idx)
|
||||
: reinterpret_ptr<T, VecT>(params.Q, dst_q_idx);
|
||||
|
||||
// Cast float scale to dst data type.
|
||||
using TScale = typename mmha::kv_cache_scale_type_t<T, TCache>::Type;
|
||||
TScale scaleOrigQuant;
|
||||
[[maybe_unused]] TScale scaleOrigQuant;
|
||||
if constexpr (FP8_OUTPUT || ENABLE_8BITS_CACHE)
|
||||
{
|
||||
mmha::convert_from_float(&scaleOrigQuant, params.kvScaleOrigQuant ? params.kvScaleOrigQuant[0] : 1.0f);
|
||||
@ -830,6 +827,9 @@ __global__ void applyBiasRopeUpdateKVCacheV2(QKVPreprocessingParams<T, KVCacheBu
|
||||
if constexpr (FP8_OUTPUT)
|
||||
{
|
||||
// Quant the vec to fp8 vec with the scale.
|
||||
QuantizedEltType* quantized_q_ptr = STORE_QKV
|
||||
? reinterpret_cast<QuantizedEltType*>(params.QuantizedQKV) + src_q_idx
|
||||
: reinterpret_cast<QuantizedEltType*>(params.Q) + dst_q_idx;
|
||||
mmha::store_8bits_vec(quantized_q_ptr, q, 0, scaleOrigQuant);
|
||||
}
|
||||
else
|
||||
|
||||
@ -32,6 +32,8 @@ namespace weight_only
|
||||
{
|
||||
enum class KernelType
|
||||
{
|
||||
FP16Int8Groupwise,
|
||||
BF16Int8Groupwise,
|
||||
FP16Int4Groupwise,
|
||||
BF16Int4Groupwise,
|
||||
FP16Int8PerChannel,
|
||||
@ -49,6 +51,8 @@ struct kernel_type_traits;
|
||||
static constexpr bool isGroupwise = _isGroupwise; \
|
||||
static constexpr bool isInt4 = _isInt4; \
|
||||
};
|
||||
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8Groupwise, true, false);
|
||||
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int8Groupwise, true, false);
|
||||
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int4Groupwise, true, true);
|
||||
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::BF16Int4Groupwise, true, true);
|
||||
KERNEL_TYPE_TRAITS_REGISTRY(KernelType::FP16Int8PerChannel, false, false);
|
||||
|
||||
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
|
||||
KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
|
||||
KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
|
||||
KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false, 64);
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/kernelDispatcher.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
INSTANTIATE_WEIGHT_ONLY_CUDA_DISPATCHERS(
|
||||
KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true, 64);
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -61,6 +61,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
|
||||
{
|
||||
EXEC_W4A8(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
|
||||
}
|
||||
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
|
||||
EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
|
||||
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
|
||||
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajorInterleaved, true);
|
||||
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajorInterleaved, true);
|
||||
@ -70,6 +72,8 @@ inline void kernel_launcher(int arch, Params& params, cudaStream_t s)
|
||||
}
|
||||
else if (arch >= 90)
|
||||
{
|
||||
EXEC(KernelType::FP16Int8Groupwise, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
|
||||
EXEC(KernelType::BF16Int8Groupwise, BF16DetailsA, Int8DetailsW, ColumnMajor, false);
|
||||
EXEC(KernelType::FP16Int4Groupwise, FP16DetailsA, Int4DetailsW, ColumnMajor, false);
|
||||
EXEC(KernelType::BF16Int4Groupwise, BF16DetailsA, Int4DetailsW, ColumnMajor, false);
|
||||
EXEC(KernelType::FP16Int8PerChannel, FP16DetailsA, Int8DetailsW, ColumnMajor, false);
|
||||
@ -98,6 +102,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
|
||||
}
|
||||
else if (arch >= 80 && arch < 90)
|
||||
{
|
||||
SUPPORT(KernelType::FP16Int8Groupwise);
|
||||
SUPPORT(KernelType::BF16Int8Groupwise);
|
||||
SUPPORT(KernelType::FP16Int4Groupwise);
|
||||
SUPPORT(KernelType::BF16Int4Groupwise);
|
||||
SUPPORT(KernelType::FP16Int8PerChannel);
|
||||
@ -107,6 +113,8 @@ inline bool is_supported(int arch, KernelType kernel_type)
|
||||
}
|
||||
else if (arch >= 90)
|
||||
{
|
||||
SUPPORT(KernelType::FP16Int8Groupwise);
|
||||
SUPPORT(KernelType::BF16Int8Groupwise);
|
||||
SUPPORT(KernelType::FP16Int4Groupwise);
|
||||
SUPPORT(KernelType::BF16Int4Groupwise);
|
||||
SUPPORT(KernelType::FP16Int8PerChannel);
|
||||
|
||||
@ -431,7 +431,7 @@ void ExternalDraftTokensLayer<T>::getAllTopKs(std::shared_ptr<BaseDecodingOutput
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.maxTokensPerStep = 1;
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.returnAllTopK = true;
|
||||
params.returnAllSelectedTokens = true;
|
||||
params.maxSeqLen = mDecoderDomain.getVocabSizePadded(); // workaround for returning all topKs with outputIds
|
||||
params.logitsHasProbs = inputs->probsComputed;
|
||||
|
||||
@ -475,7 +475,7 @@ void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutput
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.returnAllTopP = true;
|
||||
params.returnAllSelectedTokens = true;
|
||||
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();
|
||||
|
||||
invokeBatchTopPSampling<T>(params, getStream());
|
||||
|
||||
@ -76,6 +76,8 @@ LookaheadDecodingLayer<T>::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD
|
||||
ITensor::makeShape({maxTokensPerStep, maxBatchSize, beamWidth}), nvinfer1::DataType::kINT32);
|
||||
mPathsOffsets
|
||||
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mPathsOffsetsBatch
|
||||
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mNumNewTokens = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mNumNewTokensCumSum = BufferManager::cpu(ITensor::makeShape({maxBatchSize + 1}), nvinfer1::DataType::kINT32);
|
||||
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
@ -220,7 +222,7 @@ void LookaheadDecodingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.batchSize = batchSize;
|
||||
params.maxTopK = 1;
|
||||
params.returnAllTopK = true;
|
||||
params.returnAllSelectedTokens = true;
|
||||
params.maxTokensPerStep = mDecoderDomain.getMaxDecodingTokens();
|
||||
params.maxSeqLen = mDecoderDomain.getMaxDecodingTokens();
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
@ -321,6 +323,7 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
|
||||
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
|
||||
BufferLocation<SizeType32> pathsOffsetLocation(*mCpuAlgo->mPathsOffsets);
|
||||
BufferLocation<SizeType32> pathsOffsetBatchLocation(*mCpuAlgo->mPathsOffsetsBatch);
|
||||
BufferLocation<TokenIdType> outputIdsLocation(*mCpuAlgo->mOutputIds);
|
||||
|
||||
mBufferManager->setZero(*mCpuAlgo->mPathsOffsets);
|
||||
@ -394,20 +397,22 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
D(accepted).values().c_str(), D(draft).values().c_str());
|
||||
}
|
||||
|
||||
numNewTokensCumSumRange[0] = 0;
|
||||
SizeType32 pi = 0;
|
||||
for (SizeType32 bi = 0; bi < numNewTokensRange.size(); bi++)
|
||||
numNewTokensCumSumRange[0] = 0;
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 acceptedDraftLen = numNewTokensRange[bi] <= 1 ? 0 : (numNewTokensRange[bi] - 1);
|
||||
SizeType32 gbi = batchSlotsRange[bi];
|
||||
SizeType32 acceptedDraftLen = numNewTokensRange[gbi] <= 1 ? 0 : (numNewTokensRange[gbi] - 1);
|
||||
numNewTokensCumSumRange[bi + 1] = numNewTokensCumSumRange[bi] + acceptedDraftLen;
|
||||
for (SizeType32 tj = 0; tj < acceptedDraftLen; tj++)
|
||||
{
|
||||
pathsOffsetLocation[pi++] = pathsOffsetLocation.at(bi, tj);
|
||||
pathsOffsetBatchLocation[pi++] = pathsOffsetLocation.at(gbi, tj);
|
||||
}
|
||||
}
|
||||
for (; pi < pathsOffsetLocation.size(); pi++)
|
||||
|
||||
for (; pi < pathsOffsetBatchLocation.size(); pi++)
|
||||
{
|
||||
pathsOffsetLocation[pi++] = 0;
|
||||
pathsOffsetBatchLocation[pi++] = 0;
|
||||
}
|
||||
|
||||
TLLM_CHECK(outputs->numNewTokens);
|
||||
@ -415,8 +420,8 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
mBufferManager->copy(*mCpuAlgo->mSequenceLengths, *outputs->sequenceLength.value());
|
||||
mBufferManager->copy(*mCpuAlgo->mNewTokens, *outputs->newTokens);
|
||||
|
||||
mBufferManager->copy(*mCpuAlgo->mPathsOffsets, *outputs->pathsOffsets);
|
||||
mBufferManager->copy(*mCpuAlgo->mNumNewTokens, *outputs->numNewTokens.value());
|
||||
mBufferManager->copy(*mCpuAlgo->mPathsOffsetsBatch, *outputs->pathsOffsets);
|
||||
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); //
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens);
|
||||
|
||||
|
||||
@ -70,6 +70,7 @@ private:
|
||||
|
||||
TensorPtr mOutputIds;
|
||||
TensorPtr mPathsOffsets;
|
||||
TensorPtr mPathsOffsetsBatch;
|
||||
TensorPtr mNumNewTokens;
|
||||
TensorPtr mNumNewTokensCumSum;
|
||||
TensorPtr mNewTokens;
|
||||
|
||||
@ -329,11 +329,33 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(SpeculativeDecodingOutputs const&
|
||||
auto medusaInputLogitsPtrsPtr = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
|
||||
auto medusaSelectedLogitsPtrsDevicePtr
|
||||
= const_cast<T const**>(bufferCastOrNull<T const*>(mMedusaSelectedLogitsPtrsDevice));
|
||||
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, targetTokensDevicePtr, sequenceLengths, numNewTokens,
|
||||
finishedStatesPtr, workspace->getDeviceBatchSlotsPtr(), paths, endIds, medusaInputLogitsPtrsPtr,
|
||||
medusaSelectedLogitsPtrsDevicePtr, curTokensPerStepDevice, targetTokensPerStepDevice, bestPathIdsDevicePtr,
|
||||
batchSize, mDecoderDomain.getVocabSize(), mDecoderDomain.getBatchSize(), maxSeqLen, maxDraftPathLen,
|
||||
mDecoderDomain.getMaxDecodingTokens(), getStream());
|
||||
|
||||
AcceptDraftTokensByIdsWithPathsParams<T> params;
|
||||
params.outputIds = outputIds;
|
||||
params.draftIds = draftIds;
|
||||
params.targetIds = targetTokensDevicePtr;
|
||||
params.sequenceLengths = sequenceLengths;
|
||||
params.acceptedLengths = numNewTokens;
|
||||
params.finishedFinal = finishedStatesPtr;
|
||||
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
|
||||
params.paths = paths;
|
||||
params.endIds = endIds;
|
||||
params.medusaLogits = medusaInputLogitsPtrsPtr;
|
||||
params.logitsPtrs = medusaSelectedLogitsPtrsDevicePtr;
|
||||
params.curTokensPerStep = curTokensPerStepDevice;
|
||||
params.targetTokensPerStep = targetTokensPerStepDevice;
|
||||
params.bestPathIds = bestPathIdsDevicePtr;
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.vocabSize = mDecoderDomain.getVocabSize();
|
||||
params.maxSeqLen = maxSeqLen;
|
||||
params.maxDraftPathLen = maxDraftPathLen;
|
||||
params.maxDecodingTokens = mDecoderDomain.getMaxDecodingTokens();
|
||||
params.stream = getStream();
|
||||
|
||||
params.checkParams();
|
||||
|
||||
acceptDraftTokensByIdsWithPaths(params);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -390,7 +412,7 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(SpeculativeDecodingOutputs con
|
||||
params.maxBatchSize = maxBatchSizeHeadNums;
|
||||
params.maxTokensPerStep = 1;
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.returnAllTopK = true;
|
||||
params.returnAllSelectedTokens = true;
|
||||
|
||||
invokeBatchTopKSampling(params, getStream());
|
||||
|
||||
|
||||
@ -54,7 +54,8 @@ set(PLUGIN_LISTS
|
||||
mambaConv1dPlugin
|
||||
lruPlugin
|
||||
cumsumLastDimPlugin
|
||||
lowLatencyGemmPlugin)
|
||||
lowLatencyGemmPlugin
|
||||
eaglePlugin)
|
||||
|
||||
foreach(PLUGIN_ITER ${PLUGIN_LISTS})
|
||||
include_directories(${PLUGIN_ITER})
|
||||
|
||||
@ -39,6 +39,9 @@
|
||||
#include "tensorrt_llm/plugins/ncclPlugin/sendPlugin.h"
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
#include "tensorrt_llm/plugins/cumsumLastDimPlugin/cumsumLastDimPlugin.h"
|
||||
#include "tensorrt_llm/plugins/eaglePlugin/eagleDecodeDraftTokensPlugin.h"
|
||||
#include "tensorrt_llm/plugins/eaglePlugin/eaglePrepareDrafterInputsPlugin.h"
|
||||
#include "tensorrt_llm/plugins/eaglePlugin/eagleSampleAndAcceptDraftTokensPlugin.h"
|
||||
#include "tensorrt_llm/plugins/lowLatencyGemmPlugin/lowLatencyGemmPlugin.h"
|
||||
#include "tensorrt_llm/plugins/quantizePerTokenPlugin/quantizePerTokenPlugin.h"
|
||||
#include "tensorrt_llm/plugins/quantizeTensorPlugin/quantizeTensorPlugin.h"
|
||||
@ -201,6 +204,10 @@ extern "C"
|
||||
static tensorrt_llm::plugins::lruPluginCreator lruPluginCreator;
|
||||
static tensorrt_llm::plugins::CumsumLastDimPluginCreator cumsumLastDimPluginCreator;
|
||||
static tensorrt_llm::plugins::LowLatencyGemmPluginCreator lowLatencyGemmPluginCreator;
|
||||
static tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator eagleDecodeDraftTokensPluginCreator;
|
||||
static tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator
|
||||
eagleSampleAndAcceptDraftTokensPluginCreator;
|
||||
static tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator eaglePrepareDrafterInputsPluginCreator;
|
||||
|
||||
static std::array pluginCreators
|
||||
= { creatorPtr(identityPluginCreator),
|
||||
@ -231,6 +238,9 @@ extern "C"
|
||||
creatorPtr(lruPluginCreator),
|
||||
creatorPtr(cumsumLastDimPluginCreator),
|
||||
creatorPtr(lowLatencyGemmPluginCreator),
|
||||
creatorPtr(eagleDecodeDraftTokensPluginCreator),
|
||||
creatorPtr(eagleSampleAndAcceptDraftTokensPluginCreator),
|
||||
creatorPtr(eaglePrepareDrafterInputsPluginCreator),
|
||||
};
|
||||
nbCreators = pluginCreators.size();
|
||||
return pluginCreators.data();
|
||||
|
||||
21
cpp/tensorrt_llm/plugins/eaglePlugin/CMakeLists.txt
Normal file
21
cpp/tensorrt_llm/plugins/eaglePlugin/CMakeLists.txt
Normal file
@ -0,0 +1,21 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
#
|
||||
file(GLOB SRCS *.cpp)
|
||||
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
|
||||
set(PLUGIN_SOURCES
|
||||
${PLUGIN_SOURCES}
|
||||
PARENT_SCOPE)
|
||||
@ -0,0 +1,228 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "eagleDecodeDraftTokensPlugin.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using tensorrt_llm::plugins::EagleDecodeDraftTokensPluginCreator;
|
||||
using tensorrt_llm::plugins::EagleDecodeDraftTokensPlugin;
|
||||
|
||||
static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
|
||||
static char const* EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME{"EagleDecodeDraftTokens"};
|
||||
PluginFieldCollection EagleDecodeDraftTokensPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> EagleDecodeDraftTokensPluginCreator::mPluginAttributes;
|
||||
|
||||
EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx)
|
||||
: mDtype(type)
|
||||
, mLayerIdx(layerIdx)
|
||||
{
|
||||
}
|
||||
|
||||
// Parameterized constructor
|
||||
EagleDecodeDraftTokensPlugin::EagleDecodeDraftTokensPlugin(void const* data, size_t length)
|
||||
{
|
||||
char const *d = reinterpret_cast<char const*>(data), *a = d;
|
||||
read(d, mDtype);
|
||||
read(d, mLayerIdx);
|
||||
TLLM_CHECK_WITH_INFO(d == a + length,
|
||||
"Expected length (%d) != real length (%d). This is often "
|
||||
"caused by using different TensorRT-LLM version to build "
|
||||
"engine and run engine.",
|
||||
static_cast<int>(length), static_cast<int>(d - a));
|
||||
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* EagleDecodeDraftTokensPlugin::clone() const noexcept
|
||||
{
|
||||
auto* plugin = new EagleDecodeDraftTokensPlugin(*this);
|
||||
plugin->setPluginNamespace(mNamespace.c_str());
|
||||
return plugin;
|
||||
}
|
||||
|
||||
nvinfer1::DimsExprs EagleDecodeDraftTokensPlugin::getOutputDimensions(
|
||||
int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
|
||||
{
|
||||
TLLM_CHECK(outputIndex < 2);
|
||||
TLLM_CHECK(nbInputs == 5);
|
||||
return inputs[outputIndex + 1];
|
||||
}
|
||||
|
||||
bool EagleDecodeDraftTokensPlugin::supportsFormatCombination(
|
||||
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
|
||||
{
|
||||
if (pos == 0) // logits
|
||||
{
|
||||
return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
else if (pos == 3) // rand_data_sample
|
||||
{
|
||||
return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
else // next_draft_tokens, next_draft_lens, paths, tree_indices
|
||||
{
|
||||
return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
|
||||
}
|
||||
}
|
||||
|
||||
void EagleDecodeDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
|
||||
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
|
||||
{
|
||||
}
|
||||
|
||||
size_t EagleDecodeDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
int EagleDecodeDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
// TODO fill me
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType EagleDecodeDraftTokensPlugin::getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_CHECK(index < 2);
|
||||
return inputTypes[index + 1];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
char const* EagleDecodeDraftTokensPlugin::getPluginType() const noexcept
|
||||
{
|
||||
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* EagleDecodeDraftTokensPlugin::getPluginVersion() const noexcept
|
||||
{
|
||||
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int EagleDecodeDraftTokensPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
return 2;
|
||||
}
|
||||
|
||||
int EagleDecodeDraftTokensPlugin::initialize() noexcept
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void EagleDecodeDraftTokensPlugin::terminate() noexcept {}
|
||||
|
||||
size_t EagleDecodeDraftTokensPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mDtype) + sizeof(mLayerIdx);
|
||||
}
|
||||
|
||||
void EagleDecodeDraftTokensPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mLayerIdx);
|
||||
write(d, mDtype);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
void EagleDecodeDraftTokensPlugin::destroy() noexcept
|
||||
{
|
||||
// This gets called when the network containing plugin is destroyed
|
||||
delete this;
|
||||
}
|
||||
|
||||
///////////////
|
||||
|
||||
EagleDecodeDraftTokensPluginCreator::EagleDecodeDraftTokensPluginCreator()
|
||||
{
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 0));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
char const* EagleDecodeDraftTokensPluginCreator::getPluginName() const noexcept
|
||||
{
|
||||
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* EagleDecodeDraftTokensPluginCreator::getPluginVersion() const noexcept
|
||||
{
|
||||
return EAGLE_DECODE_DRAFT_TOKENS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
PluginFieldCollection const* EagleDecodeDraftTokensPluginCreator::getFieldNames() noexcept
|
||||
{
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2* EagleDecodeDraftTokensPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
int32_t layerIdx;
|
||||
nvinfer1::DataType type;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
char const* attrName = fields[i].name;
|
||||
if (!strcmp(attrName, "layer_idx"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
layerIdx = *static_cast<int32_t const*>(fields[i].data);
|
||||
}
|
||||
else if (!strcmp(attrName, "type_id"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto* obj = new EagleDecodeDraftTokensPlugin(type, layerIdx);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IPluginV2* EagleDecodeDraftTokensPluginCreator::deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept
|
||||
{
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call EagleDecodeDraftTokensPlugin::destroy()
|
||||
try
|
||||
{
|
||||
auto* obj = new EagleDecodeDraftTokensPlugin(serialData, serialLength);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
/// TensorRT plugin that decodes draft tokens for EAGLE speculative decoding.
/// Implements the IPluginV2DynamicExt interface via the project BasePlugin.
class EagleDecodeDraftTokensPlugin : public BasePlugin
{
public:
    /// @param type     activation/logits data type the plugin operates on.
    /// @param layerIdx index of the drafting layer this plugin instance serves.
    EagleDecodeDraftTokensPlugin(nvinfer1::DataType type, int32_t layerIdx);

    /// Deserializing constructor; `data`/`length` come from serialize().
    EagleDecodeDraftTokensPlugin(void const* data, size_t length);

    ~EagleDecodeDraftTokensPlugin() override = default;

    // IPluginV2DynamicExt Methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
    int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // IPluginV2Ext Methods
    nvinfer1::DataType getOutputDataType(
        int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;

    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    int initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;

private:
    nvinfer1::DataType mDtype; // logits/activation dtype (serialized)
    int32_t mLayerIdx;         // drafting layer index (serialized)
};
|
||||
|
||||
/// Factory registered with TensorRT that builds EagleDecodeDraftTokensPlugin
/// instances from field collections or serialized engine blobs.
class EagleDecodeDraftTokensPluginCreator : public BaseCreator
{
public:
    EagleDecodeDraftTokensPluginCreator();

    char const* getPluginName() const noexcept override;

    char const* getPluginVersion() const noexcept override;

    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;

    nvinfer1::IPluginV2* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;

private:
    // Shared across all creator instances; filled once in the constructor.
    static nvinfer1::PluginFieldCollection mFC;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
|
||||
|
||||
} // namespace tensorrt_llm::plugins
|
||||
@ -0,0 +1,272 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "eaglePrepareDrafterInputsPlugin.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using tensorrt_llm::plugins::EaglePrepareDrafterInputsPluginCreator;
|
||||
using tensorrt_llm::plugins::EaglePrepareDrafterInputsPlugin;
|
||||
|
||||
static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION{"1"};
|
||||
static char const* EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME{"EaglePrepareDrafterInputs"};
|
||||
PluginFieldCollection EaglePrepareDrafterInputsPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> EaglePrepareDrafterInputsPluginCreator::mPluginAttributes;
|
||||
|
||||
/// Constructs the plugin from explicit configuration.
/// @param type     data type of the hidden-states tensors.
/// @param layerIdx index of the EAGLE drafting layer this instance serves.
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx)
    : mDtype(type)
    , mLayerIdx(layerIdx)
{
}
|
||||
|
||||
// Parameterized constructor
|
||||
// Deserializing constructor: reads fields in the fixed order
// (mDtype, mLayerIdx). serialize() must write them in the same order.
EaglePrepareDrafterInputsPlugin::EaglePrepareDrafterInputsPlugin(void const* data, size_t length)
{
    char const *d = reinterpret_cast<char const*>(data), *a = d;
    read(d, mDtype);
    read(d, mLayerIdx);
    // Guard against version skew: the blob must be consumed exactly.
    TLLM_CHECK_WITH_INFO(d == a + length,
        "Expected length (%d) != real length (%d). This is often "
        "caused by using different TensorRT-LLM version to build "
        "engine and run engine.",
        static_cast<int>(length), static_cast<int>(d - a));
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
/// Returns a deep copy of this plugin (TensorRT owns and destroys it).
nvinfer1::IPluginV2DynamicExt* EaglePrepareDrafterInputsPlugin::clone() const noexcept
{
    auto* plugin = new EaglePrepareDrafterInputsPlugin(*this);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
|
||||
|
||||
/// Computes symbolic output shapes for each of the plugin's 9 outputs.
///
/// @param outputIndex output slot in [0, getNbOutputs()).
/// @param inputs      7 input shape expressions; inputs[nbInputs-2] is assumed
///                    to be [batchSize, maxDraftLen] — TODO confirm with caller.
nvinfer1::DimsExprs EaglePrepareDrafterInputsPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    // getNbOutputs() returns 9, so valid indices are [0, 9). The previous
    // check (< 10) admitted index 9, for which the switch below has no case
    // and `ret` would be returned uninitialized.
    TLLM_CHECK(outputIndex < 9);
    TLLM_CHECK(nbInputs == 7);
    auto const batchSizeExpr = inputs[nbInputs - 2].d[0];
    auto const maxDraftLenExpr = inputs[nbInputs - 2].d[1];

    nvinfer1::DimsExprs ret;
    switch (outputIndex)
    {
    case 0: // sequence_length
    case 1: // host_request_types
    case 2: // host_past_key_value_lengths
        // Pass-through shapes from the corresponding inputs.
        ret = inputs[outputIndex];
        break;
    case 3: // spec_decoding_generation_lengths
        ret.nbDims = 1;
        ret.d[0] = batchSizeExpr;
        break;
    case 4: // spec_decoding_position_offsets
    case 5: // input_ids
    case 6: // position_ids
        // FIXME input_ids should have real value, not maxDraftLen
        ret.nbDims = 1;
        ret.d[0] = maxDraftLenExpr;
        break;
    case 7: // spec_decoding_packed_mask
        // FIXME
        ret.nbDims = 3;
        ret.d[0] = batchSizeExpr;
        ret.d[1] = maxDraftLenExpr;
        // Mask is bit-packed 32 tokens per int32 element.
        ret.d[2] = exprBuilder.operation(DimensionOperation::kCEIL_DIV, *maxDraftLenExpr, *exprBuilder.constant(32));
        break;
    case 8: // hidden_dim
        ret.nbDims = 2;
        // FIXME real dim instead of max draft len
        ret.d[0] = maxDraftLenExpr;
        ret.d[1] = inputs[4].d[1];
        break;
    }
    return ret;
}
|
||||
|
||||
/// Declares supported dtype/format per tensor position: hidden-states tensors
/// use the configured mDtype, the kv-cache pool pointers are int64, and every
/// other tensor is int32; all must be linear-format.
bool EaglePrepareDrafterInputsPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (pos == nbInputs - 1 || pos == nbInputs + nbOutputs - 1) // hidden_states
    {
        return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    else if (pos == 3) // kv cache pool pointers
    {
        return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
    }
    else // all other tensors
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
}
|
||||
|
||||
/// No shape-dependent configuration is needed; intentionally a no-op.
void EaglePrepareDrafterInputsPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
|
||||
|
||||
/// This plugin requires no scratch workspace.
size_t EaglePrepareDrafterInputsPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    return 0;
}
|
||||
|
||||
/// Runtime execution entry point. Currently an unimplemented stub that
/// reports success without writing any outputs (see TODO below).
int EaglePrepareDrafterInputsPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    // TODO fill me

    return 0;
}
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
/// Output dtype: outputs 0-7 mirror input 0 (int32 metadata tensors),
/// output 8 mirrors the last input (hidden_states dtype).
nvinfer1::DataType EaglePrepareDrafterInputsPlugin::getOutputDataType(
    int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
{
    TLLM_CHECK(index < 9);
    if (index < 8)
    {
        return inputTypes[0]; // type of sequence_length
    }
    else // hidden_states
    {
        return inputTypes[nbInputs - 1]; // type of hidden_states
    }
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
/// Returns the registry name of this plugin type.
char const* EaglePrepareDrafterInputsPlugin::getPluginType() const noexcept
{
    return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
}
|
||||
|
||||
/// Returns the registry version string of this plugin type.
char const* EaglePrepareDrafterInputsPlugin::getPluginVersion() const noexcept
{
    return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
}
|
||||
|
||||
/// The plugin produces 9 output tensors (see getOutputDimensions cases 0-8).
int EaglePrepareDrafterInputsPlugin::getNbOutputs() const noexcept
{
    return 9;
}
|
||||
|
||||
/// No per-engine resources to acquire; always succeeds.
int EaglePrepareDrafterInputsPlugin::initialize() noexcept
{
    return 0;
}
|
||||
|
||||
/// Counterpart of initialize(); nothing to release.
void EaglePrepareDrafterInputsPlugin::terminate() noexcept {}
|
||||
|
||||
/// Serialized blob holds exactly the two configuration members.
size_t EaglePrepareDrafterInputsPlugin::getSerializationSize() const noexcept
{
    return sizeof(mDtype) + sizeof(mLayerIdx);
}
|
||||
|
||||
void EaglePrepareDrafterInputsPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mLayerIdx);
|
||||
write(d, mDtype);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
/// Self-deletes; TensorRT calls this when the owning network is destroyed.
void EaglePrepareDrafterInputsPlugin::destroy() noexcept
{
    // This gets called when the network containing plugin is destroyed
    delete this;
}
|
||||
|
||||
///////////////
|
||||
|
||||
/// Registers the creation-time attributes ("type_id", "layer_idx") exposed
/// through getFieldNames().
EaglePrepareDrafterInputsPluginCreator::EaglePrepareDrafterInputsPluginCreator()
{
    // Fill PluginFieldCollection with PluginField arguments metadata.
    mPluginAttributes.clear();
    mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
    // layer_idx carries a single int32, so its length is 1 — previously
    // declared as 0, inconsistent with "type_id" above.
    mPluginAttributes.emplace_back(PluginField("layer_idx", nullptr, PluginFieldType::kINT32, 1));
    mFC.nbFields = mPluginAttributes.size();
    mFC.fields = mPluginAttributes.data();
}
|
||||
|
||||
/// Returns the registry name matching EaglePrepareDrafterInputsPlugin::getPluginType().
char const* EaglePrepareDrafterInputsPluginCreator::getPluginName() const noexcept
{
    return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_NAME;
}
|
||||
|
||||
/// Returns the registry version matching EaglePrepareDrafterInputsPlugin::getPluginVersion().
char const* EaglePrepareDrafterInputsPluginCreator::getPluginVersion() const noexcept
{
    return EAGLE_PREPARE_DRAFTER_INPUTS_PLUGIN_VERSION;
}
|
||||
|
||||
/// Returns the static field-collection populated in the creator's constructor.
PluginFieldCollection const* EaglePrepareDrafterInputsPluginCreator::getFieldNames() noexcept
{
    return &mFC;
}
|
||||
|
||||
/// Creates a new EaglePrepareDrafterInputsPlugin from a PluginFieldCollection.
///
/// @param name unused plugin instance name supplied by TensorRT.
/// @param fc   fields expected to contain "layer_idx" (int32) and
///             "type_id" (int32 encoding an nvinfer1::DataType).
/// @return a heap-allocated plugin (owned by TensorRT) or nullptr on error.
IPluginV2* EaglePrepareDrafterInputsPluginCreator::createPlugin(
    char const* name, PluginFieldCollection const* fc) noexcept
{
    PluginField const* fields = fc->fields;
    // Brace-initialize to deterministic defaults: previously these locals were
    // read uninitialized (UB) if the collection omitted a field.
    int32_t layerIdx{0};
    nvinfer1::DataType type{nvinfer1::DataType::kFLOAT};
    // Read configurations from each field.
    for (int i = 0; i < fc->nbFields; ++i)
    {
        char const* attrName = fields[i].name;
        if (!strcmp(attrName, "layer_idx"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            layerIdx = *static_cast<int32_t const*>(fields[i].data);
        }
        else if (!strcmp(attrName, "type_id"))
        {
            TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
            type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
        }
    }

    try
    {
        auto* obj = new EaglePrepareDrafterInputsPlugin(type, layerIdx);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
|
||||
|
||||
/// Reconstructs a plugin instance from a serialized engine blob.
///
/// @param serialData   buffer produced by EaglePrepareDrafterInputsPlugin::serialize.
/// @param serialLength size of that buffer in bytes.
/// @return a heap-allocated plugin or nullptr if construction throws.
IPluginV2* EaglePrepareDrafterInputsPluginCreator::deserializePlugin(
    char const* name, void const* serialData, size_t serialLength) noexcept
{
    // This object will be deleted when the network is destroyed, which will
    // call EaglePrepareDrafterInputsPlugin::destroy()
    try
    {
        auto* obj = new EaglePrepareDrafterInputsPlugin(serialData, serialLength);
        obj->setPluginNamespace(mNamespace.c_str());
        return obj;
    }
    catch (std::exception const& e)
    {
        caughtError(e);
    }
    return nullptr;
}
||||
}
|
||||
@ -0,0 +1,90 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
#include <cassert>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
/// TensorRT plugin that assembles the input tensors for an EAGLE drafting
/// layer (sequence lengths, position offsets, packed masks, hidden states).
/// Implements the IPluginV2DynamicExt interface via the project BasePlugin.
class EaglePrepareDrafterInputsPlugin : public BasePlugin
{
public:
    /// @param type     dtype of the hidden-states tensors.
    /// @param layerIdx index of the drafting layer this instance serves.
    EaglePrepareDrafterInputsPlugin(nvinfer1::DataType type, int32_t layerIdx);

    /// Deserializing constructor; `data`/`length` come from serialize().
    EaglePrepareDrafterInputsPlugin(void const* data, size_t length);

    ~EaglePrepareDrafterInputsPlugin() override = default;

    // IPluginV2DynamicExt Methods
    nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
    nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
        nvinfer1::IExprBuilder& exprBuilder) noexcept override;
    bool supportsFormatCombination(
        int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
    void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
        nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
    size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
        nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
    int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
        void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;

    // IPluginV2Ext Methods
    nvinfer1::DataType getOutputDataType(
        int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;

    // IPluginV2 Methods
    char const* getPluginType() const noexcept override;
    char const* getPluginVersion() const noexcept override;
    int getNbOutputs() const noexcept override;
    int initialize() noexcept override;
    void terminate() noexcept override;
    size_t getSerializationSize() const noexcept override;
    void serialize(void* buffer) const noexcept override;
    void destroy() noexcept override;

private:
    nvinfer1::DataType mDtype; // hidden-states dtype (serialized)
    int32_t mLayerIdx;         // drafting layer index (serialized)
};
|
||||
|
||||
/// Factory registered with TensorRT that builds EaglePrepareDrafterInputsPlugin
/// instances from field collections or serialized engine blobs.
class EaglePrepareDrafterInputsPluginCreator : public BaseCreator
{
public:
    EaglePrepareDrafterInputsPluginCreator();

    char const* getPluginName() const noexcept override;

    char const* getPluginVersion() const noexcept override;

    nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;

    nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;

    nvinfer1::IPluginV2* deserializePlugin(
        char const* name, void const* serialData, size_t serialLength) noexcept override;

private:
    // Shared across all creator instances; filled once in the constructor.
    static nvinfer1::PluginFieldCollection mFC;
    static std::vector<nvinfer1::PluginField> mPluginAttributes;
};
|
||||
|
||||
} // namespace tensorrt_llm::plugins
|
||||
@ -0,0 +1,515 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION &
|
||||
* AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "eagleSampleAndAcceptDraftTokensPlugin.h"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/medusaDecodingKernels.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
using namespace nvinfer1;
|
||||
using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPluginCreator;
|
||||
using tensorrt_llm::plugins::EagleSampleAndAcceptDraftTokensPlugin;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::kernels::speculative_decoding;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION{"1"};
|
||||
static char const* EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME{"EagleSampleAndAcceptDraftTokens"};
|
||||
PluginFieldCollection EagleSampleAndAcceptDraftTokensPluginCreator::mFC{};
|
||||
std::vector<nvinfer1::PluginField> EagleSampleAndAcceptDraftTokensPluginCreator::mPluginAttributes;
|
||||
|
||||
/// Constructs the plugin from explicit configuration.
/// @param type           logits dtype.
/// @param greedySampling must currently be true; only greedy (Top-1)
///                       acceptance is implemented.
EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(
    nvinfer1::DataType type, bool greedySampling)
    : mDtype(type)
    , mGreedySampling(greedySampling)
{
    TLLM_CHECK_WITH_INFO(mGreedySampling, "Non-greedy sampling is not supported yet.");
}
|
||||
|
||||
// Parameterized constructor
|
||||
// Deserializing constructor: reads fields in the fixed order
// (mDtype, mGreedySampling); serialize() must write them in the same order.
EagleSampleAndAcceptDraftTokensPlugin::EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length)
{
    char const *d = reinterpret_cast<char const*>(data), *a = d;
    read(d, mDtype);
    read(d, mGreedySampling);
    // Guard against version skew: the blob must be consumed exactly.
    // Named casts instead of C-style casts, matching the sibling plugins.
    TLLM_CHECK_WITH_INFO(d == a + length,
        "Expected length (%d) != real length (%d). This is often "
        "caused by using different TensorRT-LLM version to build "
        "engine and run engine.",
        static_cast<int>(length), static_cast<int>(d - a));
}
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
/// Returns a deep copy of this plugin (TensorRT owns and destroys it).
nvinfer1::IPluginV2DynamicExt* EagleSampleAndAcceptDraftTokensPlugin::clone() const noexcept
{
    auto* plugin = new EagleSampleAndAcceptDraftTokensPlugin(*this);
    plugin->setPluginNamespace(mNamespace.c_str());
    return plugin;
}
|
||||
|
||||
/// Computes symbolic output shapes for the plugin's 7 outputs.
/// Shapes derive from PATHS ([batch, maxDecodingTokens, maxPathLen]) and
/// DRAFT_TOKEN_IDS ([batch, maxDecodingDraftTokens]).
nvinfer1::DimsExprs EagleSampleAndAcceptDraftTokensPlugin::getOutputDimensions(
    int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs, nvinfer1::IExprBuilder& exprBuilder) noexcept
{
    TLLM_CHECK(nbInputs == 6);
    TLLM_CHECK(outputIndex < 7);
    auto const batchSizeExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[0];
    auto const maxDecodingDraftTokensExpr = inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)].d[1];
    auto const maxPathLenExpr = inputs[getIdx(InputIdxEntry::PATHS)].d[2];

    nvinfer1::DimsExprs dims;
    switch (outputIndex)
    {
    case 0: // accepted_tokens: [batch, maxPathLen]
        dims.nbDims = 2;
        dims.d[0] = batchSizeExpr;
        dims.d[1] = maxPathLenExpr;
        break;
    case 5: // next_draft_tokens: [batch, maxDecodingDraftTokens]
        dims.nbDims = 2;
        dims.d[0] = batchSizeExpr;
        dims.d[1] = maxDecodingDraftTokensExpr;
        break;
    case 1: // num_accepted_tokens
    case 2: // accepted_paths
    case 3: // last_accepted_tokens
    case 4: // exclusive_sum_last_accepted_indices
    case 6: // next_draft_lens
        // All remaining outputs are one value per request: [batch]
        dims.nbDims = 1;
        dims.d[0] = batchSizeExpr;
        break;
    }
    return dims;
}
|
||||
|
||||
/// Declares supported dtype/format per tensor position: logits use the
/// configured mDtype, temperature/rand_validation are float, everything else
/// is int32; all must be linear-format.
bool EagleSampleAndAcceptDraftTokensPlugin::supportsFormatCombination(
    int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept
{
    if (pos == getIdx(InputIdxEntry::LOGITS)) // logits
    {
        return (inOut[pos].type == mDtype) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    else if (pos == getIdx(InputIdxEntry::TEMPERATURE)
        || pos == getIdx(InputIdxEntry::RAND_VALIDATION)) // temperature, rand_validation
    {
        return (inOut[pos].type == nvinfer1::DataType::kFLOAT) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
    else // everything else
    {
        return (inOut[pos].type == nvinfer1::DataType::kINT32) && (inOut[pos].format == TensorFormat::kLINEAR);
    }
}
|
||||
|
||||
/// No shape-dependent configuration is needed; intentionally a no-op.
void EagleSampleAndAcceptDraftTokensPlugin::configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
    nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept
{
}
|
||||
|
||||
/// Computes the scratch workspace size for logits type T.
/// The four sub-buffers declared here MUST match, in order and size, the
/// tc::nextWorkspacePtr walk performed in samplePrimeHeadTokens/acceptDraftTokens.
template <typename T>
size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs,
    int nbInputs, nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    size_t workspaceSize{0};

    auto const vocabSizePadded = inputs[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
    auto const batchSize = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[0];
    auto const maxDecodingTokens = inputs[getIdx(InputIdxEntry::PATHS)].dims.d[1];

    // Greedy sampling
    {
        // Top1 sampling workspace
        auto const primarySamplingWorkspaceSize
            = getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);

        // Target output ids
        auto const targetOutputIdsSize = batchSize * maxDecodingTokens * sizeof(TokenIdType);

        // Logits ptrs
        auto const logitsPtrsSize = batchSize * maxDecodingTokens * sizeof(T*);
        SizeType32 constexpr NUM_BUFFERS{4};
        size_t workspaces[NUM_BUFFERS];
        workspaces[0] = primarySamplingWorkspaceSize;
        workspaces[1] = targetOutputIdsSize;
        workspaces[2] = logitsPtrsSize;
        // Per-request decoding-token counts.
        workspaces[3] = batchSize * sizeof(SizeType32);
        workspaceSize = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
    }

    return workspaceSize;
}
|
||||
|
||||
/// Dispatches workspace-size computation on the runtime logits dtype
/// (float or half); any other dtype is a configuration error.
size_t EagleSampleAndAcceptDraftTokensPlugin::getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
    nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept
{
    auto const logitsType = inputs[getIdx(InputIdxEntry::LOGITS)].type;
    if (logitsType == nvinfer1::DataType::kFLOAT)
    {
        return getWorkspaceSizeType<float>(inputs, nbInputs, outputs, nbOutputs);
    }
    else if (logitsType == nvinfer1::DataType::kHALF)
    {
        return getWorkspaceSizeType<__half>(inputs, nbInputs, outputs, nbOutputs);
    }
    else
    {
        TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
    }
    return 0;
}
|
||||
|
||||
template <typename T>
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
|
||||
auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
|
||||
auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
|
||||
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
|
||||
|
||||
auto logits = static_cast<T const*>(inputs[getIdx(InputIdxEntry::LOGITS)]);
|
||||
auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
|
||||
|
||||
int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
|
||||
size_t offset{0};
|
||||
|
||||
auto const samplingWorkspaceSize
|
||||
= getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
|
||||
|
||||
void* workspaceSampling
|
||||
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
|
||||
TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
|
||||
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
|
||||
T const** logitsPtrs = reinterpret_cast<T const**>(
|
||||
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(T*)));
|
||||
SizeType32* decodingTokens
|
||||
= reinterpret_cast<SizeType32*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * sizeof(SizeType32)));
|
||||
|
||||
// Assemble pointers to logits
|
||||
invokeAssembleTargetLogitsOffsets(
|
||||
logitsPtrs, decodingTokens, logits, prevDraftLens, batchSize, maxDecodingTokens, vocabSizePadded, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TopKSamplingKernelParams<T> params;
|
||||
params.logProbsPtrs = logitsPtrs;
|
||||
params.outputIds = outputIds;
|
||||
params.workspace = workspaceSampling;
|
||||
params.maxTopK = 1;
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = batchSize;
|
||||
params.tokensPerStep = decodingTokens;
|
||||
params.maxTokensPerStep = maxDecodingTokens;
|
||||
params.maxSeqLen = maxDecodingTokens;
|
||||
params.vocabSizePadded = vocabSizePadded;
|
||||
|
||||
invokeBatchTopKSampling(params, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const maxNumTokens = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[0];
|
||||
auto const vocabSizePadded = inputDesc[getIdx(InputIdxEntry::LOGITS)].dims.d[1];
|
||||
|
||||
auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
|
||||
auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
|
||||
auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];
|
||||
auto const maxDraftPathLen = maxPathLen - 1;
|
||||
|
||||
int8_t* workspaceBytePtr = reinterpret_cast<int8_t*>(workspace);
|
||||
size_t offset{0};
|
||||
|
||||
auto const samplingWorkspaceSize
|
||||
= getTopKWorkspaceSize<T>(batchSize, maxDecodingTokens, /* maxTopK */ 1, vocabSizePadded);
|
||||
|
||||
void* workspaceSampling
|
||||
= reinterpret_cast<void*>(tc::nextWorkspacePtr(workspaceBytePtr, offset, samplingWorkspaceSize));
|
||||
TokenIdType* outputIds = reinterpret_cast<TokenIdType*>(
|
||||
tc::nextWorkspacePtr(workspaceBytePtr, offset, batchSize * maxDecodingTokens * sizeof(TokenIdType)));
|
||||
|
||||
AcceptDraftTokensByIdsWithPathsParams<T> params;
|
||||
params.outputIds = reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
|
||||
params.draftIds = reinterpret_cast<TokenIdType const*>(inputs[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)]);
|
||||
params.targetIds = outputIds;
|
||||
params.acceptedLengths = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
|
||||
params.paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);
|
||||
params.bestPathIds = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = batchSize;
|
||||
params.vocabSize = vocabSizePadded;
|
||||
params.maxSeqLen = maxPathLen;
|
||||
params.maxDraftPathLen = maxDraftPathLen;
|
||||
params.maxDecodingTokens = maxDecodingTokens;
|
||||
params.stream = stream;
|
||||
|
||||
params.checkParams();
|
||||
|
||||
acceptDraftTokensByIdsWithPaths(params);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
/// Greedy acceptance pipeline: Top-1 sample the target tokens, then accept
/// matching draft tokens along the best path. Both steps share the same
/// workspace layout.
template <typename T>
void EagleSampleAndAcceptDraftTokensPlugin::doGreedy(nvinfer1::PluginTensorDesc const* inputDesc,
    nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
    cudaStream_t stream) noexcept
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    // Sample all main head tokens with Top-1.
    samplePrimeHeadTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);

    // Greedy accept tokens based on token ids, write the best path and best token id.
    acceptDraftTokens<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
|
||||
|
||||
/// Extracts each request's last accepted token id and computes the exclusive
/// cumulative sum of last-accepted-token indices, delegating to the
/// corresponding eagle decoding kernel.
/// NOTE: reads ACCEPTED_TOKENS/ACCEPTED_LEN/BEST_ACCEPTED_PATHS from the
/// `outputs` array, so acceptDraftTokens() must have run first.
void EagleSampleAndAcceptDraftTokensPlugin::selectLastAccTokenAndComputeIndicesCumSum(
    nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
    void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[0];
    auto const maxDecodingTokens = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[1];
    auto const maxPathLen = inputDesc[getIdx(InputIdxEntry::PATHS)].dims.d[2];

    auto lastAcceptedTokenIds
        = reinterpret_cast<TokenIdType*>(outputs[getIdx(OutputIdxEntry::LAST_ACCEPTED_TOKEN_IDS)]);
    auto exclusiveSumLastAcceptedIndices
        = reinterpret_cast<SizeType32*>(outputs[getIdx(OutputIdxEntry::EXCLUSIVE_SUM_LAST_TOKEN_INDICES)]);
    auto prevDraftLens = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::DRAFT_LENS)]);
    auto acceptedTokenIds = reinterpret_cast<TokenIdType const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_TOKENS)]);
    auto acceptedLengths = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::ACCEPTED_LEN)]);
    auto bestPathIds = reinterpret_cast<SizeType32 const*>(outputs[getIdx(OutputIdxEntry::BEST_ACCEPTED_PATHS)]);
    auto paths = reinterpret_cast<SizeType32 const*>(inputs[getIdx(InputIdxEntry::PATHS)]);

    invokeSelectLastAccTokenAndComputeIndicesCumSum(lastAcceptedTokenIds, exclusiveSumLastAcceptedIndices,
        prevDraftLens, acceptedTokenIds, acceptedLengths, bestPathIds, paths, batchSize, maxDecodingTokens, maxPathLen,
        stream);

    sync_check_cuda_error();

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
|
||||
|
||||
template <typename T>
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::enqueueType(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
// TODO split batch into greedy and non-greedy and execute both paths
|
||||
if (mGreedySampling)
|
||||
{
|
||||
doGreedy<T>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
// TODO fill me
|
||||
TLLM_CHECK_WITH_INFO(false, "Non-greedy sampling is not supported yet");
|
||||
}
|
||||
|
||||
// Find last accepted tokens and do cumulative sum of accepted indices.
|
||||
selectLastAccTokenAndComputeIndicesCumSum(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
int EagleSampleAndAcceptDraftTokensPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept
|
||||
{
|
||||
auto const logitsType = inputDesc[getIdx(InputIdxEntry::LOGITS)].type;
|
||||
if (logitsType == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
enqueueType<float>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
else if (logitsType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
enqueueType<__half>(inputDesc, outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported logits type");
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType EagleSampleAndAcceptDraftTokensPlugin::getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept
|
||||
{
|
||||
TLLM_CHECK(index < 7);
|
||||
// input 1 is draft tokens now of int32 type. All outputs are int32_t as well.
|
||||
return inputTypes[getIdx(InputIdxEntry::DRAFT_TOKEN_IDS)];
|
||||
}
|
||||
|
||||
// IPluginV2 Methods
|
||||
|
||||
char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginType() const noexcept
|
||||
{
|
||||
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* EagleSampleAndAcceptDraftTokensPlugin::getPluginVersion() const noexcept
|
||||
{
|
||||
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
int EagleSampleAndAcceptDraftTokensPlugin::getNbOutputs() const noexcept
|
||||
{
|
||||
return 7;
|
||||
}
|
||||
|
||||
int EagleSampleAndAcceptDraftTokensPlugin::initialize() noexcept
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::terminate() noexcept {}
|
||||
|
||||
size_t EagleSampleAndAcceptDraftTokensPlugin::getSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mDtype) + sizeof(mGreedySampling);
|
||||
}
|
||||
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::serialize(void* buffer) const noexcept
|
||||
{
|
||||
char *d = static_cast<char*>(buffer), *a = d;
|
||||
write(d, mDtype);
|
||||
write(d, mGreedySampling);
|
||||
assert(d == a + getSerializationSize());
|
||||
}
|
||||
|
||||
void EagleSampleAndAcceptDraftTokensPlugin::destroy() noexcept
|
||||
{
|
||||
// This gets called when the network containing plugin is destroyed
|
||||
delete this;
|
||||
}
|
||||
|
||||
///////////////
|
||||
|
||||
EagleSampleAndAcceptDraftTokensPluginCreator::EagleSampleAndAcceptDraftTokensPluginCreator()
|
||||
{
|
||||
// Fill PluginFieldCollection with PluginField arguments metadata
|
||||
mPluginAttributes.clear();
|
||||
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
|
||||
mPluginAttributes.emplace_back(PluginField("greedy_sampling", nullptr, PluginFieldType::kINT32, 1));
|
||||
mFC.nbFields = mPluginAttributes.size();
|
||||
mFC.fields = mPluginAttributes.data();
|
||||
}
|
||||
|
||||
char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginName() const noexcept
|
||||
{
|
||||
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_NAME;
|
||||
}
|
||||
|
||||
char const* EagleSampleAndAcceptDraftTokensPluginCreator::getPluginVersion() const noexcept
|
||||
{
|
||||
return EAGLE_SAMPLE_AND_ACCEPT_DRAFT_TOKENS_PLUGIN_VERSION;
|
||||
}
|
||||
|
||||
PluginFieldCollection const* EagleSampleAndAcceptDraftTokensPluginCreator::getFieldNames() noexcept
|
||||
{
|
||||
return &mFC;
|
||||
}
|
||||
|
||||
IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::createPlugin(
|
||||
char const* name, PluginFieldCollection const* fc) noexcept
|
||||
{
|
||||
PluginField const* fields = fc->fields;
|
||||
nvinfer1::DataType type;
|
||||
bool greedySampling;
|
||||
// Read configurations from each fields
|
||||
for (int i = 0; i < fc->nbFields; ++i)
|
||||
{
|
||||
char const* attrName = fields[i].name;
|
||||
if (!strcmp(attrName, "type_id"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
type = static_cast<nvinfer1::DataType>(*(static_cast<nvinfer1::DataType const*>(fields[i].data)));
|
||||
}
|
||||
else if (!strcmp(attrName, "greedy_sampling"))
|
||||
{
|
||||
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
|
||||
greedySampling = static_cast<bool>(*static_cast<int32_t const*>(fields[i].data));
|
||||
}
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(type, greedySampling);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
IPluginV2* EagleSampleAndAcceptDraftTokensPluginCreator::deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept
|
||||
{
|
||||
// This object will be deleted when the network is destroyed, which will
|
||||
// call EagleSampleAndAcceptDraftTokensPlugin::destroy()
|
||||
try
|
||||
{
|
||||
auto* obj = new EagleSampleAndAcceptDraftTokensPlugin(serialData, serialLength);
|
||||
obj->setPluginNamespace(mNamespace.c_str());
|
||||
return obj;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
caughtError(e);
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
@ -0,0 +1,163 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 1993-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/plugins/common/plugin.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <memory>
|
||||
#include <set>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::plugins
|
||||
{
|
||||
|
||||
class EagleSampleAndAcceptDraftTokensPlugin : public BasePlugin
|
||||
{
|
||||
public:
|
||||
EagleSampleAndAcceptDraftTokensPlugin(nvinfer1::DataType type, bool greedySampling);
|
||||
|
||||
EagleSampleAndAcceptDraftTokensPlugin(void const* data, size_t length);
|
||||
|
||||
~EagleSampleAndAcceptDraftTokensPlugin() override = default;
|
||||
|
||||
// IPluginV2DynamicExt Methods
|
||||
nvinfer1::IPluginV2DynamicExt* clone() const noexcept override;
|
||||
nvinfer1::DimsExprs getOutputDimensions(int outputIndex, nvinfer1::DimsExprs const* inputs, int nbInputs,
|
||||
nvinfer1::IExprBuilder& exprBuilder) noexcept override;
|
||||
bool supportsFormatCombination(
|
||||
int pos, nvinfer1::PluginTensorDesc const* inOut, int nbInputs, int nbOutputs) noexcept override;
|
||||
void configurePlugin(nvinfer1::DynamicPluginTensorDesc const* in, int nbInputs,
|
||||
nvinfer1::DynamicPluginTensorDesc const* out, int nbOutputs) noexcept override;
|
||||
size_t getWorkspaceSize(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept override;
|
||||
int enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept override;
|
||||
|
||||
// IPluginV2Ext Methods
|
||||
nvinfer1::DataType getOutputDataType(
|
||||
int index, nvinfer1::DataType const* inputTypes, int nbInputs) const noexcept override;
|
||||
|
||||
// IPluginV2 Methods
|
||||
char const* getPluginType() const noexcept override;
|
||||
char const* getPluginVersion() const noexcept override;
|
||||
int getNbOutputs() const noexcept override;
|
||||
int initialize() noexcept override;
|
||||
void terminate() noexcept override;
|
||||
size_t getSerializationSize() const noexcept override;
|
||||
void serialize(void* buffer) const noexcept override;
|
||||
void destroy() noexcept override;
|
||||
|
||||
private:
|
||||
enum class InputIdxEntry : int32_t
|
||||
{
|
||||
//! [num_tokens, vocab_size_padded]
|
||||
LOGITS = 0,
|
||||
//! [batch_size, max_decoding_draft_tokens]
|
||||
DRAFT_TOKEN_IDS,
|
||||
//! [batch_size]
|
||||
DRAFT_LENS,
|
||||
//! [batch_size]
|
||||
TEMPERATURE,
|
||||
//! []?
|
||||
RAND_VALIDATION,
|
||||
//! [batch_size, max_decoding_tokens, max_path_len]
|
||||
PATHS
|
||||
};
|
||||
|
||||
enum class OutputIdxEntry : int32_t
|
||||
{
|
||||
//! [batch_size, max_draft_path_len]
|
||||
ACCEPTED_TOKENS = 0,
|
||||
//! [batch_size]
|
||||
ACCEPTED_LEN,
|
||||
//! [batch_size]
|
||||
BEST_ACCEPTED_PATHS,
|
||||
//! [batch_size]
|
||||
LAST_ACCEPTED_TOKEN_IDS,
|
||||
//! [batch_size]
|
||||
EXCLUSIVE_SUM_LAST_TOKEN_INDICES,
|
||||
//! [batch_size, max_decoding_draft_tokens]
|
||||
NEXT_DRAFT_TOKEN_IDS,
|
||||
//! [batch_size]
|
||||
NEXT_DRAFT_LENS
|
||||
};
|
||||
|
||||
int32_t getIdx(InputIdxEntry idx) const
|
||||
{
|
||||
return static_cast<int32_t>(idx);
|
||||
}
|
||||
|
||||
int32_t getIdx(OutputIdxEntry idx) const
|
||||
{
|
||||
return static_cast<int32_t>(idx);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
size_t getWorkspaceSizeType(nvinfer1::PluginTensorDesc const* inputs, int nbInputs,
|
||||
nvinfer1::PluginTensorDesc const* outputs, int nbOutputs) const noexcept;
|
||||
|
||||
template <typename T>
|
||||
void samplePrimeHeadTokens(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept;
|
||||
|
||||
template <typename T>
|
||||
void acceptDraftTokens(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
|
||||
|
||||
template <typename T>
|
||||
void doGreedy(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
|
||||
|
||||
void selectLastAccTokenAndComputeIndicesCumSum(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
nvinfer1::PluginTensorDesc const* outputDesc, void const* const* inputs, void* const* outputs, void* workspace,
|
||||
cudaStream_t stream) noexcept;
|
||||
|
||||
template <typename T>
|
||||
void enqueueType(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::PluginTensorDesc const* outputDesc,
|
||||
void const* const* inputs, void* const* outputs, void* workspace, cudaStream_t stream) noexcept;
|
||||
|
||||
private:
|
||||
nvinfer1::DataType mDtype;
|
||||
bool mGreedySampling;
|
||||
};
|
||||
|
||||
class EagleSampleAndAcceptDraftTokensPluginCreator : public BaseCreator
|
||||
{
|
||||
public:
|
||||
EagleSampleAndAcceptDraftTokensPluginCreator();
|
||||
|
||||
char const* getPluginName() const noexcept override;
|
||||
|
||||
char const* getPluginVersion() const noexcept override;
|
||||
|
||||
nvinfer1::PluginFieldCollection const* getFieldNames() noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* createPlugin(char const* name, nvinfer1::PluginFieldCollection const* fc) noexcept override;
|
||||
|
||||
nvinfer1::IPluginV2* deserializePlugin(
|
||||
char const* name, void const* serialData, size_t serialLength) noexcept override;
|
||||
|
||||
private:
|
||||
static nvinfer1::PluginFieldCollection mFC;
|
||||
static std::vector<nvinfer1::PluginField> mPluginAttributes;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::plugins
|
||||
@ -24,14 +24,17 @@ using namespace tensorrt_llm::kernels::cutlass_kernels;
|
||||
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPluginCreator;
|
||||
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantMatmulPlugin;
|
||||
using tensorrt_llm::plugins::WeightOnlyGroupwiseQuantGemmPluginProfiler;
|
||||
using tensorrt_llm::plugins::WeightOnlyGemmRunnerPtr;
|
||||
|
||||
// Flags for indicating whether the corresponding inputs are applied in mQuantAlgo
|
||||
// mQuantAlgo = pre_quant_scale * PRE_QUANT_SCALE + zero * ZERO + bias * BIAS
|
||||
// Here pre_quant_scale, zero and bias are boolean type
|
||||
// mQuantAlgo = int8_weight * INT8_WEIGHT + use_w4a8_awq * FP8_ALPHA + pre_quant_scale * PRE_QUANT_SCALE
|
||||
// + zero * ZERO + bias * BIAS
|
||||
// Here int8_weight, use_w4a8_awq, pre_quant_scale, zero and bias are boolean type
|
||||
static constexpr int BIAS = int(1) << 0;
|
||||
static constexpr int ZERO = int(1) << 1;
|
||||
static constexpr int PRE_QUANT_SCALE = int(1) << 2;
|
||||
static constexpr int FP8_ALPHA = int(1) << 3;
|
||||
static constexpr int INT8_WEIGHT = int(1) << 4;
|
||||
using tensorrt_llm::plugins::read;
|
||||
using tensorrt_llm::plugins::write;
|
||||
|
||||
@ -43,11 +46,10 @@ std::vector<nvinfer1::PluginField> WeightOnlyGroupwiseQuantMatmulPluginCreator::
|
||||
void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
|
||||
WeightOnlyGroupwiseQuantGemmPluginProfiler::Config const& tactic, char* workspace, cudaStream_t const& stream)
|
||||
{
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
|
||||
int const originalN = n * FP16_INT4_RATIO;
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
||||
int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
|
||||
half* actPtr = reinterpret_cast<half*>(workspace);
|
||||
cutlass::uint4b_t* weightPtr = reinterpret_cast<cutlass::uint4b_t*>(
|
||||
nextWorkspacePtr(reinterpret_cast<int8_t*>(actPtr), m * k * sizeof(half)));
|
||||
void* weightPtr = nextWorkspacePtr(reinterpret_cast<int8_t*>(actPtr), m * k * sizeof(half));
|
||||
half* inputScalesPtr
|
||||
= reinterpret_cast<half*>(nextWorkspacePtr(reinterpret_cast<int8_t*>(weightPtr), n * k * sizeof(float)));
|
||||
half* zerosPtr = reinterpret_cast<half*>(
|
||||
@ -69,15 +71,22 @@ void WeightOnlyGroupwiseQuantGemmPluginProfiler::runTactic(int m, int n, int k,
|
||||
}
|
||||
|
||||
int const wsSize = mRunner->getWorkspaceSize(m, originalN, k);
|
||||
|
||||
mRunner->gemm(actPtr, weightPtr, inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m, originalN, k, mGroupSize,
|
||||
tactic, workspacePtr, wsSize, stream);
|
||||
if (mQuantAlgo & INT8_WEIGHT)
|
||||
{
|
||||
mRunner->gemm(actPtr, reinterpret_cast<int8_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr, outputPtr, m,
|
||||
originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
mRunner->gemm(actPtr, reinterpret_cast<cutlass::uint4b_t*>(weightPtr), inputScalesPtr, zerosPtr, biasesPtr,
|
||||
outputPtr, m, originalN, k, mGroupSize, tactic, workspacePtr, wsSize, stream);
|
||||
}
|
||||
}
|
||||
|
||||
void WeightOnlyGroupwiseQuantGemmPluginProfiler::computeTmpSize(size_t maxM, size_t n, size_t k)
|
||||
{
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
|
||||
int const originalN = n * FP16_INT4_RATIO;
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
||||
int const originalN = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
|
||||
std::vector<size_t> workspaces = {
|
||||
maxM * k * sizeof(half), // A
|
||||
k * n * sizeof(float), // B
|
||||
@ -129,6 +138,38 @@ WeightOnlyGroupwiseQuantMatmulPlugin::WeightOnlyGroupwiseQuantMatmulPlugin(
|
||||
(int) length, (int) (d - a));
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename OutputType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
using GemmRunner = tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp,
|
||||
OutputType, OutputType, OutputType>;
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename OutputType>
|
||||
WeightOnlyGemmRunnerPtr selectGemmRunnerForZERO(int quant_algo)
|
||||
{
|
||||
if (quant_algo & ZERO)
|
||||
{
|
||||
return std::make_shared<GemmRunner<ActivationType, WeightType, OutputType,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return std::make_shared<
|
||||
GemmRunner<ActivationType, WeightType, OutputType, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType>
|
||||
WeightOnlyGemmRunnerPtr selectGemmRunnerForWeightType(int quant_algo)
|
||||
{
|
||||
if (quant_algo & INT8_WEIGHT)
|
||||
{
|
||||
return selectGemmRunnerForZERO<ActivationType, uint8_t, ActivationType>(quant_algo);
|
||||
}
|
||||
else
|
||||
{
|
||||
return selectGemmRunnerForZERO<ActivationType, cutlass::uint4b_t, ActivationType>(quant_algo);
|
||||
}
|
||||
}
|
||||
|
||||
void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int quant_algo, int group_size)
|
||||
{
|
||||
mArch = tensorrt_llm::common::getSMVersion();
|
||||
@ -136,7 +177,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
|
||||
mQuantAlgo = quant_algo;
|
||||
mGroupSize = group_size;
|
||||
|
||||
// quant_algo = fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
|
||||
// quant_algo = int8_weight * 16 + fp8_alpha * 8 + pre_quant_scale * 4 + zero * 2 + bias
|
||||
mPreQuantScaleInputIdx = (quant_algo & PRE_QUANT_SCALE) ? 1 : 0;
|
||||
mWeightInputIdx = mPreQuantScaleInputIdx + 1;
|
||||
mScalesInputIdx = mWeightInputIdx + 1;
|
||||
@ -146,6 +187,7 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
|
||||
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
// CUTLASS kernel selection
|
||||
if (quant_algo & FP8_ALPHA)
|
||||
{
|
||||
// Ada & Hopper style kernels
|
||||
@ -153,45 +195,34 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
|
||||
{
|
||||
TLLM_THROW("W4A(fp)8 kernel is unsupported on pre-Ada (sm<89) architectures!");
|
||||
}
|
||||
if (quant_algo & ZERO)
|
||||
{
|
||||
// has zeros
|
||||
m_weightOnlyGroupwiseGemmRunner = std::make_shared<
|
||||
tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::uint4b_t,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// no zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, half, half>>();
|
||||
}
|
||||
assert(!(quant_algo & INT8_WEIGHT) && "W4A(fp)8 kernel requires INT4 weight!");
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= selectGemmRunnerForZERO<__nv_fp8_e4m3, cutlass::uint4b_t, half>(quant_algo);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (quant_algo & ZERO)
|
||||
{
|
||||
// has zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// no zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
|
||||
}
|
||||
m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<half>(quant_algo);
|
||||
}
|
||||
// CUDA kernel selection
|
||||
if (quant_algo & INT8_WEIGHT)
|
||||
{
|
||||
// INT8 weight
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int8Groupwise;
|
||||
}
|
||||
else
|
||||
{
|
||||
// INT4 weight
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
|
||||
}
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::FP16Int4Groupwise;
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (mType == nvinfer1::DataType::kBF16)
|
||||
{
|
||||
// CUTLASS kernel selection
|
||||
if (quant_algo & FP8_ALPHA)
|
||||
{
|
||||
// FP8 requires at least sm89 devices
|
||||
@ -203,24 +234,23 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::init(nvinfer1::DataType type, int qua
|
||||
}
|
||||
else
|
||||
{
|
||||
if (quant_algo & ZERO)
|
||||
{
|
||||
// has zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>>();
|
||||
}
|
||||
else
|
||||
{
|
||||
// no zeros
|
||||
m_weightOnlyGroupwiseGemmRunner
|
||||
= std::make_shared<tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<__nv_bfloat16,
|
||||
cutlass::uint4b_t, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>>();
|
||||
}
|
||||
m_weightOnlyGroupwiseGemmRunner = selectGemmRunnerForWeightType<__nv_bfloat16>(quant_algo);
|
||||
}
|
||||
// CUDA kernel selection
|
||||
if (quant_algo & INT8_WEIGHT)
|
||||
{
|
||||
// INT8 weight
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int8Groupwise;
|
||||
}
|
||||
else
|
||||
{
|
||||
// INT4 weight
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
|
||||
}
|
||||
mCudaKernelEnabled = tensorrt_llm::kernels::weight_only::is_supported(
|
||||
mArch, tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise);
|
||||
mCudaKernelType = tensorrt_llm::kernels::weight_only::KernelType::BF16Int4Groupwise;
|
||||
}
|
||||
#endif
|
||||
else
|
||||
@ -273,8 +303,9 @@ nvinfer1::DimsExprs WeightOnlyGroupwiseQuantMatmulPlugin::getOutputDimensions(
|
||||
ret.d[ii] = inputs[0].d[ii];
|
||||
}
|
||||
|
||||
// int4 weight only quant
|
||||
ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * FP16_INT4_RATIO);
|
||||
// int4/int8 weight only quant (INT4*4 -> FP16, INT8*2 -> FP16)
|
||||
int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
|
||||
ret.d[nbDimsA - 1] = exprBuilder.constant(inputs[mWeightInputIdx].d[1]->getConstantValue() * weight_multiplier);
|
||||
|
||||
return ret;
|
||||
}
|
||||
@ -320,11 +351,12 @@ void WeightOnlyGroupwiseQuantMatmulPlugin::configurePlugin(nvinfer1::DynamicPlug
|
||||
|
||||
int const maxK = in[0].max.d[in[0].max.nbDims - 1];
|
||||
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
|
||||
int const maxN = in[mWeightInputIdx].max.d[1] * FP16_INT4_RATIO;
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
||||
int const weight_multiplier = mQuantAlgo & INT8_WEIGHT ? FP16_INT8_RATIO : FP16_INT4_RATIO;
|
||||
int const maxN = in[mWeightInputIdx].max.d[1] * weight_multiplier;
|
||||
|
||||
auto const K = maxK;
|
||||
auto const N = maxN / FP16_INT4_RATIO;
|
||||
auto const N = maxN / weight_multiplier;
|
||||
|
||||
if (!mDims.isInitialized())
|
||||
{
|
||||
@ -424,8 +456,9 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(nvinfer1::PluginTensorDesc con
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF, "No valid weightOnlyGropwiseQuantMatmul configuration");
|
||||
#endif
|
||||
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16)
|
||||
int real_n = n * FP16_INT4_RATIO;
|
||||
// Quantized weights are packed in FP16 format (INT4*4 -> FP16, INT8*2 -> FP16)
|
||||
int real_n = mQuantAlgo & INT8_WEIGHT ? n * FP16_INT8_RATIO : n * FP16_INT4_RATIO;
|
||||
|
||||
if (use_cuda_kernel)
|
||||
{
|
||||
void const* pre_quant_scale_ptr = nullptr;
|
||||
|
||||
@ -46,6 +46,7 @@ constexpr int32_t INT8_BITS = 8;
|
||||
constexpr int32_t INT4_BITS = 4;
|
||||
constexpr int32_t INT8_INT4_RATIO = INT8_BITS / INT4_BITS;
|
||||
constexpr int32_t FP16_INT4_RATIO = FP16_BITS / INT4_BITS;
|
||||
constexpr int32_t FP16_INT8_RATIO = FP16_BITS / INT8_BITS;
|
||||
|
||||
inline int32_t getWeightTypeMultiplier(WeightTypeId weightTypeId)
|
||||
{
|
||||
|
||||
@ -140,6 +140,7 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_readwrite("iter", &tle::IterationStats::iter)
|
||||
.def_readwrite("iter_latency_ms", &tle::IterationStats::iterLatencyMS)
|
||||
.def_readwrite("new_active_requests_queue_latency_ms", &tle::IterationStats::newActiveRequestsQueueLatencyMS)
|
||||
.def_readwrite("num_new_active_requests", &tle::IterationStats::numNewActiveRequests)
|
||||
.def_readwrite("num_active_requests", &tle::IterationStats::numActiveRequests)
|
||||
.def_readwrite("num_queued_requests", &tle::IterationStats::numQueuedRequests)
|
||||
.def_readwrite("num_completed_requests", &tle::IterationStats::numCompletedRequests)
|
||||
@ -180,6 +181,9 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_readwrite("scheduled", &tle::RequestStats::scheduled)
|
||||
.def_readwrite("paused", &tle::RequestStats::paused)
|
||||
.def_readwrite("dis_serving_stats", &tle::RequestStats::disServingStats)
|
||||
.def_readwrite("alloc_total_blocks_per_request", &tle::RequestStats::allocTotalBlocksPerRequest)
|
||||
.def_readwrite("alloc_new_blocks_per_request", &tle::RequestStats::allocNewBlocksPerRequest)
|
||||
.def_readwrite("reused_blocks_per_request", &tle::RequestStats::reusedBlocksPerRequest)
|
||||
.def("to_json_str",
|
||||
[](tle::RequestStats const& iterationStats) { return tle::JsonSerialization::toJsonStr(iterationStats); });
|
||||
|
||||
|
||||
@ -266,6 +266,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
|
||||
auto const manageWeightsType = parseJsonFieldOr<bool>(pluginConfig, "manage_weights", false)
|
||||
? ModelConfig::ManageWeightsType::kEnabled
|
||||
: ModelConfig::ManageWeightsType::kDisabled;
|
||||
auto const ppReduceScatter = parseJsonFieldOr<bool>(pluginConfig, "pp_reduce_scatter", false);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
!removeInputPadding || modelConfig.getMaxNumTokens(), "Padding removal requires max_num_tokens to be set.");
|
||||
@ -283,6 +284,7 @@ void parsePluginConfig(ModelConfig& modelConfig, Json const& pluginConfig)
|
||||
modelConfig.setPagedContextFMHA(pagedContextFMHA);
|
||||
modelConfig.useXQA(useXQA);
|
||||
modelConfig.setManageWeightsType(manageWeightsType);
|
||||
modelConfig.setPpReduceScatter(ppReduceScatter);
|
||||
}
|
||||
|
||||
void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginConfig, bool engineVersionNone,
|
||||
|
||||
@ -72,7 +72,6 @@ auto const kProfileMbIdxs = populateMicrobatchIndexes();
|
||||
GptSession::Config setPath(GptSession::Config const& original, std::string const& path)
|
||||
{
|
||||
GptSession::Config config = original;
|
||||
config.enginePath = std::filesystem::path(path);
|
||||
return config;
|
||||
}
|
||||
|
||||
|
||||
@ -408,9 +408,7 @@ void TllmRuntime::loadManagedWeights(RawEngine const& rawEngine, int localRank)
|
||||
{
|
||||
TLLM_LOG_DEBUG("Loading managed weight: %s", name.c_str());
|
||||
auto iTensor = tensorrt_llm::executor::detail::toITensor(weight);
|
||||
auto weightsDevice = std::shared_ptr<ITensor>{
|
||||
manager.allocate(MemoryType::kGPU, iTensor->getShape(), iTensor->getDataType())};
|
||||
manager.copy(iTensor->data(), *weightsDevice, MemoryType::kCPU);
|
||||
auto weightsDevice = std::shared_ptr<ITensor>{manager.copyFrom(*iTensor, MemoryType::kGPU)};
|
||||
mManagedWeightsMap.insert(std::make_pair(name, weightsDevice));
|
||||
}
|
||||
}
|
||||
|
||||
@ -1326,16 +1326,34 @@ public:
|
||||
|
||||
void callAcceptByIdsWithPaths()
|
||||
{
|
||||
tksp::acceptDraftTokensByIdsWithPaths(bufferCast<SizeType32>(*mOutputTokens),
|
||||
bufferCast<SizeType32>(*mDraftTokens), bufferCast<SizeType32>(*mTargetTokens),
|
||||
bufferCast<SizeType32>(*mSequenceLengths), bufferCast<SizeType32>(*mAcceptedLengths),
|
||||
reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal)),
|
||||
bufferCast<SizeType32>(*mBatchSlots), bufferCast<SizeType32>(*mPaths), bufferCast<SizeType32>(*mEndIds),
|
||||
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
|
||||
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs)),
|
||||
bufferCast<SizeType32>(*mTokensPerStep), bufferCast<SizeType32>(*mTokensPerStep),
|
||||
bufferCast<SizeType32>(*mBestPaths), mBatchSize, mMaxBatchSize, mVocabSize, mMaxSeqLen, mMaxNumHeads,
|
||||
mMaxDraftSeqPerStep, mStream->get());
|
||||
tksp::AcceptDraftTokensByIdsWithPathsParams<T> params;
|
||||
|
||||
params.outputIds = bufferCast<SizeType32>(*mOutputTokens);
|
||||
params.draftIds = bufferCast<SizeType32>(*mDraftTokens);
|
||||
params.targetIds = bufferCast<SizeType32>(*mTargetTokens);
|
||||
params.sequenceLengths = bufferCast<SizeType32>(*mSequenceLengths);
|
||||
params.acceptedLengths = bufferCast<SizeType32>(*mAcceptedLengths);
|
||||
params.finishedFinal
|
||||
= reinterpret_cast<tk::FinishedState*>(bufferCast<tk::FinishedState::UnderlyingType>(*mFinishedFinal));
|
||||
params.batchSlots = bufferCast<SizeType32>(*mBatchSlots);
|
||||
params.paths = bufferCast<SizeType32>(*mPaths);
|
||||
params.endIds = bufferCast<SizeType32>(*mEndIds);
|
||||
params.medusaLogits = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs));
|
||||
params.logitsPtrs = reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaLogitsPtrs));
|
||||
params.curTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
|
||||
params.targetTokensPerStep = bufferCast<SizeType32>(*mTokensPerStep);
|
||||
params.bestPathIds = bufferCast<SizeType32>(*mBestPaths);
|
||||
params.batchSize = mBatchSize;
|
||||
params.maxBatchSize = mMaxBatchSize;
|
||||
params.vocabSize = mVocabSize;
|
||||
params.maxSeqLen = mMaxSeqLen;
|
||||
params.maxDraftPathLen = mMaxNumHeads;
|
||||
params.maxDecodingTokens = mMaxDraftSeqPerStep;
|
||||
params.stream = mStream->get();
|
||||
|
||||
params.checkParams();
|
||||
|
||||
tksp::acceptDraftTokensByIdsWithPaths(params);
|
||||
}
|
||||
|
||||
void callTestedKernel()
|
||||
|
||||
@ -91,54 +91,59 @@ TYPED_TEST_SUITE(AirTopPSamplingKernelTest, FloatAndHalfTypes);
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessAncestral)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, NondeterministicCorrectnessLargeVocabLargeP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f).setDeterministicTopP(true));
|
||||
this->runTest(
|
||||
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f).setDeterministicTopP(true));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f).setDeterministicTopP(true));
|
||||
this->runTest(
|
||||
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f).setDeterministicTopP(true));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessAncestral)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f).setDeterministicTopP(true));
|
||||
this->runTest(
|
||||
SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f).setDeterministicTopP(true));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabSmallP)
|
||||
{
|
||||
this->runTest(
|
||||
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f).setDeterministicTopP(true));
|
||||
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f).setDeterministicTopP(
|
||||
true));
|
||||
};
|
||||
|
||||
TYPED_TEST(AirTopPSamplingKernelTest, DeterministicCorrectnessLargeVocabLargeP)
|
||||
{
|
||||
this->runTest(
|
||||
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f).setDeterministicTopP(true));
|
||||
SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f).setDeterministicTopP(
|
||||
true));
|
||||
};
|
||||
|
||||
class AirTopPSamplingKernelUtilsTest : public SamplingKernelTest<float>
|
||||
|
||||
@ -110,6 +110,8 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
|
||||
|
||||
auto const topK = param.topK;
|
||||
auto const topP = param.topP;
|
||||
// TopK == 0 case (TopP kernel)
|
||||
auto const topKDistUpperBound = std::max(topK, static_cast<unsigned int>(1));
|
||||
|
||||
std::mt19937 gen(42);
|
||||
|
||||
@ -133,7 +135,7 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
|
||||
0, vocabSize - 1); // -1 because uniform_int_distribution generates closed interval
|
||||
std::uniform_real_distribution<> skipDecodeDist(0, 1);
|
||||
std::uniform_real_distribution<> topPDist(0, topP);
|
||||
std::uniform_int_distribution<> topKDist(1, topK);
|
||||
std::uniform_int_distribution<> topKDist(1, topKDistUpperBound);
|
||||
std::uniform_int_distribution<> tokensPerStepDist(1, maxTokensPerStep);
|
||||
std::uniform_int_distribution<> seqLenDist(0, mMaxSeqLen - maxTokensPerStep);
|
||||
std::uniform_real_distribution<> logProbDist(-3.f, 3.f);
|
||||
@ -158,7 +160,7 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
|
||||
endIdsHostPtr[bi] = endIdsDistr(gen);
|
||||
skipDecodeHostPtr[bi] = skipDecodeDist(gen) > 0.8;
|
||||
topPsHostPtr[bi] = topPDist(gen);
|
||||
topKsHostPtr[bi] = topKDist(gen);
|
||||
topKsHostPtr[bi] = topK == 0 ? 0 : topKDist(gen);
|
||||
tokensPerStepPtr[bi] = tokensPerStepDist(gen);
|
||||
finishedHostPtr[bi] = finishedDist(gen) > 0.8 ? tk::FinishedState::finished() : tk::FinishedState::empty();
|
||||
}
|
||||
@ -196,9 +198,9 @@ void SamplingKernelTest<T>::setupBuffers(SamplingKernelTestParam const& param)
|
||||
// Init logits randomly
|
||||
auto logitsHostPtr = bufferCast<T>(*mLogitsHost);
|
||||
initRandom(logitsHostPtr, batchSize * maxTokensPerStep * vocabSize, -3.0f, 3.0f);
|
||||
|
||||
// Only in greedy search we can guarantee the selected token and stop by condition
|
||||
if (topK == 1)
|
||||
// TopK == 1 for TopK kernel greedy, TopK == 0 for TopP kernels
|
||||
if (topK <= 1)
|
||||
{
|
||||
for (SizeType32 bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
@ -231,13 +233,29 @@ std::vector<SizeType32> SamplingKernelTest<T>::computeTopKTopPVariants(
|
||||
auto topK = bufferCast<int32_t>(*mTopKsHost)[batchSlot];
|
||||
auto topP = bufferCast<float>(*mTopPsHost)[batchSlot];
|
||||
|
||||
allowedTokens.insert(allowedTokens.begin(), indices.begin(), indices.begin() + topK);
|
||||
if (topK > 0) // handling top K kernel, top P result based on topK tokens
|
||||
{
|
||||
float sSum = 0.f; // sSum as in samplingTopKKernels.cu
|
||||
for (auto ki = 0; ki < topK; ki++)
|
||||
{
|
||||
sSum += static_cast<float>(probsPtr[indices[ki]]);
|
||||
}
|
||||
topP *= sSum; // the adjusted topP in the selected topK distribution
|
||||
}
|
||||
|
||||
float totalProb = 0.f;
|
||||
SizeType32 idx = 0;
|
||||
while (totalProb < topP && idx < vocabSize)
|
||||
{
|
||||
allowedTokens.push_back(indices[idx]);
|
||||
totalProb += static_cast<float>(probsPtr[indices[idx++]]);
|
||||
// cuda may selected a different index with same probability in kernel reduce, in test we allow them
|
||||
while (idx < vocabSize
|
||||
&& static_cast<float>(probsPtr[indices[idx]]) == static_cast<float>(probsPtr[indices[idx - 1]]))
|
||||
{
|
||||
allowedTokens.push_back(indices[idx]);
|
||||
totalProb += static_cast<float>(probsPtr[indices[idx++]]);
|
||||
}
|
||||
}
|
||||
return allowedTokens;
|
||||
}
|
||||
@ -284,12 +302,15 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
|
||||
auto const tokensPerStep = tokensPerStepPtr[batchSlot];
|
||||
for (SizeType32 ti = 0; ti < tokensPerStep; ++ti)
|
||||
{
|
||||
auto kResults = param.returnAllTopK ? bufferCast<int32_t>(*mTopKsHost)[batchSlot] : 1;
|
||||
|
||||
for (SizeType32 ki = 0; ki < kResults; ++ki)
|
||||
auto topK = bufferCast<int32_t>(*mTopKsHost)[batchSlot];
|
||||
auto kResults = param.returnAllSelectedTokens ? (topK == 0 ? vocabSize : topK) : 1;
|
||||
auto topKTopPVariants = computeTopKTopPVariants(bi, batchSlot, ti, maxTokensPerStep, vocabSize);
|
||||
SizeType32 ki;
|
||||
for (ki = 0; ki < kResults && ki < topKTopPVariants.size(); ++ki)
|
||||
{
|
||||
// Set reference finished state to true if we finished before or at current step
|
||||
auto const idsIdx = param.returnAllTopK ? ti * mMaxTopK + ki : seqLengthsOrigHostPtr[batchSlot] + ti;
|
||||
auto const idsIdx
|
||||
= param.returnAllSelectedTokens ? ti * mMaxTopK + ki : seqLengthsOrigHostPtr[batchSlot] + ti;
|
||||
auto const outputId = outputIdsHostPtr[batchSlot * mMaxSeqLen + idsIdx];
|
||||
// Check the range of the returned token ([0, vocabSize))
|
||||
EXPECT_TRUE((outputId >= 0) && (outputId < vocabSize));
|
||||
@ -299,7 +320,7 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
|
||||
if (!skipDecodeHostPtr[batchSlot] && !finishedOrigHostPtr[batchSlot].isFinished()
|
||||
&& !finishedOrigHostPtr[batchSlot].isSkipDecoding())
|
||||
{
|
||||
if (maxTokensPerStep == 1 && !param.returnAllTopK)
|
||||
if (maxTokensPerStep == 1 && !param.returnAllSelectedTokens)
|
||||
{
|
||||
if (generatedEOS)
|
||||
{
|
||||
@ -314,8 +335,6 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
|
||||
}
|
||||
}
|
||||
|
||||
auto topKTopPVariants = computeTopKTopPVariants(bi, batchSlot, ti, maxTokensPerStep, vocabSize);
|
||||
|
||||
bool found = false;
|
||||
for (auto const& var : topKTopPVariants)
|
||||
{
|
||||
@ -340,11 +359,24 @@ void SamplingKernelTest<T>::verifyResult(SamplingKernelTestParam const& param)
|
||||
EXPECT_EQ(finishedHostPtr[batchSlot].isFinished(), finishedOrigHostPtr[batchSlot].isFinished());
|
||||
}
|
||||
}
|
||||
|
||||
// a boundary check for returnAllSelectedTokens in topP kernel and when TopP selected indices < topK in topK
|
||||
// kernel.
|
||||
if (!skipDecodeHostPtr[batchSlot] && !finishedOrigHostPtr[batchSlot].isFinished()
|
||||
&& !finishedOrigHostPtr[batchSlot].isSkipDecoding())
|
||||
{
|
||||
if (param.returnAllSelectedTokens && (topK == 0 || ki != topK))
|
||||
{
|
||||
auto const idsIdx = ti * mMaxTopK + ki;
|
||||
auto const outputId = outputIdsHostPtr[batchSlot * mMaxSeqLen + idsIdx];
|
||||
EXPECT_EQ(outputId, -1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Cum log probs is not supported for multiple tokens per step or all top K return
|
||||
if (maxTokensPerStep == 1 && !param.returnAllTopK)
|
||||
if (maxTokensPerStep == 1 && !param.returnAllSelectedTokens)
|
||||
{
|
||||
for (int32_t bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
|
||||
@ -194,7 +194,7 @@ struct SamplingKernelTestParam
|
||||
bool normalizeLogProbs{false};
|
||||
bool logitsHasProbs{true};
|
||||
int32_t maxTokensPerStep{1};
|
||||
bool returnAllTopK{false};
|
||||
bool returnAllSelectedTokens{false};
|
||||
bool useLogitsPtrs{false};
|
||||
bool isDeterministicTopP{false};
|
||||
|
||||
@ -228,9 +228,9 @@ struct SamplingKernelTestParam
|
||||
return *this;
|
||||
}
|
||||
|
||||
SamplingKernelTestParam& setReturnAllTopK()
|
||||
SamplingKernelTestParam& setReturnAllSelectedTokens()
|
||||
{
|
||||
returnAllTopK = true;
|
||||
returnAllSelectedTokens = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
|
||||
@ -70,10 +70,10 @@ protected:
|
||||
kernelParams.finishedOutput = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*this->mFinishedDevice));
|
||||
kernelParams.skipDecode = bufferCast<bool>(*this->mSkipDecodeDevice);
|
||||
kernelParams.cumLogProbs = params.returnAllTopK || params.maxTokensPerStep > 1
|
||||
kernelParams.cumLogProbs = params.returnAllSelectedTokens || params.maxTokensPerStep > 1
|
||||
? nullptr
|
||||
: bufferCast<float>(*this->mCumLogProbsDevice);
|
||||
kernelParams.outputLogProbs = params.returnAllTopK || params.maxTokensPerStep > 1
|
||||
kernelParams.outputLogProbs = params.returnAllSelectedTokens || params.maxTokensPerStep > 1
|
||||
? nullptr
|
||||
: bufferCast<float>(*this->mOutputLogProbsDevice);
|
||||
kernelParams.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*this->mCurandStatesDevice));
|
||||
@ -84,7 +84,7 @@ protected:
|
||||
kernelParams.vocabSizePadded = params.vocabSize;
|
||||
kernelParams.normalizeLogProbs = params.normalizeLogProbs;
|
||||
kernelParams.logitsHasProbs = params.logitsHasProbs;
|
||||
kernelParams.returnAllTopK = params.returnAllTopK;
|
||||
kernelParams.returnAllSelectedTokens = params.returnAllSelectedTokens;
|
||||
|
||||
// Perform batched TopK sampling
|
||||
tk::invokeBatchTopKSampling(kernelParams, this->mStream->get());
|
||||
@ -136,7 +136,7 @@ TYPED_TEST(TopKSamplingKernelTest, CorrectnessTopKMaxTokensPerStep)
|
||||
SamplingKernelTestParam().setBatchSize(16).setVocabSize(4000).setTopK(63).setTopP(1.0f).setMaxTokensPerStep(4));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllTopK)
|
||||
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllSelectedTokens)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam()
|
||||
.setBatchSize(16)
|
||||
@ -144,7 +144,18 @@ TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllTopK)
|
||||
.setTopK(10)
|
||||
.setTopP(1.0f)
|
||||
.setMaxTokensPerStep(4)
|
||||
.setReturnAllTopK());
|
||||
.setReturnAllSelectedTokens());
|
||||
};
|
||||
|
||||
TYPED_TEST(TopKSamplingKernelTest, CorrectnessReturnAllSelectedTokensSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam()
|
||||
.setBatchSize(16)
|
||||
.setVocabSize(50)
|
||||
.setTopK(20)
|
||||
.setTopP(0.3f)
|
||||
.setMaxTokensPerStep(4)
|
||||
.setReturnAllSelectedTokens());
|
||||
};
|
||||
|
||||
TYPED_TEST(TopKSamplingKernelTest, CorrectnessLogitsPtrs)
|
||||
|
||||
@ -64,12 +64,15 @@ private:
|
||||
kernelParams.finishedOutput = reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(*this->mFinishedDevice));
|
||||
kernelParams.skipDecode = bufferCast<bool>(*this->mSkipDecodeDevice);
|
||||
kernelParams.cumLogProbs = bufferCast<float>(*this->mCumLogProbsDevice);
|
||||
kernelParams.outputLogProbs = bufferCast<float>(*this->mOutputLogProbsDevice);
|
||||
kernelParams.cumLogProbs
|
||||
= params.returnAllSelectedTokens ? nullptr : bufferCast<float>(*this->mCumLogProbsDevice);
|
||||
kernelParams.outputLogProbs
|
||||
= params.returnAllSelectedTokens ? nullptr : bufferCast<float>(*this->mOutputLogProbsDevice);
|
||||
kernelParams.curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*this->mCurandStatesDevice));
|
||||
kernelParams.batchSize = params.batchSize;
|
||||
kernelParams.maxBatchSize = maxBatchSize;
|
||||
kernelParams.vocabSizePadded = params.vocabSize;
|
||||
kernelParams.returnAllSelectedTokens = params.returnAllSelectedTokens;
|
||||
|
||||
// Perform batched TopP sampling
|
||||
tk::invokeBatchTopPSampling<T>(kernelParams, this->mStream->get());
|
||||
@ -80,26 +83,36 @@ TYPED_TEST_SUITE(TopPSamplingKernelTest, FloatAndHalfTypes);
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.2f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.2f));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(0.9f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(0.9f));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessAncestral)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopP(1.0f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(6).setVocabSize(4).setTopK(0).setTopP(1.0f));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabSmallP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.2f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.2f));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessLargeVocabLargeP)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopP(0.9f));
|
||||
this->runTest(SamplingKernelTestParam().setBatchSize(32).setVocabSize(51200).setTopK(0).setTopP(0.9f));
|
||||
};
|
||||
|
||||
TYPED_TEST(TopPSamplingKernelTest, CorrectnessReturnAllSelectedTokens)
|
||||
{
|
||||
this->runTest(SamplingKernelTestParam()
|
||||
.setBatchSize(16)
|
||||
.setVocabSize(50)
|
||||
.setTopK(0)
|
||||
.setTopP(0.8f)
|
||||
.setReturnAllSelectedTokens());
|
||||
};
|
||||
} // end of namespace
|
||||
|
||||
@ -164,6 +164,10 @@ struct cutlassTypeMapper
|
||||
return ss.str(); \
|
||||
} \
|
||||
};
|
||||
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int8Groupwise, "FP16Int8Groupwise", half, uint8_t, 8,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
|
||||
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int8Groupwise, "BF16Int8Groupwise", __nv_bfloat16, uint8_t, 8,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
|
||||
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::FP16Int4Groupwise, "FP16Int4Groupwise", half, cutlass::uint4b_t, 4,
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS);
|
||||
CUTLASS_TYPE_MAPPER_REGISTRY(wo::KernelType::BF16Int4Groupwise, "BF16Int4Groupwise", __nv_bfloat16, cutlass::uint4b_t,
|
||||
@ -367,8 +371,8 @@ bool benchmark_and_verify(int m, int n, int k, int groupsize, int warmup, int it
|
||||
d_out.copy_to(h_out2.data());
|
||||
float quant_scale = 1.f / (1 << (WSizeInBits - 1));
|
||||
bool pass = compare<AType>(h_out1.data(), h_out2.data(), m * n, quant_scale);
|
||||
printf(
|
||||
"cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n", time1, time2, time2 / time1);
|
||||
printf("cuda kernel cost time %.6f, cutlass kernel cost time %.6f, cuda speedup %.3f\n\n", time1, time2,
|
||||
time2 / time1);
|
||||
return pass;
|
||||
}
|
||||
|
||||
@ -392,6 +396,10 @@ TEST(Kernel, WeightOnly)
|
||||
EXPECT_TRUE(pass);
|
||||
if (arch >= 75)
|
||||
{
|
||||
pass = benchmark_and_verify<wo::KernelType::FP16Int8Groupwise>(m, n, k, 64, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::FP16Int8Groupwise>(m, n, k, 128, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::FP16Int4Groupwise>(m, n, k, 64, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::FP16Int4Groupwise>(m, n, k, 128, warmup, iter);
|
||||
@ -399,6 +407,10 @@ TEST(Kernel, WeightOnly)
|
||||
#if defined(ENABLE_BF16)
|
||||
if (arch >= 80)
|
||||
{
|
||||
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 64, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::BF16Int8Groupwise>(m, n, k, 128, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 64, warmup, iter);
|
||||
EXPECT_TRUE(pass);
|
||||
pass = benchmark_and_verify<wo::KernelType::BF16Int4Groupwise>(m, n, k, 128, warmup, iter);
|
||||
|
||||
@ -73,6 +73,8 @@ void BaseSamplingLayerTest<T>::setup(uint64_t seed, TestSamplingParams const& pa
|
||||
trk::invokeFill(*mCumLogProbsDevice, float{0.0f}, *mStream);
|
||||
trk::invokeFill(*mOutputLogProbsDevice, float{0.0f}, *mStream);
|
||||
trk::invokeFill(*mEndIdsDevice, int32_t{mEndId}, *mStream);
|
||||
tk::invokeCurandInitialize(reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice)), nullptr,
|
||||
mMaxBatchSize, seed, mStream->get());
|
||||
|
||||
auto batchSlotsPtr = bufferCast<int32_t>(*mBatchSlots);
|
||||
for (SizeType32 bi = 0; bi < mBatchSize; ++bi)
|
||||
|
||||
@ -720,17 +720,17 @@ void LookaheadDecodingLayerTest::verifyDecode()
|
||||
BufferRange<SizeType32> cumSumRange(*mNumNewTokensCumSum);
|
||||
BufferRange<SizeType32> pathOffsetsRange(*mPathsOffsets);
|
||||
PRINT_VALUES(mNumNewTokensCumSum);
|
||||
for (SizeType32 gbi = 0; gbi < mTestParam.maxBatchSize; gbi++)
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 pathOffsetBegin = cumSumRange[gbi];
|
||||
SizeType32 pathOffsetEnd = cumSumRange[gbi + 1];
|
||||
auto gbi = BufferRange<SizeType32>(*mBatchSlots)[bi];
|
||||
SizeType32 pathOffsetBegin = cumSumRange[bi];
|
||||
SizeType32 pathOffsetEnd = cumSumRange[bi + 1];
|
||||
TensorPtr golden = ITensor::at(mGoldenSampledTokens, {gbi});
|
||||
auto sequenceLength = BufferLocation<SizeType32>(*mSequenceLengths).at(gbi);
|
||||
auto numNewTokens = BufferLocation<SizeType32>(*mNumNewTokens).at(gbi);
|
||||
TensorPtr newTokens = ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens);
|
||||
BufferRange<SizeType32> goldenRange(*ITensor::at(mGoldenSampledTokens, {gbi}));
|
||||
BufferRange<TokenIdType> newTokensRange(
|
||||
*ITensor::slice(mOutputIds, {gbi, 0, sequenceLength - numNewTokens}, numNewTokens));
|
||||
BufferRange<TokenIdType> newTokensRange(*newTokens);
|
||||
|
||||
SizeType32 ni = 1;
|
||||
for (SizeType32 poi = pathOffsetBegin; poi < pathOffsetEnd; poi++)
|
||||
|
||||
@ -207,7 +207,7 @@ TEST(LookaheadRandomllm, gpuSampling)
|
||||
kernelParams.vocabSizePadded = vocabSize;
|
||||
kernelParams.normalizeLogProbs = false;
|
||||
kernelParams.logitsHasProbs = false;
|
||||
kernelParams.returnAllTopK = false;
|
||||
kernelParams.returnAllSelectedTokens = false;
|
||||
|
||||
PRINT_TOKENS(mEndIds);
|
||||
PRINT_VALUES(mTokensPerStep);
|
||||
|
||||
@ -101,6 +101,9 @@ def add_parallel_info(report, parallel):
|
||||
document.write(report, encoding="UTF-8", xml_declaration=True)
|
||||
|
||||
|
||||
default_test_parallel = 2
|
||||
|
||||
|
||||
def parallel_run_ctest(
|
||||
command: _tp.Sequence[str],
|
||||
cwd: _pl.Path,
|
||||
@ -108,7 +111,7 @@ def parallel_run_ctest(
|
||||
shell=False,
|
||||
env=None,
|
||||
timeout=None,
|
||||
parallel=2,
|
||||
parallel=default_test_parallel,
|
||||
) -> None:
|
||||
if parallel == 1:
|
||||
return run_command(command,
|
||||
@ -576,7 +579,16 @@ def run_unit_tests(build_dir: _pl.Path, timeout=1800):
|
||||
excluded_tests.append("Encoder")
|
||||
excluded_tests.append("EncDec")
|
||||
ctest.extend(["-E", "|".join(excluded_tests)])
|
||||
parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout)
|
||||
|
||||
parallel = default_test_parallel
|
||||
if parallel_override := _os.environ.get("LLM_TEST_PARALLEL_OVERRIDE", None):
|
||||
parallel = int(parallel_override)
|
||||
|
||||
parallel_run_ctest(ctest,
|
||||
cwd=build_dir,
|
||||
env=cpp_env,
|
||||
timeout=timeout,
|
||||
parallel=parallel)
|
||||
|
||||
|
||||
def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
@ -634,7 +646,17 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
ctest.extend(["-R", "|".join(included_tests)])
|
||||
if excluded_tests:
|
||||
ctest.extend(["-E", "|".join(excluded_tests)])
|
||||
parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout)
|
||||
|
||||
parallel = default_test_parallel
|
||||
if parallel_override := _os.environ.get("LLM_TEST_PARALLEL_OVERRIDE",
|
||||
None):
|
||||
parallel = int(parallel_override)
|
||||
|
||||
parallel_run_ctest(ctest,
|
||||
cwd=build_dir,
|
||||
env=cpp_env,
|
||||
timeout=timeout,
|
||||
parallel=parallel)
|
||||
if run_gpt:
|
||||
xml_output_file = build_dir / "results-single-gpu-disagg-executor_gpt.xml"
|
||||
trt_model_test = produce_mpirun_command(
|
||||
|
||||
@ -62,7 +62,7 @@ COPY benchmarks benchmarks
|
||||
COPY scripts scripts
|
||||
COPY tensorrt_llm tensorrt_llm
|
||||
COPY 3rdparty 3rdparty
|
||||
COPY setup.py requirements.txt requirements-dev.txt ./
|
||||
COPY .gitmodules setup.py requirements.txt requirements-dev.txt ./
|
||||
|
||||
# Create cache directories for pip and ccache
|
||||
RUN mkdir -p /root/.cache/pip /root/.cache/ccache
|
||||
|
||||
@ -6,6 +6,10 @@ set -ex
|
||||
# and closest to the version specified in
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-07.html#rel-24-07
|
||||
TORCH_VERSION="2.4.0"
|
||||
# Check the compatible torchvision from
|
||||
# https://github.com/pytorch/vision/tree/main?tab=readme-ov-file#installation
|
||||
# and also confirm with https://pypi.org/pypi/torchvision/0.19.0/json
|
||||
TORCHVISION_VERSION="0.19.0"
|
||||
SYSTEM_ID=$(grep -oP '(?<=^ID=).+' /etc/os-release | tr -d '"')
|
||||
|
||||
prepare_environment() {
|
||||
@ -35,29 +39,44 @@ restore_environment() {
|
||||
|
||||
install_from_source() {
|
||||
if [[ $SYSTEM_ID == *"centos"* ]]; then
|
||||
VERSION_ID=$(grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"')
|
||||
if [[ $VERSION_ID == "7" ]]; then
|
||||
echo "Installation from PyTorch source codes cannot be supported..."
|
||||
exit 1
|
||||
fi
|
||||
VERSION_ID=$(grep -oP '(?<=^VERSION_ID=).+' /etc/os-release | tr -d '"')
|
||||
if [[ $VERSION_ID == "7" ]]; then
|
||||
echo "Installation from PyTorch source codes cannot be supported..."
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
prepare_environment $1
|
||||
export _GLIBCXX_USE_CXX11_ABI=$1
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
|
||||
|
||||
export _GLIBCXX_USE_CXX11_ABI=$1
|
||||
|
||||
export TORCH_CUDA_ARCH_LIST="8.0;9.0"
|
||||
export PYTORCH_BUILD_VERSION=${TORCH_VERSION}
|
||||
export PYTORCH_BUILD_NUMBER=0
|
||||
pip3 uninstall -y torch
|
||||
cd /tmp
|
||||
git clone --depth 1 --branch v$TORCH_VERSION https://github.com/pytorch/pytorch
|
||||
git clone --depth 1 --branch v${TORCH_VERSION} https://github.com/pytorch/pytorch
|
||||
cd pytorch
|
||||
git submodule sync && git submodule update --init --recursive
|
||||
pip3 install -r requirements.txt
|
||||
python3 setup.py install
|
||||
cd /tmp && rm -rf /tmp/pytorch
|
||||
|
||||
export PYTORCH_VERSION=${PYTORCH_BUILD_VERSION}
|
||||
export FORCE_CUDA=1
|
||||
export BUILD_VERSION=${TORCHVISION_VERSION}
|
||||
pip3 uninstall -y torchvision
|
||||
cd /tmp
|
||||
git clone --depth 1 --branch v${TORCHVISION_VERSION} https://github.com/pytorch/vision
|
||||
cd vision
|
||||
python3 setup.py install
|
||||
cd /tmp && rm -rf /tmp/vision
|
||||
|
||||
restore_environment $1
|
||||
}
|
||||
|
||||
install_from_pypi() {
|
||||
pip3 install torch==${TORCH_VERSION}
|
||||
pip3 uninstall -y torch torchvision
|
||||
pip3 install torch==${TORCH_VERSION} torchvision==${TORCHVISION_VERSION}
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
|
||||
@ -15,37 +15,6 @@ The following sections provide an overview of the main classes defined in the Ex
|
||||
|
||||
The `Executor` class is responsible for receiving requests from the client, and providing responses for those requests. The executor is constructed by providing a path to a directory containing the TensorRT-LLM engine or buffers containing the engine and the model JSON configuration. The client can create requests and enqueue those requests for execution using the `enqueueRequest` or `enqueueRequests` methods of the `Executor` class. Enqueued requests will be scheduled for execution by the executor, and multiple independent requests can be batched together at every iteration of the main execution loop (a process often referred to as continuous batching or iteration-level batching). Responses for a particular request can be awaited for by calling the `awaitResponses` method, and by providing the request id. Alternatively, responses for any requests can be awaited for by omitting to provide the request id when calling `awaitResponses`. The `Executor` class also allows to cancel requests using the `cancelRequest` method and to obtain per-iteration and per-request statistics using the `getLatestIterationStats`.
|
||||
|
||||
#### Logits Post-Processor (optional)
|
||||
|
||||
Users can alter the logits produced by the network, by providing a map of named callbacks of the form:
|
||||
|
||||
```
|
||||
std::unordered_map<std::string, function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>>
|
||||
```
|
||||
to an instance of `LogitsPostProcessorConfig`. The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any.
|
||||
|
||||
The first argument to the callback is the request id, second is the logits tensor, third are the tokens produced by the request so far, fourth is the operation stream used by the logits tensor, and last one is an optional client id. The callback returns a modified tensor of logits.
|
||||
|
||||
Users *must* use the stream to access the logits tensor. For example, performing a addition with a bias tensor should be enqueued on that stream.
|
||||
Alternatively, users may call `stream->synchronize()`, however, that will slow down the entire execution pipeline.
|
||||
|
||||
Multiple requests can share same client id and callback can use different logic based on client id.
|
||||
|
||||
We also provide a batched version that allows altering logits of multiple requests in a batch. This allows further optimizations and reduces callback overheads.
|
||||
|
||||
```
|
||||
std::function<void(std::vector<IdType> const&, std::vector<Tensor>&, std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&, std::vector<std::optional<IdType>> const&)>
|
||||
```
|
||||
|
||||
A single batched callback can be specified in `LogitsPostProcessorConfig`. Each request can opt to apply this callback by specifying the name of the logits
|
||||
post-processor as `Request::kBatchedPostProcessorName`.
|
||||
|
||||
Note: Neither callback variant is supported with the `STATIC` batching type for the moment.
|
||||
|
||||
In a multi-GPU run, callback is invoked on all tensor parallel ranks (in last pipeline rank) by default.
|
||||
For correct execution, user should replicate client-side state accessed by callback on all tensor parallel ranks.
|
||||
If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setReplicate(false)` to invoke callback only on first tensor parallel rank.
|
||||
|
||||
### The Request Class
|
||||
|
||||
The `Request` class is used to define properties of the request, such as the input token ids and the maximum number of tokens to generate. The `streaming` parameter can be used to indicate if the request should generate a response for each new generated tokens (`streaming = true`) or only after all tokens have been generated (`streaming = false`). Other mandatory parameters of the request include the sampling configuration (defined by the `SamplingConfig` class) which contains parameters controlling the decoding process and the output configuration (defined by the `OutputConfig` class) which controls what information should be included in the `Result` for a particular response.
|
||||
@ -83,6 +52,32 @@ The executor can process requests with different beam widths if the following co
|
||||
|
||||
The request queue of the executor must be empty to allow it to reconfigure itself for a new beam width. This reconfiguration will happen automatically when requests with a new beam width are enqueued. If requests with different beam widths are enqueued at the same time, the executor will encounter an error and terminate all requests prematurely.
|
||||
|
||||
### Controlling output with Logits Post-Processor
|
||||
|
||||
Optionally, you can alter the logits produced by the network by providing an instance of `Executor::LogitsPostProcessorConfig`. For instance, this feature can be used to generate JSON formatted output. {cpp:class}`Executor::LogitsPostProcessorConfig <tensorrt_llm::executor::LogitsPostProcessorConfig>` specifies a map of named callbacks in the following form
|
||||
|
||||
```cpp
|
||||
std::unordered_map<std::string, function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>>
|
||||
```
|
||||
|
||||
The map key is the name associated with that logits post-processing callback. Each request can then specify the name of the logits post-processor to use for that particular request, if any.
|
||||
|
||||
The first argument to the callback is the request id, second is the logits tensor, third are the tokens produced by the request so far, fourth is the operation stream used by the logits tensor, and last one is an optional client id. The callback returns a modified tensor of logits. Multiple requests can share the same client id, and the callback can use different logic based on the client id.
|
||||
|
||||
You must use the stream to access the logits tensor. For example, to perform an addition with a bias tensor, the addition operation is enqueued on that stream. Alternatively, you can call `stream->synchronize()`, however, that will slow down the entire execution pipeline.
|
||||
|
||||
The executor also includes a {cpp:class}`LogitsPostProcessorBatched <tensorrt_llm::executor::LogitsPostProcessorBatched>` method that enables altering logits of multiple requests in a batch. The batched method allows further optimizations and reduces callback overheads.
|
||||
|
||||
```cpp
|
||||
std::function<void(std::vector<IdType> const&, std::vector<Tensor>&, std::vector<std::reference_wrapper<BeamTokens const>> const&, StreamPtr const&, std::vector<std::optional<IdType>> const&)>
|
||||
```
|
||||
|
||||
A single batched callback can be specified in `LogitsPostProcessorConfig`. Each request can opt to apply this callback by specifying the name of the logits post-processor as `Request::kBatchedPostProcessorName`.
|
||||
|
||||
Note: Neither callback variant is supported with the `STATIC` batching type for the moment.
|
||||
|
||||
In a multi-GPU run, the callback is invoked on all ranks in the first tensor-parallel group, by default. To ensure correct execution, replicate the client-side state that is accessed by the callback on these ranks. If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setReplicate(false)` to invoke the callback only on rank 0. The executor broadcasts the sampled tokens internally to ensure correct execution.
|
||||
|
||||
## C++ Executor API Example
|
||||
|
||||
Two C++ examples are provided that show how to use the Executor API and can be found in the [`examples/cpp/executor`](source:examples/cpp/executor/) folder.
|
||||
|
||||
@ -304,7 +304,7 @@ For guidance on constructing and executing Medusa with the Python runtime, consu
|
||||
|
||||
- TensorRT-LLM supports Medusa only for Vicuna (fine tuned LLaMA).
|
||||
However, similar to any new model, you can follow the same approach to define your own Medusa model and deploy with TensorRT-LLM.
|
||||
- We match only tokens during the validation phasem that is `medusa_temperature=0`.
|
||||
- We match only tokens during the validation phase, that is, `medusa_temperature=0`.
|
||||
- Beam search is **not** compatible with Medusa.
|
||||
|
||||
|
||||
|
||||
@ -93,13 +93,13 @@ def generate_llmapi():
|
||||
doc_dir.mkdir(exist_ok=True)
|
||||
doc_path = doc_dir / "index.rst"
|
||||
|
||||
hlapi_all_file = root_dir / "tensorrt_llm/hlapi/__init__.py"
|
||||
public_classes_names = extract_all_and_eval(hlapi_all_file)['__all__']
|
||||
llmapi_all_file = root_dir / "tensorrt_llm/llmapi/__init__.py"
|
||||
public_classes_names = extract_all_and_eval(llmapi_all_file)['__all__']
|
||||
|
||||
content = underline("API Reference", "-") + "\n\n"
|
||||
for cls_name in public_classes_names:
|
||||
cls_name = cls_name.strip()
|
||||
content += (f".. autoclass:: tensorrt_llm.hlapi.{cls_name}\n"
|
||||
content += (f".. autoclass:: tensorrt_llm.llmapi.{cls_name}\n"
|
||||
" :members:\n"
|
||||
" :undoc-members:\n"
|
||||
" :special-members: __init__\n"
|
||||
|
||||
@ -71,3 +71,7 @@ We recommend checking out the [v0.13.0 tag](https://github.com/NVIDIA/TensorRT-L
|
||||
This may be caused by an outdated Microsoft Visual C++ Redistributable Version. Please install
|
||||
[the latest MSVC](https://learn.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170#latest-microsoft-visual-c-redistributable-version)
|
||||
and retry. Check the system path to make sure the latest version installed in `System32` is searched first. Check dependencies to make sure no other packages are using an outdated version (e.g. package `pyarrow` might contain an outdated MSVC DLL).
|
||||
|
||||
2. OSError: [WinError 126] The specified module could not be found. Error loading “...\Lib\site-packages\torch\lib\fbgemm.dll” or one of its dependencies.
|
||||
|
||||
Installing the latest [Build Tools for Visual Studio 2022](https://visualstudio.microsoft.com/downloads/#build-tools-for-visual-studio-2022) will resolve the issue.
|
||||
|
||||
@ -5,7 +5,7 @@
|
||||
TensorRT-LLM can quantize the Hugging Face model automatically by setting the appropriate flags in the `LLM` instance. For example, to perform an Int4 AWQ quantization, the following code triggers the model quantization. Please refer to the complete list of [supported flags](https://nvidia.github.io/TensorRT-LLM/_modules/tensorrt_llm/quantization/mode.html#QuantAlgo) and acceptable values.
|
||||
|
||||
``` python
|
||||
from tensorrt_llm.hlapi import QuantConfig, QuantAlgo
|
||||
from tensorrt_llm.llmapi import QuantConfig, QuantAlgo
|
||||
|
||||
quant_config = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ)
|
||||
|
||||
@ -14,12 +14,12 @@ llm = LLM(<model-dir>, quant_config=quant_config)
|
||||
|
||||
## Sampling
|
||||
|
||||
SamplingParams can customize the sampling strategy to control LLM generated responses, such as beam search, temperature, and [others](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/hlapi/utils.py#L55-L76).
|
||||
SamplingParams can customize the sampling strategy to control LLM generated responses, such as beam search, temperature, and [others](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/llmapi/utils.py#L55-L76).
|
||||
|
||||
As an example, to enable beam search with a beam size of 4, set the `sampling_params` as follows:
|
||||
|
||||
```python
|
||||
from tensorrt_llm.hlapi import LLM, SamplingParams, BuildConfig
|
||||
from tensorrt_llm.llmapi import LLM, SamplingParams, BuildConfig
|
||||
|
||||
build_config = BuildConfig()
|
||||
build_config.max_beam_width = 4
|
||||
@ -38,7 +38,7 @@ for output in llm.generate(<prompt>, sampling_params=sampling_params):
|
||||
* [SamplingConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#_CPPv4N12tensorrt_llm7runtime14SamplingConfigE)
|
||||
* [OutputConfig](https://nvidia.github.io/TensorRT-LLM/_cpp_gen/executor.html#_CPPv4N12tensorrt_llm8executor12OutputConfigE)
|
||||
|
||||
Refer to the [class documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/index.html#tensorrt_llm.hlapi.SamplingParams) for more details.
|
||||
Refer to the [class documentation](https://nvidia.github.io/TensorRT-LLM/llm-api/index.html#tensorrt_llm.llmapi.SamplingParams) for more details.
|
||||
|
||||
## Build Configuration
|
||||
|
||||
@ -55,11 +55,11 @@ Refer to the [buildconfig documentation](https://github.com/NVIDIA/TensorRT-LLM/
|
||||
|
||||
## Runtime Customization
|
||||
|
||||
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other [arguments](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/hlapi/llm_utils.py#L186-L223) borrowed from the lower-level APIs. These runtime configuration options provide additional flexibility with respect to KV cache management, GPU memory allocation and so on. Refer to the following example:
|
||||
Similar to `build_config`, you can also customize the runtime configuration with the `runtime_config`, `peft_cache_config` or other [arguments](https://github.com/NVIDIA/TensorRT-LLM/blob/main/tensorrt_llm/llmapi/llm_utils.py#L186-L223) borrowed from the lower-level APIs. These runtime configuration options provide additional flexibility with respect to KV cache management, GPU memory allocation and so on. Refer to the following example:
|
||||
|
||||
|
||||
```python
|
||||
from tensorrt_llm.hlapi import LLM, KvCacheConfig
|
||||
from tensorrt_llm.llmapi import LLM, KvCacheConfig
|
||||
|
||||
llm = LLM(<llama_model_path>,
|
||||
kv_cache_config=KvCacheConfig(
|
||||
|
||||
@ -13,6 +13,7 @@ The LLM API can be used for both offline or online usage. See more examples of t
|
||||
* [LLM Generate Async Streaming](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_generate_async_streaming.html)
|
||||
* [LLM Quantization](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_quantization.html)
|
||||
* [LLM Auto Parallel](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_auto_parallel.html)
|
||||
* [LLM Logits Processor](https://nvidia.github.io/TensorRT-LLM/llm-api-examples/llm_logits_processor.html)
|
||||
|
||||
For more details on how to fully utilize this API, check out:
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 175 KiB |
BIN
docs/source/media/image-10-07-2024.png
Normal file
BIN
docs/source/media/image-10-07-2024.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 372 KiB |
@ -30,7 +30,8 @@ TensorRT-LLM consists of pre– and post-processing steps and multi-GPU multi-no
|
||||
|
||||
### Latest GPU Support
|
||||
|
||||
TensorRT-LLM supports GPUs based on the NVIDIA Hopper, NVIDIA Ada Lovelace, NVIDIA Ampere, NVIDIA Turing, and NVIDIA Volta architectures. Certain limitations may, however, apply. Refer to the {ref}`support-matrix` for more information.
|
||||
TensorRT-LLM supports GPUs based on the NVIDIA Hopper, NVIDIA Ada Lovelace, and NVIDIA Ampere architectures.
|
||||
Certain limitations might apply. Refer to the {ref}`support-matrix` for more information.
|
||||
|
||||
### Native Windows Support
|
||||
|
||||
|
||||
@ -34,83 +34,191 @@ and shows the throughput client-server scenario under maximum load.
|
||||
|
||||
The performance numbers below were collected using the steps described in this document.
|
||||
|
||||
**All data in the table below was generated using version 0.12.0 and presents token throughput in tokens/second.**
|
||||
**All data in the table below was generated using version 0.13.0 and presents token throughput in tokens/second.**
|
||||
|
||||
| | | | | | | | | |
|
||||
| ------------ | ------------------------ | ------------- | --------------- | ----------- | -------------- | -------------- | -------------- | ------- |
|
||||
| | | **GPU** | H200 141GB HBM3 | GH200 120GB | H100 80GB HBM3 | H100 80GB HBM3 | A100-SXM4-80GB | L40S |
|
||||
| | | **Precision** | FP8 | FP8 | FP8 | FP16 | FP16 | FP8 |
|
||||
| **Model** | **Input/Output Lengths** | **TP** | | | | | | |
|
||||
| GPTJ 6B | 128/128 | 1 | 24834.76 | 22454.79 | 24429.55 | 13085.91 | 5864.81 | 7647.24 |
|
||||
| | 128/2048 | 1 | 8348.93 | 6656.25 | 7831.38 | 3882.21 | 2194.57 | 1843.91 |
|
||||
| | 128/4096 | 1 | 5062.80 | 3678.91 | 3968.98 | 2046.53 | 1118.22 | 980.67 |
|
||||
| | 2048/128 | 1 | 2776.53 | 2491.03 | 2724.38 | 1488.56 | 657.01 | 741.06 |
|
||||
| | 2048/2048 | 1 | 3631.54 | 2994.81 | 3004.17 | 1280.54 | 854.37 | 754.16 |
|
||||
| LLaMA v2 7B | 128/128 | 1 | 19706.35 | 17803.58 | 19068.99 | 11393.48 | 5272.39 | 6345.72 |
|
||||
| | 128/2048 | 1 | 7651.12 | 5472.34 | 6610.03 | 2964.65 | 1785.79 | 1551.37 |
|
||||
| | 128/4096 | 1 | 4424.90 | 3271.61 | 3649.38 | 1596.87 | 957.12 | 817.24 |
|
||||
| | 2048/128 | 1 | 2385.54 | 2035.42 | 2271.63 | 1189.06 | 564.77 | 625.09 |
|
||||
| | 2048/2048 | 1 | 3191.34 | 2726.29 | 2802.41 | 1243.96 | 735.19 | 641.56 |
|
||||
| LLaMA v3 8B | 128/128 | 1 | 28288.75 | 25420.52 | 27399.75 | 15567.44 | 6586.88 | 8745.80 |
|
||||
| | 128/2048 | 1 | 23230.62 | 16426.68 | 19198.73 | 8817.39 | 4882.13 | 5084.49 |
|
||||
| | 128/4096 | 1 | 16144.44 | 9832.66 | 12084.97 | 5352.37 | 3079.90 | 2755.13 |
|
||||
| | 2048/128 | 1 | 3623.79 | 3290.22 | 3463.26 | 1852.48 | 781.63 | 980.86 |
|
||||
| | 2048/2048 | 1 | 11093.62 | 7573.35 | 8894.11 | 3986.83 | 2268.13 | 2051.79 |
|
||||
| Mistral 7B | 128/128 | 1 | 30223.01 | 27696.90 | 29788.46 | 16319.25 | 6807.02 | 9612.58 |
|
||||
| | 128/2048 | 1 | 24989.54 | 17942.29 | 20509.72 | 9982.01 | 5296.02 | 5444.89 |
|
||||
| | 128/4096 | 1 | 17036.14 | 10846.03 | 12807.80 | 5718.89 | 3241.33 | 2931.17 |
|
||||
| | 2048/128 | 1 | 3678.80 | 3294.02 | 3521.71 | 1887.75 | 786.43 | 1002.49 |
|
||||
| | 2048/2048 | 1 | 11510.54 | 8357.75 | 9214.61 | 4284.82 | 2363.25 | 2154.26 |
|
||||
| Mixtral 8x7B | 128/128 | 2 | 24895.03 | 8785.80 | 24394.71 | 15529.86 | 5921.41 | |
|
||||
| | | 4 | 42014.24 | 38828.53 | 40197.42 | 28132.17 | 11414.95 | 6820.26 |
|
||||
| | 128/2048 | 2 | 29389.21 | 5474.69 | 20873.02 | 7066.02 | 4306.98 | |
|
||||
| | | 4 | 52348.10 | 41573.66 | 40588.05 | 21285.72 | 10974.83 | 7467.15 |
|
||||
| | 128/4096 | 2 | 21480.27 | 2277.66 | 12838.28 | 3986.01 | 2400.11 | |
|
||||
| | | 4 | 39182.04 | 28626.55 | 28337.31 | 12447.13 | 7278.89 | 5233.43 |
|
||||
| | 2048/128 | 2 | 2934.44 | 1003.51 | 2898.27 | 1834.77 | 693.51 | |
|
||||
| | | 4 | 5152.40 | 4724.01 | 5028.61 | 3393.18 | 1362.93 | 805.49 |
|
||||
| | 2048/2048 | 2 | 14029.17 | 2671.88 | 10479.45 | 3531.31 | 1945.88 | |
|
||||
| | | 4 | 25436.05 | 20302.56 | 19971.72 | 9622.66 | 5221.74 | 3616.30 |
|
||||
| LLaMA v3 70B | 128/128 | 2 | 5386.88 | | | 2959.22 | 1301.14 | |
|
||||
| | | 4 | 8944.26 | 8587.01 | 8642.05 | 5966.47 | 2413.95 | |
|
||||
| | | 8 | 16125.20 | | 15397.47 | 10406.55 | 4548.32 | 1364.08 |
|
||||
| | 128/2048 | 2 | 7007.27 | | | 720.73 | 500.83 | |
|
||||
| | | 4 | 12906.75 | 10761.53 | 8978.95 | 4736.61 | 2380.02 | |
|
||||
| | | 8 | 19417.37 | | 14822.93 | 6672.14 | 3815.08 | 1809.40 |
|
||||
| | 128/4096 | 2 | 6183.85 | | | 369.29 | 251.24 | |
|
||||
| | | 4 | 8859.54 | 7270.77 | 6073.48 | 2969.99 | 1634.82 | |
|
||||
| | | 8 | 13969.95 | | 10094.57 | 4358.77 | 2847.54 | 1313.78 |
|
||||
| | 2048/128 | 2 | 696.59 | | | 301.46 | 140.88 | |
|
||||
| | | 4 | 1044.35 | 1000.55 | 1022.06 | 681.72 | 278.76 | |
|
||||
| | | 8 | 2018.47 | | 1933.15 | 1279.46 | 543.73 | 163.36 |
|
||||
| | 2048/2048 | 2 | 3525.18 | | | | 87.54 | |
|
||||
| | | 4 | 6550.76 | 4859.38 | 4870.26 | 2379.66 | 1209.69 | |
|
||||
| | | 8 | 9706.95 | | 7670.04 | 3692.41 | 2192.28 | 895.23 |
|
||||
| LLaMA v2 70B | 128/128 | 2 | 6355.16 | | | 2927.71 | 1374.05 | |
|
||||
| | | 4 | 10818.97 | 10819.19 | 10754.99 | 6603.10 | 2765.94 | |
|
||||
| | | 8 | 16667.25 | | 16074.84 | 11369.11 | 4796.89 | 1402.92 |
|
||||
| | 128/2048 | 2 | 6185.77 | | | 668.52 | 445.04 | |
|
||||
| | | 4 | 12884.76 | 11356.48 | 8870.71 | 5067.06 | 2710.53 | |
|
||||
| | | 8 | 19053.13 | | 17534.62 | 8805.16 | 5665.93 | 2203.33 |
|
||||
| | 128/4096 | 2 | 4873.24 | | | 334.10 | 215.70 | |
|
||||
| | | 4 | 8664.90 | 6311.85 | 7564.99 | 3354.02 | 1884.46 | |
|
||||
| | | 8 | 15110.32 | | 10584.03 | 5373.10 | 3672.80 | 1787.76 |
|
||||
| | 2048/128 | 2 | 732.09 | | | 302.49 | 141.70 | |
|
||||
| | | 4 | 1272.90 | 1269.58 | 1265.80 | 774.93 | 320.79 | |
|
||||
| | | 8 | 2015.77 | | 1943.96 | 1355.78 | 569.48 | 165.52 |
|
||||
| | 2048/2048 | 2 | 3508.50 | | | 321.95 | 212.97 | |
|
||||
| | | 4 | 6642.69 | 5545.83 | 4889.26 | 2439.10 | 1276.58 | |
|
||||
| | | 8 | 10178.71 | | 8071.77 | 4275.74 | 2589.60 | 1083.45 |
|
||||
| Falcon 180B | 128/128 | 4 | 5129.55 | | | | | |
|
||||
| | | 8 | 8370.98 | | 8268.72 | | | |
|
||||
| | 128/2048 | 4 | 7823.79 | | | | | |
|
||||
| | | 8 | 13278.59 | | 13107.48 | | | |
|
||||
| | 128/4096 | 4 | 6374.10 | | | | | |
|
||||
| | | 8 | 12660.89 | | 10493.79 | | | |
|
||||
| | 2048/128 | 4 | 601.67 | | | | | |
|
||||
| | | 8 | 1002.57 | | 991.22 | | | |
|
||||
| | 2048/2048 | 4 | 3869.76 | | | | | |
|
||||
| | | 8 | 7134.33 | | 6386.83 | | | |
|
||||
| | | | | | | | | |
|
||||
| --------------- | ------------------------ | ------------- | ------------------- | --------------- | ------------------ | ------------------ | ------------------ | -------- |
|
||||
| | | **GPU** | **H200 141GB HBM3** | **GH200 120GB** | **H100 80GB HBM3** | **H100 80GB HBM3** | **A100-SXM4-80GB** | **L40S** |
|
||||
| | | **Precision** | **FP8** | **FP8** | **FP8** | **FP16** | **FP16** | **FP8** |
|
||||
| **Model** | **Input/Output Lengths** | **TP** | | | | | | |
|
||||
| GPTJ 6B | 128/128 | 1 | 24,533.54 | 22,368.50 | 24,318.61 | 12,936.63 | 5,964.19 | 7,688.44 |
|
||||
| | 128/2048 | 1 | 8,375.67 | 6,588.73 | 7,829.91 | 3,931.61 | 2,215.88 | 1,842.82 |
|
||||
| | 128/4096 | 1 | 5,048.59 | 3,662.81 | 3,955.28 | 2,041.06 | 1,118.12 | 980.23 |
|
||||
| | 2048/128 | 1 | 2,770.27 | 2,520.37 | 2,698.08 | 1,479.48 | 650.09 | 746.54 |
|
||||
| | 5000/500 | 1 | 1,791.39 | 1,449.23 | 1,623.17 | 818.80 | 436.85 | 413.33 |
|
||||
| | 500/2000 | 1 | 6,770.60 | 5,565.62 | 6,149.65 | 3,030.03 | 1,673.05 | 1,538.45 |
|
||||
| | 1000/1000 | 1 | 6,465.73 | 5,580.37 | 6,078.80 | 2,797.48 | 1,673.45 | 1,531.57 |
|
||||
| | 2048/2048 | 1 | 3,637.42 | 2,998.01 | 3,060.80 | 1,285.08 | 845.83 | 753.55 |
|
||||
| LLaMA v3.1 8B | 128/128 | 1 | 28,125.59 | 26,045.60 | 27,147.22 | 15,647.83 | 6,687.04 | 8,548.90 |
|
||||
| | 128/2048 | 1 | 22,989.20 | 16,497.79 | 19,221.02 | 8,882.95 | 4,918.53 | 4,988.61 |
|
||||
| | 128/4096 | 1 | 16,077.62 | 9,637.91 | 11,856.11 | 5,462.96 | 3,054.46 | 2,768.91 |
|
||||
| | 2048/128 | 1 | 3,625.83 | 3,357.60 | 3,497.30 | 1,859.37 | 796.17 | 1,000.90 |
|
||||
| | 5000/500 | 1 | 3,823.76 | 3,217.40 | 3,276.69 | 1,687.74 | 788.66 | 872.14 |
|
||||
| | 500/2000 | 1 | 19,382.37 | 15,128.77 | 13,996.05 | 6,834.76 | 3,929.83 | 3,911.14 |
|
||||
| | 1000/1000 | 1 | 16,435.21 | 12,355.41 | 13,411.43 | 7,160.92 | 3,592.16 | 3,648.21 |
|
||||
| | 2048/2048 | 1 | 11,072.97 | 7,850.75 | 8,851.23 | 4,152.21 | 2,269.78 | 2,055.78 |
|
||||
| | 20000/2000 | 1 | 1,634.98 | 1,200.89 | 1,278.04 | 595.89 | 316.43 | 263.75 |
|
||||
| LLaMA v3 8B | 128/128 | 1 | 27,940.47 | 26,117.13 | 27,156.81 | 15,489.11 | 6,656.98 | 8,734.57 |
|
||||
| | 128/2048 | 1 | 23,228.98 | 16,417.04 | 19,209.17 | 8,901.43 | 4,967.37 | 5,004.93 |
|
||||
| | 128/4096 | 1 | 15,980.94 | 9,351.95 | 11,889.67 | 5,455.91 | 3,053.27 | 2,768.15 |
|
||||
| | 2048/128 | 1 | 3,631.45 | 3,339.90 | 3,476.37 | 1,918.56 | 796.28 | 1,050.68 |
|
||||
| | 5000/500 | 1 | 3,836.98 | 3,186.22 | 3,279.24 | 1,668.42 | 792.95 | 860.31 |
|
||||
| | 500/2000 | 1 | 19,725.45 | 15,241.74 | 14,218.30 | 6,816.62 | 3,899.64 | 3,990.73 |
|
||||
| | 1000/1000 | 1 | 16,201.60 | 12,049.81 | 13,371.60 | 7,041.47 | 3,617.10 | 3,679.10 |
|
||||
| | 2048/2048 | 1 | 11,097.69 | 7,255.55 | 8,852.87 | 4,251.45 | 2,269.68 | 2,048.94 |
|
||||
| LLaMA v2 7B | 128/128 | 1 | 19,549.13 | 17,823.45 | 19,298.99 | 11,436.31 | 5,238.68 | 6,396.62 |
|
||||
| | 128/2048 | 1 | 7,675.14 | 5,438.53 | 6,607.33 | 2,985.61 | 1,807.39 | 1,566.03 |
|
||||
| | 128/4096 | 1 | 4,397.83 | 3,310.09 | 3,628.46 | 1,575.35 | 957.24 | 821.83 |
|
||||
| | 2048/128 | 1 | 2,392.31 | 2,064.18 | 2,304.02 | 1,157.55 | 560.35 | 619.83 |
|
||||
| | 5000/500 | 1 | 1,570.37 | 1,250.11 | 1,419.09 | 624.75 | 366.39 | 347.03 |
|
||||
| | 500/2000 | 1 | 6,044.15 | 4,717.51 | 5,188.69 | 2,382.75 | 1,408.58 | 1,231.78 |
|
||||
| | 1000/1000 | 1 | 5,896.10 | 4,825.24 | 5,208.97 | 2,462.65 | 1,431.92 | 1,277.79 |
|
||||
| | 2048/2048 | 1 | 3,193.42 | 2,693.21 | 2,792.53 | 1,263.11 | 734.38 | 641.47 |
|
||||
| Mistral 7B | 128/128 | 1 | 30,152.19 | 27,738.08 | 29,672.75 | 16,711.12 | 6,863.59 | 9,676.88 |
|
||||
| | 128/2048 | 1 | 24,742.09 | 17,528.14 | 20,318.60 | 9,774.11 | 5,321.44 | 5,437.25 |
|
||||
| | 128/4096 | 1 | 16,905.49 | 10,671.38 | 12,715.46 | 5,740.41 | 3,257.23 | 2,941.08 |
|
||||
| | 2048/128 | 1 | 3,676.37 | 3,369.77 | 3,502.83 | 1,893.42 | 796.00 | 996.65 |
|
||||
| | 5000/500 | 1 | 3,890.07 | 3,401.45 | 3,358.65 | 1,740.69 | 807.07 | 904.45 |
|
||||
| | 500/2000 | 1 | 20,788.70 | 15,035.59 | 15,962.94 | 7,494.80 | 4,168.89 | 4,088.52 |
|
||||
| | 1000/1000 | 1 | 17,620.46 | 13,362.84 | 14,213.48 | 7,281.07 | 3,794.31 | 3,972.63 |
|
||||
| | 2048/2048 | 1 | 11,747.88 | 8,599.03 | 9,200.19 | 4,349.39 | 2,320.50 | 2,170.16 |
|
||||
| | 20000/2000 | 1 | 1,693.41 | 1,271.85 | 1,299.05 | 609.91 | 324.52 | 276.19 |
|
||||
| LLaMA v3.1 405B | 128/128 | 8 | 3,734.50 | | | | | |
|
||||
| | 128/2048 | 8 | 3,039.70 | | | | | |
|
||||
| | 128/4096 | 8 | 3,144.97 | | | | | |
|
||||
| | 2048/128 | 8 | 454.17 | | | | | |
|
||||
| | 5000/500 | 8 | 459.91 | | | | | |
|
||||
| | 500/2000 | 8 | 2,967.98 | | | | | |
|
||||
| | 1000/1000 | 8 | 2,259.32 | | | | | |
|
||||
| | 2048/2048 | 8 | 2,067.15 | | | | | |
|
||||
| | 20000/2000 | 8 | 447.67 | | | | | |
|
||||
| LLaMA v3.1 70B | 128/128 | 1 | 3,923.61 | 2,998.99 | 2,168.72 | | | |
|
||||
| | | 2 | 5,358.16 | 1,839.02 | 5,215.12 | 3,156.10 | 1,340.20 | |
|
||||
| | | 4 | 8,969.59 | 8,655.98 | 8,677.59 | 5,845.53 | 2,426.46 | 1,434.63 |
|
||||
| | | 8 | 16,449.68 | | 15,711.60 | 10,643.75 | 4,491.42 | 1,365.36 |
|
||||
| | 128/2048 | 1 | 3,503.59 | 1,343.53 | 344.22 | | | |
|
||||
| | | 2 | 7,068.42 | 1,146.08 | 5,654.43 | 801.82 | 498.44 | |
|
||||
| | | 4 | 12,890.95 | 10,358.10 | 9,377.87 | 4,791.11 | 2,460.91 | 1,748.87 |
|
||||
| | | 8 | 19,947.02 | | 15,168.97 | 6,892.18 | 4,148.33 | 1,890.62 |
|
||||
| | 128/4096 | 1 | 2,314.83 | | | | | |
|
||||
| | | 2 | 6,227.19 | 896.56 | 3,302.41 | 413.22 | 268.86 | |
|
||||
| | | 4 | 10,059.64 | 6,628.22 | 6,501.69 | 3,056.98 | 1,660.93 | 1,180.87 |
|
||||
| | | 8 | 14,393.28 | | 9,699.99 | 4,238.15 | 2,705.77 | 1,417.60 |
|
||||
| | 2048/128 | 1 | 459.73 | 372.44 | 211.51 | | | |
|
||||
| | | 2 | 689.30 | 280.61 | 690.05 | 323.66 | 143.39 | |
|
||||
| | | 4 | 1,047.96 | 1,015.14 | 1,016.24 | 672.37 | 278.87 | 167.87 |
|
||||
| | | 8 | 2,061.19 | | 1,964.49 | 1,273.97 | 539.57 | 163.91 |
|
||||
| | 5000/500 | 1 | 534.79 | 283.19 | 112.21 | | | |
|
||||
| | | 2 | 943.78 | 337.04 | 897.36 | 224.31 | 115.63 | |
|
||||
| | | 4 | 1,437.45 | 1,383.61 | 1,329.82 | 851.12 | 361.39 | 235.90 |
|
||||
| | | 8 | 2,795.95 | | 2,472.69 | 1,438.10 | 679.27 | 224.33 |
|
||||
| | 500/2000 | 1 | 2,758.24 | 1,083.48 | | | | |
|
||||
| | | 2 | 6,063.53 | 851.46 | 4,347.69 | 652.34 | 423.06 | |
|
||||
| | | 4 | 10,061.89 | 9,090.78 | 8,378.16 | 3,441.34 | 2,072.88 | 1,436.41 |
|
||||
| | | 8 | 16,139.49 | | 10,790.85 | 5,792.17 | 3,115.20 | 1,512.78 |
|
||||
| | 1000/1000 | 1 | 2,539.65 | 728.79 | | | | |
|
||||
| | | 2 | 4,572.03 | 1,223.92 | 3,880.41 | 737.40 | 451.82 | |
|
||||
| | | 4 | 7,612.56 | 6,705.02 | 6,553.00 | 3,655.64 | 1,731.86 | 1,113.18 |
|
||||
| | | 8 | 12,660.86 | | 11,121.10 | 5,599.45 | 3,013.95 | 1,120.73 |
|
||||
| | 2048/2048 | 1 | 1,753.58 | 611.08 | 161.60 | | | |
|
||||
| | | 2 | 3,407.26 | 626.26 | 2,432.55 | | 108.91 | |
|
||||
| | | 4 | 6,565.77 | 4,864.55 | 4,948.83 | 2,396.06 | 1,220.93 | 855.44 |
|
||||
| | | 8 | 9,948.56 | | 8,527.52 | 3,819.60 | 2,103.68 | 924.89 |
|
||||
| | 20000/2000 | 1 | 262.82 | 88.89 | | | | |
|
||||
| | | 2 | 598.19 | 177.04 | 414.17 | | | |
|
||||
| | | 4 | 1,047.27 | 958.88 | 856.31 | 375.85 | 187.42 | 140.73 |
|
||||
| | | 8 | 1,793.52 | | 1,359.27 | 650.78 | 344.41 | 122.04 |
|
||||
| LLaMA v3 70B | 128/128 | 1 | 3,924.02 | 3,161.73 | 2,177.84 | | | |
|
||||
| | | 2 | 5,388.22 | 1,551.84 | 5,205.80 | 3,186.61 | 1,321.55 | |
|
||||
| | | 4 | 8,958.95 | 8,618.55 | 8,678.68 | 5,857.16 | 2,424.68 | 1,432.46 |
|
||||
| | | 8 | 16,375.41 | | 15,703.26 | 10,627.36 | 4,490.19 | 1,333.09 |
|
||||
| | 128/2048 | 1 | 3,519.24 | 1,346.37 | 353.68 | | | |
|
||||
| | | 2 | 7,071.54 | 862.54 | 5,878.06 | 802.98 | 512.11 | |
|
||||
| | | 4 | 12,876.38 | 10,015.23 | 8,929.23 | 4,768.27 | 2,458.73 | 1,737.31 |
|
||||
| | | 8 | 20,013.92 | | 15,171.91 | 6,875.97 | 3,906.35 | 1,892.41 |
|
||||
| | 128/4096 | 1 | 2,310.85 | | | | | |
|
||||
| | | 2 | 6,199.95 | 602.98 | 3,311.05 | 413.29 | 269.02 | |
|
||||
| | | 4 | 9,633.49 | 7,370.19 | 6,489.95 | 3,053.89 | 1,677.51 | 1,199.71 |
|
||||
| | | 8 | 14,552.09 | | 9,632.02 | 4,259.39 | 2,697.61 | 1,358.34 |
|
||||
| | 2048/128 | 1 | 458.75 | 371.70 | 210.27 | | | |
|
||||
| | | 2 | 694.00 | 277.85 | 692.74 | 321.71 | 144.61 | |
|
||||
| | | 4 | 1,048.84 | 1,016.03 | 1,022.77 | 690.10 | 279.06 | 168.52 |
|
||||
| | | 8 | 2,072.33 | | 1,976.76 | 1,273.41 | 542.93 | 158.63 |
|
||||
| | 5000/500 | 1 | 533.37 | 303.33 | 112.68 | | | |
|
||||
| | | 2 | 936.82 | 379.62 | 899.29 | 224.65 | 115.00 | |
|
||||
| | | 4 | 1,442.76 | 1,384.62 | 1,326.95 | 853.73 | 361.06 | 235.19 |
|
||||
| | | 8 | 2,797.36 | | 2,483.56 | 1,437.15 | 678.70 | 225.15 |
|
||||
| | 500/2000 | 1 | 2,763.89 | 1,074.62 | 293.47 | | | |
|
||||
| | | 2 | 6,054.46 | 1,109.13 | 4,356.55 | 683.11 | 423.82 | |
|
||||
| | | 4 | 10,103.08 | 7,325.93 | 8,370.32 | 3,436.29 | 2,064.47 | 1,412.78 |
|
||||
| | | 8 | 16,857.45 | | 10,760.65 | 5,665.02 | 3,159.89 | 1,517.76 |
|
||||
| | 1000/1000 | 1 | 2,540.45 | 1,164.45 | | | | |
|
||||
| | | 2 | 4,590.38 | 1,040.64 | 3,879.25 | 768.53 | 453.73 | |
|
||||
| | | 4 | 7,606.92 | 6,655.61 | 6,547.23 | 3,655.19 | 1,732.86 | 1,117.53 |
|
||||
| | | 8 | 12,660.32 | | 11,155.47 | 5,617.24 | 2,894.58 | 1,126.50 |
|
||||
| | 2048/2048 | 1 | 1,746.77 | 610.87 | 162.10 | | | |
|
||||
| | | 2 | 3,405.72 | 738.51 | 2,548.70 | | 108.66 | |
|
||||
| | | 4 | 6,571.34 | 4,880.28 | 5,060.39 | 2,391.55 | 1,222.11 | 854.65 |
|
||||
| | | 8 | 9,923.96 | | 8,480.48 | 3,826.38 | 2,181.07 | 927.54 |
|
||||
| LLaMA v2 70B | 128/128 | 1 | 3,969.25 | 3,502.35 | 3,413.82 | | | |
|
||||
| | | 2 | 6,394.64 | 3,252.69 | 6,432.82 | 3,170.28 | 1,336.48 | |
|
||||
| | | 4 | 11,031.42 | 11,126.95 | 10,865.42 | 6,420.88 | 2,766.00 | 1,487.71 |
|
||||
| | | 8 | 17,060.04 | | 16,384.83 | 11,146.15 | 4,742.74 | 1,404.99 |
|
||||
| | 128/2048 | 1 | 3,742.99 | 1,660.81 | | | | |
|
||||
| | | 2 | 6,453.25 | 1,335.80 | 5,775.34 | 757.21 | 476.46 | |
|
||||
| | | 4 | 13,869.67 | 11,098.69 | 9,536.82 | 5,274.27 | 2,686.16 | 1,880.22 |
|
||||
| | | 8 | 19,220.48 | | 17,715.01 | 8,904.94 | 5,520.41 | 2,186.68 |
|
||||
| | 128/4096 | 1 | 2,459.63 | | 446.60 | | | |
|
||||
| | | 2 | 4,831.03 | 684.68 | 3,354.60 | 385.98 | 235.22 | |
|
||||
| | | 4 | 8,988.84 | 8,397.13 | 7,619.62 | 3,228.36 | 1,941.07 | 1,318.51 |
|
||||
| | | 8 | 15,115.41 | | 12,506.95 | 5,996.81 | 3,539.36 | 1,782.93 |
|
||||
| | 2048/128 | 1 | 458.88 | 400.31 | 328.90 | | | |
|
||||
| | | 2 | 745.71 | 457.57 | 742.17 | 308.02 | 138.81 | |
|
||||
| | | 4 | 1,297.10 | 1,330.90 | 1,270.78 | 755.30 | 321.72 | 171.67 |
|
||||
| | | 8 | 2,060.53 | | 2,009.57 | 1,348.71 | 561.71 | 160.37 |
|
||||
| | 5000/500 | 1 | 548.46 | 364.00 | 224.17 | | | |
|
||||
| | | 2 | 1,020.86 | 335.07 | 885.67 | 212.20 | 112.43 | |
|
||||
| | | 4 | 1,759.69 | 1,683.26 | 1,590.94 | 837.57 | 386.78 | 231.54 |
|
||||
| | | 8 | 2,839.69 | | 2,546.12 | 1,570.91 | 709.66 | 238.59 |
|
||||
| | 500/2000 | 1 | 3,019.28 | 1,364.66 | 716.54 | | | |
|
||||
| | | 2 | 6,402.94 | 1,292.24 | 4,462.98 | 629.21 | 387.61 | |
|
||||
| | | 4 | 12,429.18 | 8,951.07 | 8,753.09 | 4,012.41 | 2,158.17 | 1,517.53 |
|
||||
| | | 8 | 16,789.12 | | 15,260.29 | 7,384.79 | 4,104.80 | 1,739.28 |
|
||||
| | 1000/1000 | 1 | 2,706.04 | 1,449.83 | | | | |
|
||||
| | | 2 | 4,693.24 | 960.39 | 3,958.45 | 736.68 | 425.70 | |
|
||||
| | | 4 | 8,557.11 | 7,278.64 | 6,817.41 | 3,866.05 | 1,876.40 | 1,188.91 |
|
||||
| | | 8 | 13,483.04 | | 11,511.74 | 6,543.96 | 3,285.82 | 1,241.42 |
|
||||
| | 2048/2048 | 1 | 1,911.20 | 798.50 | 412.37 | | | |
|
||||
| | | 2 | 3,408.82 | 767.24 | 2,551.21 | 388.82 | 226.60 | |
|
||||
| | | 4 | 6,702.46 | 5,354.80 | 5,212.02 | 2,512.22 | 1,316.92 | 891.95 |
|
||||
| | | 8 | 10,348.65 | | 8,016.14 | 4,414.75 | 2,492.09 | 1,083.26 |
|
||||
| Mixtral 8x7B | 128/128 | 2 | 25,135.25 | 8,512.51 | 24,572.90 | 15,395.59 | 5,927.88 | |
|
||||
| | | 4 | 42,394.61 | 40,148.01 | 40,309.25 | 27,747.43 | 11,205.51 | 6,784.44 |
|
||||
| | | 8 | 54,648.80 | | 51,683.16 | 40,116.51 | 18,496.66 | 6,437.72 |
|
||||
| | 128/2048 | 2 | 29,412.17 | 3,271.02 | 20,938.80 | 7,391.51 | 4,278.79 | |
|
||||
| | | 4 | 52,603.13 | 43,071.34 | 40,580.94 | 21,332.15 | 10,946.58 | 7,475.05 |
|
||||
| | | 8 | 70,427.00 | | 64,161.64 | 41,101.18 | 21,235.99 | 9,955.21 |
|
||||
| | 128/4096 | 2 | 21,312.11 | 2,254.56 | | 3,896.02 | 2,388.14 | |
|
||||
| | | 4 | 39,353.01 | 30,065.77 | | | 7,108.03 | 5,232.44 |
|
||||
| | | 8 | 32,992.62 | | 47,860.65 | 27,261.67 | 15,943.70 | 8,081.21 |
|
||||
| | 2048/128 | 2 | 2,946.01 | 921.87 | 2,894.09 | 1,790.49 | 684.71 | |
|
||||
| | | 4 | 5,237.58 | 5,056.60 | 4,988.14 | 3,354.89 | 1,338.54 | 803.50 |
|
||||
| | | 8 | 7,053.32 | | 6,559.63 | 5,072.46 | 2,244.39 | 753.39 |
|
||||
| | 5000/500 | 2 | 3,848.10 | 997.06 | 3,630.24 | 1,656.04 | 739.84 | |
|
||||
| | | 4 | 6,877.65 | 6,466.39 | 6,237.22 | 3,607.46 | 1,619.49 | 1,048.60 |
|
||||
| | | 8 | 9,531.26 | | 8,709.34 | 6,237.96 | 2,927.13 | 1,109.25 |
|
||||
| | 500/2000 | 2 | 23,539.24 | 2,773.86 | 16,886.30 | 5,773.33 | 3,325.73 | |
|
||||
| | | 4 | 40,035.05 | 33,478.35 | 32,047.73 | 16,897.03 | 8,908.09 | 6,153.32 |
|
||||
| | | 8 | 60,572.77 | | 41,597.80 | 31,392.32 | 16,954.54 | 7,980.34 |
|
||||
| | 1000/1000 | 2 | 18,644.51 | 4,540.15 | 14,154.95 | 5,826.43 | 3,289.27 | |
|
||||
| | | 4 | 32,709.62 | 29,046.16 | 25,291.30 | 14,307.91 | 7,461.63 | 4,697.19 |
|
||||
| | | 8 | 44,072.88 | | 40,628.46 | 27,633.48 | 13,741.62 | 5,706.17 |
|
||||
| | 2048/2048 | 2 | 14,017.70 | 2,870.77 | 10,448.79 | 3,535.21 | 1,954.32 | |
|
||||
| | | 4 | 25,550.44 | 21,488.32 | 19,977.11 | 9,620.99 | 5,191.30 | 3,593.18 |
|
||||
| | | 8 | 24,999.94 | | 31,678.85 | 19,372.52 | 10,572.07 | 4,860.61 |
|
||||
| | 20000/2000 | 2 | 2,195.84 | 367.81 | 1,583.86 | 626.60 | 320.41 | |
|
||||
| | | 4 | 4,086.41 | 3,301.28 | 2,982.42 | 1,586.09 | 807.67 | 579.49 |
|
||||
| | | 8 | 5,797.57 | | 5,163.91 | 3,106.98 | 1,653.55 | 821.64 |
|
||||
*TP stands for Tensor Parallelism*
|
||||
|
||||
## Reproducing Benchmarked Results
|
||||
@ -169,7 +277,10 @@ remain in the system longer and therefore require less requests to achieve stead
|
||||
| 128 | 4096 | 4224 | 1500 |
|
||||
| 2048 | 128 | 2176 | 3000 |
|
||||
| 2048 | 2048 | 4096 | 1500 |
|
||||
|
||||
| 5000 | 500 | 5500 | 1500 |
|
||||
| 1000 | 1000 | 2000 | 3000 |
|
||||
| 500 | 2000 | 2500 | 3000 |
|
||||
| 20000 | 2000 | 22000 | 1000 |
|
||||
|
||||
## Engine Building
|
||||
|
||||
|
||||
@ -75,7 +75,8 @@ TensorRT-LLM optimizes the performance of a range of well-known models on NVIDIA
|
||||
|
||||
The following table shows the supported hardware for TensorRT-LLM.
|
||||
|
||||
If a GPU is not listed, it is important to note that TensorRT-LLM is expected to work on GPUs based on the Volta, Turing, Ampere, Hopper, and Ada Lovelace architectures. Certain limitations may, however, apply.
|
||||
If a GPU architecture is not listed, the TensorRT-LLM team does not develop or test the software on the architecture and support is limited to community support.
|
||||
In addition, older architectures can have limitations for newer software releases.
|
||||
|
||||
```{list-table}
|
||||
:header-rows: 1
|
||||
@ -90,8 +91,6 @@ If a GPU is not listed, it is important to note that TensorRT-LLM is expected to
|
||||
- [NVIDIA Hopper Architecture](https://www.nvidia.com/en-us/data-center/technologies/hopper-architecture/)
|
||||
- [NVIDIA Ada Lovelace Architecture](https://www.nvidia.com/en-us/technologies/ada-architecture/)
|
||||
- [NVIDIA Ampere Architecture](https://www.nvidia.com/en-us/data-center/ampere-architecture/)
|
||||
- [NVIDIA Turing Architecture](https://www.nvidia.com/en-us/geforce/turing/)
|
||||
- [NVIDIA Volta Architecture](https://www.nvidia.com/en-us/data-center/volta-gpu-architecture/) (experimental)
|
||||
```
|
||||
|
||||
(support-matrix-software)=
|
||||
@ -114,14 +113,8 @@ The following table shows the supported software for TensorRT-LLM.
|
||||
- Hopper (SM90) - FP32, FP16, BF16, FP8, INT8, INT4
|
||||
- Ada Lovelace (SM89) - FP32, FP16, BF16, FP8, INT8, INT4
|
||||
- Ampere (SM80, SM86) - FP32, FP16, BF16, INT8, INT4[^smgte89]
|
||||
- Turing (SM75) - FP32, FP16, INT8[^smooth], INT4
|
||||
- Volta (SM70) - FP32, FP16, INT8[^smooth], INT4[^smlt75]
|
||||
```
|
||||
|
||||
[^smooth]: INT8 SmoothQuant is not supported on SM70 and SM75.
|
||||
|
||||
[^smlt75]: INT4 AWQ and GPTQ are not supported on SM < 75.
|
||||
|
||||
[^smgte89]: INT4 AWQ and GPTQ with FP8 activations require SM >= 89.
|
||||
|
||||
[^encdec]: Encoder-Decoder provides general encoder-decoder functionality that supports many encoder-decoder models such as T5 family, BART family, Whisper family, NMT family, and so on.
|
||||
|
||||
@ -258,8 +258,13 @@ SLURM, depending upon the SLURM version you are using:
|
||||
Please configure as appropriate and try again.
|
||||
--------------------------------------------------------------------------
|
||||
```
|
||||
|
||||
You may experience other problems like hanging on the program startup.
|
||||
|
||||
As a rule of thumb, if you are running TensorRT-LLM interactively on a Slurm
|
||||
node, prefix your commands with `mpirun -n 1` to run TensorRT-LLM in a
|
||||
dedicated MPI environment, not the one provided by your Slurm allocation.
|
||||
|
||||
For example: `mpirun -n 1 python3 examples/gpt/build.py ...`
|
||||
|
||||
It's critical that it's always `-n 1` regardless of how many GPUs are being used. If you use `-n 2` for a 2-GPU program, it will not work. `mpirun` here isn't being used to orchestrate multiple processes, but to invoke the right environment on SLURM. The internal MPI implementation deals with spawning the additional processes.
|
||||
|
||||
@ -179,13 +179,13 @@ All published functionality in the Release Notes has been fully tested and verif
|
||||
- Moved the most commonly used options in the explicit arg-list, and hidden the expert options in the kwargs.
|
||||
- Exposed `model` to accept either HuggingFace model name or local HuggingFace model/TensorRT-LLM checkpoint/TensorRT-LLM engine.
|
||||
- Support downloading model from HuggingFace model hub, currently only Llama variants are supported.
|
||||
- Support build cache to reuse the built TensorRT-LLM engines by setting environment variable `TLLM_HLAPI_BUILD_CACHE=1` or passing `enable_build_cache=True` to `LLM` class.
|
||||
- Support build cache to reuse the built TensorRT-LLM engines by setting environment variable `TLLM_LLMAPI_BUILD_CACHE=1` or passing `enable_build_cache=True` to `LLM` class.
|
||||
- Exposed low-level options including `BuildConfig`, `SchedulerConfig` and so on in the kwargs, ideally you should be able to configure details about the build and runtime phase.
|
||||
- Refactored `LLM.generate()` and `LLM.generate_async()` API.
|
||||
- Removed `SamplingConfig`.
|
||||
- Added `SamplingParams` with more extensive parameters, see `tensorrt_llm/hlapi/utils.py`.
|
||||
- Added `SamplingParams` with more extensive parameters, see `tensorrt_llm/llmapi/utils.py`.
|
||||
- The new `SamplingParams` contains and manages fields from Python bindings of `SamplingConfig`, `OutputConfig`, and so on.
|
||||
- Refactored `LLM.generate()` output as `RequestOutput`, see `tensorrt_llm/hlapi/llm.py`.
|
||||
- Refactored `LLM.generate()` output as `RequestOutput`, see `tensorrt_llm/llmapi/llm.py`.
|
||||
- Updated the `apps` examples, especially by rewriting both `chat.py` and `fastapi_server.py` using the `LLM` APIs, please refer to the `examples/apps/README.md` for details.
|
||||
- Updated the `chat.py` to support multi-turn conversation, allowing users to chat with a model in the terminal.
|
||||
- Fixed the `fastapi_server.py` and eliminated the need for `mpirun` in multi-GPU scenarios.
|
||||
@ -481,7 +481,7 @@ All published functionality in the Release Notes has been fully tested and verif
|
||||
Refer to the {ref}`support-matrix-software` section for a list of supported models.
|
||||
|
||||
* API
|
||||
- Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
|
||||
- Add a set of LLM APIs for end-to-end generation tasks (see examples/llm-api/README.md)
|
||||
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/new_workflow.md)
|
||||
- **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and removed corresponding build parameters
|
||||
- **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
# Apps examples with GenerationExecutor / High-level API
|
||||
# Apps examples with GenerationExecutor / LLM API
|
||||
## OpenAI API
|
||||
[openai_server.py](./openai_server.py) is an OpenAI-compatible server which supports `v1/version`, `v1/completions` and `v1/chat/completions`. [openai_client.py](./openai_client.py) is a simple example using the OpenAI client to query your model. To start the server, you can run
|
||||
```
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user