Update TensorRT-LLM (#465)

* Update TensorRT-LLM

---------

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2023-11-24 22:12:26 +08:00 committed by GitHub
parent 6755a3f077
commit 711a28d9bf
186 changed files with 4365 additions and 2485 deletions

View File

@ -261,7 +261,7 @@ The list of supported models is:
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models easiler.
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models more easily.
## Performance

View File

@ -16,25 +16,19 @@
*/
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <chrono>
#include <cxxopts.hpp>
#include <iostream>
#include <nlohmann/json.hpp>
#include <string>
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::mpi;
@ -456,7 +450,7 @@ std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
}
void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
std::string const& datasetPath, int beamWidth, std::shared_ptr<nvinfer1::ILogger> const& logger,
std::string const& datasetPath, int beamWidth, int warmUp, std::shared_ptr<nvinfer1::ILogger> const& logger,
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy)
{
auto const worldConfig = WorldConfig::mpi(*logger);
@ -506,6 +500,19 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con
if (worldConfig.getRank() == 0)
{
// Warm up
for (auto i = 1; i < warmUp + 1; ++i)
{
// skip terminateReqId
if (i == terminateReqId)
{
i += 1;
}
gptServer->enqueue(tensors_list[0], i, false);
}
gptServer->waitForEmpty();
// Benchmark
recorder->initialize();
for (int i = 0; i < tensors_list.size(); ++i)
{
@ -539,6 +546,8 @@ int main(int argc, char* argv[])
cxxopts::value<std::string>()->default_value(""));
options.add_options()(
"beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
options.add_options()(
"warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>());
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
@ -651,7 +660,7 @@ int main(int argc, char* argv[])
try
{
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
datasetPath, beamWidth, logger, optionalParams, schedulerPolicy);
datasetPath, beamWidth, result["warm_up"].as<int>(), logger, optionalParams, schedulerPolicy);
}
catch (const std::exception& e)
{

View File

@ -43,9 +43,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
std::string modelNameHyphen = modelName;
std::filesystem::path jsonFileName = dataPath / "config.json";
if (tc::strStartsWith(modelName, "chatglm"))
if (tc::strStartsWith(modelName, "chatglm") || tc::strStartsWith(modelName, "glm"))
{
std::replace(modelNameHyphen.begin(), modelNameHyphen.end(), '_', '-');
jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
}
auto const json = GptJsonConfig::parse(jsonFileName);

View File

@ -15,8 +15,10 @@
from argparse import ArgumentParser
import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
# isort: on
from cuda import cuda, cudart
from mpi4py import MPI
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork

View File

@ -376,6 +376,7 @@ _allowed_configs = {
build_config=BuildConfig(
num_layers=28,
num_heads=32,
num_kv_heads=2,
hidden_size=4096,
vocab_size=65024,
hidden_act='swiglu',
@ -393,6 +394,7 @@ _allowed_configs = {
build_config=BuildConfig(
num_layers=28,
num_heads=32,
num_kv_heads=2,
hidden_size=4096,
vocab_size=65024,
hidden_act='swiglu',

View File

@ -45,7 +45,9 @@ def get_engine_name(model, dtype, tp_size, rank):
def serialize_engine(engine, path):
with open(path, 'wb') as f:
f.write(bytearray(engine))
# the engine object already complies with the Python buffer protocol, so there is no need to
# convert it to a bytearray before writing; converting to a bytearray consumes a lot of memory
f.write(engine)
class BaseBenchmark(object):

View File

@ -14,11 +14,9 @@
# limitations under the License.
import argparse
import multiprocessing as mp
from multiprocessing import Process, Queue
from time import time
import torch
from mem_monitor import mem_monitor
def parse_arguments():
@ -213,6 +211,7 @@ def main(args):
from allowed_configs import get_allowed_models
from bert_benchmark import BERTBenchmark
from gpt_benchmark import GPTBenchmark
from mem_monitor import MemoryMonitor
from tensorrt_llm.logger import logger
@ -282,11 +281,8 @@ def main(args):
torch.cuda.empty_cache()
latencies = []
# Launch a subprocess to monitor memory usage
q1 = Queue() # q1 is used for sending signal to subprocess
q2 = Queue() # q2 is used for receiving results from subprocess
mem_monitor_process = Process(target=mem_monitor, args=(q1, q2))
mem_monitor_process.start()
memory_monitor = MemoryMonitor()
memory_monitor.start()
iter_idx = 0
try:
@ -313,15 +309,12 @@ def main(args):
except Exception as e:
print("Found exception during benchmarking", e.with_traceback())
mem_monitor_process.kill()
memory_monitor.kill()
raise e
logger.debug("Sending signal to mem monitor process, start")
q1.put(1)
logger.debug("Sending signal to mem monitor process, done")
peak_gpu_used = q2.get()
logger.debug("Get peak gpu memory usage from mem monitor process, done")
mem_monitor_process.join()
logger.debug("Memory monitor process joined")
memory_monitor.stop()
_, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
peak_gpu_used = round(peak_gpu_used, 3)
latency = round(sum(latencies) / iter_idx, 3)
latencies.sort()

View File

@ -16,8 +16,10 @@ import os
import time
from collections import OrderedDict
import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
#isort: on
from allowed_configs import get_build_config
from base_benchmark import BaseBenchmark, serialize_engine

View File

@ -91,6 +91,7 @@ class GPTBenchmark(BaseBenchmark):
self.use_layernorm_plugin = False
self.use_rmsnorm_plugin = False
self.use_lookup_plugin = non_mha_plg_dtype
self.use_weight_only_quant_gemm_plugin = non_mha_plg_dtype
self.enable_context_fmha = use_mha_plugin
self.remove_input_padding = use_non_mha_plugin
@ -288,7 +289,8 @@ class GPTBenchmark(BaseBenchmark):
max_batch_size=self.max_batch_size,
max_input_len=self.max_input_len,
max_output_len=self.max_output_len,
int8=self.quant_mode.has_act_and_weight_quant(),
int8=self.quant_mode.has_act_and_weight_quant()
or self.quant_mode.is_int8_weight_only(),
quant_mode=self.quant_mode,
use_refit=self.refit,
opt_level=self.builder_opt,
@ -505,7 +507,7 @@ class GPTBenchmark(BaseBenchmark):
dtype=self.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
elif self.use_weight_only:
elif self.use_weight_only and self.use_weight_only_quant_gemm_plugin:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=self.dtype)

View File

@ -12,29 +12,57 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from multiprocessing import Event, Process, Queue
import pynvml
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
device_memory_info)
def get_memory_info(handle):
mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle,
version=pynvml.nvmlMemory_v2)
total = round(mem_info.total / 1024 / 1024 / 1024, 2)
used = round(mem_info.used / 1024 / 1024 / 1024, 2)
free = round(mem_info.free / 1024 / 1024 / 1024, 2)
return total, used, free
class MemoryMonitor:
def __init__(self, query_interval=0.1):
self.query_interval = query_interval # second(s)
self.mem_monitor_process = None
# bytes
self._peak_host_memory = 0
self._peak_device_memory = 0
def mem_monitor(q1, q2):
pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)
self.device_handles = {}
peak_used = 0
while q1.empty():
_, used, _ = get_memory_info(handle)
peak_used = max(used, peak_used)
time.sleep(0.1)
self.signal_event = Event() # Sending signal to subprocess
self.peak_mem_queue = Queue() # Receiving results from subprocess
pynvml.nvmlShutdown()
q2.put(peak_used)
def start(self):
self.mem_monitor_process = Process(target=self._upd_peak_memory_usage,
args=(self.signal_event,
self.peak_mem_queue))
self.mem_monitor_process.start()
logger.debug("Launched memory monitor subprocess.")
def kill(self):
if self.mem_monitor_process is not None:
self.mem_monitor_process.kill()
logger.debug("Memory monitor subprocess is killed.")
def stop(self):
self.signal_event.set()
logger.debug("Sent signal to stop memory monitor subprocess.")
self._peak_device_memory = max(self._peak_device_memory,
self.peak_mem_queue.get())
self.mem_monitor_process.join()
self.mem_monitor_process = None
logger.debug("Memory monitor subprocess joined.")
def _upd_peak_memory_usage(self, signal_event, peak_mem_queue):
peak_used, _, _ = device_memory_info()
while not signal_event.is_set():
used, _, _ = device_memory_info()
peak_used = max(used, peak_used)
peak_mem_queue.put(peak_used)
def get_peak_memory_usage(self, unit: MemUnitType = 'GiB'):
return bytes_to_target_unit(self._peak_host_memory, unit), \
bytes_to_target_unit(self._peak_device_memory, unit)

View File

@ -45,15 +45,17 @@ class GptManager
{
public:
using SizeType = tensorrt_llm::runtime::SizeType;
using TokenIdType = tensorrt_llm::runtime::TokenIdType;
using RequestList = std::list<std::shared_ptr<LlmRequest>>;
using TensorPtr = runtime::ITensor::SharedPtr;
GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt);
std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt,
bool excludeInputInOutput = false);
/* Wraps the user-provided callback for requests.
Adds requests to request table.
@ -70,6 +72,8 @@ public:
BatchManagerErrorCode_t waitUntilTerminate();
BatchManagerErrorCode_t shutdown();
virtual ~GptManager();
protected:
@ -78,17 +82,23 @@ protected:
virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);
private:
SizeType getMaxInputLen() const;
SizeType getMaxOutputLen() const;
SizeType getMaxNumSequences() const;
void validateLlmRequest(LlmRequest& newReq) const;
static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
static std::shared_ptr<std::vector<int32_t>> getReqInputTokens(std::shared_ptr<InferenceRequest> new_req);
static int32_t getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);
static std::shared_ptr<std::vector<TokenIdType>> getReqInputTokens(std::shared_ptr<InferenceRequest> newReq);
static SizeType getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);
GetInferenceRequestsCallback mGetInferenceRequestsCb;
SendResponseCallback mSendResponseCb;
PollStopSignalCallback mPollStopSignalCb;
ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;
std::shared_ptr<TrtGptModel> mTrtGptModel;
SizeType mMaxInputLen;
SizeType mMaxOutputLen;
SizeType mMaxKvCacheLen;
SizeType mMaxNumSequences;
std::optional<uint64_t> mTerminateReqId;
std::optional<SizeType> mMaxDraftTokens;
// Iteration counter - incremented every iteration of the generation loop
int64_t mIterationCounter;
@ -96,16 +106,14 @@ private:
RequestList mActiveRequests;
// IDs of live requests
std::set<uint64_t> mActiveRequestsIds;
// Boolean that controls if prompt should be included in output tokens for non-streaming
bool mExcludeInputInOutput;
GetInferenceRequestsCallback mGetInferenceRequestsCb;
SendResponseCallback mSendResponseCb;
PollStopSignalCallback mPollStopSignalCb;
ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;
std::atomic<bool> destructor_called_;
std::atomic<bool> shutdown_requested_;
void decoupled_execution_loop();
std::shared_ptr<std::thread> worker_thread_;
inline static const std::string kInputIdsTensorName_ = "input_ids";
inline static const std::string kDraftInputIdsTensorName_ = "draft_input_ids";
inline static const std::string kMaxNewTokensTensorName_ = "request_output_len";
inline static const std::string kBeamWidthTensorName_ = "beam_width";
inline static const std::string kEndIdTensorName_ = "end_id";

View File

@ -20,26 +20,44 @@
namespace tensorrt_llm::batch_manager
{
struct NamedTensor
template <typename TTensor>
struct GenericNamedTensor
{
using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
using TensorPtr = TTensor;
TensorPtr tensor;
std::string name;
NamedTensor() = default;
~NamedTensor() = default;
GenericNamedTensor() = default;
~GenericNamedTensor() = default;
// Host Tensor constructor
NamedTensor(
GenericNamedTensor(
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
NamedTensor(TensorPtr _tensor, std::string _name)
GenericNamedTensor(TensorPtr _tensor, std::string _name)
: tensor(std::move(_tensor))
, name(std::move(_name))
{
}
GenericNamedTensor(std::string _name)
: name(std::move(_name))
{
}
};
struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
{
using Base = GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;
NamedTensor(
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
NamedTensor(TensorPtr _tensor, std::string _name)
: Base(_tensor, _name){};
std::vector<int64_t> serialize();
static NamedTensor deserialize(const int64_t* packed);
};

View File

@ -38,33 +38,34 @@
namespace tensorrt_llm::batch_manager
{
class InferenceRequest
template <typename TTensor, typename TTensorMap>
class GenericInferenceRequest
{
public:
using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
using TensorMap = tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>;
using TensorPtr = TTensor;
using TensorMap = TTensorMap;
InferenceRequest(uint64_t requestId)
GenericInferenceRequest(uint64_t requestId)
: mRequestId(requestId)
, mIsStreaming(false)
{
}
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
GenericInferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: mInputTensors(inputTensors)
, mRequestId(requestId)
, mIsStreaming(false)
{
}
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
GenericInferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: mInputTensors(std::move(inputTensors))
, mRequestId(requestId)
, mIsStreaming(false)
{
}
~InferenceRequest() {}
~GenericInferenceRequest() {}
template <typename T>
std::tuple<bool, T> getScalarValueFromTensor(
@ -139,6 +140,36 @@ public:
return mRequestId;
}
protected:
TensorMap mInputTensors;
uint64_t mRequestId;
bool mIsStreaming;
};
class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>
{
public:
using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>;
using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap;
InferenceRequest(uint64_t requestId)
: Base(requestId)
{
}
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
{
}
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
{
}
const std::vector<int64_t> serialize() const
{
std::list<int64_t> packed;
@ -184,11 +215,6 @@ public:
ir->setIsStreaming(IsStreaming);
return ir;
}
private:
TensorMap mInputTensors;
uint64_t mRequestId;
bool mIsStreaming;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -168,7 +168,7 @@ class BlockManager
public:
using SizeType = tensorrt_llm::runtime::SizeType;
explicit BlockManager(std::size_t blocksInPool);
explicit BlockManager(SizeType blocksInPool);
void startScheduling();
@ -179,22 +179,22 @@ public:
// Simulate freeing all blocks for that sequence to check impact on number of free blocks
void schedulingFreeAllBlocks(GenerationRequest& sequence);
[[nodiscard]] std::size_t getNumFreeBlocks() const
[[nodiscard]] SizeType getNumFreeBlocks() const
{
return mFreeBlocks.size();
}
[[nodiscard]] std::size_t getNumAllocatedBlocks() const
[[nodiscard]] SizeType getNumAllocatedBlocks() const
{
return mAllocatedBlocks.size();
}
[[nodiscard]] bool hasFreeBlocks(std::size_t numRequired = 1) const
[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
{
return getNumFreeBlocks() >= numRequired;
}
[[nodiscard]] bool schedulingHasFreeBlocks(std::size_t numRequired = 1) const
[[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const
{
return mSchedulingNumFreeBlocks >= numRequired;
}
@ -205,7 +205,7 @@ private:
// List of allocated blocks for each sequences
std::vector<std::vector<KVCacheBlock>> mAllocatedBlocks;
// Used to keep track of number of free blocks during scheduling
std::size_t mSchedulingNumFreeBlocks;
SizeType mSchedulingNumFreeBlocks;
};
class KVCacheManager

View File

@ -36,24 +36,28 @@ enum LlmRequestState_t
REQUEST_STATE_GENERATION_COMPLETE = 3
};
class LlmRequest
template <typename TTensor>
class GenericLlmRequest
{
public:
using SizeType = runtime::SizeType;
using TokenIdType = runtime::TokenIdType;
using RequestIdType = std::uint64_t;
using BeamTokens = std::vector<std::vector<TokenIdType>>;
using VecTokens = std::vector<TokenIdType>;
using VecLogProbs = std::vector<float>;
using TensorPtr = runtime::ITensor::SharedPtr;
using BeamTokens = std::vector<VecTokens>;
using TensorPtr = TTensor;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens,
std::shared_ptr<std::vector<TokenIdType>> inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming,
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt,
std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false)
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
: mRequestId(requestId)
, mPromptLen(input_tokens->size())
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
, mSamplingConfig(samplingConfig)
, mState(REQUEST_STATE_CONTEXT_INIT)
@ -61,7 +65,7 @@ public:
, mEndId(endId)
, mPadId(padId)
, mBatchSlot(-1)
, mOrigPromptLen(input_tokens->size())
, mOrigPromptLen(inputTokens->size())
, mEmbeddingBias(embeddingBias)
, mBadWordsList(badWordsList)
, mStopWordsList(stopWordsList)
@ -70,10 +74,11 @@ public:
, mReturnLogProbs(returnLogProbs)
, mLogProbs(samplingConfig.beamWidth)
, mCumLogProbs(samplingConfig.beamWidth)
, mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
{
mMaxSentTokenPos = mPromptLen - 1;
// Scatter the input tokens to other beam
mTokens = std::make_shared<BeamTokens>(mSamplingConfig.beamWidth, *input_tokens);
mTokens = BeamTokens(mSamplingConfig.beamWidth, *inputTokens);
if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
|| (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
@ -91,7 +96,7 @@ public:
/// @return The number of tokens
SizeType getNumTokens(SizeType beam) const
{
return mTokens->at(beam).size();
return mTokens.at(beam).size();
}
/// @brief Get max number of tokens across all beams
@ -101,7 +106,7 @@ public:
SizeType maxTokens = 0;
for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
{
maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens->at(beam).size()));
maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens.at(beam).size()));
}
return maxTokens;
}
@ -112,19 +117,33 @@ public:
/// @return The token index
TokenIdType getToken(SizeType beam, SizeType pos) const
{
return mTokens->at(beam).at(pos);
return mTokens.at(beam).at(pos);
}
/// @brief Get the tokens at a given beam index
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
std::vector<TokenIdType> getTokens(SizeType beam) const
/// @param beam The beam index
/// @return A vector of tokens for this beam index, includes the prompt
std::vector<TokenIdType> const& getTokens(SizeType beam) const
{
return mTokens->at(beam);
return mTokens.at(beam);
}
/// @brief Get the draft tokens
/// @return shared_ptr to vector of draft tokens
std::shared_ptr<std::vector<TokenIdType>> const& getDraftTokens() const
{
return mDraftTokens;
}
/// @brief Returns true if request has draft tokens
/// @return flag
bool hasDraftTokens() const
{
return mDraftTokens && mDraftTokens->size() > 0;
}
/// @brief Get the maximum number of generated tokens among all rays in beam
/// @return The number of generated tokens (doesn't include the prompt tokens)
/// @return The number of generated tokens (doesn't include the prompt tokens)
SizeType getMaxNumGeneratedTokens() const
{
return getMaxBeamNumTokens() - mPromptLen;
@ -135,7 +154,7 @@ public:
/// @param beam The beam to which to add the new token
void addNewToken(TokenIdType token, SizeType beam)
{
mTokens->at(beam).push_back(token);
mTokens.at(beam).push_back(token);
}
/// @brief Add new generated tokens to the vector of tokens
@ -147,7 +166,7 @@ public:
for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
{
const auto outputId = beamTokens[beam];
mTokens->at(beam).push_back(outputId);
mTokens.at(beam).push_back(outputId);
}
}
@ -158,7 +177,7 @@ public:
assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
{
auto& beamTokens = (*mTokens)[beam];
auto& beamTokens = mTokens[beam];
beamTokens.resize(mPromptLen);
beamTokens.insert(beamTokens.end(), generatedBeamTokens[beam].begin(), generatedBeamTokens[beam].end());
}
@ -173,9 +192,9 @@ public:
// As a temporary solution, we currently reset the tokens to the prompt
if (mSamplingConfig.beamWidth > 1)
{
for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens->at(beam);
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(mPromptLen);
if (mReturnLogProbs)
{
@ -186,9 +205,9 @@ public:
else
{
SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
{
auto& beamTokens = mTokens->at(beam);
auto& beamTokens = mTokens.at(beam);
beamTokens.resize(newPromptLen);
if (mReturnLogProbs)
@ -225,21 +244,6 @@ public:
return mPromptEmbeddingTable;
}
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
{
if (!mPromptEmbeddingTable.has_value()
|| mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
{
return;
}
else
{
TensorPtr gpuPromptEmbeddingTable
= manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
mPromptEmbeddingTable = gpuPromptEmbeddingTable;
}
}
std::optional<SizeType> getPromptVocabSize() const
{
return mPromptVocabSize;
@ -296,6 +300,11 @@ public:
return mOrigPromptLen;
}
void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens)
{
mDraftTokens = draftTokens;
}
RequestIdType mRequestId;
SizeType mPromptLen;
SizeType mMaxNewTokens;
@ -307,9 +316,9 @@ public:
std::optional<SizeType> mPadId;
SizeType mBatchSlot;
private:
protected:
SizeType mOrigPromptLen;
std::shared_ptr<BeamTokens> mTokens;
BeamTokens mTokens;
SizeType mMaxSentTokenPos;
std::optional<TensorPtr> mEmbeddingBias;
@ -323,6 +332,47 @@ private:
std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
VecLogProbs mCumLogProbs; // [beamSize]
std::shared_ptr<VecTokens> mDraftTokens;
};
class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
{
public:
using Base = GenericLlmRequest<runtime::ITensor::SharedPtr>;
using TensorPtr = Base::TensorPtr;
using SizeType = Base::SizeType;
using TokenIdType = Base::TokenIdType;
using RequestIdType = Base::RequestIdType;
using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
: Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens)
{
}
void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
{
if (!mPromptEmbeddingTable.has_value()
|| mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
{
return;
}
else
{
TensorPtr gpuPromptEmbeddingTable
= manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
mPromptEmbeddingTable = gpuPromptEmbeddingTable;
}
}
};
} // namespace tensorrt_llm::batch_manager
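A minimal usage sketch of the draft-token API added in this hunk. The constructor arguments follow the GenericLlmRequest signature above; the llmRequest.h include path and a SamplingConfig constructor taking a beam width are assumptions, not part of the commit.
// Sketch only: construct a request that carries draft tokens for speculative decoding.
#include "tensorrt_llm/batch_manager/llmRequest.h" // assumed header path
#include <cassert>
#include <memory>
using namespace tensorrt_llm::batch_manager;
std::shared_ptr<LlmRequest> makeDraftRequest()
{
    auto inputTokens = std::make_shared<LlmRequest::VecTokens>(LlmRequest::VecTokens{1, 2, 3});
    auto draftTokens = std::make_shared<LlmRequest::VecTokens>(LlmRequest::VecTokens{7, 8});
    tensorrt_llm::runtime::SamplingConfig samplingConfig{1}; // beamWidth = 1 (assumed ctor)
    auto request = std::make_shared<LlmRequest>(/*requestId=*/42, /*maxNewTokens=*/16, inputTokens, samplingConfig,
        /*isStreaming=*/false, /*endId=*/std::nullopt, /*padId=*/std::nullopt, /*embeddingBias=*/std::nullopt,
        /*badWordsList=*/std::nullopt, /*stopWordsList=*/std::nullopt, /*promptEmbeddingTable=*/std::nullopt,
        /*promptVocabSize=*/std::nullopt, /*returnLogProbs=*/false, draftTokens);
    assert(request->hasDraftTokens());              // two draft tokens were attached
    assert(request->getDraftTokens()->size() == 2); // accessible via the new getter
    return request;
}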

View File

@ -50,6 +50,8 @@ public:
TensorPtr endIds; // [batchSize * beamWidth], on gpu
// optional parameters
TensorPtr finished; // [batchSize, beamWidth], finished states at current iteration.
// If true for some request, the decoding step of it is skipped, on gpu
TensorPtr sequenceLimitLength; // [batchSize], on gpu
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr lengths; // [batchSize, beamWidth], on gpu

View File

@ -64,11 +64,19 @@ public:
// mandatory parameters
TensorPtr ids; // [batchSize, beamWidth, maxSeqLen], on gpu, must contain previously generated token ids for all
// steps before DecodingInput.step
TensorPtr newTokens; // [batchSize, beamWidth] on gpu.
TensorPtr newTokensSteps; // [maxTokensPerStep, batchSize, beamWidth] new tokens at each generated token of
// maxTokensPerStep, on gpu.
TensorPtr newTokens; // [batchSize, beamWidth] usually a view of newTokensSteps for the current token, on gpu.
std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [batchSize, beamWidth].
// Vector of views on newTokensSteps for each token. Elements are on gpu.
// optional parameters
TensorPtr finished; // [batchSize, beamWidth], mandatory in beam search and to determine whether to stop
// according to DecodingInput.sequenceLimitLength, on gpu
TensorPtr finishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states at each generated token of
// maxTokensPerStep, on gpu
TensorPtr finished; // [batchSize, beamWidth], usually a view of finishedSteps for current token.
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
// true. In beam search and to determine whether to stop according to
// DecodingInput.sequenceLimitLength, on gpu
TensorPtr finishedSum; // [1], the sum of finished sequences, in pinned memory
// mandatory parameters for beam search

View File

@ -55,6 +55,10 @@ public:
DecodingInput const& decodingInput, BufferManager const& manager)
= 0;
static void acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds, const ITensor& contextLengths,
const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec, ITensor& finishedFinal,
ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(
nvinfer1::DataType dtype, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
};

View File

@ -40,13 +40,13 @@ class GptDecoderBatch : public IGptDecoderBatch
{
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using TensorPtr = std::shared_ptr<ITensor>;
using TensorPtr = ITensor::SharedPtr;
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
nvinfer1::DataType dtype) override;
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(
@ -69,6 +69,7 @@ public:
return {mFinished.begin(), mFinished.begin() + mActualBatchSize};
}
//! @param batchIdx index of the batch
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
//! padding for request `batchIdx`, on gpu
[[nodiscard]] TensorPtr getOutputIds(SizeType batchIdx) const override
@ -99,18 +100,6 @@ public:
return ITensor::slice(mJointDecodingOutput->parentIds, 0, mActualBatchSize);
}
//! @returns [batchSize, maxBeamWidth], marks finished requests (per beam), on gpu
[[nodiscard]] TensorPtr getFinishedBeams() const override
{
return ITensor::slice(mJointDecodingOutput->finished, 0, mActualBatchSize);
}
//! @returns [batchSize, maxBeamWidth], total sequence lengths (per beam), on gpu
[[nodiscard]] TensorPtr getOutputLengths() const override
{
return ITensor::slice(mJointDecodingOutput->lengths, 0, mActualBatchSize);
}
//! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
[[nodiscard]] TensorPtr getCumLogProbs() const override
{
@ -139,10 +128,21 @@ public:
return tensor;
}
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
[[nodiscard]] TensorPtr getNewTokens() const override
//! @brief Get maxTokensPerStep tokens generated in the last forward pass
//! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
[[nodiscard]] TensorPtr getAllNewTokens() const override
{
return ITensor::slice(mJointDecodingOutput->newTokens, 0, mActualBatchSize);
return mJointDecodingOutput->newTokensSteps;
}
//! @brief Get tokens generated in one step of last forward pass
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
{
TensorPtr newTokensView = std::move(ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1));
newTokensView->squeeze(0);
return ITensor::slice(newTokensView, 0, mActualBatchSize);
}
//! @returns [batchSize], the number of generation steps executed on each request
@ -180,13 +180,18 @@ private:
DecodingInputPtr mJointDecodingInput;
DecodingOutputPtr mJointDecodingOutput;
std::vector<TensorPtr> mDraftTokenIds;
TensorPtr mNumDraftTokens;
std::vector<SizeType> mNbSteps;
std::vector<bool> mFinished;
TensorPtr mFinishedSum;
std::vector<SizeType> mMaxNewTokens;
std::vector<SizeType> mBeamWidths;
std::vector<SizeType> mGeneratedTokensPerStep;
SizeType mMaxSequenceLength{};
SizeType mMaxKvCacheLength{};
SizeType mActualBatchSize{};
SizeType mMaxTokensPerStep{};
};
} // namespace tensorrt_llm::runtime
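For reference, a small standalone sketch (not TensorRT-LLM code) of the indexing convention documented above: newTokensSteps is laid out as [maxTokensPerStep, batchSize, beamWidth], and getNewTokens(iter) corresponds to the [batchSize, beamWidth] slab at a fixed iter.
// Illustrative only: the same layout expressed on a flat row-major buffer.
#include <vector>
int newTokenAt(std::vector<int> const& newTokensSteps, int iter, int batchIdx, int beamIdx, int batchSize,
    int beamWidth)
{
    // iter is the slowest-varying dimension, beamIdx the fastest.
    return newTokensSteps[(iter * batchSize + batchIdx) * beamWidth + beamIdx];
}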

View File

@ -54,6 +54,7 @@ public:
, mModelVariant(ModelVariant::kGpt)
, mUseCustomAllReduce(false)
, mMaxPromptEmbeddingTableSize(0)
, mMaxDraftLen(0)
{
}
@ -253,6 +254,16 @@ public:
mUseCustomAllReduce = customAllReduce;
}
void constexpr setMaxDraftLen(SizeType maxDraftLen) noexcept
{
mMaxDraftLen = maxDraftLen;
}
[[nodiscard]] SizeType constexpr getMaxTokensPerStep() const noexcept
{
return mMaxDraftLen + 1;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
@ -276,6 +287,7 @@ private:
bool mUseCustomAllReduce;
SizeType mMaxPromptEmbeddingTableSize;
SizeType mMaxDraftLen;
};
} // namespace tensorrt_llm::runtime
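A trivial standalone restatement (not from the commit) of the relation encoded by the new accessors: with a maximum draft length d, getMaxTokensPerStep() reports d + 1, presumably the freshly generated token plus up to d draft tokens per step.
// Sketch: the arithmetic behind getMaxTokensPerStep() shown above.
constexpr int maxTokensPerStep(int maxDraftLen)
{
    return maxDraftLen + 1;
}
static_assert(maxTokensPerStep(0) == 1, "no draft tokens: one token per step");
static_assert(maxTokensPerStep(4) == 5, "four draft tokens: up to five tokens per step");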

View File

@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"
#include <cstdint>
#include <memory>
@ -35,12 +36,14 @@ namespace decoder_batch
class Request
{
public:
using ConstTensorPtr = std::shared_ptr<ITensor const>;
using TensorPtr = std::shared_ptr<ITensor>;
using ConstTensorPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;
using BufferPtr = IBuffer::SharedPtr;
explicit Request(ConstTensorPtr ids, std::optional<SizeType> maxNewTokens = std::nullopt,
explicit Request(ConstTensorPtr ids, SizeType inputLen, std::optional<SizeType> maxNewTokens = std::nullopt,
std::optional<SizeType> endId = std::nullopt)
: ids{std::move(ids)}
, inputLen(inputLen)
, maxNewTokens{maxNewTokens}
, endId{endId}
, computeCumLogProbs(false)
@ -48,42 +51,68 @@ public:
{
}
// the number of tokens generated per step
SizeType generatedTokensPerStep() const
{
return draftTokens ? draftTokens->getSize() + 1 : 1;
}
// mandatory parameters
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
SizeType inputLen; // the input length without draft tokens
// optional parameters
std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
std::optional<SizeType> endId; // end token id
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
BufferPtr draftTokens; // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr badWordsList; // [2, badWordsLength], on gpu
TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
bool computeLogProbs; // boolean that controls if logProbs should be computed for that request
};
class Input : public decoder::Input
class Input
{
public:
using Base = decoder::Input;
using TensorConstPtr = ITensor::SharedConstPtr;
using TensorPtr = ITensor::SharedPtr;
explicit Input(TensorPtr logits)
: Base{std::move(logits)}
{
auto const batchSize = this->logits->getShape().d[0];
active.resize(batchSize, true);
}
explicit Input(TensorPtr logits, std::vector<bool> const& active)
: Base{std::move(logits)}
explicit Input(std::vector<TensorConstPtr> const& logits, std::vector<bool> const& active)
: logits{logits}
, active{active}
{
auto const batchSize = static_cast<std::size_t>(this->logits->getShape().d[0]);
TLLM_CHECK_WITH_INFO(this->active.size() == batchSize, "'active' vector size does not match logits batchSize");
TLLM_CHECK_WITH_INFO(
this->active.size() == logits.size(), "'active' vector size does not match logits vector size");
}
explicit Input(std::vector<TensorConstPtr> const& logits)
: Input{logits, std::vector<bool>(logits.size(), true)}
{
}
explicit Input(std::vector<TensorPtr> const& logits, std::vector<bool> const& active)
: Input{
utils::transformVector(logits, [](auto& x) { return std::const_pointer_cast<ITensor const>(x); }), active}
{
}
explicit Input(std::vector<TensorPtr> const& logits)
: Input{logits, std::vector<bool>(logits.size(), true)}
{
}
// mandatory parameters
std::vector<TensorConstPtr>
logits; // batchSize * [1, beamWidth, vocabSizePadded] or [generatedTokensPerStep, 1, vocabSizePadded], on gpu
// control activity of decoder slots in batch
std::vector<bool> active; // [batchSize]
// parameters for beam search
TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
// within one beam for beam search, on gpu
};
using Output = decoder::Output;
@ -127,6 +156,7 @@ public:
forwardSync(*forwardAsync(output, input));
}
//! @param batchIdx index of the batch
//! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding for request `batchIdx`, on gpu
virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
@ -135,12 +165,6 @@ public:
//! Result will only be available after event returned
virtual CudaEvent finalize(SizeType batchIdx) const = 0;
//! @returns [batchSize, beamWidth], marks finished requests (per beam), on gpu
virtual TensorPtr getFinishedBeams() const = 0;
//! @returns [batchSize, beamWidth], total sequence lengths (per beam), on gpu
virtual TensorPtr getOutputLengths() const = 0;
//! @returns [batchSize (actual)], marks finished requests (per batch)
virtual std::vector<bool> getFinished() const = 0;

View File

@ -75,7 +75,7 @@ public:
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
SizeType maxSequenceLength, nvinfer1::DataType dtype)
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
= 0;
//! @brief Initialize the decoder with new batch of inputs.
@ -108,8 +108,14 @@ public:
//! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
virtual TensorPtr getLogProbs() const = 0;
//! @returns [batchSize, beamWidth], latests generated tokens (per beam), on gpu
virtual TensorPtr getNewTokens() const = 0;
//! @brief Get tokens generated in one step of last forward pass
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;
//! @brief Get maxTokensPerStep tokens generated in the last forward pass
//! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
virtual TensorPtr getAllNewTokens() const = 0;
//! @returns [1], number of finished sequences, in pinned host memory
virtual TensorPtr getNbFinished() const = 0;

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b867c2e048671eecc421244d436436782093baf02f0fd5d49232b3d3042e55ea
size 1688216
oid sha256:9e6a5d7dba399049a4da9ca729153e5a6080986782a314b867e7635454eb36de
size 1705954

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db433a13ec6a017638bbb97b53a98624ad675b395787c99054d48ab370f5e3a0
size 1697778
oid sha256:64fae7bca97be7c3067b4544da0c3d79621ec3632c10e39b7a005d886702e8eb
size 1706098

View File

@ -1,3 +1,3 @@
81f472ac2b68edd03a0265299744347f libtensorrt_llm_batch_manager_static.a
4e5e3bbdfffa6deb6a50c541a946ac7a libtensorrt_llm_batch_manager_static.pre_cxx11.a
7edd8a21 commit
02375d908e57e2194e3f28a4e83dd963 libtensorrt_llm_batch_manager_static.a
e5c4994ecc347d808f6d38fb686b5cf1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
cd83045d7c127af5f907efd7c710bd6fe1f90ec4 commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:681917aea11f45d83ba1429ded44ced97cb8ce5f54eb1c3fb3055bc342f0ffbf
size 1600734
oid sha256:f3cca913fc62df4119e4df10921be97086714740148f54c528da7bb2826f67ba
size 1617426

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d59b04e3229358ec2d9476b07f0361aa4a8539e543312c8952b690173040663d
size 1598666
oid sha256:3d633874e8b32a56758bf8bbdc0955ed8c5d43d531ba330a2274bcec13e1c89f
size 1620144

View File

@ -1,2 +1,2 @@
c9d5678a2ec347188457ad4a3a59d483 libtensorrt_llm_batch_manager_static.a
53261d576d540ab330f2f2e1f8d99677 libtensorrt_llm_batch_manager_static.pre_cxx11.a
f379e62b3f69afa4bd1d8e5551a6ede4 libtensorrt_llm_batch_manager_static.a
92f44b2834d39c2c62a9b0bd0549b159 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -503,5 +503,71 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
output_log_probs, output_log_probs_tiled, sequence_lengths, batch_size, beam_width, max_seq_len);
}
__global__ void acceptTokensKernel(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
int batch_size, int beam_width, int max_seq_len, int max_draft_tokens)
{
int thread_finished_count = 0;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
index += blockDim.x * gridDim.x)
{
const auto num_draft_tokens = nums_draft_tokens[index];
const auto context_length = context_lengths[index];
auto& sequence_length = sequence_lengths[index];
int finished_draft_idx = 0;
for (int ti = context_length; ti < min(sequence_length, context_length + num_draft_tokens);
++ti, ++finished_draft_idx)
{
const auto draft_idx = ti - context_length;
const auto target_token_idx = index * max_seq_len + ti;
const auto draft_token_idx = index * max_draft_tokens + draft_idx;
// Check if draft tokens are the same as target tokens
// FIXME(nkorobov): compare logits here
const bool accepted = draft_tokens[draft_token_idx] == target_tokens[target_token_idx];
if (!accepted)
{
// Set sequence length to the numAcceptedTokens + 1
sequence_length = min(ti + 1, max_seq_len);
// FIXME(nkorobov): do we need to set endIds here?
break;
}
}
bool finish = finished[finished_draft_idx * batch_size * beam_width + index];
finished_final[index] = finish;
thread_finished_count += static_cast<int>(finish);
}
if (finished_sum)
{
int block_finished_count = 0;
if (blockDim.x <= 32)
{
block_finished_count = warpReduceSum(thread_finished_count);
}
else
{
block_finished_count = blockReduceSum(thread_finished_count);
}
__syncthreads();
if (threadIdx.x == 0)
{
finished_sum[0] = block_finished_count;
}
}
}
void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream)
{
dim3 block(min(256, batch_size * beam_width));
dim3 grid(1);
acceptTokensKernel<<<grid, block, 0, stream>>>(draft_tokens, target_tokens, context_lengths, nums_draft_tokens,
sequence_lengths, finished, finished_final, finished_sum, batch_size, beam_width, max_seq_len,
max_draft_tokens);
}
} // namespace kernels
} // namespace tensorrt_llm
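As a plain C++ restatement (illustrative, not the committed CUDA code), the per-sequence acceptance rule in acceptTokensKernel reads: walk the draft tokens after the context, accept while they match the target tokens, and on the first mismatch clamp the sequence length to the accepted prefix plus one.
// Host-side sketch of the acceptance rule only; the finished-state bookkeeping and the batched
// indexing (index * max_seq_len + ti, index * max_draft_tokens + draft_idx) used above are omitted.
#include <algorithm>
#include <vector>
int acceptDraftTokens(std::vector<int> const& draftTokens, std::vector<int> const& targetTokens, int contextLength,
    int& sequenceLength, int maxSeqLen)
{
    int accepted = 0;
    int const numDraftTokens = static_cast<int>(draftTokens.size());
    for (int ti = contextLength; ti < std::min(sequenceLength, contextLength + numDraftTokens); ++ti)
    {
        int const draftIdx = ti - contextLength;
        if (draftTokens[draftIdx] != targetTokens[ti])
        {
            // First mismatch: keep the accepted prefix plus the target token at this position.
            sequenceLength = std::min(ti + 1, maxSeqLen);
            break;
        }
        ++accepted;
    }
    return accepted;
}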

View File

@ -65,5 +65,9 @@ void invokeCopyNextStepIds(int* next_step_ids, int** output_ids_ptr, const int*
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);
void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -126,7 +126,7 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets
// Iterate over the tokens to update the number of padded elements.
for (int tokenIdx = threadIdx.x; tokenIdx < seqLength; tokenIdx += blockDim.x)
{
paddingOffsets[seqBegin + tokenIdx] = paddingOffset + max(0, tokenIdx - seqLength);
paddingOffsets[seqBegin + tokenIdx] = paddingOffset;
}
}
@ -152,7 +152,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
int seqLength = seqEnd - seqBegin;
// Iterate over the tokens to update the number of padded elements.
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < maskSize; idx += blockDim.x)
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < maskSize; idx += gridDim.x * blockDim.x)
{
// The position in the matrix.
int rowIdx = idx / maxSeqLength;
@ -235,8 +235,7 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams<T>& params, cudaStream_
// Compute the attention mask, if needed.
if (params.attentionMask != nullptr)
{
// large value like 512 hurts kernel perf at long sequence length. Keep small for now.
const int MIN_BLOCKS = 16;
const int MIN_BLOCKS = 512;
int blocksPerSeq = 16;
while (blocksPerSeq * params.batchSize < MIN_BLOCKS)
{

View File

@ -195,9 +195,9 @@ __global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* to
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** ids,
int* sequenceLengths, bool* finished, float* cumLogProbs, float* outputLogProbs, const int maxTopK,
const int* topKs, const float topP, const float* topPs, curandState_t* curandstate, const int* endIds,
const int vocabSize, const bool* skipDecode)
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const int maxTopK, const int* topKs, const float topP, const float* topPs, curandState_t* curandstate,
const int* endIds, const int vocabSize, const bool* skipDecode)
{
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@ -226,8 +226,12 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
}
TopK_2<float> partial;
if (finished != nullptr && finished[batchId] == true)
if (finishedInput != nullptr && finishedInput[batchId] == true)
{
if (finishedOutput != nullptr)
{
finishedOutput[batchId] = true;
}
return;
}
@ -300,18 +304,18 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
break;
}
}
if (sequenceLengths != nullptr && finished != nullptr)
if (sequenceLengths != nullptr && finishedOutput != nullptr)
{
const int seqLen = sequenceLengths[batchId];
if (ids[batchId][seqLen] == endIds[batchId])
{
finished[batchId] = true;
finishedOutput[batchId] = true;
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be
// outputted
}
else
{
finished[batchId] = false;
finishedOutput[batchId] = false;
sequenceLengths[batchId] += 1;
}
}
@ -319,19 +323,20 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
}
#define CASE_K(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \
topKStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_><<<batchSize * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>( \
logProbs, tempLogProbs, topKTmpIdBuf, topKTmpValBuf, finished, maxTopK, topKs, vocabSize, endIds, skipDecode); \
topKStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_> \
<<<batchSize * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(logProbs, tempLogProbs, topKTmpIdBuf, \
topKTmpValBuf, finishedInput, maxTopK, topKs, vocabSize, endIds, skipDecode); \
topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
<<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf, \
topKTmpValBuf, ids, sequenceLengths, finished, cumLogProbs, outputLogProbs, maxTopK, topKs, topP, topPs, \
curandstate, endIds, vocabSize, skipDecode); \
topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode); \
break;
template <typename T>
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
bool* finished, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK,
const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode)
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
@ -386,35 +391,35 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
#undef CASE_K
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK,
const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize,
const bool* skipDecode)
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode)
{
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finished_buf, cumLogProbs,
outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, stream, batchSize,
skipDecode);
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput,
cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, stream,
batchSize, skipDecode);
}
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
const int batchSize, const bool* skipDecode);
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode);
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
const int batchSize, const bool* skipDecode);
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -37,8 +37,8 @@ namespace kernels
//! logProbs must contain **just** probabilities instead of log probabilities.
//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding endId token
//! \param finishedBuf input/output buffer [batchSize]. Flag if sequence has finished (if finished || outputId == endId).
//! If true, request exits early.
//! \param finishedInput input buffer [batchSize]. If true, request exits early.
//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId == endId).
//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored if nullptr
//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability induced by the top-k sampling.
//! We normalize the probability 'expLogit' of the selected token by the probability 's_sum' of a set of top-k
@ -60,16 +60,16 @@ namespace kernels
// clang-format on
template <typename T>
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
bool* finished, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK,
const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode);
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
//! \brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr
template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** outputIds, int* sequenceLength,
bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK,
const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize,
const bool* skipDecode);
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
cudaStream_t stream, const int batchSize, const bool* skipDecode);
//! \brief Initialize batchSize curand states with given seed.
//!
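As a rough host-side illustration of the finishedInput/finishedOutput contract documented above: finishedInput carries the finish state read by this step, finishedOutput receives the state produced by it, and the sequence length is only advanced when a non-EOS token is emitted. The sketch below is a CPU reference with hypothetical names, not the CUDA kernel itself.

#include <cstddef>
#include <vector>

// CPU reference of the finish-flag bookkeeping described in the comments above.
// All names here are illustrative; the real work happens inside the sampling kernels.
void topKFinishReference(const std::vector<int>& sampledIds, // one sampled token per request
    const std::vector<int>& endIds,                          // end id per request
    const std::vector<bool>& finishedInput,                  // finish state from the previous step
    std::vector<bool>& finishedOutput,                       // finish state produced by this step
    std::vector<int>& sequenceLengths)                       // excludes the end id token
{
    for (std::size_t b = 0; b < sampledIds.size(); ++b)
    {
        if (finishedInput[b])
        {
            // Request already finished earlier: propagate the flag, emit nothing new.
            finishedOutput[b] = true;
            continue;
        }
        if (sampledIds[b] == endIds[b])
        {
            // EOS sampled: mark finished, but do not count EOS in the sequence length.
            finishedOutput[b] = true;
        }
        else
        {
            finishedOutput[b] = false;
            sequenceLengths[b] += 1;
        }
    }
}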

View File

@ -159,7 +159,7 @@ struct BlockPrefixCallbackOp
template <typename T>
__device__ void epilogue(int batchId, int currentStep, int offset, int** ids, int* sortedIdVals, T* sortedLogProbs,
float* cumLogProbs, float* outputLogProbs, const int* endIds, int* sequenceLengths, bool* finishedBuf)
float* cumLogProbs, float* outputLogProbs, const int* endIds, int* sequenceLengths, bool* finishedOutput)
{
ids[batchId][currentStep] = sortedIdVals[offset];
@ -175,26 +175,26 @@ __device__ void epilogue(int batchId, int currentStep, int offset, int** ids, in
outputLogProbs[batchId] = lprob;
}
}
if (sequenceLengths != nullptr && finishedBuf != nullptr)
if (sequenceLengths != nullptr && finishedOutput != nullptr)
{
if (ids[batchId][currentStep] == endIds[batchId])
{
finishedBuf[batchId] = true;
finishedOutput[batchId] = true;
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be outputted
}
else
{
finishedBuf[batchId] = false;
finishedOutput[batchId] = false;
sequenceLengths[batchId] += 1;
}
}
}
template <typename T, int blockSize>
__global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength, bool* finishedBuf,
float* cumLogProbs, float* outputLogProbs, const int* beginOffsetBuf, const int* offsetBuf, const int vocabSize,
curandState_t* curandstate, const float topP, const float* topPs, const int* endIds, const int batchSize,
const bool* skipDecode)
__global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength,
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const int* beginOffsetBuf, const int* offsetBuf, const int vocabSize, curandState_t* curandstate, const float topP,
const float* topPs, const int* endIds, const int batchSize, const bool* skipDecode)
{
/**
* Each block processes one request row sorted in descending order by probabilities.
@ -214,8 +214,12 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
}
// Exit early if sequence has finished
if (finishedBuf != nullptr && finishedBuf[batchId] == true)
if (finishedInput != nullptr && finishedInput[batchId] == true)
{
if (finishedOutput != nullptr)
{
finishedOutput[batchId] = true;
}
ids[batchId][sequenceLength[batchId]] = endIds[batchId];
return;
}
@ -244,7 +248,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
{
int offset = batchId * vocabSize;
epilogue(batchId, currentStep, offset, ids, sortedIdVals, sortedLogProbs, cumLogProbs, outputLogProbs,
endIds, sequenceLength, finishedBuf);
endIds, sequenceLength, finishedOutput);
}
return;
}
@ -285,16 +289,16 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
{
epilogue(batchId, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedLogProbs, cumLogProbs,
outputLogProbs, endIds, sequenceLength, finishedBuf);
outputLogProbs, endIds, sequenceLength, finishedOutput);
}
}
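The selection rule the kernel applies per request can be read as follows: draw a random fraction of the top-p mass, then walk the probabilities (already sorted in descending order) until the running sum crosses that threshold. Below is a hedged single-request CPU sketch of this rule, with a hypothetical helper name and no CUDA specifics.

#include <cstddef>
#include <random>
#include <vector>

// Reference top-p pick over probabilities sorted in descending order.
// probs holds plain (already softmaxed) probabilities; topP is in (0, 1].
// Returns an index into the sorted array; mapping back to vocabulary ids would
// use the sorted-id array, as the kernel does with sortedIdVals.
std::size_t pickTopP(const std::vector<float>& probs, float topP, std::mt19937& rng)
{
    std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
    const float threshold = uniform(rng) * topP; // random point inside the top-p mass
    float cumulative = 0.0f;
    for (std::size_t i = 0; i < probs.size(); ++i)
    {
        cumulative += probs[i];
        if (cumulative >= threshold)
        {
            return i; // first position whose prefix sum crosses the drawn threshold
        }
    }
    return probs.empty() ? 0 : probs.size() - 1; // numerical fallback
}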
template <typename T>
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs, cudaStream_t stream,
const bool* skipDecode)
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
cudaStream_t stream, const bool* skipDecode)
{
// Here, we put batch size as an argument because the batch size of
// initialization and inference may be different due to pipeline parallelism.
@ -338,42 +342,45 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
dim3 grid(batchSize);
// Sample with Top P given sorted tokens
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedLogProbs, sortedIdVals,
outputIds, sequenceLength, finishedBuf, cumLogProbs, outputLogProbs, beginOffsetBuf, offsetBuf + 1, vocabSize,
curandstate, maxTopP, topPs, endIds, batchSize, skipDecode);
outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf,
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode);
}
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
int** outputIds, int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs,
const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
cudaStream_t stream, const bool* skipDecode);
int** outputIds, int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs,
float* outputLogProbs, const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, const int batchSize, const size_t vocabSizePadded, const int* endIds,
const float maxTopP, const float* topPs, cudaStream_t stream, const bool* skipDecode);
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
int** outputIds, int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs,
const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
cudaStream_t stream, const bool* skipDecode);
int** outputIds, int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs,
float* outputLogProbs, const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf,
curandState_t* curandstate, const int batchSize, const size_t vocabSizePadded, const int* endIds,
const float maxTopP, const float* topPs, cudaStream_t stream, const bool* skipDecode);
template <typename T>
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode)
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
const bool* skipDecode)
{
invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedBuf,
cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate, batchSize,
vocabSizePadded, endIds, topP, nullptr, stream, skipDecode);
invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedInput,
finishedOutput, cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate,
batchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode);
}
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const float* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode);
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
const bool* skipDecode);
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const half* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode);
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
const bool* skipDecode);
template <typename T>
__global__ void addBiasSoftMax(

View File

@ -64,9 +64,9 @@ the
//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request.
//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding
endId token.
//! \param finishedBuf input/output buffer [batchSize]. Flag if sequence has finished (if finished || outputId ==
//! \param finishedInput input buffer [batchSize]. Exit early if true.
//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId ==
endId).
//! If true, request exits early.
//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored if
nullptr.
//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability
@ -94,17 +94,18 @@ nullptr.
*/
template <typename T>
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs, cudaStream_t stream,
const bool* skipDecode);
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
cudaStream_t stream, const bool* skipDecode);
//! \brief Specialization of invokeBatchTopPSampling with topPs=nullptr
template <typename T>
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
const size_t vocabSizePadded, const int* endIds, const float topPp, cudaStream_t stream, const bool* skipDecode);
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topPp, cudaStream_t stream,
const bool* skipDecode);
//! \brief Compute the topp decay by https://arxiv.org/pdf/2206.04624.pdf
//! In short, the formula is

View File

@ -69,6 +69,7 @@ struct WeightOnlyParams
const ActType* scales;
const ActType* zeros;
const ActType* in;
const ActType* act_scale;
const ActType* bias;
ActType* out;
const int m;
@ -81,13 +82,14 @@ struct WeightOnlyParams
WeightOnlyActivationType act_type;
WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, const int _group_size,
const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
: qweight(_qweight)
, scales(_scales)
, zeros(_zeros)
, in(_in)
, act_scale(_act_scale)
, bias(_bias)
, out(_out)
, m(_m)

View File

@ -292,9 +292,9 @@ public:
};
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
using ActType2 = typename ActTypeDetails<ActType>::Vec2;
@ -376,6 +376,16 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
}
}
}
ActType act_scale_v[Details::kElemsPerThread];
if constexpr (ActScale)
{
#pragma unroll
for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
{
load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
}
}
#pragma unroll
for (int b = 0; b < Batch; ++b)
{
@ -386,6 +396,16 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
// load activation elements
load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
if constexpr (ActScale)
{
#pragma unroll
for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
{
*reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
*reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
*reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
}
}
}
// Perform vector inner product and accumulate
if constexpr (NPerBlock == 1)
@ -448,20 +468,20 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
}
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
{
if constexpr (std::is_same_v<ActType, half>)
{
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
qweight, scales, zeros, in, bias, out, n, k);
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
}
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
else if (std::is_same_v<ActType, nv_bfloat16>)
{
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
qweight, scales, zeros, in, bias, out, n, k);
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
}
#endif
}
@ -478,10 +498,24 @@ struct WeightOnlyBatchedGemvKernelLauncher
dim3 grid(params.n / NPerBlock / kInterleave);
dim3 block(BlockSize);
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
BlockSize><<<grid, block, size, stream>>>(params.qweight, reinterpret_cast<const half*>(params.scales),
reinterpret_cast<const half*>(params.zeros), reinterpret_cast<const half*>(params.in),
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n, params.k);
if (params.act_scale != nullptr)
{
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock, Batch,
BlockSize><<<grid, block, size, stream>>>(params.qweight,
reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
params.k);
}
else
{
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
params.k);
}
}
#if defined(ENABLE_BF16)
else if (params.act_type == WeightOnlyActivationType::BF16)
@ -490,12 +524,28 @@ struct WeightOnlyBatchedGemvKernelLauncher
dim3 grid(params.n / NPerBlock / kInterleave);
dim3 block(BlockSize);
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
BlockSize><<<grid, block, size, stream>>>(params.qweight,
reinterpret_cast<const __nv_bfloat16*>(params.scales),
reinterpret_cast<const __nv_bfloat16*>(params.zeros), reinterpret_cast<const __nv_bfloat16*>(params.in),
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
params.n, params.k);
if (params.act_scale != nullptr)
{
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
reinterpret_cast<const __nv_bfloat16*>(params.scales),
reinterpret_cast<const __nv_bfloat16*>(params.zeros),
reinterpret_cast<const __nv_bfloat16*>(params.in),
reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
params.n, params.k);
}
else
{
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
reinterpret_cast<const __nv_bfloat16*>(params.scales),
reinterpret_cast<const __nv_bfloat16*>(params.zeros),
reinterpret_cast<const __nv_bfloat16*>(params.in),
reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
params.n, params.k);
}
}
#endif
}
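Two things are worth spelling out about the launcher above: the runtime check params.act_scale != nullptr is turned into the compile-time ActScale template flag, and under that flag the kernel multiplies every loaded activation by the per-channel (per-k) scale before the dot product. The following is a hedged host-side C++ sketch of both ideas, with illustrative names only.

#include <cstddef>
#include <vector>

// Reference of the per-channel activation scaling the ActScale path performs:
// every activation element along k is multiplied by actScale[k] before the GEMV.
template <bool ActScale>
void scaleActivations(std::vector<float>& act,            // [m x k], row-major
    const std::vector<float>& actScale,                   // [k] per-channel scales
    std::size_t m, std::size_t k)
{
    if constexpr (ActScale)
    {
        for (std::size_t row = 0; row < m; ++row)
        {
            for (std::size_t col = 0; col < k; ++col)
            {
                act[row * k + col] *= actScale[col];
            }
        }
    }
    // With ActScale == false the scales are ignored and this branch compiles away.
}

// Runtime-to-compile-time dispatch, mirroring the act_scale != nullptr check above.
void applyActScaleIfPresent(std::vector<float>& act, const float* actScalePtr, std::size_t m, std::size_t k)
{
    if (actScalePtr != nullptr)
    {
        std::vector<float> actScale(actScalePtr, actScalePtr + k);
        scaleActivations<true>(act, actScale, m, k);
    }
    else
    {
        scaleActivations<false>(act, {}, m, k);
    }
}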

View File

@ -50,8 +50,9 @@ public:
// mandatory parameters
int step;
int ite;
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
};
class DecodingOutputParams

View File

@ -226,7 +226,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
invokeBanRepeatNgram(logits.template getPtrWithOffset<T>(decode_vocab_size_units_offset),
outputs.output_ids_ptr.template getPtr<const int*>(),
outputs.finished.value_or(Tensor{}).template getPtr<bool>(),
params.finished.value_or(Tensor{}).template getPtr<bool>(),
outputs.parent_ids.value_or(Tensor{}).template getPtr<const int>(), batch_size, local_batch_size,
beam_width, no_repeat_ngram_size_buf, id_offset, vocab_size_padded_, step, stream_);
}
@ -341,6 +341,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(max_seq_len)};
decode_input_tensors.embedding_bias = params.embedding_bias;
decode_input_tensors.finished = params.finished;
if (params.input_lengths)
{

View File

@ -101,6 +101,7 @@ public:
tc::Tensor end_ids; // [batch_size], on gpu
// optional parameters
std::optional<tc::Tensor> finished; // [batch_size * beam_width], optional
std::optional<tc::Tensor> src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - the k/v cache
// index for beam search, mandatory for beam search, on gpu
std::optional<tc::Tensor> sequence_limit_length; // [batch_size], on gpu

View File

@ -93,7 +93,7 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<u
max_top_k = 1;
}
invokeTopKSampling<T>(nullptr, sampling_workspace_size_, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, max_top_k, 1.0f, vocab_size_padded_, nullptr, stream_, batch_size, skip_decode_buf_);
nullptr, nullptr, max_top_k, 1.0f, vocab_size_padded_, nullptr, stream_, batch_size, skip_decode_buf_);
sampling_workspace_ = allocator_->reMalloc(sampling_workspace_, sampling_workspace_size_, false);
runtime_top_k_buf_ = allocator_->reMalloc(runtime_top_k_buf_, sizeof(uint32_t) * batch_size, false);
runtime_top_p_buf_ = allocator_->reMalloc(runtime_top_p_buf_, sizeof(float) * batch_size, false);
@ -171,9 +171,10 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
auto* logits = !skip_any_ ? params.logits.template getPtr<T>() : runtime_logits_buf_;
auto* end_ids = params.end_ids.template getPtr<const int>();
bool* finished = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
bool* finished_input = (params.finished) ? params.finished->template getPtr<bool>() : nullptr;
bool* finished_output = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
invokeAddBiasEndMask(
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
sync_check_cuda_error();
float* cum_log_probs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
@ -181,16 +182,16 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
if (cum_log_probs != nullptr || output_log_probs != nullptr)
{
invokeAddBiasSoftMax(
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
invokeAddBiasSoftMax(logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_,
vocab_size_padded_, stream_);
sync_check_cuda_error();
}
int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
invokeBatchTopKSampling(sampling_workspace_, sampling_workspace_size_, logits,
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished, cum_log_probs, output_log_probs,
curandstate_buf_ + ite * local_batch_size,
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished_input, finished_output, cum_log_probs,
output_log_probs, curandstate_buf_ + ite * local_batch_size,
(int) runtime_max_top_k_, // useless because runtime_top_k_buf_ is never
// nullptr. Keep for legacy.
(int*) (runtime_top_k_buf_ + ite * local_batch_size),

View File

@ -110,7 +110,8 @@ void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<fl
sampling_workspace_size_, cub_temp_storage_size_,
nullptr, // output_ids
nullptr, // sequence_length
nullptr, // finished_buffer
nullptr, // finished_input_buffer
nullptr, // finished_output_buffer
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs
@ -241,9 +242,10 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, local_batch_size, vocab_size_padded_, stream_);
sync_check_cuda_error();
bool* finished = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
bool* finished_input = (params.finished) ? params.finished->template getPtr<bool>() : nullptr;
bool* finished_output = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
invokeAddBiasSoftMax(
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
sync_check_cuda_error();
float* cum_log_probs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
@ -251,10 +253,10 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
invokeBatchTopPSampling<T>(sampling_workspace_, sampling_workspace_size_, cub_temp_storage_size_,
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished, cum_log_probs, output_log_probs,
logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, curandstate_buf_ + ite * local_batch_size,
local_batch_size, vocab_size_padded_, end_ids, runtime_max_top_p_, runtime_top_p_buf_ + ite * local_batch_size,
stream_, skip_decode_buf_ + ite * local_batch_size);
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished_input, finished_output, cum_log_probs,
output_log_probs, logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_,
curandstate_buf_ + ite * local_batch_size, local_batch_size, vocab_size_padded_, end_ids, runtime_max_top_p_,
runtime_top_p_buf_ + ite * local_batch_size, stream_, skip_decode_buf_ + ite * local_batch_size);
sync_check_cuda_error();
invokeComputeToppDecay(runtime_top_p_buf_ + ite * local_batch_size, initial_top_p_buf_ + ite * local_batch_size,

View File

@ -47,7 +47,8 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
, mRemovePadding(remove_padding)
{
// pre-check whether FMHA is supported in order to save memory allocation
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM);
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM)
&& !mRelativeAttention;
}
// Parameterized constructor
@ -266,7 +267,9 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
T* linear_bias_slopes = nullptr;
if (mEnableContextFMHA && !mRelativeAttention)
// FMHA doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
// We update mEnableContextFMHA in constructor to check this condition
if (mEnableContextFMHA)
{
// b, max_seqlen, actual_total_seqlen
mFMHARunner->setup(request_batch_size, request_seq_len, request_seq_len, request_batch_size * request_seq_len);
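The comment above (and the matching change in GPTAttentionPluginCommon below) hoists the FMHA applicability checks into the constructor, so mEnableContextFMHA already encodes the data type, head-size/SM support, and the no-relative-attention restriction, and enqueue only tests a single flag. A minimal hedged sketch of that pattern, using a hypothetical class rather than the plugin itself:

// Illustrative only: folding capability checks into one constructor-time flag.
struct AttentionTraits
{
    bool halfPrecision;     // data type supported by the fused kernels
    bool fmhaSupported;     // head size / SM combination supported
    bool relativeAttention; // relative attention bias disables FMHA
};

class AttentionSketch
{
public:
    explicit AttentionSketch(AttentionTraits const& t)
        : mEnableContextFMHA(t.halfPrecision && t.fmhaSupported && !t.relativeAttention)
    {
    }

    void enqueue() const
    {
        if (mEnableContextFMHA)
        {
            // fused multi-head attention path
        }
        else
        {
            // unfused fallback path
        }
    }

private:
    bool mEnableContextFMHA;
};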

View File

@ -271,7 +271,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
{
// pre-check whether FMHA is supported in order to save memory allocation
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF || mType == DataType::kBF16)
&& MHARunner::fmha_supported(getHeadSize(), mSM);
&& MHARunner::fmha_supported(getHeadSize(), mSM) && !mCrossAttention
&& mPositionEmbeddingType != tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE;
TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
TLLM_CHECK_WITH_INFO(
@ -583,7 +584,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
// in context phase, currently FMHA runner has two restrictions:
// 1. only apply to self attention. If want fused multi-head cross attention, FMHCA kernels and runner is needed
// 2. doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
if (mEnableContextFMHA && !isCrossAttention() && !isRelativePosition())
// We update mEnableContextFMHA in constructor to check these conditions
if (mEnableContextFMHA)
{
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,

View File

@ -308,11 +308,17 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
}
const int n = inputDesc[mWeightInputIdx].dims.d[1];
const int k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
bool use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled;
bool use_pre_quant_scale = mQuantAlgo & PRE_QUANT_SCALE;
// mQuantAlgo = pre_quant_scale * 4 + zero * 2 + bias
if (mQuantAlgo & PRE_QUANT_SCALE)
const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast<const half*>(inputs[mZerosInputIdx]) : nullptr;
const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
const half* act_ptr = reinterpret_cast<const half*>(inputs[0]);
if (use_pre_quant_scale && !use_cuda_kernel)
{
// Apply pre-quant per channel scale on activations
act_ptr = reinterpret_cast<const half*>(workspace);
if (mType == nvinfer1::DataType::kHALF)
{
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half>(reinterpret_cast<half*>(workspace),
@ -329,10 +335,6 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
#endif
}
const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast<const half*>(inputs[mZerosInputIdx]) : nullptr;
const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
const half* act_ptr = reinterpret_cast<const half*>((mQuantAlgo & PRE_QUANT_SCALE) ? workspace : inputs[0]);
#if defined(ENABLE_BF16)
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,
"No valid weightOnlyGropwiseQuantMatmul configuration");
@ -350,14 +352,18 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
{
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16;
}
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
if (use_cuda_kernel)
{
// Use CUDA kernels for small batch size
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel
// when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
const void* pre_quant_scale = nullptr;
if (use_pre_quant_scale)
pre_quant_scale = inputs[mPreQuantScaleInputIdx];
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[mWeightInputIdx]),
inputs[mScalesInputIdx], zeros_ptr, act_ptr, biases_ptr, outputs[0], m, real_n, k, mGroupSize,
tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, tensorrt_llm::kernels::WeightOnlyType::GroupWise,
inputs[mScalesInputIdx], zeros_ptr, act_ptr, pre_quant_scale, biases_ptr, outputs[0], m, real_n, k,
mGroupSize, tensorrt_llm::kernels::WeightOnlyQuantType::Int4b,
tensorrt_llm::kernels::WeightOnlyType::GroupWise,
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
}
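The reshuffled enqueue logic above amounts to a two-way choice: for small m with the CUDA GEMV enabled, the unscaled activations plus the pre-quant scale go straight to the batched GEMV (which fuses the scaling), while the other path still pre-applies the per-channel scale into the workspace before the cutlass GEMM. A hedged sketch of that decision with hypothetical helpers and a guessed small-m threshold:

#include <cstddef>

// Empty placeholders standing in for the two execution paths; illustrative only.
void runFusedGemvWithActScale(void const* /*act*/, void const* /*actScale*/, void* /*out*/) {}
void applyPerChannelScaleIntoWorkspace(void const* /*act*/, void const* /*actScale*/, void* /*workspace*/) {}
void runCutlassGemm(void const* /*gemmInput*/, void* /*out*/) {}

void enqueueSketch(std::size_t m, bool cudaKernelEnabled, bool hasPreQuantScale,
    void const* act, void const* actScale, void* workspace, void* out)
{
    constexpr std::size_t kSmallMFastPath = 5; // assumption: stands in for SMALL_M_FAST_PATH
    bool const useCudaKernel = (m < kSmallMFastPath) && cudaKernelEnabled;

    if (useCudaKernel)
    {
        // Small-batch fast path: the GEMV kernel consumes act_scale directly,
        // so the activations are passed through unscaled.
        runFusedGemvWithActScale(act, hasPreQuantScale ? actScale : nullptr, out);
    }
    else
    {
        void const* gemmInput = act;
        if (hasPreQuantScale)
        {
            // General path: apply the per-channel scale once into the workspace first.
            applyPerChannelScaleIntoWorkspace(act, actScale, workspace);
            gemmInput = workspace;
        }
        runCutlassGemm(gemmInput, out);
    }
}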

View File

@ -322,7 +322,7 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass
// kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[1]), inputs[2], nullptr,
inputs[0], nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type,
inputs[0], nullptr, nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type,
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);

View File

@ -24,7 +24,14 @@ endif()
find_package(pybind11 REQUIRED)
set(SRCS bindings.cpp runtime/generationInput.cpp runtime/generationOutput.cpp)
set(SRCS
bindings.cpp
batch_manager/gptManager.cpp
batch_manager/llmRequest.cpp
batch_manager/inferenceRequest.cpp
batch_manager/namedTensor.cpp
runtime/generationInput.cpp
runtime/generationOutput.cpp)
pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})

View File

@ -0,0 +1,88 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "gptManager.h"
#include "inferenceRequest.h"
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <functional>
#include <memory>
#include <optional>
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
{
GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb, const tb::TrtGptModelOptionalParams& optionalParams,
std::optional<uint64_t> terminateReqId)
: tb::GptManager(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy, callbackAdapter(getInferenceRequestsCb),
callbackAdapter(sendResponseCb), pollStopSignalCb, returnBatchManagerStatsCb, optionalParams, terminateReqId)
{
}
py::object GptManager::enter()
{
return py::cast(this);
}
void GptManager::exit(py::handle type, py::handle value, py::handle traceback)
{
// NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
// able to make forward progress for the shutdown process to succeed. For that, we must manually release the GIL while
// waiting in `process.join()`.
py::gil_scoped_release release;
shutdown();
}
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback)
{
return [callback](int32_t max_sequences)
{
std::list<InferenceRequest> pythonResults = callback(max_sequences);
std::list<std::shared_ptr<tb::InferenceRequest>> cppResults{};
for (const auto& ir : pythonResults)
{
cppResults.push_back(ir.toTrtLlm());
}
return cppResults;
};
}
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
{
return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg)
{
std::list<NamedTensor> pythonList{};
for (const auto& cppNamedTensor : cppTensors)
{
pythonList.push_back(NamedTensor{cppNamedTensor});
}
callback(id, pythonList, isOk, errMsg);
};
}
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,54 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "inferenceRequest.h"
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include <pybind11/functional.h>
#include <ATen/ops/tensor.h>
#include <functional>
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
{
using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback);
class GptManager : tb::GptManager
{
public:
GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb = nullptr,
tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
const tb::TrtGptModelOptionalParams& optionalParams = tb::TrtGptModelOptionalParams(),
std::optional<uint64_t> terminateReqId = std::nullopt);
py::object enter();
void exit(py::handle type, py::handle value, py::handle traceback);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,37 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "inferenceRequest.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::pybind::batch_manager;
std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
{
tb::InferenceRequest::TensorMap trtTensors;
for (const auto& torchTensorItem : mInputTensors)
{
trtTensors.insert({torchTensorItem.first, tr::TorchView::of(torchTensorItem.second)});
}
return std::make_shared<tb::InferenceRequest>(trtTensors, mRequestId);
}

View File

@ -0,0 +1,59 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/assert.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
namespace tensorrt_llm::pybind::batch_manager
{
class InferenceRequest : public tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor,
std::unordered_map<std::string, at::Tensor>>
{
public:
using Base
= tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, std::unordered_map<std::string, at::Tensor>>;
using TensorPtr = Base::TensorPtr;
using TensorMap = Base::TensorMap;
InferenceRequest(uint64_t requestId)
: Base(requestId)
{
}
InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
{
}
InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
: Base(inputTensors, requestId)
{
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::InferenceRequest> toTrtLlm() const;
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,54 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "llmRequest.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;
using namespace tensorrt_llm::pybind::batch_manager;
namespace
{
std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::TensorPtr> torchPtr)
{
if (torchPtr)
{
return tr::TorchView::of(torchPtr.value());
}
return std::nullopt;
}
} // namespace
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
{
auto embeddingBias = from_torch(mEmbeddingBias);
auto badWordsList = from_torch(mBadWordsList);
auto stopWordsList = from_torch(mStopWordsList);
auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable);
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mReturnLogProbs,
mDraftTokens);
}

View File

@ -0,0 +1,62 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"
#include <ATen/ATen.h>
#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>
namespace tensorrt_llm::pybind::batch_manager
{
class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest<at::Tensor>
{
public:
using Base = GenericLlmRequest<at::Tensor>;
using TensorPtr = Base::TensorPtr;
using SizeType = Base::SizeType;
using TokenIdType = Base::TokenIdType;
using RequestIdType = Base::RequestIdType;
using VecLogProbs = Base::VecLogProbs;
using BeamTokens = Base::BeamTokens;
using VecTokens = Base::VecTokens;
LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens,
runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
std::optional<VecTokens> draftTokens = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable,
promptVocabSize, returnLogProbs,
draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
: std::make_shared<VecTokens>())
{
}
[[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,47 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/runtime/torchUtils.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
{
NamedTensor::NamedTensor(const tb::NamedTensor& cppNamedTensor)
: Base(cppNamedTensor.name)
{
auto cppTensor = cppNamedTensor.tensor;
std::vector<at::IntArrayRef::value_type> shapeValues;
for (int i = 0; i < cppTensor->getShape().nbDims; ++i)
{
shapeValues.push_back(cppTensor->getShape().d[i]);
}
tensor = at::from_blob(cppTensor->data(), shapeValues,
at::TensorOptions()
.device(runtime::TorchUtils::deviceType(cppTensor->getMemoryType()))
.pinned_memory(cppTensor->getMemoryType() == runtime::MemoryType::kPINNED)
.dtype(runtime::TorchUtils::dataType(cppTensor->getDataType())));
}
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -0,0 +1,50 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"
#include <ATen/ATen.h>
#include <ATen/core/ATen_fwd.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/tensor.h>
#include <c10/core/DeviceType.h>
#include <c10/util/ArrayRef.h>
#include <memory>
#include <optional>
namespace tb = tensorrt_llm::batch_manager;
namespace tensorrt_llm::pybind::batch_manager
{
struct NamedTensor : public tb::GenericNamedTensor<std::optional<at::Tensor>>
{
using Base = tb::GenericNamedTensor<std::optional<at::Tensor>>;
using TensorPtr = Base::TensorPtr;
NamedTensor(TensorPtr _tensor, std::string _name)
: Base(_tensor, _name){};
NamedTensor(const tb::NamedTensor& cppNamedTensor);
};
} // namespace tensorrt_llm::pybind::batch_manager

View File

@ -15,14 +15,28 @@
* limitations under the License.
*/
#include <memory>
#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <pybind11/stl.h>
#include <torch/extension.h>
#include <vector>
#include "batch_manager/gptManager.h"
#include "batch_manager/inferenceRequest.h"
#include "batch_manager/llmRequest.h"
#include "batch_manager/namedTensor.h"
#include "runtime/generationInput.h"
#include "runtime/generationOutput.h"
#include "tensorrt_llm/batch_manager/BatchManager.h"
#include "tensorrt_llm/batch_manager/batchScheduler.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
#include "utils/pathCaster.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
@ -31,9 +45,13 @@
namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;
namespace tbb = tensorrt_llm::batch_manager::batch_scheduler;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
namespace tpb = tensorrt_llm::pybind::batch_manager;
namespace tc = tensorrt_llm::common;
namespace tr = tensorrt_llm::runtime;
namespace tpr = tensorrt_llm::pybind::runtime;
using SizeType = tr::SizeType;
#if not defined(TRTLLM_PYBIND_MODULE)
#error "TRTLLM_PYBIND_MODULE must be defined"
@ -53,8 +71,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("prompt_tuning_enabled", &tpr::PromptTuningParams::promptTuningEnabled);
py::class_<tpr::GenerationInput>(m, "GenerationInput")
.def(py::init<tr::SizeType, tr::SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr,
bool>(),
.def(py::init<SizeType, SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr, bool>(),
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
.def_readwrite("end_id", &tpr::GenerationInput::endId)
.def_readwrite("pad_id", &tpr::GenerationInput::padId)
@ -76,16 +93,16 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits)
.def_readwrite("on_token_generated", &tpr::GenerationOutput::onTokenGenerated);
py::class_<tb::kv_cache_manager::KvCacheConfig>(m, "KvCacheConfig")
.def(py::init<std::optional<tr::SizeType>, std::optional<tr::SizeType>, std::optional<float>>(),
py::class_<tbk::KvCacheConfig>(m, "KvCacheConfig")
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>>(),
py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(),
py::arg("free_gpu_memory_fraction") = py::none())
.def_readwrite("max_tokens", &tb::kv_cache_manager::KvCacheConfig::maxTokens)
.def_readwrite("max_kv_cache_length", &tb::kv_cache_manager::KvCacheConfig::maxKvCacheLength)
.def_readwrite("free_gpu_memory_fraction", &tb::kv_cache_manager::KvCacheConfig::freeGpuMemoryFraction);
.def_readwrite("max_tokens", &tbk::KvCacheConfig::maxTokens)
.def_readwrite("max_kv_cache_length", &tbk::KvCacheConfig::maxKvCacheLength)
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction);
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
.def(py::init<SizeType, SizeType, SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
py::arg("max_sequence_length"))
.def_readwrite("max_batch_size", &tr::GptSession::Config::maxBatchSize)
.def_readwrite("max_beam_width", &tr::GptSession::Config::maxBeamWidth)
@ -148,9 +165,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def(py::self != py::self);
py::class_<tr::GptModelConfig>(m, "GptModelConfig")
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType, nvinfer1::DataType>(),
py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_heads"), py::arg("hidden_size"),
py::arg("data_type"))
.def(py::init<SizeType, SizeType, SizeType, SizeType, nvinfer1::DataType>(), py::arg("vocab_size"),
py::arg("num_layers"), py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
.def_property_readonly("vocab_size", &tr::GptModelConfig::getVocabSize)
.def("vocab_size_padded", &tr::GptModelConfig::getVocabSizePadded, py::arg("world_size"))
.def("num_layers", &tr::GptModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
@ -185,7 +201,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
py::overload_cast<bool>(&tr::GptModelConfig::useCustomAllReduce));
py::class_<tr::WorldConfig>(m, "WorldConfig")
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("tensor_parallelism") = 1,
.def(py::init<SizeType, SizeType, SizeType, SizeType>(), py::arg("tensor_parallelism") = 1,
py::arg("pipeline_parallelism") = 1, py::arg("rank") = 0,
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode)
.def_property_readonly("size", &tr::WorldConfig::getSize)
@ -199,13 +215,12 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
.def_static("mpi",
py::overload_cast<tr::SizeType, std::optional<tr::SizeType>, std::optional<tr::SizeType>>(
&tr::WorldConfig::mpi),
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>>(&tr::WorldConfig::mpi),
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
py::arg("pipeline_parallelism") = py::none());
py::class_<tr::SamplingConfig>(m, "SamplingConfig")
.def(py::init<tr::SizeType>(), py::arg("beam_width") = 1)
.def(py::init<SizeType>(), py::arg("beam_width") = 1)
.def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
.def_readwrite("temperature", &tr::SamplingConfig::temperature)
.def_readwrite("min_length", &tr::SamplingConfig::minLength)
@ -221,13 +236,12 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, tr::SizeType, tr::SizeType, tr::GptModelConfig>(), py::arg("name"),
.def(py::init<std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
py::arg("model_config"))
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
.def_static(
"parse_file", [](std::string const& file) { return tr::GptJsonConfig::parse(std::filesystem::path(file)); },
py::arg("file"))
"parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
.def_property_readonly("name", &tr::GptJsonConfig::getName)
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
@ -254,4 +268,104 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
tr::SamplingConfig const& samplingConfig)
{ self.generate(*outputs.toTrtLlm(), *inputs.toTrtLlm(), samplingConfig); },
py::arg("outputs"), py::arg("inputs"), py::arg("sampling_config"));
py::enum_<tb::LlmRequestState_t>(m, "LlmRequestState")
.value("REQUEST_STATE_UNKNOWN", tb::LlmRequestState_t::REQUEST_STATE_UNKNOWN)
.value("REQUEST_STATE_CONTEXT_INIT", tb::LlmRequestState_t::REQUEST_STATE_CONTEXT_INIT)
.value("REQUEST_STATE_GENERATION_IN_PROGRESS", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_IN_PROGRESS)
.value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE);
using LlmRequest = tpb::LlmRequest;
py::class_<LlmRequest>(m, "LlmRequest")
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, std::vector<LlmRequest::TokenIdType>,
tr::SamplingConfig, bool, std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>>(),
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
py::arg("draft_tokens") = std::nullopt)
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
.def("get_tokens", &LlmRequest::getTokens, py::arg("beam"))
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
.def_readwrite("request_id", &LlmRequest::mRequestId)
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
.def_readwrite("state", &LlmRequest::mState)
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
.def_readwrite("end_id", &LlmRequest::mEndId)
.def_readwrite("pad_id", &LlmRequest::mPadId)
.def_readwrite("batch_slot", &LlmRequest::mBatchSlot)
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
.def_property(
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); });
using InferenceRequest = tpb::InferenceRequest;
py::class_<InferenceRequest>(m, "InferenceRequest")
.def(py::init<uint64_t>())
.def(py::init<InferenceRequest::TensorMap const&, uint64_t>())
.def("get_input_tensor", &InferenceRequest::getInputTensor, py::arg("input_tensor_name"))
.def("emplace_input_tensor", &InferenceRequest::emplaceInputTensor, py::arg("input_tensor_name"),
py::arg("input_tensor"))
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
py::enum_<tb::TrtGptModelType>(m, "TrtGptModelType")
.value("V1", tb::TrtGptModelType::V1)
.value("InflightBatching", tb::TrtGptModelType::InflightBatching)
.value("InflightFusedBatching", tb::TrtGptModelType::InflightFusedBatching);
py::enum_<tbb::SchedulerPolicy>(m, "SchedulerPolicy")
.value("MAX_UTILIZATION", tbb::SchedulerPolicy::MAX_UTILIZATION)
.value("GUARANTEED_NO_EVICT", tbb::SchedulerPolicy::GUARANTEED_NO_EVICT);
py::class_<tb::TrtGptModelOptionalParams>(m, "TrtGptModelOptionalParams")
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool>(),
py::arg("kv_cache_config") = tbk::KvCacheConfig{}, py::arg("max_num_sequences") = py::none(),
py::arg("enable_trt_overlap") = true)
.def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig)
.def_readwrite("max_num_sequences", &tb::TrtGptModelOptionalParams::maxNumSequences)
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap);
py::class_<tpb::NamedTensor>(m, "NamedTensor")
.def(py::init<tpb::NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
.def_readwrite("tensor", &tpb::NamedTensor::tensor)
.def_readwrite("name", &tpb::NamedTensor::name);
py::class_<tpb::GptManager>(m, "GptManager")
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
tpb::GetInferenceRequestsCallback, tpb::SendResponseCallback, tb::PollStopSignalCallback,
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
py::arg("return_batch_manager_stats_cb") = nullptr,
py::arg("optional_params") = tb::TrtGptModelOptionalParams(), py::arg("terminate_req_id") = std::nullopt)
.def("shutdown", &tpb::GptManager::exit)
.def("__enter__", &tpb::GptManager::enter)
.def("__exit__", &tpb::GptManager::exit);
}

View File

@ -0,0 +1,103 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/detail/descr.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include <filesystem>
namespace PYBIND11_NAMESPACE
{
namespace detail
{
template <typename T>
struct PathCaster
{
private:
static PyObject* unicode_from_fs_native(const std::string& w)
{
return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
}
static PyObject* unicode_from_fs_native(const std::wstring& w)
{
return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
}
public:
static handle cast(const T& path, return_value_policy, handle)
{
if (auto py_str = unicode_from_fs_native(path.native()))
{
return module_::import("pathlib").attr("Path")(reinterpret_steal<object>(py_str)).release();
}
return nullptr;
}
bool load(handle handle, bool)
{
PyObject* native = nullptr;
if constexpr (std::is_same_v<typename T::value_type, char>)
{
if (PyUnicode_FSConverter(handle.ptr(), &native) != 0)
{
if (auto* c_str = PyBytes_AsString(native))
{
// AsString returns a pointer to the internal buffer, which
// must not be free'd.
value = c_str;
}
}
}
else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
{
if (PyUnicode_FSDecoder(handle.ptr(), &native) != 0)
{
if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
{
// AsWideCharString returns a new string that must be free'd.
value = c_str; // Copies the string.
PyMem_Free(c_str);
}
}
}
Py_XDECREF(native);
if (PyErr_Occurred())
{
PyErr_Clear();
return false;
}
return true;
}
PYBIND11_TYPE_CASTER(T, const_name("os.PathLike"));
};
template <>
struct type_caster<std::filesystem::path> : public PathCaster<std::filesystem::path>
{
};
} // namespace detail
} // namespace PYBIND11_NAMESPACE
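Editorial note: the caster above lets bound C++ functions take and return std::filesystem::path directly; Python callers can pass str or any os.PathLike object and receive pathlib.Path back. Below is a minimal, illustrative sketch of a binding that relies on it. The module name, include path, and bound function are assumptions for demonstration and are not part of this commit.
// Illustrative only: include path and module name are assumptions.
#include "pathCaster.h" // the header above
#include <pybind11/pybind11.h>
#include <filesystem>
namespace py = pybind11;
// Python: path_caster_demo.sibling_config("/engines/gpt/rank0.engine") accepts a str
// or pathlib.Path and returns a pathlib.Path pointing at the sibling config.json.
PYBIND11_MODULE(path_caster_demo, m)
{
    m.def(
        "sibling_config",
        [](std::filesystem::path const& enginePath) { return enginePath.parent_path() / "config.json"; },
        py::arg("engine_path"));
}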

View File

@ -130,6 +130,11 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co
forwardParams.stop_words_list = tcc::toTllmTensor(*input.stopWordsList);
}
if (input.finished)
{
forwardParams.finished = tcc::toTllmTensor(*input.finished);
}
return forwardParams;
}
@ -335,6 +340,47 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void IGptDecoder::acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec,
ITensor& finishedFinal, ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto const targetTokenIdsShape = targetTokenIds.getShape();
auto const batchSize = targetTokenIdsShape.d[0];
auto const beamWidth = targetTokenIdsShape.d[1];
auto const maxSeqLength = targetTokenIdsShape.d[2];
auto const maxDraftTokens = draftTokenIds.getShape().d[2];
TLLM_CHECK_WITH_INFO(
beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for speculative decoding", beamWidth));
TLLM_CHECK_WITH_INFO(draftTokenIds.getShape().d[0] == batchSize,
common::fmtstr("Draft tokens batch size (%d) is not equal to target batch size (%d)",
draftTokenIds.getShape().d[0], batchSize));
TLLM_CHECK_WITH_INFO(contextLengths.getShape().d[0] == batchSize,
common::fmtstr("Context length batch size (%d) is not equal to batch size (%d)", contextLengths.getShape().d[0],
batchSize));
TLLM_CHECK_WITH_INFO(numDraftTokens.getShape().d[0] == batchSize,
common::fmtstr("Num draft tokens batch size (%d) is not equal to batch size (%d)",
numDraftTokens.getShape().d[0], batchSize));
TLLM_CHECK_WITH_INFO(sequenceLengths.getShape().d[0] == batchSize,
common::fmtstr("Sequence length batch size (%d) is not equal to batch size (%d)",
sequenceLengths.getShape().d[0], batchSize));
tensorrt_llm::kernels::invokeAcceptTokens(bufferCast<SizeType>(draftTokenIds), bufferCast<SizeType>(targetTokenIds),
bufferCast<SizeType>(contextLengths), bufferCast<SizeType>(numDraftTokens),
bufferCast<SizeType>(sequenceLengths), bufferCast<bool>(finishedVec), bufferCast<bool>(finishedFinal),
bufferCast<int>(finishedSum), batchSize, beamWidth, maxSeqLength, maxDraftTokens, stream->get());
sync_check_cuda_error();
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
namespace tensorrt_llm::runtime
{
template class GptDecoder<float>;
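Editorial note: the acceptTokens implementation above launches invokeAcceptTokens, which trims each sequence to the draft tokens that the target model actually confirmed. A host-side sketch of that acceptance rule for beamWidth == 1, inferred from the decodingKernelsTest added later in this diff, is shown below; the real kernel additionally propagates the per-step finished flags into finishedFinal and finishedSum, which the sketch omits.
// Host-side reference of the acceptance rule (illustrative, beamWidth == 1).
// draftTokens holds the draft model's proposals; targetTokens holds the target
// model's tokens for the same positions, starting right after the context.
inline int acceptDraftTokens(
    int const* draftTokens, int const* targetTokens, int contextLength, int sequenceLength)
{
    int acceptedLength = sequenceLength; // no mismatch: keep every drafted position
    for (int si = contextLength; si < sequenceLength; ++si)
    {
        if (targetTokens[si] != draftTokens[si - contextLength])
        {
            // keep the first mismatching position itself: it holds the target model's corrected token
            acceptedLength = si + 1;
            break;
        }
    }
    return acceptedLength; // becomes the updated sequence length for this sequence
}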

View File

@ -92,33 +92,42 @@ GptDecoderBatch::GptDecoderBatch(
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds));
dOutput->newTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput->newTokensSteps = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
dOutput->finished = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
dOutput->finishedSteps = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
// use batchSize many entries instead of the usual 1
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
// we don't need dOutput->lengths because lengths are passed from outside
dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
dOutput->logProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
dOutput->beamHypotheses.empty(mBufferManager);
mNumDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
SizeType maxSequenceLength, nvinfer1::DataType dtype)
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxBatchSize > 0);
TLLM_CHECK(maxBeamWidth > 0);
TLLM_CHECK(maxTokensPerStep > 0);
TLLM_CHECK(maxSequenceLength > 0);
mActualBatchSize = maxBatchSize;
mGeneratedTokensPerStep.resize(maxBatchSize);
mMaxSequenceLength = maxSequenceLength;
mMaxKvCacheLength = maxKvCacheLength;
mMaxTokensPerStep = maxTokensPerStep;
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({maxBatchSize, maxBeamWidth});
auto const maxBatchSizeXmaxTokensPerStepXmaxBeamWidth
= ITensor::makeShape({maxBatchSize, maxTokensPerStep, maxBeamWidth});
auto const maxTokensPerStepXmaxBatchSizeXmaxBeamWidth
= ITensor::makeShape({maxTokensPerStep, maxBatchSize, maxBeamWidth});
auto& dInput = *mJointDecodingInput;
const_cast<ITensor&>(*dInput.endIds).reshape(maxBatchSizeXmaxBeamWidth);
@ -133,14 +142,13 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
auto& dOutput = *mJointDecodingOutput;
dOutput.ids->reshape(jointOutputIdsShape);
dOutput.newTokens->reshape(maxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(*dOutput.newTokens);
dOutput.newTokensSteps->reshape(maxTokensPerStepXmaxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(*dOutput.newTokensSteps);
dOutput.finishedSteps->reshape(maxBatchSizeXmaxTokensPerStepXmaxBeamWidth);
mBufferManager.setZero(*dOutput.finishedSteps);
dOutput.parentIds->reshape(jointOutputIdsShape);
dOutput.lengths->reshape(maxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(*dOutput.lengths);
dOutput.finished->reshape(maxBatchSizeXmaxBeamWidth);
mBufferManager.setZero(*dOutput.finished);
mBufferManager.setZero(*dOutput.finishedSum);
// use batchSize many entries instead of the usual 1
dOutput.finishedSum->reshape(maxBatchSizeShape);
mBufferManager.setZero(*dOutput.finishedSum);
@ -160,6 +168,10 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
dOutput.beamHypotheses.release();
}
// speculative decoding only works for beam width == 1
mDraftTokenIds.resize(maxBatchSize);
mNumDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
mStreams.resize(maxBatchSize);
mDecoders.resize(maxBatchSize);
mDecodingInputs.resize(maxBatchSize);
@ -180,6 +192,7 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
mFinished[i] = true;
mMaxNewTokens[i] = 0;
mBeamWidths[i] = 0;
mGeneratedTokensPerStep[i] = 0;
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
}
@ -198,7 +211,7 @@ void GptDecoderBatch::newRequest(
tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
beamWidth, maxBeamWidth));
auto const& requestIds = request.ids;
auto const inputLength = requestIds->getShape().d[0];
auto const inputLength = request.inputLen;
auto const maxNewTokens = request.maxNewTokens.value_or(mMaxSequenceLength - inputLength);
TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens <= mMaxSequenceLength,
tc::fmtstr("Input length (%d) + max new tokens (%d) must be less than max sequence length (%d).", inputLength,
@ -258,14 +271,21 @@ void GptDecoderBatch::newRequest(
outputIds->reshape(outputIdsShape);
dOutput = std::make_unique<DecodingOutput>(outputIds);
dOutput->finished = ITensor::slice(dJointOutput.finished, batchIdx, localBatchSize);
manager.setZero(*dOutput->finished);
dOutput->finishedSum = ITensor::slice(dJointOutput.finishedSum, batchIdx, localBatchSize);
manager.setZero(*dOutput->finishedSum);
dOutput->lengths = ITensor::slice(dJointOutput.lengths, batchIdx, localBatchSize);
kernels::invokeFill(*dOutput->lengths, inputLength, *stream);
dOutput->newTokens = ITensor::slice(dJointOutput.newTokens, batchIdx, localBatchSize);
manager.setZero(*dOutput->newTokens);
dOutput->newTokensVec.resize(mMaxTokensPerStep);
for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti)
{
TensorPtr newTokensStepView = std::move(ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize));
newTokensStepView->squeeze(0);
dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize);
manager.setZero(*dOutput->newTokensVec[ti]);
}
dOutput->finishedSteps = ITensor::slice(dJointOutput.finishedSteps, batchIdx, localBatchSize);
manager.setZero(*dOutput->finishedSteps);
dOutput->finishedSteps->squeeze(0);
// cumLogProb is mandatory for beamWidth > 1
dOutput->cumLogProbs = nullptr;
@ -293,12 +313,24 @@ void GptDecoderBatch::newRequest(
dOutput->beamHypotheses.init(manager, endId);
}
auto generatedTokensPerStep = request.generatedTokensPerStep();
if (generatedTokensPerStep > 1)
{
auto numDraftTokens = generatedTokensPerStep - 1;
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({1, 1, numDraftTokens}));
mDraftTokenIds[batchIdx] = draftTokensView;
auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, localBatchSize);
kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream);
}
// remaining
mDecoders[batchIdx]->setup(samplingConfig, localBatchSize, mMaxSequenceLength);
mBeamWidths[batchIdx] = beamWidth;
mNbSteps[batchIdx] = 0;
mFinished[batchIdx] = false;
mMaxNewTokens[batchIdx] = maxNewTokens;
mGeneratedTokensPerStep[batchIdx] = generatedTokensPerStep;
// copy the request ids into outputIds
auto inputIdsView = ITensor::view(requestIds, ITensor::makeShape({localBatchSize, inputLength}));
@ -312,14 +344,11 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
decoder_batch::Output& output, decoder_batch::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
auto& logits = input.logits;
auto const& logitsShape = logits->getShape();
auto& allLogits = input.logits;
TLLM_CHECK(logitsShape.d[0] == mActualBatchSize);
// TODO(nkorobov): check logits shape considering draft tokens
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
auto const maxBeamWidth = jointOutputIdsShape.d[1];
TLLM_CHECK(logitsShape.d[1] == maxBeamWidth);
TLLM_CHECK(static_cast<std::size_t>(logitsShape.d[2]) == mVocabSizePadded);
auto& srcCacheIndirection = input.cacheIndirection;
auto& tgtCacheIndirection = output.cacheIndirection;
@ -328,70 +357,83 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
TLLM_CHECK(!srcCacheIndirection || srcCacheIndirection->getDataType() == TRTDataType<SizeType>::value);
TLLM_CHECK(!tgtCacheIndirection || tgtCacheIndirection->getDataType() == TRTDataType<SizeType>::value);
TLLM_CHECK(static_cast<SizeType>(output.sequenceLengths->getSize()) == mActualBatchSize * maxBeamWidth);
// TODO should remove this reshape and set shape to [batch_size, beam_width] outside
TensorPtr sequenceLengths = ITensor::view(output.sequenceLengths);
sequenceLengths->reshape(ITensor::makeShape({mActualBatchSize, maxBeamWidth}));
TensorPtr sequenceLengths
= ITensor::view(output.sequenceLengths, ITensor::makeShape({mActualBatchSize, maxBeamWidth}));
TLLM_CHECK(sequenceLengths);
auto constexpr singleRequest = 1;
CudaEvent eventStart{};
mStream->record(eventStart);
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
for (std::int32_t bi = 0; bi < mActualBatchSize; ++bi)
{
if (mFinished[i] || !input.active.at(i))
if (mFinished[bi] || !input.active.at(bi))
continue;
auto& stream = mStreams[i];
auto& logits = allLogits[bi];
auto const& logitsShape = logits->getShape();
TLLM_CHECK_WITH_INFO(logitsShape.d[0] == mGeneratedTokensPerStep[bi],
tc::fmtstr(
"First dim (%d) does not match generated tokens (%d)", logitsShape.d[0], mGeneratedTokensPerStep[bi]));
TLLM_CHECK_WITH_INFO(logitsShape.d[1] == mBeamWidths[bi],
tc::fmtstr("Second dim (%d) does not match beam width (%d)", logitsShape.d[1], mBeamWidths[bi]));
TLLM_CHECK(static_cast<std::size_t>(logitsShape.d[2]) == mVocabSizePadded);
auto& stream = mStreams[bi];
stream->wait(eventStart.get());
auto& dInput = *mDecodingInputs[i];
auto& dOutput = *mDecodingOutputs[i];
auto logitsView = std::shared_ptr(ITensor::slice(logits, i, singleRequest));
dInput.logits
= ITensor::view(logitsView, ITensor::makeShape({singleRequest, mBeamWidths[i], logitsShape.d[2]}));
auto& dInput = *mDecodingInputs[bi];
auto& dOutput = *mDecodingOutputs[bi];
auto& decoder = *mDecoders[bi];
if (srcCacheIndirection && tgtCacheIndirection)
{
auto srcView = std::shared_ptr(ITensor::slice(srcCacheIndirection, i, singleRequest));
auto tgtView = std::shared_ptr(ITensor::slice(tgtCacheIndirection, i, singleRequest));
dInput.cacheIndirection
= ITensor::view(srcView, ITensor::makeShape({singleRequest, mBeamWidths[i], srcView->getShape().d[2]}));
dOutput.cacheIndirection
= ITensor::view(tgtView, ITensor::makeShape({singleRequest, mBeamWidths[i], tgtView->getShape().d[2]}));
auto srcView = std::shared_ptr(ITensor::slice(srcCacheIndirection, bi, singleRequest));
auto tgtView = std::shared_ptr(ITensor::slice(tgtCacheIndirection, bi, singleRequest));
dInput.cacheIndirection = ITensor::view(
srcView, ITensor::makeShape({singleRequest, mBeamWidths[bi], srcView->getShape().d[2]}));
dOutput.cacheIndirection = ITensor::view(
tgtView, ITensor::makeShape({singleRequest, mBeamWidths[bi], tgtView->getShape().d[2]}));
}
auto sequenceLengthsView = std::shared_ptr(ITensor::slice(sequenceLengths, i, singleRequest));
dOutput.lengths = ITensor::view(sequenceLengthsView, ITensor::makeShape({singleRequest, mBeamWidths[i]}));
auto& decoder = *mDecoders[i];
decoder.forwardAsync(dOutput, dInput);
auto sequenceLengthsView = std::shared_ptr(ITensor::slice(sequenceLengths, bi, singleRequest));
dOutput.lengths = ITensor::view(sequenceLengthsView, ITensor::makeShape({singleRequest, mBeamWidths[bi]}));
auto manager = BufferManager{stream};
auto jointOutputIdsView = ITensor::slice(mJointDecodingOutput->ids, i, singleRequest);
auto const& jointOutputShape = jointOutputIdsView->getShape();
// squeeze dim 0 and set beamWidth
jointOutputIdsView->reshape(ITensor::makeShape({mBeamWidths[i], jointOutputShape.d[2]}));
manager.copy(*dOutput.ids, *jointOutputIdsView);
auto jointSequenceLengthsView = ITensor::slice(mJointDecodingOutput->lengths, i, singleRequest);
jointSequenceLengthsView->reshape(ITensor::makeShape({1, mBeamWidths[i]}));
manager.copy(*dOutput.lengths, *jointSequenceLengthsView);
if (mBeamWidths[i] > 1)
for (std::int32_t di = 0; di < mGeneratedTokensPerStep[bi]; ++di)
{
auto jointOutputParentIdsView = ITensor::slice(mJointDecodingOutput->parentIds, i, singleRequest);
auto const& jointOutputParentIdsShape = jointOutputParentIdsView->getShape();
// squeeze dim 0 and set beamWidth
jointOutputParentIdsView->reshape(ITensor::makeShape({mBeamWidths[i], jointOutputParentIdsShape.d[2]}));
dInput.logits = ITensor::slice(logits, di, singleRequest);
dOutput.newTokens = ITensor::view(dOutput.newTokensVec[di]);
dInput.finished = ITensor::slice(dOutput.finishedSteps, di, 1);
dOutput.finished
= ITensor::slice(dOutput.finishedSteps, std::min(di + 1, mGeneratedTokensPerStep[bi] - 1), 1);
manager.copy(*dOutput.parentIds, *jointOutputParentIdsView);
decoder.forwardAsync(dOutput, dInput);
mNbSteps[bi] += 1;
mFinished[bi] = mNbSteps[bi] >= mMaxNewTokens[bi];
dInput.step += 1;
}
if (mGeneratedTokensPerStep[bi] > 1)
{
auto draftTokenIds = mDraftTokenIds[bi];
auto numDraftTokens = ITensor::slice(mNumDraftTokens, bi, singleRequest);
// Update finished state for 0th step
auto finishedFinal = ITensor::slice(dOutput.finishedSteps, 0, 1);
IGptDecoder::acceptTokens(
/* [bs=1, bw=1, max_seq_len] */ *dOutput.ids,
/* [bs, bw, max_draft_tokens] */ *draftTokenIds,
/* [bs, bw] */ *dInput.lengths,
/* [bs, bw] */ *numDraftTokens,
/* [bs, bw] */ *dOutput.lengths,
/* [max_draft_tokens, bs, bw] */ *dOutput.finishedSteps,
/* [bs, bw] */ *finishedFinal,
/* [1] */ *dOutput.finishedSum, stream);
}
CudaEvent event{};
stream->record(event);
mStream->wait(event);
dInput.step += 1;
mNbSteps[i] += 1;
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i];
}
CudaEvent eventStop{};
@ -412,7 +454,7 @@ void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
auto& dOutput = *mDecodingOutputs[i];
mFinished[i] = mFinished[i]
// This condition requires the synchronization above
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.lengths->getSize());
}
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
@ -449,6 +491,7 @@ void GptDecoderBatch::newBatch(
// split batch into single requests
auto const& inputLengths = inputs.lengths;
mActualBatchSize = inputLengths->getShape().d[0];
mGeneratedTokensPerStep.resize(mActualBatchSize);
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
auto const maxBatchSize = jointOutputIdsShape.d[0];
@ -464,6 +507,7 @@ void GptDecoderBatch::newBatch(
auto inputOffset = 0;
for (auto batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
{
mGeneratedTokensPerStep[batchIdx] = 1;
auto const inputLength = inputLengthsPtr[batchIdx];
auto const inputShape = ITensor::makeShape({inputLength});
TensorPtr inputView;
@ -477,8 +521,7 @@ void GptDecoderBatch::newBatch(
inputView = ITensor::slice(inputs.ids, batchIdx, 1);
inputView->reshape(inputShape);
}
auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId};
auto request = decoder_batch::Request{inputView, inputLength, inputs.maxNewTokens, inputs.endId};
request.computeCumLogProbs = (outputs.cumLogProbs != nullptr);
request.computeLogProbs = (outputs.logProbs != nullptr);
@ -515,7 +558,20 @@ void GptDecoderBatch::newBatch(
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
decoder_batch::Input batchInput{input.logits};
auto const& logitsShape = input.logits->getShape();
auto const batchSize = logitsShape.d[0];
auto constexpr singleRequest = 1;
std::vector<ITensor::SharedConstPtr> logits;
logits.reserve(batchSize);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto logitsSlice = std::shared_ptr(ITensor::slice(input.logits, batchIdx, singleRequest));
logits.emplace_back(
ITensor::view(logitsSlice, ITensor::makeShape({singleRequest, mBeamWidths[batchIdx], logitsShape.d[2]})));
}
decoder_batch::Input batchInput{logits};
batchInput.cacheIndirection = input.cacheIndirection;
decoder_batch::Output batchOutput;

View File

@ -102,6 +102,7 @@ GptJsonConfig parseJson(InputType&& i)
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
auto const maxPromptEmbeddingTableSize
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
@ -132,6 +133,7 @@ GptJsonConfig parseJson(InputType&& i)
modelConfig.setMaxInputLen(maxInputLen);
modelConfig.setMaxOutputLen(maxOutputLen);
modelConfig.setMaxNumTokens(maxNumTokens);
modelConfig.setMaxDraftLen(maxDraftLen);
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
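Editorial note: with max_draft_len now parsed from builder_config and stored on the model config, downstream code can size speculative-decoding buffers from the engine metadata. A small sketch is given below; the getMaxDraftLen() getter is an assumption that mirrors the setMaxDraftLen() call above and is not shown in this diff.
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include <filesystem>
// Reads the engine's config.json and returns the draft-token budget baked into it.
// Assumption: GptModelConfig exposes getMaxDraftLen() alongside the setter used above.
tensorrt_llm::runtime::SizeType readMaxDraftLen(std::filesystem::path const& configPath)
{
    auto const json = tensorrt_llm::runtime::GptJsonConfig::parse(configPath);
    return json.getModelConfig().getMaxDraftLen();
}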

View File

@ -133,7 +133,9 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
else
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
mDecoders.back()->setup(batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, logitsType);
constexpr SizeType maxTokensPerStep = 1;
mDecoders.back()->setup(
batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, maxTokensPerStep, logitsType);
}
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);

View File

@ -62,9 +62,10 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
}
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
SizeType maxSequenceLength, nvinfer1::DataType dtype)
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK(maxTokensPerStep == 1);
mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream);
reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength);
@ -102,6 +103,7 @@ void StatefulGptDecoder::reshapeBuffers(
mBufferManager.setZero(*dOutput.newTokens);
dOutput.parentIds->reshape(outputIdsShape);
dOutput.finished->reshape(batchSizeXbeamWidth);
dInput.finished = ITensor::view(dOutput.finished);
mBufferManager.setZero(*dOutput.finished);
mBufferManager.setZero(*dOutput.finishedSum);

View File

@ -40,7 +40,7 @@ public:
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
nvinfer1::DataType dtype) override;
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder with new batch of inputs.
void newBatch(
@ -53,6 +53,7 @@ public:
//! @brief Gather final results for all requests.
void finalize() const override;
//! @param step index within tokens generated in one step
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
//! ids without padding, on gpu
[[nodiscard]] TensorPtr getOutputIds() const override
@ -72,12 +73,24 @@ public:
return mDecodingOutput->logProbs;
}
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
[[nodiscard]] TensorPtr getNewTokens() const override
//! @brief Get tokens generated in one step of last forward pass
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
{
TLLM_CHECK(iter == 0);
return mDecodingOutput->newTokens;
}
//! @brief Get tokens generated in the last forward pass
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
[[nodiscard]] TensorPtr getAllNewTokens() const override
{
TensorPtr newTokens = std::move(ITensor::view(mDecodingOutput->newTokensSteps));
newTokens->unsqueeze(0);
return newTokens;
}
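Editorial note: token retrieval is now split into a per-step getNewTokens(iter) and a bulk getAllNewTokens(). The sketch below walks a bulk tensor laid out as [maxTokensPerStep, batchSize, maxBeamWidth], which is how the joint newTokensSteps buffer is shaped earlier in this diff; the helper name and its use are illustrative, not part of the commit.
// Illustrative helper: iterate the per-step slices of a bulk new-tokens tensor.
void forEachGeneratedStep(ITensor::SharedPtr const& allNewTokens)
{
    auto const numSteps = allNewTokens->getShape().d[0];
    for (SizeType step = 0; step < numSteps; ++step)
    {
        // [1, batchSize, maxBeamWidth] view of the tokens produced at this step
        auto stepTokens = ITensor::slice(allNewTokens, step, 1);
        stepTokens->squeeze(0); // -> [batchSize, maxBeamWidth]
        // ... consume stepTokens here, e.g. copy to host with a BufferManager ...
    }
}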
//! @returns [1], number of finished sequences, in pinned host memory
[[nodiscard]] TensorPtr getNbFinished() const override
{

View File

@ -23,6 +23,7 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/core/DeviceType.h>
#include <cuda_runtime.h>
#include <torch/types.h>
@ -111,6 +112,17 @@ public:
}
}
static at::DeviceType deviceType(runtime::MemoryType memoryType)
{
switch (memoryType)
{
case runtime::MemoryType::kGPU: return c10::kCUDA;
case runtime::MemoryType::kCPU: [[fallthrough]];
case runtime::MemoryType::kPINNED: [[fallthrough]];
default: return c10::kCPU;
}
}
static at::cuda::CUDAStream stream(runtime::CudaStream& cudaStream)
{
return at::cuda::getStreamFromExternal(cudaStream.get(), static_cast<at::DeviceIndex>(cudaStream.getDevice()));

View File

@ -144,7 +144,8 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> src_cache_indirection_opt,
// Outputs
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, th::optional<th::Tensor> finished_opt,
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
@ -166,12 +167,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
safeUpdate<int>(bad_words_list_opt, forwardParams.bad_words_list);
safeUpdate<int>(stop_words_list_opt, forwardParams.stop_words_list);
safeUpdate<int>(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size);
safeUpdate<int>(finished_input, forwardParams.finished);
auto const& output_ids_converted = convert_tensor<int>(output_token_ids);
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams outputParams{output_ids_converted};
outputParams.newTokens = std::move(convert_tensor<int>(newTokens));
safeUpdate<bool>(finished_opt, outputParams.finished);
safeUpdate<bool>(finished_output, outputParams.finished);
std::int32_t* finished_sum_host = nullptr;
if (forwardParams.sequence_limit_length && outputParams.finished.has_value())
{
@ -280,7 +282,8 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> src_cache_indirection_opt,
// output buffers.
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_opt,
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_input,
th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
@ -329,7 +332,8 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32);
CHECK_INPUT(output_token_ids, torch::kInt32);
CHECK_OPTIONAL_INPUT(finished_opt, torch::kBool);
CHECK_OPTIONAL_INPUT(finished_input, torch::kBool);
CHECK_OPTIONAL_INPUT(finished_output, torch::kBool);
CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32);
CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
@ -345,11 +349,11 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt,
src_cache_indirection_opt,
// Outputs
output_token_ids, newTokens, should_stop, finished_opt, seuqence_lengths_opt, cum_log_probs_opt,
output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt,
beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt, beam_hyps_normed_scores_opt,
beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt, beam_hyps_is_done_opt,
use_beam_hyps);
output_token_ids, newTokens, should_stop, finished_input, finished_output, seuqence_lengths_opt,
cum_log_probs_opt, output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt,
beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt,
beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt,
beam_hyps_is_done_opt, use_beam_hyps);
return should_stop;
}

View File

@ -46,10 +46,10 @@ public:
th::optional<th::Tensor> src_cache_indirection_opt,
// Outputs
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_opt, th::optional<th::Tensor> sequence_lengths_opt,
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
@ -84,10 +84,10 @@ public:
th::optional<th::Tensor> src_cache_indirection_opt,
// Outputs
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
th::optional<th::Tensor> finished_opt, th::optional<th::Tensor> sequence_lengths_opt,
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
@ -128,7 +128,8 @@ public:
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
th::optional<th::Tensor> src_cache_indirection_opt,
// output buffers.
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_opt,
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_input,
th::optional<th::Tensor> finished_output,
th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,

View File

@ -87,6 +87,7 @@ set(SAMPLING_KERNEL_TEST_SRC
kernels/sampling/samplingUtilsTest.cu)
add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
add_gtest(weightOnlyKernelTest kernels/weightOnly/weightOnlyKernelTest.cpp)
add_gtest(decodingKernelsTest kernels/decodingKernelTest.cpp)
if(BUILD_BATCH_MANAGER)
add_subdirectory(batch_manager)

View File

@ -0,0 +1,160 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef TOP_LEVEL_DIR
#error "Define TOP_LEVEL_DIR"
#endif
#include <gtest/gtest.h>
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include <random>
namespace tk = tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace
{
void runAcceptedTokensTest(SizeType seed)
{
constexpr SizeType batchSize{8};
constexpr SizeType beamWidth{1};
constexpr SizeType maxSeqLen{16};
constexpr SizeType vocabSize{32};
constexpr SizeType maxDraftTokens{8};
auto stream = std::make_shared<CudaStream>();
BufferManager manager(stream);
std::mt19937 generator(seed);
std::uniform_int_distribution<int> contextLenDistr(0, maxSeqLen - maxDraftTokens);
std::uniform_int_distribution<int> numDraftTokensDistr(1, maxDraftTokens);
std::uniform_int_distribution<int> vocabDistr(1, vocabSize);
std::uniform_real_distribution<float> acceptTokenDistr(0.f, 1.f);
auto draftTokens
= manager.pinned(ITensor::makeShape({batchSize, beamWidth, maxDraftTokens}), nvinfer1::DataType::kINT32);
auto targetTokens
= manager.pinned(ITensor::makeShape({batchSize, beamWidth, maxSeqLen}), nvinfer1::DataType::kINT32);
auto numsDraftTokens = manager.pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
auto sequenceLengths = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
auto contextLengths = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
auto finishedSteps
= manager.pinned(ITensor::makeShape({maxDraftTokens, batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
auto finishedFinal = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
auto finishedSum = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
std::vector<int> acceptedLen(batchSize * beamWidth);
std::vector<bool> acceptedFinished(batchSize * beamWidth);
auto sequenceLengthsPtr = bufferCast<SizeType>(*sequenceLengths);
auto contextLengthsPtr = bufferCast<SizeType>(*contextLengths);
auto numsDraftTokensPtr = bufferCast<SizeType>(*numsDraftTokens);
auto draftTokensPtr = bufferCast<SizeType>(*draftTokens);
auto targetTokensPtr = bufferCast<SizeType>(*targetTokens);
auto finishedStepsPtr = bufferCast<bool>(*finishedSteps);
auto finishedFinalPtr = bufferCast<bool>(*finishedFinal);
auto finishedSumPtr = bufferCast<SizeType>(*finishedSum);
for (SizeType bi = 0; bi < batchSize; ++bi)
{
numsDraftTokensPtr[bi] = numDraftTokensDistr(generator);
}
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
{
const SizeType batchIdx = bi / beamWidth;
// Randomly init context len
contextLengthsPtr[bi] = contextLenDistr(generator);
// Sequence len is at most numsDraftTokensPtr[bi] away from context len (it can be closer if e.g. endId is
// generated)
std::uniform_int_distribution<int> realDraftTokensDistr(0, numsDraftTokensPtr[batchIdx]);
const auto realLen = realDraftTokensDistr(generator);
sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + realLen;
for (int i = 0; i < realLen; ++i)
{
finishedStepsPtr[i * batchSize * beamWidth + bi] = false;
}
for (int i = realLen; i <= numsDraftTokensPtr[batchIdx]; ++i)
{
finishedStepsPtr[i * batchSize * beamWidth + bi] = true;
}
// Init helper vector with max value
acceptedLen[bi] = sequenceLengthsPtr[bi];
acceptedFinished[bi] = finishedStepsPtr[realLen * batchSize * beamWidth + bi];
}
// Fill token arrays
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
{
// Draft: [d0, d1, d2, ... for numsDraftTokensPtr[bi] ... , dN]
// Target: [vocabSize + 1, vocabSize + 1, ... for contextLengthsPtr[bi] ... vocabSize + 1,
// t0, t1, t2, ... for numsDraftTokensPtr[bi] ... , tN,
// vocabSize + 1, vocabSize + 1, .. to maxSeqLen]
for (SizeType si = 0; si < contextLengthsPtr[bi]; ++si)
{
targetTokensPtr[bi * maxSeqLen + si] = vocabSize + 1;
}
for (SizeType si = contextLengthsPtr[bi]; si < sequenceLengthsPtr[bi]; ++si)
{
const auto draftToken = vocabDistr(generator);
const auto draftTokenIdx = si - contextLengthsPtr[bi];
const auto targetToken
= acceptTokenDistr(generator) < 1.f / (draftTokenIdx + 1e-6) ? draftToken : vocabDistr(generator);
draftTokensPtr[bi * maxDraftTokens + draftTokenIdx] = draftToken;
targetTokensPtr[bi * maxSeqLen + si] = targetToken;
if (draftToken != targetToken)
{
acceptedLen[bi] = std::min(acceptedLen[bi], std::min(si + 1, maxSeqLen));
acceptedFinished[bi] = finishedStepsPtr[draftTokenIdx * batchSize * beamWidth + bi];
}
}
for (SizeType si = sequenceLengthsPtr[bi]; si < maxSeqLen; ++si)
{
targetTokensPtr[bi * maxSeqLen + si] = vocabSize + 1;
}
}
// Call function
tk::invokeAcceptTokens(draftTokensPtr, targetTokensPtr, contextLengthsPtr, numsDraftTokensPtr, sequenceLengthsPtr,
finishedStepsPtr, finishedFinalPtr, finishedSumPtr, batchSize, beamWidth, maxSeqLen, maxDraftTokens,
stream->get());
stream->synchronize();
// Verify seqLen for accepted tokens
int finishedSumRef = 0;
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
{
EXPECT_EQ(acceptedLen[bi], sequenceLengthsPtr[bi]) << " bi " << bi << " seed " << seed;
EXPECT_EQ(acceptedFinished[bi], finishedFinalPtr[bi]) << " bi " << bi << " seed " << seed;
finishedSumRef += static_cast<SizeType>(acceptedFinished[bi]);
}
EXPECT_EQ(finishedSumRef, finishedSumPtr[0]);
}
TEST(DecodingKernelsTest, acceptTokensKernel)
{
constexpr SizeType seeds = 64;
for (SizeType seed = 0; seed < seeds; ++seed)
{
runAcceptedTokensTest(seed);
}
}
} // end of namespace

View File

@ -44,7 +44,7 @@ protected:
{
size_t workspaceSize;
tk::invokeTopKSampling<T>(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
this->mMaxTopK, 1.0f, params.vocabSize, nullptr, this->mStream->get(), params.batchSize, nullptr);
nullptr, this->mMaxTopK, 1.0f, params.vocabSize, nullptr, this->mStream->get(), params.batchSize, nullptr);
return workspaceSize;
}
@ -59,8 +59,8 @@ protected:
// preprocesses log_prob_buf when those are provided.
bufferCast<T>(*this->mProbsDevice), bufferCast<int*>(*this->mIdsPtrHost),
bufferCast<int32_t>(*this->mSeqLengthsDevice), bufferCast<bool>(*this->mFinishedDevice),
bufferCast<float>(*this->mCumLogProbsDevice), bufferCast<float>(*this->mOutputLogProbsDevice),
this->mCurandStatesDevice, this->mMaxTopK,
bufferCast<bool>(*this->mFinishedDevice), bufferCast<float>(*this->mCumLogProbsDevice),
bufferCast<float>(*this->mOutputLogProbsDevice), this->mCurandStatesDevice, this->mMaxTopK,
hasDiffRuntimeArgs ? bufferCast<int32_t>(*this->mTopKsDevice) : nullptr, params.topP,
hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, params.vocabSize,
bufferCast<int32_t>(*this->mEndIdsDevice), this->mStream->get(), params.batchSize,

View File

@ -50,6 +50,7 @@ private:
nullptr, // output_ids
nullptr, // sequence_length
nullptr, // finished_buffer
nullptr, // finished_buffer
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs
@ -69,6 +70,7 @@ private:
nullptr, // output_ids
nullptr, // sequence_length
nullptr, // finished_buffer
nullptr, // finished_buffer
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs
@ -85,8 +87,8 @@ private:
// Perform batched TopP sampling
tk::invokeBatchTopPSampling<T>(workspaceDevice->data(), workspaceSize, cubTempStorageSize,
bufferCast<int*>(*this->mIdsPtrHost), bufferCast<int32_t>(*this->mSeqLengthsDevice),
bufferCast<bool>(*this->mFinishedDevice), bufferCast<float>(*this->mCumLogProbsDevice),
bufferCast<float>(*this->mOutputLogProbsDevice),
bufferCast<bool>(*this->mFinishedDevice), bufferCast<bool>(*this->mFinishedDevice),
bufferCast<float>(*this->mCumLogProbsDevice), bufferCast<float>(*this->mOutputLogProbsDevice),
// Note that the kernel needs vocab probs instead of
// log-prob if cum_log_probs or output_log_probs are
// provided. It's because the sampling layer already

View File

@ -83,8 +83,8 @@ float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, v
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, nullptr, bias, out, m, n, k,
group_size, BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
@ -164,8 +164,8 @@ float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, vo
cudaEventCreate(&end);
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
{
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, nullptr, bias, out, m, n, k,
group_size, BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
for (int i = 0; i < warmup; ++i)
{
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);

View File

@ -28,8 +28,7 @@ resources_dir = _pl.Path(
__file__).parent.parent.parent.parent.parent / "examples/chatglm"
sys.path.insert(0, str(resources_dir))
engine_target_path = _pl.Path(
__file__).parent.parent / "models/rt_engine/chatglm"
engine_target_path = _pl.Path(__file__).parent.parent / "models/rt_engine"
import build as _ecb
@ -67,7 +66,10 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
model_name_list = ["chatglm_6b", "chatglm2_6b", "chatglm3_6b"]
hf_dir_list = [resources_dir / model_name for model_name in model_name_list]
trt_dir = resources_dir / "trtModel"
trt_dir_list = [
resources_dir / ("output_" + model_name)
for model_name in model_name_list
]
run_command(
["pip", "install", "-r",
@ -89,14 +91,17 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
)
print("\nBuilding engines")
for model_name, hf_dir in zip(model_name_list, hf_dir_list):
for model_name, hf_dir, trt_dir in zip(model_name_list, hf_dir_list,
trt_dir_list):
print("Building %s" % model_name)
build_engine(model_name, hf_dir, trt_dir, world_size)
if not _Path(engine_target_path).exists():
_Path(engine_target_path).mkdir(parents=True, exist_ok=True)
for file in _Path(trt_dir).glob("*"):
_shutil.move(file, engine_target_path)
for model_name in model_name_list:
_shutil.move(
_Path(resources_dir) / ("output_" + model_name),
engine_target_path / model_name)
print("Done.")

View File

@ -147,7 +147,8 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir, tp_size,
'--dtype=float16', '--use_gpt_attention_plugin=float16',
'--remove_input_padding', '--paged_kv_cache',
'--enable_context_fmha', '--max_num_tokens=10000')
'--enable_context_fmha_fp32_acc', '--max_num_tokens=10000',
'--max_draft_len=5')
print("Done.")

View File

@ -48,11 +48,12 @@ def generate(model_name, batch_size, beam_width):
args.input_text += args.input_text[0] * (batch_size - 2)
args.beam_width = beam_width
args.tokenizer_dir = resources_dir / model_name
args.engine_dir = Path(__file__).parent.parent / "models/rt_engine/chatglm"
args.engine_dir = Path(__file__).parent.parent / ("models/rt_engine/" +
model_name)
tensorrt_llm.logger.set_level(args.log_level)
config_path = Path(args.engine_dir) / (model_name + '-config.json')
config_path = args.engine_dir / 'config.json'
with open(config_path, 'r') as f:
config = json.load(f)
assert (config['builder_config']['name'] == model_name)
@ -226,6 +227,4 @@ if __name__ == '__main__':
generate("chatglm3_6b", batch_size=1, beam_width=1)
generate("chatglm3_6b", batch_size=2, beam_width=1)
generate("chatglm3_6b", batch_size=1, beam_width=2)
#generate("glm_10b", batch_size=1, beam_width=1)
#generate("glm_10b", batch_size=2, beam_width=1)
print("Done.")

View File

@ -14,6 +14,9 @@
* limitations under the License.
*/
#include <algorithm>
#include <random>
#include <gmock/gmock-matchers.h>
#include <gtest/gtest.h>
@ -31,27 +34,109 @@ namespace tc = tensorrt_llm::common;
namespace
{
void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
std::vector<SamplingConfig> const& samplingConfigs, std::vector<SizeType> const& inputLengths, SizeType batchSize,
SizeType maxBeamWidth, SizeType maxSeqLength, SizeType nbNewTokens, int tokenId, int padId)
decoder_batch::Input prepareDecoderInputs(SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
SizeType vocabSizePadded, nvinfer1::DataType dataType, std::vector<SamplingConfig> const& samplingConfigs,
std::vector<SizeType> const& generatedTokensPerSteps, BufferManager& manager)
{
auto sequenceLengths = decoder.getOutputLengths();
ASSERT_TRUE(sequenceLengths);
EXPECT_EQ(sequenceLengths->getSize(), batchSize * maxBeamWidth);
auto sequenceLengthsHost = manager.copyFrom(*sequenceLengths, MemoryType::kCPU);
auto sequenceLengthsPtr = bufferCast<SizeType>(*sequenceLengthsHost);
manager.getStream().synchronize();
for (auto b = 0; b < batchSize; ++b)
std::vector<decoder_batch::Input::TensorPtr> logits;
logits.reserve(batchSize);
auto constexpr tokenId = 1;
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto samplingConfig = samplingConfigs[b];
for (auto bw = 0; bw < samplingConfig.beamWidth; ++bw)
{
auto index = tc::flat_index(sequenceLengths->getShape().d, b, bw);
EXPECT_EQ(sequenceLengthsPtr[index], inputLengths[b] + nbNewTokens);
}
auto const beamWidth = samplingConfigs[batchIdx].beamWidth;
logits.emplace_back(
manager.gpu(ITensor::makeShape({generatedTokensPerSteps[batchIdx], beamWidth, vocabSizePadded}), dataType));
manager.setZero(*logits.back());
}
decoder_batch::Input inputs{logits};
if (maxBeamWidth > 1)
{
auto srcCacheIndirection
= manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value);
manager.setZero(*srcCacheIndirection);
inputs.cacheIndirection = std::move(srcCacheIndirection);
}
return inputs;
}
decoder_batch::Output prepareDecoderOutputs(SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
std::vector<SizeType> const& tiledInputLengths, BufferManager& manager)
{
decoder_batch::Output outputs{};
auto sequenceLengths
= manager.copyFrom(tiledInputLengths, ITensor::makeShape({batchSize, maxBeamWidth}), MemoryType::kGPU);
outputs.sequenceLengths = std::move(sequenceLengths);
if (maxBeamWidth > 1)
{
auto tgtCacheIndirection
= manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value);
manager.setZero(*tgtCacheIndirection);
outputs.cacheIndirection = std::move(tgtCacheIndirection);
}
return outputs;
}
std::vector<decoder_batch::Request> prepareRequests(SizeType batchSize, SizeType maxNewTokens,
std::vector<SizeType> const& inputLengths, std::vector<SizeType> const& generatedTokensPerSteps,
std::vector<SizeType> const& acceptedTokensPerStep, TokenIdType tokenId, TokenIdType endId, TokenIdType padId,
bool computeLogProbs, BufferManager& manager)
{
auto& stream = manager.getStream();
std::vector<decoder_batch::Request> requests;
requests.reserve(batchSize);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto shape = ITensor::makeShape({inputLengths[batchIdx]});
auto input = manager.gpu(shape, TRTDataType<SizeType>::value);
kernels::invokeFill(*input, tokenId, stream);
requests.emplace_back(decoder_batch::Request{std::move(input), inputLengths[batchIdx], maxNewTokens, endId});
if (generatedTokensPerSteps[batchIdx] > 1)
{
std::vector<TokenIdType> draftTokens(generatedTokensPerSteps[batchIdx] - 1);
std::fill(draftTokens.begin(), draftTokens.begin() + acceptedTokensPerStep[batchIdx], 1023);
requests.back().draftTokens = manager.copyFrom(draftTokens, MemoryType::kGPU);
}
requests.back().computeCumLogProbs = computeLogProbs;
requests.back().computeLogProbs = computeLogProbs;
}
return requests;
}
void advanceSequenceLengths(std::vector<SizeType>& sequenceLengths, std::vector<SizeType> const& acceptedTokensPerStep,
std::vector<SamplingConfig> const& samplingConfigs, SizeType batchSize, SizeType maxBeamWidth)
{
for (int batchIdx = 0; batchIdx < batchSize; batchIdx++)
{
for (int beamId = 0; beamId < samplingConfigs.at(batchIdx).beamWidth; beamId++)
{
sequenceLengths.at(tc::flat_index2(batchIdx, beamId, maxBeamWidth))
+= acceptedTokensPerStep.at(batchIdx) + 1;
}
}
}
void checkSequenceLengths(
ITensor const& sequenceLengths, std::vector<SizeType> const& expectedLengths, BufferManager& manager)
{
auto sequenceLengthsHost = manager.copyFrom(sequenceLengths, MemoryType::kCPU);
auto sequenceLengthsHostRange = BufferRange<SizeType>(*sequenceLengthsHost);
EXPECT_THAT(sequenceLengthsHostRange, ::testing::ElementsAreArray(expectedLengths));
}
void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
std::vector<SamplingConfig> const& samplingConfigs, std::vector<SizeType> const& inputLengths,
std::vector<SizeType> const& sequenceLengths, SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
SizeType tokenId, SizeType padId)
{
auto outputsIds = decoder.getOutputIds();
// TODO: test parentIds
// parentIds = decoder.getParentIds();
@ -68,30 +153,33 @@ void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
for (auto b = 0; b < batchSize; ++b)
{
auto samplingConfig = samplingConfigs[b];
auto samplingConfig = samplingConfigs.at(b);
for (auto bw = 0; bw < samplingConfig.beamWidth; ++bw)
{
auto const result = (samplingConfig.beamWidth == 1) ? 1023 : bw;
auto const outputPtr = output + tc::flat_index(outputShape.d, b, bw, 0);
auto begin = outputPtr;
auto end = outputPtr + inputLengths[b];
auto end = outputPtr + inputLengths.at(b);
ASSERT_LE(begin, end) << "bad input length " << inputLengths.at(b);
ASSERT_THAT(std::vector(begin, end), ::testing::Each(tokenId)) << "input tokens: "
<< "b:" << b << " bw: " << bw;
begin = end;
end = begin + nbNewTokens;
end = outputPtr + sequenceLengths.at(tc::flat_index2(b, bw, maxBeamWidth));
ASSERT_LE(begin, end) << "bad seq length " << sequenceLengths.at(b);
ASSERT_THAT(std::vector(begin, end), ::testing::Each(result)) << "new tokens: "
<< "b:" << b << " bw: " << bw;
begin = end;
end = outputPtr + maxSeqLength;
ASSERT_LE(begin, end) << "bad max length " << maxSeqLength;
ASSERT_THAT(std::vector(begin, end), ::testing::Each(padId)) << "padding: "
<< "b:" << b << " bw: " << bw;
}
}
}
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs, int maxBeamWidth,
bool computeLogProbs)
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
SizeType maxBeamWidth, bool computeLogProbs)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
SizeType constexpr tensorParallelism{1};
@ -109,118 +197,99 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> con
auto streamPtr = std::make_shared<CudaStream>();
BufferManager manager(streamPtr);
// create decoder
int constexpr endId{50257};
int constexpr padId{50257};
TokenIdType constexpr endId{50257};
TokenIdType constexpr padId{50257};
auto const dataType = modelConfig.getDataType();
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
// setup decoder
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
SizeType constexpr maxInputLength{8};
SizeType constexpr maxNewTokens{2};
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;
SizeType const maxNewTokens{2};
auto const maxSeqLength = maxInputLength + maxNewTokens;
SizeType constexpr maxGeneratedTokensPerStep{1};
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());
std::vector<SizeType> inputLengths(batchSize);
std::iota(inputLengths.begin(), inputLengths.end(), 4);
std::vector<SizeType> const inputLengths{4, 5, 6, 7};
std::vector<SizeType> tiledInputLengths;
for (int batch_id = 0; batch_id < inputLengths.size(); batch_id++)
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
{
for (int beam_id = 0; beam_id < maxBeamWidth; beam_id++)
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
{
tiledInputLengths.push_back(inputLengths.at(batch_id));
tiledInputLengths.push_back(inputLengths.at(batchIdx));
}
}
// set up inputs
auto logits = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, vocabSizePadded}), modelConfig.getDataType()));
manager.setZero(*logits);
decoder_batch::Input inputs{logits};
if (maxBeamWidth > 1)
std::vector<SizeType> generatedTokensPerSteps(batchSize);
std::vector<SizeType> acceptedTokensPerStep(batchSize);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto srcCacheIndirection = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
manager.setZero(*srcCacheIndirection);
inputs.cacheIndirection = srcCacheIndirection;
generatedTokensPerSteps[batchIdx] = maxGeneratedTokensPerStep;
acceptedTokensPerStep[batchIdx] = generatedTokensPerSteps[batchIdx] - 1;
}
// set up outputs
decoder_batch::Output outputs{};
if (maxBeamWidth > 1)
{
auto tgtCacheIndirection = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
manager.setZero(*tgtCacheIndirection);
outputs.cacheIndirection = tgtCacheIndirection;
}
auto sequenceLengths
= std::shared_ptr(manager.gpu(ITensor::makeShape({batchSize * maxBeamWidth}), TRTDataType<SizeType>::value));
manager.copy(tiledInputLengths.data(), *sequenceLengths);
outputs.sequenceLengths = sequenceLengths;
auto constexpr tokenId = 1;
std::vector<decoder_batch::Input::TensorPtr> inputIds;
for (auto b = 0; b < batchSize; ++b)
{
auto shape = ITensor::makeShape({inputLengths[b]});
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
kernels::invokeFill(*input, tokenId, *streamPtr);
inputIds.emplace_back(input);
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
acceptedTokensPerStep, tokenId, endId, padId, computeLogProbs, manager);
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
decoderRequest.computeCumLogProbs = computeLogProbs;
decoderRequest.computeLogProbs = computeLogProbs;
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
// set up inputs and outputs
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
samplingConfigs, generatedTokensPerSteps, manager);
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;
// set up decoder
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
}
cudaDeviceSynchronize();
auto const& nbSteps = decoder.getNbSteps();
EXPECT_EQ(nbSteps.size(), batchSize);
EXPECT_THAT(nbSteps, ::testing::Each(0));
auto expectedLengths = tiledInputLengths;
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
auto const& finished = decoder.getFinished();
EXPECT_EQ(finished.size(), batchSize);
EXPECT_THAT(finished, ::testing::Each(false));
verifyResults(
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 0, tokenId, padId);
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
// run decoder for 1 step
decoder.forward(outputs, inputs);
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(1));
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
EXPECT_THAT(decoder.getFinished(), ::testing::Each(false));
verifyResults(
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 1, tokenId, padId);
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
// run decoder for 1 step
decoder.forward(outputs, inputs);
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
verifyResults(
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 2, tokenId, padId);
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
EXPECT_NO_THROW(decoder.forward(outputs, inputs));
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
auto decoderRequest = decoder_batch::Request{inputIds[0], maxNewTokens};
decoderRequest.computeCumLogProbs = computeLogProbs;
decoderRequest.computeLogProbs = computeLogProbs;
decoder.newRequest(0, decoderRequest, samplingConfigs[0]);
decoder.newRequest(0, requests[0], samplingConfigs[0]);
EXPECT_FALSE(decoder.getFinished()[0]);
EXPECT_EQ(decoder.getNbSteps()[0], 0);
}
void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
int maxBeamWidth, bool computeLogProbs)
SizeType maxBeamWidth, bool computeLogProbs)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
SizeType constexpr tensorParallelism{1};
@ -238,105 +307,192 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingCo
auto streamPtr = std::make_shared<CudaStream>();
BufferManager manager(streamPtr);
// create decoder
int constexpr endId{50257};
int constexpr padId{50257};
TokenIdType constexpr endId{50257};
TokenIdType constexpr padId{50257};
auto const dataType = modelConfig.getDataType();
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
// setup decoder
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
SizeType constexpr maxInputLength{8};
SizeType constexpr maxNewTokens{8};
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;
SizeType constexpr maxGeneratedTokensPerStep{1};
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());
std::vector<SizeType> inputLengths(batchSize);
std::iota(inputLengths.begin(), inputLengths.end(), 4);
std::vector<SizeType> const inputLengths{4, 5, 6, 7};
std::vector<SizeType> tiledInputLengths;
for (int batch_id = 0; batch_id < inputLengths.size(); batch_id++)
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
{
for (int beam_id = 0; beam_id < maxBeamWidth; beam_id++)
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
{
tiledInputLengths.push_back(inputLengths.at(batch_id));
tiledInputLengths.push_back(inputLengths.at(batchIdx));
}
}
// set up inputs
auto logits = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, vocabSizePadded}), modelConfig.getDataType()));
manager.setZero(*logits);
decoder_batch::Input inputs{logits};
if (maxBeamWidth > 1)
std::vector<SizeType> generatedTokensPerSteps(batchSize);
std::vector<SizeType> acceptedTokensPerStep(batchSize);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto srcCacheIndirection = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
manager.setZero(*srcCacheIndirection);
inputs.cacheIndirection = srcCacheIndirection;
generatedTokensPerSteps[batchIdx] = maxGeneratedTokensPerStep;
acceptedTokensPerStep[batchIdx] = generatedTokensPerSteps[batchIdx] - 1;
}
// set up outputs
decoder_batch::Output outputs{};
auto constexpr tokenId = 1;
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
acceptedTokensPerStep, tokenId, endId, padId, computeLogProbs, manager);
if (maxBeamWidth > 1)
{
auto tgtCacheIndirection = std::shared_ptr(
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
manager.setZero(*tgtCacheIndirection);
outputs.cacheIndirection = tgtCacheIndirection;
}
auto sequenceLengths
= std::shared_ptr(manager.gpu(ITensor::makeShape({batchSize * maxBeamWidth}), TRTDataType<SizeType>::value));
manager.copy(tiledInputLengths.data(), *sequenceLengths);
outputs.sequenceLengths = sequenceLengths;
// set up inputs and outputs
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
samplingConfigs, generatedTokensPerSteps, manager);
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;
// set up decoder
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
auto const& nbSteps = decoder.getNbSteps();
EXPECT_EQ(nbSteps.size(), batchSize);
std::vector<SizeType> expectedSteps(batchSize, 0);
auto expectedLengths = tiledInputLengths;
auto const& finished = decoder.getFinished();
EXPECT_EQ(finished.size(), batchSize);
std::vector<bool> expectedFinished(batchSize, true);
auto constexpr tokenId = 1;
std::vector<decoder_batch::Input::TensorPtr> inputIds;
for (auto b = 0; b < batchSize; ++b)
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
auto shape = ITensor::makeShape({inputLengths[b]});
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
kernels::invokeFill(*input, tokenId, *streamPtr);
inputIds.emplace_back(input);
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
decoderRequest.computeCumLogProbs = computeLogProbs;
decoderRequest.computeLogProbs = computeLogProbs;
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
decoder.forward(outputs, inputs);
for (auto i = 0; i < inputIds.size(); ++i)
{
expectedSteps[i] = std::min(expectedSteps[i] + 1, maxNewTokens);
expectedFinished[i] = expectedSteps[i] == maxNewTokens;
}
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchIdx + 1, maxBeamWidth);
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
EXPECT_THAT(decoder.getNbSteps(), ::testing::ElementsAreArray(expectedSteps));
for (auto bi = 0; bi <= batchIdx; ++bi)
{
auto firstBeamIndex = tc::flat_index2(bi, 0, maxBeamWidth);
expectedFinished.at(bi)
= expectedLengths.at(firstBeamIndex) - tiledInputLengths.at(firstBeamIndex) == maxNewTokens;
}
EXPECT_THAT(decoder.getFinished(), ::testing::ElementsAreArray(expectedFinished));
}
while (!decoder.getFinished().back())
auto finishedVec = decoder.getFinished();
while (!std::any_of(finishedVec.begin(), finishedVec.end(), [](bool finish) { return finish; }))
{
decoder.forward(outputs, inputs);
}
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
finishedVec = decoder.getFinished();
verifyResults(manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, maxNewTokens,
tokenId, padId);
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
for (auto bi = 0; bi < batchSize; ++bi)
{
auto firstBeamIndex = tc::flat_index2(bi, 0, maxBeamWidth);
expectedFinished.at(bi)
= expectedLengths.at(firstBeamIndex) - tiledInputLengths.at(firstBeamIndex) == maxNewTokens;
}
EXPECT_THAT(finishedVec, ::testing::ElementsAreArray(expectedFinished));
}
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
}
void testDecoderDraft(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
SizeType maxBeamWidth, std::vector<SizeType> const& generatedTokensPerSteps,
std::vector<SizeType> const& acceptedTokensPerStep, SizeType maxGeneratedTokensPerStep)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
SizeType constexpr tensorParallelism{1};
SizeType constexpr pipelineParallelism{1};
SizeType constexpr localRank{0};
WorldConfig constexpr worldConfig{tensorParallelism, pipelineParallelism, localRank};
SizeType constexpr vocabSize{51200};
SizeType constexpr nbLayers{2};
SizeType constexpr nbHeads{16};
SizeType constexpr hiddenSize{1024};
GptModelConfig modelConfig{vocabSize, nbLayers, nbHeads, hiddenSize, dtype};
modelConfig.useGptAttentionPlugin(false);
auto streamPtr = std::make_shared<CudaStream>();
BufferManager manager(streamPtr);
TokenIdType constexpr endId{50257};
TokenIdType constexpr padId{50257};
auto const dataType = modelConfig.getDataType();
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
SizeType constexpr maxInputLength{8};
SizeType const maxNewTokens{4};
auto const maxSeqLength = maxInputLength + maxNewTokens;
std::vector<SizeType> inputLengths(batchSize);
std::iota(inputLengths.begin(), inputLengths.end(), 4);
std::vector<SizeType> tiledInputLengths;
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
{
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
{
tiledInputLengths.push_back(inputLengths.at(batchIdx));
}
}
std::vector<SizeType> advancedTokensPerStep{generatedTokensPerSteps};
std::for_each(advancedTokensPerStep.begin(), advancedTokensPerStep.end(), [](auto& x) { x -= 1; });
auto constexpr tokenId = 1;
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
acceptedTokensPerStep, tokenId, endId, padId, false, manager);
// set up inputs and outputs
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
samplingConfigs, generatedTokensPerSteps, manager);
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
auto const maxKvCacheLength = maxSeqLength;
// set up decoder
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
{
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
}
cudaDeviceSynchronize();
auto expectedLengths = tiledInputLengths;
auto generatedLengths = tiledInputLengths;
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
auto const& finished = decoder.getFinished();
EXPECT_EQ(finished.size(), batchSize);
EXPECT_THAT(finished, ::testing::Each(false));
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
// run decoder for 1 step
decoder.forward(outputs, inputs);
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
// WAR: we don't write endId back into outputIds when tokens are rejected,
// so we adjust the lengths for verifyResults here
advanceSequenceLengths(generatedLengths, advancedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
EXPECT_THAT(decoder.getFinished(), ::testing::Each(false));
verifyResults(manager, decoder, samplingConfigs, inputLengths, generatedLengths, batchSize, maxBeamWidth,
maxSeqLength, tokenId, padId);
}
} // namespace
@ -347,7 +503,26 @@ struct BeamConfig
std::vector<SizeType> beamWidths;
};
class ParamTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
using ParamType = std::tuple<nvinfer1::DataType, BeamConfig, bool>;
std::string generateTestName(const testing::TestParamInfo<ParamType>& info)
{
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
BeamConfig const beamConfig = std::get<1>(info.param);
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
for (auto const beamWidth : beamConfig.beamWidths)
{
name.append("Bw" + std::to_string(beamWidth));
}
bool const computeLogProbs{std::get<2>(info.param)};
if (computeLogProbs)
{
name.append("LogProbs");
}
return name;
}
class ParamTest : public ::testing::TestWithParam<ParamType>
{
};
@ -365,29 +540,14 @@ TEST_P(ParamTest, Test)
testDecoder(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
}
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamTest,
INSTANTIATE_TEST_SUITE_P(GptDecoderBwTest, ParamTest,
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
testing::Values(false, true)),
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
{
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
BeamConfig const beamConfig = std::get<1>(info.param);
bool const computeLogProbs = std::get<2>(info.param);
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
for (auto const beamWidth : beamConfig.beamWidths)
{
name.append("Bw" + std::to_string(beamWidth));
}
if (computeLogProbs)
{
name.append("LogProbs");
}
return name;
});
generateTestName);
class ParamWavefrontTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
class ParamWavefrontTest : public ::testing::TestWithParam<ParamType>
{
};
@ -405,24 +565,70 @@ TEST_P(ParamWavefrontTest, Test)
testDecoderWavefront(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
}
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamWavefrontTest,
INSTANTIATE_TEST_SUITE_P(GptDecoderBwTest, ParamWavefrontTest,
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
testing::Values(false, true)),
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
generateTestName);
struct DraftConfig
{
SizeType maxGeneratedTokensPerStep;
std::vector<SizeType> generatedTokensPerSteps;
std::vector<SizeType> acceptedTokensPerStep;
};
using DraftTestParamType = std::tuple<nvinfer1::DataType, BeamConfig, DraftConfig>;
class ParamDraftTest : public ::testing::TestWithParam<DraftTestParamType>
{
};
TEST_P(ParamDraftTest, Test)
{
nvinfer1::DataType const dtype{std::get<0>(GetParam())};
BeamConfig const beamConfig{std::get<1>(GetParam())};
DraftConfig const draftConfig{std::get<2>(GetParam())};
ASSERT_EQ(beamConfig.beamWidths.size(), draftConfig.acceptedTokensPerStep.size());
ASSERT_EQ(beamConfig.beamWidths.size(), draftConfig.generatedTokensPerSteps.size());
std::vector<SamplingConfig> samplingConfigs;
for (auto const beamWidth : beamConfig.beamWidths)
{
samplingConfigs.emplace_back(beamWidth);
}
testDecoderDraft(dtype, samplingConfigs, beamConfig.maxBeamWidth, draftConfig.generatedTokensPerSteps,
draftConfig.acceptedTokensPerStep, draftConfig.maxGeneratedTokensPerStep);
}
INSTANTIATE_TEST_SUITE_P(GptDecoderTest, ParamDraftTest,
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{4, {1, 1, 1}}),
testing::Values( //
DraftConfig{2, {1, 1, 1}, {0, 0, 0}}, DraftConfig{2, {2, 2, 2}, {1, 1, 1}},
DraftConfig{4, {1, 2, 3}, {0, 0, 1}}
)),
[](const testing::TestParamInfo<DraftTestParamType>& info)
{
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
BeamConfig const beamConfig = std::get<1>(info.param);
bool const computeLogProbs = std::get<2>(info.param);
DraftConfig const draftConfig = std::get<2>(info.param);
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
auto const batchSize = beamConfig.beamWidths.size();
for (auto const beamWidth : beamConfig.beamWidths)
{
name.append("Bw" + std::to_string(beamWidth));
}
if (computeLogProbs)
name.append("PerStep" + std::to_string(draftConfig.maxGeneratedTokensPerStep));
for (std::size_t i = 0; i < batchSize; ++i)
{
name.append("LogProbs");
auto const acc = draftConfig.acceptedTokensPerStep.at(i);
auto const gen = draftConfig.generatedTokensPerSteps.at(i);
name.append("Acc" + std::to_string(acc) + "of" + std::to_string(gen));
}
return name;
});

View File

@ -97,6 +97,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC
outputs.lengths
= manager.copyFrom(sequenceLengthsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU);
outputs.finished = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
inputs.finished = ITensor::view(outputs.finished);
manager.setZero(*outputs.finished);
outputs.finishedSum = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
auto* finishedSumHost = bufferCast<std::int32_t>(*outputs.finishedSum);

View File

@ -692,7 +692,7 @@ namespace
// TODO: consolidate this function with testGptSession
// Notice: all ChatGLM / GLM models use this
// function The differences are GptModelConfig::ModelVariant
// The differences are GptModelConfig::ModelVariant
void testChatGlmSession(fs::path const& modelPath, std::string const& modelName, ModelSpec const& modelSpec,
ModelIds const modelIds, SizeType beamWidth, std::initializer_list<int> const& batchSizes,
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, MicroBatchSizes microBatchSizes)
@ -718,7 +718,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
auto const expectedOutputData = bufferCast<TokenIdType const>(*expectedOutput);
ASSERT_TRUE(fs::exists(modelPath));
auto const json = GptJsonConfig::parse(modelPath / (modelName + "-config.json"));
auto const json = GptJsonConfig::parse(modelPath / "config.json");
auto const modelConfig = json.getModelConfig();
verifyModelConfig(modelConfig, modelSpec);
auto const decoderPerRequest = modelSpec.mDecoderPerRequest;
@ -867,7 +867,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm_6b"};
auto const modelPath{ENGINE_PATH / "chatglm"};
auto const modelPath{ENGINE_PATH / modelName};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
@ -879,7 +879,7 @@ TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm2_6b"};
auto const modelPath{ENGINE_PATH / "chatglm"};
auto const modelPath{ENGINE_PATH / modelName};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
@ -891,7 +891,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
{
auto const modelName{"chatglm2_6b"};
auto const modelPath{ENGINE_PATH / "chatglm"};
auto const modelPath{ENGINE_PATH / modelName};
auto const batchSizes = {2};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
@ -903,7 +903,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
{
auto const modelName{"chatglm2_6b"};
auto const modelPath{ENGINE_PATH / "chatglm"};
auto const modelPath{ENGINE_PATH / modelName};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
@ -915,7 +915,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
TEST_F(ChatGlm3SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
{
auto const modelName{"chatglm3_6b"};
auto const modelPath{ENGINE_PATH / "chatglm"};
auto const modelPath{ENGINE_PATH / modelName};
auto const batchSizes = {1};
auto constexpr dtype = nvinfer1::DataType::kHALF;
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();

View File

@ -19,23 +19,24 @@ COPY docker/common/install_cmake.sh install_cmake.sh
RUN bash ./install_cmake.sh && rm install_cmake.sh
# Download & install internal TRT release
ARG TRT_VER="9.1.0.4"
ENV TRT_VER=$TRT_VER
ARG CUDA_VER="12.2"
ENV CUDA_VER=$CUDA_VER
ARG CUDNN_VER="8.9.4.25-1+cuda12.2"
ENV CUDNN_VER=$CUDNN_VER
ARG NCCL_VER="2.18.3-1+cuda12.2"
ENV NCCL_VER=$NCCL_VER
ARG CUBLAS_VER="12.2.5.6-1"
ENV CUBLAS_VER=$CUBLAS_VER
ARG TRT_VER CUDA_VER CUDNN_VER NCCL_VER CUBLAS_VER
COPY docker/common/install_tensorrt.sh install_tensorrt.sh
RUN bash ./install_tensorrt.sh && rm install_tensorrt.sh
RUN bash ./install_tensorrt.sh \
--TRT_VER=${TRT_VER} \
--CUDA_VER=${CUDA_VER} \
--CUDNN_VER=${CUDNN_VER} \
--NCCL_VER=${NCCL_VER} \
--CUBLAS_VER=${CUBLAS_VER} && \
rm install_tensorrt.sh
# Install latest Polygraphy
COPY docker/common/install_polygraphy.sh install_polygraphy.sh
RUN bash ./install_polygraphy.sh && rm install_polygraphy.sh
# Install mpi4py
COPY docker/common/install_mpi4py.sh install_mpi4py.sh
RUN bash ./install_mpi4py.sh && rm install_mpi4py.sh
# Install PyTorch
ARG TORCH_INSTALL_TYPE="skip"
COPY docker/common/install_pytorch.sh install_pytorch.sh
@ -49,7 +50,7 @@ COPY benchmarks benchmarks
COPY scripts scripts
COPY tensorrt_llm tensorrt_llm
COPY 3rdparty 3rdparty
COPY setup.py requirements.txt ./
COPY setup.py requirements.txt requirements-dev.txt ./
ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt"
RUN python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS}

View File

@ -102,10 +102,14 @@ wheel_%: STAGE = wheel
release_%: STAGE = release
# For x86_64 and aarch64
# For x86_64
jenkins_%: IMAGE_WITH_TAG = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
jenkins_%: STAGE = devel
# For aarch64
jenkins-aarch64_%: IMAGE_WITH_TAG = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L1_Nightly_GH200RemoteJob.groovy | grep -o '".*"' | tr -d '"')
jenkins-aarch64_%: STAGE = devel
# For x86_64
centos7_%: IMAGE_WITH_TAG = $(shell grep 'LLM_CENTOS7_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
centos7_%: STAGE = devel
@ -117,7 +121,7 @@ ubuntu22_%: STAGE = devel
ubuntu22_%: BASE_IMAGE = nvidia/cuda
ubuntu22_%: BASE_TAG = 12.2.2-devel-ubuntu22.04
# For x86_64 and aarch64
# For x86_64
old-cuda_%: IMAGE_WITH_TAG = $(shell grep 'LLM_OLD_CUDA_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
old-cuda_%: BASE_TAG = 23.07-py3
old-cuda_%: STAGE = devel

View File

@ -30,7 +30,7 @@ init_ubuntu() {
fi
apt-get clean
rm -rf /var/lib/apt/lists/*
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64' >> "${ENV}"
# Remove previous TRT installation
if [[ $(apt list --installed | grep libnvinfer) ]]; then
apt-get remove --purge -y libnvinfer*
@ -39,7 +39,6 @@ init_ubuntu() {
apt-get remove --purge -y tensorrt*
fi
pip uninstall -y tensorrt
pip install mpi4py
}
install_gcc_centos() {
@ -53,7 +52,7 @@ install_gcc_centos() {
./contrib/download_prerequisites
./configure --disable-multilib --enable-languages=c,c++ --with-pi
make -j$(nproc) && make install
echo "export LD_LIBRARY_PATH=/usr/local/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64' >> "${ENV}"
cd .. && rm -rf /tmp/gcc-*
yum clean all
}
@ -64,7 +63,7 @@ init_centos() {
yum -y update
yum -y install centos-release-scl-rh epel-release
# https://gitlab.com/nvidia/container-images/cuda
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64' >> "${ENV}"
CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
YUM_CUDA=${CUDA_VERSION/./-}
# Consistent with manylinux2014 centos-7 based version
@ -73,7 +72,7 @@ init_centos() {
echo "source scl_source enable rh-git227 rh-python38" >> "${ENV}"
echo "source scl_source enable devtoolset-10" >> "${DEVTOOLSET_ENV_FILE}"
echo "source ${DEVTOOLSET_ENV_FILE}" >> "${ENV}"
echo 'export PATH=/usr/lib64/openmpi3/bin:$PATH' >> "${ENV}"
echo 'export PATH=$PATH:/usr/lib64/openmpi3/bin' >> "${ENV}"
bash -c "pip install 'urllib3<2.0'"
yum clean all
}

View File

@ -12,4 +12,4 @@ wget --no-verbose ${RELEASE_URL_CMAKE} -P /tmp
tar -xf /tmp/${CMAKE_FILE_NAME}.tar.gz -C /usr/local/
ln -s /usr/local/${CMAKE_FILE_NAME} /usr/local/cmake
echo 'export PATH=/usr/local/cmake/bin:$PATH' >> "${ENV}"
echo 'export PATH=$PATH:/usr/local/cmake/bin' >> "${ENV}"

View File

@ -0,0 +1,11 @@
#!/bin/bash
set -ex
MPI4PY_VERSION="3.1.5"
RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz"
curl -L ${RELEASE_URL} | tar -zx -C /tmp
# Bypassing compatibility issues with higher versions (>= 69) of setuptools.
sed -i 's/>= 40\.9\.0/>= 40.9.0, < 69/g' /tmp/mpi4py-${MPI4PY_VERSION}/pyproject.toml
pip install /tmp/mpi4py-${MPI4PY_VERSION}
rm -rf /tmp/mpi4py*

View File

@ -2,6 +2,24 @@
set -ex
TRT_VER="9.1.0.4"
CUDA_VER="12.2"
CUDNN_VER="8.9.4.25-1+cuda12.2"
NCCL_VER="2.18.3-1+cuda12.2"
CUBLAS_VER="12.2.5.6-1"
for i in "$@"; do
case $i in
--TRT_VER=?*) TRT_VER="${i#*=}";;
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
*) ;;
esac
shift
done
NVCC_VERSION_OUTPUT=$(nvcc --version)
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
@ -64,7 +82,7 @@ install_tensorrt() {
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
pip install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
rm -rf /tmp/TensorRT.tar
echo 'export LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH' >> "${ENV}"
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/tensorrt/lib' >> "${ENV}"
}
# Install base packages depending on the base OS

View File

@ -257,7 +257,7 @@ as:
# Collectives.
def allreduce(tensor: Tensor, group: List[int]) -> Tensor
def allgather(tensor: Tensor, group: List[int]) -> Tensor
def allgather(tensor: Tensor, group: List[int], gather_dim: int = 0) -> Tensor
# Point-to-point communication primitives.
def send(tensor: Tensor, tgt: int) -> Tensor
@ -267,7 +267,7 @@ def recv(tensor: Tensor, src: int) -> Tensor
The multi-GPU support can be enabled through two different modes of model
parallelism: Tensor Parallelism and Pipeline Parallelism. The former mode
splits the different layers of a model across the GPUs. Each GPU runs the
entire network and synchronizes with its sibblings when needed. The Pipeline
entire network and synchronizes with its siblings when needed. The Pipeline
Parallelism distributes the different layers to the GPUs. Each GPU runs a
subset of the entire model and communications happen at the boundary of those
subsets of layers. Tensor Parallelism usually leads to more balanced executions

View File

@ -4,11 +4,10 @@ This document shows how to build and run a Baichuan models (including `v1_7b`/`v
## Overview
The TensorRT-LLM Baichuan implementation can be found in [tensorrt_llm/models/baichuan/model.py](../../tensorrt_llm/models/baichuan/model.py). The TensorRT-LLM Baichuan example code is located in [`examples/baichuan`](./). There are three main files:
The TensorRT-LLM Baichuan implementation can be found in [tensorrt_llm/models/baichuan/model.py](../../tensorrt_llm/models/baichuan/model.py). The TensorRT-LLM Baichuan example code is located in [`examples/baichuan`](./). There are two main files:
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Baichuan model,
* [`run.py`](./run.py) to run the inference on an input text,
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
* [`run.py`](./run.py) to run the inference on an input text.
These scripts accept an argument named `model_version`, whose value can be `v1_7b`, `v1_13b`, `v2_7b`, or `v2_13b`; the default is `v1_13b`.

View File

@ -18,9 +18,12 @@ import time
from pathlib import Path
import onnx
import tensorrt as trt
# isort: off
import torch
import torch.multiprocessing as mp
import tensorrt as trt
# isort: on
from onnx import TensorProto, helper
from transformers import AutoConfig, AutoModelForCausalLM
@ -499,8 +502,6 @@ def build_rank_engine(builder: Builder,
config_path = os.path.join(args.output_dir, 'config.json')
builder.save_config(builder_config, config_path)
tensorrt_llm.tools.cleanup(network, tensorrt_llm_baichuan)
return engine

View File

@ -132,6 +132,7 @@ def hf_baichuan_converter(args):
saved_dir.mkdir(parents=True, exist_ok=True)
model = AutoModelForCausalLM.from_pretrained(args.in_file,
torch_dtype="auto",
device_map="auto",
trust_remote_code=True)

View File

@ -16,8 +16,10 @@ import argparse
import os
from collections import OrderedDict
import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
# isort: on
from transformers import BertConfig, BertForQuestionAnswering, BertModel
import tensorrt_llm

View File

@ -16,8 +16,10 @@ import argparse
import json
import os
import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
# isort: on
import tensorrt_llm
from tensorrt_llm import logger

View File

@ -3,8 +3,10 @@ import json
import os
from pathlib import Path
import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
# isort: on
from transformers import AutoTokenizer
import tensorrt_llm

View File

@ -4,11 +4,10 @@ This document shows how to build and run a BLOOM model in TensorRT-LLM on both s
## Overview
The TensorRT-LLM BLOOM implementation can be found in [tensorrt_llm/models/bloom/model.py](../../tensorrt_llm/models/bloom/model.py). The TensorRT-LLM BLOOM example code is located in [`examples/bloom`](./). There are three main files:
The TensorRT-LLM BLOOM implementation can be found in [tensorrt_llm/models/bloom/model.py](../../tensorrt_llm/models/bloom/model.py). The TensorRT-LLM BLOOM example code is located in [`examples/bloom`](./). There are two main files:
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the BLOOM model,
* [`run.py`](./run.py) to run the inference on an input text,
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
* [`run.py`](./run.py) to run the inference on an input text.
## Support Matrix
* FP16

View File

@ -18,9 +18,12 @@ import time
from pathlib import Path
import onnx
import tensorrt as trt
# isort: off
import torch
import torch.multiprocessing as mp
import tensorrt as trt
# isort: on
from onnx import TensorProto, helper
from transformers import BloomConfig, BloomForCausalLM
@ -477,8 +480,6 @@ def build_rank_engine(builder: Builder,
config_path = os.path.join(args.output_dir, 'config.json')
builder.save_config(builder_config, config_path)
tensorrt_llm.tools.cleanup(network, tensorrt_llm_bloom)
return engine

View File

@ -4,4 +4,4 @@ awq/
chatglm*_6b*/
dataset/
glm_10b/
trtModel/
output_*/

View File

@ -5,11 +5,10 @@ This document explains how to build the [ChatGLM-6B](https://huggingface.co/THUD
## Overview
The TensorRT-LLM ChatGLM implementation can be found in [`tensorrt_llm/models/chatglm/model.py`](../../tensorrt_llm/models/chatglm/model.py).
The TensorRT-LLM ChatGLM example code is located in [`examples/chatglm`](./). There are three main files:
The TensorRT-LLM ChatGLM example code is located in [`examples/chatglm`](./). There are two main files:
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the ChatGLM model.
* [`run.py`](./run.py) to run the inference on an input text.
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
## Support Matrix

View File

@ -18,12 +18,13 @@ import time
from pathlib import Path
from typing import List
import onnx
import tensorrt as trt
# isort: off
import torch
import torch.multiprocessing as mp
from onnx import TensorProto, helper
from weight import load_from_hf
import tensorrt as trt
# isort: on
from visualize import to_onnx
from weight import get_scaling_factors, load_from_hf
import tensorrt_llm
from tensorrt_llm._utils import str_dtype_to_trt
@ -36,12 +37,12 @@ from tensorrt_llm.plugin.plugin import ContextFMHAType
from tensorrt_llm.profiler import check_gpt_mem_usage
from tensorrt_llm.quantization import QuantMode
from weight import get_scaling_factors # isort:skip
from weight import load_from_hf_checkpoint # isort:skip
def get_engine_name(model, dtype, tp_size, rank):
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
def get_engine_name(model, dtype, tp_size, pp_size, rank):
if pp_size == 1:
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
pp_size, rank)
def find_engines(dir: Path,
@ -53,61 +54,6 @@ def find_engines(dir: Path,
return list(dir.glob(template))
def trt_dtype_to_onnx(dtype):
if dtype == trt.float16:
return TensorProto.DataType.FLOAT16
elif dtype == trt.float32:
return TensorProto.DataType.FLOAT
elif dtype == trt.int32:
return TensorProto.DataType.INT32
else:
raise TypeError("%s is not supported" % dtype)
def to_onnx(network, path):
inputs = []
for i in range(network.num_inputs):
network_input = network.get_input(i)
inputs.append(
helper.make_tensor_value_info(
network_input.name, trt_dtype_to_onnx(network_input.dtype),
list(network_input.shape)))
outputs = []
for i in range(network.num_outputs):
network_output = network.get_output(i)
outputs.append(
helper.make_tensor_value_info(
network_output.name, trt_dtype_to_onnx(network_output.dtype),
list(network_output.shape)))
nodes = []
for i in range(network.num_layers):
layer = network.get_layer(i)
layer_inputs = []
for j in range(layer.num_inputs):
ipt = layer.get_input(j)
if ipt is not None:
layer_inputs.append(layer.get_input(j).name)
layer_outputs = [
layer.get_output(j).name for j in range(layer.num_outputs)
]
nodes.append(
helper.make_node(str(layer.type),
name=layer.name,
inputs=layer_inputs,
outputs=layer_outputs,
domain="com.nvidia"))
onnx_model = helper.make_model(helper.make_graph(nodes,
'attention',
inputs,
outputs,
initializer=None),
producer_name='NVIDIA')
onnx.save(onnx_model, path)
def serialize_engine(engine, path):
logger.info(f'Serializing engine to {path}...')
tik = time.time()
@ -118,7 +64,7 @@ def serialize_engine(engine, path):
logger.info(f'Engine serialized. Total time: {t}')
def truncate_input_output(
def truncate_input_output_len(
max_input_len,
max_output_len,
max_seq_length_from_config,
@ -152,21 +98,29 @@ def parse_arguments(args):
help=
'the name of the model, use "_" rather than "-" to connect the name parts'
)
parser.add_argument('--world_size',
type=int,
default=1,
help='world size, only support tensor parallelism now')
parser.add_argument('--tp_size', type=int, default=1)
parser.add_argument('--pp_size', type=int, default=1)
parser.add_argument(
'--world_size',
'-ws',
type=int,
default=1,
help='world size, only support tensor parallelism now',
)
parser.add_argument('--tp_size', '-tp', type=int, default=1)
parser.add_argument('--pp_size', '-pp', type=int, default=1)
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument('--dtype',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16'])
parser.add_argument('--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'])
parser.add_argument('--quant_ckpt_path', type=str, default="awq/")
parser.add_argument(
'--dtype',
type=str,
default='float16',
choices=['float32', 'float16', 'bfloat16'],
)
parser.add_argument(
'--logits_dtype',
type=str,
default='float32',
choices=['float16', 'float32'],
)
parser.add_argument(
'--timing_cache',
type=str,
@ -177,8 +131,9 @@ def parse_arguments(args):
parser.add_argument(
'--log_level',
type=str,
default='verbose',
choices=['verbose', 'info', 'warning', 'error', 'internal_error'])
default='info',
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
)
parser.add_argument('--max_batch_size', type=int, default=8)
parser.add_argument('--max_input_len', type=int, default=1024)
parser.add_argument('--max_output_len', type=int, default=1024)
@ -226,12 +181,16 @@ def parse_arguments(args):
action='store_true',
default=False)
parser.add_argument('--parallel_build', default=False, action='store_true')
parser.add_argument('--enable_context_fmha',
default=False,
action='store_true')
parser.add_argument('--enable_context_fmha_fp32_acc',
default=False,
action='store_true')
parser.add_argument(
'--enable_context_fmha',
default=False,
action='store_true',
)
parser.add_argument(
'--enable_context_fmha_fp32_acc',
default=False,
action='store_true',
)
parser.add_argument(
'--multi_block_mode',
default=False,
@ -240,19 +199,18 @@ def parse_arguments(args):
'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
It is beneficial when batch x num_heads cannot fully utilize the GPU.'
)
parser.add_argument('--load_by_shard',
action='store_true',
help='Load a pretrained model shard-by-shard.')
parser.add_argument('--visualize', default=False, action='store_true')
parser.add_argument('--enable_debug_output',
default=False,
action='store_true')
parser.add_argument(
'--enable_debug_output',
default=False,
action='store_true',
)
parser.add_argument('--gpus_per_node', type=int, default=8)
parser.add_argument('--builder_opt', type=int, default=None)
parser.add_argument(
'--output_dir',
type=Path,
default='trtModel',
default=None,
help=
'The path to save the serialized engine files, timing cache file and model configs'
)
@ -263,9 +221,11 @@ def parse_arguments(args):
help=
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
)
parser.add_argument('--remove_input_padding',
default=False,
action='store_true')
parser.add_argument(
'--remove_input_padding',
default=False,
action='store_true',
)
parser.add_argument(
'--paged_kv_cache',
action="store_true",
@ -277,7 +237,8 @@ def parse_arguments(args):
'--use_inflight_batching',
action="store_true",
default=False,
help="Activates inflight batching mode of gptAttentionPlugin.")
help="Activates inflight batching mode of gptAttentionPlugin.",
)
# Arguments related to the quantization of the model.
parser.add_argument(
@ -293,17 +254,18 @@ def parse_arguments(args):
default=False,
action="store_true",
help='Quantize weights for the various GEMMs to INT4/INT8.'
'See --weight_only_precision to set the precision')
'See --weight_only_precision to set the precision',
)
parser.add_argument(
'--weight_only_precision',
const='int8',
type=str,
nargs='?',
default='int8',
choices=['int8', 'int4'],
choices=['int8', 'int4', 'int4_awq'],
help=
'Define the precision for the weights when using weight-only quantization.'
'You must also use --use_weight_only for that argument to have an impact.'
'You must also use --use_weight_only for that argument to have an impact.',
)
parser.add_argument(
'--per_channel',
@ -312,7 +274,8 @@ def parse_arguments(args):
help=
'By default, we use a single static scaling factor for the GEMM\'s result. '
'per_channel instead uses a different static scaling factor for each channel. '
'The latter is usually more accurate, but a little slower.')
'The latter is usually more accurate, but a little slower.',
)
parser.add_argument(
'--per_token',
default=False,
@ -320,7 +283,23 @@ def parse_arguments(args):
help=
'By default, we use a single static scaling factor to scale activations in the int8 range. '
'per_token chooses at run time, and for each token, a custom scaling factor. '
'The latter is usually more accurate, but a little slower.')
'The latter is usually more accurate, but a little slower.',
)
parser.add_argument(
'--per_group',
default=False,
action="store_true",
help=
'By default, we use a single static scaling factor to scale weights in the int4 range. '
'per_group chooses at run time, and for each group, a custom scaling factor. '
'The flag is built for GPTQ/AWQ quantization.',
)
parser.add_argument(
'--group_size',
type=int,
default=128,
help='Group size used in GPTQ/AWQ quantization.',
)
parser.add_argument(
'--int8_kv_cache',
default=False,
@ -333,16 +312,20 @@ def parse_arguments(args):
type=int,
default=None,
help=
'Seed to use when initializing the random number generator for torch.')
parser.add_argument('--tokens_per_block',
type=int,
default=64,
help='Number of tokens per block in paged KV cache')
'Seed to use when initializing the random number generator for torch.',
)
parser.add_argument(
'--tokens_per_block',
type=int,
default=64,
help='Number of tokens per block in paged KV cache',
)
parser.add_argument(
'--enable_fp8',
default=False,
action='store_true',
help='Use FP8 Linear layer for Attention QKV/Dense and MLP.',
)
parser.add_argument(
'--fp8_kv_cache',
@ -355,13 +338,16 @@ def parse_arguments(args):
'--max_num_tokens',
type=int,
default=None,
help='Define the max number of tokens supported by the engine')
help='Define the max number of tokens supported by the engine',
)
parser.add_argument(
'--use_custom_all_reduce',
action='store_true',
help=
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
'Activates latency-optimized algorithm for all-reduce instead of NCCL.',
)
args = parser.parse_args(args)
logger.set_level(args.log_level)
plugins_args = [
@ -377,6 +363,8 @@ def parse_arguments(args):
)
setattr(args, plugin_arg, args.dtype)
assert args.world_size == args.tp_size * args.pp_size # only TP is supported now
if args.model_dir is None:
args.model_dir = args.model_name
with open(Path(args.model_dir) / "config.json", "r") as f:
@ -385,75 +373,96 @@ def parse_arguments(args):
if args.model_name in ["chatglm_6b", "glm_10b"]:
assert args.max_input_len < js["max_sequence_length"]
if args.output_dir is None:
args.output_dir = Path("output_" + args.model_name)
if args.model_name in ["chatglm_6b"]:
args.ffn_hidden_size = js["inner_hidden_size"]
args.hidden_size = js["hidden_size"]
args.norm_epsilon = js["layernorm_epsilon"]
args.num_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.vocab_size = js["vocab_size"]
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output(
args.max_input_len, args.max_output_len, js["max_sequence_length"])
args.apply_query_key_layer_scaling = False
args.hidden_act = 'gelu'
args.linear_bias = True
args.multi_block_mode = False
args.multi_query_mode = False
args.num_kv_heads = js["num_attention_heads"]
args.qkv_bias = True
args.use_cache = js["use_cache"]
elif args.model_name in ["glm_10b"]:
args.hidden_size = js["hidden_size"]
args.num_attention_heads = js["num_attention_heads"]
args.num_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.vocab_size = js["vocab_size"]
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output(
args.max_input_len, args.max_output_len, js["max_sequence_length"],
True)
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = False
args.ffn_hidden_size = 4 * args.hidden_size
args.ffn_hidden_size = js["inner_hidden_size"]
args.hidden_act = 'gelu'
args.hidden_size = js["hidden_size"]
args.linear_bias = True
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["max_sequence_length"],
)
args.multi_block_mode = False
args.multi_query_mode = False
args.norm_epsilon = 1.0e-5
args.norm_epsilon = js["layernorm_epsilon"]
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.qkv_bias = True
args.use_cache = True
args.rmsnorm = False
args.rotary_embedding_scaling = 1.0
args.use_cache = js["use_cache"]
args.vocab_size = js["vocab_size"]
elif args.model_name in [
"chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b", "chatglm3_6b_base",
"chatglm3_6b_32k"
"chatglm2_6b",
"chatglm2_6b_32k",
"chatglm3_6b",
"chatglm3_6b_base",
"chatglm3_6b_32k",
]:
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = js[
"apply_residual_connection_post_layernorm"]
args.ffn_hidden_size = js["ffn_hidden_size"]
args.hidden_act = 'swiglu'
args.hidden_size = js["hidden_size"]
args.linear_bias = js["add_bias_linear"]
args.multi_query_mode = js["multi_query_attention"]
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["seq_length"],
)
args.multi_block_mode = False
args.multi_query_mode = False # regardless of config.json
args.norm_epsilon = js["layernorm_epsilon"]
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["multi_query_group_num"]
args.num_layers = js["num_layers"]
args.qkv_bias = js["add_qkv_bias"]
args.rmsnorm = js["rmsnorm"]
args.use_cache = js["use_cache"]
args.vocab_size = js["padded_vocab_size"]
args.max_seq_length = min(args.max_input_len + args.max_output_len,
js["seq_length"])
if args.model_name in ["chatglm2_6b_32k", "chatglm3_6b_32k"]:
args.rotary_embedding_scaling = js["rope_ratio"]
args.hidden_act = 'swiglu'
else:
args.rotary_embedding_scaling = 1.0
args.use_cache = js["use_cache"]
args.vocab_size = js["padded_vocab_size"]
elif args.model_name in ["glm_10b"]:
args.apply_query_key_layer_scaling = False
args.apply_residual_connection_post_layernorm = False
args.ffn_hidden_size = 4 * js["hidden_size"]
args.hidden_act = 'gelu'
args.hidden_size = js["hidden_size"]
args.linear_bias = True
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
args.max_input_len,
args.max_output_len,
js["max_sequence_length"],
True,
)
args.multi_block_mode = False
args.multi_query_mode = False
args.norm_epsilon = 1.0e-5
args.num_heads = js["num_attention_heads"]
args.num_kv_heads = js["num_attention_heads"]
args.num_layers = js["num_layers"]
args.qkv_bias = True
args.rmsnorm = False
args.rotary_embedding_scaling = 1.0
args.use_cache = True
args.vocab_size = js["vocab_size"]
if args.use_inflight_batching:
if not args.use_gpt_attention_plugin:
args.use_gpt_attention_plugin = 'float16'
logger.info(
f"Using GPT attention plugin for inflight batching mode. "
f"Setting to default '{args.use_gpt_attention_plugin}'")
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
)
if not args.remove_input_padding:
args.remove_input_padding = True
logger.info(
@ -478,18 +487,19 @@ def parse_arguments(args):
if args.int8_kv_cache:
args.quant_mode = args.quant_mode.set_int8_kv_cache()
if args.fp8_kv_cache:
assert (
args.use_gpt_attention_plugin or args.use_inflight_batching
), "You have to use GPT attention plugin when fp8 KV cache is set"
elif args.fp8_kv_cache:
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
if args.enable_fp8:
args.quant_mode = args.quant_mode.set_fp8_qdq()
if args.max_num_tokens is not None:
assert args.enable_context_fmha
logger.info(' Build Arguments '.center(100, '='))
for k, v in vars(args).items():
logger.info(f' - {k.ljust(30, ".")}: {v}')
logger.info('=' * 100)
return args
@ -510,40 +520,59 @@ def build_rank_engine(
args.mapping = Mapping(
world_size=args.world_size,
rank=rank,
tp_size=args.world_size,
tp_size=args.tp_size,
)
assert args.num_layers % args.pp_size == 0, \
f"num_layers {args.n_layer} must be a multiple of pipeline "\
f"parallelism size {args.pp_size}"
trtllm_model = ChatGLMHeadModel(args=args)
trtllm_model = ChatGLMHeadModel(
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
apply_residual_connection_post_layernorm=args.
apply_residual_connection_post_layernorm,
dtype=args.dtype,
enable_debug_output=args.enable_debug_output,
ffn_hidden_size=args.ffn_hidden_size,
hidden_act=args.hidden_act,
hidden_size=args.hidden_size,
linear_bias=args.linear_bias,
logits_dtype=args.logits_dtype,
mapping=args.mapping,
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
max_seq_length=args.max_seq_length,
model_name=args.model_name,
norm_epsilon=args.norm_epsilon,
num_heads=args.num_heads,
num_kv_heads=args.num_kv_heads,
num_layers=args.num_layers,
qkv_bias=args.qkv_bias,
quant_mode=args.quant_mode,
rmsnorm=args.rmsnorm,
rotary_embedding_scaling=args.rotary_embedding_scaling,
tokens_per_block=args.tokens_per_block,
use_cache=args.use_cache,
vocab_size=args.vocab_size,
)
if args.use_smooth_quant or args.use_weight_only:
trtllm_model = quantize_model(trtllm_model, args.quant_mode)
if args.enable_fp8 or args.fp8_kv_cache:
elif args.enable_fp8 or args.fp8_kv_cache:
logger.info(f'Loading scaling factors from '
f'{args.quantized_fp8_model_path}')
quant_scales = get_scaling_factors(args.quantized_fp8_model_path,
num_layers=args.n_layer,
quant_mode=args.quant_mode)
tensorrt_llm_falcon = quantize_model(tensorrt_llm_falcon,
quant_mode=args.quant_mode,
quant_scales=quant_scales)
if not args.load_by_shard:
trtllm_model = load_from_hf(
trtllm_model,
args.model_dir,
mapping=args.mapping,
dtype=args.dtype,
model_name=args.model_name,
)
else:
trtllm_model = load_from_hf_checkpoint(
trtllm_model,
args.model_dir,
mapping=args.mapping,
dtype=args.dtype,
model_name=args.model_name,
)
trtllm_model = quantize_model(trtllm_model,
quant_mode=args.quant_mode,
quant_scales=quant_scales)
trtllm_model = load_from_hf(
trtllm_model,
args.model_dir,
mapping=args.mapping,
dtype=args.dtype,
model_name=args.model_name,
)
# Module -> Network
network = builder.create_network()
@ -552,12 +581,20 @@ def build_rank_engine(
network.plugin_config.set_gpt_attention_plugin(
dtype=args.use_gpt_attention_plugin)
if args.use_gemm_plugin:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
if args.use_layernorm_plugin:
network.plugin_config.set_layernorm_plugin(
dtype=args.use_layernorm_plugin)
if not args.enable_fp8:
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
else:
logger.info(
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
if args.use_rmsnorm_plugin:
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
# Quantization plugins.
if args.use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
if args.enable_context_fmha:
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
@ -566,7 +603,13 @@ def build_rank_engine(
ContextFMHAType.enabled_with_fp32_acc)
if args.multi_block_mode:
network.plugin_config.enable_mmha_multi_block_mode()
if args.use_weight_only:
if args.per_group:
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
dtype='float16')
else:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype='float16')
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype,
args.use_custom_all_reduce)
@ -575,21 +618,6 @@ def build_rank_engine(
if args.paged_kv_cache:
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
# Quantization plugins.
if args.use_smooth_quant:
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
network.plugin_config.set_layernorm_quantization_plugin(
dtype=args.dtype)
network.plugin_config.set_quantize_tensor_plugin()
network.plugin_config.set_quantize_per_token_plugin()
elif args.use_weight_only:
network.plugin_config.set_weight_only_quant_matmul_plugin(
dtype=args.dtype)
if args.world_size > 1:
network.plugin_config.set_nccl_plugin(args.dtype,
args.use_custom_all_reduce)
with net_guard(network):
# Prepare
network.set_named_parameters(trtllm_model.named_parameters())
@ -605,7 +633,7 @@ def build_rank_engine(
trtllm_model(*inputs)
if args.enable_debug_output:
# mark intermediate nodes' outputs
for k, v in tensorrt_llm_falcon.named_network_outputs():
for k, v in trtllm_model.named_network_outputs():
v = v.trt_tensor
v.name = k
network.trt_network.mark_output(v)
@ -616,24 +644,21 @@ def build_rank_engine(
tensorrt_llm.graph_rewriting.optimize(network)
engine = None
# Network -> Engine
engine = None
engine = builder.build_engine(network, builder_config)
if rank == 0:
config_path = args.output_dir / (args.model_name + '-config.json')
config_path = args.output_dir / 'config.json'
builder.save_config(builder_config, config_path)
tensorrt_llm.tools.cleanup(network, trtllm_model)
return engine
def build(rank, args):
torch.cuda.set_device(rank % args.gpus_per_node)
tensorrt_llm.logger.set_level(args.log_level)
logger.set_level(args.log_level)
args.output_dir.mkdir(parents=True, exist_ok=True)
timing_cache_file = args.output_dir / "model.cache"
timing_cache_file = args.timing_cache
timing_cache = timing_cache_file
builder = Builder()
@ -642,12 +667,15 @@ def build(rank, args):
# skip other ranks if parallel_build is enabled
if args.parallel_build and cur_rank != rank:
continue
# NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
builder_config = builder.create_builder_config(
precision=args.dtype,
timing_cache=timing_cache,
tensor_parallel=args.world_size,
int8=(args.quant_mode.has_act_or_weight_quant()
or args.quant_mode.has_int8_kv_cache()),
tensor_parallel=args.tp_size,
pipeline_parallel=args.pp_size,
int8=int8_trt_flag,
fp8=args.enable_fp8,
strongly_typed=args.strongly_typed,
opt_level=args.builder_opt,
@ -678,6 +706,7 @@ def build(rank, args):
args.model_name,
args.dtype,
args.world_size,
args.pp_size,
cur_rank,
)
engine = build_rank_engine(
@ -706,7 +735,7 @@ def build(rank, args):
max_input_len=args.max_input_len,
max_output_len=args.max_output_len,
local_num_kv_heads=local_num_kv_heads,
head_size=args.hidden_size / args.num_heads,
head_size=args.hidden_size // args.num_heads,
num_layers=args.num_layers)
if cur_rank == 0:
@ -734,8 +763,8 @@ def run_build(args=None):
if args.parallel_build and args.world_size > 1 and \
torch.cuda.device_count() >= args.world_size:
logger.warning(
f'Parallelly build TensorRT engines. Please make sure that all '
f'of the {args.world_size} GPUs are totally free.')
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
)
mp.spawn(build, nprocs=args.world_size, args=(args, ))
else:
args.parallel_build = False

View File

@ -93,40 +93,50 @@ def get_model(ckpt_path, dtype="float16", cache_dir=None):
return model
def get_args():
def parse_arguments(args):
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--model_dir",
type=str,
required=True,
help="Directory of a HF model checkpoint")
parser.add_argument("--dtype", help="Model data type.", default="float16")
parser.add_argument(
"--qformat",
'--model_name',
'-m',
type=str,
choices=['fp8', 'int4_awq'],
default='fp8',
help='Quantization format. Currently only fp8 is supported. '
'For int8 smoothquant, use smoothquant.py instead. ')
required=True,
choices=[
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
],
help=
'The name of the model; use "_" rather than "-" to connect the name parts'
)
parser.add_argument("--dtype", help="Model data type.", default="float16")
parser.add_argument("--qformat",
type=str,
choices=['fp8', 'int4_awq'],
default='int4_awq',
help='Quantization format. '
'For int8 smoothquant, use smoothquant.py instead.')
parser.add_argument("--calib_size",
type=int,
default=512,
default=32,
help="Number of samples for calibration.")
parser.add_argument("--export_path", default="exported_model")
parser.add_argument('--model_dir', type=str, default=None)
parser.add_argument("--export_path", default="awq")
parser.add_argument("--cache_dir",
type=str,
default=None,
default="dataset/",
help="Directory of dataset cache.")
parser.add_argument('--seed', type=int, default=None, help='Random seed')
args = parser.parse_args()
return args
def main():
def main(args=None):
if not torch.cuda.is_available():
raise EnvironmentError("GPU is required for inference.")
args = get_args()
args = parse_arguments(args)
if args.model_dir is None:
args.model_dir = args.model_name
if args.seed is not None:
random.seed(args.seed)
np.random.seed(args.seed)

View File

@ -43,8 +43,9 @@ def parse_arguments(args=None):
)
parser.add_argument('--max_output_len', type=int, default=1024)
parser.add_argument('--log_level', type=str, default='error')
parser.add_argument('--engine_dir', type=str, default='trtModel')
parser.add_argument('--engine_dir', type=str, default=None)
parser.add_argument('--beam_width', type=int, default=1)
parser.add_argument('--streaming', default=False, action='store_true')
parser.add_argument(
'--input_text',
type=str,
@ -71,14 +72,20 @@ def parse_arguments(args=None):
parser.add_argument('--top_k', type=int, default=1)
parser.add_argument('--top_p', type=float, default=0.0)
parser.add_argument('--random_seed', type=int, default=1)
return parser.parse_args(args)
args = parser.parse_args(args)
if args.engine_dir is None:
args.engine_dir = Path("output_" + args.model_name)
return args
if __name__ == '__main__':
args = parse_arguments()
tensorrt_llm.logger.set_level(args.log_level)
config_path = Path(args.engine_dir) / (args.model_name + '-config.json')
config_path = Path(args.engine_dir) / 'config.json'
with open(config_path, 'r') as f:
config = json.load(f)
@ -89,9 +96,11 @@ if __name__ == '__main__':
max_beam_width = config['builder_config']['max_beam_width']
remove_input_padding = config['builder_config']['remove_input_padding']
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
world_size = config['builder_config']['tensor_parallel']
assert world_size == tensorrt_llm.mpi_world_size(
), f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
tp_size = config['builder_config']['tensor_parallel']
pp_size = config['builder_config']['pipeline_parallel']
world_size = tp_size * pp_size
assert world_size == tensorrt_llm.mpi_world_size(), \
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
if args.max_output_len > max_output_len:
print("Truncate max_output_len as %d" % max_output_len)
@ -199,9 +208,10 @@ if __name__ == '__main__':
model_config = ModelConfig(
vocab_size=config['builder_config']['vocab_size'],
num_layers=config['builder_config']['num_layers'],
num_heads=config['builder_config']['num_heads'] // world_size,
num_kv_heads=config['builder_config']['num_kv_heads'] // world_size,
hidden_size=config['builder_config']['hidden_size'] // world_size,
num_heads=config['builder_config']['num_heads'] // tp_size,
num_kv_heads=(config['builder_config']['num_kv_heads'] + tp_size - 1) //
tp_size,
hidden_size=config['builder_config']['hidden_size'] // tp_size,
gpt_attention_plugin=use_gpt_attention_plugin,
remove_input_padding=config['builder_config']['remove_input_padding'],
model_name=args.model_name,
@ -251,11 +261,8 @@ if __name__ == '__main__':
sampling_config,
output_sequence_lengths=True,
return_dict=True,
streaming=args.streaming,
)
torch.cuda.synchronize()
output_ids = output["output_ids"]
output_lengths = output["sequence_lengths"]
if runtime_rank == 0:
@ -271,18 +278,44 @@ if __name__ == '__main__':
]:
from process import process_response
for i in range(batch_size):
print("\nInput %2d ---> len=%d\n%s" %
(i, input_lengths[i], input_text[i]))
if args.streaming: # streaming output
print("#" * 80)
# only print the first batch and the first beam to show the effect,
# actually all beams of all batches are available
print("Input %2d ---> len=%d\n%s" %
(0, input_lengths[0], input_text[0]))
print("\nOutput %2d --->" % i)
output_ids_one_batch = output_ids[i, :, input_lengths[i]:]
output_lengths_one_batch = output_lengths[i]
output_token_list = tokenizer.batch_decode(output_ids_one_batch,
skip_special_tokens=True)
output_token_list = process_response(output_token_list)
for j, (length, simple_output) in enumerate(
zip(output_lengths_one_batch, output_token_list)):
print("\n Beam %2d ---> len=%d\n%s" %
(j, length, simple_output))
print("Finished!")
for output_item in output:
output_id = output_item["output_ids"]
output_sequence_lengths = output_item["sequence_lengths"]
output_id = output_id[0, 0, output_sequence_lengths[0, 0] - 1]
output_word = tokenizer.convert_ids_to_tokens(int(output_id))
output_word = output_word.replace("▁", " ")  # For English
output_word = tokenizer.convert_tokens_to_string(output_word)
print(output_word, end="", flush=True)
print("\n" + "#" * 80)
else: # regular output
torch.cuda.synchronize()
output_ids = output["output_ids"]
output_lengths = output["sequence_lengths"]
print("#" * 80)
for i in range(batch_size):
print("Input %2d ---> len=%d\n%s" %
(i, input_lengths[i], input_text[i]))
print("\nOutput %2d --->" % i)
output_ids_one_batch = output_ids[i, :, input_lengths[i]:]
output_lengths_one_batch = output_lengths[i] - input_lengths[
i] + 1
output_token_list = tokenizer.batch_decode(
output_ids_one_batch, skip_special_tokens=True)
output_token_list = process_response(output_token_list)
for j, (length, simple_output) in enumerate(
zip(output_lengths_one_batch, output_token_list)):
print(" Beam %2d ---> len=%d\n%s" %
(j, length, simple_output))
print("#" * 80)
del decoder
print(f"Finished from worker {runtime_rank}")

View File

@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import onnx
import tensorrt as trt
from onnx import TensorProto, helper
def trt_dtype_to_onnx(dtype):
if dtype == trt.float16:
return TensorProto.DataType.FLOAT16
elif dtype == trt.float32:
return TensorProto.DataType.FLOAT
elif dtype == trt.int32:
return TensorProto.DataType.INT32
else:
raise TypeError("%s is not supported" % dtype)
def to_onnx(network, path):
inputs = []
for i in range(network.num_inputs):
network_input = network.get_input(i)
inputs.append(
helper.make_tensor_value_info(
network_input.name, trt_dtype_to_onnx(network_input.dtype),
list(network_input.shape)))
outputs = []
for i in range(network.num_outputs):
network_output = network.get_output(i)
outputs.append(
helper.make_tensor_value_info(
network_output.name, trt_dtype_to_onnx(network_output.dtype),
list(network_output.shape)))
nodes = []
for i in range(network.num_layers):
layer = network.get_layer(i)
layer_inputs = []
for j in range(layer.num_inputs):
ipt = layer.get_input(j)
if ipt is not None:
layer_inputs.append(layer.get_input(j).name)
layer_outputs = [
layer.get_output(j).name for j in range(layer.num_outputs)
]
nodes.append(
helper.make_node(str(layer.type),
name=layer.name,
inputs=layer_inputs,
outputs=layer_outputs,
domain="com.nvidia"))
onnx_model = helper.make_model(helper.make_graph(nodes,
'attention',
inputs,
outputs,
initializer=None),
producer_name='NVIDIA')
onnx.save(onnx_model, path)
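# Usage sketch (illustrative only, not part of this change): build a toy TensorRT
# network and dump its layer graph to ONNX for inspection in a viewer such as
# Netron. "onnx_export" is a hypothetical module name for the file above; adjust
# the import to wherever it is saved.
import tensorrt as trt

from onnx_export import to_onnx  # hypothetical module name

trt_logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(trt_logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))

# A single-activation network is enough to exercise the exporter.
x = network.add_input("x", trt.float32, (1, 16))
relu = network.add_activation(x, trt.ActivationType.RELU)
network.mark_output(relu.get_output(0))

# The exported graph only mirrors layer names, types and I/O shapes; it is meant
# for visualization, not for running inference through ONNX Runtime.
to_onnx(network, "toy_network.onnx")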

View File

@ -29,6 +29,15 @@ from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode
def split(weight: np.ndarray, tp_size: int, rank: int = 0, dim: int = 0):
if tp_size == 1:
return weight
elif weight.ndim == 1:
return np.ascontiguousarray(np.split(weight, tp_size)[rank].copy())
return np.ascontiguousarray(
np.split(weight, tp_size, axis=dim)[rank].copy())
def split_matrix(weight: np.ndarray, tp_size: int, rank: int, dim: int):
return np.ascontiguousarray(split(weight, tp_size, rank, dim=dim))
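# Illustrative sketch (not part of this change) of how the helpers above split
# weights under 2-way tensor parallelism; the arrays are made up, and `split` /
# `split_matrix` are assumed to be in scope from the definitions above.
import numpy as np

bias = np.arange(8, dtype=np.float32)                   # 1-D weight, e.g. a bias
weight = np.arange(12, dtype=np.float32).reshape(4, 3)  # 2-D weight matrix

for rank in range(2):
    b = split(bias, tp_size=2, rank=rank)                   # shape (4,) per rank
    w = split_matrix(weight, tp_size=2, rank=rank, dim=0)   # shape (2, 3) per rank
    print(rank, b.shape, w.shape)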
@ -81,7 +90,7 @@ def load_quant_weight(src, value_dst, scale_dst, plugin_weight_only_quant_type):
def load_from_hf(
trt_model,
hf_model_dir,
mapping=None,
mapping=Mapping(),
dtype="float32",
model_name=None,
multi_query_mode=False,
@ -100,9 +109,9 @@ def load_from_hf(
hf_model = transformers.AutoModel.from_pretrained(hf_model_dir,
trust_remote_code=True)
num_layers = hf_model.config.num_layers
hidden_size = hf_model.config.hidden_size
num_heads = hf_model.config.num_attention_heads
num_layers = hf_model.config.num_layers
torch_type = str_dtype_to_torch(dtype)
quant_mode = getattr(trt_model, 'quant_mode', QuantMode(0))
@ -344,7 +353,8 @@ def load_from_hf(
scale_dst=dst.per_channel_scale,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
else:
dst.weight.value = torch_to_numpy(split_weight)
dst.weight.value = np.ascontiguousarray(
torch_to_numpy(split_weight))
feed_weight_count += 1
# Dense multiplication bias, only GLM-10B
@ -462,7 +472,8 @@ def load_from_hf(
scale_dst=dst.per_channel_scale,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
else:
dst.weight.value = torch_to_numpy(split_weight)
dst.weight.value = np.ascontiguousarray(
torch_to_numpy(split_weight))
feed_weight_count += 1
# Multilayer perceptron 4h -> h bias, only GLM-10B
@ -498,115 +509,6 @@ def load_from_hf(
return trt_model
def load_from_hf_checkpoint(
trtllm_falcon: tensorrt_llm.models.FalconForCausalLM,
model_dir: Union[str, Path],
mapping=Mapping(),
dtype: Union[str, torch.dtype] = torch.float32,
):
logger.info('Loading weights from HF Falcon...')
tik = time.time()
model_dir = Path(model_dir)
if isinstance(dtype, str):
dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)
def is_bias(_name):
return 'bias' in _name
layers_range = trtllm_falcon.get_transformer_layers(
trtllm_falcon.mapping, trtllm_falcon.num_layers)
for model_file in iterate_shard_files(model_dir, mapping.tp_rank):
logger.debug(f'Loading file {str(model_file)}...')
state_dict = load_state_dict(model_file, dtype)
for name, param in state_dict.items():
logger.debug(f'Converting weight {name}...')
i = retrieved_layer_index_from_name(name)
if i is None:
layer = None
else:
if i not in layers_range:
continue
layer = trtllm_falcon.layers[i - layers_range[0]]
if 'self_attention.query_key_value' in name:
if not is_bias(name):
layer.attention.qkv.weight.value = split_qkv_weight(
trtllm_falcon,
param,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=trtllm_falcon.num_kv_heads)
else:
layer.attention.qkv.bias.value = split_qkv_weight(
trtllm_falcon,
param,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=trtllm_falcon.num_kv_heads)
elif 'self_attention.dense' in name:
if not is_bias(name):
layer.attention.dense.weight.value = split_matrix(
param, mapping.tp_size, mapping.tp_rank, dim=1)
else:
layer.attention.dense.bias.value = param
elif 'mlp.dense_h_to_4h' in name:
if not is_bias(name):
layer.mlp.fc.weight.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
else:
layer.mlp.fc.bias.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
elif 'mlp.dense_4h_to_h' in name:
if not is_bias(name):
layer.mlp.proj.weight.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=1)
else:
layer.mlp.proj.bias.value = param
elif 'ln_attn' in name or 'input_layernorm' in name:
if not is_bias(name):
layer.input_layernorm.weight.value = param
else:
layer.input_layernorm.bias.value = param
elif 'ln_mlp' in name:
assert layer.mlp_layernorm is not None
if not is_bias(name):
layer.mlp_layernorm.weight.value = param
else:
layer.mlp_layernorm.bias.value = param
elif 'post_attention_layernorm' in name:
assert layer.post_layernorm is not None
if not is_bias(name):
layer.post_layernorm.weight.value = param
else:
layer.post_layernorm.bias.value = param
elif 'word_embeddings' in name:
if mapping.is_first_pp_rank():
trtllm_falcon.embedding.weight.value = param.copy()
if mapping.is_last_pp_rank():
trtllm_falcon.lm_head.weight.value = split_matrix(
param, mapping.tp_size, mapping.tp_rank, dim=0)
elif 'ln_f' in name:
if mapping.is_last_pp_rank():
if not is_bias(name):
trtllm_falcon.ln_f.weight.value = param
else:
trtllm_falcon.ln_f.bias.value = param
del state_dict
tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Weights loaded. Total time: {t}')
def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,

View File

@ -74,14 +74,14 @@ python build.py --model_type t5 \
--dtype float32 \
--max_beam_width 1
# Example 2: build flan-t5-small using 4-way tensor parallelism on a node with 8 GPUs (but only use 4 of them, for demonstration purpose), BF16, enabling beam search up to width=3
# Example 2: build t5-small using 4-way tensor parallelism on a node with 8 GPUs (but only use 4 of them, for demonstration purpose), BF16, enabling beam search up to width=3
python build.py --model_type t5 \
--world_size 4 \
--tp_size 4 \
--gpus_per_node 4 \
--weight_dir tmp/trt_models/flan-t5-small/tp4 \
-o tmp/trt_engines/flan-t5-small/4-gpu \
--engine_name flan-t5-small \
--weight_dir tmp/trt_models/t5-small/tp4 \
-o tmp/trt_engines/t5-small/4-gpu \
--engine_name t5-small \
--remove_input_padding \
--use_bert_attention_plugin \
--use_gpt_attention_plugin \
@ -90,7 +90,7 @@ python build.py --model_type t5 \
--dtype bfloat16 \
--max_beam_width 3
# Example 3: build flan-t5-small using 2-way tensor parallelism and 2-way pipeline parallelism on a node with 8 GPUs, FP16, enabling beam search up to width=3
# Example 3: build flan-t5-small using 2-way tensor parallelism and 2-way pipeline parallelism on a node with 8 GPUs, BF16, enabling beam search up to width=3
python build.py --model_type t5 \
--world_size 4 \
--tp_size 2 \
@ -104,7 +104,7 @@ python build.py --model_type t5 \
--use_gpt_attention_plugin \
--use_gemm_plugin \
--use_rmsnorm_plugin \
--dtype float16 \
--dtype bfloat16 \
--max_beam_width 3
```
@ -117,9 +117,15 @@ Note that during model deployment, only the TensorRT engine files are needed. Pr
# Example 1: inference w/ single GPU, FP32, greedy search, compare results with HuggingFace FP32
python3 run.py --engine_dir tmp/trt_engines/t5-small/1-gpu/float32/tp1 --engine_name t5-small --model_name t5-small --max_new_token=64 --num_beams=1 --compare_hf_fp32
# Example 2: inference w/ 4 GPUs (4-way TP, as configured during the engine building step), BF16, greedy search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/bfloat16/tp4 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=1
# Example 2: inference w/ 4 GPUs (4-way TP, as configured during the engine building step), BF16, greedy search, compare results with HuggingFace FP32
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/t5-small/4-gpu/bfloat16/tp4 --engine_name t5-small --model_name t5-small --max_new_token=64 --num_beams=1 --compare_hf_fp32
# Example 3: inference w/ 4 GPUs (2-way TP and 2-way PP, as configured during the engine building step), FP16, beam search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/float16/tp2 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=3
# Example 3: inference w/ 4 GPUs (2-way TP and 2-way PP, as configured during the engine building step), BF16, greedy search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/bfloat16/tp2 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=1
```
### Reminders
- Flan-T5 models have known issues with FP16 precision (independent of TRT-LLM), so BF16 is recommended. While we are working on improving FP16 results, please stay with FP32 or BF16 precision for the Flan-T5 family.
- Batched/ragged input with beam search has a subtle issue that can truncate some sequence results. For the time being: (1) batch size = 1 works fine; (2) padded batched input (i.e., not using the `--remove_input_padding` flag) works fine; (3) for ragged batched input (i.e., using `--remove_input_padding`), only use greedy search for now.
- For the T5 and Flan-T5 families, which use a relative attention bias design, the relative attention table is split along the `num_heads` dimension in tensor parallelism mode. Therefore, `num_heads` must be divisible by `tp_size`; please keep this in mind when setting the TP parameter (a minimal check is sketched below).
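A minimal sketch of that divisibility check (illustrative; `num_heads` below is an example value and should be read from the model's HuggingFace config in practice):

```python
# Hypothetical pre-flight check before choosing tp_size for a T5/Flan-T5 build.
num_heads = 8  # example value; take this from the HF config of your checkpoint
for tp_size in (1, 2, 4, 8):
    if num_heads % tp_size != 0:
        raise ValueError(
            f"num_heads={num_heads} is not divisible by tp_size={tp_size}; "
            "the relative attention bias table cannot be split evenly across ranks.")
print("All candidate tp_size values divide num_heads evenly.")
```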

Some files were not shown because too many files have changed in this diff.