Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#465)

* Update TensorRT-LLM

Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

parent 6755a3f077
commit 711a28d9bf
@@ -261,7 +261,7 @@ The list of supported models is:
* [StarCoder](examples/gpt)
* [T5](examples/enc_dec)

Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models easiler.
Note: [Encoder-Decoder](examples/enc_dec/) provides general encoder-decoder support that contains many encoder-decoder models such as T5, Flan-T5, etc. We unroll the exact model names in the list above to let users find specific models easier.

## Performance
@@ -16,25 +16,19 @@
*/
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/common/mpiUtils.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/gptJsonConfig.h"
#include "tensorrt_llm/runtime/gptSession.h"
#include "tensorrt_llm/runtime/tllmLogger.h"

#include <NvInfer.h>
#include <NvInferPlugin.h>
#include <chrono>
#include <cxxopts.hpp>
#include <iostream>
#include <nlohmann/json.hpp>
#include <string>

using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
using namespace tensorrt_llm::mpi;
@@ -456,7 +450,7 @@ std::pair<std::vector<std::vector<int32_t>>, std::vector<int32_t>> parseDataset(
}

void benchmarkGptManager(std::string const& modelName, std::filesystem::path const& engineDir, std::string const& type,
    std::string const& datasetPath, int beamWidth, std::shared_ptr<nvinfer1::ILogger> const& logger,
    std::string const& datasetPath, int beamWidth, int warmUp, std::shared_ptr<nvinfer1::ILogger> const& logger,
    TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy)
{
    auto const worldConfig = WorldConfig::mpi(*logger);
@@ -506,6 +500,19 @@ void benchmarkGptManager(std::string const& modelName, std::filesystem::path con

    if (worldConfig.getRank() == 0)
    {
        // Warm up
        for (auto i = 1; i < warmUp + 1; ++i)
        {
            // skip terminateReqId
            if (i == terminateReqId)
            {
                i += 1;
            }
            gptServer->enqueue(tensors_list[0], i, false);
        }
        gptServer->waitForEmpty();

        // Benchmark
        recorder->initialize();
        for (int i = 0; i < tensors_list.size(); ++i)
        {
@@ -539,6 +546,8 @@ int main(int argc, char* argv[])
        cxxopts::value<std::string>()->default_value(""));
    options.add_options()(
        "beam_width", "Specify beam width you want to benchmark.", cxxopts::value<int>()->default_value("1"));
    options.add_options()(
        "warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));

    options.add_options()("max_num_sequences", "Max number of Sequences.", cxxopts::value<int>());
    options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
@@ -651,7 +660,7 @@ int main(int argc, char* argv[])
    try
    {
        benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
            datasetPath, beamWidth, logger, optionalParams, schedulerPolicy);
            datasetPath, beamWidth, result["warm_up"].as<int>(), logger, optionalParams, schedulerPolicy);
    }
    catch (const std::exception& e)
    {
@@ -43,9 +43,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

    std::string modelNameHyphen = modelName;
    std::filesystem::path jsonFileName = dataPath / "config.json";
    if (tc::strStartsWith(modelName, "chatglm"))
    if (tc::strStartsWith(modelName, "chatglm") || tc::strStartsWith(modelName, "glm"))
    {
        std::replace(modelNameHyphen.begin(), modelNameHyphen.end(), '_', '-');
        jsonFileName = dataPath / (modelNameHyphen + std::string("-config.json"));
    }
    auto const json = GptJsonConfig::parse(jsonFileName);
@@ -15,8 +15,10 @@

from argparse import ArgumentParser

import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
# isort: on
from cuda import cuda, cudart
from mpi4py import MPI
from polygraphy.backend.trt import CreateConfig, EngineFromNetwork
@@ -376,6 +376,7 @@ _allowed_configs = {
        build_config=BuildConfig(
            num_layers=28,
            num_heads=32,
            num_kv_heads=2,
            hidden_size=4096,
            vocab_size=65024,
            hidden_act='swiglu',
@@ -393,6 +394,7 @@ _allowed_configs = {
        build_config=BuildConfig(
            num_layers=28,
            num_heads=32,
            num_kv_heads=2,
            hidden_size=4096,
            vocab_size=65024,
            hidden_act='swiglu',
@@ -45,7 +45,9 @@ def get_engine_name(model, dtype, tp_size, rank):

def serialize_engine(engine, path):
    with open(path, 'wb') as f:
        f.write(bytearray(engine))
        # engine object is already complies with python buffer protocol, no need to
        # convert it to bytearray before write, converting to bytearray consumes lots of memory
        f.write(engine)


class BaseBenchmark(object):
@@ -14,11 +14,9 @@
# limitations under the License.
import argparse
import multiprocessing as mp
from multiprocessing import Process, Queue
from time import time

import torch
from mem_monitor import mem_monitor


def parse_arguments():
@@ -213,6 +211,7 @@ def main(args):
    from allowed_configs import get_allowed_models
    from bert_benchmark import BERTBenchmark
    from gpt_benchmark import GPTBenchmark
    from mem_monitor import MemoryMonitor

    from tensorrt_llm.logger import logger
@@ -282,11 +281,8 @@ def main(args):
    torch.cuda.empty_cache()
    latencies = []

    # Launch a subprocess to monitor memory usage
    q1 = Queue()  # q1 is used for sending signal to subprocess
    q2 = Queue()  # q2 is used for receiving results from subprocess
    mem_monitor_process = Process(target=mem_monitor, args=(q1, q2))
    mem_monitor_process.start()
    memory_monitor = MemoryMonitor()
    memory_monitor.start()

    iter_idx = 0
    try:
@@ -313,15 +309,12 @@ def main(args):

    except Exception as e:
        print("Found exception during benchmarking", e.with_traceback())
        mem_monitor_process.kill()
        memory_monitor.kill()
        raise e
    logger.debug("Sending signal to mem monitor process, start")
    q1.put(1)
    logger.debug("Sending signal to mem monitor process, done")
    peak_gpu_used = q2.get()
    logger.debug("Get peak gpu memory usage from mem monitor process, done")
    mem_monitor_process.join()
    logger.debug("Memory monitor process joined")

    memory_monitor.stop()
    _, peak_gpu_used = memory_monitor.get_peak_memory_usage("GiB")
    peak_gpu_used = round(peak_gpu_used, 3)

    latency = round(sum(latencies) / iter_idx, 3)
    latencies.sort()
@@ -16,8 +16,10 @@ import os
import time
from collections import OrderedDict

import tensorrt as trt
# isort: off
import torch
import tensorrt as trt
#isort: on
from allowed_configs import get_build_config
from base_benchmark import BaseBenchmark, serialize_engine
@@ -91,6 +91,7 @@ class GPTBenchmark(BaseBenchmark):
            self.use_layernorm_plugin = False
            self.use_rmsnorm_plugin = False
            self.use_lookup_plugin = non_mha_plg_dtype
            self.use_weight_only_quant_gemm_plugin = non_mha_plg_dtype
            self.enable_context_fmha = use_mha_plugin

            self.remove_input_padding = use_non_mha_plugin
@@ -288,7 +289,8 @@ class GPTBenchmark(BaseBenchmark):
            max_batch_size=self.max_batch_size,
            max_input_len=self.max_input_len,
            max_output_len=self.max_output_len,
            int8=self.quant_mode.has_act_and_weight_quant(),
            int8=self.quant_mode.has_act_and_weight_quant()
            or self.quant_mode.is_int8_weight_only(),
            quant_mode=self.quant_mode,
            use_refit=self.refit,
            opt_level=self.builder_opt,
@@ -505,7 +507,7 @@ class GPTBenchmark(BaseBenchmark):
                dtype=self.dtype)
            network.plugin_config.set_quantize_tensor_plugin()
            network.plugin_config.set_quantize_per_token_plugin()
        elif self.use_weight_only:
        elif self.use_weight_only and self.use_weight_only_quant_gemm_plugin:
            network.plugin_config.set_weight_only_quant_matmul_plugin(
                dtype=self.dtype)
@@ -12,29 +12,57 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from multiprocessing import Event, Process, Queue

import pynvml
from tensorrt_llm.logger import logger
from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
                                   device_memory_info)


def get_memory_info(handle):
    mem_info = pynvml.nvmlDeviceGetMemoryInfo(handle,
                                              version=pynvml.nvmlMemory_v2)
    total = round(mem_info.total / 1024 / 1024 / 1024, 2)
    used = round(mem_info.used / 1024 / 1024 / 1024, 2)
    free = round(mem_info.free / 1024 / 1024 / 1024, 2)
    return total, used, free
class MemoryMonitor:

    def __init__(self, query_interval=0.1):
        self.query_interval = query_interval  # second(s)
        self.mem_monitor_process = None
        # bytes
        self._peak_host_memory = 0
        self._peak_device_memory = 0

def mem_monitor(q1, q2):
    pynvml.nvmlInit()
    handle = pynvml.nvmlDeviceGetHandleByIndex(0)
        self.device_handles = {}

    peak_used = 0
    while q1.empty():
        _, used, _ = get_memory_info(handle)
        peak_used = max(used, peak_used)
        time.sleep(0.1)
        self.signal_event = Event()  # Sending signal to subprocess
        self.peak_mem_queue = Queue()  # Receiving results from subprocess

    pynvml.nvmlShutdown()
    q2.put(peak_used)
    def start(self):
        self.mem_monitor_process = Process(target=self._upd_peak_memory_usage,
                                           args=(self.signal_event,
                                                 self.peak_mem_queue))
        self.mem_monitor_process.start()
        logger.debug("Launched memory monitor subprocess.")

    def kill(self):
        if self.mem_monitor_process is not None:
            self.mem_monitor_process.kill()
            logger.debug("Memory monitor subprocess is killed.")

    def stop(self):
        self.signal_event.set()
        logger.debug("Sent signal to stop memory monitor subprocess.")

        self._peak_device_memory = max(self._peak_device_memory,
                                       self.peak_mem_queue.get())

        self.mem_monitor_process.join()
        self.mem_monitor_process = None
        logger.debug("Memory monitor subprocess joined.")

    def _upd_peak_memory_usage(self, signal_event, peak_mem_queue):
        peak_used, _, _ = device_memory_info()
        while not signal_event.is_set():
            used, _, _ = device_memory_info()
            peak_used = max(used, peak_used)
        peak_mem_queue.put(peak_used)

    def get_peak_memory_usage(self, unit: MemUnitType = 'GiB'):
        return bytes_to_target_unit(self._peak_host_memory, unit), \
            bytes_to_target_unit(self._peak_device_memory, unit)
@@ -45,15 +45,17 @@ class GptManager
{
public:
    using SizeType = tensorrt_llm::runtime::SizeType;
    using TokenIdType = tensorrt_llm::runtime::TokenIdType;
    using RequestList = std::list<std::shared_ptr<LlmRequest>>;
    using TensorPtr = runtime::ITensor::SharedPtr;

    GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
    GptManager(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth,
        batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
        SendResponseCallback sendResponseCb, PollStopSignalCallback pollStopSignalCb = nullptr,
        ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
        const TrtGptModelOptionalParams& optionalParams = TrtGptModelOptionalParams(),
        std::optional<uint64_t> terminateReqId = std::nullopt);
        std::optional<uint64_t> terminateReqId = std::nullopt, std::optional<SizeType> maxDraftTokens = std::nullopt,
        bool excludeInputInOutput = false);

    /* Wraps the user-provided callback for requests.
       Adds requests to request table.
@@ -70,6 +72,8 @@ public:

    BatchManagerErrorCode_t waitUntilTerminate();

    BatchManagerErrorCode_t shutdown();

    virtual ~GptManager();

protected:
@@ -78,17 +82,23 @@ protected:
    virtual BatchManagerErrorCode_t step(RequestList& activeRequests, std::set<uint64_t>& activeRequestsIds);

private:
    SizeType getMaxInputLen() const;
    SizeType getMaxOutputLen() const;
    SizeType getMaxNumSequences() const;

    void validateLlmRequest(LlmRequest& newReq) const;
    static std::shared_ptr<LlmRequest> fillLlmRequest(std::shared_ptr<InferenceRequest> newReq);
    static std::shared_ptr<std::vector<int32_t>> getReqInputTokens(std::shared_ptr<InferenceRequest> new_req);
    static int32_t getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);
    static std::shared_ptr<std::vector<TokenIdType>> getReqInputTokens(std::shared_ptr<InferenceRequest> newReq);
    static SizeType getMaxNewTokens(std::shared_ptr<InferenceRequest> newReq);

    GetInferenceRequestsCallback mGetInferenceRequestsCb;
    SendResponseCallback mSendResponseCb;
    PollStopSignalCallback mPollStopSignalCb;
    ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;

    std::shared_ptr<TrtGptModel> mTrtGptModel;
    SizeType mMaxInputLen;
    SizeType mMaxOutputLen;
    SizeType mMaxKvCacheLen;
    SizeType mMaxNumSequences;
    std::optional<uint64_t> mTerminateReqId;
    std::optional<SizeType> mMaxDraftTokens;

    // Iteration counter - incremented every iteration of the generation loop
    int64_t mIterationCounter;
@@ -96,16 +106,14 @@ private:
    RequestList mActiveRequests;
    // IDs of live requests
    std::set<uint64_t> mActiveRequestsIds;
    // Boolean that controls if prompt should be included in output tokens for non-streaming
    bool mExcludeInputInOutput;

    GetInferenceRequestsCallback mGetInferenceRequestsCb;
    SendResponseCallback mSendResponseCb;
    PollStopSignalCallback mPollStopSignalCb;
    ReturnBatchManagerStatsCallback mReturnBatchManagerStatsCb;

    std::atomic<bool> destructor_called_;
    std::atomic<bool> shutdown_requested_;
    void decoupled_execution_loop();
    std::shared_ptr<std::thread> worker_thread_;
    inline static const std::string kInputIdsTensorName_ = "input_ids";
    inline static const std::string kDraftInputIdsTensorName_ = "draft_input_ids";
    inline static const std::string kMaxNewTokensTensorName_ = "request_output_len";
    inline static const std::string kBeamWidthTensorName_ = "beam_width";
    inline static const std::string kEndIdTensorName_ = "end_id";
@@ -20,26 +20,44 @@

namespace tensorrt_llm::batch_manager
{
struct NamedTensor
template <typename TTensor>
struct GenericNamedTensor
{
    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
    using TensorPtr = TTensor;

    TensorPtr tensor;
    std::string name;

    NamedTensor() = default;
    ~NamedTensor() = default;
    GenericNamedTensor() = default;
    ~GenericNamedTensor() = default;

    // Host Tensor constructor
    NamedTensor(
    GenericNamedTensor(
        nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);

    NamedTensor(TensorPtr _tensor, std::string _name)
    GenericNamedTensor(TensorPtr _tensor, std::string _name)
        : tensor(std::move(_tensor))
        , name(std::move(_name))
    {
    }

    GenericNamedTensor(std::string _name)
        : name(std::move(_name))
    {
    }
};

struct NamedTensor : public GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>
{
    using Base = GenericNamedTensor<tensorrt_llm::runtime::ITensor::SharedPtr>;
    using TensorPtr = Base::TensorPtr;

    NamedTensor(
        nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);

    NamedTensor(TensorPtr _tensor, std::string _name)
        : Base(_tensor, _name){};

    std::vector<int64_t> serialize();
    static NamedTensor deserialize(const int64_t* packed);
};
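For readers skimming the diff, the point of the new `GenericNamedTensor` template is that the tensor pointer type is no longer hard-coded to `ITensor::SharedPtr`. Below is a minimal, self-contained sketch of what that enables; the `HostBuffer` alias and the `main` harness are hypothetical illustrations, not part of this commit or of the TensorRT-LLM API.

```cpp
#include <memory>
#include <string>
#include <vector>

// Minimal stand-in for the template introduced above; the real class lives in
// tensorrt_llm/batch_manager/NamedTensor.h and is instantiated with ITensor::SharedPtr.
template <typename TTensor>
struct GenericNamedTensor
{
    using TensorPtr = TTensor;
    TensorPtr tensor;
    std::string name;

    GenericNamedTensor() = default;
    GenericNamedTensor(TensorPtr t, std::string n)
        : tensor(std::move(t))
        , name(std::move(n))
    {
    }
};

// Hypothetical instantiation over a plain host buffer, to show why the pointer
// type is now a template parameter instead of being fixed in the struct.
using HostBuffer = std::shared_ptr<std::vector<float>>;
using HostNamedTensor = GenericNamedTensor<HostBuffer>;

int main()
{
    HostNamedTensor logits{std::make_shared<std::vector<float>>(16, 0.0f), "logits"};
    return logits.name == "logits" ? 0 : 1;
}
```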
@@ -38,33 +38,34 @@
namespace tensorrt_llm::batch_manager
{

class InferenceRequest
template <typename TTensor, typename TTensorMap>
class GenericInferenceRequest
{
public:
    using TensorPtr = tensorrt_llm::runtime::ITensor::SharedPtr;
    using TensorMap = tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>;
    using TensorPtr = TTensor;
    using TensorMap = TTensorMap;

    InferenceRequest(uint64_t requestId)
    GenericInferenceRequest(uint64_t requestId)
        : mRequestId(requestId)
        , mIsStreaming(false)
    {
    }

    InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
    GenericInferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
        : mInputTensors(inputTensors)
        , mRequestId(requestId)
        , mIsStreaming(false)
    {
    }

    InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
    GenericInferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
        : mInputTensors(std::move(inputTensors))
        , mRequestId(requestId)
        , mIsStreaming(false)
    {
    }

    ~InferenceRequest() {}
    ~GenericInferenceRequest() {}

    template <typename T>
    std::tuple<bool, T> getScalarValueFromTensor(
@@ -139,6 +140,36 @@ public:
        return mRequestId;
    }

protected:
    TensorMap mInputTensors;
    uint64_t mRequestId;
    bool mIsStreaming;
};

class InferenceRequest : public GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
                             tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>
{
public:
    using Base = GenericInferenceRequest<tensorrt_llm::runtime::ITensor::SharedPtr,
        tensorrt_llm::runtime::StringPtrMap<tensorrt_llm::runtime::ITensor>>;
    using TensorPtr = Base::TensorPtr;
    using TensorMap = Base::TensorMap;

    InferenceRequest(uint64_t requestId)
        : Base(requestId)
    {
    }

    InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
        : Base(inputTensors, requestId)
    {
    }

    InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
        : Base(inputTensors, requestId)
    {
    }

    const std::vector<int64_t> serialize() const
    {
        std::list<int64_t> packed;
@@ -184,11 +215,6 @@ public:
        ir->setIsStreaming(IsStreaming);
        return ir;
    }

private:
    TensorMap mInputTensors;
    uint64_t mRequestId;
    bool mIsStreaming;
};

} // namespace tensorrt_llm::batch_manager
@@ -168,7 +168,7 @@ class BlockManager
public:
    using SizeType = tensorrt_llm::runtime::SizeType;

    explicit BlockManager(std::size_t blocksInPool);
    explicit BlockManager(SizeType blocksInPool);

    void startScheduling();
@@ -179,22 +179,22 @@ public:
    // Simulate freeing all blocks for that sequence to check impact on number of free blocks
    void schedulingFreeAllBlocks(GenerationRequest& sequence);

    [[nodiscard]] std::size_t getNumFreeBlocks() const
    [[nodiscard]] SizeType getNumFreeBlocks() const
    {
        return mFreeBlocks.size();
    }

    [[nodiscard]] std::size_t getNumAllocatedBlocks() const
    [[nodiscard]] SizeType getNumAllocatedBlocks() const
    {
        return mAllocatedBlocks.size();
    }

    [[nodiscard]] bool hasFreeBlocks(std::size_t numRequired = 1) const
    [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
    {
        return getNumFreeBlocks() >= numRequired;
    }

    [[nodiscard]] bool schedulingHasFreeBlocks(std::size_t numRequired = 1) const
    [[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const
    {
        return mSchedulingNumFreeBlocks >= numRequired;
    }
@@ -205,7 +205,7 @@ private:
    // List of allocated blocks for each sequences
    std::vector<std::vector<KVCacheBlock>> mAllocatedBlocks;
    // Used to keep track of number of free blocks during scheduling
    std::size_t mSchedulingNumFreeBlocks;
    SizeType mSchedulingNumFreeBlocks;
};

class KVCacheManager
@@ -36,24 +36,28 @@ enum LlmRequestState_t
    REQUEST_STATE_GENERATION_COMPLETE = 3
};

class LlmRequest
template <typename TTensor>
class GenericLlmRequest
{
public:
    using SizeType = runtime::SizeType;
    using TokenIdType = runtime::TokenIdType;
    using RequestIdType = std::uint64_t;
    using BeamTokens = std::vector<std::vector<TokenIdType>>;
    using VecTokens = std::vector<TokenIdType>;
    using VecLogProbs = std::vector<float>;
    using TensorPtr = runtime::ITensor::SharedPtr;
    using BeamTokens = std::vector<VecTokens>;
    using TensorPtr = TTensor;

    LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> input_tokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
        std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
    GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens,
        std::shared_ptr<std::vector<TokenIdType>> inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming,
        std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
        std::optional<TensorPtr> embeddingBias = std::nullopt, std::optional<TensorPtr> badWordsList = std::nullopt,
        std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false)
        std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
        std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
        : mRequestId(requestId)
        , mPromptLen(input_tokens->size())
        , mPromptLen(inputTokens->size())
        , mMaxNewTokens(maxNewTokens)
        , mSamplingConfig(samplingConfig)
        , mState(REQUEST_STATE_CONTEXT_INIT)
@@ -61,7 +65,7 @@ public:
        , mEndId(endId)
        , mPadId(padId)
        , mBatchSlot(-1)
        , mOrigPromptLen(input_tokens->size())
        , mOrigPromptLen(inputTokens->size())
        , mEmbeddingBias(embeddingBias)
        , mBadWordsList(badWordsList)
        , mStopWordsList(stopWordsList)
@@ -70,10 +74,11 @@ public:
        , mReturnLogProbs(returnLogProbs)
        , mLogProbs(samplingConfig.beamWidth)
        , mCumLogProbs(samplingConfig.beamWidth)
        , mDraftTokens(draftTokens.value_or(std::make_shared<VecTokens>()))
    {
        mMaxSentTokenPos = mPromptLen - 1;
        // Scatter the input tokens to other beam
        mTokens = std::make_shared<BeamTokens>(mSamplingConfig.beamWidth, *input_tokens);
        mTokens = BeamTokens(mSamplingConfig.beamWidth, *inputTokens);

        if ((mPromptEmbeddingTable.has_value() && !mPromptVocabSize.has_value())
            || (!mPromptEmbeddingTable.has_value() && mPromptVocabSize.has_value()))
@@ -91,7 +96,7 @@ public:
    /// @return The number of tokens
    SizeType getNumTokens(SizeType beam) const
    {
        return mTokens->at(beam).size();
        return mTokens.at(beam).size();
    }

    /// @brief Get max number of tokens across all beams
@@ -101,7 +106,7 @@ public:
        SizeType maxTokens = 0;
        for (SizeType beam = 0; beam < mSamplingConfig.beamWidth; ++beam)
        {
            maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens->at(beam).size()));
            maxTokens = std::max(maxTokens, static_cast<SizeType>(mTokens.at(beam).size()));
        }
        return maxTokens;
    }
@@ -112,19 +117,33 @@ public:
    /// @return The token index
    TokenIdType getToken(SizeType beam, SizeType pos) const
    {
        return mTokens->at(beam).at(pos);
        return mTokens.at(beam).at(pos);
    }

    /// @brief Get the tokens at a given beam index
    /// @param beam The beam index
    /// @return A vector of tokens for this beam index, includes the prompt
    std::vector<TokenIdType> getTokens(SizeType beam) const
    /// @param beam The beam index
    /// @return A vector of tokens for this beam index, includes the prompt
    std::vector<TokenIdType> const& getTokens(SizeType beam) const
    {
        return mTokens->at(beam);
        return mTokens.at(beam);
    }

    /// @brief Get the draft tokens
    /// @return shared_ptr to vector of draft tokens
    std::shared_ptr<std::vector<TokenIdType>> const& getDraftTokens() const
    {
        return mDraftTokens;
    }

    /// @brief Returns true if request has draft tokens
    /// @return flag
    bool hasDraftTokens() const
    {
        return mDraftTokens && mDraftTokens->size() > 0;
    }

    /// @brief Get the maximum number of generated tokens among all rays in beam
    /// @return The number of generated tokens (doesn't include the prompt tokens)
    /// @return The number of generated tokens (doesn't include the prompt tokens)
    SizeType getMaxNumGeneratedTokens() const
    {
        return getMaxBeamNumTokens() - mPromptLen;
@@ -135,7 +154,7 @@ public:
    /// @param beam The beam to which to add the new token
    void addNewToken(TokenIdType token, SizeType beam)
    {
        mTokens->at(beam).push_back(token);
        mTokens.at(beam).push_back(token);
    }

    /// @brief Add new generated tokens to the vector of tokens
@@ -147,7 +166,7 @@ public:
        for (std::size_t beam = 0; beam < beamTokens.size(); ++beam)
        {
            const auto outputId = beamTokens[beam];
            mTokens->at(beam).push_back(outputId);
            mTokens.at(beam).push_back(outputId);
        }
    }
@@ -158,7 +177,7 @@ public:
        assert(generatedBeamTokens.size() == mSamplingConfig.beamWidth);
        for (std::size_t beam = 0; beam < generatedBeamTokens.size(); ++beam)
        {
            auto& beamTokens = (*mTokens)[beam];
            auto& beamTokens = mTokens[beam];
            beamTokens.resize(mPromptLen);
            beamTokens.insert(beamTokens.end(), generatedBeamTokens[beam].begin(), generatedBeamTokens[beam].end());
        }
@@ -173,9 +192,9 @@ public:
        // As a temporary solution, we currently reset the tokens to the prompt
        if (mSamplingConfig.beamWidth > 1)
        {
            for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
            for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
            {
                auto& beamTokens = mTokens->at(beam);
                auto& beamTokens = mTokens.at(beam);
                beamTokens.resize(mPromptLen);
                if (mReturnLogProbs)
                {
@@ -186,9 +205,9 @@ public:
        else
        {
            SizeType newPromptLen = std::min(maxInputLen, mPromptLen + getMaxNumGeneratedTokens());
            for (std::size_t beam = 0; beam < mTokens->size(); ++beam)
            for (std::size_t beam = 0; beam < mTokens.size(); ++beam)
            {
                auto& beamTokens = mTokens->at(beam);
                auto& beamTokens = mTokens.at(beam);
                beamTokens.resize(newPromptLen);

                if (mReturnLogProbs)
@@ -225,21 +244,6 @@ public:
        return mPromptEmbeddingTable;
    }

    void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
    {
        if (!mPromptEmbeddingTable.has_value()
            || mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
        {
            return;
        }
        else
        {
            TensorPtr gpuPromptEmbeddingTable
                = manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
            mPromptEmbeddingTable = gpuPromptEmbeddingTable;
        }
    }

    std::optional<SizeType> getPromptVocabSize() const
    {
        return mPromptVocabSize;
@@ -296,6 +300,11 @@ public:
        return mOrigPromptLen;
    }

    void setDraftTokens(const std::shared_ptr<VecTokens>& draftTokens)
    {
        mDraftTokens = draftTokens;
    }

    RequestIdType mRequestId;
    SizeType mPromptLen;
    SizeType mMaxNewTokens;
@@ -307,9 +316,9 @@ public:
    std::optional<SizeType> mPadId;
    SizeType mBatchSlot;

private:
protected:
    SizeType mOrigPromptLen;
    std::shared_ptr<BeamTokens> mTokens;
    BeamTokens mTokens;
    SizeType mMaxSentTokenPos;

    std::optional<TensorPtr> mEmbeddingBias;
@@ -323,6 +332,47 @@ private:

    std::vector<VecLogProbs> mLogProbs; // [beamSize, seqLen]
    VecLogProbs mCumLogProbs;           // [beamSize]
    std::shared_ptr<VecTokens> mDraftTokens;
};

class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
{
public:
    using Base = GenericLlmRequest<runtime::ITensor::SharedPtr>;
    using TensorPtr = Base::TensorPtr;
    using SizeType = Base::SizeType;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;

    LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<std::vector<TokenIdType>> inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
        std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
        std::optional<std::shared_ptr<VecTokens>> draftTokens = std::nullopt)
        : Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, embeddingBias,
            badWordsList, stopWordsList, promptEmbeddingTable, promptVocabSize, returnLogProbs, draftTokens)
    {
    }

    void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager)
    {
        if (!mPromptEmbeddingTable.has_value()
            || mPromptEmbeddingTable.value()->getMemoryType() == runtime::MemoryType::kGPU)
        {
            return;
        }
        else
        {
            TensorPtr gpuPromptEmbeddingTable
                = manager.copyFrom(*mPromptEmbeddingTable.value(), runtime::MemoryType::kGPU);
            mPromptEmbeddingTable = gpuPromptEmbeddingTable;
        }
    }
};

} // namespace tensorrt_llm::batch_manager
@@ -50,6 +50,8 @@ public:
    TensorPtr endIds; // [batchSize * beamWidth], on gpu

    // optional parameters
    TensorPtr finished; // [batchSize, beamWidth], finished states at current iteration.
                        // If true for some request, the decoding step of it is skipped, on gpu
    TensorPtr sequenceLimitLength; // [batchSize], on gpu
    TensorPtr embeddingBias;       // [vocabSizePadded], on gpu
    TensorPtr lengths;             // [batchSize, beamWidth], on gpu
@@ -64,11 +64,19 @@ public:
    // mandatory parameters
    TensorPtr ids; // [batchSize, beamWidth, maxSeqLen], on gpu, must contain previously generated token ids for all
                   // steps before DecodingInput.step
    TensorPtr newTokens; // [batchSize, beamWidth] on gpu.
    TensorPtr newTokensSteps; // [maxTokensPerStep, batchSize, beamWidth] new tokens at each generated token of
                              // maxTokensPerStep, on gpu.
    TensorPtr newTokens; // [batchSize, beamWidth] usually a view of newTokensSteps for the current token, on gpu.
    std::vector<TensorPtr> newTokensVec; // vector of size maxTokensPerStep with tensor [batchSize, beamWidth].
                                         // Vector of views on newTokensSteps for each token. Elements are on gpu.

    // optional parameters
    TensorPtr finished; // [batchSize, beamWidth], mandatory in beam search and to determine whether to stop
                        // according to DecodingInput.sequenceLimitLength, on gpu
    TensorPtr finishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states at each generated token of
                             // maxTokensPerStep, on gpu
    TensorPtr finished; // [batchSize, beamWidth], usually a view of finishedSteps for current token.
                        // Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
                        // true. In beam search and to determine whether to stop according to
                        // DecodingInput.sequenceLimitLength, on gpu
    TensorPtr finishedSum; // [1], the sum of finished sequences, in pinned memory

    // mandatory parameters for beam search
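As a reading aid, the following sketch shows how a per-step `newTokens` view can be carved out of `newTokensSteps`; it mirrors the `getNewTokens(iter)` accessor that appears further down in this diff and assumes the `ITensor::slice`/`squeeze` interface used there. It is an illustration, not the literal implementation in the commit.

```cpp
#include "tensorrt_llm/runtime/iTensor.h"

using tensorrt_llm::runtime::ITensor;

// newTokensSteps has shape [maxTokensPerStep, batchSize, beamWidth]; the result
// is the [batchSize, beamWidth] slab for generation step `iter`, restricted to
// the currently active batch.
ITensor::SharedPtr newTokensAt(ITensor::SharedPtr const& newTokensSteps, int iter, int actualBatchSize)
{
    auto view = ITensor::slice(newTokensSteps, iter, 1); // [1, batchSize, beamWidth]
    view->squeeze(0);                                    // [batchSize, beamWidth]
    return ITensor::slice(view, 0, actualBatchSize);
}
```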
@@ -55,6 +55,10 @@ public:
        DecodingInput const& decodingInput, BufferManager const& manager)
        = 0;

    static void acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds, const ITensor& contextLengths,
        const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec, ITensor& finishedFinal,
        ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream);

    static std::unique_ptr<IGptDecoder> create(
        nvinfer1::DataType dtype, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
};
@@ -40,13 +40,13 @@ class GptDecoderBatch : public IGptDecoderBatch
{
public:
    using CudaStreamPtr = std::shared_ptr<CudaStream>;
    using TensorPtr = std::shared_ptr<ITensor>;
    using TensorPtr = ITensor::SharedPtr;

    GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);

    //! Setup the decoder before calling `forward()`
    void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
        nvinfer1::DataType dtype) override;
        SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;

    //! @brief Initialize the decoder at `batchIdx` with a new `request`.
    void newRequest(
@@ -69,6 +69,7 @@ public:
        return {mFinished.begin(), mFinished.begin() + mActualBatchSize};
    }

    //! @param batchIdx index of the batch
    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token ids without
    //! padding for request `batchIdx`, on gpu
    [[nodiscard]] TensorPtr getOutputIds(SizeType batchIdx) const override
@@ -99,18 +100,6 @@ public:
        return ITensor::slice(mJointDecodingOutput->parentIds, 0, mActualBatchSize);
    }

    //! @returns [batchSize, maxBeamWidth], marks finished requests (per beam), on gpu
    [[nodiscard]] TensorPtr getFinishedBeams() const override
    {
        return ITensor::slice(mJointDecodingOutput->finished, 0, mActualBatchSize);
    }

    //! @returns [batchSize, maxBeamWidth], total sequence lengths (per beam), on gpu
    [[nodiscard]] TensorPtr getOutputLengths() const override
    {
        return ITensor::slice(mJointDecodingOutput->lengths, 0, mActualBatchSize);
    }

    //! @returns [batchSize, maxBeamWidth], cumulative log probabilities (per beam), on gpu
    [[nodiscard]] TensorPtr getCumLogProbs() const override
    {
@@ -139,10 +128,21 @@ public:
        return tensor;
    }

    //! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
    [[nodiscard]] TensorPtr getNewTokens() const override
    //! @brief Get maxTokensPerStep tokens generated in the last forward pass
    //! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
    [[nodiscard]] TensorPtr getAllNewTokens() const override
    {
        return ITensor::slice(mJointDecodingOutput->newTokens, 0, mActualBatchSize);
        return mJointDecodingOutput->newTokensSteps;
    }

    //! @brief Get tokens generated in one step of last forward pass
    //! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
    //! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
    [[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
    {
        TensorPtr newTokensView = std::move(ITensor::slice(mJointDecodingOutput->newTokensSteps, iter, 1));
        newTokensView->squeeze(0);
        return ITensor::slice(newTokensView, 0, mActualBatchSize);
    }

    //! @returns [batchSize], the number of generation steps executed on each request
@@ -180,13 +180,18 @@ private:
    DecodingInputPtr mJointDecodingInput;
    DecodingOutputPtr mJointDecodingOutput;

    std::vector<TensorPtr> mDraftTokenIds;
    TensorPtr mNumDraftTokens;

    std::vector<SizeType> mNbSteps;
    std::vector<bool> mFinished;
    TensorPtr mFinishedSum;
    std::vector<SizeType> mMaxNewTokens;
    std::vector<SizeType> mBeamWidths;
    std::vector<SizeType> mGeneratedTokensPerStep;
    SizeType mMaxSequenceLength{};
    SizeType mMaxKvCacheLength{};
    SizeType mActualBatchSize{};
    SizeType mMaxTokensPerStep{};
};
} // namespace tensorrt_llm::runtime
@@ -54,6 +54,7 @@ public:
        , mModelVariant(ModelVariant::kGpt)
        , mUseCustomAllReduce(false)
        , mMaxPromptEmbeddingTableSize(0)
        , mMaxDraftLen(0)
    {
    }
@@ -253,6 +254,16 @@ public:
        mUseCustomAllReduce = customAllReduce;
    }

    void constexpr setMaxDraftLen(SizeType maxDraftLen) noexcept
    {
        mMaxDraftLen = maxDraftLen;
    }

    [[nodiscard]] SizeType constexpr getMaxTokensPerStep() const noexcept
    {
        return mMaxDraftLen + 1;
    }

private:
    SizeType mVocabSize;
    SizeType mNbLayers;
@@ -276,6 +287,7 @@ private:
    bool mUseCustomAllReduce;

    SizeType mMaxPromptEmbeddingTableSize;
    SizeType mMaxDraftLen;
};

} // namespace tensorrt_llm::runtime
@@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/iStatefulGptDecoder.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/utils/sessionUtils.h"

#include <cstdint>
#include <memory>
@@ -35,12 +36,14 @@ namespace decoder_batch
class Request
{
public:
    using ConstTensorPtr = std::shared_ptr<ITensor const>;
    using TensorPtr = std::shared_ptr<ITensor>;
    using ConstTensorPtr = ITensor::SharedConstPtr;
    using TensorPtr = ITensor::SharedPtr;
    using BufferPtr = IBuffer::SharedPtr;

    explicit Request(ConstTensorPtr ids, std::optional<SizeType> maxNewTokens = std::nullopt,
    explicit Request(ConstTensorPtr ids, SizeType inputLen, std::optional<SizeType> maxNewTokens = std::nullopt,
        std::optional<SizeType> endId = std::nullopt)
        : ids{std::move(ids)}
        , inputLen(inputLen)
        , maxNewTokens{maxNewTokens}
        , endId{endId}
        , computeCumLogProbs(false)
@@ -48,42 +51,68 @@ public:
    {
    }

    // the number of tokens generated per step
    SizeType generatedTokensPerStep() const
    {
        return draftTokens ? draftTokens->getSize() + 1 : 1;
    }

    // mandatory parameters
    ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
    SizeType inputLen;  // the input length without draft tokens

    // optional parameters
    std::optional<SizeType> maxNewTokens; // maximum number of tokens to generate for this request
    std::optional<SizeType> endId;        // end token id
    TensorPtr embeddingBias; // [vocabSizePadded], on gpu
    TensorPtr badWordsList;  // [2, badWordsLength], on gpu
    TensorPtr stopWordsList; // [2, stopWordsLength], on gpu
    BufferPtr draftTokens;   // [generatedTokensPerStep - 1], on gpu, draft tokens from speculative decoding
    TensorPtr embeddingBias; // [vocabSizePadded], on gpu
    TensorPtr badWordsList;  // [2, badWordsLength], on gpu
    TensorPtr stopWordsList; // [2, stopWordsLength], on gpu

    bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
    bool computeLogProbs;    // boolean that controls if cumLogProbs should be computed for that request
    bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
    bool computeLogProbs;    // boolean that controls if cumLogProbs should be computed for that request
};

class Input : public decoder::Input
class Input
{
public:
    using Base = decoder::Input;
    using TensorConstPtr = ITensor::SharedConstPtr;
    using TensorPtr = ITensor::SharedPtr;

    explicit Input(TensorPtr logits)
        : Base{std::move(logits)}
    {
        auto const batchSize = this->logits->getShape().d[0];
        active.resize(batchSize, true);
    }

    explicit Input(TensorPtr logits, std::vector<bool> const& active)
        : Base{std::move(logits)}
    explicit Input(std::vector<TensorConstPtr> const& logits, std::vector<bool> const& active)
        : logits{logits}
        , active{active}
    {
        auto const batchSize = static_cast<std::size_t>(this->logits->getShape().d[0]);
        TLLM_CHECK_WITH_INFO(this->active.size() == batchSize, "'active' vector size does not match logits batchSize");
        TLLM_CHECK_WITH_INFO(
            this->active.size() == logits.size(), "'active' vector size does not match logits vector size");
    }

    explicit Input(std::vector<TensorConstPtr> const& logits)
        : Input{logits, std::vector<bool>(logits.size(), true)}
    {
    }

    explicit Input(std::vector<TensorPtr> const& logits, std::vector<bool> const& active)
        : Input{
            utils::transformVector(logits, [](auto& x) { return std::const_pointer_cast<ITensor const>(x); }), active}
    {
    }

    explicit Input(std::vector<TensorPtr> const& logits)
        : Input{logits, std::vector<bool>(logits.size(), true)}
    {
    }

    // mandatory parameters
    std::vector<TensorConstPtr>
        logits; // batchSize * [1, beamWidth, vocabSizePadded] or [generatedTokensPerStep, 1, vocabSizePadded], on gpu

    // control activity of decoder slots in batch
    std::vector<bool> active; // [batchSize]

    // parameters for beam search
    TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
                                     // within one beam for beam search, on gpu
};

using Output = decoder::Output;
@@ -127,6 +156,7 @@ public:
        forwardSync(*forwardAsync(output, input));
    }

    //! @param batchIdx index of the batch
    //! @returns [maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
    //! ids without padding for request `batchIdx`, on gpu
    virtual TensorPtr getOutputIds(SizeType batchIdx) const = 0;
@@ -135,12 +165,6 @@ public:
    //! Result will only be available after event returned
    virtual CudaEvent finalize(SizeType batchIdx) const = 0;

    //! @returns [batchSize, beamWidth], marks finished requests (per beam), on gpu
    virtual TensorPtr getFinishedBeams() const = 0;

    //! @returns [batchSize, beamWidth], total sequence lengths (per beam), on gpu
    virtual TensorPtr getOutputLengths() const = 0;

    //! @returns [batchSize (actual)], marks finished requests (per batch)
    virtual std::vector<bool> getFinished() const = 0;
@@ -75,7 +75,7 @@ public:

    //! Setup the decoder before calling `forward()`, also calls reshapeBuffers
    virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
        SizeType maxSequenceLength, nvinfer1::DataType dtype)
        SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
        = 0;

    //! @brief Initialize the decoder with new batch of inputs.
@@ -108,8 +108,14 @@ public:
    //! @returns [batchSize, maxBeamWidth, maxSequenceLength], log probabilities (per beam), on gpu
    virtual TensorPtr getLogProbs() const = 0;

    //! @returns [batchSize, beamWidth], latests generated tokens (per beam), on gpu
    virtual TensorPtr getNewTokens() const = 0;
    //! @brief Get tokens generated in one step of last forward pass
    //! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
    //! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
    virtual TensorPtr getNewTokens(SizeType iter = 0) const = 0;

    //! @brief Get maxTokensPerStep tokens generated in the last forward pass
    //! @returns [maxTokensPerStep, batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
    virtual TensorPtr getAllNewTokens() const = 0;

    //! @returns [1], number of finished sequences, in pinned host memory
    virtual TensorPtr getNbFinished() const = 0;
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b867c2e048671eecc421244d436436782093baf02f0fd5d49232b3d3042e55ea
size 1688216
oid sha256:9e6a5d7dba399049a4da9ca729153e5a6080986782a314b867e7635454eb36de
size 1705954
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:db433a13ec6a017638bbb97b53a98624ad675b395787c99054d48ab370f5e3a0
size 1697778
oid sha256:64fae7bca97be7c3067b4544da0c3d79621ec3632c10e39b7a005d886702e8eb
size 1706098
@@ -1,3 +1,3 @@
81f472ac2b68edd03a0265299744347f libtensorrt_llm_batch_manager_static.a
4e5e3bbdfffa6deb6a50c541a946ac7a libtensorrt_llm_batch_manager_static.pre_cxx11.a
7edd8a21 commit
02375d908e57e2194e3f28a4e83dd963 libtensorrt_llm_batch_manager_static.a
e5c4994ecc347d808f6d38fb686b5cf1 libtensorrt_llm_batch_manager_static.pre_cxx11.a
cd83045d7c127af5f907efd7c710bd6fe1f90ec4 commit
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:681917aea11f45d83ba1429ded44ced97cb8ce5f54eb1c3fb3055bc342f0ffbf
size 1600734
oid sha256:f3cca913fc62df4119e4df10921be97086714740148f54c528da7bb2826f67ba
size 1617426
@@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d59b04e3229358ec2d9476b07f0361aa4a8539e543312c8952b690173040663d
size 1598666
oid sha256:3d633874e8b32a56758bf8bbdc0955ed8c5d43d531ba330a2274bcec13e1c89f
size 1620144
@@ -1,2 +1,2 @@
c9d5678a2ec347188457ad4a3a59d483 libtensorrt_llm_batch_manager_static.a
53261d576d540ab330f2f2e1f8d99677 libtensorrt_llm_batch_manager_static.pre_cxx11.a
f379e62b3f69afa4bd1d8e5551a6ede4 libtensorrt_llm_batch_manager_static.a
92f44b2834d39c2c62a9b0bd0549b159 libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -503,5 +503,71 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
        output_log_probs, output_log_probs_tiled, sequence_lengths, batch_size, beam_width, max_seq_len);
}

__global__ void acceptTokensKernel(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
    const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
    int batch_size, int beam_width, int max_seq_len, int max_draft_tokens)
{
    int thread_finished_count = 0;
    for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batch_size * beam_width;
         index += blockDim.x * gridDim.x)
    {
        const auto num_draft_tokens = nums_draft_tokens[index];

        const auto context_length = context_lengths[index];
        auto& sequence_length = sequence_lengths[index];
        int finished_draft_idx = 0;
        for (int ti = context_length; ti < min(sequence_length, context_length + num_draft_tokens);
             ++ti, ++finished_draft_idx)
        {
            const auto draft_idx = ti - context_length;
            const auto target_token_idx = index * max_seq_len + ti;
            const auto draft_token_idx = index * max_draft_tokens + draft_idx;
            // Check if draft tokens are the same as target tokens
            // FIXME(nkorobov); compare logits here
            const bool accepted = draft_tokens[draft_token_idx] == target_tokens[target_token_idx];
            if (!accepted)
            {
                // Set sequence length to the numAcceptedTokens + 1
                sequence_length = min(ti + 1, max_seq_len);
                // FIXME(nkorobov): do we need to set endIds here?
                break;
            }
        }
        bool finish = finished[finished_draft_idx * batch_size * beam_width + index];
        finished_final[index] = finish;
        thread_finished_count += static_cast<int>(finish);
    }

    if (finished_sum)
    {
        int block_finished_count = 0;
        if (blockDim.x <= 32)
        {
            block_finished_count = warpReduceSum(thread_finished_count);
        }
        else
        {
            block_finished_count = blockReduceSum(thread_finished_count);
        }
        __syncthreads();

        if (threadIdx.x == 0)
        {
            finished_sum[0] = block_finished_count;
        }
    }
}

void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
    const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
    int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream)
{
    dim3 block(min(256, batch_size * beam_width));
    dim3 grid(1);
    acceptTokensKernel<<<grid, block, 0, stream>>>(draft_tokens, target_tokens, context_lengths, nums_draft_tokens,
        sequence_lengths, finished, finished_final, finished_sum, batch_size, beam_width, max_seq_len,
        max_draft_tokens);
}

} // namespace kernels
} // namespace tensorrt_llm
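To make the kernel's control flow easier to follow, here is a hedged, host-side C++ sketch of the same acceptance rule for a single sequence: draft tokens are accepted left to right while they match the tokens produced by the target model, and the sequence length is clamped to the position of the first mismatch plus one. This is an illustration only; the CUDA kernel above is the actual implementation in the commit.

```cpp
#include <algorithm>
#include <vector>

// Reference (non-CUDA) version of the per-sequence logic in acceptTokensKernel.
// Returns the updated sequence length after comparing this sequence's draft
// tokens against the target model's tokens.
int acceptDraftTokens(std::vector<int> const& targetTokens, std::vector<int> const& draftTokens,
    int contextLength, int sequenceLength, int maxSeqLen)
{
    int const numDraftTokens = static_cast<int>(draftTokens.size());
    int const end = std::min(sequenceLength, contextLength + numDraftTokens);
    for (int ti = contextLength; ti < end; ++ti)
    {
        int const draftIdx = ti - contextLength;
        // On the first mismatching draft token, keep the target token at that
        // position (hence ti + 1) and stop accepting further draft tokens.
        if (draftTokens[draftIdx] != targetTokens[ti])
        {
            return std::min(ti + 1, maxSeqLen);
        }
    }
    return sequenceLength; // all draft tokens matched
}
```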
@@ -65,5 +65,9 @@ void invokeCopyNextStepIds(int* next_step_ids, int** output_ids_ptr, const int*
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
    int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);

void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
    const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
    int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
@@ -126,7 +126,7 @@ __global__ void computePaddingOffsets(int* paddingOffsets, const int* seqOffsets
    // Iterate over the tokens to update the number of padded elements.
    for (int tokenIdx = threadIdx.x; tokenIdx < seqLength; tokenIdx += blockDim.x)
    {
        paddingOffsets[seqBegin + tokenIdx] = paddingOffset + max(0, tokenIdx - seqLength);
        paddingOffsets[seqBegin + tokenIdx] = paddingOffset;
    }
}
@@ -152,7 +152,7 @@ __global__ void computeAttentionMask(AttentionMaskDataType* attentionMask, const
    int seqLength = seqEnd - seqBegin;

    // Iterate over the tokens to update the number of padded elements.
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < maskSize; idx += blockDim.x)
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < maskSize; idx += gridDim.x * blockDim.x)
    {
        // The position in the matrix.
        int rowIdx = idx / maxSeqLength;
@@ -235,8 +235,7 @@ void invokeBuildDecoderInfo(const BuildDecoderInfoParams<T>& params, cudaStream_
    // Compute the attention mask, if needed.
    if (params.attentionMask != nullptr)
    {
        // large value like 512 hurts kernel perf at long sequence length. Keep small for now.
        const int MIN_BLOCKS = 16;
        const int MIN_BLOCKS = 512;
        int blocksPerSeq = 16;
        while (blocksPerSeq * params.batchSize < MIN_BLOCKS)
        {
@@ -195,9 +195,9 @@ __global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* to

template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTmpValBuf, int** ids,
    int* sequenceLengths, bool* finished, float* cumLogProbs, float* outputLogProbs, const int maxTopK,
    const int* topKs, const float topP, const float* topPs, curandState_t* curandstate, const int* endIds,
    const int vocabSize, const bool* skipDecode)
    int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
    const int maxTopK, const int* topKs, const float topP, const float* topPs, curandState_t* curandstate,
    const int* endIds, const int vocabSize, const bool* skipDecode)
{
    const bool IS_FP16 = std::is_same<T, half>::value;
    const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@@ -226,8 +226,12 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
    }
    TopK_2<float> partial;

    if (finished != nullptr && finished[batchId] == true)
    if (finishedInput != nullptr && finishedInput[batchId] == true)
    {
        if (finishedOutput != nullptr)
        {
            finishedOutput[batchId] = true;
        }
        return;
    }
@ -300,18 +304,18 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (sequenceLengths != nullptr && finished != nullptr)
|
||||
if (sequenceLengths != nullptr && finishedOutput != nullptr)
|
||||
{
|
||||
const int seqLen = sequenceLengths[batchId];
|
||||
if (ids[batchId][seqLen] == endIds[batchId])
|
||||
{
|
||||
finished[batchId] = true;
|
||||
finishedOutput[batchId] = true;
|
||||
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be
|
||||
// outputted
|
||||
}
|
||||
else
|
||||
{
|
||||
finished[batchId] = false;
|
||||
finishedOutput[batchId] = false;
|
||||
sequenceLengths[batchId] += 1;
|
||||
}
|
||||
}
|
||||
@ -319,19 +323,20 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
}
|
||||
|
||||
#define CASE_K(K_MAX, BLOCK_SIZE_1_, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_) \
|
||||
topKStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_><<<batchSize * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>( \
|
||||
logProbs, tempLogProbs, topKTmpIdBuf, topKTmpValBuf, finished, maxTopK, topKs, vocabSize, endIds, skipDecode); \
|
||||
topKStage1<T, BLOCK_SIZE_1_, BLOCKS_PER_BEAM_> \
|
||||
<<<batchSize * BLOCKS_PER_BEAM_, BLOCK_SIZE_1_, 0, stream>>>(logProbs, tempLogProbs, topKTmpIdBuf, \
|
||||
topKTmpValBuf, finishedInput, maxTopK, topKs, vocabSize, endIds, skipDecode); \
|
||||
topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
|
||||
<<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf, \
|
||||
topKTmpValBuf, ids, sequenceLengths, finished, cumLogProbs, outputLogProbs, maxTopK, topKs, topP, topPs, \
|
||||
curandstate, endIds, vocabSize, skipDecode); \
|
||||
topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \
|
||||
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode); \
|
||||
break;
|
||||
|
||||
template <typename T>
|
||||
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
|
||||
bool* finished, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK,
|
||||
const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode)
|
||||
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -386,35 +391,35 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
|
||||
#undef CASE_K
|
||||
|
||||
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
|
||||
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
|
||||
const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
|
||||
const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
|
||||
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
|
||||
const int maxTopK, const int* topKs, const float topP, const float* topPs, const int vocabSizePadded,
|
||||
const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
template <typename T>
|
||||
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
|
||||
bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK,
|
||||
const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize,
|
||||
const bool* skipDecode)
|
||||
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode)
|
||||
{
|
||||
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finished_buf, cumLogProbs,
|
||||
outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, stream, batchSize,
|
||||
skipDecode);
|
||||
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput,
|
||||
cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, stream,
|
||||
batchSize, skipDecode);
|
||||
}
|
||||
|
||||
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
|
||||
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
|
||||
const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
|
||||
const int batchSize, const bool* skipDecode);
|
||||
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
|
||||
int* sequenceLengths, bool* finished_buf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate,
|
||||
const int topK, const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream,
|
||||
const int batchSize, const bool* skipDecode);
|
||||
int* sequenceLengths, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -37,8 +37,8 @@ namespace kernels
|
||||
//! logProbs must contain **just** probabilities instead of log probabilities.
|
||||
//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
|
||||
//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding endId token
|
||||
//! \param finishedBuf input/output buffer [batchSize]. Flag if sequence has finished (if finished || outputId == endId).
|
||||
//! If true, request exits early.
|
||||
//! \param finishedInput input buffer [batchSize]. If true, request exits early.
|
||||
//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId == endId).
|
||||
//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored if nullptr
|
||||
//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability induced by the top-k sampling.
|
||||
//! We normalize the probability 'expLogit' of the selected token by the probability 's_sum' of a set of top-k
|
||||
@ -60,16 +60,16 @@ namespace kernels
|
||||
// clang-format on
|
||||
template <typename T>
|
||||
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
|
||||
bool* finished, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int maxTopK,
|
||||
const int* topKs, const float topP, const float* topPs, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
//! \brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr
|
||||
template <typename T>
|
||||
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** outputIds, int* sequenceLength,
|
||||
bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, curandState_t* curandstate, const int topK,
|
||||
const float topP, const int vocabSizePadded, const int* endIds, cudaStream_t stream, const int batchSize,
|
||||
const bool* skipDecode);
|
||||
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
cudaStream_t stream, const int batchSize, const bool* skipDecode);
|
||||
|
||||
//! \brief Initialize batchSize curand states with given seed.
|
||||
//!
|
||||
|
||||
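The declarations above split the old read/write finished buffer into a const finishedInput and a separate finishedOutput. A hedged caller-side sketch (hypothetical names; the workspace-sizing call with a null workspace mirrors what TopKSamplingLayer::allocateBuffer does later in this diff):

#include <cuda_runtime.h>
#include <curand_kernel.h>
// plus the TensorRT-LLM header that declares tensorrt_llm::kernels::invokeTopKSampling

void runTopKSampling(const float* logProbs, int** outputIdsPtrs, int* sequenceLengths, const bool* finishedIn,
    bool* finishedOut, float* cumLogProbs, float* outputLogProbs, curandState_t* curandStates, int maxTopK,
    float topP, int vocabSizePadded, const int* endIds, cudaStream_t stream, int batchSize, const bool* skipDecode)
{
    size_t workspaceSize = 0;
    // A null workspace pointer only queries the required workspace size.
    tensorrt_llm::kernels::invokeTopKSampling<float>(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr,
        nullptr, nullptr, nullptr, nullptr, maxTopK, 1.0f, vocabSizePadded, nullptr, stream, batchSize, skipDecode);

    void* workspace = nullptr;
    cudaMallocAsync(&workspace, workspaceSize, stream); // any allocator works; chosen here only for illustration

    // The actual sampling call: finishedIn is read to skip finished requests, finishedOut is written per request.
    tensorrt_llm::kernels::invokeTopKSampling<float>(workspace, workspaceSize, logProbs, outputIdsPtrs,
        sequenceLengths, finishedIn, finishedOut, cumLogProbs, outputLogProbs, curandStates, maxTopK, topP,
        vocabSizePadded, endIds, stream, batchSize, skipDecode);

    cudaFreeAsync(workspace, stream);
}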
@ -159,7 +159,7 @@ struct BlockPrefixCallbackOp
|
||||
|
||||
template <typename T>
|
||||
__device__ void epilogue(int batchId, int currentStep, int offset, int** ids, int* sortedIdVals, T* sortedLogProbs,
|
||||
float* cumLogProbs, float* outputLogProbs, const int* endIds, int* sequenceLengths, bool* finishedBuf)
|
||||
float* cumLogProbs, float* outputLogProbs, const int* endIds, int* sequenceLengths, bool* finishedOutput)
|
||||
{
|
||||
ids[batchId][currentStep] = sortedIdVals[offset];
|
||||
|
||||
@ -175,26 +175,26 @@ __device__ void epilogue(int batchId, int currentStep, int offset, int** ids, in
|
||||
outputLogProbs[batchId] = lprob;
|
||||
}
|
||||
}
|
||||
if (sequenceLengths != nullptr && finishedBuf != nullptr)
|
||||
if (sequenceLengths != nullptr && finishedOutput != nullptr)
|
||||
{
|
||||
if (ids[batchId][currentStep] == endIds[batchId])
|
||||
{
|
||||
finishedBuf[batchId] = true;
|
||||
finishedOutput[batchId] = true;
|
||||
// Do not increase seq len when EOS is generated. Seq len should always contain only tokens to be outputted
|
||||
}
|
||||
else
|
||||
{
|
||||
finishedBuf[batchId] = false;
|
||||
finishedOutput[batchId] = false;
|
||||
sequenceLengths[batchId] += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int blockSize>
|
||||
__global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength, bool* finishedBuf,
|
||||
float* cumLogProbs, float* outputLogProbs, const int* beginOffsetBuf, const int* offsetBuf, const int vocabSize,
|
||||
curandState_t* curandstate, const float topP, const float* topPs, const int* endIds, const int batchSize,
|
||||
const bool* skipDecode)
|
||||
__global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, int* sequenceLength,
|
||||
const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const int* beginOffsetBuf, const int* offsetBuf, const int vocabSize, curandState_t* curandstate, const float topP,
|
||||
const float* topPs, const int* endIds, const int batchSize, const bool* skipDecode)
|
||||
{
|
||||
/**
|
||||
* Each block processes one request row sorted in descending order by probabilities.
|
||||
@ -214,8 +214,12 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
|
||||
}
|
||||
|
||||
// Exit early if sequence has finished
|
||||
if (finishedBuf != nullptr && finishedBuf[batchId] == true)
|
||||
if (finishedInput != nullptr && finishedInput[batchId] == true)
|
||||
{
|
||||
if (finishedOutput != nullptr)
|
||||
{
|
||||
finishedOutput[batchId] = true;
|
||||
}
|
||||
ids[batchId][sequenceLength[batchId]] = endIds[batchId];
|
||||
return;
|
||||
}
|
||||
@ -244,7 +248,7 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
|
||||
{
|
||||
int offset = batchId * vocabSize;
|
||||
epilogue(batchId, currentStep, offset, ids, sortedIdVals, sortedLogProbs, cumLogProbs, outputLogProbs,
|
||||
endIds, sequenceLength, finishedBuf);
|
||||
endIds, sequenceLength, finishedOutput);
|
||||
}
|
||||
return;
|
||||
}
|
||||
@ -285,16 +289,16 @@ __global__ void topPSsampling(T* sortedLogProbs, int* sortedIdVals, int** ids, i
|
||||
if (threadIdx.x == min(blockDim.x - count, blockDim.x - 1))
|
||||
{
|
||||
epilogue(batchId, currentStep, offset + selectedTokenId, ids, sortedIdVals, sortedLogProbs, cumLogProbs,
|
||||
outputLogProbs, endIds, sequenceLength, finishedBuf);
|
||||
outputLogProbs, endIds, sequenceLength, finishedOutput);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs, cudaStream_t stream,
|
||||
const bool* skipDecode)
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
|
||||
cudaStream_t stream, const bool* skipDecode)
|
||||
{
|
||||
// Here, we put batch size as an argument because the batch size of
|
||||
// initialization and inference may be different due to pipeline parallelism.
|
||||
@ -338,42 +342,45 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
|
||||
dim3 grid(batchSize);
|
||||
// Sample with Top P given sorted tokens
|
||||
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedLogProbs, sortedIdVals,
|
||||
outputIds, sequenceLength, finishedBuf, cumLogProbs, outputLogProbs, beginOffsetBuf, offsetBuf + 1, vocabSize,
|
||||
curandstate, maxTopP, topPs, endIds, batchSize, skipDecode);
|
||||
outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf,
|
||||
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode);
|
||||
}
|
||||
|
||||
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
|
||||
int** outputIds, int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs,
|
||||
const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
|
||||
cudaStream_t stream, const bool* skipDecode);
|
||||
int** outputIds, int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf,
|
||||
curandState_t* curandstate, const int batchSize, const size_t vocabSizePadded, const int* endIds,
|
||||
const float maxTopP, const float* topPs, cudaStream_t stream, const bool* skipDecode);
|
||||
|
||||
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
|
||||
int** outputIds, int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs,
|
||||
const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
|
||||
cudaStream_t stream, const bool* skipDecode);
|
||||
int** outputIds, int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf,
|
||||
curandState_t* curandstate, const int batchSize, const size_t vocabSizePadded, const int* endIds,
|
||||
const float maxTopP, const float* topPs, cudaStream_t stream, const bool* skipDecode);
|
||||
|
||||
template <typename T>
|
||||
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode)
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
|
||||
const bool* skipDecode)
|
||||
{
|
||||
invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedBuf,
|
||||
cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate, batchSize,
|
||||
vocabSizePadded, endIds, topP, nullptr, stream, skipDecode);
|
||||
invokeBatchTopPSampling(workspace, workspaceSize, cubTempStorageSize, outputIds, sequenceLength, finishedInput,
|
||||
finishedOutput, cumLogProbs, outputLogProbs, logProbs, idVals, offsetBuf, beginOffsetBuf, curandstate,
|
||||
batchSize, vocabSizePadded, endIds, topP, nullptr, stream, skipDecode);
|
||||
}
|
||||
|
||||
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const float* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode);
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const float* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
|
||||
const bool* skipDecode);
|
||||
|
||||
template void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const half* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream, const bool* skipDecode);
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const half* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topP, cudaStream_t stream,
|
||||
const bool* skipDecode);
|
||||
|
||||
template <typename T>
|
||||
__global__ void addBiasSoftMax(
|
||||
|
||||
@ -64,9 +64,9 @@ the
|
||||
//! \param outputIds output buffer [batchSize][maxSeqLen]. Contains pointers to rows with output tokens per request.
|
||||
//! \param sequenceLength input/output buffer [batchSize]. Current sequence length of the request up to, but excluding
|
||||
endId token.
|
||||
//! \param finishedBuf input/output buffer [batchSize]. Flag if sequence has finished (if finished || outputId ==
|
||||
//! \param finishedInput input buffer [batchSize]. Exit early if true.
|
||||
//! \param finishedOutput output buffer [batchSize]. Set flag if sequence has finished (if finished || outputId ==
|
||||
endId).
|
||||
//! If true, request exits early.
|
||||
//! \param cumLogProbs input/output buffer [batchSize]. Cumulative log probability of selected tokens. Ignored if
|
||||
nullptr.
|
||||
//! \param outputLogProbs output buffer [batchSize]. Log probs is the probability
|
||||
@ -94,17 +94,18 @@ nullptr.
|
||||
*/
|
||||
template <typename T>
|
||||
void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs, cudaStream_t stream,
|
||||
const bool* skipDecode);
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float maxTopP, const float* topPs,
|
||||
cudaStream_t stream, const bool* skipDecode);
|
||||
|
||||
//! \brief Specialization of invokeBatchTopPSampling with topPs=nullptr
|
||||
template <typename T>
|
||||
void invokeTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize, int** outputIds,
|
||||
int* sequenceLength, bool* finishedBuf, float* cumLogProbs, float* outputLogProbs, const T* logProbs,
|
||||
const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate, const int batchSize,
|
||||
const size_t vocabSizePadded, const int* endIds, const float topPp, cudaStream_t stream, const bool* skipDecode);
|
||||
int* sequenceLength, const bool* finishedInput, bool* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
const T* logProbs, const int* idVals, int* offsetBuf, int* beginOffsetBuf, curandState_t* curandstate,
|
||||
const int batchSize, const size_t vocabSizePadded, const int* endIds, const float topPp, cudaStream_t stream,
|
||||
const bool* skipDecode);
|
||||
|
||||
//! \brief Compute the topp decay by https://arxiv.org/pdf/2206.04624.pdf
|
||||
//! In short, the formula is
|
||||
|
||||
@ -69,6 +69,7 @@ struct WeightOnlyParams
|
||||
const ActType* scales;
|
||||
const ActType* zeros;
|
||||
const ActType* in;
|
||||
const ActType* act_scale;
|
||||
const ActType* bias;
|
||||
ActType* out;
|
||||
const int m;
|
||||
@ -81,13 +82,14 @@ struct WeightOnlyParams
|
||||
WeightOnlyActivationType act_type;
|
||||
|
||||
WeightOnlyParams(const uint8_t* _qweight, const ActType* _scales, const ActType* _zeros, const ActType* _in,
|
||||
const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k, const int _group_size,
|
||||
const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
|
||||
const ActType* _act_scale, const ActType* _bias, ActType* _out, const int _m, const int _n, const int _k,
|
||||
const int _group_size, const WeightOnlyQuantType _quant_type, const WeightOnlyType _weight_only_type,
|
||||
const WeightOnlyActivationFunctionType _act_func_type, const WeightOnlyActivationType _act_type)
|
||||
: qweight(_qweight)
|
||||
, scales(_scales)
|
||||
, zeros(_zeros)
|
||||
, in(_in)
|
||||
, act_scale(_act_scale)
|
||||
, bias(_bias)
|
||||
, out(_out)
|
||||
, m(_m)
|
||||
|
||||
@ -292,9 +292,9 @@ public:
|
||||
};
|
||||
|
||||
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
|
||||
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
|
||||
bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
|
||||
__device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
|
||||
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
|
||||
const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
|
||||
{
|
||||
static_assert(NPerBlock == 1 || (NPerBlock % 2 == 0));
|
||||
using ActType2 = typename ActTypeDetails<ActType>::Vec2;
|
||||
@ -376,6 +376,16 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
|
||||
}
|
||||
}
|
||||
}
|
||||
ActType act_scale_v[Details::kElemsPerThread];
|
||||
if constexpr (ActScale)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int idx = 0; idx < Details::kActivationAccessNum; ++idx)
|
||||
{
|
||||
load<AccType>(act_scale_v + idx * Details::kActivationElemNumPerAccess,
|
||||
act_scale + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
|
||||
}
|
||||
}
|
||||
#pragma unroll
|
||||
for (int b = 0; b < Batch; ++b)
|
||||
{
|
||||
@ -386,6 +396,16 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
|
||||
// load activation elements
|
||||
load<AccType>(in_v + idx * Details::kActivationElemNumPerAccess,
|
||||
in + b * k + scale_loader.offset() + idx * Details::kActivationElemNumPerAccess);
|
||||
if constexpr (ActScale)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < Details::kActivationElemNumPerAccess; i += 2)
|
||||
{
|
||||
*reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i) = __hmul2(
|
||||
*reinterpret_cast<ActType2*>(in_v + idx * Details::kActivationElemNumPerAccess + i),
|
||||
*reinterpret_cast<ActType2*>(act_scale_v + idx * Details::kActivationElemNumPerAccess + i));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Perform vector inner product and accumulate
|
||||
if constexpr (NPerBlock == 1)
|
||||
@ -448,20 +468,20 @@ __device__ void weight_only_batched_gemv(const uint8_t* qweight, const ActType*
|
||||
}
|
||||
|
||||
template <typename ActType, WeightOnlyQuantType QType, typename WeightOnlyFlag, template <typename T> class ActOp,
|
||||
bool Zero, bool Bias, int NPerBlock, int Batch, int BlockSize>
|
||||
bool Zero, bool Bias, bool ActScale, int NPerBlock, int Batch, int BlockSize>
|
||||
__global__ void weight_only_batched_gemv_wrapper(const uint8_t* qweight, const ActType* scales, const ActType* zeros,
|
||||
const ActType* in, const ActType* bias, ActType* out, const int n, const int k)
|
||||
const ActType* in, const ActType* act_scale, const ActType* bias, ActType* out, const int n, const int k)
|
||||
{
|
||||
if constexpr (std::is_same_v<ActType, half>)
|
||||
{
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
|
||||
qweight, scales, zeros, in, bias, out, n, k);
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
|
||||
BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
|
||||
}
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
else if (std::is_same_v<ActType, nv_bfloat16>)
|
||||
{
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch, BlockSize>(
|
||||
qweight, scales, zeros, in, bias, out, n, k);
|
||||
weight_only_batched_gemv<ActType, QType, WeightOnlyFlag, ActOp, Zero, Bias, ActScale, NPerBlock, Batch,
|
||||
BlockSize>(qweight, scales, zeros, in, act_scale, bias, out, n, k);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@ -478,10 +498,24 @@ struct WeightOnlyBatchedGemvKernelLauncher
|
||||
dim3 grid(params.n / NPerBlock / kInterleave);
|
||||
dim3 block(BlockSize);
|
||||
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
|
||||
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
|
||||
BlockSize><<<grid, block, size, stream>>>(params.qweight, reinterpret_cast<const half*>(params.scales),
|
||||
reinterpret_cast<const half*>(params.zeros), reinterpret_cast<const half*>(params.in),
|
||||
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n, params.k);
|
||||
if (params.act_scale != nullptr)
|
||||
{
|
||||
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, true, NPerBlock, Batch,
|
||||
BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
|
||||
reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
|
||||
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
|
||||
params.k);
|
||||
}
|
||||
else
|
||||
{
|
||||
weight_only_batched_gemv_wrapper<half, QType, WeightOnlyFlag, ActOp, Zero, Bias, false, NPerBlock,
|
||||
Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const half*>(params.scales), reinterpret_cast<const half*>(params.zeros),
|
||||
reinterpret_cast<const half*>(params.in), reinterpret_cast<const half*>(params.act_scale),
|
||||
reinterpret_cast<const half*>(params.bias), reinterpret_cast<half*>(params.out), params.n,
|
||||
params.k);
|
||||
}
|
||||
}
|
||||
#if defined(ENABLE_BF16)
|
||||
else if (params.act_type == WeightOnlyActivationType::BF16)
|
||||
@ -490,12 +524,28 @@ struct WeightOnlyBatchedGemvKernelLauncher
|
||||
dim3 grid(params.n / NPerBlock / kInterleave);
|
||||
dim3 block(BlockSize);
|
||||
int size = sizeof(float) * BlockSize / 32 * Batch * NPerBlock * kInterleave;
|
||||
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, NPerBlock, Batch,
|
||||
BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.scales),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.zeros), reinterpret_cast<const __nv_bfloat16*>(params.in),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
|
||||
params.n, params.k);
|
||||
if (params.act_scale != nullptr)
|
||||
{
|
||||
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, true,
|
||||
NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.scales),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.zeros),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.in),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
|
||||
params.n, params.k);
|
||||
}
|
||||
else
|
||||
{
|
||||
weight_only_batched_gemv_wrapper<__nv_bfloat16, QType, WeightOnlyFlag, ActOp, Zero, Bias, false,
|
||||
NPerBlock, Batch, BlockSize><<<grid, block, size, stream>>>(params.qweight,
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.scales),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.zeros),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.in),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.act_scale),
|
||||
reinterpret_cast<const __nv_bfloat16*>(params.bias), reinterpret_cast<__nv_bfloat16*>(params.out),
|
||||
params.n, params.k);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
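The launcher above now dispatches on params.act_scale and, when it is non-null, multiplies each loaded activation element by a per-channel scale (the __hmul2 path added in the kernel). As a plain reference for what that pre-scaling computes -- an editorial sketch with an assumed row-major [m, k] layout, not code from this commit:

#include <vector>

std::vector<float> applyActScaleReference(
    const std::vector<float>& act, const std::vector<float>& actScale, int m, int k)
{
    // act is [m, k]; actScale holds one factor per input channel, shared by every row of the batch.
    std::vector<float> scaled(act.size());
    for (int row = 0; row < m; ++row)
    {
        for (int col = 0; col < k; ++col)
        {
            scaled[row * k + col] = act[row * k + col] * actScale[col];
        }
    }
    return scaled;
}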
@ -50,8 +50,9 @@ public:
|
||||
// mandatory parameters
|
||||
int step;
|
||||
int ite;
|
||||
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
|
||||
tc::Tensor end_ids; // [local_batch_size]
|
||||
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
|
||||
tc::Tensor end_ids; // [local_batch_size]
|
||||
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
|
||||
};
|
||||
|
||||
class DecodingOutputParams
|
||||
|
||||
@ -226,7 +226,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
|
||||
|
||||
invokeBanRepeatNgram(logits.template getPtrWithOffset<T>(decode_vocab_size_units_offset),
|
||||
outputs.output_ids_ptr.template getPtr<const int*>(),
|
||||
outputs.finished.value_or(Tensor{}).template getPtr<bool>(),
|
||||
params.finished.value_or(Tensor{}).template getPtr<bool>(),
|
||||
outputs.parent_ids.value_or(Tensor{}).template getPtr<const int>(), batch_size, local_batch_size,
|
||||
beam_width, no_repeat_ngram_size_buf, id_offset, vocab_size_padded_, step, stream_);
|
||||
}
|
||||
@ -341,6 +341,7 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
|
||||
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(max_seq_len)};
|
||||
|
||||
decode_input_tensors.embedding_bias = params.embedding_bias;
|
||||
decode_input_tensors.finished = params.finished;
|
||||
|
||||
if (params.input_lengths)
|
||||
{
|
||||
|
||||
@ -101,6 +101,7 @@ public:
|
||||
tc::Tensor end_ids; // [batch_size], on gpu
|
||||
|
||||
// optional parameters
|
||||
std::optional<tc::Tensor> finished; // [batch_size * beam_width], optional
|
||||
std::optional<tc::Tensor> src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - the k/v cache
|
||||
// index for beam search, mandatory for beam search, on gpu
|
||||
std::optional<tc::Tensor> sequence_limit_length; // [batch_size], on gpu
|
||||
|
||||
@ -93,7 +93,7 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batch_size, std::vector<u
|
||||
max_top_k = 1;
|
||||
}
|
||||
invokeTopKSampling<T>(nullptr, sampling_workspace_size_, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, max_top_k, 1.0f, vocab_size_padded_, nullptr, stream_, batch_size, skip_decode_buf_);
|
||||
nullptr, nullptr, max_top_k, 1.0f, vocab_size_padded_, nullptr, stream_, batch_size, skip_decode_buf_);
|
||||
sampling_workspace_ = allocator_->reMalloc(sampling_workspace_, sampling_workspace_size_, false);
|
||||
runtime_top_k_buf_ = allocator_->reMalloc(runtime_top_k_buf_, sizeof(uint32_t) * batch_size, false);
|
||||
runtime_top_p_buf_ = allocator_->reMalloc(runtime_top_p_buf_, sizeof(float) * batch_size, false);
|
||||
@ -171,9 +171,10 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
auto* logits = !skip_any_ ? params.logits.template getPtr<T>() : runtime_logits_buf_;
|
||||
auto* end_ids = params.end_ids.template getPtr<const int>();
|
||||
|
||||
bool* finished = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
|
||||
bool* finished_input = (params.finished) ? params.finished->template getPtr<bool>() : nullptr;
|
||||
bool* finished_output = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
|
||||
invokeAddBiasEndMask(
|
||||
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
|
||||
logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
|
||||
sync_check_cuda_error();
|
||||
|
||||
float* cum_log_probs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
|
||||
@ -181,16 +182,16 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
|
||||
if (cum_log_probs != nullptr || output_log_probs != nullptr)
|
||||
{
|
||||
invokeAddBiasSoftMax(
|
||||
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
|
||||
invokeAddBiasSoftMax(logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_,
|
||||
vocab_size_padded_, stream_);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
|
||||
|
||||
invokeBatchTopKSampling(sampling_workspace_, sampling_workspace_size_, logits,
|
||||
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished, cum_log_probs, output_log_probs,
|
||||
curandstate_buf_ + ite * local_batch_size,
|
||||
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished_input, finished_output, cum_log_probs,
|
||||
output_log_probs, curandstate_buf_ + ite * local_batch_size,
|
||||
(int) runtime_max_top_k_, // useless because runtime_top_k_buf_ is never
|
||||
// nullptr. Keep for legacy.
|
||||
(int*) (runtime_top_k_buf_ + ite * local_batch_size),
|
||||
|
||||
@ -110,7 +110,8 @@ void TopPSamplingLayer<T>::allocateBuffer(std::size_t batch_size, std::vector<fl
|
||||
sampling_workspace_size_, cub_temp_storage_size_,
|
||||
nullptr, // output_ids
|
||||
nullptr, // sequence_length
|
||||
nullptr, // finished_buffer
|
||||
nullptr, // finished_input_buffer
|
||||
nullptr, // finished_output_buffer
|
||||
nullptr, // cum_log_probs
|
||||
nullptr, // output_log_probs
|
||||
nullptr, // log_probs
|
||||
@ -241,9 +242,10 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, local_batch_size, vocab_size_padded_, stream_);
|
||||
sync_check_cuda_error();
|
||||
|
||||
bool* finished = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
|
||||
bool* finished_input = (params.finished) ? params.finished->template getPtr<bool>() : nullptr;
|
||||
bool* finished_output = (outputs.finished) ? outputs.finished->template getPtr<bool>() : nullptr;
|
||||
invokeAddBiasSoftMax(
|
||||
logits, (T*) (nullptr), end_ids, finished, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
|
||||
logits, (T*) (nullptr), end_ids, finished_input, local_batch_size, vocab_size_, vocab_size_padded_, stream_);
|
||||
sync_check_cuda_error();
|
||||
|
||||
float* cum_log_probs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
|
||||
@ -251,10 +253,10 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
int* sequence_length = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
|
||||
|
||||
invokeBatchTopPSampling<T>(sampling_workspace_, sampling_workspace_size_, cub_temp_storage_size_,
|
||||
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished, cum_log_probs, output_log_probs,
|
||||
logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, curandstate_buf_ + ite * local_batch_size,
|
||||
local_batch_size, vocab_size_padded_, end_ids, runtime_max_top_p_, runtime_top_p_buf_ + ite * local_batch_size,
|
||||
stream_, skip_decode_buf_ + ite * local_batch_size);
|
||||
outputs.output_ids_ptr.template getPtr<int*>(), sequence_length, finished_input, finished_output, cum_log_probs,
|
||||
output_log_probs, logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_,
|
||||
curandstate_buf_ + ite * local_batch_size, local_batch_size, vocab_size_padded_, end_ids, runtime_max_top_p_,
|
||||
runtime_top_p_buf_ + ite * local_batch_size, stream_, skip_decode_buf_ + ite * local_batch_size);
|
||||
sync_check_cuda_error();
|
||||
|
||||
invokeComputeToppDecay(runtime_top_p_buf_ + ite * local_batch_size, initial_top_p_buf_ + ite * local_batch_size,
|
||||
|
||||
@ -47,7 +47,8 @@ BertAttentionPlugin::BertAttentionPlugin(int num_heads, int head_size, float q_s
|
||||
, mRemovePadding(remove_padding)
|
||||
{
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM);
|
||||
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF) && MHARunner::fmha_supported(mHeadSize, mSM)
|
||||
&& !mRelativeAttention;
|
||||
}
|
||||
|
||||
// Parameterized constructor
|
||||
@ -266,7 +267,9 @@ int BertAttentionPlugin::enqueueImpl(const nvinfer1::PluginTensorDesc* inputDesc
|
||||
|
||||
T* linear_bias_slopes = nullptr;
|
||||
|
||||
if (mEnableContextFMHA && !mRelativeAttention)
|
||||
// FMHA doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
|
||||
// We update mEnableContextFMHA in constructor to check this condition
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
// b, max_seqlen, actual_total_seqlen
|
||||
mFMHARunner->setup(request_batch_size, request_seq_len, request_seq_len, request_batch_size * request_seq_len);
|
||||
|
||||
@ -271,7 +271,8 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int num_heads, int num_kv_hea
|
||||
{
|
||||
// pre-check whether FMHA is supported in order to save memory allocation
|
||||
mEnableContextFMHA = mEnableContextFMHA && (mType == DataType::kHALF || mType == DataType::kBF16)
|
||||
&& MHARunner::fmha_supported(getHeadSize(), mSM);
|
||||
&& MHARunner::fmha_supported(getHeadSize(), mSM) && !mCrossAttention
|
||||
&& mPositionEmbeddingType != tensorrt_llm::kernels::PositionEmbeddingType::kRELATIVE;
|
||||
|
||||
TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
@ -583,7 +584,8 @@ int GPTAttentionPluginCommon::enqueueContext(const EnqueueContextParams<T, KVCac
|
||||
// in context phase, currently FMHA runner has two restrictions:
|
||||
// 1. only apply to self attention. If want fused multi-head cross attention, FMHCA kernels and runner is needed
|
||||
// 2. doesn't apply to MHA with relative attention bias, i.e. softmax(QK + bias) * V
|
||||
if (mEnableContextFMHA && !isCrossAttention() && !isRelativePosition())
|
||||
// We update mEnableContextFMHA in constructor to check these conditions
|
||||
if (mEnableContextFMHA)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), kv_cache_buffer,
|
||||
const_cast<T*>(params.qkv_bias), params.context_lengths, mRemovePadding ? padding_offset : nullptr,
|
||||
|
||||
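The two constructor hunks above fold the FMHA eligibility checks into mEnableContextFMHA once, so the enqueue paths can simply test the flag. Summarized as a predicate (editorial helper, name hypothetical; it mirrors the GPTAttentionPluginCommon condition):

bool contextFmhaEligible(
    nvinfer1::DataType type, int headSize, int sm, bool crossAttention, bool relativePositionEmbedding)
{
    return (type == nvinfer1::DataType::kHALF || type == nvinfer1::DataType::kBF16) // fused kernels are fp16/bf16 only
        && MHARunner::fmha_supported(headSize, sm) // a fused kernel exists for this head size and SM
        && !crossAttention                         // self-attention only; fused cross attention needs FMHCA kernels
        && !relativePositionEmbedding;             // softmax(QK + bias) * V is not fused
}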
@ -308,11 +308,17 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
|
||||
}
|
||||
const int n = inputDesc[mWeightInputIdx].dims.d[1];
|
||||
const int k = inputDesc[0].dims.d[inputDesc[0].dims.nbDims - 1];
|
||||
bool use_cuda_kernel = m < SMALL_M_FAST_PATH && mCudaKernelEnabled;
|
||||
bool use_pre_quant_scale = mQuantAlgo & PRE_QUANT_SCALE;
|
||||
|
||||
// mQuantAlgo = pre_quant_scale * 4 + zero * 2 + bias
|
||||
if (mQuantAlgo & PRE_QUANT_SCALE)
|
||||
const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast<const half*>(inputs[mZerosInputIdx]) : nullptr;
|
||||
const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
|
||||
const half* act_ptr = reinterpret_cast<const half*>(inputs[0]);
|
||||
|
||||
if (use_pre_quant_scale && !use_cuda_kernel)
|
||||
{
|
||||
// Apply pre-quant per channel scale on activations
|
||||
act_ptr = reinterpret_cast<const half*>(workspace);
|
||||
if (mType == nvinfer1::DataType::kHALF)
|
||||
{
|
||||
tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half>(reinterpret_cast<half*>(workspace),
|
||||
@ -329,10 +335,6 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
|
||||
#endif
|
||||
}
|
||||
|
||||
const half* zeros_ptr = (mQuantAlgo & ZERO) ? reinterpret_cast<const half*>(inputs[mZerosInputIdx]) : nullptr;
|
||||
const half* biases_ptr = (mQuantAlgo & BIAS) ? reinterpret_cast<const half*>(inputs[mBiasesInputIdx]) : nullptr;
|
||||
const half* act_ptr = reinterpret_cast<const half*>((mQuantAlgo & PRE_QUANT_SCALE) ? workspace : inputs[0]);
|
||||
|
||||
#if defined(ENABLE_BF16)
|
||||
TLLM_CHECK_WITH_INFO(mType == nvinfer1::DataType::kHALF || mType == nvinfer1::DataType::kBF16,
|
||||
"No valid weightOnlyGropwiseQuantMatmul configuration");
|
||||
@ -350,14 +352,18 @@ int WeightOnlyGroupwiseQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDe
|
||||
{
|
||||
weight_only_act_type = tensorrt_llm::kernels::WeightOnlyActivationType::BF16;
|
||||
}
|
||||
if (m < SMALL_M_FAST_PATH && mCudaKernelEnabled)
|
||||
if (use_cuda_kernel)
|
||||
{
|
||||
// Use CUDA kernels for small batch size
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass kernel
|
||||
// when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
const void* pre_quant_scale = nullptr;
|
||||
if (use_pre_quant_scale)
|
||||
pre_quant_scale = inputs[mPreQuantScaleInputIdx];
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[mWeightInputIdx]),
|
||||
inputs[mScalesInputIdx], zeros_ptr, act_ptr, biases_ptr, outputs[0], m, real_n, k, mGroupSize,
|
||||
tensorrt_llm::kernels::WeightOnlyQuantType::Int4b, tensorrt_llm::kernels::WeightOnlyType::GroupWise,
|
||||
inputs[mScalesInputIdx], zeros_ptr, act_ptr, pre_quant_scale, biases_ptr, outputs[0], m, real_n, k,
|
||||
mGroupSize, tensorrt_llm::kernels::WeightOnlyQuantType::Int4b,
|
||||
tensorrt_llm::kernels::WeightOnlyType::GroupWise,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
||||
}
|
||||
|
||||
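The comment "mQuantAlgo = pre_quant_scale * 4 + zero * 2 + bias" treats the algo field as a bit mask. The flag names PRE_QUANT_SCALE, ZERO and BIAS appear in the code above; the numeric values below are only an inference from that comment, shown for clarity:

enum WeightOnlyGroupwiseQuantAlgoBits : int
{
    BIAS = 1,            // a bias tensor is supplied
    ZERO = 2,            // zero points are supplied for the quantization groups
    PRE_QUANT_SCALE = 4, // a per-channel activation scale is applied before the quantized GEMM
};
// Example: PRE_QUANT_SCALE | BIAS == 5 selects pre-scaling and bias but no zero points.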
@ -322,7 +322,7 @@ int WeightOnlyQuantMatmulPlugin::enqueue(const nvinfer1::PluginTensorDesc* input
|
||||
// The CUDA kernel is designed for ColumnMajorTileInterleave weight layout used in fpAIntB cutlass
|
||||
// kernel when sm >= 75 and the preprocessing of cutlass on sm70 does not interleave the weights.
|
||||
tensorrt_llm::kernels::WeightOnlyParams params{reinterpret_cast<const uint8_t*>(inputs[1]), inputs[2], nullptr,
|
||||
inputs[0], nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type,
|
||||
inputs[0], nullptr, nullptr, outputs[0], m, real_n, k, 0, weight_only_quant_type,
|
||||
tensorrt_llm::kernels::WeightOnlyType::PerChannel,
|
||||
tensorrt_llm::kernels::WeightOnlyActivationFunctionType::Identity, weight_only_act_type};
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, stream);
|
||||
|
||||
@ -24,7 +24,14 @@ endif()

find_package(pybind11 REQUIRED)

set(SRCS bindings.cpp runtime/generationInput.cpp runtime/generationOutput.cpp)
set(SRCS
    bindings.cpp
    batch_manager/gptManager.cpp
    batch_manager/llmRequest.cpp
    batch_manager/inferenceRequest.cpp
    batch_manager/namedTensor.cpp
    runtime/generationInput.cpp
    runtime/generationOutput.cpp)

pybind11_add_module(${TRTLLM_PYBIND_MODULE} ${SRCS})

cpp/tensorrt_llm/pybind/batch_manager/gptManager.cpp (new file, 88 lines)
@ -0,0 +1,88 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "gptManager.h"
#include "inferenceRequest.h"
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include "tensorrt_llm/common/assert.h"

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <functional>
#include <memory>
#include <optional>

namespace tb = tensorrt_llm::batch_manager;

namespace tensorrt_llm::pybind::batch_manager
{

GptManager::GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
    tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
    SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb,
    tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb, const tb::TrtGptModelOptionalParams& optionalParams,
    std::optional<uint64_t> terminateReqId)
    : tb::GptManager(trtEnginePath, modelType, maxBeamWidth, schedulerPolicy, callbackAdapter(getInferenceRequestsCb),
        callbackAdapter(sendResponseCb), pollStopSignalCb, returnBatchManagerStatsCb, optionalParams, terminateReqId)
{
}

py::object GptManager::enter()
{
    return py::cast(this);
}

void GptManager::exit(py::handle type, py::handle value, py::handle traceback)
{
    // NOTE: we must release the GIL here. GptManager has spawned a thread for the execution loop. That thread must be
    // able to do forward progress for the shutdown process to succeed. For that, we must manually release the GIL while
    // waiting in `process.join()`.
    py::gil_scoped_release release;
    shutdown();
}

tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback)
{
    return [callback](int32_t max_sequences)
    {
        std::list<InferenceRequest> pythonResults = callback(max_sequences);
        std::list<std::shared_ptr<tb::InferenceRequest>> cppResults{};

        for (const auto& ir : pythonResults)
        {
            cppResults.push_back(ir.toTrtLlm());
        }

        return cppResults;
    };
}

tb::SendResponseCallback callbackAdapter(SendResponseCallback callback)
{
    return [callback](uint64_t id, std::list<tb::NamedTensor> const& cppTensors, bool isOk, const std::string& errMsg)
    {
        std::list<NamedTensor> pythonList{};
        for (const auto& cppNamedTensor : cppTensors)
        {
            pythonList.push_back(NamedTensor{cppNamedTensor});
        }
        callback(id, pythonList, isOk, errMsg);
    };
}
} // namespace tensorrt_llm::pybind::batch_manager
cpp/tensorrt_llm/pybind/batch_manager/gptManager.h (new file, 54 lines)
@ -0,0 +1,54 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include "inferenceRequest.h"
#include "namedTensor.h"
#include "tensorrt_llm/batch_manager/GptManager.h"
#include "tensorrt_llm/batch_manager/callbacks.h"
#include <pybind11/functional.h>

#include <ATen/ops/tensor.h>
#include <functional>

namespace py = pybind11;
namespace tb = tensorrt_llm::batch_manager;

namespace tensorrt_llm::pybind::batch_manager
{

using GetInferenceRequestsCallback = std::function<std::list<InferenceRequest>(int32_t)>;
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;

tb::GetInferenceRequestsCallback callbackAdapter(GetInferenceRequestsCallback callback);
tb::SendResponseCallback callbackAdapter(SendResponseCallback callback);

class GptManager : tb::GptManager
{
public:
    GptManager(std::filesystem::path const& trtEnginePath, tb::TrtGptModelType modelType, int32_t maxBeamWidth,
        tb::batch_scheduler::SchedulerPolicy schedulerPolicy, GetInferenceRequestsCallback getInferenceRequestsCb,
        SendResponseCallback sendResponseCb, tb::PollStopSignalCallback pollStopSignalCb = nullptr,
        tb::ReturnBatchManagerStatsCallback returnBatchManagerStatsCb = nullptr,
        const tb::TrtGptModelOptionalParams& optionalParams = tb::TrtGptModelOptionalParams(),
        std::optional<uint64_t> terminateReqId = std::nullopt);

    py::object enter();
    void exit(py::handle type, py::handle value, py::handle traceback);
};

} // namespace tensorrt_llm::pybind::batch_manager
cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.cpp (new file, 37 lines)
@ -0,0 +1,37 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "inferenceRequest.h"

#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/runtime/torchView.h"
#include <memory>

namespace tb = tensorrt_llm::batch_manager;
namespace tr = tensorrt_llm::runtime;

using namespace tensorrt_llm::pybind::batch_manager;

std::shared_ptr<tb::InferenceRequest> InferenceRequest::toTrtLlm() const
{
    tb::InferenceRequest::TensorMap trtTensors;
    for (const auto& torchTensorItem : mInputTensors)
    {
        trtTensors.insert({torchTensorItem.first, tr::TorchView::of(torchTensorItem.second)});
    }

    return std::make_shared<tb::InferenceRequest>(trtTensors, mRequestId);
}
cpp/tensorrt_llm/pybind/batch_manager/inferenceRequest.h (new file, 59 lines)
@ -0,0 +1,59 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/batch_manager/inferenceRequest.h"
#include "tensorrt_llm/common/assert.h"

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>

namespace tensorrt_llm::pybind::batch_manager
{

class InferenceRequest : public tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor,
    std::unordered_map<std::string, at::Tensor>>
{
public:
    using Base
        = tensorrt_llm::batch_manager::GenericInferenceRequest<at::Tensor, std::unordered_map<std::string, at::Tensor>>;
    using TensorPtr = Base::TensorPtr;
    using TensorMap = Base::TensorMap;

    InferenceRequest(uint64_t requestId)
        : Base(requestId)
    {
    }

    InferenceRequest(TensorMap const& inputTensors, uint64_t requestId)
        : Base(inputTensors, requestId)
    {
    }

    InferenceRequest(TensorMap&& inputTensors, uint64_t requestId)
        : Base(inputTensors, requestId)
    {
    }

    [[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::InferenceRequest> toTrtLlm() const;
};

} // namespace tensorrt_llm::pybind::batch_manager
54 cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp Normal file
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "llmRequest.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
|
||||
using namespace tensorrt_llm::pybind::batch_manager;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::optional<tb::LlmRequest::TensorPtr> from_torch(std::optional<LlmRequest::TensorPtr> torchPtr)
|
||||
{
|
||||
if (torchPtr)
|
||||
{
|
||||
return tr::TorchView::of(torchPtr.value());
|
||||
}
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
std::shared_ptr<tb::LlmRequest> LlmRequest::toTrtLlm() const
|
||||
{
|
||||
auto embeddingBias = from_torch(mEmbeddingBias);
|
||||
auto badWordsList = from_torch(mBadWordsList);
|
||||
auto stopWordsList = from_torch(mStopWordsList);
|
||||
auto promptEmbeddingTable = from_torch(mPromptEmbeddingTable);
|
||||
|
||||
return std::make_shared<tb::LlmRequest>(mRequestId, mMaxNewTokens,
|
||||
std::make_shared<std::vector<TokenIdType>>(mTokens.at(0)), mSamplingConfig, mIsStreaming, mEndId, mPadId,
|
||||
embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable, mPromptVocabSize, mReturnLogProbs,
|
||||
mDraftTokens);
|
||||
}
|
||||
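A brief aside on what from_torch delegates to (a hedged sketch, not part of the change): TorchView::of wraps the torch storage rather than copying it, so the resulting runtime tensor is only valid while the Python-side at::Tensor stays alive.

#include "tensorrt_llm/runtime/torchView.h"

#include <ATen/ATen.h>

// Minimal illustration: turn an at::Tensor coming from Python into an
// ITensor usable by the batch manager. No data is copied; the view shares
// the tensor's memory and device placement.
auto toRuntimeTensor(at::Tensor const& torchTensor)
{
    return tensorrt_llm::runtime::TorchView::of(torchTensor);
}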
62 cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h Normal file
@ -0,0 +1,62 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/

#pragma once

#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/assert.h"

#include <ATen/ATen.h>

#include <ATen/ops/tensor.h>
#include <memory>
#include <optional>

namespace tensorrt_llm::pybind::batch_manager
{

class LlmRequest : public tensorrt_llm::batch_manager::GenericLlmRequest<at::Tensor>
{
public:
    using Base = GenericLlmRequest<at::Tensor>;
    using TensorPtr = Base::TensorPtr;
    using SizeType = Base::SizeType;
    using TokenIdType = Base::TokenIdType;
    using RequestIdType = Base::RequestIdType;
    using VecLogProbs = Base::VecLogProbs;
    using BeamTokens = Base::BeamTokens;
    using VecTokens = Base::VecTokens;

    LlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::vector<TokenIdType> inputTokens,
        runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
        std::optional<SizeType> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
        std::optional<TensorPtr> badWordsList = std::nullopt, std::optional<TensorPtr> stopWordsList = std::nullopt,
        std::optional<TensorPtr> promptEmbeddingTable = std::nullopt,
        std::optional<SizeType> promptVocabSize = std::nullopt, bool returnLogProbs = false,
        std::optional<VecTokens> draftTokens = std::nullopt)
        : Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
            samplingConfig, isStreaming, endId, padId, embeddingBias, badWordsList, stopWordsList, promptEmbeddingTable,
            promptVocabSize, returnLogProbs,
            draftTokens.has_value() ? std::make_shared<VecTokens>(std::move(draftTokens.value()))
                                    : std::make_shared<VecTokens>())
    {
    }

    [[nodiscard]] std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> toTrtLlm() const;
};

} // namespace tensorrt_llm::pybind::batch_manager
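As a hedged illustration of how the binding-side class above is meant to be populated before handing work to the batch manager: the helper name, the SamplingConfig include path, and all literal values below are assumptions made for the example only.

#include "llmRequest.h"

#include "tensorrt_llm/runtime/samplingConfig.h"

namespace tpb = tensorrt_llm::pybind::batch_manager;

// Build a greedy (beam width 1) request with a couple of draft tokens for
// speculative decoding, then convert it to the batch_manager representation.
std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest> makeDraftedRequest()
{
    tensorrt_llm::runtime::SamplingConfig samplingConfig{/*beamWidth=*/1};
    tpb::LlmRequest request{/*requestId=*/42, /*maxNewTokens=*/32,
        /*inputTokens=*/{1, 2, 3, 4}, samplingConfig, /*isStreaming=*/false};
    request.setDraftTokens(std::make_shared<tpb::LlmRequest::VecTokens>(tpb::LlmRequest::VecTokens{5, 6}));
    return request.toTrtLlm();
}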
47 cpp/tensorrt_llm/pybind/batch_manager/namedTensor.cpp Normal file
@ -0,0 +1,47 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "namedTensor.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchUtils.h"
|
||||
#include "tensorrt_llm/runtime/torchView.h"
|
||||
#include <memory>
|
||||
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
|
||||
namespace tensorrt_llm::pybind::batch_manager
|
||||
{
|
||||
|
||||
NamedTensor::NamedTensor(const tb::NamedTensor& cppNamedTensor)
|
||||
: Base(cppNamedTensor.name)
|
||||
{
|
||||
auto cppTensor = cppNamedTensor.tensor;
|
||||
std::vector<at::IntArrayRef::value_type> shapeValues;
|
||||
for (int i = 0; i < cppTensor->getShape().nbDims; ++i)
|
||||
{
|
||||
shapeValues.push_back(cppTensor->getShape().d[i]);
|
||||
}
|
||||
|
||||
tensor = at::from_blob(cppTensor->data(), shapeValues,
|
||||
at::TensorOptions()
|
||||
.device(runtime::TorchUtils::deviceType(cppTensor->getMemoryType()))
|
||||
.pinned_memory(cppTensor->getMemoryType() == runtime::MemoryType::kPINNED)
|
||||
.dtype(runtime::TorchUtils::dataType(cppTensor->getDataType())));
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::pybind::batch_manager
|
||||
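One ownership caveat worth spelling out (hedged commentary, not part of the diff): at::from_blob does not take ownership of the underlying buffer, so the at::Tensor built above aliases the batch manager's memory and must not outlive the runtime tensor it wraps. A minimal sketch of the wrapping direction, with std::list chosen only for illustration:

#include "namedTensor.h"

#include "tensorrt_llm/batch_manager/NamedTensor.h"

#include <list>

namespace tb = tensorrt_llm::batch_manager;
namespace tpb = tensorrt_llm::pybind::batch_manager;

// Convert a list of batch_manager response tensors into their pybind
// counterparts; each converted entry keeps pointing at the original storage.
std::list<tpb::NamedTensor> toPythonTensors(std::list<tb::NamedTensor> const& responseTensors)
{
    std::list<tpb::NamedTensor> converted;
    for (auto const& namedTensor : responseTensors)
    {
        converted.emplace_back(namedTensor);
    }
    return converted;
}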
50 cpp/tensorrt_llm/pybind/batch_manager/namedTensor.h Normal file
@ -0,0 +1,50 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/

#pragma once

#include "tensorrt_llm/batch_manager/NamedTensor.h"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/runtime/iBuffer.h"

#include <ATen/ATen.h>

#include <ATen/core/ATen_fwd.h>
#include <ATen/ops/from_blob.h>
#include <ATen/ops/tensor.h>
#include <c10/core/DeviceType.h>
#include <c10/util/ArrayRef.h>
#include <memory>
#include <optional>

namespace tb = tensorrt_llm::batch_manager;

namespace tensorrt_llm::pybind::batch_manager
{

struct NamedTensor : public tb::GenericNamedTensor<std::optional<at::Tensor>>
{
    using Base = tb::GenericNamedTensor<std::optional<at::Tensor>>;
    using TensorPtr = Base::TensorPtr;

    NamedTensor(TensorPtr _tensor, std::string _name)
        : Base(_tensor, _name){};

    NamedTensor(const tb::NamedTensor& cppNamedTensor);
};

} // namespace tensorrt_llm::pybind::batch_manager
@ -15,14 +15,28 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <memory>
|
||||
#include <pybind11/functional.h>
|
||||
#include <pybind11/operators.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <torch/extension.h>
|
||||
#include <vector>
|
||||
|
||||
#include "batch_manager/gptManager.h"
|
||||
#include "batch_manager/inferenceRequest.h"
|
||||
#include "batch_manager/llmRequest.h"
|
||||
#include "batch_manager/namedTensor.h"
|
||||
#include "runtime/generationInput.h"
|
||||
#include "runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/batch_manager/BatchManager.h"
|
||||
#include "tensorrt_llm/batch_manager/batchScheduler.h"
|
||||
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"
|
||||
#include "utils/pathCaster.h"
|
||||
|
||||
#include "tensorrt_llm/batch_manager/GptManager.h"
|
||||
#include "tensorrt_llm/batch_manager/inferenceRequest.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
@ -31,9 +45,13 @@
|
||||
|
||||
namespace py = pybind11;
|
||||
namespace tb = tensorrt_llm::batch_manager;
|
||||
namespace tbb = tensorrt_llm::batch_manager::batch_scheduler;
|
||||
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;
|
||||
namespace tpb = tensorrt_llm::pybind::batch_manager;
|
||||
namespace tc = tensorrt_llm::common;
|
||||
namespace tr = tensorrt_llm::runtime;
|
||||
namespace tpr = tensorrt_llm::pybind::runtime;
|
||||
using SizeType = tr::SizeType;
|
||||
|
||||
#if not defined(TRTLLM_PYBIND_MODULE)
|
||||
#error "TRTLLM_PYBIND_MODULE must be defined"
|
||||
@ -53,8 +71,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("prompt_tuning_enabled", &tpr::PromptTuningParams::promptTuningEnabled);
|
||||
|
||||
py::class_<tpr::GenerationInput>(m, "GenerationInput")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr,
|
||||
bool>(),
|
||||
.def(py::init<SizeType, SizeType, tpr::GenerationInput::TensorPtr, tpr::GenerationInput::TensorPtr, bool>(),
|
||||
py::arg("end_id"), py::arg("pad_id"), py::arg("ids"), py::arg("lengths"), py::arg("packed") = false)
|
||||
.def_readwrite("end_id", &tpr::GenerationInput::endId)
|
||||
.def_readwrite("pad_id", &tpr::GenerationInput::padId)
|
||||
@ -76,16 +93,16 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("context_logits", &tpr::GenerationOutput::contextLogits)
|
||||
.def_readwrite("on_token_generated", &tpr::GenerationOutput::onTokenGenerated);
|
||||
|
||||
py::class_<tb::kv_cache_manager::KvCacheConfig>(m, "KvCacheConfig")
|
||||
.def(py::init<std::optional<tr::SizeType>, std::optional<tr::SizeType>, std::optional<float>>(),
|
||||
py::class_<tbk::KvCacheConfig>(m, "KvCacheConfig")
|
||||
.def(py::init<std::optional<SizeType>, std::optional<SizeType>, std::optional<float>>(),
|
||||
py::arg("max_tokens") = py::none(), py::arg("max_kv_cache_length") = py::none(),
|
||||
py::arg("free_gpu_memory_fraction") = py::none())
|
||||
.def_readwrite("max_tokens", &tb::kv_cache_manager::KvCacheConfig::maxTokens)
|
||||
.def_readwrite("max_kv_cache_length", &tb::kv_cache_manager::KvCacheConfig::maxKvCacheLength)
|
||||
.def_readwrite("free_gpu_memory_fraction", &tb::kv_cache_manager::KvCacheConfig::freeGpuMemoryFraction);
|
||||
.def_readwrite("max_tokens", &tbk::KvCacheConfig::maxTokens)
|
||||
.def_readwrite("max_kv_cache_length", &tbk::KvCacheConfig::maxKvCacheLength)
|
||||
.def_readwrite("free_gpu_memory_fraction", &tbk::KvCacheConfig::freeGpuMemoryFraction);
|
||||
|
||||
py::class_<tr::GptSession::Config>(m, "GptSessionConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
|
||||
.def(py::init<SizeType, SizeType, SizeType>(), py::arg("max_batch_size"), py::arg("max_beam_width"),
|
||||
py::arg("max_sequence_length"))
|
||||
.def_readwrite("max_batch_size", &tr::GptSession::Config::maxBatchSize)
|
||||
.def_readwrite("max_beam_width", &tr::GptSession::Config::maxBeamWidth)
|
||||
@ -148,9 +165,8 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def(py::self != py::self);
|
||||
|
||||
py::class_<tr::GptModelConfig>(m, "GptModelConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType, nvinfer1::DataType>(),
|
||||
py::arg("vocab_size"), py::arg("num_layers"), py::arg("num_heads"), py::arg("hidden_size"),
|
||||
py::arg("data_type"))
|
||||
.def(py::init<SizeType, SizeType, SizeType, SizeType, nvinfer1::DataType>(), py::arg("vocab_size"),
|
||||
py::arg("num_layers"), py::arg("num_heads"), py::arg("hidden_size"), py::arg("data_type"))
|
||||
.def_property_readonly("vocab_size", &tr::GptModelConfig::getVocabSize)
|
||||
.def("vocab_size_padded", &tr::GptModelConfig::getVocabSizePadded, py::arg("world_size"))
|
||||
.def("num_layers", &tr::GptModelConfig::getNbLayers, py::arg("pipeline_parallelism") = 1)
|
||||
@ -185,7 +201,7 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
py::overload_cast<bool>(&tr::GptModelConfig::useCustomAllReduce));
|
||||
|
||||
py::class_<tr::WorldConfig>(m, "WorldConfig")
|
||||
.def(py::init<tr::SizeType, tr::SizeType, tr::SizeType, tr::SizeType>(), py::arg("tensor_parallelism") = 1,
|
||||
.def(py::init<SizeType, SizeType, SizeType, SizeType>(), py::arg("tensor_parallelism") = 1,
|
||||
py::arg("pipeline_parallelism") = 1, py::arg("rank") = 0,
|
||||
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode)
|
||||
.def_property_readonly("size", &tr::WorldConfig::getSize)
|
||||
@ -199,13 +215,12 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_property_readonly("pipeline_parallel_rank", &tr::WorldConfig::getPipelineParallelRank)
|
||||
.def_property_readonly("tensor_parallel_rank", &tr::WorldConfig::getTensorParallelRank)
|
||||
.def_static("mpi",
|
||||
py::overload_cast<tr::SizeType, std::optional<tr::SizeType>, std::optional<tr::SizeType>>(
|
||||
&tr::WorldConfig::mpi),
|
||||
py::overload_cast<SizeType, std::optional<SizeType>, std::optional<SizeType>>(&tr::WorldConfig::mpi),
|
||||
py::arg("gpus_per_node") = tr::WorldConfig::kDefaultGpusPerNode, py::arg("tensor_parallelism") = py::none(),
|
||||
py::arg("pipeline_parallelism") = py::none());
|
||||
|
||||
py::class_<tr::SamplingConfig>(m, "SamplingConfig")
|
||||
.def(py::init<tr::SizeType>(), py::arg("beam_width") = 1)
|
||||
.def(py::init<SizeType>(), py::arg("beam_width") = 1)
|
||||
.def_readwrite("beam_width", &tr::SamplingConfig::beamWidth)
|
||||
.def_readwrite("temperature", &tr::SamplingConfig::temperature)
|
||||
.def_readwrite("min_length", &tr::SamplingConfig::minLength)
|
||||
@ -221,13 +236,12 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
.def_readwrite("length_penalty", &tr::SamplingConfig::lengthPenalty);
|
||||
|
||||
py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
|
||||
.def(py::init<std::string, std::string, tr::SizeType, tr::SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
.def(py::init<std::string, std::string, SizeType, SizeType, tr::GptModelConfig>(), py::arg("name"),
|
||||
py::arg("precision"), py::arg("tensor_parallelism"), py::arg("pipeline_parallelism"),
|
||||
py::arg("model_config"))
|
||||
.def_static("parse", py::overload_cast<std::string const&>(&tr::GptJsonConfig::parse), py::arg("json"))
|
||||
.def_static(
|
||||
"parse_file", [](std::string const& file) { return tr::GptJsonConfig::parse(std::filesystem::path(file)); },
|
||||
py::arg("file"))
|
||||
"parse_file", py::overload_cast<std::filesystem::path const&>(&tr::GptJsonConfig::parse), py::arg("path"))
|
||||
.def_property_readonly("model_config", &tr::GptJsonConfig::getModelConfig)
|
||||
.def_property_readonly("name", &tr::GptJsonConfig::getName)
|
||||
.def_property_readonly("precision", &tr::GptJsonConfig::getPrecision)
|
||||
@ -254,4 +268,104 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
|
||||
tr::SamplingConfig const& samplingConfig)
|
||||
{ self.generate(*outputs.toTrtLlm(), *inputs.toTrtLlm(), samplingConfig); },
|
||||
py::arg("outputs"), py::arg("inputs"), py::arg("sampling_config"));
|
||||
|
||||
py::enum_<tb::LlmRequestState_t>(m, "LlmRequestState")
|
||||
.value("REQUEST_STATE_UNKNOWN", tb::LlmRequestState_t::REQUEST_STATE_UNKNOWN)
|
||||
.value("REQUEST_STATE_CONTEXT_INIT", tb::LlmRequestState_t::REQUEST_STATE_CONTEXT_INIT)
|
||||
.value("REQUEST_STATE_GENERATION_IN_PROGRESS", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_IN_PROGRESS)
|
||||
.value("REQUEST_STATE_GENERATION_COMPLETE", tb::LlmRequestState_t::REQUEST_STATE_GENERATION_COMPLETE);
|
||||
|
||||
using LlmRequest = tpb::LlmRequest;
|
||||
py::class_<LlmRequest>(m, "LlmRequest")
|
||||
.def(py::init<LlmRequest::RequestIdType, LlmRequest::SizeType, std::vector<LlmRequest::TokenIdType>,
|
||||
tr::SamplingConfig, bool, std::optional<LlmRequest::SizeType>, std::optional<LlmRequest::SizeType>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::TensorPtr>, std::optional<LlmRequest::TensorPtr>,
|
||||
std::optional<LlmRequest::SizeType>, bool, std::optional<LlmRequest::VecTokens>>(),
|
||||
py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"),
|
||||
py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt,
|
||||
py::arg("embedding_bias") = std::nullopt, py::arg("bad_words_list") = std::nullopt,
|
||||
py::arg("stop_words_list") = std::nullopt, py::arg("prompt_embedding_table") = std::nullopt,
|
||||
py::arg("prompt_vocab_size") = std::nullopt, py::arg("return_log_probs") = false,
|
||||
py::arg("draft_tokens") = std::nullopt)
|
||||
.def("get_num_tokens", &LlmRequest::getNumTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_beam_num_tokens", &LlmRequest::getMaxBeamNumTokens)
|
||||
.def("get_token", &LlmRequest::getToken, py::arg("beam"), py::arg("pos"))
|
||||
.def("get_tokens", &LlmRequest::getTokens, py::arg("beam"))
|
||||
.def_property_readonly("max_num_generated_tokens", &LlmRequest::getMaxNumGeneratedTokens)
|
||||
.def("add_new_token", &LlmRequest::addNewToken, py::arg("token"), py::arg("beam"))
|
||||
.def("add_new_tokens", &LlmRequest::addNewTokens, py::arg("beam_tokens"))
|
||||
.def("set_generated_tokens", &LlmRequest::setGeneratedTokens, py::arg("generated_beam_tokens"))
|
||||
.def("pause", &LlmRequest::pause, py::arg("max_input_len"))
|
||||
.def_property("max_sent_token_pos", &LlmRequest::getMaxSentTokenPos, &LlmRequest::setMaxSentTokenPos)
|
||||
.def_property_readonly("prompt_embedding_table", &LlmRequest::getPromptEmbeddingTable)
|
||||
.def_property_readonly("prompt_vocab_size", &LlmRequest::getPromptVocabSize)
|
||||
.def_property_readonly("embedding_bias", &LlmRequest::getEmbeddingBias)
|
||||
.def_property_readonly("bad_words_list", &LlmRequest::getBadWordsList)
|
||||
.def_property_readonly("stop_words_list", &LlmRequest::getStopWordsList)
|
||||
.def_readwrite("request_id", &LlmRequest::mRequestId)
|
||||
.def_readwrite("prompt_len", &LlmRequest::mPromptLen)
|
||||
.def_readwrite("max_new_tokens", &LlmRequest::mMaxNewTokens)
|
||||
.def_readwrite("sampling_config", &LlmRequest::mSamplingConfig)
|
||||
.def_readwrite("state", &LlmRequest::mState)
|
||||
.def_readwrite("is_streaming", &LlmRequest::mIsStreaming)
|
||||
.def_readwrite("end_id", &LlmRequest::mEndId)
|
||||
.def_readwrite("pad_id", &LlmRequest::mPadId)
|
||||
.def_readwrite("batch_slot", &LlmRequest::mBatchSlot)
|
||||
.def_property_readonly("return_log_probs", &LlmRequest::returnLogProbs)
|
||||
.def_property_readonly("log_probs", py::overload_cast<>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("get_log_probs", py::overload_cast<SizeType>(&LlmRequest::getLogProbs, py::const_))
|
||||
.def("set_log_probs", &LlmRequest::setLogProbs, py::arg("log_probs"), py::arg("beam"))
|
||||
.def_property_readonly("cum_log_probs", &LlmRequest::getCumLogProbs)
|
||||
.def("set_cum_log_prob", &LlmRequest::setCumLogProb, py::arg("cum_log_prob"), py::arg("beam"))
|
||||
.def_property_readonly("orig_prompt_len", &LlmRequest::getOrigPromptLen)
|
||||
.def("has_draft_tokens", &LlmRequest::hasDraftTokens)
|
||||
.def_property(
|
||||
"draft_tokens", [](LlmRequest& self) { return *self.getDraftTokens(); },
|
||||
[](LlmRequest& self, LlmRequest::VecTokens& draftTokens)
|
||||
{ self.setDraftTokens(std::make_shared<LlmRequest::VecTokens>(std::move(draftTokens))); });
|
||||
|
||||
using InferenceRequest = tpb::InferenceRequest;
|
||||
py::class_<InferenceRequest>(m, "InferenceRequest")
|
||||
.def(py::init<uint64_t>())
|
||||
.def(py::init<InferenceRequest::TensorMap const&, uint64_t>())
|
||||
.def("get_input_tensor", &InferenceRequest::getInputTensor, py::arg("input_tensor_name"))
|
||||
.def("emplace_input_tensor", &InferenceRequest::emplaceInputTensor, py::arg("input_tensor_name"),
|
||||
py::arg("input_tensor"))
|
||||
.def_property("is_streaming", &InferenceRequest::isStreaming, &InferenceRequest::setIsStreaming)
|
||||
.def_property_readonly("request_id", &InferenceRequest::getRequestId);
|
||||
|
||||
py::enum_<tb::TrtGptModelType>(m, "TrtGptModelType")
|
||||
.value("V1", tb::TrtGptModelType::V1)
|
||||
.value("InflightBatching", tb::TrtGptModelType::InflightBatching)
|
||||
.value("InflightFusedBatching", tb::TrtGptModelType::InflightFusedBatching);
|
||||
|
||||
py::enum_<tbb::SchedulerPolicy>(m, "SchedulerPolicy")
|
||||
.value("MAX_UTILIZATION", tbb::SchedulerPolicy::MAX_UTILIZATION)
|
||||
.value("GUARANTEED_NO_EVICT", tbb::SchedulerPolicy::GUARANTEED_NO_EVICT);
|
||||
|
||||
py::class_<tb::TrtGptModelOptionalParams>(m, "TrtGptModelOptionalParams")
|
||||
.def(py::init<tbk::KvCacheConfig, std::optional<SizeType>, bool>(),
|
||||
py::arg("kv_cache_config") = tbk::KvCacheConfig{}, py::arg("max_num_sequences") = py::none(),
|
||||
py::arg("enable_trt_overlap") = true)
|
||||
.def_readwrite("kv_cache_config", &tb::TrtGptModelOptionalParams::kvCacheConfig)
|
||||
.def_readwrite("max_num_sequences", &tb::TrtGptModelOptionalParams::maxNumSequences)
|
||||
.def_readwrite("enable_trt_overlap", &tb::TrtGptModelOptionalParams::enableTrtOverlap);
|
||||
|
||||
py::class_<tpb::NamedTensor>(m, "NamedTensor")
|
||||
.def(py::init<tpb::NamedTensor::TensorPtr, std::string>(), py::arg("tensor"), py::arg("name"))
|
||||
.def_readwrite("tensor", &tpb::NamedTensor::tensor)
|
||||
.def_readwrite("name", &tpb::NamedTensor::name);
|
||||
|
||||
py::class_<tpb::GptManager>(m, "GptManager")
|
||||
.def(py::init<std::filesystem::path const&, tb::TrtGptModelType, int32_t, tb::batch_scheduler::SchedulerPolicy,
|
||||
tpb::GetInferenceRequestsCallback, tpb::SendResponseCallback, tb::PollStopSignalCallback,
|
||||
tb::ReturnBatchManagerStatsCallback, const tb::TrtGptModelOptionalParams&, std::optional<uint64_t>>(),
|
||||
py::arg("trt_engine_path"), py::arg("model_type"), py::arg("max_beam_width"), py::arg("scheduler_policy"),
|
||||
py::arg("get_inference_requests_cb"), py::arg("send_response_cb"), py::arg("poll_stop_signal_cb") = nullptr,
|
||||
py::arg("return_batch_manager_stats_cb") = nullptr,
|
||||
py::arg("optional_params") = tb::TrtGptModelOptionalParams(), py::arg("terminate_req_id") = std::nullopt)
|
||||
.def("shutdown", &tpb::GptManager::exit)
|
||||
.def("__enter__", &tpb::GptManager::enter)
|
||||
.def("__exit__", &tpb::GptManager::exit);
|
||||
}
|
||||
|
||||
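For readers following the binding signatures above, here is a small C++ sketch of the objects behind the KvCacheConfig and TrtGptModelOptionalParams constructors as they are exposed to Python; the helper name and the literal values are arbitrary examples, only the argument order is taken from the py::init signatures.

#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/trtGptModelOptionalParams.h"

#include <optional>

namespace tb = tensorrt_llm::batch_manager;
namespace tbk = tensorrt_llm::batch_manager::kv_cache_manager;

// Mirrors KvCacheConfig(max_tokens, max_kv_cache_length, free_gpu_memory_fraction)
// and TrtGptModelOptionalParams(kv_cache_config, max_num_sequences, enable_trt_overlap).
tb::TrtGptModelOptionalParams makeOptionalParams()
{
    tbk::KvCacheConfig kvCacheConfig{/*maxTokens=*/std::nullopt,
        /*maxKvCacheLength=*/std::nullopt, /*freeGpuMemoryFraction=*/0.9f};
    return tb::TrtGptModelOptionalParams{kvCacheConfig, /*maxNumSequences=*/std::nullopt, /*enableTrtOverlap=*/true};
}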
103 cpp/tensorrt_llm/pybind/utils/pathCaster.h Normal file
@ -0,0 +1,103 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/

#pragma once

#include "pybind11/cast.h"
#include "pybind11/detail/common.h"
#include "pybind11/detail/descr.h"
#include "pybind11/pybind11.h"
#include "pybind11/pytypes.h"
#include <filesystem>

namespace PYBIND11_NAMESPACE
{

namespace detail
{

template <typename T>
struct PathCaster
{

private:
    static PyObject* unicode_from_fs_native(const std::string& w)
    {
        return PyUnicode_DecodeFSDefaultAndSize(w.c_str(), ssize_t(w.size()));
    }

    static PyObject* unicode_from_fs_native(const std::wstring& w)
    {
        return PyUnicode_FromWideChar(w.c_str(), ssize_t(w.size()));
    }

public:
    static handle cast(const T& path, return_value_policy, handle)
    {
        if (auto py_str = unicode_from_fs_native(path.native()))
        {
            return module_::import("pathlib").attr("Path")(reinterpret_steal<object>(py_str)).release();
        }
        return nullptr;
    }

    bool load(handle handle, bool)
    {
        PyObject* native = nullptr;
        if constexpr (std::is_same_v<typename T::value_type, char>)
        {
            if (PyUnicode_FSConverter(handle.ptr(), &native) != 0)
            {
                if (auto* c_str = PyBytes_AsString(native))
                {
                    // AsString returns a pointer to the internal buffer, which
                    // must not be free'd.
                    value = c_str;
                }
            }
        }
        else if constexpr (std::is_same_v<typename T::value_type, wchar_t>)
        {
            if (PyUnicode_FSDecoder(handle.ptr(), &native) != 0)
            {
                if (auto* c_str = PyUnicode_AsWideCharString(native, nullptr))
                {
                    // AsWideCharString returns a new string that must be free'd.
                    value = c_str; // Copies the string.
                    PyMem_Free(c_str);
                }
            }
        }
        Py_XDECREF(native);
        if (PyErr_Occurred())
        {
            PyErr_Clear();
            return false;
        }
        return true;
    }

    PYBIND11_TYPE_CASTER(T, const_name("os.PathLike"));
};

template <>
struct type_caster<std::filesystem::path> : public PathCaster<std::filesystem::path>
{
};

} // namespace detail
} // namespace PYBIND11_NAMESPACE
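To make the intent of the caster concrete, a hedged usage sketch follows; the module and function names are invented for illustration. Once pathCaster.h is visible in a binding translation unit, bound functions taking std::filesystem::path accept Python str and pathlib.Path arguments, and returned paths surface as pathlib.Path objects.

#include "pathCaster.h"

#include <pybind11/pybind11.h>

#include <filesystem>

namespace py = pybind11;

// Illustrative only: a tiny module whose single function relies on the
// PathCaster above to convert between pathlib.Path and std::filesystem::path.
PYBIND11_MODULE(path_demo, m)
{
    m.def(
        "engine_config_path",
        [](std::filesystem::path const& engineDir) { return engineDir / "config.json"; },
        py::arg("engine_dir"));
}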
@ -130,6 +130,11 @@ typename tl::DynamicDecodeLayer<T>::ForwardParams prepareInputs(DecodingInput co
|
||||
forwardParams.stop_words_list = tcc::toTllmTensor(*input.stopWordsList);
|
||||
}
|
||||
|
||||
if (input.finished)
|
||||
{
|
||||
forwardParams.finished = tcc::toTllmTensor(*input.finished);
|
||||
}
|
||||
|
||||
return forwardParams;
|
||||
}
|
||||
|
||||
@ -335,6 +340,47 @@ void GptDecoder<T>::gatherTree(ITensor& finalOutputIds, DecodingOutput const& de
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void IGptDecoder::acceptTokens(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
|
||||
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths, const ITensor& finishedVec,
|
||||
ITensor& finishedFinal, ITensor& finishedSum, BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const targetTokenIdsShape = targetTokenIds.getShape();
|
||||
auto const batchSize = targetTokenIdsShape.d[0];
|
||||
auto const beamWidth = targetTokenIdsShape.d[1];
|
||||
auto const maxSeqLength = targetTokenIdsShape.d[2];
|
||||
auto const maxDraftTokens = draftTokenIds.getShape().d[2];
|
||||
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
beamWidth == 1, common::fmtstr("Beam width (%d) > 1 is not supported for the speculative decoding", beamWidth));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(draftTokenIds.getShape().d[0] == batchSize,
|
||||
common::fmtstr("Draft tokens batch size (%d) is not equal to target batch size (%d)",
|
||||
draftTokenIds.getShape().d[0], batchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(contextLengths.getShape().d[0] == batchSize,
|
||||
common::fmtstr("Context length batch size (%d) is not equal to batch size (%d)", contextLengths.getShape().d[0],
|
||||
batchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(numDraftTokens.getShape().d[0] == batchSize,
|
||||
common::fmtstr("Num draft tokens batch size (%d) is not equal to batch size (%d)",
|
||||
numDraftTokens.getShape().d[0], batchSize));
|
||||
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths.getShape().d[0] == batchSize,
|
||||
common::fmtstr("Sequence length batch size (%d) is not equal to batch size (%d)",
|
||||
sequenceLengths.getShape().d[0], batchSize));
|
||||
|
||||
tensorrt_llm::kernels::invokeAcceptTokens(bufferCast<SizeType>(draftTokenIds), bufferCast<SizeType>(targetTokenIds),
|
||||
bufferCast<SizeType>(contextLengths), bufferCast<SizeType>(numDraftTokens),
|
||||
bufferCast<SizeType>(sequenceLengths), bufferCast<bool>(finishedVec), bufferCast<bool>(finishedFinal),
|
||||
bufferCast<int>(finishedSum), batchSize, beamWidth, maxSeqLength, maxDraftTokens, stream->get());
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
template class GptDecoder<float>;
|
||||
|
||||
@ -92,33 +92,42 @@ GptDecoderBatch::GptDecoderBatch(
|
||||
auto outputIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput = std::make_unique<DecodingOutput>(std::move(outputIds));
|
||||
|
||||
dOutput->newTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput->newTokensSteps = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput->parentIds = mBufferManager.emptyTensor(MemoryType::kGPU, nvTokenIdType);
|
||||
dOutput->finished = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
|
||||
dOutput->finishedSteps = mBufferManager.emptyTensor(MemoryType::kGPU, TRTDataType<bool>::value);
|
||||
// use batchSize many entries instead of the usual 1
|
||||
dOutput->finishedSum = mBufferManager.emptyTensor(MemoryType::kPINNED, nvSizeType);
|
||||
mFinishedSum = mBufferManager.pinned(ITensor::makeShape({1}), nvSizeType);
|
||||
dOutput->lengths = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
// we don't need dOutput->lengths because lengths are passed from outside
|
||||
dOutput->cumLogProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
dOutput->logProbs = mBufferManager.emptyTensor(MemoryType::kGPU, nvFloatType);
|
||||
dOutput->beamHypotheses.empty(mBufferManager);
|
||||
|
||||
mNumDraftTokens = mBufferManager.emptyTensor(MemoryType::kGPU, nvSizeType);
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType dtype)
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxBatchSize > 0);
|
||||
TLLM_CHECK(maxBeamWidth > 0);
|
||||
TLLM_CHECK(maxTokensPerStep > 0);
|
||||
TLLM_CHECK(maxSequenceLength > 0);
|
||||
|
||||
mActualBatchSize = maxBatchSize;
|
||||
mGeneratedTokensPerStep.resize(maxBatchSize);
|
||||
mMaxSequenceLength = maxSequenceLength;
|
||||
mMaxKvCacheLength = maxKvCacheLength;
|
||||
mMaxTokensPerStep = maxTokensPerStep;
|
||||
|
||||
auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
|
||||
auto const maxBatchSizeXmaxBeamWidth = ITensor::makeShape({maxBatchSize, maxBeamWidth});
|
||||
auto const maxBatchSizeXmaxTokensPerStepXmaxBeamWidth
|
||||
= ITensor::makeShape({maxBatchSize, maxTokensPerStep, maxBeamWidth});
|
||||
auto const maxTokensPerStepXmaxBatchSizeXmaxBeamWidth
|
||||
= ITensor::makeShape({maxTokensPerStep, maxBatchSize, maxBeamWidth});
|
||||
|
||||
auto& dInput = *mJointDecodingInput;
|
||||
const_cast<ITensor&>(*dInput.endIds).reshape(maxBatchSizeXmaxBeamWidth);
|
||||
@ -133,14 +142,13 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
|
||||
|
||||
auto& dOutput = *mJointDecodingOutput;
|
||||
dOutput.ids->reshape(jointOutputIdsShape);
|
||||
dOutput.newTokens->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.newTokens);
|
||||
|
||||
dOutput.newTokensSteps->reshape(maxTokensPerStepXmaxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.newTokensSteps);
|
||||
dOutput.finishedSteps->reshape(maxBatchSizeXmaxTokensPerStepXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.finishedSteps);
|
||||
|
||||
dOutput.parentIds->reshape(jointOutputIdsShape);
|
||||
dOutput.lengths->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.lengths);
|
||||
dOutput.finished->reshape(maxBatchSizeXmaxBeamWidth);
|
||||
mBufferManager.setZero(*dOutput.finished);
|
||||
mBufferManager.setZero(*dOutput.finishedSum);
|
||||
// use batchSize many entries instead of the usual 1
|
||||
dOutput.finishedSum->reshape(maxBatchSizeShape);
|
||||
mBufferManager.setZero(*dOutput.finishedSum);
|
||||
@ -160,6 +168,10 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
|
||||
dOutput.beamHypotheses.release();
|
||||
}
|
||||
|
||||
// speculative decoding only works for beam width == 1
|
||||
mDraftTokenIds.resize(maxBatchSize);
|
||||
mNumDraftTokens->reshape(ITensor::makeShape({maxBatchSize, 1}));
|
||||
|
||||
mStreams.resize(maxBatchSize);
|
||||
mDecoders.resize(maxBatchSize);
|
||||
mDecodingInputs.resize(maxBatchSize);
|
||||
@ -180,6 +192,7 @@ void GptDecoderBatch::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeTy
|
||||
mFinished[i] = true;
|
||||
mMaxNewTokens[i] = 0;
|
||||
mBeamWidths[i] = 0;
|
||||
mGeneratedTokensPerStep[i] = 0;
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -198,7 +211,7 @@ void GptDecoderBatch::newRequest(
|
||||
tc::fmtstr("Beam width (%d) must be smaller than maxBeamWidth (%d) passed to decoder setup function.",
|
||||
beamWidth, maxBeamWidth));
|
||||
auto const& requestIds = request.ids;
|
||||
auto const inputLength = requestIds->getShape().d[0];
|
||||
auto const inputLength = request.inputLen;
|
||||
auto const maxNewTokens = request.maxNewTokens.value_or(mMaxSequenceLength - inputLength);
|
||||
TLLM_CHECK_WITH_INFO(inputLength + maxNewTokens <= mMaxSequenceLength,
|
||||
tc::fmtstr("Input length (%d) + max new tokens (%d) must be less than max sequence length (%d).", inputLength,
|
||||
@ -258,14 +271,21 @@ void GptDecoderBatch::newRequest(
|
||||
outputIds->reshape(outputIdsShape);
|
||||
dOutput = std::make_unique<DecodingOutput>(outputIds);
|
||||
|
||||
dOutput->finished = ITensor::slice(dJointOutput.finished, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->finished);
|
||||
dOutput->finishedSum = ITensor::slice(dJointOutput.finishedSum, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->finishedSum);
|
||||
dOutput->lengths = ITensor::slice(dJointOutput.lengths, batchIdx, localBatchSize);
|
||||
kernels::invokeFill(*dOutput->lengths, inputLength, *stream);
|
||||
dOutput->newTokens = ITensor::slice(dJointOutput.newTokens, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->newTokens);
|
||||
|
||||
dOutput->newTokensVec.resize(mMaxTokensPerStep);
|
||||
for (SizeType ti = 0; ti < mMaxTokensPerStep; ++ti)
|
||||
{
|
||||
TensorPtr newTokensStepView = std::move(ITensor::slice(dJointOutput.newTokensSteps, ti, localBatchSize));
|
||||
newTokensStepView->squeeze(0);
|
||||
dOutput->newTokensVec[ti] = ITensor::slice(newTokensStepView, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->newTokensVec[ti]);
|
||||
}
|
||||
|
||||
dOutput->finishedSteps = ITensor::slice(dJointOutput.finishedSteps, batchIdx, localBatchSize);
|
||||
manager.setZero(*dOutput->finishedSteps);
|
||||
dOutput->finishedSteps->squeeze(0);
|
||||
|
||||
// cumLogProb is mandatory for beamWidth > 1
|
||||
dOutput->cumLogProbs = nullptr;
|
||||
@ -293,12 +313,24 @@ void GptDecoderBatch::newRequest(
|
||||
dOutput->beamHypotheses.init(manager, endId);
|
||||
}
|
||||
|
||||
auto generatedTokensPerStep = request.generatedTokensPerStep();
|
||||
if (generatedTokensPerStep > 1)
|
||||
{
|
||||
auto numDraftTokens = generatedTokensPerStep - 1;
|
||||
TensorPtr draftTokensView = ITensor::view(request.draftTokens, ITensor::makeShape({1, 1, numDraftTokens}));
|
||||
mDraftTokenIds[batchIdx] = draftTokensView;
|
||||
|
||||
auto numDraftTokensView = ITensor::slice(mNumDraftTokens, batchIdx, localBatchSize);
|
||||
kernels::invokeFill(*numDraftTokensView, numDraftTokens, *stream);
|
||||
}
|
||||
|
||||
// remaining
|
||||
mDecoders[batchIdx]->setup(samplingConfig, localBatchSize, mMaxSequenceLength);
|
||||
mBeamWidths[batchIdx] = beamWidth;
|
||||
mNbSteps[batchIdx] = 0;
|
||||
mFinished[batchIdx] = false;
|
||||
mMaxNewTokens[batchIdx] = maxNewTokens;
|
||||
mGeneratedTokensPerStep[batchIdx] = generatedTokensPerStep;
|
||||
|
||||
// copy the request ids into outputIds
|
||||
auto inputIdsView = ITensor::view(requestIds, ITensor::makeShape({localBatchSize, inputLength}));
|
||||
@ -312,14 +344,11 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
|
||||
decoder_batch::Output& output, decoder_batch::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
auto& logits = input.logits;
|
||||
auto const& logitsShape = logits->getShape();
|
||||
auto& allLogits = input.logits;
|
||||
|
||||
TLLM_CHECK(logitsShape.d[0] == mActualBatchSize);
|
||||
// TODO(nkorobov): check logits shape considering draft tokens
|
||||
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
|
||||
auto const maxBeamWidth = jointOutputIdsShape.d[1];
|
||||
TLLM_CHECK(logitsShape.d[1] == maxBeamWidth);
|
||||
TLLM_CHECK(static_cast<std::size_t>(logitsShape.d[2]) == mVocabSizePadded);
|
||||
|
||||
auto& srcCacheIndirection = input.cacheIndirection;
|
||||
auto& tgtCacheIndirection = output.cacheIndirection;
|
||||
@ -328,70 +357,83 @@ GptDecoderBatch::TokenPtr GptDecoderBatch::forwardAsync(
|
||||
TLLM_CHECK(!srcCacheIndirection || srcCacheIndirection->getDataType() == TRTDataType<SizeType>::value);
|
||||
TLLM_CHECK(!tgtCacheIndirection || tgtCacheIndirection->getDataType() == TRTDataType<SizeType>::value);
|
||||
|
||||
TLLM_CHECK(static_cast<SizeType>(output.sequenceLengths->getSize()) == mActualBatchSize * maxBeamWidth);
|
||||
// TODO should remove this reshape and set shape to [batch_size, beam_width] outside
|
||||
TensorPtr sequenceLengths = ITensor::view(output.sequenceLengths);
|
||||
sequenceLengths->reshape(ITensor::makeShape({mActualBatchSize, maxBeamWidth}));
|
||||
TensorPtr sequenceLengths
|
||||
= ITensor::view(output.sequenceLengths, ITensor::makeShape({mActualBatchSize, maxBeamWidth}));
|
||||
TLLM_CHECK(sequenceLengths);
|
||||
auto constexpr singleRequest = 1;
|
||||
|
||||
CudaEvent eventStart{};
|
||||
mStream->record(eventStart);
|
||||
for (std::int32_t i = 0; i < mActualBatchSize; ++i)
|
||||
for (std::int32_t bi = 0; bi < mActualBatchSize; ++bi)
|
||||
{
|
||||
if (mFinished[i] || !input.active.at(i))
|
||||
if (mFinished[bi] || !input.active.at(bi))
|
||||
continue;
|
||||
|
||||
auto& stream = mStreams[i];
|
||||
auto& logits = allLogits[bi];
|
||||
auto const& logitsShape = logits->getShape();
|
||||
TLLM_CHECK_WITH_INFO(logitsShape.d[0] == mGeneratedTokensPerStep[bi],
|
||||
tc::fmtstr(
|
||||
"First dim (%d) does not match generated tokens (%d)", logitsShape.d[0], mGeneratedTokensPerStep[bi]));
|
||||
TLLM_CHECK_WITH_INFO(logitsShape.d[1] == mBeamWidths[bi],
|
||||
tc::fmtstr("Second dim (%d) does not match beam width (%d)", logitsShape.d[1], mBeamWidths[bi]));
|
||||
TLLM_CHECK(static_cast<std::size_t>(logitsShape.d[2]) == mVocabSizePadded);
|
||||
|
||||
auto& stream = mStreams[bi];
|
||||
stream->wait(eventStart.get());
|
||||
auto& dInput = *mDecodingInputs[i];
|
||||
auto& dOutput = *mDecodingOutputs[i];
|
||||
auto logitsView = std::shared_ptr(ITensor::slice(logits, i, singleRequest));
|
||||
dInput.logits
|
||||
= ITensor::view(logitsView, ITensor::makeShape({singleRequest, mBeamWidths[i], logitsShape.d[2]}));
|
||||
auto& dInput = *mDecodingInputs[bi];
|
||||
auto& dOutput = *mDecodingOutputs[bi];
|
||||
auto& decoder = *mDecoders[bi];
|
||||
|
||||
if (srcCacheIndirection && tgtCacheIndirection)
|
||||
{
|
||||
auto srcView = std::shared_ptr(ITensor::slice(srcCacheIndirection, i, singleRequest));
|
||||
auto tgtView = std::shared_ptr(ITensor::slice(tgtCacheIndirection, i, singleRequest));
|
||||
dInput.cacheIndirection
|
||||
= ITensor::view(srcView, ITensor::makeShape({singleRequest, mBeamWidths[i], srcView->getShape().d[2]}));
|
||||
dOutput.cacheIndirection
|
||||
= ITensor::view(tgtView, ITensor::makeShape({singleRequest, mBeamWidths[i], tgtView->getShape().d[2]}));
|
||||
auto srcView = std::shared_ptr(ITensor::slice(srcCacheIndirection, bi, singleRequest));
|
||||
auto tgtView = std::shared_ptr(ITensor::slice(tgtCacheIndirection, bi, singleRequest));
|
||||
dInput.cacheIndirection = ITensor::view(
|
||||
srcView, ITensor::makeShape({singleRequest, mBeamWidths[bi], srcView->getShape().d[2]}));
|
||||
dOutput.cacheIndirection = ITensor::view(
|
||||
tgtView, ITensor::makeShape({singleRequest, mBeamWidths[bi], tgtView->getShape().d[2]}));
|
||||
}
|
||||
auto sequenceLengthsView = std::shared_ptr(ITensor::slice(sequenceLengths, i, singleRequest));
|
||||
dOutput.lengths = ITensor::view(sequenceLengthsView, ITensor::makeShape({singleRequest, mBeamWidths[i]}));
|
||||
|
||||
auto& decoder = *mDecoders[i];
|
||||
decoder.forwardAsync(dOutput, dInput);
|
||||
auto sequenceLengthsView = std::shared_ptr(ITensor::slice(sequenceLengths, bi, singleRequest));
|
||||
dOutput.lengths = ITensor::view(sequenceLengthsView, ITensor::makeShape({singleRequest, mBeamWidths[bi]}));
|
||||
|
||||
auto manager = BufferManager{stream};
|
||||
|
||||
auto jointOutputIdsView = ITensor::slice(mJointDecodingOutput->ids, i, singleRequest);
|
||||
auto const& jointOutputShape = jointOutputIdsView->getShape();
|
||||
// squeeze dim 0 and set beamWidth
|
||||
jointOutputIdsView->reshape(ITensor::makeShape({mBeamWidths[i], jointOutputShape.d[2]}));
|
||||
|
||||
manager.copy(*dOutput.ids, *jointOutputIdsView);
|
||||
|
||||
auto jointSequenceLengthsView = ITensor::slice(mJointDecodingOutput->lengths, i, singleRequest);
|
||||
jointSequenceLengthsView->reshape(ITensor::makeShape({1, mBeamWidths[i]}));
|
||||
manager.copy(*dOutput.lengths, *jointSequenceLengthsView);
|
||||
|
||||
if (mBeamWidths[i] > 1)
|
||||
for (std::int32_t di = 0; di < mGeneratedTokensPerStep[bi]; ++di)
|
||||
{
|
||||
auto jointOutputParentIdsView = ITensor::slice(mJointDecodingOutput->parentIds, i, singleRequest);
|
||||
auto const& jointOutputParentIdsShape = jointOutputParentIdsView->getShape();
|
||||
// squeeze dim 0 and set beamWidth
|
||||
jointOutputParentIdsView->reshape(ITensor::makeShape({mBeamWidths[i], jointOutputParentIdsShape.d[2]}));
|
||||
dInput.logits = ITensor::slice(logits, di, singleRequest);
|
||||
dOutput.newTokens = ITensor::view(dOutput.newTokensVec[di]);
|
||||
dInput.finished = ITensor::slice(dOutput.finishedSteps, di, 1);
|
||||
dOutput.finished
|
||||
= ITensor::slice(dOutput.finishedSteps, std::min(di + 1, mGeneratedTokensPerStep[bi] - 1), 1);
|
||||
|
||||
manager.copy(*dOutput.parentIds, *jointOutputParentIdsView);
|
||||
decoder.forwardAsync(dOutput, dInput);
|
||||
|
||||
mNbSteps[bi] += 1;
|
||||
mFinished[bi] = mNbSteps[bi] >= mMaxNewTokens[bi];
|
||||
dInput.step += 1;
|
||||
}
|
||||
|
||||
if (mGeneratedTokensPerStep[bi] > 1)
|
||||
{
|
||||
auto draftTokenIds = mDraftTokenIds[bi];
|
||||
auto numDraftTokens = ITensor::slice(mNumDraftTokens, bi, singleRequest);
|
||||
// Update finished state for 0th step
|
||||
auto finishedFinal = ITensor::slice(dOutput.finishedSteps, 0, 1);
|
||||
IGptDecoder::acceptTokens(
|
||||
/* [bs=1, bw=1, max_seq_len] */ *dOutput.ids,
|
||||
/* [bs, bw, max_draft_tokens] */ *draftTokenIds,
|
||||
/* [bs, bw] */ *dInput.lengths,
|
||||
/* [bs, bw] */ *numDraftTokens,
|
||||
/* [bs, bw] */ *dOutput.lengths,
|
||||
/* [max_draft_tokens, bs, bw] */ *dOutput.finishedSteps,
|
||||
/* [bs, bw] */ *finishedFinal,
|
||||
/* [1] */ *dOutput.finishedSum, stream);
|
||||
}
|
||||
|
||||
CudaEvent event{};
|
||||
stream->record(event);
|
||||
mStream->wait(event);
|
||||
dInput.step += 1;
|
||||
mNbSteps[i] += 1;
|
||||
mFinished[i] = mNbSteps[i] >= mMaxNewTokens[i];
|
||||
}
|
||||
|
||||
CudaEvent eventStop{};
|
||||
@ -412,7 +454,7 @@ void GptDecoderBatch::forwardSync(decoder_batch::Token const& token)
|
||||
auto& dOutput = *mDecodingOutputs[i];
|
||||
mFinished[i] = mFinished[i]
|
||||
// This condition requires the synchronization above
|
||||
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.finished->getSize());
|
||||
|| *bufferCast<SizeType>(*dOutput.finishedSum) == static_cast<SizeType>(dOutput.lengths->getSize());
|
||||
}
|
||||
}
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
@ -449,6 +491,7 @@ void GptDecoderBatch::newBatch(
|
||||
// split batch into single requests
|
||||
auto const& inputLengths = inputs.lengths;
|
||||
mActualBatchSize = inputLengths->getShape().d[0];
|
||||
mGeneratedTokensPerStep.resize(mActualBatchSize);
|
||||
|
||||
auto const& jointOutputIdsShape = mJointDecodingOutput->ids->getShape();
|
||||
auto const maxBatchSize = jointOutputIdsShape.d[0];
|
||||
@ -464,6 +507,7 @@ void GptDecoderBatch::newBatch(
|
||||
auto inputOffset = 0;
|
||||
for (auto batchIdx = 0; batchIdx < mActualBatchSize; ++batchIdx)
|
||||
{
|
||||
mGeneratedTokensPerStep[batchIdx] = 1;
|
||||
auto const inputLength = inputLengthsPtr[batchIdx];
|
||||
auto const inputShape = ITensor::makeShape({inputLength});
|
||||
TensorPtr inputView;
|
||||
@ -477,8 +521,7 @@ void GptDecoderBatch::newBatch(
|
||||
inputView = ITensor::slice(inputs.ids, batchIdx, 1);
|
||||
inputView->reshape(inputShape);
|
||||
}
|
||||
|
||||
auto request = decoder_batch::Request{inputView, inputs.maxNewTokens, inputs.endId};
|
||||
auto request = decoder_batch::Request{inputView, inputLength, inputs.maxNewTokens, inputs.endId};
|
||||
request.computeCumLogProbs = (outputs.cumLogProbs != nullptr);
|
||||
request.computeLogProbs = (outputs.logProbs != nullptr);
|
||||
|
||||
@ -515,7 +558,20 @@ void GptDecoderBatch::newBatch(
|
||||
void GptDecoderBatch::forwardAsync(decoder::Output& output, decoder::Input const& input)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
decoder_batch::Input batchInput{input.logits};
|
||||
|
||||
auto const& logitsShape = input.logits->getShape();
|
||||
auto const batchSize = logitsShape.d[0];
|
||||
auto constexpr singleRequest = 1;
|
||||
std::vector<ITensor::SharedConstPtr> logits;
|
||||
logits.reserve(batchSize);
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto logitsSlice = std::shared_ptr(ITensor::slice(input.logits, batchIdx, singleRequest));
|
||||
logits.emplace_back(
|
||||
ITensor::view(logitsSlice, ITensor::makeShape({singleRequest, mBeamWidths[batchIdx], logitsShape.d[2]})));
|
||||
}
|
||||
|
||||
decoder_batch::Input batchInput{logits};
|
||||
batchInput.cacheIndirection = input.cacheIndirection;
|
||||
|
||||
decoder_batch::Output batchOutput;
|
||||
|
||||
@ -102,6 +102,7 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
auto const maxBatchSize = parseJsonFieldOr(builderConfig, "max_batch_size", 0);
|
||||
auto const maxInputLen = parseJsonFieldOr(builderConfig, "max_input_len", 0);
|
||||
auto const maxOutputLen = parseJsonFieldOr(builderConfig, "max_output_len", 0);
|
||||
auto const maxDraftLen = parseJsonFieldOr(builderConfig, "max_draft_len", 0);
|
||||
auto const maxNumTokens = parseJsonFieldOptional<SizeType>(builderConfig, "max_num_tokens");
|
||||
auto const maxPromptEmbeddingTableSize
|
||||
= parseJsonFieldOr<SizeType>(builderConfig, "max_prompt_embedding_table_size", 0);
|
||||
@ -132,6 +133,7 @@ GptJsonConfig parseJson(InputType&& i)
|
||||
modelConfig.setMaxInputLen(maxInputLen);
|
||||
modelConfig.setMaxOutputLen(maxOutputLen);
|
||||
modelConfig.setMaxNumTokens(maxNumTokens);
|
||||
modelConfig.setMaxDraftLen(maxDraftLen);
|
||||
modelConfig.setMaxPromptEmbeddingTableSize(maxPromptEmbeddingTableSize);
|
||||
|
||||
if (name == std::string("chatglm_6b") || name == std::string("glm_10b"))
|
||||
|
||||
@ -133,7 +133,9 @@ void GptSession::createDecoders(SizeType batchSize, SizeType beamWidth, SizeType
|
||||
mDecoders.emplace_back(std::make_shared<GptDecoderBatch>(vocabSize, vocabSizePadded, stream));
|
||||
else
|
||||
mDecoders.emplace_back(std::make_shared<StatefulGptDecoder>(vocabSize, vocabSizePadded, stream));
|
||||
mDecoders.back()->setup(batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, logitsType);
|
||||
constexpr SizeType maxTokensPerStep = 1;
|
||||
mDecoders.back()->setup(
|
||||
batchSize, beamWidth, maxKvCacheLength, maxSequenceLength, maxTokensPerStep, logitsType);
|
||||
}
|
||||
|
||||
TLLM_LOG_DEBUG("%s stop", __PRETTY_FUNCTION__);
|
||||
|
||||
@ -62,9 +62,10 @@ StatefulGptDecoder::StatefulGptDecoder(std::size_t vocabSize, std::size_t vocabS
|
||||
}
|
||||
|
||||
void StatefulGptDecoder::setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType dtype)
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_CHECK(maxTokensPerStep == 1);
|
||||
mDecoder = IGptDecoder::create(dtype, mVocabSize, mVocabSizePadded, mStream);
|
||||
|
||||
reshapeBuffers(maxBatchSize, maxBeamWidth, maxKvCacheLength, maxSequenceLength);
|
||||
@ -102,6 +103,7 @@ void StatefulGptDecoder::reshapeBuffers(
|
||||
mBufferManager.setZero(*dOutput.newTokens);
|
||||
dOutput.parentIds->reshape(outputIdsShape);
|
||||
dOutput.finished->reshape(batchSizeXbeamWidth);
|
||||
dInput.finished = ITensor::view(dOutput.finished);
|
||||
mBufferManager.setZero(*dOutput.finished);
|
||||
mBufferManager.setZero(*dOutput.finishedSum);
|
||||
|
||||
|
||||
@ -40,7 +40,7 @@ public:
|
||||
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxKvCacheLength, SizeType maxSequenceLength,
|
||||
nvinfer1::DataType dtype) override;
|
||||
SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
void newBatch(
|
||||
@ -53,6 +53,7 @@ public:
|
||||
//! @brief Gather final results for all requests.
|
||||
void finalize() const override;
|
||||
|
||||
//! @param step index within tokens generated in one step
|
||||
//! @returns [batchSize, maxBeamWidth, maxInputLength + maxNewTokens], contains input token ids and generated token
|
||||
//! ids without padding, on gpu
|
||||
[[nodiscard]] TensorPtr getOutputIds() const override
|
||||
@ -72,12 +73,24 @@ public:
|
||||
return mDecodingOutput->logProbs;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
[[nodiscard]] TensorPtr getNewTokens() const override
|
||||
//! @brief Get tokens generated in one step of last forward pass
|
||||
//! @param iter The iteration within [0; maxTokensPerStep) for which to get the tokens
|
||||
//! @returns [batchSize, beamWidth], tokens generated in `iter` (per beam), on gpu
|
||||
[[nodiscard]] TensorPtr getNewTokens(SizeType iter = 0) const override
|
||||
{
|
||||
TLLM_CHECK(iter == 0);
|
||||
return mDecodingOutput->newTokens;
|
||||
}
|
||||
|
||||
//! @brief Get tokens generated in the last forward pass
|
||||
//! @returns [batchSize, maxBeamWidth], tokens generated in last forward pass, on gpu
|
||||
[[nodiscard]] TensorPtr getAllNewTokens() const override
|
||||
{
|
||||
TensorPtr newTokens = std::move(ITensor::view(mDecodingOutput->newTokensSteps));
|
||||
newTokens->unsqueeze(0);
|
||||
return newTokens;
|
||||
}
|
||||
|
||||
//! @returns [1], number of finished sequences, in pinned host memory
|
||||
[[nodiscard]] TensorPtr getNbFinished() const override
|
||||
{
|
||||
|
||||
@ -23,6 +23,7 @@
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/core/DeviceType.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <torch/types.h>
|
||||
|
||||
@ -111,6 +112,17 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
static at::DeviceType deviceType(runtime::MemoryType memoryType)
|
||||
{
|
||||
switch (memoryType)
|
||||
{
|
||||
case runtime::MemoryType::kGPU: return c10::kCUDA;
|
||||
case runtime::MemoryType::kCPU: [[fallthrough]];
|
||||
case runtime::MemoryType::kPINNED: [[fallthrough]];
|
||||
default: return c10::kCPU;
|
||||
}
|
||||
}
|
||||
|
||||
static at::cuda::CUDAStream stream(runtime::CudaStream& cudaStream)
|
||||
{
|
||||
return at::cuda::getStreamFromExternal(cudaStream.get(), static_cast<at::DeviceIndex>(cudaStream.getDevice()));
|
||||
|
||||
@ -144,7 +144,8 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> src_cache_indirection_opt,
|
||||
// Outputs
|
||||
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop, th::optional<th::Tensor> finished_opt,
|
||||
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
|
||||
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
|
||||
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
|
||||
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
|
||||
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
|
||||
@ -166,12 +167,13 @@ void FtDynamicDecode<T>::forward(th::Tensor& logits, // (batch_size, beam_width,
|
||||
safeUpdate<int>(bad_words_list_opt, forwardParams.bad_words_list);
|
||||
safeUpdate<int>(stop_words_list_opt, forwardParams.stop_words_list);
|
||||
safeUpdate<int>(no_repeat_ngram_size_opt, forwardParams.no_repeat_ngram_size);
|
||||
safeUpdate<int>(finished_input, forwardParams.finished);
|
||||
|
||||
auto const& output_ids_converted = convert_tensor<int>(output_token_ids);
|
||||
typename tensorrt_llm::layers::DynamicDecodeLayer<T>::OutputParams outputParams{output_ids_converted};
|
||||
outputParams.newTokens = std::move(convert_tensor<int>(newTokens));
|
||||
|
||||
safeUpdate<bool>(finished_opt, outputParams.finished);
|
||||
safeUpdate<bool>(finished_output, outputParams.finished);
|
||||
std::int32_t* finished_sum_host = nullptr;
|
||||
if (forwardParams.sequence_limit_length && outputParams.finished.has_value())
|
||||
{
|
||||
@ -280,7 +282,8 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> src_cache_indirection_opt,
|
||||
// output buffers.
|
||||
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_opt,
|
||||
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_input,
|
||||
th::optional<th::Tensor> finished_output,
|
||||
th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
|
||||
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
|
||||
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
|
||||
@ -329,7 +332,8 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
|
||||
CHECK_OPTIONAL_INPUT(src_cache_indirection_opt, torch::kInt32);
|
||||
|
||||
CHECK_INPUT(output_token_ids, torch::kInt32);
|
||||
CHECK_OPTIONAL_INPUT(finished_opt, torch::kBool);
|
||||
CHECK_OPTIONAL_INPUT(finished_input, torch::kBool);
|
||||
CHECK_OPTIONAL_INPUT(finished_output, torch::kBool);
|
||||
CHECK_OPTIONAL_INPUT(seuqence_lengths_opt, torch::kInt32);
|
||||
CHECK_OPTIONAL_INPUT(cum_log_probs_opt, torch::kFloat32);
|
||||
CHECK_OPTIONAL_INPUT(output_log_probs_opt, torch::kFloat32);
|
||||
@ -345,11 +349,11 @@ th::Tensor DynamicDecodeOp::forward(th::Tensor logits, int64_t step, int64_t max
|
||||
sequence_limit_length_opt, stop_words_list_opt, bad_words_list_opt, no_repeat_ngram_size_opt,
|
||||
src_cache_indirection_opt,
|
||||
// Outputs
|
||||
output_token_ids, newTokens, should_stop, finished_opt, seuqence_lengths_opt, cum_log_probs_opt,
|
||||
output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt, beam_hyps_output_ids_tgt_opt,
|
||||
beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt, beam_hyps_normed_scores_opt,
|
||||
beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt, beam_hyps_is_done_opt,
|
||||
use_beam_hyps);
|
||||
output_token_ids, newTokens, should_stop, finished_input, finished_output, seuqence_lengths_opt,
|
||||
cum_log_probs_opt, output_log_probs_opt, parent_ids_opt, tgt_cache_indirection_opt,
|
||||
beam_hyps_output_ids_tgt_opt, beam_hyps_sequence_lengths_tgt_opt, beam_hyps_cum_log_probs_opt,
|
||||
beam_hyps_normed_scores_opt, beam_hyps_log_probs_opt, beam_hyps_min_normed_scores_opt, beam_hyps_num_beams_opt,
|
||||
beam_hyps_is_done_opt, use_beam_hyps);
|
||||
|
||||
return should_stop;
|
||||
}
|
||||
|
||||
@ -46,10 +46,10 @@ public:
|
||||
th::optional<th::Tensor> src_cache_indirection_opt,
|
||||
// Outputs
|
||||
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
|
||||
th::optional<th::Tensor> finished_opt, th::optional<th::Tensor> sequence_lengths_opt,
|
||||
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
|
||||
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
|
||||
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
|
||||
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
|
||||
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
|
||||
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
|
||||
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
|
||||
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
|
||||
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
|
||||
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
|
||||
@ -84,10 +84,10 @@ public:
|
||||
th::optional<th::Tensor> src_cache_indirection_opt,
|
||||
// Outputs
|
||||
th::Tensor& output_token_ids, th::Tensor& newTokens, th::Tensor& should_stop,
|
||||
th::optional<th::Tensor> finished_opt, th::optional<th::Tensor> sequence_lengths_opt,
|
||||
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
|
||||
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
|
||||
th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
|
||||
th::optional<th::Tensor> finished_input, th::optional<th::Tensor> finished_output,
|
||||
th::optional<th::Tensor> sequence_lengths_opt, th::optional<th::Tensor> cum_log_probs_opt,
|
||||
th::optional<th::Tensor> output_log_probs_opt, th::optional<th::Tensor> parent_ids_opt,
|
||||
th::optional<th::Tensor> tgt_cache_indirection_opt, th::optional<th::Tensor> beam_hyps_output_ids_tgt_opt,
|
||||
th::optional<th::Tensor> beam_hyps_sequence_lengths_tgt_opt,
|
||||
th::optional<th::Tensor> beam_hyps_cum_log_probs_opt, th::optional<th::Tensor> beam_hyps_normed_scores_opt,
|
||||
th::optional<th::Tensor> beam_hyps_log_probs_opt, th::optional<th::Tensor> beam_hyps_min_normed_scores_opt,
|
||||
@ -128,7 +128,8 @@ public:
|
||||
th::optional<th::Tensor> bad_words_list_opt, th::optional<th::Tensor> no_repeat_ngram_size_opt,
|
||||
th::optional<th::Tensor> src_cache_indirection_opt,
|
||||
// output buffers.
|
||||
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_opt,
|
||||
th::Tensor output_token_ids, th::Tensor newTokens, th::optional<th::Tensor> finished_input,
|
||||
th::optional<th::Tensor> finished_output,
|
||||
th::optional<th::Tensor> seuqence_lengths_opt, // length of the current sequences.
|
||||
th::optional<th::Tensor> cum_log_probs_opt, th::optional<th::Tensor> output_log_probs_opt,
|
||||
th::optional<th::Tensor> parent_ids_opt, th::optional<th::Tensor> tgt_cache_indirection_opt,
|
||||
|
||||
@ -87,6 +87,7 @@ set(SAMPLING_KERNEL_TEST_SRC
|
||||
kernels/sampling/samplingUtilsTest.cu)
|
||||
add_gtest(samplingKernelsTest "${SAMPLING_KERNEL_TEST_SRC}")
|
||||
add_gtest(weightOnlyKernelTest kernels/weightOnly/weightOnlyKernelTest.cpp)
|
||||
add_gtest(decodingKernelsTest kernels/decodingKernelTest.cpp)
|
||||
|
||||
if(BUILD_BATCH_MANAGER)
|
||||
add_subdirectory(batch_manager)
|
||||
|
||||
160 cpp/tests/kernels/decodingKernelTest.cpp Normal file
@ -0,0 +1,160 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifndef TOP_LEVEL_DIR
|
||||
#error "Define TOP_LEVEL_DIR"
|
||||
#endif
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include <random>
|
||||
|
||||
namespace tk = tensorrt_llm::kernels;
|
||||
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
void runAcceptedTokensTest(SizeType seed)
|
||||
{
|
||||
constexpr SizeType batchSize{8};
|
||||
constexpr SizeType beamWidth{1};
|
||||
constexpr SizeType maxSeqLen{16};
|
||||
constexpr SizeType vocabSize{32};
|
||||
constexpr SizeType maxDraftTokens{8};
|
||||
|
||||
auto stream = std::make_shared<CudaStream>();
|
||||
BufferManager manager(stream);
|
||||
|
||||
std::mt19937 generator(seed);
|
||||
std::uniform_int_distribution<int> contextLenDistr(0, maxSeqLen - maxDraftTokens);
|
||||
std::uniform_int_distribution<int> numDraftTokensDistr(1, maxDraftTokens);
|
||||
std::uniform_int_distribution<int> vocabDistr(1, vocabSize);
|
||||
std::uniform_real_distribution<float> acceptTokenDistr(0.f, 1.f);
|
||||
|
||||
auto draftTokens
|
||||
= manager.pinned(ITensor::makeShape({batchSize, beamWidth, maxDraftTokens}), nvinfer1::DataType::kINT32);
|
||||
auto targetTokens
|
||||
= manager.pinned(ITensor::makeShape({batchSize, beamWidth, maxSeqLen}), nvinfer1::DataType::kINT32);
|
||||
auto numsDraftTokens = manager.pinned(ITensor::makeShape({batchSize}), nvinfer1::DataType::kINT32);
|
||||
auto sequenceLengths = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
|
||||
auto contextLengths = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kINT32);
|
||||
auto finishedSteps
|
||||
= manager.pinned(ITensor::makeShape({maxDraftTokens, batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
|
||||
auto finishedFinal = manager.pinned(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
|
||||
auto finishedSum = manager.pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
|
||||
|
||||
std::vector<int> acceptedLen(batchSize * beamWidth);
|
||||
std::vector<bool> acceptedFinished(batchSize * beamWidth);
|
||||
|
||||
auto sequenceLengthsPtr = bufferCast<SizeType>(*sequenceLengths);
|
||||
auto contextLengthsPtr = bufferCast<SizeType>(*contextLengths);
|
||||
auto numsDraftTokensPtr = bufferCast<SizeType>(*numsDraftTokens);
|
||||
auto draftTokensPtr = bufferCast<SizeType>(*draftTokens);
|
||||
auto targetTokensPtr = bufferCast<SizeType>(*targetTokens);
|
||||
auto finishedStepsPtr = bufferCast<bool>(*finishedSteps);
|
||||
auto finishedFinalPtr = bufferCast<bool>(*finishedFinal);
|
||||
auto finishedSumPtr = bufferCast<SizeType>(*finishedSum);
|
||||
for (SizeType bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
numsDraftTokensPtr[bi] = numDraftTokensDistr(generator);
|
||||
}
|
||||
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
|
||||
{
|
||||
const SizeType batchIdx = bi / beamWidth;
|
||||
// Randomly init context len
|
||||
contextLengthsPtr[bi] = contextLenDistr(generator);
|
||||
|
||||
// Sequence len is at most numsDraftTokensPtr[bi] away from context len (it can be closer if e.g. endId is
|
||||
// generated)
|
||||
std::uniform_int_distribution<int> realDraftTokensDistr(0, numsDraftTokensPtr[batchIdx]);
|
||||
const auto realLen = realDraftTokensDistr(generator);
|
||||
sequenceLengthsPtr[bi] = contextLengthsPtr[bi] + realLen;
|
||||
|
||||
for (int i = 0; i < realLen; ++i)
|
||||
{
|
||||
finishedStepsPtr[i * batchSize * beamWidth + bi] = false;
|
||||
}
|
||||
for (int i = realLen; i <= numsDraftTokensPtr[batchIdx]; ++i)
|
||||
{
|
||||
finishedStepsPtr[i * batchSize * beamWidth + bi] = true;
|
||||
}
|
||||
|
||||
// Init helper vector with max value
|
||||
acceptedLen[bi] = sequenceLengthsPtr[bi];
|
||||
acceptedFinished[bi] = finishedStepsPtr[realLen * batchSize * beamWidth + bi];
|
||||
}
|
||||
// Fill token arrays
|
||||
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
|
||||
{
|
||||
// Draft: [d0, d1, d2, ... for numsDraftTokensPtr[bi] ... , dN]
|
||||
// Target: [vocabSize + 1, vocabSize + 1, ... for contextLengthsPtr[bi] ... vocabSize + 1,
|
||||
// t0, t1, t2, ... for numsDraftTokensPtr[bi] ... , tN,
|
||||
// vocabSize + 1, vocabSize + 1, .. to maxSeqLen]
|
||||
for (SizeType si = 0; si < contextLengthsPtr[bi]; ++si)
|
||||
{
|
||||
targetTokensPtr[bi * maxSeqLen + si] = vocabSize + 1;
|
||||
}
|
||||
for (SizeType si = contextLengthsPtr[bi]; si < sequenceLengthsPtr[bi]; ++si)
|
||||
{
|
||||
const auto draftToken = vocabDistr(generator);
|
||||
const auto draftTokenIdx = si - contextLengthsPtr[bi];
|
||||
const auto targetToken
|
||||
= acceptTokenDistr(generator) < 1.f / (draftTokenIdx + 1e-6) ? draftToken : vocabDistr(generator);
|
||||
draftTokensPtr[bi * maxDraftTokens + draftTokenIdx] = draftToken;
|
||||
targetTokensPtr[bi * maxSeqLen + si] = targetToken;
|
||||
if (draftToken != targetToken)
|
||||
{
|
||||
acceptedLen[bi] = std::min(acceptedLen[bi], std::min(si + 1, maxSeqLen));
|
||||
acceptedFinished[bi] = finishedStepsPtr[draftTokenIdx * batchSize * beamWidth + bi];
|
||||
}
|
||||
}
|
||||
for (SizeType si = sequenceLengthsPtr[bi]; si < maxSeqLen; ++si)
|
||||
{
|
||||
targetTokensPtr[bi * maxSeqLen + si] = vocabSize + 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Call function
|
||||
tk::invokeAcceptTokens(draftTokensPtr, targetTokensPtr, contextLengthsPtr, numsDraftTokensPtr, sequenceLengthsPtr,
|
||||
finishedStepsPtr, finishedFinalPtr, finishedSumPtr, batchSize, beamWidth, maxSeqLen, maxDraftTokens,
|
||||
stream->get());
|
||||
|
||||
stream->synchronize();
|
||||
|
||||
// Verify seqLen for accepted tokens
|
||||
int finishedSumRef = 0;
|
||||
for (SizeType bi = 0; bi < batchSize * beamWidth; ++bi)
|
||||
{
|
||||
EXPECT_EQ(acceptedLen[bi], sequenceLengthsPtr[bi]) << " bi " << bi << " seed " << seed;
|
||||
EXPECT_EQ(acceptedFinished[bi], finishedFinalPtr[bi]) << " bi " << bi << " seed " << seed;
|
||||
finishedSumRef += static_cast<SizeType>(acceptedFinished[bi]);
|
||||
}
|
||||
EXPECT_EQ(finishedSumRef, finishedSumPtr[0]);
|
||||
}
|
||||
|
||||
TEST(DecodingKernelsTest, acceptTokensKernel)
|
||||
{
|
||||
constexpr SizeType seeds = 64;
|
||||
for (SizeType seed = 0; seed < seeds; ++seed)
|
||||
{
|
||||
runAcceptedTokensTest(seed);
|
||||
}
|
||||
}
|
||||
|
||||
} // end of namespace
|
||||
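To make the property this test checks easier to see, here is a simplified host-side restatement of the acceptance rule (a reference for reading the test, not the CUDA kernel itself): draft tokens are accepted position by position while they match the target tokens, the first mismatching position keeps the corrected target token, and everything after it is rejected. The helper below is a sketch written for this document, not code from the commit.

#include <algorithm>
#include <vector>

// Simplified reference of the accepted-length computation used by the test:
// `draft` holds the speculative tokens, `target` the full target sequence,
// `seqLen` is contextLen plus the number of generated draft tokens.
int acceptedLength(std::vector<int> const& draft, std::vector<int> const& target,
    int contextLen, int seqLen, int maxSeqLen)
{
    int accepted = seqLen; // everything is accepted if no mismatch occurs
    for (int si = contextLen; si < seqLen; ++si)
    {
        if (draft[si - contextLen] != target[si])
        {
            accepted = std::min(si + 1, maxSeqLen); // keep the corrected token, drop the rest
            break;
        }
    }
    return accepted;
}
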
@ -44,7 +44,7 @@ protected:
|
||||
{
|
||||
size_t workspaceSize;
|
||||
tk::invokeTopKSampling<T>(nullptr, workspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
this->mMaxTopK, 1.0f, params.vocabSize, nullptr, this->mStream->get(), params.batchSize, nullptr);
|
||||
nullptr, this->mMaxTopK, 1.0f, params.vocabSize, nullptr, this->mStream->get(), params.batchSize, nullptr);
|
||||
return workspaceSize;
|
||||
}
|
||||
|
||||
@ -59,8 +59,8 @@ protected:
|
||||
// preprocesses log_prob_buf when those are provided.
|
||||
bufferCast<T>(*this->mProbsDevice), bufferCast<int*>(*this->mIdsPtrHost),
|
||||
bufferCast<int32_t>(*this->mSeqLengthsDevice), bufferCast<bool>(*this->mFinishedDevice),
|
||||
bufferCast<float>(*this->mCumLogProbsDevice), bufferCast<float>(*this->mOutputLogProbsDevice),
|
||||
this->mCurandStatesDevice, this->mMaxTopK,
|
||||
bufferCast<bool>(*this->mFinishedDevice), bufferCast<float>(*this->mCumLogProbsDevice),
|
||||
bufferCast<float>(*this->mOutputLogProbsDevice), this->mCurandStatesDevice, this->mMaxTopK,
|
||||
hasDiffRuntimeArgs ? bufferCast<int32_t>(*this->mTopKsDevice) : nullptr, params.topP,
|
||||
hasDiffRuntimeArgs ? bufferCast<float>(*this->mTopPsDevice) : nullptr, params.vocabSize,
|
||||
bufferCast<int32_t>(*this->mEndIdsDevice), this->mStream->get(), params.batchSize,
|
||||
|
||||
@ -50,6 +50,7 @@ private:
|
||||
nullptr, // output_ids
|
||||
nullptr, // sequence_length
|
||||
nullptr, // finished_buffer
|
||||
nullptr, // finished_buffer
|
||||
nullptr, // cum_log_probs
|
||||
nullptr, // output_log_probs
|
||||
nullptr, // log_probs
|
||||
@ -69,6 +70,7 @@ private:
|
||||
nullptr, // output_ids
|
||||
nullptr, // sequence_length
|
||||
nullptr, // finished_buffer
|
||||
nullptr, // finished_buffer
|
||||
nullptr, // cum_log_probs
|
||||
nullptr, // output_log_probs
|
||||
nullptr, // log_probs
|
||||
@ -85,8 +87,8 @@ private:
|
||||
// Perform batched TopP sampling
|
||||
tk::invokeBatchTopPSampling<T>(workspaceDevice->data(), workspaceSize, cubTempStorageSize,
|
||||
bufferCast<int*>(*this->mIdsPtrHost), bufferCast<int32_t>(*this->mSeqLengthsDevice),
|
||||
bufferCast<bool>(*this->mFinishedDevice), bufferCast<float>(*this->mCumLogProbsDevice),
|
||||
bufferCast<float>(*this->mOutputLogProbsDevice),
|
||||
bufferCast<bool>(*this->mFinishedDevice), bufferCast<bool>(*this->mFinishedDevice),
|
||||
bufferCast<float>(*this->mCumLogProbsDevice), bufferCast<float>(*this->mOutputLogProbsDevice),
|
||||
// Note that the kernel needs vocab probs instead of
|
||||
// log-prob if cum_log_probs or output_log_probs are
|
||||
// provided. It's because the sampling layer already
|
||||
|
||||
@ -83,8 +83,8 @@ float benchmark_perchannel(void* act, void* weight, void* scales, void* zeros, v
|
||||
cudaEventCreate(&end);
|
||||
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
|
||||
{
|
||||
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
|
||||
BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
|
||||
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, nullptr, bias, out, m, n, k,
|
||||
group_size, BFlag, WeightOnlyType::PerChannel, WeightOnlyActivationFunctionType::Identity, AFlag};
|
||||
for (int i = 0; i < warmup; ++i)
|
||||
{
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
|
||||
@ -164,8 +164,8 @@ float benchmark_groupwise(void* act, void* weight, void* scales, void* zeros, vo
|
||||
cudaEventCreate(&end);
|
||||
if constexpr (std::is_same_v<KernelFlag, CudaKernel>)
|
||||
{
|
||||
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, bias, out, m, n, k, group_size,
|
||||
BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
|
||||
WeightOnlyParams params{reinterpret_cast<uint8_t*>(weight), scales, zeros, act, nullptr, bias, out, m, n, k,
|
||||
group_size, BFlag, WeightOnlyType::GroupWise, WeightOnlyActivationFunctionType::Identity, AFlag};
|
||||
for (int i = 0; i < warmup; ++i)
|
||||
{
|
||||
tensorrt_llm::kernels::weight_only_batched_gemv_launcher(params, s);
|
||||
|
||||
@ -28,8 +28,7 @@ resources_dir = _pl.Path(
|
||||
__file__).parent.parent.parent.parent.parent / "examples/chatglm"
|
||||
sys.path.insert(0, str(resources_dir))
|
||||
|
||||
engine_target_path = _pl.Path(
|
||||
__file__).parent.parent / "models/rt_engine/chatglm"
|
||||
engine_target_path = _pl.Path(__file__).parent.parent / "models/rt_engine"
|
||||
|
||||
import build as _ecb
|
||||
|
||||
@ -67,7 +66,10 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
|
||||
model_name_list = ["chatglm_6b", "chatglm2_6b", "chatglm3_6b"]
|
||||
hf_dir_list = [resources_dir / model_name for model_name in model_name_list]
|
||||
trt_dir = resources_dir / "trtModel"
|
||||
trt_dir_list = [
|
||||
resources_dir / ("output_" + model_name)
|
||||
for model_name in model_name_list
|
||||
]
|
||||
|
||||
run_command(
|
||||
["pip", "install", "-r",
|
||||
@ -89,14 +91,17 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
)
|
||||
|
||||
print("\nBuilding engines")
|
||||
for model_name, hf_dir in zip(model_name_list, hf_dir_list):
|
||||
for model_name, hf_dir, trt_dir in zip(model_name_list, hf_dir_list,
|
||||
trt_dir_list):
|
||||
print("Building %s" % model_name)
|
||||
build_engine(model_name, hf_dir, trt_dir, world_size)
|
||||
|
||||
if not _Path(engine_target_path).exists():
|
||||
_Path(engine_target_path).mkdir(parents=True, exist_ok=True)
|
||||
for file in _Path(trt_dir).glob("*"):
|
||||
_shutil.move(file, engine_target_path)
|
||||
for model_name in model_name_list:
|
||||
_shutil.move(
|
||||
_Path(resources_dir) / ("output_" + model_name),
|
||||
engine_target_path / model_name)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
@ -147,7 +147,8 @@ def build_engines(model_cache: _tp.Optional[str] = None, world_size: int = 1):
|
||||
engine_dir / 'fp16-plugin-packed-paged' / tp_pp_dir, tp_size,
|
||||
'--dtype=float16', '--use_gpt_attention_plugin=float16',
|
||||
'--remove_input_padding', '--paged_kv_cache',
|
||||
'--enable_context_fmha', '--max_num_tokens=10000')
|
||||
'--enable_context_fmha_fp32_acc', '--max_num_tokens=10000',
|
||||
'--max_draft_len=5')
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
@ -48,11 +48,12 @@ def generate(model_name, batch_size, beam_width):
|
||||
args.input_text += args.input_text[0] * (batch_size - 2)
|
||||
args.beam_width = beam_width
|
||||
args.tokenizer_dir = resources_dir / model_name
|
||||
args.engine_dir = Path(__file__).parent.parent / "models/rt_engine/chatglm"
|
||||
args.engine_dir = Path(__file__).parent.parent / ("models/rt_engine/" +
|
||||
model_name)
|
||||
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = Path(args.engine_dir) / (model_name + '-config.json')
|
||||
config_path = args.engine_dir / 'config.json'
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
assert (config['builder_config']['name'] == model_name)
|
||||
@ -226,6 +227,4 @@ if __name__ == '__main__':
|
||||
generate("chatglm3_6b", batch_size=1, beam_width=1)
|
||||
generate("chatglm3_6b", batch_size=2, beam_width=1)
|
||||
generate("chatglm3_6b", batch_size=1, beam_width=2)
|
||||
#generate("glm_10b", batch_size=1, beam_width=1)
|
||||
#generate("glm_10b", batch_size=2, beam_width=1)
|
||||
print("Done.")
|
||||
|
||||
@ -14,6 +14,9 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include <algorithm>
|
||||
#include <random>
|
||||
|
||||
#include <gmock/gmock-matchers.h>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
@ -31,27 +34,109 @@ namespace tc = tensorrt_llm::common;
|
||||
namespace
|
||||
{
|
||||
|
||||
void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
|
||||
std::vector<SamplingConfig> const& samplingConfigs, std::vector<SizeType> const& inputLengths, SizeType batchSize,
|
||||
SizeType maxBeamWidth, SizeType maxSeqLength, SizeType nbNewTokens, int tokenId, int padId)
|
||||
decoder_batch::Input prepareDecoderInputs(SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
|
||||
SizeType vocabSizePadded, nvinfer1::DataType dataType, std::vector<SamplingConfig> const& samplingConfigs,
|
||||
std::vector<SizeType> const& generatedTokensPerSteps, BufferManager& manager)
|
||||
{
|
||||
auto sequenceLengths = decoder.getOutputLengths();
|
||||
ASSERT_TRUE(sequenceLengths);
|
||||
EXPECT_EQ(sequenceLengths->getSize(), batchSize * maxBeamWidth);
|
||||
auto sequenceLengthsHost = manager.copyFrom(*sequenceLengths, MemoryType::kCPU);
|
||||
auto sequenceLengthsPtr = bufferCast<SizeType>(*sequenceLengthsHost);
|
||||
manager.getStream().synchronize();
|
||||
|
||||
for (auto b = 0; b < batchSize; ++b)
|
||||
std::vector<decoder_batch::Input::TensorPtr> logits;
|
||||
logits.reserve(batchSize);
|
||||
auto constexpr tokenId = 1;
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto samplingConfig = samplingConfigs[b];
|
||||
for (auto bw = 0; bw < samplingConfig.beamWidth; ++bw)
|
||||
{
|
||||
auto index = tc::flat_index(sequenceLengths->getShape().d, b, bw);
|
||||
EXPECT_EQ(sequenceLengthsPtr[index], inputLengths[b] + nbNewTokens);
|
||||
}
|
||||
auto const beamWidth = samplingConfigs[batchIdx].beamWidth;
|
||||
logits.emplace_back(
|
||||
manager.gpu(ITensor::makeShape({generatedTokensPerSteps[batchIdx], beamWidth, vocabSizePadded}), dataType));
|
||||
manager.setZero(*logits.back());
|
||||
}
|
||||
|
||||
decoder_batch::Input inputs{logits};
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
auto srcCacheIndirection
|
||||
= manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value);
|
||||
manager.setZero(*srcCacheIndirection);
|
||||
inputs.cacheIndirection = std::move(srcCacheIndirection);
|
||||
}
|
||||
|
||||
return inputs;
|
||||
}
|
||||
|
||||
decoder_batch::Output prepareDecoderOutputs(SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
|
||||
std::vector<SizeType> const& tiledInputLengths, BufferManager& manager)
|
||||
{
|
||||
decoder_batch::Output outputs{};
|
||||
|
||||
auto sequenceLengths
|
||||
= manager.copyFrom(tiledInputLengths, ITensor::makeShape({batchSize, maxBeamWidth}), MemoryType::kGPU);
|
||||
outputs.sequenceLengths = std::move(sequenceLengths);
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
auto tgtCacheIndirection
|
||||
= manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value);
|
||||
manager.setZero(*tgtCacheIndirection);
|
||||
outputs.cacheIndirection = std::move(tgtCacheIndirection);
|
||||
}
|
||||
|
||||
return outputs;
|
||||
}
|
||||
|
||||
std::vector<decoder_batch::Request> prepareRequests(SizeType batchSize, SizeType maxNewTokens,
|
||||
std::vector<SizeType> const& inputLengths, std::vector<SizeType> const& generatedTokensPerSteps,
|
||||
std::vector<SizeType> const& acceptedTokensPerStep, TokenIdType tokenId, TokenIdType endId, TokenIdType padId,
|
||||
bool computeLogProbs, BufferManager& manager)
|
||||
{
|
||||
auto& stream = manager.getStream();
|
||||
|
||||
std::vector<decoder_batch::Request> requests;
|
||||
requests.reserve(batchSize);
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto shape = ITensor::makeShape({inputLengths[batchIdx]});
|
||||
auto input = manager.gpu(shape, TRTDataType<SizeType>::value);
|
||||
kernels::invokeFill(*input, tokenId, stream);
|
||||
|
||||
requests.emplace_back(decoder_batch::Request{std::move(input), inputLengths[batchIdx], maxNewTokens, endId});
|
||||
if (generatedTokensPerSteps[batchIdx] > 1)
|
||||
{
|
||||
std::vector<TokenIdType> draftTokens(generatedTokensPerSteps[batchIdx] - 1);
|
||||
std::fill(draftTokens.begin(), draftTokens.begin() + acceptedTokensPerStep[batchIdx], 1023);
|
||||
requests.back().draftTokens = manager.copyFrom(draftTokens, MemoryType::kGPU);
|
||||
}
|
||||
requests.back().computeCumLogProbs = computeLogProbs;
|
||||
requests.back().computeLogProbs = computeLogProbs;
|
||||
}
|
||||
|
||||
return requests;
|
||||
}
|
||||
|
||||
void advanceSequenceLengths(std::vector<SizeType>& sequenceLengths, std::vector<SizeType> const& acceptedTokensPerStep,
|
||||
std::vector<SamplingConfig> const& samplingConfigs, SizeType batchSize, SizeType maxBeamWidth)
|
||||
{
|
||||
for (int batchIdx = 0; batchIdx < batchSize; batchIdx++)
|
||||
{
|
||||
for (int beamId = 0; beamId < samplingConfigs.at(batchIdx).beamWidth; beamId++)
|
||||
{
|
||||
sequenceLengths.at(tc::flat_index2(batchIdx, beamId, maxBeamWidth))
|
||||
+= acceptedTokensPerStep.at(batchIdx) + 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void checkSequenceLengths(
|
||||
ITensor const& sequenceLengths, std::vector<SizeType> const& expectedLengths, BufferManager& manager)
|
||||
{
|
||||
auto sequenceLengthsHost = manager.copyFrom(sequenceLengths, MemoryType::kCPU);
|
||||
auto sequenceLengthsHostRange = BufferRange<SizeType>(*sequenceLengthsHost);
|
||||
EXPECT_THAT(sequenceLengthsHostRange, ::testing::ElementsAreArray(expectedLengths));
|
||||
}
|
||||
|
||||
void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
|
||||
std::vector<SamplingConfig> const& samplingConfigs, std::vector<SizeType> const& inputLengths,
|
||||
std::vector<SizeType> const& sequenceLengths, SizeType batchSize, SizeType maxBeamWidth, SizeType maxSeqLength,
|
||||
SizeType tokenId, SizeType padId)
|
||||
{
|
||||
auto outputsIds = decoder.getOutputIds();
|
||||
// TODO: test parentIds
|
||||
// parentIds = decoder.getParentIds();
|
||||
@ -68,30 +153,33 @@ void verifyResults(BufferManager& manager, GptDecoderBatch const& decoder,
|
||||
|
||||
for (auto b = 0; b < batchSize; ++b)
|
||||
{
|
||||
auto samplingConfig = samplingConfigs[b];
|
||||
auto samplingConfig = samplingConfigs.at(b);
|
||||
for (auto bw = 0; bw < samplingConfig.beamWidth; ++bw)
|
||||
{
|
||||
auto const result = (samplingConfig.beamWidth == 1) ? 1023 : bw;
|
||||
|
||||
auto const outputPtr = output + tc::flat_index(outputShape.d, b, bw, 0);
|
||||
auto begin = outputPtr;
|
||||
auto end = outputPtr + inputLengths[b];
|
||||
auto end = outputPtr + inputLengths.at(b);
|
||||
ASSERT_LE(begin, end) << "bad input length " << inputLengths.at(b);
|
||||
ASSERT_THAT(std::vector(begin, end), ::testing::Each(tokenId)) << "input tokens: "
|
||||
<< "b:" << b << " bw: " << bw;
|
||||
begin = end;
|
||||
end = begin + nbNewTokens;
|
||||
end = outputPtr + sequenceLengths.at(tc::flat_index2(b, bw, maxBeamWidth));
|
||||
ASSERT_LE(begin, end) << "bad seq length " << sequenceLengths.at(b);
|
||||
ASSERT_THAT(std::vector(begin, end), ::testing::Each(result)) << "new tokens: "
|
||||
<< "b:" << b << " bw: " << bw;
|
||||
begin = end;
|
||||
end = outputPtr + maxSeqLength;
|
||||
ASSERT_LE(begin, end) << "bad max length " << maxSeqLength;
|
||||
ASSERT_THAT(std::vector(begin, end), ::testing::Each(padId)) << "padding: "
|
||||
<< "b:" << b << " bw: " << bw;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs, int maxBeamWidth,
|
||||
bool computeLogProbs)
|
||||
void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
|
||||
SizeType maxBeamWidth, bool computeLogProbs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
SizeType constexpr tensorParallelism{1};
|
||||
@ -109,118 +197,99 @@ void testDecoder(nvinfer1::DataType const dtype, std::vector<SamplingConfig> con
|
||||
auto streamPtr = std::make_shared<CudaStream>();
|
||||
BufferManager manager(streamPtr);
|
||||
|
||||
// create decoder
|
||||
int constexpr endId{50257};
|
||||
int constexpr padId{50257};
|
||||
TokenIdType constexpr endId{50257};
|
||||
TokenIdType constexpr padId{50257};
|
||||
|
||||
auto const dataType = modelConfig.getDataType();
|
||||
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
|
||||
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
|
||||
|
||||
// setup decoder
|
||||
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
|
||||
SizeType constexpr maxInputLength{8};
|
||||
SizeType constexpr maxNewTokens{2};
|
||||
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
|
||||
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
|
||||
auto const maxKvCacheLength = maxSeqLength;
|
||||
SizeType const maxNewTokens{2};
|
||||
auto const maxSeqLength = maxInputLength + maxNewTokens;
|
||||
SizeType constexpr maxGeneratedTokensPerStep{1};
|
||||
|
||||
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());
|
||||
std::vector<SizeType> inputLengths(batchSize);
|
||||
std::iota(inputLengths.begin(), inputLengths.end(), 4);
|
||||
|
||||
std::vector<SizeType> const inputLengths{4, 5, 6, 7};
|
||||
std::vector<SizeType> tiledInputLengths;
|
||||
for (int batch_id = 0; batch_id < inputLengths.size(); batch_id++)
|
||||
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
|
||||
{
|
||||
for (int beam_id = 0; beam_id < maxBeamWidth; beam_id++)
|
||||
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
|
||||
{
|
||||
tiledInputLengths.push_back(inputLengths.at(batch_id));
|
||||
tiledInputLengths.push_back(inputLengths.at(batchIdx));
|
||||
}
|
||||
}
|
||||
|
||||
// set up inputs
|
||||
auto logits = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, vocabSizePadded}), modelConfig.getDataType()));
|
||||
manager.setZero(*logits);
|
||||
|
||||
decoder_batch::Input inputs{logits};
|
||||
if (maxBeamWidth > 1)
|
||||
std::vector<SizeType> generatedTokensPerSteps(batchSize);
|
||||
std::vector<SizeType> acceptedTokensPerStep(batchSize);
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto srcCacheIndirection = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
|
||||
manager.setZero(*srcCacheIndirection);
|
||||
inputs.cacheIndirection = srcCacheIndirection;
|
||||
generatedTokensPerSteps[batchIdx] = maxGeneratedTokensPerStep;
|
||||
acceptedTokensPerStep[batchIdx] = generatedTokensPerSteps[batchIdx] - 1;
|
||||
}
|
||||
|
||||
// set up outputs
|
||||
decoder_batch::Output outputs{};
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
auto tgtCacheIndirection = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
|
||||
manager.setZero(*tgtCacheIndirection);
|
||||
outputs.cacheIndirection = tgtCacheIndirection;
|
||||
}
|
||||
auto sequenceLengths
|
||||
= std::shared_ptr(manager.gpu(ITensor::makeShape({batchSize * maxBeamWidth}), TRTDataType<SizeType>::value));
|
||||
manager.copy(tiledInputLengths.data(), *sequenceLengths);
|
||||
outputs.sequenceLengths = sequenceLengths;
|
||||
|
||||
auto constexpr tokenId = 1;
|
||||
std::vector<decoder_batch::Input::TensorPtr> inputIds;
|
||||
for (auto b = 0; b < batchSize; ++b)
|
||||
{
|
||||
auto shape = ITensor::makeShape({inputLengths[b]});
|
||||
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
|
||||
kernels::invokeFill(*input, tokenId, *streamPtr);
|
||||
inputIds.emplace_back(input);
|
||||
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
|
||||
acceptedTokensPerStep, tokenId, endId, padId, computeLogProbs, manager);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
|
||||
// set up inputs and outputs
|
||||
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
|
||||
samplingConfigs, generatedTokensPerSteps, manager);
|
||||
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
|
||||
|
||||
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
|
||||
auto const maxKvCacheLength = maxSeqLength;
|
||||
|
||||
// set up decoder
|
||||
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
|
||||
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
|
||||
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
auto const& nbSteps = decoder.getNbSteps();
|
||||
EXPECT_EQ(nbSteps.size(), batchSize);
|
||||
EXPECT_THAT(nbSteps, ::testing::Each(0));
|
||||
auto expectedLengths = tiledInputLengths;
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
|
||||
auto const& finished = decoder.getFinished();
|
||||
EXPECT_EQ(finished.size(), batchSize);
|
||||
EXPECT_THAT(finished, ::testing::Each(false));
|
||||
|
||||
verifyResults(
|
||||
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 0, tokenId, padId);
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
|
||||
// run decoder for 1 step
|
||||
decoder.forward(outputs, inputs);
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(1));
|
||||
|
||||
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::Each(false));
|
||||
|
||||
verifyResults(
|
||||
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 1, tokenId, padId);
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
|
||||
// run decoder for 1 step
|
||||
decoder.forward(outputs, inputs);
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
|
||||
|
||||
verifyResults(
|
||||
manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, 2, tokenId, padId);
|
||||
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
|
||||
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
|
||||
EXPECT_NO_THROW(decoder.forward(outputs, inputs));
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[0], maxNewTokens};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(0, decoderRequest, samplingConfigs[0]);
|
||||
decoder.newRequest(0, requests[0], samplingConfigs[0]);
|
||||
EXPECT_FALSE(decoder.getFinished()[0]);
|
||||
EXPECT_EQ(decoder.getNbSteps()[0], 0);
|
||||
}
|
||||
|
||||
void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
|
||||
int maxBeamWidth, bool computeLogProbs)
|
||||
SizeType maxBeamWidth, bool computeLogProbs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
SizeType constexpr tensorParallelism{1};
|
||||
@ -238,105 +307,192 @@ void testDecoderWavefront(nvinfer1::DataType const dtype, std::vector<SamplingCo
|
||||
auto streamPtr = std::make_shared<CudaStream>();
|
||||
BufferManager manager(streamPtr);
|
||||
|
||||
// create decoder
|
||||
int constexpr endId{50257};
|
||||
int constexpr padId{50257};
|
||||
TokenIdType constexpr endId{50257};
|
||||
TokenIdType constexpr padId{50257};
|
||||
|
||||
auto const dataType = modelConfig.getDataType();
|
||||
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
|
||||
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
|
||||
|
||||
// setup decoder
|
||||
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
|
||||
SizeType constexpr maxInputLength{8};
|
||||
SizeType constexpr maxNewTokens{8};
|
||||
auto constexpr maxSeqLength = maxInputLength + maxNewTokens;
|
||||
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
|
||||
auto const maxKvCacheLength = maxSeqLength;
|
||||
SizeType constexpr maxGeneratedTokensPerStep{1};
|
||||
|
||||
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, modelConfig.getDataType());
|
||||
std::vector<SizeType> inputLengths(batchSize);
|
||||
std::iota(inputLengths.begin(), inputLengths.end(), 4);
|
||||
|
||||
std::vector<SizeType> const inputLengths{4, 5, 6, 7};
|
||||
std::vector<SizeType> tiledInputLengths;
|
||||
for (int batch_id = 0; batch_id < inputLengths.size(); batch_id++)
|
||||
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
|
||||
{
|
||||
for (int beam_id = 0; beam_id < maxBeamWidth; beam_id++)
|
||||
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
|
||||
{
|
||||
tiledInputLengths.push_back(inputLengths.at(batch_id));
|
||||
tiledInputLengths.push_back(inputLengths.at(batchIdx));
|
||||
}
|
||||
}
|
||||
|
||||
// set up inputs
|
||||
auto logits = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, vocabSizePadded}), modelConfig.getDataType()));
|
||||
manager.setZero(*logits);
|
||||
|
||||
decoder_batch::Input inputs{logits};
|
||||
if (maxBeamWidth > 1)
|
||||
std::vector<SizeType> generatedTokensPerSteps(batchSize);
|
||||
std::vector<SizeType> acceptedTokensPerStep(batchSize);
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto srcCacheIndirection = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
|
||||
manager.setZero(*srcCacheIndirection);
|
||||
inputs.cacheIndirection = srcCacheIndirection;
|
||||
generatedTokensPerSteps[batchIdx] = maxGeneratedTokensPerStep;
|
||||
acceptedTokensPerStep[batchIdx] = generatedTokensPerSteps[batchIdx] - 1;
|
||||
}
|
||||
|
||||
// set up outputs
|
||||
decoder_batch::Output outputs{};
|
||||
auto constexpr tokenId = 1;
|
||||
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
|
||||
acceptedTokensPerStep, tokenId, endId, padId, computeLogProbs, manager);
|
||||
|
||||
if (maxBeamWidth > 1)
|
||||
{
|
||||
auto tgtCacheIndirection = std::shared_ptr(
|
||||
manager.gpu(ITensor::makeShape({batchSize, maxBeamWidth, maxSeqLength}), TRTDataType<SizeType>::value));
|
||||
manager.setZero(*tgtCacheIndirection);
|
||||
outputs.cacheIndirection = tgtCacheIndirection;
|
||||
}
|
||||
auto sequenceLengths
|
||||
= std::shared_ptr(manager.gpu(ITensor::makeShape({batchSize * maxBeamWidth}), TRTDataType<SizeType>::value));
|
||||
manager.copy(tiledInputLengths.data(), *sequenceLengths);
|
||||
outputs.sequenceLengths = sequenceLengths;
|
||||
// set up inputs and outputs
|
||||
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
|
||||
samplingConfigs, generatedTokensPerSteps, manager);
|
||||
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
|
||||
|
||||
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
|
||||
auto const maxKvCacheLength = maxSeqLength;
|
||||
|
||||
// set up decoder
|
||||
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
|
||||
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
|
||||
|
||||
auto const& nbSteps = decoder.getNbSteps();
|
||||
EXPECT_EQ(nbSteps.size(), batchSize);
|
||||
std::vector<SizeType> expectedSteps(batchSize, 0);
|
||||
auto expectedLengths = tiledInputLengths;
|
||||
|
||||
auto const& finished = decoder.getFinished();
|
||||
EXPECT_EQ(finished.size(), batchSize);
|
||||
std::vector<bool> expectedFinished(batchSize, true);
|
||||
|
||||
auto constexpr tokenId = 1;
|
||||
std::vector<decoder_batch::Input::TensorPtr> inputIds;
|
||||
for (auto b = 0; b < batchSize; ++b)
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
auto shape = ITensor::makeShape({inputLengths[b]});
|
||||
auto input = std::shared_ptr(manager.gpu(shape, TRTDataType<SizeType>::value));
|
||||
kernels::invokeFill(*input, tokenId, *streamPtr);
|
||||
inputIds.emplace_back(input);
|
||||
|
||||
auto decoderRequest = decoder_batch::Request{inputIds[b], maxNewTokens, endId};
|
||||
decoderRequest.computeCumLogProbs = computeLogProbs;
|
||||
decoderRequest.computeLogProbs = computeLogProbs;
|
||||
decoder.newRequest(b, decoderRequest, samplingConfigs[b]);
|
||||
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
|
||||
|
||||
decoder.forward(outputs, inputs);
|
||||
|
||||
for (auto i = 0; i < inputIds.size(); ++i)
|
||||
{
|
||||
expectedSteps[i] = std::min(expectedSteps[i] + 1, maxNewTokens);
|
||||
expectedFinished[i] = expectedSteps[i] == maxNewTokens;
|
||||
}
|
||||
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchIdx + 1, maxBeamWidth);
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::ElementsAreArray(expectedSteps));
|
||||
for (auto bi = 0; bi <= batchIdx; ++bi)
|
||||
{
|
||||
auto firstBeamIndex = tc::flat_index2(bi, 0, maxBeamWidth);
|
||||
expectedFinished.at(bi)
|
||||
= expectedLengths.at(firstBeamIndex) - tiledInputLengths.at(firstBeamIndex) == maxNewTokens;
|
||||
}
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::ElementsAreArray(expectedFinished));
|
||||
}
|
||||
|
||||
while (!decoder.getFinished().back())
|
||||
auto finishedVec = decoder.getFinished();
|
||||
while (!std::any_of(finishedVec.begin(), finishedVec.end(), [](bool finish) { return finish; }))
|
||||
{
|
||||
decoder.forward(outputs, inputs);
|
||||
}
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::Each(true));
|
||||
EXPECT_THAT(decoder.getNbSteps(), ::testing::Each(maxNewTokens));
|
||||
finishedVec = decoder.getFinished();
|
||||
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, batchSize, maxBeamWidth, maxSeqLength, maxNewTokens,
|
||||
tokenId, padId);
|
||||
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
|
||||
for (auto bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
auto firstBeamIndex = tc::flat_index2(bi, 0, maxBeamWidth);
|
||||
expectedFinished.at(bi)
|
||||
= expectedLengths.at(firstBeamIndex) - tiledInputLengths.at(firstBeamIndex) == maxNewTokens;
|
||||
}
|
||||
EXPECT_THAT(finishedVec, ::testing::ElementsAreArray(expectedFinished));
|
||||
}
|
||||
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
}
|
||||
|
||||
void testDecoderDraft(nvinfer1::DataType const dtype, std::vector<SamplingConfig> const& samplingConfigs,
|
||||
SizeType maxBeamWidth, std::vector<SizeType> const& generatedTokensPerSteps,
|
||||
std::vector<SizeType> const& acceptedTokensPerStep, SizeType maxGeneratedTokensPerStep)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
SizeType constexpr tensorParallelism{1};
|
||||
SizeType constexpr pipelineParallelism{1};
|
||||
SizeType constexpr localRank{0};
|
||||
WorldConfig constexpr worldConfig{tensorParallelism, pipelineParallelism, localRank};
|
||||
|
||||
SizeType constexpr vocabSize{51200};
|
||||
SizeType constexpr nbLayers{2};
|
||||
SizeType constexpr nbHeads{16};
|
||||
SizeType constexpr hiddenSize{1024};
|
||||
GptModelConfig modelConfig{vocabSize, nbLayers, nbHeads, hiddenSize, dtype};
|
||||
modelConfig.useGptAttentionPlugin(false);
|
||||
|
||||
auto streamPtr = std::make_shared<CudaStream>();
|
||||
BufferManager manager(streamPtr);
|
||||
|
||||
TokenIdType constexpr endId{50257};
|
||||
TokenIdType constexpr padId{50257};
|
||||
|
||||
auto const dataType = modelConfig.getDataType();
|
||||
auto const vocabSizePadded = modelConfig.getVocabSizePadded(worldConfig.getSize());
|
||||
|
||||
auto const batchSize = static_cast<SizeType>(samplingConfigs.size());
|
||||
SizeType constexpr maxInputLength{8};
|
||||
SizeType const maxNewTokens{4};
|
||||
auto const maxSeqLength = maxInputLength + maxNewTokens;
|
||||
|
||||
std::vector<SizeType> inputLengths(batchSize);
|
||||
std::iota(inputLengths.begin(), inputLengths.end(), 4);
|
||||
|
||||
std::vector<SizeType> tiledInputLengths;
|
||||
for (int batchIdx = 0; batchIdx < inputLengths.size(); batchIdx++)
|
||||
{
|
||||
for (int beamId = 0; beamId < maxBeamWidth; beamId++)
|
||||
{
|
||||
tiledInputLengths.push_back(inputLengths.at(batchIdx));
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<SizeType> advancedTokensPerStep{generatedTokensPerSteps};
|
||||
std::for_each(advancedTokensPerStep.begin(), advancedTokensPerStep.end(), [](auto& x) { x -= 1; });
|
||||
|
||||
auto constexpr tokenId = 1;
|
||||
auto requests = prepareRequests(batchSize, maxNewTokens, inputLengths, generatedTokensPerSteps,
|
||||
acceptedTokensPerStep, tokenId, endId, padId, false, manager);
|
||||
|
||||
// set up inputs and outputs
|
||||
auto inputs = prepareDecoderInputs(batchSize, maxBeamWidth, maxSeqLength, vocabSizePadded, dataType,
|
||||
samplingConfigs, generatedTokensPerSteps, manager);
|
||||
auto outputs = prepareDecoderOutputs(batchSize, maxBeamWidth, maxSeqLength, tiledInputLengths, manager);
|
||||
|
||||
// We set maxKvCacheLength = maxSeqLength, but it can be smaller than maxSeqLength (cyclic kv cache).
|
||||
auto const maxKvCacheLength = maxSeqLength;
|
||||
|
||||
// set up decoder
|
||||
auto decoder = GptDecoderBatch(vocabSize, vocabSizePadded, streamPtr);
|
||||
decoder.setup(batchSize, maxBeamWidth, maxSeqLength, maxKvCacheLength, maxGeneratedTokensPerStep, dataType);
|
||||
|
||||
for (auto batchIdx = 0; batchIdx < batchSize; ++batchIdx)
|
||||
{
|
||||
decoder.newRequest(batchIdx, requests[batchIdx], samplingConfigs[batchIdx]);
|
||||
}
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
auto expectedLengths = tiledInputLengths;
|
||||
auto generatedLengths = tiledInputLengths;
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
|
||||
auto const& finished = decoder.getFinished();
|
||||
EXPECT_EQ(finished.size(), batchSize);
|
||||
EXPECT_THAT(finished, ::testing::Each(false));
|
||||
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, expectedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
|
||||
// run decoder for 1 step
|
||||
decoder.forward(outputs, inputs);
|
||||
|
||||
advanceSequenceLengths(expectedLengths, acceptedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
|
||||
// WAR: we don't write endId back into outputIds when we rejected tokens,
|
||||
// so we adjust the lengths for verifyResults here
|
||||
advanceSequenceLengths(generatedLengths, advancedTokensPerStep, samplingConfigs, batchSize, maxBeamWidth);
|
||||
checkSequenceLengths(*outputs.sequenceLengths, expectedLengths, manager);
|
||||
EXPECT_THAT(decoder.getFinished(), ::testing::Each(false));
|
||||
|
||||
verifyResults(manager, decoder, samplingConfigs, inputLengths, generatedLengths, batchSize, maxBeamWidth,
|
||||
maxSeqLength, tokenId, padId);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -347,7 +503,26 @@ struct BeamConfig
|
||||
std::vector<SizeType> beamWidths;
|
||||
};

class ParamTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
using ParamType = std::tuple<nvinfer1::DataType, BeamConfig, bool>;

std::string generateTestName(const testing::TestParamInfo<ParamType>& info)
{
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
BeamConfig const beamConfig = std::get<1>(info.param);
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
for (auto const beamWdith : beamConfig.beamWidths)
{
name.append("Bw" + std::to_string(beamWdith));
}
bool const computeLogProbs{std::get<2>(info.param)};
if (computeLogProbs)
{
name.append("LogProbs");
}
return name;
}

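For instance, the parameter tuple (nvinfer1::DataType::kHALF, BeamConfig{4, {1, 2, 3, 4}}, true) yields the test name HalfMaxBeamWidth4Bw1Bw2Bw3Bw4LogProbs; the same generator is reused by the parameterized suites declared below.
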
class ParamTest : public ::testing::TestWithParam<ParamType>
|
||||
{
|
||||
};
|
||||
|
||||
@ -365,29 +540,14 @@ TEST_P(ParamTest, Test)
|
||||
testDecoder(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamTest,
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBwTest, ParamTest,
|
||||
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
|
||||
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
|
||||
testing::Values(false, true)),
|
||||
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
|
||||
{
|
||||
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
|
||||
BeamConfig const beamConfig = std::get<1>(info.param);
|
||||
bool const computeLogProbs = std::get<2>(info.param);
|
||||
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
name.append("Bw" + std::to_string(beamWidth));
|
||||
}
|
||||
if (computeLogProbs)
|
||||
{
|
||||
name.append("LogProbs");
|
||||
}
|
||||
return name;
|
||||
});
|
||||
generateTestName);
|
||||
|
||||
class ParamWavefrontTest : public ::testing::TestWithParam<std::tuple<nvinfer1::DataType, BeamConfig, bool>>
|
||||
class ParamWavefrontTest : public ::testing::TestWithParam<ParamType>
|
||||
{
|
||||
};
|
||||
|
||||
@ -405,24 +565,70 @@ TEST_P(ParamWavefrontTest, Test)
|
||||
testDecoderWavefront(dtype, samplingConfigs, beamConfig.maxBeamWidth, computeLogProbs);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBatchTest, ParamWavefrontTest,
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderBwTest, ParamWavefrontTest,
|
||||
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
|
||||
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{3, {3, 3, 3, 3}}, BeamConfig{4, {1, 1}},
|
||||
BeamConfig{4, {3, 3, 3}}, BeamConfig{4, {1, 2, 3, 4}}),
|
||||
testing::Values(false, true)),
|
||||
[](const testing::TestParamInfo<ParamTest::ParamType>& info)
|
||||
generateTestName);
|
||||
|
||||
struct DraftConfig
|
||||
{
|
||||
SizeType maxGeneratedTokensPerStep;
|
||||
std::vector<SizeType> generatedTokensPerSteps;
|
||||
std::vector<SizeType> acceptedTokensPerStep;
|
||||
};
|
||||
|
||||
using DraftTestParamType = std::tuple<nvinfer1::DataType, BeamConfig, DraftConfig>;
|
||||
|
||||
class ParamDraftTest : public ::testing::TestWithParam<DraftTestParamType>
|
||||
{
|
||||
};
|
||||
|
||||
TEST_P(ParamDraftTest, Test)
|
||||
{
|
||||
nvinfer1::DataType const dtype{std::get<0>(GetParam())};
|
||||
BeamConfig const beamConfig{std::get<1>(GetParam())};
|
||||
DraftConfig const draftConfig{std::get<2>(GetParam())};
|
||||
|
||||
ASSERT_EQ(beamConfig.beamWidths.size(), draftConfig.acceptedTokensPerStep.size());
|
||||
ASSERT_EQ(beamConfig.beamWidths.size(), draftConfig.generatedTokensPerSteps.size());
|
||||
|
||||
std::vector<SamplingConfig> samplingConfigs;
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
samplingConfigs.emplace_back(beamWidth);
|
||||
}
|
||||
|
||||
testDecoderDraft(dtype, samplingConfigs, beamConfig.maxBeamWidth, draftConfig.generatedTokensPerSteps,
|
||||
draftConfig.acceptedTokensPerStep, draftConfig.maxGeneratedTokensPerStep);
|
||||
}
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(GptDecoderTest, ParamDraftTest,
|
||||
testing::Combine(testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kHALF),
|
||||
testing::Values(BeamConfig{1, {1, 1, 1}}, BeamConfig{4, {1, 1, 1}}),
|
||||
testing::Values( //
|
||||
DraftConfig{2, {1, 1, 1}, {0, 0, 0}}, DraftConfig{2, {2, 2, 2}, {1, 1, 1}},
|
||||
DraftConfig{4, {1, 2, 3}, {0, 0, 1}}
|
||||
|
||||
)),
|
||||
[](const testing::TestParamInfo<DraftTestParamType>& info)
|
||||
{
|
||||
std::string name{std::get<0>(info.param) == nvinfer1::DataType::kFLOAT ? "Float" : "Half"};
|
||||
BeamConfig const beamConfig = std::get<1>(info.param);
|
||||
bool const computeLogProbs = std::get<2>(info.param);
|
||||
DraftConfig const draftConfig = std::get<2>(info.param);
|
||||
name.append("MaxBeamWidth" + std::to_string(beamConfig.maxBeamWidth));
|
||||
auto const batchSize = beamConfig.beamWidths.size();
|
||||
for (auto const beamWidth : beamConfig.beamWidths)
|
||||
{
|
||||
name.append("Bw" + std::to_string(beamWdith));
|
||||
}
|
||||
if (computeLogProbs)
|
||||
name.append("PerStep" + std::to_string(draftConfig.maxGeneratedTokensPerStep));
|
||||
for (std::size_t i = 0; i < batchSize; ++i)
|
||||
{
|
||||
name.append("LogProbs");
|
||||
auto const acc = draftConfig.acceptedTokensPerStep.at(i);
|
||||
auto const gen = draftConfig.generatedTokensPerSteps.at(i);
|
||||
name.append("Acc" + std::to_string(acc) + "of" + std::to_string(gen));
|
||||
}
|
||||
return name;
|
||||
});
|
||||
|
||||
@ -97,6 +97,7 @@ void testDecoder(nvinfer1::DataType const dtype, SamplingConfig const& samplingC
|
||||
outputs.lengths
|
||||
= manager.copyFrom(sequenceLengthsVec, ITensor::makeShape({batchSize, beamWidth}), MemoryType::kGPU);
|
||||
outputs.finished = manager.gpu(ITensor::makeShape({batchSize, beamWidth}), nvinfer1::DataType::kBOOL);
|
||||
inputs.finished = ITensor::view(outputs.finished);
|
||||
manager.setZero(*outputs.finished);
|
||||
outputs.finishedSum = BufferManager::pinned(ITensor::makeShape({1}), nvinfer1::DataType::kINT32);
|
||||
auto* finishedSumHost = bufferCast<std::int32_t>(*outputs.finishedSum);
|
||||
|
||||
@ -692,7 +692,7 @@ namespace
|
||||
|
||||
// TODO: consolidate this function with testGptSession
|
||||
// Notice: all ChatGLM / GLM models use this
|
||||
// function The differences are GptModelConfig::ModelVariant
|
||||
// The differences are GptModelConfig::ModelVariant
|
||||
void testChatGlmSession(fs::path const& modelPath, std::string const& modelName, ModelSpec const& modelSpec,
|
||||
ModelIds const modelIds, SizeType beamWidth, std::initializer_list<int> const& batchSizes,
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, bool cudaGraphMode, MicroBatchSizes microBatchSizes)
|
||||
@ -718,7 +718,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
|
||||
auto const expectedOutputData = bufferCast<TokenIdType const>(*expectedOutput);
|
||||
|
||||
ASSERT_TRUE(fs::exists(modelPath));
|
||||
auto const json = GptJsonConfig::parse(modelPath / (modelName + "-config.json"));
|
||||
auto const json = GptJsonConfig::parse(modelPath / "config.json");
|
||||
auto const modelConfig = json.getModelConfig();
|
||||
verifyModelConfig(modelConfig, modelSpec);
|
||||
auto const decoderPerRequest = modelSpec.mDecoderPerRequest;
|
||||
@ -867,7 +867,7 @@ void testChatGlmSession(fs::path const& modelPath, std::string const& modelName,
|
||||
TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const modelPath{ENGINE_PATH / modelName};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -879,7 +879,7 @@ TEST_F(ChatGlmSessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const modelPath{ENGINE_PATH / modelName};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -891,7 +891,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
|
||||
{
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const modelPath{ENGINE_PATH / modelName};
|
||||
auto const batchSizes = {2};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -903,7 +903,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS2BM1)
|
||||
TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
|
||||
{
|
||||
auto const modelName{"chatglm2_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const modelPath{ENGINE_PATH / modelName};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
@ -915,7 +915,7 @@ TEST_F(ChatGlm2SessionTest, SamplingFP16WithGptAttentionPluginBS1BM2)
|
||||
TEST_F(ChatGlm3SessionTest, SamplingFP16WithGptAttentionPluginBS1BM1)
|
||||
{
|
||||
auto const modelName{"chatglm3_6b"};
|
||||
auto const modelPath{ENGINE_PATH / "chatglm"};
|
||||
auto const modelPath{ENGINE_PATH / modelName};
|
||||
auto const batchSizes = {1};
|
||||
auto constexpr dtype = nvinfer1::DataType::kHALF;
|
||||
auto const modelSpec = ModelSpec{"", "", dtype}.useGptAttentionPlugin();
|
||||
|
||||
@ -19,23 +19,24 @@ COPY docker/common/install_cmake.sh install_cmake.sh
|
||||
RUN bash ./install_cmake.sh && rm install_cmake.sh
|
||||
|
||||
# Download & install internal TRT release
|
||||
ARG TRT_VER="9.1.0.4"
|
||||
ENV TRT_VER=$TRT_VER
|
||||
ARG CUDA_VER="12.2"
|
||||
ENV CUDA_VER=$CUDA_VER
|
||||
ARG CUDNN_VER="8.9.4.25-1+cuda12.2"
|
||||
ENV CUDNN_VER=$CUDNN_VER
|
||||
ARG NCCL_VER="2.18.3-1+cuda12.2"
|
||||
ENV NCCL_VER=$NCCL_VER
|
||||
ARG CUBLAS_VER="12.2.5.6-1"
|
||||
ENV CUBLAS_VER=$CUBLAS_VER
|
||||
ARG TRT_VER CUDA_VER CUDNN_VER NCCL_VER CUBLAS_VER
|
||||
COPY docker/common/install_tensorrt.sh install_tensorrt.sh
|
||||
RUN bash ./install_tensorrt.sh && rm install_tensorrt.sh
|
||||
RUN bash ./install_tensorrt.sh \
|
||||
--TRT_VER=${TRT_VER} \
|
||||
--CUDA_VER=${CUDA_VER} \
|
||||
--CUDNN_VER=${CUDNN_VER} \
|
||||
--NCCL_VER=${NCCL_VER} \
|
||||
--CUBLAS_VER=${CUBLAS_VER} && \
|
||||
rm install_tensorrt.sh
|
||||
|
||||
# Install latest Polygraphy
|
||||
COPY docker/common/install_polygraphy.sh install_polygraphy.sh
|
||||
RUN bash ./install_polygraphy.sh && rm install_polygraphy.sh
|
||||
|
||||
# Install mpi4py
|
||||
COPY docker/common/install_mpi4py.sh install_mpi4py.sh
|
||||
RUN bash ./install_mpi4py.sh && rm install_mpi4py.sh
|
||||
|
||||
# Install PyTorch
|
||||
ARG TORCH_INSTALL_TYPE="skip"
|
||||
COPY docker/common/install_pytorch.sh install_pytorch.sh
|
||||
@ -49,7 +50,7 @@ COPY benchmarks benchmarks
|
||||
COPY scripts scripts
|
||||
COPY tensorrt_llm tensorrt_llm
|
||||
COPY 3rdparty 3rdparty
|
||||
COPY setup.py requirements.txt ./
|
||||
COPY setup.py requirements.txt requirements-dev.txt ./
|
||||
|
||||
ARG BUILD_WHEEL_ARGS="--clean --trt_root /usr/local/tensorrt"
|
||||
RUN python3 scripts/build_wheel.py ${BUILD_WHEEL_ARGS}
|
||||
|
||||
@ -102,10 +102,14 @@ wheel_%: STAGE = wheel
|
||||
|
||||
release_%: STAGE = release
|
||||
|
||||
# For x86_64 and aarch64
|
||||
# For x86_64
|
||||
jenkins_%: IMAGE_WITH_TAG = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
jenkins_%: STAGE = devel
|
||||
|
||||
# For aarch64
|
||||
jenkins-aarch64_%: IMAGE_WITH_TAG = $(shell grep 'LLM_DOCKER_IMAGE = ' ../jenkins/L1_Nightly_GH200RemoteJob.groovy | grep -o '".*"' | tr -d '"')
|
||||
jenkins-aarch64_%: STAGE = devel
|
||||
|
||||
# For x86_64
|
||||
centos7_%: IMAGE_WITH_TAG = $(shell grep 'LLM_CENTOS7_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
centos7_%: STAGE = devel
|
||||
@ -117,7 +121,7 @@ ubuntu22_%: STAGE = devel
|
||||
ubuntu22_%: BASE_IMAGE = nvidia/cuda
|
||||
ubuntu22_%: BASE_TAG = 12.2.2-devel-ubuntu22.04
|
||||
|
||||
# For x86_64 and aarch64
|
||||
# For x86_64
|
||||
old-cuda_%: IMAGE_WITH_TAG = $(shell grep 'LLM_OLD_CUDA_DOCKER_IMAGE = ' ../jenkins/L0_MergeRequest.groovy | grep -o '".*"' | tr -d '"')
|
||||
old-cuda_%: BASE_TAG = 23.07-py3
|
||||
old-cuda_%: STAGE = devel
|
||||
|
||||
@ -30,7 +30,7 @@ init_ubuntu() {
|
||||
fi
|
||||
apt-get clean
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64' >> "${ENV}"
|
||||
# Remove previous TRT installation
|
||||
if [[ $(apt list --installed | grep libnvinfer) ]]; then
|
||||
apt-get remove --purge -y libnvinfer*
|
||||
@ -39,7 +39,6 @@ init_ubuntu() {
|
||||
apt-get remove --purge -y tensorrt*
|
||||
fi
|
||||
pip uninstall -y tensorrt
|
||||
pip install mpi4py
|
||||
}
|
||||
|
||||
install_gcc_centos() {
|
||||
@ -53,7 +52,7 @@ install_gcc_centos() {
|
||||
./contrib/download_prerequisites
|
||||
./configure --disable-multilib --enable-languages=c,c++ --with-pi
|
||||
make -j$(nproc) && make install
|
||||
echo "export LD_LIBRARY_PATH=/usr/local/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64' >> "${ENV}"
|
||||
cd .. && rm -rf /tmp/gcc-*
|
||||
yum clean all
|
||||
}
|
||||
@ -64,7 +63,7 @@ init_centos() {
|
||||
yum -y update
|
||||
yum -y install centos-release-scl-rh epel-release
|
||||
# https://gitlab.com/nvidia/container-images/cuda
|
||||
echo "export LD_LIBRARY_PATH=/usr/local/cuda/lib64:\$LD_LIBRARY_PATH" >> "${ENV}"
|
||||
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64' >> "${ENV}"
|
||||
CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p')
|
||||
YUM_CUDA=${CUDA_VERSION/./-}
|
||||
# Consistent with manylinux2014 centos-7 based version
|
||||
@ -73,7 +72,7 @@ init_centos() {
|
||||
echo "source scl_source enable rh-git227 rh-python38" >> "${ENV}"
|
||||
echo "source scl_source enable devtoolset-10" >> "${DEVTOOLSET_ENV_FILE}"
|
||||
echo "source ${DEVTOOLSET_ENV_FILE}" >> "${ENV}"
|
||||
echo 'export PATH=/usr/lib64/openmpi3/bin:$PATH' >> "${ENV}"
|
||||
echo 'export PATH=$PATH:/usr/lib64/openmpi3/bin' >> "${ENV}"
|
||||
bash -c "pip install 'urllib3<2.0'"
|
||||
yum clean all
|
||||
}
|
||||
|
||||
@ -12,4 +12,4 @@ wget --no-verbose ${RELEASE_URL_CMAKE} -P /tmp
|
||||
tar -xf /tmp/${CMAKE_FILE_NAME}.tar.gz -C /usr/local/
|
||||
ln -s /usr/local/${CMAKE_FILE_NAME} /usr/local/cmake
|
||||
|
||||
echo 'export PATH=/usr/local/cmake/bin:$PATH' >> "${ENV}"
|
||||
echo 'export PATH=$PATH:/usr/local/cmake/bin' >> "${ENV}"
|
||||
|
||||
11
docker/common/install_mpi4py.sh
Normal file
@ -0,0 +1,11 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
|
||||
MPI4PY_VERSION="3.1.5"
|
||||
RELEASE_URL="https://github.com/mpi4py/mpi4py/archive/refs/tags/${MPI4PY_VERSION}.tar.gz"
|
||||
curl -L ${RELEASE_URL} | tar -zx -C /tmp
|
||||
# Bypassing compatibility issues with higher versions (>= 69) of setuptools.
|
||||
sed -i 's/>= 40\.9\.0/>= 40.9.0, < 69/g' /tmp/mpi4py-${MPI4PY_VERSION}/pyproject.toml
|
||||
pip install /tmp/mpi4py-${MPI4PY_VERSION}
|
||||
rm -rf /tmp/mpi4py*
|
||||
@ -2,6 +2,24 @@
|
||||
|
||||
set -ex
|
||||
|
||||
TRT_VER="9.1.0.4"
|
||||
CUDA_VER="12.2"
|
||||
CUDNN_VER="8.9.4.25-1+cuda12.2"
|
||||
NCCL_VER="2.18.3-1+cuda12.2"
|
||||
CUBLAS_VER="12.2.5.6-1"
|
||||
|
||||
for i in "$@"; do
|
||||
case $i in
|
||||
--TRT_VER=?*) TRT_VER="${i#*=}";;
|
||||
--CUDA_VER=?*) CUDA_VER="${i#*=}";;
|
||||
--CUDNN_VER=?*) CUDNN_VER="${i#*=}";;
|
||||
--NCCL_VER=?*) NCCL_VER="${i#*=}";;
|
||||
--CUBLAS_VER=?*) CUBLAS_VER="${i#*=}";;
|
||||
*) ;;
|
||||
esac
|
||||
shift
|
||||
done
|
||||
|
||||
NVCC_VERSION_OUTPUT=$(nvcc --version)
|
||||
if [[ $(echo $NVCC_VERSION_OUTPUT | grep -oP "\d+\.\d+" | head -n 1) != ${CUDA_VER} ]]; then
|
||||
echo "The version of pre-installed CUDA is not equal to ${CUDA_VER}."
|
||||
@ -64,7 +82,7 @@ install_tensorrt() {
|
||||
mv /usr/local/TensorRT-${TRT_VER} /usr/local/tensorrt
|
||||
pip install /usr/local/tensorrt/python/tensorrt-*-cp${PARSED_PY_VERSION}-*.whl
|
||||
rm -rf /tmp/TensorRT.tar
|
||||
echo 'export LD_LIBRARY_PATH=/usr/local/tensorrt/lib:$LD_LIBRARY_PATH' >> "${ENV}"
|
||||
echo 'export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/tensorrt/lib' >> "${ENV}"
|
||||
}
|
||||
|
||||
# Install base packages depending on the base OS
|
||||
|
||||
@ -257,7 +257,7 @@ as:
|
||||
|
||||
# Collectives.
|
||||
def allreduce(tensor: Tensor, group: List[int]) -> Tensor
|
||||
def allgather(tensor: Tensor, group: List[int]) -> Tensor
|
||||
def allgather(tensor: Tensor, group: List[int], gather_dim: int = 0) -> Tensor
|
||||
|
||||
# Point-to-point communication primitives.
|
||||
def send(tensor: Tensor, tgt: int) -> Tensor
|
||||
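As a rough usage illustration of the primitives listed above (a sketch only: it assumes these functions are importable from `tensorrt_llm.functional` and are called while a network is being built, which this diff does not itself guarantee):

# Sketch: merging per-rank results across a tensor-parallel group.
from tensorrt_llm.functional import allgather, allreduce

tp_group = [0, 1, 2, 3]  # ranks taking part in the collective

def merge_row_parallel_output(partial):
    # Each rank holds a partial GEMM result; sum the partials element-wise.
    return allreduce(partial, tp_group)

def collect_column_parallel_output(shard):
    # Each rank holds a slice of the output; stack the slices along gather_dim
    # (the new optional argument introduced above, default 0).
    return allgather(shard, tp_group, gather_dim=0)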
@ -267,7 +267,7 @@ def recv(tensor: Tensor, src: int) -> Tensor
|
||||
The multi-GPU support can be enabled through two different modes of model
|
||||
parallelism: Tensor Parallelism and Pipeline Parallelism. The former mode
|
||||
splits the different layers of a model across the GPUs. Each GPU runs the
|
||||
entire network and synchronizes with its sibblings when needed. The Pipeline
|
||||
entire network and synchronizes with its siblings when needed. The Pipeline
|
||||
Parallelism distributes the different layers to the GPUs. Each GPU runs a
|
||||
subset of the entire model and communications happen at the boundary of those
|
||||
subsets of layers. Tensor Parallelism usually leads to more balanced executions
|
||||
|
||||
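A minimal sketch of how the two modes are selected, assuming the `Mapping` helper that the build scripts in this change configure with `tp_size` (the `pp_size` keyword and the `tp_rank`/`pp_rank` attributes are assumptions here, not confirmed by the diff):

# Sketch: a 2-way tensor-parallel x 2-way pipeline-parallel layout.
from tensorrt_llm import Mapping

tp_size, pp_size = 2, 2
world_size = tp_size * pp_size  # must match the MPI world size at runtime

for rank in range(world_size):
    mapping = Mapping(world_size=world_size, rank=rank,
                      tp_size=tp_size, pp_size=pp_size)
    # tp_rank: which slice of every layer this rank owns (Tensor Parallelism).
    # pp_rank: which contiguous block of layers this rank runs (Pipeline Parallelism).
    print(rank, mapping.tp_rank, mapping.pp_rank)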
@ -4,11 +4,10 @@ This document shows how to build and run a Baichuan models (including `v1_7b`/`v
|
||||
|
||||
## Overview
|
||||
|
||||
The TensorRT-LLM Baichuan implementation can be found in [tensorrt_llm/models/baichuan/model.py](../../tensorrt_llm/models/baichuan/model.py). The TensorRT-LLM Baichuan example code is located in [`examples/baichuan`](./). There are three main files:
|
||||
The TensorRT-LLM Baichuan implementation can be found in [tensorrt_llm/models/baichuan/model.py](../../tensorrt_llm/models/baichuan/model.py). The TensorRT-LLM Baichuan example code is located in [`examples/baichuan`](./). There are two main files:
|
||||
|
||||
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the Baichuan model,
|
||||
* [`run.py`](./run.py) to run the inference on an input text,
|
||||
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
|
||||
* [`run.py`](./run.py) to run the inference on an input text.
|
||||
|
||||
These scripts accept an argument named `model_version`, whose value should be one of `v1_7b`/`v1_13b`/`v2_7b`/`v2_13b`; the default value is `v1_13b`.
|
||||
|
||||
|
||||
@ -18,9 +18,12 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import tensorrt as trt
|
||||
|
||||
# isort: off
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
from onnx import TensorProto, helper
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
@ -499,8 +502,6 @@ def build_rank_engine(builder: Builder,
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
tensorrt_llm.tools.cleanup(network, tensorrt_llm_baichuan)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
@ -132,6 +132,7 @@ def hf_baichuan_converter(args):
|
||||
saved_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
model = AutoModelForCausalLM.from_pretrained(args.in_file,
|
||||
torch_dtype="auto",
|
||||
device_map="auto",
|
||||
trust_remote_code=True)
|
||||
|
||||
|
||||
@ -16,8 +16,10 @@ import argparse
|
||||
import os
|
||||
from collections import OrderedDict
|
||||
|
||||
import tensorrt as trt
|
||||
# isort: off
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
from transformers import BertConfig, BertForQuestionAnswering, BertModel
|
||||
|
||||
import tensorrt_llm
|
||||
|
||||
@ -16,8 +16,10 @@ import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
import tensorrt as trt
|
||||
# isort: off
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm import logger
|
||||
|
||||
@ -3,8 +3,10 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import tensorrt as trt
|
||||
# isort: off
|
||||
import torch
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
import tensorrt_llm
|
||||
|
||||
@ -4,11 +4,10 @@ This document shows how to build and run a BLOOM model in TensorRT-LLM on both s
|
||||
|
||||
## Overview
|
||||
|
||||
The TensorRT-LLM BLOOM implementation can be found in [tensorrt_llm/models/bloom/model.py](../../tensorrt_llm/models/bloom/model.py). The TensorRT-LLM BLOOM example code is located in [`examples/bloom`](./). There are three main files:
|
||||
The TensorRT-LLM BLOOM implementation can be found in [tensorrt_llm/models/bloom/model.py](../../tensorrt_llm/models/bloom/model.py). The TensorRT-LLM BLOOM example code is located in [`examples/bloom`](./). There are two main files:
|
||||
|
||||
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the BLOOM model,
|
||||
* [`run.py`](./run.py) to run the inference on an input text,
|
||||
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
|
||||
* [`run.py`](./run.py) to run the inference on an input text.
|
||||
|
||||
## Support Matrix
|
||||
* FP16
|
||||
|
||||
@ -18,9 +18,12 @@ import time
|
||||
from pathlib import Path
|
||||
|
||||
import onnx
|
||||
import tensorrt as trt
|
||||
|
||||
# isort: off
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
from onnx import TensorProto, helper
|
||||
from transformers import BloomConfig, BloomForCausalLM
|
||||
|
||||
@ -477,8 +480,6 @@ def build_rank_engine(builder: Builder,
|
||||
config_path = os.path.join(args.output_dir, 'config.json')
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
tensorrt_llm.tools.cleanup(network, tensorrt_llm_bloom)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
||||
2
examples/chatglm/.gitignore
vendored
@ -4,4 +4,4 @@ awq/
|
||||
chatglm*_6b*/
|
||||
dataset/
|
||||
glm_10b/
|
||||
trtModel/
|
||||
output_*/
|
||||
|
||||
@ -5,11 +5,10 @@ This document explains how to build the [ChatGLM-6B](https://huggingface.co/THUD
|
||||
## Overview
|
||||
|
||||
The TensorRT-LLM ChatGLM implementation can be found in [`tensorrt_llm/models/chatglm/model.py`](../../tensorrt_llm/models/chatglm/model.py).
|
||||
The TensorRT-LLM ChatGLM example code is located in [`examples/chatglm`](./). There are three main files:
|
||||
The TensorRT-LLM ChatGLM example code is located in [`examples/chatglm`](./). There are two main files:
|
||||
|
||||
* [`build.py`](./build.py) to build the [TensorRT](https://developer.nvidia.com/tensorrt) engine(s) needed to run the ChatGLM model.
|
||||
* [`run.py`](./run.py) to run the inference on an input text.
|
||||
* and a shared [`../summarize.py`](../summarize.py) to summarize the articles in the [cnn_dailymail](https://huggingface.co/datasets/cnn_dailymail) dataset using the model.
|
||||
|
||||
## Support Matrix
|
||||
|
||||
|
||||
@ -18,12 +18,13 @@ import time
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
import onnx
|
||||
import tensorrt as trt
|
||||
# isort: off
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from onnx import TensorProto, helper
|
||||
from weight import load_from_hf
|
||||
import tensorrt as trt
|
||||
# isort: on
|
||||
from visualize import to_onnx
|
||||
from weight import get_scaling_factors, load_from_hf
|
||||
|
||||
import tensorrt_llm
|
||||
from tensorrt_llm._utils import str_dtype_to_trt
|
||||
@ -36,12 +37,12 @@ from tensorrt_llm.plugin.plugin import ContextFMHAType
|
||||
from tensorrt_llm.profiler import check_gpt_mem_usage
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
|
||||
from weight import get_scaling_factors # isort:skip
|
||||
from weight import load_from_hf_checkpoint # isort:skip
|
||||
|
||||
|
||||
def get_engine_name(model, dtype, tp_size, rank):
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
def get_engine_name(model, dtype, tp_size, pp_size, rank):
|
||||
if pp_size == 1:
|
||||
return '{}_{}_tp{}_rank{}.engine'.format(model, dtype, tp_size, rank)
|
||||
return '{}_{}_tp{}_pp{}_rank{}.engine'.format(model, dtype, tp_size,
|
||||
pp_size, rank)
|
||||
|
||||
|
||||
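# Illustration only (derived from the format strings above, not from running the
# script): with pipeline parallelism the stage count now appears in the file name.
#   get_engine_name("chatglm2_6b", "float16", 2, 1, 0)
#     -> "chatglm2_6b_float16_tp2_rank0.engine"
#   get_engine_name("chatglm2_6b", "float16", 2, 2, 3)
#     -> "chatglm2_6b_float16_tp2_pp2_rank3.engine"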
def find_engines(dir: Path,
|
||||
@ -53,61 +54,6 @@ def find_engines(dir: Path,
|
||||
return list(dir.glob(template))
|
||||
|
||||
|
||||
def trt_dtype_to_onnx(dtype):
|
||||
if dtype == trt.float16:
|
||||
return TensorProto.DataType.FLOAT16
|
||||
elif dtype == trt.float32:
|
||||
return TensorProto.DataType.FLOAT
|
||||
elif dtype == trt.int32:
|
||||
return TensorProto.DataType.INT32
|
||||
else:
|
||||
raise TypeError("%s is not supported" % dtype)
|
||||
|
||||
|
||||
def to_onnx(network, path):
|
||||
inputs = []
|
||||
for i in range(network.num_inputs):
|
||||
network_input = network.get_input(i)
|
||||
inputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_input.name, trt_dtype_to_onnx(network_input.dtype),
|
||||
list(network_input.shape)))
|
||||
|
||||
outputs = []
|
||||
for i in range(network.num_outputs):
|
||||
network_output = network.get_output(i)
|
||||
outputs.append(
|
||||
helper.make_tensor_value_info(
|
||||
network_output.name, trt_dtype_to_onnx(network_output.dtype),
|
||||
list(network_output.shape)))
|
||||
|
||||
nodes = []
|
||||
for i in range(network.num_layers):
|
||||
layer = network.get_layer(i)
|
||||
layer_inputs = []
|
||||
for j in range(layer.num_inputs):
|
||||
ipt = layer.get_input(j)
|
||||
if ipt is not None:
|
||||
layer_inputs.append(layer.get_input(j).name)
|
||||
layer_outputs = [
|
||||
layer.get_output(j).name for j in range(layer.num_outputs)
|
||||
]
|
||||
nodes.append(
|
||||
helper.make_node(str(layer.type),
|
||||
name=layer.name,
|
||||
inputs=layer_inputs,
|
||||
outputs=layer_outputs,
|
||||
domain="com.nvidia"))
|
||||
|
||||
onnx_model = helper.make_model(helper.make_graph(nodes,
|
||||
'attention',
|
||||
inputs,
|
||||
outputs,
|
||||
initializer=None),
|
||||
producer_name='NVIDIA')
|
||||
onnx.save(onnx_model, path)
|
||||
|
||||
|
||||
def serialize_engine(engine, path):
|
||||
logger.info(f'Serializing engine to {path}...')
|
||||
tik = time.time()
|
||||
@ -118,7 +64,7 @@ def serialize_engine(engine, path):
|
||||
logger.info(f'Engine serialized. Total time: {t}')
|
||||
|
||||
|
||||
def truncate_input_output(
|
||||
def truncate_input_output_len(
|
||||
max_input_len,
|
||||
max_output_len,
|
||||
max_seq_length_from_config,
|
||||
@ -152,21 +98,29 @@ def parse_arguments(args):
|
||||
help=
|
||||
'the name of the model, use "_" rather than "-" to connect the name parts'
|
||||
)
|
||||
parser.add_argument('--world_size',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now')
|
||||
parser.add_argument('--tp_size', type=int, default=1)
|
||||
parser.add_argument('--pp_size', type=int, default=1)
|
||||
parser.add_argument(
|
||||
'--world_size',
|
||||
'-ws',
|
||||
type=int,
|
||||
default=1,
|
||||
help='world size, only support tensor parallelism now',
|
||||
)
|
||||
parser.add_argument('--tp_size', '-tp', type=int, default=1)
|
||||
parser.add_argument('--pp_size', '-pp', type=int, default=1)
|
||||
parser.add_argument('--model_dir', type=str, default=None)
|
||||
parser.add_argument('--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16'])
|
||||
parser.add_argument('--logits_dtype',
|
||||
type=str,
|
||||
default='float32',
|
||||
choices=['float16', 'float32'])
|
||||
parser.add_argument('--quant_ckpt_path', type=str, default="awq/")
|
||||
parser.add_argument(
|
||||
'--dtype',
|
||||
type=str,
|
||||
default='float16',
|
||||
choices=['float32', 'float16', 'bfloat16'],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--logits_dtype',
|
||||
type=str,
|
||||
default='float32',
|
||||
choices=['float16', 'float32'],
|
||||
)
|
||||
parser.add_argument(
|
||||
'--timing_cache',
|
||||
type=str,
|
||||
@ -177,8 +131,9 @@ def parse_arguments(args):
|
||||
parser.add_argument(
|
||||
'--log_level',
|
||||
type=str,
|
||||
default='verbose',
|
||||
choices=['verbose', 'info', 'warning', 'error', 'internal_error'])
|
||||
default='info',
|
||||
choices=['verbose', 'info', 'warning', 'error', 'internal_error'],
|
||||
)
|
||||
parser.add_argument('--max_batch_size', type=int, default=8)
|
||||
parser.add_argument('--max_input_len', type=int, default=1024)
|
||||
parser.add_argument('--max_output_len', type=int, default=1024)
|
||||
@ -226,12 +181,16 @@ def parse_arguments(args):
|
||||
action='store_true',
|
||||
default=False)
|
||||
parser.add_argument('--parallel_build', default=False, action='store_true')
|
||||
parser.add_argument('--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument('--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--enable_context_fmha',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--enable_context_fmha_fp32_acc',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--multi_block_mode',
|
||||
default=False,
|
||||
@ -240,19 +199,18 @@ def parse_arguments(args):
|
||||
'Split long kv sequence into multiple blocks (applied to generation MHA kernels). \
|
||||
It is beneficial when batch x num_heads cannot fully utilize the GPU.'
|
||||
)
|
||||
parser.add_argument('--load_by_shard',
|
||||
action='store_true',
|
||||
help='Load a pretrained model shard-by-shard.')
|
||||
parser.add_argument('--visualize', default=False, action='store_true')
|
||||
parser.add_argument('--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--enable_debug_output',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument('--gpus_per_node', type=int, default=8)
|
||||
parser.add_argument('--builder_opt', type=int, default=None)
|
||||
parser.add_argument(
|
||||
'--output_dir',
|
||||
type=Path,
|
||||
default='trtModel',
|
||||
default=None,
|
||||
help=
|
||||
'The path to save the serialized engine files, timing cache file and model configs'
|
||||
)
|
||||
@ -263,9 +221,11 @@ def parse_arguments(args):
|
||||
help=
|
||||
'This option is introduced with trt 9.1.0.1+ and will reduce the building time significantly for fp8.'
|
||||
)
|
||||
parser.add_argument('--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true')
|
||||
parser.add_argument(
|
||||
'--remove_input_padding',
|
||||
default=False,
|
||||
action='store_true',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--paged_kv_cache',
|
||||
action="store_true",
|
||||
@ -277,7 +237,8 @@ def parse_arguments(args):
|
||||
'--use_inflight_batching',
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.")
|
||||
help="Activates inflight batching mode of gptAttentionPlugin.",
|
||||
)
|
||||
|
||||
# Arguments related to the quantization of the model.
|
||||
parser.add_argument(
|
||||
@ -293,17 +254,18 @@ def parse_arguments(args):
|
||||
default=False,
|
||||
action="store_true",
|
||||
help='Quantize weights for the various GEMMs to INT4/INT8.'
|
||||
'See --weight_only_precision to set the precision')
|
||||
'See --weight_only_precision to set the precision',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--weight_only_precision',
|
||||
const='int8',
|
||||
type=str,
|
||||
nargs='?',
|
||||
default='int8',
|
||||
choices=['int8', 'int4'],
|
||||
choices=['int8', 'int4', 'int4_awq'],
|
||||
help=
|
||||
'Define the precision for the weights when using weight-only quantization.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.'
|
||||
'You must also use --use_weight_only for that argument to have an impact.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_channel',
|
||||
@ -312,7 +274,8 @@ def parse_arguments(args):
|
||||
help=
|
||||
'By default, we use a single static scaling factor for the GEMM\'s result. '
|
||||
'per_channel instead uses a different static scaling factor for each channel. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
'The latter is usually more accurate, but a little slower.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_token',
|
||||
default=False,
|
||||
@ -320,7 +283,23 @@ def parse_arguments(args):
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale activations in the int8 range. '
|
||||
'per_token chooses at run time, and for each token, a custom scaling factor. '
|
||||
'The latter is usually more accurate, but a little slower.')
|
||||
'The latter is usually more accurate, but a little slower.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--per_group',
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=
|
||||
'By default, we use a single static scaling factor to scale weights in the int4 range. '
|
||||
'per_group chooses at run time, and for each group, a custom scaling factor. '
|
||||
'The flag is built for GPTQ/AWQ quantization.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size used in GPTQ/AWQ quantization.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--int8_kv_cache',
|
||||
default=False,
|
||||
@ -333,16 +312,20 @@ def parse_arguments(args):
|
||||
type=int,
|
||||
default=None,
|
||||
help=
|
||||
'Seed to use when initializing the random number generator for torch.')
|
||||
parser.add_argument('--tokens_per_block',
|
||||
type=int,
|
||||
default=64,
|
||||
help='Number of tokens per block in paged KV cache')
|
||||
'Seed to use when initializing the random number generator for torch.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--tokens_per_block',
|
||||
type=int,
|
||||
default=64,
|
||||
help='Number of tokens per block in paged KV cache',
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
'--enable_fp8',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help='Use FP8 Linear layer for Attention QKV/Dense and MLP.',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--fp8_kv_cache',
|
||||
@ -355,13 +338,16 @@ def parse_arguments(args):
|
||||
'--max_num_tokens',
|
||||
type=int,
|
||||
default=None,
|
||||
help='Define the max number of tokens supported by the engine')
|
||||
help='Define the max number of tokens supported by the engine',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--use_custom_all_reduce',
|
||||
action='store_true',
|
||||
help=
|
||||
'Activates latency-optimized algorithm for all-reduce instead of NCCL.')
|
||||
'Activates latency-optimized algorithm for all-reduce instead of NCCL.',
|
||||
)
|
||||
args = parser.parse_args(args)
|
||||
|
||||
logger.set_level(args.log_level)
|
||||
|
||||
plugins_args = [
|
||||
@ -377,6 +363,8 @@ def parse_arguments(args):
|
||||
)
|
||||
setattr(args, plugin_arg, args.dtype)
|
||||
|
||||
assert args.world_size == args.tp_size * args.pp_size # only TP is supported now
|
||||
|
||||
if args.model_dir is None:
|
||||
args.model_dir = args.model_name
|
||||
with open(Path(args.model_dir) / "config.json", "r") as f:
|
||||
@ -385,75 +373,96 @@ def parse_arguments(args):
|
||||
if args.model_name in ["chatglm_6b", "glm_10b"]:
|
||||
assert args.max_input_len < js["max_sequence_length"]
|
||||
|
||||
if args.output_dir is None:
|
||||
args.output_dir = Path("output_" + args.model_name)
|
||||
|
||||
if args.model_name in ["chatglm_6b"]:
|
||||
args.ffn_hidden_size = js["inner_hidden_size"]
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.norm_epsilon = js["layernorm_epsilon"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.vocab_size = js["vocab_size"]
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output(
|
||||
args.max_input_len, args.max_output_len, js["max_sequence_length"])
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.hidden_act = 'gelu'
|
||||
args.linear_bias = True
|
||||
args.multi_block_mode = False
|
||||
args.multi_query_mode = False
|
||||
args.num_kv_heads = js["num_attention_heads"]
|
||||
args.qkv_bias = True
|
||||
args.use_cache = js["use_cache"]
|
||||
elif args.model_name in ["glm_10b"]:
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.num_attention_heads = js["num_attention_heads"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.vocab_size = js["vocab_size"]
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output(
|
||||
args.max_input_len, args.max_output_len, js["max_sequence_length"],
|
||||
True)
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = False
|
||||
args.ffn_hidden_size = 4 * args.hidden_size
|
||||
args.ffn_hidden_size = js["inner_hidden_size"]
|
||||
args.hidden_act = 'gelu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = True
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["max_sequence_length"],
|
||||
)
|
||||
args.multi_block_mode = False
|
||||
args.multi_query_mode = False
|
||||
args.norm_epsilon = 1.0e-5
|
||||
args.norm_epsilon = js["layernorm_epsilon"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = True
|
||||
args.use_cache = True
|
||||
args.rmsnorm = False
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = js["use_cache"]
|
||||
args.vocab_size = js["vocab_size"]
|
||||
elif args.model_name in [
|
||||
"chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b", "chatglm3_6b_base",
|
||||
"chatglm3_6b_32k"
|
||||
"chatglm2_6b",
|
||||
"chatglm2_6b_32k",
|
||||
"chatglm3_6b",
|
||||
"chatglm3_6b_base",
|
||||
"chatglm3_6b_32k",
|
||||
]:
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = js[
|
||||
"apply_residual_connection_post_layernorm"]
|
||||
args.ffn_hidden_size = js["ffn_hidden_size"]
|
||||
args.hidden_act = 'swiglu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = js["add_bias_linear"]
|
||||
args.multi_query_mode = js["multi_query_attention"]
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["seq_length"],
|
||||
)
|
||||
args.multi_block_mode = False
|
||||
args.multi_query_mode = False # regardless of config.json
|
||||
args.norm_epsilon = js["layernorm_epsilon"]
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["multi_query_group_num"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = js["add_qkv_bias"]
|
||||
args.rmsnorm = js["rmsnorm"]
|
||||
args.use_cache = js["use_cache"]
|
||||
args.vocab_size = js["padded_vocab_size"]
|
||||
args.max_seq_length = min(args.max_input_len + args.max_output_len,
|
||||
js["seq_length"])
|
||||
if args.model_name in ["chatglm2_6b_32k", "chatglm3_6b_32k"]:
|
||||
args.rotary_embedding_scaling = js["rope_ratio"]
|
||||
args.hidden_act = 'swiglu'
|
||||
else:
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = js["use_cache"]
|
||||
args.vocab_size = js["padded_vocab_size"]
|
||||
elif args.model_name in ["glm_10b"]:
|
||||
args.apply_query_key_layer_scaling = False
|
||||
args.apply_residual_connection_post_layernorm = False
|
||||
args.ffn_hidden_size = 4 * js["hidden_size"]
|
||||
args.hidden_act = 'gelu'
|
||||
args.hidden_size = js["hidden_size"]
|
||||
args.linear_bias = True
|
||||
args.max_input_len, args.max_output_len, args.max_seq_length = truncate_input_output_len(
|
||||
args.max_input_len,
|
||||
args.max_output_len,
|
||||
js["max_sequence_length"],
|
||||
True,
|
||||
)
|
||||
args.multi_block_mode = False
|
||||
args.multi_query_mode = False
|
||||
args.norm_epsilon = 1.0e-5
|
||||
args.num_heads = js["num_attention_heads"]
|
||||
args.num_kv_heads = js["num_attention_heads"]
|
||||
args.num_layers = js["num_layers"]
|
||||
args.qkv_bias = True
|
||||
args.rmsnorm = False
|
||||
args.rotary_embedding_scaling = 1.0
|
||||
args.use_cache = True
|
||||
args.vocab_size = js["vocab_size"]
|
||||
|
||||
if args.use_inflight_batching:
|
||||
if not args.use_gpt_attention_plugin:
|
||||
args.use_gpt_attention_plugin = 'float16'
|
||||
logger.info(
|
||||
f"Using GPT attention plugin for inflight batching mode. "
|
||||
f"Setting to default '{args.use_gpt_attention_plugin}'")
|
||||
f"Using GPT attention plugin for inflight batching mode. Setting to default '{args.use_gpt_attention_plugin}'"
|
||||
)
|
||||
if not args.remove_input_padding:
|
||||
args.remove_input_padding = True
|
||||
logger.info(
|
||||
@ -478,18 +487,19 @@ def parse_arguments(args):
|
||||
if args.int8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_int8_kv_cache()
|
||||
|
||||
if args.fp8_kv_cache:
|
||||
assert (
|
||||
args.use_gpt_attention_plugin or args.use_inflight_batching
|
||||
), "You have to use GPT attention plugin when fp8 KV cache is set"
|
||||
elif args.fp8_kv_cache:
|
||||
args.quant_mode = args.quant_mode.set_fp8_kv_cache()
|
||||
|
||||
if args.enable_fp8:
|
||||
args.quant_mode = args.quant_mode.set_fp8_qdq()
|
||||
|
||||
if args.max_num_tokens is not None:
|
||||
assert args.enable_context_fmha
|
||||
|
||||
logger.info(' Build Arguments '.center(100, '='))
|
||||
for k, v in vars(args).items():
|
||||
logger.info(f' - {k.ljust(30, ".")}: {v}')
|
||||
logger.info('=' * 100)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
@ -510,40 +520,59 @@ def build_rank_engine(
|
||||
args.mapping = Mapping(
|
||||
world_size=args.world_size,
|
||||
rank=rank,
|
||||
tp_size=args.world_size,
|
||||
tp_size=args.tp_size,
|
||||
)
|
||||
assert args.num_layers % args.pp_size == 0, \
|
||||
f"num_layers {args.n_layer} must be a multiple of pipeline "\
|
||||
f"parallelism size {args.pp_size}"
|
||||
trtllm_model = ChatGLMHeadModel(args=args)
|
||||
trtllm_model = ChatGLMHeadModel(
|
||||
apply_query_key_layer_scaling=args.apply_query_key_layer_scaling,
|
||||
apply_residual_connection_post_layernorm=args.
|
||||
apply_residual_connection_post_layernorm,
|
||||
dtype=args.dtype,
|
||||
enable_debug_output=args.enable_debug_output,
|
||||
ffn_hidden_size=args.ffn_hidden_size,
|
||||
hidden_act=args.hidden_act,
|
||||
hidden_size=args.hidden_size,
|
||||
linear_bias=args.linear_bias,
|
||||
logits_dtype=args.logits_dtype,
|
||||
mapping=args.mapping,
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
max_seq_length=args.max_seq_length,
|
||||
model_name=args.model_name,
|
||||
norm_epsilon=args.norm_epsilon,
|
||||
num_heads=args.num_heads,
|
||||
num_kv_heads=args.num_kv_heads,
|
||||
num_layers=args.num_layers,
|
||||
qkv_bias=args.qkv_bias,
|
||||
quant_mode=args.quant_mode,
|
||||
rmsnorm=args.rmsnorm,
|
||||
rotary_embedding_scaling=args.rotary_embedding_scaling,
|
||||
tokens_per_block=args.tokens_per_block,
|
||||
use_cache=args.use_cache,
|
||||
vocab_size=args.vocab_size,
|
||||
)
|
||||
|
||||
if args.use_smooth_quant or args.use_weight_only:
|
||||
trtllm_model = quantize_model(trtllm_model, args.quant_mode)
|
||||
if args.enable_fp8 or args.fp8_kv_cache:
|
||||
elif args.enable_fp8 or args.fp8_kv_cache:
|
||||
logger.info(f'Loading scaling factors from '
|
||||
f'{args.quantized_fp8_model_path}')
|
||||
quant_scales = get_scaling_factors(args.quantized_fp8_model_path,
|
||||
num_layers=args.n_layer,
|
||||
quant_mode=args.quant_mode)
|
||||
tensorrt_llm_falcon = quantize_model(tensorrt_llm_falcon,
|
||||
quant_mode=args.quant_mode,
|
||||
quant_scales=quant_scales)
|
||||
if not args.load_by_shard:
|
||||
trtllm_model = load_from_hf(
|
||||
trtllm_model,
|
||||
args.model_dir,
|
||||
mapping=args.mapping,
|
||||
dtype=args.dtype,
|
||||
model_name=args.model_name,
|
||||
)
|
||||
else:
|
||||
trtllm_model = load_from_hf_checkpoint(
|
||||
trtllm_model,
|
||||
args.model_dir,
|
||||
mapping=args.mapping,
|
||||
dtype=args.dtype,
|
||||
model_name=args.model_name,
|
||||
)
|
||||
trtllm_model = quantize_model(trtllm_model,
|
||||
quant_mode=args.quant_mode,
|
||||
quant_scales=quant_scales)
|
||||
|
||||
trtllm_model = load_from_hf(
|
||||
trtllm_model,
|
||||
args.model_dir,
|
||||
mapping=args.mapping,
|
||||
dtype=args.dtype,
|
||||
model_name=args.model_name,
|
||||
)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
@ -552,12 +581,20 @@ def build_rank_engine(
|
||||
network.plugin_config.set_gpt_attention_plugin(
|
||||
dtype=args.use_gpt_attention_plugin)
|
||||
if args.use_gemm_plugin:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
if args.use_layernorm_plugin:
|
||||
network.plugin_config.set_layernorm_plugin(
|
||||
dtype=args.use_layernorm_plugin)
|
||||
if not args.enable_fp8:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.use_gemm_plugin)
|
||||
else:
|
||||
logger.info(
|
||||
"Gemm plugin does not support FP8. Disabled Gemm plugin.")
|
||||
if args.use_rmsnorm_plugin:
|
||||
network.plugin_config.set_rmsnorm_plugin(dtype=args.use_rmsnorm_plugin)
|
||||
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_rmsnorm_quantization_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
assert not (args.enable_context_fmha and args.enable_context_fmha_fp32_acc)
|
||||
if args.enable_context_fmha:
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
@ -566,7 +603,13 @@ def build_rank_engine(
|
||||
ContextFMHAType.enabled_with_fp32_acc)
|
||||
if args.multi_block_mode:
|
||||
network.plugin_config.enable_mmha_multi_block_mode()
|
||||
|
||||
if args.use_weight_only:
|
||||
if args.per_group:
|
||||
network.plugin_config.set_weight_only_groupwise_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
else:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype='float16')
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype,
|
||||
args.use_custom_all_reduce)
|
||||
@ -575,21 +618,6 @@ def build_rank_engine(
|
||||
if args.paged_kv_cache:
|
||||
network.plugin_config.enable_paged_kv_cache(args.tokens_per_block)
|
||||
|
||||
# Quantization plugins.
|
||||
if args.use_smooth_quant:
|
||||
network.plugin_config.set_smooth_quant_gemm_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_layernorm_quantization_plugin(
|
||||
dtype=args.dtype)
|
||||
network.plugin_config.set_quantize_tensor_plugin()
|
||||
network.plugin_config.set_quantize_per_token_plugin()
|
||||
elif args.use_weight_only:
|
||||
network.plugin_config.set_weight_only_quant_matmul_plugin(
|
||||
dtype=args.dtype)
|
||||
|
||||
if args.world_size > 1:
|
||||
network.plugin_config.set_nccl_plugin(args.dtype,
|
||||
args.use_custom_all_reduce)
|
||||
|
||||
with net_guard(network):
|
||||
# Prepare
|
||||
network.set_named_parameters(trtllm_model.named_parameters())
|
||||
@ -605,7 +633,7 @@ def build_rank_engine(
|
||||
trtllm_model(*inputs)
|
||||
if args.enable_debug_output:
|
||||
# mark intermediate nodes' outputs
|
||||
for k, v in tensorrt_llm_falcon.named_network_outputs():
|
||||
for k, v in trtllm_model.named_network_outputs():
|
||||
v = v.trt_tensor
|
||||
v.name = k
|
||||
network.trt_network.mark_output(v)
|
||||
@ -616,24 +644,21 @@ def build_rank_engine(
|
||||
|
||||
tensorrt_llm.graph_rewriting.optimize(network)
|
||||
|
||||
engine = None
|
||||
|
||||
# Network -> Engine
|
||||
engine = None
|
||||
engine = builder.build_engine(network, builder_config)
|
||||
if rank == 0:
|
||||
config_path = args.output_dir / (args.model_name + '-config.json')
|
||||
config_path = args.output_dir / 'config.json'
|
||||
builder.save_config(builder_config, config_path)
|
||||
|
||||
tensorrt_llm.tools.cleanup(network, trtllm_model)
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
def build(rank, args):
|
||||
torch.cuda.set_device(rank % args.gpus_per_node)
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
logger.set_level(args.log_level)
|
||||
args.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
timing_cache_file = args.output_dir / "model.cache"
|
||||
timing_cache_file = args.timing_cache
|
||||
timing_cache = timing_cache_file
|
||||
|
||||
builder = Builder()
|
||||
@ -642,12 +667,15 @@ def build(rank, args):
|
||||
# skip other ranks if parallel_build is enabled
|
||||
if args.parallel_build and cur_rank != rank:
|
||||
continue
|
||||
# NOTE: when only int8 kv cache is used together with paged kv cache no int8 tensors are exposed to TRT
|
||||
int8_trt_flag = args.quant_mode.has_act_or_weight_quant() or (
|
||||
not args.paged_kv_cache and args.quant_mode.has_int8_kv_cache())
|
||||
builder_config = builder.create_builder_config(
|
||||
precision=args.dtype,
|
||||
timing_cache=timing_cache,
|
||||
tensor_parallel=args.world_size,
|
||||
int8=(args.quant_mode.has_act_or_weight_quant()
|
||||
or args.quant_mode.has_int8_kv_cache()),
|
||||
tensor_parallel=args.tp_size,
|
||||
pipeline_parallel=args.pp_size,
|
||||
int8=int8_trt_flag,
|
||||
fp8=args.enable_fp8,
|
||||
strongly_typed=args.strongly_typed,
|
||||
opt_level=args.builder_opt,
|
||||
@ -678,6 +706,7 @@ def build(rank, args):
|
||||
args.model_name,
|
||||
args.dtype,
|
||||
args.world_size,
|
||||
args.pp_size,
|
||||
cur_rank,
|
||||
)
|
||||
engine = build_rank_engine(
|
||||
@ -706,7 +735,7 @@ def build(rank, args):
|
||||
max_input_len=args.max_input_len,
|
||||
max_output_len=args.max_output_len,
|
||||
local_num_kv_heads=local_num_kv_heads,
|
||||
head_size=args.hidden_size / args.num_heads,
|
||||
head_size=args.hidden_size // args.num_heads,
|
||||
num_layers=args.num_layers)
|
||||
|
||||
if cur_rank == 0:
|
||||
@ -734,8 +763,8 @@ def run_build(args=None):
|
||||
if args.parallel_build and args.world_size > 1 and \
|
||||
torch.cuda.device_count() >= args.world_size:
|
||||
logger.warning(
|
||||
f'Parallelly build TensorRT engines. Please make sure that all '
|
||||
f'of the {args.world_size} GPUs are totally free.')
|
||||
f'Parallelly build TensorRT engines. Please make sure that all of the {args.world_size} GPUs are totally free.'
|
||||
)
|
||||
mp.spawn(build, nprocs=args.world_size, args=(args, ))
|
||||
else:
|
||||
args.parallel_build = False
|
||||
|
||||
@ -93,40 +93,50 @@ def get_model(ckpt_path, dtype="float16", cache_dir=None):
|
||||
return model
|
||||
|
||||
|
||||
def get_args():
|
||||
def parse_arguments(args):
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument("--model_dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Directory of a HF model checkpoint")
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument(
|
||||
"--qformat",
|
||||
'--model_name',
|
||||
'-m',
|
||||
type=str,
|
||||
choices=['fp8', 'int4_awq'],
|
||||
default='fp8',
|
||||
help='Quantization format. Currently only fp8 is supported. '
|
||||
'For int8 smoothquant, use smoothquant.py instead. ')
|
||||
required=True,
|
||||
choices=[
|
||||
"chatglm_6b", "chatglm2_6b", "chatglm2_6b_32k", "chatglm3_6b",
|
||||
"chatglm3_6b_base", "chatglm3_6b_32k", "glm_10b"
|
||||
],
|
||||
help=
|
||||
'the name of the model, use "_" rather than "-" to connect the name parts'
|
||||
)
|
||||
parser.add_argument("--dtype", help="Model data type.", default="float16")
|
||||
parser.add_argument("--qformat",
|
||||
type=str,
|
||||
choices=['fp8', 'int4_awq'],
|
||||
default='int4_awq',
|
||||
help='Quantization format.'
|
||||
'For int8 smoothquant, use smoothquant.py instead.')
|
||||
parser.add_argument("--calib_size",
|
||||
type=int,
|
||||
default=512,
|
||||
default=32,
|
||||
help="Number of samples for calibration.")
|
||||
parser.add_argument("--export_path", default="exported_model")
|
||||
parser.add_argument('--model_dir', type=str, default=None)
|
||||
parser.add_argument("--export_path", default="awq")
|
||||
parser.add_argument("--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
default="dataset/",
|
||||
help="Directory of dataset cache.")
|
||||
parser.add_argument('--seed', type=int, default=None, help='Random seed')
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
def main(args=None):
|
||||
if not torch.cuda.is_available():
|
||||
raise EnvironmentError("GPU is required for inference.")
|
||||
|
||||
args = get_args()
|
||||
args = parse_arguments(args)
|
||||
|
||||
if args.model_dir is None:
|
||||
args.model_dir = args.model_name
|
||||
if args.seed is not None:
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
@ -43,8 +43,9 @@ def parse_arguments(args=None):
|
||||
)
|
||||
parser.add_argument('--max_output_len', type=int, default=1024)
|
||||
parser.add_argument('--log_level', type=str, default='error')
|
||||
parser.add_argument('--engine_dir', type=str, default='trtModel')
|
||||
parser.add_argument('--engine_dir', type=str, default=None)
|
||||
parser.add_argument('--beam_width', type=int, default=1)
|
||||
parser.add_argument('--streaming', default=False, action='store_true')
|
||||
parser.add_argument(
|
||||
'--input_text',
|
||||
type=str,
|
||||
@ -71,14 +72,20 @@ def parse_arguments(args=None):
|
||||
parser.add_argument('--top_k', type=int, default=1)
|
||||
parser.add_argument('--top_p', type=float, default=0.0)
|
||||
parser.add_argument('--random_seed', type=int, default=1)
|
||||
return parser.parse_args(args)
|
||||
|
||||
args = parser.parse_args(args)
|
||||
|
||||
if args.engine_dir is None:
|
||||
args.engine_dir = Path("output_" + args.model_name)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
args = parse_arguments()
|
||||
tensorrt_llm.logger.set_level(args.log_level)
|
||||
|
||||
config_path = Path(args.engine_dir) / (args.model_name + '-config.json')
|
||||
config_path = Path(args.engine_dir) / 'config.json'
|
||||
with open(config_path, 'r') as f:
|
||||
config = json.load(f)
|
||||
|
||||
@ -89,9 +96,11 @@ if __name__ == '__main__':
|
||||
max_beam_width = config['builder_config']['max_beam_width']
|
||||
remove_input_padding = config['builder_config']['remove_input_padding']
|
||||
use_gpt_attention_plugin = config['plugin_config']['gpt_attention_plugin']
|
||||
world_size = config['builder_config']['tensor_parallel']
|
||||
assert world_size == tensorrt_llm.mpi_world_size(
|
||||
), f'Engine world size ({world_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
tp_size = config['builder_config']['tensor_parallel']
|
||||
pp_size = config['builder_config']['pipeline_parallel']
|
||||
world_size = tp_size * pp_size
|
||||
assert world_size == tensorrt_llm.mpi_world_size(), \
|
||||
f'Engine world size ({tp_size} * {pp_size}) != Runtime world size ({tensorrt_llm.mpi_world_size()})'
|
||||
|
||||
if args.max_output_len > max_output_len:
|
||||
print("Truncate max_output_len as %d" % max_output_len)
|
||||
@ -199,9 +208,10 @@ if __name__ == '__main__':
|
||||
model_config = ModelConfig(
|
||||
vocab_size=config['builder_config']['vocab_size'],
|
||||
num_layers=config['builder_config']['num_layers'],
|
||||
num_heads=config['builder_config']['num_heads'] // world_size,
|
||||
num_kv_heads=config['builder_config']['num_kv_heads'] // world_size,
|
||||
hidden_size=config['builder_config']['hidden_size'] // world_size,
|
||||
num_heads=config['builder_config']['num_heads'] // tp_size,
|
||||
num_kv_heads=(config['builder_config']['num_kv_heads'] + tp_size - 1) //
|
||||
tp_size,
|
||||
hidden_size=config['builder_config']['hidden_size'] // tp_size,
|
||||
gpt_attention_plugin=use_gpt_attention_plugin,
|
||||
remove_input_padding=config['builder_config']['remove_input_padding'],
|
||||
model_name=args.model_name,
|
||||
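# Aside (illustrates the num_kv_heads change above; not part of the script):
# ceil-division keeps at least one KV head per rank when the head count does
# not divide evenly, e.g. with num_kv_heads = 2 and tp_size = 8:
#   (2 + 8 - 1) // 8 == 1   # new: every rank still allocates one KV head
#   2 // 8 == 0             # old: would size the KV cache to zero heads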
@ -251,11 +261,8 @@ if __name__ == '__main__':
|
||||
sampling_config,
|
||||
output_sequence_lengths=True,
|
||||
return_dict=True,
|
||||
streaming=args.streaming,
|
||||
)
|
||||
torch.cuda.synchronize()
|
||||
|
||||
output_ids = output["output_ids"]
|
||||
output_lengths = output["sequence_lengths"]
|
||||
|
||||
if runtime_rank == 0:
|
||||
|
||||
@ -271,18 +278,44 @@ if __name__ == '__main__':
|
||||
]:
|
||||
from process import process_response
|
||||
|
||||
for i in range(batch_size):
|
||||
print("\nInput %2d ---> len=%d\n%s" %
|
||||
(i, input_lengths[i], input_text[i]))
|
||||
if args.streaming: # streaming output
|
||||
print("#" * 80)
|
||||
# only print the first batch and the first beam to show the effect,
|
||||
# actually all beams of all batches are available
|
||||
print("Input %2d ---> len=%d\n%s" %
|
||||
(0, input_lengths[0], input_text[0]))
|
||||
print("\nOutput %2d --->" % i)
|
||||
output_ids_one_batch = output_ids[i, :, input_lengths[i]:]
|
||||
output_lengths_one_batch = output_lengths[i]
|
||||
output_token_list = tokenizer.batch_decode(output_ids_one_batch,
|
||||
skip_special_tokens=True)
|
||||
output_token_list = process_response(output_token_list)
|
||||
for j, (length, simple_output) in enumerate(
|
||||
zip(output_lengths_one_batch, output_token_list)):
|
||||
print("\n Beam %2d ---> len=%d\n%s" %
|
||||
(j, length, simple_output))
|
||||
|
||||
print("Finished!")
|
||||
for output_item in output:
|
||||
output_id = output_item["output_ids"]
|
||||
output_sequence_lengths = output_item["sequence_lengths"]
|
||||
output_id = output_id[0, 0, output_sequence_lengths[0, 0] - 1]
|
||||
output_word = tokenizer.convert_ids_to_tokens(int(output_id))
|
||||
output_word = output_word.replace("▁", " ") # For English
|
||||
output_word = tokenizer.convert_tokens_to_string(output_word)
|
||||
print(output_word, end="", flush=True)
|
||||
print("\n" + "#" * 80)
|
||||
else: # regular output
|
||||
torch.cuda.synchronize()
|
||||
output_ids = output["output_ids"]
|
||||
output_lengths = output["sequence_lengths"]
|
||||
print("#" * 80)
|
||||
for i in range(batch_size):
|
||||
print("Input %2d ---> len=%d\n%s" %
|
||||
(i, input_lengths[i], input_text[i]))
|
||||
print("\nOutput %2d --->" % i)
|
||||
output_ids_one_batch = output_ids[i, :, input_lengths[i]:]
|
||||
output_lengths_one_batch = output_lengths[i] - input_lengths[
|
||||
i] + 1
|
||||
output_token_list = tokenizer.batch_decode(
|
||||
output_ids_one_batch, skip_special_tokens=True)
|
||||
output_token_list = process_response(output_token_list)
|
||||
for j, (length, simple_output) in enumerate(
|
||||
zip(output_lengths_one_batch, output_token_list)):
|
||||
print(" Beam %2d ---> len=%d\n%s" %
|
||||
(j, length, simple_output))
|
||||
print("#" * 80)
|
||||
|
||||
del decoder
|
||||
|
||||
print(f"Finished from worker {runtime_rank}")
|
||||
|
||||
73
examples/chatglm/visualize.py
Normal file
@ -0,0 +1,73 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import onnx
import tensorrt as trt
from onnx import TensorProto, helper


def trt_dtype_to_onnx(dtype):
if dtype == trt.float16:
return TensorProto.DataType.FLOAT16
elif dtype == trt.float32:
return TensorProto.DataType.FLOAT
elif dtype == trt.int32:
return TensorProto.DataType.INT32
else:
raise TypeError("%s is not supported" % dtype)


def to_onnx(network, path):
inputs = []
for i in range(network.num_inputs):
network_input = network.get_input(i)
inputs.append(
helper.make_tensor_value_info(
network_input.name, trt_dtype_to_onnx(network_input.dtype),
list(network_input.shape)))

outputs = []
for i in range(network.num_outputs):
network_output = network.get_output(i)
outputs.append(
helper.make_tensor_value_info(
network_output.name, trt_dtype_to_onnx(network_output.dtype),
list(network_output.shape)))

nodes = []
for i in range(network.num_layers):
layer = network.get_layer(i)
layer_inputs = []
for j in range(layer.num_inputs):
ipt = layer.get_input(j)
if ipt is not None:
layer_inputs.append(layer.get_input(j).name)
layer_outputs = [
layer.get_output(j).name for j in range(layer.num_outputs)
]
nodes.append(
helper.make_node(str(layer.type),
name=layer.name,
inputs=layer_inputs,
outputs=layer_outputs,
domain="com.nvidia"))

onnx_model = helper.make_model(helper.make_graph(nodes,
'attention',
inputs,
outputs,
initializer=None),
producer_name='NVIDIA')
onnx.save(onnx_model, path)
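The new visualize.py exports the layer structure of a TensorRT network as a weight-free ONNX graph so it can be opened in a graph viewer such as Netron. A small usage sketch follows; the builder/network wiring is an illustrative assumption, and only `to_onnx(network, path)` comes from the file above (it expects the underlying TensorRT `INetworkDefinition`).

```python
import tensorrt_llm

from visualize import to_onnx  # helper added in this commit

builder = tensorrt_llm.Builder()
network = builder.create_network()
# ... populate `network` with a model definition, e.g. inside
# `with tensorrt_llm.net_guard(network):` ...

# Dump the layer graph (no weights) for inspection in Netron or a similar viewer.
to_onnx(network.trt_network, "chatglm_network.onnx")
```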
@ -29,6 +29,15 @@ from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode


def split(weight: np.ndarray, tp_size: int, rank: int = 0, dim: int = 0):
if tp_size == 1:
return weight
elif weight.ndim == 1:
return np.ascontiguousarray(np.split(weight, tp_size)[rank].copy())
return np.ascontiguousarray(
np.split(weight, tp_size, axis=dim)[rank].copy())


def split_matrix(weight: np.ndarray, tp_size: int, rank: int, dim: int):
return np.ascontiguousarray(split(weight, tp_size, rank, dim=dim))
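The new `split`/`split_matrix` helpers shard a weight across tensor-parallel ranks along a chosen dimension and return a dense, C-contiguous array. A quick illustration, assuming the helpers above are in scope; the shapes and values are made up for the example.

```python
import numpy as np

# A toy [out_features, in_features] weight as it might appear in a checkpoint.
weight = np.arange(32, dtype=np.float32).reshape(4, 8)

# 2-way tensor parallelism; this process is rank 1 and shards the input dimension.
shard = split_matrix(weight, tp_size=2, rank=1, dim=1)

print(shard.shape)                   # (4, 4) -- each rank keeps half of dim=1
print(shard.flags["C_CONTIGUOUS"])   # True -- np.split returns views, so the helpers copy
```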
@ -81,7 +90,7 @@ def load_quant_weight(src, value_dst, scale_dst, plugin_weight_only_quant_type):
def load_from_hf(
trt_model,
hf_model_dir,
mapping=None,
mapping=Mapping(),
dtype="float32",
model_name=None,
multi_query_mode=False,
@ -100,9 +109,9 @@ def load_from_hf(

hf_model = transformers.AutoModel.from_pretrained(hf_model_dir,
trust_remote_code=True)
num_layers = hf_model.config.num_layers
hidden_size = hf_model.config.hidden_size
num_heads = hf_model.config.num_attention_heads
num_layers = hf_model.config.num_layers

torch_type = str_dtype_to_torch(dtype)
quant_mode = getattr(trt_model, 'quant_mode', QuantMode(0))
@ -344,7 +353,8 @@ def load_from_hf(
scale_dst=dst.per_channel_scale,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
else:
dst.weight.value = torch_to_numpy(split_weight)
dst.weight.value = np.ascontiguousarray(
torch_to_numpy(split_weight))
feed_weight_count += 1

# Dense multiplication bias, only GLM-10B
@ -462,7 +472,8 @@ def load_from_hf(
scale_dst=dst.per_channel_scale,
plugin_weight_only_quant_type=plugin_weight_only_quant_type)
else:
dst.weight.value = torch_to_numpy(split_weight)
dst.weight.value = np.ascontiguousarray(
torch_to_numpy(split_weight))
feed_weight_count += 1

# Multilayer perceptron 4h -> h bias, only GLM-10B
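The two hunks above wrap `torch_to_numpy(split_weight)` in `np.ascontiguousarray(...)`. The diff does not state the motivation, but a plausible reading is that a split or transposed torch tensor can convert to a non-contiguous NumPy view, which is risky to hand to the engine's weight setters. A tiny demonstration, with illustrative names and values:

```python
import numpy as np
import torch

split_weight = torch.arange(12.0).reshape(3, 4).t()  # a transposed, non-contiguous tensor
as_np = split_weight.numpy()                          # shares memory, inherits the strided layout

print(as_np.flags["C_CONTIGUOUS"])                         # False
print(np.ascontiguousarray(as_np).flags["C_CONTIGUOUS"])   # True -- a dense copy
```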
@ -498,115 +509,6 @@ def load_from_hf(
return trt_model


def load_from_hf_checkpoint(
trtllm_falcon: tensorrt_llm.models.FalconForCausalLM,
model_dir: Union[str, Path],
mapping=Mapping(),
dtype: Union[str, torch.dtype] = torch.float32,
):
logger.info('Loading weights from HF Falcon...')
tik = time.time()

model_dir = Path(model_dir)
if isinstance(dtype, str):
dtype = tensorrt_llm._utils.str_dtype_to_torch(dtype)

def is_bias(_name):
return 'bias' in _name

layers_range = trtllm_falcon.get_transformer_layers(
trtllm_falcon.mapping, trtllm_falcon.num_layers)
for model_file in iterate_shard_files(model_dir, mapping.tp_rank):
logger.debug(f'Loading file {str(model_file)}...')
state_dict = load_state_dict(model_file, dtype)
for name, param in state_dict.items():
logger.debug(f'Converting weight {name}...')
i = retrieved_layer_index_from_name(name)
if i is None:
layer = None
else:
if i not in layers_range:
continue
layer = trtllm_falcon.layers[i - layers_range[0]]

if 'self_attention.query_key_value' in name:
if not is_bias(name):
layer.attention.qkv.weight.value = split_qkv_weight(
trtllm_falcon,
param,
mapping.tp_size,
mapping.tp_rank,
is_bias=False,
num_kv_heads=trtllm_falcon.num_kv_heads)
else:
layer.attention.qkv.bias.value = split_qkv_weight(
trtllm_falcon,
param,
mapping.tp_size,
mapping.tp_rank,
is_bias=True,
num_kv_heads=trtllm_falcon.num_kv_heads)
elif 'self_attention.dense' in name:
if not is_bias(name):
layer.attention.dense.weight.value = split_matrix(
param, mapping.tp_size, mapping.tp_rank, dim=1)
else:
layer.attention.dense.bias.value = param
elif 'mlp.dense_h_to_4h' in name:
if not is_bias(name):
layer.mlp.fc.weight.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
else:
layer.mlp.fc.bias.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=0)
elif 'mlp.dense_4h_to_h' in name:
if not is_bias(name):
layer.mlp.proj.weight.value = split_matrix(param,
mapping.tp_size,
mapping.tp_rank,
dim=1)
else:
layer.mlp.proj.bias.value = param
elif 'ln_attn' in name or 'input_layernorm' in name:
if not is_bias(name):
layer.input_layernorm.weight.value = param
else:
layer.input_layernorm.bias.value = param
elif 'ln_mlp' in name:
assert layer.mlp_layernorm is not None
if not is_bias(name):
layer.mlp_layernorm.weight.value = param
else:
layer.mlp_layernorm.bias.value = param
elif 'post_attention_layernorm' in name:
assert layer.post_layernorm is not None
if not is_bias(name):
layer.post_layernorm.weight.value = param
else:
layer.post_layernorm.bias.value = param
elif 'word_embeddings' in name:
if mapping.is_first_pp_rank():
trtllm_falcon.embedding.weight.value = param.copy()
if mapping.is_last_pp_rank():
trtllm_falcon.lm_head.weight.value = split_matrix(
param, mapping.tp_size, mapping.tp_rank, dim=0)
elif 'ln_f' in name:
if mapping.is_last_pp_rank():
if not is_bias(name):
trtllm_falcon.ln_f.weight.value = param
else:
trtllm_falcon.ln_f.bias.value = param
del state_dict

tok = time.time()
t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
logger.info(f'Weights loaded. Total time: {t}')


def get_scaling_factors(
model_path: Union[str, Path],
num_layers: int,
@ -74,14 +74,14 @@ python build.py --model_type t5 \
--dtype float32 \
--max_beam_width 1

# Example 2: build flan-t5-small using 4-way tensor parallelism on a node with 8 GPUs (but only use 4 of them, for demonstration purpose), BF16, enabling beam search up to width=3
# Example 2: build t5-small using 4-way tensor parallelism on a node with 8 GPUs (but only use 4 of them, for demonstration purpose), BF16, enabling beam search up to width=3
python build.py --model_type t5 \
--world_size 4 \
--tp_size 4 \
--gpus_per_node 4 \
--weight_dir tmp/trt_models/flan-t5-small/tp4 \
-o tmp/trt_engines/flan-t5-small/4-gpu \
--engine_name flan-t5-small \
--weight_dir tmp/trt_models/t5-small/tp4 \
-o tmp/trt_engines/t5-small/4-gpu \
--engine_name t5-small \
--remove_input_padding \
--use_bert_attention_plugin \
--use_gpt_attention_plugin \
@ -90,7 +90,7 @@ python build.py --model_type t5 \
--dtype bfloat16 \
--max_beam_width 3

# Example 3: build flan-t5-small using 2-way tensor parallelism and 2-way pipeline parallelism on a node with 8 GPUs, FP16, enabling beam search up to width=3
# Example 3: build flan-t5-small using 2-way tensor parallelism and 2-way pipeline parallelism on a node with 8 GPUs, BF16, enabling beam search up to width=3
python build.py --model_type t5 \
--world_size 4 \
--tp_size 2 \
@ -104,7 +104,7 @@ python build.py --model_type t5 \
--use_gpt_attention_plugin \
--use_gemm_plugin \
--use_rmsnorm_plugin \
--dtype float16 \
--dtype bfloat16 \
--max_beam_width 3
```
@ -117,9 +117,15 @@ Note that during model deployment, only the TensorRT engine files are needed. Pr
# Example 1: inference w/ single GPU, FP32, greedy search, compare results with HuggingFace FP32
python3 run.py --engine_dir tmp/trt_engines/t5-small/1-gpu/float32/tp1 --engine_name t5-small --model_name t5-small --max_new_token=64 --num_beams=1 --compare_hf_fp32

# Example 2: inference w/ 4 GPUs (4-way TP, as configured during the engine building step), BF16, greedy search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/bfloat16/tp4 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=1
# Example 2: inference w/ 4 GPUs (4-way TP, as configured during the engine building step), BF16, greedy search, compare results with HuggingFace FP32
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/t5-small/4-gpu/bfloat16/tp4 --engine_name t5-small --model_name t5-small --max_new_token=64 --num_beams=1 --compare_hf_fp32

# Example 3: inference w/ 4 GPUs (2-way TP and 2-way PP, as configured during the engine building step), FP16, beam search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/float16/tp2 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=3
# Example 3: inference w/ 4 GPUs (2-way TP and 2-way PP, as configured during the engine building step), BF16, greedy search
mpirun --allow-run-as-root -np 4 python3 run.py --engine_dir tmp/trt_engines/flan-t5-small/4-gpu/bfloat16/tp2 --engine_name flan-t5-small --model_name google/flan-t5-small --max_new_token=64 --num_beams=1
```

### Reminders

- Flan-T5 models have known issues with FP16 precision, independent of TRT-LLM, so BF16 is recommended. While we are working on improving the FP16 results, please stay with FP32 or BF16 precision for the Flan-T5 family.
- Batched/ragged input with beam search has a subtle issue where some sequence results can be truncated. For the time being, please follow these rules: (1) if batch size = 1, there is no problem; (2) if the batched input is padded (i.e., not using the `--remove_input_padding` flag), there is no problem; (3) if the batched input is ragged (i.e., using `--remove_input_padding`), use only greedy search for now.
- For the T5 and Flan-T5 families, which use a relative attention bias design, the relative attention table is split along the `num_heads` dimension in Tensor Parallelism mode. Therefore, `num_heads` must be divisible by `tp_size`; please be aware of this when setting the TP parameter (see the short check after this list).
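A small way to sanity-check the last reminder before building; the model name and `tp_size` value are illustrative, and HuggingFace T5-style configs are assumed to expose the head count as `num_heads`.

```python
from transformers import AutoConfig

tp_size = 4
cfg = AutoConfig.from_pretrained("t5-small")  # t5-small has 8 attention heads

assert cfg.num_heads % tp_size == 0, (
    f"num_heads={cfg.num_heads} is not divisible by tp_size={tp_size}; "
    "choose a tp_size that divides the head count")
```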
Some files were not shown because too many files have changed in this diff.