Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

Commit 8681b3a4c0 (parent 48686bca3a)

open source 4dbf696ae9b74a26829d120b67ab8443d70c8e58 (#2297)

* Update TensorRT-LLM

Co-authored-by: Bhuvanesh Sridharan <bhuvanesh.sridharan@sprinklr.com>
Co-authored-by: Qingquan Song <ustcsqq@gmail.com>
.gitmodules (vendored, 3 lines changed)
@@ -14,3 +14,6 @@
[submodule "3rdparty/ucxx"]
    path = 3rdparty/ucxx
    url = https://github.com/GuanLuo/ucxx.git
[submodule "3rdparty/pybind11"]
    path = 3rdparty/pybind11
    url = https://github.com/pybind/pybind11.git
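After this change, a fresh checkout also needs the new submodule fetched, for example with the standard `git submodule update --init 3rdparty/pybind11` (generic git usage, not part of this diff).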
@@ -46,5 +46,5 @@ repos:
        args:
          - --skip=".git,3rdparty"
          - --exclude-file=examples/whisper/tokenizer.py
          - --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid
          - --ignore-words-list=rouge,inout,atleast,strat,nd,subtile,thrid,improbe
        exclude: 'tests/llm-test-defs/turtle/test_input_files'
3rdparty/pybind11 (vendored submodule, 1 line changed)
@@ -0,0 +1 @@
Subproject commit f99ffd7e03001810a3e722bf48ad1a9e08415d7d
@@ -103,7 +103,7 @@ Serverless TensorRT-LLM (LLaMA 3 8B) | Modal Docs [➡️ link](https://modal.co
## TensorRT-LLM Overview

TensorRT-LLM is a library for optimizing Large Language Model (LLM) inference.
It provides state-of-the-art optimziations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs
It provides state-of-the-art optimizations, including custom attention kernels, inflight batching, paged KV caching, quantization (FP8, INT4 [AWQ](https://arxiv.org/abs/2306.00978), INT8 [SmoothQuant](https://arxiv.org/abs/2211.10438), ++) and much more, to perform inference efficiently on NVIDIA GPUs

TensorRT-LLM provides a Python API to build LLMs into optimized
[TensorRT](https://developer.nvidia.com/tensorrt) engines.
@@ -42,7 +42,7 @@ This section covers how to benchmark TensorRT-LLM using inflight batching.
### Quickstart

For this quick start guide, we will focus on running a short max throughput benchmark on
`meta-llama/Llama-2-7b-hf` on a syntehtic dataset with a uniform distribution of prompts with ISL:OSL
`meta-llama/Llama-2-7b-hf` on a synthetic dataset with a uniform distribution of prompts with ISL:OSL
of 128:128. In order to run the benchmark from start to finish simply run the following commands:

```shell
@@ -101,12 +101,12 @@ The workflow for `trtllm-bench` is composed of the following steps:
The inflight benchmark utilizes a fixed JSON schema so that it is simple and
straightforward to specify requests. The schema is defined as follows:

| Key | Required | Type | Description |
| :- | :-: | :-: | :- |
| `task_id`| Y | String | Unique identifier for the request. |
| `prompt` | N* | String | Input text for a generation request. |
| `logits` | N* | List[Integer] | List of logits that make up the request prompt. |
| `output_tokens` | Y | Integer | Number of generated tokens for this request. |
| Key             | Required | Type          | Description                                      |
| :-------------- | :------: | :-----------: | :----------------------------------------------- |
| `task_id`       | Y        | String        | Unique identifier for the request.               |
| `prompt`        | N*       | String        | Input text for a generation request.             |
| `logits`        | N*       | List[Integer] | List of logits that make up the request prompt.  |
| `output_tokens` | Y        | Integer       | Number of generated tokens for this request.     |

> [!NOTE] Prompt and logits are mutually exclusive*
> While having both `prompt` and `logits` is not required, at least one is required.
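For illustration only (this example line is not taken from the documentation), a single dataset entry following this schema could look like `{"task_id": "1", "prompt": "What is the capital of France?", "output_tokens": 128}`, or it could supply a `logits` list in place of `prompt`.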
@@ -316,6 +316,8 @@ endif()
get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_CURRENT_SOURCE_DIR} PATH)

set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
add_subdirectory(${3RDPARTY_DIR}/pybind11 ${CMAKE_CURRENT_BINARY_DIR}/pybind11)

include_directories(
    ${CUDAToolkit_INCLUDE_DIRS}
    ${CUDNN_ROOT_DIR}/include
@@ -323,7 +325,8 @@ include_directories(
    ${3RDPARTY_DIR}/cutlass/include
    ${3RDPARTY_DIR}/cutlass/tools/util/include
    ${3RDPARTY_DIR}/NVTX/include
    ${3RDPARTY_DIR}/json/include)
    ${3RDPARTY_DIR}/json/include
    ${3RDPARTY_DIR}/pybind11/include)

# TRT dependencies
set_ifndef(TRT_LIB_DIR ${CMAKE_BINARY_DIR})
cpp/include/tensorrt_llm/batch_manager/capacityScheduler.h (new file, 187 lines)
@@ -0,0 +1,187 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"
#include <variant>

namespace tensorrt_llm::batch_manager
{
namespace kv_cache_manager
{
class KVCacheManager;
}
class BasePeftCacheManager;
} // namespace tensorrt_llm::batch_manager

namespace tensorrt_llm::batch_manager
{

using tensorrt_llm::runtime::SizeType32;

/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
/// or even pause previously started requests.
class BaseCapacityScheduler
{
public:
    explicit BaseCapacityScheduler(LlmRequestState noScheduleUntilState, LlmRequestState noScheduleAfterState)
        : mNoScheduleUntilState(noScheduleUntilState)
        , mNoScheduleAfterState(noScheduleAfterState)
    {
    }

    [[nodiscard]] LlmRequestState constexpr getNoScheduleUntilState() const noexcept
    {
        return mNoScheduleUntilState;
    }

    [[nodiscard]] LlmRequestState constexpr getNoScheduleAfterState() const noexcept
    {
        return mNoScheduleAfterState;
    }

private:
    /// The state until/after which the scheduler should not schedule requests
    LlmRequestState mNoScheduleUntilState;
    LlmRequestState mNoScheduleAfterState;
};

/// @brief Schedule up to maxNumRequests requests
class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
    explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    /// @brief Takes as input a sorted list of requests and outputs a sorted lists of requests
    /// to update for this current iteration, and a map of requests to pause
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
};

/// @brief Schedule requests using the MAX_UTILIZATION policy
/// @details Try reserving resources to advance requests by one step,
/// may pause previously started requests.
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
    MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    /// @return {fitsKvCache, fitsPeft}
    std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
        RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
        std::unordered_set<uint64_t>& seenTaskIds) const;

    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
    std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
    /// @brief Boolean that indicates if multiple micro batches might be in flight
    bool mManyMicroBatches;
};

/// @brief Schedule requests using the GUARANTEED_NO_EVICT policy
class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
    GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

protected:
    [[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
        RequestList const& activeRequests, bool staticBatchScheduling) const;

private:
    SizeType32 mMaxNumRequests;
    std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
    std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
    std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
};

/// @brief Schedule requests using the STATIC_BATCH policy
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
    StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
};

class CapacityScheduler : public Algorithm
{
public:
    constexpr static auto name{"CapacityScheduler"};

    CapacityScheduler() = default;

    CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    static CapacityScheduler make(SizeType32 maxNumRequests,
        std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
        std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
        std::shared_ptr<BasePeftCacheManager> peftCacheManager,
        executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
    {
        return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
            std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
            noScheduleAfterState};
    }

    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;

private:
    std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
        StaticBatchScheduler>
        mScheduler;
};

} // namespace tensorrt_llm::batch_manager
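As a hedged illustration of how this header's factory might be used (this is not code from the commit; `kvCacheManager`, `crossKvCacheManager`, `peftCacheManager` and `activeRequests` are assumed to exist elsewhere):

```cpp
#include "tensorrt_llm/batch_manager/capacityScheduler.h"

using namespace tensorrt_llm::batch_manager;

// Sketch only: pick a policy-backed scheduler once, then run it each iteration.
std::tuple<RequestVector, RequestVector> scheduleOnce(RequestList const& activeRequests,
    std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
    std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
    std::shared_ptr<BasePeftCacheManager> peftCacheManager)
{
    auto const scheduler = CapacityScheduler::make(/*maxNumRequests=*/64, std::move(kvCacheManager),
        std::move(crossKvCacheManager), std::move(peftCacheManager),
        executor::CapacitySchedulerPolicy::kGUARANTEED_NO_EVICT);
    // First element: requests to run this iteration; second element: requests to pause.
    return scheduler(activeRequests);
}
```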
cpp/include/tensorrt_llm/batch_manager/common.h (new file, 118 lines)
@@ -0,0 +1,118 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/common.h"
#include <cstdint>
#include <list>
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>

namespace tensorrt_llm::executor
{
class RequestWithId;
}

namespace tensorrt_llm::batch_manager
{
class LlmRequest;

using RequestList = std::list<std::shared_ptr<LlmRequest>>;
using RequestIdType = std::uint64_t;
using RequestVector = std::vector<std::shared_ptr<LlmRequest>>;
using ReqIdsSet = std::unordered_set<RequestIdType>;

class ScheduledRequests
{
public:
    /// @brief context phase requests (for decoder-only models) or encoder phase requests (for encoder-decoder models
    /// and encoder-only models)
    RequestVector contextRequests;

    /// @brief generation phase requests (for decoder-only models) or empty for others
    RequestVector generationRequests;

    ScheduledRequests() = default;

    explicit ScheduledRequests(RequestVector contextRequests, RequestVector generationRequests)
        : contextRequests{std::move(contextRequests)}
        , generationRequests{std::move(generationRequests)}
    {
    }

    [[nodiscard]] bool empty() const
    {
        return contextRequests.empty() && generationRequests.empty();
    }

    [[nodiscard]] std::size_t size() const
    {
        return contextRequests.size() + generationRequests.size();
    }
};

class BatchState
{
public:
    BatchState() = default;

    BatchState(runtime::SizeType32 numCtxRequests, runtime::SizeType32 numGenRequests, runtime::SizeType32 numTokens,
        runtime::SizeType32 maxKvCacheLength)
        : mNumCtxRequests{numCtxRequests}
        , mNumGenRequests{numGenRequests}
        , mNumTokens{numTokens}
        , mMaxKvCacheLength{maxKvCacheLength}
    {
    }

    bool isAnyContext() const
    {
        return mNumCtxRequests > 0;
    }

    bool operator==(BatchState const& other) const
    {
        return mNumCtxRequests == other.mNumCtxRequests && mNumGenRequests == other.mNumGenRequests
            && mNumTokens == other.mNumTokens && mMaxKvCacheLength == other.mMaxKvCacheLength;
    }

    size_t hash() const
    {
        size_t h1 = std::hash<runtime::SizeType32>{}(mNumCtxRequests);
        size_t h2 = std::hash<runtime::SizeType32>{}(mNumGenRequests);
        size_t h3 = std::hash<runtime::SizeType32>{}(mNumTokens);
        size_t h4 = std::hash<runtime::SizeType32>{}(mMaxKvCacheLength);
        return h1 ^ h2 ^ h3 ^ h4;
    }

    runtime::SizeType32 mNumCtxRequests;
    runtime::SizeType32 mNumGenRequests;
    runtime::SizeType32 mNumTokens;
    runtime::SizeType32 mMaxKvCacheLength;
};

struct BatchStateHash
{
    size_t operator()(BatchState const& bs) const
    {
        return bs.hash();
    }
};

} // namespace tensorrt_llm::batch_manager
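A small illustration (not part of the commit) of what `BatchStateHash` enables, namely using `BatchState` as the key of an unordered container:

```cpp
#include <unordered_map>

using tensorrt_llm::batch_manager::BatchState;
using tensorrt_llm::batch_manager::BatchStateHash;

// Sketch only: count how often each batch shape has been seen (illustrative use).
std::unordered_map<BatchState, int, BatchStateHash> perShapeCounter;

void countShape(BatchState const& state)
{
    ++perShapeCounter[state];
}
```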
cpp/include/tensorrt_llm/batch_manager/evictionPolicy.h (new file, 74 lines)
@@ -0,0 +1,74 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/kvCacheManager.h"

#include <vector>

using namespace tensorrt_llm::batch_manager::kv_cache_manager;

namespace tensorrt_llm::batch_manager::eviction_policy
{

class BaseEvictionPolicy
{
public:
    virtual ~BaseEvictionPolicy() = default;

    virtual void initialize(
        std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks)
        = 0;

    // Get a free block from the primary memory pool
    virtual BlockPtr getFreePrimaryBlock() = 0;
    // Get a free block from the secondary memory pool
    virtual BlockPtr getFreeSecondaryBlock() = 0;
    // Release a block. Prioritize the block for eviction if toFront=true
    virtual void releaseBlock(BlockPtr block, bool toFront = false) = 0;
    // Get the amount of free blocks in the primary memory pool
    virtual SizeType32 getNumFreePrimaryBlocks() = 0;
    // Get the amount of free blocks in the secondary memory pool
    virtual SizeType32 getNumFreeSecondaryBlocks() = 0;
    // Claim a free block. Called when the cache manager allocates or reuses a new block
    virtual void claimBlock(KVCacheBlock block) = 0;
};

class LRUEvictionPolicy : public BaseEvictionPolicy
{
public:
    void initialize(
        std::vector<BlockPtr>& mAllBlocksById, SizeType32 numPrimaryBlocks, SizeType32 numSecondaryBlocks) override;
    BlockPtr getFreePrimaryBlock() override;
    BlockPtr getFreeSecondaryBlock() override;
    void releaseBlock(BlockPtr block, bool toFront = false) override;
    SizeType32 getNumFreePrimaryBlocks() override;
    SizeType32 getNumFreeSecondaryBlocks() override;

    void claimBlock(KVCacheBlock block);

private:
    FreeBlocksQueue mFreePrimaryBlocks;
    FreeBlocksQueue mFreeSecondaryBlocks;

    std::vector<std::optional<FreeBlocksQueue::iterator>> mFreeBlockIterators;

    SizeType32 mFreePrimaryBlocksSize;
    SizeType32 mFreeSecondaryBlocksSize;
};

} // namespace tensorrt_llm::batch_manager::eviction_policy
@@ -22,6 +22,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@@ -29,13 +30,18 @@
#include <NvInferRuntime.h>

#include <cstdint>
#include <functional>
#include <limits>
#include <list>
#include <memory>
#include <optional>
#include <unordered_map>
#include <vector>

namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy;
}

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

@@ -124,6 +130,8 @@ public:

    [[nodiscard]] IdType getBlockId() const;

    [[nodiscard]] NextBlockMap getNextBlocks() const;

    [[nodiscard]] kernels::KVCacheIndex::UnderlyingType getMemoryPoolBlockIndex() const;

    [[nodiscard]] bool isPrimary() const;
@@ -144,22 +152,12 @@ public:

    [[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;

    void setFreeBlockIterator(FreeBlocksQueue::iterator freeBlockIterator);

    void resetFreeBlockIterator();

    [[nodiscard]] std::optional<FreeBlocksQueue::iterator> const& getFreeBlockIterator() const;

    void setPrevBlock(BlockPtr prevBlock);

    void addNextBlock(BlockKey const& blockKey, BlockPtr block);

    void removeNextBlock(BlockKey const& blockKey);

    static std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree(std::shared_ptr<KVCacheBlock> searchStart);

    static std::shared_ptr<KVCacheBlock> findLeafBlock(std::shared_ptr<KVCacheBlock> searchStart);

    [[nodiscard]] BlockPtr findMatchingBlock(BlockKey const& blockKey) const;

    //! \brief Free block from previous block if present.
@@ -203,14 +201,21 @@ class GenerationRequest
{
public:
    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using SharedPtr = std::shared_ptr<GenerationRequest>;

    explicit GenerationRequest(SizeType32 seqSlotIdx, SizeType32 numTokens, SizeType32 beamWidth)
        : mSeqSlotIdx(seqSlotIdx)
    explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
        SizeType32 maxBlocks, SizeType32 numPools = 1)
        : mRequestId(requestId)
        , mNumTokens(numTokens)
        , mBeamWidth(beamWidth)
        , mCacheBlockIds(beamWidth)
        , mCacheBlockIndices{
            runtime::BufferManager::cpu(runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
                runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value)}
    {
        auto cacheBlockIdsRange = runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices);
        std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
            tensorrt_llm::kernels::KVCacheIndex{
                std::numeric_limits<tensorrt_llm::kernels::KVCacheIndex::UnderlyingType>::max()});
    }

    void addNewTokens(SizeType32 n)
@@ -225,9 +230,9 @@ public:
        mNumTokens -= n;
    }

    [[nodiscard]] SizeType32 getSequenceSlotIdx() const
    [[nodiscard]] LlmRequest::RequestIdType getRequestId() const
    {
        return mSeqSlotIdx;
        return mRequestId;
    }

    [[nodiscard]] SizeType32 getNumTokens() const
@@ -245,6 +250,16 @@ public:
        return mCacheBlockIds;
    }

    [[nodiscard]] runtime::ITensor& getCacheBlockIndices()
    {
        return *mCacheBlockIndices;
    }

    [[nodiscard]] runtime::ITensor const& getCacheBlockIndices() const
    {
        return *mCacheBlockIndices;
    }

    void addCacheBlock(SizeType32 beamIdx, KVCacheBlock::IdType blockId)
    {
        mCacheBlockIds.at(beamIdx).push_back(blockId);
@@ -272,14 +287,16 @@ public:
    }

private:
    // Slot id of the sequence
    SizeType32 mSeqSlotIdx;
    // Request id of the sequence
    LlmRequest::RequestIdType mRequestId;
    // Current number of generated tokens
    SizeType32 mNumTokens;
    // Number of beams
    SizeType32 mBeamWidth;
    // List of blocks allocated for each beam of the sequence
    // List of block ids allocated for each beam of the sequence
    std::vector<std::vector<KVCacheBlock::IdType>> mCacheBlockIds;
    // Tensor of block indices allocated for each beam of the sequence
    runtime::ITensor::SharedPtr mCacheBlockIndices;
};

// attach metadata to a pool pointer
@@ -315,17 +332,19 @@ public:
// tokens_per_block, head_size]. The size per block and number of blocks are pre-determined and set in the constructor.
// BlockManager maintains a list of free blocks at any time.
// Alloc pops off the block at the front, and Free pushes it back to the vector.
// BlockManager maintains a vector of lists of seqSlotIdx to allocated blocks
// BlockManager maintains a vector of lists of request ids to allocated blocks
// per sequence. This can be used to Free all blocks belonging to a sequence.
class BlockManager
{
public:
    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;
    using BaseEvictionPolicy = tensorrt_llm::batch_manager::eviction_policy::BaseEvictionPolicy;

    explicit BlockManager(std::vector<SizeType32> const& numKvHeadsPerLayer, SizeType32 sizePerHead,
        SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
        std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks, CacheType cacheType = CacheType::kSELF);
        SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
        CacheType cacheType = CacheType::kSELF);

    ~BlockManager();

@@ -340,10 +359,6 @@ public:
    //! \brief Assign blocks for new sequence. Does not try to reuse blocks.
    void addSequence(GenerationRequest& sequence, SizeType32 numBlocks, SizeType32 unsharedBlockIdx);

    //! \brief Release block, which puts it back onto free blocks queue.
    //! \details Block appended by default, will be put at front if toFront is true.
    void releaseBlock(std::shared_ptr<KVCacheBlock> block, bool toFront = false);

    //! \brief Allocate new block for each beam of the sequence.
    //! \details Might free cached blocks if no free blocks are available.
    void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
@@ -359,10 +374,7 @@ public:
    //! \brief Release last block in the sequence
    void releaseLastBlock(GenerationRequest& sequence);

    [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept
    {
        return mFreePrimaryBlocksSize;
    }
    [[nodiscard]] SizeType32 getNumFreeBlocks() const noexcept;

    [[nodiscard]] SizeType32 getNumAllocTotalBlocks() const
    {
@@ -467,6 +479,11 @@ public:
    BlockKey findNewContextBlock(
        VecUniqueTokens const& uniqueTokens, std::shared_ptr<LlmRequest> const& llmRequest) const;

    [[nodiscard]] runtime::BufferManager const& getBufferManager() const
    {
        return mBufferManager;
    }

private:
    //! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
    void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@@ -486,17 +503,9 @@ private:
    SizeType32 loadOrAllocateBlocks(
        std::list<BlockKey> const& blockKeys, SizeType32 numContextBlocks, GenerationRequest& sequence);

    //! \brief Find best primary block to free.
    //! \details The best primary block to free is the primary block that appears first in the queue and have no primary
    //! block descendants
    [[nodiscard]] std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree();

    //! \brief Find block least likely to be reused, free it if necessary and return.
    [[nodiscard]] BlockPtr getFreeBlock();

    //! \brief Claim block if it is in free blocks list.
    void claimBlock(KVCacheBlock& block);

    //! \brief Free block from previous block and claim it from free blocks list.
    void claimLeafBlock(KVCacheBlock& block);

@@ -511,15 +520,9 @@ private:
    // Number of blocks in pools
    SizeType32 mNumPrimaryBlocks;
    SizeType32 mNumSecondaryBlocks;
    // List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory.
    // We maintain separate queues for these.
    // We cache size of each queue instead of calling std::list::size, because size is O(N) function.
    SizeType32 mFreePrimaryBlocksSize;
    SizeType32 mFreeSecondaryBlocksSize;
    FreeBlocksQueue mFreePrimaryBlocks;
    FreeBlocksQueue mFreeSecondaryBlocks;

    // List of allocated blocks for each sequences
    std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
    std::unordered_map<LlmRequest::RequestIdType, std::vector<BlockPtr>> mAllocatedBlocksPerSeq;

    // Pool per unique numKvHeads in the model
    std::vector<KVCacheBlockPool> mPools;
@@ -547,6 +550,8 @@ private:
    std::size_t mAllocTotalBlocks, mAllocNewBlocks, mReusedBlocks;
    // KV cache type (self or cross)
    CacheType mCacheType;
    // Eviction Policy
    std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;

private:
    friend class KVCacheManager;
@@ -555,8 +560,9 @@ private:
class KVCacheManager
{
public:
    friend class KVCacheManagerBindings;

    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using SequencesPtr = GenerationRequest::SharedPtr;
    using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
    using CacheType = tensorrt_llm::batch_manager::kv_cache_manager::CacheType;

@@ -647,10 +653,10 @@ public:
    /// @return The number of blocks
    [[nodiscard]] SizeType32 getRemainingBlocksToCompletion(LlmRequest const& req) const;

    void addContextTokens(SizeType32 seqSlotIdx, SizeType32 numTokens);
    void addContextTokens(LlmRequest::RequestIdType requestId, SizeType32 numTokens);

    /// @brief Increase size for request at seqSlotIdx. Allocate new KV cache block(s) if needed.
    void addToken(SizeType32 seqSlotIdx);
    /// @brief Increase size for request with requestId. Allocate new KV cache block(s) if needed.
    void addToken(LlmRequest::RequestIdType requestId);

    /// @brief Add new request to the KV cache manager.
    /// @param inputLength Input length for which KV cache need to be allocated.
@@ -658,23 +664,29 @@ public:
    /// @param llmRequest Optional request to use for KV cache lookup.
    /// @details If llmRequest is supplied and KV cache reuse is enabled, try to recover KV cache blocks for
    /// inputLength - 1 tokens and populate prepopulatedPromptLen.
    void addSequence(SizeType32 seqSlotIdx, SizeType32 inputLength, SizeType32 beamWidth,
    void addSequence(LlmRequest::RequestIdType requestId, SizeType32 inputLength, SizeType32 beamWidth,
        std::shared_ptr<LlmRequest> const& llmRequest = nullptr);

    void removeSequence(SizeType32 seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);
    void removeSequence(LlmRequest::RequestIdType requestId, std::shared_ptr<LlmRequest> const& llmRequest = nullptr);

    void schedulingRemoveSequence(SizeType32 seqSlotIdx);
    void schedulingRemoveSequence(LlmRequest::RequestIdType requestId);

    [[nodiscard]] runtime::ITensor::UniquePtr getBlockPoolPointers() const;
    [[nodiscard]] runtime::ITensor::SharedPtr getBlockPoolPointers() const
    {
        return mBlockPoolPointers;
    }

    [[nodiscard]] runtime::ITensor::UniquePtr getLayerToPoolMapping() const;
    [[nodiscard]] runtime::ITensor::SharedPtr getLayerToPoolMapping() const
    {
        return mLayerToPoolMapping;
    }

    void getBlockOffsetsOfBatch(
        runtime::ITensor& output, SizeType32 firstBatchSlotIdx, SizeType32 batchSize, SizeType32 beamWidth) const;

    //! @return maxBlockCount of all beams
    SizeType32 copyBlockOffsets(
        runtime::ITensor& output, SizeType32 outputSlotOffset, SizeType32 seqSlotIdx, SizeType32 beamWidth) const;
        runtime::ITensor& output, SizeType32 outputSlotOffset, LlmRequest::RequestIdType requestId) const;

    // Sum of numLayers * 2 * numKvHeads * sizePerHead for each pool
    [[nodiscard]] static SizeType32 calculateCacheSizePerToken(
@@ -697,10 +709,10 @@ public:
        return mEnableBlockReuse;
    }

    void removeToken(SizeType32 seqSlotIdx);
    void rewindKVCache(SizeType32 seqSlotIdx, SizeType32 rewindLengths);
    void removeToken(LlmRequest::RequestIdType requestId);
    void rewindKVCache(LlmRequest::RequestIdType requestId, SizeType32 rewindLengths);

    [[nodiscard]] GenerationRequest const& getSequence(SizeType32 seqSlotIdx) const;
    [[nodiscard]] GenerationRequest const& getSequence(LlmRequest::RequestIdType requestId) const;

    [[nodiscard]] bool isCrossKv() const
    {
@@ -714,7 +726,7 @@ public:

    //! \brief Store full context blocks contributed by llmRequest.
    //! \details These blocks become reusable from next step.
    void storeContextBlocks(SizeType32 seqSlotIdx, std::shared_ptr<LlmRequest> const& llmRequest);
    void storeContextBlocks(std::shared_ptr<LlmRequest> const& llmRequest);

    [[nodiscard]] static SizeType32 getSinkBubbleLength(SizeType32 sinkTokenLen, SizeType32 tokensPerBlock);

@@ -722,14 +734,13 @@ public:
        SizeType32 tokensPerBlock, SizeType32 maxBeamWidth, SizeType32 sinkTokenLen, bool useOneMoreBlock);

private:
    void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 seqSlotIdx,
        SizeType32 beamIdx, SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
    void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
        SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;

    void resetBlockOffsets(SizeType32 seqSlotIdx, SizeType32 beamWidth);
    void cacheBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
    void cacheNewBlockOffsets(GenerationRequest const& seq, SizeType32 seqSlotIdx);
    void updateNewBlockPointer(GenerationRequest const& seq, SizeType32 seqSlotIdx, SizeType32 blockIdx);
    void updateToken(SizeType32 seqSlotIdx, bool addToken);
    void cacheBlockOffsets(GenerationRequest& seq);
    void cacheNewBlockOffsets(GenerationRequest& seq);
    void updateNewBlockPointer(GenerationRequest& seq, SizeType32 blockIdx);
    void updateToken(GenerationRequest& sequence, bool addToken);

private:
    // Maximum number of sequences
@@ -749,12 +760,13 @@ private:
    SizeType32 mSinkBlockTokenLength;
    // Block manager
    BlockManager mBlockManager;
    // List of all sequences
    std::vector<SequencesPtr> mSequences;
    // buffer for block indices for all managed sequences
    runtime::ITensor::SharedPtr mSequenceBlockIndices;
    // Map of all sequences
    std::unordered_map<LlmRequest::RequestIdType, GenerationRequest> mSequences;
    // Whether to cache KV pages for reuse
    bool mEnableBlockReuse;
    // buffers for static tensors, will be created after allocating pools
    runtime::ITensor::SharedPtr mBlockPoolPointers;
    runtime::ITensor::SharedPtr mLayerToPoolMapping;
};

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
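A brief hedged sketch of the request-id based flow that replaces the old `seqSlotIdx` based calls (method and member names are taken from the hunks above; the request, beam width and manager are assumed to come from elsewhere, and error handling is omitted):

```cpp
// Sketch only: allocate, grow and release KV cache for one request by its id.
void kvCacheLifecycle(tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager& kvCacheManager,
    std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> const& llmRequest, int32_t beamWidth)
{
    auto const requestId = llmRequest->mRequestId;
    kvCacheManager.addSequence(requestId, llmRequest->mPromptLen, beamWidth, llmRequest);
    kvCacheManager.addToken(requestId); // one call per generated token
    kvCacheManager.removeSequence(requestId, llmRequest);
}
```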
@@ -65,6 +65,11 @@ public:
        return ret;
    }

    operator runtime::ITensor::SharedPtr()
    {
        return mCurrent;
    }

    [[nodiscard]] bool operator==(BlockIterator const& other) const
    {
        return mIdx == other.mIdx && mPool.get() == other.mPool.get();
@@ -55,6 +55,7 @@ enum class LlmRequestState : int32_t
    /// Waiting context-only request transmitting the kv cache
    kDISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
    kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< For disaggregated serving only: transmitting the kv cache
    kWAITING_TO_SEND_LOGITS = 10, ///< Generation phase completed, logits not sent yet
};

enum LlmRequestType
@@ -132,8 +133,7 @@ public:
        , mLoraWeights(std::move(loraWeights))
        , mLoraConfig(std::move(loraConfig))
        , mLookaheadConfig(std::move(lookaheadConfig))
        , mContextChunkSize(std::nullopt)
        , mContextCurrentPosition(0)
        , mContextChunkSize{mPromptLen}
        , mLogProbs(samplingConfig.beamWidth)
        , mCumLogProbs(samplingConfig.beamWidth)
        , mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
@@ -186,8 +186,7 @@ public:
        , mLoraWeights(std::nullopt)
        , mLoraConfig(std::nullopt)
        , mLookaheadConfig(std::nullopt)
        , mContextChunkSize(std::nullopt)
        , mContextCurrentPosition(0)
        , mContextChunkSize{mPromptLen}
        , mLogProbs(mSamplingConfig.beamWidth)
        , mCumLogProbs(mSamplingConfig.beamWidth)
        , mDraftTokens(std::make_shared<VecTokens>())
@@ -392,6 +391,15 @@ public:
            mMaxNewTokens = maxNewTokens;
        }

        if (mNumReturnSequences > 1 && mSamplingConfig.beamWidth > 1)
        {
            TLLM_THROW(
                "Using mNumReturnSequences (%d) > 1 with beam search is currently disabled, since TensorRT-LLM returns "
                "a total of mNumReturnSequences x beamWidth beams, rather than limiting the number of returned beams "
                "to mNumReturnSequences. This restriction will be removed once the issue is resolved.",
                mNumReturnSequences);
        }

        TLLM_CHECK_WITH_INFO(mSamplingConfig.validate(), "Incorrect sampling config");

        // validate extra ids when enabling kv cache reuse with prompt table
@@ -722,7 +730,7 @@ public:
        mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT
                                                                     : LlmRequestState::kCONTEXT_INIT;
        mContextCurrentPosition = 0;
        mContextChunkSize = std::nullopt;
        mContextChunkSize = mPromptLen;
        mSeqSlot.reset();
    }

@@ -869,34 +877,33 @@ public:
        return mPromptLen;
    }

    [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const
    {
        return mPrepopulatedPromptLen;
    }

    void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock)
    {
        auto const promptLen = getPromptLen();
        TLLM_CHECK(prepopulatedPromptLen < promptLen);
        mPrepopulatedPromptLen = prepopulatedPromptLen;

        if (prepopulatedPromptLen > 0)
        {
            // Currently, the runtime process is to apply for cache first and then determine prepopulation.
            // Use the prepopulated length to advance the context position and decrease chunk size if necessary.
            if (isFullContextRequest())
            auto chunkSize = getContextChunkSize();
            if (prepopulatedPromptLen + chunkSize < promptLen)
            {
                setContextCurrentPosition(prepopulatedPromptLen);
                setContextChunkSize(promptLen);
            }
            else
            {
                auto chunkSize = getContextChunkSize();
                if (prepopulatedPromptLen + chunkSize < promptLen)
                {
                    // make sure to end at block boundary after current chunk
                    auto const flooredEndPosition
                        = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
                    chunkSize = flooredEndPosition - prepopulatedPromptLen;
                    TLLM_CHECK(chunkSize <= getContextChunkSize());
                }
                setContextCurrentPosition(prepopulatedPromptLen);
                setContextChunkSize(chunkSize);
                // make sure to end at block boundary after current chunk
                auto const flooredEndPosition
                    = (prepopulatedPromptLen + chunkSize) / kvTokensPerBlock * kvTokensPerBlock;
                chunkSize = flooredEndPosition - prepopulatedPromptLen;
                TLLM_CHECK(chunkSize <= getContextChunkSize());
            }
            setContextCurrentPosition(prepopulatedPromptLen);
            setContextChunkSize(chunkSize);

            if (!isLastContextChunk())
            {
                TLLM_CHECK_WITH_INFO((getContextCurrentPosition() + getContextChunkSize()) % kvTokensPerBlock == 0,
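To make the block-boundary flooring above concrete (illustrative numbers, not from the diff): with `kvTokensPerBlock = 64`, `promptLen = 300`, `prepopulatedPromptLen = 130` and a current chunk size of 128, we have 130 + 128 = 258 < 300, so the end position is floored to (258 / 64) * 64 = 256 and the chunk shrinks to 256 - 130 = 126. The context position is then set to 130, and the chunk ends exactly on a block boundary while staying no larger than the original chunk size.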
@@ -1176,6 +1183,11 @@ public:
        return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
    }

    [[nodiscard]] bool isCompleteWaitingToSendLogits() const noexcept
    {
        return mState == LlmRequestState::kWAITING_TO_SEND_LOGITS;
    }

    /// To determine whether the context is unchunked. When a context is chunked into only a part, it
    /// is still different from the unchunked state, which indicates the initial status.
    [[nodiscard]] bool isFullContextRequest() const noexcept
@@ -1211,12 +1223,11 @@ public:
        return mPromptLen - getContextCurrentPosition();
    }

    /// To retrieve the context chunk size, throw an exception when the context is not chunked.
    [[nodiscard]] SizeType32 getContextChunkSize() const
    {
        TLLM_CHECK_WITH_INFO(
            isContextInitState() && mContextChunkSize, "The current request is not in context chunking state.");
        return mContextChunkSize.value();
        TLLM_CHECK_WITH_INFO(isContextInitState() || isDisaggGenerationInitState(),
            "getContextChunkSize is only possible during the context phase.");
        return mContextChunkSize;
    }

    /// To set the context chunk size, throw an exception when the chunk size is negative. If the chunk
@@ -1224,45 +1235,34 @@ public:
    /// remaining length.
    void setContextChunkSize(SizeType32 size)
    {
        TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
        TLLM_CHECK_WITH_INFO(isContextInitState(), "setContextChunkSize is only possible during the context phase.");
        TLLM_CHECK_WITH_INFO(size >= 0, "The chunk size of context (%d) can't be negative.", size);
        mContextChunkSize = std::min(size, getContextRemainingLength());
    }

    /// Determines whether the current position is only one chunk away from the end of the context.
    /// It will return true when the context is not chunked.
    [[nodiscard]] bool isLastContextChunk() const noexcept
    {
        return isFullContextRequest()
            || (isContextInitState() && getContextCurrentPosition() + getContextChunkSize() == mPromptLen);
        return isDisaggGenerationInitState() || getContextCurrentPosition() + getContextChunkSize() == mPromptLen;
    }

    /// Returns whether the position is at the beginning of the context. It will return true when the
    /// context is not chunked.
    /// Returns whether the position is at the beginning of the context.
    [[nodiscard]] bool isFirstContextChunk() const noexcept
    {
        return isFullContextRequest() || getContextCurrentPosition() == 0;
    }

    [[nodiscard]] executor::PriorityType priority() const noexcept
    {
        return mPriority;
        return getContextCurrentPosition() == 0;
    }

    /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context.
    void moveToNextContextChunk()
    {
        TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase.");
        if (mContextChunkSize)
        {
            mContextCurrentPosition += getContextChunkSize();
            setContextChunkSize(0);
        }
        else
        {
            TLLM_CHECK_WITH_INFO(mContextCurrentPosition == 0, "Full context out of bounds.");
            mContextCurrentPosition = mPromptLen;
        }
        mContextCurrentPosition += getContextChunkSize();
        setContextChunkSize(0);
    }

    [[nodiscard]] executor::PriorityType priority() const noexcept
    {
        return mPriority;
    }

    /// Increment the counter of decoding iterations.
@@ -1282,20 +1282,24 @@ public:
        return static_cast<float>(getMaxNumGeneratedTokens()) / mDecodingIter;
    }

    [[nodiscard]] bool isFinished() const noexcept
    {
        return isGenerationCompleteState() || isDisaggContextTransmissionState() || isCompleteWaitingToSendLogits();
    }

    /// @brief Create a Response from the current state of the request
    /// @return An optional Response
    std::optional<executor::Response> createResponse()
    std::optional<executor::Response> createResponse(bool useFastLogits = false, int32_t mpiWorldRank = 0)
    {
        TLLM_CHECK(!isDisaggContextCompleteState());
        if (isGenerationCompleteState() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS)
            || isDisaggContextTransmissionState())
        if (isFinished() || (mIsStreaming && mState == LlmRequestState::kGENERATION_IN_PROGRESS))
        {
            TLLM_LOG_DEBUG("Creating response for request %lu", mRequestId);

            executor::Result result;
            result.sequenceIndex = mSequenceIndex;

            result.isSequenceFinal = isGenerationCompleteState() || isDisaggContextTransmissionState();
            result.isSequenceFinal = isFinished();
            mSequenceFinalVec->at(mSequenceIndex) = result.isSequenceFinal;

            result.isFinal = std::all_of(mSequenceFinalVec->begin(), mSequenceFinalVec->end(),
@@ -1333,8 +1337,7 @@ public:

            auto const startTokenPos = maxNbTokens - maxNbTokensOut;

            auto const shouldSendResponse = isGenerationCompleteState()
                || (mIsStreaming && maxNbTokens > getMaxSentTokenLen()) || isDisaggContextTransmissionState();
            auto const shouldSendResponse = isFinished() || (mIsStreaming && maxNbTokens > getMaxSentTokenLen());

            if (!shouldSendResponse)
            {
@@ -1374,6 +1377,11 @@ public:
                        = runtime::ITensor::slice(getGenerationLogitsHost(), startGenTokenPos, maxNbTokensOut);
                    result.generationLogits = executor::detail::ofITensor(generationLogitsHostCurrentStep);
                }
                else if (useFastLogits)
                {
                    result.specDecFastLogitsInfo
                        = executor::SpeculativeDecodingFastLogitsInfo{mRequestId, mpiWorldRank};
                }
                else
                {
                    result.generationLogits = executor::detail::ofITensor(getGenerationLogitsHost());
@@ -1392,7 +1400,7 @@ public:
            setMaxSentTokenLen(maxNbTokens);

            auto requestId = isChild() ? mParentRequestId : mRequestId;
            auto response = executor::Response(requestId, std::move(result));
            auto response = executor::Response(requestId, std::move(result), mClientId);

            return response;
        }
@@ -1483,8 +1491,8 @@ protected:
    // To enable chunked context, the FHMA paged kv-cache also needs to be enabled. Except for the last one,
    // the size of the context chunk needs to be an integer multiple of the kv-cache block size. The meaning
    // of null value is that the context is not chunked.
    std::optional<SizeType32> mContextChunkSize;
    SizeType32 mContextCurrentPosition;
    SizeType32 mContextChunkSize{0};
    SizeType32 mContextCurrentPosition{0};

    std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
    VecLogProbs mCumLogProbs; // [beamSize]
@@ -1636,6 +1644,8 @@ private:

class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
{
    friend class LlmRequestBindings;

public:
    using Base = GenericLlmRequest<runtime::ITensor::SharedPtr>;
    using TensorPtr = Base::TensorPtr;
cpp/include/tensorrt_llm/batch_manager/microBatchScheduler.h (new file, 108 lines)
@@ -0,0 +1,108 @@
/*
 * Copyright (c) 2023-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm::batch_manager
{

namespace batch_scheduler
{

struct ContextChunkingConfig
{
    ContextChunkingConfig() = default;

    executor::ContextChunkingPolicy chunkingPolicy;
    /// The minimum size, also known as the chunk unit size. It generally
    /// needs to be equal to the size of the kv cache block or its integer
    /// multiples (except for the last context chunk) to avoid fragmentation.
    /// When set to null, it indicates that the context chunk is disabled.
    tensorrt_llm::runtime::SizeType32 chunkUnitSize;
};

} // namespace batch_scheduler

/// @brief This scheduler takes into account the desired batch size and limits of the TRT engine to schedule requests.
class MicroBatchScheduler : Algorithm
{
public:
    constexpr static auto name{"MicroBatchScheduler"};

    using SizeType32 = tensorrt_llm::runtime::SizeType32;
    using ContextChunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy;

    MicroBatchScheduler() = default;

    explicit MicroBatchScheduler(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
        std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
        std::optional<SizeType32> maxContextLength = std::nullopt,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);

    static MicroBatchScheduler make(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
        std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
        std::optional<SizeType32> maxContextLength = std::nullopt,
        LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
        LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
    {
        return MicroBatchScheduler{
            maxBatchSize, maxNumTokens, ctxChunkConfig, maxContextLength, noScheduleUntilState, noScheduleAfterState};
    }

    std::tuple<RequestVector, RequestVector> operator()(
        RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds);

    static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
        std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
        std::optional<SizeType32> const& maxContextLength);

private:
    template <ContextChunkingPolicy tPolicy>
    static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked,
        std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
        std::optional<SizeType32> const& maxContextLength);

    /// After the chunk sizes have been determined, this function will discard
    /// any draft tokens that don't fit.
    static void fitDraftTokens(RequestVector const& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
        SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);

    /// The maximum number of requests returned by scheduleRequests
    SizeType32 mMaxBatchSize;

    /// The maximum number of tokens to include in a batch
    std::optional<SizeType32> mMaxNumTokens;

    /// The maximum length of the context. If the context exceeds this length,
    /// it must be chunked, otherwise it cannot be processed. Therefore, it
    /// needs to be set together with the chunk unit size to make sense.
    /// When set to null, it indicates that context length is unlimited.
    std::optional<SizeType32> mMaxContextLength;

    std::optional<batch_scheduler::ContextChunkingConfig> mCtxChunkConfig;

    /// The state until/after which the scheduler should not schedule requests
    LlmRequestState mNoScheduleUntilState;
    LlmRequestState mNoScheduleAfterState;
};

} // namespace tensorrt_llm::batch_manager
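As with the capacity scheduler, a hedged usage sketch follows (values are illustrative and not taken from this commit):

```cpp
using namespace tensorrt_llm::batch_manager;

// Sketch only: build a micro batch scheduler with context chunking enabled and run it once.
std::tuple<RequestVector, RequestVector> microBatchOnce(
    RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds)
{
    batch_scheduler::ContextChunkingConfig chunkingConfig;
    chunkingConfig.chunkingPolicy = executor::ContextChunkingPolicy::kFIRST_COME_FIRST_SERVED;
    chunkingConfig.chunkUnitSize = 64; // typically the KV cache block size or a multiple of it
    auto scheduler = MicroBatchScheduler::make(/*maxBatchSize=*/64, /*maxNumTokens=*/8192, chunkingConfig);
    return scheduler(activeRequests, inflightReqIds);
}
```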
@@ -51,6 +51,8 @@ public:
class BasePeftCacheManager
{
public:
    friend class BasePeftCacheManagerBindings;

    using LlmRequestPtr = std::shared_ptr<LlmRequest>;
    using RequestVector = std::vector<LlmRequestPtr>;
    using PeftTable = std::map<uint64_t, std::shared_ptr<std::vector<runtime::LoraCache::TaskLayerModuleConfig>>>;
@@ -46,7 +46,9 @@ public:
        executor::SchedulerConfig const& schedulerConfig = executor::SchedulerConfig{},
        executor::ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig
        = executor::ExtendedRuntimePerfKnobConfig{},
        std::optional<executor::DebugConfig> debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000)
        std::optional<executor::DebugConfig> debugConfig = std::nullopt, uint64_t maxSeqIdleMicroseconds = 180000000,
        std::optional<executor::SpeculativeDecodingConfig> specDecConfig = std::nullopt,
        bool isLeaderInOrchMode = false)
        : kvCacheConfig{kvCacheConfig}
        , enableTrtOverlap{enableTrtOverlap}
        , deviceIds(deviceIds)
@@ -62,10 +64,12 @@ public:
        , extendedRuntimePerfKnobConfig(extendedRuntimePerfKnobConfig)
        , debugConfig{std::move(debugConfig)}
        , maxSeqIdleMicroseconds{maxSeqIdleMicroseconds}
        , speculativeDecodingConfig{std::move(specDecConfig)}
        , isLeaderInOrchMode{isLeaderInOrchMode}
    {
    }

    explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
    explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig, bool isLeaderInOrchMode)
        : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
            executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
            executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(),
@@ -74,7 +78,7 @@ public:
            executorConfig.getGpuWeightsPercent(), executorConfig.getMaxBeamWidth(), executorConfig.getMaxBatchSize(),
            executorConfig.getMaxNumTokens(), executorConfig.getSchedulerConfig(),
            executorConfig.getExtendedRuntimePerfKnobConfig(), executorConfig.getDebugConfig(),
            executorConfig.getMaxSeqIdleMicroseconds())
            executorConfig.getMaxSeqIdleMicroseconds(), executorConfig.getSpecDecConfig(), isLeaderInOrchMode)
    {
    }

@@ -94,6 +98,8 @@ public:
        && extendedRuntimePerfKnobConfig == other.extendedRuntimePerfKnobConfig //
        && debugConfig == other.debugConfig //
        && maxSeqIdleMicroseconds == other.maxSeqIdleMicroseconds //
        && speculativeDecodingConfig == other.speculativeDecodingConfig //
        && isLeaderInOrchMode == other.isLeaderInOrchMode //
        ;
    }

@@ -117,6 +123,9 @@ public:
    std::optional<executor::DebugConfig> debugConfig;
    // Sequence is considered idle if not updated for this amount of time.
    uint64_t maxSeqIdleMicroseconds;
    std::optional<executor::SpeculativeDecodingConfig> speculativeDecodingConfig;
    // This rank is the leader worker in orchestrator mode
    bool isLeaderInOrchMode;
};

} // namespace tensorrt_llm::batch_manager
32
cpp/include/tensorrt_llm/common/algorithm.h
Normal file
32
cpp/include/tensorrt_llm/common/algorithm.h
Normal file
@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
// Base class for algorithms
|
||||
struct Algorithm
|
||||
{
|
||||
Algorithm() = default;
|
||||
Algorithm(Algorithm&&) = default;
|
||||
Algorithm& operator=(Algorithm&&) = default;
|
||||
Algorithm(Algorithm const&) = delete;
|
||||
Algorithm& operator=(Algorithm const&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm
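
The `Algorithm` base class only fixes ownership semantics: anything derived from it is movable but not copyable. A minimal sketch of how a concrete algorithm object could build on that contract (the `ExampleScheduler` class and its members are made up for illustration, not part of the library):

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

namespace example
{

// Re-declared here so the sketch is self-contained; mirrors the move-only base above.
struct Algorithm
{
    Algorithm() = default;
    Algorithm(Algorithm&&) = default;
    Algorithm& operator=(Algorithm&&) = default;
    Algorithm(Algorithm const&) = delete;
    Algorithm& operator=(Algorithm const&) = delete;
};

// Hypothetical algorithm object: callable, owns simple state, and inherits the
// "movable, not copyable" rule from the base class.
class ExampleScheduler : public Algorithm
{
public:
    explicit ExampleScheduler(std::size_t maxBatchSize)
        : mMaxBatchSize{maxBatchSize}
    {
    }

    // Returns how many of the pending requests fit into one batch.
    std::size_t operator()(std::vector<int> const& pendingRequestIds) const
    {
        return std::min(pendingRequestIds.size(), mMaxBatchSize);
    }

private:
    std::size_t mMaxBatchSize;
};

} // namespace example
```

Copying such an object is rejected at compile time, while returning it from a factory function or storing it in a container works through the defaulted move operations.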
|
||||
@ -99,7 +99,6 @@ struct MpiTypeConverter<std::byte>
|
||||
};
|
||||
|
||||
template <>
|
||||
|
||||
struct MpiTypeConverter<half>
|
||||
|
||||
{
|
||||
@ -387,6 +386,7 @@ public:
|
||||
void barrier() const;
|
||||
|
||||
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
|
||||
bool improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
|
||||
|
||||
//! \brief Returns if a message with the specified source and tag is available
|
||||
bool iprobe(int source, int tag, MPI_Status* status) const;
|
||||
|
||||
@ -186,11 +186,13 @@ class ExternalDraftTokensConfig
|
||||
{
|
||||
public:
|
||||
explicit ExternalDraftTokensConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
|
||||
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
|
||||
std::optional<FloatType> const& acceptanceThreshold = std::nullopt,
|
||||
std::optional<bool> const& fastLogits = std::nullopt);
|
||||
|
||||
[[nodiscard]] VecTokens getTokens() const;
|
||||
[[nodiscard]] std::optional<Tensor> getLogits() const;
|
||||
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
|
||||
[[nodiscard]] std::optional<bool> getFastLogits() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -200,6 +202,8 @@ private:
|
||||
std::optional<Tensor> mLogits;
|
||||
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
|
||||
std::optional<FloatType> mAcceptanceThreshold;
|
||||
/// @brief Use direct transfer for draft logits
|
||||
std::optional<bool> mFastLogits;
|
||||
};
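
With the new `fastLogits` flag, a request that carries external draft tokens can ask the target model to pull the draft logits directly from the draft-model rank instead of shipping a logits tensor inside the request. A hedged sketch of the request-side wiring, assuming the two-argument `Request` constructor and a `setExternalDraftTokensConfig` setter (neither is shown in this excerpt), with made-up token values:

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <optional>
#include <utility>

namespace tle = tensorrt_llm::executor;

tle::Request makeDraftedRequest(tle::VecTokens inputIds, tle::VecTokens draftTokens)
{
    // No logits tensor is attached here: fastLogits asks the target model to
    // fetch the draft logits directly from the draft rank.
    tle::ExternalDraftTokensConfig draftConfig(
        std::move(draftTokens), /*logits=*/std::nullopt, /*acceptanceThreshold=*/0.5f, /*fastLogits=*/true);

    tle::Request request(std::move(inputIds), /*maxNewTokens=*/64);
    request.setExternalDraftTokensConfig(draftConfig);
    return request;
}
```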
|
||||
|
||||
/// @brief Configuration for prompt tuning
|
||||
@ -318,6 +322,18 @@ private:
|
||||
StatePtr mState{nullptr, deleter};
|
||||
};
|
||||
|
||||
/// @brief Configuration for speculative decoding (both draft and target models)
|
||||
class SpeculativeDecodingConfig
|
||||
{
|
||||
public:
|
||||
explicit SpeculativeDecodingConfig(bool fastLogits);
|
||||
|
||||
bool operator==(SpeculativeDecodingConfig const& other) const;
|
||||
|
||||
/// @brief Send logits tensor directly from draft to target model.
|
||||
bool fastLogits;
|
||||
};
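
`SpeculativeDecodingConfig` is deliberately small: its only knob is whether draft logits travel over the direct (fast) path. A sketch of plugging it into the executor configuration, assuming `ExecutorConfig` is default-constructible and relying on the `setSpecDecConfig` setter declared further down in this diff:

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::ExecutorConfig makeSpecDecExecutorConfig()
{
    tle::ExecutorConfig config;

    // Enable direct draft-to-target logits transfer for speculative decoding.
    tle::SpeculativeDecodingConfig specDecConfig(/*fastLogits=*/true);
    config.setSpecDecConfig(specDecConfig);

    return config;
}
```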
|
||||
|
||||
/// @brief A class that holds information about the request
|
||||
class Request
|
||||
{
|
||||
@ -437,6 +453,16 @@ private:
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the logits information when using direct transfer
|
||||
struct SpeculativeDecodingFastLogitsInfo
|
||||
{
|
||||
/// @brief Draft request id
|
||||
uint64_t draftRequestId;
|
||||
|
||||
/// @brief MPI world rank of the draft model leader
|
||||
int32_t draftParticipantId;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the generation result
|
||||
struct Result
|
||||
{
|
||||
@ -455,11 +481,14 @@ struct Result
|
||||
/// @brief The context logits. Size [promptLen, vocabSizePadded]
|
||||
std::optional<Tensor> contextLogits;
|
||||
|
||||
/// @brief The context logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
|
||||
/// @brief The generation logits. Size [beamSize, maxNewTokens, vocabSizePadded] (non-streaming)
|
||||
/// or [maxNewTokens, beamSize, vocabSizePadded] (streaming and allGeneratedTokens)
|
||||
/// or [1, beamSize, vocabSizePadded] (streaming and non-allGeneratedTokens)
|
||||
std::optional<Tensor> generationLogits;
|
||||
|
||||
/// @brief Logits information for direct transfer when using fast logits
|
||||
std::optional<SpeculativeDecodingFastLogitsInfo> specDecFastLogitsInfo;
|
||||
|
||||
/// @brief The encoder output. Size [encoderLen, hiddenSize]
|
||||
std::optional<Tensor> encoderOutput;
|
||||
|
||||
@ -484,8 +513,8 @@ struct Result
|
||||
class Response
|
||||
{
|
||||
public:
|
||||
Response(IdType requestId, std::string errorMsg);
|
||||
Response(IdType requestId, Result Result);
|
||||
Response(IdType requestId, std::string errorMsg, std::optional<IdType> clientId = std::nullopt);
|
||||
Response(IdType requestId, Result Result, std::optional<IdType> clientId = std::nullopt);
|
||||
|
||||
~Response();
|
||||
Response(Response const& other);
|
||||
@ -496,6 +525,9 @@ public:
|
||||
/// @brief Get the id of the request for which this response was generated
|
||||
[[nodiscard]] IdType getRequestId() const;
|
||||
|
||||
/// @brief Get the client id of the request for which this response was generated
|
||||
[[nodiscard]] std::optional<IdType> getClientId() const;
|
||||
|
||||
/// @brief Indicates if this response has an error or not
|
||||
[[nodiscard]] bool hasError() const;
|
||||
|
||||
@ -873,7 +905,8 @@ public:
|
||||
std::optional<SizeType32> maxQueueSize = std::nullopt,
|
||||
ExtendedRuntimePerfKnobConfig const& extendedRuntimePerfKnobConfig = ExtendedRuntimePerfKnobConfig(),
|
||||
std::optional<DebugConfig> debugConfig = std::nullopt, SizeType32 recvPollPeriodMs = 0,
|
||||
uint64_t maxSeqIdleMicroseconds = 180000000);
|
||||
uint64_t maxSeqIdleMicroseconds = 180000000,
|
||||
std::optional<SpeculativeDecodingConfig> specDecConfig = std::nullopt);
|
||||
|
||||
[[nodiscard]] SizeType32 getMaxBeamWidth() const;
|
||||
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
|
||||
@ -895,6 +928,7 @@ public:
|
||||
[[nodiscard]] std::optional<DebugConfig> getDebugConfig() const;
|
||||
[[nodiscard]] SizeType32 getRecvPollPeriodMs() const;
|
||||
[[nodiscard]] uint64_t getMaxSeqIdleMicroseconds() const;
|
||||
[[nodiscard]] std::optional<SpeculativeDecodingConfig> getSpecDecConfig() const;
|
||||
|
||||
void setMaxBeamWidth(SizeType32 maxBeamWidth);
|
||||
void setMaxBatchSize(SizeType32 maxBatchSize);
|
||||
@ -916,6 +950,7 @@ public:
|
||||
void setDebugConfig(DebugConfig const& debugConfig);
|
||||
void setRecvPollPeriodMs(SizeType32 const& recvPollPeriodMs);
|
||||
void setMaxSeqIdleMicroseconds(uint64_t maxSeqIdleMicroseconds);
|
||||
void setSpecDecConfig(SpeculativeDecodingConfig const& specDecConfig);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -978,6 +1013,9 @@ private:
|
||||
/// @brief The maximum time in microseconds a scheduled request can remain idle before getting terminated. Default
|
||||
/// is 3 minutes.
|
||||
uint64_t mMaxSeqIdleMicroseconds;
|
||||
|
||||
/// @brief The speculative decoding configuration
|
||||
std::optional<SpeculativeDecodingConfig> mSpeculativeDecodingConfig;
|
||||
};
|
||||
|
||||
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
|
||||
@ -1080,6 +1118,9 @@ public:
|
||||
/// @brief Indicates if the current process is allowed to enqueueRequests
|
||||
[[nodiscard]] bool canEnqueueRequests() const;
|
||||
|
||||
/// @brief Indicates if the current process participates in this executor instance
|
||||
[[nodiscard]] bool isParticipant() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
|
||||
@ -95,6 +95,11 @@ public:
|
||||
static void serialize(Tensor const& tensor, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(Tensor const& tensor);
|
||||
|
||||
// SpeculativeDecodingFastLogitsInfo
|
||||
[[nodiscard]] static SpeculativeDecodingFastLogitsInfo deserializeSpecDecFastLogitsInfo(std::istream& is);
|
||||
static void serialize(SpeculativeDecodingFastLogitsInfo const& info, std::ostream& os);
|
||||
[[nodiscard]] static size_t serializedSize(SpeculativeDecodingFastLogitsInfo const& info);
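
These three declarations follow the same serialize / deserialize / serializedSize pattern as the other executor types, so a fast-logits handle can be round-tripped through any byte stream. A minimal sketch, assuming the include path and that the methods behave like their siblings; the stringstream is only a stand-in transport:

```cpp
#include "tensorrt_llm/executor/serialization.h"

#include <cstdint>
#include <sstream>

namespace tle = tensorrt_llm::executor;

tle::SpeculativeDecodingFastLogitsInfo roundTrip(uint64_t draftRequestId, int32_t draftParticipantId)
{
    tle::SpeculativeDecodingFastLogitsInfo info{draftRequestId, draftParticipantId};

    // Serialize into an in-memory stream, then read it back.
    std::stringstream buffer;
    tle::Serialization::serialize(info, buffer);
    return tle::Serialization::deserializeSpecDecFastLogitsInfo(buffer);
}
```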
|
||||
|
||||
// Result
|
||||
[[nodiscard]] static Result deserializeResult(std::istream& is);
|
||||
static void serialize(Result const& result, std::ostream& os);
|
||||
|
||||
@ -446,6 +446,11 @@ public:
|
||||
return DecodingMode{kExplicitDraftTokens | kStandardStopCriteria | kUseExplicitEosStop};
|
||||
}
|
||||
|
||||
static auto constexpr ExternalDraftTokens()
|
||||
{
|
||||
return DecodingMode{kExternalDraftTokens | kUsePenalties | kUseBanTokens | kStandardStopCriteria};
|
||||
}
|
||||
|
||||
auto constexpr useTemperature(bool useTemp)
|
||||
{
|
||||
mState = setBitTo(kUseTemperature, useTemp);
|
||||
@ -563,6 +568,11 @@ public:
|
||||
return anyBitSet(kExplicitDraftTokens);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isExternalDraftTokens() const
|
||||
{
|
||||
return anyBitSet(kExternalDraftTokens);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isUseTemperature() const
|
||||
{
|
||||
return anyBitSet(kUseTemperature);
|
||||
@ -676,6 +686,7 @@ private:
|
||||
static UnderlyingType constexpr kMedusa{1u << (kNumFlags + 4)};
|
||||
static UnderlyingType constexpr kLookahead{1u << (kNumFlags + 5)};
|
||||
static UnderlyingType constexpr kExplicitDraftTokens{1u << (kNumFlags + 6)};
|
||||
static UnderlyingType constexpr kExternalDraftTokens{1u << (kNumFlags + 7)};
|
||||
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
|
||||
|
||||
[[nodiscard]] bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
@ -706,6 +717,7 @@ static_assert(!DecodingMode::Auto().isBeamSearch());
|
||||
static_assert(!DecodingMode::Auto().isMedusa());
|
||||
static_assert(!DecodingMode::Auto().isLookahead());
|
||||
static_assert(!DecodingMode::Auto().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::Auto().isExternalDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopK().isTopK());
|
||||
static_assert(DecodingMode::TopK().isTopKorTopP());
|
||||
@ -726,6 +738,7 @@ static_assert(!DecodingMode::TopK().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopK().isMedusa());
|
||||
static_assert(!DecodingMode::TopK().isLookahead());
|
||||
static_assert(!DecodingMode::TopK().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::TopK().isExternalDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopP().isTopP());
|
||||
static_assert(DecodingMode::TopP().isTopKorTopP());
|
||||
@ -739,6 +752,7 @@ static_assert(!DecodingMode::TopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopP().isMedusa());
|
||||
static_assert(!DecodingMode::TopP().isLookahead());
|
||||
static_assert(!DecodingMode::TopP().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::TopP().isExternalDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::TopKTopP().isTopK());
|
||||
static_assert(DecodingMode::TopKTopP().isTopP());
|
||||
@ -752,6 +766,7 @@ static_assert(!DecodingMode::TopKTopP().isBeamSearch());
|
||||
static_assert(!DecodingMode::TopKTopP().isMedusa());
|
||||
static_assert(!DecodingMode::TopKTopP().isLookahead());
|
||||
static_assert(!DecodingMode::TopKTopP().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::TopKTopP().isExternalDraftTokens());
|
||||
|
||||
static_assert(DecodingMode::BeamSearch().isBeamSearch());
|
||||
static_assert(DecodingMode::BeamSearch().isUseStopCriteria());
|
||||
@ -760,6 +775,7 @@ static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
|
||||
static_assert(!DecodingMode::BeamSearch().isMedusa());
|
||||
static_assert(!DecodingMode::BeamSearch().isLookahead());
|
||||
static_assert(!DecodingMode::BeamSearch().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::BeamSearch().isExternalDraftTokens());
|
||||
|
||||
static_assert(!DecodingMode::Medusa().isAuto());
|
||||
static_assert(!DecodingMode::Medusa().isTopK());
|
||||
@ -775,6 +791,7 @@ static_assert(DecodingMode::Medusa().isUseStopCriteria());
|
||||
static_assert(DecodingMode::Medusa().isUsePenalty());
|
||||
static_assert(DecodingMode::Medusa().isUseMinLength());
|
||||
static_assert(DecodingMode::Medusa().isMedusa());
|
||||
static_assert(!DecodingMode::Medusa().isExternalDraftTokens());
|
||||
|
||||
static_assert(!DecodingMode::Lookahead().isAuto());
|
||||
static_assert(!DecodingMode::Lookahead().isTopK());
|
||||
@ -788,6 +805,7 @@ static_assert(DecodingMode::Lookahead().isUseStopCriteria());
|
||||
static_assert(DecodingMode::Lookahead().isUseStopWords());
|
||||
static_assert(DecodingMode::Lookahead().isUseExplicitEosStop());
|
||||
static_assert(DecodingMode::Lookahead().isLookahead());
|
||||
static_assert(!DecodingMode::Lookahead().isExternalDraftTokens());
|
||||
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isAuto());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isTopK());
|
||||
@ -801,4 +819,19 @@ static_assert(!DecodingMode::ExplicitDraftTokens().isUsePenalty());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isUseBanWords());
|
||||
static_assert(DecodingMode::ExplicitDraftTokens().isExplicitDraftTokens());
|
||||
static_assert(!DecodingMode::ExplicitDraftTokens().isExternalDraftTokens());
|
||||
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isTopK());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isTopP());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isTopKorTopP());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isTopKandTopP());
|
||||
static_assert(DecodingMode::ExternalDraftTokens().isUseBanWords());
|
||||
static_assert(DecodingMode::ExternalDraftTokens().isUseOccurrencePenalty());
|
||||
static_assert(DecodingMode::ExternalDraftTokens().isUseStopCriteria());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isAuto());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isBeamSearch());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isMedusa());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isLookahead());
|
||||
static_assert(!DecodingMode::ExternalDraftTokens().isExplicitDraftTokens());
|
||||
static_assert(DecodingMode::ExternalDraftTokens().isExternalDraftTokens());
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@ -108,6 +108,20 @@ public:
|
||||
TensorConstPtr medusaTargetTokensPerStep; //!< [batchSize], on gpu
|
||||
};
|
||||
|
||||
class ExternalDraftTokensInputs
|
||||
{
|
||||
public:
|
||||
TensorPtr draftLogits;
|
||||
TensorPtr draftProbs;
|
||||
TensorPtr targetProbs;
|
||||
TensorPtr numDraftTokens;
|
||||
TensorPtr draftTokenIds;
|
||||
TensorPtr useDraftLogits;
|
||||
SizeType32 step;
|
||||
float constantThreshold;
|
||||
bool useRandomAcceptanceThreshold;
|
||||
};
|
||||
|
||||
class ExplicitDraftTokensInputs
|
||||
{
|
||||
public:
|
||||
@ -138,6 +152,8 @@ public:
|
||||
std::optional<ExplicitDraftTokensInputs> explicitDraftTokensInputs;
|
||||
|
||||
std::optional<LookaheadInputs> lookaheadInputs;
|
||||
|
||||
std::optional<ExternalDraftTokensInputs> externalDraftTokensInputs;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -95,7 +95,7 @@ public:
|
||||
// mandatory parameters for beam search
|
||||
TensorPtr logProbs; // [BS, BM, MSL], must be float*
|
||||
TensorPtr cumLogProbs; // [BS, BM], optional for sampling
|
||||
TensorPtr parentIds; // [BS, BM, MSL]
|
||||
TensorPtr parentIds; // [BS, BM, MSL] index of the beam where the previous token is
|
||||
TensorPtr lengths; // [BS, BM], total sequence lengths including padding
|
||||
TensorPtr cacheIndirection; // [BS, BM, MSL], k/v indirection for next generation step
|
||||
|
||||
|
||||
@ -64,16 +64,6 @@ public:
|
||||
|
||||
virtual SamplingConfig const& getSamplingConfig() = 0;
|
||||
|
||||
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
|
||||
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
|
||||
ITensor const& finishedVec, ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static void acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
|
||||
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream,
|
||||
|
||||
@ -245,7 +245,7 @@ private:
|
||||
void newRequest(SizeType32 batchSlot, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
//! @brief Allocate buffers for speculative decoding.
|
||||
void allocateSpeculativeDecodingBuffers();
|
||||
void allocateSpeculativeDecodingBuffers(nvinfer1::DataType dtype);
|
||||
|
||||
//! @brief Setup buffers for speculative decoding.
|
||||
void setupSpeculativeDecoding(ModelConfig const& modelConfig);
|
||||
@ -300,10 +300,6 @@ private:
|
||||
DecodingInputPtr mJointDecodingInput;
|
||||
DecodingOutputPtr mJointDecodingOutput;
|
||||
|
||||
std::vector<bool> mAcceptByLogits;
|
||||
TensorPtr mNumDraftTokens;
|
||||
TensorPtr mCurandStates;
|
||||
|
||||
std::vector<SizeType32> mNbSteps;
|
||||
std::vector<bool> mFinished;
|
||||
TensorPtr mFinishedSum;
|
||||
@ -313,18 +309,9 @@ private:
|
||||
|
||||
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
|
||||
// for each generated token of maxTokensPerStep, on gpu
|
||||
TensorPtr mDraftProbs; // [batchSize, maxTokensPerEngineStep, beamWidth, vocabPadded], temporary data for
|
||||
// speculative decoding accept by logits kernel, on gpu
|
||||
TensorPtr mTargetProbs; // [batchSize, maxTokensPerEngineStep, beamWidth, vocabPadded], temporary data for
|
||||
// speculative decoding accept by logits kernel, on gpu
|
||||
TensorPtr mDraftTokenIds; // [batchSize, maxTokensPerEngineStep], draft token indices, on gpu
|
||||
TensorPtr mDraftLogits; // [batchSize, maxTokensPerEngineStep, vocabSizePadded], draft token logits, on gpu
|
||||
|
||||
TensorPtr mBatchSlotsSetup; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsDecoder; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptTokens; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptLogits; // [maxTokensPerEngineStep, maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
|
||||
SizeType32 mMaxSequenceLength{};
|
||||
SizeType32 mMaxAttentionWindow{};
|
||||
SizeType32 mSinkTokenLength{};
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:10b940475c5acd80a61674d8ce4e42cc4ef3d806bafb245bbed26751378274e3
|
||||
size 4904726
|
||||
oid sha256:1a292517d802f2297c5d12d5d14ab597f47f46ebd31412fac044ceb9ca51a482
|
||||
size 5160586
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b2754f7887a1b5c37ba3d589320e16144039cfe5dc6a6c78ee71925861d7d511
|
||||
size 5015842
|
||||
oid sha256:8575fb58200701ae30feb4b8bd3f325f8018aac5505167fdba42e269adb3bd8c
|
||||
size 5271836
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
ff71eabd0ac6ede5398b5b6ce4e26dcf libtensorrt_llm_batch_manager_static.a
|
||||
846eb112a182973e7c3b0b193300b4b8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
954182e0c057f71f858a84f746201044 libtensorrt_llm_batch_manager_static.a
|
||||
dfe6ca360cf1d24a3dcae0a2bf8589c0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:13b8701dd767b414a5376a91905985979ad9d2b975465ac00835c04656ee6508
|
||||
size 4766226
|
||||
oid sha256:8fe84073b7ccff8dc361fdee64c3ef30bc523909e0bf9c16547f76a05a53fb5c
|
||||
size 5009886
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:cd0b73a017fc5c663235dcd724eb104ecc49d12ff29b6e3744be6ea952d027db
|
||||
size 4722522
|
||||
oid sha256:6e565c2c3ce58656742772591d992aca91c7e46eb9fc711599d2d51928b88b48
|
||||
size 4970532
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
1eb5c88f894f3361445d7254cbc29b03 libtensorrt_llm_batch_manager_static.a
|
||||
4e73341b23e8fb20b732ba08e03a54a8 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
61fd34e765788884d42f4ba27f085520 libtensorrt_llm_batch_manager_static.a
|
||||
e8a64dd19a234304483ef6756e67fd40 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b4ac61c0b0816477c11bd6c66ec4c2f23f7b6e1400eacd8c07c333f79dec0bea
|
||||
size 30794956
|
||||
oid sha256:200a6721aa1d6e009c94866adab36ac686eb1beef02df267af7e18e31e11612b
|
||||
size 32436708
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
eefe7310a60098897724f46cf4aa54f8 tensorrt_llm_batch_manager_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
9485cfa635b17378f23d1624b3acfbaf tensorrt_llm_batch_manager_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -21,7 +21,7 @@
|
||||
namespace tensorrt_llm::utils::customAllReduceUtils
|
||||
{
|
||||
|
||||
constexpr size_t NUM_POINTERS_PER_RANK = 4;
|
||||
constexpr size_t NUM_POINTERS_PER_RANK = 7;
|
||||
|
||||
// WARNING: MUST BE KEPT IN SYNC with tensorrt_llm/plugin/plugin.py
|
||||
inline size_t getMaxRequiredWorkspaceSize(int worldSize) noexcept
|
||||
|
||||
@ -335,6 +335,18 @@ void MpiComm::mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status)
|
||||
#endif // ENABLE_MULTI_DEVICE
|
||||
}
|
||||
|
||||
bool MpiComm::improbe(int source, int tag, MPI_Message* msg, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
int flag{0};
|
||||
MPICHECK(MPI_Improbe(source, tag, mComm, &flag, msg, status));
|
||||
return flag != 0;
|
||||
#else
|
||||
TLLM_THROW("Multi device support is disabled.");
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool MpiComm::iprobe(int source, int tag, MPI_Status* status) const
|
||||
{
|
||||
#if ENABLE_MULTI_DEVICE
|
||||
|
||||
@ -38,6 +38,12 @@ namespace common
|
||||
template <int VPT>
|
||||
struct BytesToType;
|
||||
|
||||
template <>
|
||||
struct BytesToType<1>
|
||||
{
|
||||
using type = uint8_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BytesToType<2>
|
||||
{
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ebab2cc2c62a826ddec02597178b8e0c9bc316726f37f8eef37c06795aebcf03
|
||||
size 1784658
|
||||
oid sha256:809a1da76123ec4c640d63efc902209585223b66e23d887db9a198c5836986a2
|
||||
size 3349066
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4b630f89708614e63c67871e21b6e32bfde71acc51549b650c57048c0fa343e7
|
||||
size 1812686
|
||||
oid sha256:6846ecefa017d03ab7d853908794c884ab4e92a500e223278b1d64eab59ed061
|
||||
size 3376088
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
136f1b9d2168cbb9011a341b267af9a2 libtensorrt_llm_executor_static.a
|
||||
183bd079377d6cd698d46370168a5726 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
5a771664fdb75d99ba5fb90249ac26f0 libtensorrt_llm_executor_static.a
|
||||
3b433ea93b7d1d6fa471b457980f2680 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e04c76f6441a49db4d3996c62b4055395ae018384d8ee2f02ea5f0c4c0843902
|
||||
size 1853180
|
||||
oid sha256:479e86f410763445357f5d879cc666d210352dda9709ab5ab56e73591a9e8af8
|
||||
size 7851266
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:95ba1a4b6bdcecbb592bbb42b4998bcb0eb1f45a318163635183bcde6950c4bf
|
||||
size 1764982
|
||||
oid sha256:6473c77d18929fa75342d63ffc591df39e8aeba1dda0b920b0187d4888710559
|
||||
size 7767384
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
dfbd0d424c150253ff758aa5bd37a971 libtensorrt_llm_executor_static.a
|
||||
e82866739fef1d6df8293541967924bf libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
5424fb0f82076e03b5316f73aed04434 libtensorrt_llm_executor_static.a
|
||||
d0b1236baf61fc5c43383bbc1cd50fa8 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:aa8ba34fb98c5407e3d6944245086158c61b2c784b15c7b923fdd156b942224d
|
||||
size 19670642
|
||||
oid sha256:dee57c9257a6678833e3c0d83e8df07aff25c185bc085db75938cec6652044c0
|
||||
size 24568210
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
784ad1fabd3d02466f95fbc463b64f5b tensorrt_llm_executor_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
305fac5d046a574ded2d46d968f746b0 tensorrt_llm_executor_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -630,7 +630,7 @@ void topKSoftMaxKernelLauncher(T const* logits, T const* bias, void* workspace,
|
||||
// ┃ pTemp ┃ BS * PAD_K * VP * (2 * (PAD_K * 2) + 2) | | float |
|
||||
// ┗━━━━━━━━━━┛ --------------------------------------------------------------------------------
|
||||
|
||||
// Stage1: gridDim(BS,BM,nVPart), blockDim(nBlockSize,1,1)
|
||||
// beamStage1Kernel: gridDim(BS,BM,nVPart), blockDim(nBlockSize,1,1)
|
||||
// Each ThreadBlock takes `nVocabChunk` contiguous elements in logits to do TopK and reduce_md,
|
||||
// then writes output into pTemp.
|
||||
// At end of this kernel, each ThreadBlock holds the indices and values of the top 2*BM elements,
|
||||
@ -647,7 +647,7 @@ void topKSoftMaxKernelLauncher(T const* logits, T const* bias, void* workspace,
|
||||
// ┃ md ┃ 2 | 2 | float |
|
||||
// ┗━━━━━━━━━━┛ -----------------------------------------
|
||||
|
||||
// Stage2: gridDim(BS,BM,1), blockDim(32/64/128,1,1)
|
||||
// beamStage2Kernel: gridDim(BS,BM,1), blockDim(32/64/128,1,1)
|
||||
// Each ThreadBlock takes `nVPart` contiguous Tiles in pTemp to do reduce_topk and reduce_md,
|
||||
// writes output topk_id into pTempId, writes topk_value + cumLogProbs into pTempVal.
|
||||
|
||||
|
||||
@ -165,7 +165,7 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams)
|
||||
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
|
||||
if (mLaunchParams.useBase2ExpTrick)
|
||||
{
|
||||
// The kernel adopts the log2f optimziation.
|
||||
// The kernel adopts the log2f optimization.
|
||||
constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E
|
||||
set_alpha(mKernelParams.scale_bmm1, scale_bmm1 * float(kLog2e), DATA_TYPE_FP32);
|
||||
}
|
||||
@ -364,8 +364,8 @@ void FusedMHARunnerV2::setupLaunchParams(MHARunnerParams runnerParams)
|
||||
void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
|
||||
{
|
||||
// split D into multiple groups in order to match the TMA swizzle mode (128B)
|
||||
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
|
||||
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
|
||||
uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
|
||||
// separate q, k, v and o tma descriptors
|
||||
Multiple_tma_descriptor<4> qkv_tma_descriptor;
|
||||
@ -421,8 +421,8 @@ void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
|
||||
uint32_t fp32_to_tf32 = 0;
|
||||
|
||||
// gmma descriptor mode
|
||||
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
|
||||
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
|
||||
uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
|
||||
cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
|
||||
? cudaTmaDescSwizzle::SWIZZLE_128B
|
||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
||||
|
||||
@ -474,8 +474,8 @@ void FusedMHARunnerV2::setPackedQkvTmaDescriptors(MHARunnerParams runnerParams)
|
||||
void FusedMHARunnerV2::setSeparateQKvTmaDescriptors(MHARunnerParams runnerParams)
|
||||
{
|
||||
// split D into multiple groups in order to match the TMA swizzle mode (128B)
|
||||
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
|
||||
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
uint32_t const d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mFixedParams.dataType);
|
||||
uint32_t const d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
|
||||
uint32_t q_step = 0, kv_step = 0;
|
||||
xmmaKernel->getStepSize(q_step, kv_step, mKernelParams, mLaunchParams);
|
||||
@ -518,7 +518,7 @@ void FusedMHARunnerV2::setSeparateQKvTmaDescriptors(MHARunnerParams runnerParams
|
||||
= (get_size_in_bytes(mFixedParams.dataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;
|
||||
|
||||
// gmma descriptor mode
|
||||
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
|
||||
uint32_t const d_bytes_per_group = d_in_bytes / d_groups;
|
||||
cudaTmaDescSwizzle const swizzle_mode = (d_bytes_per_group > 64
|
||||
? cudaTmaDescSwizzle::SWIZZLE_128B
|
||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
||||
|
||||
@ -17,8 +17,11 @@
|
||||
#include "customAllReduceKernels.h"
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/customAllReduceUtils.h"
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include "tensorrt_llm/common/envUtils.h"
|
||||
#include <cooperative_groups.h>
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
@ -174,12 +177,6 @@ __inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag
|
||||
|
||||
namespace reduce_fusion
|
||||
{
|
||||
namespace details
|
||||
{
|
||||
static constexpr int kBytesPerAccess = 16;
|
||||
static constexpr int kWarpSize = 32;
|
||||
static constexpr int kMaxCtaSize = 1024;
|
||||
}; // namespace details
|
||||
|
||||
inline __device__ float warp_reduce_sum(float val)
|
||||
{
|
||||
@ -318,7 +315,7 @@ __global__ void rms_norm_kernel(AllReduceParams params)
|
||||
}
|
||||
|
||||
template <typename T, bool Bias = false, bool Residual = false, bool Affine = false>
|
||||
void rms_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
|
||||
void rms_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
|
||||
{
|
||||
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
|
||||
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
|
||||
@ -387,6 +384,395 @@ void rms_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct NegZero128b
|
||||
{
|
||||
static constexpr int v = static_cast<int>(0x80008000);
|
||||
static constexpr int4 value = {v, v, v, v};
|
||||
};
|
||||
|
||||
template <>
|
||||
struct NegZero128b<float>
|
||||
{
|
||||
static constexpr int v = static_cast<int>(0x80000000);
|
||||
static constexpr int4 value = {v, v, v, v};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ static constexpr int4 NegZero128b_v = NegZero128b<T>::value;
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ bool is_neg_zero(T& v);
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ bool is_neg_zero<float>(float& v)
|
||||
{
|
||||
uint32_t bits = *reinterpret_cast<uint32_t*>(&v);
|
||||
return bits == 0x80000000;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ bool is_neg_zero<half>(half& v)
|
||||
{
|
||||
uint16_t bits = *reinterpret_cast<uint16_t*>(&v);
|
||||
return bits == 0x8000;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ bool is_neg_zero<__nv_bfloat16>(__nv_bfloat16& v)
|
||||
{
|
||||
uint16_t bits = *reinterpret_cast<uint16_t*>(&v);
|
||||
return bits == 0x8000;
|
||||
}
|
||||
|
||||
template <typename ValType, typename VecType>
|
||||
__device__ __forceinline__ VecType remove_neg_zero(VecType const& vec)
|
||||
{
|
||||
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
|
||||
using ReadOnlyValType = std::add_const_t<ValType>;
|
||||
VecType ret;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kIter; ++i)
|
||||
{
|
||||
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
|
||||
reinterpret_cast<ValType*>(&ret)[i] = is_neg_zero(val) ? static_cast<ValType>(0.f) : val;
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
|
||||
template <typename ValType, typename VecType>
|
||||
__device__ __forceinline__ bool has_neg_zero(VecType const& vec)
|
||||
{
|
||||
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
|
||||
using ReadOnlyValType = std::add_const_t<ValType>;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kIter; ++i)
|
||||
{
|
||||
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
|
||||
if (is_neg_zero(val))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
template <typename ValType, typename VecType>
|
||||
__device__ __forceinline__ bool all_neg_zero(VecType const& vec)
|
||||
{
|
||||
static constexpr int kIter = sizeof(VecType) / sizeof(ValType);
|
||||
using ReadOnlyValType = std::add_const_t<ValType>;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kIter; ++i)
|
||||
{
|
||||
auto val = reinterpret_cast<ReadOnlyValType*>(&vec)[i];
|
||||
if (!is_neg_zero(val))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_global_release(int4 const& val, int4* addr)
|
||||
{
|
||||
asm volatile("st.release.global.sys.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z),
|
||||
"r"(val.w), "l"(addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_global_acquire(int4* addr)
|
||||
{
|
||||
int4 val;
|
||||
asm volatile("ld.acquire.global.sys.v4.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
|
||||
: "l"(addr));
|
||||
return val;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void st_global_volatile(int4 const& val, int4* addr)
|
||||
{
|
||||
asm volatile("st.volatile.global.v4.b32 [%4], {%0, %1, %2, %3};" ::"r"(val.x), "r"(val.y), "r"(val.z), "r"(val.w),
|
||||
"l"(addr));
|
||||
}
|
||||
|
||||
__device__ __forceinline__ int4 ld_global_volatile(int4* addr)
|
||||
{
|
||||
int4 val;
|
||||
asm volatile("ld.volatile.global.v4.b32 {%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(val.x), "=r"(val.y), "=r"(val.z), "=r"(val.w)
|
||||
: "l"(addr));
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename ValType>
|
||||
__device__ __forceinline__ void set_neg_zero(int4* addr)
|
||||
{
|
||||
st_global_volatile(NegZero128b_v<ValType>, addr);
|
||||
}
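
The helpers above implement the Lamport-style handshake used by the fused all-reduce kernels: -0.0 acts as the "not yet written" sentinel because its bit pattern (sign bit only) is distinguishable from every value a producer can write once inputs have been scrubbed by `remove_neg_zero`, while still behaving like zero arithmetically. A host-side sketch of the same bit test, just to make the invariant concrete:

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Host-side equivalent of is_neg_zero<float>: check the raw sign-only bit pattern.
bool isNegZero(float v)
{
    std::uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));
    return bits == 0x80000000u;
}

int main()
{
    float sentinel = -0.0f;
    float realZero = 0.0f;

    std::cout << (sentinel == realZero) << "\n"; // 1: arithmetically equal
    std::cout << isNegZero(sentinel) << "\n";    // 1: but the bit pattern is distinct
    std::cout << isNegZero(realZero) << "\n";    // 0: a produced 0.0 is never mistaken for the sentinel
    return 0;
}
```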
|
||||
|
||||
template <typename T, int RanksPerNode, bool PushMode>
|
||||
struct Reducer;
|
||||
|
||||
template <typename T, int RanksPerNode>
|
||||
struct Reducer<T, RanksPerNode, true>
|
||||
{
|
||||
static __device__ __forceinline__ int4 allreduce(AllReduceParams& params, int global_offset)
|
||||
{
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
int ping = params.barrier_flag % 3;
|
||||
int pong = (params.barrier_flag + 2) % 3;
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + ping * MAX_RANKS_PER_NODE]);
|
||||
T* local_clean_buffer = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + pong * MAX_RANKS_PER_NODE]);
|
||||
local_input_buffer += global_offset;
|
||||
local_shared_buffer += global_offset;
|
||||
local_clean_buffer += global_offset;
|
||||
T* buffers[RanksPerNode];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RanksPerNode;
|
||||
buffers[ii] = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[rank + ping * MAX_RANKS_PER_NODE])
|
||||
+ global_offset + params.local_rank * params.elts_total;
|
||||
}
|
||||
PackedStruct sum_vec, val;
|
||||
val.packed = remove_neg_zero<T>(*reinterpret_cast<int4 const*>(local_input_buffer));
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
st_global_volatile(val.packed, reinterpret_cast<int4*>(buffers[ii]));
|
||||
}
|
||||
sum_vec.packed = val.packed;
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RanksPerNode;
|
||||
set_neg_zero<T>(reinterpret_cast<int4*>(local_clean_buffer + rank * params.elts_total));
|
||||
}
|
||||
PackedStruct vals[RanksPerNode - 1];
|
||||
bool done = false;
|
||||
while (!done)
|
||||
{
|
||||
done = true;
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RanksPerNode;
|
||||
vals[ii - 1].packed
|
||||
= ld_global_volatile(reinterpret_cast<int4*>(local_shared_buffer + rank * params.elts_total));
|
||||
}
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RanksPerNode - 1; ii++)
|
||||
{
|
||||
done &= !has_neg_zero<T>(vals[ii].packed);
|
||||
}
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
sum_vec.packed = add128b(sum_vec, vals[ii - 1]);
|
||||
}
|
||||
return sum_vec.packed;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int RanksPerNode>
|
||||
struct Reducer<T, RanksPerNode, false>
|
||||
{
|
||||
static __device__ __forceinline__ int4 allreduce(AllReduceParams& params, int global_offset)
|
||||
{
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
int ping = params.barrier_flag % 3;
|
||||
int pong = (params.barrier_flag + 2) % 3;
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + ping * MAX_RANKS_PER_NODE]);
|
||||
T* local_clean_buffer = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[params.local_rank + pong * MAX_RANKS_PER_NODE]);
|
||||
local_input_buffer += global_offset;
|
||||
local_shared_buffer += global_offset;
|
||||
local_clean_buffer += global_offset;
|
||||
T* buffers[RanksPerNode];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
int rank = (params.local_rank + ii) % RanksPerNode;
|
||||
buffers[ii] = reinterpret_cast<T*>(
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[rank + ping * MAX_RANKS_PER_NODE])
|
||||
+ global_offset;
|
||||
}
|
||||
PackedStruct sum_vec, val;
|
||||
val.packed = remove_neg_zero<T>(*reinterpret_cast<int4 const*>(local_input_buffer));
|
||||
st_global_volatile(val.packed, reinterpret_cast<int4*>(local_shared_buffer));
|
||||
sum_vec.packed = val.packed;
|
||||
#pragma unroll
|
||||
for (int ii = 1; ii < RanksPerNode; ++ii)
|
||||
{
|
||||
do
|
||||
{
|
||||
val.packed = ld_global_volatile(reinterpret_cast<int4*>(buffers[ii]));
|
||||
} while (has_neg_zero<T>(val.packed));
|
||||
sum_vec.packed = add128b(sum_vec, val);
|
||||
}
|
||||
set_neg_zero<T>(reinterpret_cast<int4*>(local_clean_buffer));
|
||||
return sum_vec.packed;
|
||||
}
|
||||
};
|
||||
|
||||
template <int ClusterSize, typename T, int RanksPerNode, bool Bias = false, bool Affine = false, bool PushMode = true>
|
||||
static __global__ void lamport_style_one_shot_all_reduce_norm_kernel(AllReduceParams params)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
|
||||
namespace cg = cooperative_groups;
|
||||
static_assert(RanksPerNode <= 8);
|
||||
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
cg::cluster_group cluster = cg::this_cluster();
|
||||
|
||||
__shared__ float cluster_acc;
|
||||
|
||||
int bid = blockIdx.x, tid = threadIdx.x;
|
||||
int cluster_id = bid / ClusterSize, cluster_block_rank = bid % ClusterSize;
|
||||
|
||||
int token_id = cluster_id;
|
||||
int cluster_offset = token_id * params.fusion_params.hidden_size;
|
||||
int block_offset = cluster_block_rank * params.fusion_params.hidden_size / ClusterSize;
|
||||
int thread_offset = tid * kPackedSize;
|
||||
|
||||
int inner_token_offset = block_offset + thread_offset;
|
||||
int global_offset = cluster_offset + inner_token_offset;
|
||||
|
||||
T const* bias_buffer = reinterpret_cast<T const*>(params.fusion_params.bias_buffer);
|
||||
T const* residual_buffer = reinterpret_cast<T const*>(params.fusion_params.residual_buffer);
|
||||
T const* weight_buffer = reinterpret_cast<T const*>(params.fusion_params.weight_buffer);
|
||||
T* local_final_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
|
||||
T* intermediate_buffer = reinterpret_cast<T*>(params.fusion_params.intermediate_buffer);
|
||||
|
||||
local_final_output_buffer += global_offset;
|
||||
intermediate_buffer += global_offset;
|
||||
residual_buffer += global_offset;
|
||||
bias_buffer += inner_token_offset;
|
||||
weight_buffer += inner_token_offset;
|
||||
|
||||
PackedStruct weight_vec, bias_vec, residual_vec;
|
||||
residual_vec.packed = *reinterpret_cast<int4 const*>(residual_buffer);
|
||||
if constexpr (Bias)
|
||||
{
|
||||
bias_vec.packed = *reinterpret_cast<int4 const*>(bias_buffer);
|
||||
}
|
||||
if constexpr (Affine)
|
||||
{
|
||||
weight_vec.packed = *reinterpret_cast<int4 const*>(weight_buffer);
|
||||
}
|
||||
|
||||
cudaGridDependencySynchronize();
|
||||
|
||||
float acc = 0.f;
|
||||
PackedStruct sum_vec;
|
||||
sum_vec.packed = Reducer<T, RanksPerNode, PushMode>::allreduce(params, global_offset);
|
||||
|
||||
if constexpr (Bias)
|
||||
{
|
||||
sum_vec.packed = add128b(sum_vec, bias_vec);
|
||||
}
|
||||
sum_vec.packed = add128b(sum_vec, residual_vec);
|
||||
*reinterpret_cast<int4*>(intermediate_buffer) = sum_vec.packed;
|
||||
acc = accumulate<T>(acc, sum_vec);
|
||||
acc = block_reduce_sum(acc);
|
||||
if (ClusterSize > 1)
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
cluster_acc = acc;
|
||||
}
|
||||
cluster.sync();
|
||||
acc = 0.f;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < ClusterSize; ++ii)
|
||||
{
|
||||
acc += *cluster.map_shared_rank(&cluster_acc, ii);
|
||||
}
|
||||
}
|
||||
|
||||
float denom = __fsqrt_rn(__fdividef(acc, params.fusion_params.hidden_size) + params.fusion_params.eps);
|
||||
sum_vec.packed = rms_norm<T, Affine>(denom, sum_vec, weight_vec);
|
||||
*reinterpret_cast<int4*>(local_final_output_buffer) = sum_vec.packed;
|
||||
|
||||
cudaTriggerProgrammaticLaunchCompletion();
|
||||
#endif
|
||||
}
|
||||
|
||||
int heuristic_min_warp_number(int tp_size, int hidden_size)
|
||||
{
|
||||
if (hidden_size >= 4096)
|
||||
{
|
||||
return 4;
|
||||
}
|
||||
if (tp_size == 2)
|
||||
{
|
||||
return 32;
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RanksPerNode, bool Bias, bool Affine>
|
||||
void lamport_style_one_shot_all_reduce_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
|
||||
{
|
||||
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
|
||||
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
|
||||
int threads_per_token = params.fusion_params.hidden_size / kPackedSize;
|
||||
int warps_per_token = (threads_per_token + details::kWarpSize - 1) / details::kWarpSize;
|
||||
int token_num = params.elts_total / params.fusion_params.hidden_size;
|
||||
int warp_min_number = heuristic_min_warp_number(RanksPerNode, params.fusion_params.hidden_size);
|
||||
int cluster_size = std::min(((warps_per_token + warp_min_number - 1) / warp_min_number), details::kClusterMaxSize);
|
||||
int cta_size = warps_per_token / cluster_size * details::kWarpSize;
|
||||
TLLM_CHECK(cta_size <= details::kMaxCtaSize);
|
||||
int cta_num = token_num * cluster_size;
|
||||
cudaLaunchConfig_t kernel_config = {0};
|
||||
kernel_config.gridDim = cta_num;
|
||||
kernel_config.blockDim = cta_size;
|
||||
kernel_config.dynamicSmemBytes = 0;
|
||||
kernel_config.stream = stream;
|
||||
|
||||
cudaLaunchAttribute attribute[2];
|
||||
attribute[0].id = cudaLaunchAttributeClusterDimension;
|
||||
attribute[0].val.clusterDim.x = cluster_size;
|
||||
attribute[0].val.clusterDim.y = 1;
|
||||
attribute[0].val.clusterDim.z = 1;
|
||||
kernel_config.attrs = attribute;
|
||||
kernel_config.numAttrs = 1;
|
||||
if (tensorrt_llm::common::getEnvEnablePDL())
|
||||
{
|
||||
attribute[1].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[1].val.programmaticStreamSerializationAllowed = 1;
|
||||
kernel_config.numAttrs++;
|
||||
}
|
||||
#define LAUNCH_LAMPORT_KERNEL(CLUSTER_SIZE) \
|
||||
if (cluster_size == CLUSTER_SIZE) \
|
||||
{ \
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(&kernel_config, \
|
||||
lamport_style_one_shot_all_reduce_norm_kernel<CLUSTER_SIZE, T, RanksPerNode, Bias, Affine>, params)); \
|
||||
return; \
|
||||
}
|
||||
LAUNCH_LAMPORT_KERNEL(1);
|
||||
LAUNCH_LAMPORT_KERNEL(2);
|
||||
LAUNCH_LAMPORT_KERNEL(3);
|
||||
LAUNCH_LAMPORT_KERNEL(4);
|
||||
LAUNCH_LAMPORT_KERNEL(5);
|
||||
LAUNCH_LAMPORT_KERNEL(6);
|
||||
LAUNCH_LAMPORT_KERNEL(7);
|
||||
LAUNCH_LAMPORT_KERNEL(8);
|
||||
#undef LAUNCH_LAMPORT_KERNEL
|
||||
}
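
The launcher above spreads one token's hidden dimension across a thread-block cluster sized by `heuristic_min_warp_number`. For concreteness, here is the same arithmetic reproduced on the host for one assumed configuration (fp16, hidden size 8192, TP 8); the helper is illustrative only and simply mirrors the constants from `reduce_fusion::details`:

```cpp
#include <algorithm>
#include <iostream>

// Mirrors the sizing logic of lamport_style_one_shot_all_reduce_norm_kernel_launcher.
void printLaunchShape(int hiddenSize, int tpSize, int tokenNum, int bytesPerElement)
{
    constexpr int kBytesPerAccess = 16;
    constexpr int kWarpSize = 32;
    constexpr int kClusterMaxSize = 8;

    int packedSize = kBytesPerAccess / bytesPerElement;
    int threadsPerToken = hiddenSize / packedSize;
    int warpsPerToken = (threadsPerToken + kWarpSize - 1) / kWarpSize;
    int warpMinNumber = hiddenSize >= 4096 ? 4 : (tpSize == 2 ? 32 : 16);
    int clusterSize = std::min((warpsPerToken + warpMinNumber - 1) / warpMinNumber, kClusterMaxSize);
    int ctaSize = warpsPerToken / clusterSize * kWarpSize;
    int ctaNum = tokenNum * clusterSize;

    std::cout << "cluster=" << clusterSize << " ctaSize=" << ctaSize << " ctaNum=" << ctaNum << "\n";
}

int main()
{
    // fp16, hidden 8192, TP=8, 4 tokens -> cluster=8, ctaSize=128, ctaNum=32
    printLaunchShape(/*hiddenSize=*/8192, /*tpSize=*/8, /*tokenNum=*/4, /*bytesPerElement=*/2);
    return 0;
}
```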
|
||||
|
||||
template <typename T, int RanksPerNode, bool Bias = false, bool Affine = false, bool UseSmem = false>
|
||||
static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kernel(AllReduceParams params)
|
||||
{
|
||||
@ -495,79 +881,144 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool is_lamport_supported(int token_num)
|
||||
{
|
||||
static char* disableLamportReduceNormFusionChar = std::getenv("DISABLE_LAMPORT_REDUCE_NORM_FUSION");
|
||||
bool disableLamportReduceNormFusion = (disableLamportReduceNormFusionChar != nullptr);
|
||||
if (disableLamportReduceNormFusion)
|
||||
return false;
|
||||
static int sm = tensorrt_llm::common::getSMVersion();
|
||||
if (sm < 90)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (!std::is_same_v<T, half> && !std::is_same_v<T, __nv_bfloat16>)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (token_num > details::kLamportTokenNumThreshold)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num)
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num);
|
||||
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num);
|
||||
#ifdef ENABLE_BF16
|
||||
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num);
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RanksPerNode, bool Bias, bool Affine>
|
||||
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams params, cudaStream_t stream)
|
||||
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
|
||||
{
|
||||
int token_num = params.elts_total / params.fusion_params.hidden_size;
|
||||
if (is_lamport_supported<T>(token_num))
|
||||
{
|
||||
lamport_style_one_shot_all_reduce_norm_kernel_launcher<T, RanksPerNode, Bias, Affine>(params, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
|
||||
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
|
||||
int need_threads = params.fusion_params.hidden_size / kPackedSize;
|
||||
int cta_size;
|
||||
if (need_threads <= details::kMaxCtaSize)
|
||||
{
|
||||
cta_size = (need_threads + details::kWarpSize - 1) / details::kWarpSize * details::kWarpSize;
|
||||
}
|
||||
else
|
||||
{
|
||||
cta_size = details::kMaxCtaSize;
|
||||
}
|
||||
int norm_num = params.elts_total / params.fusion_params.hidden_size;
|
||||
int cta_num = std::min(norm_num, static_cast<int>(MAX_ALL_REDUCE_BLOCKS));
|
||||
int smem_size = 0;
|
||||
|
||||
if (cta_size * kPackedSize < params.fusion_params.hidden_size)
|
||||
{
|
||||
smem_size = params.fusion_params.hidden_size * sizeof(T);
|
||||
if (tensorrt_llm::common::getEnvEnablePDL())
|
||||
{
|
||||
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
|
||||
|
||||
cudaLaunchConfig_t kernelConfig = {0};
|
||||
kernelConfig.gridDim = cta_num;
|
||||
kernelConfig.blockDim = cta_size;
|
||||
kernelConfig.dynamicSmemBytes = smem_size;
|
||||
kernelConfig.stream = stream;
|
||||
|
||||
cudaLaunchAttribute attribute[1];
|
||||
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
kernelConfig.attrs = attribute;
|
||||
kernelConfig.numAttrs = 1;
|
||||
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>
|
||||
<<<cta_num, cta_size, smem_size, stream>>>(params);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tensorrt_llm::common::getEnvEnablePDL())
|
||||
{
|
||||
cudaLaunchConfig_t kernelConfig = {0};
|
||||
kernelConfig.gridDim = cta_num;
|
||||
kernelConfig.blockDim = cta_size;
|
||||
kernelConfig.dynamicSmemBytes = smem_size;
|
||||
kernelConfig.stream = stream;
|
||||
|
||||
cudaLaunchAttribute attribute[1];
|
||||
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
kernelConfig.attrs = attribute;
|
||||
kernelConfig.numAttrs = 1;
|
||||
|
||||
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>
|
||||
<<<cta_num, cta_size, smem_size, stream>>>(params);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void lamport_initialize_kernel(T* buffer, size_t size)
|
||||
{
|
||||
static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
|
||||
TLLM_CHECK(params.fusion_params.hidden_size % kPackedSize == 0);
|
||||
int need_threads = params.fusion_params.hidden_size / kPackedSize;
|
||||
int cta_size;
|
||||
if (need_threads <= details::kMaxCtaSize)
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
for (size_t offset = (blockIdx.x * blockDim.x + threadIdx.x) * kPackedSize; offset < size;
|
||||
offset += gridDim.x * blockDim.x * kPackedSize)
|
||||
{
|
||||
cta_size = (need_threads + details::kWarpSize - 1) / details::kWarpSize * details::kWarpSize;
|
||||
set_neg_zero<T>(reinterpret_cast<int4*>(&buffer[offset]));
|
||||
}
|
||||
else
|
||||
{
|
||||
cta_size = details::kMaxCtaSize;
|
||||
}
|
||||
int norm_num = params.elts_total / params.fusion_params.hidden_size;
|
||||
int cta_num = std::min(norm_num, static_cast<int>(MAX_ALL_REDUCE_BLOCKS));
|
||||
int smem_size = 0;
|
||||
}
|
||||
|
||||
if (cta_size * kPackedSize < params.fusion_params.hidden_size)
|
||||
{
|
||||
smem_size = params.fusion_params.hidden_size * sizeof(T);
|
||||
if (tensorrt_llm::common::getEnvEnablePDL())
|
||||
{
|
||||
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
|
||||
|
||||
cudaLaunchConfig_t kernelConfig = {0};
|
||||
kernelConfig.gridDim = cta_num;
|
||||
kernelConfig.blockDim = cta_size;
|
||||
kernelConfig.dynamicSmemBytes = smem_size;
|
||||
kernelConfig.stream = stream;
|
||||
|
||||
cudaLaunchAttribute attribute[1];
|
||||
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
kernelConfig.attrs = attribute;
|
||||
kernelConfig.numAttrs = 1;
|
||||
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, true>
|
||||
<<<cta_num, cta_size, smem_size, stream>>>(params);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (tensorrt_llm::common::getEnvEnablePDL())
|
||||
{
|
||||
cudaLaunchConfig_t kernelConfig = {0};
|
||||
kernelConfig.gridDim = cta_num;
|
||||
kernelConfig.blockDim = cta_size;
|
||||
kernelConfig.dynamicSmemBytes = smem_size;
|
||||
kernelConfig.stream = stream;
|
||||
|
||||
cudaLaunchAttribute attribute[1];
|
||||
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
|
||||
attribute[0].val.programmaticStreamSerializationAllowed = 1;
|
||||
kernelConfig.attrs = attribute;
|
||||
kernelConfig.numAttrs = 1;
|
||||
|
||||
TLLM_LOG_DEBUG("Enable PDL in one_shot_all_reduce_norm_kernel");
|
||||
TLLM_CUDA_CHECK(cudaLaunchKernelEx(
|
||||
&kernelConfig, one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>, params));
|
||||
}
|
||||
else
|
||||
{
|
||||
one_shot_all_reduce_norm_kernel<T, RanksPerNode, Bias, Affine, false>
|
||||
<<<cta_num, cta_size, smem_size, stream>>>(params);
|
||||
}
|
||||
}
|
||||
template <typename T>
void lamport_initialize_kernel_launcher(void* buffer, size_t size, cudaStream_t stream)
{
    static constexpr int kPackedSize = details::kBytesPerAccess / sizeof(T);
    int block_size = 1024;
    int grid_size = (size + 1024 * kPackedSize - 1) / (1024 * kPackedSize);
    lamport_initialize_kernel<T><<<grid_size, block_size, 0, stream>>>(reinterpret_cast<T*>(buffer), size);
}
}; // namespace reduce_fusion

@ -1117,13 +1568,24 @@ void AllReduceDispatchType(AllReduceParams& params, AllReduceStrategyType strat,
    }
}

AllReduceParams AllReduceParams::deserialize(int64_t* buffer, size_t tpSize, size_t tpRank)
AllReduceParams AllReduceParams::deserialize(
    int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, AllReduceFusionOp op)
{
    void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
    auto const flag_ptr = &buffer[4 * tpSize];
    int flag_offset;
    if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM && reduce_fusion::is_lamport_supported(dataType, token_num))
    {
        flag_offset = 0;
    }
    else
    {
        flag_offset = 1;
    }
    auto const flag_ptr
        = &buffer[tensorrt_llm::utils::customAllReduceUtils::NUM_POINTERS_PER_RANK * tpSize + flag_offset];
    // cannot use 0 since 0 represents released state for barrier
    *flag_ptr += 1;
    TLLM_LOG_TRACE("AllReduceParams's flag value is %d", *flag_ptr);
    TLLM_LOG_TRACE("AllReduceParams's flag value is %d, flag offset %d", *flag_ptr, flag_offset);
    uint32_t flag_value = *flag_ptr;
    AllReduceParams params;
    // Even plugins use ping buffers, odd plugins use pong.
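A minimal host-side sketch of the flag-slot arithmetic used in deserialize() above; the layout (a pointer table followed by a Lamport flag at offset 0 and a regular barrier flag at offset 1) comes from the diff, while the constant value below is only a placeholder.

#include <cstddef>
#include <cstdint>

// Placeholder mirroring tensorrt_llm::utils::customAllReduceUtils::NUM_POINTERS_PER_RANK;
// the value 4 echoes the pre-change `buffer[4 * tpSize]` indexing and is illustrative only.
constexpr std::size_t kNumPointersPerRank = 4;

// Returns the flag slot that deserialize() increments. useLamportPath corresponds to
// RESIDUAL_RMS_NORM fusion with is_lamport_supported(dataType, token_num) == true.
inline int64_t* locateBarrierFlag(int64_t* buffer, std::size_t tpSize, bool useLamportPath)
{
    int const flagOffset = useLamportPath ? 0 : 1;
    return &buffer[kNumPointersPerRank * tpSize + flagOffset];
}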
@ -1208,4 +1670,25 @@ void residualRmsNorm(kernels::AllReduceParams& params, nvinfer1::DataType dataTy
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream)
|
||||
{
|
||||
sync_check_cuda_error();
|
||||
switch (dataType)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
reduce_fusion::lamport_initialize_kernel_launcher<float>(buffer, size, stream);
|
||||
break;
|
||||
case nvinfer1::DataType::kHALF:
|
||||
reduce_fusion::lamport_initialize_kernel_launcher<half>(buffer, size, stream);
|
||||
break;
|
||||
#ifdef ENABLE_BF16
|
||||
case nvinfer1::DataType::kBF16:
|
||||
reduce_fusion::lamport_initialize_kernel_launcher<__nv_bfloat16>(buffer, size, stream);
|
||||
break;
|
||||
#endif
|
||||
default: TLLM_THROW("Unsupported dataType for customAllReduce");
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels
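A hedged usage sketch for the lamportInitialize dispatcher added above (assuming the customAllReduce and NvInfer headers are already included); buffer ownership, sizing, and the chosen dtype are illustrative, not prescribed by the commit.

// Clears a caller-owned device buffer to the Lamport "empty" pattern (-0.0)
// before the first fused all-reduce touches it.
void initLamportCommBuffer(void* deviceBuffer, size_t numElements, cudaStream_t stream)
{
    tensorrt_llm::kernels::lamportInitialize(deviceBuffer, numElements, nvinfer1::DataType::kHALF, stream);
}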
|
||||
|
||||
@ -31,6 +31,15 @@ constexpr size_t MAX_ALL_REDUCE_BLOCKS = 24;
constexpr size_t MAX_RANKS_PER_NODE = 8;
constexpr size_t DEFAULT_BLOCK_SIZE = 512;

namespace reduce_fusion::details
{
static constexpr int kBytesPerAccess = 16;
static constexpr int kWarpSize = 32;
static constexpr int kMaxCtaSize = 1024;
static constexpr int kClusterMaxSize = 8;
static constexpr int kLamportTokenNumThreshold = 16;
}; // namespace reduce_fusion::details
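The constants above fix the vectorization width of the fused kernels. A standalone mirror of the packing arithmetic (illustrative, not part of the commit):

#include <cuda_fp16.h>

constexpr int kBytesPerAccessMirror = 16; // mirrors reduce_fusion::details::kBytesPerAccess

template <typename T>
constexpr int elementsPerAccess()
{
    return kBytesPerAccessMirror / static_cast<int>(sizeof(T));
}

static_assert(elementsPerAccess<float>() == 4, "one 16B access packs 4 floats");
static_assert(elementsPerAccess<__half>() == 8, "one 16B access packs 8 halves");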
|
||||
|
||||
// Warning: python definition is in tensorrt_llm/functional.py
|
||||
// they must be kept in sync
|
||||
enum class AllReduceStrategyType : int8_t
|
||||
@ -73,6 +82,7 @@ struct AllReduceFusionParams
|
||||
float eps;
|
||||
// new residual
|
||||
void* intermediate_buffer;
|
||||
void* lamport_peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE * 3];
|
||||
};
|
||||
|
||||
struct AllReduceParams
|
||||
@ -81,7 +91,8 @@ struct AllReduceParams
|
||||
size_t elts_per_rank;
|
||||
size_t elts_per_block;
|
||||
size_t rank_offset;
|
||||
size_t ranks_per_node, local_rank;
|
||||
size_t ranks_per_node;
|
||||
size_t local_rank;
|
||||
uint32_t barrier_flag;
|
||||
uint32_t* peer_barrier_ptrs_in[MAX_RANKS_PER_NODE];
|
||||
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
|
||||
@ -91,7 +102,8 @@ struct AllReduceParams
|
||||
|
||||
AllReduceFusionParams fusion_params;
|
||||
|
||||
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank);
|
||||
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
|
||||
int token_num, AllReduceFusionOp op);
|
||||
};
|
||||
|
||||
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);
|
||||
@ -101,4 +113,6 @@ void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataTy
|
||||
|
||||
void residualRmsNorm(kernels::AllReduceParams& params, nvinfer1::DataType dataType, cudaStream_t stream);
|
||||
|
||||
void lamportInitialize(void* buffer, size_t size, nvinfer1::DataType dataType, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
88c30973b9b3452baa3f063d34d08169 libtensorrt_llm_nvrtc_wrapper.so
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,2 +1,2 @@
|
||||
95e9f87610383348e444d2d0b8396f2d libtensorrt_llm_nvrtc_wrapper.so
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1471e322bb44cd65b98ee30e0befa32ae4c86e828f0b4fd4f02d4af4e710d08f
|
||||
oid sha256:db512d533ab4e4a4abd0047a65d891dfd6e1522f2d34c90f29296c3239fd3cc1
|
||||
size 1128448
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
b7e624ba775e9f5090ef4b67bcdbd7a2 tensorrt_llm_nvrtc_wrapper.lib
|
||||
f9b1cc37a27dd0574bb41a2763a97be7 tensorrt_llm_nvrtc_wrapper.dll
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
d89a0a140d2d427af13c3794a4b21e2c tensorrt_llm_nvrtc_wrapper.dll
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -121,6 +121,17 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti

namespace runtime::kernels
{
//! \brief Inserts the running beams into the finished beams stored in the CBA buffers (beams whose most likely
//! continuation is the end token are stored separately, together with another candidate next token), then sorts the
//! beams according to their cumulative log probs. Note: the kernels in gatherTree modify the buffers in place. When
//! streaming, we use tmp buffers since beam search kernels expect ungathered data.
//!
//! \param decodingOutput contains a slice of the output buffers to gather. Also contains the
//! DecodingOutput::BeamHypotheses object with the finished beams.
//! \param decodingInput used for endIds and input lengths.
//! \param manager the usual buffer manager.
//! \param samplingConfig the usual samplingConfig.

void gatherTree(DecodingOutput const& decodingOutput, DecodingInput const& decodingInput, BufferManager const& manager,
    SamplingConfig const& samplingConfig);
} // namespace runtime::kernels
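A hedged sketch of a call site for the declaration above; the wrapper name and the assumption that the runtime headers are already included are ours.

// Gathers and sorts finished + running beams in place once decoding completes.
void finalizeBeamSearch(tensorrt_llm::runtime::DecodingOutput const& decodingOutput,
    tensorrt_llm::runtime::DecodingInput const& decodingInput,
    tensorrt_llm::runtime::BufferManager const& manager,
    tensorrt_llm::runtime::SamplingConfig const& samplingConfig)
{
    tensorrt_llm::runtime::kernels::gatherTree(decodingOutput, decodingInput, manager, samplingConfig);
}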
|
||||
|
||||
@ -228,7 +228,7 @@ __global__ __launch_bounds__(THREADS_PER_BLOCK) void computeSeqAndPaddingOffsets
|
||||
}
|
||||
}
|
||||
|
||||
// Perpare values for fmha.
|
||||
// Prepare values for fmha.
|
||||
if (threadIdx.x == 0 && blockIdx.x == 0)
|
||||
{
|
||||
// Reset fmha tile counter to 0 before launching fmha kernels.
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9117f7cf5eef0ed452c0d0bc79242b84def103e7038c9d3df6e366690801ca92
|
||||
oid sha256:0814af36fed752bbe70d953cefbb78dd306c42f3d9f6848b7043a865e48f9662
|
||||
size 25364090
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2b04913f9e9029a5ce5a222d5cc7492ff53323a548079d2fb32d5b2aeb0c2268
|
||||
oid sha256:ee46f2d1c9162f4302a1031f778fcb7c7110c84110427f97af6532ed9bd342fd
|
||||
size 25768990
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
d54fb93f256601f4c4ad7f1c8e6e9919 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
71028d801074f11138e890391e48591d libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
90740ead1def66f350e14c133278463d libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
b0104227ffd1ce19fc1fdb45e349df36 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:d8c685f8ea2f84838dfdbf448eab41c76fe88fe29db0d4a511d6d6d241ad1832
|
||||
oid sha256:4d9ba0f8b95cf64227cb0b17654fb7c9bc1741fe003889658b305750b388a4dc
|
||||
size 44173632
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:b9d75392ba3b59853c43072b4f9949b32cb6724813a39048e4585e9a8fb3e136
|
||||
oid sha256:4f848d5beebbd69792047a96b16f7145f8e1e3e311d2a19789ce639ad8149b0e
|
||||
size 43561206
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
4fc3e1fb0db6a121f88a9141605d9285 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
253731af750407020dbe6f2fbe50fa2b libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
2aaf05cb84f52b024e89d4fa634d6900 libtensorrt_llm_internal_cutlass_kernels_static.a
|
||||
f17ce186e9105c594e39d252777ce4c7 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:62af58f5e09d1cf5e347b02ef3bd3a186469162fc9645d038fb2cba23b597722
|
||||
size 88140804
|
||||
oid sha256:c429687e335c75f08186bcd8f629b50467cb0f2e484d755834c5b1cdbb9ecaf3
|
||||
size 88140796
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
eb7fc4a105eb6e6f52ba865f2b055233 tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
7f370deb0090d885d7518c2b146399ba3933c004 commit
|
||||
4f663be2b768088805ccec6dc33545fc tensorrt_llm_internal_cutlass_kernels_static.lib
|
||||
4dbf696ae9b74a26829d120b67ab8443d70c8e58 commit
|
||||
@ -1458,7 +1458,7 @@ std::vector<size_t> CutlassMoeFCRunner<T, WeightType, OutputType, ScaleBiasType,
|
||||
size_t const softmax_out_size = num_softmax_outs * sizeof(float);
|
||||
size_t const permuted_scales_size = mayHaveFinalizeFused() ? num_moe_inputs * sizeof(float) : 0;
|
||||
size_t const glu_inter_size = glu_inter_elems * gemm_output_dtype; // May be an intermediate type for quantization
|
||||
size_t const fc1_result_size = interbuf_elems * sizeof(T); // Acitvation quantizes so back to sizeof(T)
|
||||
size_t const fc1_result_size = interbuf_elems * sizeof(T); // Activation quantizes so back to sizeof(T)
|
||||
size_t const sorter_size = CubKeyValueSorter::getWorkspaceSize(num_rows, num_experts);
|
||||
size_t const fc2_result_size = permuted_elems * gemm_output_dtype; // May be an intermediate type for quantization
|
||||
|
||||
|
||||
@ -1427,7 +1427,7 @@ void invokeAirTopPSamplingWithDeterministicPara(TopPSamplingKernelParams<T> cons
|
||||
kernel = airTopPSampling<T, IdxT, AccT, HisT, BitsPerPass, SAMPLING_BLOCK_SIZE, true, isDeterministic>;
|
||||
}
|
||||
|
||||
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, params.outputIds,
|
||||
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, params.outputIdsPtrs,
|
||||
params.sequenceLength, params.finishedInput, params.finishedOutput, params.cumLogProbs,
|
||||
params.outputLogProbs, params.endIds, params.maxBatchSize, params.skipDecode, pass, buf1, idxBuf1, buf2,
|
||||
idxBuf2, params.batchSlots);
|
||||
|
||||
@ -196,11 +196,11 @@ __device__ void epilogue(SizeType32 batchId, SizeType32 currentStep, SizeType32
|
||||
}
|
||||
|
||||
template <typename T, int blockSize>
|
||||
__global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenIdType** ids, SizeType32* sequenceLength,
|
||||
FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize, curandState_t* curandState,
|
||||
float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize, bool const* skipDecode,
|
||||
SizeType32 const* batchSlots)
|
||||
__global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenIdType* ids, TokenIdType** idsPtrs,
|
||||
SizeType32* sequenceLength, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, SizeType32 const* beginOffsetBuf, SizeType32 const* offsetBuf, SizeType32 vocabSize,
|
||||
curandState_t* curandState, float const* topPs, TokenIdType const* endIds, SizeType32 maxBatchSize,
|
||||
bool const* skipDecode, SizeType32 const* batchSlots, bool returnAllTopP, SizeType32 maxSeqLen)
|
||||
{
|
||||
/**
|
||||
* Each block processes one request row sorted in descending order by probabilities.
|
||||
@ -235,14 +235,16 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
}
|
||||
|
||||
auto const probThreshold = topPs[batchSlot];
|
||||
auto const currentStep = sequenceLength[batchSlot];
|
||||
auto const currentStep = sequenceLength == nullptr ? 0 : sequenceLength[batchSlot];
|
||||
auto* outputIdsRequestPtr = idsPtrs == nullptr ? ids + batchSlot * maxSeqLen : idsPtrs[batchSlot];
|
||||
|
||||
// With P in (0.0; 1.0] we draw a random number P' in range (0.0; P]
|
||||
// We will sum all probs moving from the largest probability to the smallest and
|
||||
// will choose the token which probability makes cumulative probability sum to exceed P'
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
randNumS = curand_uniform(curandState + blockIdx.x) * probThreshold;
|
||||
// if we want to return all top p indices, we should not do random sampling for probThreshold
|
||||
randNumS = returnAllTopP ? probThreshold : curand_uniform(curandState + blockIdx.x) * probThreshold;
|
||||
}
|
||||
|
||||
// if beginOffsetBuf and offsetBuf of sorting have same value,
|
||||
@ -253,8 +255,15 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
if (tid == 0)
|
||||
{
|
||||
auto offset = batchId * vocabSize;
|
||||
epilogue(batchSlot, currentStep, offset, ids, sortedIdVals, sortedProbs, cumLogProbs, outputLogProbs,
|
||||
endIds, sequenceLength, finishedOutput, maxBatchSize);
|
||||
if (returnAllTopP)
|
||||
{
|
||||
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
|
||||
}
|
||||
else
|
||||
{
|
||||
epilogue(batchSlot, currentStep, offset, idsPtrs, sortedIdVals, sortedProbs, cumLogProbs,
|
||||
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -267,7 +276,7 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
__syncthreads();
|
||||
|
||||
auto offset = batchId * vocabSize;
|
||||
ids[batchSlot][currentStep] = sortedIdVals[offset];
|
||||
outputIdsRequestPtr[currentStep] = sortedIdVals[offset];
|
||||
auto end = ((vocabSize + blockSize - 1) / blockSize) * blockSize;
|
||||
SizeType32 selectedTokenId = 0;
|
||||
// Cumulative sum
|
||||
@ -285,11 +294,31 @@ __global__ void topPSsampling(T* sortedProbs, TokenIdType* sortedIdVals, TokenId
|
||||
}
|
||||
}
|
||||
|
||||
// select first thread exceeded the prob threshold or the last thread in case of P=1.0f
|
||||
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
|
||||
if (returnAllTopP)
|
||||
{
|
||||
epilogue(batchSlot, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedProbs, cumLogProbs,
|
||||
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
|
||||
__shared__ SizeType32 sharedSelectedTokenId;
|
||||
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
|
||||
{
|
||||
sharedSelectedTokenId = selectedTokenId;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int vi = tid; vi <= sharedSelectedTokenId; vi += blockSize)
|
||||
{
|
||||
outputIdsRequestPtr[vi] = sortedIdVals[offset + vi];
|
||||
}
|
||||
if (tid == 0 && sharedSelectedTokenId != end - 1)
|
||||
{
|
||||
outputIdsRequestPtr[sharedSelectedTokenId + 1] = -1; // a boundary to record the end of all selected top Ps.
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// select first thread exceeded the prob threshold or the last thread in case of P=1.0f
|
||||
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
|
||||
{
|
||||
epilogue(batchSlot, currentStep, offset + selectedTokenId, idsPtrs, sortedIdVals, sortedProbs, cumLogProbs,
|
||||
outputLogProbs, endIds, sequenceLength, finishedOutput, maxBatchSize);
|
||||
}
|
||||
}
|
||||
}
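For readers of the comments in the kernel above, here is a simplified single-threaded reference of the selection rule it parallelizes; it assumes probabilities already sorted in descending order and is illustrative only.

#include <cstddef>
#include <vector>

// Draw P' = uniform01 * topP in (0, P], then pick the first token whose running
// probability sum reaches P' (the kernel does the same with a block-wide scan).
std::size_t selectTopPToken(std::vector<float> const& sortedProbs, float topP, float uniform01)
{
    float const threshold = uniform01 * topP;
    float cumSum = 0.f;
    for (std::size_t i = 0; i < sortedProbs.size(); ++i)
    {
        cumSum += sortedProbs[i];
        if (cumSum >= threshold)
        {
            return i;
        }
    }
    return sortedProbs.size() - 1; // P == 1.0 plus rounding: fall back to the last token
}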
|
||||
|
||||
@ -371,9 +400,10 @@ void invokeBatchTopPSampling(TopPSamplingKernelParams<T> const& params, cudaStre
|
||||
dim3 grid(params.batchSize);
|
||||
// Sample with Top P given sorted tokens
|
||||
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedProbs, sortedIdVals,
|
||||
params.outputIds, params.sequenceLength, params.finishedInput, params.finishedOutput, params.cumLogProbs,
|
||||
params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded, params.curandState, params.topPs,
|
||||
params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots);
|
||||
params.outputIds, params.outputIdsPtrs, params.sequenceLength, params.finishedInput, params.finishedOutput,
|
||||
params.cumLogProbs, params.outputLogProbs, beginOffsetBuf, offsetBuf + 1, params.vocabSizePadded,
|
||||
params.curandState, params.topPs, params.endIds, params.maxBatchSize, params.skipDecode, params.batchSlots,
|
||||
params.returnAllTopP, params.maxSeqLen);
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -28,8 +28,13 @@ struct TopPSamplingKernelParams
|
||||
//! input buffer [batchSize, vocabSizePadded], required. Probabilities of each token in the vocab.
|
||||
T const* probs{nullptr};
|
||||
|
||||
//! output buffer [maxBatchSize][maxSeqLen], required. Contains pointers to rows with output tokens per request.
|
||||
runtime::TokenIdType** outputIds{nullptr};
|
||||
//! output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request.
|
||||
//! If nullptr, outputIds must be provided.
|
||||
runtime::TokenIdType** outputIdsPtrs{nullptr};
|
||||
|
||||
//! output buffer [maxBatchSize, maxSeqLen], optional. Tensor to store output tokens.
|
||||
//! Not used if outputIdsPtrs != nullptr
|
||||
runtime::TokenIdType* outputIds{nullptr};
|
||||
|
||||
//! pointer to the workspace. Has to be pre-allocated by caller.
|
||||
//! Function does not take ownership of the buffer.
|
||||
@ -73,6 +78,9 @@ struct TopPSamplingKernelParams
|
||||
runtime::SizeType32 batchSize{-1};
|
||||
runtime::SizeType32 maxBatchSize{-1};
|
||||
runtime::SizeType32 vocabSizePadded{-1};
|
||||
runtime::SizeType32 maxSeqLen{-1};
|
||||
|
||||
bool returnAllTopP{false};
|
||||
|
||||
void checkParams() const
|
||||
{
|
||||
@ -81,12 +89,17 @@ struct TopPSamplingKernelParams
|
||||
TLLM_CHECK(maxBatchSize >= batchSize);
|
||||
TLLM_CHECK(vocabSizePadded > 0);
|
||||
TLLM_CHECK(probs);
|
||||
TLLM_CHECK(outputIds);
|
||||
TLLM_CHECK(outputIds || outputIdsPtrs);
|
||||
TLLM_CHECK(workspace);
|
||||
TLLM_CHECK(sequenceLength);
|
||||
TLLM_CHECK((sequenceLength != nullptr) || returnAllTopP);
|
||||
TLLM_CHECK(curandState);
|
||||
TLLM_CHECK(topPs);
|
||||
|
||||
if (outputIds)
|
||||
{
|
||||
TLLM_CHECK(maxSeqLen > 0);
|
||||
}
|
||||
|
||||
TLLM_CHECK(((finishedOutput == nullptr) ^ (endIds == nullptr)) == 0);
|
||||
}
|
||||
};
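A hedged sketch (not from the diff) of wiring the new output contract in this struct: either outputIdsPtrs (per-request row pointers) or the flat outputIds buffer plus maxSeqLen must be provided, and with returnAllTopP the sequenceLength buffer may stay null. Every d-prefixed parameter below stands for a caller-owned device buffer of the documented shape.

template <typename T>
TopPSamplingKernelParams<T> makeReturnAllTopPParams(T const* dProbs, runtime::TokenIdType* dFlatOutputIds,
    void* dWorkspace, curandState_t* dCurandStates, float const* dTopPs, runtime::SizeType32 batchSize,
    runtime::SizeType32 maxBatchSize, runtime::SizeType32 vocabSizePadded, runtime::SizeType32 maxSeqLen)
{
    TopPSamplingKernelParams<T> params{};
    params.probs = dProbs;             // [batchSize, vocabSizePadded]
    params.outputIdsPtrs = nullptr;    // not using per-request row pointers here
    params.outputIds = dFlatOutputIds; // flat [maxBatchSize, maxSeqLen] tensor instead
    params.maxSeqLen = maxSeqLen;      // required whenever the flat buffer is used
    params.workspace = dWorkspace;
    params.curandState = dCurandStates;
    params.topPs = dTopPs;
    params.batchSize = batchSize;
    params.maxBatchSize = maxBatchSize;
    params.vocabSizePadded = vocabSizePadded;
    params.returnAllTopP = true;       // sequenceLength may then remain nullptr
    params.checkParams();              // asserts the combination above is consistent
    return params;
}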
|
||||
|
||||
@ -35,230 +35,281 @@ namespace tensorrt_llm::kernels::speculative_decoding
|
||||
{
|
||||
namespace
|
||||
{
|
||||
__global__ void acceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
|
||||
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
|
||||
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
|
||||
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 maxSeqLen, SizeType32 maxDraftTokens)
|
||||
|
||||
template <typename T>
|
||||
__global__ void maskTargetLogitsKernel(T* targetLogits, SizeType32 const* batchSlots, SizeType32 beamWidth,
|
||||
SizeType32 vocabSize, FinishedState const* finishedInput, SizeType32 maxBatchSize, bool const* batchUseDraftLogits,
|
||||
SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds, SizeType32* runtimeTopKDevicePtr, bool* maskBuffer)
|
||||
{
|
||||
for (auto batchIdx = static_cast<SizeType32>(threadIdx.x); batchIdx < batchSize; batchIdx += blockDim.x)
|
||||
/**
|
||||
* @brief Masking the selected token to -inf as was done in Huggingface TopK/TopP Logits Warper
|
||||
* https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/generation/logits_process.py#L533
|
||||
*/
|
||||
|
||||
auto const bid = blockIdx.x;
|
||||
auto const batchIdx = bid / beamWidth;
|
||||
auto const tid = static_cast<SizeType32>(threadIdx.x);
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
|
||||
constexpr bool IS_HALF = std::is_same<T, half>::value;
|
||||
T const MAX_T_VAL = (IS_HALF) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
auto targetLogitsBatch = targetLogits + batchIdx * vocabSize;
|
||||
auto& finishedState = finishedInput[batchSlot];
|
||||
|
||||
auto* outputIdsAfterSamplingPtr = outputIdsAfterSampling + batchSlot * vocabSize;
|
||||
auto const useDraftLogits = batchUseDraftLogits[batchSlot];
|
||||
|
||||
if (finishedState.isSkipDecoding())
|
||||
{
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlot];
|
||||
return;
|
||||
}
|
||||
|
||||
auto const contextLength = contextLengths[batchSlot];
|
||||
auto& sequenceLength = sequenceLengths[batchSlot];
|
||||
SizeType32 finishedDraftIdx = 0;
|
||||
for (auto ti = contextLength; ti < min(sequenceLength, contextLength + numDraftTokens);
|
||||
++ti, ++finishedDraftIdx)
|
||||
{
|
||||
auto const draftIdx = ti - contextLength;
|
||||
auto const targetTokenIdx = batchSlot * maxSeqLen + ti;
|
||||
auto const draftTokenIdx = batchSlot * maxDraftTokens + draftIdx;
|
||||
// Check if draft tokens are the same as target tokens
|
||||
bool const accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
|
||||
if (!accepted)
|
||||
{
|
||||
// Set sequence length to the numAcceptedTokens + 1
|
||||
sequenceLength = min(ti + 1, maxSeqLen);
|
||||
// FIXME(nkorobov): do we need to set endIds here?
|
||||
break;
|
||||
}
|
||||
__shared__ SizeType32 tokensToMask;
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
tokensToMask = runtimeTopKDevicePtr[batchSlot];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
if (tokensToMask == 0 && outputIdsAfterSamplingPtr[vIdx] == -1)
|
||||
{ // we need to find the -1 boundary from returnAllTopP outputIds if topK == 0
|
||||
tokensToMask = vIdx;
|
||||
}
|
||||
FinishedState finishState = finished[finishedDraftIdx * maxBatchSize + batchSlot];
|
||||
finishedFinal[batchSlot] = finishState;
|
||||
maskBuffer[vIdx] = false;
|
||||
}
|
||||
|
||||
if (finishedSum)
|
||||
__syncthreads();
|
||||
|
||||
if (!useDraftLogits && tid == 0)
|
||||
{
|
||||
targetOutputIds[batchSlot] = outputIdsAfterSamplingPtr[tokensToMask - 1];
|
||||
}
|
||||
|
||||
for (SizeType32 vIdx = tid; vIdx < tokensToMask; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
auto tokenToMask = outputIdsAfterSamplingPtr[vIdx];
|
||||
maskBuffer[tokenToMask] = true;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
if (!maskBuffer[vIdx])
|
||||
{
|
||||
finishedSum[batchSlot] = static_cast<int>(finishState.isFinished());
|
||||
targetLogitsBatch[vIdx] = -MAX_T_VAL;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void invokeAcceptDraftTokensByIds(TokenIdType const* draftIds, TokenIdType const* targetIds,
|
||||
SizeType32 const* contextLengths, SizeType32 const* numsDraftTokens, SizeType32* sequenceLengths,
|
||||
FinishedState const* finished, FinishedState* finishedFinal, SizeType32* finishedSum, SizeType32 const* batchSlots,
|
||||
SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth, SizeType32 maxSeqLen,
|
||||
SizeType32 maxDraftTokens, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
dim3 block(min(1024, batchSize));
|
||||
dim3 grid(1);
|
||||
acceptDraftTokensByIds<<<grid, block, 0, stream>>>(draftIds, targetIds, contextLengths, numsDraftTokens,
|
||||
sequenceLengths, finished, finishedFinal, finishedSum, batchSlots, batchSize, maxBatchSize, maxSeqLen,
|
||||
maxDraftTokens);
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
template <typename T>
|
||||
__global__ void acceptDraftTokensByLogitsKernel(T const* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
|
||||
FinishedState* finished, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 batchSize,
|
||||
SizeType32 maxBatchSize, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSize,
|
||||
bool randomThreshold, float constantThreshold)
|
||||
__global__ void acceptDraftTokensKernel(T const* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
|
||||
bool const* batchUseDraftLogits, TokenIdType const* draftIds, FinishedState const* finishedInput,
|
||||
FinishedState* finishedOutput, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 maxDraftTokens,
|
||||
SizeType32 beamWidth, SizeType32 vocabSize, bool randomThreshold, float constantThreshold, SizeType32 step,
|
||||
bool* batchIsAccepted, SizeType32* targetOutputIds)
|
||||
{
|
||||
auto const bid = blockIdx.x;
|
||||
auto const draftTokenIdx = blockIdx.y;
|
||||
auto const draftTokenIdx = step;
|
||||
auto const batchIdx = bid / beamWidth;
|
||||
auto const beamIdx = bid % beamWidth;
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
|
||||
auto const tid = static_cast<SizeType32>(threadIdx.x);
|
||||
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
|
||||
auto const useDraftLogits = batchUseDraftLogits[batchSlotBeamWidth];
|
||||
|
||||
if (draftTokenIdx >= numDraftTokens)
|
||||
if (draftTokenIdx > numDraftTokens || finishedInput[batchSlot].isSkipDecoding())
|
||||
{
|
||||
if (tid == 0)
|
||||
{
|
||||
batchIsAccepted[batchSlot] = true;
|
||||
finishedOutput[batchSlot].setSkipDecoding();
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
auto const logitsOffset = (batchSlot * maxDraftTokens + draftTokenIdx) * beamWidth * vocabSize;
|
||||
auto const draftProbsBatch = draftProbs + logitsOffset;
|
||||
auto const targetProbsBatch = targetProbs + logitsOffset;
|
||||
auto const vocabSizePadded = static_cast<SizeType32>((vocabSize + blockDim.x - 1) / blockDim.x) * blockDim.x;
|
||||
auto const targetProbsBatch = targetProbs + (batchIdx * beamWidth * vocabSize);
|
||||
|
||||
struct Candidate candidate;
|
||||
__shared__ float threshold;
|
||||
if (threadIdx.x == 0)
|
||||
__shared__ bool isAccepted;
|
||||
__shared__ T sSumVal;
|
||||
if (tid == 0)
|
||||
{
|
||||
threshold = randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (auto vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSizePadded;
|
||||
vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
bool const pred = vIdx < vocabSize;
|
||||
auto const targetProb = pred ? static_cast<float>(targetProbsBatch[vIdx]) : 1.f;
|
||||
auto const draftProb = pred ? static_cast<float>(draftProbsBatch[vIdx]) : 0.f;
|
||||
|
||||
if (draftProb > candidate.maxProb)
|
||||
if (draftTokenIdx < numDraftTokens)
|
||||
{
|
||||
candidate.maxProb = draftProb;
|
||||
candidate.rateQP = pred ? targetProb / draftProb : 0.f;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
typedef cub::BlockReduce<Candidate, 1024> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage reduce_buffer;
|
||||
Candidate candidate_global = BlockReduce(reduce_buffer).Reduce(candidate, reduce_op);
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
finished[draftTokenIdx * maxBatchSize * beamWidth + batchSlotBeamWidth]
|
||||
= candidate_global.rateQP < threshold ? FinishedState::skipDecoding() : FinishedState::empty();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void correctAcceptedStatesAndLogits(T const* draftProbs, T* targetProbs, T** targetLogits,
|
||||
SizeType32 const* numsDraftTokens, FinishedState* finished, SizeType32 const* batchSlots, SizeType32 batchSize,
|
||||
SizeType32 maxBatchSize, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSize)
|
||||
{
|
||||
auto const bid = blockIdx.x;
|
||||
auto const batchIdx = bid / beamWidth;
|
||||
auto const beamIdx = bid % beamWidth;
|
||||
auto const batchSlot = batchSlots[batchIdx];
|
||||
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
|
||||
|
||||
__shared__ SizeType32 numAcceptedTokens;
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
numAcceptedTokens = numDraftTokens;
|
||||
bool cummulativeSkipDecoding = false;
|
||||
for (SizeType32 ti = 0; ti < numDraftTokens + 1; ++ti)
|
||||
{
|
||||
auto& finishedState = finished[ti * maxBatchSize * beamWidth + batchSlotBeamWidth];
|
||||
bool localSkipDecoding = finishedState.isSkipDecoding();
|
||||
if (cummulativeSkipDecoding == false && localSkipDecoding == true)
|
||||
auto const draftOutputTokenId = draftIds[batchSlot * maxDraftTokens + draftTokenIdx];
|
||||
if (useDraftLogits)
|
||||
{
|
||||
numAcceptedTokens = ti;
|
||||
float threshold = randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold;
|
||||
auto const targetProb = static_cast<float>(targetProbsBatch[draftOutputTokenId]);
|
||||
auto const draftProb = static_cast<float>(draftProbsBatch[draftOutputTokenId]);
|
||||
auto rateQP = targetProb / draftProb;
|
||||
if (rateQP < threshold)
|
||||
{
|
||||
isAccepted = false;
|
||||
finishedOutput[batchSlot].setSkipDecoding();
|
||||
}
|
||||
else
|
||||
{
|
||||
isAccepted = true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Check if draft tokens are the same as target tokens
|
||||
isAccepted = targetOutputIds[batchSlot] == draftOutputTokenId;
|
||||
if (!isAccepted)
|
||||
{
|
||||
finishedOutput[batchSlot].setSkipDecoding();
|
||||
}
|
||||
}
|
||||
|
||||
finishedState = cummulativeSkipDecoding ? FinishedState::skipDecoding() : FinishedState::empty();
|
||||
cummulativeSkipDecoding |= localSkipDecoding;
|
||||
}
|
||||
else
|
||||
{
|
||||
isAccepted = false;
|
||||
finishedOutput[batchSlot].setSkipDecoding();
|
||||
}
|
||||
batchIsAccepted[batchSlot] = isAccepted;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (numAcceptedTokens < numDraftTokens)
|
||||
if (!isAccepted)
|
||||
{
|
||||
auto const logitsIdx = (batchSlot * maxDraftTokens + numAcceptedTokens) * beamWidth * vocabSize;
|
||||
auto const draftProbBatch = draftProbs + logitsIdx;
|
||||
auto targetProbBatch = targetProbs + logitsIdx;
|
||||
auto targetLogitsBatch = targetLogits[bid] + numAcceptedTokens * beamWidth * vocabSize;
|
||||
|
||||
float sumProbs = 0.f;
|
||||
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
|
||||
vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
T const zeroVal = static_cast<T>(0.0f);
|
||||
T sumVal = zeroVal;
|
||||
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
auto const correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
|
||||
sumProbs += correctedProb;
|
||||
targetProbBatch[vIdx] = correctedProb;
|
||||
targetProbsBatch[vIdx]
|
||||
-= (draftTokenIdx < numDraftTokens && useDraftLogits) ? draftProbsBatch[vIdx] : zeroVal;
|
||||
targetProbsBatch[vIdx] = targetProbsBatch[vIdx] >= zeroVal ? targetProbsBatch[vIdx] : zeroVal;
|
||||
sumVal += targetProbsBatch[vIdx];
|
||||
}
|
||||
|
||||
__shared__ float sumProbsShared;
|
||||
sumProbs = blockReduceSum<float>((float) sumProbs);
|
||||
if (threadIdx.x == 0)
|
||||
sumVal = blockReduceSum<T>(sumVal);
|
||||
if (tid == 0)
|
||||
{
|
||||
sumProbsShared = max(sumProbs, 1e-6f);
|
||||
sSumVal = sumVal;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (SizeType32 vIdx = static_cast<SizeType32>(threadIdx.x); vIdx < vocabSize;
|
||||
vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
for (SizeType32 vIdx = tid; vIdx < vocabSize; vIdx += static_cast<SizeType32>(blockDim.x))
|
||||
{
|
||||
auto const correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
|
||||
targetLogitsBatch[vIdx] = __logf(correctedNormProb / (1.f - correctedNormProb));
|
||||
targetProbsBatch[vIdx] /= sSumVal;
|
||||
}
|
||||
}
|
||||
}
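A host-side reference of the per-token acceptance rule and the probability correction implemented by acceptDraftTokensKernel above, for a single beam; the helper and its names are illustrative only.

#include <algorithm>
#include <vector>

bool acceptAndCorrect(std::vector<float>& targetProbs, std::vector<float> const& draftProbs,
    int draftTokenId, float threshold)
{
    // Accept when p_target(x) / q_draft(x) clears the (random or constant) threshold.
    bool const accepted = targetProbs[draftTokenId] / draftProbs[draftTokenId] >= threshold;
    if (!accepted)
    {
        // On rejection, resample from norm(max(p - q, 0)), mirroring the kernel.
        float sum = 0.f;
        for (std::size_t v = 0; v < targetProbs.size(); ++v)
        {
            targetProbs[v] = std::max(targetProbs[v] - draftProbs[v], 0.f);
            sum += targetProbs[v];
        }
        for (auto& p : targetProbs)
        {
            p /= sum;
        }
    }
    return accepted;
}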
|
||||
|
||||
__global__ void forwardAcceptedTokensKernel(SizeType32 batchSize, SizeType32 const* batchSlots, bool* batchIsAccepted,
|
||||
SizeType32* sequenceLengths, TokenIdType const* draftIds, TokenIdType** idsPtrs, SizeType32 step,
|
||||
SizeType32 maxDraftTokens, TokenIdType const* endIds, FinishedState* finishedOutput)
|
||||
{
|
||||
auto index = static_cast<SizeType32>(blockIdx.x * blockDim.x + threadIdx.x);
|
||||
for (SizeType32 bi = index; bi < batchSize; bi += static_cast<SizeType32>(gridDim.x * blockDim.x))
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
if (batchIsAccepted[batchSlot] && !finishedOutput[batchSlot].isSkipDecoding())
|
||||
{
|
||||
auto const curSeqLen = sequenceLengths[batchSlot];
|
||||
auto const draftTokenIdx = step;
|
||||
auto const draftOutputTokenId = draftIds[batchSlot * maxDraftTokens + draftTokenIdx];
|
||||
auto* outputIdsRequestPtr = idsPtrs[batchSlot];
|
||||
auto const outIdx = curSeqLen;
|
||||
outputIdsRequestPtr[outIdx] = draftOutputTokenId;
|
||||
if (outputIdsRequestPtr[outIdx] == endIds[batchSlot])
|
||||
{
|
||||
finishedOutput[batchSlot].setFinishedEOS();
|
||||
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be
|
||||
// outputted
|
||||
}
|
||||
else
|
||||
{
|
||||
// We don't need to set output finished state as it is assumed to be in non finished state
|
||||
sequenceLengths[batchSlot] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
|
||||
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
|
||||
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
|
||||
float constantThreshold, cudaStream_t stream)
|
||||
void invokeMaskTargetLogits(SizeType32 batchSize, T* targetLogits, SizeType32 const* batchSlots, SizeType32 beamWidth,
|
||||
SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
|
||||
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
|
||||
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
{
|
||||
invokeAddBiasSoftMax(draftLogits, static_cast<T**>(nullptr), draftProbs, static_cast<T*>(nullptr), nullptr,
|
||||
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
|
||||
/* skip softmax */ false,
|
||||
/* batchSlotLogits */ true, stream);
|
||||
invokeAddBiasSoftMax(static_cast<T*>(nullptr), targetLogits, targetProbs, static_cast<T*>(nullptr), nullptr,
|
||||
finished, batchSlots, batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded,
|
||||
/* skip softmax */ false,
|
||||
/* batchSlotLogits */ true, stream);
|
||||
}
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(batchSize * beamWidth, maxDraftTokens);
|
||||
acceptDraftTokensByLogitsKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens, finished,
|
||||
curandState, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded,
|
||||
randomThreshold, constantThreshold);
|
||||
}
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(batchSize * beamWidth);
|
||||
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(draftProbs, targetProbs, targetLogits,
|
||||
numsDraftTokens, finished, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded);
|
||||
maskTargetLogitsKernel<<<grid, block, 0, stream>>>(targetLogits, batchSlots, beamWidth, vocabSizePadded,
|
||||
finishedInput, maxBatchSize, batchUseDraftLogits, outputIdsAfterSampling, targetOutputIds,
|
||||
runtimeTopKDevicePtr, maskBuffer);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template void acceptDraftTokensByLogits(float* draftLogits, float** targetLogits, float* draftProbs, float* targetProbs,
|
||||
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
|
||||
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
|
||||
float constantThreshold, cudaStream_t stream);
|
||||
template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, half* draftProbs, half* targetProbs,
|
||||
SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
|
||||
SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 maxBatchSize, SizeType32 beamWidth,
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, SizeType32 maxDraftTokens, bool randomThreshold,
|
||||
float constantThreshold, cudaStream_t stream);
|
||||
template <typename T>
|
||||
void invokeAcceptDraftTokens(SizeType32 batchSize, T* draftProbs, T* targetProbs, SizeType32 const* numsDraftTokens,
|
||||
bool const* batchUseDraftLogits, TokenIdType const* draftIds, FinishedState const* finishedInput,
|
||||
FinishedState* finishedOutput, curandState_t* curandState, SizeType32 const* batchSlots, SizeType32 maxDraftTokens,
|
||||
SizeType32 beamWidth, SizeType32 vocabSizePadded, bool randomThreshold, float constantThreshold, SizeType32 step,
|
||||
bool* batchIsAccepted, SizeType32* targetOutputIds, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
{
|
||||
dim3 block(1024);
|
||||
dim3 grid(batchSize * beamWidth);
|
||||
acceptDraftTokensKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens,
|
||||
batchUseDraftLogits, draftIds, finishedInput, finishedOutput, curandState, batchSlots, maxDraftTokens,
|
||||
beamWidth, vocabSizePadded, randomThreshold, constantThreshold, step, batchIsAccepted, targetOutputIds);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template void invokeMaskTargetLogits(SizeType32 batchSize, float* targetLogits, SizeType32 const* batchSlots,
|
||||
SizeType32 beamWidth, SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
|
||||
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
|
||||
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream);
|
||||
template void invokeMaskTargetLogits(SizeType32 batchSize, half* targetLogits, SizeType32 const* batchSlots,
|
||||
SizeType32 beamWidth, SizeType32 vocabSizePadded, FinishedState const* finishedInput, SizeType32 maxBatchSize,
|
||||
bool const* batchUseDraftLogits, SizeType32* outputIdsAfterSampling, SizeType32* targetOutputIds,
|
||||
SizeType32* runtimeTopKDevicePtr, bool* maskBuffer, cudaStream_t stream);
|
||||
|
||||
template void invokeAcceptDraftTokens(SizeType32 batchSize, float* draftProbs, float* targetProbs,
|
||||
SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, TokenIdType const* draftIds,
|
||||
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
|
||||
SizeType32 const* batchSlots, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSizePadded,
|
||||
bool randomThreshold, float constantThreshold, SizeType32 step, bool* batchIsAccepted, SizeType32* targetOutputIds,
|
||||
cudaStream_t stream);
|
||||
template void invokeAcceptDraftTokens(SizeType32 batchSize, half* draftProbs, half* targetProbs,
|
||||
SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, TokenIdType const* draftIds,
|
||||
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
|
||||
SizeType32 const* batchSlots, SizeType32 maxDraftTokens, SizeType32 beamWidth, SizeType32 vocabSizePadded,
|
||||
bool randomThreshold, float constantThreshold, SizeType32 step, bool* batchIsAccepted, SizeType32* targetOutputIds,
|
||||
cudaStream_t stream);
|
||||
|
||||
void invokeForwardAcceptedTokens(SizeType32 batchSize, SizeType32 const* batchSlots, bool* batchIsAccepted,
|
||||
SizeType32* outputSequenceLengths, TokenIdType const* draftIds, TokenIdType** idsPtrs, SizeType32 step,
|
||||
SizeType32 maxDraftTokens, TokenIdType const* endIds, FinishedState* finishedOutput, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
dim3 block(std::min(static_cast<uint32_t>(batchSize), 256u));
|
||||
dim3 grid(divUp(static_cast<uint32_t>(batchSize), block.x));
|
||||
forwardAcceptedTokensKernel<<<grid, block, 0, stream>>>(batchSize, batchSlots, batchIsAccepted,
|
||||
outputSequenceLengths, draftIds, idsPtrs, step, maxDraftTokens, endIds, finishedOutput);
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
} // namespace tensorrt_llm::kernels::speculative_decoding
|
||||
|
||||
@ -26,84 +26,77 @@
|
||||
namespace tensorrt_llm::kernels::speculative_decoding
|
||||
{
|
||||
|
||||
//! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens
|
||||
//! for speculative decoding. Target token is accepted if targetToken == draftToken.
|
||||
//! If number of accepted tokens N < maxDraftTokens, then function accepts N + 1 tokens of target model.
|
||||
//! sequenceLengths, finishedSum and finishedFinal are modified accordingly.
|
||||
//!
|
||||
//! \param draftIds input buffer [batchSize, maxDraftTokens].
|
||||
//! Indices of the draft tokens.
|
||||
//! \param targetIds input buffer [batchSize, maxSeqLen]. Indices of the tokens decoded by the target model
|
||||
//! \param contextLengths input buffer [batchSize]. Context lengths of the requests without draft tokens
|
||||
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
|
||||
//! \param sequenceLengths input/output buffer [batchSize] sequence lengths of the requests in batch
|
||||
//! Modified in-place according to the accepted/rejected tokens
|
||||
//! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration
|
||||
//! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens
|
||||
//! \param finishedSum output buffer [1] total number of requests in batch that finished the execution
|
||||
//! \param batchSlots input buffer [batchSize], address map from local index
|
||||
//! to global index [0, batchSize] -> [0, maxBatchSize]
|
||||
//! \param batchSize current batch size
|
||||
//! \param maxBatchSize maximum batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param maxSeqLen maximum sequence length
|
||||
//! \param maxDraftTokens maximum number of draft tokens
|
||||
//! \param stream stream
|
||||
void invokeAcceptDraftTokensByIds(runtime::TokenIdType const* draftIds, runtime::TokenIdType const* targetIds,
|
||||
runtime::SizeType32 const* contextLengths, runtime::SizeType32 const* numsDraftTokens,
|
||||
runtime::SizeType32* sequenceLengths, FinishedState const* finished, FinishedState* finishedFinal,
|
||||
runtime::SizeType32* finishedSum, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 maxBatchSize, runtime::SizeType32 beamWidth, runtime::SizeType32 maxSeqLen,
|
||||
runtime::SizeType32 maxDraftTokens, cudaStream_t stream);
|
||||
|
||||
//! \brief Performs probabilistic acceptance of draft tokens based on their probability distributions.
|
||||
//! Corrects targetLogits for the next to the last accepted token
|
||||
//! \brief Accepts or rejects draft tokens based on their probability distributions or the equality of draft and target
|
||||
//! tokens. Corrects targetLogits for the last accepted token
|
||||
//! according to https://openreview.net/pdf?id=C9NEblP8vS
|
||||
//!
|
||||
//! \param draftLogits input/output buffer [draftTokens, batchSize, beamWidth, vocabSize].
|
||||
//! Initially contains token logits of the draft model.
|
||||
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
|
||||
//! Vector of pointers to the logits.
|
||||
//! Initially contains token logits of the target model.
|
||||
//! It is modified in-place for next to the last accepted token such as
|
||||
//! P'(x) = norm(max(0, P_{n+1}(x) - Q_{n+1}(x))), where N < maxDraftTokens is number of accepted tokens.
|
||||
//! \param batchSize current batch size
|
||||
//! \param draftProbs output buffer [maxDraftTokens, batchSize, beamWidth, vocabSize].
|
||||
//! Workspace buffer for token probabilities of the draft model.
|
||||
//! \param targetProbs output buffer [maxDraftTokens+1, batchSize, beamWidth, vocabSize].
|
||||
//! Workspace buffer for token probabilities of the target model.
|
||||
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
|
||||
//! \param finished output buffer [draftTokens, batchSize, beamWidth].
|
||||
//! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted
|
||||
//! \param curandState input buffer [batchSize]. Curand states properly
|
||||
//! initialized using invokeCurandInitialize per request.
|
||||
//! \param batchSlots input buffer [batchSize], address map from local index
|
||||
//! to global index [0, batchSize] -> [0, maxBatchSize]
|
||||
//! \param batchSize current batch size
|
||||
//! \param maxBatchSize maximum batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param vocabSize unpadded vocab size
|
||||
//! \param vocabSizePadded padded vocab size
|
||||
//! \param batchUseDraftLogits input buffer [batchSize]. Acceptance logic using draft logits or not, per request
|
||||
//! \param draftIds input buffer [batchSize, draftTokens]. Pointer to draft token ids.
|
||||
//! \param finishedInput input buffer [batchSize, beamWidth].
|
||||
//! \param finishedOutput output buffer [batchSize, beamWidth]. At each step sets SKIP_DECODING if token is not
|
||||
//! accepted.
|
||||
//! \param curandState input buffer [batchSize]. Curand states properly initialized using invokeCurandInitialize
|
||||
//! per request.
|
||||
//! \param batchSlots input buffer [batchSize], address map from local index to global index [0, batchSize] ->
|
||||
//! [0, maxBatchSize].
|
||||
//! \param maxDraftTokens maximum number of draft tokens
|
||||
//! \param beamWidth beam width (only beamWidth == 1 supported)
|
||||
//! \param vocabSizePadded padded vocab size
|
||||
//! \param randomThreshold True if use uniformly sampled threshold for token acceptance
|
||||
//! \param constantThreshold threshold used to accept tokens if randomThreshold is false
|
||||
//! \param step The current step of decoding (draft token id index)
|
||||
//! \param batchIsAccepted output buffer [batchSize]. Stores acceptance result for multinomial sampling later or
|
||||
//! forwarding next step.
|
||||
//! \param targetOutputIds input/output buffer [batchSize]. Stores target sampling output ids for acceptById
|
||||
//! logics.
|
||||
//! \param stream stream
|
||||
template <typename T>
|
||||
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
|
||||
runtime::SizeType32 const* numsDraftTokens, FinishedState* finished, curandState_t* curandState,
|
||||
runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize, runtime::SizeType32 maxBatchSize,
|
||||
runtime::SizeType32 beamWidth, runtime::SizeType32 vocabSize, runtime::SizeType32 vocabSizePadded,
|
||||
runtime::SizeType32 maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
|
||||
void invokeAcceptDraftTokens(runtime::SizeType32 batchSize, T* draftProbs, T* targetProbs,
|
||||
runtime::SizeType32 const* numsDraftTokens, bool const* batchUseDraftLogits, runtime::TokenIdType const* draftIds,
|
||||
FinishedState const* finishedInput, FinishedState* finishedOutput, curandState_t* curandState,
|
||||
runtime::SizeType32 const* batchSlots, runtime::SizeType32 maxDraftTokens, runtime::SizeType32 beamWidth,
|
||||
runtime::SizeType32 vocabSizePadded, bool randomThreshold, float constantThreshold, runtime::SizeType32 step,
|
||||
bool* batchIsAccepted, runtime::SizeType32* targetOutputIds, cudaStream_t stream);
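Restating the acceptance logic documented above in the cited paper's notation (our paraphrase, not text from this commit): drawing the threshold uniformly from (0, 1) and comparing P/Q against it is equivalent to accepting with probability min(1, P/Q), and rejected positions are resampled from the corrected distribution.

P\big(\text{accept } x\big) = \min\!\left(1,\; \frac{P_{n+1}(x)}{Q_{n+1}(x)}\right),
\qquad
P'_{n+1}(x) = \frac{\max\!\big(0,\; P_{n+1}(x) - Q_{n+1}(x)\big)}{\sum_{x'} \max\!\big(0,\; P_{n+1}(x') - Q_{n+1}(x')\big)}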
|
||||
|
||||
struct Candidate // Holds the probability maximum and the target / draft rate, used in `acceptDraftTokensByLogits`
|
||||
{
|
||||
float maxProb{0.f};
|
||||
float rateQP{0.f};
|
||||
};
|
||||
//! \brief Mask the target logits with -inf for unselected topK/topP token ids.
|
||||
//! according to
|
||||
//! https://github.com/huggingface/transformers/blob/2e24ee4dfa39cc0bc264b89edbccc373c8337086/src/transformers/generation/utils.py#L4064
|
||||
//!
|
||||
//! \param batchSize current batch size
|
||||
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
|
||||
//! Vector of pointers to the logits. (beamWidth == 1)
|
||||
//! Initially contains token logits of the target model.
|
||||
//! \param batchSlots input buffer [batchSize], address map from local index to global index [0, batchSize] ->
|
||||
//! [0, maxBatchSize].
|
||||
//! \param beamWidth beam width (only beamWidth == 1 supported)
|
||||
//! \param vocabSizePadded padded vocab size
|
||||
//! \param finishedInput input buffer [batchSize, beamWidth].
|
||||
//! \param maxBatchSize maximum batch size
|
||||
//! \param batchUseDraftLogits input buffer [batchSize]. Acceptance logic using draft logits or not, per request
|
||||
//! \param outputIdsAfterSampling input buffer [batchSize, vocabSize]. Stores all selected IDs from sampling for
|
||||
//! masking.
|
||||
//! \param targetOutputIds input/output buffer [batchSize]. Stores target sampling output ids for acceptById
|
||||
//! logics.
|
||||
//! \param numsDraftTokens input buffer [batchSize]. Number of draft tokens per request
|
||||
//! \param runtimeTopKDevicePtr input buffer [batchSize] the topks in sampling step, for porting topK ids out.
|
||||
//! \param maskBuffer input buffer [batchSize, vocabSize] for masking calculation (index value to position).
|
||||
//! \param stream stream
|
||||
template <typename T>
|
||||
void invokeMaskTargetLogits(runtime::SizeType32 batchSize, T* targetLogits, runtime::SizeType32 const* batchSlots,
|
||||
runtime::SizeType32 beamWidth, runtime::SizeType32 vocabSizePadded, FinishedState const* finishedInput,
|
||||
runtime::SizeType32 maxBatchSize, bool const* batchUseDraftLogits, runtime::SizeType32* outputIdsAfterSampling,
|
||||
runtime::SizeType32* targetOutputIds, runtime::SizeType32* runtimeTopKDevicePtr, bool* maskBuffer,
|
||||
cudaStream_t stream);
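An illustrative host-side analogue of the masking declared above: every vocab id not produced by the top-K/top-P sampling step is pushed to a huge negative value so the target model cannot pick it (the kernel uses -MAX_T_VAL rather than a literal infinity). Buffer names are placeholders.

#include <cstddef>
#include <limits>
#include <vector>

void maskUnselected(std::vector<float>& targetLogits, std::vector<int> const& selectedIds)
{
    std::vector<bool> keep(targetLogits.size(), false);
    for (int id : selectedIds)
    {
        if (id < 0) { break; } // -1 marks the end of the selected ids (see the returnAllTopP path)
        keep[static_cast<std::size_t>(id)] = true;
    }
    for (std::size_t v = 0; v < targetLogits.size(); ++v)
    {
        if (!keep[v])
        {
            targetLogits[v] = -std::numeric_limits<float>::infinity();
        }
    }
}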
|
||||
|
||||
__device__ __forceinline__ Candidate reduce_op(Candidate const& a, Candidate const& b)
|
||||
{
|
||||
// Max-reduce operator of Candidate
|
||||
return (a.maxProb > b.maxProb) ? a : b;
|
||||
}
|
||||
void invokeForwardAcceptedTokens(runtime::SizeType32 batchSize, runtime::SizeType32 const* batchSlots,
|
||||
bool* batchIsAccepted, runtime::SizeType32* outputSequenceLengths, runtime::TokenIdType const* draftIds,
|
||||
runtime::TokenIdType** idsPtrs, runtime::SizeType32 step, runtime::SizeType32 maxDraftTokens,
|
||||
runtime::TokenIdType const* endIds, FinishedState* finishedOutput, cudaStream_t stream);
|
||||
|
||||
} // namespace tensorrt_llm::kernels::speculative_decoding
|
||||
|
||||
@ -73,7 +73,7 @@ void invokeLengthCriterion(FinishedState* finished, runtime::SizeType32* finishe
|
||||
runtime::SizeType32* numNewTokens, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
|
||||
runtime::SizeType32 beamWidth, cudaStream_t stream);
|
||||
|
||||
//! \brief Sets finished states based on the endIds and ajusts sequence length to length before the first EOS token.
|
||||
//! \brief Sets finished states based on the endIds and adjusts sequence length to length before the first EOS token.
|
||||
//! Does not support beamWidth > 1 for now.
|
||||
//!
|
||||
//! \param outputIds input buffer [maxBatchSize][beamWidth, maxSeqLen].
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "tensorrt_llm/layers/beamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/layers/explicitDraftTokensLayer.h"
|
||||
#include "tensorrt_llm/layers/externalDraftTokensLayer.h"
|
||||
#include "tensorrt_llm/layers/layerUtils.h"
|
||||
#include "tensorrt_llm/layers/lookaheadDecodingLayer.h"
|
||||
#include "tensorrt_llm/layers/medusaDecodingLayer.h"
|
||||
@ -96,6 +97,10 @@ DecodingLayer<T>::DecodingLayer(executor::DecodingMode const& mode, DecoderDomai
|
||||
{
|
||||
mDecodingLayer = std::make_unique<ExplicitDraftTokensLayer<T>>(decoderDomain, mBufferManager);
|
||||
}
|
||||
else if (mDecodingMode.isExternalDraftTokens())
|
||||
{
|
||||
mDecodingLayer = std::make_unique<ExternalDraftTokensLayer<T>>(mDecodingMode, decoderDomain, mBufferManager);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false,
|
||||
@ -144,6 +149,12 @@ void DecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorC
|
||||
beamWidth == 1, "Decoding mode is ExplicitDraftTokens, but beamWidth != 1 (%d != 1)", beamWidth);
|
||||
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams, workspace);
|
||||
}
|
||||
else if (mDecodingMode.isExternalDraftTokens())
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
beamWidth == 1, "Decoding mode is external draft tokens, but beamWidth != 1 (%d != 1)", beamWidth);
|
||||
mDecodingLayer->setup(batchSize, beamWidth, batchSlots, setupParams->decodingParams, workspace);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false,
|
||||
@ -249,6 +260,45 @@ std::tuple<std::shared_ptr<BaseDecodingOutputs>, std::shared_ptr<BaseDecodingInp
|
||||
preparedInputs = baseInputs;
|
||||
preparedOutputs = baseOutputs;
|
||||
}
|
||||
else if (mDecodingMode.isExternalDraftTokens())
|
||||
{
|
||||
auto externalDraftTokenParams = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
auto const ite = externalDraftTokenParams->ite;
|
||||
auto const step = externalDraftTokenParams->step;
|
||||
auto const localBatchSize = static_cast<int64_t>(externalDraftTokenParams->localBatchSize);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(localDecoderDomain.getBeamWidth() == 1,
|
||||
"Decoding mode is TopK and/or TopP, but beamWidth != 1 (%d != 1)", localDecoderDomain.getBeamWidth());
|
||||
|
||||
// In sampling, we have supported batch sampling. So, we always compute all
|
||||
// sentences once.
|
||||
TensorConstPtr logitsSlice = ITensor::slice(*externalDraftTokenParams->logits, 0, localBatchSize);
|
||||
TensorConstPtr endIdSlice = ITensor::slice(endIds, 0, localBatchSize);
|
||||
auto decodeInputs = std::make_shared<ExternalDraftTokensInputs>(
|
||||
endIdSlice, externalDraftTokenParams->batchSlots, step, ite, localBatchSize);
|
||||
|
||||
decodeInputs->finished = externalDraftTokenParams->finished;
|
||||
|
||||
decodeInputs->logits = logitsSlice;
|
||||
|
||||
if (externalDraftTokenParams->inputLengths)
|
||||
{
|
||||
auto& inputLengths = externalDraftTokenParams->inputLengths.value();
|
||||
decodeInputs->inputLengths = ITensor::slice(inputLengths, 0, localBatchSize);
|
||||
}
|
||||
decodeInputs->draftLogits = externalDraftTokenParams->draftLogits;
|
||||
decodeInputs->draftProbs = externalDraftTokenParams->draftProbs;
|
||||
decodeInputs->targetProbs = externalDraftTokenParams->targetProbs;
|
||||
decodeInputs->numDraftTokens = externalDraftTokenParams->numDraftTokens;
|
||||
decodeInputs->draftTokenIds = externalDraftTokenParams->draftTokenIds;
|
||||
decodeInputs->constantThreshold = externalDraftTokenParams->constantThreshold;
|
||||
decodeInputs->useRandomAcceptanceThreshold = externalDraftTokenParams->useRandomAcceptanceThreshold;
|
||||
decodeInputs->step = externalDraftTokenParams->step;
|
||||
decodeInputs->useDraftLogits = externalDraftTokenParams->useDraftLogits;
|
||||
|
||||
preparedInputs = decodeInputs;
|
||||
preparedOutputs = baseOutputs;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false,
|
||||
|
||||
@ -45,7 +45,7 @@ public:
std::shared_ptr<BaseDecodingInputs> const& inputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;

//! \brief Calls forwardSync of configired decoding layer.
//! \brief Calls forwardSync of configured decoding layer.
void forwardSync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
std::shared_ptr<BaseDecodingInputs> const& inputs,
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
@ -212,6 +212,13 @@ struct LookaheadSetupParams : public DecodingSetupParams
TensorPtr attentionPackedMasks;
};

class ExternalDraftTokensSetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<runtime::SizeType32>> runtimeTopK; // [1] or [setupBatchSize] on cpu
std::optional<std::vector<float>> runtimeTopP;                // [1] or [setupBatchSize] on cpu
};

class BaseDecodingInputs
{
public:
@ -331,6 +338,33 @@ public:
|
||||
bool probsComputed{};
|
||||
};
|
||||
|
||||
class ExternalDraftTokensInputs : public DecodingInputs
|
||||
{
|
||||
public:
|
||||
explicit ExternalDraftTokensInputs(TensorConstPtr endIds, TensorConstPtr batchSlots, runtime::SizeType32 step,
|
||||
runtime::SizeType32 ite, runtime::SizeType32 localBatchSize)
|
||||
: DecodingInputs{std::move(endIds), std::move(batchSlots), step, ite, localBatchSize}
|
||||
{
|
||||
}
|
||||
|
||||
TensorPtr draftLogits;
|
||||
TensorPtr draftProbs;
|
||||
TensorPtr targetProbs;
|
||||
TensorPtr numDraftTokens;
|
||||
TensorPtr draftTokenIds;
|
||||
TensorPtr useDraftLogits;
|
||||
runtime::SizeType32 step;
|
||||
float constantThreshold;
|
||||
bool useRandomAcceptanceThreshold;
|
||||
|
||||
//! optional parameters
|
||||
//! [localBatchSize]
|
||||
curandState_t* curandStates{};
|
||||
|
||||
//! Flag to mark that logits tensor contains probabilities
|
||||
bool probsComputed{};
|
||||
};
|
||||
|
||||
// Medusa inputs
|
||||
class MedusaDecodingInputs : public DecodingInputs
|
||||
{
|
||||
@ -477,7 +511,7 @@ public:
//! {c'} is always accepted and {x', z'} is supposed to be accepted.
//! The accepted tokens [c', x', z'] is saved in `outputIds` in-place, starting from `sequenceLength`.
//! The `acceptedLength` is 3, and the accepted draft tokens length is 2.
//! `sequenceLength` is also increaded by `acceptedLength` in-place.
//! `sequenceLength` is also increased by `acceptedLength` in-place.
//! The pathsOffset is {0, 1, 3} for {c', x', z'}.
//! [] for accepted, <> for draft, {} for input/output.
//!
514  cpp/tensorrt_llm/layers/externalDraftTokensLayer.cpp  Normal file
@ -0,0 +1,514 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "externalDraftTokensLayer.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopPKernels.h"
|
||||
#include "tensorrt_llm/kernels/speculativeDecoding/externalDraftTokensKernels.h"
|
||||
#include "tensorrt_llm/layers/defaultDecodingParams.h"
|
||||
#include "tensorrt_llm/layers/layerUtils.h"
|
||||
#include "tensorrt_llm/runtime/runtimeKernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
namespace tksd = tensorrt_llm::kernels::speculative_decoding;
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
ExternalDraftTokensLayer<T>::ExternalDraftTokensLayer(executor::DecodingMode const& mode,
|
||||
DecoderDomain const& decoderDomain, std::shared_ptr<BufferManager> bufferManager)
|
||||
: BaseLayer(decoderDomain, bufferManager)
|
||||
, mDecodingMode(mode)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "ExternalDraftTokensLayer does not support Beam search mode");
|
||||
|
||||
allocateBuffer(decoderDomain.getBatchSize());
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::allocateBuffer(SizeType32 batchSize)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
// top k workspace size
|
||||
auto workspaceSize = getTopKWorkspaceSize<T>(batchSize, 1, TOP_K_MAX, mDecoderDomain.getVocabSizePadded());
|
||||
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
|
||||
// top p workspace size
|
||||
workspaceSize = getTopPWorkspaceSize<T>(batchSize, mDecoderDomain.getVocabSizePadded());
|
||||
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
|
||||
// multinomial (top p == 1) workspace size
|
||||
workspaceSize = getTopPWorkspaceSize<float>(batchSize, mDecoderDomain.getVocabSizePadded());
|
||||
mWorkspaceSize = std::max(workspaceSize, mWorkspaceSize);
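// Note: the layer reuses a single scratch buffer across sampling variants, so the required
// workspace is the maximum of the top-K, top-P and multinomial (top-P == 1) sizes, not their sum.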
|
||||
|
||||
// batchsize here is maxBatchSize
|
||||
auto const batchSizeShape = ITensor::makeShape({batchSize});
|
||||
|
||||
mCurandStatesDevice
|
||||
= mBufferManager->gpu(ITensor::makeShape({batchSize, sizeof(curandState_t)}), TRTDataType<int8_t>::value);
|
||||
mBatchIsAccepted = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
|
||||
mRuntimeMultinomialDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
|
||||
|
||||
// host buffers.
|
||||
mSkipTopKDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
|
||||
mSkipTopKDecodeHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<bool>::value);
|
||||
mSkipTopPDecodeDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<bool>::value);
|
||||
mSkipTopPDecodeHost = mBufferManager->pinnedPool(batchSizeShape, TRTDataType<bool>::value);
|
||||
auto skipTopPDecodeHostRange = BufferRange<bool>(*mSkipTopPDecodeHost);
|
||||
std::fill(skipTopPDecodeHostRange.begin(), skipTopPDecodeHostRange.end(), true);
|
||||
|
||||
mOutputIdsAfterSampling = mBufferManager->gpu(
|
||||
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<TokenIdType>::value);
|
||||
mTargetOutputIds = mBufferManager->gpu(ITensor::makeShape({batchSize}), TRTDataType<TokenIdType>::value);
|
||||
|
||||
mRuntimeTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<SizeType32>::value);
|
||||
|
||||
mRuntimeTopPForTopKDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
|
||||
|
||||
mRuntimeTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
|
||||
mInitialTopPDevice = mBufferManager->gpu(batchSizeShape, TRTDataType<float>::value);
|
||||
|
||||
mMaskBuffer = mBufferManager->gpu(
|
||||
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<bool>::value);
|
||||
|
||||
mSetupWorkspaceSize = std::max({mBatchIsAccepted->getSizeInBytes(), mRuntimeMultinomialDevice->getSizeInBytes(),
|
||||
mSkipTopKDecodeDevice->getSizeInBytes(), mSkipTopPDecodeDevice->getSizeInBytes(),
|
||||
mOutputIdsAfterSampling->getSizeInBytes(), mTargetOutputIds->getSizeInBytes(),
|
||||
mRuntimeTopKDevice->getSizeInBytes(), mRuntimeTopPForTopKDevice->getSizeInBytes(),
|
||||
mRuntimeTopPDevice->getSizeInBytes(), mInitialTopPDevice->getSizeInBytes(), mMaskBuffer->getSizeInBytes()});
|
||||
|
||||
mTargetLogits = mBufferManager->gpu(
|
||||
ITensor::makeShape({batchSize, mDecoderDomain.getVocabSizePadded()}), TRTDataType<T>::value);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth, TensorConstPtr batchSlots,
|
||||
std::shared_ptr<BaseSetupParams> const& baseSetupParams,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto setupParams = std::dynamic_pointer_cast<ExternalDraftTokensSetupParams>(baseSetupParams);
|
||||
|
||||
workspace->initializeDeviceCurandStates(
|
||||
setupParams->randomSeed, batchSize, workspace->getDeviceBatchSlots(), mCurandStatesDevice);
|
||||
|
||||
auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
|
||||
auto& runtimeMultinomialDeviceTensor = const_cast<ITensor&>(*mRuntimeMultinomialDevice);
|
||||
tensorrt_llm::runtime::kernels::invokeFill(runtimeMultinomialDeviceTensor, 1.0f, mBufferManager->getStream());
|
||||
|
||||
auto* runtimeTopKDevicePtr = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
|
||||
|
||||
// Prepare runtime top K
|
||||
auto constexpr defaultTopK = 1u;
|
||||
auto runtimeTopK = setupParams->runtimeTopK.value_or(std::vector<SizeType32>(batchSize, defaultTopK));
|
||||
auto const runtimeTopKSize = runtimeTopK.size();
|
||||
for (auto& topK : runtimeTopK)
|
||||
{
|
||||
if (topK < 0 || topK > TOP_K_MAX)
|
||||
{
|
||||
TLLM_LOG_WARNING(
|
||||
"TopK (%d) is larger than max supported number (%d). Clip to max supported number.", topK, TOP_K_MAX);
|
||||
topK = std::clamp(topK, 0, static_cast<SizeType32>(TOP_K_MAX));
|
||||
}
|
||||
}
|
||||
|
||||
if (runtimeTopKSize > 1)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
|
||||
fmtstr("runtimeTopK.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopK.size(), batchSize));
|
||||
DecodingLayerWorkspace::copyToWorkspace<SizeType32>(
|
||||
*this->mBufferManager, runtimeTopK, workspace->getWorkspaceDeviceBuffer());
|
||||
auto* setupWorkspaceDevicePtr = workspace->getWorkspaceDevicePtrAs<SizeType32>();
|
||||
// fill top ks into runtimeTopKDevice
|
||||
invokeScatterDecodingParams(
|
||||
setupWorkspaceDevicePtr, runtimeTopKDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
|
||||
}
|
||||
|
||||
// FIXME(nkorobov): monotonically growing
|
||||
auto const curMaxTopK = *std::max_element(std::begin(runtimeTopK), std::end(runtimeTopK));
|
||||
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
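// Note: mRuntimeMaxTopK only grows across setup() calls, so later batches configured with a smaller
// top-K still launch the kernels with a bound that covers every K seen so far.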
|
||||
|
||||
auto runtimeTopP = setupParams->runtimeTopP.value_or(std::vector<float>{});
|
||||
auto const runtimeTopPSize = runtimeTopP.size();
|
||||
auto* runtimeTopPForTopKDevicePtr = bufferCastOrNull<float>(mRuntimeTopPForTopKDevice);
|
||||
auto* runtimeTopPDevicePtr = bufferCastOrNull<float>(mRuntimeTopPDevice);
|
||||
auto* skipTopPDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
|
||||
|
||||
// if no top P, fill topP skip decode to true
|
||||
if (runtimeTopPSize == 0)
|
||||
{
|
||||
auto const* batchSlotsPtr = bufferCast<SizeType32>(*batchSlots);
|
||||
for (SizeType32 bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
auto const bid = batchSlotsPtr[bi];
|
||||
skipTopPDecodeHostPtr[bid] = true;
|
||||
}
|
||||
auto skipTopPDecodeHostSlice = IBuffer::slice(mSkipTopPDecodeHost, 0, mDecoderDomain.getBatchSize());
|
||||
mBufferManager->copy(*skipTopPDecodeHostSlice, *mSkipTopPDecodeDevice);
|
||||
}
|
||||
else
|
||||
{
|
||||
for (auto& topP : runtimeTopP)
|
||||
{
|
||||
if (topP < 0.f || topP > 1.0f)
|
||||
{
|
||||
TLLM_LOG_WARNING("TopP (%f) is out of range ([0.0, 1.0f]). Clip to closest number.", topP);
|
||||
topP = std::clamp(topP, 0.f, 1.f);
|
||||
}
|
||||
}
|
||||
if (runtimeTopPSize > 1)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize,
|
||||
fmtstr("runtimeTopP.size() (%lu) == batchSize (%d) is not satisfied!", runtimeTopP.size(), batchSize));
|
||||
DecodingLayerWorkspace::copyToWorkspace<float>(
|
||||
*this->mBufferManager, runtimeTopP, workspace->getWorkspaceDeviceBuffer());
|
||||
auto* setupWorkspaceDevicePtr = workspace->getWorkspaceDevicePtrAs<float>();
|
||||
// fill runtime top p device for top k kernel
|
||||
invokeScatterDecodingParams(
|
||||
setupWorkspaceDevicePtr, runtimeTopPForTopKDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
|
||||
// fill runtime top p device for top p kernel
|
||||
invokeScatterDecodingParams(
|
||||
setupWorkspaceDevicePtr, runtimeTopPDevicePtr, batchSlotsDevicePtr, batchSize, getStream());
|
||||
}
|
||||
}
|
||||
// if no topP, default topP is 0.0f, but in invokeSetupTopKRuntimeArgs, it gets set to 1.0f if k > 0
|
||||
auto const topP = (runtimeTopPSize == 0) ? DefaultDecodingParams::getTopP() : runtimeTopP.front();
|
||||
|
||||
auto* skipTopKDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
|
||||
{
|
||||
dim3 block(std::min(static_cast<uint32_t>(batchSize), 256u));
|
||||
dim3 grid(divUp(static_cast<uint32_t>(batchSize), block.x));
|
||||
// support topK up to TOP_K_MAX.
|
||||
invokeSetupTopKRuntimeArgs(batchSize, curMaxTopK, runtimeTopKDevicePtr, runtimeTopKSize, topP,
|
||||
runtimeTopPForTopKDevicePtr, runtimeTopPSize, skipTopKDecodeDevicePtr, batchSlotsDevicePtr, getStream());
|
||||
}
|
||||
auto const skipTopKHostDecodeDeviceSlice = ITensor::slice(mSkipTopKDecodeDevice, 0, mDecoderDomain.getBatchSize());
|
||||
auto skipTopKDecodeHostSlice = ITensor::slice(mSkipTopKDecodeHost, 0, mDecoderDomain.getBatchSize());
|
||||
mBufferManager->copy(*skipTopKHostDecodeDeviceSlice, *skipTopKDecodeHostSlice);
|
||||
|
||||
auto* skipTopPDecodeDevicePtr = bufferCast<bool>(*mSkipTopPDecodeDevice);
|
||||
{
|
||||
auto* initialTopPDevicePtr = bufferCast<float>(*mInitialTopPDevice);
|
||||
invokeSetTopPRuntimeArgs(batchSize, curMaxTopK, runtimeTopKDevicePtr, runtimeTopKSize, topP,
|
||||
runtimeTopPDevicePtr, runtimeTopPSize, skipTopPDecodeDevicePtr, batchSlotsDevicePtr, initialTopPDevicePtr,
|
||||
getStream());
|
||||
}
|
||||
auto const skipTopPHostDecodeDeviceSlice = ITensor::slice(mSkipTopPDecodeDevice, 0, mDecoderDomain.getBatchSize());
|
||||
auto skipTopPDecodeHostSlice = ITensor::slice(mSkipTopPDecodeHost, 0, mDecoderDomain.getBatchSize());
|
||||
mBufferManager->copy(*skipTopPHostDecodeDeviceSlice, *skipTopPDecodeHostSlice);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
|
||||
auto const* endIds = bufferCast<TokenIdType>(*inputs->endIds);
|
||||
|
||||
FinishedState const* finishedInput = (inputs->finished)
|
||||
? reinterpret_cast<FinishedState const*>(bufferCast<FinishedState::UnderlyingType>(*inputs->finished.value()))
|
||||
: nullptr;
|
||||
|
||||
inputs->curandStates = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStatesDevice));
|
||||
inputs->probsComputed = true;
|
||||
|
||||
auto runtimeLogitsPtr = bufferCast<T>(*workspace->getDeviceRuntimeLogits());
|
||||
auto logitsPtrsPtr = static_cast<T**>(nullptr);
|
||||
auto biasPtr = static_cast<T*>(nullptr);
|
||||
auto const* batchSlotsPtr = workspace->getDeviceBatchSlotsPtr();
|
||||
mBufferManager->copy(runtimeLogitsPtr, *mTargetLogits);
|
||||
invokeAddBiasSoftMax(runtimeLogitsPtr, logitsPtrsPtr, runtimeLogitsPtr, biasPtr, endIds, finishedInput,
|
||||
batchSlotsPtr, batchSize, mDecoderDomain.getBatchSize(), /* bw */ 1, mDecoderDomain.getVocabSize(),
|
||||
mDecoderDomain.getVocabSizePadded(), /*skipSoftMax*/ false, /* batchSlotLogits */ false, getStream());
|
||||
|
||||
auto const targetTokenIdsShape = (*outputs->outputIds).getShape();
|
||||
|
||||
// Fill the buffer for selected ids from sampling with zero. -1 will be set as a boundary if topP kernel is required
|
||||
auto& outputIdsAfterSamplingTensor = const_cast<ITensor&>(*mOutputIdsAfterSampling);
|
||||
tensorrt_llm::runtime::kernels::invokeFill(outputIdsAfterSamplingTensor, 0, mBufferManager->getStream());
|
||||
|
||||
// The logits from target engine should go through samplings first.
|
||||
// gptDecoderBatched.cpp is calling dynamic decoder step by step, in this step, dynamic Decoder already forwarded
|
||||
// PenaltyLayer, BanWordsLayer. For (TopK > 0) && (TopK == 0 && TopP == 0), we invoke TopK sampling kernel. The same
|
||||
// logic is implemented in SamplingLayer.cpp
|
||||
getAllTopKs(outputs, baseInputs, workspace);
|
||||
|
||||
// Only for (TopK == 0 && TopP > 0), we invoke TopP sampling
|
||||
getAllTopPs(outputs, baseInputs, workspace);
|
||||
|
||||
// After all selected tokens are filled in mOutputIdsAfterSampling by topK, topP kernels, token acceptance logics
|
||||
// starts. First we mask the logits of unselected token id to -inf as HF's TopK, TopP implementation. We compute the
|
||||
// logit probs of draft and target and go through acceptance logics.
|
||||
acceptDraftTokens(outputs, baseInputs, workspace);
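// Sketch (assuming the standard speculative-sampling rule; the exact check lives in
// invokeAcceptDraftTokens): with draft logits, a draft token x is kept when
// u < min(1, pTarget(x) / pDraft(x)) for u ~ U(0, 1); without draft logits, the constant or
// random threshold configured via constantThreshold / useRandomAcceptanceThreshold acts as the bar.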
|
||||
|
||||
// If the token of the sequence is not accepted, a multinomial sampling is required for the bonus token.
|
||||
// Multinomial sampling is achieved through TopP kernel with TopP = 1 and already weighted-sum target logits.
|
||||
// The acceptance result of each batch is used as skipDecode in topP kernel. If is accepted, no sampling is needed
|
||||
// (early exit). Forwarding for the next step is also set in this kernel.
|
||||
multinomialSampling(outputs, baseInputs, workspace);
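// Note: mRuntimeMultinomialDevice is filled with 1.0f in setup(), so the top-P kernel here degenerates
// to plain multinomial sampling over the full distribution, and mBatchIsAccepted doubles as its skip mask.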
|
||||
|
||||
// For the sequence with accepted tokens, we simply forward a step.
|
||||
forwardAcceptedTokens(outputs, baseInputs, workspace);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
size_t ExternalDraftTokensLayer<T>::getWorkspaceSize() const noexcept
|
||||
{
|
||||
return std::max(mWorkspaceSize, mSetupWorkspaceSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto const draftLogitsShape = (*inputs->draftLogits).getShape();
|
||||
auto const maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
auto constexpr beamWidth = 1;
|
||||
|
||||
FinishedState const* finishedInput = (inputs->finished)
|
||||
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
|
||||
: nullptr;
|
||||
|
||||
FinishedState* finishedOutput = (outputs->finished)
|
||||
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
|
||||
: nullptr;
|
||||
|
||||
tksd::invokeMaskTargetLogits(batchSize, bufferCast<T>(*mTargetLogits), workspace->getDeviceBatchSlotsPtr(),
|
||||
beamWidth, mDecoderDomain.getVocabSizePadded(), finishedInput, maxBatchSize,
|
||||
bufferCast<bool>(*inputs->useDraftLogits), bufferCast<SizeType32>(*mOutputIdsAfterSampling),
|
||||
bufferCast<SizeType32>(*mTargetOutputIds), bufferCastOrNull<SizeType32>(mRuntimeTopKDevice),
|
||||
bufferCast<bool>(*mMaskBuffer), getStream());
|
||||
|
||||
if (inputs->step == 0)
|
||||
{
|
||||
invokeAddBiasSoftMax(bufferCast<T>(*inputs->draftLogits), static_cast<T**>(nullptr),
|
||||
bufferCast<T>(*inputs->draftProbs), static_cast<T*>(nullptr), nullptr, finishedInput,
|
||||
workspace->getDeviceBatchSlotsPtr(), batchSize, maxBatchSize, beamWidth * maxTokensPerStep,
|
||||
mDecoderDomain.getVocabSize(), mDecoderDomain.getVocabSizePadded(),
|
||||
/* skip softmax */ false,
|
||||
/* batchSlotLogits */ true, getStream());
|
||||
}
|
||||
|
||||
invokeAddBiasSoftMax(bufferCast<T>(*mTargetLogits), static_cast<T**>(nullptr), bufferCast<T>(*inputs->targetProbs),
|
||||
static_cast<T*>(nullptr), nullptr, finishedInput, workspace->getDeviceBatchSlotsPtr(), batchSize, maxBatchSize,
|
||||
beamWidth /* 1 */, mDecoderDomain.getVocabSize(), mDecoderDomain.getVocabSizePadded(),
|
||||
/* skip softmax */ false,
|
||||
/* batchSlotLogits */ false, getStream());
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
tksd::invokeAcceptDraftTokens(batchSize, bufferCast<T>(*inputs->draftProbs), bufferCast<T>(*inputs->targetProbs),
|
||||
bufferCast<SizeType32>(*inputs->numDraftTokens), bufferCast<bool>(*inputs->useDraftLogits),
|
||||
bufferCast<TokenIdType>(*inputs->draftTokenIds), finishedInput, finishedOutput, inputs->curandStates,
|
||||
workspace->getDeviceBatchSlotsPtr(), maxTokensPerStep, beamWidth, mDecoderDomain.getVocabSizePadded(),
|
||||
inputs->useRandomAcceptanceThreshold, inputs->constantThreshold, inputs->step,
|
||||
bufferCast<bool>(*mBatchIsAccepted), bufferCast<SizeType32>(*mTargetOutputIds), getStream());
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::multinomialSampling(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
auto probs = bufferCastOrNull<T>(inputs->targetProbs);
|
||||
auto* sequenceLength = bufferCastOrNull<SizeType32>(outputs->sequenceLength);
|
||||
auto const* endIds = bufferCastOrNull<TokenIdType>(inputs->endIds);
|
||||
|
||||
FinishedState* finishedOutput = (outputs->finished)
|
||||
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
|
||||
: nullptr;
|
||||
TopPSamplingKernelParams<T> params{};
|
||||
|
||||
params.probs = probs;
|
||||
params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
|
||||
params.workspace = workspace->getRawWorkspaceDevicePtr();
|
||||
params.topPs = bufferCastOrNull<float>(mRuntimeMultinomialDevice);
|
||||
params.sequenceLength = sequenceLength;
|
||||
params.endIds = endIds;
|
||||
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
|
||||
params.finishedInput = nullptr;
|
||||
params.finishedOutput = finishedOutput;
|
||||
params.skipDecode = bufferCastOrNull<bool>(mBatchIsAccepted);
|
||||
params.cumLogProbs = nullptr;
|
||||
params.outputLogProbs = nullptr;
|
||||
params.curandState = inputs->curandStates;
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
|
||||
invokeBatchTopPSampling<T>(params, getStream());
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::getAllTopKs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto logits = bufferCastOrNull<T>(inputs->logits);
|
||||
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
|
||||
auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
|
||||
auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopKDecodeHost);
|
||||
auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
|
||||
if (skip)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
FinishedState const* finishedInput = (inputs->finished)
|
||||
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
|
||||
: nullptr;
|
||||
|
||||
TopKSamplingKernelParams<T> params{};
|
||||
params.logProbs = logits;
|
||||
params.outputIds = bufferCastOrNull<TokenIdType>(mOutputIdsAfterSampling);
|
||||
params.workspace = workspace->getRawWorkspaceDevicePtr();
|
||||
params.maxTopP = 1.0f;
|
||||
params.topPs = bufferCastOrNull<float>(mRuntimeTopPForTopKDevice);
|
||||
params.maxTopK = mRuntimeMaxTopK;
|
||||
params.topKs = bufferCastOrNull<SizeType32>(mRuntimeTopKDevice);
|
||||
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
|
||||
params.finishedInput = finishedInput;
|
||||
params.skipDecode = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
|
||||
params.curandState = inputs->curandStates;
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.maxTokensPerStep = 1;
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.returnAllTopK = true;
|
||||
params.maxSeqLen = mDecoderDomain.getVocabSizePadded(); // workaround for returning all topKs with outputIds
|
||||
params.logitsHasProbs = inputs->probsComputed;
|
||||
|
||||
invokeBatchTopKSampling(params, getStream());
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::getAllTopPs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto logits = bufferCastOrNull<T>(inputs->logits);
|
||||
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
|
||||
auto const* batchSlotsHost = bufferCast<SizeType32>(*inputs->batchSlots);
|
||||
auto* skipDecodeHostPtr = bufferCastOrNull<bool>(mSkipTopPDecodeHost);
|
||||
auto const skip = allOfBatchSlots(batchSlotsHost, skipDecodeHostPtr, batchSize, true);
|
||||
if (skip)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
FinishedState const* finishedInput = (inputs->finished)
|
||||
? reinterpret_cast<FinishedState const*>(bufferCastOrNull<FinishedState::UnderlyingType>(inputs->finished))
|
||||
: nullptr;
|
||||
|
||||
TopPSamplingKernelParams<T> params{};
|
||||
params.probs = logits;
|
||||
params.outputIds = bufferCastOrNull<TokenIdType>(mOutputIdsAfterSampling);
|
||||
params.workspace = workspace->getRawWorkspaceDevicePtr();
|
||||
params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
|
||||
params.batchSlots = workspace->getDeviceBatchSlotsPtr();
|
||||
params.finishedInput = finishedInput;
|
||||
params.skipDecode = bufferCastOrNull<bool>(mSkipTopPDecodeDevice);
|
||||
params.curandState = inputs->curandStates;
|
||||
params.batchSize = batchSize;
|
||||
params.maxBatchSize = mDecoderDomain.getBatchSize();
|
||||
params.vocabSizePadded = mDecoderDomain.getVocabSizePadded();
|
||||
params.returnAllTopP = true;
|
||||
params.maxSeqLen = mDecoderDomain.getVocabSizePadded();
|
||||
|
||||
invokeBatchTopPSampling<T>(params, getStream());
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void ExternalDraftTokensLayer<T>::forwardAcceptedTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto inputs = std::dynamic_pointer_cast<ExternalDraftTokensInputs>(baseInputs);
|
||||
auto const batchSize = inputs->logits.value()->getDimension<0>();
|
||||
|
||||
auto const draftLogitsShape = (*inputs->draftLogits).getShape();
|
||||
auto const maxTokensPerStep = draftLogitsShape.d[1]; // 1
|
||||
|
||||
FinishedState* finishedOutput = (outputs->finished)
|
||||
? reinterpret_cast<FinishedState*>(bufferCastOrNull<FinishedState::UnderlyingType>(outputs->finished))
|
||||
: nullptr;
|
||||
|
||||
tksd::invokeForwardAcceptedTokens(batchSize, workspace->getDeviceBatchSlotsPtr(),
|
||||
bufferCast<bool>(*mBatchIsAccepted), bufferCastOrNull<SizeType32>(outputs->sequenceLength),
|
||||
bufferCast<TokenIdType>(*inputs->draftTokenIds), bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr),
|
||||
inputs->step, maxTokensPerStep, bufferCastOrNull<TokenIdType>(inputs->endIds), finishedOutput, getStream());
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template class ExternalDraftTokensLayer<float>;
|
||||
template class ExternalDraftTokensLayer<half>;
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
100  cpp/tensorrt_llm/layers/externalDraftTokensLayer.h  Normal file
@ -0,0 +1,100 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/layers/baseLayer.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
#include <curand_kernel.h>
|
||||
|
||||
namespace tensorrt_llm::layers
|
||||
{
|
||||
|
||||
//! \brief Top class for sampling layers.
|
||||
//! It sets up and executes TopKSamplingLayer and TopPSamplingLayer samplings
|
||||
template <typename T>
|
||||
class ExternalDraftTokensLayer : public BaseLayer
|
||||
{
|
||||
public:
|
||||
using Base = BaseLayer;
|
||||
|
||||
ExternalDraftTokensLayer(executor::DecodingMode const& mode, DecoderDomain const& decoderDomain,
|
||||
std::shared_ptr<runtime::BufferManager> bufferManager);
|
||||
|
||||
void setup(runtime::SizeType32 batchSize, runtime::SizeType32 beamWidth, TensorConstPtr batchSlots,
|
||||
std::shared_ptr<BaseSetupParams> const& setupParams,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
|
||||
|
||||
void forwardAsync(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& inputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace) override;
|
||||
|
||||
//! @returns workspace needed for this layer in bytes
|
||||
[[nodiscard]] size_t getWorkspaceSize() const noexcept override;
|
||||
|
||||
protected:
|
||||
runtime::SizeType32 mRuntimeMaxTopK{0};
|
||||
|
||||
private:
|
||||
using Base::mDecoderDomain;
|
||||
|
||||
executor::DecodingMode mDecodingMode;
|
||||
|
||||
size_t mWorkspaceSize{0};
|
||||
size_t mSetupWorkspaceSize{0};
|
||||
|
||||
TensorPtr mCurandStatesDevice;
|
||||
TensorPtr mSkipTopKDecodeDevice;
|
||||
TensorPtr mSkipTopKDecodeHost;
|
||||
TensorPtr mSkipTopPDecodeDevice;
|
||||
TensorPtr mSkipTopPDecodeHost;
|
||||
|
||||
TensorPtr mBatchIsAccepted;
|
||||
TensorPtr mRuntimeMultinomialDevice;
|
||||
|
||||
TensorPtr mOutputIdsAfterSampling;
|
||||
TensorPtr mTargetOutputIds;
|
||||
TensorPtr mRuntimeTopKDevice;
|
||||
TensorPtr mRuntimeTopPForTopKDevice;
|
||||
TensorPtr mRuntimeTopPDevice;
|
||||
TensorPtr mInitialTopPDevice;
|
||||
TensorPtr mMaskBuffer;
|
||||
|
||||
TensorPtr mTargetLogits;
|
||||
|
||||
private:
|
||||
void allocateBuffer(runtime::SizeType32 batchSize);
|
||||
void acceptDraftTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
|
||||
void multinomialSampling(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
|
||||
void getAllTopKs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
|
||||
void getAllTopPs(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
|
||||
void forwardAcceptedTokens(std::shared_ptr<BaseDecodingOutputs> const& outputs,
|
||||
std::shared_ptr<BaseDecodingInputs> const& baseInputs,
|
||||
std::shared_ptr<runtime::DecodingLayerWorkspace> const& workspace);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::layers
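For readers skimming the diff, here is a minimal, self-contained sketch of the standard draft-token
acceptance rule that a layer like ExternalDraftTokensLayer builds on. It is not taken from this commit:
the function name and the scalar, single-sequence form are illustrative only, whereas the real layer runs
batched on GPU tensors and additionally supports the constant/random acceptance thresholds shown above.

```cpp
#include <algorithm>
#include <cstddef>
#include <random>
#include <vector>

// Illustrative only: scalar CPU version of the accept-or-resample step used in speculative decoding.
// Accept the draft token with probability min(1, pTarget/pDraft); on rejection, resample from the
// normalized residual distribution max(0, pTarget - pDraft).
int acceptOrResample(std::vector<float> const& pDraft, std::vector<float> const& pTarget,
    int draftToken, std::mt19937& gen)
{
    std::uniform_real_distribution<float> uniform(0.f, 1.f);
    float const ratio = pDraft[draftToken] > 0.f ? pTarget[draftToken] / pDraft[draftToken] : 1.f;
    if (uniform(gen) < std::min(1.f, ratio))
    {
        return draftToken; // accepted: keep the draft token
    }
    // Rejected: sample a replacement from the residual distribution max(0, pTarget - pDraft).
    std::vector<float> residual(pTarget.size());
    float norm = 0.f;
    for (std::size_t i = 0; i < pTarget.size(); ++i)
    {
        residual[i] = std::max(0.f, pTarget[i] - pDraft[i]);
        norm += residual[i];
    }
    if (norm <= 0.f)
    {
        return draftToken; // draft and target distributions coincide; nothing to correct
    }
    std::discrete_distribution<int> dist(residual.begin(), residual.end());
    return dist(gen);
}
```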
|
||||
@ -267,7 +267,7 @@ void TopPSamplingLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> con

TopPSamplingKernelParams<T> params{};
params.probs = probs;
params.outputIds = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
params.outputIdsPtrs = bufferCastOrNull<TokenIdType*>(outputs->outputIdsPtr);
params.workspace = workspace->getRawWorkspaceDevicePtr();
params.topPs = bufferCastOrNull<float>(mRuntimeTopPDevice);
params.sequenceLength = sequenceLength;
@ -259,7 +259,7 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
|
||||
int idx = 0;
|
||||
for (int reqId = 0; reqId < numReqs; reqId++)
|
||||
{
|
||||
const RequestType reqType = static_cast<RequestType const>(reqTypes[reqId]);
|
||||
RequestType const reqType = static_cast<RequestType const>(reqTypes[reqId]);
|
||||
if (reqType == RequestType::kGENERATION)
|
||||
{
|
||||
mExpandLoraWeightPtrs.push_back(reinterpret_cast<void const*>(loraWeightModulePtrs[reqId * 2]));
|
||||
@ -284,7 +284,7 @@ int LoraPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P
|
||||
fmtstr("LoraParams and input dims don't match, lora tokens %d input tokens %d", idx, numTokens));
|
||||
}
|
||||
|
||||
// only used for unifed gemm
|
||||
// only used for unified gemm
|
||||
auto bestTactic = mPluginProfiler->getBestConfig(numTokens, mGemmId);
|
||||
mLoraImpl->setBestTactic(bestTactic);
|
||||
mLoraImpl->run(numTokens, numReqs, input, mExpandLoraRanks.data(), mExpandLoraWeightPtrs.data(), mWeightIndex,
|
||||
|
||||
@ -305,14 +305,17 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe
|
||||
++tpRank;
|
||||
}
|
||||
|
||||
int token_num = size / inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
|
||||
|
||||
auto params = tensorrt_llm::kernels::AllReduceParams::deserialize(
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(inputs[1])), tpSize, tpRank);
|
||||
reinterpret_cast<int64_t*>(const_cast<void*>(inputs[1])), tpSize, tpRank, mType, token_num, mOp);
|
||||
|
||||
params.local_output_buffer_ptr = outputs[0];
|
||||
params.local_input_buffer_ptr = inputs[0];
|
||||
params.elts_total = size;
|
||||
if (mOp == AllReduceFusionOp::RESIDUAL_RMS_NORM)
|
||||
{
|
||||
|
||||
int fusion_ptr_idx = 2;
|
||||
params.fusion_params.bias_buffer = mBias ? inputs[fusion_ptr_idx++] : nullptr;
|
||||
params.fusion_params.residual_buffer = inputs[fusion_ptr_idx++];
|
||||
@ -320,6 +323,15 @@ int AllreducePlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfe
|
||||
params.fusion_params.hidden_size = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
|
||||
params.fusion_params.eps = mEps;
|
||||
params.fusion_params.intermediate_buffer = outputs[1];
|
||||
for (int i = 0; i < tpSize; ++i)
|
||||
{
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[i]
|
||||
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 4 + i];
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tensorrt_llm::kernels::MAX_RANKS_PER_NODE]
|
||||
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 5 + i];
|
||||
params.fusion_params.lamport_peer_comm_buffer_ptrs[i + tensorrt_llm::kernels::MAX_RANKS_PER_NODE * 2]
|
||||
= reinterpret_cast<void**>(const_cast<void*>(inputs[1]))[tpSize * 6 + i];
|
||||
}
|
||||
}
|
||||
tensorrt_llm::kernels::customAllReduce(params, mType, runtimeStrategy, mConfig, mOp, stream);
|
||||
}
|
||||
|
||||
@ -3,35 +3,19 @@ set(TRTLLM_PYBIND_MODULE
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PARENT_SCOPE)
|
||||
|
||||
if(NOT BUILD_PYT)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Python bindings for C++ runtime require PyTorch. Please enable BUILD_PYT"
|
||||
)
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} "-c"
|
||||
"import pybind11 as pb11; print(pb11.get_cmake_dir(),end='');"
|
||||
RESULT_VARIABLE PYBIND_CMAKE_DIR_RET
|
||||
OUTPUT_VARIABLE PYBIND_CMAKE_DIR)
|
||||
|
||||
if(PYBIND_CMAKE_DIR_RET MATCHES 0)
|
||||
list(APPEND CMAKE_PREFIX_PATH "${PYBIND_CMAKE_DIR}")
|
||||
else()
|
||||
message(ERROR "pybind11 CMake directory not found.")
|
||||
endif()
|
||||
|
||||
find_package(pybind11 REQUIRED)
|
||||
|
||||
set(SRCS
|
||||
bindings.cpp
|
||||
batch_manager/algorithms.cpp
|
||||
batch_manager/bindings.cpp
|
||||
batch_manager/gptManager.cpp
|
||||
batch_manager/llmRequest.cpp
|
||||
batch_manager/inferenceRequest.cpp
|
||||
batch_manager/kvCacheManager.cpp
|
||||
batch_manager/llmRequest.cpp
|
||||
batch_manager/namedTensor.cpp
|
||||
executor/bindings.cpp
|
||||
executor/executor.cpp)
|
||||
executor/executor.cpp
|
||||
bindings.cpp)
|
||||
|
||||
include_directories(${PROJECT_SOURCE_DIR}/include)
|
||||
|
||||
pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})
|
||||
|
||||
@ -41,15 +25,12 @@ set_property(TARGET ${TRTLLM_PYBIND_MODULE} PROPERTY POSITION_INDEPENDENT_CODE
|
||||
target_link_directories(${TRTLLM_PYBIND_MODULE} PUBLIC
|
||||
"${TORCH_INSTALL_PREFIX}/lib")
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG}
|
||||
${NO_AS_NEEDED_FLAG})
|
||||
target_link_libraries(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC ${Python3_LIBRARIES} ${TORCH_LIBRARIES}
|
||||
torch_python ${UNDEFINED_FLAG})
|
||||
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE})
|
||||
target_compile_definitions(${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC PYBIND11_DETAILED_ERROR_MESSAGES=1)
|
||||
${TRTLLM_PYBIND_MODULE}
|
||||
PUBLIC ${SHARED_TARGET} ${UNDEFINED_FLAG} ${NO_AS_NEEDED_FLAG}
|
||||
${Python3_LIBRARIES} ${TORCH_LIBRARIES} torch_python)
|
||||
target_compile_definitions(
|
||||
${TRTLLM_PYBIND_MODULE} PUBLIC TRTLLM_PYBIND_MODULE=${TRTLLM_PYBIND_MODULE}
|
||||
PYBIND11_DETAILED_ERROR_MESSAGES=1)
|
||||
|
||||
if(NOT WIN32)
|
||||
set_target_properties(
|
||||
|
||||
55  cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp  Normal file
@ -0,0 +1,55 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "algorithms.h"
|
||||
#include "tensorrt_llm/batch_manager/capacityScheduler.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
|
||||
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
|
||||
#include "tensorrt_llm/pybind/common/algorithmBindings.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
using namespace tensorrt_llm::batch_manager;
|
||||
using namespace PybindUtils;
|
||||
|
||||
void tensorrt_llm::pybind::batch_manager::algorithms::initBindings(pybind11::module_& m)
|
||||
{
|
||||
// Algorithms with custom bindings
|
||||
py::class_<CapacityScheduler>(m, CapacityScheduler::name)
|
||||
.def_static("make", &CapacityScheduler::make, py::arg("max_num_requests"), py::arg("kv_cache_manager"),
|
||||
py::arg("cross_kv_cache_manager"), py::arg("peft_cache_manager"), py::arg("capacity_scheduler_policy"),
|
||||
py::arg("many_micro_batches") = false,
|
||||
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
|
||||
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
|
||||
"LlmRequestState.GENERATION_COMPLETE"))
|
||||
.def(py::init())
|
||||
.def("__call__", &CapacityScheduler::operator())
|
||||
.def("name", [](CapacityScheduler const&) { return CapacityScheduler::name; });
|
||||
|
||||
py::class_<MicroBatchScheduler>(m, MicroBatchScheduler::name)
|
||||
.def_static("make", &MicroBatchScheduler::make, py::arg("max_batch_size"),
|
||||
py::arg_v("max_num_tokens", std::nullopt, "None"), py::arg_v("ctx_chunk_config", std::nullopt, "None"),
|
||||
py::arg_v("max_context_length", std::nullopt, "None"),
|
||||
py::arg_v("no_schedule_until_state", LlmRequestState::kCONTEXT_INIT, "LlmRequestState.CONTEXT_INIT"),
|
||||
py::arg_v("no_schedule_after_state", LlmRequestState::kGENERATION_COMPLETE,
|
||||
"LlmRequestState.GENERATION_COMPLETE"))
|
||||
.def(py::init())
|
||||
.def("__call__", &MicroBatchScheduler::operator())
|
||||
.def("name", [](MicroBatchScheduler const&) { return MicroBatchScheduler::name; });
|
||||
}
|
||||
28  cpp/tensorrt_llm/pybind/batch_manager/algorithms.h  Normal file
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager::algorithms
|
||||
{
|
||||
|
||||
void initBindings(pybind11::module_& m);
|
||||
|
||||
}
|
||||
41  cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp  Normal file
@ -0,0 +1,41 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "bindings.h"
|
||||
#include "tensorrt_llm/batch_manager/common.h"
|
||||
#include "tensorrt_llm/batch_manager/microBatchScheduler.h"
|
||||
#include "tensorrt_llm/pybind/utils/bindTypes.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
void initBindings(pybind11::module_& m)
|
||||
{
|
||||
py::class_<tb::batch_scheduler::ContextChunkingConfig>(m, "ContextChunkingConfig")
|
||||
.def(py::init<tle::ContextChunkingPolicy, tensorrt_llm::runtime::SizeType32>(), py::arg("chunking_policy"),
|
||||
py::arg("chunk_unit_size"))
|
||||
.def_readwrite("chunking_policy", &tb::batch_scheduler::ContextChunkingConfig::chunkingPolicy)
|
||||
.def_readwrite("chunk_unit_size", &tb::batch_scheduler::ContextChunkingConfig::chunkUnitSize);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
28  cpp/tensorrt_llm/pybind/batch_manager/bindings.h  Normal file
@ -0,0 +1,28 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
void initBindings(pybind11::module_& m);
|
||||
|
||||
}
|
||||
@ -21,6 +21,7 @@
|
||||
#include "namedTensor.h"
|
||||
#include "tensorrt_llm/batch_manager/GptManager.h"
|
||||
#include "tensorrt_llm/batch_manager/callbacks.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
|
||||
#include <ATen/ops/tensor.h>
|
||||
#include <functional>
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
29  cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp  Normal file
@ -0,0 +1,29 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#include "kvCacheManager.h"
|
||||
#include "tensorrt_llm/pybind/utils/bindTypes.h"
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace py = pybind11;
|
||||
|
||||
void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m)
|
||||
{
|
||||
// TODO: Provide proper bindings
|
||||
py::classh<tb::kv_cache_manager::KVCacheManager>(m, "KVCacheManager");
|
||||
}
|
||||
|
||||
void tb::BasePeftCacheManagerBindings::initBindings(py::module_& m)
|
||||
{
|
||||
// TODO: Provide proper bindings
|
||||
py::classh<tb::BasePeftCacheManager>(m, "BasePeftCacheManager");
|
||||
}
|
||||
36  cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.h  Normal file
@ -0,0 +1,36 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
class KVCacheManagerBindings
|
||||
{
|
||||
public:
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
class BasePeftCacheManagerBindings
|
||||
{
|
||||
public:
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
@ -17,22 +17,29 @@
|
||||
#include "llmRequest.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/utils/bindTypes.h"
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <pybind11/stl_bind.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
using namespace tensorrt_llm::pybind::batch_manager;
|
||||
|
||||
using LlmRequestPtr = std::shared_ptr<tb::LlmRequest>;
|
||||
using RequestList = std::list<LlmRequestPtr>;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
@ -166,7 +173,6 @@ void LlmRequest::initBindings(py::module_& m)
|
||||
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
|
||||
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
|
||||
.def("move_to_next_context_chunk", &LlmRequest::moveToNextContextChunk)
|
||||
.def("is_full_context_request", py::overload_cast<>(&LlmRequest::isFullContextRequest, py::const_))
|
||||
.def("is_last_context_chunk", py::overload_cast<>(&LlmRequest::isLastContextChunk, py::const_))
|
||||
.def("is_first_context_chunk", py::overload_cast<>(&LlmRequest::isFirstContextChunk, py::const_))
|
||||
.def("get_context_remaining_length", py::overload_cast<>(&LlmRequest::getContextRemainingLength, py::const_))
|
||||
@ -180,3 +186,140 @@ void LlmRequest::initBindings(py::module_& m)
|
||||
{ self.setDraftLogits(std::make_optional<LlmRequest::TensorPtr>(logits)); })
|
||||
.def_property("num_return_sequences", &LlmRequest::getNumReturnSequences, &LlmRequest::setNumReturnSequences);
|
||||
}
|
||||
|
||||
void tb::LlmRequestBindings::initBindings(py::module_& m)
|
||||
{
|
||||
py::classh<tb::LlmRequest>(m, "PyLlmRequest")
|
||||
.def("get_num_tokens", &tb::LlmRequest::getNumTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_beam_num_tokens", &tb::LlmRequest::getMaxBeamNumTokens)
|
||||
.def("get_token", &tb::LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
|
||||
.def("get_tokens", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getTokens, py::const_),
|
||||
py::arg("beam"))
|
||||
.def("get_tokens", py::overload_cast<>(&tb::LlmRequest::getTokens, py::const_))
|
||||
.def_property_readonly("max_num_generated_tokens", &tb::LlmRequest::getMaxNumGeneratedTokens)
|
||||
.def("add_new_token", &tb::LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
|
||||
.def("add_new_tokens", &tb::LlmRequest::addNewTokens, py::arg("beam_tokens"))
|
||||
.def("set_generated_tokens", &tb::LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
|
||||
.def("pause", &tb::LlmRequest::pause, py::arg("max_input_len"))
|
||||
.def_property("max_sent_token_len", &tb::LlmRequest::getMaxSentTokenLen, &tb::LlmRequest::setMaxSentTokenLen)
|
||||
.def("prompt_embedding_table",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getPromptEmbeddingTable();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def("bad_words_list",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getBadWordsList();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def_property(
|
||||
"draft_logits",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getDraftLogits();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
},
|
||||
[](tb::LlmRequest& self, at::Tensor& logits)
|
||||
{ self.setDraftLogits(std::make_optional<tb::LlmRequest::TensorPtr>(tr::TorchView::of(logits))); })
|
||||
.def("embedding_bias",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getEmbeddingBias();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def("lora_config",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getLoraConfig();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def("lora_weights",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getLoraWeights();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def("stop_words_list",
|
||||
[](tb::LlmRequest& self)
|
||||
{
|
||||
std::optional<at::Tensor> value{std::nullopt};
|
||||
auto tensor = self.getStopWordsList();
|
||||
if (tensor)
|
||||
{
|
||||
value = tr::Torch::tensor(*tensor);
|
||||
}
|
||||
return value;
|
||||
})
|
||||
.def_property_readonly("prompt_vocab_size", &tb::LlmRequest::getPromptVocabSize)
|
||||
.def_property_readonly("lora_task_id", &tb::LlmRequest::getLoraTaskId)
|
||||
.def_property_readonly("lookahead_config", &tb::LlmRequest::getLookaheadConfig)
|
||||
.def_property_readonly(
|
||||
"context_current_position", py::overload_cast<>(&tb::LlmRequest::getContextCurrentPosition, py::const_))
|
||||
.def_property("context_chunk_size", &tb::LlmRequest::getContextChunkSize, &tb::LlmRequest::setContextChunkSize)
|
||||
.def_readwrite("request_id", &tb::LlmRequest::mRequestId)
|
||||
.def_readwrite("prompt_len", &tb::LlmRequest::mPromptLen)
|
||||
.def_readwrite("max_new_tokens", &tb::LlmRequest::mMaxNewTokens)
|
||||
.def_readwrite("sampling_config", &tb::LlmRequest::mSamplingConfig)
|
||||
.def_readwrite("state", &tb::LlmRequest::mState)
|
||||
.def_readwrite("is_streaming", &tb::LlmRequest::mIsStreaming)
|
||||
.def_readwrite("end_id", &tb::LlmRequest::mEndId)
|
||||
.def_readwrite("pad_id", &tb::LlmRequest::mPadId)
|
||||
.def_readwrite("seq_slot", &tb::LlmRequest::mSeqSlot)
|
||||
.def_property_readonly("return_log_probs", &tb::LlmRequest::returnLogProbs)
|
||||
.def_property_readonly("return_context_logits", &tb::LlmRequest::setReturnContextLogits)
|
||||
.def_property_readonly("return_generation_logits", &tb::LlmRequest::setReturnGenerationLogits)
|
||||
.def_property_readonly("log_probs", py::overload_cast<>(&tb::LlmRequest::getLogProbs, py::const_))
|
||||
.def("get_log_probs", py::overload_cast<tb::LlmRequest::SizeType32>(&tb::LlmRequest::getLogProbs, py::const_))
|
||||
.def("set_log_probs", &tb::LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
||||
.def("set_return_encoder_output", &tb::LlmRequest::setReturnEncoderOutput, py::arg("return_encoder_output"))
|
||||
.def("get_return_encoder_output", &tb::LlmRequest::getReturnEncoderOutput)
|
||||
.def("priority", py::overload_cast<>(&tb::LlmRequest::priority, py::const_))
|
||||
.def("set_priority", py::overload_cast<tle::PriorityType>(&tb::LlmRequest::setPriority))
|
||||
.def_property_readonly("cum_log_probs", &tb::LlmRequest::getCumLogProbs)
|
||||
.def("set_cum_log_prob", &tb::LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
||||
.def_property_readonly("orig_prompt_len", &tb::LlmRequest::getOrigPromptLen)
|
||||
.def("has_draft_tokens", &tb::LlmRequest::hasDraftTokens)
|
||||
.def("move_to_next_context_chunk", &tb::LlmRequest::moveToNextContextChunk)
|
||||
.def("is_last_context_chunk", py::overload_cast<>(&tb::LlmRequest::isLastContextChunk, py::const_))
|
||||
.def("is_first_context_chunk", py::overload_cast<>(&tb::LlmRequest::isFirstContextChunk, py::const_))
|
||||
.def(
|
||||
"get_context_remaining_length", py::overload_cast<>(&tb::LlmRequest::getContextRemainingLength, py::const_))
|
||||
.def_property(
|
||||
"draft_tokens", [](tb::LlmRequest& self) { return *self.getDraftTokens(); },
|
||||
[](tb::LlmRequest& self, tb::LlmRequest::VecTokens& draftTokens)
|
||||
{ self.setDraftTokens(std::make_shared<tb::LlmRequest::VecTokens>(std::move(draftTokens))); });
|
||||
|
||||
py::bind_vector<tb::RequestVector>(m, "RequestVector");
|
||||
}
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/ops/tensor.h>
|
||||
@ -25,6 +26,15 @@
|
||||
#include <optional>
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
class LlmRequestBindings
|
||||
{
|
||||
public:
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
@ -91,6 +101,7 @@ public:
|
||||
std::optional<LlmRequest::LogitsPostProcessor> callback);
|
||||
|
||||
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
|
||||
|
||||
static void initBindings(pybind11::module_& m);
|
||||
};
|
||||
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
|
||||
@ -23,18 +23,20 @@
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorrt_llm/pybind/batch_manager/gptManager.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/pybind/executor/bindings.h"
|
||||
#include "tensorrt_llm/pybind/utils/pathCaster.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/BatchManager.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
|
||||
#include "tensorrt_llm/common/mpiUtils.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/algorithms.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/bindings.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/gptManager.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/pybind/batch_manager/namedTensor.h"
|
||||
#include "tensorrt_llm/pybind/executor/bindings.h"
|
||||
#include "tensorrt_llm/pybind/utils/pathCaster.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
#include "tensorrt_llm/runtime/memoryCounters.h"
|
||||
@ -333,6 +335,10 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
|
||||
tpb::NamedTensor::initBindings(m);
|
||||
tpb::LlmRequest::initBindings(m);
|
||||
tb::kv_cache_manager::KVCacheManagerBindings::initBindings(m);
|
||||
tb::BasePeftCacheManagerBindings::initBindings(m);
|
||||
|
||||
tb::LlmRequestBindings::initBindings(m);
|
||||
|
||||
auto tensorNames = m.def_submodule("tensor_names");
|
||||
// Input tensor names
|
||||
@ -412,8 +418,6 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def(py::pickle(gptModelParamsGetState, gptModelParamsSetState))
|
||||
.def("__eq__", &tb::TrtGptModelOptionalParams::operator==);
|
||||
|
||||
tpb::GptManager::initBindings(m);
|
||||
|
||||
py::class_<tr::MemoryCounters>(m, "MemoryCounters")
|
||||
.def_static("instance", &tr::MemoryCounters::getInstance, py::return_value_policy::reference)
|
||||
.def_property_readonly("gpu", &tr::MemoryCounters::getGpu)
|
||||
@ -447,4 +451,11 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
auto& world = tensorrt_llm::mpi::MpiComm::world();
|
||||
tensorrt_llm::mpi::MpiComm::setSession(world.split(color, rank));
|
||||
});
|
||||
|
||||
auto mInternal = m.def_submodule("internal", "Internal submodule of TRTLLM runtime");
|
||||
|
||||
tensorrt_llm::pybind::batch_manager::initBindings(mInternal);
|
||||
tensorrt_llm::pybind::batch_manager::algorithms::initBindings(mInternal);
|
||||
|
||||
tpb::GptManager::initBindings(m);
|
||||
}
|
||||
|
||||
39
cpp/tensorrt_llm/pybind/common/algorithmBindings.h
Normal file
39
cpp/tensorrt_llm/pybind/common/algorithmBindings.h
Normal file
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <string>
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
namespace PybindUtils
|
||||
{
|
||||
template <typename T>
|
||||
void makeAlgorithmBindings(py::module_& m)
|
||||
{
|
||||
py::class_<T>(m, T::name).def(py::init()).def("forward", &T::forward).def("name", [](T const&) { return T::name; });
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void instantiatePybindAlgorithm(py::module_& m);
|
||||
} // namespace PybindUtils
|
||||
|
||||
#define INSTANTIATE_ALGORITHM(TYPE) \
|
||||
template <> \
|
||||
void PybindUtils::instantiatePybindAlgorithm<TYPE>(py::module_ & m) \
|
||||
{ \
|
||||
makeAlgorithmBindings<TYPE>(m); \
|
||||
};
|
||||
18
cpp/tensorrt_llm/pybind/common/opaqueBindings.h
Normal file
18
cpp/tensorrt_llm/pybind/common/opaqueBindings.h
Normal file
@ -0,0 +1,18 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/common.h"
|
||||
#include <pybind11/stl_bind.h>
|
||||
|
||||
PYBIND11_MAKE_OPAQUE(tensorrt_llm::batch_manager::RequestVector)
|
||||
@ -408,9 +408,12 @@ void InitBindings(pybind11::module_& m)
|
||||
.def_readwrite("is_sequence_final", &tle::Result::isSequenceFinal);
|
||||
|
||||
py::class_<tle::Response>(m, "Response")
|
||||
.def(py::init<IdType, std::string>(), py::arg("request_id"), py::arg("error_msg"))
|
||||
.def(py::init<IdType, tle::Result>(), py::arg("request_id"), py::arg("result"))
|
||||
.def(py::init<IdType, std::string, std::optional<IdType>>(), py::arg("request_id"), py::arg("error_msg"),
|
||||
py::arg("client_id") = std::nullopt)
|
||||
.def(py::init<IdType, tle::Result, std::optional<IdType>>(), py::arg("request_id"), py::arg("result"),
|
||||
py::arg("client_id") = std::nullopt)
|
||||
.def_property_readonly("request_id", &tle::Response::getRequestId)
|
||||
.def_property_readonly("client_id", &tle::Response::getClientId)
|
||||
.def("has_error", &tle::Response::hasError)
|
||||
.def_property_readonly("error_msg", &tle::Response::getErrorMsg)
|
||||
.def_property_readonly("result", &tle::Response::getResult);
|
||||
|
||||
@ -16,6 +16,8 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tensorrt_llm::pybind::executor
|
||||
|
||||
@ -16,8 +16,10 @@
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
@ -17,10 +17,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace PYBIND11_NAMESPACE
|
||||
{
|
||||
|
||||
@ -17,11 +17,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
#include "tensorrt_llm/executor/tensor.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include "tensorrt_llm/runtime/torch.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <torch/extension.h>
|
||||
|
||||
namespace PYBIND11_NAMESPACE
|
||||
|
||||
69
cpp/tensorrt_llm/pybind/utils/bindTypes.h
Normal file
69
cpp/tensorrt_llm/pybind/utils/bindTypes.h
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
|
||||
*
|
||||
* NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
|
||||
* property and proprietary rights in and to this material, related
|
||||
* documentation and any modifications thereto. Any use, reproduction,
|
||||
* disclosure or distribution of this material and related documentation
|
||||
* without an express license agreement from NVIDIA CORPORATION or
|
||||
* its affiliates is strictly prohibited.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <pybind11/pybind11.h>
|
||||
|
||||
namespace PybindUtils
|
||||
{
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
template <typename T>
|
||||
void bindList(py::module& m, std::string const& name)
|
||||
{
|
||||
py::class_<T>(m, name.c_str())
|
||||
.def(py::init())
|
||||
.def("push_back", [](T& lst, const typename T::value_type& value) { lst.push_back(value); })
|
||||
.def("pop_back", [](T& lst) { lst.pop_back(); })
|
||||
.def("push_front", [](T& lst, const typename T::value_type& value) { lst.push_front(value); })
|
||||
.def("pop_front", [](T& lst) { lst.pop_front(); })
|
||||
.def("__len__", [](T const& lst) { return lst.size(); })
|
||||
.def(
|
||||
"__iter__", [](T& lst) { return py::make_iterator(lst.begin(), lst.end()); }, py::keep_alive<0, 1>())
|
||||
.def("__getitem__",
|
||||
[](T const& lst, size_t index)
|
||||
{
|
||||
if (index >= lst.size())
|
||||
throw py::index_error();
|
||||
auto it = lst.begin();
|
||||
std::advance(it, index);
|
||||
return *it;
|
||||
})
|
||||
.def("__setitem__",
|
||||
[](T& lst, size_t index, const typename T::value_type& value)
|
||||
{
|
||||
if (index >= lst.size())
|
||||
throw py::index_error();
|
||||
auto it = lst.begin();
|
||||
std::advance(it, index);
|
||||
*it = value;
|
||||
});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void bindSet(py::module& m, std::string const& name)
|
||||
{
|
||||
py::class_<T>(m, name.c_str())
|
||||
.def(py::init())
|
||||
.def("clear", &T::clear)
|
||||
.def("size", &T::size)
|
||||
// .def("insert", py::overload_cast<const typename T::value_type&>(&T::insert))
|
||||
.def("erase", py::overload_cast<typename T::value_type const&>(&T::erase))
|
||||
.def("__contains__", [](T const& s, typename T::value_type x) { return s.find(x) != s.end(); })
|
||||
.def(
|
||||
"__iter__", [](T& s) { return py::make_iterator(s.begin(), s.end()); }, py::keep_alive<0, 1>());
|
||||
}
|
||||
|
||||
} // namespace PybindUtils
|
||||
@ -22,6 +22,7 @@
|
||||
#include "pybind11/detail/descr.h"
|
||||
#include "pybind11/pybind11.h"
|
||||
#include "pybind11/pytypes.h"
|
||||
#include "tensorrt_llm/pybind/common/opaqueBindings.h"
|
||||
#include <filesystem>
|
||||
|
||||
namespace PYBIND11_NAMESPACE
|
||||
|
||||
@ -161,6 +161,19 @@ void GptDecoder<T>::setup(SamplingConfig const& samplingConfig, size_t batchSize
|
||||
lookaheadParams->attentionPackedMasks = output->lookaheadOutputs->packedMasks;
|
||||
setupParams->decodingParams = std::move(lookaheadParams);
|
||||
}
|
||||
else if (mDecodingMode.isExternalDraftTokens())
|
||||
{
|
||||
auto externalDraftTokensParams = std::make_shared<tl::ExternalDraftTokensSetupParams>();
|
||||
// signed to unsigned
|
||||
if (mSamplingConfig.topK)
|
||||
{
|
||||
auto const& topK = mSamplingConfig.topK.value();
|
||||
externalDraftTokensParams->runtimeTopK = std::vector<SizeType32>(std::begin(topK), std::end(topK));
|
||||
}
|
||||
|
||||
externalDraftTokensParams->runtimeTopP = mSamplingConfig.topP;
|
||||
setupParams->decodingParams = std::move(externalDraftTokensParams);
|
||||
}
|
||||
setupParams->decodingParams->randomSeed = mSamplingConfig.randomSeed;
|
||||
|
||||
mDecodingLayerWorkspace->setDeviceBatchSlots(batchSlots);
|
||||
@ -244,6 +257,27 @@ void prepareMedusaInputs(
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void prepareExternalDraftTokensInputs(
|
||||
DecodingInput const& inputs, size_t maxBatchSize, std::shared_ptr<tl::DecodingInputs>& baseInputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto inputParams = std::dynamic_pointer_cast<tl::ExternalDraftTokensInputs>(baseInputs);
|
||||
|
||||
auto const& externalDraftTokensInputs = inputs.externalDraftTokensInputs.value();
|
||||
|
||||
inputParams->draftLogits = externalDraftTokensInputs.draftLogits;
|
||||
inputParams->draftProbs = externalDraftTokensInputs.draftProbs;
|
||||
inputParams->targetProbs = externalDraftTokensInputs.targetProbs;
|
||||
inputParams->numDraftTokens = externalDraftTokensInputs.numDraftTokens;
|
||||
inputParams->draftTokenIds = externalDraftTokensInputs.draftTokenIds;
|
||||
inputParams->constantThreshold = externalDraftTokensInputs.constantThreshold;
|
||||
inputParams->useRandomAcceptanceThreshold = externalDraftTokensInputs.useRandomAcceptanceThreshold;
|
||||
inputParams->step = externalDraftTokensInputs.step;
|
||||
inputParams->useDraftLogits = externalDraftTokensInputs.useDraftLogits;
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void prepareExplicitDraftTokensInput(DecodingInput const& inputs, std::shared_ptr<tl::DecodingInputs>& baseInputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
@ -316,6 +350,11 @@ std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
|
||||
forwardParams
|
||||
= std::make_shared<tl::ExplicitDraftTokensInputs>(input.endIds, input.batchSlots, input.batchSize);
|
||||
}
|
||||
else if (decodingMode.isExternalDraftTokens())
|
||||
{
|
||||
forwardParams = std::make_shared<tl::ExternalDraftTokensInputs>(
|
||||
input.endIds, input.batchSlots, input.step, ite, input.batchSize);
|
||||
}
|
||||
|
||||
// No logits for explicit draft tokens
|
||||
if (!decodingMode.isExplicitDraftTokens())
|
||||
@ -379,6 +418,11 @@ std::shared_ptr<tl::BaseDecodingInputs> prepareInputs(
|
||||
forwardParams->localBatchSize = input.batchSize;
|
||||
}
|
||||
|
||||
if (decodingMode.isExternalDraftTokens())
|
||||
{
|
||||
prepareExternalDraftTokensInputs(input, maxBatchSize, forwardParams);
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
return forwardParams;
|
||||
@ -593,105 +637,3 @@ namespace tensorrt_llm::runtime
|
||||
template class GptDecoder<float>;
|
||||
template class GptDecoder<half>;
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
void IGptDecoder::acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
|
||||
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& finishedVec,
|
||||
ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots, BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const finishedVecShape = finishedVec.getShape();
|
||||
auto const maxBatchSize = finishedVecShape.d[1];
|
||||
auto const batchSlotsShape = batchSlots.getShape();
|
||||
auto const batchSize = batchSlotsShape.d[0];
|
||||
auto const targetTokenIdsShape = targetTokenIds.getShape();
|
||||
auto const beamWidth = targetTokenIdsShape.d[1];
|
||||
auto const maxSeqLength = targetTokenIdsShape.d[2];
|
||||
auto const maxDraftTokens = draftTokenIds.getDimension<1>();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(beamWidth == 1,
|
||||
common::fmtstr("Beam width (" FMT_DIM ") > 1 is not supported for the speculative decoding", beamWidth));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSize <= maxBatchSize,
|
||||
common::fmtstr("Batch size (" FMT_DIM ") is not smaller or equal to max batch size (" FMT_DIM ")", batchSize,
|
||||
maxBatchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(draftTokenIds.getDimension<0>() == maxBatchSize,
|
||||
common::fmtstr("Draft tokens batch size (" FMT_DIM ") is not equal to target batch size (" FMT_DIM ")",
|
||||
draftTokenIds.getDimension<0>(), maxBatchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(contextLengths.getDimension<0>() == maxBatchSize,
|
||||
common::fmtstr("Context length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
|
||||
contextLengths.getDimension<0>(), maxBatchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(numDraftTokens.getDimension<0>() == maxBatchSize,
|
||||
common::fmtstr("Num draft tokens batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
|
||||
numDraftTokens.getDimension<0>(), maxBatchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths.getDimension<0>() == maxBatchSize,
|
||||
common::fmtstr("Sequence length batch size (" FMT_DIM ") is not equal to batch size (" FMT_DIM ")",
|
||||
sequenceLengths.getDimension<0>(), maxBatchSize));
|
||||
|
||||
tksd::invokeAcceptDraftTokensByIds(bufferCast<TokenIdType>(draftTokenIds), bufferCast<TokenIdType>(targetTokenIds),
|
||||
bufferCast<SizeType32>(contextLengths), bufferCast<SizeType32>(numDraftTokens),
|
||||
bufferCast<SizeType32>(sequenceLengths),
|
||||
reinterpret_cast<tensorrt_llm::kernels::FinishedState const*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedVec)),
|
||||
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finishedFinal)),
|
||||
bufferCast<int>(finishedSum), bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth,
|
||||
maxSeqLength, maxDraftTokens, stream->get());
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void IGptDecoder::acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
|
||||
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
|
||||
SizeType32 vocabSize, SizeType32 vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const draftLogitsShape = draftLogits.getShape();
|
||||
auto const maxBatchSize = draftLogitsShape.d[0];
|
||||
auto const maxTokensPerStep = draftLogitsShape.d[1];
|
||||
auto const batchSlotsShape = batchSlots.getShape();
|
||||
auto const batchSize = batchSlotsShape.d[0];
|
||||
auto constexpr beamWidth = 1;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for the speculative decoding", beamWidth));
|
||||
|
||||
TLLM_CHECK(draftLogitsShape.d[2] == vocabSize);
|
||||
|
||||
if (draftLogits.getDataType() == nvinfer1::DataType::kFLOAT)
|
||||
{
|
||||
tksd::acceptDraftTokensByLogits(bufferCast<float>(draftLogits),
|
||||
const_cast<float**>(reinterpret_cast<float const* const*>(bufferCast<int64_t>(targetLogits))),
|
||||
bufferCast<float>(draftProbs), bufferCast<float>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
|
||||
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
|
||||
curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
|
||||
vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
|
||||
}
|
||||
else if (draftLogits.getDataType() == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
tksd::acceptDraftTokensByLogits(bufferCast<half>(draftLogits),
|
||||
const_cast<half**>(reinterpret_cast<half const* const*>(bufferCast<int64_t>(targetLogits))),
|
||||
bufferCast<half>(draftProbs), bufferCast<half>(targetProbs), bufferCast<SizeType32>(numDraftTokens),
|
||||
reinterpret_cast<tensorrt_llm::kernels::FinishedState*>(
|
||||
bufferCast<tensorrt_llm::kernels::FinishedState::UnderlyingType>(finished)),
|
||||
curandState, bufferCast<SizeType32>(batchSlots), batchSize, maxBatchSize, beamWidth, vocabSize,
|
||||
vocabSizePadded, maxTokensPerStep, useRandomAcceptThreshold, randomAcceptThreshold, stream->get());
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Incorrect logits dtype. Only float32 and float16 are supported");
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -93,30 +93,28 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
|
||||
auto constexpr nvFloatType = TRTDataType<float>::value;
|
||||
|
||||
auto& dInput = mJointDecodingInput;
|
||||
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
auto batchSlots = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
|
||||
dInput
|
||||
= std::make_unique<DecodingInput>(0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
|
||||
|
||||
{ // prevent reusing these vars after std::move
|
||||
auto dummyLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
auto endIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
auto batchSlots = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
|
||||
dInput = std::make_unique<DecodingInput>(
|
||||
0, 0, 0, 0, std::move(dummyLogits), std::move(endIds), std::move(batchSlots));
|
||||
}
|
||||
dInput->sequenceLimitLength = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
dInput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
|
||||
auto& dOutput = mJointDecodingOutput;
|
||||
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
auto gatheredOutputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds), std::move(gatheredOutputIds));
|
||||
|
||||
{ // prevent reusing these vars after std::move
|
||||
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
auto gatheredOutputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds), std::move(gatheredOutputIds));
|
||||
}
|
||||
dOutput->newTokensSteps = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
mFinishedSteps
|
||||
= mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<tk::FinishedState::UnderlyingType>::value);
|
||||
mDraftProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
mTargetProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
mBatchSlotsSetup = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
mBatchSlotsDecoder = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
mBatchSlotsAcceptTokens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
mBatchSlotsAcceptLogits = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
mBatchSlotsSetup = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
|
||||
mBatchSlotsDecoder = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
|
||||
// use batchSize many entries instead of the usual 1
|
||||
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
|
||||
mFinishedSum = BufferManager::pinned(ITensor::makeShape({1}), nvSizeType);
|
||||
@ -129,16 +127,10 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
|
||||
|
||||
dOutput->logProbsTiled = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<float>::value);
|
||||
|
||||
mNumDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
mCurandStates = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT8);
|
||||
mDraftTokenIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
|
||||
mDraftLogits = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
mTargetLogitsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<float*>::value);
|
||||
|
||||
dInput->stopWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<int32_t*>::value);
|
||||
dInput->stopWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
dInput->stopWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
|
||||
dInput->badWordsPtrs = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<int32_t*>::value);
|
||||
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, TRTDataType<SizeType32>::value);
|
||||
dInput->badWordsLens = mBufferManager.emptyTensor(MemoryType::kPINNEDPOOL, nvSizeType);
|
||||
dInput->embeddingBias = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
|
||||
int device;
|
||||
@ -149,13 +141,13 @@ GptDecoderBatched::GptDecoderBatched(std::size_t vocabSize, std::size_t vocabSiz
|
||||
|
||||
if (!mSpeculativeDecodingMode.isNone())
|
||||
{
|
||||
allocateSpeculativeDecodingBuffers();
|
||||
allocateSpeculativeDecodingBuffers(dtype);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatched::allocateSpeculativeDecodingBuffers()
|
||||
void GptDecoderBatched::allocateSpeculativeDecodingBuffers(nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto constexpr nvSizeType = TRTDataType<SizeType32>::value;
|
||||
@ -201,6 +193,22 @@ void GptDecoderBatched::allocateSpeculativeDecodingBuffers()
|
||||
}
|
||||
dOutput->speculativeDecodingOutputs = speculativeDecodingOutputs;
|
||||
|
||||
if (mSpeculativeDecodingMode.isDraftTokensExternal())
|
||||
{
|
||||
DecodingInput::ExternalDraftTokensInputs externalDraftTokensInputs;
|
||||
|
||||
externalDraftTokensInputs.draftLogits = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.draftProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.targetProbs = mBufferManager.emptyTensor(MemoryType::kGPU, dtype);
|
||||
externalDraftTokensInputs.numDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
externalDraftTokensInputs.useDraftLogits
|
||||
= mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
|
||||
externalDraftTokensInputs.draftTokenIds
|
||||
= mBufferManager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
|
||||
|
||||
dInput->externalDraftTokensInputs = externalDraftTokensInputs;
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -251,6 +259,7 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
|
||||
auto const maxTokensPerStepXmaxBatchSizeXmaxBeamWidth
|
||||
= ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize, maxBeamWidth});
|
||||
auto const maxBatchSizeXmaxTokensPerStep = ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep});
|
||||
auto const jointOutputIdsShape = ITensor::makeShape({maxBatchSize, maxBeamWidth, maxSequenceLength});
|
||||
|
||||
auto& dInput = *mJointDecodingInput;
|
||||
dInput.maxLength = mMaxSequenceLength;
|
||||
@ -268,8 +277,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
|
||||
inputLengths.reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(inputLengths);
|
||||
|
||||
auto const jointOutputIdsShape = ITensor::makeShape({maxBatchSize, maxBeamWidth, maxSequenceLength});
|
||||
|
||||
auto& dOutput = *mJointDecodingOutput;
|
||||
dOutput.ids->reshape(jointOutputIdsShape);
|
||||
|
||||
@ -296,15 +303,18 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
|
||||
|
||||
mBatchSlotsSetup->reshape(ITensor::makeShape({maxBatchSize}));
|
||||
mBatchSlotsDecoder->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
|
||||
mBatchSlotsAcceptTokens->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
|
||||
mBatchSlotsAcceptLogits->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
|
||||
|
||||
if (mSpeculativeDecodingMode.isDraftTokensExternal())
|
||||
{
|
||||
mDraftProbs->reshape(ITensor::makeShape(
|
||||
dInput.externalDraftTokensInputs->draftProbs->reshape(ITensor::makeShape(
|
||||
{maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(mVocabSizePadded)}));
|
||||
mTargetProbs->reshape(ITensor::makeShape(
|
||||
dInput.externalDraftTokensInputs->targetProbs->reshape(ITensor::makeShape(
|
||||
{maxBatchSize, maxTokensPerEngineStep, maxBeamWidth, static_cast<SizeType32>(mVocabSizePadded)}));
|
||||
dInput.externalDraftTokensInputs->draftLogits->reshape(
|
||||
ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep, static_cast<SizeType32>(mVocabSizePadded)}));
|
||||
dInput.externalDraftTokensInputs->draftTokenIds->reshape(maxBatchSizeXmaxTokensPerStep);
|
||||
dInput.externalDraftTokensInputs->numDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
|
||||
dInput.externalDraftTokensInputs->useDraftLogits->reshape(ITensor::makeShape({maxBatchSize, 1}));
|
||||
}
|
||||
|
||||
dOutput.parentIds->reshape(jointOutputIdsShape);
|
||||
@ -317,7 +327,7 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
|
||||
dOutput.cumLogProbs->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.cumLogProbs);
|
||||
|
||||
dOutput.logProbs->reshape(ITensor::makeShape({maxBatchSize, maxBeamWidth, mMaxSequenceLength}));
|
||||
dOutput.logProbs->reshape(jointOutputIdsShape);
|
||||
mBufferManager.setZero(*dOutput.logProbs);
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
@ -328,15 +338,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max
|
||||
dOutput.logProbsTiled->reshape(ITensor::makeShape({maxSequenceLength, maxBatchSize, maxBeamWidth}));
|
||||
mBufferManager.setZero(*dOutput.logProbsTiled);
|
||||
|
||||
// speculative decoding only works for beam width == 1
|
||||
mDraftTokenIds->reshape(maxBatchSizeXmaxTokensPerStep);
|
||||
mDraftLogits->reshape(
|
||||
ITensor::makeShape({maxBatchSize, maxTokensPerEngineStep, static_cast<SizeType32>(mVocabSizePadded)}));
|
||||
mAcceptByLogits.resize(maxBatchSize);
|
||||
mNumDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
|
||||
mCurandStates->reshape(ITensor::makeShape({maxBatchSize, sizeof(curandState_t)}));
|
||||
mTargetLogitsPtrs->reshape(ITensor::makeShape({maxTokensPerEngineStep, maxBatchSize}));
|
||||
|
||||
const_cast<ITensor&>(*dInput.embeddingBias)
|
||||
.reshape(ITensor::makeShape({maxBatchSize, static_cast<SizeType32>(mVocabSizePadded)}));
|
||||
const_cast<ITensor&>(*dInput.badWordsPtrs).reshape(ITensor::makeShape({maxBatchSize}));
|
||||
@ -591,7 +592,6 @@ void GptDecoderBatched::newRequestSpeculativeDecoding(
|
||||
SizeType32 batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
mAcceptByLogits[batchIdx] = false;
|
||||
|
||||
if (mSpeculativeDecodingMode.predictsDraftTokens())
|
||||
{
|
||||
@ -639,40 +639,41 @@ void GptDecoderBatched::newRequestDraftTokensExternal(
|
||||
auto const& stream = mDecoderStream;
|
||||
BufferManager manager{stream};
|
||||
|
||||
auto constexpr localBatchSize = 1;
|
||||
auto& dJointInput = *mJointDecodingInput;
|
||||
auto useDraftLogits = false;
|
||||
|
||||
auto const numDraftTokens = request.generatedTokensPerEngineStep - 1;
|
||||
if (request.draftLogits.has_value())
|
||||
{
|
||||
TensorPtr draftLogitsView = ITensor::view(request.draftLogits.value());
|
||||
mAcceptByLogits[batchIdx] = true;
|
||||
useDraftLogits = true;
|
||||
|
||||
TensorPtr draftLogitsReqBatchSlice = ITensor::slice(mDraftLogits, batchIdx, 1);
|
||||
TensorPtr draftLogitsReqBatchSlice
|
||||
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftLogits, batchIdx, 1);
|
||||
draftLogitsReqBatchSlice->squeeze(0);
|
||||
TensorPtr draftLogitsReqTokensSlice = ITensor::slice(draftLogitsReqBatchSlice, 0, numDraftTokens);
|
||||
manager.copy(*draftLogitsView, *draftLogitsReqTokensSlice);
|
||||
}
|
||||
TensorPtr draftTokensReqBatchSlice = ITensor::slice(mDraftTokenIds, batchIdx, 1);
|
||||
auto useDraftLogitsView = ITensor::slice(dJointInput.externalDraftTokensInputs->useDraftLogits, batchIdx, 1);
|
||||
kernels::invokeFill(*useDraftLogitsView, useDraftLogits, *stream);
|
||||
|
||||
TensorPtr draftTokensReqBatchSlice
|
||||
= ITensor::slice(dJointInput.externalDraftTokensInputs->draftTokenIds, batchIdx, 1);
|
||||
draftTokensReqBatchSlice->squeeze(0);
|
||||
TensorPtr draftTokensReqTokensSlice = ITensor::slice(draftTokensReqBatchSlice, 0, numDraftTokens);
|
||||
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({numDraftTokens}));
|
||||
manager.copy(*draftTokensView, *draftTokensReqTokensSlice);
|
||||
|
||||
auto const curandStatesView = ITensor::slice(mCurandStates, batchIdx, 1);
|
||||
auto curandState = reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*curandStatesView));
|
||||
auto batchSlotsPtr = bufferCast<SizeType32>(*ITensor::slice(mBatchSlotsSetup, 0, localBatchSize));
|
||||
if (samplingConfig.randomSeed.has_value())
|
||||
{
|
||||
tk::invokeCurandInitialize(
|
||||
curandState, batchSlotsPtr, localBatchSize, samplingConfig.randomSeed.value()[0], stream->get());
|
||||
}
|
||||
else
|
||||
{
|
||||
tk::invokeCurandInitialize(curandState, batchSlotsPtr, localBatchSize, 0, stream->get());
|
||||
}
|
||||
auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, 1);
|
||||
auto numDraftTokensView = ITensor::slice(dJointInput.externalDraftTokensInputs->numDraftTokens, batchIdx, 1);
|
||||
kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream);
|
||||
|
||||
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
|
||||
float const constantThreshold
|
||||
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
|
||||
|
||||
dJointInput.externalDraftTokensInputs->useRandomAcceptanceThreshold = useRandomAcceptanceThreshold;
|
||||
dJointInput.externalDraftTokensInputs->constantThreshold = constantThreshold;
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -838,8 +839,6 @@ void GptDecoderBatched::forwardDecoder(
|
||||
|
||||
auto batchSlotsDecoderPtr = maxBeamWidth > 1 && input.seqSlots ? bufferCast<SizeType32>(*input.seqSlots)
|
||||
: bufferCast<SizeType32>(*mBatchSlotsDecoder);
|
||||
auto batchSlotsAcceptTokensPtr = bufferCast<SizeType32>(*mBatchSlotsAcceptTokens);
|
||||
auto batchSlotsAcceptLogitsPtr = bufferCast<SizeType32>(*mBatchSlotsAcceptLogits);
|
||||
auto& dInput = *mJointDecodingInput;
|
||||
auto& dOutput = *mJointDecodingOutput;
|
||||
auto& decoder = *mDecoder;
|
||||
@ -864,26 +863,12 @@ void GptDecoderBatched::forwardDecoder(
|
||||
}
|
||||
|
||||
SizeType32 localBatchDecoderIdx = 0;
|
||||
SizeType32 localBatchAcceptTokensIdx = 0;
|
||||
SizeType32 localBatchAcceptLogitsIdx = 0;
|
||||
for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
|
||||
{
|
||||
if (mFinished[bi] || !input.active.at(bi) || step >= mNumDecodingEngineTokens[bi])
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if (!mAcceptByLogits[bi] && mMaxDecodingDecoderTokens == 1 && mNumDecodingEngineTokens[bi] > 1
|
||||
&& step == mNumDecodingEngineTokens[bi] - 1)
|
||||
{
|
||||
batchSlotsAcceptTokensPtr[step * mActualBatchSize + localBatchAcceptTokensIdx] = bi;
|
||||
localBatchAcceptTokensIdx++;
|
||||
}
|
||||
else if (mAcceptByLogits[bi] && mMaxDecodingDecoderTokens == 1 && mNumDecodingEngineTokens[bi] > 1 && step == 0)
|
||||
{
|
||||
batchSlotsAcceptLogitsPtr[step * mActualBatchSize + localBatchAcceptLogitsIdx] = bi;
|
||||
localBatchAcceptLogitsIdx++;
|
||||
}
|
||||
batchSlotsDecoderPtr[step * mActualBatchSize + localBatchDecoderIdx] = bi;
|
||||
localBatchDecoderIdx++;
|
||||
}
|
||||
@ -892,9 +877,6 @@ void GptDecoderBatched::forwardDecoder(
|
||||
= *std::max_element(std::begin(mNumDecodingEngineTokens), std::end(mNumDecodingEngineTokens));
|
||||
|
||||
std::vector<SharedConstPtr> logitsVec;
|
||||
auto targetLogitsPtrsSlice = ITensor::slice(mTargetLogitsPtrs, step, 1);
|
||||
auto targetLogitsPtrsSlicePtr = reinterpret_cast<void const**>(bufferCast<int64_t>(*targetLogitsPtrsSlice));
|
||||
SizeType32 targetLogitsIdx = 0;
|
||||
for (SizeType32 bi = 0; bi < mActualBatchSize; ++bi)
|
||||
{
|
||||
if (mFinished[bi] || !input.active.at(bi) || step >= mNumDecodingEngineTokens[bi])
|
||||
@ -904,32 +886,6 @@ void GptDecoderBatched::forwardDecoder(
|
||||
auto const& targetLogits = allTargetLogits[bi];
|
||||
TensorPtr logitsSlice = ITensor::slice(targetLogits, step, singleRequest);
|
||||
logitsVec.push_back(logitsSlice);
|
||||
targetLogitsPtrsSlicePtr[targetLogitsIdx++] = logitsSlice->data();
|
||||
}
|
||||
|
||||
if (async && localBatchAcceptLogitsIdx > 0)
|
||||
{
|
||||
// These params are only used for testing. Thus, can be per batch instead of per request
|
||||
auto const& samplingConfig = decoder.getSamplingConfig();
|
||||
bool const useRandomAcceptanceThreshold = !samplingConfig.draftAcceptanceThreshold.has_value();
|
||||
float const randomAcceptanceThreshold
|
||||
= useRandomAcceptanceThreshold ? 0 : samplingConfig.draftAcceptanceThreshold.value()[0];
|
||||
|
||||
TensorPtr batchSlotsAcceptLogitsStepSlice = ITensor::slice(mBatchSlotsAcceptLogits, step, 1);
|
||||
batchSlotsAcceptLogitsStepSlice->squeeze(0);
|
||||
TensorPtr batchSlotsAcceptLogitsSlice
|
||||
= ITensor::slice(batchSlotsAcceptLogitsStepSlice, 0, localBatchAcceptLogitsIdx);
|
||||
|
||||
IGptDecoder::acceptDraftTokensByLogits(
|
||||
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mDraftLogits,
|
||||
/* [maxBatchSize][maxDecodingTokens, vocabPadded] */ *targetLogitsPtrsSlice,
|
||||
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mDraftProbs,
|
||||
/* [maxBatchSize, maxDecodingTokens, vocabPadded] */ *mTargetProbs,
|
||||
/* [maxBatchSize] */ *mNumDraftTokens,
|
||||
/* [maxDecodingTokens, maxBatchSize] */ *mFinishedSteps,
|
||||
/* [bs] */ *batchSlotsAcceptLogitsSlice, static_cast<SizeType32>(mVocabSize),
|
||||
static_cast<SizeType32>(mVocabSizePadded), useRandomAcceptanceThreshold, randomAcceptanceThreshold,
|
||||
reinterpret_cast<curandState_t*>(bufferCast<int8_t>(*mCurandStates)), stream);
|
||||
}
|
||||
|
||||
TensorPtr finishedStepsInput = ITensor::slice(mFinishedSteps, step, 1);
|
||||
@ -958,6 +914,11 @@ void GptDecoderBatched::forwardDecoder(
|
||||
dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
|
||||
}
|
||||
|
||||
if (mSpeculativeDecodingMode.isDraftTokensExternal())
|
||||
{
|
||||
dInput.externalDraftTokensInputs->step = step;
|
||||
}
|
||||
|
||||
dOutput.newTokens = newTokensStepView;
|
||||
dOutput.finishReasons = finishedStepsOutput;
|
||||
dOutput.lengths = sequenceLengths;
|
||||
@ -987,26 +948,6 @@ void GptDecoderBatched::forwardDecoder(
|
||||
mNbSteps[bi] += 1;
|
||||
mFinished[bi] = mNbSteps[bi] >= mMaxNewTokens[bi];
|
||||
}
|
||||
if (async && localBatchAcceptTokensIdx > 0)
|
||||
{
|
||||
TensorPtr batchSlotsAcceptTokensStepSlice = ITensor::slice(mBatchSlotsAcceptTokens, step, 1);
|
||||
batchSlotsAcceptTokensStepSlice->squeeze(0);
|
||||
auto batchSlotsAcceptTokensSlice
|
||||
= ITensor::slice(batchSlotsAcceptTokensStepSlice, 0, localBatchAcceptTokensIdx);
|
||||
|
||||
// Update finished state for 0th step
|
||||
auto finishedFinal = ITensor::slice(mFinishedSteps, step, 1);
|
||||
IGptDecoder::acceptDraftTokensByIds(
|
||||
/* [maxBatchSize, maxBeamWidth, maxSeqLen] */ *dOutput.ids,
|
||||
/* [maxBatchSize, maxDecodingDraftTokens] */ *mDraftTokenIds,
|
||||
/* [maxBatchSize] */ *dInput.lengths,
|
||||
/* [maxBatchSize] */ *mNumDraftTokens,
|
||||
/* [maxBatchSize] */ *dOutput.lengths,
|
||||
/* [maxDecodingTokens, maxBatchSize] */ *mFinishedSteps,
|
||||
/* [maxBatchSize] */ *finishedFinal,
|
||||
/* [maxBatchSize] */ *dOutput.finishedSum,
|
||||
/* [bs] */ *batchSlotsAcceptTokensSlice, stream);
|
||||
}
|
||||
|
||||
// If last iteration
|
||||
if (async && step == maxDecodingEngineTokens - mMaxDecodingDecoderTokens)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user