open source 4dbf696ae9b74a26829d120b67ab8443d70c8e58 (#2297)

* Update TensorRT-LLM

---------

Co-authored-by: Bhuvanesh Sridharan <bhuvanesh.sridharan@sprinklr.com>
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
Kaiyu Xie 2024-10-08 18:19:19 +08:00 committed by GitHub
parent 48686bca3a
commit 8681b3a4c0
205 changed files with 5589 additions and 1828 deletions

.gitmodules

@ -14,3 +14,6 @@
[submodule "3rdparty/ucxx"]
path = 3rdparty/ucxx
url = https://github.com/GuanLuo/ucxx.git
[submodule "3rdparty/pybind11"]
path = 3rdparty/pybind11
url = https://github.com/pybind/pybind11.git


@ -46,5 +46,5 @@ repos:
args:
- --skip=".git,3rdparty"
- --exclude-file=examples/whisper/tokenizer.py
- --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid
- --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid,improbe
exclude: 'tests/llm-test-defs/turtle/test_input_files'

3rdparty/pybind11 (submodule)

@ -0,0 +1 @@
Subproject commit f99ffd7e03001810a3e722bf48ad1a9e08415d7d


@ -103,7 +103,7 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
## TensorRT-LLM Overview
TensorRT-LLM is a library for optimizing Large Language Model (LLM) inference.
It provides state-of-the-art optimziations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
TensorRT-LLM provides a Python API to build LLMs into optimized
[TensorRT](https://developer.nvidia.com/tensorrt) engines.


@ -42,7 +42,7 @@ This section covers how to benchmark TensorRT-LLM using inflight batching.
### Quickstart
For this quick start guide, we will focus on running a short max throughput benchmark on
`meta-llama/Llama-2-7b-hf` on a syntehtic dataset with a uniform distribution of prompts with ISL:OSL
`meta-llama/Llama-2-7b-hf` on a synthetic dataset with a uniform distribution of prompts with ISL:OSL
of 128:128. In order to run the benchmark from start to finish simply run the following commands:
```shell
@ -101,12 +101,12 @@ The workflow for `trtllm-bench` is composed of the following steps:
The inflight benchmark utilizes a fixed JSON schema so that it is simple and
straightforward to specify requests. The schema is defined as follows:
| Key | Required | Type | Description |
| :- | :-: | :-: | :- |
| `task_id`| Y | String | Unique identifier for the request. |
| `prompt` | N* | String | Input text for a generation request. |
| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
| Key | Required | Type | Description |
| :-------------- | :------: | :-----------: | :---------------------------------------------- |
| `task_id` | Y | String | Unique identifier for the request. |
| `prompt` | N* | String | Input text for a generation request. |
| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
> [!NOTE] Prompt and logits are mutually exclusive*
> While having both `prompt` and `logits` is not required, at least one is required.
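For illustration, here is a hypothetical pair of dataset entries following the schema above (one JSON object per line); only the keys come from the table, the values are made up, and one entry uses `prompt` while the other uses `logits`:

```json
{"task_id": 1, "prompt": "Summarize the following paragraph: ...", "output_tokens": 128}
{"task_id": 2, "logits": [791, 2612, 315, 279, 2768, 2262, 374], "output_tokens": 128}
```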


@ -316,6 +316,8 @@ endif()
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11)
include_directories(
${CUDAToolkit_INCLUDE_DIRS}
${CUDNN_ROOT_DIR}/include
@ -323,7 +325,8 @@ include_directories(
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/NVTX/include
${3RDPARTY_DIR}/json/include)
${3RDPARTY_DIR}/json/include
${3RDPARTY_DIR}/pybind11/include)
# TRT dependencies
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})


@ -0,0 +1,187 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"
#include <variant>
namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class KVCacheManager;
}
class BasePeftCacheManager;
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::batch_manager
{
using tensorrt_llm::runtime::SizeType32;
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
/// or even pause previously started requests.
class BaseCapacityScheduler
{
public:
explicit BaseCapacityScheduler(LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
: mNoScheduleUntilState(noScheduleUntilState)
, mNoScheduleAfterState(noScheduleAfterState)
{
}
[[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept
{
return mNoScheduleUntilState;
}
[[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept
{
return mNoScheduleAfterState;
}
private:
/// The state until/after which the scheduler should not schedule requests
LlmRequestState mNoScheduleUntilState;
LlmRequestState mNoScheduleAfterState;
};
/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
/// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
/// to update for this current iteration, and a map of requests to pause
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
private:
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
};
/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may pause previously started requests.
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
private:
/// @return {fitsKvCache, fitsPeft}
std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds) const;
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
/// @brief Boolean that indicates if multiple micro batches might be in flight
bool mManyMicroBatches;
};
/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
protected:
[[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
RequestList const& activeRequests, bool staticBatchScheduling) const;
private:
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
};
/// @brief Schedule requests using the STATIC_BATCH policy
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
};
class CapacityScheduler : public Algorithm
{
public:
constexpr static auto name{"CapacityScheduler"};
CapacityScheduler() = default;
CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
static CapacityScheduler make(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
{
return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
noScheduleAfterState};
}
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
private:
std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
StaticBatchScheduler>
mScheduler;
};
} // namespace tensorrt_llm::batch_manager
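As a rough usage sketch of the declarations above (not taken from the repository): a concrete policy is selected via `CapacityScheduler::make` and the scheduler is invoked once per iteration on the active request list. The helper name, the include path, the `maxNumRequests` value, and the `kGUARANTEED_NO_EVICT` enumerator spelling are assumptions.

```cpp
#include "tensorrt_llm/batch_manager/capacityScheduler.h"

#include <memory>
#include <tuple>

namespace tb = tensorrt_llm::batch_manager;

// Hypothetical helper: run one capacity-scheduling pass over the active requests.
// The KV-cache and PEFT cache managers are assumed to be created elsewhere.
std::tuple<tb::RequestVector, tb::RequestVector> scheduleOneIteration(
    std::shared_ptr<tb::kv_cache_manager::KVCacheManager> kvCacheManager,
    std::shared_ptr<tb::BasePeftCacheManager> peftCacheManager, tb::RequestList const& activeRequests)
{
    auto const scheduler = tb::CapacityScheduler::make(/*maxNumRequests=*/64, std::move(kvCacheManager),
        /*crossKvCacheManager=*/nullptr, std::move(peftCacheManager),
        tensorrt_llm::executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT);

    // First vector: requests scheduled for this iteration; second vector: requests to pause.
    return scheduler(activeRequests);
}
```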


@ -0,0 +1,118 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/runtime/common.h"
#include <cstdint>
#include <list>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
namespace tensorrt_llm::executor
{
class RequestWithId;
}
namespace tensorrt_llm::batch_manager
{
class LlmRequest;
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
using RequestIdType = std::uint64_t;
using RequestVector = std::vector<std::shared_ptr<LlmRequest>>;
using ReqIdsSet = std::unordered_set<RequestIdType>;
class ScheduledRequests
{
public:
/// @brief context phase requests (for decoder-only models) or encoder phase requests (for encoder-decoder models
/// and encoder-only models)
RequestVector contextRequests;
/// @brief generation phase requests (for decoder-only models) or empty for others
RequestVector generationRequests;
ScheduledRequests() = default;
explicit ScheduledRequests(RequestVector contextRequests, RequestVector generationRequests)
: contextRequests{std::move(contextRequests)}
, generationRequests{std::move(generationRequests)}
{
}
[[nodiscard]] bool empty() const
{
return contextRequests.empty() && generationRequests.empty();
}
[[nodiscard]] std::size_t size() const
{
return contextRequests.size() + generationRequests.size();
}
};
class BatchState
{
public:
BatchState() = default;
BatchState(runtime::SizeType32 numCtxRequests, runtime::SizeType32 numGenRequests, runtime::SizeType32 numTokens,
runtime::SizeType32 maxKvCacheLength)
: mNumCtxRequests{numCtxRequests}
, mNumGenRequests{numGenRequests}
, mNumTokens{numTokens}
, mMaxKvCacheLength{maxKvCacheLength}
{
}
bool isAnyContext() const
{
return mNumCtxRequests > 0;
}
bool operator==(BatchState const& other) const
{
return mNumCtxRequests == other.mNumCtxRequests && mNumGenRequests == other.mNumGenRequests
&& mNumTokens == other.mNumTokens && mMaxKvCacheLength == other.mMaxKvCacheLength;
}
size_t hash() const
{
size_t h1 = std::hash<runtime::SizeType32>{}(mNumCtxRequests);
size_t h2 = std::hash<runtime::SizeType32>{}(mNumGenRequests);
size_t h3 = std::hash<runtime::SizeType32>{}(mNumTokens);
size_t h4 = std::hash<runtime::SizeType32>{}(mMaxKvCacheLength);
return h1 ^ h2 ^ h3 ^ h4;
}
runtime::SizeType32 mNumCtxRequests;
runtime::SizeType32 mNumGenRequests;
runtime::SizeType32 mNumTokens;
runtime::SizeType32 mMaxKvCacheLength;
};
struct BatchStateHash
{
size_t operator()(BatchState const& bs) const
{
return bs.hash();
}
};
} // namespace tensorrt_llm::batch_manager
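A small, hypothetical illustration of why `BatchStateHash` exists: together with `BatchState::operator==` it lets `BatchState` serve as a key in hashed containers. The include path and the function/container names below are assumptions made for the example.

```cpp
#include "tensorrt_llm/batch_manager/common.h"

#include <unordered_map>

using tensorrt_llm::batch_manager::BatchState;
using tensorrt_llm::batch_manager::BatchStateHash;
using tensorrt_llm::runtime::SizeType32;

// Hypothetical bookkeeping: count how often each batch shape occurs, e.g. to decide
// which shapes are worth caching per-shape setup work for.
std::unordered_map<BatchState, int, BatchStateHash> shapeCounts;

void recordBatchShape(SizeType32 numCtxRequests, SizeType32 numGenRequests, SizeType32 numTokens,
    SizeType32 maxKvCacheLength)
{
    ++shapeCounts[BatchState{numCtxRequests, numGenRequests, numTokens, maxKvCacheLength}];
}
```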


@ -0,0 +1,74 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include <vector>
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy
{
public:
virtual ~BaseEvictionPolicy() = default;
virtual void initialize(
std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks)
= 0;
// Get a free block from the primary memory pool
virtual BlockPtr getFreePrimaryBlock() = 0;
// Get a free block from the secondary memory pool
virtual BlockPtr getFreeSecondaryBlock() = 0;
// Release a block. Prioritize the block for eviction if toFront=true
virtual void releaseBlock(BlockPtr block, bool toFront = false) = 0;
// Get the amount of free blocks in the primary memory pool
virtual SizeType32 getNumFreePrimaryBlocks() = 0;
// Get the amount of free blocks in the secondary memory pool
virtual SizeType32 getNumFreeSecondaryBlocks() = 0;
// Claim a free block. Called when the cache manager allocates or reuses a new block
virtual void claimBlock(KVCacheBlock block) = 0;
};
class LRUEvictionPolicy : public BaseEvictionPolicy
{
public:
void initialize(
std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks) override;
BlockPtr getFreePrimaryBlock() override;
BlockPtr getFreeSecondaryBlock() override;
void releaseBlock(BlockPtr block, bool toFront = false) override;
SizeType32 getNumFreePrimaryBlocks() override;
SizeType32 getNumFreeSecondaryBlocks() override;
void claimBlock(KVCacheBlock block);
private:
FreeBlocksQueue mFreePrimaryBlocks;
FreeBlocksQueue mFreeSecondaryBlocks;
std::vector<std::optional<FreeBlocksQueue::iterator>> mFreeBlockIterators;
SizeType32 mFreePrimaryBlocksSize;
SizeType32 mFreeSecondaryBlocksSize;
};
} // namespace tensorrt_llm::batch_manager::eviction_policy
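A hypothetical call sequence against the `BaseEvictionPolicy` interface above, only to show how the pieces fit together; the function names and include path are assumptions, and the real block bookkeeping (offloading between pools, onboarding, etc.) lives in the cache manager.

```cpp
#include "tensorrt_llm/batch_manager/evictionPolicy.h"

#include <utility>

using namespace tensorrt_llm::batch_manager::eviction_policy;

// Hypothetical: take a free block (preferring primary/GPU memory) and mark it as in use.
BlockPtr acquireBlock(BaseEvictionPolicy& policy)
{
    auto block = policy.getNumFreePrimaryBlocks() > 0 ? policy.getFreePrimaryBlock()
                                                      : policy.getFreeSecondaryBlock();
    policy.claimBlock(*block); // claimed blocks are no longer eviction candidates
    return block;
}

// Hypothetical: hand a block back; toFront=true makes it the preferred eviction candidate.
void returnBlock(BaseEvictionPolicy& policy, BlockPtr block, bool evictSoon)
{
    policy.releaseBlock(std::move(block), /*toFront=*/evictSoon);
}
```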


@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@ -29,13 +30,18 @@
#include <NvInferRuntime.h>
#include <cstdint>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>
namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy;
}
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
@ -124,6 +130,8 @@ public:
[[nodiscard]] IdType getBlockId() const;
[[nodiscard]] NextBlockMap getNextBlocks() const;
[[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;
[[nodiscard]] bool isPrimary() const;
@ -144,22 +152,12 @@ public:
[[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;
void setFreeBlockIterator(FreeBlocksQueue::iterator freeBlockIterator);
void resetFreeBlockIterator();
[[nodiscard]] std::optional<FreeBlocksQueue::iterator> const& getFreeBlockIterator() const;
void setPrevBlock(BlockPtr prevBlock);
void addNextBlock(BlockKey const& blockKey, BlockPtr block);
void removeNextBlock(BlockKey const& blockKey);
static std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree(std::shared_ptr<KVCacheBlock> searchStart);
static std::shared_ptr<KVCacheBlock> findLeafBlock(std::shared_ptr<KVCacheBlock> searchStart);
[[nodiscard]] BlockPtr findMatchingBlock(BlockKey const& blockKey) const;
//! \brief Free block from previous block if present.
@ -203,14 +201,21 @@ class GenerationRequest
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using SharedPtr = std::shared_ptr<GenerationRequest>;
explicit GenerationRequest(SizeType32 seqSlotIdx, SizeType32 numTokens, SizeType32 beamWidth)
: mSeqSlotIdx(seqSlotIdx)
explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
SizeType32 maxBlocks, SizeType32 numPools = 1)
: mRequestId(requestId)
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
, mCacheBlockIds(beamWidth)
, mCacheBlockIndices{
runtime::BufferManager::cpu(runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value)}
{
auto cacheBlockIdsRange = runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices);
std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
tensorrt_llm::kernels::KVCacheIndex{
std::numeric_limits<tensorrt_llm::kernels::KVCacheIndex::UnderlyingType>::max()});
}
void addNewTokens(SizeType32 n)
@ -225,9 +230,9 @@ public:
mNumTokens -= n;
}
[[nodiscard]] SizeType32 getSequenceSlotIdx() const
[[nodiscard]] LlmRequest::RequestIdType getRequestId() const
{
return mSeqSlotIdx;
return mRequestId;
}
[[nodiscard]] SizeType32 getNumTokens() const
@ -245,6 +250,16 @@ public:
return mCacheBlockIds;
}
[[nodiscard]] runtime::ITensor& getCacheBlockIndices()
{
return *mCacheBlockIndices;
}
[[nodiscard]] runtime::ITensor const& getCacheBlockIndices() const
{
return *mCacheBlockIndices;
}
void addCacheBlock(SizeType32 beamIdx, KVCacheBlock::IdType blockId)
{
mCacheBlockIds.at(beamIdx).push_back(blockId);
@ -272,14 +287,16 @@ public:
}
private:
// Slot id of the sequence
SizeType32 mSeqSlotIdx;
// Request id of the sequence
LlmRequest::RequestIdType mRequestId;
// Current number of generated tokens
SizeType32 mNumTokens;
// Number of beams
SizeType32 mBeamWidth;
// List of blocks allocated for each beam of the sequence
// List of block ids allocated for each beam of the sequence
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
// Tensor of block indices allocated for each beam of the sequence
runtime::ITensor::SharedPtr mCacheBlockIndices;
};
// attach metadata to a pool pointer
@ -315,17 +332,19 @@ public:
// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
// BlockManager maintains a vector of lists of request ids to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class BlockManager
{
public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;
explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF);
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
CacheType cacheType = CacheType::kSELF);
~BlockManager();
@ -340,10 +359,6 @@ public:
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);
//! \brief Release block, which puts it back onto free blocks queue.
//! \details Block appended by default, will be put at front if toFront is true.
void releaseBlock(std::shared_ptr<KVCacheBlock> block, bool toFront = false);
//! \brief Allocate new block for each beam of the sequence.
//! \details Might free cached blocks if no free blocks are available.
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
@ -359,10 +374,7 @@ public:
//! \brief Release last block in the sequence
void releaseLastBlock(GenerationRequest& sequence);
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept
{
return mFreePrimaryBlocksSize;
}
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
{
@ -467,6 +479,11 @@ public:
BlockKey findNewContextBlock(
VecUniqueTokens const& uniqueTokens, std::shared_ptr<LlmRequest> const& llmRequest) const;
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
{
return mBufferManager;
}
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@ -486,17 +503,9 @@ private:
SizeType32 loadOrAllocateBlocks(
std::list<BlockKey> const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence);
//! \brief Find best primary block to free.
//! \details The best primary block to free is the primary block that appears first in the queue and have no primary
//! block descendants
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree();
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock();
//! \brief Claim block if it is in free blocks list.
void claimBlock(KVCacheBlock& block);
//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(KVCacheBlock& block);
@ -511,15 +520,9 @@ private:
// Number of blocks in pools
SizeType32 mNumPrimaryBlocks;
SizeType32 mNumSecondaryBlocks;
// List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory.
// We maintain separate queues for these.
// We cache size of each queue instead of calling std::list::size, because size is O(N) function.
SizeType32 mFreePrimaryBlocksSize;
SizeType32 mFreeSecondaryBlocksSize;
FreeBlocksQueue mFreePrimaryBlocks;
FreeBlocksQueue mFreeSecondaryBlocks;
// List of allocated blocks for each sequences
std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
std::unordered_map<LlmRequest::RequestIdType, std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
// Pool per unique numKvHeads in the model
std::vector<KVCacheBlockPool> mPools;
@ -547,6 +550,8 @@ private:
std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks;
// KV cache type (self or cross)
CacheType mCacheType;
// Eviction Policy
std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;
private:
friend class KVCacheManager;
@ -555,8 +560,9 @@ private:
class KVCacheManager
{
public:
friend class KVCacheManagerBindings;
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using SequencesPtr = GenerationRequest::SharedPtr;
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
@ -647,10 +653,10 @@ public:
/// @return The number of blocks
[[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const;
void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens);
void addContextTokens(LlmRequest::RequestIdType requestId, SizeType32 numTokens);
/// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
void addToken(SizeType32 seqSlotIdx);
/// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed.
void addToken(LlmRequest::RequestIdType requestId);
/// @brief Add new request to the KV cache manager.
/// @param inputLength Input length for which KV cache need to be allocated.
@ -658,23 +664,29 @@ public:
/// @param llmRequest Optional request to use for KV cache lookup.
/// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
/// inputLength - 1 tokens and populate prepopulatedPromptLen.
void addSequence(SizeType32 seqSlotIdx, SizeType32 inputLength, SizeType32 beamWidth,
void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
void removeSequence(SizeType32 seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
void removeSequence(LlmRequest::RequestIdType requestId, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
void schedulingRemoveSequence(SizeType32 seqSlotIdx);
void schedulingRemoveSequence(LlmRequest::RequestIdType requestId);
[[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const;
[[nodiscard]] runtime::ITensor::SharedPtr getBlockPoolPointers() const
{
return mBlockPoolPointers;
}
[[nodiscard]] runtime::ITensor::UniquePtr getLayerToPoolMapping() const;
[[nodiscard]] runtime::ITensor::SharedPtr getLayerToPoolMapping() const
{
return mLayerToPoolMapping;
}
void getBlockOffsetsOfBatch(
runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const;
//! @return maxBlockCount of all beams
SizeType32 copyBlockOffsets(
runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
runtime::ITensor& output, SizeType32 outputSlotOffset, LlmRequest::RequestIdType requestId) const;
// Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool
[[nodiscard]] static SizeType32 calculateCacheSizePerToken(
@ -697,10 +709,10 @@ public:
return mEnableBlockReuse;
}
void removeToken(SizeType32 seqSlotIdx);
void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths);
void removeToken(LlmRequest::RequestIdType requestId);
void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths);
[[nodiscard]] GenerationRequest const& getSequence(SizeType32 seqSlotIdx) const;
[[nodiscard]] GenerationRequest const& getSequence(LlmRequest::RequestIdType requestId) const;
[[nodiscard]] bool isCrossKv() const
{
@ -714,7 +726,7 @@ public:
//! \brief Store full context blocks contributed by llmRequest.
//! \details These blocks become reusable from next step.
void storeContextBlocks(SizeType32 seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest);
void storeContextBlocks(std::shared_ptr<LlmRequest> const& llmRequest);
[[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);
@ -722,14 +734,13 @@ public:
SizeType32 tokensPerBlock, SizeType32 maxBeamWidth, SizeType32 sinkTokenLen, bool useOneMoreBlock);
private:
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 seqSlotIdx,
SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
void resetBlockOffsets(SizeType32 seqSlotIdx, SizeType32 beamWidth);
void cacheBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
void cacheNewBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
void updateNewBlockPointer(GenerationRequest const& seq, SizeType32 seqSlotIdx, SizeType32 blockIdx);
void updateToken(SizeType32 seqSlotIdx, bool addToken);
void cacheBlockOffsets(GenerationRequest& seq);
void cacheNewBlockOffsets(GenerationRequest& seq);
void updateNewBlockPointer(GenerationRequest& seq, SizeType32 blockIdx);
void updateToken(GenerationRequest& sequence, bool addToken);
private:
// Maximum number of sequences
@ -749,12 +760,13 @@ private:
SizeType32 mSinkBlockTokenLength;
// Block manager
BlockManager mBlockManager;
// List of all sequences
std::vector<SequencesPtr> mSequences;
// buffer for block indices for all managed sequences
runtime::ITensor::SharedPtr mSequenceBlockIndices;
// Map of all sequences
std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences;
// Whether to cache KV pages for reuse
bool mEnableBlockReuse;
// buffers for static tensors, will be created after allocating pools
runtime::ITensor::SharedPtr mBlockPoolPointers;
runtime::ITensor::SharedPtr mLayerToPoolMapping;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager


@ -65,6 +65,11 @@ public:
return ret;
}
operator runtime::ITensor::SharedPtr()
{
return mCurrent;
}
[[nodiscard]] bool operator==(BlockIterator const& other) const
{
return mIdx == other.mIdx && mPool.get() == other.mPool.get();


@ -55,6 +55,7 @@ enum class LlmRequestState : int32_t
/// Waiting context-only request transmitting the kv cache
kDISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< For disaggregated serving only: transmitting the kv cache
kWAITING_TO_SEND_LOGITS = 10, ///< Generation phase completed, logits not sent yet
};
enum LlmRequestType
@ -132,8 +133,7 @@ public:
, mLoraWeights(std::move(loraWeights))
, mLoraConfig(std::move(loraConfig))
, mLookaheadConfig(std::move(lookaheadConfig))
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mContextChunkSize{mPromptLen}
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
@ -186,8 +186,7 @@ public:
, mLoraWeights(std::nullopt)
, mLoraConfig(std::nullopt)
, mLookaheadConfig(std::nullopt)
, mContextChunkSize(std::nullopt)
, mContextCurrentPosition(0)
, mContextChunkSize{mPromptLen}
, mLogProbs(mSamplingConfig.beamWidth)
, mCumLogProbs(mSamplingConfig.beamWidth)
, mDraftTokens(std::make_shared<VecTokens>())
@ -392,6 +391,15 @@ public:
mMaxNewTokens = maxNewTokens;
}
if (mNumReturnSequences > 1 && mSamplingConfig.beamWidth > 1)
{
TLLM_THROW(
"Using mNumReturnSequences (%d) > 1 with beam search is currently disabled, since TensorRT-LLM returns "
"a total of mNumReturnSequences x beamWidth beams, rather than limiting the number of returned beams "
"to mNumReturnSequences. This restriction will be removed once the issue is resolved.",
mNumReturnSequences);
}
TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Incorrect sampling config");
// validate extra ids when enabling kv cache reuse with prompt table
@ -722,7 +730,7 @@ public:
mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
: LlmRequestState::kCONTEXT_INIT;
mContextCurrentPosition = 0;
mContextChunkSize = std::nullopt;
mContextChunkSize = mPromptLen;
mSeqSlot.reset();
}
@ -869,34 +877,33 @@ public:
return mPromptLen;
}
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
{
return mPrepopulatedPromptLen;
}
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
{
auto const promptLen = getPromptLen();
TLLM_CHECK(prepopulatedPromptLen < promptLen);
mPrepopulatedPromptLen = prepopulatedPromptLen;
if (prepopulatedPromptLen > 0)
{
// Currently, the runtime process is to apply for cache first and then determine prepopulation.
// Use the prepopulated length to advance the context position and decrease chunk size if necessary.
if (isFullContextRequest())
auto chunkSize = getContextChunkSize();
if (prepopulatedPromptLen + chunkSize < promptLen)
{
setContextCurrentPosition(prepopulatedPromptLen);
setContextChunkSize(promptLen);
}
else
{
auto chunkSize = getContextChunkSize();
if (prepopulatedPromptLen + chunkSize < promptLen)
{
// make sure to end at block boundary after current chunk
auto const flooredEndPosition
= (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
chunkSize = flooredEndPosition - prepopulatedPromptLen;
TLLM_CHECK(chunkSize <= getContextChunkSize());
}
setContextCurrentPosition(prepopulatedPromptLen);
setContextChunkSize(chunkSize);
// make sure to end at block boundary after current chunk
auto const flooredEndPosition
= (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
chunkSize = flooredEndPosition - prepopulatedPromptLen;
TLLM_CHECK(chunkSize <= getContextChunkSize());
}
setContextCurrentPosition(prepopulatedPromptLen);
setContextChunkSize(chunkSize);
if (!isLastContextChunk())
{
TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
@ -1176,6 +1183,11 @@ public:
return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
}
[[nodiscard]] bool isCompleteWaitingToSendLogits() const noexcept
{
return mState == LlmRequestState::kWAITING_TO_SEND_LOGITS;
}
/// To determine whether the context is unchunked. When a context is chunked into only a part, it
/// is still different from the unchunked state, which indicates the initial status.
[[nodiscard]] bool isFullContextRequest() const noexcept
@ -1211,12 +1223,11 @@ public:
return mPromptLen - getContextCurrentPosition();
}
/// To retrieve the context chunk size, throw an exception when the context is not chunked.
[[nodiscard]] SizeType32 getContextChunkSize() const
{
TLLM_CHECK_WITH_INFO(
isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
return mContextChunkSize.value();
TLLM_CHECK_WITH_INFO(isContextInitState() || isDisaggGenerationInitState(),
"getContextChunkSize is only possible during the context phase.");
return mContextChunkSize;
}
/// To set the context chunk size, throw an exception when the chunk size is negative. If the chunk
@ -1224,45 +1235,34 @@ public:
/// remaining length.
void setContextChunkSize(SizeType32 size)
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
TLLM_CHECK_WITH_INFO(isContextInitState(), "setContextChunkSize is only possible during the context phase.");
TLLM_CHECK_WITH_INFO(size >= 0, "The chunk size of context (%d) can't be negative.", size);
mContextChunkSize = std::min(size, getContextRemainingLength());
}
/// Determines whether the current position is only one chunk away from the end of the context.
/// It will return true when the context is not chunked.
[[nodiscard]] bool isLastContextChunk() const noexcept
{
return isFullContextRequest()
|| (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
return isDisaggGenerationInitState() || getContextCurrentPosition() + getContextChunkSize() == mPromptLen;
}
/// Returns whether the position is at the beginning of the context. It will return true when the
/// context is not chunked.
/// Returns whether the position is at the beginning of the context.
[[nodiscard]] bool isFirstContextChunk() const noexcept
{
return isFullContextRequest() || getContextCurrentPosition() == 0;
}
[[nodiscard]] executor::PriorityType priority() const noexcept
{
return mPriority;
return getContextCurrentPosition() == 0;
}
/// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
void moveToNextContextChunk()
{
TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
if (mContextChunkSize)
{
mContextCurrentPosition += getContextChunkSize();
setContextChunkSize(0);
}
else
{
TLLM_CHECK_WITH_INFO(mContextCurrentPosition == 0, "Full context out of bounds.");
mContextCurrentPosition = mPromptLen;
}
mContextCurrentPosition += getContextChunkSize();
setContextChunkSize(0);
}
[[nodiscard]] executor::PriorityType priority() const noexcept
{
return mPriority;
}
/// Increment the counter of decoding iterations.
@ -1282,20 +1282,24 @@ public:
return static_cast<float>(getMaxNumGeneratedTokens()) / mDecodingIter;
}
[[nodiscard]] bool isFinished() const noexcept
{
return isGenerationCompleteState() || isDisaggContextTransmissionState() || isCompleteWaitingToSendLogits();
}
/// @brief Create a Response from the current state of the request
/// @return An optional Response
std::optional<executor::Response> createResponse()
std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0)
{
TLLM_CHECK(!isDisaggContextCompleteState());
if (isGenerationCompleteState() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)
|| isDisaggContextTransmissionState())
if (isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS))
{
TLLM_LOG_DEBUG("Creating response for request %lu", mRequestId);
executor::Result result;
result.sequenceIndex = mSequenceIndex;
result.isSequenceFinal = isGenerationCompleteState() || isDisaggContextTransmissionState();
result.isSequenceFinal = isFinished();
mSequenceFinalVec->at(mSequenceIndex) = result.isSequenceFinal;
result.isFinal = std::all_of(mSequenceFinalVec->begin(), mSequenceFinalVec->end(),
@ -1333,8 +1337,7 @@ public:
auto const startTokenPos = maxNbTokens - maxNbTokensOut;
auto const shouldSendResponse = isGenerationCompleteState()
|| (mIsStreaming && maxNbTokens > getMaxSentTokenLen()) || isDisaggContextTransmissionState();
auto const shouldSendResponse = isFinished() || (mIsStreaming && maxNbTokens > getMaxSentTokenLen());
if (!shouldSendResponse)
{
@ -1374,6 +1377,11 @@ public:
= runtime::ITensor::slice(getGenerationLogitsHost(), startGenTokenPos, maxNbTokensOut);
result.generationLogits = executor::detail::ofITensor(generationLogitsHostCurrentStep);
}
else if (useFastLogits)
{
result.specDecFastLogitsInfo
= executor::SpeculativeDecodingFastLogitsInfo{mRequestId, mpiWorldRank};
}
else
{
result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
@ -1392,7 +1400,7 @@ public:
setMaxSentTokenLen(maxNbTokens);
auto requestId = isChild() ? mParentRequestId : mRequestId;
auto response = executor::Response(requestId, std::move(result));
auto response = executor::Response(requestId, std::move(result), mClientId);
return response;
}
@ -1483,8 +1491,8 @@ protected:
// To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
// the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
// of null value is that the context is not chunked.
std::optional<SizeType32> mContextChunkSize;
SizeType32 mContextCurrentPosition;
SizeType32 mContextChunkSize{0};
SizeType32 mContextCurrentPosition{0};
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
@ -1636,6 +1644,8 @@ private:
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
{
friend class LlmRequestBindings;
public:
using Base = GenericLlmRequest<runtime::ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;


@ -0,0 +1,108 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::batch_manager
{
namespace batch_scheduler
{
struct ContextChunkingConfig
{
ContextChunkingConfig() = default;
executor::ContextChunkingPolicy chunkingPolicy;
/// The minimum size, also known as the chunk unit size. It generally
/// needs to be equal to the size of the kv cache block or its integer
/// multiples (except for the last context chunk) to avoid fragmentation.
/// When set to null, it indicates that the context chunk is disabled.
tensorrt_llm::runtime::SizeType32 chunkUnitSize;
};
} // namespace batch_scheduler
/// @brief This scheduler takes into account the desired batch size and limits of the TRT engine to schedule requests.
class MicroBatchScheduler : Algorithm
{
public:
constexpr static auto name{"MicroBatchScheduler"};
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using ContextChunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy;
MicroBatchScheduler() = default;
explicit MicroBatchScheduler(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
std::optional<SizeType32> maxContextLength = std::nullopt,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
static MicroBatchScheduler make(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
std::optional<SizeType32> maxContextLength = std::nullopt,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
{
return MicroBatchScheduler{
maxBatchSize, maxNumTokens, ctxChunkConfig, maxContextLength, noScheduleUntilState, noScheduleAfterState};
}
std::tuple<RequestVector, RequestVector> operator()(
RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds);
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
std::optional<SizeType32> const& maxContextLength);
private:
template <ContextChunkingPolicy tPolicy>
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked,
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
std::optional<SizeType32> const& maxContextLength);
/// After the chunk sizes have been determined, this function will discard
/// any draft tokens that don't fit.
static void fitDraftTokens(RequestVector const& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);
/// The maximum number of requests returned by scheduleRequests
SizeType32 mMaxBatchSize;
/// The maximum number of tokens to include in a batch
std::optional<SizeType32> mMaxNumTokens;
/// The maximum length of the context. If the context exceeds this length,
/// it must be chunked, otherwise it cannot be processed. Therefore, it
/// needs to be set together with the chunk unit size to make sense.
/// When set to null, it indicates that context length is unlimited.
std::optional<SizeType32> mMaxContextLength;
std::optional<batch_scheduler::ContextChunkingConfig> mCtxChunkConfig;
/// The state until/after which the scheduler should not schedule requests
LlmRequestState mNoScheduleUntilState;
LlmRequestState mNoScheduleAfterState;
};
} // namespace tensorrt_llm::batch_manager
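A rough sketch, based only on the declarations above, of how the micro batch scheduler might be wired up; the include path, the `kFIRST_COME_FIRST_SERVED` enumerator spelling, and the concrete numbers (batch size, token budget, chunk unit size) are assumptions.

```cpp
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"

#include <optional>
#include <tuple>

namespace tb = tensorrt_llm::batch_manager;

// Hypothetical wiring: chunkUnitSize would typically match the KV-cache block size.
std::tuple<tb::RequestVector, tb::RequestVector> scheduleMicroBatch(
    tb::RequestVector const& activeRequests, tb::ReqIdsSet const& inflightReqIds)
{
    tb::batch_scheduler::ContextChunkingConfig chunkConfig;
    chunkConfig.chunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED;
    chunkConfig.chunkUnitSize = 64;

    auto scheduler = tb::MicroBatchScheduler::make(/*maxBatchSize=*/8, /*maxNumTokens=*/8192, chunkConfig,
        /*maxContextLength=*/std::nullopt);

    // Returns the two request vectors selected for this micro batch, skipping requests already in flight.
    return scheduler(activeRequests, inflightReqIds);
}
```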


@ -51,6 +51,8 @@ public:
class BasePeftCacheManager
{
public:
friend class BasePeftCacheManagerBindings;
using LlmRequestPtr = std::shared_ptr<LlmRequest>;
using RequestVector = std::vector<LlmRequestPtr>;
using PeftTable = std::map<uint64_t, std::shared_ptr<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>>;


@ -46,7 +46,9 @@ public:
executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{},
executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig
= executor::ExtendedRuntimePerfKnobConfig{},
std::optional<executor::DebugConfig> debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000)
std::optional<executor::DebugConfig> debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000,
std::optional<executor::SpeculativeDecodingConfig> specDecConfig = std::nullopt,
bool isLeaderInOrchMode = false)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
@ -62,10 +64,12 @@ public:
, extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)
, debugConfig{std::move(debugConfig)}
, maxSeqIdleMicroseconds{maxSeqIdleMicroseconds}
, speculativeDecodingConfig{std::move(specDecConfig)}
, isLeaderInOrchMode{isLeaderInOrchMode}
{
}
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode)
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
@ -74,7 +78,7 @@ public:
executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(),
executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig(),
executorConfig.getMaxSeqIdleMicroseconds())
executorConfig.getMaxSeqIdleMicroseconds(), executorConfig.getSpecDecConfig(), isLeaderInOrchMode)
{
}
@ -94,6 +98,8 @@ public:
&& extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig //
&& debugConfig == other.debugConfig //
&& maxSeqIdleMicroseconds == other.maxSeqIdleMicroseconds //
&& speculativeDecodingConfig == other.speculativeDecodingConfig //
&& isLeaderInOrchMode == other.isLeaderInOrchMode //
;
}
@ -117,6 +123,9 @@ public:
std::optional<executor::DebugConfig> debugConfig;
// Sequence is considered idle if not updated for this amount of time.
uint64_t maxSeqIdleMicroseconds;
std::optional<executor::SpeculativeDecodingConfig> speculativeDecodingConfig;
// This rank is the leader worker in orchestrator mode
bool isLeaderInOrchMode;
};
} // namespace tensorrt_llm::batch_manager


@ -0,0 +1,32 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace tensorrt_llm
{
// Base class for algorithms
struct Algorithm
{
Algorithm() = default;
Algorithm(Algorithm&&) = default;
Algorithm& operator=(Algorithm&&) = default;
Algorithm(Algorithm const&) = delete;
Algorithm& operator=(Algorithm const&) = delete;
};
} // namespace tensorrt_llm


@ -99,7 +99,6 @@ struct MpiTypeConverter<std::byte>
};
template <>
struct MpiTypeConverter<half>
{
@ -387,6 +386,7 @@ public:
void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
bool improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
//! \brief Returns if a message with the specified source and tag is available
bool iprobe(int source, int tag, MPI_Status* status) const;


@ -186,11 +186,13 @@ class ExternalDraftTokensConfig
{
public:
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
std::optional<FloatType> const& acceptanceThreshold = std::nullopt,
std::optional<bool> const& fastLogits = std::nullopt);
[[nodiscard]] VecTokens getTokens() const;
[[nodiscard]] std::optional<Tensor> getLogits() const;
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
[[nodiscard]] std::optional<bool> getFastLogits() const;
private:
friend class Serialization;
@ -200,6 +202,8 @@ private:
std::optional<Tensor> mLogits;
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
std::optional<FloatType> mAcceptanceThreshold;
/// @brief Use direct transfer for draft logits
std::optional<bool> mFastLogits;
};
/// @brief Configuration for prompt tuning
@ -318,6 +322,18 @@ private:
StatePtr mState{nullptr, deleter};
};
/// @brief Configuration for speculative decoding (both draft and target models)
class SpeculativeDecodingConfig
{
public:
explicit SpeculativeDecodingConfig(bool fastLogits);
bool operator==(SpeculativeDecodingConfig const& other) const;
/// @brief Send logits tensor directly from draft to target model.
bool fastLogits;
};
/// @brief A class that holds information about the request
class Request
{
@ -437,6 +453,16 @@ private:
std::unique_ptr<Impl> mImpl;
};
/// @brief Struct that holds the logits information when using direct transfer
struct SpeculativeDecodingFastLogitsInfo
{
/// @brief Draft request id
uint64_t draftRequestId;
/// @brief MPI world rank of the draft model leader
int32_t draftParticipantId;
};
/// @brief Struct that holds the generation result
struct Result
{
@ -455,11 +481,14 @@ struct Result
/// @brief The context logits. Size [promptLen, vocabSizePadded]
std::optional<Tensor> contextLogits;
/// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
/// @brief The generation logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
/// or [maxNewTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens)
/// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens)
std::optional<Tensor> generationLogits;
/// @brief Logits information for direct transfer when using fast logits
std::optional<SpeculativeDecodingFastLogitsInfo> specDecFastLogitsInfo;
/// @brief The encoder output. Size [encoderLen, hiddenSize]
std::optional<Tensor> encoderOutput;
@ -484,8 +513,8 @@ struct Result
class Response
{
public:
Response(IdType requestId, std::string errorMsg);
Response(IdType requestId, Result Result);
Response(IdType requestId, std::string errorMsg, std::optional<IdType> clientId = std::nullopt);
Response(IdType requestId, Result Result, std::optional<IdType> clientId = std::nullopt);
~Response();
Response(Response const& other);
@ -496,6 +525,9 @@ public:
/// @brief Get the id of the request for which this response was generated
[[nodiscard]] IdType getRequestId() const;
/// @brief Get the client id of the request for which this response was generated
[[nodiscard]] std::optional<IdType> getClientId() const;
/// @brief Indicates if this response has an error or not
[[nodiscard]] bool hasError() const;
@ -873,7 +905,8 @@ public:
std::optional<SizeType32> maxQueueSize = std::nullopt,
ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(),
std::optional<DebugConfig> debugConfig = std::nullopt, SizeType32 recvPollPeriodMs = 0,
uint64_t maxSeqIdleMicroseconds = 180000000);
uint64_t maxSeqIdleMicroseconds = 180000000,
std::optional<SpeculativeDecodingConfig> specDecConfig = std::nullopt);
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
@ -895,6 +928,7 @@ public:
[[nodiscard]] std::optional<DebugConfig> getDebugConfig() const;
[[nodiscard]] SizeType32 getRecvPollPeriodMs() const;
[[nodiscard]] uint64_t getMaxSeqIdleMicroseconds() const;
[[nodiscard]] std::optional<SpeculativeDecodingConfig> getSpecDecConfig() const;
void setMaxBeamWidth(SizeType32 maxBeamWidth);
void setMaxBatchSize(SizeType32 maxBatchSize);
@ -916,6 +950,7 @@ public:
void setDebugConfig(DebugConfig const& debugConfig);
void setRecvPollPeriodMs(SizeType32 const& recvPollPeriodMs);
void setMaxSeqIdleMicroseconds(uint64_t maxNumTokens);
void setSpecDecConfig(SpeculativeDecodingConfig const& specDecConfig);
private:
friend class Serialization;
@ -978,6 +1013,9 @@ private:
/// @brief The maximum time in microseconds a scheduled request can remain idle before getting terminated. Default
/// is 3 minutes.
uint64_t mMaxSeqIdleMicroseconds;
/// @brief The speculative decoding configuration
std::optional<SpeculativeDecodingConfig> mSpeculativeDecodingConfig;
};
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
@ -1080,6 +1118,9 @@ public:
/// @brief Indicates if the current process is allowed to enqueueRequests
[[nodiscard]] bool canEnqueueRequests() const;
/// @brief Indicates if the current process participates in this executor instance
[[nodiscard]] bool isParticipant() const;
private:
class Impl;
std::unique_ptr<Impl> mImpl;


@ -95,6 +95,11 @@ public:
static void serialize(Tensor const& tensor, std::ostream& os);
[[nodiscard]] static size_t serializedSize(Tensor const& tensor);
// SpeculativeDecodingFastLogitsInfo
[[nodiscard]] static SpeculativeDecodingFastLogitsInfo deserializeSpecDecFastLogitsInfo(std::istream& is);
static void serialize(SpeculativeDecodingFastLogitsInfo const& info, std::ostream& os);
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingFastLogitsInfo const& info);
// Result
[[nodiscard]] static Result deserializeResult(std::istream& is);
static void serialize(Result const& result, std::ostream& os);


@ -446,6 +446,11 @@ public:
return DecodingMode{kExplicitDraftTokens | kStandardStopCriteria | kUseExplicitEosStop};
}
static auto constexpr ExternalDraftTokens()
{
return DecodingMode{kExternalDraftTokens | kUsePenalties | kUseBanTokens | kStandardStopCriteria};
}
auto constexpr useTemperature(bool useTemp)
{
mState = setBitTo(kUseTemperature, useTemp);
@ -563,6 +568,11 @@ public:
return anyBitSet(kExplicitDraftTokens);
}
[[nodiscard]] bool constexpr isExternalDraftTokens() const
{
return anyBitSet(kExternalDraftTokens);
}
[[nodiscard]] bool constexpr isUseTemperature() const
{
return anyBitSet(kUseTemperature);
@ -676,6 +686,7 @@ private:
static UnderlyingType constexpr kMedusa{1u << (kNumFlags + 4)};
static UnderlyingType constexpr kLookahead{1u << (kNumFlags + 5)};
static UnderlyingType constexpr kExplicitDraftTokens{1u << (kNumFlags + 6)};
static UnderlyingType constexpr kExternalDraftTokens{1u << (kNumFlags + 7)};
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
@ -706,6 +717,7 @@ static_assert(!DecodingMode::Auto().isBeamSearch());
static_assert(!DecodingMode::Auto().isMedusa());
static_assert(!DecodingMode::Auto().isLookahead());
static_assert(!DecodingMode::Auto().isExplicitDraftTokens());
static_assert(!DecodingMode::Auto().isExternalDraftTokens());
static_assert(DecodingMode::TopK().isTopK());
static_assert(DecodingMode::TopK().isTopKorTopP());
@ -726,6 +738,7 @@ static_assert(!DecodingMode::TopK().isBeamSearch());
static_assert(!DecodingMode::TopK().isMedusa());
static_assert(!DecodingMode::TopK().isLookahead());
static_assert(!DecodingMode::TopK().isExplicitDraftTokens());
static_assert(!DecodingMode::TopK().isExternalDraftTokens());
static_assert(DecodingMode::TopP().isTopP());
static_assert(DecodingMode::TopP().isTopKorTopP());
@ -739,6 +752,7 @@ static_assert(!DecodingMode::TopP().isBeamSearch());
static_assert(!DecodingMode::TopP().isMedusa());
static_assert(!DecodingMode::TopP().isLookahead());
static_assert(!DecodingMode::TopP().isExplicitDraftTokens());
static_assert(!DecodingMode::TopP().isExternalDraftTokens());
static_assert(DecodingMode::TopKTopP().isTopK());
static_assert(DecodingMode::TopKTopP().isTopP());
@ -752,6 +766,7 @@ static_assert(!DecodingMode::TopKTopP().isBeamSearch());
static_assert(!DecodingMode::TopKTopP().isMedusa());
static_assert(!DecodingMode::TopKTopP().isLookahead());
static_assert(!DecodingMode::TopKTopP().isExplicitDraftTokens());
static_assert(!DecodingMode::TopKTopP().isExternalDraftTokens());
static_assert(DecodingMode::BeamSearch().isBeamSearch());
static_assert(DecodingMode::BeamSearch().isUseStopCriteria());
@ -760,6 +775,7 @@ static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
static_assert(!DecodingMode::BeamSearch().isMedusa());
static_assert(!DecodingMode::BeamSearch().isLookahead());
static_assert(!DecodingMode::BeamSearch().isExplicitDraftTokens());
static_assert(!DecodingMode::BeamSearch().isExternalDraftTokens());
static_assert(!DecodingMode::Medusa().isAuto());
static_assert(!DecodingMode::Medusa().isTopK());
@ -775,6 +791,7 @@ static_assert(DecodingMode::Medusa().isUseStopCriteria());
static_assert(DecodingMode::Medusa().isUsePenalty());
static_assert(DecodingMode::Medusa().isUseMinLength());
static_assert(DecodingMode::Medusa().isMedusa());
static_assert(!DecodingMode::Medusa().isExternalDraftTokens());
static_assert(!DecodingMode::Lookahead().isAuto());
static_assert(!DecodingMode::Lookahead().isTopK());
@ -788,6 +805,7 @@ static_assert(DecodingMode::Lookahead().isUseStopCriteria());
static_assert(DecodingMode::Lookahead().isUseStopWords());
static_assert(DecodingMode::Lookahead().isUseExplicitEosStop());
static_assert(DecodingMode::Lookahead().isLookahead());
static_assert(!DecodingMode::Lookahead().isExternalDraftTokens());
static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
static_assert(!DecodingMode::ExplicitDraftTokens().isTopK());
@ -801,4 +819,19 @@ static_assert(!DecodingMode::ExplicitDraftTokens().isUsePenalty());
static_assert(DecodingMode::ExplicitDraftTokens().isUseStopCriteria());
static_assert(!DecodingMode::ExplicitDraftTokens().isUseBanWords());
static_assert(DecodingMode::ExplicitDraftTokens().isExplicitDraftTokens());
static_assert(!DecodingMode::ExplicitDraftTokens().isExternalDraftTokens());
static_assert(!DecodingMode::ExternalDraftTokens().isTopK());
static_assert(!DecodingMode::ExternalDraftTokens().isTopP());
static_assert(!DecodingMode::ExternalDraftTokens().isTopKorTopP());
static_assert(!DecodingMode::ExternalDraftTokens().isTopKandTopP());
static_assert(DecodingMode::ExternalDraftTokens().isUseBanWords());
static_assert(DecodingMode::ExternalDraftTokens().isUseOccurrencePenalty());
static_assert(DecodingMode::ExternalDraftTokens().isUseStopCriteria());
static_assert(!DecodingMode::ExternalDraftTokens().isAuto());
static_assert(!DecodingMode::ExternalDraftTokens().isBeamSearch());
static_assert(!DecodingMode::ExternalDraftTokens().isMedusa());
static_assert(!DecodingMode::ExternalDraftTokens().isLookahead());
static_assert(!DecodingMode::ExternalDraftTokens().isExplicitDraftTokens());
static_assert(DecodingMode::ExternalDraftTokens().isExternalDraftTokens());
} // namespace tensorrt_llm::executor
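
The new ExternalDraftTokens mode reuses the constexpr bit-flag pattern that the other decoding modes already follow: each mode is a bitmask, and the is*() queries test individual bits at compile time. A standalone, simplified sketch of that pattern (names and flag values here are illustrative, not the library's own):

```cpp
// Minimal standalone illustration of the constexpr bit-flag pattern used by
// DecodingMode above; simplified names, not library code.
#include <cstdint>

class MiniMode
{
public:
    static constexpr MiniMode ExternalDraftTokens()
    {
        return MiniMode{kExternalDraftTokens | kStandardStopCriteria};
    }

    [[nodiscard]] constexpr bool isExternalDraftTokens() const
    {
        return anyBitSet(kExternalDraftTokens);
    }

    [[nodiscard]] constexpr bool isTopK() const
    {
        return anyBitSet(kTopK);
    }

private:
    using UnderlyingType = std::uint32_t;

    constexpr explicit MiniMode(UnderlyingType state)
        : mState(state)
    {
    }

    [[nodiscard]] constexpr bool anyBitSet(UnderlyingType bits) const
    {
        return (mState & bits) != 0;
    }

    static constexpr UnderlyingType kTopK{1u << 0};
    static constexpr UnderlyingType kStandardStopCriteria{1u << 1};
    static constexpr UnderlyingType kExternalDraftTokens{1u << 2};

    UnderlyingType mState;
};

// Compile-time checks in the spirit of the static_asserts above.
static_assert(MiniMode::ExternalDraftTokens().isExternalDraftTokens());
static_assert(!MiniMode::ExternalDraftTokens().isTopK());
```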

View File

@ -108,6 +108,20 @@ public:
TensorConstPtr medusaTargetTokensPerStep; //!< [batchSize], on gpu
};
class ExternalDraftTokensInputs
{
public:
TensorPtr draftLogits;
TensorPtr draftProbs;
TensorPtr targetProbs;
TensorPtr numDraftTokens;
TensorPtr draftTokenIds;
TensorPtr useDraftLogits;
SizeType32 step;
float constantThreshold;
bool useRandomAcceptanceThreshold;
};
class ExplicitDraftTokensInputs
{
public:
@ -138,6 +152,8 @@ public:
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
std::optional<LookaheadInputs> lookaheadInputs;
std::optional<ExternalDraftTokensInputs> externalDraftTokensInputs;
};
} // namespace tensorrt_llm::runtime

View File

@ -95,7 +95,7 @@ public:
// mandatory parameters for beam search
TensorPtr logProbs; // [BS, BM, MSL], must be float*
TensorPtr cumLogProbs; // [BS, BM], optional for sampling
TensorPtr parentIds; // [BS, BM, MSL]
TensorPtr parentIds; // [BS, BM, MSL] index of the beam where the previous token is
TensorPtr lengths; // [BS, BM], total sequence lengths including padding
TensorPtr cacheIndirection; // [BS, BM, MSL], k/v indirection for next generation step

View File

@ -64,16 +64,6 @@ public:
virtual SamplingConfig const& getSamplingConfig() = 0;
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
ITensor const& finishedVec, ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots,
BufferManager::CudaStreamPtr const& stream);
static void acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream,

View File

@ -245,7 +245,7 @@ private:
void newRequest(SizeType32 batchSlot, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
//! @brief Allocate buffers for speculative decoding.
void allocateSpeculativeDecodingBuffers();
void allocateSpeculativeDecodingBuffers(nvinfer1::DataType dtype);
//! @brief Setup buffers for speculative decoding.
void setupSpeculativeDecoding(ModelConfig const& modelConfig);
@ -300,10 +300,6 @@ private:
DecodingInputPtr mJointDecodingInput;
DecodingOutputPtr mJointDecodingOutput;
std::vector<bool> mAcceptByLogits;
TensorPtr mNumDraftTokens;
TensorPtr mCurandStates;
std::vector<SizeType32> mNbSteps;
std::vector<bool> mFinished;
TensorPtr mFinishedSum;
@ -313,18 +309,9 @@ private:
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
// for each generated token of maxTokensPerStep, on gpu
TensorPtr mDraftProbs; // [batchSize, maxTokensPerEngineStep, beamWidth, vocabPadded], temporary data for
// speculative decoding accept by logits kernel, on gpu
TensorPtr mTargetProbs; // [batchSize, maxTokensPerEngineStep, beamWidth, vocabPadded], temporary data for
// speculative decoding accept by logits kernel, on gpu
TensorPtr mDraftTokenIds; // [batchSize, maxTokensPerEngineStep], draft token indices, on gpu
TensorPtr mDraftLogits; // [batchSize, maxTokensPerEngineStep, vocabSizePadded], draft token logits, on gpu
TensorPtr mBatchSlotsSetup; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsDecoder; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptTokens; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptLogits; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
SizeType32 mMaxSequenceLength{};
SizeType32 mMaxAttentionWindow{};
SizeType32 mSinkTokenLength{};

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:10b940475c5acd80a61674d8ce4e42cc4ef3d806bafb245bbed26751378274e3
size 4904726
oid sha256:1a292517d802f2297c5d12d5d14ab597f47f46ebd31412fac044ceb9ca51a482
size 5160586

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b2754f7887a1b5c37ba3d589320e16144039cfe5dc6a6c78ee71925861d7d511
size 5015842
oid sha256:8575fb58200701ae30feb4b8bd3f325f8018aac5505167fdba42e269adb3bd8c
size 5271836

View File

@ -1,3 +1,3 @@
ff71eabd0ac6ede5398b5b6ce4e26dcf libtensorrt_llm_batch_manager_static.a
846eb112a182973e7c3b0b193300b4b8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
954182e0c057f71f858a84f746201044 libtensorrt_llm_batch_manager_static.a
dfe6ca360cf1d24a3dcae0a2bf8589c0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:13b8701dd767b414a5376a91905985979ad9d2b975465ac00835c04656ee6508
size 4766226
oid sha256:8fe84073b7ccff8dc361fdee64c3ef30bc523909e0bf9c16547f76a05a53fb5c
size 5009886

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cd0b73a017fc5c663235dcd724eb104ecc49d12ff29b6e3744be6ea952d027db
size 4722522
oid sha256:6e565c2c3ce58656742772591d992aca91c7e46eb9fc711599d2d51928b88b48
size 4970532

View File

@ -1,3 +1,3 @@
1eb5c88f894f3361445d7254cbc29b03 libtensorrt_llm_batch_manager_static.a
4e73341b23e8fb20b732ba08e03a54a8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
61fd34e765788884d42f4ba27f085520 libtensorrt_llm_batch_manager_static.a
e8a64dd19a234304483ef6756e67fd40 libtensorrt_llm_batch_manager_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b4ac61c0b0816477c11bd6c66ec4c2f23f7b6e1400eacd8c07c333f79dec0bea
size 30794956
oid sha256:200a6721aa1d6e009c94866adab36ac686eb1beef02df267af7e18e31e11612b
size 32436708

View File

@ -1,2 +1,2 @@
eefe7310a60098897724f46cf4aa54f8 tensorrt_llm_batch_manager_static.lib
7f370deb0090d885d7518c2b146399ba3933c004 commit
9485cfa635b17378f23d1624b3acfbaf tensorrt_llm_batch_manager_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -21,7 +21,7 @@
namespace tensorrt_llm::utils::customAllReduceUtils
{
constexpr size_t NUM_POINTERS_PER_RANK = 4;
constexpr size_t NUM_POINTERS_PER_RANK = 7;
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
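
Raising NUM_POINTERS_PER_RANK from 4 to 7 widens the per-rank pointer section of the custom all-reduce workspace, and the barrier flag slots now start at index 7 * tpSize (see the updated AllReduceParams::deserialize later in this diff). A small sketch of that index arithmetic, for illustration only:

```cpp
// Illustration of the workspace index math implied by the constant above and
// by the updated AllReduceParams::deserialize in this diff. Not library code.
#include <cstddef>
#include <cstdio>

int main()
{
    constexpr std::size_t kNumPointersPerRank = 7; // was 4 before this commit
    constexpr std::size_t tpSize = 8;              // example tensor-parallel size

    // Pointer slots occupy the first kNumPointersPerRank * tpSize entries.
    constexpr std::size_t pointerSlots = kNumPointersPerRank * tpSize; // 56
    // deserialize() reads the barrier flag at pointerSlots + flagOffset,
    // where flagOffset is 0 for the Lamport fused path and 1 otherwise.
    constexpr std::size_t lamportFlagIndex = pointerSlots + 0; // 56
    constexpr std::size_t defaultFlagIndex = pointerSlots + 1; // 57

    std::printf("pointer slots: %zu, flag indices: %zu / %zu\n", pointerSlots, lamportFlagIndex, defaultFlagIndex);
    return 0;
}
```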

View File

@ -335,6 +335,18 @@ void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status)
#endif // ENABLE_MULTI_DEVICE
}
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE
int flag{0};
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
return flag != 0;
#else
TLLM_THROW("Multi device support is disabled.");
return false;
#endif
}
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
{
#if ENABLE_MULTI_DEVICE

View File

@ -38,6 +38,12 @@ namespace common
template <int VPT>
struct BytesToType;
template <>
struct BytesToType<1>
{
using type = uint8_t;
};
template <>
struct BytesToType<2>
{

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ebab2cc2c62a826ddec02597178b8e0c9bc316726f37f8eef37c06795aebcf03
size 1784658
oid sha256:809a1da76123ec4c640d63efc902209585223b66e23d887db9a198c5836986a2
size 3349066

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4b630f89708614e63c67871e21b6e32bfde71acc51549b650c57048c0fa343e7
size 1812686
oid sha256:6846ecefa017d03ab7d853908794c884ab4e92a500e223278b1d64eab59ed061
size 3376088

View File

@ -1,3 +1,3 @@
136f1b9d2168cbb9011a341b267af9a2 libtensorrt_llm_executor_static.a
183bd079377d6cd698d46370168a5726 libtensorrt_llm_executor_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
5a771664fdb75d99ba5fb90249ac26f0 libtensorrt_llm_executor_static.a
3b433ea93b7d1d6fa471b457980f2680 libtensorrt_llm_executor_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e04c76f6441a49db4d3996c62b4055395ae018384d8ee2f02ea5f0c4c0843902
size 1853180
oid sha256:479e86f410763445357f5d879cc666d210352dda9709ab5ab56e73591a9e8af8
size 7851266

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:95ba1a4b6bdcecbb592bbb42b4998bcb0eb1f45a318163635183bcde6950c4bf
size 1764982
oid sha256:6473c77d18929fa75342d63ffc591df39e8aeba1dda0b920b0187d4888710559
size 7767384

View File

@ -1,3 +1,3 @@
dfbd0d424c150253ff758aa5bd37a971 libtensorrt_llm_executor_static.a
e82866739fef1d6df8293541967924bf libtensorrt_llm_executor_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
5424fb0f82076e03b5316f73aed04434 libtensorrt_llm_executor_static.a
d0b1236baf61fc5c43383bbc1cd50fa8 libtensorrt_llm_executor_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:aa8ba34fb98c5407e3d6944245086158c61b2c784b15c7b923fdd156b942224d
size 19670642
oid sha256:dee57c9257a6678833e3c0d83e8df07aff25c185bc085db75938cec6652044c0
size 24568210

View File

@ -1,2 +1,2 @@
784ad1fabd3d02466f95fbc463b64f5b tensorrt_llm_executor_static.lib
7f370deb0090d885d7518c2b146399ba3933c004 commit
305fac5d046a574ded2d46d968f746b0 tensorrt_llm_executor_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -630,7 +630,7 @@ void topKSoftMaxKernelLauncher(T const* logits, T const* bias, void* workspace,
// ┃ pTemp ┃ BS * PAD_K * VP * (2 * (PAD_K * 2) + 2) | | float |
// ┗━━━━━━━━━━┛ --------------------------------------------------------------------------------
// Stage1: gridDim(BS,BM,nVPart), blockDim(nBlockSize,1,1)
// beamStage1Kernel: gridDim(BS,BM,nVPart), blockDim(nBlockSize,1,1)
// Each ThreadBlock takes `nVocabChunk` contiguous elements in logits to do TopK and reduce_md,
// then writes output into pTemp.
// At the end of this kernel, each ThreadBlock holds the indices and values of the top 2*BM elements,

@ -647,7 +647,7 @@ void topKSoftMaxKernelLauncher(T const* logits, T const* bias, void* workspace,
// ┃ md ┃ 2 | 2 | float |
// ┗━━━━━━━━━━┛ -----------------------------------------
// Stage2: gridDim(BS,BM,1), blockDim(32/64/128,1,1)
// beamStage2Kernel: gridDim(BS,BM,1), blockDim(32/64/128,1,1)
// Each ThreadBlock takes `nVPart` contiguous Tiles in pTemp to do reduce_topk and reduce_md,
// writes output topk_id into pTempId, writes topk_value + cumLogProbs into pTempVal.

View File

@ -165,7 +165,7 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
if (mLaunchParams.useBase2ExpTrick)
{
// The kernel adopts the log2f optimziation.
// The kernel adopts the log2f optimization.
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
set_alpha(mKernelParams.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32);
}
@ -364,8 +364,8 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
// separate q, k, v and o tma descriptors
Multiple_tma_descriptor<4> qkv_tma_descriptor;
@ -421,8 +421,8 @@ void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
uint32_t fp32_to_tf32 = 0;
// gmma descriptor mode
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
@ -474,8 +474,8 @@ void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
void FusedMHARunnerV2::setSeparateQKvTmaDescriptors(MHARunnerParams runnerParams)
{
// split D into multiple groups in order to match the TMA swizzle mode (128B)
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
uint32_t q_step = 0, kv_step = 0;
xmmaKernel->getStepSize(q_step, kv_step, mKernelParams, mLaunchParams);
@ -518,7 +518,7 @@ void FusedMHARunnerV2::setSeparateQKvTmaDescriptors(MHARunnerParams runnerParams
= (get_size_in_bytes(mFixedParams.dataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;
// gmma descriptor mode
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
? cudaTmaDescSwizzle::SWIZZLE_128B
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));

View File

@ -17,8 +17,11 @@
#include "customAllReduceKernels.h"
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/customAllReduceUtils.h"
#include "tensorrt_llm/common/dataType.h"
#include "tensorrt_llm/common/envUtils.h"
#include <cooperative_groups.h>
#include <tuple>
#include <type_traits>
@ -174,12 +177,6 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
namespace reduce_fusion
{
namespace details
{
static constexpr int kBytesPerAccess = 16;
static constexpr int kWarpSize = 32;
static constexpr int kMaxCtaSize = 1024;
}; // namespace details
inline __device__ float warp_reduce_sum(float val)
{
@ -318,7 +315,7 @@ __global__ void rms_norm_kernel(AllReduceParams params)
}
template <typename T, bool Bias = false, bool Residual = false, bool Affine = false>
void rms_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
void rms_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
{
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
@ -387,6 +384,395 @@ void rms_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
}
}
template <typename T>
struct NegZero128b
{
static constexpr int v = static_cast<int>(0x80008000);
static constexpr int4 value = {v, v, v, v};
};
template <>
struct NegZero128b<float>
{
static constexpr int v = static_cast<int>(0x80000000);
static constexpr int4 value = {v, v, v, v};
};
template <typename T>
__device__ static constexpr int4 NegZero128b_v = NegZero128b<T>::value;
template <typename T>
__device__ __forceinline__ bool is_neg_zero(T& v);
template <>
__device__ __forceinline__ bool is_neg_zero<float>(float& v)
{
uint32_t bits = *reinterpret_cast<uint32_t*>(&v);
return bits == 0x80000000;
}
template <>
__device__ __forceinline__ bool is_neg_zero<half>(half& v)
{
uint16_t bits = *reinterpret_cast<uint16_t*>(&v);
return bits == 0x8000;
}
template <>
__device__ __forceinline__ bool is_neg_zero<__nv_bfloat16>(__nv_bfloat16& v)
{
uint16_t bits = *reinterpret_cast<uint16_t*>(&v);
return bits == 0x8000;
}
template <typename ValType, typename VecType>
__device__ __forceinline__ VecType remove_neg_zero(VecType const& vec)
{
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
using ReadOnlyValType = std::add_const_t<ValType>;
VecType ret;
#pragma unroll
for (int i = 0; i < kIter; ++i)
{
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
reinterpret_cast<ValType*>(&ret)[i] = is_neg_zero(val) ? static_cast<ValType>(0.f) : val;
}
return ret;
}
template <typename ValType, typename VecType>
__device__ __forceinline__ bool has_neg_zero(VecType const& vec)
{
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
using ReadOnlyValType = std::add_const_t<ValType>;
#pragma unroll
for (int i = 0; i < kIter; ++i)
{
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
if (is_neg_zero(val))
{
return true;
}
}
return false;
}
template <typename ValType, typename VecType>
__device__ __forceinline__ bool all_neg_zero(VecType const& vec)
{
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
using ReadOnlyValType = std::add_const_t<ValType>;
#pragma unroll
for (int i = 0; i < kIter; ++i)
{
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
if (!is_neg_zero(val))
{
return false;
}
}
return true;
}
__device__ __forceinline__ void st_global_release(int4 const& val, int4* addr)
{
asm volatile("st.release.global.sys.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z),
"r"(val.w), "l"(addr));
}
__device__ __forceinline__ int4 ld_global_acquire(int4* addr)
{
int4 val;
asm volatile("ld.acquire.global.sys.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(addr));
return val;
}
__device__ __forceinline__ void st_global_volatile(int4 const& val, int4* addr)
{
asm volatile("st.volatile.global.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w),
"l"(addr));
}
__device__ __forceinline__ int4 ld_global_volatile(int4* addr)
{
int4 val;
asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
: "l"(addr));
return val;
}
template <typename ValType>
__device__ __forceinline__ void set_neg_zero(int4* addr)
{
st_global_volatile(NegZero128b_v<ValType>, addr);
}
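
The helpers above implement the Lamport-style "empty slot" convention: remove_neg_zero() rewrites any -0.0 in the payload to +0.0 before it is published, set_neg_zero() stamps slots with the -0.0 bit pattern, and readers treat a slot that still holds -0.0 as not yet written. A host-side illustration of the bit test (not part of this commit):

```cpp
// Host-side illustration of the -0.0 sentinel test used above (float case).
// The device code performs the same bit comparison per 16-byte vector lane.
#include <cassert>
#include <cstdint>
#include <cstring>

static bool isNegZeroBits(float v)
{
    std::uint32_t bits = 0;
    std::memcpy(&bits, &v, sizeof(bits)); // well-defined type pun on the host
    return bits == 0x80000000u;
}

int main()
{
    assert(isNegZeroBits(-0.0f));  // sentinel: slot not yet written
    assert(!isNegZeroBits(0.0f));  // +0.0 has different bits, so real zeros survive
    assert(!isNegZeroBits(1.5f));  // ordinary payload value
    assert(-0.0f == 0.0f);         // why a plain float compare could not be the test
    return 0;
}
```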
template <typename T, int RanksPerNode, bool PushMode>
struct Reducer;
template <typename T, int RanksPerNode>
struct Reducer<T, RanksPerNode, true>
{
static __device__ __forceinline__ int4 allreduce(AllReduceParams& params, int global_offset)
{
using PackedStruct = typename PackedOn16Bytes<T>::Type;
int ping = params.barrier_flag % 3;
int pong = (params.barrier_flag + 2) % 3;
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
T* local_shared_buffer = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + ping * MAX_RANKS_PER_NODE]);
T* local_clean_buffer = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + pong * MAX_RANKS_PER_NODE]);
local_input_buffer += global_offset;
local_shared_buffer += global_offset;
local_clean_buffer += global_offset;
T* buffers[RanksPerNode];
#pragma unroll
for (int ii = 0; ii < RanksPerNode; ++ii)
{
int rank = (params.local_rank + ii) % RanksPerNode;
buffers[ii] = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[rank + ping * MAX_RANKS_PER_NODE])
+ global_offset + params.local_rank * params.elts_total;
}
PackedStruct sum_vec, val;
val.packed = remove_neg_zero<T>(*reinterpret_cast<int4 const*>(local_input_buffer));
#pragma unroll
for (int ii = 1; ii < RanksPerNode; ++ii)
{
st_global_volatile(val.packed, reinterpret_cast<int4*>(buffers[ii]));
}
sum_vec.packed = val.packed;
#pragma unroll
for (int ii = 1; ii < RanksPerNode; ++ii)
{
int rank = (params.local_rank + ii) % RanksPerNode;
set_neg_zero<T>(reinterpret_cast<int4*>(local_clean_buffer + rank * params.elts_total));
}
PackedStruct vals[RanksPerNode - 1];
bool done = false;
while (!done)
{
done = true;
#pragma unroll
for (int ii = 1; ii < RanksPerNode; ++ii)
{
int rank = (params.local_rank + ii) % RanksPerNode;
vals[ii - 1].packed
= ld_global_volatile(reinterpret_cast<int4*>(local_shared_buffer + rank * params.elts_total));
}
#pragma unroll
for (int ii = 0; ii < RanksPerNode - 1; ii++)
{
done &= !has_neg_zero<T>(vals[ii].packed);
}
}
#pragma unroll
for (int ii = 1; ii < RanksPerNode; ++ii)
{
sum_vec.packed = add128b(sum_vec, vals[ii - 1]);
}
return sum_vec.packed;
}
};
template <typename T, int RanksPerNode>
struct Reducer<T, RanksPerNode, false>
{
static __device__ __forceinline__ int4 allreduce(AllReduceParams& params, int global_offset)
{
using PackedStruct = typename PackedOn16Bytes<T>::Type;
int ping = params.barrier_flag % 3;
int pong = (params.barrier_flag + 2) % 3;
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
T* local_shared_buffer = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + ping * MAX_RANKS_PER_NODE]);
T* local_clean_buffer = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + pong * MAX_RANKS_PER_NODE]);
local_input_buffer += global_offset;
local_shared_buffer += global_offset;
local_clean_buffer += global_offset;
T* buffers[RanksPerNode];
#pragma unroll
for (int ii = 0; ii < RanksPerNode; ++ii)
{
int rank = (params.local_rank + ii) % RanksPerNode;
buffers[ii] = reinterpret_cast<T*>(
params.fusion_params.lamport_peer_comm_buffer_ptrs[rank + ping * MAX_RANKS_PER_NODE])
+ global_offset;
}
PackedStruct sum_vec, val;
val.packed = remove_neg_zero<T>(*reinterpret_cast<int4 const*>(local_input_buffer));
st_global_volatile(val.packed, reinterpret_cast<int4*>(local_shared_buffer));
sum_vec.packed = val.packed;
#pragma unroll
for (int ii = 1; ii < RanksPerNode; ++ii)
{
do
{
val.packed = ld_global_volatile(reinterpret_cast<int4*>(buffers[ii]));
} while (has_neg_zero<T>(val.packed));
sum_vec.packed = add128b(sum_vec, val);
}
set_neg_zero<T>(reinterpret_cast<int4*>(local_clean_buffer));
return sum_vec.packed;
}
};
template <int ClusterSize, typename T, int RanksPerNode, bool Bias = false, bool Affine = false, bool PushMode = true>
static __global__ void lamport_style_one_shot_all_reduce_norm_kernel(AllReduceParams params)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
namespace cg = cooperative_groups;
static_assert(RanksPerNode <= 8);
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
using PackedStruct = typename PackedOn16Bytes<T>::Type;
cg::cluster_group cluster = cg::this_cluster();
__shared__ float cluster_acc;
int bid = blockIdx.x, tid = threadIdx.x;
int cluster_id = bid / ClusterSize, cluster_block_rank = bid % ClusterSize;
int token_id = cluster_id;
int cluster_offset = token_id * params.fusion_params.hidden_size;
int block_offset = cluster_block_rank * params.fusion_params.hidden_size / ClusterSize;
int thread_offset = tid * kPackedSize;
int inner_token_offset = block_offset + thread_offset;
int global_offset = cluster_offset + inner_token_offset;
T const* bias_buffer = reinterpret_cast<T const*>(params.fusion_params.bias_buffer);
T const* residual_buffer = reinterpret_cast<T const*>(params.fusion_params.residual_buffer);
T const* weight_buffer = reinterpret_cast<T const*>(params.fusion_params.weight_buffer);
T* local_final_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
T* intermediate_buffer = reinterpret_cast<T*>(params.fusion_params.intermediate_buffer);
local_final_output_buffer += global_offset;
intermediate_buffer += global_offset;
residual_buffer += global_offset;
bias_buffer += inner_token_offset;
weight_buffer += inner_token_offset;
PackedStruct weight_vec, bias_vec, residual_vec;
residual_vec.packed = *reinterpret_cast<int4 const*>(residual_buffer);
if constexpr (Bias)
{
bias_vec.packed = *reinterpret_cast<int4 const*>(bias_buffer);
}
if constexpr (Affine)
{
weight_vec.packed = *reinterpret_cast<int4 const*>(weight_buffer);
}
cudaGridDependencySynchronize();
float acc = 0.f;
PackedStruct sum_vec;
sum_vec.packed = Reducer<T, RanksPerNode, PushMode>::allreduce(params, global_offset);
if constexpr (Bias)
{
sum_vec.packed = add128b(sum_vec, bias_vec);
}
sum_vec.packed = add128b(sum_vec, residual_vec);
*reinterpret_cast<int4*>(intermediate_buffer) = sum_vec.packed;
acc = accumulate<T>(acc, sum_vec);
acc = block_reduce_sum(acc);
if (ClusterSize > 1)
{
if (threadIdx.x == 0)
{
cluster_acc = acc;
}
cluster.sync();
acc = 0.f;
#pragma unroll
for (int ii = 0; ii < ClusterSize; ++ii)
{
acc += *cluster.map_shared_rank(&cluster_acc, ii);
}
}
float denom = __fsqrt_rn(__fdividef(acc, params.fusion_params.hidden_size) + params.fusion_params.eps);
sum_vec.packed = rms_norm<T, Affine>(denom, sum_vec, weight_vec);
*reinterpret_cast<int4*>(local_final_output_buffer) = sum_vec.packed;
cudaTriggerProgrammaticLaunchCompletion();
#endif
}
int heuristic_min_warp_number(int tp_size, int hidden_size)
{
if (hidden_size >= 4096)
{
return 4;
}
if (tp_size == 2)
{
return 32;
}
else
{
return 16;
}
}
template <typename T, int RanksPerNode, bool Bias, bool Affine>
void lamport_style_one_shot_all_reduce_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
{
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
int threads_per_token = params.fusion_params.hidden_size / kPackedSize;
int warps_per_token = (threads_per_token + details::kWarpSize - 1) / details::kWarpSize;
int token_num = params.elts_total / params.fusion_params.hidden_size;
int warp_min_number = heuristic_min_warp_number(RanksPerNode, params.fusion_params.hidden_size);
int cluster_size = std::min(((warps_per_token + warp_min_number - 1) / warp_min_number), details::kClusterMaxSize);
int cta_size = warps_per_token / cluster_size * details::kWarpSize;
TLLM_CHECK(cta_size <= details::kMaxCtaSize);
int cta_num = token_num * cluster_size;
cudaLaunchConfig_t kernel_config = {0};
kernel_config.gridDim = cta_num;
kernel_config.blockDim = cta_size;
kernel_config.dynamicSmemBytes = 0;
kernel_config.stream = stream;
cudaLaunchAttribute attribute[2];
attribute[0].id = cudaLaunchAttributeClusterDimension;
attribute[0].val.clusterDim.x = cluster_size;
attribute[0].val.clusterDim.y = 1;
attribute[0].val.clusterDim.z = 1;
kernel_config.attrs = attribute;
kernel_config.numAttrs = 1;
if (tensorrt_llm::common::getEnvEnablePDL())
{
attribute[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[1].val.programmaticStreamSerializationAllowed = 1;
kernel_config.numAttrs++;
}
#define LAUNCH_LAMPORT_KERNEL(CLUSTER_SIZE) \
if (cluster_size == CLUSTER_SIZE) \
{ \
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernel_config, \
lamport_style_one_shot_all_reduce_norm_kernel<CLUSTER_SIZE, T, RanksPerNode, Bias, Affine>, params)); \
return; \
}
LAUNCH_LAMPORT_KERNEL(1);
LAUNCH_LAMPORT_KERNEL(2);
LAUNCH_LAMPORT_KERNEL(3);
LAUNCH_LAMPORT_KERNEL(4);
LAUNCH_LAMPORT_KERNEL(5);
LAUNCH_LAMPORT_KERNEL(6);
LAUNCH_LAMPORT_KERNEL(7);
LAUNCH_LAMPORT_KERNEL(8);
#undef LAUNCH_LAMPORT_KERNEL
}
template <typename T, int RanksPerNode, bool Bias = false, bool Affine = false, bool UseSmem = false>
static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kernel(AllReduceParams params)
{
@ -495,79 +881,144 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
#endif
}
template <typename T>
bool is_lamport_supported(int token_num)
{
static char* disableLamportReduceNormFusionChar = std::getenv("DISABLE_LAMPORT_REDUCE_NORM_FUSION");
bool disableLamportReduceNormFusion = (disableLamportReduceNormFusionChar != nullptr);
if (disableLamportReduceNormFusion)
return false;
static int sm = tensorrt_llm::common::getSMVersion();
if (sm < 90)
{
return false;
}
if (!std::is_same_v<T, half> && !std::is_same_v<T, __nv_bfloat16>)
{
return false;
}
if (token_num > details::kLamportTokenNumThreshold)
{
return false;
}
return true;
}
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num)
{
switch (dataType)
{
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num);
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num);
#ifdef ENABLE_BF16
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num);
#endif
default: return false;
}
}
template <typename T, int RanksPerNode, bool Bias, bool Affine>
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
{
int token_num = params.elts_total / params.fusion_params.hidden_size;
if (is_lamport_supported<T>(token_num))
{
lamport_style_one_shot_all_reduce_norm_kernel_launcher<T, RanksPerNode, Bias, Affine>(params, stream);
}
else
{
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
int need_threads = params.fusion_params.hidden_size / kPackedSize;
int cta_size;
if (need_threads <= details::kMaxCtaSize)
{
cta_size = (need_threads + details::kWarpSize - 1) / details::kWarpSize * details::kWarpSize;
}
else
{
cta_size = details::kMaxCtaSize;
}
int norm_num = params.elts_total / params.fusion_params.hidden_size;
int cta_num = std::min(norm_num, static_cast<int>(MAX_ALL_REDUCE_BLOCKS));
int smem_size = 0;
if (cta_size * kPackedSize < params.fusion_params.hidden_size)
{
smem_size = params.fusion_params.hidden_size * sizeof(T);
if (tensorrt_llm::common::getEnvEnablePDL())
{
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
cudaLaunchConfig_t kernelConfig = {0};
kernelConfig.gridDim = cta_num;
kernelConfig.blockDim = cta_size;
kernelConfig.dynamicSmemBytes = smem_size;
kernelConfig.stream = stream;
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
kernelConfig.attrs = attribute;
kernelConfig.numAttrs = 1;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>, params));
}
else
{
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>
<<<cta_num, cta_size, smem_size, stream>>>(params);
}
}
else
{
if (tensorrt_llm::common::getEnvEnablePDL())
{
cudaLaunchConfig_t kernelConfig = {0};
kernelConfig.gridDim = cta_num;
kernelConfig.blockDim = cta_size;
kernelConfig.dynamicSmemBytes = smem_size;
kernelConfig.stream = stream;
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
kernelConfig.attrs = attribute;
kernelConfig.numAttrs = 1;
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>, params));
}
else
{
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>
<<<cta_num, cta_size, smem_size, stream>>>(params);
}
}
}
}
template <typename T>
__global__ void lamport_initialize_kernel(T* buffer, size_t size)
{
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
int need_threads = params.fusion_params.hidden_size / kPackedSize;
int cta_size;
if (need_threads <= details::kMaxCtaSize)
using PackedStruct = typename PackedOn16Bytes<T>::Type;
for (size_t offset = (blockIdx.x * blockDim.x + threadIdx.x) * kPackedSize; offset < size;
offset += gridDim.x * blockDim.x * kPackedSize)
{
cta_size = (need_threads + details::kWarpSize - 1) / details::kWarpSize * details::kWarpSize;
set_neg_zero<T>(reinterpret_cast<int4*>(&buffer[offset]));
}
else
{
cta_size = details::kMaxCtaSize;
}
int norm_num = params.elts_total / params.fusion_params.hidden_size;
int cta_num = std::min(norm_num, static_cast<int>(MAX_ALL_REDUCE_BLOCKS));
int smem_size = 0;
}
if (cta_size * kPackedSize < params.fusion_params.hidden_size)
{
smem_size = params.fusion_params.hidden_size * sizeof(T);
if (tensorrt_llm::common::getEnvEnablePDL())
{
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
cudaLaunchConfig_t kernelConfig = {0};
kernelConfig.gridDim = cta_num;
kernelConfig.blockDim = cta_size;
kernelConfig.dynamicSmemBytes = smem_size;
kernelConfig.stream = stream;
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
kernelConfig.attrs = attribute;
kernelConfig.numAttrs = 1;
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>, params));
}
else
{
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>
<<<cta_num, cta_size, smem_size, stream>>>(params);
}
}
else
{
if (tensorrt_llm::common::getEnvEnablePDL())
{
cudaLaunchConfig_t kernelConfig = {0};
kernelConfig.gridDim = cta_num;
kernelConfig.blockDim = cta_size;
kernelConfig.dynamicSmemBytes = smem_size;
kernelConfig.stream = stream;
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
kernelConfig.attrs = attribute;
kernelConfig.numAttrs = 1;
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>, params));
}
else
{
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>
<<<cta_num, cta_size, smem_size, stream>>>(params);
}
}
template <typename T>
void lamport_initialize_kernel_launcher(void* buffer, size_t size, cudaStream_t stream)
{
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
int block_size = 1024;
int grid_size = (size + 1024 * kPackedSize - 1) / (1024 * kPackedSize);
lamport_initialize_kernel<T><<<grid_size, block_size, 0, stream>>>(reinterpret_cast<T*>(buffer), size);
}
}; // namespace reduce_fusion
@ -1117,13 +1568,24 @@ void AllReduceDispatchType(AllReduceParams& params, AllReduceStrategyType strat,
}
}
AllReduceParams AllReduceParams::deserialize(int64_t* buffer, size_t tpSize, size_t tpRank)
AllReduceParams AllReduceParams::deserialize(
int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, AllReduceFusionOp op)
{
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
auto const flag_ptr = &buffer[4 * tpSize];
int flag_offset;
if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM && reduce_fusion::is_lamport_supported(dataType, token_num))
{
flag_offset = 0;
}
else
{
flag_offset = 1;
}
auto const flag_ptr
= &buffer[tensorrt_llm::utils::customAllReduceUtils::NUM_POINTERS_PER_RANK * tpSize + flag_offset];
// cannot use 0 since 0 represents released state for barrier
*flag_ptr += 1;
TLLM_LOG_TRACE("AllReduceParams's flag value is %d", *flag_ptr);
TLLM_LOG_TRACE("AllReduceParams's flag value is %d, flag offset %d", *flag_ptr, flag_offset);
uint32_t flag_value = *flag_ptr;
AllReduceParams params;
// Even plugins use ping buffers, odd plugins use pong.
@ -1208,4 +1670,25 @@ void residualRmsNorm(kernels::AllReduceParams& params, nvinfer1::DataType dataTy
sync_check_cuda_error();
}
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
{
sync_check_cuda_error();
switch (dataType)
{
case nvinfer1::DataType::kFLOAT:
reduce_fusion::lamport_initialize_kernel_launcher<float>(buffer, size, stream);
break;
case nvinfer1::DataType::kHALF:
reduce_fusion::lamport_initialize_kernel_launcher<half>(buffer, size, stream);
break;
#ifdef ENABLE_BF16
case nvinfer1::DataType::kBF16:
reduce_fusion::lamport_initialize_kernel_launcher<__nv_bfloat16>(buffer, size, stream);
break;
#endif
default: TLLM_THROW("Unsupported dataType for customAllReduce");
}
sync_check_cuda_error();
}
} // namespace tensorrt_llm::kernels
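
lamportInitialize() pre-fills a Lamport communication buffer with the -0.0 sentinel so that the first fused reduction can tell unwritten slots apart from real data. A hedged usage sketch; the header path is assumed, and `size` is interpreted as an element count of the given dtype, which is what the kernel loop above suggests:

```cpp
// Usage sketch for lamportInitialize (assumptions: header path below, and
// `size` is an element count of `dataType` rather than a byte count).
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include "tensorrt_llm/kernels/customAllReduceKernels.h" // assumed header path

void initLamportBuffer(cudaStream_t stream)
{
    constexpr size_t numElems = 8 * 1024 * 1024; // example buffer size in half elements
    void* buffer = nullptr;
    cudaMalloc(&buffer, numElems * sizeof(__half));
    // Fill every element with the negative-zero sentinel before first use.
    // (Ownership and lifetime of the comm buffer are managed elsewhere.)
    tensorrt_llm::kernels::lamportInitialize(buffer, numElems, nvinfer1::DataType::kHALF, stream);
}
```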

View File

@ -31,6 +31,15 @@ constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;
namespace reduce_fusion::details
{
static constexpr int kBytesPerAccess = 16;
static constexpr int kWarpSize = 32;
static constexpr int kMaxCtaSize = 1024;
static constexpr int kClusterMaxSize = 8;
static constexpr int kLamportTokenNumThreshold = 16;
}; // namespace reduce_fusion::details
// Warning: python definition is in tensorrt_llm/functional.py
// they must be kept in sync
enum class AllReduceStrategyType : int8_t
@ -73,6 +82,7 @@ struct AllReduceFusionParams
float eps;
// new residual
void* intermediate_buffer;
void* lamport_peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE * 3];
};
struct AllReduceParams
@ -81,7 +91,8 @@ struct AllReduceParams
size_t elts_per_rank;
size_t elts_per_block;
size_t rank_offset;
size_t ranks_per_node, local_rank;
size_t ranks_per_node;
size_t local_rank;
uint32_t barrier_flag;
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
@ -91,7 +102,8 @@ struct AllReduceParams
AllReduceFusionParams fusion_params;
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank);
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
int token_num, AllReduceFusionOp op);
};
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);
@ -101,4 +113,6 @@ void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataTy
void residualRmsNorm(kernels::AllReduceParams& params, nvinfer1::DataType dataType, cudaStream_t stream);
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream);
} // namespace tensorrt_llm::kernels

View File

@ -1,2 +1,2 @@
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
7f370deb0090d885d7518c2b146399ba3933c004 commit
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,2 +1,2 @@
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
7f370deb0090d885d7518c2b146399ba3933c004 commit
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1471e322bb44cd65b98ee30e0befa32ae4c86e828f0b4fd4f02d4af4e710d08f
oid sha256:db512d533ab4e4a4abd0047a65d891dfd6e1522f2d34c90f29296c3239fd3cc1
size 1128448

View File

@ -1,3 +1,3 @@
b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
f9b1cc37a27dd0574bb41a2763a97be7 tensorrt_llm_nvrtc_wrapper.dll
7f370deb0090d885d7518c2b146399ba3933c004 commit
d89a0a140d2d427af13c3794a4b21e2c tensorrt_llm_nvrtc_wrapper.dll
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -121,6 +121,17 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
namespace runtime::kernels
{
//! \brief Inserts the running beams into the finished beams stored in the CBA buffers. (beams where the most likely
//! continuation is the end token get stored separately, and another candidate next token is stored). Then sorts the
//! beams according to their cumulative log probs. Note: the kernels in gatherTree modify the buffers inplace. When
//! streaming, we use tmp buffers since beam search kernels expect ungathered data.
//!
//! \param decodingOutput contains a slice of the output buffers to gather. Also contains the
//! DecodingOutput::BeamHypotheses object with the finished beams.
//! \param decodingInput used for endIds and input lengths.
//! \param manager the usual buffer manager.
//! \param samplingConfig the sampling configuration used for decoding.
void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decodingInput, BufferManager const& manager,
SamplingConfig const& samplingConfig);
} // namespace runtime::kernels

View File

@ -228,7 +228,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
}
}
// Perpare values for fmha.
// Prepare values for fmha.
if (threadIdx.x == 0 && blockIdx.x == 0)
{
// Reset fmha tile counter to 0 before launching fmha kernels.

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9117f7cf5eef0ed452c0d0bc79242b84def103e7038c9d3df6e366690801ca92
oid sha256:0814af36fed752bbe70d953cefbb78dd306c42f3d9f6848b7043a865e48f9662
size 25364090

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2b04913f9e9029a5ce5a222d5cc7492ff53323a548079d2fb32d5b2aeb0c2268
oid sha256:ee46f2d1c9162f4302a1031f778fcb7c7110c84110427f97af6532ed9bd342fd
size 25768990

View File

@ -1,3 +1,3 @@
d54fb93f256601f4c4ad7f1c8e6e9919 libtensorrt_llm_internal_cutlass_kernels_static.a
71028d801074f11138e890391e48591d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
90740ead1def66f350e14c133278463d libtensorrt_llm_internal_cutlass_kernels_static.a
b0104227ffd1ce19fc1fdb45e349df36 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d8c685f8ea2f84838dfdbf448eab41c76fe88fe29db0d4a511d6d6d241ad1832
oid sha256:4d9ba0f8b95cf64227cb0b17654fb7c9bc1741fe003889658b305750b388a4dc
size 44173632

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9d75392ba3b59853c43072b4f9949b32cb6724813a39048e4585e9a8fb3e136
oid sha256:4f848d5beebbd69792047a96b16f7145f8e1e3e311d2a19789ce639ad8149b0e
size 43561206

View File

@ -1,3 +1,3 @@
4fc3e1fb0db6a121f88a9141605d9285 libtensorrt_llm_internal_cutlass_kernels_static.a
253731af750407020dbe6f2fbe50fa2b libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
7f370deb0090d885d7518c2b146399ba3933c004 commit
2aaf05cb84f52b024e89d4fa634d6900 libtensorrt_llm_internal_cutlass_kernels_static.a
f17ce186e9105c594e39d252777ce4c7 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62af58f5e09d1cf5e347b02ef3bd3a186469162fc9645d038fb2cba23b597722
size 88140804
oid sha256:c429687e335c75f08186bcd8f629b50467cb0f2e484d755834c5b1cdbb9ecaf3
size 88140796

View File

@ -1,2 +1,2 @@
eb7fc4a105eb6e6f52ba865f2b055233 tensorrt_llm_internal_cutlass_kernels_static.lib
7f370deb0090d885d7518c2b146399ba3933c004 commit
4f663be2b768088805ccec6dc33545fc tensorrt_llm_internal_cutlass_kernels_static.lib
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit

View File

@ -1458,7 +1458,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
size_t const softmax_out_size = num_softmax_outs * sizeof(float);
size_t const permuted_scales_size = mayHaveFinalizeFused() ? num_moe_inputs * sizeof(float) : 0;
size_t const glu_inter_size = glu_inter_elems * gemm_output_dtype; // May be an intermediate type for quantization
size_t const fc1_result_size = interbuf_elems * sizeof(T); // Acitvation quantizes so back to sizeof(T)
size_t const fc1_result_size = interbuf_elems * sizeof(T); // Activation quantizes so back to sizeof(T)
size_t const sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);
size_t const fc2_result_size = permuted_elems * gemm_output_dtype; // May be an intermediate type for quantization

View File

@ -1427,7 +1427,7 @@ void invokeAirTopPSamplingWithDeterministicPara(TopPSamplingKernelParams<T> cons
kernel = airTopPSampling<T, IdxT, AccT, HisT, BitsPerPass, SAMPLING_BLOCK_SIZE, true, isDeterministic>;
}
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, params.outputIds,
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, params.outputIdsPtrs,
params.sequenceLength, params.finishedInput, params.finishedOutput, params.cumLogProbs,
params.outputLogProbs, params.endIds, params.maxBatchSize, params.skipDecode, pass, buf1, idxBuf1, buf2,
idxBuf2, params.batchSlots);

View File

@ -196,11 +196,11 @@ __device__ void epilogue(SizeType32 batchId, SizeType32 currentStep, SizeType32
}
template <typename T, int blockSize>
__global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenIdType** ids, SizeType32* sequenceLength,
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize, curandState_t* curandState,
float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize, bool const* skipDecode,
SizeType32 const* batchSlots)
__global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenIdType* ids, TokenIdType** idsPtrs,
SizeType32* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize,
curandState_t* curandState, float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize,
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllTopP, SizeType32 maxSeqLen)
{
/**
* Each block processes one request row sorted in descending order by probabilities.
@ -235,14 +235,16 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
}
auto const probThreshold = topPs[batchSlot];
auto const currentStep = sequenceLength[batchSlot];
auto const currentStep = sequenceLength == nullptr ? 0 : sequenceLength[batchSlot];
auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot];
// With P in (0.0; 1.0] we draw a random number P' in range (0.0; P]
// We will sum all probs moving from the largest probability to the smallest and
// will choose the token which probability makes cumulative probability sum to exceed P'
if (threadIdx.x == 0)
{
randNumS = curand_uniform(curandState + blockIdx.x) * probThreshold;
// if we want to return all top p indices, we should not do random sampling for probThreshold
randNumS = returnAllTopP ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
}
// if beginOffsetBuf and offsetBuf of sorting have same value,
@ -253,8 +255,15 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
if (tid == 0)
{
auto offset = batchId * vocabSize;
epilogue(batchSlot, currentStep, offset, ids, sortedIdVals, sortedProbs, cumLogProbs, outputLogProbs,
endIds, sequenceLength, finishedOutput, maxBatchSize);
if (returnAllTopP)
{
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
}
else
{
epilogue(batchSlot, currentStep, offset, idsPtrs, sortedIdVals, sortedProbs, cumLogProbs,
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
}
}
return;
}
@ -267,7 +276,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
__syncthreads();
auto offset = batchId * vocabSize;
ids[batchSlot][currentStep] = sortedIdVals[offset];
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
auto end = ((vocabSize + blockSize - 1) / blockSize) * blockSize;
SizeType32 selectedTokenId = 0;
// Cumulative sum
@ -285,11 +294,31 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
}
}
// select first thread exceeded the prob threshold or the last thread in case of P=1.0f
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
if (returnAllTopP)
{
epilogue(batchSlot, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedProbs, cumLogProbs,
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
__shared__ SizeType32 sharedSelectedTokenId;
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
{
sharedSelectedTokenId = selectedTokenId;
}
__syncthreads();
for (int vi = tid; vi <= sharedSelectedTokenId; vi += blockSize)
{
outputIdsRequestPtr[vi] = sortedIdVals[offset + vi];
}
if (tid == 0 && sharedSelectedTokenId != end - 1)
{
outputIdsRequestPtr[sharedSelectedTokenId + 1] = -1; // a boundary to record the end of all selected top Ps.
}
}
else
{
// select first thread exceeded the prob threshold or the last thread in case of P=1.0f
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
{
epilogue(batchSlot, currentStep, offset + selectedTokenId, idsPtrs, sortedIdVals, sortedProbs, cumLogProbs,
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
}
}
}
@ -371,9 +400,10 @@ void invokeBatchTopPSampling(TopPSamplingKernelParams<T> const& params, cudaStre
dim3 grid(params.batchSize);
// Sample with Top P given sorted tokens
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedProbs, sortedIdVals,
params.outputIds, params.sequenceLength, params.finishedInput, params.finishedOutput, params.cumLogProbs,
params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded, params.curandState, params.topPs,
params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots);
params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput,
params.cumLogProbs, params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded,
params.curandState, params.topPs, params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots,
params.returnAllTopP, params.maxSeqLen);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
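
The sampling kernel above draws P' uniformly in (0, P] and walks the probability-sorted vocabulary until the cumulative sum reaches P'; with returnAllTopP it instead keeps every index inside the top-P mass and writes -1 as an end marker. A host-side sketch of that selection logic for a single request (illustration only):

```cpp
// Host-side sketch of the top-P selection logic in topPSsampling above;
// probabilities are assumed already sorted in descending order.
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> sortedProbs = {0.40f, 0.25f, 0.15f, 0.10f, 0.06f, 0.04f};
    std::vector<int> sortedIds = {7, 3, 11, 0, 5, 2};
    float const topP = 0.75f;

    // returnAllTopP path: keep every token inside the top-P mass, then a -1 boundary.
    std::vector<int> kept;
    float cumSum = 0.f;
    for (std::size_t i = 0; i < sortedProbs.size(); ++i)
    {
        kept.push_back(sortedIds[i]);
        cumSum += sortedProbs[i];
        if (cumSum >= topP)
        {
            break; // the token that pushes the sum past P is included, as in the kernel
        }
    }
    kept.push_back(-1); // boundary marker, mirrors outputIdsRequestPtr[selected + 1] = -1

    // The sampling path instead draws randNumS in (0, topP] and emits only the
    // first token whose cumulative probability exceeds randNumS.
    for (int id : kept)
    {
        std::printf("%d ", id); // prints: 7 3 11 -1
    }
    std::printf("\n");
    return 0;
}
```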

View File

@ -28,8 +28,13 @@ struct TopPSamplingKernelParams
//! input buffer [batchSize, vocabSizePadded], required. Probabilities of each token in the vocab.
T const* probs{nullptr};
//! output buffer [maxBatchSize][maxSeqLen], required. Contains pointers to rows with output tokens per request.
runtime::TokenIdType** outputIds{nullptr};
//! output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request.
//! If nullptr, outputIds must be provided.
runtime::TokenIdType** outputIdsPtrs{nullptr};
//! output buffer [maxBatchSize, maxSeqLen], optional. Tensor to store output tokens.
//! Not used if outputIdsPtrs != nullptr
runtime::TokenIdType* outputIds{nullptr};
//! pointer to the workspace. Has to be pre-allocated by caller.
//! Function does not take ownership of the buffer.
@ -73,6 +78,9 @@ struct TopPSamplingKernelParams
runtime::SizeType32 batchSize{-1};
runtime::SizeType32 maxBatchSize{-1};
runtime::SizeType32 vocabSizePadded{-1};
runtime::SizeType32 maxSeqLen{-1};
bool returnAllTopP{false};
void checkParams() const
{
@ -81,12 +89,17 @@ struct TopPSamplingKernelParams
TLLM_CHECK(maxBatchSize >= batchSize);
TLLM_CHECK(vocabSizePadded > 0);
TLLM_CHECK(probs);
TLLM_CHECK(outputIds);
TLLM_CHECK(outputIds || outputIdsPtrs);
TLLM_CHECK(workspace);
TLLM_CHECK(sequenceLength);
TLLM_CHECK((sequenceLength != nullptr) || returnAllTopP);
TLLM_CHECK(curandState);
TLLM_CHECK(topPs);
if (outputIds)
{
TLLM_CHECK(maxSeqLen > 0);
}
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
}
};

View File

@ -35,230 +35,281 @@ namespace tensorrt_llm::kernels::speculative_decoding
{
namespace
{
__global__ void acceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxDraftTokens)
template <typename T>
__global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchSlots, SizeType32 beamWidth,
SizeType32 vocabSize, FinishedState const* finishedInput, SizeType32 maxBatchSize, bool const* batchUseDraftLogits,
SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds, SizeType32* runtimeTopKDevicePtr, bool* maskBuffer)
{
for (auto batchIdx = static_cast<SizeType32>(threadIdx.x); batchIdx < batchSize; batchIdx += blockDim.x)
/**
* @brief Mask the logits of tokens outside the selected top-K/top-P set to -inf, as done in the Huggingface TopK/TopP Logits Warper
* https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/generation/logits_process.py#L533
*/
auto const bid = blockIdx.x;
auto const batchIdx = bid / beamWidth;
auto const tid = static_cast<SizeType32>(threadIdx.x);
auto const batchSlot = batchSlots[batchIdx];
constexpr bool IS_HALF = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_HALF) ? HALF_FLT_MAX : FLT_MAX;
auto targetLogitsBatch = targetLogits + batchIdx * vocabSize;
auto& finishedState = finishedInput[batchSlot];
auto* outputIdsAfterSamplingPtr = outputIdsAfterSampling + batchSlot * vocabSize;
auto const useDraftLogits = batchUseDraftLogits[batchSlot];
if (finishedState.isSkipDecoding())
{
auto const batchSlot = batchSlots[batchIdx];
auto const numDraftTokens = numsDraftTokens[batchSlot];
return;
}
auto const contextLength = contextLengths[batchSlot];
auto& sequenceLength = sequenceLengths[batchSlot];
SizeType32 finishedDraftIdx = 0;
for (auto ti = contextLength; ti < min(sequenceLength, contextLength + numDraftTokens);
++ti, ++finishedDraftIdx)
{
auto const draftIdx = ti - contextLength;
auto const targetTokenIdx = batchSlot * maxSeqLen + ti;
auto const draftTokenIdx = batchSlot * maxDraftTokens + draftIdx;
// Check if draft tokens are the same as target tokens
bool const accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
if (!accepted)
{
// Set sequence length to the numAcceptedTokens + 1
sequenceLength = min(ti + 1, maxSeqLen);
// FIXME(nkorobov): do we need to set endIds here?
break;
}
__shared__ SizeType32 tokensToMask;
if (tid == 0)
{
tokensToMask = runtimeTopKDevicePtr[batchSlot];
}
__syncthreads();
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
{
if (tokensToMask == 0 && outputIdsAfterSamplingPtr[vIdx] == -1)
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0
tokensToMask = vIdx;
}
FinishedState finishState = finished[finishedDraftIdx * maxBatchSize + batchSlot];
finishedFinal[batchSlot] = finishState;
maskBuffer[vIdx] = false;
}
if (finishedSum)
__syncthreads();
if (!useDraftLogits && tid == 0)
{
targetOutputIds[batchSlot] = outputIdsAfterSamplingPtr[tokensToMask - 1];
}
for (SizeType32 vIdx = tid; vIdx < tokensToMask; vIdx += static_cast<SizeType32>(blockDim.x))
{
auto tokenToMask = outputIdsAfterSamplingPtr[vIdx];
maskBuffer[tokenToMask] = true;
}
__syncthreads();
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
{
if (!maskBuffer[vIdx])
{
finishedSum[batchSlot] = static_cast<int>(finishState.isFinished());
targetLogitsBatch[vIdx] = -MAX_T_VAL;
}
}
}
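As a reading aid, the following standalone host-side sketch (not part of the diff) mirrors the masking rule of `maskTargetLogitsKernel`: every vocab id that the preceding TopK/TopP pass did not select is forced to -inf, in the spirit of the Huggingface logits warpers referenced above. The concrete values, and the use of -1 as the boundary marker when `topK == 0`, are assumptions taken from the comments in this kernel.

```cpp
// Host-side sketch of the masking rule: keep only the ids selected by sampling,
// set every other logit to -inf. Values are illustrative only.
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    int const vocabSize = 8;
    std::vector<float> targetLogits = {0.1f, 2.0f, -0.5f, 1.2f, 0.3f, 0.9f, -1.0f, 0.4f};
    // Ids kept by the sampling pass (topK == 3 here); unused slots would hold -1.
    std::vector<int> outputIdsAfterSampling = {1, 3, 5, -1, -1, -1, -1, -1};
    int const tokensToMask = 3; // runtimeTopK; when 0, scan for the -1 boundary instead

    std::vector<bool> keep(vocabSize, false);
    for (int i = 0; i < tokensToMask; ++i)
    {
        keep[outputIdsAfterSampling[i]] = true;
    }
    for (int v = 0; v < vocabSize; ++v)
    {
        if (!keep[v])
        {
            targetLogits[v] = -std::numeric_limits<float>::infinity();
        }
    }
    for (float x : targetLogits)
    {
        std::printf("%g ", x);
    }
    std::printf("\n");
    return 0;
}
```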
} // namespace
void invokeAcceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth, SizeType32 maxSeqLen,
SizeType32 maxDraftTokens, cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
dim3 block(min(1024, batchSize));
dim3 grid(1);
acceptDraftTokensByIds<<<grid, block, 0, stream>>>(draftIds, targetIds, contextLengths, numsDraftTokens,
sequenceLengths, finished, finishedFinal, finishedSum, batchSlots, batchSize, maxBatchSize, maxSeqLen,
maxDraftTokens);
}
namespace
{
template <typename T>
__global__ void acceptDraftTokensByLogitsKernel(T const* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
FinishedState* finished, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 maxBatchSize, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSize,
bool randomThreshold, float constantThreshold)
__global__ void acceptDraftTokensKernel(T const* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
bool const* batchUseDraftLogits, TokenIdType const* draftIds, FinishedState const* finishedInput,
FinishedState* finishedOutput, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 maxDraftTokens,
SizeType32 beamWidth, SizeType32 vocabSize, bool randomThreshold, float constantThreshold, SizeType32 step,
bool* batchIsAccepted, SizeType32* targetOutputIds)
{
auto const bid = blockIdx.x;
auto const draftTokenIdx = blockIdx.y;
auto const draftTokenIdx = step;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
auto const tid = static_cast<SizeType32>(threadIdx.x);
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
auto const useDraftLogits = batchUseDraftLogits[batchSlotBeamWidth];
if (draftTokenIdx >= numDraftTokens)
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding())
{
if (tid == 0)
{
batchIsAccepted[batchSlot] = true;
finishedOutput[batchSlot].setSkipDecoding();
}
return;
}
auto const logitsOffset = (batchSlot * maxDraftTokens + draftTokenIdx) * beamWidth * vocabSize;
auto const draftProbsBatch = draftProbs + logitsOffset;
auto const targetProbsBatch = targetProbs + logitsOffset;
auto const vocabSizePadded = static_cast<SizeType32>((vocabSize + blockDim.x - 1) / blockDim.x) * blockDim.x;
auto const targetProbsBatch = targetProbs + (batchIdx * beamWidth * vocabSize);
struct Candidate candidate;
__shared__ float threshold;
if (threadIdx.x == 0)
__shared__ bool isAccepted;
__shared__ T sSumVal;
if (tid == 0)
{
threshold = randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold;
}
__syncthreads();
for (auto vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSizePadded;
vIdx += static_cast<SizeType32>(blockDim.x))
{
bool const pred = vIdx < vocabSize;
auto const targetProb = pred ? static_cast<float>(targetProbsBatch[vIdx]) : 1.f;
auto const draftProb = pred ? static_cast<float>(draftProbsBatch[vIdx]) : 0.f;
if (draftProb > candidate.maxProb)
if (draftTokenIdx < numDraftTokens)
{
candidate.maxProb = draftProb;
candidate.rateQP = pred ? targetProb / draftProb : 0.f;
}
}
__syncthreads();
typedef cub::BlockReduce<Candidate, 1024> BlockReduce;
__shared__ typename BlockReduce::TempStorage reduce_buffer;
Candidate candidate_global = BlockReduce(reduce_buffer).Reduce(candidate, reduce_op);
__syncthreads();
if (threadIdx.x == 0)
{
finished[draftTokenIdx * maxBatchSize * beamWidth + batchSlotBeamWidth]
= candidate_global.rateQP < threshold ? FinishedState::skipDecoding() : FinishedState::empty();
}
}
template <typename T>
__global__ void correctAcceptedStatesAndLogits(T const* draftProbs, T* targetProbs, T** targetLogits,
SizeType32 const* numsDraftTokens, FinishedState* finished, SizeType32 const* batchSlots, SizeType32 batchSize,
SizeType32 maxBatchSize, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSize)
{
auto const bid = blockIdx.x;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
__shared__ SizeType32 numAcceptedTokens;
if (threadIdx.x == 0)
{
numAcceptedTokens = numDraftTokens;
bool cummulativeSkipDecoding = false;
for (SizeType32 ti = 0; ti < numDraftTokens + 1; ++ti)
{
auto& finishedState = finished[ti * maxBatchSize * beamWidth + batchSlotBeamWidth];
bool localSkipDecoding = finishedState.isSkipDecoding();
if (cummulativeSkipDecoding == false && localSkipDecoding == true)
auto const draftOutputTokenId = draftIds[batchSlot * maxDraftTokens + draftTokenIdx];
if (useDraftLogits)
{
numAcceptedTokens = ti;
float threshold = randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold;
auto const targetProb = static_cast<float>(targetProbsBatch[draftOutputTokenId]);
auto const draftProb = static_cast<float>(draftProbsBatch[draftOutputTokenId]);
auto rateQP = targetProb / draftProb;
if (rateQP < threshold)
{
isAccepted = false;
finishedOutput[batchSlot].setSkipDecoding();
}
else
{
isAccepted = true;
}
}
else
{
// Check if draft tokens are the same as target tokens
isAccepted = targetOutputIds[batchSlot] == draftOutputTokenId;
if (!isAccepted)
{
finishedOutput[batchSlot].setSkipDecoding();
}
}
finishedState = cummulativeSkipDecoding ? FinishedState::skipDecoding() : FinishedState::empty();
cummulativeSkipDecoding |= localSkipDecoding;
}
else
{
isAccepted = false;
finishedOutput[batchSlot].setSkipDecoding();
}
batchIsAccepted[batchSlot] = isAccepted;
}
__syncthreads();
if (numAcceptedTokens < numDraftTokens)
if (!isAccepted)
{
auto const logitsIdx = (batchSlot * maxDraftTokens + numAcceptedTokens) * beamWidth * vocabSize;
auto const draftProbBatch = draftProbs + logitsIdx;
auto targetProbBatch = targetProbs + logitsIdx;
auto targetLogitsBatch = targetLogits[bid] + numAcceptedTokens * beamWidth * vocabSize;
float sumProbs = 0.f;
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
vIdx += static_cast<SizeType32>(blockDim.x))
T const zeroVal = static_cast<T>(0.0f);
T sumVal = zeroVal;
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
{
auto const correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
sumProbs += correctedProb;
targetProbBatch[vIdx] = correctedProb;
targetProbsBatch[vIdx]
-= (draftTokenIdx < numDraftTokens && useDraftLogits) ? draftProbsBatch[vIdx] : zeroVal;
targetProbsBatch[vIdx] = targetProbsBatch[vIdx] >= zeroVal ? targetProbsBatch[vIdx] : zeroVal;
sumVal += targetProbsBatch[vIdx];
}
__shared__ float sumProbsShared;
sumProbs = blockReduceSum<float>((float) sumProbs);
if (threadIdx.x == 0)
sumVal = blockReduceSum<T>(sumVal);
if (tid == 0)
{
sumProbsShared = max(sumProbs, 1e-6f);
sSumVal = sumVal;
}
__syncthreads();
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
vIdx += static_cast<SizeType32>(blockDim.x))
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
{
auto const correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
targetLogitsBatch[vIdx] = __logf(correctedNormProb / (1.f - correctedNormProb));
targetProbsBatch[vIdx] /= sSumVal;
}
}
}
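For clarity, here is a small standalone sketch (not part of the diff) of the acceptance test and the residual renormalisation that `acceptDraftTokensKernel` performs for one request when draft logits are used: the draft token is accepted if P(x)/Q(x) reaches the threshold, and on rejection the target distribution is replaced by norm(max(0, P - Q)) so the bonus token can later be sampled from it. The probability values and the threshold are made up for illustration.

```cpp
// Host-side sketch of the per-token acceptance test and residual correction.
// Distributions and the threshold are illustrative only.
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<float> draftProbs = {0.70f, 0.20f, 0.10f};  // Q(x)
    std::vector<float> targetProbs = {0.50f, 0.30f, 0.20f}; // P(x)
    int const draftTokenId = 0;
    float const threshold = 0.8f; // curand_uniform(...) or constantThreshold

    bool const accepted = targetProbs[draftTokenId] / draftProbs[draftTokenId] >= threshold;
    if (!accepted)
    {
        // P'(x) = norm(max(0, P(x) - Q(x))); the bonus token is then drawn from P'.
        float sum = 0.f;
        for (std::size_t v = 0; v < targetProbs.size(); ++v)
        {
            targetProbs[v] = std::max(targetProbs[v] - draftProbs[v], 0.f);
            sum += targetProbs[v];
        }
        for (float& p : targetProbs)
        {
            p /= sum;
        }
    }
    std::printf("accepted=%d residual={%g, %g, %g}\n", static_cast<int>(accepted), targetProbs[0], targetProbs[1],
        targetProbs[2]);
    return 0;
}
```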
__global__ void forwardAcceptedTokensKernel(SizeType32 batchSize, SizeType32 const* batchSlots, bool* batchIsAccepted,
SizeType32* sequenceLengths, TokenIdType const* draftIds, TokenIdType** idsPtrs, SizeType32 step,
SizeType32 maxDraftTokens, TokenIdType const* endIds, FinishedState* finishedOutput)
{
auto index = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
for (SizeType32 bi = index; bi < batchSize; bi += static_cast<SizeType32>(gridDim.x * blockDim.x))
{
auto const batchSlot = batchSlots[bi];
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding())
{
auto const curSeqLen = sequenceLengths[batchSlot];
auto const draftTokenIdx = step;
auto const draftOutputTokenId = draftIds[batchSlot * maxDraftTokens + draftTokenIdx];
auto* outputIdsRequestPtr = idsPtrs[batchSlot];
auto const outIdx = curSeqLen;
outputIdsRequestPtr[outIdx] = draftOutputTokenId;
if (outputIdsRequestPtr[outIdx] == endIds[batchSlot])
{
finishedOutput[batchSlot].setFinishedEOS();
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be
// outputted
}
else
{
// We don't need to set output finished state as it is assumed to be in non finished state
sequenceLengths[batchSlot] += 1;
}
}
}
} // namespace
} // namespace
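The next standalone sketch (not part of the diff) shows, for one accepted request, the bookkeeping done by `forwardAcceptedTokensKernel`: the accepted draft token is written at the current sequence length, and the length only advances when the written token is not the end id. The buffer size and token values are illustrative.

```cpp
// Host-side sketch of forwarding one accepted draft token for a single request.
#include <cstdio>
#include <vector>

int main()
{
    int const endId = 2;
    int sequenceLength = 5;
    int const draftToken = 7;           // draftIds[batchSlot * maxDraftTokens + step]
    std::vector<int> outputIds(16, -1); // the row pointed to by idsPtrs[batchSlot]

    outputIds[sequenceLength] = draftToken; // write at the current position
    if (draftToken == endId)
    {
        std::printf("EOS generated: mark finished, keep length at %d\n", sequenceLength);
    }
    else
    {
        ++sequenceLength; // advance only for non-EOS tokens
    }
    std::printf("sequenceLength = %d\n", sequenceLength);
    return 0;
}
```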
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
float constantThreshold, cudaStream_t stream)
void invokeMaskTargetLogits(SizeType32 batchSize, T* targetLogits, SizeType32 const* batchSlots, SizeType32 beamWidth,
SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(beamWidth == 1);
{
invokeAddBiasSoftMax(draftLogits, static_cast<T**>(nullptr), draftProbs, static_cast<T*>(nullptr), nullptr,
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
/* skip softmax */ false,
/* batchSlotLogits */ true, stream);
invokeAddBiasSoftMax(static_cast<T*>(nullptr), targetLogits, targetProbs, static_cast<T*>(nullptr), nullptr,
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
/* skip softmax */ false,
/* batchSlotLogits */ true, stream);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth, maxDraftTokens);
acceptDraftTokensByLogitsKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens, finished,
curandState, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded,
randomThreshold, constantThreshold);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(draftProbs, targetProbs, targetLogits,
numsDraftTokens, finished, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded);
maskTargetLogitsKernel<<<grid, block, 0, stream>>>(targetLogits, batchSlots, beamWidth, vocabSizePadded,
finishedInput, maxBatchSize, batchUseDraftLogits, outputIdsAfterSampling, targetOutputIds,
runtimeTopKDevicePtr, maskBuffer);
}
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template void acceptDraftTokensByLogits(float* draftLogits, float** targetLogits, float* draftProbs, float* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
float constantThreshold, cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, half* draftProbs, half* targetProbs,
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
float constantThreshold, cudaStream_t stream);
template <typename T>
void invokeAcceptDraftTokens(SizeType32 batchSize, T* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
bool const* batchUseDraftLogits, TokenIdType const* draftIds, FinishedState const* finishedInput,
FinishedState* finishedOutput, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 maxDraftTokens,
SizeType32 beamWidth, SizeType32 vocabSizePadded, bool randomThreshold, float constantThreshold, SizeType32 step,
bool* batchIsAccepted, SizeType32* targetOutputIds, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(beamWidth == 1);
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth);
acceptDraftTokensKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens,
batchUseDraftLogits, draftIds, finishedInput, finishedOutput, curandState, batchSlots, maxDraftTokens,
beamWidth, vocabSizePadded, randomThreshold, constantThreshold, step, batchIsAccepted, targetOutputIds);
}
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template void invokeMaskTargetLogits(SizeType32 batchSize, float* targetLogits, SizeType32 const* batchSlots,
SizeType32 beamWidth, SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream);
template void invokeMaskTargetLogits(SizeType32 batchSize, half* targetLogits, SizeType32 const* batchSlots,
SizeType32 beamWidth, SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream);
template void invokeAcceptDraftTokens(SizeType32 batchSize, float* draftProbs, float* targetProbs,
SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, TokenIdType const* draftIds,
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSizePadded,
bool randomThreshold, float constantThreshold, SizeType32 step, bool* batchIsAccepted, SizeType32* targetOutputIds,
cudaStream_t stream);
template void invokeAcceptDraftTokens(SizeType32 batchSize, half* draftProbs, half* targetProbs,
SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, TokenIdType const* draftIds,
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
SizeType32 const* batchSlots, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSizePadded,
bool randomThreshold, float constantThreshold, SizeType32 step, bool* batchIsAccepted, SizeType32* targetOutputIds,
cudaStream_t stream);
void invokeForwardAcceptedTokens(SizeType32 batchSize, SizeType32 const* batchSlots, bool* batchIsAccepted,
SizeType32* outputSequenceLengths, TokenIdType const* draftIds, TokenIdType** idsPtrs, SizeType32 step,
SizeType32 maxDraftTokens, TokenIdType const* endIds, FinishedState* finishedOutput, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
dim3 block(std::min(static_cast<uint32_t>(batchSize), 256u));
dim3 grid(divUp(static_cast<uint32_t>(batchSize), block.x));
forwardAcceptedTokensKernel<<<grid, block, 0, stream>>>(batchSize, batchSlots, batchIsAccepted,
outputSequenceLengths, draftIds, idsPtrs, step, maxDraftTokens, endIds, finishedOutput);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
} // namespace tensorrt_llm::kernels::speculative_decoding


@ -26,84 +26,77 @@
namespace tensorrt_llm::kernels::speculative_decoding
{
//! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens
//! for speculative decoding. Target token is accepted if targetToken == draftToken.
//! If number of accepted tokens N < maxDraftTokens, then function accepts N + 1 tokens of target model.
//! sequenceLengths, finishedSum and finishedFinal are modified accordingly.
//!
//! \param draftIds input buffer [batchSize, maxDraftTokens].
//! Indices of the draft tokens.
//! \param targetIds input buffer [batchSize, maxSeqLen]. Indices of the tokens decoded by the target model
//! \param contextLengths input buffer [batchSize]. Context lengths of the requests without draft tokens
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
//! \param sequenceLengths input/output buffer [batchSize] sequence lengths of the requests in batch
//! Modified in-place according to the accepted/rejected tokens
//! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration
//! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens
//! \param finishedSum output buffer [1] total number of requests in batch that finished the execution
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param maxSeqLen maximum sequence length
//! \param maxDraftTokens maximum number of draft tokens
//! \param stream stream
void invokeAcceptDraftTokensByIds(runtime::TokenIdType const* draftIds, runtime::TokenIdType const* targetIds,
runtime::SizeType32 const* contextLengths, runtime::SizeType32 const* numsDraftTokens,
runtime::SizeType32* sequenceLengths, FinishedState const* finished, FinishedState* finishedFinal,
runtime::SizeType32* finishedSum, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
runtime::SizeType32 maxBatchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen,
runtime::SizeType32 maxDraftTokens, cudaStream_t stream);
//! \brief Performs probabilistic acceptance of draft tokens based on their probability distributions.
//! Corrects targetLogits for the next to the last accepted token
//! \brief Accepts or rejects draft tokens based on their probability distributions or the equality of draft and target
//! tokens. Corrects targetLogits for the last accepted token
//! according to https://openreview.net/pdf?id=C9NEblP8vS
//!
//! \param draftLogits input/output buffer [draftTokens, batchSize, beamWidth, vocabSize].
//! Initially contains token logits of the draft model.
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
//! Vector of pointers to the logits.
//! Initially contains token logits of the target model.
//! It is modified in-place for next to the last accepted token such as
//! P'(x) = norm(max(0, P_{n+1}(x) - Q_{n+1}(x))), where N < maxDraftTokens is number of accepted tokens.
//! \param batchSize current batch size
//! \param draftProbs output buffer [maxDraftTokens, batchSize, beamWidth, vocabSize].
//! Workspace buffer for token probabilities of the draft model.
//! \param targetProbs output buffer [maxDraftTokens+1, batchSize, beamWidth, vocabSize].
//! Workspace buffer for token probabilities of the target model.
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
//! \param finished output buffer [draftTokens, batchSize, beamWidth].
//! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted
//! \param curandState input buffer [batchSize]. Curand states properly
//! initialized using invokeCurandInitialize per request.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
//! \param batchUseDraftLogits input buffer [batchSize]. Acceptance logic using draft logits or not, per request
//! \param draftIds input buffer [batchSize, draftTokens]. Pointer to draft token ids.
//! \param finishedInput input buffer [batchSize, beamWidth].
//! \param finishedOutput output buffer [batchSize, beamWidth]. At each step sets SKIP_DECODING if token is not
//! accepted.
//! \param curandState input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize
//! per request.
//! \param batchSlots input buffer [batchSize], address map from local index to global index [0, batchSize] ->
//! [0, maxBatchSize].
//! \param maxDraftTokens maximum number of draft tokens
//! \param beamWidth beam width (only beamWidth == 1 supported)
//! \param vocabSizePadded padded vocab size
//! \param randomThreshold True if use uniformly sampled threshold for token acceptance
//! \param constantThreshold threshold used to accept tokens if randomThreshold is false
//! \param step The current step of decoding (draft token id index)
//! \param batchIsAccepted output buffer [batchSize]. Stores acceptance result for multinomial sampling later or
//! forwarding next step.
//! \param targetOutputIds input/output buffer [batchSize]. Stores target sampling output ids for acceptById
//! logics.
//! \param stream stream
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
runtime::SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
runtime::SizeType32 beamWidth, runtime::SizeType32 vocabSize, runtime::SizeType32 vocabSizePadded,
runtime::SizeType32 maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
void invokeAcceptDraftTokens(runtime::SizeType32 batchSize, T* draftProbs, T* targetProbs,
runtime::SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, runtime::TokenIdType const* draftIds,
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
runtime::SizeType32 const* batchSlots, runtime::SizeType32 maxDraftTokens, runtime::SizeType32 beamWidth,
runtime::SizeType32 vocabSizePadded, bool randomThreshold, float constantThreshold, runtime::SizeType32 step,
bool* batchIsAccepted, runtime::SizeType32* targetOutputIds, cudaStream_t stream);
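For reference, with a uniformly drawn threshold the acceptance rule above reduces to the standard speculative-sampling scheme from the linked paper; a hedged summary in LaTeX, where P is the target distribution, Q the draft distribution and x the current draft token (the code also allows a fixed constantThreshold instead of the random draw):

```latex
% Summary of the acceptance rule and of the corrected (residual) distribution
% used for the bonus token when the draft token is rejected.
\begin{align}
  p_{\text{accept}}(x) &= \min\!\left(1, \frac{P(x)}{Q(x)}\right), \\
  P'(x) &= \frac{\max\bigl(0,\, P(x) - Q(x)\bigr)}{\sum_{x'} \max\bigl(0,\, P(x') - Q(x')\bigr)}.
\end{align}
```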
struct Candidate // Hold probability maximum and rate of target / dfraft, used in `acceptDraftTokensByLogits`
{
float maxProb{0.f};
float rateQP{0.f};
};
//! \brief Mask the target logits with -inf for unselected topK/topP token ids.
//! according to
//! https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/generation/utils.py#L4064
//!
//! \param batchSize current batch size
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
//! Vector of pointers to the logits. (beamWidth == 1)
//! Initially contains token logits of the target model.
//! \param batchSlots input buffer [batchSize], address map from local index to global index [0, batchSize] ->
//! [0, maxBatchSize].
//! \param beamWidth beam width (only beamWidth == 1 supported)
//! \param vocabSizePadded padded vocab size
//! \param finishedInput input buffer [batchSize, beamWidth].
//! \param maxBatchSize maximum batch size
//! \param batchUseDraftLogits input buffer [batchSize]. Acceptance logic using draft logits or not, per request
//! \param outputIdsAfterSampling input buffer [batchSize, vocabSize]. Stores all selected IDs from sampling for
//! masking.
//! \param targetOutputIds input/output buffer [batchSize]. Stores target sampling output ids for acceptById
//! logics.
//! \param runtimeTopKDevicePtr input buffer [batchSize] the topks in sampling step, for porting topK ids out.
//! \param maskBuffer input buffer [batchSize, vocabSize] for masking calculation (index value to position).
//! \param stream stream
template <typename T>
void invokeMaskTargetLogits(runtime::SizeType32 batchSize, T* targetLogits, runtime::SizeType32 const* batchSlots,
runtime::SizeType32 beamWidth, runtime::SizeType32 vocabSizePadded, FinishedState const* finishedInput,
runtime::SizeType32 maxBatchSize, bool const* batchUseDraftLogits, runtime::SizeType32* outputIdsAfterSampling,
runtime::SizeType32* targetOutputIds, runtime::SizeType32* runtimeTopKDevicePtr, bool* maskBuffer,
cudaStream_t stream);
__device__ __forceinline__ Candidate reduce_op(Candidate const& a, Candidate const& b)
{
// Max-reduce operator of Candidate
return (a.maxProb > b.maxProb) ? a : b;
}
void invokeForwardAcceptedTokens(runtime::SizeType32 batchSize, runtime::SizeType32 const* batchSlots,
bool* batchIsAccepted, runtime::SizeType32* outputSequenceLengths, runtime::TokenIdType const* draftIds,
runtime::TokenIdType** idsPtrs, runtime::SizeType32 step, runtime::SizeType32 maxDraftTokens,
runtime::TokenIdType const* endIds, FinishedState* finishedOutput, cudaStream_t stream);
} // namespace tensorrt_llm::kernels::speculative_decoding


@ -73,7 +73,7 @@ void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishe
runtime::SizeType32* numNewTokens, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
runtime::SizeType32 beamWidth, cudaStream_t stream);
//! \brief Sets finished states based on the endIds and ajusts sequence length to length before the first EOS token.
//! \brief Sets finished states based on the endIds and adjusts sequence length to length before the first EOS token.
//! Does not support beamWidth > 1 for now.
//!
//! \param outputIds input buffer [maxBatchSize][beamWidth, maxSeqLen].


@ -19,6 +19,7 @@
#include "tensorrt_llm/layers/beamSearchLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
#include "tensorrt_llm/layers/externalDraftTokensLayer.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/layers/lookaheadDecodingLayer.h"
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
@ -96,6 +97,10 @@ DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai
{
mDecodingLayer = std::make_unique<ExplicitDraftTokensLayer<T>>(decoderDomain, mBufferManager);
}
else if (mDecodingMode.isExternalDraftTokens())
{
mDecodingLayer = std::make_unique<ExternalDraftTokensLayer<T>>(mDecodingMode, decoderDomain, mBufferManager);
}
else
{
TLLM_CHECK_WITH_INFO(false,
@ -144,6 +149,12 @@ void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorC
beamWidth == 1, "Decoding mode is ExplicitDraftTokens, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams, workspace);
}
else if (mDecodingMode.isExternalDraftTokens())
{
TLLM_CHECK_WITH_INFO(
beamWidth == 1, "Decoding mode is external draft tokens, but beamWidth != 1 (%d != 1)", beamWidth);
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams, workspace);
}
else
{
TLLM_CHECK_WITH_INFO(false,
@ -249,6 +260,45 @@ std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInp
preparedInputs = baseInputs;
preparedOutputs = baseOutputs;
}
else if (mDecodingMode.isExternalDraftTokens())
{
auto externalDraftTokenParams = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto const ite = externalDraftTokenParams->ite;
auto const step = externalDraftTokenParams->step;
auto const localBatchSize = static_cast<int64_t>(externalDraftTokenParams->localBatchSize);
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
"Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
// Batched sampling is supported, so all sequences are computed at once.
TensorConstPtr logitsSlice = ITensor::slice(*externalDraftTokenParams->logits, 0, localBatchSize);
TensorConstPtr endIdSlice = ITensor::slice(endIds, 0, localBatchSize);
auto decodeInputs = std::make_shared<ExternalDraftTokensInputs>(
endIdSlice, externalDraftTokenParams->batchSlots, step, ite, localBatchSize);
decodeInputs->finished = externalDraftTokenParams->finished;
decodeInputs->logits = logitsSlice;
if (externalDraftTokenParams->inputLengths)
{
auto& inputLengths = externalDraftTokenParams->inputLengths.value();
decodeInputs->inputLengths = ITensor::slice(inputLengths, 0, localBatchSize);
}
decodeInputs->draftLogits = externalDraftTokenParams->draftLogits;
decodeInputs->draftProbs = externalDraftTokenParams->draftProbs;
decodeInputs->targetProbs = externalDraftTokenParams->targetProbs;
decodeInputs->numDraftTokens = externalDraftTokenParams->numDraftTokens;
decodeInputs->draftTokenIds = externalDraftTokenParams->draftTokenIds;
decodeInputs->constantThreshold = externalDraftTokenParams->constantThreshold;
decodeInputs->useRandomAcceptanceThreshold = externalDraftTokenParams->useRandomAcceptanceThreshold;
decodeInputs->step = externalDraftTokenParams->step;
decodeInputs->useDraftLogits = externalDraftTokenParams->useDraftLogits;
preparedInputs = decodeInputs;
preparedOutputs = baseOutputs;
}
else
{
TLLM_CHECK_WITH_INFO(false,


@ -45,7 +45,7 @@ public:
std::shared_ptr<BaseDecodingInputs> const& inputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
//! \brief Calls forwardSync of configired decoding layer.
//! \brief Calls forwardSync of configured decoding layer.
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;


@ -212,6 +212,13 @@ struct LookaheadSetupParams : public DecodingSetupParams
TensorPtr attentionPackedMasks;
};
class ExternalDraftTokensSetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> runtimeTopP; // [1] or [setupBatchSize] on cpu
};
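As a usage illustration, here is a minimal sketch (not part of the diff) of how the optional per-request TopK/TopP values carried by `ExternalDraftTokensSetupParams` might be filled on the host; a simplified stand-in struct keeps the snippet self-contained, and the concrete values are arbitrary.

```cpp
// Simplified stand-in for the setup parameters: one K and one P per request.
#include <cstdint>
#include <optional>
#include <vector>

struct SetupParamsSketch
{
    std::optional<std::vector<std::int32_t>> runtimeTopK; // [1] or [setupBatchSize]
    std::optional<std::vector<float>> runtimeTopP;        // [1] or [setupBatchSize]
};

int main()
{
    SetupParamsSketch params;
    params.runtimeTopK = std::vector<std::int32_t>{1, 4, 0, 8};   // K == 0 falls through to TopP
    params.runtimeTopP = std::vector<float>{0.f, 0.f, 0.9f, 0.f}; // only request 2 samples with TopP
    return 0;
}
```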
class BaseDecodingInputs
{
public:
@ -331,6 +338,33 @@ public:
bool probsComputed{};
};
class ExternalDraftTokensInputs : public DecodingInputs
{
public:
explicit ExternalDraftTokensInputs(TensorConstPtr endIds, TensorConstPtr batchSlots, runtime::SizeType32 step,
runtime::SizeType32 ite, runtime::SizeType32 localBatchSize)
: DecodingInputs{std::move(endIds), std::move(batchSlots), step, ite, localBatchSize}
{
}
TensorPtr draftLogits;
TensorPtr draftProbs;
TensorPtr targetProbs;
TensorPtr numDraftTokens;
TensorPtr draftTokenIds;
TensorPtr useDraftLogits;
runtime::SizeType32 step;
float constantThreshold;
bool useRandomAcceptanceThreshold;
//! optional parameters
//! [localBatchSize]
curandState_t* curandStates{};
//! Flag to mark that logits tensor contains probabilities
bool probsComputed{};
};
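Similarly, a minimal sketch (not part of the diff) of the per-step data a caller supplies through `ExternalDraftTokensInputs`. The real fields are device tensors (`TensorPtr`), so a simplified host-side stand-in is used here purely to show which pieces belong together; all shapes and values are assumptions.

```cpp
// Simplified host-side stand-in for the per-step inputs of the external
// draft-token path; the real class stores device tensors.
#include <cstdint>
#include <vector>

struct ExternalDraftTokensInputsSketch
{
    std::vector<float> draftLogits;           // draft-model logits for the drafted positions
    std::vector<float> draftProbs;            // workspace for the draft distribution
    std::vector<float> targetProbs;           // workspace for the target distribution
    std::vector<std::int32_t> numDraftTokens; // [batchSize]
    std::vector<std::int32_t> draftTokenIds;  // [batchSize, maxDraftTokens]
    std::vector<bool> useDraftLogits;         // [batchSize]
    std::int32_t step{0};                     // index of the draft token being verified
    float constantThreshold{0.5f};
    bool useRandomAcceptanceThreshold{true};
};

int main()
{
    ExternalDraftTokensInputsSketch in;
    in.numDraftTokens = {2};     // one request with two draft tokens
    in.draftTokenIds = {11, 42}; // the drafted ids for that request
    in.useDraftLogits = {true};  // use logit-based acceptance rather than id equality
    in.step = 0;                 // the layer is invoked once per draft position
    return 0;
}
```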
// Medusa inputs
class MedusaDecodingInputs : public DecodingInputs
{
@ -477,7 +511,7 @@ public:
//! {c'} is always accepted and {x', z'} is supposed to be accepted.
//! The accepted tokens [c', x', z'] is saved in `outputIds` in-place, starting from `sequenceLength`.
//! The `acceptedLength` is 3, and the accepted draft tokens length is 2.
//! `sequenceLength` is also increaded by `acceptedLength` in-place.
//! `sequenceLength` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!

View File

@ -0,0 +1,514 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "externalDraftTokensLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include "tensorrt_llm/kernels/samplingTopPKernels.h"
#include "tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.h"
#include "tensorrt_llm/layers/defaultDecodingParams.h"
#include "tensorrt_llm/layers/layerUtils.h"
#include "tensorrt_llm/runtime/runtimeKernels.h"
#include <algorithm>
namespace tksd = tensorrt_llm::kernels::speculative_decoding;
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::layers
{
template <typename T>
ExternalDraftTokensLayer<T>::ExternalDraftTokensLayer(executor::DecodingMode const& mode,
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
: BaseLayer(decoderDomain, bufferManager)
, mDecodingMode(mode)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "ExternalDraftTokensLayer does not support Beam search mode");
allocateBuffer(decoderDomain.getBatchSize());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExternalDraftTokensLayer<T>::allocateBuffer(SizeType32 batchSize)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// top k workspace size
auto workspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// top p workspace size
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// multinomial (top p == 1) workspace size
workspaceSize = getTopPWorkspaceSize<float>(batchSize, mDecoderDomain.getVocabSizePadded());
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
// batchsize here is maxBatchSize
auto const batchSizeShape = ITensor::makeShape({batchSize});
mCurandStatesDevice
= mBufferManager->gpu(ITensor::makeShape({batchSize, sizeof(curandState_t)}), TRTDataType<int8_t>::value);
mBatchIsAccepted = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
mRuntimeMultinomialDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
// host buffers.
mSkipTopKDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
mSkipTopKDecodeHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<bool>::value);
mSkipTopPDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
mSkipTopPDecodeHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<bool>::value);
auto skipTopPDecodeHostRange = BufferRange<bool>(*mSkipTopPDecodeHost);
std::fill(skipTopPDecodeHostRange.begin(), skipTopPDecodeHostRange.end(), true);
mOutputIdsAfterSampling = mBufferManager->gpu(
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<TokenIdType>::value);
mTargetOutputIds = mBufferManager->gpu(ITensor::makeShape({batchSize}), TRTDataType<TokenIdType>::value);
mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
mRuntimeTopPForTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
mRuntimeTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
mInitialTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
mMaskBuffer = mBufferManager->gpu(
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<bool>::value);
mSetupWorkspaceSize = std::max({mBatchIsAccepted->getSizeInBytes(), mRuntimeMultinomialDevice->getSizeInBytes(),
mSkipTopKDecodeDevice->getSizeInBytes(), mSkipTopPDecodeDevice->getSizeInBytes(),
mOutputIdsAfterSampling->getSizeInBytes(), mTargetOutputIds->getSizeInBytes(),
mRuntimeTopKDevice->getSizeInBytes(), mRuntimeTopPForTopKDevice->getSizeInBytes(),
mRuntimeTopPDevice->getSizeInBytes(), mInitialTopPDevice->getSizeInBytes(), mMaskBuffer->getSizeInBytes()});
mTargetLogits = mBufferManager->gpu(
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<T>::value);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
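A small standalone sketch (not part of the diff) of the sizing pattern used in `allocateBuffer()`: one shared scratch buffer is sized as the maximum of what the TopK, TopP and multinomial (TopP == 1) kernels would each need, so a single allocation can back whichever kernel runs for a given step. The `fake*WorkspaceSize` helpers are stand-ins, not the real TensorRT-LLM functions.

```cpp
// Sketch of sizing one shared workspace as the max over candidate kernels.
#include <algorithm>
#include <cstddef>
#include <cstdio>

// Stand-ins for getTopKWorkspaceSize / getTopPWorkspaceSize.
static std::size_t fakeTopKWorkspaceSize(int batchSize, int vocabSizePadded)
{
    return static_cast<std::size_t>(batchSize) * vocabSizePadded * 2;
}

static std::size_t fakeTopPWorkspaceSize(int batchSize, int vocabSizePadded)
{
    return static_cast<std::size_t>(batchSize) * vocabSizePadded * 3;
}

int main()
{
    int const batchSize = 8;
    int const vocabSizePadded = 32000;
    std::size_t workspace = 0;
    workspace = std::max(workspace, fakeTopKWorkspaceSize(batchSize, vocabSizePadded)); // top k
    workspace = std::max(workspace, fakeTopPWorkspaceSize(batchSize, vocabSizePadded)); // top p
    workspace = std::max(workspace, fakeTopPWorkspaceSize(batchSize, vocabSizePadded)); // multinomial (top p == 1)
    std::printf("shared workspace bytes: %zu\n", workspace);
    return 0;
}
```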
template <typename T>
void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
std::shared_ptr<BaseSetupParams> const& baseSetupParams,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto setupParams = std::dynamic_pointer_cast<ExternalDraftTokensSetupParams>(baseSetupParams);
workspace->initializeDeviceCurandStates(
setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);
auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
auto& runtimeMultinomialDeviceTensor = const_cast<ITensor&>(*mRuntimeMultinomialDevice);
tensorrt_llm::runtime::kernels::invokeFill(runtimeMultinomialDeviceTensor, 1.0f, mBufferManager->getStream());
auto* runtimeTopKDevicePtr = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
// Prepare runtime top K
auto constexpr defaultTopK = 1u;
auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector<SizeType32>(batchSize, defaultTopK));
auto const runtimeTopKSize = runtimeTopK.size();
for (auto& topK : runtimeTopK)
{
if (topK < 0 || topK > TOP_K_MAX)
{
TLLM_LOG_WARNING(
"TopK (%d) is larger than max supported number (%d). Clip to max supported number.", topK, TOP_K_MAX);
topK = std::clamp(topK, 0, static_cast<SizeType32>(TOP_K_MAX));
}
}
if (runtimeTopKSize > 1)
{
TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize));
DecodingLayerWorkspace::copyToWorkspace<SizeType32>(
*this->mBufferManager, runtimeTopK, workspace->getWorkspaceDeviceBuffer());
auto* setupWorkspaceDevicePtr = workspace->getWorkspaceDevicePtrAs<SizeType32>();
// fill top ks into runtimeTopKDevice
invokeScatterDecodingParams(
setupWorkspaceDevicePtr, runtimeTopKDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
}
// FIXME(nkorobov): monotonically growing
auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK));
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
auto runtimeTopP = setupParams->runtimeTopP.value_or(std::vector<float>{});
auto const runtimeTopPSize = runtimeTopP.size();
auto* runtimeTopPForTopKDevicePtr = bufferCastOrNull<float>(mRuntimeTopPForTopKDevice);
auto* runtimeTopPDevicePtr = bufferCastOrNull<float>(mRuntimeTopPDevice);
auto* skipTopPDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
// if no top P, fill topP skip decode to true
if (runtimeTopPSize == 0)
{
auto const* batchSlotsPtr = bufferCast<SizeType32>(*batchSlots);
for (SizeType32 bi = 0; bi < batchSize; ++bi)
{
auto const bid = batchSlotsPtr[bi];
skipTopPDecodeHostPtr[bid] = true;
}
auto skipTopPDecodeHostSlice = IBuffer::slice(mSkipTopPDecodeHost, 0, mDecoderDomain.getBatchSize());
mBufferManager->copy(*skipTopPDecodeHostSlice, *mSkipTopPDecodeDevice);
}
else
{
for (auto& topP : runtimeTopP)
{
if (topP < 0.f || topP > 1.0f)
{
TLLM_LOG_WARNING("TopP (%f) is out of range ([0.0, 1.0f]). Clip to closest number.", topP);
topP = std::clamp(topP, 0.f, 1.f);
}
}
if (runtimeTopPSize > 1)
{
TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize,
fmtstr("runtimeTopP.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopP.size(), batchSize));
DecodingLayerWorkspace::copyToWorkspace<float>(
*this->mBufferManager, runtimeTopP, workspace->getWorkspaceDeviceBuffer());
auto* setupWorkspaceDevicePtr = workspace->getWorkspaceDevicePtrAs<float>();
// fill runtime top p device for top k kernel
invokeScatterDecodingParams(
setupWorkspaceDevicePtr, runtimeTopPForTopKDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
// fill runtime top p device for top p kernel
invokeScatterDecodingParams(
setupWorkspaceDevicePtr, runtimeTopPDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
}
}
// if no topP, default topP is 0.0f, but in invokeSetupTopKRuntimeArgs, it gets set to 1.0f if k > 0
auto const topP = (runtimeTopPSize == 0) ? DefaultDecodingParams::getTopP() : runtimeTopP.front();
auto* skipTopKDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
{
dim3 block(std::min(static_cast<uint32_t>(batchSize), 256u));
dim3 grid(divUp(static_cast<uint32_t>(batchSize), block.x));
// support topK up to TOP_K_MAX.
invokeSetupTopKRuntimeArgs(batchSize, curMaxTopK, runtimeTopKDevicePtr, runtimeTopKSize, topP,
runtimeTopPForTopKDevicePtr, runtimeTopPSize, skipTopKDecodeDevicePtr, batchSlotsDevicePtr, getStream());
}
auto const skipTopKHostDecodeDeviceSlice = ITensor::slice(mSkipTopKDecodeDevice, 0, mDecoderDomain.getBatchSize());
auto skipTopKDecodeHostSlice = ITensor::slice(mSkipTopKDecodeHost, 0, mDecoderDomain.getBatchSize());
mBufferManager->copy(*skipTopKHostDecodeDeviceSlice, *skipTopKDecodeHostSlice);
auto* skipTopPDecodeDevicePtr = bufferCast<bool>(*mSkipTopPDecodeDevice);
{
auto* initialTopPDevicePtr = bufferCast<float>(*mInitialTopPDevice);
invokeSetTopPRuntimeArgs(batchSize, curMaxTopK, runtimeTopKDevicePtr, runtimeTopKSize, topP,
runtimeTopPDevicePtr, runtimeTopPSize, skipTopPDecodeDevicePtr, batchSlotsDevicePtr, initialTopPDevicePtr,
getStream());
}
auto const skipTopPHostDecodeDeviceSlice = ITensor::slice(mSkipTopPDecodeDevice, 0, mDecoderDomain.getBatchSize());
auto skipTopPDecodeHostSlice = ITensor::slice(mSkipTopPDecodeHost, 0, mDecoderDomain.getBatchSize());
mBufferManager->copy(*skipTopPHostDecodeDeviceSlice, *skipTopPDecodeHostSlice);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
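To spell out the per-request routing implied by the skip flags computed in this setup, a tiny standalone sketch (not part of the diff): requests with `topK > 0`, or with both `topK == 0` and `topP == 0` (which falls back to `topK = 1`), take the TopK path, while requests with `topK == 0` and `topP > 0` take the TopP path. The values are illustrative.

```cpp
// Sketch of the TopK / TopP routing decision made per request.
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    std::vector<int> topK = {4, 0, 0};
    std::vector<float> topP = {0.f, 0.9f, 0.f};
    for (std::size_t i = 0; i < topK.size(); ++i)
    {
        bool const useTopP = (topK[i] == 0 && topP[i] > 0.f);
        // topK == 0 && topP == 0 falls back to topK = 1 inside the TopK path.
        std::printf("request %zu -> %s sampling\n", i, useTopP ? "TopP" : "TopK");
    }
    return 0;
}
```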
template <typename T>
void ExternalDraftTokensLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto const* endIds = bufferCast<TokenIdType>(*inputs->endIds);
FinishedState const* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState const*>(bufferCast<FinishedState::UnderlyingType>(*inputs->finished.value()))
: nullptr;
inputs->curandStates = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
inputs->probsComputed = true;
auto runtimeLogitsPtr = bufferCast<T>(*workspace->getDeviceRuntimeLogits());
auto logitsPtrsPtr = static_cast<T**>(nullptr);
auto biasPtr = static_cast<T*>(nullptr);
auto const* batchSlotsPtr = workspace->getDeviceBatchSlotsPtr();
mBufferManager->copy(runtimeLogitsPtr, *mTargetLogits);
invokeAddBiasSoftMax(runtimeLogitsPtr, logitsPtrsPtr, runtimeLogitsPtr, biasPtr, endIds, finishedInput,
batchSlotsPtr, batchSize, mDecoderDomain.getBatchSize(), /* bw */ 1, mDecoderDomain.getVocabSize(),
mDecoderDomain.getVocabSizePadded(), /*skipSoftMax*/ false, /* batchSlotLogits */ false, getStream());
auto const targetTokenIdsShape = (*outputs->outputIds).getShape();
// Fill the buffer for selected ids from sampling with zero. -1 will be set as a boundary if topP kernel is required
auto& outputIdsAfterSamplingTensor = const_cast<ITensor&>(*mOutputIdsAfterSampling);
tensorrt_llm::runtime::kernels::invokeFill(outputIdsAfterSamplingTensor, 0, mBufferManager->getStream());
// The logits from the target engine should go through sampling first.
// gptDecoderBatched.cpp calls the dynamic decoder step by step; at this point the dynamic decoder has already run
// PenaltyLayer and BanWordsLayer. For (TopK > 0) or (TopK == 0 && TopP == 0), we invoke the TopK sampling kernel. The
// same logic is implemented in SamplingLayer.cpp.
getAllTopKs(outputs, baseInputs, workspace);
// Only for (TopK == 0 && TopP > 0), we invoke TopP sampling
getAllTopPs(outputs, baseInputs, workspace);
// After all selected tokens have been filled into mOutputIdsAfterSampling by the TopK/TopP kernels, the token
// acceptance logic starts. First we mask the logits of unselected token ids to -inf, as in HF's TopK/TopP
// implementation. Then we compute the probabilities of the draft and target logits and run the acceptance logic.
acceptDraftTokens(outputs, baseInputs, workspace);
// If the draft token of a sequence is not accepted, multinomial sampling is required for the bonus token.
// Multinomial sampling is achieved through the TopP kernel with TopP = 1 over the target probabilities that were
// already corrected (residual distribution) by the acceptance step. The acceptance result of each request is used as
// skipDecode in the TopP kernel: if the token is accepted, no sampling is needed (early exit). Forwarding for the next
// step is also set in this kernel.
multinomialSampling(outputs, baseInputs, workspace);
// For sequences whose draft token was accepted, we simply advance by one step.
forwardAcceptedTokens(outputs, baseInputs, workspace);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
size_t ExternalDraftTokensLayer<T>::getWorkspaceSize() const noexcept
{
return std::max(mWorkspaceSize, mSetupWorkspaceSize);
}
template <typename T>
void ExternalDraftTokensLayer<T>::acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto const draftLogitsShape = (*inputs->draftLogits).getShape();
auto const maxBatchSize = mDecoderDomain.getBatchSize();
auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto constexpr beamWidth = 1;
FinishedState const* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
: nullptr;
FinishedState* finishedOutput = (outputs->finished)
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
: nullptr;
tksd::invokeMaskTargetLogits(batchSize, bufferCast<T>(*mTargetLogits), workspace->getDeviceBatchSlotsPtr(),
beamWidth, mDecoderDomain.getVocabSizePadded(), finishedInput, maxBatchSize,
bufferCast<bool>(*inputs->useDraftLogits), bufferCast<SizeType32>(*mOutputIdsAfterSampling),
bufferCast<SizeType32>(*mTargetOutputIds), bufferCastOrNull<SizeType32>(mRuntimeTopKDevice),
bufferCast<bool>(*mMaskBuffer), getStream());
if (inputs->step == 0)
{
invokeAddBiasSoftMax(bufferCast<T>(*inputs->draftLogits), static_cast<T**>(nullptr),
bufferCast<T>(*inputs->draftProbs), static_cast<T*>(nullptr), nullptr, finishedInput,
workspace->getDeviceBatchSlotsPtr(), batchSize, maxBatchSize, beamWidth * maxTokensPerStep,
mDecoderDomain.getVocabSize(), mDecoderDomain.getVocabSizePadded(),
/* skip softmax */ false,
/* batchSlotLogits */ true, getStream());
}
invokeAddBiasSoftMax(bufferCast<T>(*mTargetLogits), static_cast<T**>(nullptr), bufferCast<T>(*inputs->targetProbs),
static_cast<T*>(nullptr), nullptr, finishedInput, workspace->getDeviceBatchSlotsPtr(), batchSize, maxBatchSize,
beamWidth /* 1 */, mDecoderDomain.getVocabSize(), mDecoderDomain.getVocabSizePadded(),
/* skip softmax */ false,
/* batchSlotLogits */ false, getStream());
sync_check_cuda_error();
tksd::invokeAcceptDraftTokens(batchSize, bufferCast<T>(*inputs->draftProbs), bufferCast<T>(*inputs->targetProbs),
bufferCast<SizeType32>(*inputs->numDraftTokens), bufferCast<bool>(*inputs->useDraftLogits),
bufferCast<TokenIdType>(*inputs->draftTokenIds), finishedInput, finishedOutput, inputs->curandStates,
workspace->getDeviceBatchSlotsPtr(), maxTokensPerStep, beamWidth, mDecoderDomain.getVocabSizePadded(),
inputs->useRandomAcceptanceThreshold, inputs->constantThreshold, inputs->step,
bufferCast<bool>(*mBatchIsAccepted), bufferCast<SizeType32>(*mTargetOutputIds), getStream());
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExternalDraftTokensLayer<T>::multinomialSampling(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto probs = bufferCastOrNull<T>(inputs->targetProbs);
auto* sequenceLength = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);
FinishedState* finishedOutput = (outputs->finished)
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
: nullptr;
TopPSamplingKernelParams<T> params{};
params.probs = probs;
params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
params.workspace = workspace->getRawWorkspaceDevicePtr();
params.topPs = bufferCastOrNull<float>(mRuntimeMultinomialDevice);
params.sequenceLength = sequenceLength;
params.endIds = endIds;
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
params.finishedInput = nullptr;
params.finishedOutput = finishedOutput;
params.skipDecode = bufferCastOrNull<bool>(mBatchIsAccepted);
params.cumLogProbs = nullptr;
params.outputLogProbs = nullptr;
params.curandState = inputs->curandStates;
params.batchSize = batchSize;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
invokeBatchTopPSampling<T>(params, getStream());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
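The comments in `forwardAsync()` note that the bonus-token multinomial sampling is realised as TopP sampling with `TopP = 1`; the standalone sketch below (not part of the diff) shows why that is equivalent: with p = 1 no probability mass is truncated, so an inverse-CDF draw selects a token from the full (residual) distribution. The RNG here is a host-side stand-in for the device-side curand state.

```cpp
// Sketch: top-p sampling with p == 1 is plain multinomial sampling.
#include <cstddef>
#include <cstdio>
#include <random>
#include <vector>

int main()
{
    std::vector<float> probs = {0.0f, 0.5f, 0.5f}; // residual distribution P'(x)
    std::mt19937 rng(0);
    std::uniform_real_distribution<float> uniform(0.f, 1.f);

    float const topP = 1.0f;       // keep the whole distribution
    float r = uniform(rng) * topP; // inverse-CDF draw over the kept probability mass
    int token = 0;
    for (std::size_t v = 0; v < probs.size(); ++v)
    {
        r -= probs[v];
        if (r <= 0.f)
        {
            token = static_cast<int>(v);
            break;
        }
    }
    std::printf("bonus token id = %d\n", token);
    return 0;
}
```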
template <typename T>
void ExternalDraftTokensLayer<T>::getAllTopKs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto logits = bufferCastOrNull<T>(inputs->logits);
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopKDecodeHost);
auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
if (skip)
{
return;
}
FinishedState const* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
: nullptr;
TopKSamplingKernelParams<T> params{};
params.logProbs = logits;
params.outputIds = bufferCastOrNull<TokenIdType>(mOutputIdsAfterSampling);
params.workspace = workspace->getRawWorkspaceDevicePtr();
params.maxTopP = 1.0f;
params.topPs = bufferCastOrNull<float>(mRuntimeTopPForTopKDevice);
params.maxTopK = mRuntimeMaxTopK;
params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
params.finishedInput = finishedInput;
params.skipDecode = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
params.curandState = inputs->curandStates;
params.batchSize = batchSize;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.maxTokensPerStep = 1;
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.returnAllTopK = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded(); // workaround for returning all topKs with outputIds
params.logitsHasProbs = inputs->probsComputed;
invokeBatchTopKSampling(params, getStream());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto logits = bufferCastOrNull<T>(inputs->logits);
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
if (skip)
{
return;
}
FinishedState const* finishedInput = (inputs->finished)
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
: nullptr;
TopPSamplingKernelParams<T> params{};
params.probs = logits;
params.outputIds = bufferCastOrNull<TokenIdType>(mOutputIdsAfterSampling);
params.workspace = workspace->getRawWorkspaceDevicePtr();
params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
params.finishedInput = finishedInput;
params.skipDecode = bufferCastOrNull<bool>(mSkipTopPDecodeDevice);
params.curandState = inputs->curandStates;
params.batchSize = batchSize;
params.maxBatchSize = mDecoderDomain.getBatchSize();
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
params.returnAllTopP = true;
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();
invokeBatchTopPSampling<T>(params, getStream());
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void ExternalDraftTokensLayer<T>::forwardAcceptedTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
auto const batchSize = inputs->logits.value()->getDimension<0>();
auto const draftLogitsShape = (*inputs->draftLogits).getShape();
auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1
FinishedState* finishedOutput = (outputs->finished)
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
: nullptr;
tksd::invokeForwardAcceptedTokens(batchSize, workspace->getDeviceBatchSlotsPtr(),
bufferCast<bool>(*mBatchIsAccepted), bufferCastOrNull<SizeType32>(outputs->sequenceLength),
bufferCast<TokenIdType>(*inputs->draftTokenIds), bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr),
inputs->step, maxTokensPerStep, bufferCastOrNull<TokenIdType>(inputs->endIds), finishedOutput, getStream());
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class ExternalDraftTokensLayer<float>;
template class ExternalDraftTokensLayer<half>;
} // namespace tensorrt_llm::layers

View File

@ -0,0 +1,100 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/runtime/common.h"
#include <curand_kernel.h>
namespace tensorrt_llm::layers
{
//! \brief Layer for sampling with externally provided draft tokens (speculative decoding).
//! It runs top-K/top-P sampling on the target logits and accepts or rejects the draft tokens.
template <typename T>
class ExternalDraftTokensLayer : public BaseLayer
{
public:
using Base = BaseLayer;
ExternalDraftTokensLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
std::shared_ptr<runtime::BufferManager> bufferManager);
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, TensorConstPtr batchSlots,
std::shared_ptr<BaseSetupParams> const& setupParams,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
//! @returns workspace needed for this layer in bytes
[[nodiscard]] size_t getWorkspaceSize() const noexcept override;
protected:
runtime::SizeType32 mRuntimeMaxTopK{0};
private:
using Base::mDecoderDomain;
executor::DecodingMode mDecodingMode;
size_t mWorkspaceSize{0};
size_t mSetupWorkspaceSize{0};
TensorPtr mCurandStatesDevice;
TensorPtr mSkipTopKDecodeDevice;
TensorPtr mSkipTopKDecodeHost;
TensorPtr mSkipTopPDecodeDevice;
TensorPtr mSkipTopPDecodeHost;
TensorPtr mBatchIsAccepted;
TensorPtr mRuntimeMultinomialDevice;
TensorPtr mOutputIdsAfterSampling;
TensorPtr mTargetOutputIds;
TensorPtr mRuntimeTopKDevice;
TensorPtr mRuntimeTopPForTopKDevice;
TensorPtr mRuntimeTopPDevice;
TensorPtr mInitialTopPDevice;
TensorPtr mMaskBuffer;
TensorPtr mTargetLogits;
private:
void allocateBuffer(runtime::SizeType32 batchSize);
void acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
void multinomialSampling(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
void getAllTopKs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
void getAllTopPs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
void forwardAcceptedTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
};
} // namespace tensorrt_llm::layers

View File

@ -267,7 +267,7 @@ void TopPSamplingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> con
TopPSamplingKernelParams<T> params{};
params.probs = probs;
params.outputIds = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
params.workspace = workspace->getRawWorkspaceDevicePtr();
params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
params.sequenceLength = sequenceLength;

View File

@ -259,7 +259,7 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
int idx = 0;
for (int reqId = 0; reqId < numReqs; reqId++)
{
const RequestType reqType = static_cast<RequestType const>(reqTypes[reqId]);
RequestType const reqType = static_cast<RequestType const>(reqTypes[reqId]);
if (reqType == RequestType::kGENERATION)
{
mExpandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 2]));
@ -284,7 +284,7 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %d", idx, numTokens));
}
// only used for unifed gemm
// only used for unified gemm
auto bestTactic = mPluginProfiler->getBestConfig(numTokens, mGemmId);
mLoraImpl->setBestTactic(bestTactic);
mLoraImpl->run(numTokens, numReqs, input, mExpandLoraRanks.data(), mExpandLoraWeightPtrs.data(), mWeightIndex,

View File

@ -305,14 +305,17 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe
++tpRank;
}
int token_num = size / inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
auto params = tensorrt_llm::kernels::AllReduceParams::deserialize(
reinterpret_cast<int64_t*>(const_cast<void*>(inputs[1])), tpSize, tpRank);
reinterpret_cast<int64_t*>(const_cast<void*>(inputs[1])), tpSize, tpRank, mType, token_num, mOp);
params.local_output_buffer_ptr = outputs[0];
params.local_input_buffer_ptr = inputs[0];
params.elts_total = size;
if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
{
int fusion_ptr_idx = 2;
params.fusion_params.bias_buffer = mBias ? inputs[fusion_ptr_idx++] : nullptr;
params.fusion_params.residual_buffer = inputs[fusion_ptr_idx++];
@ -320,6 +323,15 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe
params.fusion_params.hidden_size = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
params.fusion_params.eps = mEps;
params.fusion_params.intermediate_buffer = outputs[1];
for (int i = 0; i < tpSize; ++i)
{
params.fusion_params.lamport_peer_comm_buffer_ptrs[i]
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 4 + i];
params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tensorrt_llm::kernels::MAX_RANKS_PER_NODE]
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 5 + i];
params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tensorrt_llm::kernels::MAX_RANKS_PER_NODE * 2]
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 6 + i];
}
}
tensorrt_llm::kernels::customAllReduce(params, mType, runtimeStrategy, mConfig, mOp, stream);
}

View File

@ -3,35 +3,19 @@ set(TRTLLM_PYBIND_MODULE
${TRTLLM_PYBIND_MODULE}
PARENT_SCOPE)
if(NOT BUILD_PYT)
message(
FATAL_ERROR
"Python bindings for C++ runtime require PyTorch. Please enable BUILD_PYT"
)
endif()
execute_process(
COMMAND ${Python3_EXECUTABLE} "-c"
"import pybind11 as pb11; print(pb11.get_cmake_dir(),end='');"
RESULT_VARIABLE PYBIND_CMAKE_DIR_RET
OUTPUT_VARIABLE PYBIND_CMAKE_DIR)
if(PYBIND_CMAKE_DIR_RET MATCHES 0)
list(APPEND CMAKE_PREFIX_PATH "${PYBIND_CMAKE_DIR}")
else()
message(ERROR "pybind11 CMake directory not found.")
endif()
find_package(pybind11 REQUIRED)
set(SRCS
bindings.cpp
batch_manager/algorithms.cpp
batch_manager/bindings.cpp
batch_manager/gptManager.cpp
batch_manager/llmRequest.cpp
batch_manager/inferenceRequest.cpp
batch_manager/kvCacheManager.cpp
batch_manager/llmRequest.cpp
batch_manager/namedTensor.cpp
executor/bindings.cpp
executor/executor.cpp)
executor/executor.cpp
bindings.cpp)
include_directories(${PROJECT_SOURCE_DIR}/include)
pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})
@ -41,15 +25,12 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
"${TORCH_INSTALL_PREFIX}/lib")
target_link_libraries(
${TRTLLM_PYBIND_MODULE} PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG}
${NO_AS_NEEDED_FLAG})
target_link_libraries(
${TRTLLM_PYBIND_MODULE} PUBLIC ${Python3_LIBRARIES} ${TORCH_LIBRARIES}
torch_python ${UNDEFINED_FLAG})
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE})
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
PUBLIC PYBIND11_DETAILED_ERROR_MESSAGES=1)
${TRTLLM_PYBIND_MODULE}
PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
target_compile_definitions(
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
PYBIND11_DETAILED_ERROR_MESSAGES=1)
if(NOT WIN32)
set_target_properties(

View File

@ -0,0 +1,55 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "algorithms.h"
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/pybind/common/algorithmBindings.h"
namespace py = pybind11;
using namespace tensorrt_llm::batch_manager;
using namespace PybindUtils;
void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::module_& m)
{
// Algorithms with custom bindings
py::class_<CapacityScheduler>(m, CapacityScheduler::name)
.def_static("make", &CapacityScheduler::make, py::arg("max_num_requests"), py::arg("kv_cache_manager"),
py::arg("cross_kv_cache_manager"), py::arg("peft_cache_manager"), py::arg("capacity_scheduler_policy"),
py::arg("many_micro_batches") = false,
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))
.def(py::init())
.def("__call__", &CapacityScheduler::operator())
.def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });
py::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
.def_static("make", &MicroBatchScheduler::make, py::arg("max_batch_size"),
py::arg_v("max_num_tokens", std::nullopt, "None"), py::arg_v("ctx_chunk_config", std::nullopt, "None"),
py::arg_v("max_context_length", std::nullopt, "None"),
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
"LlmRequestState.GENERATION_COMPLETE"))
.def(py::init())
.def("__call__", &MicroBatchScheduler::operator())
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
}

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager::algorithms
{
void initBindings(pybind11::module_& m);
}

View File

@ -0,0 +1,41 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "bindings.h"
#include "tensorrt_llm/batch_manager/common.h"
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
#include "tensorrt_llm/pybind/utils/bindTypes.h"
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tle = tensorrt_llm::executor;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm::pybind::batch_manager
{
void initBindings(pybind11::module_& m)
{
py::class_<tb::batch_scheduler::ContextChunkingConfig>(m, "ContextChunkingConfig")
.def(py::init<tle::ContextChunkingPolicy, tensorrt_llm::runtime::SizeType32>(), py::arg("chunking_policy"),
py::arg("chunk_unit_size"))
.def_readwrite("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy)
.def_readwrite("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize);
}
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,28 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::batch_manager
{
void initBindings(pybind11::module_& m);
}

View File

@ -21,6 +21,7 @@
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <ATen/ops/tensor.h>
#include <functional>

View File

@ -20,6 +20,7 @@
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <ATen/ATen.h>
#include <pybind11/pybind11.h>

View File

@ -0,0 +1,29 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#include "kvCacheManager.h"
#include "tensorrt_llm/pybind/utils/bindTypes.h"
namespace tb = tensorrt_llm::batch_manager;
namespace py = pybind11;
void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
{
// TODO: Provide proper bindings
py::classh<tb::kv_cache_manager::KVCacheManager>(m, "KVCacheManager");
}
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
{
// TODO: Provide proper bindings
py::classh<tb::BasePeftCacheManager>(m, "BasePeftCacheManager");
}

View File

@ -0,0 +1,36 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManagerBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::batch_manager
{
class BasePeftCacheManagerBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager

View File

@ -17,22 +17,29 @@
#include "llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/utils/bindTypes.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <ATen/ATen.h>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <torch/extension.h>
#include <memory>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
namespace tle = tensorrt_llm::executor;
using namespace tensorrt_llm::pybind::batch_manager;
using LlmRequestPtr = std::shared_ptr<tb::LlmRequest>;
using RequestList = std::list<LlmRequestPtr>;
namespace
{
@ -166,7 +173,6 @@ void LlmRequest::initBindings(py::module_& m)
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
.def("move_to_next_context_chunk", &LlmRequest::moveToNextContextChunk)
.def("is_full_context_request", py::overload_cast<>(&LlmRequest::isFullContextRequest, py::const_))
.def("is_last_context_chunk", py::overload_cast<>(&LlmRequest::isLastContextChunk, py::const_))
.def("is_first_context_chunk", py::overload_cast<>(&LlmRequest::isFirstContextChunk, py::const_))
.def("get_context_remaining_length", py::overload_cast<>(&LlmRequest::getContextRemainingLength, py::const_))
@ -180,3 +186,140 @@ void LlmRequest::initBindings(py::module_& m)
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); })
.def_property("num_return_sequences", &LlmRequest::getNumReturnSequences, &LlmRequest::setNumReturnSequences);
}
void tb::LlmRequestBindings::initBindings(py::module_& m)
{
py::classh<tb::LlmRequest>(m, "PyLlmRequest")
.def("get_num_tokens", &tb::LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &tb::LlmRequest::getMaxBeamNumTokens)
.def("get_token", &tb::LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
.def("get_tokens", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getTokens, py::const_),
py::arg("beam"))
.def("get_tokens", py::overload_cast<>(&tb::LlmRequest::getTokens, py::const_))
.def_property_readonly("max_num_generated_tokens", &tb::LlmRequest::getMaxNumGeneratedTokens)
.def("add_new_token", &tb::LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &tb::LlmRequest::addNewTokens, py::arg("beam_tokens"))
.def("set_generated_tokens", &tb::LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &tb::LlmRequest::pause, py::arg("max_input_len"))
.def_property("max_sent_token_len", &tb::LlmRequest::getMaxSentTokenLen, &tb::LlmRequest::setMaxSentTokenLen)
.def("prompt_embedding_table",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getPromptEmbeddingTable();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def("bad_words_list",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getBadWordsList();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def_property(
"draft_logits",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getDraftLogits();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
},
[](tb::LlmRequest& self, at::Tensor& logits)
{ self.setDraftLogits(std::make_optional<tb::LlmRequest::TensorPtr>(tr::TorchView::of(logits))); })
.def("embedding_bias",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getEmbeddingBias();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def("lora_config",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getLoraConfig();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def("lora_weights",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getLoraWeights();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def("stop_words_list",
[](tb::LlmRequest& self)
{
std::optional<at::Tensor> value{std::nullopt};
auto tensor = self.getStopWordsList();
if (tensor)
{
value = tr::Torch::tensor(*tensor);
}
return value;
})
.def_property_readonly("prompt_vocab_size", &tb::LlmRequest::getPromptVocabSize)
.def_property_readonly("lora_task_id", &tb::LlmRequest::getLoraTaskId)
.def_property_readonly("lookahead_config", &tb::LlmRequest::getLookaheadConfig)
.def_property_readonly(
"context_current_position", py::overload_cast<>(&tb::LlmRequest::getContextCurrentPosition, py::const_))
.def_property("context_chunk_size", &tb::LlmRequest::getContextChunkSize, &tb::LlmRequest::setContextChunkSize)
.def_readwrite("request_id", &tb::LlmRequest::mRequestId)
.def_readwrite("prompt_len", &tb::LlmRequest::mPromptLen)
.def_readwrite("max_new_tokens", &tb::LlmRequest::mMaxNewTokens)
.def_readwrite("sampling_config", &tb::LlmRequest::mSamplingConfig)
.def_readwrite("state", &tb::LlmRequest::mState)
.def_readwrite("is_streaming", &tb::LlmRequest::mIsStreaming)
.def_readwrite("end_id", &tb::LlmRequest::mEndId)
.def_readwrite("pad_id", &tb::LlmRequest::mPadId)
.def_readwrite("seq_slot", &tb::LlmRequest::mSeqSlot)
.def_property_readonly("return_log_probs", &tb::LlmRequest::returnLogProbs)
.def_property_readonly("return_context_logits", &tb::LlmRequest::getReturnContextLogits)
.def_property_readonly("return_generation_logits", &tb::LlmRequest::getReturnGenerationLogits)
.def_property_readonly("log_probs", py::overload_cast<>(&tb::LlmRequest::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getLogProbs, py::const_))
.def("set_log_probs", &tb::LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
.def("set_return_encoder_output", &tb::LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output"))
.def("get_return_encoder_output", &tb::LlmRequest::getReturnEncoderOutput)
.def("priority", py::overload_cast<>(&tb::LlmRequest::priority, py::const_))
.def("set_priority", py::overload_cast<tle::PriorityType>(&tb::LlmRequest::setPriority))
.def_property_readonly("cum_log_probs", &tb::LlmRequest::getCumLogProbs)
.def("set_cum_log_prob", &tb::LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
.def_property_readonly("orig_prompt_len", &tb::LlmRequest::getOrigPromptLen)
.def("has_draft_tokens", &tb::LlmRequest::hasDraftTokens)
.def("move_to_next_context_chunk", &tb::LlmRequest::moveToNextContextChunk)
.def("is_last_context_chunk", py::overload_cast<>(&tb::LlmRequest::isLastContextChunk, py::const_))
.def("is_first_context_chunk", py::overload_cast<>(&tb::LlmRequest::isFirstContextChunk, py::const_))
.def(
"get_context_remaining_length", py::overload_cast<>(&tb::LlmRequest::getContextRemainingLength, py::const_))
.def_property(
"draft_tokens", [](tb::LlmRequest& self) { return *self.getDraftTokens(); },
[](tb::LlmRequest& self, tb::LlmRequest::VecTokens& draftTokens)
{ self.setDraftTokens(std::make_shared<tb::LlmRequest::VecTokens>(std::move(draftTokens))); });
py::bind_vector<tb::RequestVector>(m, "RequestVector");
}

View File

@ -18,6 +18,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
@ -25,6 +26,15 @@
#include <optional>
#include <pybind11/pybind11.h>
namespace tensorrt_llm::batch_manager
{
class LlmRequestBindings
{
public:
static void initBindings(pybind11::module_& m);
};
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::pybind::batch_manager
{
@ -91,6 +101,7 @@ public:
std::optional<LlmRequest::LogitsPostProcessor> callback);
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
static void initBindings(pybind11::module_& m);
};

View File

@ -18,6 +18,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/namedTensor.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <ATen/ATen.h>

View File

@ -23,18 +23,20 @@
#include <torch/extension.h>
#include <vector>
#include "tensorrt_llm/pybind/batch_manager/gptManager.h"
#include "tensorrt_llm/pybind/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h"
#include "tensorrt_llm/batch_manager/BatchManager.h"
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
#include "tensorrt_llm/pybind/batch_manager/gptManager.h"
#include "tensorrt_llm/pybind/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
#include "tensorrt_llm/pybind/executor/bindings.h"
#include "tensorrt_llm/pybind/utils/pathCaster.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/memoryCounters.h"
@ -333,6 +335,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tpb::NamedTensor::initBindings(m);
tpb::LlmRequest::initBindings(m);
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(m);
tb::BasePeftCacheManagerBindings::initBindings(m);
tb::LlmRequestBindings::initBindings(m);
auto tensorNames = m.def_submodule("tensor_names");
// Input tensor names
@ -412,8 +418,6 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def(py::pickle(gptModelParamsGetState, gptModelParamsSetState))
.def("__eq__", &tb::TrtGptModelOptionalParams::operator==);
tpb::GptManager::initBindings(m);
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
@ -447,4 +451,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
auto& world = tensorrt_llm::mpi::MpiComm::world();
tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank));
});
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
tensorrt_llm::pybind::batch_manager::initBindings(mInternal);
tensorrt_llm::pybind::batch_manager::algorithms::initBindings(mInternal);
tpb::GptManager::initBindings(m);
}

View File

@ -0,0 +1,39 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "opaqueBindings.h"
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <string>
namespace py = pybind11;
namespace PybindUtils
{
template <typename T>
void makeAlgorithmBindings(py::module_& m)
{
py::class_<T>(m, T::name).def(py::init()).def("forward", &T::forward).def("name", [](T const&) { return T::name; });
}
template <typename T>
void instantiatePybindAlgorithm(py::module_& m);
} // namespace PybindUtils
#define INSTANTIATE_ALGORITHM(TYPE) \
template <> \
void PybindUtils::instantiatePybindAlgorithm<TYPE>(py::module_ & m) \
{ \
makeAlgorithmBindings<TYPE>(m); \
};
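As an editorial aside (not part of the commit), here is a minimal sketch of how this helper is meant to be consumed, assuming a hypothetical type example::DummyAlgorithm that satisfies what makeAlgorithmBindings expects: a public static name, a default constructor and a forward member.
// Editorial illustration only; example::DummyAlgorithm is a placeholder, not a real TensorRT-LLM type.
#include "tensorrt_llm/pybind/common/algorithmBindings.h"
namespace example
{
class DummyAlgorithm
{
public:
    constexpr static auto name = "DummyAlgorithm";
    int forward(int x) const { return x + 1; } // placeholder work
};
} // namespace example
// Expands to the specialization PybindUtils::instantiatePybindAlgorithm<example::DummyAlgorithm>,
// which registers py::class_<example::DummyAlgorithm>(m, name) with __init__, forward and name.
INSTANTIATE_ALGORITHM(example::DummyAlgorithm)
// Inside PYBIND11_MODULE(...) the module entry point would then only need:
//   PybindUtils::instantiatePybindAlgorithm<example::DummyAlgorithm>(m);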

View File

@ -0,0 +1,18 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "tensorrt_llm/batch_manager/common.h"
#include <pybind11/stl_bind.h>
PYBIND11_MAKE_OPAQUE(tensorrt_llm::batch_manager::RequestVector)

View File

@ -408,9 +408,12 @@ void InitBindings(pybind11::module_& m)
.def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal);
py::class_<tle::Response>(m, "Response")
.def(py::init<IdType, std::string>(), py::arg("request_id"), py::arg("error_msg"))
.def(py::init<IdType, tle::Result>(), py::arg("request_id"), py::arg("result"))
.def(py::init<IdType, std::string, std::optional<IdType>>(), py::arg("request_id"), py::arg("error_msg"),
py::arg("client_id") = std::nullopt)
.def(py::init<IdType, tle::Result, std::optional<IdType>>(), py::arg("request_id"), py::arg("result"),
py::arg("client_id") = std::nullopt)
.def_property_readonly("request_id", &tle::Response::getRequestId)
.def_property_readonly("client_id", &tle::Response::getClientId)
.def("has_error", &tle::Response::hasError)
.def_property_readonly("error_msg", &tle::Response::getErrorMsg)
.def_property_readonly("result", &tle::Response::getResult);
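As an editorial aside (not part of the commit), a minimal C++ sketch of the client id now carried by Response, relying only on the constructor and accessors bound above; the literal values and the function name are placeholders.
// Editorial illustration only; exercises the newly bound (request id, error message, optional client id) constructor.
#include "tensorrt_llm/executor/executor.h"
#include <cassert>
namespace tle = tensorrt_llm::executor;
void exampleResponseClientId()
{
    tle::Response response(/*requestId=*/1, std::string{"some error"}, /*clientId=*/42);
    assert(response.hasError());
    assert(response.getClientId().has_value()); // client id round-trips through the response
}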

View File

@ -16,6 +16,8 @@
*/
#pragma once
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace tensorrt_llm::pybind::executor

View File

@ -16,8 +16,10 @@
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace tle = tensorrt_llm::executor;

View File

@ -17,10 +17,10 @@
#pragma once
#include <pybind11/pybind11.h>
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include <pybind11/pybind11.h>
namespace PYBIND11_NAMESPACE
{

View File

@ -17,11 +17,11 @@
#pragma once
#include <pybind11/pybind11.h>
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <pybind11/pybind11.h>
#include <torch/extension.h>
namespace PYBIND11_NAMESPACE

View File

@ -0,0 +1,69 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
*
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
* property and proprietary rights in and to this material, related
* documentation and any modifications thereto. Any use, reproduction,
* disclosure or distribution of this material and related documentation
* without an express license agreement from NVIDIA CORPORATION or
* its affiliates is strictly prohibited.
*/
#pragma once
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <pybind11/pybind11.h>
namespace PybindUtils
{
namespace py = pybind11;
template <typename T>
void bindList(py::module& m, std::string const& name)
{
py::class_<T>(m, name.c_str())
.def(py::init())
.def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
.def("pop_back", [](T& lst) { lst.pop_back(); })
.def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
.def("pop_front", [](T& lst) { lst.pop_front(); })
.def("__len__", [](T const& lst) { return lst.size(); })
.def(
"__iter__", [](T& lst) { return py::make_iterator(lst.begin(), lst.end()); }, py::keep_alive<0, 1>())
.def("__getitem__",
[](T const& lst, size_t index)
{
if (index >= lst.size())
throw py::index_error();
auto it = lst.begin();
std::advance(it, index);
return *it;
})
.def("__setitem__",
[](T& lst, size_t index, const typename T::value_type& value)
{
if (index >= lst.size())
throw py::index_error();
auto it = lst.begin();
std::advance(it, index);
*it = value;
});
}
template <typename T>
void bindSet(py::module& m, std::string const& name)
{
py::class_<T>(m, name.c_str())
.def(py::init())
.def("clear", &T::clear)
.def("size", &T::size)
// .def("insert", py::overload_cast<const typename T::value_type&>(&T::insert))
.def("erase", py::overload_cast<typename T::value_type const&>(&T::erase))
.def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); })
.def(
"__iter__", [](T& s) { return py::make_iterator(s.begin(), s.end()); }, py::keep_alive<0, 1>());
}
} // namespace PybindUtils
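As an editorial aside (not part of the commit), a minimal sketch of registering a container with the bindList helper above; std::list<int>, the exported name "IntList" and initExampleBindings are placeholders. For containers that must not be copied through pybind11's default STL casters, one would additionally declare them opaque, as opaqueBindings.h does for RequestVector.
// Editorial illustration only; shows the intended call pattern for bindList.
#include "tensorrt_llm/pybind/utils/bindTypes.h"
#include <list>
void initExampleBindings(pybind11::module_& m)
{
    // Exposes push_back/pop_back/push_front/pop_front, __len__, __iter__,
    // __getitem__ and __setitem__ on the Python side, as defined by bindList.
    PybindUtils::bindList<std::list<int>>(m, "IntList");
}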

View File

@ -22,6 +22,7 @@
#include "pybind11/detail/descr.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
#include <filesystem>
namespace PYBIND11_NAMESPACE

View File

@ -161,6 +161,19 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
lookaheadParams->attentionPackedMasks = output->lookaheadOutputs->packedMasks;
setupParams->decodingParams = std::move(lookaheadParams);
}
else if (mDecodingMode.isExternalDraftTokens())
{
auto externalDraftTokensParams = std::make_shared<tl::ExternalDraftTokensSetupParams>();
// signed to unsigned
if (mSamplingConfig.topK)
{
auto const& topK = mSamplingConfig.topK.value();
externalDraftTokensParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
}
externalDraftTokensParams->runtimeTopP = mSamplingConfig.topP;
setupParams->decodingParams = std::move(externalDraftTokensParams);
}
setupParams->decodingParams->randomSeed = mSamplingConfig.randomSeed;
mDecodingLayerWorkspace->setDeviceBatchSlots(batchSlots);
@ -244,6 +257,27 @@ void prepareMedusaInputs(
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void prepareExternalDraftTokensInputs(
DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto inputParams = std::dynamic_pointer_cast<tl::ExternalDraftTokensInputs>(baseInputs);
auto const& externalDraftTokensInputs = inputs.externalDraftTokensInputs.value();
inputParams->draftLogits = externalDraftTokensInputs.draftLogits;
inputParams->draftProbs = externalDraftTokensInputs.draftProbs;
inputParams->targetProbs = externalDraftTokensInputs.targetProbs;
inputParams->numDraftTokens = externalDraftTokensInputs.numDraftTokens;
inputParams->draftTokenIds = externalDraftTokensInputs.draftTokenIds;
inputParams->constantThreshold = externalDraftTokensInputs.constantThreshold;
inputParams->useRandomAcceptanceThreshold = externalDraftTokensInputs.useRandomAcceptanceThreshold;
inputParams->step = externalDraftTokensInputs.step;
inputParams->useDraftLogits = externalDraftTokensInputs.useDraftLogits;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void prepareExplicitDraftTokensInput(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@ -316,6 +350,11 @@ std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
forwardParams
= std::make_shared<tl::ExplicitDraftTokensInputs>(input.endIds, input.batchSlots, input.batchSize);
}
else if (decodingMode.isExternalDraftTokens())
{
forwardParams = std::make_shared<tl::ExternalDraftTokensInputs>(
input.endIds, input.batchSlots, input.step, ite, input.batchSize);
}
// No logits for explicit draft tokens
if (!decodingMode.isExplicitDraftTokens())
@ -379,6 +418,11 @@ std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
forwardParams->localBatchSize = input.batchSize;
}
if (decodingMode.isExternalDraftTokens())
{
prepareExternalDraftTokensInputs(input, maxBatchSize, forwardParams);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return forwardParams;
@ -593,105 +637,3 @@ namespace tensorrt_llm::runtime
template class GptDecoder<float>;
template class GptDecoder<half>;
} // namespace tensorrt_llm::runtime
void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const finishedVecShape = finishedVec.getShape();
auto const maxBatchSize = finishedVecShape.d[1];
auto const batchSlotsShape = batchSlots.getShape();
auto const batchSize = batchSlotsShape.d[0];
auto const targetTokenIdsShape = targetTokenIds.getShape();
auto const beamWidth = targetTokenIdsShape.d[1];
auto const maxSeqLength = targetTokenIdsShape.d[2];
auto const maxDraftTokens = draftTokenIds.getDimension<1>();
TLLM_CHECK_WITH_INFO(beamWidth == 1,
common::fmtstr("Beam width (" FMT_DIM ") > 1 is not supported for the speculative decoding", beamWidth));
TLLM_CHECK_WITH_INFO(batchSize <= maxBatchSize,
common::fmtstr("Batch size (" FMT_DIM ") is not smaller or equal to max batch size (" FMT_DIM ")", batchSize,
maxBatchSize));
TLLM_CHECK_WITH_INFO(draftTokenIds.getDimension<0>() == maxBatchSize,
common::fmtstr("Draft tokens batch size (" FMT_DIM ") is not equal to target batch size (" FMT_DIM ")",
draftTokenIds.getDimension<0>(), maxBatchSize));
TLLM_CHECK_WITH_INFO(contextLengths.getDimension<0>() == maxBatchSize,
common::fmtstr("Context length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
contextLengths.getDimension<0>(), maxBatchSize));
TLLM_CHECK_WITH_INFO(numDraftTokens.getDimension<0>() == maxBatchSize,
common::fmtstr("Num draft tokens batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
numDraftTokens.getDimension<0>(), maxBatchSize));
TLLM_CHECK_WITH_INFO(sequenceLengths.getDimension<0>() == maxBatchSize,
common::fmtstr("Sequence length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
sequenceLengths.getDimension<0>(), maxBatchSize));
tksd::invokeAcceptDraftTokensByIds(bufferCast<TokenIdType>(draftTokenIds), bufferCast<TokenIdType>(targetTokenIds),
bufferCast<SizeType32>(contextLengths), bufferCast<SizeType32>(numDraftTokens),
bufferCast<SizeType32>(sequenceLengths),
reinterpret_cast<tensorrt_llm::kernels::FinishedState const*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedVec)),
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedFinal)),
bufferCast<int>(finishedSum), bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth,
maxSeqLength, maxDraftTokens, stream->get());
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const draftLogitsShape = draftLogits.getShape();
auto const maxBatchSize = draftLogitsShape.d[0];
auto const maxTokensPerStep = draftLogitsShape.d[1];
auto const batchSlotsShape = batchSlots.getShape();
auto const batchSize = batchSlotsShape.d[0];
auto constexpr beamWidth = 1;
TLLM_CHECK_WITH_INFO(
beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for the speculative decoding", beamWidth));
TLLM_CHECK(draftLogitsShape.d[2] == vocabSize);
if (draftLogits.getDataType() == nvinfer1::DataType::kFLOAT)
{
tksd::acceptDraftTokensByLogits(bufferCast<float>(draftLogits),
const_cast<float**>(reinterpret_cast<float const* const*>(bufferCast<int64_t>(targetLogits))),
bufferCast<float>(draftProbs), bufferCast<float>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
}
else if (draftLogits.getDataType() == nvinfer1::DataType::kHALF)
{
tksd::acceptDraftTokensByLogits(bufferCast<half>(draftLogits),
const_cast<half**>(reinterpret_cast<half const* const*>(bufferCast<int64_t>(targetLogits))),
bufferCast<half>(draftProbs), bufferCast<half>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
}
else
{
TLLM_THROW("Incorrect logits dtype. Only float32 and float16 are supported");
}
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -93,30 +93,28 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
auto constexpr nvFloatType = TRTDataType<float>::value;
auto& dInput = mJointDecodingInput;
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
auto batchSlots = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
dInput
= std::make_unique<DecodingInput>(0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
{ // prevent reusing these vars after std::move
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
auto batchSlots = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
dInput = std::make_unique<DecodingInput>(
0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
}
dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
auto& dOutput = mJointDecodingOutput;
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
auto gatheredOutputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds), std::move(gatheredOutputIds));
{ // prevent reusing these vars after std::move
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
auto gatheredOutputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds), std::move(gatheredOutputIds));
}
dOutput->newTokensSteps = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
mFinishedSteps
= mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
mDraftProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
mTargetProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
mBatchSlotsSetup = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
mBatchSlotsDecoder = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
mBatchSlotsAcceptTokens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
mBatchSlotsAcceptLogits = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
mBatchSlotsSetup = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
mBatchSlotsDecoder = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
// use batchSize many entries instead of the usual 1
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
mFinishedSum = BufferManager::pinned(ITensor::makeShape({1}), nvSizeType);
@ -129,16 +127,10 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
dOutput->logProbsTiled = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<float>::value);
mNumDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
mCurandStates = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT8);
mDraftTokenIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
mDraftLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
mTargetLogitsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<float*>::value);
dInput->stopWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<int32_t*>::value);
dInput->stopWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
dInput->stopWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
dInput->badWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<int32_t*>::value);
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
int device;
@ -149,13 +141,13 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
if (!mSpeculativeDecodingMode.isNone())
{
allocateSpeculativeDecodingBuffers();
allocateSpeculativeDecodingBuffers(dtype);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatched::allocateSpeculativeDecodingBuffers()
void GptDecoderBatched::allocateSpeculativeDecodingBuffers(nvinfer1::DataType dtype)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
@ -201,6 +193,22 @@ void GptDecoderBatched::allocateSpeculativeDecodingBuffers()
}
dOutput->speculativeDecodingOutputs = speculativeDecodingOutputs;
if (mSpeculativeDecodingMode.isDraftTokensExternal())
{
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;
externalDraftTokensInputs.draftLogits = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.draftProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.targetProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
externalDraftTokensInputs.numDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
externalDraftTokensInputs.useDraftLogits
= mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
externalDraftTokensInputs.draftTokenIds
= mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
dInput->externalDraftTokensInputs = externalDraftTokensInputs;
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -251,6 +259,7 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
auto const maxTokensPerStepXmaxBatchSizeXmaxBeamWidth
= ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize, maxBeamWidth});
auto const maxBatchSizeXmaxTokensPerStep = ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep});
auto const jointOutputIdsShape = ITensor::makeShape({maxBatchSize, maxBeamWidth, maxSequenceLength});
auto& dInput = *mJointDecodingInput;
dInput.maxLength = mMaxSequenceLength;
@ -268,8 +277,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
inputLengths.reshape(maxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(inputLengths);
auto const jointOutputIdsShape = ITensor::makeShape({maxBatchSize, maxBeamWidth, maxSequenceLength});
auto& dOutput = *mJointDecodingOutput;
dOutput.ids->reshape(jointOutputIdsShape);
@ -296,15 +303,18 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
mBatchSlotsSetup->reshape(ITensor::makeShape({maxBatchSize}));
mBatchSlotsDecoder->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
mBatchSlotsAcceptTokens->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
mBatchSlotsAcceptLogits->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
if (mSpeculativeDecodingMode.isDraftTokensExternal())
{
mDraftProbs->reshape(ITensor::makeShape(
dInput.externalDraftTokensInputs->draftProbs->reshape(ITensor::makeShape(
{maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(mVocabSizePadded)}));
mTargetProbs->reshape(ITensor::makeShape(
dInput.externalDraftTokensInputs->targetProbs->reshape(ITensor::makeShape(
{maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(mVocabSizePadded)}));
dInput.externalDraftTokensInputs->draftLogits->reshape(
ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep, static_cast<SizeType32>(mVocabSizePadded)}));
dInput.externalDraftTokensInputs->draftTokenIds->reshape(maxBatchSizeXmaxTokensPerStep);
dInput.externalDraftTokensInputs->numDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
dInput.externalDraftTokensInputs->useDraftLogits->reshape(ITensor::makeShape({maxBatchSize, 1}));
}
dOutput.parentIds->reshape(jointOutputIdsShape);
@ -317,7 +327,7 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(*dOutput.cumLogProbs);
dOutput.logProbs->reshape(ITensor::makeShape({maxBatchSize, maxBeamWidth, mMaxSequenceLength}));
dOutput.logProbs->reshape(jointOutputIdsShape);
mBufferManager.setZero(*dOutput.logProbs);
if (maxBeamWidth > 1)
@ -328,15 +338,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
dOutput.logProbsTiled->reshape(ITensor::makeShape({maxSequenceLength, maxBatchSize, maxBeamWidth}));
mBufferManager.setZero(*dOutput.logProbsTiled);
// speculative decoding only works for beam width == 1
mDraftTokenIds->reshape(maxBatchSizeXmaxTokensPerStep);
mDraftLogits->reshape(
ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep, static_cast<SizeType32>(mVocabSizePadded)}));
mAcceptByLogits.resize(maxBatchSize);
mNumDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
mCurandStates->reshape(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}));
mTargetLogitsPtrs->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
const_cast<ITensor&>(*dInput.embeddingBias)
.reshape(ITensor::makeShape({maxBatchSize, static_cast<SizeType32>(mVocabSizePadded)}));
const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(ITensor::makeShape({maxBatchSize}));
@ -591,7 +592,6 @@ void GptDecoderBatched::newRequestSpeculativeDecoding(
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mAcceptByLogits[batchIdx] = false;
if (mSpeculativeDecodingMode.predictsDraftTokens())
{
@ -639,40 +639,41 @@ void GptDecoderBatched::newRequestDraftTokensExternal(
auto const& stream = mDecoderStream;
BufferManager manager{stream};
auto constexpr localBatchSize = 1;
auto& dJointInput = *mJointDecodingInput;
auto useDraftLogits = false;
auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
if (request.draftLogits.has_value())
{
TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
mAcceptByLogits[batchIdx] = true;
useDraftLogits = true;
TensorPtr draftLogitsReqBatchSlice = ITensor::slice(mDraftLogits, batchIdx, 1);
TensorPtr draftLogitsReqBatchSlice
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
draftLogitsReqBatchSlice->squeeze(0);
TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
}
TensorPtr draftTokensReqBatchSlice = ITensor::slice(mDraftTokenIds, batchIdx, 1);
auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
kernels::invokeFill(*useDraftLogitsView, useDraftLogits, *stream);
TensorPtr draftTokensReqBatchSlice
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
draftTokensReqBatchSlice->squeeze(0);
TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
auto const curandStatesView = ITensor::slice(mCurandStates, batchIdx, 1);
auto curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*curandStatesView));
auto batchSlotsPtr = bufferCast<SizeType32>(*ITensor::slice(mBatchSlotsSetup, 0, localBatchSize));
if (samplingConfig.randomSeed.has_value())
{
tk::invokeCurandInitialize(
curandState, batchSlotsPtr, localBatchSize, samplingConfig.randomSeed.value()[0], stream->get());
}
else
{
tk::invokeCurandInitialize(curandState, batchSlotsPtr, localBatchSize, 0, stream->get());
}
auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, 1);
auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream);
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
float const constantThreshold
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -838,8 +839,6 @@ void GptDecoderBatched::forwardDecoder(
auto batchSlotsDecoderPtr = maxBeamWidth > 1 && input.seqSlots ? bufferCast<SizeType32>(*input.seqSlots)
: bufferCast<SizeType32>(*mBatchSlotsDecoder);
auto batchSlotsAcceptTokensPtr = bufferCast<SizeType32>(*mBatchSlotsAcceptTokens);
auto batchSlotsAcceptLogitsPtr = bufferCast<SizeType32>(*mBatchSlotsAcceptLogits);
auto& dInput = *mJointDecodingInput;
auto& dOutput = *mJointDecodingOutput;
auto& decoder = *mDecoder;
@@ -864,26 +863,12 @@ void GptDecoderBatched::forwardDecoder(
}
SizeType32 localBatchDecoderIdx = 0;
SizeType32 localBatchAcceptTokensIdx = 0;
SizeType32 localBatchAcceptLogitsIdx = 0;
for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
{
if (mFinished[bi] || !input.active.at(bi) || step >= mNumDecodingEngineTokens[bi])
{
continue;
}
if (!mAcceptByLogits[bi] && mMaxDecodingDecoderTokens == 1 && mNumDecodingEngineTokens[bi] > 1
&& step == mNumDecodingEngineTokens[bi] - 1)
{
batchSlotsAcceptTokensPtr[step * mActualBatchSize + localBatchAcceptTokensIdx] = bi;
localBatchAcceptTokensIdx++;
}
else if (mAcceptByLogits[bi] && mMaxDecodingDecoderTokens == 1 && mNumDecodingEngineTokens[bi] > 1 && step == 0)
{
batchSlotsAcceptLogitsPtr[step * mActualBatchSize + localBatchAcceptLogitsIdx] = bi;
localBatchAcceptLogitsIdx++;
}
batchSlotsDecoderPtr[step * mActualBatchSize + localBatchDecoderIdx] = bi;
localBatchDecoderIdx++;
}
@@ -892,9 +877,6 @@ void GptDecoderBatched::forwardDecoder(
= *std::max_element(std::begin(mNumDecodingEngineTokens), std::end(mNumDecodingEngineTokens));
std::vector<SharedConstPtr> logitsVec;
auto targetLogitsPtrsSlice = ITensor::slice(mTargetLogitsPtrs, step, 1);
auto targetLogitsPtrsSlicePtr = reinterpret_cast<void const**>(bufferCast<int64_t>(*targetLogitsPtrsSlice));
SizeType32 targetLogitsIdx = 0;
for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
{
if (mFinished[bi] || !input.active.at(bi) || step >= mNumDecodingEngineTokens[bi])
@@ -904,32 +886,6 @@ void GptDecoderBatched::forwardDecoder(
auto const& targetLogits = allTargetLogits[bi];
TensorPtr logitsSlice = ITensor::slice(targetLogits, step, singleRequest);
logitsVec.push_back(logitsSlice);
targetLogitsPtrsSlicePtr[targetLogitsIdx++] = logitsSlice->data();
}
if (async && localBatchAcceptLogitsIdx > 0)
{
// These params are only used for testing. Thus, can be per batch instead of per request
auto const& samplingConfig = decoder.getSamplingConfig();
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
float const randomAcceptanceThreshold
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
TensorPtr batchSlotsAcceptLogitsStepSlice = ITensor::slice(mBatchSlotsAcceptLogits, step, 1);
batchSlotsAcceptLogitsStepSlice->squeeze(0);
TensorPtr batchSlotsAcceptLogitsSlice
= ITensor::slice(batchSlotsAcceptLogitsStepSlice, 0, localBatchAcceptLogitsIdx);
IGptDecoder::acceptDraftTokensByLogits(
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mDraftLogits,
/* [maxBatchSize][maxDecodingTokens, vocabPadded] */ *targetLogitsPtrsSlice,
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mDraftProbs,
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mTargetProbs,
/* [maxBatchSize] */ *mNumDraftTokens,
/* [maxDecodingTokens, maxBatchSize] */ *mFinishedSteps,
/* [bs] */ *batchSlotsAcceptLogitsSlice, static_cast<SizeType32>(mVocabSize),
static_cast<SizeType32>(mVocabSizePadded), useRandomAcceptanceThreshold, randomAcceptanceThreshold,
reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)), stream);
}
TensorPtr finishedStepsInput = ITensor::slice(mFinishedSteps, step, 1);
@@ -958,6 +914,11 @@ void GptDecoderBatched::forwardDecoder(
dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
}
if (mSpeculativeDecodingMode.isDraftTokensExternal())
{
dInput.externalDraftTokensInputs->step = step;
}
dOutput.newTokens = newTokensStepView;
dOutput.finishReasons = finishedStepsOutput;
dOutput.lengths = sequenceLengths;
@@ -987,26 +948,6 @@ void GptDecoderBatched::forwardDecoder(
mNbSteps[bi] += 1;
mFinished[bi] = mNbSteps[bi] >= mMaxNewTokens[bi];
}
if (async && localBatchAcceptTokensIdx > 0)
{
TensorPtr batchSlotsAcceptTokensStepSlice = ITensor::slice(mBatchSlotsAcceptTokens, step, 1);
batchSlotsAcceptTokensStepSlice->squeeze(0);
auto batchSlotsAcceptTokensSlice
= ITensor::slice(batchSlotsAcceptTokensStepSlice, 0, localBatchAcceptTokensIdx);
// Update finished state for 0th step
auto finishedFinal = ITensor::slice(mFinishedSteps, step, 1);
IGptDecoder::acceptDraftTokensByIds(
/* [maxBatchSize, maxBeamWidth, maxSeqLen] */ *dOutput.ids,
/* [maxBatchSize, maxDecodingDraftTokens] */ *mDraftTokenIds,
/* [maxBatchSize] */ *dInput.lengths,
/* [maxBatchSize] */ *mNumDraftTokens,
/* [maxBatchSize] */ *dOutput.lengths,
/* [maxDecodingTokens, maxBatchSize] */ *mFinishedSteps,
/* [maxBatchSize] */ *finishedFinal,
/* [maxBatchSize] */ *dOutput.finishedSum,
/* [bs] */ *batchSlotsAcceptTokensSlice, stream);
}
// If last iteration
if (async && step == maxDecodingEngineTokens - mMaxDecodingDecoderTokens)

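With the accept-by-ids and accept-by-logits branches removed from the forwardDecoder hunks above, the per-step slot selection reduces to a plain compaction of active, unfinished requests. The following is a minimal standalone sketch of that loop in plain C++; the function and parameter names are illustrative only and are not TensorRT-LLM types.

#include <cstdint>
#include <vector>

// Packs the batch indices of requests that still have work at this step into
// batchSlots and returns the resulting local batch size, mirroring the loop that
// fills batchSlotsDecoderPtr in forwardDecoder.
int32_t gatherActiveSlots(std::vector<bool> const& finished, std::vector<bool> const& active,
    std::vector<int32_t> const& numDecodingEngineTokens, int32_t step, std::vector<int32_t>& batchSlots)
{
    int32_t localBatchIdx = 0;
    for (int32_t bi = 0; bi < static_cast<int32_t>(finished.size()); ++bi)
    {
        if (finished[bi] || !active[bi] || step >= numDecodingEngineTokens[bi])
        {
            continue; // nothing left to decode for this request at this step
        }
        batchSlots[localBatchIdx++] = bi;
    }
    return localBatchIdx; // local batch size handed to the underlying decoder
}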