mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
open source 7f370deb0090d885d7518c2b146399ba3933c004 (#2273)
* Update TensorRT-LLM --------- Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
This commit is contained in:
parent
40274aac39
commit
48686bca3a
10
README.md
10
README.md
@ -7,8 +7,8 @@ TensorRT-LLM
|
||||
[](https://nvidia.github.io/TensorRT-LLM/)
|
||||
[](https://www.python.org/downloads/release/python-31012/)
|
||||
[](https://developer.nvidia.com/cuda-downloads)
|
||||
[](https://developer.nvidia.com/tensorrt)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](https://developer.nvidia.com/tensorrt)
|
||||
[](./tensorrt_llm/version.py)
|
||||
[](./LICENSE)
|
||||
|
||||
[Architecture](./docs/source/architecture/overview.md) | [Results](./docs/source/performance/perf-overview.md) | [Examples](./examples/) | [Documentation](./docs/source/)
|
||||
@ -17,6 +17,12 @@ TensorRT-LLM
|
||||
<div align="left">
|
||||
|
||||
## Latest News
|
||||
* [2024/09/29] 🌟 AI at Meta PyTorch + TensorRT v2.4 🌟 ⚡TensorRT 10.1 ⚡PyTorch 2.4 ⚡CUDA 12.4 ⚡Python 3.12
|
||||
[➡️ link](https://github.com/pytorch/TensorRT/releases/tag/v2.4.0)
|
||||
<div align="center">
|
||||
<img src="docs/source/media/image-09-29-2024.png" width="50%">
|
||||
<div align="left">
|
||||
|
||||
* [2024/09/17] ✨ NVIDIA TensorRT-LLM Meetup
|
||||
[➡️ link](https://drive.google.com/file/d/1RR8GqC-QbuaKuHj82rZcXb3MS20SWo6F/view?usp=share_link)
|
||||
|
||||
|
||||
@ -159,6 +159,8 @@ struct BenchmarkParams
|
||||
std::optional<int> sinkTokenLength{std::nullopt};
|
||||
bool multiBlockMode{true};
|
||||
bool enableContextFMHAFP32Acc{false};
|
||||
bool cudaGraphMode{false};
|
||||
SizeType32 cudaGraphCacheSize{0};
|
||||
|
||||
// lora / peft params
|
||||
std::optional<std::string> loraDir{std::nullopt};
|
||||
@ -470,7 +472,38 @@ public:
|
||||
mRequestBenchInfos[requestId].firstTokenSeen = true;
|
||||
}
|
||||
|
||||
mRequestBenchInfos[requestId].outputLength += 1;
|
||||
mRequestBenchInfos[requestId].decodingIter += 1;
|
||||
}
|
||||
|
||||
void recordToken(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
|
||||
{
|
||||
int32_t outputLength = 1;
|
||||
for (auto& tensor : responseTensors)
|
||||
{
|
||||
if (tensor.name == inference_request::kSequenceLengthTensorName)
|
||||
{
|
||||
// Tensor of shape nBeams, and we only need the first one
|
||||
outputLength = *(bufferCast<int32_t>(*(tensor.tensor)));
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
mRequestBenchInfos[requestId].outputLength += outputLength;
|
||||
this->recordToken(requestId);
|
||||
}
|
||||
|
||||
void recordToken(uint64_t requestId, texec::Response const& response)
|
||||
{
|
||||
auto outputTokenIds = response.getResult().outputTokenIds;
|
||||
|
||||
int32_t outputLength = 1;
|
||||
for (auto const& beam : outputTokenIds)
|
||||
{
|
||||
outputLength = std::max(static_cast<int32_t>(beam.size()), outputLength);
|
||||
}
|
||||
|
||||
mRequestBenchInfos[requestId].outputLength += outputLength;
|
||||
this->recordToken(requestId);
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
|
||||
@ -500,7 +533,7 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
this->recordToken(requestId);
|
||||
this->recordToken(requestId, responseTensors);
|
||||
}
|
||||
}
|
||||
|
||||
@ -532,7 +565,7 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
this->recordToken(requestId);
|
||||
this->recordToken(requestId, response);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -821,8 +854,9 @@ public:
|
||||
benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize, benchmarkParams.kvOnboardBlocks);
|
||||
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
|
||||
std::nullopt, benchmarkParams.loraHostCacheSize);
|
||||
texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(
|
||||
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
|
||||
texec::ExtendedRuntimePerfKnobConfig extendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
|
||||
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode,
|
||||
benchmarkParams.cudaGraphCacheSize);
|
||||
texec::ExecutorConfig executorConfig(
|
||||
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
|
||||
executorConfig.setGpuWeightsPercent(benchmarkParams.gpuWeightsPercent);
|
||||
@ -940,7 +974,7 @@ public:
|
||||
{
|
||||
if (!warmup && !response.hasError())
|
||||
{
|
||||
mRecorder->recordToken(reqId);
|
||||
mRecorder->recordToken(reqId, response);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1228,7 +1262,7 @@ public:
|
||||
{
|
||||
if (errMsg.empty())
|
||||
{
|
||||
mRecorder->recordToken(requestId);
|
||||
mRecorder->recordToken(requestId, response_tensors);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -1458,8 +1492,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
|
||||
: texec::DecodingMode::Auto(),
|
||||
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
|
||||
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(
|
||||
benchmarkParams.multiBlockMode, benchmarkParams.enableContextFMHAFP32Acc);
|
||||
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
|
||||
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);
|
||||
|
||||
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
|
||||
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
|
||||
@ -1895,7 +1929,8 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("return_generation_logits", "Whether to return generation logits.",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
|
||||
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
|
||||
options.add_options()("scheduler_policy",
|
||||
"Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
|
||||
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
|
||||
|
||||
options.add_options()("first_batch_delay",
|
||||
@ -1946,6 +1981,12 @@ int main(int argc, char* argv[])
|
||||
cxxopts::value<bool>()->default_value("true"));
|
||||
options.add_options()(
|
||||
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
|
||||
options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options.add_options()("cuda_graph_cache_size",
|
||||
"Specify how many cuda graphs are cached in the runtime. Larger cache gives better perf, but consumes more GPU "
|
||||
"memory.",
|
||||
cxxopts::value<SizeType32>()->default_value("0"));
|
||||
|
||||
options.add_options()("enable_context_fmha_fp32_acc", "Enable FMHA runner FP32 accumulation",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
@ -2131,6 +2172,12 @@ int main(int argc, char* argv[])
|
||||
// Argument: enable_context_fmha_fp32_acc
|
||||
benchmarkParams.enableContextFMHAFP32Acc = result["enable_context_fmha_fp32_acc"].as<bool>();
|
||||
|
||||
// Argument: cuda_graph_mode
|
||||
benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();
|
||||
|
||||
// Argument: cuda_graph_mode
|
||||
benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();
|
||||
|
||||
std::optional<TokenIdType> padId;
|
||||
// Argument: Padding token id
|
||||
if (result.count("pad_id"))
|
||||
@ -2168,6 +2215,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT;
|
||||
}
|
||||
else if (capacitySchedulerPolicyArg == "static_batch")
|
||||
{
|
||||
capacitySchedulerPolicy = texec::CapacitySchedulerPolicy::kSTATIC_BATCH;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unexpected scheduler policy: " + capacitySchedulerPolicyArg);
|
||||
|
||||
@ -80,7 +80,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
|
||||
kv_cache_type = KVCacheType.CONTINUOUS
|
||||
if hasattr(self, 'kv_cache_type'):
|
||||
kv_cache_type = self.kv_cache_type
|
||||
kv_cache_type = KVCacheType(self.kv_cache_type)
|
||||
else:
|
||||
if hasattr(self, 'paged_kv_cache'):
|
||||
kv_cache_type = KVCacheType.PAGED if self.paged_kv_cache == True else KVCacheType.CONTINUOUS
|
||||
|
||||
@ -282,14 +282,37 @@ private:
|
||||
std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
|
||||
};
|
||||
|
||||
// BlockManager manages overall metadata of KVCacheBlocks in a layer of the
|
||||
// network. Layers are expected to be symmetric, so the metadata can be
|
||||
// reused for all layers of the network.
|
||||
// The array of cache blocks for a layer is called a pool.
|
||||
// Each pool has shape [max_blocks, 2, num_heads, tokens_per_block, head_size].
|
||||
// Size per block and number of blocks per pool are pre-determined and set in
|
||||
// constructor. These should not be changed after.
|
||||
// Block shape is [2, num_heads, tokens_per_block, head_size].
|
||||
// attach metadata to a pool pointer
|
||||
class KVCacheBlockPool
|
||||
{
|
||||
public:
|
||||
SizeType32 numKvHeads;
|
||||
SizeType32 numLayers;
|
||||
SizeType32 blockSize;
|
||||
|
||||
// Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
|
||||
runtime::ITensor::SharedPtr primaryPtr;
|
||||
runtime::ITensor::SharedPtr secondaryPtr;
|
||||
|
||||
KVCacheBlockPool(SizeType32 numKvHeads, SizeType32 numLayers, SizeType32 blockSize,
|
||||
runtime::ITensor::SharedPtr primaryPtr = nullptr, runtime::ITensor::SharedPtr secondaryPtr = nullptr)
|
||||
: numKvHeads(numKvHeads)
|
||||
, numLayers(numLayers)
|
||||
, blockSize(blockSize)
|
||||
, primaryPtr(std::move(primaryPtr))
|
||||
, secondaryPtr(std::move(secondaryPtr))
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// The BlockManager manages the metadata of KVCacheBlocks.
|
||||
// It manages multiple arrays of cache blocks called pools.
|
||||
// Layers with the same number of kv heads are grouped under the same pool.
|
||||
// Each pool has shape [max_blocks, num_layers, 2, num_kv_heads, tokens_pre_block, head_size], where num_layers refers
|
||||
// to the number of layers with the same num_kv_heads that share that pool.
|
||||
// The metadata of KVCacheBlocks is shared between layers, so each block spans all of the managed pool - an allocated
|
||||
// block matches some chunk of memory in each pool. The shape of the chunk in every pool is [2, num_kv_heads,
|
||||
// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
|
||||
// BlockManager maintains a list of free blocks at any time.
|
||||
// Alloc pops off the block at the front, and Free pushes it back to the vector.
|
||||
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
|
||||
@ -300,7 +323,7 @@ public:
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
|
||||
|
||||
explicit BlockManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead,
|
||||
explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
|
||||
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
|
||||
std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF);
|
||||
|
||||
@ -338,7 +361,7 @@ public:
|
||||
|
||||
[[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept
|
||||
{
|
||||
return mFreePrimaryBlocks.size();
|
||||
return mFreePrimaryBlocksSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
|
||||
@ -381,21 +404,26 @@ public:
|
||||
return mTokensPerBlock;
|
||||
}
|
||||
|
||||
//! \brief Get size of one K/V cache block in one layer.
|
||||
//! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] SizeType32 getBlockSize() const
|
||||
//! \brief Get size of one K/V cache block in one layer for the specified pool.
|
||||
//! @details Volume of [numKvHeads, tokensPerBlock, sizePerHead] in the specified pool.
|
||||
[[nodiscard]] SizeType32 getBlockSize(SizeType32 poolIdx) const
|
||||
{
|
||||
return mBlockSize;
|
||||
return mPools.at(poolIdx).blockSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool() const noexcept
|
||||
[[nodiscard]] SizeType32 getNumPools() const noexcept
|
||||
{
|
||||
return mPrimaryPool;
|
||||
return mPools.size();
|
||||
}
|
||||
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool() const noexcept
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool(SizeType32 poolIdx) const
|
||||
{
|
||||
return mSecondaryPool;
|
||||
return mPools.at(poolIdx).primaryPtr;
|
||||
}
|
||||
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr getSecondaryPool(SizeType32 poolIdx) const
|
||||
{
|
||||
return mPools.at(poolIdx).secondaryPtr;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumLayers() const
|
||||
@ -403,10 +431,32 @@ public:
|
||||
return mNumLayers;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumPrimaryBlocks() const
|
||||
{
|
||||
return mNumPrimaryBlocks;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNumSecondaryBlocks() const
|
||||
{
|
||||
return mNumSecondaryBlocks;
|
||||
}
|
||||
|
||||
[[nodiscard]] CacheType getCacheType() const
|
||||
{
|
||||
return mCacheType;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getLayerPoolIdx(SizeType32 layerIdx) const
|
||||
{
|
||||
return mLayerToPool.at(layerIdx);
|
||||
}
|
||||
|
||||
//! \brief Get index in pool to K or V block.
|
||||
//! \param blockId the blockId as returned by getBlockId()
|
||||
//! \param fieldIdx either 0 (K) or 1 (V),
|
||||
[[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(KVCacheBlock::IdType blockId, SizeType32 fieldIdx) const;
|
||||
//! \param poolIdx the index of the pool for which the index is calculated (each pool has different strides)
|
||||
[[nodiscard]] kernels::KVCacheIndex getKOrVBlockIndex(
|
||||
KVCacheBlock::IdType blockId, SizeType32 fieldIdx, SizeType32 poolIdx) const;
|
||||
|
||||
//! \brief Bring offloaded block from secondary to primary memory.
|
||||
//! \details Does nothing of block is already in primary memory.
|
||||
@ -451,7 +501,8 @@ private:
|
||||
void claimLeafBlock(KVCacheBlock& block);
|
||||
|
||||
//! \brief Compute pointer to raw KV block (K & V, all layers).
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(std::shared_ptr<KVCacheBlock> block) const;
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(
|
||||
std::shared_ptr<KVCacheBlock> block, SizeType32 poolIdx) const;
|
||||
|
||||
//! \brief Copy content of src block to dst.
|
||||
void copyBlock(BlockPtr src, BlockPtr dst);
|
||||
@ -460,23 +511,30 @@ private:
|
||||
// Number of blocks in pools
|
||||
SizeType32 mNumPrimaryBlocks;
|
||||
SizeType32 mNumSecondaryBlocks;
|
||||
// List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory,
|
||||
// we maintain separate queues for these.
|
||||
// List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory.
|
||||
// We maintain separate queues for these.
|
||||
// We cache size of each queue instead of calling std::list::size, because size is O(N) function.
|
||||
SizeType32 mFreePrimaryBlocksSize;
|
||||
SizeType32 mFreeSecondaryBlocksSize;
|
||||
FreeBlocksQueue mFreePrimaryBlocks;
|
||||
FreeBlocksQueue mFreeSecondaryBlocks;
|
||||
// List of allocated blocks for each sequences
|
||||
std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
|
||||
// Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
|
||||
runtime::ITensor::SharedPtr mPrimaryPool;
|
||||
runtime::ITensor::SharedPtr mSecondaryPool;
|
||||
|
||||
// Pool per unique numKvHeads in the model
|
||||
std::vector<KVCacheBlockPool> mPools;
|
||||
// Matching of model layers to their pools
|
||||
std::vector<SizeType32> mLayerToPool;
|
||||
|
||||
// Whether offloaded blocks should be onboarded before reuse.
|
||||
bool mOnboardBlocks;
|
||||
// Buffer manager
|
||||
runtime::BufferManager mBufferManager;
|
||||
|
||||
// Size of a single KV heads
|
||||
SizeType32 mSizePerHead;
|
||||
// Number of layers
|
||||
SizeType32 mNumLayers;
|
||||
// Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
SizeType32 mBlockSize;
|
||||
// Used to keep track of number of free blocks during scheduling
|
||||
SizeType32 mSchedulingNumFreeBlocks;
|
||||
// Number of tokens per one block
|
||||
@ -502,12 +560,18 @@ public:
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
|
||||
|
||||
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
KVCacheManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
|
||||
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true,
|
||||
CacheType cacheType = CacheType::kSELF);
|
||||
|
||||
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
|
||||
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true,
|
||||
CacheType cacheType = CacheType::kSELF);
|
||||
|
||||
void allocatePools(nvinfer1::DataType dtype, bool useUvm = false);
|
||||
|
||||
void startScheduling();
|
||||
@ -577,11 +641,11 @@ public:
|
||||
/// @return The number of blocks
|
||||
[[nodiscard]] SizeType32 getNeededBlocksOneStep(LlmRequest const& req, bool twoStepsLookAhead) const;
|
||||
|
||||
/// @brief Function that computes the number of KV cache blocks needed to advance a request to completion (i.e. for
|
||||
/// maxNewTokens)
|
||||
/// @brief Function that computes the number of KV cache blocks remaining to advance a request to completion (i.e.
|
||||
/// for maxNewTokens); the allocated blocks are excluded
|
||||
/// @param req The request for which we need to calculate the number of needed KV cache blocks
|
||||
/// @return The number of blocks
|
||||
[[nodiscard]] SizeType32 getNeededBlocksToCompletion(LlmRequest const& req) const;
|
||||
[[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const;
|
||||
|
||||
void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens);
|
||||
|
||||
@ -603,6 +667,8 @@ public:
|
||||
|
||||
[[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const;
|
||||
|
||||
[[nodiscard]] runtime::ITensor::UniquePtr getLayerToPoolMapping() const;
|
||||
|
||||
void getBlockOffsetsOfBatch(
|
||||
runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const;
|
||||
|
||||
@ -610,18 +676,16 @@ public:
|
||||
SizeType32 copyBlockOffsets(
|
||||
runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
|
||||
|
||||
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] static SizeType32 constexpr calculatePageSize(tensorrt_llm::runtime::ModelConfig const& modelConfig)
|
||||
{
|
||||
return 2 * modelConfig.getNbKvHeads() * modelConfig.getTokensPerBlock() * modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
// numLayers * 2 * numKvHeads * sizePerHead
|
||||
[[nodiscard]] static SizeType32 constexpr calculateCacheSizePerToken(
|
||||
// Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool
|
||||
[[nodiscard]] static SizeType32 calculateCacheSizePerToken(
|
||||
tensorrt_llm::runtime::ModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig)
|
||||
{
|
||||
return modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism()) * 2 * modelConfig.getNbKvHeads()
|
||||
* modelConfig.getSizePerHead();
|
||||
// NOTE: We expect the initialization of modelConfig to have already taken the tp size into account and do not
|
||||
// address it here
|
||||
// consider only local layers for the calculation
|
||||
return modelConfig.getSumLocalKvHeads(
|
||||
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank())
|
||||
* 2 * modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
[[nodiscard]] static std::tuple<SizeType32, SizeType32> const calculateMaxNumBlocks(KvCacheConfig const& config,
|
||||
@ -640,7 +704,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool isCrossKv() const
|
||||
{
|
||||
return mCacheType == CacheType::kCROSS;
|
||||
return mBlockManager.getCacheType() == CacheType::kCROSS;
|
||||
}
|
||||
|
||||
//! \brief Find first new block that must be allocated for context phase and return it's concatenated token vector.
|
||||
@ -691,8 +755,6 @@ private:
|
||||
runtime::ITensor::SharedPtr mSequenceBlockIndices;
|
||||
// Whether to cache KV pages for reuse
|
||||
bool mEnableBlockReuse;
|
||||
// KV cache type (self or cross)
|
||||
CacheType mCacheType;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -91,9 +91,9 @@ private:
|
||||
};
|
||||
|
||||
[[nodiscard]] BlockIterator getBlockBeginIt(
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
|
||||
|
||||
[[nodiscard]] BlockIterator getBlockEndIt(
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam);
|
||||
KVCacheManager const& cacheManager, LlmRequest const& request, SizeType32 beam, SizeType32 poolIdx);
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -400,7 +400,8 @@ public:
|
||||
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.has_value() && mInputTokenExtraIds.value(),
|
||||
"Input token extra ids must be provided when enabling kv cache reuse with prompt table");
|
||||
TLLM_CHECK_WITH_INFO(mInputTokenExtraIds.value()->size() == static_cast<size_t>(mOrigPromptLen),
|
||||
"inputTokenExtraIds vector size must be the same as input token vector size.");
|
||||
"inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
|
||||
mInputTokenExtraIds.value()->size(), static_cast<size_t>(mOrigPromptLen));
|
||||
}
|
||||
}
|
||||
|
||||
@ -411,7 +412,7 @@ public:
|
||||
|
||||
/// @brief Get the params of the context
|
||||
/// @return The params of the context
|
||||
std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
|
||||
[[nodiscard]] std::optional<executor::ContextPhaseParams> const& getContextPhaseParams() const noexcept
|
||||
{
|
||||
return mContextPhaseParams;
|
||||
}
|
||||
@ -423,10 +424,10 @@ public:
|
||||
|
||||
/// @brief Get the state params of the context
|
||||
/// @return The state params of the context
|
||||
executor::ContextPhaseState const& getContextPhaseState() const
|
||||
[[nodiscard]] executor::DataTransceiverState const& getDataTransceiverState() const
|
||||
{
|
||||
TLLM_CHECK(mContextPhaseParams.has_value());
|
||||
return *static_cast<executor::ContextPhaseState const*>(mContextPhaseParams.value().getState());
|
||||
return *static_cast<executor::DataTransceiverState const*>(mContextPhaseParams.value().getState());
|
||||
}
|
||||
|
||||
/// @brief Get total number of tokens for this req (prompt + generated)
|
||||
@ -659,6 +660,11 @@ public:
|
||||
return mSequenceIndex > 0;
|
||||
}
|
||||
|
||||
[[nodiscard]] RequestIdType getParentRequestId() const
|
||||
{
|
||||
return mParentRequestId;
|
||||
}
|
||||
|
||||
/// @brief Return a vector of the last-generated tokens of shape [num_beams]
|
||||
[[nodiscard]] VecTokens const& getLastTokens()
|
||||
{
|
||||
@ -858,14 +864,46 @@ public:
|
||||
return mOrigPromptLen;
|
||||
}
|
||||
|
||||
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen)
|
||||
[[nodiscard]] SizeType32 getPromptLen() const
|
||||
{
|
||||
mPrepopulatedPromptLen = prepopulatedPromptLen;
|
||||
return mPromptLen;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
|
||||
void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
|
||||
{
|
||||
return mPrepopulatedPromptLen;
|
||||
auto const promptLen = getPromptLen();
|
||||
TLLM_CHECK(prepopulatedPromptLen < promptLen);
|
||||
|
||||
if (prepopulatedPromptLen > 0)
|
||||
{
|
||||
// Currently, the runtime process is to apply for cache first and then determine prepopulation.
|
||||
// Use the prepopulated length to advance the context position and decrease chunk size if necessary.
|
||||
if (isFullContextRequest())
|
||||
{
|
||||
setContextCurrentPosition(prepopulatedPromptLen);
|
||||
setContextChunkSize(promptLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
auto chunkSize = getContextChunkSize();
|
||||
if (prepopulatedPromptLen + chunkSize < promptLen)
|
||||
{
|
||||
// make sure to end at block boundary after current chunk
|
||||
auto const flooredEndPosition
|
||||
= (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
|
||||
chunkSize = flooredEndPosition - prepopulatedPromptLen;
|
||||
TLLM_CHECK(chunkSize <= getContextChunkSize());
|
||||
}
|
||||
setContextCurrentPosition(prepopulatedPromptLen);
|
||||
setContextChunkSize(chunkSize);
|
||||
}
|
||||
if (!isLastContextChunk())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
|
||||
"To prevent cache fragmentation, the context position after current chunk should be divisible "
|
||||
"by the number of tokens per block, except for the last chunk.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void setDraftTokens(std::shared_ptr<VecTokens> const& draftTokens)
|
||||
@ -1276,7 +1314,7 @@ public:
|
||||
}
|
||||
// TODO: fill the rank ids
|
||||
result.contextPhaseParams = executor::ContextPhaseParams{
|
||||
std::move(firstGenTokens), mContextPhaseParams.value().releaseState()};
|
||||
std::move(firstGenTokens), mRequestId, mContextPhaseParams.value().releaseState()};
|
||||
}
|
||||
|
||||
auto const calculateNbTokensOut = [this](SizeType32 maxNbTokens)
|
||||
@ -1513,8 +1551,8 @@ private:
|
||||
{
|
||||
if (mInputTokenExtraIds.value()->size() != inputTokens.size())
|
||||
{
|
||||
std::string errStr = "inputTokenExtraIds vector size must be the same as input token vector size.";
|
||||
TLLM_THROW(errStr);
|
||||
TLLM_THROW("inputTokenExtraIds vector size (%lu) must be the same as input token vector size (%lu).",
|
||||
mInputTokenExtraIds.value()->size(), inputTokens.size());
|
||||
}
|
||||
VecTokenExtraIds tokenExtraIds = *mInputTokenExtraIds.value();
|
||||
for (std::size_t i = 0; i < inputTokens.size(); ++i)
|
||||
|
||||
@ -161,7 +161,7 @@ inline std::optional<bool> isCudaLaunchBlocking()
|
||||
return result;
|
||||
}
|
||||
|
||||
inline void syncAndCheck(char const* const file, int const line)
|
||||
inline bool doCheckError()
|
||||
{
|
||||
auto const cudaLaunchBlocking = isCudaLaunchBlocking();
|
||||
#ifndef NDEBUG
|
||||
@ -170,10 +170,15 @@ inline void syncAndCheck(char const* const file, int const line)
|
||||
bool const checkError = cudaLaunchBlocking.value_or(false);
|
||||
#endif
|
||||
|
||||
if (checkError)
|
||||
return checkError;
|
||||
}
|
||||
|
||||
inline void syncAndCheck(char const* const file, int const line)
|
||||
{
|
||||
if (doCheckError())
|
||||
{
|
||||
cudaError_t result = cudaDeviceSynchronize();
|
||||
check(result, "cudaDeviceSynchronize", file, line);
|
||||
check(cudaGetLastError(), "cudaGetLastError", file, line);
|
||||
check(cudaDeviceSynchronize(), "cudaDeviceSynchronize", file, line);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -380,6 +380,10 @@ public:
|
||||
|
||||
void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
|
||||
void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
|
||||
|
||||
void allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
|
||||
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const;
|
||||
|
||||
void barrier() const;
|
||||
|
||||
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
|
||||
|
||||
@ -43,7 +43,7 @@ char const* version() noexcept;
|
||||
|
||||
class Model;
|
||||
class Serialization;
|
||||
class ContextPhaseState;
|
||||
class DataTransceiverState;
|
||||
|
||||
/// @brief Sampling configuration
|
||||
class SamplingConfig
|
||||
@ -283,8 +283,10 @@ private:
|
||||
class ContextPhaseParams
|
||||
{
|
||||
public:
|
||||
explicit ContextPhaseParams(VecTokens firstGenTokens);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, void* state);
|
||||
using RequestIdType = std::uint64_t;
|
||||
|
||||
explicit ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId);
|
||||
ContextPhaseParams(VecTokens firstGenTokens, RequestIdType reqId, void* state);
|
||||
|
||||
ContextPhaseParams(ContextPhaseParams const&);
|
||||
ContextPhaseParams(ContextPhaseParams&&);
|
||||
@ -295,6 +297,8 @@ public:
|
||||
|
||||
[[nodiscard]] VecTokens const& getFirstGenTokens() const& noexcept;
|
||||
[[nodiscard]] VecTokens popFirstGenTokens() && noexcept;
|
||||
[[nodiscard]] RequestIdType getReqId() const noexcept;
|
||||
|
||||
[[nodiscard]] void const* getState() const noexcept;
|
||||
[[nodiscard]] void* getState() noexcept;
|
||||
[[nodiscard]] void* releaseState() noexcept;
|
||||
@ -304,6 +308,9 @@ private:
|
||||
static void deleter(void const* data);
|
||||
using StatePtr = std::unique_ptr<void, decltype(&deleter)>;
|
||||
|
||||
/// @brief This request corresponds to the request ID in the context phase.
|
||||
RequestIdType mReqId{0};
|
||||
|
||||
/// @brief The first tokens generated by context executor
|
||||
VecTokens mFirstGenTokens;
|
||||
|
||||
@ -593,18 +600,24 @@ private:
|
||||
class ExtendedRuntimePerfKnobConfig
|
||||
{
|
||||
public:
|
||||
explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false);
|
||||
explicit ExtendedRuntimePerfKnobConfig(bool multiBlockMode = true, bool enableContextFMHAFP32Acc = false,
|
||||
bool cudaGraphMode = false, SizeType32 cudaGraphCacheSize = 0);
|
||||
|
||||
bool operator==(ExtendedRuntimePerfKnobConfig const& other) const
|
||||
{
|
||||
return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc;
|
||||
return mMultiBlockMode == other.mMultiBlockMode && mEnableContextFMHAFP32Acc == other.mEnableContextFMHAFP32Acc
|
||||
&& mCudaGraphMode == other.mCudaGraphMode && mCudaGraphCacheSize == other.mCudaGraphCacheSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool getMultiBlockMode() const;
|
||||
[[nodiscard]] bool getEnableContextFMHAFP32Acc() const;
|
||||
[[nodiscard]] bool getCudaGraphMode() const;
|
||||
[[nodiscard]] SizeType32 getCudaGraphCacheSize() const;
|
||||
|
||||
void setMultiBlockMode(bool multiBlockMode);
|
||||
void setEnableContextFMHAFP32Acc(bool enableContextFMHAFP32Acc);
|
||||
void setCudaGraphMode(bool cudaGraphMode);
|
||||
void setCudaGraphCacheSize(SizeType32 cacheSize);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -614,6 +627,13 @@ private:
|
||||
|
||||
/// @brief If enable FMHA runner FP32 accumulation.
|
||||
bool mEnableContextFMHAFP32Acc;
|
||||
|
||||
/// @brief Control if enable cuda graph.
|
||||
bool mCudaGraphMode;
|
||||
|
||||
/// @brief Number of cuda graphs to be cached in the runtime.
|
||||
/// The larger the cache, the better the perf, but more GPU memory is consumed.
|
||||
SizeType32 mCudaGraphCacheSize;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for debugging output
|
||||
|
||||
@ -75,10 +75,10 @@ public:
|
||||
static void serialize(kv_cache::CacheState const& state, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(kv_cache::CacheState const& state);
|
||||
|
||||
// ContextPhaseState
|
||||
[[nodiscard]] static ContextPhaseState deserializeContextPhaseState(std::istream& is);
|
||||
static void serialize(ContextPhaseState const& contextPhaseState, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(ContextPhaseState const& contextPhaseState);
|
||||
// DataTransceiverState
|
||||
[[nodiscard]] static DataTransceiverState deserializeDataTransceiverState(std::istream& is);
|
||||
static void serialize(DataTransceiverState const& dataTransceiverState, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(DataTransceiverState const& dataTransceiverState);
|
||||
|
||||
// ContextPhaseParams
|
||||
[[nodiscard]] static ContextPhaseParams deserializeContextPhaseParams(std::istream& is);
|
||||
|
||||
@ -198,6 +198,10 @@ enum class CapacitySchedulerPolicy
|
||||
/// @brief GUARANTEED_NO_EVICT uses KV cache more conservatively guaranteeing that a request, once started, will run
|
||||
/// to completion without eviction.
|
||||
kGUARANTEED_NO_EVICT = 1,
|
||||
|
||||
/// @brief kSTATIC_BATCH does not schedule new requests until all requests in current batch are completed.
|
||||
/// Similar to kGUARANTEED_NO_EVICT, requests will run to completion without eviction.
|
||||
kSTATIC_BATCH = 2
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& os, CapacitySchedulerPolicy policy);
|
||||
|
||||
@ -62,12 +62,12 @@ public:
|
||||
void newRequests(std::vector<SizeType32> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
|
||||
std::vector<SamplingConfig> const& samplingConfigs) override;
|
||||
|
||||
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
|
||||
void forwardSync(decoder_batch::Token const& token) override;
|
||||
void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent) override;
|
||||
|
||||
void forwardSync(
|
||||
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
void forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent, decoder_batch::Output& output,
|
||||
decoder_batch::Input const& input) override;
|
||||
|
||||
void forwardAsync(decoder::Output& output, decoder::Input const& input) override;
|
||||
|
||||
@ -271,7 +271,7 @@ private:
|
||||
void newRequestExplicitDraftTokens(SizeType32 batchIdx, decoder_batch::Request const& request);
|
||||
|
||||
//! @brief Updates finished state on host for all active requests
|
||||
void updateFinished(decoder_batch::Token const& token);
|
||||
void updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent);
|
||||
|
||||
//! @brief Sets inputs for explicit draft tokens.
|
||||
void setExplicitDraftTokensInputs(decoder_batch::Input const& input);
|
||||
@ -289,7 +289,7 @@ private:
|
||||
CudaStreamPtr mRuntimeStream;
|
||||
CudaStreamPtr mDecoderStream;
|
||||
BufferManager mBufferManager;
|
||||
TokenPtr mForwardToken;
|
||||
DecoderFinishedEventPtr mDecoderFinishEvent;
|
||||
CudaEvent mForwardEvent;
|
||||
|
||||
using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
|
||||
|
||||
@ -75,11 +75,11 @@ public:
|
||||
|
||||
using Output = decoder::Output;
|
||||
|
||||
// TODO: is this a bad name to mix up with token concept in LLM? Would 'Event' be better? And should move to common.h
|
||||
class Token
|
||||
// used just as a container for easy returning / passing to function
|
||||
class DecoderFinishedEvent
|
||||
{
|
||||
public:
|
||||
explicit Token(CudaEvent&& event, std::vector<bool> const& active)
|
||||
explicit DecoderFinishedEvent(CudaEvent&& event, std::vector<bool> const& active)
|
||||
: event(std::move(event))
|
||||
, active(active)
|
||||
{
|
||||
@ -96,7 +96,7 @@ class IGptDecoderBatched : public virtual IStatefulGptDecoder
|
||||
public:
|
||||
using CudaStreamPtr = std::shared_ptr<CudaStream>;
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
|
||||
using DecoderFinishedEventPtr = std::unique_ptr<decoder_batch::DecoderFinishedEvent const>;
|
||||
|
||||
//! @brief Setup buffers for ExplicitDraftTokens decoding.
|
||||
virtual void setupExplicitDraftTokens(ExplicitDraftTokensBuffers::Inputs explicitDraftTokensBuffers) = 0;
|
||||
@ -105,15 +105,15 @@ public:
|
||||
virtual void setupLookahead(LookaheadDecodingBuffers lookaheadDecodingBuffers) = 0;
|
||||
|
||||
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
|
||||
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
|
||||
virtual DecoderFinishedEventPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
|
||||
|
||||
//! @brief Call decoder forwardSync and wait for the call to `forwardAsync` associated with a token to complete.
|
||||
virtual void forwardSync(
|
||||
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token, decoder_batch::Output& output,
|
||||
decoder_batch::Input const& input)
|
||||
= 0;
|
||||
|
||||
//! @brief Wait for the call to `forwardAsync` associated with a token to complete.
|
||||
virtual void forwardSync(decoder_batch::Token const& token) = 0;
|
||||
virtual void forwardSync(decoder_batch::DecoderFinishedEvent const& token) = 0;
|
||||
|
||||
//! @brief Run one step for all requests and wait for completion on the host.
|
||||
virtual void forward(decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
|
||||
@ -62,6 +62,7 @@ public:
|
||||
TensorMap& inputBuffers, TensorMap& outputBuffers, runtime::WorldConfig const& worldConfig) const;
|
||||
|
||||
public:
|
||||
TensorPtr cumSumLength; // [1] the cumulative sum of generation length, on pinned
|
||||
TensorPtr packedMasksDevice; // [forwardBatchSize, tokensPerStep, numPackedMasks], on gpu
|
||||
TensorPtr generationLengthsDevice; // [forwardBatchSize], on gpu
|
||||
TensorPtr positionOffsetsDevice; // [forwardBatchSize, tokensPerStep], on gpu
|
||||
|
||||
@ -179,7 +179,7 @@ public:
|
||||
|
||||
static std::vector<LoraModule> createLoraModules(std::vector<std::string> const& loraModuleNames,
|
||||
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
|
||||
SizeType32 attentionHeadSize, SizeType32 tpSize);
|
||||
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts);
|
||||
|
||||
static ModuleType constexpr toModuleType(std::string_view const& name)
|
||||
{
|
||||
|
||||
@ -60,6 +60,9 @@ public:
|
||||
{
|
||||
kATTENTION,
|
||||
kRECURRENT,
|
||||
// NOTE: Linear and noop are attention alternatives introduced in Nemotron-NAS. They do not use the KV cache.
|
||||
kLINEAR,
|
||||
kNOOP,
|
||||
};
|
||||
|
||||
enum class KVCacheType : std::int32_t
|
||||
@ -97,13 +100,13 @@ public:
|
||||
kEnabled,
|
||||
};
|
||||
|
||||
explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbAttentionLayers, SizeType32 nbRnnLayers, SizeType32 nbHeads,
|
||||
SizeType32 hiddenSize, nvinfer1::DataType dtype)
|
||||
explicit ModelConfig(SizeType32 vocabSize, SizeType32 nbLayers, SizeType32 nbAttentionLayers,
|
||||
SizeType32 nbRnnLayers, SizeType32 nbHeads, SizeType32 hiddenSize, nvinfer1::DataType dtype)
|
||||
: mVocabSize(vocabSize)
|
||||
, mNbLayers(nbLayers)
|
||||
, mNbAttentionLayers(nbAttentionLayers)
|
||||
, mNbRnnLayers(nbRnnLayers)
|
||||
, mNbHeads(nbHeads)
|
||||
, mNbKvHeads(nbHeads)
|
||||
, mHiddenSize(hiddenSize)
|
||||
, mSizePerHead(mHiddenSize / mNbHeads)
|
||||
, mDataType(dtype)
|
||||
@ -134,6 +137,10 @@ public:
|
||||
, mUseShapeInference(true)
|
||||
, mManageWeightsType(ManageWeightsType::kDisabled)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers,
|
||||
"Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers,
|
||||
mNbAttentionLayers, mNbRnnLayers);
|
||||
setNbKvHeads(mNbHeads);
|
||||
}
|
||||
|
||||
[[nodiscard]] static std::vector<SizeType32> getOptProfilesSplitPoints() noexcept
|
||||
@ -151,14 +158,55 @@ public:
|
||||
return (mVocabSize + worldSize - 1) / worldSize * worldSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getNbAttentionLayers(SizeType32 pipelineParallelism = 1) const
|
||||
[[nodiscard]] SizeType32 countLocalLayers(
|
||||
LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
return mNbAttentionLayers / pipelineParallelism;
|
||||
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
|
||||
auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
|
||||
auto const firstLocalLayerIt = mLayerTypes.cbegin() + (numLocalLayers * pipelineParallelismRank);
|
||||
return std::count(firstLocalLayerIt, firstLocalLayerIt + numLocalLayers, layerType);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getNbRnnLayers(SizeType32 pipelineParallelism = 1) const
|
||||
[[nodiscard]] SizeType32 countLowerRankLayers(
|
||||
LayerType layerType, SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
return mNbRnnLayers / pipelineParallelism;
|
||||
auto const numLocalLayers = mNbLayers / pipelineParallelism; // WARNING: assume no remainder
|
||||
auto const firstLocalLayer = numLocalLayers * pipelineParallelismRank;
|
||||
// count number of previous non-local attention layers
|
||||
return std::count(mLayerTypes.cbegin(), mLayerTypes.cbegin() + firstLocalLayer, layerType);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNbLayers(SizeType32 pipelineParallelism = 1) const
|
||||
{
|
||||
return mNbLayers / pipelineParallelism; // WARNING: assume no remainder
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNbAttentionLayers(
|
||||
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
// TODO(oargov): get rid of this invalid state
|
||||
if (mLayerTypes.empty())
|
||||
{
|
||||
// this assumption might be wrong in a few cases, for example:
|
||||
// layer types: [attention, recurrent, recurrent], pp=2 ==> first rank has 1 attention layer, not 0
|
||||
TLLM_LOG_DEBUG("Assuming uniform distribution of attention layers between ranks");
|
||||
return mNbAttentionLayers / pipelineParallelism;
|
||||
}
|
||||
return countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getNbRnnLayers(
|
||||
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
// TODO(oargov): get rid of this invalid state
|
||||
if (mLayerTypes.empty())
|
||||
{
|
||||
// this assumption might be wrong in a few cases, for example:
|
||||
// layer types: [attention, attention, recurrent], pp=2 ==> second rank has 1 rnn layer, not 0
|
||||
TLLM_LOG_DEBUG("Assuming uniform distribution of recurrent layers between ranks");
|
||||
return mNbRnnLayers / pipelineParallelism;
|
||||
}
|
||||
return countLocalLayers(LayerType::kRECURRENT, pipelineParallelism, pipelineParallelismRank);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getNbHeads() const noexcept
|
||||
@ -166,14 +214,16 @@ public:
|
||||
return mNbHeads;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getNbKvHeads() const noexcept
|
||||
[[nodiscard]] SizeType32 getNbKvHeads(SizeType32 layerIdx) const
|
||||
{
|
||||
return mNbKvHeads;
|
||||
TLLM_CHECK_WITH_INFO(layerIdx < mNbAttentionLayers, "Layer index %d is out of bounds", layerIdx);
|
||||
return mNumKvHeadsPerAttentionLayer[layerIdx];
|
||||
}
|
||||
|
||||
void constexpr setNbKvHeads(SizeType32 nbKvHeads) noexcept
|
||||
// set the number of kv heads for all layers
|
||||
void setNbKvHeads(SizeType32 nbKvHeads)
|
||||
{
|
||||
mNbKvHeads = nbKvHeads;
|
||||
mNumKvHeadsPerAttentionLayer = std::vector<SizeType32>(mNbAttentionLayers, nbKvHeads);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 constexpr getHiddenSize() const noexcept
|
||||
@ -645,12 +695,46 @@ public:
|
||||
mModelName = modelName;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::vector<SizeType32> const& getNumKvHeadsPerLayer() const
|
||||
{
|
||||
return mNumKvHeadsPerAttentionLayer;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::pair<std::vector<SizeType32>::const_iterator, std::vector<SizeType32>::const_iterator>
|
||||
getNumKvHeadsPerLayerLocalRange(SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(pipelineParallelism > 0, "Invalid pipelineParallelism: %d", pipelineParallelism);
|
||||
// count number of previous non-local attention layers
|
||||
auto const numPrevAttnLayers
|
||||
= countLowerRankLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
|
||||
auto const firstLocalAttentionLayerIt = mNumKvHeadsPerAttentionLayer.cbegin() + numPrevAttnLayers;
|
||||
auto const numLocalAttentionLayers
|
||||
= countLocalLayers(LayerType::kATTENTION, pipelineParallelism, pipelineParallelismRank);
|
||||
return std::make_pair(firstLocalAttentionLayerIt, firstLocalAttentionLayerIt + numLocalAttentionLayers);
|
||||
}
|
||||
|
||||
void setNumKvHeadsPerLayer(std::vector<SizeType32> const& headsPerLayer)
|
||||
{
|
||||
auto const numElems = static_cast<SizeType32>(headsPerLayer.size());
|
||||
TLLM_CHECK_WITH_INFO(numElems == mNbAttentionLayers,
|
||||
"Length of head_per_layer (%d) must match number of attention layers (%d)", numElems, mNbAttentionLayers);
|
||||
mNumKvHeadsPerAttentionLayer = headsPerLayer;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType32 getSumLocalKvHeads(
|
||||
SizeType32 pipelineParallelism = 1, SizeType32 pipelineParallelismRank = 0) const
|
||||
{
|
||||
auto [cbegin, cend] = getNumKvHeadsPerLayerLocalRange(pipelineParallelism, pipelineParallelismRank);
|
||||
auto const sumLocalHeads = std::reduce(cbegin, cend);
|
||||
return sumLocalHeads;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType32 mVocabSize;
|
||||
SizeType32 mNbLayers;
|
||||
SizeType32 mNbAttentionLayers;
|
||||
SizeType32 mNbRnnLayers;
|
||||
SizeType32 mNbHeads;
|
||||
SizeType32 mNbKvHeads;
|
||||
SizeType32 mHiddenSize;
|
||||
SizeType32 mSizePerHead;
|
||||
nvinfer1::DataType mDataType;
|
||||
@ -703,6 +787,7 @@ private:
|
||||
bool mUseShapeInference;
|
||||
ManageWeightsType mManageWeightsType;
|
||||
std::string mModelName;
|
||||
std::vector<SizeType32> mNumKvHeadsPerAttentionLayer;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -97,8 +97,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr variableDraftLength() const
|
||||
{
|
||||
// Add Lookahead, when lookahead supports it.
|
||||
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens);
|
||||
return anyBitSet(kDraftTokensExternal | kExplicitDraftTokens | kLookaheadDecoding);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr hasDraftLogits() const
|
||||
|
||||
@ -348,9 +348,11 @@ endif()
|
||||
if(NOT WIN32) # Unix-like compilers
|
||||
set(UNDEFINED_FLAG "-Wl,--no-undefined")
|
||||
set(AS_NEEDED_FLAG "-Wl,--as-needed")
|
||||
set(NO_AS_NEEDED_FLAG "-Wl,--no-as-needed")
|
||||
else() # Windows
|
||||
set(UNDEFINED_FLAG "")
|
||||
set(AS_NEEDED_FLAG "")
|
||||
set(NO_AS_NEEDED_FLAG "")
|
||||
endif()
|
||||
|
||||
set(CMAKE_WINDOWS_EXPORT_ALL_SYMBOLS ON)
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e08b60b89bb4934490ee61383c55c22d831fa1cfcccedea5735400e3574aadbc
|
||||
size 4671466
|
||||
oid sha256:10b940475c5acd80a61674d8ce4e42cc4ef3d806bafb245bbed26751378274e3
|
||||
size 4904726
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2b6b3bf449c4b4d67f0bb9879af6b8eda6f46f272eaa5b7305582a2cc8c73e17
|
||||
size 4775694
|
||||
oid sha256:b2754f7887a1b5c37ba3d589320e16144039cfe5dc6a6c78ee71925861d7d511
|
||||
size 5015842
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
f229593e4699180b52e38f99c8ac31dc libtensorrt_llm_batch_manager_static.a
|
||||
440b3ae47982d88fc8517c5f01f67b3c libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
ff71eabd0ac6ede5398b5b6ce4e26dcf libtensorrt_llm_batch_manager_static.a
|
||||
846eb112a182973e7c3b0b193300b4b8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cb9d3d05ef4b08df0fc02f39c053a4435b58f9431d1ce269439b2c1f0a055b21
|
||||
size 4523116
|
||||
oid sha256:13b8701dd767b414a5376a91905985979ad9d2b975465ac00835c04656ee6508
|
||||
size 4766226
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6597b35faffe93244d89595dc369ece59729c984871ad5aab531d714d39c8e49
|
||||
size 4487214
|
||||
oid sha256:cd0b73a017fc5c663235dcd724eb104ecc49d12ff29b6e3744be6ea952d027db
|
||||
size 4722522
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
c015186a31c789891a27e44f5a9ab9ec libtensorrt_llm_batch_manager_static.a
|
||||
cac21708838abf82b18e1846c40b5c79 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
1eb5c88f894f3361445d7254cbc29b03 libtensorrt_llm_batch_manager_static.a
|
||||
4e73341b23e8fb20b732ba08e03a54a8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:85f18d38a66cd8b15c7e447be16171b9db854f2b2fe9dc49daa4f93fae9bc125
|
||||
size 30145896
|
||||
oid sha256:b4ac61c0b0816477c11bd6c66ec4c2f23f7b6e1400eacd8c07c333f79dec0bea
|
||||
size 30794956
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
8ebb7d383e97bcd738cc24b00d58a2d0 tensorrt_llm_batch_manager_static.lib
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
eefe7310a60098897724f46cf4aa54f8 tensorrt_llm_batch_manager_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -314,6 +314,18 @@ void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType d
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::allgatherv(void const* sendbuf, int sendcount, MpiType sendtype, void* recvbuf,
|
||||
std::vector<int> const& recvcounts, std::vector<int> const& displs, MpiType recvtype) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
MPICHECK(MPI_Allgatherv(sendbuf, sendcount, getMpiDtype(sendtype), recvbuf, recvcounts.data(), displs.data(),
|
||||
getMpiDtype(recvtype), mComm));
|
||||
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:41eae80923f8634a635f2fce84fdbe33101ee6cf86c0a98ed4ce30a7f4cea350
|
||||
size 1782460
|
||||
oid sha256:ebab2cc2c62a826ddec02597178b8e0c9bc316726f37f8eef37c06795aebcf03
|
||||
size 1784658
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:dd4fb660b1c3e664a012e26cea949dcded855e133aa6aadd01157a15df3e0d44
|
||||
size 1808956
|
||||
oid sha256:4b630f89708614e63c67871e21b6e32bfde71acc51549b650c57048c0fa343e7
|
||||
size 1812686
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
2744c4784cfd34c7311148c7f7614757 libtensorrt_llm_executor_static.a
|
||||
d56af9e74a9d49e32860d89dcca024d0 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
136f1b9d2168cbb9011a341b267af9a2 libtensorrt_llm_executor_static.a
|
||||
183bd079377d6cd698d46370168a5726 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e9a5c64c347400297bc8b6907b4feaa890305aaf5c1b45ce57aca8fcae3e881f
|
||||
size 1846898
|
||||
oid sha256:e04c76f6441a49db4d3996c62b4055395ae018384d8ee2f02ea5f0c4c0843902
|
||||
size 1853180
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e4b5912ce9a1c13554f4b16d29deb6f2ad51477c56810b758ba488212f8e5dc9
|
||||
size 1757522
|
||||
oid sha256:95ba1a4b6bdcecbb592bbb42b4998bcb0eb1f45a318163635183bcde6950c4bf
|
||||
size 1764982
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
69dc85fe48625b6f8f684487f2048458 libtensorrt_llm_executor_static.a
|
||||
d46f0be3543e24c4df51ae287086ca52 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
dfbd0d424c150253ff758aa5bd37a971 libtensorrt_llm_executor_static.a
|
||||
e82866739fef1d6df8293541967924bf libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fc94de246db88ca06e5a7c20d04f78057cc3c2928b3fa3e79f49e8b9d90b76da
|
||||
size 19683718
|
||||
oid sha256:aa8ba34fb98c5407e3d6944245086158c61b2c784b15c7b923fdd156b942224d
|
||||
size 19670642
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
0e361ba639fa897f489f6d0f48cfe13f tensorrt_llm_executor_static.lib
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
784ad1fabd3d02466f95fbc463b64f5b tensorrt_llm_executor_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -115,6 +115,11 @@ template <typename T>
|
||||
void invokeCumsumLastDim(SizeType32 batchSize, SizeType32 inputLength, void const* __restrict__ input,
|
||||
void* __restrict__ output, void* deviceTempStorage, size_t tempStorageBytes, cudaStream_t stream)
|
||||
{
|
||||
// For empty tensor support
|
||||
if (batchSize == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
if (deviceTempStorage != nullptr) // we need to use DeviceScan
|
||||
{
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,2 +1,2 @@
|
||||
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:6fc6f35c712d83404e40a7840a0c9d1f5157df61df91a7207c4e4131783f4676
|
||||
oid sha256:1471e322bb44cd65b98ee30e0befa32ae4c86e828f0b4fd4f02d4af4e710d08f
|
||||
size 1128448
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e74ab8e65851dfc44e015714fe166f521649b781c85bd0215d42b488218e9ca5
|
||||
oid sha256:e207a8f57b944529163c7ed2ab30639a5f2779c5118602c6ebd50a623d16f845
|
||||
size 3488
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
700fc148d9a0f939e0088bf69e899360 tensorrt_llm_nvrtc_wrapper.lib
|
||||
6ea6ac6dff8793afbd79dd5768daae85 tensorrt_llm_nvrtc_wrapper.dll
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
|
||||
f9b1cc37a27dd0574bb41a2763a97be7 tensorrt_llm_nvrtc_wrapper.dll
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:f4cc1e6f0b6d1e7bc875284275b591d34c707471e636019b4c2904f30798dbc9
|
||||
oid sha256:9117f7cf5eef0ed452c0d0bc79242b84def103e7038c9d3df6e366690801ca92
|
||||
size 25364090
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:658a3f0bc5b9877e5ad447437287908dc9b7df87ae0e86f5338aaf81e26f723e
|
||||
oid sha256:2b04913f9e9029a5ce5a222d5cc7492ff53323a548079d2fb32d5b2aeb0c2268
|
||||
size 25768990
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
979f3165fbc7a68528df6e343cc54e3f libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
68e84c294a658734a8b26d7270540e1d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
d54fb93f256601f4c4ad7f1c8e6e9919 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
71028d801074f11138e890391e48591d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:24358d8eb15a5e802cbee6d2d27735033eedb33091e9355d199229c3ba7b6447
|
||||
oid sha256:d8c685f8ea2f84838dfdbf448eab41c76fe88fe29db0d4a511d6d6d241ad1832
|
||||
size 44173632
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4b54efb78e98bf6580f73a3b6b689823d7c2eb851c0bab5f36906f4ebbfc44fc
|
||||
size 43561142
|
||||
oid sha256:b9d75392ba3b59853c43072b4f9949b32cb6724813a39048e4585e9a8fb3e136
|
||||
size 43561206
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
0b49d88b8b5e83c8c6997c725a37f373 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
7a12fc880d2a13ee5c7cf2b1e169cb19 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
4fc3e1fb0db6a121f88a9141605d9285 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
253731af750407020dbe6f2fbe50fa2b libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5ba6f8610db3b967f3de4beeff6394cdd3e56d15916f39110ed932c3c3a65417
|
||||
size 88141376
|
||||
oid sha256:62af58f5e09d1cf5e347b02ef3bd3a186469162fc9645d038fb2cba23b597722
|
||||
size 88140804
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
5c26d1347bb8b47288d598b6d7444900 tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
7adf157833793b3215570b0a95b9c4b2998a620c commit
|
||||
eb7fc4a105eb6e6f52ba865f2b055233 tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
@ -210,8 +210,6 @@ struct LookaheadSetupParams : public DecodingSetupParams
|
||||
TensorPtr positionOffsets;
|
||||
//! see LookaheadDecodingOutputs::attentionPackedMasks
|
||||
TensorPtr attentionPackedMasks;
|
||||
//! see LookaheadDecodingOutputs::actualGenerationLengths
|
||||
TensorPtr actualGenerationLengths;
|
||||
};
|
||||
|
||||
class BaseDecodingInputs
|
||||
@ -551,8 +549,6 @@ public:
|
||||
TensorPtr positionOffsets;
|
||||
//! [maxBatchSize, maxDecodingTokens]
|
||||
TensorPtr positionIds;
|
||||
//! The actual decoding tokens length, for debug and for future.
|
||||
TensorPtr actualGenerationLengths;
|
||||
};
|
||||
|
||||
class ExplicitDraftTokensOutputs : public SpeculativeDecodingOutputs
|
||||
|
||||
@ -18,8 +18,12 @@
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/lookaheadModule.h"
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
@ -27,6 +31,36 @@ namespace tensorrt_llm::layers
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
LookaheadAlgorithm::LookaheadAlgorithm(
|
||||
runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id)
|
||||
: mMaxW(maxW)
|
||||
, mMaxN(maxN)
|
||||
, mMaxG(maxG)
|
||||
, mFilling(0)
|
||||
, mPoolManager(maxG)
|
||||
, mId(id)
|
||||
, mGoldenTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
|
||||
, mPrefillsMax(runtime::BufferManager::cpu(
|
||||
runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
|
||||
, mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
|
||||
, mPastTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
|
||||
, mGuessTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32))
|
||||
{
|
||||
runtime::SizeType32 maxGeneratedLen, maxDraftLen;
|
||||
std::tie(maxGeneratedLen, std::ignore, maxDraftLen, std::ignore)
|
||||
= executor::LookaheadDecodingConfig(maxW, maxN, maxG).calculateSpeculativeResource();
|
||||
mAttentionMask = runtime::BufferManager::cpu(
|
||||
runtime::ITensor::makeShape({maxDraftLen, maxDraftLen}), nvinfer1::DataType::kBOOL);
|
||||
mDraftTokensMax
|
||||
= runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mSampledTokensMax
|
||||
= runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxGeneratedLen}), nvinfer1::DataType::kINT32);
|
||||
mEncodeMapMax = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
}
|
||||
|
||||
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
@ -36,7 +70,7 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
|
||||
mW = w;
|
||||
mN = n;
|
||||
mG = g;
|
||||
std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, std::ignore)
|
||||
std::tie(std::ignore, std::ignore, mRuntimeMaxDraftLen, mRuntimeMaxDraftPathLen)
|
||||
= executor::LookaheadDecodingConfig(mW, mN, mG).calculateSpeculativeResource();
|
||||
|
||||
mPoolManager.setup(mG);
|
||||
@ -81,8 +115,8 @@ void LookaheadAlgorithm::accept(TensorConstPtr const& generatedTokens)
|
||||
}
|
||||
|
||||
//! lookahead has two phase, prefill the past tokens matrix and maintain past tokens matrix.
|
||||
runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
|
||||
TensorPtr const& samplingMask, runtime::SizeType32 offset)
|
||||
runtime::SizeType32 LookaheadAlgorithm::lookahead(
|
||||
TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -90,7 +124,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
|
||||
SizeType32 len = prefill + mFilling * mW;
|
||||
TLLM_CHECK(len <= ITensor::volume(draftTokens->getShape()));
|
||||
TLLM_CHECK(len <= ITensor::volume(positionIds->getShape()));
|
||||
TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape()));
|
||||
BufferRange<TokenIdType> prefillRange(*mPrefills);
|
||||
BufferRange<TokenIdType> pastRange(*mPastTokens);
|
||||
BufferRange<TokenIdType> draftRange(*draftTokens);
|
||||
@ -112,11 +145,6 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
|
||||
}
|
||||
|
||||
BufferRange<TokenIdType> positionIdsRange(*positionIds);
|
||||
BufferRange<bool> samplingMaskRange(*samplingMask);
|
||||
for (auto& v : samplingMaskRange)
|
||||
{
|
||||
v = 0;
|
||||
}
|
||||
SizeType32 idx = 0, wj = 0;
|
||||
auto fillPosition = [&positionIdsRange, &idx](SizeType32 start, SizeType32 len)
|
||||
{
|
||||
@ -127,20 +155,18 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
|
||||
};
|
||||
if (prefill >= 0)
|
||||
{
|
||||
fillPosition(offset, prefill);
|
||||
fillPosition(startPosId, prefill);
|
||||
for (wj = 0; wj < mW; wj++)
|
||||
{
|
||||
fillPosition(offset + prefill + wj, mFilling);
|
||||
samplingMaskRange[prefill + wj * mFilling + mFilling - 1] = true;
|
||||
fillPosition(startPosId + prefill + wj, mFilling);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
fillPosition(offset, mFilling - 1);
|
||||
fillPosition(startPosId, mFilling - 1);
|
||||
for (wj = 1; wj < mW; wj++)
|
||||
{
|
||||
fillPosition(offset - 1 + wj, mFilling);
|
||||
samplingMaskRange[wj * mFilling + mFilling - 1 - 1] = true;
|
||||
fillPosition(startPosId - 1 + wj, mFilling);
|
||||
}
|
||||
}
|
||||
PRINT_VALUES(positionIds);
|
||||
@ -150,7 +176,7 @@ runtime::SizeType32 LookaheadAlgorithm::lookahead(TensorPtr const& draftTokens,
|
||||
}
|
||||
|
||||
runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, TensorPtr const& guessIds,
|
||||
TensorPtr const& samplingMask, runtime::SizeType32 offset, runtime::TokenIdType lastToken)
|
||||
runtime::SizeType32 startPosId, runtime::TokenIdType lastToken)
|
||||
{
|
||||
auto guesses = mPoolManager.guess(lastToken, mW);
|
||||
|
||||
@ -158,67 +184,227 @@ runtime::SizeType32 LookaheadAlgorithm::guess(TensorPtr const& guessTokens, Tens
|
||||
std::for_each(guesses.begin(), guesses.end(), [&len](auto& a) { len += ITensor::volume(a->getShape()); });
|
||||
TLLM_CHECK(len <= ITensor::volume(guessTokens->getShape()));
|
||||
TLLM_CHECK(len <= ITensor::volume(guessIds->getShape()));
|
||||
TLLM_CHECK(len <= ITensor::volume(samplingMask->getShape()));
|
||||
BufferRange<TokenIdType> guessTokensRange(*guessTokens);
|
||||
BufferRange<SizeType32> guessIdsRange(*guessIds);
|
||||
BufferRange<bool> samplingMaskRange(*samplingMask);
|
||||
|
||||
SizeType32 cur = 0;
|
||||
for (auto guess : guesses)
|
||||
{
|
||||
BufferRange<TokenIdType const> guessRange(*guess);
|
||||
std::copy(guessRange.begin(), guessRange.end(), guessTokensRange.begin() + cur);
|
||||
SizeType32 tmp = offset;
|
||||
SizeType32 tmp = startPosId;
|
||||
std::for_each(
|
||||
guessIdsRange.begin() + cur, guessIdsRange.begin() + cur + mN - 1, [&tmp](auto& v) { v = tmp++; });
|
||||
cur += ITensor::volume(guess->getShape());
|
||||
}
|
||||
|
||||
std::for_each(samplingMaskRange.begin(), samplingMaskRange.begin() + len, [](auto& a) { a = true; });
|
||||
|
||||
return len;
|
||||
}
|
||||
|
||||
void LookaheadAlgorithm::posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds)
|
||||
{
|
||||
auto len = ITensor::volume(posIds->getShape());
|
||||
TLLM_CHECK(mask->getDimension<0>() >= len);
|
||||
TLLM_CHECK(mask->getDimension<1>() >= len);
|
||||
auto posIdsRange = BufferRange<SizeType32 const>(*posIds);
|
||||
auto maskLocation = BufferLocation<bool>(*mask);
|
||||
|
||||
for (auto& item : maskLocation)
|
||||
{
|
||||
item = false;
|
||||
}
|
||||
|
||||
if (len > 0)
|
||||
{
|
||||
std::vector<std::pair<SizeType32, SizeType32>> stack;
|
||||
for (auto i = 0; i < len; i++)
|
||||
{
|
||||
auto cur = posIdsRange[i];
|
||||
while (stack.size() > 0 && cur <= stack.back().second)
|
||||
{
|
||||
stack.pop_back();
|
||||
}
|
||||
TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
|
||||
stack.push_back(std::make_pair(i, cur));
|
||||
for (auto prev : stack)
|
||||
{
|
||||
maskLocation.at(i, prev.first) = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct TreeValue;
|
||||
using TreeMap = std::unordered_map<TokenIdType, TreeValue>;
|
||||
|
||||
struct TreeValue
|
||||
{
|
||||
TreeValue()
|
||||
: nexts(std::make_shared<TreeMap>())
|
||||
{
|
||||
}
|
||||
|
||||
using Nexts = std::shared_ptr<TreeMap>;
|
||||
Nexts nexts{nullptr};
|
||||
std::list<SizeType32> sources;
|
||||
};
|
||||
|
||||
using TreeNode = TreeMap::value_type;
|
||||
|
||||
template <typename BF, typename AF>
|
||||
void treeDFS(TreeNode& node, BF const& visitBefore, AF const& visitAfter)
|
||||
{
|
||||
visitBefore(node);
|
||||
for (auto& next : *(node.second.nexts))
|
||||
{
|
||||
treeDFS(next, visitBefore, visitAfter);
|
||||
}
|
||||
visitAfter(node);
|
||||
}
|
||||
|
||||
SizeType32 LookaheadAlgorithm::treeEncode(
|
||||
TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& mask, TensorPtr const& encodeMap)
|
||||
{
|
||||
TLLM_CHECK(ITensor::volume(tokens->getShape()) == ITensor::volume(posIds->getShape()));
|
||||
auto len = ITensor::volume(tokens->getShape());
|
||||
|
||||
BufferRange<TokenIdType> tokensRange(*tokens);
|
||||
BufferRange<SizeType32> posIdsRange(*posIds);
|
||||
BufferLocation<bool> maskLocation(*mask);
|
||||
BufferRange<SizeType32> mapRange(*encodeMap);
|
||||
|
||||
auto branches = std::make_shared<TreeMap>();
|
||||
|
||||
for (auto i = 0; i < len; i++)
|
||||
{
|
||||
auto nexts = branches;
|
||||
for (auto j = 0; j <= i; j++)
|
||||
{
|
||||
if (maskLocation.at(i, j))
|
||||
{
|
||||
auto pos = posIdsRange[j];
|
||||
auto tok = tokensRange[j];
|
||||
auto found = nexts->find(tok);
|
||||
if (found != nexts->end())
|
||||
{
|
||||
found->second.sources.push_back(j);
|
||||
nexts = found->second.nexts;
|
||||
}
|
||||
else
|
||||
{
|
||||
auto [inserted, ok] = nexts->insert({tok, TreeValue()});
|
||||
inserted->second.sources.push_back(j);
|
||||
nexts = inserted->second.nexts;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (auto& item : maskLocation)
|
||||
{
|
||||
item = 0;
|
||||
}
|
||||
std::vector<std::pair<SizeType32, TokenIdType>> stack;
|
||||
SizeType32 offset = 0;
|
||||
SizeType32 posId = posIdsRange.size() ? posIdsRange[0] : 0;
|
||||
|
||||
auto visitBefore
|
||||
= [&stack, &maskLocation, &tokensRange, &posIdsRange, &posId, &offset, &mapRange](TreeNode const& node)
|
||||
{
|
||||
stack.push_back(std::make_pair(offset, node.first));
|
||||
for (auto const& source : node.second.sources)
|
||||
{
|
||||
mapRange[source] = offset;
|
||||
}
|
||||
for (auto const& prev : stack)
|
||||
{
|
||||
maskLocation.at(offset, prev.first) = true;
|
||||
}
|
||||
tokensRange[offset] = node.first;
|
||||
posIdsRange[offset] = posId;
|
||||
offset++;
|
||||
posId++;
|
||||
};
|
||||
auto visitAfter = [&stack, &posId](TreeNode const& node)
|
||||
{
|
||||
stack.pop_back();
|
||||
posId--;
|
||||
};
|
||||
|
||||
for (auto& next : *branches)
|
||||
{
|
||||
treeDFS(next, visitBefore, visitAfter);
|
||||
}
|
||||
|
||||
for (SizeType32 i = offset; i < len; i++)
|
||||
{
|
||||
tokensRange[i] = 0;
|
||||
posIdsRange[i] = 0;
|
||||
}
|
||||
for (SizeType32 i = 0; i < len; i++)
|
||||
{
|
||||
for (SizeType32 j = i < offset ? offset : 0; j < len; j++)
|
||||
{
|
||||
maskLocation.at(i, j) = false;
|
||||
}
|
||||
}
|
||||
|
||||
return offset;
|
||||
}
|
||||
|
||||
void LookaheadAlgorithm::prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds,
|
||||
TensorPtr const& samplingMask, TensorPtr const& length, TensorConstPtr const& offsetPtr,
|
||||
TensorConstPtr const& lastTokenPtr)
|
||||
TensorPtr const& draftLengthPtr, TensorPtr const& attentionMask, SizeType32 attentionMaskOffset,
|
||||
TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (mRuntimeMaxDraftLen == 0)
|
||||
{
|
||||
(BufferRange<SizeType32>(*length))[0] = 0;
|
||||
mDraftTokens = ITensor::slice(mDraftTokensMax, 0, 0);
|
||||
mEncodeMap = ITensor::slice(mEncodeMapMax, 0, 0);
|
||||
(BufferRange<SizeType32>(*draftLengthPtr))[0] = 0;
|
||||
return;
|
||||
}
|
||||
|
||||
auto lastToken = BufferRange<TokenIdType const>(*lastTokenPtr)[0];
|
||||
auto offset = BufferRange<SizeType32 const>(*offsetPtr)[0];
|
||||
auto offset = BufferRange<SizeType32 const>(*lastPositionIdPtr)[0];
|
||||
|
||||
SizeType32 inputLen = ITensor::volume(draftTokens->getShape());
|
||||
TLLM_CHECK(inputLen >= mRuntimeMaxDraftLen);
|
||||
|
||||
BufferRange<TokenIdType> draftRange(*draftTokens);
|
||||
BufferRange<TokenIdType> positionRange(*positionIds);
|
||||
BufferRange<bool> samplingRange(*samplingMask);
|
||||
|
||||
SizeType32 filledLen = 0;
|
||||
|
||||
filledLen += lookahead(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset);
|
||||
|
||||
auto guessStart = filledLen;
|
||||
filledLen += guess(ITensor::slice(draftTokens, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen),
|
||||
ITensor::slice(samplingMask, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
|
||||
ITensor::slice(positionIds, filledLen, mRuntimeMaxDraftLen - filledLen), offset, lastToken);
|
||||
auto guessEnd = filledLen;
|
||||
|
||||
std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
|
||||
BufferRange<TokenIdType>(*mGuessTokensMax).begin());
|
||||
mGuessTokens = ITensor::slice(mGuessTokensMax, 0, guessEnd - guessStart);
|
||||
|
||||
std::copy(draftRange.begin() + guessStart, draftRange.begin() + guessEnd,
|
||||
BufferRange<TokenIdType>(*mGuessTokens).begin());
|
||||
posIdsToMask(mAttentionMask, ITensor::slice(positionIds, 0, filledLen));
|
||||
|
||||
(BufferRange<SizeType32>(*length))[0] = filledLen;
|
||||
auto draftLen = treeEncode(ITensor::slice(draftTokens, 0, filledLen), ITensor::slice(positionIds, 0, filledLen),
|
||||
mAttentionMask, mEncodeMapMax);
|
||||
|
||||
for (SizeType32 i = 0; i < draftLen; i++)
|
||||
{
|
||||
BufferRange<bool> srcRange(*ITensor::at(mAttentionMask, {i}));
|
||||
BufferRange<bool> dstRange(*ITensor::slice(attentionMask, {i + attentionMaskOffset, attentionMaskOffset}));
|
||||
std::copy(srcRange.begin(), srcRange.end(), dstRange.begin());
|
||||
}
|
||||
|
||||
std::copy(draftRange.begin(), draftRange.begin() + draftLen, BufferRange<TokenIdType>(*mDraftTokensMax).begin());
|
||||
mDraftTokens = ITensor::slice(mDraftTokensMax, 0, draftLen);
|
||||
(BufferRange<SizeType32>(*draftLengthPtr))[0] = draftLen;
|
||||
mEncodeMap = ITensor::slice(mEncodeMapMax, 0, filledLen);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -229,29 +415,31 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mGuessTokens->getShape()));
|
||||
TLLM_CHECK(ITensor::volume(goldenTokens->getShape()) == ITensor::volume(mDraftTokens->getShape()));
|
||||
BufferRange<TokenIdType const> goldRange(*goldenTokens);
|
||||
BufferRange<TokenIdType> guessTokensRange(*mGuessTokens);
|
||||
auto guessSize = ITensor::volume(mGuessTokens->getShape());
|
||||
BufferRange<TokenIdType> draftRange(*mDraftTokens);
|
||||
BufferLocation<bool const> maskLocation(*mAttentionMask);
|
||||
auto draftSize = ITensor::volume(mDraftTokens->getShape());
|
||||
auto end = *BufferRange<TokenIdType const>(*endToken).begin();
|
||||
|
||||
SizeType32 guesses = (mN - 1 > 0) ? (guessSize / (mN - 1)) : 0;
|
||||
SizeType32 hit = 0, maxHit = 0, hitIdx = 0;
|
||||
for (SizeType32 i = 0; i < guesses; i++)
|
||||
SizeType32 maxHit = 0, hitIdx = 0;
|
||||
for (SizeType32 i = 0; i < draftSize; i++)
|
||||
{
|
||||
SizeType32 hit = 0;
|
||||
for (SizeType32 j = 0; j < mN - 1; j++)
|
||||
TokenIdType cur = newLastToken;
|
||||
for (SizeType32 j = 0; j < draftSize; j++)
|
||||
{
|
||||
auto idx = i * (mN - 1) + j;
|
||||
bool ok
|
||||
= (j == 0) ? (newLastToken == guessTokensRange[idx]) : (goldRange[idx - 1] == guessTokensRange[idx]);
|
||||
bool finish = guessTokensRange[idx] == *BufferRange<TokenIdType const>(*endToken).begin();
|
||||
if (ok && !finish)
|
||||
if (maskLocation.at(i, j))
|
||||
{
|
||||
hit++;
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
if (draftRange[j] == cur && draftRange[j] != end)
|
||||
{
|
||||
hit++;
|
||||
cur = goldRange[j];
|
||||
}
|
||||
else
|
||||
{
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
if (hit > maxHit)
|
||||
@ -261,17 +449,19 @@ void LookaheadAlgorithm::verify(TensorPtr const& accepted, TensorPtr const& acce
|
||||
}
|
||||
}
|
||||
|
||||
BufferRange<TokenIdType> acceptedRange(*accepted);
|
||||
acceptedRange[0] = newLastToken;
|
||||
std::copy(goldRange.begin() + hitIdx * (mN - 1), goldRange.begin() + hitIdx * (mN - 1) + maxHit,
|
||||
acceptedRange.begin() + 1);
|
||||
maxHit = maxHit > mRuntimeMaxDraftPathLen ? mRuntimeMaxDraftPathLen : maxHit;
|
||||
|
||||
SizeType32 acceptedIdx = 0;
|
||||
BufferRange<TokenIdType> acceptedRange(*accepted);
|
||||
BufferRange<SizeType32> acceptedOffsetsRange(*acceptedOffsets);
|
||||
auto lookSize = 1 + mN - 2 - mFilling + mFilling * mW;
|
||||
// acceptedOffsetsRange[0] = 0;
|
||||
for (SizeType32 i = 0; i < maxHit; i++)
|
||||
acceptedRange[acceptedIdx] = newLastToken;
|
||||
for (SizeType32 j = 0; j < draftSize; j++)
|
||||
{
|
||||
acceptedOffsetsRange[i] = lookSize + hitIdx * (mN - 1) + i - 1;
|
||||
if (maskLocation.at(hitIdx, j) && acceptedIdx < maxHit)
|
||||
{
|
||||
acceptedOffsetsRange[acceptedIdx++] = j;
|
||||
acceptedRange[acceptedIdx] = goldRange[j];
|
||||
}
|
||||
}
|
||||
|
||||
*BufferRange<SizeType32>(*acceptedLength).begin() = maxHit + 1;
|
||||
@ -325,7 +515,19 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK(ITensor::volume(acceptedTokens->getShape()) >= mN);
|
||||
BufferRange<TokenIdType const> sampledRange(*sampledTokens);
|
||||
BufferRange<TokenIdType const> zippedTokensRange(*sampledTokens);
|
||||
BufferRange<TokenIdType const> sampledRange(*mSampledTokensMax);
|
||||
|
||||
BufferRange<SizeType32 const> mapRange(*mEncodeMap);
|
||||
BufferRange<TokenIdType> unzipRange(*mSampledTokensMax);
|
||||
mSampledTokens = ITensor::slice(mSampledTokensMax, 0, mEncodeMap->getShape().d[0] + 1);
|
||||
|
||||
unzipRange[0] = zippedTokensRange[0];
|
||||
for (SizeType32 i = 0; i < mapRange.size(); i++)
|
||||
{
|
||||
unzipRange[i + 1] = zippedTokensRange[mapRange[i] + 1];
|
||||
}
|
||||
|
||||
BufferRange<TokenIdType> keyRange(*mKeyTokens);
|
||||
BufferRange<TokenIdType> pastRange(*mPastTokens);
|
||||
|
||||
@ -359,13 +561,15 @@ void LookaheadAlgorithm::update(TensorPtr const& acceptedTokens, TensorPtr const
|
||||
}
|
||||
|
||||
auto guessSize = ITensor::volume(mGuessTokens->getShape());
|
||||
auto outputSize = ITensor::volume(sampledTokens->getShape());
|
||||
auto outputSize = ITensor::volume(mSampledTokens->getShape());
|
||||
auto lookSize = 1 + (mN > 1 ? mN - 2 : 0) - mFilling + mFilling * mW;
|
||||
TLLM_CHECK(guessSize + lookSize == outputSize);
|
||||
|
||||
TensorConstPtr goldenTokens = ITensor::slice(sampledTokens, lookSize, guessSize);
|
||||
TensorConstPtr goldenTokens = ITensor::slice(mSampledTokens, lookSize, guessSize);
|
||||
|
||||
verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, goldenTokens, endToken);
|
||||
auto& acptLen = *BufferRange<SizeType32>(*acceptedLength).begin();
|
||||
|
||||
verify(acceptedTokens, acceptedOffsets, acceptedLength, newLastToken, ITensor::slice(sampledTokens, 1), endToken);
|
||||
|
||||
accept(ITensor::slice(acceptedTokens, 0, *BufferRange<SizeType32>(*acceptedLength).begin()));
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include <curand_kernel.h>
|
||||
#include <tuple>
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
@ -35,24 +36,7 @@ public:
|
||||
//! @brief Currently the resource management is to be aligned with batch manager.
|
||||
//! @param w, n, g is the Jacobi window, n-gram level and guess set size respectively.
|
||||
LookaheadAlgorithm(
|
||||
runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0)
|
||||
: mMaxW(maxW)
|
||||
, mMaxN(maxN)
|
||||
, mMaxG(maxG)
|
||||
, mFilling(0)
|
||||
, mPoolManager(maxG)
|
||||
, mId(id)
|
||||
, mGoldenTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxN * 2 - 1}), nvinfer1::DataType::kINT32))
|
||||
, mPrefillsMax(runtime::BufferManager::cpu(
|
||||
runtime::ITensor::makeShape({(maxN <= 1 ? 0 : maxN - 2)}), nvinfer1::DataType::kINT32))
|
||||
, mKeyTokensMax(runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW}), nvinfer1::DataType::kINT32))
|
||||
, mPastTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxW * (maxN - 1)}), nvinfer1::DataType::kINT32))
|
||||
, mGuessTokensMax(
|
||||
runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxG * (maxN - 1)}), nvinfer1::DataType::kINT32))
|
||||
{
|
||||
}
|
||||
runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0);
|
||||
|
||||
//! @brief setup per request, fill internal states from @param prompt.
|
||||
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g);
|
||||
@ -62,43 +46,55 @@ public:
|
||||
void accept(TensorConstPtr const& generatedTokens);
|
||||
|
||||
//! @brief combine lookahead and guess to prepare the tensors.
|
||||
//! input @param offsetPtr is position id of the last golden token, in a TensorPtr.
|
||||
//! input @param lastPositionIdPtr is position id of the last golden token, in a TensorPtr.
|
||||
//! input @param lastTokenPtr the last golden token for searching in the pool, in a TensorPtr.
|
||||
//! output @param draftTokens, positionIds, samplingMask; including the golden token, the lookahead
|
||||
//! and the verification branch information. @param length holds the draft tokens length.
|
||||
void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& samplingMask,
|
||||
TensorPtr const& length, TensorConstPtr const& offsetPtr, TensorConstPtr const& lastTokenPtr);
|
||||
//! output @param draftTokens, positionIds includes the lookahead and the verification branch information.
|
||||
//! output @param draftLengthPtr holds the draft tokens length.
|
||||
//! output @param attentionMask holds the draft tokens dependency mask, and attentionMaskOffset is the index offset
|
||||
//! in attentionMask.
|
||||
void prepare(TensorPtr const& draftTokens, TensorPtr const& positionIds, TensorPtr const& draftLengthPtr,
|
||||
TensorPtr const& attentionMask, runtime::SizeType32 attentionMaskOffset,
|
||||
TensorConstPtr const& lastPositionIdPtr, TensorConstPtr const& lastTokenPtr);
|
||||
|
||||
//! @brief update the internal states and generate accepted tokens from @param outputTokens.
|
||||
//! input @param sampledTokens is the all the tokens from the language model. The position at samplingMask=1 is
|
||||
//! valid. input @param endToken is the end token for `verify` early quit.
|
||||
//! output @param acceptedTokens, acceptedOffsets ind @param acceptedLength.
|
||||
//! input @param sampledTokens is the all the tokens from the language model.
|
||||
//! input @param endToken is the end token for `verify` early quit.
|
||||
//! output @param acceptedTokens, acceptedOffsets in @param acceptedLength.
|
||||
void update(TensorPtr const& acceptedTokens, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength,
|
||||
TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken);
|
||||
|
||||
//! generate attention @param mask from @param posIds.
|
||||
static void posIdsToMask(TensorPtr const& mask, TensorConstPtr const& posIds);
|
||||
|
||||
//! inplace encode the @param tokens and @param posIds according to attention @param masks, and record the offsets
|
||||
//! in @param encodeMap.
|
||||
static runtime::SizeType32 treeEncode(
|
||||
TensorPtr const& tokens, TensorPtr const& posIds, TensorPtr const& masks, TensorPtr const& encodeMap);
|
||||
|
||||
private:
|
||||
//! @brief generate lookahead branch information.
|
||||
//! input @param offset the position id of the last golden token.
|
||||
//! output @param draftTokens, positionIds, samplingMask of the lookahead branch.
|
||||
//! input @param startPosId is the first position id of the draftTokens.
|
||||
//! output @param draftTokens, positionIds of the lookahead branch.
|
||||
//! @return the actual filled lookahead length.
|
||||
runtime::SizeType32 lookahead(TensorPtr const& draftTokens, TensorPtr const& positionIds,
|
||||
TensorPtr const& samplingMask, runtime::SizeType32 offset);
|
||||
runtime::SizeType32 lookahead(
|
||||
TensorPtr const& draftTokens, TensorPtr const& positionIds, runtime::SizeType32 startPosId);
|
||||
|
||||
//! @brief generate verification branch information. Also save the guessed tokens for future verification.
|
||||
//! input @param offset the position id of the last golden token.
|
||||
//! input @param startPosId the first position id.
|
||||
//! input @param lastToken the last golden token for searching in the pool.
|
||||
//! output @param guessTokens, guessIds, samplingMask of the verification branch.
|
||||
//! output @param guessTokens, guessIds of the verification branch.
|
||||
//! @return the actual filled guess length.
|
||||
runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, TensorPtr const& samplingMask,
|
||||
runtime::SizeType32 offset, runtime::TokenIdType lastToken);
|
||||
runtime::SizeType32 guess(TensorPtr const& guessTokens, TensorPtr const& guessIds, runtime::SizeType32 startPosId,
|
||||
runtime::TokenIdType lastToken);
|
||||
|
||||
//! @brief verify the guessed tokens results and generate the longest accepted tokens.
|
||||
//! input @param newLastToken is the new-generated last golden token.
|
||||
//! input @param goldenTokens is the guessed token results from the language model.
|
||||
//! input @param sampledTokens is the generated token results from the language model.
|
||||
//! input @param endToken is the end token for early quit detection.
|
||||
//! output @param accepted, acceptedOffsets in @param acceptedLength, .
|
||||
//! output @param accepted in @param acceptedLength, including the first golden one.
|
||||
//! output @param acceptedOffsets is the offsets of draft tokens, excluding the first golden one.
|
||||
void verify(TensorPtr const& accepted, TensorPtr const& acceptedOffsets, TensorPtr const& acceptedLength,
|
||||
runtime::TokenIdType newLastToken, TensorConstPtr const& goldenTokens, TensorConstPtr const& endToken);
|
||||
runtime::TokenIdType newLastToken, TensorConstPtr const& sampledTokens, TensorConstPtr const& endToken);
|
||||
|
||||
private:
|
||||
LookaheadPoolManager mPoolManager;
|
||||
@ -117,6 +113,13 @@ private:
|
||||
//! the same guess tokens from `guess` and used in `verify`
|
||||
TensorPtr mGuessTokensMax; // shape [mMaxG*(mMaxN-1)]
|
||||
TensorPtr mGuessTokens; // shape [mG*(mN-1)]
|
||||
TensorPtr mDraftTokensMax;
|
||||
TensorPtr mDraftTokens;
|
||||
TensorPtr mAttentionMask;
|
||||
TensorPtr mEncodeMapMax;
|
||||
TensorPtr mEncodeMap;
|
||||
TensorPtr mSampledTokensMax;
|
||||
TensorPtr mSampledTokens;
|
||||
|
||||
//! look ahead algorithm parameters, Window size, Level and Guess set size.
|
||||
//! max for reserving resources and current for current request.
|
||||
@ -127,6 +130,7 @@ private:
|
||||
runtime::SizeType32 mN{0};
|
||||
runtime::SizeType32 mG{0};
|
||||
runtime::SizeType32 mRuntimeMaxDraftLen{0};
|
||||
runtime::SizeType32 mRuntimeMaxDraftPathLen{0};
|
||||
//! in prefilling mode when mFilling < mN-1.
|
||||
runtime::SizeType32 mFilling;
|
||||
|
||||
|
||||
@ -24,6 +24,7 @@
|
||||
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/lookaheadModule.h"
|
||||
@ -80,14 +81,14 @@ LookaheadDecodingLayer<T>::CpuAlgorithmResources::CpuAlgorithmResources(DecoderD
|
||||
mNextDraftTokens = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mNextDraftPosIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mGenerationLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mGenerationLengthsMax = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mPositionOffsets
|
||||
= BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
|
||||
mPositionIds = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep}), nvinfer1::DataType::kINT32);
|
||||
mAttentionMask
|
||||
= BufferManager::cpu(ITensor::makeShape({maxTokensPerStep, maxTokensPerStep}), nvinfer1::DataType::kBOOL);
|
||||
mPackedMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxTokensPerStep,
|
||||
static_cast<ITensor::DimType64>(divUp(maxTokensPerStep, 32))}),
|
||||
nvinfer1::DataType::kINT32);
|
||||
mSamplingMask = BufferManager::cpu(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kBOOL);
|
||||
mNextDraftLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mSequenceLengths = BufferManager::cpu(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
}
|
||||
@ -113,7 +114,6 @@ LookaheadDecodingLayer<T>::LookaheadDecodingLayer(
|
||||
|
||||
mWorkspaceSize = getTopKWorkspaceSize<T>(maxBatchSize, maxTokensPerStep, maxTopK, vocabSizePadded);
|
||||
mTargetTokensDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kINT32);
|
||||
mSamplingMaskDevice = mBufferManager->gpu(maxBatchShape2D, nvinfer1::DataType::kBOOL);
|
||||
mCurandStatesDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}), nvinfer1::DataType::kINT8);
|
||||
|
||||
@ -168,6 +168,7 @@ void LookaheadDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi];
|
||||
(BufferRange<SizeType32>(*mCpuAlgo->mGenerationLengths))[gbi] = 1;
|
||||
(BufferRange<SizeType32>(*mCpuAlgo->mNextDraftLengths))[gbi] = 0;
|
||||
BufferLocation<SizeType32>(*mCpuAlgo->mPositionOffsets).at(gbi, 0) = 0;
|
||||
BufferRange<SizeType32> packedMaskRange(*ITensor::at(mCpuAlgo->mPackedMask, {gbi}));
|
||||
for (auto& mask : packedMaskRange)
|
||||
@ -184,11 +185,6 @@ void LookaheadDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth
|
||||
PRINT_SHAPE(setupParams->attentionPackedMasks);
|
||||
mBufferManager->copy(
|
||||
*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}), *ITensor::at(setupParams->generationLengths, {gbi}));
|
||||
if (setupParams->actualGenerationLengths)
|
||||
{
|
||||
mBufferManager->copy(*ITensor::at(mCpuAlgo->mGenerationLengths, {gbi}),
|
||||
*ITensor::at(setupParams->actualGenerationLengths, {gbi}));
|
||||
}
|
||||
mBufferManager->copy(
|
||||
*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}), *ITensor::at(setupParams->positionOffsets, {gbi}));
|
||||
mBufferManager->copy(
|
||||
@ -261,39 +257,32 @@ size_t LookaheadDecodingLayer<T>::getWorkspaceSize() const noexcept
|
||||
return std::max(mWorkspaceSize, mSetupWorkspaceSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LookaheadDecodingLayer<T>::posIdsToMask(TensorPtr mask, TensorConstPtr posIds)
|
||||
inline void initAttentionMask(TensorPtr const& mask, std::shared_ptr<runtime::BufferManager>& bufferManager)
|
||||
{
|
||||
auto len = ITensor::volume(posIds->getShape());
|
||||
TLLM_CHECK(mask->getDimension<0>() > len);
|
||||
TLLM_CHECK(mask->getDimension<1>() * 32 > len);
|
||||
auto posIdsRange = BufferRange<SizeType32 const>(*posIds);
|
||||
auto maskLocation = BufferLocation<SizeType32>(*mask);
|
||||
|
||||
for (auto i = 0; i < maskLocation.size(); i++)
|
||||
bufferManager->setZero(*mask);
|
||||
BufferLocation<bool> maskLocation(*mask);
|
||||
auto maskShape = mask->getShape();
|
||||
for (SizeType32 i = 0; i < maskShape.d[0]; i++)
|
||||
{
|
||||
maskLocation[i] = 0;
|
||||
maskLocation.at(i, 0) = true;
|
||||
}
|
||||
maskLocation.at(0, 0) = 1;
|
||||
}
|
||||
|
||||
auto setBit = [](SizeType32& x, SizeType32 idx) { x |= (1 << idx); };
|
||||
if (len > 0)
|
||||
inline void convertBoolToInt32(TensorPtr const& dst, TensorConstPtr const& src)
|
||||
{
|
||||
auto dstShape = dst->getShape();
|
||||
auto srcShape = src->getShape();
|
||||
TLLM_CHECK(dstShape.d[0] == srcShape.d[0]);
|
||||
TLLM_CHECK(dstShape.d[1] * 32 >= srcShape.d[1]);
|
||||
BufferLocation<SizeType32> dstLocation(*dst);
|
||||
BufferLocation<bool const> srcLocation(*src);
|
||||
|
||||
auto setBit = [](SizeType32& x, SizeType32 idx, bool value) { x |= (value << idx); };
|
||||
for (auto i = 0; i < srcShape.d[0]; i++)
|
||||
{
|
||||
std::vector<std::pair<SizeType32, SizeType32>> stack;
|
||||
stack.emplace_back(0, posIdsRange[0] - 1);
|
||||
for (auto i = 1; i < len + 1; i++)
|
||||
for (auto j = 0; j < srcShape.d[1]; j++)
|
||||
{
|
||||
auto cur = posIdsRange[i - 1];
|
||||
while (stack.size() > 0 && cur <= stack.back().second)
|
||||
{
|
||||
stack.pop_back();
|
||||
}
|
||||
TLLM_CHECK(stack.size() > 0 ? cur == stack.back().second + 1 : true);
|
||||
stack.emplace_back(i, cur);
|
||||
for (auto prev : stack)
|
||||
{
|
||||
setBit(maskLocation.at(i, prev.first / 32), prev.first % 32);
|
||||
}
|
||||
setBit(dstLocation.at(i, j / 32), j % 32, srcLocation.at(i, j));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -307,12 +296,16 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
mCpuAlgo->mBatchSlots->reshape(inputs->batchSlots->getShape());
|
||||
mBufferManager->copy(*inputs->batchSlots, *mCpuAlgo->mBatchSlots);
|
||||
mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep);
|
||||
mBufferManager->copy(*inputs->curTokensPerStep.value(), *mCpuAlgo->mTokensPerStep);
|
||||
mBufferManager->copy(*inputs->endIds, *mCpuAlgo->mEndIds);
|
||||
mBufferManager->copy(*outputs->sequenceLength.value(), *mCpuAlgo->mSequenceLengths);
|
||||
|
||||
mBufferManager->copy(*mTargetTokensDevice, *mCpuAlgo->mTargetTokens);
|
||||
|
||||
if (outputs->prevDraftLengths)
|
||||
{
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->prevDraftLengths);
|
||||
}
|
||||
|
||||
mBufferManager->getStream().synchronize();
|
||||
|
||||
auto const batchSize = inputs->localBatchSize;
|
||||
@ -325,7 +318,6 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
BufferRange<SizeType32> numNewTokensCumSumRange(*mCpuAlgo->mNumNewTokensCumSum);
|
||||
BufferRange<SizeType32> batchSlotsRange(*mCpuAlgo->mBatchSlots);
|
||||
BufferRange<SizeType32> generationLengthsRange(*mCpuAlgo->mGenerationLengths);
|
||||
BufferRange<SizeType32> generationLengthsMaxRange(*mCpuAlgo->mGenerationLengthsMax);
|
||||
BufferRange<SizeType32> nextDraftLengthsRange(*mCpuAlgo->mNextDraftLengths);
|
||||
BufferRange<SizeType32> sequenceLengthsRange(*mCpuAlgo->mSequenceLengths);
|
||||
BufferLocation<SizeType32> pathsOffsetLocation(*mCpuAlgo->mPathsOffsets);
|
||||
@ -334,6 +326,7 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
mBufferManager->setZero(*mCpuAlgo->mPathsOffsets);
|
||||
mBufferManager->setZero(*mCpuAlgo->mNumNewTokens);
|
||||
mBufferManager->setZero(*mCpuAlgo->mNumNewTokensCumSum);
|
||||
mBufferManager->setZero(*mCpuAlgo->mPackedMask);
|
||||
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
@ -342,7 +335,6 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
|
||||
SizeType32 const tokensPerStep = generationLengthsRange[gbi];
|
||||
TensorPtr sampledTokens = ITensor::slice(mCpuAlgo->mTargetTokens, {gbi, 0}, tokensPerStep);
|
||||
PRINT_VALUES(sampledTokens);
|
||||
|
||||
if (tokensPerStep == 1)
|
||||
{
|
||||
@ -369,14 +361,18 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
|
||||
sequenceLengthsRange[gbi] += numNewTokensRange[gbi];
|
||||
|
||||
initAttentionMask(mCpuAlgo->mAttentionMask, mBufferManager);
|
||||
|
||||
theAlgo.prepare( //
|
||||
ITensor::at(mCpuAlgo->mNextDraftTokens, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mNextDraftPosIds, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mSamplingMask, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mNextDraftLengths, {gbi}), //
|
||||
mCpuAlgo->mAttentionMask, 1, //
|
||||
ITensor::at(mCpuAlgo->mSequenceLengths, {gbi}), //
|
||||
ITensor::at(mCpuAlgo->mOutputIds, {gbi, numNewTokensRange[gbi] - 1}));
|
||||
|
||||
convertBoolToInt32(ITensor::at(mCpuAlgo->mPackedMask, {gbi}), mCpuAlgo->mAttentionMask);
|
||||
|
||||
BufferLocation<SizeType32> posIdsLocation(*ITensor::at(mCpuAlgo->mPositionIds, {gbi}));
|
||||
for (auto& posid : posIdsLocation)
|
||||
{
|
||||
@ -385,20 +381,14 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
mBufferManager->copy(*ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]),
|
||||
*ITensor::slice(mCpuAlgo->mPositionIds, {gbi, 1}, nextDraftLengthsRange[gbi]));
|
||||
|
||||
posIdsToMask( //
|
||||
ITensor::at(mCpuAlgo->mPackedMask, {gbi}), //
|
||||
ITensor::slice(mCpuAlgo->mNextDraftPosIds, {gbi, 0}, nextDraftLengthsRange[gbi]));
|
||||
|
||||
BufferRange<SizeType32> offsetRange(*ITensor::at(mCpuAlgo->mPositionOffsets, {gbi}));
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
posIdsLocation.size() == offsetRange.size(), "%ld, %ld", posIdsLocation.size(), offsetRange.size());
|
||||
for (auto i = 0; i < posIdsLocation.size(); i++)
|
||||
{
|
||||
offsetRange[i] = posIdsLocation[i] - posIdsLocation[0];
|
||||
}
|
||||
|
||||
TensorPtr accepted = ITensor::slice(mCpuAlgo->mOutputIds, {gbi, 0}, numNewTokensRange[gbi]);
|
||||
TensorPtr draft = ITensor::slice(mCpuAlgo->mNextDraftTokens, {gbi, 0}, nextDraftLengthsRange[gbi]);
|
||||
|
||||
TLLM_LOG_DEBUG("CPU ALGO [ %d ] forward, %s", gbi, D(sampledTokens).values().c_str());
|
||||
TLLM_LOG_DEBUG("[%d][%d] CPU ALGO [ %d ] forward, %s, %s", mGlobalSteps, batchSize, gbi,
|
||||
D(accepted).values().c_str(), D(draft).values().c_str());
|
||||
@ -430,29 +420,23 @@ void LookaheadDecodingLayer<T>::forwardSyncCPU(
|
||||
mBufferManager->copy(*mCpuAlgo->mNumNewTokensCumSum, *outputs->numNewTokensCumSum); //
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftTokens, *outputs->nextDraftTokens);
|
||||
|
||||
mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks);
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi];
|
||||
// nextDraftLengthsRange[gbi] = mDecoderDomain.getMaxDecodingTokens() - 1;
|
||||
generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1;
|
||||
}
|
||||
|
||||
if (outputs->nextDraftLengths)
|
||||
{
|
||||
mBufferManager->copy(*mCpuAlgo->mNextDraftLengths, *outputs->nextDraftLengths);
|
||||
}
|
||||
|
||||
for (SizeType32 bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi];
|
||||
generationLengthsRange[gbi] = nextDraftLengthsRange[gbi] + 1;
|
||||
generationLengthsMaxRange[gbi] = mDecoderDomain.getMaxDecodingTokens();
|
||||
}
|
||||
mBufferManager->copy(*mCpuAlgo->mPackedMask, *outputs->packedMasks);
|
||||
mBufferManager->copy(*mCpuAlgo->mGenerationLengthsMax, *outputs->generationLengths);
|
||||
mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->generationLengths);
|
||||
mBufferManager->copy(*mCpuAlgo->mPositionOffsets, *outputs->positionOffsets);
|
||||
mBufferManager->copy(*mCpuAlgo->mPositionIds, *outputs->positionIds);
|
||||
|
||||
if (outputs->actualGenerationLengths)
|
||||
{
|
||||
mBufferManager->copy(*mCpuAlgo->mGenerationLengths, *outputs->actualGenerationLengths);
|
||||
}
|
||||
|
||||
mBufferManager->getStream().synchronize();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -48,7 +48,6 @@ public:
|
||||
private:
|
||||
void forwardSyncCPU(std::shared_ptr<LookaheadDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<LookaheadDecodingInputs> const& inputs);
|
||||
void posIdsToMask(TensorPtr mask, TensorConstPtr posIds);
|
||||
|
||||
private:
|
||||
using Base::mDecoderDomain;
|
||||
@ -57,7 +56,6 @@ private:
|
||||
size_t mSetupWorkspaceSize{};
|
||||
TensorPtr mCurandStatesDevice;
|
||||
TensorPtr mTargetTokensDevice;
|
||||
TensorPtr mSamplingMaskDevice;
|
||||
|
||||
struct CpuAlgorithmResources
|
||||
{
|
||||
@ -78,11 +76,10 @@ private:
|
||||
|
||||
TensorPtr mNextDraftTokens;
|
||||
TensorPtr mNextDraftPosIds;
|
||||
TensorPtr mSamplingMask;
|
||||
TensorPtr mNextDraftLengths;
|
||||
TensorPtr mSequenceLengths;
|
||||
TensorPtr mGenerationLengths;
|
||||
TensorPtr mGenerationLengthsMax;
|
||||
TensorPtr mAttentionMask;
|
||||
TensorPtr mPackedMask;
|
||||
TensorPtr mPositionOffsets;
|
||||
TensorPtr mPositionIds;
|
||||
|
||||
@ -423,6 +423,7 @@ int BertAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
request_seq_len, mNumHeads, mHeadSize, padding_offset, (float*) nullptr, 0, stream);
|
||||
}
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -143,6 +143,7 @@ int CumsumLastDimPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
invokeCumsumLastDim<T>(
|
||||
batchSize, inputLength, inputs[getInputTensorIdx()], outputs[0], wp, mTempStorageBytes, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -164,7 +164,7 @@ bool GPTAttentionPluginCommon::convertMMHAParamsToXQAParams(tensorrt_llm::kernel
|
||||
memset(&xqaParams, 0, sizeof(XQAParams));
|
||||
xqaParams.data_type = ConvertMMHAToXQAParamsHelper<T, KVCacheBuffer>::data_type;
|
||||
|
||||
xqaParams.layer_idx = mLayerIdx;
|
||||
xqaParams.layer_idx = mLayerIdxInCachePool;
|
||||
xqaParams.num_q_heads = mNumHeads;
|
||||
xqaParams.num_kv_heads = mNumKVHeads;
|
||||
xqaParams.head_size = mHeadSize;
|
||||
@ -376,13 +376,13 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
|
||||
|
||||
#define INSTANTIATE_MMHA_DISPATCH(T_MMHA, T) \
|
||||
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, false>&, \
|
||||
const FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer>&, cudaStream_t stream); \
|
||||
FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer> const&, cudaStream_t stream); \
|
||||
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \
|
||||
const FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer>&, cudaStream_t stream); \
|
||||
FusedQKVMaskedAttentionDispatchParams<T, KVLinearBuffer> const&, cudaStream_t stream); \
|
||||
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, false>&, \
|
||||
const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream); \
|
||||
FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray> const&, cudaStream_t stream); \
|
||||
template void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, true>&, \
|
||||
const FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray>&, cudaStream_t stream);
|
||||
FusedQKVMaskedAttentionDispatchParams<T, KVBlockArray> const&, cudaStream_t stream);
|
||||
INSTANTIATE_MMHA_DISPATCH(float, float)
|
||||
INSTANTIATE_MMHA_DISPATCH(uint16_t, half)
|
||||
#ifdef ENABLE_BF16
|
||||
@ -391,8 +391,8 @@ INSTANTIATE_MMHA_DISPATCH(__nv_bfloat16, __nv_bfloat16)
|
||||
#undef INSTANTIATE_MMHA_DISPATCH
|
||||
|
||||
GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length,
|
||||
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling,
|
||||
float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, float rotary_embedding_short_m_scale, float rotary_embedding_long_m_scale,
|
||||
@ -411,6 +411,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
|
||||
, mVisionStart(vision_start)
|
||||
, mVisionLength(vision_length)
|
||||
, mNumKVHeads(num_kv_heads)
|
||||
, mLayerIdxInCachePool(layer_idx_in_cache_pool)
|
||||
, mHeadSize(head_size)
|
||||
, mUnidirectional(unidirectional)
|
||||
, mQScaling(q_scaling)
|
||||
@ -525,6 +526,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
|
||||
read(d, mVisionStart);
|
||||
read(d, mVisionLength);
|
||||
read(d, mNumKVHeads);
|
||||
read(d, mLayerIdxInCachePool);
|
||||
read(d, mHeadSize);
|
||||
read(d, mUnidirectional);
|
||||
read(d, mQScaling);
|
||||
@ -721,7 +723,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
|
||||
|
||||
KVCacheBuffer kv_cache_buffer;
|
||||
auto const elemSize = mKVCacheQuantMode.hasKvCacheQuant() ? sizeof(int8_t) : sizeof(T);
|
||||
auto const sizePerToken = num_kv_heads * head_size * elemSize;
|
||||
auto sizePerToken = num_kv_heads * head_size * elemSize;
|
||||
KVBlockArray::DataType* hostKvCacheBlockOffsets = nullptr;
|
||||
if (useKVCache())
|
||||
{
|
||||
@ -1751,13 +1753,13 @@ void GPTAttentionPluginCommon::destroy() noexcept
|
||||
size_t GPTAttentionPluginCommon::getCommonSerializationSize() const noexcept
|
||||
{
|
||||
return sizeof(mLayerIdx) + sizeof(mNumHeads) + +sizeof(mVisionStart) + sizeof(mVisionLength) + sizeof(mNumKVHeads)
|
||||
+ sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling) + sizeof(mQKTanhScale)
|
||||
+ sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim) + sizeof(mRotaryEmbeddingBase)
|
||||
+ sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale) + sizeof(mRotaryEmbeddingShortMscale)
|
||||
+ sizeof(mRotaryEmbeddingLongMscale) + sizeof(mRotaryEmbeddingMaxPositions)
|
||||
+ sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize) + sizeof(mTpRank) + sizeof(mEnableContextFMHA)
|
||||
+ sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode) + sizeof(mEnableXQA)
|
||||
+ sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mLayerIdxInCachePool) + sizeof(mHeadSize) + sizeof(mUnidirectional) + sizeof(mQScaling)
|
||||
+ sizeof(mQKTanhScale) + sizeof(mPositionEmbeddingType) + sizeof(mRotaryEmbeddingDim)
|
||||
+ sizeof(mRotaryEmbeddingBase) + sizeof(mRotaryEmbeddingScaleType) + sizeof(mRotaryEmbeddingScale)
|
||||
+ sizeof(mRotaryEmbeddingShortMscale) + sizeof(mRotaryEmbeddingLongMscale)
|
||||
+ sizeof(mRotaryEmbeddingMaxPositions) + sizeof(mRotaryEmbeddingOriginalMaxPositions) + sizeof(mTpSize)
|
||||
+ sizeof(mTpRank) + sizeof(mEnableContextFMHA) + sizeof(mFMHAForceFP32Acc) + sizeof(mMultiBlockMode)
|
||||
+ sizeof(mEnableXQA) + sizeof(unsigned int) // mKVCacheQuantMode
|
||||
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mBlockSparseParams) + sizeof(mPagedKVCache)
|
||||
+ sizeof(mTokensPerBlock) + sizeof(mType) + sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled)
|
||||
+ sizeof(mCrossAttention) + sizeof(mMaxDistance) + sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA)
|
||||
@ -1776,6 +1778,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
|
||||
write(d, mVisionStart);
|
||||
write(d, mVisionLength);
|
||||
write(d, mNumKVHeads);
|
||||
write(d, mLayerIdxInCachePool);
|
||||
write(d, mHeadSize);
|
||||
write(d, mUnidirectional);
|
||||
write(d, mQScaling);
|
||||
|
||||
@ -38,7 +38,7 @@ public:
|
||||
GPTAttentionPluginCommon() = delete;
|
||||
|
||||
GPTAttentionPluginCommon(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
|
||||
int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. Use 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
@ -307,6 +307,7 @@ protected:
|
||||
int mVisionStart;
|
||||
int mVisionLength;
|
||||
int mNumKVHeads;
|
||||
int mLayerIdxInCachePool;
|
||||
int mHeadSize;
|
||||
int mUnidirectional;
|
||||
float mQScaling;
|
||||
@ -389,6 +390,7 @@ protected:
|
||||
ss << "gptAttentionCommon members ====================" << std::endl;
|
||||
ss << "mNumHeads: " << mNumHeads << std::endl;
|
||||
ss << "mNumKVHeads: " << mNumKVHeads << std::endl;
|
||||
ss << "mLayerIdxInCachePool " << mLayerIdxInCachePool << std::endl;
|
||||
ss << "mHeadSize: " << mHeadSize << std::endl;
|
||||
ss << "mUnidirectional: " << mUnidirectional << std::endl;
|
||||
ss << "mQScaling: " << mQScaling << std::endl;
|
||||
|
||||
@ -15,6 +15,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "gptAttentionPlugin.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/decoderMaskedMultiheadAttention.h"
|
||||
#include "tensorrt_llm/kernels/gptKernels.h"
|
||||
#include "tensorrt_llm/kernels/unfusedAttentionKernels.h"
|
||||
@ -26,6 +27,7 @@
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/utils/debugUtils.h"
|
||||
#include <NvInferRuntimeBase.h>
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
@ -41,8 +43,8 @@ static char const* GPT_ATTENTION_PLUGIN_VERSION{"1"};
|
||||
static char const* GPT_ATTENTION_PLUGIN_NAME{"GPTAttention"};
|
||||
|
||||
GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length,
|
||||
int num_kv_heads, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int num_kv_heads, int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling,
|
||||
float qk_tanh_scale, tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
float rotary_embedding_scale, float rotary_embedding_short_m_scale,
|
||||
@ -57,9 +59,9 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int vision_
|
||||
bool pos_shift_enabled, bool dense_context_fmha, bool use_paged_context_fmha, bool use_fp8_context_fmha,
|
||||
bool use_cache, bool is_spec_decoding_enabled, bool spec_decoding_is_generation_length_variable,
|
||||
int spec_decoding_max_generation_length)
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, head_size,
|
||||
unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
|
||||
: GPTAttentionPluginCommon(layer_idx, num_heads, vision_start, vision_length, num_kv_heads, layer_idx_in_cache_pool,
|
||||
head_size, unidirectional, q_scaling, qk_tanh_scale, position_embedding_type, rotary_embedding_dim,
|
||||
rotary_embedding_base, rotary_embedding_scale_type, rotary_embedding_scale, rotary_embedding_short_m_scale,
|
||||
rotary_embedding_long_m_scale, rotary_embedding_max_positions, rotary_embedding_original_max_positions, tp_size,
|
||||
tp_rank, unfuse_qkv_gemm, context_fmha_type, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type,
|
||||
block_sparse_params, paged_kv_cache, tokens_per_block, type, max_context_length, qkv_bias_enabled,
|
||||
@ -94,6 +96,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
|
||||
case IdxEntry::KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache;
|
||||
case IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS: return useKVCache() && mPagedKVCache;
|
||||
case IdxEntry::HOST_KV_CACHE_POOL_POINTERS: return useKVCache() && mPagedKVCache;
|
||||
case IdxEntry::HOST_KV_CACHE_POOL_MAPPING: return useKVCache() && mPagedKVCache;
|
||||
case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache;
|
||||
case IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
|
||||
case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
|
||||
@ -244,6 +247,11 @@ bool GPTAttentionPlugin::supportsFormatCombination(
|
||||
// kv cache pool pointers
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT64 && inOut[pos].format == TensorFormat::kLINEAR;
|
||||
}
|
||||
else if (useKVCache() && mPagedKVCache && (pos == getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)))
|
||||
{
|
||||
// kv cache pool mapping
|
||||
return inOut[pos].type == nvinfer1::DataType::kINT32 && inOut[pos].format == TensorFormat::kLINEAR;
|
||||
}
|
||||
else if (useKVCache() && mKVCacheQuantMode.hasInt8KvCache()
|
||||
&& (!mPagedKVCache && (pos == getIdx(IdxEntry::PAST_KEY_VALUE) || pos == nbInputs + 1)))
|
||||
{
|
||||
@ -478,6 +486,7 @@ int GPTAttentionPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
outputDesc, inputs, outputs, workspace, stream);
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("Attention plugin stop at layer %d", mLayerIdx);
|
||||
|
||||
return 0;
|
||||
@ -624,27 +633,36 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
|
||||
auto const& kvCacheBlockOffsets = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)];
|
||||
auto const& kvCacheBlockOffsetsShape = inputDesc[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)].dims;
|
||||
max_blocks_per_sequence = kvCacheBlockOffsetsShape.d[kvCacheBlockOffsetsShape.nbDims - 1];
|
||||
auto const seqStride = getStride(kvCacheBlockOffsetsShape, 0);
|
||||
|
||||
std::int32_t const* host_pool_mapping
|
||||
= static_cast<std::int32_t const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_MAPPING)]);
|
||||
|
||||
const int32_t layerToPool = host_pool_mapping[mLayerIdx];
|
||||
auto const seqStride = getStride(kvCacheBlockOffsetsShape, 1);
|
||||
auto const poolStride = getStride(kvCacheBlockOffsetsShape, 0);
|
||||
auto const seqOffset = seqIdxBeg * seqStride;
|
||||
auto const poolOffset = layerToPool * poolStride;
|
||||
|
||||
block_offsets
|
||||
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::KV_CACHE_BLOCK_OFFSETS)])
|
||||
+ seqOffset;
|
||||
+ poolOffset + seqOffset;
|
||||
|
||||
host_block_offsets
|
||||
= reinterpret_cast<kernels::KVBlockArray::DataType*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_OFFSETS)])
|
||||
+ seqOffset;
|
||||
+ poolOffset + seqOffset;
|
||||
|
||||
auto const* const typed_host_pool_pointers
|
||||
= static_cast<char* const*>(inputs[getIdx(IdxEntry::HOST_KV_CACHE_POOL_POINTERS)]);
|
||||
|
||||
auto const cacheElemSize = (mKVCacheQuantMode.hasKvCacheQuant() ? 1 : sizeof(T));
|
||||
|
||||
auto const blockSize = mTokensPerBlock * mNumKVHeads * mHeadSize;
|
||||
auto const bytesPerBlock = blockSize * cacheElemSize;
|
||||
auto const layerOffset = mLayerIdx * 2 * bytesPerBlock;
|
||||
auto const layerOffset = mLayerIdxInCachePool * 2 * bytesPerBlock;
|
||||
|
||||
host_primary_pool_pointer = reinterpret_cast<void*>(typed_host_pool_pointers[0] + layerOffset);
|
||||
host_secondary_pool_pointer = reinterpret_cast<void*>(typed_host_pool_pointers[1] + layerOffset);
|
||||
host_primary_pool_pointer = reinterpret_cast<void*>(typed_host_pool_pointers[layerToPool * 2] + layerOffset);
|
||||
host_secondary_pool_pointer
|
||||
= reinterpret_cast<void*>(typed_host_pool_pointers[layerToPool * 2 + 1] + layerOffset);
|
||||
}
|
||||
|
||||
AttentionOutT* context_buf_ = static_cast<AttentionOutT*>(outputs[0])
|
||||
@ -962,8 +980,9 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
|
||||
auto* obj = new GPTAttentionPlugin(p.getScalar<int32_t>("layer_idx").value(),
|
||||
p.getScalar<int32_t>("num_heads").value(), p.getScalar<int32_t>("vision_start").value(),
|
||||
p.getScalar<int32_t>("vision_length").value(), p.getScalar<int32_t>("num_kv_heads").value(),
|
||||
p.getScalar<int32_t>("head_size").value(), p.getScalar<int32_t>("unidirectional").value(),
|
||||
p.getScalar<float>("q_scaling").value(), p.getScalar<float>("qk_tanh_scale").value(),
|
||||
p.getScalar<int32_t>("layer_idx_in_cache_pool").value(), p.getScalar<int32_t>("head_size").value(),
|
||||
p.getScalar<int32_t>("unidirectional").value(), p.getScalar<float>("q_scaling").value(),
|
||||
p.getScalar<float>("qk_tanh_scale").value(),
|
||||
static_cast<PositionEmbeddingType>(p.getScalar<int8_t>("position_embedding_type").value()),
|
||||
p.getScalar<int32_t>("rotary_embedding_dim").value(), p.getScalar<float>("rotary_embedding_base").value(),
|
||||
static_cast<RotaryScalingType>(p.getScalar<int8_t>("rotary_embedding_scale_type").value()),
|
||||
|
||||
@ -85,7 +85,7 @@ class GPTAttentionPlugin : public GPTAttentionPluginCommon
|
||||
{
|
||||
public:
|
||||
GPTAttentionPlugin(int layer_idx, int num_heads, int vision_start, int vision_length, int num_kv_heads,
|
||||
int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
int layer_idx_in_cache_pool, int head_size, int unidirectional, float q_scaling, float qk_tanh_scale,
|
||||
tensorrt_llm::kernels::PositionEmbeddingType position_embedding_type,
|
||||
int rotary_embedding_dim, // for RoPE. 0 for non-RoPE
|
||||
float rotary_embedding_base, tensorrt_llm::kernels::RotaryScalingType rotary_embedding_scale_type,
|
||||
@ -182,6 +182,7 @@ private:
|
||||
KV_CACHE_BLOCK_OFFSETS,
|
||||
HOST_KV_CACHE_BLOCK_OFFSETS,
|
||||
HOST_KV_CACHE_POOL_POINTERS,
|
||||
HOST_KV_CACHE_POOL_MAPPING,
|
||||
PAST_KEY_VALUE,
|
||||
KV_CACHE_QUANTIZATION_SCALE,
|
||||
KV_CACHE_DEQUANTIZATION_SCALE,
|
||||
|
||||
@ -90,6 +90,7 @@ int IdentityPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer
|
||||
|
||||
cudaMemcpyAsync(outputs[0], inputs[0], count, cudaMemcpyDeviceToDevice, stream);
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -177,7 +177,7 @@ int LayernormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* input
|
||||
scale, dynamic_scale, output);
|
||||
}
|
||||
#endif
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -214,6 +214,7 @@ int lruPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1
|
||||
{
|
||||
invokeRGLRUUpdate<T>(lru_params, stream);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -200,6 +200,7 @@ int MambaConv1dPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
{
|
||||
invokeMambaConv1dGeneration<T>(mambaConv1dParams, stream);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -203,7 +203,7 @@ int QuantizePerTokenPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -134,7 +134,7 @@ int QuantizeTensorPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc,
|
||||
stream, mProp.maxGridSize[0]);
|
||||
}
|
||||
#endif
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -236,7 +236,7 @@ int RmsnormQuantizationPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDe
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
@ -347,6 +347,7 @@ int SelectiveScanPlugin::enqueueImpl(nvinfer1::PluginTensorDesc const* inputDesc
|
||||
{
|
||||
invokeSelectiveScanUpdate<T, float>(ssm_params, stream);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
return 0;
|
||||
}
|
||||
|
||||
|
||||
8
cpp/tensorrt_llm/pybind/CMakeLists.txt
Normal file → Executable file
8
cpp/tensorrt_llm/pybind/CMakeLists.txt
Normal file → Executable file
@ -41,9 +41,11 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
|
||||
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
|
||||
"${TORCH_INSTALL_PREFIX}/lib")
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC ${SHARED_TARGET} ${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python
|
||||
${UNDEFINED_FLAG})
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG}
|
||||
${NO_AS_NEEDED_FLAG})
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC ${Python3_LIBRARIES} ${TORCH_LIBRARIES}
|
||||
torch_python ${UNDEFINED_FLAG})
|
||||
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE})
|
||||
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
|
||||
|
||||
@ -178,19 +178,25 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def(py::self != py::self);
|
||||
|
||||
py::class_<tr::ModelConfig>(m, "ModelConfig")
|
||||
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, nvinfer1::DataType>(),
|
||||
py::arg("vocab_size"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"), py::arg("num_heads"),
|
||||
py::arg("hidden_size"), py::arg("data_type"))
|
||||
.def(py::init<SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, SizeType32, nvinfer1::DataType>(),
|
||||
py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_attention_layers"), py::arg("num_rnn_layers"),
|
||||
py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
|
||||
.def_property_readonly("vocab_size", &tr::ModelConfig::getVocabSize)
|
||||
.def("vocab_size_padded", &tr::ModelConfig::getVocabSizePadded, py::arg("world_size"))
|
||||
.def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1)
|
||||
.def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1)
|
||||
.def("num_layers", &tr::ModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
|
||||
.def("num_attention_layers", &tr::ModelConfig::getNbAttentionLayers, py::arg("pipeline_parallelism") = 1,
|
||||
py::arg("pipeline_parallelism_rank") = 0)
|
||||
.def("num_rnn_layers", &tr::ModelConfig::getNbRnnLayers, py::arg("pipeline_parallelism") = 1,
|
||||
py::arg("pipeline_parallelism_rank") = 0)
|
||||
.def("num_kv_heads", &tr::ModelConfig::getNbKvHeads, py::arg("layer_idx"))
|
||||
.def("set_num_kv_heads", &tr::ModelConfig::setNbKvHeads, py::arg("num_kv_heads"))
|
||||
.def_property_readonly("num_heads", &tr::ModelConfig::getNbHeads)
|
||||
.def_property_readonly("hidden_size", &tr::ModelConfig::getHiddenSize)
|
||||
.def_property_readonly("size_per_head", &tr::ModelConfig::getSizePerHead)
|
||||
.def_property_readonly("data_type", &tr::ModelConfig::getDataType)
|
||||
.def_property("num_kv_heads", &tr::ModelConfig::getNbKvHeads, &tr::ModelConfig::setNbKvHeads)
|
||||
.def_property("head_size", &tr::ModelConfig::getSizePerHead, &tr::ModelConfig::setSizePerHead)
|
||||
.def_property(
|
||||
"num_kv_heads_per_layer", &tr::ModelConfig::getNumKvHeadsPerLayer, &tr::ModelConfig::setNumKvHeadsPerLayer)
|
||||
.def_property("use_gpt_attention_plugin",
|
||||
py::overload_cast<>(&tr::ModelConfig::useGptAttentionPlugin, py::const_),
|
||||
py::overload_cast<bool>(&tr::ModelConfig::useGptAttentionPlugin))
|
||||
|
||||
@ -93,7 +93,8 @@ void InitBindings(pybind11::module_& m)
|
||||
|
||||
py::enum_<tle::CapacitySchedulerPolicy>(m, "CapacitySchedulerPolicy")
|
||||
.value("MAX_UTILIZATION", tle::CapacitySchedulerPolicy::kMAX_UTILIZATION)
|
||||
.value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT);
|
||||
.value("GUARANTEED_NO_EVICT", tle::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT)
|
||||
.value("STATIC_BATCH", tle::CapacitySchedulerPolicy::kSTATIC_BATCH);
|
||||
|
||||
py::enum_<tle::ContextChunkingPolicy>(m, "ContextChunkingPolicy")
|
||||
.value("EQUAL_PROGRESS", tle::ContextChunkingPolicy::kEQUAL_PROGRESS)
|
||||
@ -299,7 +300,8 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_property_readonly("max_verification_set_size", &tle::LookaheadDecodingConfig::getVerificationSetSize);
|
||||
|
||||
py::class_<tle::ContextPhaseParams>(m, "ContextPhaseParams")
|
||||
.def(py::init<VecTokens>(), py::arg("first_gen_tokens"));
|
||||
.def(py::init<VecTokens, tle::ContextPhaseParams::RequestIdType>(), py::arg("first_gen_tokens"),
|
||||
py::arg("req_id"));
|
||||
|
||||
py::class_<tle::Request> request(m, "Request");
|
||||
request
|
||||
@ -631,14 +633,18 @@ void InitBindings(pybind11::module_& m)
|
||||
|
||||
auto extendedRuntimePerfKnobConfigSetstate = [](py::tuple state)
|
||||
{
|
||||
if (state.size() != 2)
|
||||
if (state.size() != 4)
|
||||
{
|
||||
throw std::runtime_error("Invalid extendedRuntimePerfKnobConfig state!");
|
||||
}
|
||||
return tle::ExtendedRuntimePerfKnobConfig(state[0].cast<bool>(), state[1].cast<bool>());
|
||||
return tle::ExtendedRuntimePerfKnobConfig(
|
||||
state[0].cast<bool>(), state[1].cast<bool>(), state[2].cast<bool>(), state[2].cast<SizeType32>());
|
||||
};
|
||||
auto extendedRuntimePerfKnobConfigGetstate = [](tle::ExtendedRuntimePerfKnobConfig const& self)
|
||||
{ return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc()); };
|
||||
{
|
||||
return py::make_tuple(self.getMultiBlockMode(), self.getEnableContextFMHAFP32Acc(), self.getCudaGraphMode(),
|
||||
self.getCudaGraphCacheSize());
|
||||
};
|
||||
py::class_<tle::ExtendedRuntimePerfKnobConfig>(m, "ExtendedRuntimePerfKnobConfig")
|
||||
.def(
|
||||
py::init<bool, bool>(), py::arg("multi_block_mode") = true, py::arg("enable_context_fmha_fp32_acc") = false)
|
||||
@ -646,6 +652,10 @@ void InitBindings(pybind11::module_& m)
|
||||
&tle::ExtendedRuntimePerfKnobConfig::setMultiBlockMode)
|
||||
.def_property("enable_context_fmha_fp32_acc", &tle::ExtendedRuntimePerfKnobConfig::getEnableContextFMHAFP32Acc,
|
||||
&tle::ExtendedRuntimePerfKnobConfig::setEnableContextFMHAFP32Acc)
|
||||
.def_property("cuda_graph_mode", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphMode,
|
||||
&tle::ExtendedRuntimePerfKnobConfig::setCudaGraphMode)
|
||||
.def_property("cuda_graph_cache_size", &tle::ExtendedRuntimePerfKnobConfig::getCudaGraphCacheSize,
|
||||
&tle::ExtendedRuntimePerfKnobConfig::setCudaGraphCacheSize)
|
||||
.def(py::pickle(extendedRuntimePerfKnobConfigGetstate, extendedRuntimePerfKnobConfigSetstate));
|
||||
|
||||
auto executorConfigGetState = [](tle::ExecutorConfig const& self)
|
||||
|
||||
@ -803,7 +803,7 @@ void GptDecoderBatched::forwardDispatch(
|
||||
}
|
||||
}
|
||||
|
||||
GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync(
|
||||
GptDecoderBatched::DecoderFinishedEventPtr GptDecoderBatched::forwardAsync(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
@ -813,7 +813,7 @@ GptDecoderBatched::TokenPtr GptDecoderBatched::forwardAsync(
|
||||
CudaEvent eventStop{};
|
||||
mRuntimeStream->record(eventStop);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
return std::make_unique<decoder_batch::Token>(std::move(eventStop), input.active);
|
||||
return std::make_unique<decoder_batch::DecoderFinishedEvent>(std::move(eventStop), input.active);
|
||||
}
|
||||
|
||||
void GptDecoderBatched::forwardDecoder(
|
||||
@ -1019,12 +1019,12 @@ void GptDecoderBatched::forwardDecoder(
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatched::updateFinished(decoder_batch::Token const& token)
|
||||
void GptDecoderBatched::updateFinished(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
|
||||
{
|
||||
if (token.active[i] && !mFinished[i])
|
||||
if (decoderFinishEvent.active[i] && !mFinished[i])
|
||||
{
|
||||
auto finishedSum = ITensor::slice(mJointDecodingOutput->finishedSum, i, 1);
|
||||
mFinished[i] = mFinished[i]
|
||||
@ -1035,25 +1035,25 @@ void GptDecoderBatched::updateFinished(decoder_batch::Token const& token)
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatched::forwardSync(decoder_batch::Token const& token)
|
||||
void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
token.event.synchronize();
|
||||
decoderFinishEvent.event.synchronize();
|
||||
|
||||
updateFinished(token);
|
||||
updateFinished(decoderFinishEvent);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatched::forwardSync(
|
||||
decoder_batch::Token const& token, decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
void GptDecoderBatched::forwardSync(decoder_batch::DecoderFinishedEvent const& decoderFinishEvent,
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
token.event.synchronize();
|
||||
decoderFinishEvent.event.synchronize();
|
||||
|
||||
forwardDispatch(output, input, ForwardType::kSYNC);
|
||||
|
||||
updateFinished(token);
|
||||
updateFinished(decoderFinishEvent);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -1231,7 +1231,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con
|
||||
batchOutput.cacheIndirection = output.cacheIndirection;
|
||||
batchOutput.sequenceLengths = output.sequenceLengths;
|
||||
|
||||
mForwardToken = forwardAsync(batchOutput, batchInput);
|
||||
mDecoderFinishEvent = forwardAsync(batchOutput, batchInput);
|
||||
mBufferManager.setZero(*mFinishedSum);
|
||||
kernels::reduce(
|
||||
*mFinishedSum, *ITensor::slice(mJointDecodingOutput->finishedSum, 0, mActualBatchSize), *mRuntimeStream);
|
||||
@ -1243,7 +1243,7 @@ void GptDecoderBatched::forwardAsync(decoder::Output& output, decoder::Input con
|
||||
void GptDecoderBatched::forwardSync()
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
forwardSync(*mForwardToken);
|
||||
forwardSync(*mDecoderFinishEvent);
|
||||
// wait for mFinishedSum to be updated
|
||||
mForwardEvent.synchronize();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -85,6 +85,8 @@ std::vector<ModelConfig::LayerType> buildLayerTypes(
|
||||
|
||||
auto constexpr layerNameAttention = "attention";
|
||||
auto constexpr layerNameRecurrent = "recurrent";
|
||||
auto constexpr layerNameLinear = "linear";
|
||||
auto constexpr layerNameNoop = "no_op";
|
||||
|
||||
// The json field specifies a "group" of layers, which gets repeated multiple times
|
||||
// Note that the total number of layers does not need to be a multiple of a layer
|
||||
@ -102,9 +104,17 @@ std::vector<ModelConfig::LayerType> buildLayerTypes(
|
||||
{
|
||||
result[i] = ModelConfig::LayerType::kRECURRENT;
|
||||
}
|
||||
else if (layerStringTypes[i % groupSize] == layerNameLinear)
|
||||
{
|
||||
result[i] = ModelConfig::LayerType::kLINEAR;
|
||||
}
|
||||
else if (layerStringTypes[i % groupSize] == layerNameNoop)
|
||||
{
|
||||
result[i] = ModelConfig::LayerType::kNOOP;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unknown layer type: %s", layerStringTypes[i % groupSize].c_str());
|
||||
TLLM_LOG_WARNING("Unknown layer type: %s, assuming attention", layerStringTypes[i % groupSize].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
@ -147,9 +157,25 @@ ModelConfig createModelConfig(
|
||||
|
||||
auto const mlpHiddenSize = parseJsonFieldOptional<SizeType32>(config, mlpHiddenSizeField);
|
||||
|
||||
auto modelConfig = ModelConfig{vocabSize, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType};
|
||||
auto numKvHeadsPerAttentionLayer
|
||||
= parseJsonFieldOr<std::vector<SizeType32>>(config, "num_kv_heads_per_layer", std::vector<SizeType32>());
|
||||
|
||||
auto modelConfig
|
||||
= ModelConfig{vocabSize, numLayers, numAttentionLayers, numRnnLayers, numHeads, hiddenSize, dataType};
|
||||
|
||||
if (!numKvHeadsPerAttentionLayer.empty())
|
||||
{
|
||||
std::transform(numKvHeadsPerAttentionLayer.cbegin(), numKvHeadsPerAttentionLayer.cend(),
|
||||
numKvHeadsPerAttentionLayer.begin(),
|
||||
[tensorParallelism](SizeType32 const numKvHeads) { return std::max(numKvHeads / tensorParallelism, 1); });
|
||||
modelConfig.setNumKvHeadsPerLayer(numKvHeadsPerAttentionLayer);
|
||||
}
|
||||
else
|
||||
{
|
||||
modelConfig.setNbKvHeads(numKvHeads);
|
||||
}
|
||||
|
||||
modelConfig.setSizePerHead(sizePerHead);
|
||||
modelConfig.setNbKvHeads(numKvHeads);
|
||||
modelConfig.setLayerTypes(layerTypes);
|
||||
|
||||
// Set logits datatype
|
||||
@ -269,9 +295,24 @@ void parseLora(ModelConfig& modelConfig, Json const& json, Json const& pluginCon
|
||||
|
||||
if (loraTargetModules.has_value())
|
||||
{
|
||||
auto const& loraModuleNames = loraTargetModules.value();
|
||||
auto const& numKvHeadsPerLayer = modelConfig.getNumKvHeadsPerLayer();
|
||||
if (!loraModuleNames.empty())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayer.cbegin(), numKvHeadsPerLayer.cend(),
|
||||
[firstNumKvHeads = numKvHeadsPerLayer[0]](SizeType32 numKvHeads)
|
||||
{ return numKvHeads == firstNumKvHeads; }),
|
||||
"LORA with a VGQA model is not supported");
|
||||
}
|
||||
// TODO(oargov): don't assume all layers have the same num_kv_heads to support VGQA
|
||||
auto const numKvHeads = numKvHeadsPerLayer.empty() ? modelConfig.getNbHeads() : numKvHeadsPerLayer[0];
|
||||
bool hasMoE = !engineVersionNone && json.at("pretrained_config").contains("moe");
|
||||
auto const numExperts = hasMoE
|
||||
? json.at("pretrained_config").at("moe").at("num_experts").template get<SizeType32>()
|
||||
: SizeType32{0};
|
||||
modelConfig.setLoraModules(LoraModule::createLoraModules(loraTargetModules.value(), modelConfig.getHiddenSize(),
|
||||
modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), modelConfig.getNbKvHeads(),
|
||||
modelConfig.getSizePerHead(), tensorParallelism));
|
||||
modelConfig.getMlpHiddenSize(), modelConfig.getNbHeads(), numKvHeads, modelConfig.getSizePerHead(),
|
||||
tensorParallelism, numExperts));
|
||||
}
|
||||
|
||||
modelConfig.setMaxLoraRank(loraMaxRank);
|
||||
|
||||
@ -219,8 +219,13 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
|
||||
// tokens, when enabling cyclic kv cache.
|
||||
auto const useOneMoreBlock = maxBeamWidth > 1 && maxSequenceLength > maxAttentionWindow;
|
||||
|
||||
auto const localNbLayers = mModelConfig.getNbAttentionLayers(mWorldConfig.getPipelineParallelism());
|
||||
auto const nbKvHeads = mModelConfig.getNbKvHeads();
|
||||
auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = mModelConfig.getNumKvHeadsPerLayerLocalRange(
|
||||
mWorldConfig.getPipelineParallelism(), mWorldConfig.getPipelineParallelRank());
|
||||
TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd,
|
||||
[firstNumKvHeads = *numKvHeadsPerLayerBegin](SizeType32 numKvHeads)
|
||||
{ return numKvHeads == firstNumKvHeads; }),
|
||||
"Deprecated session API does not support multiple cache pools, use the newer executor API instead");
|
||||
|
||||
auto const sizePerHead = mModelConfig.getSizePerHead();
|
||||
bool constexpr enableBlockReuse{false};
|
||||
bool enableDiffMaxAttenWin = false;
|
||||
@ -235,7 +240,8 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
|
||||
TLLM_CHECK_WITH_INFO(maxBeamWidth == 1 || !enableDiffMaxAttenWin,
|
||||
"Can't support layer-wise max_attention_window with beam search. Please use a unified max_attention_window for "
|
||||
"all layers.");
|
||||
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(localNbLayers, nbKvHeads, sizePerHead, tokensPerBlock,
|
||||
mKvCacheManager = std::make_shared<bmkv::KVCacheManager>(
|
||||
std::vector<SizeType32>(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd), sizePerHead, tokensPerBlock,
|
||||
blocksInPrimaryPool, blocksInSecondaryPool, maxBatchSize, maxBeamWidth, maxAttentionWindow, sinkTokenLength,
|
||||
useOneMoreBlock, mRuntime->getStreamPtr(), enableBlockReuse, kvCacheConfig.onboardBlocks);
|
||||
|
||||
@ -253,6 +259,7 @@ void GptSession::createKvCacheManager(SizeType32 maxBatchSize, SizeType32 maxBea
|
||||
for (auto& buffers : mBuffers)
|
||||
{
|
||||
buffers->transformerBuffers->setKvPoolPointers(mKvCacheManager.get());
|
||||
buffers->transformerBuffers->setKvPoolMapping(mKvCacheManager.get());
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -11,7 +11,9 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/lookaheadBuffers.h"
|
||||
#include "iTensor.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -28,8 +30,6 @@ LookaheadDecodingBuffers::LookaheadDecodingBuffers(
|
||||
, positionIds(
|
||||
bufferManager.gpu(ITensor::makeShape({maxNumSequences, maxTokensPerStep}), nvinfer1::DataType::kINT32))
|
||||
{
|
||||
TLLM_LOG_DEBUG(
|
||||
"LookaheadDecodingBuffers, maxNumSequences = %d, maxTokensPerStep = %d", maxNumSequences, maxTokensPerStep);
|
||||
}
|
||||
|
||||
LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
|
||||
@ -40,11 +40,11 @@ LookaheadRuntimeBuffers::LookaheadRuntimeBuffers(SizeType32 maxBatchSize, SizeTy
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK_WITH_INFO(maxBeamWidth == 1, "Lookahead decoding does not support beam search");
|
||||
|
||||
// auto const tokensPerStep = modelConfig.getMaxTokensPerStep();
|
||||
auto const tokensPerStep = modelConfig.getMaxDecodingTokens();
|
||||
auto const numPackedMasks = static_cast<ITensor::DimType64>(tensorrt_llm::common::divUp(tokensPerStep, 32));
|
||||
|
||||
// Copy buffers to device
|
||||
cumSumLength = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
|
||||
|
||||
packedMasksDevice
|
||||
= manager.gpu(ITensor::makeShape({maxBatchSize * tokensPerStep, numPackedMasks}), nvinfer1::DataType::kINT32);
|
||||
positionOffsetsDevice = manager.gpu(ITensor::makeShape({maxBatchSize, tokensPerStep}), nvinfer1::DataType::kINT32);
|
||||
@ -76,24 +76,59 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType
|
||||
|
||||
auto const tokensPerStep = modelConfig.getMaxDecodingTokens();
|
||||
|
||||
manager.copy(seqSlots, *batchSlotsHostCopy);
|
||||
manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy);
|
||||
manager.copy(*decoderLookaheadBuffers.positionOffsets, *positionOffsetsHostCopy);
|
||||
manager.copy(*decoderLookaheadBuffers.packedMasks, *packedMaskHostCopy);
|
||||
manager.copy(*decoderLookaheadBuffers.positionIds, *positionIdsHostCopy);
|
||||
manager.copy(seqSlots, *batchSlotsHostCopy);
|
||||
manager.copy(*decoderLookaheadBuffers.generationLengths, *generationLengthsHostCopy);
|
||||
|
||||
manager.getStream().synchronize();
|
||||
|
||||
BufferRange<SizeType32 const> batchSlotsRange(*batchSlotsHostCopy);
|
||||
BufferRange<SizeType32> cumSumLengthRange(*cumSumLength);
|
||||
|
||||
SizeType32 maxGenerationLength = 0;
|
||||
for (SizeType32 bi = 0; bi < numGenSequences; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi + numCtxSequences];
|
||||
manager.copy(*ITensor::at(generationLengthsHostCopy, {gbi}), *ITensor::at(generationLengthsHost, {bi}));
|
||||
manager.copy(*ITensor::at(positionOffsetsHostCopy, {gbi}), *ITensor::at(positionOffsetsHost, {bi}));
|
||||
manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, tokensPerStep),
|
||||
*ITensor::slice(packedMaskHost, bi * tokensPerStep, tokensPerStep));
|
||||
manager.copy(*ITensor::at(positionIdsHostCopy, {gbi}), *ITensor::at(positionIdsHost, {bi}));
|
||||
SizeType32 theLength = BufferRange<SizeType32>(*generationLengthsHostCopy)[gbi];
|
||||
maxGenerationLength = std::max(maxGenerationLength, theLength);
|
||||
}
|
||||
|
||||
auto positionOffsetShape = positionOffsetsHost->getShape();
|
||||
positionOffsetShape.d[1] = maxGenerationLength;
|
||||
positionOffsetsHost->reshape(positionOffsetShape);
|
||||
positionOffsetsDevice->reshape(positionOffsetShape);
|
||||
|
||||
auto positionIdsShape = positionIdsHostCopy->getShape();
|
||||
auto positionIdsShape1D = ITensor::makeShape({ITensor::volume(positionIdsShape)});
|
||||
positionIdsHostCopy->reshape(positionIdsShape1D);
|
||||
positionIdsHost->reshape(positionIdsShape1D);
|
||||
|
||||
cumSumLengthRange[0] = 0;
|
||||
for (SizeType32 bi = 0; bi < numGenSequences; bi++)
|
||||
{
|
||||
SizeType32 gbi = batchSlotsRange[bi + numCtxSequences];
|
||||
SizeType32 theLength = BufferRange<SizeType32>(*generationLengthsHostCopy)[gbi];
|
||||
|
||||
manager.copy(*ITensor::at(generationLengthsHostCopy, {gbi}), *ITensor::at(generationLengthsHost, {bi}));
|
||||
|
||||
manager.copy(*ITensor::slice(positionOffsetsHostCopy, {gbi, 0}, theLength),
|
||||
*ITensor::slice(positionOffsetsHost, {bi, 0}, theLength));
|
||||
|
||||
manager.copy(*ITensor::slice(packedMaskHostCopy, gbi * tokensPerStep, theLength),
|
||||
*ITensor::slice(packedMaskHost, cumSumLengthRange[0], theLength));
|
||||
|
||||
manager.copy(*ITensor::slice(positionIdsHostCopy, gbi * tokensPerStep, theLength),
|
||||
*ITensor::slice(positionIdsHost, cumSumLengthRange[0], theLength));
|
||||
|
||||
cumSumLengthRange[0] += theLength;
|
||||
}
|
||||
|
||||
positionIdsHostCopy->reshape(positionIdsShape);
|
||||
positionIdsHost->reshape(positionIdsShape);
|
||||
positionIdsDevice->reshape(positionIdsShape);
|
||||
|
||||
manager.copy(*ITensor::slice(generationLengthsHost, 0, numGenSequences),
|
||||
*ITensor::slice(generationLengthsDevice, 0, numGenSequences));
|
||||
manager.copy(*ITensor::slice(positionOffsetsHost, 0, numGenSequences),
|
||||
@ -102,6 +137,7 @@ void LookaheadRuntimeBuffers::setFromInputs(SizeType32 numCtxSequences, SizeType
|
||||
*ITensor::slice(packedMasksDevice, 0, numGenSequences * tokensPerStep));
|
||||
manager.copy(
|
||||
*ITensor::slice(positionIdsHost, 0, numGenSequences), *ITensor::slice(positionIdsDevice, 0, numGenSequences));
|
||||
positionIdsDevice->reshape(ITensor::makeShape({cumSumLengthRange[0]}));
|
||||
|
||||
manager.getStream().synchronize();
|
||||
|
||||
|
||||
@ -21,7 +21,7 @@ namespace tensorrt_llm::runtime
|
||||
|
||||
std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> const& loraModuleNames,
|
||||
SizeType32 hiddenSize, SizeType32 mlpHiddenSize, SizeType32 numAttentionHeads, SizeType32 numKvAttentionHeads,
|
||||
SizeType32 attentionHeadSize, SizeType32 tpSize)
|
||||
SizeType32 attentionHeadSize, SizeType32 tpSize, SizeType32 numExperts)
|
||||
{
|
||||
auto const hidden = hiddenSize * tpSize;
|
||||
auto const mlpHidden = mlpHiddenSize * tpSize;
|
||||
@ -55,10 +55,10 @@ std::vector<LoraModule> LoraModule::createLoraModules(std::vector<std::string> c
|
||||
case ModuleType::kMLP_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
|
||||
// TODO(TRTLLM-379): Support MOE LoRA weights
|
||||
case ModuleType::kMOE_H_TO_4H:
|
||||
case ModuleType::kMOE_GATE:
|
||||
case ModuleType::kMOE_4H_TO_H:
|
||||
case ModuleType::kMOE_ROUTER:
|
||||
case ModuleType::kMLP_ROUTER:
|
||||
case ModuleType::kMOE_GATE: modules.emplace_back(t, hidden, mlpHidden, false, true, -1, 0); break;
|
||||
case ModuleType::kMOE_4H_TO_H: modules.emplace_back(t, mlpHiddenSize, hidden, false, true, 1, -1); break;
|
||||
case ModuleType::kMOE_ROUTER: modules.emplace_back(t, hidden, numExperts, false, true, -1, -1); break;
|
||||
case ModuleType::kMLP_ROUTER: modules.emplace_back(t, hidden, 1, false, true, -1, -1); break;
|
||||
case ModuleType::kINVALID: throw std::runtime_error("Invalid LoRA module " + moduleName);
|
||||
}
|
||||
}
|
||||
|
||||
@ -15,11 +15,11 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/rnnStateBuffers.h"
|
||||
#include "iBuffer.h"
|
||||
#include "tensorrt_llm/runtime/runtimeBuffers.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
|
||||
RnnStateBuffers::RnnStateBuffers()
|
||||
{
|
||||
@ -92,8 +92,8 @@ RnnStateBuffers::RnnStateBuffers(
|
||||
auto statePtrsShape = ITensor::makeShape({localNbLayers});
|
||||
slotMappingDevice = bufferManager.gpu(slotMappingShape, nvinfer1::DataType::kINT32);
|
||||
slotMappingHost = BufferManager::cpu(slotMappingShape, nvinfer1::DataType::kINT32);
|
||||
rnnStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64);
|
||||
convStatePtrs = BufferManager::cpu(statePtrsShape, nvinfer1::DataType::kINT64);
|
||||
rnnStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
|
||||
convStatePtrs = BufferManager::cpu(statePtrsShape, TRTDataType<void*>::value);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -179,8 +179,8 @@ void RnnStateBuffers::fillStatePtrs()
|
||||
rnnStatePtr.resize(mLocalNbLayers);
|
||||
convStatePtr.resize(mLocalNbLayers);
|
||||
|
||||
void** rnnStatePtrArray = static_cast<void**>(rnnStatePtrs->data());
|
||||
void** convStatePtrArray = static_cast<void**>(convStatePtrs->data());
|
||||
auto* rnnStatePtrArray = bufferCast<void*>(*rnnStatePtrs);
|
||||
auto* convStatePtrArray = bufferCast<void*>(*convStatePtrs);
|
||||
|
||||
for (int i = 0; i < mLocalNbLayers; i++)
|
||||
{
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "tensorrt_llm/common/nvtxUtils.h"
|
||||
#include "tensorrt_llm/common/safetensors.h"
|
||||
#include "tensorrt_llm/executor/tensor.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tllmLogger.h"
|
||||
|
||||
#include <limits>
|
||||
@ -182,7 +183,9 @@ bool TllmRuntime::executeContext(SizeType32 contextIndex) const
|
||||
{
|
||||
NVTX3_FUNC_RANGE();
|
||||
auto& context = getContext(contextIndex);
|
||||
return context.enqueueV3(mStream->get());
|
||||
auto res = context.enqueueV3(mStream->get());
|
||||
sync_check_cuda_error();
|
||||
return res;
|
||||
}
|
||||
|
||||
void TllmRuntime::setInputTensors(SizeType32 contextIndex, TensorMap const& tensorMap)
|
||||
|
||||
@ -15,12 +15,15 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/runtime/transformerBuffers.h"
|
||||
#include "iTensor.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/stlUtils.h"
|
||||
#include "tensorrt_llm/runtime/runtimeBuffers.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
|
||||
#include <cstdlib> // std::getenv
|
||||
#include <vector>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
@ -34,6 +37,7 @@ TransformerBuffers::TransformerBuffers()
|
||||
presentKeysVals.clear();
|
||||
presentKeysValsAlt.clear();
|
||||
kvCacheBlockPoolPointers = nullptr;
|
||||
kvCacheBlockPoolMapping = nullptr;
|
||||
kvCacheBlockOffsetsHost = nullptr;
|
||||
kvCacheBlockOffsetsDevice = nullptr;
|
||||
}
|
||||
@ -101,15 +105,16 @@ void TransformerBuffers::reshape(
|
||||
auto const maxAttentionWindow = generationConfig.maxAttentionWindow;
|
||||
|
||||
auto const kvCacheReserve = ITensor::makeShape(
|
||||
{batchSize, 2, modelConfig.getNbKvHeads(), maxAttentionWindow, modelConfig.getSizePerHead()});
|
||||
{batchSize, 2, modelConfig.getNbKvHeads(0), maxAttentionWindow, modelConfig.getSizePerHead()});
|
||||
auto const kvCacheShape
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(), maxInputLength, modelConfig.getSizePerHead()});
|
||||
= ITensor::makeShape({batchSize, 2, modelConfig.getNbKvHeads(0), maxInputLength, modelConfig.getSizePerHead()});
|
||||
|
||||
if (modelConfig.isPagedKVCache())
|
||||
{
|
||||
auto cacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape();
|
||||
if (cacheBlockOffsetsShape.nbDims > 0)
|
||||
{
|
||||
cacheBlockOffsetsShape.d[0] = batchSize;
|
||||
cacheBlockOffsetsShape.d[1] = batchSize;
|
||||
kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape);
|
||||
kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape);
|
||||
}
|
||||
@ -123,7 +128,8 @@ void TransformerBuffers::reshape(
|
||||
utils::reshapeBufferVector(presentKeysVals, kvCacheReserve);
|
||||
}
|
||||
|
||||
auto const localNbLayers = modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism());
|
||||
auto const localNbLayers
|
||||
= modelConfig.getNbAttentionLayers(worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
|
||||
|
||||
if (modelConfig.useGptAttentionPlugin())
|
||||
{
|
||||
@ -147,7 +153,7 @@ void TransformerBuffers::reshapeKvTensors(
|
||||
{
|
||||
auto const& manager = runtime.getBufferManager();
|
||||
|
||||
auto const cacheBlockOffsetsShape = ITensor::makeShape({maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq});
|
||||
auto const cacheBlockOffsetsShape = ITensor::makeShape({1, maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq});
|
||||
|
||||
kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape);
|
||||
manager.setZero(*kvCacheBlockOffsetsHost);
|
||||
@ -161,6 +167,11 @@ void TransformerBuffers::setKvPoolPointers(KvCacheManager const* kvCacheManager)
|
||||
kvCacheBlockPoolPointers = kvCacheManager->getBlockPoolPointers();
|
||||
}
|
||||
|
||||
void TransformerBuffers::setKvPoolMapping(KvCacheManager const* kvCacheManager)
|
||||
{
|
||||
kvCacheBlockPoolMapping = kvCacheManager->getLayerToPoolMapping();
|
||||
}
|
||||
|
||||
TransformerBuffers TransformerBuffers::sliceTo(
|
||||
GenerationConfig const& generationConfig, ModelConfig const& modelConfig, SizeType32 offset, SizeType32 batchSize)
|
||||
{
|
||||
@ -169,8 +180,15 @@ TransformerBuffers TransformerBuffers::sliceTo(
|
||||
auto const generationBatchSize = generationConfig.batchSize;
|
||||
if (modelConfig.isPagedKVCache())
|
||||
{
|
||||
|
||||
auto const& realCacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape();
|
||||
auto const maxBlocksPerSeq = realCacheBlockOffsetsShape.d[2];
|
||||
auto const numPools = realCacheBlockOffsetsShape.d[0];
|
||||
// (oargov) with multiple pools, slicing the tensor along the batch*beam dimension would require us to support
|
||||
// non-contiguous tensors. with a single pool, we can just ignore the pools dimension when slicing and restore
|
||||
// it later. this is part of the deprecated GPTSession API, so not supporting VGQA here should be ok.
|
||||
TLLM_CHECK_WITH_INFO(numPools == 1,
|
||||
"Deprecated transformerBuffers API does not support multiple cache pools, use the newer API instead");
|
||||
auto const maxBlocksPerSeq = realCacheBlockOffsetsShape.d[3];
|
||||
|
||||
// enable slicing by moving generationBatchSize to first dim
|
||||
auto const fakeCacheBlockOffsetsShape = ITensor::makeShape({generationBatchSize, 2, maxBlocksPerSeq});
|
||||
@ -178,13 +196,14 @@ TransformerBuffers TransformerBuffers::sliceTo(
|
||||
TensorPtr kvCacheBlockOffsetsDeviceView{ITensor::view(kvCacheBlockOffsetsDevice, fakeCacheBlockOffsetsShape)};
|
||||
|
||||
// slice and reshape to correct shape
|
||||
auto const cacheBlockOffsetsShape = ITensor::makeShape({batchSize, 2, maxBlocksPerSeq});
|
||||
auto const cacheBlockOffsetsShape = ITensor::makeShape({numPools, batchSize, 2, maxBlocksPerSeq});
|
||||
buffers.kvCacheBlockOffsetsHost = ITensor::slice(kvCacheBlockOffsetsHostView, offset, batchSize);
|
||||
buffers.kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape);
|
||||
buffers.kvCacheBlockOffsetsDevice = ITensor::slice(kvCacheBlockOffsetsDeviceView, offset, batchSize);
|
||||
buffers.kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape);
|
||||
|
||||
buffers.kvCacheBlockPoolPointers = kvCacheBlockPoolPointers;
|
||||
buffers.kvCacheBlockPoolMapping = kvCacheBlockPoolMapping;
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -529,7 +548,7 @@ void TransformerBuffers::postContextStep(RuntimeBuffers* runtimeBuffers,
|
||||
if (modelConfig.useGptAttentionPlugin() && modelConfig.isPagedKVCache())
|
||||
{
|
||||
auto cacheBlockOffsetsShape = kvCacheBlockOffsetsHost->getShape();
|
||||
cacheBlockOffsetsShape.d[0] = batchSize * beamWidth;
|
||||
cacheBlockOffsetsShape.d[1] = batchSize * beamWidth;
|
||||
kvCacheBlockOffsetsHost->reshape(cacheBlockOffsetsShape);
|
||||
kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape);
|
||||
}
|
||||
@ -720,6 +739,7 @@ void TransformerBuffers::getRuntimeBuffers(RuntimeBuffers const* runtimeBuffers,
|
||||
inputBuffers.insert_or_assign("kv_cache_block_offsets", kvCacheBlockOffsetsDevice);
|
||||
inputBuffers.insert_or_assign("host_kv_cache_block_offsets", kvCacheBlockOffsetsHost);
|
||||
inputBuffers.insert_or_assign("host_kv_cache_pool_pointers", kvCacheBlockPoolPointers);
|
||||
inputBuffers.insert_or_assign("host_kv_cache_pool_mapping", kvCacheBlockPoolMapping);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
@ -53,6 +53,7 @@ public:
|
||||
runtime::TllmRuntime const& runtime);
|
||||
|
||||
void setKvPoolPointers(KvCacheManager const* kvCacheManager);
|
||||
void setKvPoolMapping(KvCacheManager const* kvCacheManager);
|
||||
|
||||
void reset(BufferManager& manager){};
|
||||
|
||||
@ -92,9 +93,10 @@ public:
|
||||
TensorPtr maxAttentionWindows; // with attention plugin, host tensor
|
||||
TensorPtr sinkTokenLengths; // with attention plugin, host tensor
|
||||
TensorPtr kvCacheBlockPoolPointers;
|
||||
TensorPtr kvCacheBlockOffsetsHost; // [batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockOffsetsDevice; // [batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr runtimePerfKnobsHost; // can hold max 16 perf knobs
|
||||
TensorPtr kvCacheBlockPoolMapping;
|
||||
TensorPtr kvCacheBlockOffsetsHost; // [numPools, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr kvCacheBlockOffsetsDevice; // [numPools, batchSize * beamWidth, 2, maxBlocksPerSeq * 2]
|
||||
TensorPtr runtimePerfKnobsHost; // can hold max 16 perf knobs
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
#include <cstddef>
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
@ -89,6 +90,16 @@ void reshapeBufferVector(std::vector<ITensor::SharedPtr>& vector, nvinfer1::Dims
|
||||
}
|
||||
}
|
||||
|
||||
void assertNoVGQA(ModelConfig const& modelConfig, WorldConfig const& worldConfig)
|
||||
{
|
||||
auto [numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd] = modelConfig.getNumKvHeadsPerLayerLocalRange(
|
||||
worldConfig.getPipelineParallelism(), worldConfig.getPipelineParallelRank());
|
||||
TLLM_CHECK_WITH_INFO(std::all_of(numKvHeadsPerLayerBegin, numKvHeadsPerLayerEnd,
|
||||
[firstNumKvHeads = *numKvHeadsPerLayerBegin](SizeType32 numKvHeads)
|
||||
{ return numKvHeads == firstNumKvHeads; }),
|
||||
"Deprecated session API does not support multiple cache pools, use the newer executor API instead");
|
||||
}
|
||||
|
||||
std::vector<ITensor::SharedPtr> sliceBufferVector(
|
||||
std::vector<ITensor::SharedPtr> const& vector, SizeType32 const offset, SizeType32 const size)
|
||||
{
|
||||
|
||||
@ -56,6 +56,8 @@ std::vector<ITensor::SharedPtr> createBufferVector(
|
||||
|
||||
void reshapeBufferVector(std::vector<ITensor::SharedPtr>& vector, nvinfer1::Dims const& shape);
|
||||
|
||||
void assertNoVGQA(ModelConfig const& modelConfig, WorldConfig const& worldConfig);
|
||||
|
||||
std::vector<ITensor::SharedPtr> sliceBufferVector(
|
||||
std::vector<ITensor::SharedPtr> const& vector, SizeType32 offset, SizeType32 size);
|
||||
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "tensorrt_llm/layers/lookaheadAlgorithm.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingUtils.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include "tensorrt_llm/runtime/lookaheadModule.h"
|
||||
#include "tests/layers/randomLlm.h"
|
||||
|
||||
@ -84,9 +85,10 @@ TEST_P(LookaheadAlgorithmTest, predict)
|
||||
std::tie(std::ignore, std::ignore, maxDraftLenRuntime, std::ignore)
|
||||
= executor::LookaheadDecodingConfig(w, n, g).calculateSpeculativeResource();
|
||||
auto shape = ITensor::makeShape({maxTokensPerStep});
|
||||
auto shape2d = ITensor::makeShape({maxTokensPerStep, maxTokensPerStep});
|
||||
auto shapeSingle = ITensor::makeShape({1});
|
||||
TensorPtr posidMax = BufferManager::cpu(shape, nvinfer1::DataType::kINT32);
|
||||
TensorPtr smaskMax = BufferManager::cpu(shape, nvinfer1::DataType::kBOOL);
|
||||
TensorPtr attentionMaskMax = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL);
|
||||
TensorPtr inputLengthPtr = BufferManager::cpu(shapeSingle, nvinfer1::DataType::kINT32);
|
||||
auto& inputLength(*BufferRange<SizeType32>(*inputLengthPtr).begin());
|
||||
|
||||
@ -123,26 +125,34 @@ TEST_P(LookaheadAlgorithmTest, predict)
|
||||
{
|
||||
TLLM_LOG_DEBUG("\noracle[%d] = '%c'", sequenceLength - 1, static_cast<char>(sequenceRange[sequenceLength - 1]));
|
||||
bufferCast<SizeType32>(*posidMax)[0] = sequenceLength - 1;
|
||||
bufferCast<bool>(*smaskMax)[0] = true;
|
||||
BufferLocation<bool> amaskLocation(*attentionMaskMax);
|
||||
for (auto& item : amaskLocation)
|
||||
{
|
||||
item = false;
|
||||
}
|
||||
for (SizeType32 i = 0; i < maxTokensPerStep; i++)
|
||||
{
|
||||
amaskLocation.at(i, 0) = true;
|
||||
}
|
||||
|
||||
algo.prepare( //
|
||||
ITensor::slice(sequence, sequenceLength, maxDraftLenRuntime), //
|
||||
ITensor::slice(posidMax, 1, maxDraftLenRuntime), //
|
||||
ITensor::slice(smaskMax, 1, maxDraftLenRuntime), //
|
||||
inputLengthPtr, //
|
||||
attentionMaskMax, 1, //
|
||||
sequenceLengthPtr, //
|
||||
ITensor::slice(sequence, sequenceLength - 1, 1));
|
||||
|
||||
TensorPtr input = ITensor::slice(sequence, sequenceLength - 1, inputLength + 1);
|
||||
TensorPtr posid = ITensor::slice(posidMax, 0, inputLength + 1);
|
||||
TensorPtr smask = ITensor::slice(smaskMax, 0, inputLength + 1);
|
||||
TensorPtr amask = ITensor::slice(attentionMaskMax, 0, inputLength + 1);
|
||||
|
||||
PRINT_TOKENS(input);
|
||||
PRINT_VALUES(posid);
|
||||
PRINT_VALUES(smask);
|
||||
PRINT_VALUES(amask);
|
||||
|
||||
TensorPtr output = ITensor::slice(outputMax, 0, inputLength + 1);
|
||||
llm.foretell(output, input, posid);
|
||||
llm.sampleByMask(output, smask);
|
||||
llm.foretell(output, input, posid, amask);
|
||||
PRINT_TOKENS(output);
|
||||
|
||||
// algo.update(acceptedMax, acceptedOffsetsMax, acceptedLengthPtr, output, endIdPtr);
|
||||
@ -207,4 +217,46 @@ INSTANTIATE_TEST_CASE_P(CombineLookaheadAlgorithmTestSmall_222, LookaheadAlgorit
|
||||
testing::Combine(testing::Values(std::make_tuple(2, 2)), testing::Values(std::make_tuple(2, 2)),
|
||||
testing::Values(std::make_tuple(2, 2))));
|
||||
|
||||
TEST(LookaheadAlgorithmTest, treeEncodeTest)
|
||||
{
|
||||
auto testWithData = [](TensorPtr inputTokens, TensorPtr inputPosIds, SizeType32 lastPosId, SizeType32 gold_len)
|
||||
{
|
||||
auto shape = inputTokens->getShape();
|
||||
auto shape2d = ITensor::makeShape({shape.d[0], shape.d[0]});
|
||||
|
||||
TensorPtr inputMasks = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL);
|
||||
LookaheadAlgorithm::posIdsToMask(inputMasks, inputPosIds);
|
||||
|
||||
TensorPtr outputTokens = BufferManager::cpu(shape, nvinfer1::DataType::kINT32);
|
||||
TensorPtr outputPosIds = BufferManager::cpu(shape, nvinfer1::DataType::kINT32);
|
||||
TensorPtr encodeMap = BufferManager::cpu(shape, nvinfer1::DataType::kINT32);
|
||||
TensorPtr outputMasks = BufferManager::cpu(shape2d, nvinfer1::DataType::kBOOL);
|
||||
|
||||
// auto len = LookaheadAlgorithm::treeEncode(outputTokens, outputPosIds, outputMasks, inputTokens, inputPosIds,
|
||||
// inputMasks, '$', 9);
|
||||
auto len = LookaheadAlgorithm::treeEncode(inputTokens, inputPosIds, inputMasks, encodeMap);
|
||||
TLLM_LOG_DEBUG("len = %d", len);
|
||||
|
||||
EXPECT_EQ(len, gold_len);
|
||||
};
|
||||
|
||||
testWithData( //
|
||||
initTensor(std::string("01234512345")), //
|
||||
initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}), //
|
||||
9, 6);
|
||||
|
||||
testWithData( //
|
||||
initTensor(std::string("01234512abc")), //
|
||||
initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15}), //
|
||||
9, 9);
|
||||
|
||||
testWithData( //
|
||||
initTensor(std::string("01234512abc2aBCD")), //
|
||||
initTensor({10, 11, 12, 13, 14, 15, 11, 12, 13, 14, 15, 12, 13, 14, 15, 16}), //
|
||||
9, 12);
|
||||
|
||||
testWithData(initTensor(std::string("wmplhi folxamp")),
|
||||
initTensor({21, 22, 23, 24, 25, 26, 27, 21, 22, 23, 24, 21, 22, 23, 24}), 20, 15);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::tests::layers
|
||||
|
||||
@ -230,11 +230,11 @@ protected:
|
||||
TensorPtr mNumNewTokensCumSum;
|
||||
TensorPtr mPathsOffsets;
|
||||
TensorPtr mDraftLengths;
|
||||
TensorPtr mPrevDraftLengths;
|
||||
TensorPtr mDraftTokens;
|
||||
TensorPtr mPackedMasks;
|
||||
TensorPtr mPackedMasksBool;
|
||||
TensorPtr mGenerationLengths;
|
||||
TensorPtr mGenerationLengthsMax;
|
||||
TensorPtr mPositionOffsets;
|
||||
TensorPtr mPositionIds;
|
||||
TensorPtr mAttentionPackedMask;
|
||||
@ -371,6 +371,7 @@ void LookaheadDecodingLayerTest::allocateBuffers()
|
||||
ITensor::makeShape({mMaxTokensPerStep, maxBatchSize, 1}), nvinfer1::DataType::kINT32);
|
||||
mNumNewTokens = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mDraftLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mPrevDraftLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mDraftTokens
|
||||
= BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, maxDraftLen}), nvinfer1::DataType::kINT32);
|
||||
auto packedMaskShape = ITensor::makeShape(
|
||||
@ -382,7 +383,6 @@ void LookaheadDecodingLayerTest::allocateBuffers()
|
||||
mPathsOffsets = BufferManager::pinnedPool(
|
||||
ITensor::makeShape({maxBatchSize, maxAcceptedDraftLen}), nvinfer1::DataType::kINT32);
|
||||
mGenerationLengths = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mGenerationLengthsMax = BufferManager::pinnedPool(maxBatchShape1D, nvinfer1::DataType::kINT32);
|
||||
mPositionOffsets
|
||||
= BufferManager::pinnedPool(ITensor::makeShape({maxBatchSize, mMaxTokensPerStep}), nvinfer1::DataType::kINT32);
|
||||
mPositionIds
|
||||
@ -462,10 +462,8 @@ void LookaheadDecodingLayerTest::newRequests(std::vector<SizeType32> requestIds)
|
||||
setupParams->prompt.emplace_back(mPrompt[gbi]);
|
||||
setupParams->algoConfigs.emplace_back(mTestParam.w, mTestParam.n, mTestParam.g);
|
||||
PRINT_TOKENS(setupParams->prompt[bi]);
|
||||
setupParams->generationLengths = mGenerationLengthsMax;
|
||||
setupParams->actualGenerationLengths = mGenerationLengths;
|
||||
setupParams->generationLengths = mGenerationLengths;
|
||||
setupParams->positionOffsets = mPositionOffsets;
|
||||
// setupParams->outputs.positionIds = mPositionIds;
|
||||
setupParams->attentionPackedMasks = mPackedMasks;
|
||||
}
|
||||
std::vector<uint64_t> seed(requestIds.begin(), requestIds.end());
|
||||
@ -669,14 +667,14 @@ void LookaheadDecodingLayerTest::decodeForward()
|
||||
PRINT_VALUES(mSequenceLengths);
|
||||
outputParams->sequenceLength = mSequenceLengths;
|
||||
outputParams->nextDraftLengths = mDraftLengths;
|
||||
outputParams->prevDraftLengths = mPrevDraftLengths;
|
||||
outputParams->nextDraftTokens = mDraftTokens;
|
||||
outputParams->packedMasks = mPackedMasks;
|
||||
outputParams->numNewTokens = mNumNewTokens;
|
||||
outputParams->newTokens = mNewTokens;
|
||||
outputParams->numNewTokensCumSum = mNumNewTokensCumSum;
|
||||
outputParams->pathsOffsets = mPathsOffsets;
|
||||
outputParams->generationLengths = mGenerationLengthsMax;
|
||||
outputParams->actualGenerationLengths = mGenerationLengths;
|
||||
outputParams->generationLengths = mGenerationLengths;
|
||||
outputParams->positionOffsets = mPositionOffsets;
|
||||
outputParams->positionIds = mPositionIds;
|
||||
outputParams->packedMasks = mPackedMasks;
|
||||
|
||||
@ -276,8 +276,8 @@ void LookaheadRandomLlm::foretell(TensorPtr const& output, TensorConstPtr const&
|
||||
{
|
||||
right &= maskLocation.at(i, j) ? oracleRange[positionRange[j]] == inputRange[j] : true;
|
||||
}
|
||||
if (i < verifyStart)
|
||||
{ // lookahead might be right
|
||||
if (i < verifyStart && false)
|
||||
{ // lookahead might be right. Since we verify lookahead branch, then must be right.
|
||||
outputRange[i] = ((right || rand() % 5) && legal) ? oracleRange[positionRange[i] + 1] : invalid;
|
||||
}
|
||||
else
|
||||
|
||||
@ -90,7 +90,7 @@ def build_engines(model_cache: str, only_multi_gpu: bool):
|
||||
|
||||
tp_pp_sizes = [(1, 1)]
|
||||
if only_multi_gpu:
|
||||
tp_pp_sizes = [(1, 4), (4, 1), (1, 2), (2, 2)]
|
||||
tp_pp_sizes = [(1, 4), (4, 1), (1, 2), (2, 2), (2, 1)]
|
||||
for tp_size, pp_size in tp_pp_sizes:
|
||||
tp_pp_dir = f"tp{tp_size}-pp{pp_size}-gpu"
|
||||
print(f"\nBuilding fp16 tp{tp_size} pp{pp_size} engine")
|
||||
|
||||
@ -72,7 +72,7 @@ def generate_outputs(num_beams, only_multi_gpu=False):
|
||||
elif COMM_WORLD.size == 4:
|
||||
tp_pp_sizes = [(4, 1), (2, 2), (1, 4)]
|
||||
elif COMM_WORLD.size == 2:
|
||||
tp_pp_sizes = [(1, 2)]
|
||||
tp_pp_sizes = [(1, 2), (2, 1)]
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"The world size of MPI {COMM_WORLD.size} is not equal to 1, 2, or 4."
|
||||
|
||||
@ -664,6 +664,17 @@ def run_single_gpu_tests(build_dir: _pl.Path,
|
||||
if excluded_tests:
|
||||
ctest.extend(["-E", "|".join(excluded_tests)])
|
||||
parallel_run_ctest(ctest, cwd=build_dir, env=cpp_env, timeout=timeout)
|
||||
if run_gpt:
|
||||
xml_output_file = build_dir / "results-single-gpu-disagg-executor_gpt.xml"
|
||||
trt_model_test = produce_mpirun_command(
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=2,
|
||||
local_commands=[
|
||||
"tests/executor/executorTest",
|
||||
"--gtest_filter=*GptSingleDeviceDisaggExecutorTest*"
|
||||
],
|
||||
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
|
||||
run_command(trt_model_test, cwd=build_dir, env=cpp_env, timeout=timeout)
|
||||
|
||||
|
||||
def produce_mpirun_command(*, global_commands, nranks, local_commands,
|
||||
@ -777,25 +788,37 @@ def run_multi_gpu_tests(build_dir: _pl.Path, timeout=1500):
|
||||
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
|
||||
|
||||
new_env = copy.copy(cpp_env)
|
||||
xml_output_file = build_dir / "results-multi-gpu-dist-executor_gpt.xml"
|
||||
xml_output_file = build_dir / "results-multi-gpu-disagg-executor-2-process.xml"
|
||||
trt_model_test = produce_mpirun_command(
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=2,
|
||||
local_commands=[
|
||||
"executor/executorTest",
|
||||
"--gtest_filter=DistExecutorTest.GPTTokenComparison"
|
||||
"executor/executorTest", "--gtest_filter=*DisaggExecutorTest*"
|
||||
],
|
||||
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
|
||||
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
|
||||
|
||||
new_env = copy.copy(cpp_env)
|
||||
xml_output_file = build_dir / "results-multi-gpu-dist-executor_chatglm.xml"
|
||||
new_env["RUN_LLAMA_MULTI_GPU"] = "true"
|
||||
xml_output_file = build_dir / "results-multi-gpu-disagg-executor-4-process.xml"
|
||||
trt_model_test = produce_mpirun_command(
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=2,
|
||||
nranks=4,
|
||||
local_commands=[
|
||||
"executor/executorTest", "--gtest_filter=*DisaggExecutorTest*"
|
||||
],
|
||||
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
|
||||
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
|
||||
|
||||
new_env = copy.copy(cpp_env)
|
||||
new_env["RUN_LLAMA_MULTI_GPU"] = "true"
|
||||
xml_output_file = build_dir / "results-multi-gpu-disagg-executor-8-process.xml"
|
||||
trt_model_test = produce_mpirun_command(
|
||||
global_commands=["mpirun", "--allow-run-as-root"],
|
||||
nranks=8,
|
||||
local_commands=[
|
||||
"executor/executorTest",
|
||||
"--gtest_filter=DistExecutorTest.ChatGLMTokenComparison"
|
||||
"--gtest_filter=*LlamaTP2PP2DisaggExecutorTest*"
|
||||
],
|
||||
leader_commands=[f"--gtest_output=xml:{xml_output_file}"])
|
||||
run_command(trt_model_test, cwd=tests_dir, env=new_env, timeout=1500)
|
||||
|
||||
@ -195,7 +195,8 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig>& sa
|
||||
SizeType32 constexpr nbRnnLayers{0};
|
||||
SizeType32 constexpr nbHeads{16};
|
||||
SizeType32 constexpr hiddenSize{1024};
|
||||
ModelConfig modelConfig{vocabSize, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
ModelConfig modelConfig{
|
||||
vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
modelConfig.useGptAttentionPlugin(false);
|
||||
|
||||
auto streamPtr = std::make_shared<CudaStream>();
|
||||
@ -315,7 +316,8 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingCo
|
||||
SizeType32 constexpr nbRnnLayers{0};
|
||||
SizeType32 constexpr nbHeads{16};
|
||||
SizeType32 constexpr hiddenSize{1024};
|
||||
ModelConfig modelConfig{vocabSize, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
ModelConfig modelConfig{
|
||||
vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
modelConfig.useGptAttentionPlugin(false);
|
||||
|
||||
auto streamPtr = std::make_shared<CudaStream>();
|
||||
@ -440,7 +442,8 @@ void testDecoderDraft(nvinfer1::DataType const dtype, std::vector<SamplingConfig
|
||||
SizeType32 constexpr nbRnnLayers{0};
|
||||
SizeType32 constexpr nbHeads{16};
|
||||
SizeType32 constexpr hiddenSize{1024};
|
||||
ModelConfig modelConfig{vocabSize, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
ModelConfig modelConfig{
|
||||
vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
modelConfig.useGptAttentionPlugin(false);
|
||||
modelConfig.setSpeculativeDecodingMode(SpeculativeDecodingMode::DraftTokensExternal());
|
||||
|
||||
|
||||
@ -81,7 +81,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC
|
||||
SizeType32 constexpr nbHeads{16};
|
||||
SizeType32 constexpr hiddenSize{1024};
|
||||
SizeType32 constexpr batchSize{4};
|
||||
ModelConfig modelConfig{vocabSize, nbLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
ModelConfig modelConfig{vocabSize, nbLayers + nbRnnLayers, nbLayers, nbRnnLayers, nbHeads, hiddenSize, dtype};
|
||||
modelConfig.useGptAttentionPlugin(false);
|
||||
|
||||
SizeType32 constexpr maxInputLength{8};
|
||||
|
||||
@ -62,7 +62,7 @@ protected:
|
||||
|
||||
void SetUp() override
|
||||
{
|
||||
mModelConfig = std::make_unique<ModelConfig>(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
mModelConfig = std::make_unique<ModelConfig>(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
mModelConfig->setMlpHiddenSize(32);
|
||||
mWorldConfig = std::make_unique<WorldConfig>(2, 1, 0);
|
||||
std::vector<LoraModule> modules{
|
||||
@ -166,8 +166,8 @@ TEST_F(LoraCacheTest, LoraCachePageManagerTest)
|
||||
|
||||
TEST_F(LoraCacheTest, determineNumPages)
|
||||
{
|
||||
ModelConfig modelConfig(0, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2));
|
||||
ModelConfig modelConfig(0, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2, 0));
|
||||
WorldConfig worldConfig(1, 1, 0);
|
||||
|
||||
LoraCachePageManagerConfig pageConfig(MemoryType::kCPU, nvinfer1::DataType::kFLOAT, 12393, 40, 80, 16, 1);
|
||||
@ -358,7 +358,7 @@ TEST_F(LoraCacheTest, basicPutGet)
|
||||
|
||||
TEST_F(LoraCacheTest, splitTransposeCpu)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto worldConfig = WorldConfig(2, 1, 0);
|
||||
|
||||
SizeType32 const split{2};
|
||||
@ -421,7 +421,7 @@ TEST_F(LoraCacheTest, splitTransposeCpu)
|
||||
|
||||
TEST_F(LoraCacheTest, copyToPages_tp1)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setMlpHiddenSize(32);
|
||||
auto worldConfig = WorldConfig(1, 1, 0);
|
||||
std::vector<LoraModule> modules{
|
||||
@ -479,7 +479,7 @@ TEST_F(LoraCacheTest, copyToPages_tp1)
|
||||
|
||||
TEST_F(LoraCacheTest, copyToPages_tp2_rank0)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setMlpHiddenSize(32);
|
||||
auto worldConfig = WorldConfig(2, 1, 0);
|
||||
std::vector<LoraModule> modules{
|
||||
@ -536,7 +536,7 @@ TEST_F(LoraCacheTest, copyToPages_tp2_rank0)
|
||||
|
||||
TEST_F(LoraCacheTest, copyToPages_tp2_rank1)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setMlpHiddenSize(32);
|
||||
auto worldConfig = WorldConfig(2, 1, 1);
|
||||
std::vector<LoraModule> modules{
|
||||
|
||||
@ -59,7 +59,7 @@ class LoraManagerTest : public ::testing::Test // NOLINT(cppcoreguidelines-pro-t
|
||||
{
|
||||
protected:
|
||||
LoraManagerTest()
|
||||
: mModelConfig(1, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT)
|
||||
: mModelConfig(1, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
}
|
||||
|
||||
@ -70,7 +70,7 @@ protected:
|
||||
|
||||
mWorldConfig = WorldConfig(2);
|
||||
|
||||
mModelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2));
|
||||
mModelConfig.setLoraModules(LoraModule::createLoraModules({"attn_dense", "attn_qkv"}, 4, 4, 1, 1, 2, 2, 0));
|
||||
}
|
||||
|
||||
std::unique_ptr<BufferManager> mManager;
|
||||
@ -80,7 +80,7 @@ protected:
|
||||
|
||||
PeftTable getPeftTable(SizeType32 tpRank = 0)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setMlpHiddenSize(32);
|
||||
auto worldConfig = WorldConfig(2, 2, 3);
|
||||
std::vector<LoraModule> modules{
|
||||
@ -292,7 +292,7 @@ static std::tuple<std::vector<int32_t>, std::vector<int64_t>, PeftTable> createF
|
||||
TEST_F(LoraManagerTest, fillInputTensors)
|
||||
{
|
||||
LoraManager loraManager;
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 16, nvinfer1::DataType::kFLOAT);
|
||||
modelConfig.setMlpHiddenSize(32);
|
||||
auto worldConfig = WorldConfig(1, 1, 0);
|
||||
std::vector<LoraModule> modules{
|
||||
|
||||
@ -86,7 +86,7 @@ TEST_F(LoraUtilsTest, dims_mem_type)
|
||||
|
||||
TEST_F(LoraUtilsTest, loraValidateRequestTensors)
|
||||
{
|
||||
auto modelConfig = ModelConfig(0, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT);
|
||||
auto modelConfig = ModelConfig(0, 2, 2, 0, 1, 4, nvinfer1::DataType::kFLOAT);
|
||||
auto worldConfig = WorldConfig();
|
||||
|
||||
std::optional<TensorPtr> optReqLoraWeights
|
||||
|
||||
@ -3,7 +3,7 @@
|
||||
set -ex
|
||||
|
||||
ARCH=$(uname -m)
|
||||
CMAKE_VERSION="3.24.4"
|
||||
CMAKE_VERSION="3.30.2"
|
||||
|
||||
PARSED_CMAKE_VERSION=$(echo $CMAKE_VERSION | sed 's/\.[0-9]*$//')
|
||||
CMAKE_FILE_NAME="cmake-${CMAKE_VERSION}-linux-${ARCH}"
|
||||
|
||||
@ -2,7 +2,7 @@
|
||||
|
||||
set -ex
|
||||
|
||||
TRT_VER="10.3.0.26"
|
||||
TRT_VER="10.4.0.26"
|
||||
# Align with the pre-installed cuDNN / cuBLAS / NCCL versions from
|
||||
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes/rel-24-07.html#rel-24-07
|
||||
CUDA_VER="12.5" # 12.5.1
|
||||
@ -14,6 +14,7 @@ CUBLAS_VER="12.5.3.2-1"
|
||||
# Align with the pre-installed CUDA / NVCC / NVRTC versions from
|
||||
# https://docs.nvidia.com/cuda/archive/12.5.1/cuda-toolkit-release-notes/index.html
|
||||
NVRTC_VER="12.5.82-1"
|
||||
CUDA_RUNTIME="12.5.82-1"
|
||||
|
||||
for i in "$@"; do
|
||||
case $i in
|
||||
@ -71,12 +72,14 @@ install_centos_requirements() {
|
||||
yum -y install epel-release
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libnccl-${NCCL_VER}.x86_64.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libnccl-devel-${NCCL_VER}.x86_64.rpm
|
||||
yum remove -y libnccl* && yum -y localinstall libnccl-${NCCL_VER}.x86_64.rpm libnccl-devel-${NCCL_VER}.x86_64.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${NVRTC_VER}.noarch.rpm
|
||||
yum remove -y cuda-toolkit* && yum -y localinstall cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${NVRTC_VER}.noarch.rpm
|
||||
yum remove -y "libnccl*" && yum -y localinstall libnccl-${NCCL_VER}.x86_64.rpm libnccl-devel-${NCCL_VER}.x86_64.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-12-config-common-${CUDA_RUNTIME}.noarch.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch.rpm
|
||||
yum remove -y "cuda-toolkit*" && yum -y localinstall cuda-toolkit-${CUBLAS_CUDA_VERSION}-config-common-${CUDA_RUNTIME}.noarch.rpm cuda-toolkit-12-config-common-${CUDA_RUNTIME}.noarch.rpm cuda-toolkit-config-common-${CUDA_RUNTIME}.noarch.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm
|
||||
wget -q https://developer.download.nvidia.cn/compute/cuda/repos/rhel8/x86_64/libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm
|
||||
yum remove -y libcublas* && yum -y localinstall libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm
|
||||
yum remove -y "libcublas*" && yum -y localinstall libcublas-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm libcublas-devel-${CUBLAS_CUDA_VERSION}-${CUBLAS_VER}.x86_64.rpm
|
||||
yum clean all
|
||||
nvcc --version
|
||||
}
|
||||
@ -84,7 +87,7 @@ install_centos_requirements() {
|
||||
install_tensorrt() {
|
||||
PY_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[0:2])))')
|
||||
PARSED_PY_VERSION=$(echo "${PY_VERSION//./}")
|
||||
TRT_CUDA_VERSION="12.5"
|
||||
TRT_CUDA_VERSION="12.6"
|
||||
|
||||
if [ -z "$RELEASE_URL_TRT" ];then
|
||||
ARCH=${TRT_TARGETARCH}
|
||||
@ -92,8 +95,8 @@ install_tensorrt() {
|
||||
if [ "$ARCH" = "arm64" ];then ARCH="aarch64";fi
|
||||
if [ "$ARCH" = "amd64" ];then ARCH="x86_64";fi
|
||||
if [ "$ARCH" = "x86_64" ];then DIR_NAME="x64-agnostic"; else DIR_NAME=${ARCH};fi
|
||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu22_04" && OS2="Ubuntu-22.04" && OS="ubuntu-22.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||
if [ "$ARCH" = "aarch64" ];then OS1="Ubuntu24_04" && OS2="Ubuntu-24.04" && OS="ubuntu-24.04"; else OS1="Linux" && OS2="Linux" && OS="linux";fi
|
||||
RELEASE_URL_TRT=https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/tars/TensorRT-${TRT_VER}.${OS2}.${ARCH}-gnu.cuda-${TRT_CUDA_VERSION}.tar.gz
|
||||
fi
|
||||
wget --no-verbose ${RELEASE_URL_TRT} -O /tmp/TensorRT.tar
|
||||
tar -xf /tmp/TensorRT.tar -C /usr/local/
|
||||
|
||||
@ -147,6 +147,7 @@ Note: this feature isn't supported with the `V1` batching scheme for the moment.
|
||||
* `capacitySchedulerPolicy`, policy used to select the subset available requests in each iteration of the InflightBatching generation loop.
|
||||
- `MAX_UTILIZATION` packs as many requests as the underlying TRT engine can support in any iteration of the InflightBatching generation loop. While this is expected to maximize GPU throughput, it might require that some requests be paused and restarted depending on peak KV cache memory availability.
|
||||
- `GUARANTEED_NO_EVICT` uses KV cache more conservatively guaranteeing that a request, once started, will run to completion without eviction.
|
||||
- `STATIC_BATCH` similarly to `GUARANTEED_NO_EVICT` schedules the maximum possible batch size without eviction. New requests are scheduled only after all requests in the previous batch have finished.
|
||||
|
||||
### Optional GptManager parameters
|
||||
* `TrtGptModelOptionalParams` class encapsulates the following fields:
|
||||
@ -227,6 +228,9 @@ It can also adopt a more conservative approach and schedule requests only when i
|
||||
knows that the memory allocation will be sufficient to process all active requests
|
||||
even in the worst case of KV cache consumption. That mode corresponds to a
|
||||
`SchedulerConfig::capacitySchedulerPolicy` set to `kGUARANTEED_NO_EVICT`.
|
||||
Another traditional batching scheme with a batch of requests running in lockstep
|
||||
until generation for all of them is completed corresponds to
|
||||
`SchedulerConfig::capacitySchedulerPolicy` set to `kSTATIC_BATCH`.
|
||||
|
||||
The `GptManager`'s worker thread terminates when the `GptManager` destructor is
|
||||
called and there are no more active requests.
|
||||
|
||||
@ -50,7 +50,7 @@ If replication is expensive or infeasible, use `LogitsPostProcessorConfig::setRe
|
||||
|
||||
The `Request` class is used to define properties of the request, such as the input token ids and the maximum number of tokens to generate. The `streaming` parameter can be used to indicate if the request should generate a response for each new generated tokens (`streaming = true`) or only after all tokens have been generated (`streaming = false`). Other mandatory parameters of the request include the sampling configuration (defined by the `SamplingConfig` class) which contains parameters controlling the decoding process and the output configuration (defined by the `OutputConfig` class) which controls what information should be included in the `Result` for a particular response.
|
||||
|
||||
Optional parameters can also be provided when constructing a request such as a list of bad words, a list of stop words, a client id, or configurations objects for prompt tuning, LoRA, or speculative decoding for example.
|
||||
Optional parameters can also be provided when constructing a request such as a list of bad words, a list of stop words, a client id, or configurations objects for prompt tuning, LoRA, or speculative decoding, or a number of sequences to generate for example.
|
||||
|
||||
### The Response Class
|
||||
|
||||
@ -58,7 +58,19 @@ The `awaitResponses` method of the `Executor` class returns a vector of response
|
||||
|
||||
### The Result Class
|
||||
|
||||
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false`, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` is used, a `Result` will only include 1 token and the `isFinal` flag will be set to `true` for the last result associated with this request.
|
||||
The `Result` class holds the result for a given request. It contains a Boolean parameter called `isFinal` that indicates if this is the last `Result` that will be returned for the given request id. It also contains the generated tokens. If the request is configured with `streaming = false` and `numReturnSequences = 1`, a single response will be returned, the `isFinal` Boolean will be set to `true` and all generated tokens will be included in the `outputTokenIds`. If `streaming = true` and `numReturnSequences = 1` is used, a `Result` will include one or more tokens (depending on the request `returnAllGeneratedTokens` parameter) except the last result and the `isFinal` flag will be set to `true` for the last result associated with this request.
|
||||
|
||||
The request `numReturnSequences` parameter controls the number of output sequences to generate for each prompt. When this option is used, the Executor will return at least `numReturnSequences` responses for each request, each containing one Result. The `sequenceIndex` attribute of the `Result` class indicates the index of the generated sequence in the result (`0 <= sequenceIndex < numReturnSequences`). It contains a Boolean parameter called `isSequenceFinal` that indicates if this is the last result for the sequence and also contains a Boolean parameter `isFinal` that indicates when all sequences for the request have been generated. When `numReturnSequences = 1`, `isFinal` is identical to `isSequenceFinal`.
|
||||
|
||||
Here is an example that shows how a subset of 3 responses might look like for `numReturnSequences = 3`:
|
||||
|
||||
```
|
||||
Response 1: requestId = 1, Result with sequenceIndex = 0, isSequenceFinal = false, isFinal = false
|
||||
Response 2: requestId = 1, Result with sequenceIndex = 1, isSequenceFinal = true, isFinal = false
|
||||
Response 3: requestId = 1, Result with sequenceIndex = 2, isSequenceFinal = false, isFinal = false
|
||||
```
|
||||
|
||||
In this example, each response contains one result for different sequences. The `isSequenceFinal` flag of the second Result is set to true, indicating that it is the last result for `sequenceIndex = 1`, however, the isFinal flag of each Response is set to false because sequences 0 and 2 are not completed.
|
||||
|
||||
### Sending Requests with Different Beam Widths
|
||||
|
||||
|
||||
@ -195,6 +195,7 @@ loader = ModelWeightsLoader(external_checkpoint_dir, llava_dict)
|
||||
loader.generate_tllm_weights(trtllm_model)
|
||||
```
|
||||
Users need to specify the different part from the default `tllm_to_externel_key_dict`. The loader still have support across different precisions.
|
||||
The support for LLaVA and Exaone is in `LLaMAForCausalLM.from_hugging_face()` of [model.py](../../../tensorrt_llm/models/llama/model.py), and can also be taken as examples.
|
||||
|
||||
### Models with customized weight layout
|
||||
For models with different weight layout, users can write the conversion loop explicitly and do customized operations.
|
||||
@ -225,9 +226,10 @@ for tllm_key, _ in tqdm(trtllm_model.named_parameters()):
|
||||
tllm_weights.update(loader.load(tllm_key, preprocess=customized_preprocess))
|
||||
else:
|
||||
tllm_weights.update(loader.load(tllm_key))
|
||||
loader.check(tllm_weights)
|
||||
loader.fill(tllm_weights)
|
||||
```
|
||||
This will apply `preprocess` after `load_tensor()` and before `postprocess`, and demonstrates how to convert the loaded shard into default HF layout. The loader still have support for precisions quantized from FP16/BF16 (e.g. INT8-wo/INT4-wo), the other precisions may require special operations, and can be addressed inside the `preprocess` function.
|
||||
The support for Qwen-1 is in `QWenForCausalLM.from_hugging_face()` of [model.py](../../../tensorrt_llm/models/qwen/model.py), and can also be taken as example.
|
||||
|
||||
### Fully customized
|
||||
If the model weights loader cannot satisfy the requirements, users can write the conversion loop totally on their own.
|
||||
|
||||
@ -11,7 +11,7 @@ This section is for advanced users. Skip this section if you plan to use the pre
|
||||
1. Install prerequisites listed in our [Installing on Windows](https://nvidia.github.io/TensorRT-LLM/installation/windows.html) document.
|
||||
2. Install [CMake](https://cmake.org/download/), version 3.27.7 is recommended, and select the option to add it to the system path.
|
||||
3. Download and install [Visual Studio 2022](https://visualstudio.microsoft.com/).
|
||||
4. Download and unzip [TensorRT 10.3.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip).
|
||||
4. Download and unzip [TensorRT 10.4.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip).
|
||||
|
||||
## Building a TensorRT-LLM Docker Image
|
||||
|
||||
@ -65,7 +65,7 @@ git submodule update --init --recursive
|
||||
2. Build TensorRT-LLM. This command generates `build\tensorrt_llm-*.whl`.
|
||||
|
||||
```bash
|
||||
python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10.3.0.26\
|
||||
python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10.4.0.26\
|
||||
```
|
||||
|
||||
3. Copy or move `build\tensorrt_llm-*.whl` into your mounted folder so it can be accessed on your host machine. If you intend to use the C++ runtime, you'll also need to gather various DLLs from the build into your mounted folder. For more information, refer to [C++ Runtime Usage](#c-runtime-usage).
|
||||
@ -103,7 +103,7 @@ python .\scripts\build_wheel.py -a "89-real" --trt_root C:\workspace\TensorRT-10
|
||||
|
||||
1. Install [CMake](https://cmake.org/download/), version 3.27.7 is recommended, and select the option to add it to the system path.
|
||||
2. Download and install [Visual Studio 2022](https://visualstudio.microsoft.com/). When prompted to select more Workloads, check **Desktop development with C++**.
|
||||
3. Download and unzip [TensorRT 10.3.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.3.0/zip/TensorRT-10.3.0.26.Windows.win10.cuda-12.5.zip). Move the folder to a location you can reference later, such as `%USERPROFILE%\inference\TensorRT`.
|
||||
3. Download and unzip [TensorRT 10.4.0.26](https://developer.nvidia.com/downloads/compute/machine-learning/tensorrt/10.4.0/zip/TensorRT-10.4.0.26.Windows.win10.cuda-12.6.zip). Move the folder to a location you can reference later, such as `%USERPROFILE%\inference\TensorRT`.
|
||||
|
||||
1. Add the libraries for TensorRT to your system's `Path` environment variable. Your `Path` should include a line like this:
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@
|
||||
|
||||
```{note}
|
||||
The Windows release of TensorRT-LLM is currently in beta.
|
||||
We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-LLM/releases/tag/v0.12.0) for the most stable experience.
|
||||
We recommend checking out the [v0.13.0 tag](https://github.com/NVIDIA/TensorRT-LLM/releases/tag/v0.13.0) for the most stable experience.
|
||||
```
|
||||
|
||||
**Prerequisites**
|
||||
@ -15,7 +15,7 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L
|
||||
|
||||
1. Install all dependencies together.
|
||||
|
||||
1. Run the provided PowerShell script `setup_env.ps1` located under the `/windows/` folder which installs Python and CUDA 12.4.1 automatically with default settings. Run PowerShell as Administrator to use the script.
|
||||
1. Run the provided PowerShell script `setup_env.ps1` located under the `/windows/` folder which installs Python and CUDA 12.5.1 automatically with default settings. Run PowerShell as Administrator to use the script.
|
||||
|
||||
```bash
|
||||
./setup_env.ps1 [-skipCUDA] [-skipPython]
|
||||
@ -52,7 +52,7 @@ We recommend checking out the [v0.12.0 tag](https://github.com/NVIDIA/TensorRT-L
|
||||
before installing TensorRT-LLM with the following command.
|
||||
|
||||
```bash
|
||||
pip install tensorrt_llm==0.12.0 --extra-index-url https://pypi.nvidia.com --extra-index-url https://download.pytorch.org/whl/cu121/
|
||||
pip install tensorrt_llm==0.13.0 --extra-index-url https://pypi.nvidia.com --extra-index-url https://download.pytorch.org/whl/
|
||||
```
|
||||
|
||||
Run the following command to verify that your TensorRT-LLM installation is working properly.
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user