Update TensorRT-LLM (#1098)

* Update TensorRT-LLM

* update submodule

* Remove unused binaries
This commit is contained in:
Kaiyu Xie 2024-02-18 15:48:08 +08:00 committed by GitHub
parent 0ab9d17a59
commit 0f041b7b57
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
231 changed files with 11513 additions and 4728 deletions

3rdparty/cutlass vendored

@ -1 +1 @@
Subproject commit 39c6a83f231d6db2bc6b9c251e7add77d68cbfb4
Subproject commit 8236f30675bbe98f81d11c05764b77bfcb25b8cc

View File

@ -73,7 +73,7 @@ Run a preprocessing script to prepare/generate dataset into a json that gptManag
This tool supports two different modes of traffic generation.
1 Dataset
##### 1 Dataset
“Prompt”, “Instruction” (optional) and “Answer” are specified as sentences in a JSON file
@ -90,7 +90,7 @@ python3 prepare_dataset.py \
--max-input-len 300
```
2 Normal token length distribution
##### 2 Normal token length distribution
This mode generates requests whose token lengths follow a normal distribution with a user-specified mean and standard deviation.
For example, setting mean=100 and std dev=10 generates requests where 95.4% of the token lengths fall in the <80,120> range (i.e. within two standard deviations of the mean). Setting std dev=0 generates all requests with exactly the mean number of tokens.
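As a rough illustration of this sampling rule (the actual tool is `prepare_dataset.py`, which is Python; the C++ below is only a sketch, and `sampleInputLen` is a hypothetical helper, not part of the tool):
```
#include <algorithm>
#include <cmath>
#include <random>

// Draw one input length from N(mean, stdDev); stdDev == 0 degenerates to the mean,
// matching the behavior described above.
int sampleInputLen(std::mt19937& gen, double mean, double stdDev)
{
    if (stdDev == 0.0)
    {
        return static_cast<int>(mean);
    }
    std::normal_distribution<double> dist(mean, stdDev);
    // With mean=100 and stdDev=10, ~95.4% of draws land in [80, 120] (mean +/- 2 sigma).
    return std::max(1, static_cast<int>(std::lround(dist(gen))));
}
```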
@ -140,3 +140,17 @@ mpirun -n 2 ./benchmarks/gptManagerBenchmark \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json \
--max_num_samples 500
```
To emulate `gptSessionBenchmark` static batching, use the `--static_emulated_batch_size` and `--static_emulated_timeout` arguments.
Given a `static_emulated_batch_size` of `n`, the server waits for `n` requests to arrive before submitting them to the batch manager at once. If the `static_emulated_timeout` (in ms) is reached before `n` requests are collected, the batch is submitted early with the requests collected so far.
Take GPT-350M as an example of running emulated static batching on a single GPU:
```
./benchmarks/gptManagerBenchmark \
--model gpt \
--engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
--type IFB \
--static_emulated_batch_size 32 \
--static_emulated_timeout 100 \
--dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
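The dispatch rule behind these two flags (implemented in `gptManagerBenchmark.cpp`, visible further down in this diff) can be summarized as the sketch below; names are shortened and this is not the verbatim implementation.
```
#include <chrono>
#include <cstdint>
#include <optional>

// Submit the next emulated static batch only when the previous batch has drained
// and we either have a full batch or the timeout has expired.
bool readyForNextBatch(int64_t numPending, int64_t numActive,
    std::optional<int64_t> staticBatchSize, std::chrono::steady_clock::time_point batchDeadline)
{
    if (!staticBatchSize)
    {
        return numPending > 0; // regular in-flight batching: hand over whatever is pending
    }
    bool const timeout = std::chrono::steady_clock::now() > batchDeadline;
    bool const previousBatchFinished = numActive == 0;
    bool const haveFullBatch = numPending >= staticBatchSize.value();
    return numPending > 0 && previousBatchFinished && (timeout || haveFullBatch);
}
```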

Binary file not shown.

View File

@ -27,12 +27,13 @@
#include "tensorrt_llm/runtime/worldConfig.h"
#include <chrono>
#include <cstdint>
#include <cxxopts.hpp>
#include <iostream>
#include <nlohmann/json.hpp>
#include <random>
#include <string>
#include <thread>
#include <utility>
using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
@ -104,7 +105,7 @@ public:
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
"requestId %lu is already in progress, request is ignored.", requestId);
@ -119,14 +120,14 @@ public:
/// has been marked in progress
std::tuple<std::shared_ptr<WorkItem>, bool> pop()
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
auto workItem = mPendingWorkItems.front();
mPendingWorkItems.pop_front();
mPendingWorkItemsReqIds.erase(workItem->requestId());
bool markedInProgress;
mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem));
bool markedInProgress = false;
mInProgressWorkItems.emplace(workItem->requestId(), workItem);
markedInProgress = true;
return {workItem, markedInProgress};
@ -134,19 +135,19 @@ public:
size_t numPendingWorkItems() const
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
return mPendingWorkItems.size();
}
size_t numInProgressWorkItems() const
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
return mInProgressWorkItems.size();
}
size_t size() const
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
return mPendingWorkItems.size() + mInProgressWorkItems.size();
}
@ -154,7 +155,7 @@ public:
/// @param requestId
void markFinished(const uint64_t requestId)
{
std::lock_guard<std::mutex> lk(mMutex);
std::lock_guard<std::mutex> lock(mMutex);
if (hasInProgressReqId(requestId))
{
mInProgressWorkItems.erase(requestId);
@ -175,12 +176,13 @@ private:
struct BenchInfo
{
BenchInfo() {}
BenchInfo() = default;
BenchInfo(int _inputLength, int _outputLength, std::chrono::time_point<std::chrono::steady_clock> _start)
: inputLength(_inputLength)
, outputLength(_outputLength)
, start(_start)
, latency()
{
}
@ -194,9 +196,9 @@ struct BenchInfo
class Recorder
{
public:
Recorder(std::string opCsvFile)
explicit Recorder(std::string opCsvFile)
: mOpCsvFile(std::move(opCsvFile))
{
mOpCsvFile = opCsvFile;
}
void initialize()
@ -286,11 +288,11 @@ private:
std::chrono::time_point<std::chrono::steady_clock> mStart;
std::chrono::time_point<std::chrono::steady_clock> mEnd;
int mNumSamples;
float mTotalLatency;
float mSeqThroughput;
float mAvgSeqLatency;
float mTokenThroughput;
int mNumSamples{};
float mTotalLatency{};
float mSeqThroughput{};
float mAvgSeqLatency{};
float mTokenThroughput{};
std::string mOpCsvFile;
}; // class Recorder
@ -299,25 +301,39 @@ class GptServer
public:
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, int waitSleep)
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
: mRecorder(std::move(recorder))
, mTerminateReqId(terminateReqId)
, mWaitSleep(waitSleep)
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
, mEmulatedBatchEndTimestamp(
std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs))
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
, mActiveCount(0)
{
ReturnBatchManagerStatsCallback iterationDataCallback{nullptr};
if (optionalParams.logIterationData)
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &optionalParams](std::string const& log)
{
iterationDataCallback = [this](const std::string& s) { return TLLM_LOG_INFO(s); };
}
if (optionalParams.logIterationData)
{
TLLM_LOG_INFO(log);
}
if (mStaticEmulatedBatchSize)
{
auto const json = nlohmann::json::parse(log);
auto const activeRequests = json["Active Request Count"];
TLLM_CHECK(activeRequests <= mStaticEmulatedBatchSize.value());
}
};
mBatchManager = std::make_shared<GptManager>(
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
[this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
const std::string& errMsg)
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
std::string const& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, iterationDataCallback, optionalParams, terminateReqId);
mRecorder = recorder;
mTerminateReqId = terminateReqId;
mWaitSleep = waitSleep;
}
~GptServer()
@ -353,10 +369,9 @@ public:
void waitForEmpty() const
{
while (mWorkItemsQueue.size() > 0)
while (!mWorkItemsQueue.empty())
{
std::chrono::milliseconds timespan(mWaitSleep);
std::this_thread::sleep_for(timespan);
std::this_thread::sleep_for(mWaitSleep);
}
}
@ -365,6 +380,11 @@ public:
mBatchManager->waitUntilTerminate();
}
void shutdown() const
{
mBatchManager->shutdown();
}
// Return up to max_num_requests inference requests.
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
{
@ -376,17 +396,40 @@ public:
auto rank = comm.getRank();
if (rank == 0)
{
auto num_new_work_items = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
auto numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
static_cast<int64_t>(max_num_requests));
if (world_size > 1)
{
comm.bcast(&num_new_work_items, 1, mpi::MpiType::kINT64, 0);
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
}
if (num_new_work_items > 0)
bool readyForNextBatch = numNewWorkItems > 0;
if (mStaticEmulatedBatchSize)
{
if (numNewWorkItems > 0)
{
bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp;
bool const previousBatchFinished = mActiveCount == 0;
bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
}
if (numNewWorkItems == 0 || readyForNextBatch)
{
// Timeout should only begin once we have at least 1 pending request.
// Reset timeout when no requests are pending or we submit a new batch.
mEmulatedBatchEndTimestamp
= std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs);
}
}
if (readyForNextBatch)
{
int count = 0;
while (count < num_new_work_items)
// Only add a single batch at a time when emulating static batching
auto const numItemsToAdd = std::min(
numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
mActiveCount += numItemsToAdd;
while (count < numItemsToAdd)
{
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
@ -420,14 +463,14 @@ public:
else
{
// subordinate ranks hang until master rank sends work
int64_t num_new_work_items;
comm.bcast(&num_new_work_items, 1, mpi::MpiType::kINT64, 0);
if (num_new_work_items > 0)
int64_t numNewWorkItems = 0;
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
if (numNewWorkItems > 0)
{
std::vector<int64_t> packed;
comm.bcast(packed, 0);
int64_t* packed_ptr = packed.data();
for (int64_t count = 0; count < num_new_work_items; ++count)
for (int64_t count = 0; count < numNewWorkItems; ++count)
{
int64_t n = *(packed_ptr++);
auto ir = InferenceRequest::deserialize(packed_ptr);
@ -453,6 +496,7 @@ public:
{
mWorkItemsQueue.markFinished(requestId);
mRecorder->recordEnd(requestId);
mActiveCount--;
}
}
catch (const std::exception& e)
@ -466,15 +510,27 @@ private:
std::shared_ptr<Recorder> mRecorder;
WorkItemsQueue mWorkItemsQueue;
std::optional<uint64_t> mTerminateReqId;
int mWaitSleep;
std::chrono::milliseconds mWaitSleep;
std::optional<int> mStaticEmulatedBatchSize;
std::chrono::time_point<std::chrono::steady_clock> mEmulatedBatchEndTimestamp;
int32_t mStaticEmulatedTimeoutMs;
std::atomic<uint64_t> mActiveCount;
}; // class GptServer
namespace
{
std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<float>> parseWorkloadJson(
std::filesystem::path const& datasetPath, int maxNumSamples)
struct Sample
{
std::vector<int32_t> inputIds;
int32_t outputLen;
float delay;
};
using Samples = std::vector<Sample>;
Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSamples)
{
auto constexpr allowExceptions = true;
auto constexpr ignoreComments = true;
@ -482,37 +538,29 @@ std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<
std::ifstream jsonStream(datasetPath);
auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments);
std::vector<std::vector<int32_t>> inputIds;
std::vector<int32_t> outputLens;
std::vector<float> delays;
Samples samples;
long int numSamples = 0;
for (auto& sample : json["samples"])
for (auto const& sample : json["samples"])
{
if (numSamples >= maxNumSamples)
if (samples.size() >= maxNumSamples)
break;
numSamples++;
inputIds.push_back(sample["input_ids"]);
outputLens.push_back(sample["output_len"]);
delays.push_back(sample["delay"]);
samples.emplace_back(Sample{sample["input_ids"], sample["output_len"], sample["delay"]});
}
return std::make_tuple(inputIds, outputLens, delays);
return samples;
}
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<float>> const& dataset,
std::size_t sample_idx, ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId,
ITensor::SharedPtr const& padId, BufferManager const& bufferManager,
ITensor::SharedPtr const& returnContextLogits = nullptr, ITensor::SharedPtr const& returnGenerationLogits = nullptr)
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample,
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
ITensor::SharedPtr const& returnGenerationLogits = nullptr)
{
auto request = std::make_shared<InferenceRequest>(reqId);
auto const& inputIds = (std::get<0>(dataset))[sample_idx];
auto const& inputIds = sample.inputIds;
request->setInputIds(bufferManager.copyFrom(
inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
auto const request_output_len = (std::get<1>(dataset))[sample_idx];
auto const requestOutputLen = sample.outputLen;
request->setMaxNewTokens(
bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
request->setBeamWidth(beamWidthTensor);
if (eosId != nullptr)
{
@ -533,29 +581,15 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
return request;
}
void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::filesystem::path const& engineDir,
std::string const& type, std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples,
int beamWidth, int warmUp, const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
std::shared_ptr<nvinfer1::ILogger> const& logger, TrtGptModelOptionalParams const& optionalParams,
batch_scheduler::SchedulerPolicy schedulerPolicy, int waitSleep, bool returnContextLogits,
bool returnGenerationLogits)
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
{
auto const worldConfig = WorldConfig::mpi();
TrtGptModelType modelType;
if (type == "V1")
{
modelType = TrtGptModelType::V1;
}
else if (type == "IFB")
{
modelType = TrtGptModelType::InflightFusedBatching;
}
else
{
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
}
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
ITensor::SharedPtr beamWidthTensor{
@ -563,14 +597,13 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
// Load dataset
const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
const auto numSamples = (std::get<0>(samples)).size();
const auto numSamples = samples.size();
const int maxBeamWidth = beamWidth;
auto recorder = std::make_shared<Recorder>(opCsvFile);
uint64_t terminateReqId = numSamples + 1;
auto gptServer = std::make_shared<GptServer>(
engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId, waitSleep);
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs);
ITensor::SharedPtr eosIdTensor{
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
@ -594,7 +627,7 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
++reqId;
if (i == terminateReqId)
++reqId;
auto request = makeRequest(reqId, samples, 0, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
auto request = makeRequest(reqId, samples[0], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
@ -603,11 +636,10 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
recorder->initialize();
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples, i, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
gptServer->enqueue(request);
auto delayInMs = int(std::get<2>(samples)[i] * 1000);
auto delayInMs = static_cast<int>(samples[i].delay * 1000);
if (delayInMs != 0)
{
@ -635,6 +667,7 @@ int main(int argc, char* argv[])
cxxopts::Options options(
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
options.add_options()("h,help", "Print usage");
// TODO(rkobus): remove because unused
options.add_options()(
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
@ -668,6 +701,11 @@ int main(int argc, char* argv[])
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
options.add_options()("static_emulated_batch_size",
"Emulate static batching performance with the provided batch size.", cxxopts::value<int>());
options.add_options()("static_emulated_timeout",
"Timeout (ms) before launching a partial batch in emulated static batching mode",
cxxopts::value<int>()->default_value("500"));
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
options.add_options()(
@ -693,6 +731,20 @@ int main(int argc, char* argv[])
// Argument: Batching Type
auto const type = result["type"].as<std::string>();
TrtGptModelType modelType{TrtGptModelType::V1};
if (type == "V1")
{
modelType = TrtGptModelType::V1;
}
else if (type == "IFB")
{
modelType = TrtGptModelType::InflightFusedBatching;
}
else
{
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
return 1;
}
// Argument: Dataset
auto const datasetPath = result["dataset"].as<std::string>();
@ -705,7 +757,7 @@ int main(int argc, char* argv[])
auto const beamWidth = result["beam_width"].as<int>();
// Argument: wait_sleep
auto const waitSleep = result["wait_sleep"].as<int>();
auto const waitSleep = std::chrono::milliseconds(result["wait_sleep"].as<int>());
TrtGptModelOptionalParams optionalParams;
// Argument: Max tokens in paged K-V Cache
@ -765,6 +817,14 @@ int main(int argc, char* argv[])
eosId = result["eos_id"].as<int>();
}
std::optional<int> staticEmulatedBatchSize;
// Argument: Static emulated batch size
if (result.count("static_emulated_batch_size"))
{
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<int>();
}
auto const staticEmulatedTimeout = result["static_emulated_timeout"].as<int>();
// Argument: Scheduler policy
batch_scheduler::SchedulerPolicy schedulerPolicy;
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
@ -815,9 +875,9 @@ int main(int argc, char* argv[])
try
{
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
datasetPath, opCsvFile, maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, logger,
optionalParams, schedulerPolicy, waitSleep, returnContextLogits, returnGenerationLogits);
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile, maxNumSamples,
beamWidth, result["warm_up"].as<int>(), eosId, padId, optionalParams, schedulerPolicy, waitSleep,
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout);
}
catch (const std::exception& e)
{

Binary file not shown.

View File

@ -704,6 +704,7 @@ _allowed_configs = {
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
@ -724,6 +725,7 @@ _allowed_configs = {
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
@ -744,6 +746,7 @@ _allowed_configs = {
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
@ -764,6 +767,7 @@ _allowed_configs = {
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
@ -784,6 +788,7 @@ _allowed_configs = {
hidden_act="relu",
n_positions=512,
num_buckets=32,
max_distance=128,
max_batch_size=8,
max_encoder_input_len=1024,
max_decoder_input_len=1,
@ -947,7 +952,7 @@ _allowed_configs = {
)),
"baichuan_7b":
ModelConfig(name="baichuan_7b",
family="baichuan_7b",
family="baichuan",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
@ -964,7 +969,7 @@ _allowed_configs = {
)),
"baichuan2_7b_chat":
ModelConfig(name="baichuan2_7b_chat",
family="baichuan_7b",
family="baichuan",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=32,
@ -981,7 +986,7 @@ _allowed_configs = {
)),
"baichuan_13b_chat":
ModelConfig(name="baichuan_13b_chat",
family="baichuan_13b",
family="baichuan",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,
@ -998,7 +1003,7 @@ _allowed_configs = {
)),
"baichuan2_13b_chat":
ModelConfig(name="baichuan2_13b_chat",
family="baichuan_13b",
family="baichuan",
benchmark_type="gpt",
build_config=BuildConfig(
num_layers=40,

View File

@ -613,36 +613,42 @@ def build_gpt(args):
})
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
elif family == "baichuan_7b":
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=None,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
dtype=kv_dtype,
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),
quant_mode=quant_mode)
elif family == "baichuan_13b":
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
num_layers=build_config['num_layers'],
num_heads=build_config['num_heads'],
num_kv_heads=None,
hidden_size=build_config['hidden_size'],
vocab_size=build_config['vocab_size'],
hidden_act=build_config['hidden_act'],
max_position_embeddings=build_config['n_positions'],
position_embedding_type=PositionEmbeddingType.alibi,
dtype=kv_dtype,
mlp_hidden_size=build_config['inter_size'],
mapping=tensorrt_llm.Mapping(world_size=world_size,
tp_size=world_size),
quant_mode=quant_mode)
elif family == "baichuan":
config = {
'architecture':
'BaichuanForCausalLM',
'dtype':
args.dtype,
'logits_dtype':
'float32',
'vocab_size':
build_config['vocab_size'],
'max_position_embeddings':
build_config['n_positions'],
'hidden_size':
build_config['hidden_size'],
'num_hidden_layers':
build_config['num_layers'],
'num_attention_heads':
build_config['num_heads'],
'num_key_value_heads':
build_config['num_heads'],
'hidden_act':
build_config['hidden_act'],
'intermediate_size':
build_config['inter_size'],
'position_embedding_type':
'alibi_with_scale' if '7b' in args.model else 'rope_gpt_neox',
'quantization': {
'group_size': 128
},
'mapping': {
'world_size': world_size,
'tp_size': world_size,
},
}
config = PretrainedConfig.from_dict(config)
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
elif family == "internlm":
config = {
'architecture':
@ -748,6 +754,8 @@ def build_gpt(args):
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
network.plugin_config.enable_remove_input_padding()
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
network.plugin_config.set_moe_plugin(dtype=args.dtype)
if args.quantization is None or "fp8" not in args.quantization:
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
@ -787,7 +795,7 @@ def build_gpt(args):
max_beam_width=max_beam_width)
if family in [
'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
'gptj', "mamba"
'gptj', 'mamba', 'baichuan'
]:
tensorrt_llm_model(**inputs)
else:
@ -1013,6 +1021,8 @@ def enc_dec_build_helper(component, config, args):
tp_size=world_size,
pp_size=1) # TP only
fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
if component == 'encoder':
tllm_model = tensorrt_llm.models.EncoderModel(
num_layers=config['num_layers'],
@ -1040,7 +1050,8 @@ def enc_dec_build_helper(component, config, args):
dtype=dtype,
use_parallel_embedding=False, # by default
embedding_sharding_dim=0, # by default
mapping=mapping)
mapping=mapping,
fp16_clamping=fp16_clamping)
elif component == 'decoder':
tllm_model = tensorrt_llm.models.DecoderModel(
num_layers=config['num_layers'],
@ -1073,7 +1084,8 @@ def enc_dec_build_helper(component, config, args):
embedding_sharding_dim=0, # by default
mapping=mapping,
rescale_before_lm_head=rescale_before_lm_head,
logits_dtype='float32') # by default
logits_dtype='float32', # by default
fp16_clamping=fp16_clamping)
# Module -> Network
engine_name = get_engine_name(args.model, args.dtype, world_size,
@ -1129,6 +1141,7 @@ def enc_dec_build_helper(component, config, args):
hidden_size=hidden_size,
head_size=builder_config.head_size,
max_batch_size=builder_config.max_batch_size,
max_beam_width=builder_config.max_beam_width,
vocab_size=builder_config.vocab_size,
num_layers=builder_config.num_layers,
gpt_attention_plugin=network.plugin_config.gpt_attention_plugin,

View File

@ -89,6 +89,7 @@ class EncDecBenchmark(BaseBenchmark):
hidden_size=hidden_size,
head_size=config["builder_config"]["head_size"],
max_batch_size=config["builder_config"]["max_batch_size"],
max_beam_width=config["builder_config"]["max_beam_width"],
vocab_size=config["builder_config"]["vocab_size"],
num_layers=config["builder_config"]["num_layers"],
gpt_attention_plugin=config["plugin_config"]

View File

@ -95,6 +95,7 @@ class GPTBenchmark(BaseBenchmark):
if args.mode == 'plugin':
self.use_gpt_attention_plugin = True
self.remove_input_padding = True
self.use_moe_plugin = True
elif args.mode == 'ootb-except-mha':
self.use_gpt_attention_plugin = True
@ -110,6 +111,7 @@ class GPTBenchmark(BaseBenchmark):
self.num_kv_heads = self.num_heads
model_config = tensorrt_llm.runtime.ModelConfig(
max_batch_size=self.max_batch_size,
max_beam_width=self.num_beams,
vocab_size=self.vocab_size,
num_layers=self.num_layers,
num_heads=self.num_heads // self.world_size,

View File

@ -172,8 +172,12 @@ get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)
set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
include_directories(
${CUDAToolkit_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include ${NCCL_INCLUDE_DIR}
${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include
${CUDAToolkit_INCLUDE_DIRS}
${CUDNN_ROOT_DIR}/include
${NCCL_INCLUDE_DIR}
${3RDPARTY_DIR}/cutlass/include
${3RDPARTY_DIR}/cutlass/tools/util/include
${3RDPARTY_DIR}/NVTX/include
${3RDPARTY_DIR}/json/include)
# TRT dependencies

View File

@ -72,6 +72,8 @@ public:
BatchManagerErrorCode_t shutdown();
SizeType getNumActiveRequests();
virtual ~GptManager();
protected:

View File

@ -354,6 +354,11 @@ public:
mDraftLogits = draftLogits;
}
SizeType getNumDraftTokens() const
{
return mDraftTokens->size();
}
void setReturnContextLogits(const bool returnContextLogits)
{
mReturnContextLogits = returnContextLogits;

View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include <optional>
#include <vector>
@ -35,13 +36,15 @@ public:
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt,
bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false)
bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false,
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt)
: kvCacheConfig{kvCacheConfig}
, enableTrtOverlap{enableTrtOverlap}
, deviceIds(deviceIds)
, normalizeLogProbs{normalizeLogProbs}
, logIterationData{logIterationData}
, enableChunkedContext{enableChunkedContext}
, decodingMode{decodingMode}
{
}
@ -51,6 +54,7 @@ public:
bool normalizeLogProbs;
bool logIterationData;
bool enableChunkedContext;
std::optional<runtime::DecodingMode> decodingMode;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -21,6 +21,7 @@
#include "tensorrt_llm/runtime/iTensor.h"
#include <memory>
#include <optional>
namespace tensorrt_llm::runtime
{
@ -36,6 +37,8 @@ public:
, maxAttentionWindow{maxAttentionWindow}
, sinkTokenLength{sinkTokenLength}
, maxBatchSize{maxBatchSize}
, maxStopWordsLen{0}
, maxBadWordsLen{0}
, logits{std::move(logits)}
, endIds{std::move(endIds)}
{
@ -49,20 +52,28 @@ public:
SizeType maxAttentionWindow;
SizeType sinkTokenLength;
SizeType maxBatchSize;
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
TensorPtr endIds; // [batchSize * beamWidth], on gpu
SizeType maxStopWordsLen; // The maximum value in the `stopWordsLens` tensor
SizeType maxBadWordsLen; // The maximum value in the `badWordsLens` tensor
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
std::optional<std::vector<TensorPtr>>
logitsVec; // vector of size [batchSize] contains logits of size [beamWidth, vocabSizePadded], on gpu
TensorPtr endIds; // [maxBatchSize * beamWidth], on gpu
// optional parameters
TensorPtr finished; // [maxBatchSize, beamWidth], finished states at current iteration.
// If true for some request, the decoding step of it is skipped, on gpu
TensorPtr sequenceLimitLength; // [maxBatchSize], on gpu
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
TensorPtr embeddingBias; // [maxBatchSize, vocabSizePadded], on gpu
TensorPtr lengths; // [maxBatchSize, beamWidth], on gpu
TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
TensorPtr badWordsList; // [2, badWordsLength] or [maxBatchSize, 2, badWordsLength], on gpu
TensorPtr badWordsPtrs; // [maxBatchSize][2, badWordsLength], on gpu
TensorPtr badWordsLens; // [maxBatchSize], on gpu
TensorPtr stopWordsList; // [maxBatchSize, 2, stopWordsLength], on gpu
TensorPtr stopWordsPtrs; // [maxBatchSize][2, stopWordsLength], on gpu
TensorPtr stopWordsLens; // [maxBatchSize], on gpu
TensorPtr noRepeatNgramSize; // [maxBatchSize], on gpu
TensorPtr
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, on gpu
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, pinned
// parameters for beam search
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu

View File

@ -0,0 +1,138 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
namespace tensorrt_llm
{
namespace runtime
{
class DecodingMode
{
public:
static auto constexpr None()
{
return DecodingMode{kNone};
}
static auto constexpr TopK()
{
return DecodingMode{kTopK};
}
static auto constexpr TopP()
{
return DecodingMode{kTopP};
}
static auto constexpr TopKTopP()
{
return DecodingMode{kTopKTopP};
}
static auto constexpr BeamSearch()
{
return DecodingMode{kBeamSearch};
}
bool constexpr isNone()
{
return mState == 0;
}
bool constexpr isTopK()
{
return anyBitSet(kTopK);
}
bool constexpr isTopP()
{
return anyBitSet(kTopP);
}
bool constexpr isTopKorTopP()
{
return anyBitSet(kTopKTopP);
}
bool constexpr isTopKandTopP()
{
return allBitSet(kTopKTopP);
}
bool constexpr isBeamSearch()
{
return anyBitSet(kBeamSearch);
}
using UnderlyingType = uint8_t;
private:
constexpr DecodingMode(UnderlyingType state)
: mState(state)
{
}
// No mode specified. Config will be determined from the beam width of the first request at runtime
// TopKTopP if beamWidth == 1, BeamSearch otherwise
static UnderlyingType constexpr kNone{0};
static UnderlyingType constexpr kTopK{1u << 0};
static UnderlyingType constexpr kTopP{1u << 1};
static UnderlyingType constexpr kBeamSearch{1u << 2};
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
bool constexpr anyBitSet(UnderlyingType bits) const
{
return (mState & bits) != 0;
}
bool constexpr allBitSet(UnderlyingType bits) const
{
return (mState & bits) == bits;
}
UnderlyingType mState{};
};
static_assert(DecodingMode::None().isNone());
static_assert(!DecodingMode::None().isTopK());
static_assert(!DecodingMode::None().isTopP());
static_assert(!DecodingMode::None().isBeamSearch());
static_assert(DecodingMode::TopK().isTopK());
static_assert(DecodingMode::TopK().isTopKorTopP());
static_assert(!DecodingMode::TopK().isTopKandTopP());
static_assert(!DecodingMode::TopK().isTopP());
static_assert(!DecodingMode::TopK().isBeamSearch());
static_assert(DecodingMode::TopP().isTopP());
static_assert(DecodingMode::TopP().isTopKorTopP());
static_assert(!DecodingMode::TopP().isTopKandTopP());
static_assert(!DecodingMode::TopP().isTopK());
static_assert(!DecodingMode::TopP().isBeamSearch());
static_assert(DecodingMode::TopKTopP().isTopK());
static_assert(DecodingMode::TopKTopP().isTopP());
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
static_assert(DecodingMode::BeamSearch().isBeamSearch());
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
} // namespace runtime
} // namespace tensorrt_llm
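As a usage sketch only (assumed caller code, not part of the header above): the comment on `kNone` states that the runtime default is TopKTopP for beam width 1 and BeamSearch otherwise, so an explicit selection, e.g. for `GptSession::Config::decodingMode` or `TrtGptModelOptionalParams::decodingMode`, could look like this:
```
#include "tensorrt_llm/runtime/decodingMode.h"

using tensorrt_llm::runtime::DecodingMode;

// Hypothetical helper mirroring the documented default of DecodingMode::None().
DecodingMode selectDecodingMode(int beamWidth)
{
    return beamWidth == 1 ? DecodingMode::TopKTopP() : DecodingMode::BeamSearch();
}
```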

View File

@ -75,7 +75,7 @@ public:
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
// true. In beam search and to determine whether to stop according to
// DecodingInput.sequenceLimitLength, on gpu
TensorPtr finishedSum; // [1], the sum of finished sequences, in pinned memory
TensorPtr finishedSum; // [batchSize], the sum of finished sequences per request, in pinned memory
// mandatory parameters for beam search
TensorPtr logProbs; // [batchSize, beamWidth, maxSeqLen], must be float*, on gpu

View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/common/cudaAllocator.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include <curand_kernel.h>
@ -44,9 +45,13 @@ namespace runtime
class IGptDecoder
{
public:
using TensorPtr = std::shared_ptr<ITensor>;
virtual ~IGptDecoder() = default;
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) = 0;
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
std::optional<TensorPtr> const& batchSlots = std::nullopt)
= 0;
virtual bool forward(DecodingOutput& output, DecodingInput const& input) = 0;
@ -58,18 +63,19 @@ public:
virtual const SamplingConfig& getSamplingConfig() = 0;
static void acceptDraftTokensByIds(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths,
const ITensor& finishedVec, ITensor& finishedFinal, ITensor& finishedSum,
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
ITensor const& finishedVec, ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots,
BufferManager::CudaStreamPtr const& stream);
static void acceptDraftTokensByLogits(ITensor& draftLogits, const ITensor& targetLogits, ITensor& draftProbs,
ITensor& targetProbs, const ITensor& numDraftTokens, ITensor& finished, SizeType vocabSize,
SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
static void acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(nvinfer1::DataType dtype, size_t maxBatchSize, size_t vocabSize,
size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream);
};
template <typename T>
@ -80,9 +86,11 @@ public:
using CudaStreamPtr = BufferManager::CudaStreamPtr;
using TensorPtr = std::shared_ptr<ITensor>;
GptDecoder(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream);
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream);
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) override;
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
bool forward(DecodingOutput& output, DecodingInput const& input) override;
@ -105,15 +113,18 @@ private:
SamplingConfig mSamplingConfig;
};
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(nvinfer1::DataType dtype, size_t maxBatchSize, size_t vocabSize,
size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream)
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream)
{
switch (dtype)
{
case nvinfer1::DataType::kFLOAT:
return std::make_unique<GptDecoder<float>>(maxBatchSize, vocabSize, vocabSizePadded, stream);
return std::make_unique<GptDecoder<float>>(
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
case nvinfer1::DataType::kHALF:
return std::make_unique<GptDecoder<half>>(maxBatchSize, vocabSize, vocabSizePadded, stream);
return std::make_unique<GptDecoder<half>>(
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
default: return nullptr;
}
}

View File

@ -41,20 +41,21 @@ class GptDecoderBatch : public IGptDecoderBatch
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using TensorPtr = ITensor::SharedPtr;
using SharedConstPtr = ITensor::SharedConstPtr;
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
//! Setup the decoder before calling `forward()`
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig) override;
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
nvinfer1::DataType dtype) override;
void newBatch(
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
std::vector<SamplingConfig> const& samplingConfigs) override;
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
void forwardSync(decoder_batch::Token const& e) override;
@ -161,6 +162,9 @@ private:
//! @brief Gather final beam search results for request `batchIdx`.
CudaEvent postProcessRequest(SizeType batchIdx) const;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
private:
std::size_t const mVocabSize;
std::size_t const mVocabSizePadded;
@ -180,8 +184,6 @@ private:
DecodingInputPtr mJointDecodingInput;
DecodingOutputPtr mJointDecodingOutput;
std::vector<TensorPtr> mDraftTokenIds;
std::vector<TensorPtr> mDraftLogits;
std::vector<bool> mAcceptByLogits;
TensorPtr mNumDraftTokens;
TensorPtr mCurandStates;
@ -193,16 +195,28 @@ private:
std::vector<SizeType> mBeamWidths;
std::vector<SizeType> mGeneratedTokensPerStep;
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
// for each generated token of maxTokensPerStep, on gpu
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
// for each generated token of maxTokensPerStep, on gpu
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
// decoding accept by logits kernel, on gpu
TensorPtr mDraftTokenIds; // [batchSize, maxDraftTokens+1], draft token indices, on gpu
TensorPtr mDraftLogits; // [batchSize, maxDraftTokens+1, vocabSizePadded], draft token logits, on gpu
TensorPtr mBatchSlotsSetup; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsDecoder; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
SizeType mMaxSequenceLength{};
SizeType mMaxAttentionWindow{};
SizeType mSinkTokenLength{};
SizeType mActualBatchSize{};
SizeType mMaxTokensPerStep{};
SizeType mMaxStopWordsLen{};
SizeType mMaxBadWordsLen{};
bool mFusedDecoder{false};
};
} // namespace tensorrt_llm::runtime

View File

@ -19,6 +19,7 @@
#include "tensorrt_llm/common/quantization.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/loraModule.h"
#include "tensorrt_llm/runtime/medusaModule.h"
#include <NvInferRuntime.h>
namespace tensorrt_llm::runtime
@ -62,6 +63,7 @@ public:
, mPagedContextFMHA(false)
, mUseLoraPlugin(false)
, mMlpHiddenSize(0)
, mMedusaModule(std::nullopt)
{
}
@ -341,6 +343,31 @@ public:
mMlpHiddenSize = mlpHiddenSize;
}
[[nodiscard]] SizeType constexpr getMaxLoraRank() const noexcept
{
return mMaxLoraRank;
}
void constexpr setMaxLoraRank(SizeType maxLoraRank) noexcept
{
mMaxLoraRank = maxLoraRank;
}
[[nodiscard]] bool constexpr useMedusa() const noexcept
{
return mMedusaModule.has_value();
}
[[nodiscard]] std::optional<MedusaModule> getMedusaModule() const noexcept
{
return mMedusaModule;
}
void setMedusaModule(MedusaModule const& medusaModule) noexcept
{
mMedusaModule = medusaModule;
}
private:
SizeType mVocabSize;
SizeType mNbLayers;
@ -374,5 +401,8 @@ private:
bool mUseLoraPlugin;
std::vector<LoraModule> mLoraModules;
SizeType mMlpHiddenSize;
SizeType mMaxLoraRank;
std::optional<MedusaModule> mMedusaModule;
};
} // namespace tensorrt_llm::runtime

View File

@ -20,6 +20,7 @@
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
@ -90,6 +91,7 @@ public:
KvCacheConfig kvCacheConfig{};
std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
std::optional<SizeType> genMicroBatchSize = std::nullopt;
std::optional<DecodingMode> decodingMode = std::nullopt;
};
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
@ -146,7 +148,8 @@ private:
void createContexts();
void createBuffers(SizeType numMicroBatches);
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches,
DecodingMode const& decodingMode);
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, KvCacheConfig const& config);
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);

View File

@ -141,11 +141,6 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
virtual void newRequest(
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
= 0;
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
@ -186,6 +181,11 @@ public:
virtual std::vector<SizeType> getNbSteps() const = 0;
//! @brief Initialize batched decoder at seqSlots with a new `requests`.
virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
std::vector<SamplingConfig> const& samplingConfigs)
= 0;
protected:
IGptDecoderBatch() = default;
};

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -74,8 +75,9 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype)
= 0;
//! @brief Initialize the decoder with new batch of inputs.

View File

@ -26,17 +26,67 @@ namespace tensorrt_llm::runtime
class SamplingConfig
{
private:
using FloatType = float;
template <typename T>
using OptVec = std::optional<std::vector<T>>;
private:
template <typename T>
static OptVec<T> fuseValues(
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
{
std::vector<T> values;
auto const hasValues = accessor(0).has_value();
for (size_t ci = 0; ci < configs.size(); ++ci)
{
const auto& configValue = accessor(ci);
TLLM_CHECK(hasValues == configValue.has_value());
if (hasValues)
{
TLLM_CHECK(configValue.value().size() == 1);
values.push_back(configValue.value().front());
}
}
if (!hasValues)
{
return std::nullopt;
}
return std::make_optional<std::vector<T>>(values);
}
public:
explicit SamplingConfig(SizeType beamWidth = 1)
: beamWidth{beamWidth}
{
}
explicit SamplingConfig(std::vector<SamplingConfig> const& configs)
{
TLLM_CHECK(configs.size() > 0);
beamWidth = configs.front().beamWidth;
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
repetitionPenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].repetitionPenalty; });
presencePenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].presencePenalty; });
topK = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topK; });
topP = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topP; });
randomSeed = fuseValues<uint64_t>(configs, [&configs](SizeType ci) { return configs[ci].randomSeed; });
topPDecay = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPDecay; });
topPMin = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPMin; });
topPResetIds = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topPResetIds; });
beamSearchDiversityRate
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
}
public:
SizeType beamWidth;
OptVec<FloatType> temperature; // [1] or [batch_size] on cpu
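A minimal usage sketch of the fusing constructor above (hypothetical values; only the fields shown here are set): per-request configs whose optional vectors each hold one value collapse into a single batched config, and fields left unset in every config stay `nullopt`.
```
#include <vector>

#include "tensorrt_llm/runtime/samplingConfig.h"

using tensorrt_llm::runtime::SamplingConfig;

SamplingConfig fuseTwoRequests()
{
    SamplingConfig a(/*beamWidth=*/1);
    a.temperature = std::vector<float>{0.7f};

    SamplingConfig b(/*beamWidth=*/1);
    b.temperature = std::vector<float>{1.0f};

    // beamWidth is taken from the first config; fused.temperature == {0.7f, 1.0f}.
    return SamplingConfig(std::vector<SamplingConfig>{a, b});
}
```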

View File

@ -138,6 +138,7 @@ set(TRTLLM_LINK_LIBS
${TRT_LIB}
common_src
kernels_src
cutlass_src
layers_src
runtime_src)

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a8b82a255fc93e99bfca1bb9975f8ac524a980e25c6678fbed0e64b7d8e1841
size 1949506

View File

@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bfef429b1985539c3956ada86c0578ad9b783c6b79d0c5123e7e23a18f3356b
size 1966228

View File

@ -1,3 +0,0 @@
86bf72386b323b73b0fd95f564270c8b libtensorrt_llm_batch_manager_static.a
93e03895d79092f5bf81a4233078d0b3 libtensorrt_llm_batch_manager_static.pre_cxx11.a
b3fa820622b86294b498b661362a06ec386a6e1b commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a94a642407d43a81d7c8221b4158bd6958c6fb57c3c1c39446e15ac8471f7b41
size 1897882
oid sha256:0268f64b0c2540e07bf05ad458f7aa33c9d6e65fc4f5c85cd8d0946d658ffeb8
size 2092012

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c1e32ebb36f74e6b971fe8898b9a1d1763332f5b2cdbf089443471de6087b12
size 1871190
oid sha256:89ae0be676e7aa9b562f6745636f7d77198f87b83ec6295aff74273767e4fca7
size 2071180

View File

@ -1,2 +1,2 @@
ba7b3bdcc6754724cc9405c2699b1900 libtensorrt_llm_batch_manager_static.a
b7bf0d41e6bde342352c6a5eee2d5ad0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
63c3f64faa14f9d5d66b7e186a6cc80b libtensorrt_llm_batch_manager_static.a
dbcc1bbe80d977c1655d32ef69b36578 libtensorrt_llm_batch_manager_static.pre_cxx11.a

View File

@ -25,7 +25,7 @@ namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#if !defined(NVTX_DISABLE)
#ifndef NVTX_DISABLE
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();

View File

@ -25,7 +25,7 @@ namespace tensorrt_llm
namespace cutlass_extensions
{
template <typename GemmKernel>
template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
@ -39,7 +39,14 @@ inline int compute_occupancy_for_kernel()
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
}
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
@ -51,8 +58,17 @@ inline int compute_occupancy_for_kernel()
}
int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
}
return max_active_blocks;
}

View File

@ -251,7 +251,6 @@ struct GemmFpAIntB
CUTLASS_HOST_DEVICE
static Status can_implement(Arguments const& args)
{
static int const kAlignmentA
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
@ -340,19 +339,6 @@ struct GemmFpAIntB
return 0;
}
// The dummy template parameter is not used and exists so that we can compile this code using
// a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in
// a namespace
template <bool B, typename dummy = void>
struct KernelRunner
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
{
CUTLASS_NOT_IMPLEMENTED();
}
};
// Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
// has a different constructor signature than a regular cutlass iterator
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
@ -375,169 +361,177 @@ struct GemmFpAIntB
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
}
template <typename dummy>
struct KernelRunner<true, dummy>
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
using LayoutB = typename Mma::IteratorB::Layout;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
// Compute threadblock location
ThreadblockSwizzle threadblock_swizzle;
cutlass::gemm::GemmCoord threadblock_tile_offset
= threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// Early exit if CTA is out of range
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
{
return;
}
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
params.gather_B_indices);
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0)
{
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(),
params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(),
params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k())
{
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
{
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else
{
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
return;
}
};
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_tile_offset.m() * Mma::Shape::kM,
threadblock_tile_offset.k() * params.gemm_k_size,
};
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
// Problem size is a function of threadblock index in the K dimension
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
params.gather_B_indices);
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Main loop
//
// Construct thread-scoped matrix multiply
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
typename Mma::FragmentC accumulators;
accumulators.clear();
if (!kSplitKSerial || gemm_k_iterations > 0)
{
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
//
// Masked tile iterators constructed from members
//
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
// assume identity swizzle
MatrixCoord threadblock_offset(
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
// Construct the semaphore.
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
// If performing a reduction via split-K, fetch the initial synchronization
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// Fetch the synchronization lock initially but do not block.
semaphore.fetch();
// Indicate which position in a serial reduction the output operator is currently updating
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
}
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
thread_idx, threadblock_offset, params.scatter_D_indices);
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Wait on the semaphore - this latency may have been covered by iterator construction
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
if (threadblock_tile_offset.k())
{
iterator_C = iterator_D;
}
semaphore.wait(threadblock_tile_offset.k());
}
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
//
// Release the semaphore
//
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
{
int lock = 0;
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
{
// The final threadblock resets the semaphore for subsequent grids.
lock = 0;
}
else
{
// Otherwise, the semaphore is incremented
lock = threadblock_tile_offset.k() + 1;
}
semaphore.release(lock);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
@ -549,14 +543,13 @@ struct GemmFpAIntB
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");

View File

@ -407,7 +407,7 @@ public:
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
// Compute threadblock location
@ -534,6 +534,48 @@ public:
// Execute the epilogue operator to update the destination tensor.
epilogue(epilogue_visitor, accumulators);
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
to the ArchTag of the cutlass kernel operator.
*/
/// Executes one GEMM
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720)
run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
run_kernel<arch::Sm72>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
run_kernel<arch::Sm80>(params, shared_storage);
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@ -80,7 +80,8 @@ struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinC
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB < uint8_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;
@ -95,7 +96,8 @@ public:
};
template <typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB < uint4b_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;
@ -109,6 +111,24 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};
template <typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
template <typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};
} // namespace kernel
} // namespace gemm
} // namespace cutlass
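
The new SM90 specializations above are selected purely from the Arch tag's `kMinComputeCapability` via `enable_if`. Here is a toy, self-contained illustration of that gating; the tags, layout types, and trait below are stand-ins, not the CUTLASS types.

```
#include <type_traits>

struct Sm80 { static constexpr int kMinComputeCapability = 80; };
struct Sm90 { static constexpr int kMinComputeCapability = 90; };

struct ColumnMajorInterleaved {};
struct ColumnMajor {};

// Primary template intentionally undefined: only the gated specializations exist.
template <typename Arch, typename Enable = void>
struct ToyLayoutDetailsB;

// Pre-Hopper: interleaved column-major B, dequantized inside the mainloop.
template <typename Arch>
struct ToyLayoutDetailsB<Arch,
    typename std::enable_if<(Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability < 90)>::type>
{
    using Layout = ColumnMajorInterleaved;
};

// Hopper: plain column-major B with a regular multiply-add mainloop.
template <typename Arch>
struct ToyLayoutDetailsB<Arch, typename std::enable_if<(Arch::kMinComputeCapability >= 90)>::type>
{
    using Layout = ColumnMajor;
};

static_assert(std::is_same<ToyLayoutDetailsB<Sm80>::Layout, ColumnMajorInterleaved>::value, "SM80 stays interleaved");
static_assert(std::is_same<ToyLayoutDetailsB<Sm90>::Layout, ColumnMajor>::value, "SM90 uses plain column-major");

int main()
{
    return 0;
}
```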

View File

@ -321,174 +321,170 @@ public:
return 0;
}
// The dummy template parameter is not used and exists so that we can compile this code using
// a standard earlier than C++17. Prior to C++17, fully specialized templates had to exist in
// a namespace.
template <bool B, typename dummy = void>
struct KernelRunner
CUTLASS_DEVICE
void run_kernel_(Params const& params, SharedStorage& shared_storage)
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
template <typename CompilationArch>
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
{
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
{
run_kernel_(params, shared_storage);
}
else
{
CUTLASS_NOT_IMPLEMENTED();
}
};
template <typename dummy>
struct KernelRunner<true, dummy>
{
CUTLASS_DEVICE
static void run_kernel(Params const& params, SharedStorage& shared_storage)
{
//
// These types shadow the type-level definitions and support the ability to implement
// a 'transposed' GEMM that computes the transposed problems.
//
using ElementA = typename Mma::IteratorA::Element;
using LayoutA = typename Mma::IteratorA::Layout;
using ElementB = typename Mma::IteratorB::Element;
using LayoutB = typename Mma::IteratorB::Layout;
using ElementC = typename Epilogue::OutputTileIterator::Element;
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
"B must be row major/col major OR col major interleaved.");
//
// Problem visitor.
//
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
const int64_t gemm_k = params.problem_visitor.gemm_k;
const int64_t gemm_n = params.problem_visitor.gemm_n;
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
// Outer 'persistent' loop to iterate over tiles
int loop = 0;
while (problem_visitor.next_tile())
{
loop++;
GemmCoord problem_size = problem_visitor.problem_size();
int32_t problem_idx = problem_visitor.problem_index();
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
cutlass::gemm::GemmCoord threadblock_offset(
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
// Load element pointers. Exchange pointers and strides if working on the transpose
const int64_t rows_to_jump
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
typename LayoutA::LongIndex ldm_A = gemm_k;
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
typename LayoutB::LongIndex ldm_B
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
// Compute initial location in logical coordinates
cutlass::MatrixCoord tb_offset_A{
threadblock_offset.m(),
0,
};
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
// Compute position within threadblock
int thread_idx = threadIdx.x;
// Construct iterators to A and B operands
typename Mma::IteratorA iterator_A(
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
typename Mma::FragmentC accumulators;
accumulators.clear();
// Broadcast the warp_id computed by lane 0 to ensure dependent code
// is compiled as warp-uniform.
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
int lane_idx = threadIdx.x % 32;
//
// Matrix multiply phase
//
// Construct thread-scoped matrix multiply
auto CreateMMA = [&]()
{
if constexpr (use_dq_gemm<Mma>::value)
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
else
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
};
Mma mma = CreateMMA();
// Compute threadblock-scoped matrix multiply-add
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
// Wait for all threads to finish their epilogue phases from the previous tile.
__syncthreads();
// Compute threadblock-scoped matrix multiply-add
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
if constexpr (use_dq_gemm<Mma>::value)
{
const MatrixCoord scale_extent = {1, problem_size.n()};
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
}
else
{
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
}
//
// Epilogue
//
EpilogueOutputOp output_op(params.output_op);
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
LayoutC layout_C(0);
LayoutC layout_D(gemm_n);
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
// Tile iterator loading from source tensor.
typename Epilogue::OutputTileIterator iterator_C(
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
// Tile iterator writing to destination tensor.
typename Epilogue::OutputTileIterator iterator_D(
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
// Execute the epilogue operator to update the destination tensor.
epilogue(output_op, iterator_D, accumulators, iterator_C);
// Next tile
problem_visitor.advance(gridDim.x);
}
}
};
}
/*
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
@ -498,19 +494,20 @@ public:
CUTLASS_DEVICE
void operator()(Params const& params, SharedStorage& shared_storage)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
#if defined(__CUDA_ARCH__)
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
run_kernel<arch::Sm70>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
run_kernel<arch::Sm75>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
run_kernel<arch::Sm80>(params, shared_storage);
#elif (__CUDA_ARCH__ >= 900)
run_kernel<arch::Sm80>(
params, shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
#else
static_assert(
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
#endif
#else
CUTLASS_NOT_IMPLEMENTED();
#endif

View File

@ -60,12 +60,81 @@ enum class SplitKStyle
// SPLIT_K_PARALLEL // Not supported yet
};
enum class CutlassTileConfigSM90
{
// Signals that we should run heuristics to choose a config
Undefined,
// Signals that we should run heuristics to choose a config
ChooseWithHeuristic,
// CTA configs for M=64
CtaShape64x16x128B,
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,
// CTA configs for M=128
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,
};
enum class MainloopScheduleType
{
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
// defaults to the "legacy" main loop schedule.
};
enum class EpilogueScheduleType
{
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
// architectures older than Hopper, the epilogue is always performed by the same thread block as the main loop.
};
enum class ClusterShape
{
ClusterShape_1x1x1,
ClusterShape_2x1x1,
ClusterShape_1x2x1,
ClusterShape_2x2x1
};
struct CutlassGemmConfig
{
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;
// config options for sm90
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;
CutlassGemmConfig() {}
CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
, split_k_style(split_k_style)
, split_k_factor(split_k_factor)
, stages(stages)
{
}
CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
{
}
};
} // namespace cutlass_extensions
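
A minimal usage sketch of the two `CutlassGemmConfig` constructors defined above. The include path is an assumption, and the tile, stage, and cluster values are examples only.

```
// Assumed include path for the header shown above.
#include "cutlass_extensions/gemm_configs.h"

using namespace tensorrt_llm::cutlass_extensions;

int main()
{
    // Pre-Hopper style config: CTA tile + split-k strategy + pipeline stage count.
    CutlassGemmConfig ampere_config(
        CutlassTileConfig::ChooseWithHeuristic, SplitKStyle::NO_SPLIT_K, /*split_k_factor=*/1, /*stages=*/4);

    // Hopper (SM90) style config: SM90 tile + mainloop/epilogue schedules + CTA cluster shape.
    CutlassGemmConfig hopper_config(CutlassTileConfigSM90::CtaShape128x128x128B, MainloopScheduleType::AUTO,
        EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_2x1x1);

    (void) ampere_config;
    (void) hopper_config;
    return 0;
}
```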

View File

@ -18,6 +18,10 @@
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
# Exclude files in the cutlass_kernels folder
list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")
# skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
if(FAST_BUILD)
list(FILTER SRC_CU EXCLUDE REGEX
@ -27,3 +31,5 @@ endif()
add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU})
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
add_subdirectory(cutlass_kernels)

View File

@ -25,42 +25,43 @@ namespace kernels
{
template <typename T>
__global__ void ban_bad_words(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slots,
int batch_size, int beam_width, const int* bad_words, size_t bad_words_len, bool share_words, int vocab_size_padded,
const int* sequence_lengths, const int max_seq_len)
__global__ void ban_bad_words(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slots, int32_t beam_width, int32_t const** bad_words_ptrs, int32_t const* bad_words_lens,
int32_t vocab_size_padded, int32_t const* sequence_lengths, const int32_t max_seq_len)
{
const int id = blockIdx.x * blockDim.x + threadIdx.x;
const int batch_idx = blockIdx.y / beam_width;
const int beam_idx = blockIdx.y % beam_width;
int32_t const id = blockIdx.x * blockDim.x + threadIdx.x;
int32_t const batch_idx = blockIdx.y / beam_width;
int32_t const beam_idx = blockIdx.y % beam_width;
auto const batch_slot = batch_slots != nullptr ? batch_slots[batch_idx] : batch_idx;
auto const batch_beam_idx = batch_slot * beam_width + beam_idx;
const int* base_bad_words = share_words ? bad_words : bad_words + batch_slot * 2 * bad_words_len;
const int* base_bad_words_offsets = base_bad_words + bad_words_len;
int32_t const* base_bad_words = bad_words_ptrs[batch_slot];
auto const bad_words_len = bad_words_lens[batch_slot];
int32_t const* base_bad_words_offsets = base_bad_words + bad_words_len;
if (id >= bad_words_len || base_bad_words_offsets[id] < 0)
{
return;
}
const int item_end = base_bad_words_offsets[id];
const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
const int item_size = item_end - item_start;
auto const item_end = base_bad_words_offsets[id];
auto const item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
auto const item_size = item_end - item_start;
/* The single-token case unconditionally bans the token */
bool should_ban = item_size == 1;
const int current_step{sequence_lengths[batch_beam_idx]};
int32_t const current_step{sequence_lengths[batch_beam_idx]};
/* Multi-token case and enough previously generated tokens to look for a match
*/
if (item_size > 1 && current_step >= item_size - 1)
{
should_ban = true;
int parent_id = beam_idx;
const bool gather_beam = beam_width > 1;
int32_t parent_id = beam_idx;
bool const gather_beam = beam_width > 1;
for (int token_idx = item_size - 2; token_idx >= 0; token_idx--)
for (int32_t token_idx = item_size - 2; token_idx >= 0; token_idx--)
{
const int previous_token
auto const previous_token
= output_ids_ptr[batch_slot][parent_id * max_seq_len + current_step - (item_size - 1) + token_idx];
if (previous_token != base_bad_words[item_start + token_idx])
@ -85,42 +86,46 @@ __global__ void ban_bad_words(T* logits, const int** output_ids_ptr, const int**
if (should_ban)
{
int banned_token = base_bad_words[item_end - 1];
if (0 < banned_token && banned_token < vocab_size_padded)
auto banned_token = base_bad_words[item_end - 1];
if (0 <= banned_token && banned_token < vocab_size_padded)
{
logits[batch_slot * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token]
logits[batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token]
= static_cast<T>(-INFINITY);
}
}
}
template <typename T>
void invokeBanBadWords(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slot,
int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words, size_t bad_words_len,
int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream)
void invokeBanBadWords(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream)
{
dim3 block, grid;
constexpr size_t max_blocks{256};
block.x = min(((bad_words_len + 32 - 1) / 32) * 32, max_blocks);
grid.x = (bad_words_len + block.x - 1) / block.x;
grid.y = local_batch_size * beam_width;
constexpr int32_t max_blocks{256};
block.x = min(((max_bad_words_len + 32 - 1) / 32) * 32, max_blocks);
grid.x = (max_bad_words_len + block.x - 1) / block.x;
grid.y = batch_size * beam_width;
ban_bad_words<<<grid, block, 0, stream>>>(logits, output_ids_ptr, parent_ids_ptr, batch_slot, batch_size,
beam_width, bad_words, bad_words_len, share_words, vocab_size_padded, sequence_lengths, max_seq_len);
ban_bad_words<<<grid, block, 0, stream>>>(logits, output_ids_ptr, parent_ids_ptr, batch_slot, beam_width, bad_words,
bad_words_lens, vocab_size_padded, sequence_lengths, max_seq_len);
sync_check_cuda_error();
}
template void invokeBanBadWords(half* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
template void invokeBanBadWords(half* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
#ifdef ENABLE_BF16
template void invokeBanBadWords(__nv_bfloat16* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
template void invokeBanBadWords(__nv_bfloat16* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
#endif
template void invokeBanBadWords(float* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
template void invokeBanBadWords(float* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm
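
For reference, a host-side sketch of the per-slot buffer layout the kernel above reads: the first `bad_words_len` entries are the flattened banned token ids, the next `bad_words_len` entries are cumulative end offsets into that token list, and unused offset slots are padded with -1 so extra thread ids exit early. The builder below is purely illustrative, not an API from this change.

```
#include <cstdint>
#include <vector>

// Flattens banned token sequences into the [tokens | cumulative end offsets] layout
// expected by ban_bad_words. Illustrative helper only.
std::vector<int32_t> build_bad_words_buffer(std::vector<std::vector<int32_t>> const& banned)
{
    std::vector<int32_t> tokens;
    std::vector<int32_t> offsets;
    for (auto const& seq : banned)
    {
        tokens.insert(tokens.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(tokens.size())); // end offset of this sequence
    }
    // Pad the offsets half with -1 so both halves have length bad_words_len == tokens.size().
    while (offsets.size() < tokens.size())
    {
        offsets.push_back(-1);
    }
    std::vector<int32_t> buffer = tokens;
    buffer.insert(buffer.end(), offsets.begin(), offsets.end());
    return buffer;
}

int main()
{
    // Ban the single token 42 and the two-token sequence [7, 9].
    auto const buf = build_bad_words_buffer({{42}, {7, 9}});
    // buf == {42, 7, 9, 1, 3, -1}: tokens first, then cumulative ends padded to length 3.
    return buf.size() == 6 ? 0 : 1;
}
```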

View File

@ -25,9 +25,10 @@ namespace kernels
{
template <typename T>
void invokeBanBadWords(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slot,
int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words, size_t bad_words_len,
int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
void invokeBanBadWords(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_len, int32_t max_bad_words_len, int32_t vocab_size_padded, int32_t const* sequence_lengths,
int32_t max_seq_len, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -134,9 +134,8 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size,
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded,
size_t max_step, cudaStream_t stream)
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
{
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
@ -149,7 +148,7 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedS
constexpr size_t max_blocks{256};
block.x = min(((max_step + 32 - 1) / 32) * 32, max_blocks);
grid.x = (max_step + block.x - 1) / block.x;
grid.y = local_batch_size * beam_width;
grid.y = batch_size * beam_width;
// dynamically allocate shared memory of int[blockDim + 2*(ngram_size - 1)], where ngram_size - 1 is for boundary
// token's ngram and for most recent tokens
@ -162,8 +161,8 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedS
#define INVOKE_BAN_REPEAT_NGRAM(T) \
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, \
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, \
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, \
int vocab_size_padded, size_t max_step, cudaStream_t stream);
int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, \
cudaStream_t stream);
INVOKE_BAN_REPEAT_NGRAM(float)
INVOKE_BAN_REPEAT_NGRAM(half)

View File

@ -27,9 +27,8 @@ namespace kernels
template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size,
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded,
size_t max_step, cudaStream_t stream);
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,91 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)
# This can happen when not building for Torch
if(NOT Python3_EXECUTABLE)
find_package(
Python3
COMPONENTS Interpreter
REQUIRED)
endif()
execute_process(
WORKING_DIRECTORY ${3RDPARTY_DIR}/cutlass/python/
COMMAND ${Python3_EXECUTABLE} setup_library.py develop --user
RESULT_VARIABLE _CUTLASS_LIBRARY_SUCCESS)
if(NOT _CUTLASS_LIBRARY_SUCCESS MATCHES 0)
message(
FATAL_ERROR
"Failed to set up the CUTLASS library due to ${_CUTLASS_LIBRARY_SUCCESS}."
)
endif()
set_directory_properties(
PROPERTIES CMAKE_CONFIGURE_DEPENDS
${CMAKE_CURRENT_SOURCE_DIR}/python/generate_kernels.py)
execute_process(
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/
COMMAND ${Python3_EXECUTABLE} generate_kernels.py -o
${CMAKE_CURRENT_BINARY_DIR}
RESULT_VARIABLE _KERNEL_GEN_SUCCESS)
if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
message(
FATAL_ERROR
"Failed to generate CUTLASS kernel instantiations due to ${_KERNEL_GEN_SUCCESS}."
)
endif()
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
add_library(cutlass_src OBJECT ${SRC_CPP} ${SRC_CU} ${CU_INSTANTIATIONS})
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
# specified). This is because sm_90a has arch conditional instructions that are
# not forward compatible. As a result, it does not make sense to embed PTX into
# the binary anyway.
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto")
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
target_compile_options(
cutlass_src
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
# Hopper kernels require cuda lib for TMA APIs
target_link_libraries(cutlass_src PRIVATE CUDA::cuda_driver)
# No kernels should be parsed, unless hopper is specified. This is a build
# time improvement
target_compile_definitions(cutlass_src
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
endif()
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
# changed in GCC 4.6. This note appears for kernels using TMA and clutters the
# compilation output.
if(NOT WIN32)
target_compile_options(
cutlass_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
endif()

View File

@ -30,6 +30,7 @@
#endif // #ifndef _WIN32
#include <cuda_runtime_api.h>
#include <set>
#include <vector>
using namespace tensorrt_llm::cutlass_extensions;
@ -164,9 +165,103 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}
}
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
const bool int8_configs_only, const int max_split_k)
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
{
enum class CutlassGemmType : char
{
Default,
WeightOnly,
Simt,
Int8
};
CutlassGemmType gemm_type = CutlassGemmType::Default;
if (simt_configs_only)
{
gemm_type = CutlassGemmType::Simt;
}
else if (is_weight_only)
{
gemm_type = CutlassGemmType::WeightOnly;
}
else if (int8_configs_only)
{
gemm_type = CutlassGemmType::Int8;
}
switch (gemm_type)
{
case CutlassGemmType::WeightOnly:
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
default: throw std::runtime_error("get_candidate_tiles_sm90 only supports WeightOnly now.");
}
}
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
{
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
return valid_tiles.count(tile) == 1;
}
// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
// compilation speed.
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
{
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
CutlassTileConfigSM90::CtaShape128x256x128B};
return valid_tiles.count(tile) == 1;
}
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma)
{
if (sm == 90 && enable_hopper_gmma)
{
std::vector<CutlassTileConfigSM90> tiles
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);
std::vector<CutlassGemmConfig> candidate_configs;
for (const auto& tile_config : tiles)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);
const bool has_m_mcast = supports_mcast_along_m(tile_config);
const bool has_n_mcast = supports_mcast_along_n(tile_config);
if (has_m_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x1x1);
candidate_configs.push_back(config);
}
if (has_n_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_1x2x1);
candidate_configs.push_back(config);
}
if (has_m_mcast && has_n_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
ClusterShape::ClusterShape_2x2x1);
candidate_configs.push_back(config);
}
}
return candidate_configs;
}
std::vector<CutlassTileConfig> tiles
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
@ -177,7 +272,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages);
candidate_configs.push_back(config);
if (sm >= 75)
{
@ -253,8 +348,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
config_waves = num_waves_total;
SplitKStyle split_style
= split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
}
else if (current_score == config_score
@ -264,8 +359,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
// Prefer deeper pipeline or smaller split-k
SplitKStyle split_style
= split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
best_config = CutlassGemmConfig{
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
best_config = CutlassGemmConfig(
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
current_m_tile = tile_shape.m;
config_waves = num_waves_total;
}
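
Putting the two multicast predicates together: each SM90 tile yields one candidate config per legal cluster shape. Below is a self-contained illustration of the expansion rule, with plain integers standing in for the tile enum.

```
#include <cstdio>
#include <string>
#include <vector>

// Mirrors supports_mcast_along_m / supports_mcast_along_n above: multicast along a
// dimension is only generated when that CTA tile dimension is >= 128.
std::vector<std::string> cluster_shapes_for_tile(int cta_m, int cta_n)
{
    std::vector<std::string> shapes{"1x1x1"};
    bool const m_mcast = cta_m >= 128;
    bool const n_mcast = cta_n >= 128;
    if (m_mcast)
        shapes.push_back("2x1x1");
    if (n_mcast)
        shapes.push_back("1x2x1");
    if (m_mcast && n_mcast)
        shapes.push_back("2x2x1");
    return shapes;
}

int main()
{
    // CtaShape64x16x128B   -> 1x1x1 only (one candidate config).
    // CtaShape128x256x128B -> 1x1x1, 2x1x1, 1x2x1, 2x2x1 (four candidate configs).
    for (auto const& shape : cluster_shapes_for_tile(128, 256))
    {
        std::printf("%s\n", shape.c_str());
    }
    return 0;
}
```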

View File

@ -28,7 +28,7 @@ namespace cutlass_kernels
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
const int max_split_k = 1);
const int max_split_k = 1, const bool enable_hopper_gmma = false);
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,

View File

@ -138,10 +138,14 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type)
{
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
}
else if (arch >= 80 && arch <= 90)
else if (arch >= 80 && arch <= 89)
{
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
}
else if (arch == 90)
{
return getLayoutDetailsForArch<cutlass::arch::Sm90>(quant_type);
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
@ -532,6 +536,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type)
{
const int arch = getSMVersion();
LayoutDetails details = getLayoutDetailsForTransform(quant_type);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
@ -551,7 +556,6 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
// Works on row major data, so issue this permutation first.
if (details.uses_imma_ldsm)
{
const int arch = getSMVersion();
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
src_buf.swap(dst_buf);
}
@ -568,7 +572,10 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
src_buf.swap(dst_buf);
}
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
if (arch >= 70 && arch < 90)
{
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
}
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
}

View File

@ -0,0 +1,112 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include "cutlass/bfloat16.h"
#include "cutlass/float8.h"
#include "cutlass/half.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
///////////////////////////////////////////////////////////////////////////////////////////////////
// Tllm to Cutlass
template <typename T>
struct TllmToCutlassTypeAdapter
{
using type = T;
};
template <>
struct TllmToCutlassTypeAdapter<half>
{
using type = cutlass::half_t;
};
#if defined(ENABLE_BF16)
template <>
struct TllmToCutlassTypeAdapter<__nv_bfloat16>
{
using type = cutlass::bfloat16_t;
};
#endif
#if defined(ENABLE_FP8)
template <>
struct TllmToCutlassTypeAdapter<__nv_fp8_e4m3>
{
using type = cutlass::float_e4m3_t;
};
template <>
struct TllmToCutlassTypeAdapter<__nv_fp8_e5m2>
{
using type = cutlass::float_e5m2_t;
};
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////
// Cutlass to Tllm
template <typename T>
struct CutlassToTllmTypeAdapter
{
using type = T;
};
template <>
struct CutlassToTllmTypeAdapter<cutlass::half_t>
{
using type = half;
};
#if defined(ENABLE_BF16)
template <>
struct CutlassToTllmTypeAdapter<cutlass::bfloat16_t>
{
using type = __nv_bfloat16;
};
#endif
#if defined(ENABLE_FP8)
template <>
struct CutlassToTllmTypeAdapter<cutlass::float_e4m3_t>
{
using type = __nv_fp8_e4m3;
};
template <>
struct CutlassToTllmTypeAdapter<cutlass::float_e5m2_t>
{
using type = __nv_fp8_e5m2;
};
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
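
A minimal compile-time check of the adapters above. The include path is an assumption; the mappings themselves come straight from the specializations in this file.

```
// Assumed include path for the adapter header above.
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"

#include <type_traits>

using namespace tensorrt_llm::kernels::cutlass_kernels;

static_assert(std::is_same<TllmToCutlassTypeAdapter<half>::type, cutlass::half_t>::value,
    "half maps to cutlass::half_t");
static_assert(std::is_same<CutlassToTllmTypeAdapter<cutlass::half_t>::type, half>::value,
    "and the mapping round-trips");

int main()
{
    return 0;
}
```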

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, /*Scale and Zero Type*/
half, /*Bias Type*/
half /*Output Type*/
>;
#endif
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, /*Scale and Zero Type*/
half, /*Bias Type*/
half /*Output Type*/
>;
#endif
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,35 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
#ifdef ENABLE_FP8
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
cutlass::int4b_t, /*Weight Type*/
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, half, /*Scale and Zero Type*/
half, /*Bias Type*/
half /*Output Type*/
>;
#endif
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
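The three new translation units above explicitly instantiate the FP8-activation / int4-weight runner, one per quantization mode. A hedged sketch of naming and querying the matching runner from client code (the sizes are arbitrary examples, and an FP8-enabled build is assumed):

```
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"

namespace tkk = tensorrt_llm::kernels::cutlass_kernels;

#if defined(ENABLE_FP8)
// Matches the FINEGRAINED_SCALE_AND_ZEROS instantiation above: FP8 activations,
// int4 weights, half scales/zeros, half bias, half output.
using Fp8Int4Runner = tkk::CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::int4b_t,
    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>;

size_t queryFp8Int4Workspace(int m, int n, int k)
{
    Fp8Int4Runner runner;
    auto const configs = runner.getConfigs(); // candidate tile/schedule configs to profile
    (void) configs;
    return runner.getWorkspaceSize(m, n, k);  // scratch bytes the caller must allocate
}
#endif
```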

View File

@ -62,11 +62,21 @@ public:
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n,
int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream)
= 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;
// Returns desired workspace size in bytes.
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
@ -78,7 +88,8 @@ protected:
static constexpr int MIN_N_TILE = 64;
};
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp,
typename ScaleZeroType = ActivationType, typename BiasType = ActivationType, typename OutputType = ActivationType>
class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface
{
public:
@ -89,10 +100,19 @@ public:
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;
// Disabled since the fused GEMM + activation kernels will not be used in v1.
// void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n,
@ -106,9 +126,10 @@ public:
private:
template <typename EpilogueTag>
void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);
void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);
private:
int sm_;
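The overloads added here thread a `float alpha` through to the epilogue, where it scales the accumulator before the bias is applied. A minimal, hypothetical call site for the fine-grained path; every pointer and size below is a placeholder, and only the call shape comes from the interface above:

```
void runFineGrainedGemm(
    tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunnerInterface& runner,
    void const* dAct, void const* dWeights, void const* dScales, void const* dZeros,
    void const* dBias, void* dOut, int m, int n, int k,
    tensorrt_llm::cutlass_extensions::CutlassGemmConfig config, char* dWorkspace,
    size_t workspaceBytes, cudaStream_t stream)
{
    float const alpha = 1.0f;  // e.g. fold an input dequantization factor into the epilogue
    int const groupSize = 128; // number of K rows sharing one scale/zero pair (placeholder)
    runner.gemm(dAct, dWeights, dScales, dZeros, dBias, alpha, dOut, m, n, k, groupSize,
        config, dWorkspace, workspaceBytes, stream);
}
```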

View File

@ -38,6 +38,7 @@
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
@ -52,7 +53,7 @@ namespace cutlass_kernels
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales,
const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, const int group_size,
const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
@ -177,7 +178,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_scales)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_zero_points)), ld_scale_zero},
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0}, {reinterpret_cast<ElementType*>(C), n},
gemm_config.split_k_factor, {ElementAccumulator(1.f), output_op_beta});
gemm_config.split_k_factor, {ElementAccumulator(alpha), output_op_beta});
// This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
@ -230,8 +231,9 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr)
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -252,16 +254,17 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_
else
{
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape,
Stages>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
Stages>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config,
workspace, workspace_bytes, stream, occupancy);
}
}
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape>
void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr)
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -269,18 +272,18 @@ void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scale
{
case 2:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
break;
case 3:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
break;
case 4:
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
workspace_bytes, stream, occupancy);
break;
default:
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
@ -289,58 +292,89 @@ void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scale
}
}
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, T* C, int m, int n, int k, const int group_size, char* workspace, size_t workspace_bytes,
tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr)
template <typename T>
constexpr bool is_fp8()
{
return std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
// for mixed type gemms.
switch (gemm_config.tile_config)
// Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead.
constexpr bool any_is_fp8 = is_fp8<ActivationType>() || is_fp8<WeightType>() || is_fp8<ScaleZeroType>()
|| is_fp8<BiasType>() || is_fp8<OutputType>();
constexpr bool all_types_are_the_same = std::is_same_v<ActivationType, ScaleZeroType>
&& std::is_same_v<ActivationType, BiasType> && std::is_same_v<ActivationType, OutputType>;
constexpr bool is_valid_pre_hopper = all_types_are_the_same && !any_is_fp8;
if constexpr (is_valid_pre_hopper)
{
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
if (arch::kMinComputeCapability < 75)
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the
// best for mixed type gemms.
switch (gemm_config.tile_config)
{
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
break;
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
if (arch::kMinComputeCapability < 75)
{
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
}
else
{
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
stream, occupancy);
}
break;
case tkc::CutlassTileConfig::Undefined:
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
break;
case tkc::CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by "
"heuristic.");
break;
default:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
break;
}
else
{
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
}
break;
case tkc::CutlassTileConfig::Undefined:
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
break;
case tkc::CutlassTileConfig::ChooseWithHeuristic:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by "
"heuristic.");
break;
default:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
break;
}
else
{
// This is not a limitation in CUTLASS. We just do not need to support this case.
std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier.";
throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_to_cutlass] " + err_msg);
}
}
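The `if constexpr` guard above avoids instantiating FP8 or mixed-type kernels for the pre-Hopper architecture tags and reports a readable runtime error instead. Two concrete evaluations of the same predicates, for illustration only:

```
// (half, half, half, half): all types equal, no FP8 -> pre-Hopper kernels are instantiated.
static_assert(!is_fp8<half>(), "half is not an FP8 type");
#if defined(ENABLE_FP8)
// (__nv_fp8_e4m3 activation, half scale/bias/output): any_is_fp8 holds, so the Sm70/75/80
// paths take the runtime_error branch and only the SM90 dispatch handles this combination.
static_assert(is_fp8<__nv_fp8_e4m3>(), "e4m3 activations are FP8");
#endif
```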
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::CutlassFpAIntBGemmRunner()
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
OutputType>::CutlassFpAIntBGemmRunner()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
int device{-1};
@ -349,42 +383,52 @@ CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::CutlassFpAIntBGemmRunner()
tk::check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
}
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::~CutlassFpAIntBGemmRunner()
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
OutputType>::~CutlassFpAIntBGemmRunner()
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
}
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
template <typename EpilogueTag>
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::dispatch_to_arch<EpilogueTag>(const T* A, const WeightType* B,
const T* weight_scales, const T* weight_zero_points, const T* biases, T* C, int m, int n, int k,
const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream, int* occupancy)
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
OutputType>::dispatch_to_arch<EpilogueTag>(const ActivationType* A, const WeightType* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if (sm_ >= 70 && sm_ < 75)
{
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, QuantOp, EpilogueTag>(A, B, weight_scales,
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
occupancy);
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm70,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ >= 75 && sm_ < 80)
{
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, QuantOp, EpilogueTag>(A, B, weight_scales,
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
occupancy);
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm75,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ >= 80 && sm_ <= 90)
else if (sm_ >= 80 && sm_ < 90)
{
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, QuantOp, EpilogueTag>(A, B, weight_scales,
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
occupancy);
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
}
else if (sm_ == 90)
{
sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr,
workspace_bytes, gemm_config, stream, occupancy);
}
else
{
throw std::runtime_error(
"[TensorRT-LLm Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type "
"[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for CUTLASS mixed type "
"GEMM");
}
}
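The bucketing also changes here: SM90 leaves the former `sm_ <= 90` Ampere range and gets its own CUTLASS 3.x dispatch path. A small illustrative helper (not part of the API) summarizing the resulting mapping:

```
// Illustrative only: which path dispatch_to_arch selects for a given SM value.
char const* fpAIntBArchBucket(int sm)
{
    if (sm >= 70 && sm < 75) return "cutlass::arch::Sm70";
    if (sm >= 75 && sm < 80) return "cutlass::arch::Sm75";
    if (sm >= 80 && sm < 90) return "cutlass::arch::Sm80"; // SM90 is no longer folded in here
    if (sm == 90)            return "sm90_dispatch_gemm_to_cutlass (CUTLASS 3.x path)";
    return "unsupported";
}
```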
@ -424,18 +468,20 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::dispatch_to_arch<Epilogue
// }
// }
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const void* B, const void* weight_scales,
const void* weight_zero_points, const void* biases, void* C, int m, int n, int k, const int group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
|| (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY))
{
dispatch_to_arch<tkc::EpilogueOpBias>((const T*) A, (const WeightType*) B, (const T*) weight_scales,
(const T*) weight_zero_points, (const T*) biases, (T*) C, m, n, k, group_size, gemmConfig, workspace_ptr,
workspace_bytes, stream, nullptr);
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
(const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases,
alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
}
else
{
@ -444,17 +490,31 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const
}
}
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const void* B, const void* weight_scales,
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream)
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
gemm(A, B, weight_scales, weight_zero_points, biases, 1.f, C, m, n, k, group_size, gemmConfig, workspace_ptr,
workspace_bytes, stream);
}
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
{
dispatch_to_arch<tkc::EpilogueOpDefault>((const T*) A, (const WeightType*) B, (const T*) weight_scales, nullptr,
nullptr, (T*) C, m, n, k, k, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
(const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
workspace_ptr, workspace_bytes, stream, nullptr);
}
else
{
@ -462,17 +522,32 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const
}
}
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
std::vector<tkc::CutlassGemmConfig> CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getConfigs() const
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream);
}
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
std::vector<tkc::CutlassGemmConfig>
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getConfigs() const
{
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
std::vector<tkc::CutlassGemmConfig> candidateConfigs
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT);
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT, true);
return candidateConfigs;
}
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
size_t CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getWorkspaceSize(const int m, const int n, const int k)
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
size_t
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize(
const int m, const int n, const int k)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// These are the min tile sizes for each config, which would launch the maximum number of blocks

View File

@ -0,0 +1,275 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cute/numeric/integral_constant.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using namespace cute;
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
// instantiating SM=75, Stages=3 is invalid, so we would need to filter that out. Fine-grained
// quantization is only supported on Ampere+ GPUs.
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType>
void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.epilogue_schedule)
{
case tkc::EpilogueScheduleType::AUTO:
using EpilogueScheduleType = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{},
cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>;
sm90_generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, ClusterShape, MainloopScheduleType, EpilogueScheduleType>(A, B, weight_scales,
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream,
occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule config is invalid for "
"mixed "
"type GEMM.");
break;
}
}
/*
1x1x1 cluster shape is supported for any tile shape.
2x1x1 cluster shape is only supported when the M tile is at least 128.
1x2x1 cluster shape is only supported when the N tile is at least 128.
2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.
We impose the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
that are unlikely to be useful in practice.
*/
template <typename CTAShape, typename ClusterShape>
constexpr bool are_tile_shapes_supported()
{
constexpr int cta_m = get<0>(CTAShape{});
constexpr int cta_n = get<1>(CTAShape{});
constexpr int cga_m = get<0>(ClusterShape{});
constexpr int cga_n = get<1>(ClusterShape{});
if constexpr (cga_m == _1{} && cga_n == _1{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
{
return true;
}
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
{
return true;
}
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
{
return true;
}
else
{
return false;
}
}
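A few compile-time spot checks of these restrictions, with shapes chosen to exercise each branch (they rely on the `using namespace cute;` already in this header):

```
static_assert(are_tile_shapes_supported<Shape<_64, _64, _64>, Shape<_1, _1, _1>>(),
    "1x1x1 clusters are allowed for any tile");
static_assert(are_tile_shapes_supported<Shape<_128, _64, _64>, Shape<_2, _1, _1>>(),
    "2x1x1 clusters need CTA_M >= 128");
static_assert(!are_tile_shapes_supported<Shape<_64, _256, _64>, Shape<_2, _2, _1>>(),
    "2x2x1 clusters need both CTA_M and CTA_N >= 128");
```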
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
constexpr bool tile_shapes_supported = are_tile_shapes_supported<CTAShape, ClusterShape>();
if constexpr (tile_shapes_supported)
{
switch (gemm_config.mainloop_schedule)
{
case tkc::MainloopScheduleType::AUTO:
using KernelScheduleType = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{},
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
sm90_dispatch_epilogue_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, ClusterShape, KernelScheduleType>(A, B, weight_scales, weight_zero_points,
biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule config is invalid "
"for "
"mixed type GEMM.");
break;
}
}
else
{
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and Cluster shapes for "
"mixed type GEMM.");
}
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
switch (gemm_config.cluster_shape)
{
case tkc::ClusterShape::ClusterShape_1x1x1:
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, Shape<_1, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_2x1x1:
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, Shape<_2, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_1x2x1:
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, Shape<_1, _2, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::ClusterShape::ClusterShape_2x2x1:
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
EpilogueTag, CTAShape, Shape<_2, _2, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
default:
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][dispatch_CGA_config] Config is invalid for mixed type GEMM.");
break;
}
}
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
// for mixed type gemms.
constexpr int Ktile = 128 / sizeof(ActivationType);
using _Ktile = Int<Ktile>;
switch (gemm_config.tile_config_sm90)
{
case tkc::CutlassTileConfigSM90::CtaShape64x16x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_64, _16, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x32x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_64, _32, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x64x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_64, _64, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x128x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_64, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape64x256x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_64, _256, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x16x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_128, _16, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x32x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_128, _32, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x64x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_128, _64, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x128x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_128, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::CtaShape128x256x128B:
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
Shape<_128, _256, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
gemm_config, workspace, workspace_bytes, stream, occupancy);
break;
case tkc::CutlassTileConfigSM90::Undefined:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] gemm config undefined.");
break;
case tkc::CutlassTileConfigSM90::ChooseWithHeuristic:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] gemm config should have already been set by "
"heuristic.");
break;
default:
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
break;
}
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
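`Ktile` above fixes every SM90 tile at 128 bytes along K, so the element count follows from the activation width; this is the `128B` suffix in the tile names. A quick check of the arithmetic, assuming the CUDA fp16/fp8 headers are available:

```
#include <cuda_fp16.h>
static_assert(128 / sizeof(half) == 64, "fp16/bf16 activations -> 64-element K tile");
#if defined(ENABLE_FP8)
#include <cuda_fp8.h>
static_assert(128 / sizeof(__nv_fp8_e4m3) == 128, "fp8 activations -> 128-element K tile");
#endif
```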

View File

@ -0,0 +1,39 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "cutlass_extensions/gemm_configs.h"
#include "cutlass_extensions/weight_only_quant_op.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType>
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size,
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr);
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,298 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef _WIN32
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif // #ifndef _WIN32
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass_extensions/compute_occupancy.h"
#include "cutlass_extensions/epilogue_helpers.h"
#include "cutlass_extensions/gemm_configs.h"
#ifndef _WIN32
#pragma GCC diagnostic pop
#endif // #ifndef _WIN32
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
namespace tk = tensorrt_llm::common;
namespace tkc = tensorrt_llm::cutlass_extensions;
using namespace cute;
namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType>
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
#ifdef COMPILE_HOPPER_MIXED_INPUT_GEMMS
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
// For FAST_BUILD, only instantiate kernels with a 128x128x128B CTA shape and a 1x1x1 cluster shape.
#ifdef FAST_BUILD
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<CutlassActivationType>::value;
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
using SupportedCgaShape = Shape<_1, _1, _1>;
if constexpr (cute::is_same_v<SupportedCtaShape, CTAShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
{
#endif // FAST_BUILD
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
// We need to remap this since SM90 uses a different layout for the weight matrix.
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
cutlass::int4b_t, CutlassWeightType__>;
using CutlassWeightType
= std::conditional_t<std::is_same_v<CutlassWeightType_, uint8_t>, int8_t, CutlassWeightType_>;
using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;
static_assert(std::is_same_v<CutlassActivationType, cutlass::half_t>
|| std::is_same_v<CutlassActivationType, cutlass::bfloat16_t>
|| std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
|| std::is_same_v<CutlassActivationType, cutlass::float_e5m2_t>,
"Activation type must be bfloat16, half, FP8");
static_assert(std::is_same_v<CutlassWeightType, int8_t> || std::is_same_v<CutlassWeightType, cutlass::int4b_t>
|| std::is_same_v<CutlassWeightType, cutlass::float_e4m3_t>
|| std::is_same_v<CutlassWeightType, cutlass::float_e5m2_t>,
"Weight type must be fp8, int8_t or int4_t");
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<CutlassActivationType>::value;
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<CutlassWeightType>::value;
// We manually swap and transpose the operands here, so keep the transposed input layouts
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
using ElementZero = CutlassScaleZeroType;
using ElementScale = CutlassScaleZeroType;
// C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast.
using LayoutBias = cutlass::layout::RowMajor;
constexpr int AlignmentBias = 128 / cutlass::sizeof_bits<CutlassBiasType>::value;
// D matrix configuration
using LayoutOutput = cutlass::layout::RowMajor;
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;
// Core kernel configurations
using ElementAccumulator = float; // Element type for internal accumulation
using ElementCompute = float; // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = CTAShape; // Threadblock-level tile size
using KernelSchedule = MainloopScheduleType;
using EpilogueSchedule = EpilogueScheduleType;
// Shrink the N dimension to match CTA_N if needed
constexpr int epi_tile_M = cute::min(shape<0>(TileShape{}), 128); // 64 or 128
constexpr int epi_tile_N = cute::min(shape<1>(TileShape{}), 32); // Allow this to be 16 for some small N tiles.
using EpilogueTileType = cute::Shape<cute::Int<epi_tile_M>, cute::Int<epi_tile_N>>;
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
static_assert(std::is_same_v<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpBias>, "");
using EVT_bias_addition = cutlass::epilogue::fusion::Sm90EVT<
cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, CutlassOutputType, ElementCompute,
RoundStyle>, // alpha * acc + bias
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAccumulator>, // alpha
cutlass::epilogue::fusion::Sm90AccFetch, // acc
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, CutlassBiasType, Stride<_1, _0, _0>,
AlignmentBias> // bias
>;
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<ArchTag, OperatorClass,
TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementAccumulator,
// Transpose layout of D here since we use the explicit swap + transpose trick
// Void C since we don't use it. Prevents smem allocation.
void, typename cutlass::layout::LayoutTranspose<LayoutBias>::type, AlignmentBias, CutlassOutputType,
typename cutlass::layout::LayoutTranspose<LayoutOutput>::type, AlignmentOutput, EpilogueSchedule,
EVT_bias_addition>::CollectiveOp;
using PackedScaleZero = cute::tuple<CutlassWeightType, ElementScale, ElementZero>;
using PackedScale = cute::tuple<CutlassWeightType, ElementScale>;
using ElementBCollectiveInfo = std::conditional_t<cutlass::hasZero(QuantOp), PackedScaleZero, PackedScale>;
// We swap A and B operands to the builder here
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<ArchTag, OperatorClass,
ElementBCollectiveInfo, LayoutB_Transpose, AlignmentB, CutlassActivationType, LayoutA_Transpose, AlignmentA,
ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue>;
if (occupancy != nullptr)
{
*occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
return;
}
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
using StrideA = typename GemmKernel::StrideA;
using StrideB = typename GemmKernel::StrideB;
using StrideC = typename GemmKernel::StrideC;
using StrideD = typename GemmKernel::StrideD;
using StrideS = typename CollectiveMainloop::StrideScale;
if (weight_scales == nullptr)
{
throw std::runtime_error("Weight scales must always be set to a non-null value.");
}
if constexpr (cutlass::isFinegrained(QuantOp))
{
int cta_shape_k = cute::size<2>(TileShape{});
if (group_size % cta_shape_k != 0)
{
std::string err_msg = "The group size must a multiple of " + std::to_string(cta_shape_k);
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner]" + err_msg);
}
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
{
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero pointer must be a nullptr for scale only fine grained");
}
}
else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
{
if (weight_zero_points == nullptr)
{
throw std::runtime_error("Weight zero pointer must be valid for scale and bias fine grained");
}
}
}
else
{
if (group_size != k)
{
throw std::runtime_error("Invalid group size for per column scaling kernels.");
}
if (weight_zero_points != nullptr)
{
throw std::runtime_error("Weight zero-points must be null when running per column scaling");
}
}
auto cutlass_scale_k = (k + group_size - 1) / group_size;
StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1));
// Use the output as the bias to avoid making a tma descriptor with a nullptr.
auto output_as_bias_type = reinterpret_cast<const CutlassBiasType*>(C);
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
stride_A, reinterpret_cast<ElementScale const*>(weight_scales), stride_S, group_size,
reinterpret_cast<ElementZero const*>(weight_zero_points)},
{{}, output_as_bias_type, stride_D, reinterpret_cast<CutlassOutputType*>(C), stride_D}};
args.epilogue.thread = {
{alpha}, // alpha args
{}, // accumulator
{reinterpret_cast<CutlassBiasType const*>(biases), CutlassBiasType(0.f)}, // bias args
{} // end multiply_add
};
Gemm gemm;
if (gemm.get_workspace_size(args) > workspace_bytes)
{
TLLM_LOG_ERROR("[TensorRT-LLm Error][fpA_intB Runner] given workspace size insufficient.");
}
auto can_implement = gemm.can_implement(args);
if (can_implement != cutlass::Status::kSuccess)
{
std::string err_msg = "fpA_intB cutlass kernel will fail for params. Error: "
+ std::string(cutlassGetStatusString(can_implement));
std::cout << err_msg << std::endl;
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
auto init_status = gemm.initialize(args, workspace, stream);
if (init_status != cutlass::Status::kSuccess)
{
std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: "
+ std::string(cutlassGetStatusString(init_status));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
auto run_status = gemm.run(stream);
if (run_status != cutlass::Status::kSuccess)
{
std::string err_msg
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] " + err_msg);
}
#ifdef FAST_BUILD
}
else
{
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB Runner] Config not compiled with FAST_BUILD.");
}
#endif // FAST_BUILD
#else // COMPILE_HOPPER_MIXED_INPUT_GEMMS
throw std::runtime_error(
"[TensorRT-LLm Error][fpA_intB Runner] Please recompile with support for hopper by passing 90-real as an arch "
"to build_wheel.py.");
#endif // COMPILE_HOPPER_MIXED_INPUT_GEMMS
}
} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
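Putting the epilogue pieces together, for the fine-grained path the launcher above computes, with group size g along K and B dequantized from its packed integer form using the per-group scales and optional zero points (the exact dequantization convention follows the CUTLASS mixed-input mainloop and is not spelled out here):

```
D_{mn} \;=\; \alpha \sum_{k} A_{mk}\,\widetilde{B}_{kn} \;+\; \mathrm{bias}_{n},
\qquad
\widetilde{B}_{kn} \;=\; \mathrm{dequant}\!\left(B_{kn};\ \mathrm{scale}_{\lfloor k/g\rfloor,\,n},\ \mathrm{zero}_{\lfloor k/g\rfloor,\,n}\right)
```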

View File

@ -0,0 +1,304 @@
import argparse
import enum
import os
from itertools import product
from cutlass_library import *
################################################################################
# Epilogue Tag enum and string utils
class TrtLlm_EpilogueTag(enum.Enum):
epilogue_op_default = enum_auto()
epilogue_op_bias = enum_auto()
EpiTagNames = {
TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination
TrtLlm_EpilogueTag.epilogue_op_bias:
"lc_bias" # linear combination with bias addition
}
EpiTag = {
TrtLlm_EpilogueTag.epilogue_op_default:
"tensorrt_llm::cutlass_extensions::EpilogueOpDefault",
TrtLlm_EpilogueTag.epilogue_op_bias:
"tensorrt_llm::cutlass_extensions::EpilogueOpBias"
}
################################################################################
# Quantization Operation and string utils
class TrtLlm_QuantOp(enum.Enum):
per_column_scale_only = enum_auto()
finegrained_scale_only = enum_auto()
finegrained_scale_and_zeros = enum_auto()
QuantOpNames = {
TrtLlm_QuantOp.per_column_scale_only: "cs",
TrtLlm_QuantOp.finegrained_scale_only: "fgs",
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz"
}
QuantOpTag = {
TrtLlm_QuantOp.per_column_scale_only:
"cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY",
TrtLlm_QuantOp.finegrained_scale_only:
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
TrtLlm_QuantOp.finegrained_scale_and_zeros:
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS"
}
################################################################################
# The activations, biases, scales and zeros are instantiated using CUDA types,
# not CUTLASS types. This map materializes the name of the CUDA type.
CudaTypeName = {
DataType.e4m3: "__nv_fp8_e4m3",
DataType.bf16: "__nv_bfloat16",
DataType.f16: "half"
}
################################################################################
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
class TrtLlm_GemmLauncher:
def __init__(self,
gemm_kind,
arch,
act_type,
weight_type,
scalezero_type,
bias_type,
output_type,
quant_op,
epi_tag,
cta_shape,
warp_shape,
stages,
cga_shape=None,
mainloop_schedule=None,
epi_schedule=None):
self.gemm_kind = gemm_kind
self.arch = arch
self.act_type = act_type
self.weight_type = weight_type
self.scalezero_type = scalezero_type
self.bias_type = bias_type
self.output_type = output_type
self.quant_op = quant_op
self.epi_tag = epi_tag
self.cta_shape = cta_shape
self.warp_shape = warp_shape
self.stages = stages
self.cga_shape = cga_shape
self.mainloop_schedule = mainloop_schedule
self.epi_schedule = epi_schedule
def __repr__(self):
kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format(
GemmKindNames[self.gemm_kind], self.arch,
DataTypeNames[self.act_type], DataTypeNames[self.weight_type],
DataTypeNames[self.scalezero_type], DataTypeNames[self.bias_type],
DataTypeNames[self.output_type], QuantOpNames[self.quant_op],
EpiTagNames[self.epi_tag], self.cta_shape[0], self.cta_shape[1],
self.cta_shape[2], self.warp_shape[0], self.warp_shape[1],
self.warp_shape[2], self.stages)
hopper_suffix = "_{}x{}x{}{}{}".format(
self.cga_shape[0], self.cga_shape[1], self.cga_shape[2],
KernelScheduleSuffixes[self.mainloop_schedule],
EpilogueScheduleSuffixes[self.epi_schedule])
if self.arch == 90:
return kernel_prefix + hopper_suffix
elif self.arch > 90:
raise ValueError(f"SM{self.arch} not supported yet.")
return kernel_prefix
################################################################################
def tuple_to_cute_shape(shape):
return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>"
def instantiate_operation(operation):
act_tag = CudaTypeName[operation.act_type]
weight_tag = DataTypeTag[operation.weight_type]
scale_zero_tag = CudaTypeName[operation.scalezero_type]
bias_tag = CudaTypeName[operation.bias_type]
out_tag = CudaTypeName[operation.output_type]
quant_op = QuantOpTag[operation.quant_op]
epi_tag = EpiTag[operation.epi_tag]
cute_cta_shape = tuple_to_cute_shape(operation.cta_shape)
cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)
kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
# It is a workaround since the CUTLASS library did not have the MixedInput schedules at the time of writing.
if operation.mainloop_schedule in [
KernelScheduleType.TmaWarpSpecializedCooperative,
KernelScheduleType.TmaWarpSpecializedPingpong,
KernelScheduleType.TmaWarpSpecialized
]:
kernel_sched += "MixedInput"
epi_sched = EpilogueScheduleTag[operation.epi_schedule]
instantiation = f"""
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
{quant_op}, {epi_tag},
{cute_cta_shape}, {cute_cga_shape},
{kernel_sched}, {epi_sched}> (
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
);
"""
return instantiation
def get_file_content(launcher_inl_files, operations):
include_list = list()
for file in launcher_inl_files:
include_list.append(f"#include \"{file}\"")
includes = "\n".join(include_list)
insts_list = list()
for op in operations:
insts_list.append(instantiate_operation(op))
instantiations = "\n".join(insts_list)
file_content = f"""{includes}
namespace tensorrt_llm
{{
namespace kernels
{{
namespace cutlass_kernels
{{
{instantiations}
}} // namespace cutlass_kernels
}} // namespace kernels
}} // namespace tensorrt_llm
"""
return file_content
def write_file(launcher_inl_files, operations, output_file):
with open(output_file, mode="w") as f:
f.write(get_file_content(launcher_inl_files, operations))
def is_op_valid(op):
tile_m, tile_n, _ = op.cta_shape
cga_m, cga_n, _ = op.cga_shape
if cga_m == 1 and cga_n == 1:
return True
if cga_m == 2 and cga_n == 1 and tile_m >= 128:
return True
if cga_m == 1 and cga_n == 2 and tile_n >= 128:
return True
if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128:
return True
return False
################################################################################
def generate_sm90_operations():
arch = 90
# For legacy reasons, we use unsigned weight types when the activations are fp16 / bf16.
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
supported_dtypes = [
(DataType.e4m3, DataType.s4, DataType.f16, DataType.f16, DataType.f16),
(DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
(DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
DataType.bf16),
(DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
(DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16,
DataType.bf16)
]
quant_ops = [
TrtLlm_QuantOp.per_column_scale_only,
TrtLlm_QuantOp.finegrained_scale_only,
TrtLlm_QuantOp.finegrained_scale_and_zeros
]
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]
M_TILES = [64, 128]
N_TILES = [16, 32, 64, 128, 256]
cta_shapes_mn = product(M_TILES, N_TILES)
warp_shape = [4, 1, 1]
stages = 0 # auto
cga_shapes = product([1, 2], [1, 2], [1])
partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
cga_shapes)
operations = list()
for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
max_k_bits = 128 * 8
cta_shape_k = max_k_bits // DataTypeSize[dtype_combo[0]]
cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
use_coop = cta_shape_mn[0] == 128
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized
operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
if is_op_valid(operation):
operations.append(operation)
return operations
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Generate CUTLASS mixed-input GEMM kernel instantiation files')
# Add the output_dir argument with short and long options
parser.add_argument('-o',
'--output_dir',
type=str,
required=True,
help='Path to the output directory')
# Parse the command line arguments
args = parser.parse_args()
# Get the absolute path of the provided directory
output_dir = os.path.abspath(args.output_dir)
hopper_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
# The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
operations = generate_sm90_operations()
op_groups = dict()
for op in operations:
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0])
op_group = op_groups.get(dict_key, list())
op_group.append(op)
op_groups[dict_key] = op_group
file_counter = 1
for key, value in op_groups.items():
out_file = os.path.join(
output_dir, f"cutlass_kernel_file_{file_counter}.generated.cu")
write_file([hopper_inl], value, out_file)
file_counter += 1
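For reference, below is a sketch of the kind of `cutlass_kernel_file_N.generated.cu` this script writes. The exact type, quantization, and schedule tags are looked up from `cutlass_library` at generation time, so the single instantiation shown here (fp16 activation, uint4 weight, 128x128x64 CTA tile, 1x1x1 cluster, cooperative schedules) is illustrative only, not guaranteed output.

```
// Illustrative sketch of one generated translation unit; the concrete tags come from
// cutlass_library's DataTypeTag / KernelScheduleTag / EpilogueScheduleTag maps.
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"

namespace tensorrt_llm
{
namespace kernels
{
namespace cutlass_kernels
{

template void sm90_generic_mixed_gemm_kernelLauncher<half, cutlass::uint4b_t, half, half, half,
    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, tensorrt_llm::cutlass_extensions::EpilogueOpBias,
    cute::Shape<cute::Int<128>, cute::Int<128>, cute::Int<64>>, cute::Shape<cute::Int<1>, cute::Int<1>, cute::Int<1>>,
    cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, cutlass::epilogue::TmaWarpSpecializedCooperative>(
    const half*, const cutlass::uint4b_t*, const half*, const half*, const half*, const float, half*, int, int, int,
    const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*);

} // namespace cutlass_kernels
} // namespace kernels
} // namespace tensorrt_llm
```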

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -64,93 +64,113 @@ void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, c
}
template <typename T>
__global__ void addBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int vocabSize, const int vocabSizePadded)
__global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits)
{
auto const batchIdx = blockIdx.x;
auto const beamIdx = blockIdx.y;
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
const FinishedState finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty();
auto const batchIdxLogits = batchSlotsLogits ? batchSlot : batchIdx;
FinishedState const finishState
= finished != nullptr ? finished[beamIdx * maxBatchSize + batchSlot] : FinishedState::empty();
if (finishState.isSkipDecoding())
{
return;
}
auto logitsPtr = logitsPtrs ? logitsPtrs[batchIdx] + beamIdx * vocabSizePadded
: logits + (batchIdxLogits * beamWidth + beamIdx) * vocabSizePadded;
bool finish = finishState.isFinished();
int offset = batchIdx * vocabSizePadded;
int offset = (batchIdxLogits * beamWidth + beamIdx) * vocabSizePadded;
float maxVal = -1 * FLT_MAX;
const bool IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
bool const IS_FP16 = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
__shared__ float sMaxVal;
__shared__ float sSumVal;
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
auto logit = logitsPtr[tid];
if (tid < vocabSize)
{
if (finish && endIds != nullptr)
{
logits[offset + tid] = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
logit = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
}
else
{
T bias_val = (bias != nullptr) ? bias[tid] : (T) 0.0f;
logits[offset + tid] += bias_val;
logit += bias_val;
}
}
else
{
logits[offset + tid] = -MAX_T_VAL;
logit = -MAX_T_VAL;
}
maxVal = max(maxVal, (float) logits[offset + tid]);
maxVal = max(maxVal, (float) logit);
logitsPtr[tid] = logit;
}
maxVal = blockReduceMax<float>((float) maxVal);
if (threadIdx.x == 0)
if (!skipSoftMax)
{
sMaxVal = maxVal;
}
__syncthreads();
maxVal = blockReduceMax<float>((float) maxVal);
if (threadIdx.x == 0)
{
sMaxVal = maxVal;
}
__syncthreads();
float sumVal = 0.0f;
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
probs[offset + tid] = __expf((float) logits[offset + tid] - sMaxVal);
sumVal += (float) probs[offset + tid];
}
float sumVal = 0.0f;
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
probs[offset + tid] = __expf((float) logitsPtr[tid] - sMaxVal);
sumVal += (float) probs[offset + tid];
}
sumVal = blockReduceSum<float>(sumVal);
if (threadIdx.x == 0)
{
sSumVal = sumVal;
}
__syncthreads();
sumVal = blockReduceSum<float>(sumVal);
if (threadIdx.x == 0)
{
sSumVal = sumVal;
}
__syncthreads();
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
probs[offset + tid] = ((float) probs[offset + tid] / (sSumVal + 1e-6f));
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
probs[offset + tid] = ((float) probs[offset + tid] / (sSumVal + 1e-6f));
}
}
}
template <typename T>
void invokeAddBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream)
void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
cudaStream_t stream)
{
dim3 grid(batchSize);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
dim3 grid(batchSize, beamWidth);
auto const vocabRoundedToWarp = roundUp(vocabSize, 32);
dim3 block(min(vocabRoundedToWarp, 1024));
// vocabSize, e.g., 30000, 7000.... vocabSize is usually very big.
addBiasSoftMax<<<grid, block, 0, stream>>>(
logits, probs, bias, endIds, finished, batchSlots, vocabSize, vocabSizePadded);
addBiasSoftMax<<<grid, block, 0, stream>>>(logits, logitsPtrs, probs, bias, endIds, finished, batchSlots, batchSize,
maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template void invokeAddBiasSoftMax(float* logits, float* probs, const float* bias, const int* endIds,
const FinishedState* finished, const int* batchSlots, const int m, const int nPadded, const int n,
cudaStream_t stream);
template void invokeAddBiasSoftMax(float* logits, float** logitsPtrs, float* probs, float const* bias,
int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
bool batchSlotsLogits, cudaStream_t stream);
template void invokeAddBiasSoftMax(half* logits, half* probs, const half* bias, const int* endIds,
const FinishedState* finished, const int* batchSlots, const int m, const int nPadded, const int n,
cudaStream_t stream);
template void invokeAddBiasSoftMax(half* logits, half** logitsPtrs, half* probs, half const* bias,
int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
bool batchSlotsLogits, cudaStream_t stream);
template <typename T>
__global__ void scatterDecodingParamsKernel(T const* src, T* dst, int const* batchSlots, int batchSize)

View File

@ -169,19 +169,28 @@ void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, c
//! endId token. Otherwise, adds bias per token if bias pointer is not nullptr.
//!
//! \param logits input/output buffer [maxBatchSize, vocabSize]. Logits to be modified by mask and bias.
//! If nullptr, logitsPtrs has to be provided.
//! \param logitsPtrs input/output buffer [maxBatchSize][vocabSize]. Vector of pointers to the logits.
//! If nullptr, logits has to be provided.
//! \param probs output buffer [maxBatchSize, vocabSize]. Probabilities of logits compute by softmax.
//! Can be the same pointer as logits
//! \param bias input buffer [vocabSize]. Bias to logit per token. Ignored if nullptr
//! \param endIds input buffer [maxBatchSize]. EOS token ids per request
//! \param finished input buffer [maxBatchSize] with flags set to true if request has finished the generation
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
//! \param batchSize current batch size
//! \param maxBatchSize max batch size
//! \param beamWidth beam width
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
//! \param skipSoftMax flag to skip softmax computation
//! \param batchSlotsLogits flag to use batchSlot as index for logits and probs
//! \param stream stream
template <typename T>
void invokeAddBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
cudaStream_t stream);
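A minimal host-side sketch of calling this entry point on the plain contiguous-logits path (beam width 1, no batch slots) is shown below. Buffer names are illustrative, the header above and the TensorRT-LLM common headers are assumed to be included, and T must be float or half to match the explicit instantiations.

```
using namespace tensorrt_llm::kernels;

// Sketch only: add an optional bias and run softmax over padded logits.
template <typename T>
void runAddBiasSoftMax(T* dLogits, T* dProbs, T const* dBias, int32_t const* dEndIds,
    FinishedState const* dFinished, int32_t batchSize, int32_t vocabSize, int32_t vocabSizePadded,
    cudaStream_t stream)
{
    invokeAddBiasSoftMax(dLogits, /* logitsPtrs */ static_cast<T**>(nullptr), dProbs, dBias, dEndIds, dFinished,
        /* batchSlots */ nullptr, batchSize, /* maxBatchSize */ batchSize, /* beamWidth */ 1, vocabSize,
        vocabSizePadded, /* skipSoftMax */ false, /* batchSlotsLogits */ false, stream);
}
```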
//! \brief Distributes values located in src to dst according to the indices from batchSlots
//!

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -499,26 +499,26 @@ void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled,
outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen);
}
__global__ void acceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens)
__global__ void acceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t maxSeqLen, int32_t maxDraftTokens)
{
int threadFinishedCount = 0;
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batchSize * beamWidth;
index += blockDim.x * gridDim.x)
for (int batchIdx = threadIdx.x; batchIdx < batchSize; batchIdx += blockDim.x)
{
const auto numDraftTokens = numsDraftTokens[index];
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const numDraftTokens = numsDraftTokens[batchSlot];
const auto contextLength = contextLengths[index];
auto& sequenceLength = sequenceLengths[index];
auto const contextLength = contextLengths[batchSlot];
auto& sequenceLength = sequenceLengths[batchSlot];
int finishedDraftIdx = 0;
for (int ti = contextLength; ti < min(sequenceLength, contextLength + numDraftTokens); ++ti, ++finishedDraftIdx)
{
const auto draftIdx = ti - contextLength;
const auto targetTokenIdx = index * maxSeqLen + ti;
const auto draftTokenIdx = index * maxDraftTokens + draftIdx;
auto const draftIdx = ti - contextLength;
auto const targetTokenIdx = batchSlot * maxSeqLen + ti;
auto const draftTokenIdx = batchSlot * maxDraftTokens + draftIdx;
// Check if draft tokens are the same as target tokens
const bool accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
bool const accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
if (!accepted)
{
// Set sequence length to the numAcceptedTokens + 1
@ -527,65 +527,57 @@ __global__ void acceptDraftTokensByIds(const int* draftIds, const int* targetIds
break;
}
}
FinishedState finishState = finished[finishedDraftIdx * batchSize * beamWidth + index];
finishedFinal[index] = finishState;
threadFinishedCount += static_cast<int>(finishState.isFinished());
}
FinishedState finishState = finished[finishedDraftIdx * maxBatchSize + batchSlot];
finishedFinal[batchSlot] = finishState;
if (finishedSum)
{
int blockFinishedCount = 0;
if (blockDim.x <= 32)
if (finishedSum)
{
blockFinishedCount = warpReduceSum(threadFinishedCount);
}
else
{
blockFinishedCount = blockReduceSum(threadFinishedCount);
}
__syncthreads();
if (threadIdx.x == 0)
{
finishedSum[0] = blockFinishedCount;
finishedSum[batchSlot] = static_cast<int>(finishState.isFinished());
}
}
}
void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens, cudaStream_t stream)
void invokeAcceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t beamWidth, int32_t maxSeqLen, int32_t maxDraftTokens, cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
dim3 block(min(256, batchSize * beamWidth));
dim3 block(min(1024, batchSize));
dim3 grid(1);
acceptDraftTokensByIds<<<grid, block, 0, stream>>>(draftIds, targetIds, contextLengths, numsDraftTokens,
sequenceLengths, finished, finishedFinal, finishedSum, batchSize, beamWidth, maxSeqLen, maxDraftTokens);
sequenceLengths, finished, finishedFinal, finishedSum, batchSlots, batchSize, maxBatchSize, maxSeqLen,
maxDraftTokens);
}
template <typename T>
__global__ void acceptDraftTokensByLogitsKernel(const T* draftProbs, T* targetProbs, const int* numsDraftTokens,
FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth, int vocabSize,
bool randomThreshold, float constantThreshold)
__global__ void acceptDraftTokensByLogitsKernel(T const* draftProbs, T* targetProbs, int32_t const* numsDraftTokens,
FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t maxDraftTokens, int32_t beamWidth, int32_t vocabSize, bool randomThreshold,
float constantThreshold)
{
const auto bid = blockIdx.x;
const auto draftTokenIdx = blockIdx.y;
const auto batchIdx = bid / beamWidth;
auto const bid = blockIdx.x;
auto const draftTokenIdx = blockIdx.y;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
const auto numDraftTokens = numsDraftTokens[bid];
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
if (draftTokenIdx >= numDraftTokens)
{
return;
}
const auto logitsOffset = draftTokenIdx * batchSize * beamWidth * vocabSize + bid * vocabSize;
const auto draftProbsBatch = draftProbs + logitsOffset;
const auto targetProbsBatch = targetProbs + logitsOffset;
auto const logitsOffset = (batchSlot * maxDraftTokens + draftTokenIdx) * beamWidth * vocabSize;
auto const draftProbsBatch = draftProbs + logitsOffset;
auto const targetProbsBatch = targetProbs + logitsOffset;
int rejected = 0;
int32_t rejected = 0;
auto vocabSizePadded = static_cast<int32_t>((vocabSize + blockDim.x - 1) / blockDim.x) * blockDim.x;
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
for (int32_t vIdx = threadIdx.x; vIdx < vocabSizePadded; vIdx += blockDim.x)
{
if (rejected > 0)
{
@ -594,35 +586,41 @@ __global__ void acceptDraftTokensByLogitsKernel(const T* draftProbs, T* targetPr
// FIXME(nkorobov): We compare probability distributions, but it might make sense to compare probabilities of
// the selected tokens based on the https://arxiv.org/pdf/2302.01318.pdf
const auto threshold = randomThreshold ? curand_uniform(curandState + batchIdx) : constantThreshold;
const auto targetProb = static_cast<float>(targetProbsBatch[vIdx]);
const auto draftProb = static_cast<float>(draftProbsBatch[vIdx]);
bool const pred = vIdx < vocabSize;
auto const threshold
= pred ? (randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold) : 0.f;
auto const targetProb = pred ? static_cast<float>(targetProbsBatch[vIdx]) : 1.f;
auto const draftProb = pred ? static_cast<float>(draftProbsBatch[vIdx]) : 0.f;
rejected = __syncthreads_count(targetProb < threshold * draftProb);
}
if (threadIdx.x == 0)
{
finished[draftTokenIdx * batchSize * beamWidth + bid]
finished[draftTokenIdx * maxBatchSize * beamWidth + batchSlotBeamWidth]
= rejected > 0 ? FinishedState::skipDecoding() : FinishedState::empty();
}
}
template <typename T>
__global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetProbs, T* targetLogits,
const int* numsDraftTokens, FinishedState* finished, int batchSize, int beamWidth, int vocabSize)
__global__ void correctAcceptedStatesAndLogits(T const* draftProbs, T* targetProbs, T** targetLogits,
int32_t const* numsDraftTokens, FinishedState* finished, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t maxDraftTokens, int32_t beamWidth, int32_t vocabSize)
{
const auto bid = blockIdx.x;
const auto numDraftTokens = numsDraftTokens[bid];
auto const bid = blockIdx.x;
auto const batchIdx = bid / beamWidth;
auto const beamIdx = bid % beamWidth;
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
__shared__ int numAcceptedTokens;
__shared__ int32_t numAcceptedTokens;
if (threadIdx.x == 0)
{
numAcceptedTokens = numDraftTokens;
bool cummulativeSkipDecoding = false;
for (int ti = 0; ti < numDraftTokens + 1; ++ti)
for (int32_t ti = 0; ti < numDraftTokens + 1; ++ti)
{
auto& finishedState = finished[ti * batchSize * beamWidth + bid];
auto& finishedState = finished[ti * maxBatchSize * beamWidth + batchSlotBeamWidth];
bool localSkipDecoding = finishedState.isSkipDecoding();
if (cummulativeSkipDecoding == false && localSkipDecoding == true)
{
@ -637,15 +635,15 @@ __global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetPro
if (numAcceptedTokens < numDraftTokens)
{
const auto logitsIdx = numAcceptedTokens * batchSize * beamWidth * vocabSize + bid * vocabSize;
const auto draftProbBatch = draftProbs + logitsIdx;
auto const logitsIdx = (batchSlot * maxDraftTokens + numAcceptedTokens) * beamWidth * vocabSize;
auto const draftProbBatch = draftProbs + logitsIdx;
auto targetProbBatch = targetProbs + logitsIdx;
auto targetLogitsBatch = targetLogits + logitsIdx;
auto targetLogitsBatch = targetLogits[bid] + numAcceptedTokens * beamWidth * vocabSize;
float sumProbs = 0.f;
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
for (int32_t vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
{
const auto correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
auto const correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
sumProbs += correctedProb;
targetProbBatch[vIdx] = correctedProb;
}
@ -658,49 +656,52 @@ __global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetPro
}
__syncthreads();
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
for (int32_t vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
{
const auto correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
auto const correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
targetLogitsBatch[vIdx] = __logf(correctedNormProb / (1.f - correctedNormProb));
}
}
}
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T* targetLogits, T* draftProbs, T* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream)
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
{
invokeAddBiasSoftMax(draftLogits, draftProbs, (T*) (nullptr), nullptr, finished, nullptr,
batchSize * beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, stream);
invokeAddBiasSoftMax(targetLogits, targetProbs, (T*) (nullptr), nullptr, finished, nullptr,
batchSize * beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, stream);
invokeAddBiasSoftMax(draftLogits, (T**) (nullptr), draftProbs, (T*) (nullptr), nullptr, finished, batchSlots,
batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, /* skip softmax */ false,
/* batchSlotLogits */ true, stream);
invokeAddBiasSoftMax((T*) (nullptr), targetLogits, targetProbs, (T*) (nullptr), nullptr, finished, batchSlots,
batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, /* skip softmax */ false,
/* batchSlotLogits */ true, stream);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth, maxDraftTokens);
acceptDraftTokensByLogitsKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens, finished,
curandState, batchSize, beamWidth, vocabSizePadded, randomThreshold, constantThreshold);
curandState, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded,
randomThreshold, constantThreshold);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(
draftProbs, targetProbs, targetLogits, numsDraftTokens, finished, batchSize, beamWidth, vocabSizePadded);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(draftProbs, targetProbs, targetLogits,
numsDraftTokens, finished, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded);
}
}
template void acceptDraftTokensByLogits(float* draftLogits, float* targetLogits, float* draftProbs, float* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half* targetLogits, half* draftProbs, half* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
template void acceptDraftTokensByLogits(float* draftLogits, float** targetLogits, float* draftProbs, float* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, half* draftProbs, half* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,22 +31,22 @@ namespace kernels
struct gatherTreeParam
{
// TODO rename the parameters
int* beams = nullptr; // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
int* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
int maxSequenceLengthFinalStep = 0;
const int* inputLengths = nullptr; // [batchSize, beamWidth]
int32_t* beams = nullptr; // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
int32_t* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
int32_t maxSequenceLengthFinalStep = 0;
int32_t const* inputLengths = nullptr; // [batchSize, beamWidth]
// response input lengths (used to slice the ids during postprocessing)
int* responseInputLengths = nullptr;
int maxSeqLen = 0;
int batchSize = 0;
int beamWidth = 0;
const int* stepIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
const int* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
const int* endTokens = nullptr; // [batchSize], end token ids of each query
int* outputIds = nullptr; // the buffer to put finalized ids
int32_t* responseInputLengths = nullptr;
int32_t maxSeqLen = 0;
int32_t batchSize = 0;
int32_t beamWidth = 0;
int32_t const* stepIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
int32_t const* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
int32_t const* endTokens = nullptr; // [batchSize], end token ids of each query
int32_t* outputIds = nullptr; // the buffer to put finalized ids
cudaStream_t stream;
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu
};
/*
@ -54,15 +54,16 @@ Do gatherTree on beam search to get final result.
*/
void invokeGatherTree(gatherTreeParam param);
void invokeFinalize(int* outputIds, int* sequenceLengths, float* cumLogProbs, float* outputLogProbs,
const int* topKOutputIds, const int* topKSequenceLengths, const float* scores, const float* topKCumLogProbs,
const float* topKLogProbs, const int* numBeams, const int* inputLengths, const int beamWidth, const int maxSeqLen,
const int batchSize, cudaStream_t stream);
void invokeFinalize(int32_t* outputIds, int32_t* sequenceLengths, float* cumLogProbs, float* outputLogProbs,
int32_t const* topKOutputIds, int32_t const* topKSequenceLengths, float const* scores, float const* topKCumLogProbs,
float const* topKLogProbs, int32_t const* numBeams, int32_t const* inputLengths, int32_t beamWidth,
int32_t maxSeqLen, int32_t batchSize, cudaStream_t stream);
void invokeInitializeOutput(int* outputIds, const int* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream);
void invokeInitializeOutput(
int32_t* outputIds, const int32_t* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream);
void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequenceLengths, const int* batchSlots,
int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream);
void invokeCopyNextStepIds(int32_t* nextStepIds, int32_t** outputIdsPtr, int32_t const* sequenceLengths,
int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream);
//! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens
//! for speculative decoding. Target token is accepted if targetToken == draftToken.
@ -79,14 +80,17 @@ void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequ
//! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration
//! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens
//! \param finishedSum output buffer [1] total number of requests in batch that finished the execution
//! \param batchSize batch size
//! \param batchSlots input buffer [batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param maxSeqLen maximum sequence length
//! \param maxDraftTokens maximum number of draft tokens
//! \param stream stream
void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens, cudaStream_t stream);
void invokeAcceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t beamWidth, int32_t maxSeqLen, int32_t maxDraftTokens, cudaStream_t stream);
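A hedged host-side sketch of driving the id-based acceptance path is shown below; names are illustrative and this header is assumed to be included. Beam width is fixed to 1, as required by the kernel launcher.

```
using namespace tensorrt_llm::kernels;

// Sketch only: accept or reject draft tokens by comparing draft and target ids.
void runAcceptDraftTokensByIds(int32_t const* dDraftIds, int32_t const* dTargetIds, int32_t const* dContextLengths,
    int32_t const* dNumsDraftTokens, int32_t* dSequenceLengths, FinishedState const* dFinished,
    FinishedState* dFinishedFinal, int32_t* dFinishedSum, int32_t const* dBatchSlots, int32_t batchSize,
    int32_t maxBatchSize, int32_t maxSeqLen, int32_t maxDraftTokens, cudaStream_t stream)
{
    invokeAcceptDraftTokensByIds(dDraftIds, dTargetIds, dContextLengths, dNumsDraftTokens, dSequenceLengths,
        dFinished, dFinishedFinal, dFinishedSum, dBatchSlots, batchSize, maxBatchSize, /* beamWidth */ 1, maxSeqLen,
        maxDraftTokens, stream);
}
```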
//! \brief Performs probabilistic acceptance of draft tokens based on their probability distributions.
//! Corrects targetLogits for the next to the last accepted token
@ -94,7 +98,8 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//!
//! \param draftLogits input/output buffer [draftTokens, batchSize, beamWidth, vocabSize].
//! Initially contains token logits of the draft model.
//! \param targetLogits input/output buffer [draftTokens+1, batchSize, beamWidth, vocabSize].
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
//! Vector of pointers to the logits.
//! Initially contains token logits of the target model.
//! It is modified in-place for next to the last accepted token such as
//! P'(x) = norm(max(0, P_{n+1}(x) - Q_{n+1}(x))), where N < maxDraftTokens is number of accepted tokens.
@ -107,7 +112,9 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted
//! \param curandState input buffer [batchSize]. Curand states properly
//! initialized using invokeCurandInitialize per request.
//! \param batchSize batch size
//! \param batchSlots input buffer [batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
@ -116,17 +123,18 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//! \param constantThreshold threshold used to accept tokens if randomThreshold is false
//! \param stream stream
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T* targetLogits, T* draftProbs, T* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
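Similarly, a hedged sketch of the logits-based acceptance path; targetLogits is a vector of per-request pointers, and constantThreshold is only consulted when randomThreshold is false. Names are illustrative, and curand plus this header are assumed to be included.

```
using namespace tensorrt_llm::kernels;

// Sketch only: probabilistic acceptance of draft tokens (beam width 1).
void runAcceptDraftTokensByLogits(half* dDraftLogits, half** dTargetLogitsPtrs, half* dDraftProbs, half* dTargetProbs,
    int32_t const* dNumsDraftTokens, FinishedState* dFinished, curandState_t* dCurandStates,
    int32_t const* dBatchSlots, int32_t batchSize, int32_t maxBatchSize, int32_t vocabSize, int32_t vocabSizePadded,
    int32_t maxDraftTokens, cudaStream_t stream)
{
    acceptDraftTokensByLogits(dDraftLogits, dTargetLogitsPtrs, dDraftProbs, dTargetProbs, dNumsDraftTokens, dFinished,
        dCurandStates, dBatchSlots, batchSize, maxBatchSize, /* beamWidth */ 1, vocabSize, vocabSizePadded,
        maxDraftTokens, /* randomThreshold */ true, /* constantThreshold */ 0.f, stream);
}
```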
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, int32_t const* sequence_lengths,
int32_t const* batchSlots, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, cudaStream_t stream);
void invokeAcceptTokens(int32_t const* draft_tokens, int32_t const* target_tokens, int32_t const* context_lengths,
int32_t const* nums_draft_tokens, int32_t* sequence_lengths, bool const* finished, bool* finished_final,
int32_t* finished_sum, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, int32_t max_draft_tokens,
cudaStream_t stream);
void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
const int* batchSlots, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);
void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -30,7 +30,8 @@ static constexpr int kUpdateKVCacheKernelShmSize = 16384;
template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices,
const int32_t* pastKeyValueLengths, int rewindDraftTokenCount, int eltCountPerHead)
const int32_t* pastKeyValueLengths, int rewindDraftTokenCommonCount, const int* rewindDraftTokenSeparateAdjustments,
int eltCountPerHead)
{
int seqIdx = blockIdx.x;
int headIdx = blockIdx.y;
@ -46,7 +47,11 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
return;
}
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCount;
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCommonCount;
if (rewindDraftTokenSeparateAdjustments != nullptr)
{
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqIdx];
}
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
__shared__ char loadSmemBuffer[kUpdateKVCacheKernelShmSize];
@ -121,7 +126,7 @@ template <typename KVCacheBuffer, int MaxLayerCount>
void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices,
const int32_t* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream)
{
// make sure launch buffer is enough
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
@ -144,7 +149,7 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
kvCacheBufferArray[i] = kvCacheBuffers[i];
}
void (*pKernelFunc)(
std::array<KVCacheBuffer, MaxLayerCount>, const int*, const IndexType*, const int32_t*, int, int)
std::array<KVCacheBuffer, MaxLayerCount>, const int*, const IndexType*, const int32_t*, int, const int*, int)
= nullptr;
switch (alignedBytes)
{
@ -176,7 +181,8 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
}
}
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCount, eltCountPerHead);
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, eltCountPerHead);
TLLM_CUDA_CHECK(cudaGetLastError());
}
@ -191,14 +197,17 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead :
* @param rewindDraftTokenCount
* @param stream
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common count to rewind
* @param rewindDraftTokenSeparateAdjustments : Separate rewind adjustment for each sequence. If nullptr, only
* rewindDraftTokenCommonCount is used; otherwise sequence i is rewound by rewindDraftTokenSeparateAdjustments[i] + rewindDraftTokenCommonCount
* @param stream : CUDA stream to use.
*/
template <typename KVCacheBuffer>
void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int layerCount, int seqCount,
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, cudaStream_t stream)
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
cudaStream_t stream)
{
int startLayer = 0;
static constexpr int kMaxLayersPerIter = 32;
@ -207,7 +216,8 @@ void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const
int microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCount, stream);
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, stream);
startLayer += microBatchLayerCount;
}
}
@ -215,7 +225,7 @@ void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const
void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream)
{
std::vector<KVLinearBuffer> kvLinearBuffers;
kvLinearBuffers.reserve(layerCount);
@ -227,13 +237,14 @@ void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffse
}
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
}
void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream)
{
std::vector<KVBlockArray> kvBlockArrays;
kvBlockArrays.reserve(layerCount);
@ -245,7 +256,47 @@ void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffset
}
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
}
void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, stream);
}
void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
} // namespace tensorrt_llm::kernels::parallel_decoding

View File

@ -27,14 +27,141 @@ namespace tensorrt_llm::kernels::parallel_decoding
using IndexType = int;
void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
/*!
* Update Linear KV cache using common rewind count.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream);
void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
/*!
* Update Block KV cache using common rewind count.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
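A hedged usage sketch of the common-rewind block-cache variant follows; the offset, index, and pointer arrays are assumed to be prepared by the runtime, and names are illustrative.

```
using namespace tensorrt_llm::kernels::parallel_decoding;

// Sketch only: rewind every sequence in the block KV cache by the same number of draft tokens.
void rewindBlockKVCacheCommon(int const* dSeqAcceptedDraftTokenOffsets,
    IndexType const* dPackedAcceptedDraftTokensIndices, int32_t const* dPastKeyValueLengths,
    int64_t* const* dPointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
    int rewindDraftTokenCount, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
    updateKVBlockArrayDraftTokenLocationCommonRewind(dSeqAcceptedDraftTokenOffsets,
        dPackedAcceptedDraftTokensIndices, dPastKeyValueLengths, dPointerArray, layerCount, seqCount, numKVHeads,
        sizeInBytesPerKVHead, rewindDraftTokenCount, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
```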
/*!
* Update Linear KV cache using separate rewind count for each sequence.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount; each element indicates the rewind count of
* one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using separate rewind count for each sequence.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount; each element indicates the rewind count of
* one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
/*!
* Update Linear KV cache using both a common rewind count and a separate rewind count for each sequence. The common
* rewindDraftTokenCommonCount and the per-sequence value in rewindDraftTokenSeparateAdjustments are added together to
* form the final rewind count; this saves one addition when both are needed.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount; each element indicates the
* rewind adjustment for one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream);
/*!
* Update Block KV cache using both a common rewind count and a separate rewind count for each sequence. The common
* rewindDraftTokenCommonCount and the per-sequence value in rewindDraftTokenSeparateAdjustments are added together to
* form the final rewind count; this saves one addition when both are needed.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount; each element indicates the
* rewind adjustment for one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream);
} // namespace tensorrt_llm::kernels::parallel_decoding
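To illustrate how the combined entry point composes the two rewind sources, the hedged sketch below rewinds sequence i by commonRewind + dPerSeqAdjustments[i] in a single call; names are illustrative.

```
using namespace tensorrt_llm::kernels::parallel_decoding;

// Sketch only: combined rewind; sequence i is rewound by commonRewind + dPerSeqAdjustments[i].
void rewindBlockKVCacheMixed(int const* dSeqAcceptedDraftTokenOffsets,
    IndexType const* dPackedAcceptedDraftTokensIndices, int32_t const* dPastKeyValueLengths,
    int64_t* const* dPointerArray, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
    int commonRewind, int* dPerSeqAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
    cudaStream_t stream)
{
    updateKVBlockArrayDraftTokenLocation(dSeqAcceptedDraftTokenOffsets, dPackedAcceptedDraftTokensIndices,
        dPastKeyValueLengths, dPointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, commonRewind,
        dPerSeqAdjustments, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}
```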

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -30,35 +30,37 @@ namespace kernels
{
template <typename T>
__global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorkspace, const int* penaltyWorkspacePrev,
const float* temperatures, const float* repetitionPenalties, const float* presencePenalties,
const float* frequencyPenalties, const bool accumulateVocab, const int maxSeqLen, const int vocabSize,
const int vocabSizePadded, const int** outputIdsPtr, const int** parentIdsPtr, const int* inputLengths,
const int* sequenceLengths, const int* minLengths, const int* endIds, const int* batchSlots)
__global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases,
int32_t* penaltyWorkspace, int32_t const* penaltyWorkspacePrev, float const* temperatures,
float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties,
const bool accumulateVocab, int32_t const maxSeqLen, int32_t const vocabSize, int32_t const vocabSizePadded,
int32_t const** outputIdsPtr, int32_t const** parentIdsPtr, int32_t const* inputLengths,
int32_t const* sequenceLengths, int32_t const* minLengths, int32_t const* endIds, int32_t const* batchSlots)
{
const int beamWidth = gridDim.y;
const int batchIdx = blockIdx.x;
const int beamIdx = blockIdx.y;
const int batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
const int batchBeamIdx = batchIdx * beamWidth + beamIdx;
const int batchSlotBeamIdx = batchSlot * beamWidth + beamIdx;
const int inputLen = inputLengths == nullptr ? 0 : inputLengths[batchSlotBeamIdx];
const int currentStep = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlotBeamIdx];
int32_t const beamWidth = gridDim.y;
int32_t const batchIdx = blockIdx.x;
int32_t const beamIdx = blockIdx.y;
int32_t const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
int32_t const batchBeamIdx = batchIdx * beamWidth + beamIdx;
int32_t const batchSlotBeamIdx = batchSlot * beamWidth + beamIdx;
int32_t const inputLen = inputLengths == nullptr ? 0 : inputLengths[batchSlotBeamIdx];
int32_t const currentStep = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlotBeamIdx];
T const* biasBase = biases + batchSlot * vocabSizePadded;
// Initialize or update the number of occurrences of tokens
if (accumulateVocab)
{
penaltyWorkspace += batchBeamIdx * vocabSize;
if (currentStep <= inputLen)
{ // Context phase
for (int index = threadIdx.x; index < vocabSize; index += blockDim.x)
for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x)
{
penaltyWorkspace[index] = 0;
}
__syncthreads();
for (int step = threadIdx.x; step < inputLen; step += blockDim.x)
for (int32_t step = threadIdx.x; step < inputLen; step += blockDim.x)
{
// All beams in the context phase are identical
int penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
if (penaltyIndex < vocabSize)
{
atomicAdd(&penaltyWorkspace[penaltyIndex], 1);
@ -69,9 +71,9 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
{ // Generation phase
if (beamWidth > 1)
{
int parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2];
int32_t parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2];
penaltyWorkspacePrev += (batchIdx * beamWidth + parentBeam) * vocabSize;
for (int index = threadIdx.x; index < vocabSize; index += blockDim.x)
for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x)
{
penaltyWorkspace[index] = penaltyWorkspacePrev[index];
}
@ -79,7 +81,7 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
}
if (threadIdx.x == 0)
{
int penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
if (penaltyIndex < vocabSize)
{
penaltyWorkspace[penaltyIndex] += 1;
@ -89,7 +91,8 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
__syncthreads();
}
// Apply bias and penalties
logits += batchBeamIdx * vocabSizePadded;
auto const inLogitsPtr = inputLogits[batchIdx] + beamIdx * vocabSizePadded;
auto outLogitsPtr = outputLogits + batchBeamIdx * vocabSizePadded;
const T MASK_VAL = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
float invTemperature, repetitionPenalty, presencePenalty, frequencyPenalty;
if (temperatures != nullptr)
@ -108,22 +111,22 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
{
frequencyPenalty = frequencyPenalties[batchSlot];
}
for (int index = threadIdx.x; index < vocabSizePadded; index += blockDim.x)
for (int32_t index = threadIdx.x; index < vocabSizePadded; index += blockDim.x)
{
if (index < vocabSize)
{
float logit = (float) logits[index];
float logit = (float) inLogitsPtr[index];
// Bias
if (biases != nullptr)
{
logit += (float) biases[index];
logit += (float) biasBase[index];
}
// Temperature
if (temperatures != nullptr)
{
logit *= invTemperature;
}
int numOccurences = penaltyWorkspace[index];
int32_t numOccurences = penaltyWorkspace[index];
if (numOccurences > 0)
{
// Repetition
@ -142,11 +145,11 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
logit -= frequencyPenalty * numOccurences;
}
}
logits[index] = logit;
outLogitsPtr[index] = logit;
}
else
{
logits[index] = MASK_VAL;
outLogitsPtr[index] = MASK_VAL;
}
}
if (minLengths != nullptr)
@ -155,7 +158,7 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
// Min length
if ((threadIdx.x == 0) && (currentStep - inputLen < minLengths[batchSlot]))
{
logits[endIds[batchSlot]] = MASK_VAL;
outLogitsPtr[endIds[batchSlot]] = MASK_VAL;
}
}
}
@ -166,11 +169,11 @@ void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams<T>& params)
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
dim3 block(256);
dim3 grid(params.batchSize, params.beamWidth);
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.logits, params.biases, params.penaltyWorkspace,
params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties, params.presencePenalties,
params.frequencyPenalties, params.accumulateVocab, params.maxSeqLen, params.vocabSize, params.vocabSizePadded,
params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, params.minLengths,
params.endIds, params.batchSlots);
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.inputLogits, params.outputLogits, params.biases,
params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties,
params.presencePenalties, params.frequencyPenalties, params.accumulateVocab, params.maxSeqLen, params.vocabSize,
params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths,
params.minLengths, params.endIds, params.batchSlots);
}
template void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams<float>& params);
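
The kernel above folds bias, temperature, and the repetition/presence/frequency penalties into one pass over the vocabulary, reading from inputLogits and writing to outputLogits. The CPU sketch below mirrors that per-token arithmetic for a single beam; `applyPenaltiesRef` is a hypothetical name, the repetition step follows the usual convention of dividing positive logits and multiplying negative ones, and batch slots, padded-vocab masking, and the min-length rule are omitted.

```
// Illustrative single-beam reference of the per-token penalty math
// (bias -> temperature -> repetition -> presence -> frequency).
#include <vector>

void applyPenaltiesRef(std::vector<float>& logits, const std::vector<float>* bias,
    const std::vector<int>& occurrences, float temperature, float repetitionPenalty,
    float presencePenalty, float frequencyPenalty)
{
    const float invTemperature = 1.f / temperature; // assumes temperature > 0
    for (size_t token = 0; token < logits.size(); ++token)
    {
        float logit = logits[token];
        if (bias != nullptr)
        {
            logit += (*bias)[token];
        }
        logit *= invTemperature;
        const int numOccurrences = occurrences[token];
        if (numOccurrences > 0)
        {
            // Repetition penalty: divide positive logits, multiply negative ones.
            logit = logit < 0.f ? logit * repetitionPenalty : logit / repetitionPenalty;
            // Presence penalty: flat subtraction once the token has appeared.
            logit -= presencePenalty;
            // Frequency penalty: subtraction scaled by the occurrence count.
            logit -= frequencyPenalty * numOccurrences;
        }
        logits[token] = logit;
    }
}
```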

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,7 +28,8 @@ namespace kernels
template <typename T>
struct InvokeBatchApplyPenaltyParams
{
T* logits;
T const* const* inputLogits;
T* outputLogits;
const T* biases;
int* penaltyWorkspace;
const int* penaltyWorkspacePrev;

View File

@ -24,7 +24,7 @@ namespace tensorrt_llm
namespace kernels
{
enum class RepetitionPenaltyType
enum class DecodingPenaltyType
{
Temperature, // the temperature penalty
Repetition, // the repetition penalty
@ -33,15 +33,15 @@ enum class RepetitionPenaltyType
MinLength, // the min length penalty
};
inline float getDefaultPenaltyValue(RepetitionPenaltyType penalty_type)
inline float getDefaultPenaltyValue(DecodingPenaltyType penalty_type)
{
switch (penalty_type)
{
case RepetitionPenaltyType::Temperature: return 1.0f;
case RepetitionPenaltyType::Repetition: return 1.0f;
case RepetitionPenaltyType::Presence: return 0.0f;
case RepetitionPenaltyType::Frequency: return 0.0f;
case RepetitionPenaltyType::MinLength: return 1.0f;
case DecodingPenaltyType::Temperature: return 1.0f;
case DecodingPenaltyType::Repetition: return 1.0f;
case DecodingPenaltyType::Presence: return 0.0f;
case DecodingPenaltyType::Frequency: return 0.0f;
case DecodingPenaltyType::MinLength: return 1.0f;
default: break;
}
return 0.0f;
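
The defaults above are the neutral values for each penalty (1.0 leaves temperature and repetition untouched, 0.0 disables presence and frequency). A hedged illustration of how a caller might fall back to them when a request leaves a value unset; `resolveTemperature` and the use of std::optional are assumptions for the example, not TensorRT-LLM API.

```
// Hypothetical helper: use the neutral default when no temperature was given.
#include <optional>

float resolveTemperature(std::optional<float> userTemperature)
{
    using tensorrt_llm::kernels::DecodingPenaltyType;
    using tensorrt_llm::kernels::getDefaultPenaltyValue;
    return userTemperature.value_or(getDefaultPenaltyValue(DecodingPenaltyType::Temperature)); // 1.0f
}
```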

View File

@ -23,11 +23,12 @@ struct Vec2Type<__nv_bfloat16>
#endif
}; // namespace
template <typename T, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols)
template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
__global__ void apply_per_channel_scale(
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
T scale[kElems], act_vec[kElems];
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
T_in scale[kElems], act_vec[kElems];
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
int row_offset = blockIdx.y;
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
@ -39,13 +40,13 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
for (int i = 0; i < kProcessRows; ++i)
{
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<const AccessType*>(act + i * cols)[col_offset];
if constexpr ((std::is_same_v<T, half>
if constexpr ((std::is_same_v<T_in, half>
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|| std::is_same_v<T, __nv_bfloat16>
|| std::is_same_v<T_in, __nv_bfloat16>
#endif
) &&(kElems % 2 == 0))
{
using Vec2 = typename Vec2Type<T>::type;
using Vec2 = typename Vec2Type<T_in>::type;
#pragma unroll
for (int j = 0; j < kElems; j += 2)
{
@ -58,58 +59,77 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
act_vec[j] = static_cast<T>(static_cast<float>(act_vec[j]) * static_cast<float>(scale[j]));
act_vec[j] = static_cast<T_in>(static_cast<float>(act_vec[j]) * static_cast<float>(scale[j]));
}
}
if constexpr (std::is_same_v<T_in, T_out>)
{
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset]
= *reinterpret_cast<AccessType*>(act_vec);
}
else
{
#pragma unroll
for (int j = 0; j < kElems; ++j)
{
(smoothed_act + i * cols)[col_offset * kElems + j] = static_cast<T_out>(act_vec[j]);
}
}
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset] = *reinterpret_cast<AccessType*>(act_vec);
}
}
template <typename T, int kProcessRows, typename AccessType = float4>
template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
void apply_per_channel_scale_kernel_launcher_(
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream = 0)
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0)
{
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
dim3 block(128);
dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows);
apply_per_channel_scale<T, kProcessRows, AccessType>
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols);
}
template <typename T>
template <typename T_in, typename T_out>
void apply_per_channel_scale_kernel_launcher(
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream)
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream)
{
int elems = rows * cols;
if (elems < 2048 * 2048)
{
apply_per_channel_scale_kernel_launcher_<T, 1, float4>(
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
}
else if (elems < 4096 * 4096)
{
apply_per_channel_scale_kernel_launcher_<T, 4, float4>(
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
}
else if (elems < 8192 * 8192)
{
apply_per_channel_scale_kernel_launcher_<T, 8, float4>(
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
}
else
{
apply_per_channel_scale_kernel_launcher_<T, 16, float4>(
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(
smoothed_act, act, per_channel_scale, rows, cols, stream);
}
}
#define INSTANTIATE_PREQUANT_SCALE(T) \
template void apply_per_channel_scale_kernel_launcher<T>( \
T * smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream)
#define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>( \
T_out * smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream)
INSTANTIATE_PREQUANT_SCALE(half, half);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE(half, __nv_fp8_e4m3);
#endif
INSTANTIATE_PREQUANT_SCALE(half);
#if defined(ENABLE_BF16)
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16);
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_bfloat16);
#if defined(ENABLE_FP8)
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
#endif
#endif
} // namespace kernels
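
The launcher above now allows the output type to differ from the input type (for example FP8 output) and picks kProcessRows from the problem size; the underlying element-wise operation is unchanged. A reference sketch of that operation, with the hypothetical name `applyPerChannelScaleRef` and float used for both types to keep it type-agnostic:

```
// Reference semantics of apply_per_channel_scale: each element of a row is
// multiplied by the scale of its column.
#include <vector>

void applyPerChannelScaleRef(std::vector<float>& smoothedAct, const std::vector<float>& act,
    const std::vector<float>& perChannelScale, int rows, int cols)
{
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < cols; ++c)
        {
            smoothedAct[r * cols + c] = act[r * cols + c] * perChannelScale[c];
        }
    }
}
```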

View File

@ -17,6 +17,7 @@
#pragma once
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
@ -32,9 +33,9 @@ namespace tensorrt_llm
namespace kernels
{
template <typename T>
template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -301,7 +301,7 @@ __device__ void vectorizedProcess(size_t threadRank, size_t numThreads, T const*
}
/**
* Fused filtering of the current pass and building histogram for the next pass (see steps 4 & 1 in `airTopPSsampling`
* Fused filtering of the current pass and building histogram for the next pass (see steps 4 & 1 in `airTopPSampling`
* description).
*/
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
@ -418,7 +418,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i
}
/**
* Replace histogram with its own prefix sum (step 2 in `airTopPSsampling` description)
* Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description)
*/
template <typename IdxT, int BitsPerPass, int BlockSize>
__device__ void scan(volatile IdxT* histogram)
@ -472,7 +472,7 @@ __device__ void scan(volatile IdxT* histogram)
/**
* Calculate in which bucket the k-th value will fall
* (steps 3 in `airTopPSsampling` description)
* (steps 3 in `airTopPSampling` description)
*/
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
__device__ void chooseBucket(
@ -486,13 +486,17 @@ __device__ void chooseBucket(
// one and only one thread will satisfy this condition, so counter is
// written by only one thread
if ((prev < sum && cur >= sum) || (sum <= 0 && i == 0))
// Add strict check for negative cases.
if ((sum > 0 && prev < sum && cur >= sum) || (sum <= 0 && prev == 0 && cur != 0))
{
counter->sum = sum - prev; // how many values still are there to find
counter->len = countHistogram[i]; // cur - prev; // number of values in next pass
typename cub::Traits<T>::UnsignedBits bucket = i;
int startBit = calcsStartBit<T, BitsPerPass>(pass);
counter->kthValueBits |= bucket << startBit;
if (countHistogram[i]) // Only check meaningful ones
{
counter->sum = sum - prev; // how many values still are there to find
counter->len = countHistogram[i]; // cur - prev; // number of values in next pass
typename cub::Traits<T>::UnsignedBits bucket = i;
int startBit = calcsStartBit<T, BitsPerPass>(pass);
counter->kthValueBits |= bucket << startBit;
}
}
}
}
@ -533,7 +537,7 @@ __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs,
/**
* Find the target element.
* (steps 4 in `airTopPSsampling` description)
* (steps 4 in `airTopPSampling` description)
*/
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
__device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen, Counter<T, IdxT, AccT>* counter,
@ -603,7 +607,7 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen
* their indices.
*/
template <typename T, typename IdxT, typename AccT, int BitsPerPass, int BlockSize, bool is_fused_filter = false>
__global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids,
__global__ void airTopPSampling(Counter<T, IdxT, AccT>* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids,
int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, IdxT const* endIds, int const batchSize, bool const* skipDecode, int const pass, T* buf1,
IdxT* idxBuf1, T* buf2, IdxT* idxBuf2, int32_t const* batchSlots)
@ -697,6 +701,7 @@ __global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histogr
earlyStop);
__syncthreads();
__threadfence();
bool isLastBlock = false;
if (threadIdx.x == 0)
@ -711,13 +716,42 @@ __global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histogr
{
return;
}
__shared__ IdxT maxBucket;
if (pass > 0)
{
// Avoid the scenario where currentSum is larger than the meaningful maximum prefix sum.
// This situation happens because the two values are calculated in different ways,
// so the precision loss during the calculation also differs.
if (threadIdx.x == 0)
{
maxBucket = 0;
}
__syncthreads();
for (int i = threadIdx.x; i < numBuckets; i += blockDim.x)
{
if (countHistogram[i])
{
atomicMax(&maxBucket, i);
}
}
__syncthreads();
}
scan<AccT, BitsPerPass, BlockSize>(histogram);
__syncthreads();
if (pass == 0)
{
currentSum = histogram[numBuckets - 1] * counter->p;
}
__syncthreads();
else
{
if (currentSum > histogram[maxBucket])
{
currentSum = histogram[maxBucket];
}
}
chooseBucket<T, IdxT, AccT, BitsPerPass>(counter, histogram, countHistogram, currentSum, pass);
__syncthreads();
@ -818,7 +852,7 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt)
int activeBlocks;
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&activeBlocks, airTopPSsampling<T, IdxT, AccT, BitsPerPass, BlockSize, false>, BlockSize, 0);
&activeBlocks, airTopPSampling<T, IdxT, AccT, BitsPerPass, BlockSize, false>, BlockSize, 0);
activeBlocks *= smCnt;
IdxT bestNumBlocks = 0;
@ -909,13 +943,13 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou
dim3 grid(blockNum, batchSize);
// Sample with Top P given sorted tokens
int constexpr numPasses = calcNumPasses<T, BitsPerPass>();
auto kernel = airTopPSsampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, false>;
auto kernel = airTopPSampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, false>;
for (int pass = 0; pass < numPasses; ++pass)
{
if (pass == numPasses - 1)
{
kernel = airTopPSsampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, true>;
kernel = airTopPSampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, true>;
}
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, outputIds,

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -36,59 +36,6 @@ namespace tensorrt_llm
namespace kernels
{
template <typename T>
__global__ void addBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int vocabSize, const int vocabSizePadded)
{
auto const batchIdx = blockIdx.x;
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
FinishedState const finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty();
if (finishState.isSkipDecoding())
{
return;
}
bool finish = finishState.isFinished();
int offset = batchIdx * vocabSizePadded;
bool const IS_FP16 = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
{
if (tid >= vocabSize)
{
logits[offset + tid] = -MAX_T_VAL;
}
else if (finish)
{
logits[offset + tid] = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
}
else
{
if (bias != nullptr)
{
logits[offset + tid] += bias[tid];
}
}
}
}
template <typename T>
void invokeAddBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream)
{
dim3 grid(batchSize);
dim3 block(min(vocabSizePadded, 1024));
// n is the vocabSize, e.g., 30000, 7000.... vocabSize is usually very big.
addBiasEndMask<<<grid, block, 0, stream>>>(logits, bias, endIds, finished, batchSlots, vocabSize, vocabSizePadded);
}
template void invokeAddBiasEndMask(float* logits, const float* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
template void invokeAddBiasEndMask(half* logits, const half* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* topKTmpIdBuf, T* topKTmpValBuf,
const FinishedState* finished, const int maxTopK, const int* topKs, const int vocabSize, const int* endIds,
@ -176,7 +123,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, const int maxTopK, const int* topKs, const float topP, const float* topPs,
curandState_t* curandstate, const int* endIds, const int vocabSize, const bool* skipDecode, const int* batchSlots,
const bool normalizeLogProbs)
const bool normalizeLogProbs, const bool logitHasProbs)
{
bool const IS_FP16 = std::is_same<T, half>::value;
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
@ -240,7 +187,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
// when cumLogProbs are computed, topKTmpValBuf (logits_buf_) are
// already pre-processed by softmax_kernel
if (cumLogProbs == nullptr && outputLogProbs == nullptr)
if (!logitHasProbs)
{
total.u = __expf(total.u - maxLogit);
}
@ -309,7 +256,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
<<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf, \
topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs); \
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs, \
logitsHasProbs); \
break;
template <typename T>
@ -317,9 +265,9 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
const bool* skipDecode, const bool normalizeLogProbs)
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs)
{
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Do not allow ambiguous inputs: topP and topPs.
assert(topP == 1.0f || topPs == nullptr);
@ -341,6 +289,11 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
return;
}
if (maxTopK == 0)
{
return;
}
T* tempLogProbs = (T*) workspace;
int* topKTmpIdBuf = (int*) (tempLogProbs + tempLogProbsBufSize);
T* topKTmpValBuf = (T*) (topKTmpIdBuf + topKTmpIdsBufSize);
@ -367,6 +320,8 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
CASE_K(1024, 256, 256, 8, normalizeLogProbs);
default: throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", maxTopK));
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
#undef CASE_K
@ -375,37 +330,37 @@ template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, co
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs);
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs);
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
template <typename T>
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs)
const bool normalizeLogProbs, const bool logitsHasProbs)
{
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput,
cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots,
stream, batchSize, skipDecode, normalizeLogProbs);
stream, batchSize, skipDecode, normalizeLogProbs, logitsHasProbs);
}
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs);
const bool normalizeLogProbs, const bool logitsHasProbs);
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs);
const bool normalizeLogProbs, const bool logitsHasProbs);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -34,7 +34,7 @@ namespace kernels
//! buffer.
//! \param workspaceSize size of the workspace in bytes
//! \param logProbs input buffer [batchSize x vocabSizePadded].
//! Log probabilities of each token in the vocab. If cumLogProbs or outputLogProbs are specified,
//! Log probabilities of each token in the vocab. If logitsHasProbs is true,
//! logProbs must contain **just** probabilities instead of log probabilities.
//! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
//! \param sequenceLength input/output buffer [maxBatchSize]. Current sequence length of the request up to, but excluding endId token
@ -61,13 +61,14 @@ namespace kernels
//! \param batchSize batch size
//! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request
//! \param normalizeLogProbs when set to True outputLogProbs are normalized to TopK
//! \param logitsHasProbs flag to highlight that logProbs contains probabilities
// clang-format on
template <typename T>
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
const bool* skipDecode, const bool normalizeLogProbs);
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
//! \brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr
template <typename T>
@ -75,24 +76,7 @@ void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProb
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
const bool normalizeLogProbs);
//! \brief Applies mask and bias to logits. Sets -MAX_FLT value for tokens in range [vocabSize; vocabSizePadded) to
//! prevent them being chosen If request finished the generation, sets MAX_FLT to endId token and -MAX_FLT to all other
//! tokens forcing to choose endId token. Otherwise, adds bias per token if bias pointer is not nullptr.
//!
//! \param logits input/output buffer [batchSize, vocabSize]. Logits to be modified.
//! \param bias input buffer [vocabSize]. Bias to logit per token. Ignored if nullptr
//! \param endIds input buffer [maxBatchSize]. EOS token ids per request
//! \param finished input buffer [maxBatchSize] with flags set to true if request has finished the generation
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param batchSize batch size
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
//! \param stream stream
template <typename T>
void invokeAddBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
const bool normalizeLogProbs, const bool logitsHasProbs);
} // namespace kernels
} // namespace tensorrt_llm
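
The new logitsHasProbs flag tells the sampling kernel whether logProbs already holds probabilities (softmax applied upstream), so the exponentiation step can be skipped. The sketch below is a hedged CPU reference of top-k sampling with that flag; `sampleTopKRef` is a hypothetical name and beams, batch slots, top-p filtering, and finished states are ignored.

```
// Illustrative top-k sampling: keep the k largest entries, convert to weights
// (exp of shifted log-probs unless the inputs are already probabilities), draw one.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

int sampleTopKRef(const std::vector<float>& logProbs, int k, bool logitsHasProbs, std::mt19937& rng)
{
    // Assumes 1 <= k <= logProbs.size().
    std::vector<int> ids(logProbs.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::partial_sort(ids.begin(), ids.begin() + k, ids.end(),
        [&](int a, int b) { return logProbs[a] > logProbs[b]; });

    std::vector<float> weights(k);
    const float maxVal = logProbs[ids[0]];
    for (int i = 0; i < k; ++i)
    {
        // Already probabilities: use as-is. Otherwise exponentiate with a shift
        // for numerical stability, mirroring the kernel's __expf(val - maxLogit).
        weights[i] = logitsHasProbs ? logProbs[ids[i]] : std::exp(logProbs[ids[i]] - maxVal);
    }
    std::discrete_distribution<int> dist(weights.begin(), weights.end());
    return ids[dist(rng)];
}
```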

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -310,6 +310,8 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
int const vocabSize = vocabSizePadded;
size_t sortedLogProbBufSize = batchSize * vocabSize * sizeof(T); // type T
@ -338,6 +340,7 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
// If the most probable token exceeds P, we skip sorting by setting beginOffsetBuf[bi] = offsetBuf[bi]
topPBeamTopKKernel<T, BLOCK_SIZE><<<batchSize, BLOCK_SIZE, 0, stream>>>(logProbs, sortedIdVals, sortedLogProbs,
finishedInput, vocabSize, offsetBuf, beginOffsetBuf, maxTopP, topPs, skipDecode, batchSlots);
sync_check_cuda_error();
// Sort tokens by probability in descending order
check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(cubTempStorage, cubTempStorageSize, logProbs,
@ -352,6 +355,9 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedLogProbs, sortedIdVals,
outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf,
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode, batchSlots);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -24,43 +24,44 @@ namespace tensorrt_llm
{
namespace kernels
{
__global__ void stopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
int beamWidth, int maxSeqLen)
__global__ void stopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLens,
int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen)
{
int const id = blockIdx.x * blockDim.x + threadIdx.x;
int const batchIdx = blockIdx.y / beamWidth;
int const beamIdx = blockIdx.y % beamWidth;
int32_t const id = blockIdx.x * blockDim.x + threadIdx.x;
int32_t const batchIdx = blockIdx.y / beamWidth;
int32_t const beamIdx = blockIdx.y % beamWidth;
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
int const* baseStopWords = stopWords + batchSlot * 2 * stopWordsLen;
int const* baseOffsets = baseStopWords + stopWordsLen;
auto const* baseStopWords = stopWords[batchSlot];
auto const stopWordsLen = stopWordsLens[batchSlot];
auto const* baseOffsets = baseStopWords + stopWordsLen;
if (id >= stopWordsLen || baseOffsets[id] < 0)
{
return;
}
int const itemEnd = baseOffsets[id];
int const itemStart = (id > 0) ? baseOffsets[id - 1] : 0;
int const itemSize = itemEnd - itemStart;
auto const itemEnd = baseOffsets[id];
auto const itemStart = (id > 0) ? baseOffsets[id - 1] : 0;
auto const itemSize = itemEnd - itemStart;
// The single-token case unconditionally bans the token
bool shouldStop = false;
// Need to subtract 1 because sequenceLengths has already been updated in this step
int const currentStep = sequenceLengths[batchBeamIdx] - 1;
auto const currentStep = sequenceLengths[batchBeamIdx] - 1;
// Enough previously generated tokens to look for a match
if (currentStep + 1 >= itemSize)
{
shouldStop = true;
int parentId = beamIdx;
auto parentId = beamIdx;
bool const gatherBeam = beamWidth > 1;
for (int tokenIdx = itemSize - 1; tokenIdx >= 0; tokenIdx--)
for (int32_t tokenIdx = itemSize - 1; tokenIdx >= 0; tokenIdx--)
{
int const previousToken
auto const previousToken
= outputIds[batchSlot][parentId * maxSeqLen + currentStep - (itemSize - 1) + tokenIdx];
if (previousToken != baseStopWords[itemStart + tokenIdx])
{
@ -88,15 +89,16 @@ __global__ void stopWordsCriterion(const int** outputIds, const int** parentIds,
}
}
void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
int beamWidth, int maxSeqLen, cudaStream_t stream)
void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLen,
int32_t maxStopWordsLen, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream)
{
// Check if we have sampled a word from the stopWords list. If so, stop the sequence.
dim3 block, grid;
constexpr size_t maxBlockSize{256};
block.x = min(((stopWordsLen + 32 - 1) / 32) * 32, maxBlockSize);
grid.x = (stopWordsLen + block.x - 1) / block.x;
constexpr int32_t maxBlockSize{256};
block.x = min(((maxStopWordsLen + 32 - 1) / 32) * 32, maxBlockSize);
grid.x = (maxStopWordsLen + block.x - 1) / block.x;
grid.y = batchSize * beamWidth;
stopWordsCriterion<<<grid, block, 0, stream>>>(outputIds, parentIds, stopWords, finished, sequenceLengths,
@ -104,15 +106,15 @@ void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, cons
sync_check_cuda_error();
}
__global__ void lengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth)
__global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
{
int threadFinishedCount = 0;
for (int index = threadIdx.x; index < batchSize * beamWidth; index += blockDim.x)
int32_t threadFinishedCount = 0;
auto const batchIdx = blockIdx.x;
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
for (int32_t beamIdx = threadIdx.x; beamIdx < beamWidth; beamIdx += blockDim.x)
{
int const batchIdx = index / beamWidth;
int const beamIdx = index % beamWidth;
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
auto const batchSlotBeamWidthIdx = batchSlot * beamWidth + beamIdx;
auto finishState = finished[batchSlotBeamWidthIdx];
@ -140,19 +142,20 @@ __global__ void lengthCriterion(FinishedState* finished, int* finishedSum, const
if (threadIdx.x == 0)
{
finishedSum[0] = blockFinishedCount;
finishedSum[batchSlot] = blockFinishedCount;
}
}
}
void invokeLengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth, cudaStream_t stream)
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
cudaStream_t stream)
{
// Check if we have attained the sequence length limit. If so, stop the
// sequence. In addition, check if all sequences are stopped and return the
// result in shouldStop
dim3 block{min(512, uint32_t(batchSize * beamWidth))};
dim3 grid{1};
dim3 block{min(512, uint32_t(beamWidth))};
dim3 grid{uint32_t(batchSize)};
lengthCriterion<<<grid, block, 0, stream>>>(
finished, finishedSum, sequenceLimitLength, sequenceLengths, batchSlots, batchSize, beamWidth);

View File

@ -27,7 +27,7 @@ namespace kernels
//! \param outputIds input buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
//! \param parentIds input buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with parent ids. Applicable when
//! beamWidth > 1
//! \param stopWords input buffer [maxBatchSize, 2, stopWordsLen]. For each instance in batch the first row
//! \param stopWords input buffer [maxBatchSize][2, stopWordsLen]. For each instance in batch the first row
//! is the token ids of the stop words. The second row is the exclusive prefix sum of the word lengths.
//! In case all the words are made of a single token,
//! the inner-most dimension of the tensor must be increased by 1.
@ -37,14 +37,15 @@ namespace kernels
//! \param sequenceLengths input buffer [maxBatchSize, beamWidth]. Current sequence
//! lengths of the request tokens.
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
//! \param stopWordsLen cumulative length of all stop words
//! \param stopWordsLen input buffer [maxBatchSize], cumulative length of all stop words per request
//! \param maxStopWordsLen maximum stopWordsLen over all requests in the batch
//! \param batchSize batch size
//! \param beamWidth beam width
//! \param maxSeqLen maximum length of the sequence
//! \param stream stream
void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
int beamWidth, int maxSeqLen, cudaStream_t stream);
void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLen,
int32_t maxStopWordsLen, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream);
//! \brief Sets finished states based on the sequenceLimitLength and computes number of finished sequences in the batch.
//!
@ -59,7 +60,8 @@ void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, cons
//! \param batchSize batch size
//! \param beamWidth beam width
//! \param stream stream
void invokeLengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth, cudaStream_t stream);
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm
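
The stop-words layout documented above stores, per request, a row of flattened stop-word token ids and a row of prefix sums of the word lengths. A hedged CPU sketch of the matching rule used by the kernel (single beam, no batch slots); `hitsStopWordRef` is a hypothetical name.

```
// Reference of the stop-word check: for each word, compare its tokens against
// the most recently generated tokens; true means the sequence should stop.
#include <vector>

bool hitsStopWordRef(const std::vector<int>& outputIds, // generated tokens so far
    const std::vector<int>& stopWordTokens,             // row 0: flattened token ids
    const std::vector<int>& stopWordOffsets)            // row 1: prefix sums of word lengths
{
    for (size_t w = 0; w < stopWordOffsets.size(); ++w)
    {
        if (stopWordOffsets[w] < 0)
        {
            break; // negative offset marks padding
        }
        const int itemEnd = stopWordOffsets[w];
        const int itemStart = w > 0 ? stopWordOffsets[w - 1] : 0;
        const int itemSize = itemEnd - itemStart;
        if (static_cast<int>(outputIds.size()) < itemSize)
        {
            continue; // not enough generated tokens to match this word
        }
        bool match = true;
        for (int t = 0; t < itemSize; ++t)
        {
            if (outputIds[outputIds.size() - itemSize + t] != stopWordTokens[itemStart + t])
            {
                match = false;
                break;
            }
        }
        if (match)
        {
            return true;
        }
    }
    return false;
}
```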

View File

@ -274,9 +274,9 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
// src QKV: [batch, time, 3, head_num, size_per_head]
// head_num != kv_head_num:
// src QKV: [batch, time, head_num * size_per_head + 2 * kv_head_num * size_per_head]
const int src_q_idx = token_idx * n + hidden_idx;
const int src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv;
const int src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv;
auto const src_q_idx = static_cast<size_t>(token_idx) * n + hidden_idx;
auto const src_k_idx = static_cast<size_t>(token_idx) * n + src_k_offset + hidden_idx_kv;
auto const src_v_idx = static_cast<size_t>(token_idx) * n + src_v_offset + hidden_idx_kv;
Vec_type q, k, v;
Vec_type q_bias, k_bias, v_bias;

View File

@ -56,7 +56,8 @@ enum class WeightOnlyActivationFunctionType
enum class WeightOnlyActivationType
{
FP16,
BF16
BF16,
FP8
};
struct WeightOnlyParams

View File

@ -41,6 +41,12 @@ struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterl
static constexpr bool value = true;
};
template <>
struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajor>
{
static constexpr bool value = true;
};
template <typename TypeB, typename Arch>
bool isEnabled()
{
@ -53,16 +59,20 @@ bool isEnabledForArch(int arch)
{
if (arch >= 70 && arch < 75)
{
return isEnabled<TypeB, cutlass::arch::Sm70>();
return false;
}
else if (arch >= 75 && arch < 80)
{
return isEnabled<TypeB, cutlass::arch::Sm75>();
}
else if (arch >= 80 && arch <= 90)
else if (arch >= 80 && arch < 90)
{
return isEnabled<TypeB, cutlass::arch::Sm80>();
}
else if (arch >= 90)
{
return isEnabled<TypeB, cutlass::arch::Sm90>();
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");

View File

@ -0,0 +1,75 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>
#include <iostream>
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
enum class KernelType
{
W4A16,
W4A8
};
struct Params
{
void* act;
void* act_scale;
void* weight;
void* scales;
void* zeros;
void* bias;
void* out;
float alpha;
int m;
int n;
int k;
int groupsize;
KernelType type;
Params(void* _act, void* _act_scale, void* _weight, void* _scales, void* _zeros, void* _bias, void* _out,
float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type)
: act(_act)
, act_scale(_act_scale)
, weight(_weight)
, scales(_scales)
, zeros(_zeros)
, bias(_bias)
, out(_out)
, alpha(_alpha)
, m(_m)
, n(_n)
, k(_k)
, groupsize(_groupsize)
, type(_type)
{
}
};
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
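
A hedged example of filling the Params struct for a W4A16 problem. The helper name `makeW4A16Params` is hypothetical, the buffer pointers are assumed to be device allocations prepared elsewhere, and groupsize is assumed to be one of the supported values (for example 64 or 128).

```
// Hypothetical setup of a weight-only GEMV call, assuming the Params struct
// above lives in the sm90/common.h header included by the kernel.
using tensorrt_llm::kernels::weight_only::KernelType;
using tensorrt_llm::kernels::weight_only::Params;

Params makeW4A16Params(void* act, void* weight, void* scales, void* out, int m, int n, int k, int groupsize)
{
    // No act_scale, zeros, or bias in this minimal configuration; alpha of 1.0f
    // leaves the accumulated result unscaled.
    return Params(act, /*act_scale=*/nullptr, weight, scales, /*zeros=*/nullptr,
        /*bias=*/nullptr, out, /*alpha=*/1.0f, m, n, k, groupsize, KernelType::W4A16);
}
```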

View File

@ -0,0 +1,389 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
struct ConverterI4ToF16
{
__device__ __forceinline__ static void convert(uint32_t& src, uint4& dst)
{
uint32_t* r = reinterpret_cast<uint32_t*>(&dst);
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
#pragma unroll
for (int ii = 0; ii < 4; ++ii)
{
asm volatile(
"{\n"
" prmt.b32 %0, %1, %2, %3;\n"
"}\n"
: "=r"(r[ii])
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
}
static constexpr uint32_t xor_mask = 0x64806408;
static constexpr uint32_t and_mask = 0xFFF0FF0F;
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
#pragma unroll
for (int ii = 0; ii < 4; ++ii)
{
asm volatile(
"{\n"
" lop3.b32 %0, %0, %1, %2, %3;\n"
"}\n"
: "+r"(r[ii])
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
}
static constexpr uint32_t hfma_bias_rep = 0xD480E408;
static constexpr uint32_t hfma_scale_rep = 0x2C003C00;
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
#pragma unroll
for (int ii = 0; ii < 4; ++ii)
{
__half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
}
}
template <int N>
__device__ __forceinline__ static void convert(void* src, void* dst)
{
static_assert(N == 8 || N == 16);
convert(reinterpret_cast<uint32_t*>(src)[0], reinterpret_cast<uint4*>(dst)[0]);
if constexpr (N == 16)
{
convert(reinterpret_cast<uint32_t*>(src)[1], reinterpret_cast<uint4*>(dst)[1]);
}
}
};
template <typename TVec, int N, bool Enable, typename TSrc>
__device__ __forceinline__ void load(void* dst, TSrc* src, int stride)
{
if constexpr (Enable)
{
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
reinterpret_cast<TVec*>(dst)[ii] = reinterpret_cast<TVec*>(src + ii * stride)[0];
}
}
}
template <int M, int K, bool Enable>
__device__ __forceinline__ void apply_scale(void* act, void* act_scale)
{
static_assert(K % 2 == 0);
static constexpr int VecK = K / 2;
if constexpr (Enable)
{
half2* pa = reinterpret_cast<half2*>(act);
half2* pb = reinterpret_cast<half2*>(act_scale);
#pragma unroll
for (int m = 0; m < M; ++m)
{
#pragma unroll
for (int k = 0; k < VecK; ++k)
{
pa[m * VecK + k] = __hmul2(pa[m * VecK + k], pb[k]);
}
}
}
}
template <int N, int K, bool EnableZero>
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros)
{
using Converter = ConverterI4ToF16;
static_assert(K % 2 == 0);
static constexpr int VecK = K / 2;
#pragma unroll
for (int n = 0; n < N; ++n)
{
ConverterI4ToF16::convert<K>(
reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K);
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n]);
half2 vec_zero = __half2half2(__float2half_rn(0.f));
if constexpr (EnableZero)
{
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n]);
}
#pragma unroll
for (int k = 0; k < VecK; ++k)
{
reinterpret_cast<half2*>(w)[n * VecK + k]
= __hfma2(reinterpret_cast<half2*>(w)[n * VecK + k], vec_scale, vec_zero);
}
}
}
template <int N, int K>
__device__ __forceinline__ void pack_to_vec2(void* dst, void* src)
{
#pragma unroll
for (int n = 0; n < N; n += 2)
{
#pragma unroll
for (int k = 0; k < K; ++k)
{
reinterpret_cast<half*>(dst)[n * K + k * 2] = reinterpret_cast<half*>(src)[n * K + k];
reinterpret_cast<half*>(dst)[n * K + k * 2 + 1] = reinterpret_cast<half*>(src)[(n + 1) * K + k];
}
}
}
template <int M, int N, int K>
__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
{
static_assert(N % 2 == 0);
static constexpr int VecN = N / 2;
#pragma unroll
for (int m = 0; m < M; ++m)
{
#pragma unroll
for (int n = 0; n < VecN; ++n)
{
#pragma unroll
for (int k = 0; k < K; ++k)
{
reinterpret_cast<half2*>(acc)[m * VecN + n] = __hfma2(reinterpret_cast<half2*>(w_pack2)[n * K + k],
__half2half2(reinterpret_cast<half*>(act)[m * K + k]), reinterpret_cast<half2*>(acc)[m * VecN + n]);
}
}
}
}
template <typename T>
__device__ __forceinline__ T warp_reduce_sum(T& val)
{
val += __shfl_xor_sync(~0, val, 16);
val += __shfl_xor_sync(~0, val, 8);
val += __shfl_xor_sync(~0, val, 4);
val += __shfl_xor_sync(~0, val, 2);
val += __shfl_xor_sync(~0, val, 1);
return val;
}
template <int CtaM, int CtaN, int Threads, bool EnableBias>
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha)
{
static constexpr int WarpSize = 32;
static constexpr int WarpNum = Threads / WarpSize;
static constexpr int AlignShmemSize = (CtaM * CtaN + 31) / 32 * 32;
static_assert(Threads % WarpSize == 0);
__shared__ float shmem[AlignShmemSize * WarpNum];
int tid = threadIdx.x;
int warp_id = tid / WarpSize, lane_id = tid % WarpSize;
#pragma unroll
for (int m = 0; m < CtaM; ++m)
{
#pragma unroll
for (int n = 0; n < CtaN; ++n)
{
float v = __half2float(reinterpret_cast<half*>(tile_acc)[m * CtaN + n]);
v = warp_reduce_sum(v);
if (lane_id == 0)
{
shmem[warp_id * AlignShmemSize + m * CtaN + n] = v;
}
}
}
__syncthreads();
#pragma unroll
for (int ii = tid; ii < CtaM * CtaN; ii += Threads)
{
int m = ii / CtaN, n = ii % CtaN;
float val = 0.f, v_bias = 0.f;
if constexpr (EnableBias)
{
v_bias = static_cast<float>(reinterpret_cast<half*>(bias)[n]);
}
#pragma unroll
for (int jj = 0; jj < WarpNum; ++jj)
{
val += shmem[jj * AlignShmemSize + ii];
}
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(alpha * val + v_bias);
}
}
template <int N>
__device__ __forceinline__ void fill(void* tile, half v)
{
#pragma unroll
for (int ii = 0; ii < N; ++ii)
{
reinterpret_cast<half*>(tile)[ii] = v;
}
}
struct Fp16Details
{
using ActDataType = half;
static constexpr int StepK = 8;
using AccessTypeAct = float4;
using AccessTypeActScale = float4;
using AccessTypeW = float;
template <int CtaM>
__device__ __forceinline__ static void load_act(void* dst, void* src, int stride)
{
load<AccessTypeAct, CtaM, true>(dst, reinterpret_cast<ActDataType*>(src), stride);
}
};
struct Fp8Details
{
using ActDataType = __nv_fp8_e4m3;
static constexpr int StepK = 8;
using AccessTypeAct = float2;
using AccessTypeActScale = float4;
using AccessTypeW = float;
__device__ __forceinline__ static void conversion(void* dst, void* src)
{
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
#pragma unroll
for (int ii = 0; ii < StepK / 4; ++ii)
{
asm volatile(
"{\n"
".reg .b16 lo, hi;\n"
"mov.b32 {lo, hi}, %2;\n"
"cvt.rn.f16x2.e4m3x2 %0, lo;\n"
"cvt.rn.f16x2.e4m3x2 %1, hi;\n"
"}\n"
: "=r"(reinterpret_cast<uint32_t*>(dst)[ii * 2]), "=r"(reinterpret_cast<uint32_t*>(dst)[ii * 2 + 1])
: "r"(reinterpret_cast<uint32_t*>(src)[ii]));
}
#else
#pragma unroll
for (int ii = 0; ii < StepK; ++ii)
{
reinterpret_cast<half*>(dst)[ii] = static_cast<half>(reinterpret_cast<ActDataType*>(src)[ii]);
}
#endif
}
template <int CtaM>
__device__ __forceinline__ static void load_act(void* dst, void* src, int stride)
{
ActDataType vec[CtaM * StepK];
load<AccessTypeAct, CtaM, true>(vec, reinterpret_cast<ActDataType*>(src), stride);
#pragma unroll
for (int ii = 0; ii < CtaM; ++ii)
{
conversion(reinterpret_cast<half*>(dst) + ii * StepK, vec + ii * StepK);
}
}
};
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
bool EnableBias>
__global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint8_t* weight, half* scales, half* zeros,
half* bias, half* out, float alpha, int m, int n, int k)
{
// ArgType ArgName DataType Shape Layout
//
// input act fp16 [m, k] RowMajor
// input act_scale fp16 [1, k] RowMajor
// input weight int4b [k, n] ColumnMajor
// input scales fp16 [k / GroupSize, n] RowMajor
// input zeros fp16 [k / GroupSize, n] RowMajor
// input bias fp16 [1, n] RowMajor
// output out fp16 [m, n] RowMajor
using AccessTypeActScale = typename Details::AccessTypeActScale;
using AccessTypeW = typename Details::AccessTypeW;
static constexpr int StepK = Details::StepK;
static constexpr bool Mandatory = true;
static constexpr int CtaK = StepK * Threads;
static_assert(CtaN % 2 == 0);
const int m_tile_id = blockIdx.x, n_tile_id = blockIdx.y, tid = threadIdx.x;
const int m_offset = m_tile_id * CtaM, n_offset = n_tile_id * CtaN;
act += m_offset * k;
weight += n_offset * k / 2;
scales += n_offset;
zeros += n_offset;
bias += n_offset;
out += m_offset * n + n_offset;
half tile_a[StepK * CtaM], tile_w[StepK * CtaN], tile_w_pack2[StepK * CtaN];
half tile_acc[CtaM * CtaN];
fill<CtaM * CtaN>(tile_acc, __float2half_rn(0.f));
for (int idx_k = tid * StepK; idx_k < k; idx_k += CtaK)
{
half vec_act_scale[StepK];
half vec_scale[CtaN], vec_zero[CtaN];
uint8_t tile_w_quantized[StepK * CtaN / 2];
// Load Data
Details::load_act<CtaM>(tile_a, act + idx_k, k);
load<AccessTypeActScale, 1, EnableActScale>(vec_act_scale, act_scale + idx_k, 0);
load<AccessTypeW, CtaN, Mandatory>(tile_w_quantized, weight + idx_k / 2, k / 2);
load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1);
load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1);
// Dequantize Data
apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale);
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero);
// Rearrange
pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w);
// MMA
mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a);
}
// Epilogue
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias, alpha);
}
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
bool EnableBias>
void exec_kernel(Params& params, cudaStream_t s)
{
if (params.m % CtaM || params.n % CtaN)
{
throw std::runtime_error("launch failed");
}
dim3 grid(params.m / CtaM, params.n / CtaN);
dim3 block(Threads);
// clang-format off
kernel<Details, CtaM, CtaN, Threads, GroupSize, EnableActScale, EnableZero, EnableBias><<<grid, block, 0, s>>>(
reinterpret_cast<typename Details::ActDataType*>(params.act),
reinterpret_cast<half*>(params.act_scale),
reinterpret_cast<uint8_t*>(params.weight),
reinterpret_cast<half*>(params.scales),
reinterpret_cast<half*>(params.zeros),
reinterpret_cast<half*>(params.bias),
reinterpret_cast<half*>(params.out),
params.alpha,
params.m, params.n, params.k
);
// clang-format on
}
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
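
For orientation, the argument table in the kernel above maps onto the host-side reference computation below. This is a minimal sketch only: it assumes the int4 weights are already unpacked to one integer per element, the common `(q - zero) * scale` dequantization convention, and an `alpha * acc + bias` epilogue; the device kernel implements the same math with vectorized loads, fp16 MMA tiles, and per-thread partial sums that the epilogue combines.

```
#include <vector>

// Reference math for the batched GEMV kernel above (illustrative only).
// Assumptions: weights are unpacked to one value per element, dequantization
// is (q - zero) * scale, and the epilogue computes alpha * acc + bias.
void w4a16GemvReference(std::vector<float> const& act,       // [m, k] row-major
    std::vector<float> const& actScale,                      // [k] per-channel activation scale (1.0f if disabled)
    std::vector<int> const& weightQ,                         // [k, n] column-major, unpacked int4 values
    std::vector<float> const& scales,                        // [k / groupSize, n] row-major
    std::vector<float> const& zeros,                         // [k / groupSize, n] row-major
    std::vector<float> const& bias,                          // [n]
    std::vector<float>& out,                                 // [m, n] row-major
    float alpha, int m, int n, int k, int groupSize)
{
    for (int mi = 0; mi < m; ++mi)
    {
        for (int ni = 0; ni < n; ++ni)
        {
            float acc = 0.f;
            for (int ki = 0; ki < k; ++ki)
            {
                // Column-major weight: element (ki, ni) lives at ni * k + ki.
                int const groupRow = ki / groupSize;
                float const w = (static_cast<float>(weightQ[ni * k + ki]) - zeros[groupRow * n + ni])
                    * scales[groupRow * n + ni];
                acc += act[mi * k + ki] * actScale[ki] * w;
            }
            out[mi * n + ni] = alpha * acc + bias[ni];
        }
    }
}
```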

View File

@ -0,0 +1,123 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h"
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \
do \
{ \
if (params.m == target_m) \
{ \
exec_kernel<Details, CtaM, CtaN, Threads, GroupSize, EnableActScale, EnableZero, EnableBias>(params, s); \
return; \
} \
} while (0);
template <typename Details, int GroupSize, bool EnableActScale, bool EnableZero, bool EnableBias>
void dispatcher(Params& params, cudaStream_t s)
{
// clang-format off
DISPATCHER_FOR_M(1, 1, 8, 128);
DISPATCHER_FOR_M(2, 2, 4, 128);
DISPATCHER_FOR_M(3, 3, 16, 128);
DISPATCHER_FOR_M(4, 4, 16, 128);
DISPATCHER_FOR_M(5, 5, 16, 128);
DISPATCHER_FOR_M(6, 6, 16, 128);
DISPATCHER_FOR_M(7, 7, 16, 128);
DISPATCHER_FOR_M(8, 8, 16, 128);
DISPATCHER_FOR_M(9, 9, 8, 128);
DISPATCHER_FOR_M(10, 10, 8, 128);
DISPATCHER_FOR_M(11, 11, 8, 128);
DISPATCHER_FOR_M(12, 12, 8, 128);
DISPATCHER_FOR_M(13, 13, 8, 128);
DISPATCHER_FOR_M(14, 14, 8, 128);
DISPATCHER_FOR_M(15, 15, 8, 128);
DISPATCHER_FOR_M(16, 16, 8, 128);
// clang-format on
throw std::runtime_error("unsupported m");
}
template <typename Details, int GroupSize>
void check_pointer(Params& params, cudaStream_t s)
{
if (params.act_scale && params.zeros && params.bias)
{
dispatcher<Details, GroupSize, true, true, true>(params, s);
}
else if (params.act_scale && params.zeros && !params.bias)
{
dispatcher<Details, GroupSize, true, true, false>(params, s);
}
else if (params.act_scale && !params.zeros && params.bias)
{
dispatcher<Details, GroupSize, true, false, true>(params, s);
}
else if (!params.act_scale && params.zeros && params.bias)
{
dispatcher<Details, GroupSize, false, true, true>(params, s);
}
else if (!params.act_scale && !params.zeros && params.bias)
{
dispatcher<Details, GroupSize, false, false, true>(params, s);
}
else if (params.act_scale && !params.zeros && !params.bias)
{
dispatcher<Details, GroupSize, true, false, false>(params, s);
}
else if (!params.act_scale && params.zeros && !params.bias)
{
dispatcher<Details, GroupSize, false, true, false>(params, s);
}
else
{
dispatcher<Details, GroupSize, false, false, false>(params, s);
}
}
template <typename Details>
void select_gs(Params& params, cudaStream_t s)
{
if (params.groupsize == 64)
{
check_pointer<Details, 64>(params, s);
}
else if (params.groupsize == 128)
{
check_pointer<Details, 128>(params, s);
}
}
void kernel_launcher(Params& params, cudaStream_t s)
{
if (params.type == KernelType::W4A16)
{
select_gs<Fp16Details>(params, s);
}
else if (params.type == KernelType::W4A8)
{
select_gs<Fp8Details>(params, s);
}
}
} // namespace weight_only
} // namespace kernels
} // namespace tensorrt_llm
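
A hedged usage sketch for the launcher above. The exact `Params` definition lives in `sm90/common.h`, so the field assignments below are assumptions based only on how `exec_kernel`, `check_pointer`, `select_gs` and `kernel_launcher` read the struct; the pointer fields are shown as `void*` purely for illustration.

```
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h"

// Illustrative only: the real field types and layout of Params are defined in sm90/common.h.
using namespace tensorrt_llm::kernels::weight_only;

void launchW4A16Gemv(void* act, void* actScale, void* weight, void* scales, void* zeros, void* bias, void* out,
    int m, int n, int k, cudaStream_t stream)
{
    Params params{};
    params.act = act;            // fp16 activations, [m, k] row-major
    params.act_scale = actScale; // optional per-channel activation scale, may be nullptr
    params.weight = weight;      // packed int4 weights, [k, n] column-major
    params.scales = scales;      // group-wise dequantization scales, [k / groupsize, n]
    params.zeros = zeros;        // optional group-wise zero points, may be nullptr
    params.bias = bias;          // optional bias, [1, n], may be nullptr
    params.out = out;            // fp16 output, [m, n] row-major
    params.alpha = 1.f;
    params.m = m;                // only m in [1, 16] is dispatched; anything else throws "unsupported m"
    params.n = n;                // must be divisible by the selected CtaN, otherwise the launch throws
    params.k = k;
    params.groupsize = 128;      // only group sizes 64 and 128 are handled by select_gs
    params.type = KernelType::W4A16; // or KernelType::W4A8 for fp8 activations
    kernel_launcher(params, stream);
}
```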

View File

@ -0,0 +1,30 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{
void kernel_launcher(Params& params, cudaStream_t s);
}
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -16,9 +16,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
#include "tensorrt_llm/layers/fillBuffers.h"
#include <algorithm>
@ -111,11 +109,6 @@ void BaseBeamSearchLayer<T>::freeBuffer()
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mIsAllocateBuffer)
{
mAllocator->free((void**) (&temperature_buf_));
mAllocator->free((void**) (&min_lengths_buf_));
mAllocator->free((void**) (&repetition_penalty_buf_));
mAllocator->free((void**) (&presence_penalty_buf_));
mAllocator->free((void**) (&frequency_penalty_buf_));
mIsAllocateBuffer = false;
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
@ -125,12 +118,6 @@ template <typename T>
void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
temperature_buf_ = mAllocator->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
min_lengths_buf_ = mAllocator->reMalloc(min_lengths_buf_, sizeof(int) * batch_size, false);
repetition_penalty_buf_ = mAllocator->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);
presence_penalty_buf_ = mAllocator->reMalloc(presence_penalty_buf_, sizeof(float) * batch_size, false);
frequency_penalty_buf_ = mAllocator->reMalloc(frequency_penalty_buf_, sizeof(float) * batch_size, false);
mIsAllocateBuffer = true;
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -138,47 +125,13 @@ void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
template <typename T>
void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& setupParams)
{
allocateBuffer(batch_size);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
// Setup penalties.
FillBuffers const fillBuffers{batch_size, mStream};
use_temperature_ = static_cast<bool>(setupParams.temperature);
use_repetition_penalty_ = static_cast<bool>(setupParams.repetition_penalty);
use_presence_penalty_ = static_cast<bool>(setupParams.presence_penalty);
use_frequency_penalty_ = static_cast<bool>(setupParams.frequency_penalty);
use_min_lengths_ = static_cast<bool>(setupParams.min_length);
if (use_temperature_)
{
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(RepetitionPenaltyType::Temperature), mTemperature,
temperature_buf_, (float*) nullptr, (int*) nullptr);
}
if (use_repetition_penalty_)
{
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Repetition),
mRepetitionPenalty, repetition_penalty_buf_, (float*) nullptr, (int*) nullptr);
}
if (use_presence_penalty_)
{
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Presence),
mPresencePenalty, presence_penalty_buf_, (float*) nullptr, (int*) nullptr);
}
if (use_frequency_penalty_)
{
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency),
mFrequencyPenalty, frequency_penalty_buf_, (float*) nullptr, (int*) nullptr);
}
if (use_min_lengths_)
{
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength), mMinLengths,
min_lengths_buf_, (int*) nullptr, (int*) nullptr);
}
allocateBuffer(batch_size);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params,
int* penalty_workspace, const int* penalty_workspace_prev)
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
Tensor& output_ids_ptr = outputs.output_ids_ptr;
@ -195,47 +148,6 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
Tensor const& logits = params.logits;
const auto local_batch_size = logits.shape[0];
#define ALL_OF(p_, sz_, dt_, v_) (std::all_of(p_, p_ + sz_, [&](dt_ b) { return b == v_; }))
const T* embedding_bias = params.embedding_bias ? params.embedding_bias->template getPtr<const T>() : nullptr;
auto* temperatures = (use_temperature_
&& !ALL_OF(std::begin(mTemperature) + ite * local_batch_size, local_batch_size, float,
getDefaultPenaltyValue(RepetitionPenaltyType::Temperature)))
? temperature_buf_ + ite * local_batch_size
: nullptr;
auto* repetition_penalties
= (use_repetition_penalty_
&& !ALL_OF(std::begin(mRepetitionPenalty) + ite * local_batch_size, local_batch_size, float,
getDefaultPenaltyValue(RepetitionPenaltyType::Repetition)))
? repetition_penalty_buf_ + ite * local_batch_size
: nullptr;
auto* presence_penalties = (use_presence_penalty_
&& !ALL_OF(std::begin(mPresencePenalty) + ite * local_batch_size, local_batch_size,
float, getDefaultPenaltyValue(RepetitionPenaltyType::Presence)))
? presence_penalty_buf_ + ite * local_batch_size
: nullptr;
auto* frequency_penalties = (use_frequency_penalty_
&& !ALL_OF(std::begin(mFrequencyPenalty) + ite * local_batch_size, local_batch_size,
float, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency)))
? frequency_penalty_buf_ + ite * local_batch_size
: nullptr;
auto* min_lengths = (use_min_lengths_
&& !ALL_OF(std::begin(mMinLengths) + ite * local_batch_size, local_batch_size, int,
(int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength)))
? min_lengths_buf_ + ite * local_batch_size
: nullptr;
InvokeBatchApplyPenaltyParams<T> penalty_params{logits.getPtr<T>(), embedding_bias,
penalty_workspace + ite * local_batch_size * beam_width * vocab_size_,
penalty_workspace_prev + ite * local_batch_size * beam_width * vocab_size_, temperatures, repetition_penalties,
presence_penalties, frequency_penalties,
(use_repetition_penalty_ || use_presence_penalty_ || use_frequency_penalty_), local_batch_size, beam_width,
max_seq_len, vocab_size_, vocab_size_padded_, output_ids_ptr.template getPtr<const int*>(),
outputs.parent_ids_ptr.template getPtr<const int*>(), input_lengths, sequence_length, min_lengths,
params.end_ids.template getPtr<const int>(), nullptr, mStream};
invokeBatchApplyPenalty(penalty_params);
sync_check_cuda_error();
invokeSoftMax(outputs, params);
sync_check_cuda_error();

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -19,7 +19,6 @@
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/penaltyTypes.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
@ -72,8 +71,7 @@ public:
tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len]
// optional parameters
std::optional<tc::Tensor> embedding_bias; // [vocab_size_padded]
std::optional<tc::Tensor> input_lengths; // [local_batch_size * beam_width]
std::optional<tc::Tensor> input_lengths; // [local_batch_size * beam_width]
};
class BeamSearchOutputParams : public DecodingOutputParams
@ -96,8 +94,7 @@ public:
parent_ids_ptr; // [batch_size] int*, each array is [beam_width, max_seq_len], necessary in beam search
};
void forward(BeamSearchOutputParams& outputs, ForwardParams const& params, int* penalty_workspace,
const int* penalty_workspace_prev);
void forward(BeamSearchOutputParams& outputs, ForwardParams const& params);
protected:
// meta data
@ -107,24 +104,6 @@ protected:
size_t topk_softmax_workspace_size_;
void* topk_softmax_workspace_ = nullptr;
float* temperature_buf_;
float* repetition_penalty_buf_;
float* presence_penalty_buf_;
float* frequency_penalty_buf_;
int* min_lengths_buf_;
std::vector<float> mTemperature;
std::vector<float> mRepetitionPenalty;
std::vector<float> mPresencePenalty;
std::vector<float> mFrequencyPenalty;
std::vector<int> mMinLengths;
bool use_temperature_ = false;
bool use_repetition_penalty_ = false;
bool use_presence_penalty_ = false;
bool use_frequency_penalty_ = false;
bool use_min_lengths_ = false;
virtual void invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) = 0;
void setupBase(size_t batch_size, SetupParams const& setupParams);

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -32,64 +32,6 @@ namespace tensorrt_llm
{
namespace layers
{
template <typename T>
void BaseSamplingLayer<T>::allocateBuffer(size_t batchSize)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
std::array<size_t, 10> deviceBufferSizes;
deviceBufferSizes[0] = sizeof(curandState_t) * batchSize;
deviceBufferSizes[1] = sizeof(uint64_t) * batchSize;
deviceBufferSizes[2] = sizeof(float) * batchSize;
deviceBufferSizes[3] = sizeof(float) * batchSize;
deviceBufferSizes[4] = sizeof(float) * batchSize;
deviceBufferSizes[5] = sizeof(float) * batchSize;
deviceBufferSizes[6] = sizeof(int) * batchSize;
deviceBufferSizes[7] = sizeof(T) * batchSize * mVocabSizePadded;
deviceBufferSizes[8] = sizeof(bool) * batchSize;
deviceBufferSizes[9] = sizeof(float) * batchSize;
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
mTemperaturesDevice = mAllocator->reMalloc(mTemperaturesDevice, deviceBufferSizes[2], false);
mRepetitionPenaltiesDevice = mAllocator->reMalloc(mRepetitionPenaltiesDevice, deviceBufferSizes[3], false);
mPresencePenaltiesDevice = mAllocator->reMalloc(mPresencePenaltiesDevice, deviceBufferSizes[4], false);
mFrequencyPenaltiesDevice = mAllocator->reMalloc(mFrequencyPenaltiesDevice, deviceBufferSizes[5], false);
mMinLengthsDevice = mAllocator->reMalloc(mMinLengthsDevice, deviceBufferSizes[6], false);
mRuntimeLogitsDevice = mAllocator->reMalloc(mRuntimeLogitsDevice, deviceBufferSizes[7], false);
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[8], false);
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[9], false);
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
TLLM_LOG_DEBUG("baseSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
// host buffers.
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
TLLM_CHECK(mSkipDecodeHost != nullptr);
mIsAllocateBuffer = true;
}
template <typename T>
void BaseSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (mIsAllocateBuffer)
{
mAllocator->free((void**) (&mCurandStatesDevice));
mAllocator->free((void**) (&mRandomSeedsDevice));
mAllocator->free((void**) (&mTemperaturesDevice));
mAllocator->free((void**) (&mRepetitionPenaltiesDevice));
mAllocator->free((void**) (&mPresencePenaltiesDevice));
mAllocator->free((void**) (&mFrequencyPenaltiesDevice));
mAllocator->free((void**) (&mMinLengthsDevice));
mAllocator->free((void**) (&mRuntimeLogitsDevice));
mAllocator->free((void**) (&mSkipDecodeDevice));
mAllocator->free((void**) (&mSetupWorkspaceDevice));
std::free(mSkipDecodeHost);
mIsAllocateBuffer = false;
}
}
template <typename T>
BaseSamplingLayer<T>::BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop)
@ -98,166 +40,6 @@ BaseSamplingLayer<T>::BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, s
, mVocabSize(vocabSize)
, mVocabSizePadded(vocabSizePadded)
{
allocateBuffer(maxBatchSize);
}
template <typename T>
BaseSamplingLayer<T>::BaseSamplingLayer(BaseSamplingLayer const& samplingLayer)
: BaseLayer(samplingLayer)
, mMaxBatchSize(samplingLayer.mMaxBatchSize)
, mVocabSize(samplingLayer.mVocabSize)
, mVocabSizePadded(samplingLayer.mVocabSizePadded)
, mSamplingWorkspaceSize(samplingLayer.mSamplingWorkspaceSize)
{
allocateBuffer(mMaxBatchSize);
}
template <typename T>
void BaseSamplingLayer<T>::setupBase(const size_t batchSize, int const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
// If the runtime argument provides a single random seed, use it to initialize
// the curand states of all sequences. If it provides [batchSize] random seeds,
// initialize each sequence's state with its own seed. If no random seed is
// given, initialize all states with seed 0.
if (setupParams.randomSeed)
{
if (setupParams.randomSeed->size() == 1)
{
invokeCurandInitialize(
mCurandStatesDevice, batchSlots, batchSize, setupParams.randomSeed->front(), mStream);
sync_check_cuda_error();
}
else
{
TLLM_CHECK_WITH_INFO(setupParams.randomSeed->size() == batchSize, "Random seed vector size mismatch.");
cudaAutoCpy(mRandomSeedsDevice, setupParams.randomSeed->data(), batchSize, mStream);
invokeCurandBatchInitialize(mCurandStatesDevice, batchSlots, batchSize, mRandomSeedsDevice, mStream);
sync_check_cuda_error();
}
}
else
{
// Initialize curand states using the default seed 0.
invokeCurandInitialize(mCurandStatesDevice, batchSlots, batchSize, 0, mStream);
}
// Setup penalties.
FillBuffers const fillBuffers{batchSize, mStream};
mUseTemperature = static_cast<bool>(setupParams.temperature);
mUseRepetitionPenalty = static_cast<bool>(setupParams.repetition_penalty);
mUsePresencePenalty = static_cast<bool>(setupParams.presence_penalty);
mUseFrequencyPenalty = static_cast<bool>(setupParams.frequency_penalty);
mUseMinLengths = static_cast<bool>(setupParams.min_length);
if (mUseTemperature)
{
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(RepetitionPenaltyType::Temperature), mTemperature,
mTemperaturesDevice, mSetupWorkspaceDevice, batchSlots);
}
if (mUseRepetitionPenalty)
{
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Repetition),
mRepetitionPenalty, mRepetitionPenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
}
if (mUsePresencePenalty)
{
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Presence),
mPresencePenalty, mPresencePenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
}
if (mUseFrequencyPenalty)
{
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency),
mFrequencyPenalty, mFrequencyPenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
}
if (mUseMinLengths)
{
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength), mMinLengths,
mMinLengthsDevice, mSetupWorkspaceDevice, batchSlots);
}
}
template <typename T>
void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams const& inputs, int* penaltyWorkspace)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto const step = inputs.step;
auto* const inputLengths = inputs.input_lengths ? inputs.input_lengths->template getPtr<const int>() : nullptr;
auto* logits = inputs.logits.template getPtr<T>();
TLLM_CHECK_WITH_INFO((inputs.batch_slots_host.has_value() ^ inputs.batch_slots.has_value()) == 0,
"either both batch_slots_host and batch_slots have to be provided or neither of them");
auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
std::vector<int32_t> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto* batchSlotsHost
= inputs.batch_slots_host ? inputs.batch_slots_host->template getPtr<const int>() : batchSlotsVec.data();
#define ALL_OF(addrs_, p_, sz_, v_) (std::all_of(addrs_, addrs_ + sz_, [&](int32_t b) { return p_[b] == v_; }))
if (ALL_OF(batchSlotsHost, mSkipDecodeHost, batchSize, true))
{
// No sample in the current batch to do TopX sampling.
return;
}
mSkipAny = std::any_of(batchSlotsHost, batchSlotsHost + batchSize,
[this](int32_t batchSlot) { return this->mSkipDecodeHost[batchSlot]; });
if (mSkipAny)
{
// A TopX sampling layer modifies the logit values in place. When
// skip_any == true, the topk and topp layers run on the same batch in the
// same step, so we copy the logits to an internal buffer to avoid affecting
// the other sampling layer.
TLLM_CHECK(inputs.logits.size() == batchSize * mVocabSizePadded);
cudaD2Dcpy(mRuntimeLogitsDevice, logits, inputs.logits.size(), mStream);
logits = mRuntimeLogitsDevice;
}
auto* embeddingBias = inputs.embedding_bias ? inputs.embedding_bias->template getPtr<T const>() : nullptr;
auto* temperatures = (mUseTemperature
&& !ALL_OF(batchSlotsHost, mTemperature, batchSize,
getDefaultPenaltyValue(RepetitionPenaltyType::Temperature)))
? mTemperaturesDevice
: nullptr;
auto* repetitionPenalties = (mUseRepetitionPenalty
&& !ALL_OF(batchSlotsHost, mRepetitionPenalty, batchSize,
getDefaultPenaltyValue(RepetitionPenaltyType::Repetition)))
? mRepetitionPenaltiesDevice
: nullptr;
auto* presencePenalties = (mUsePresencePenalty
&& !ALL_OF(batchSlotsHost, mPresencePenalty, batchSize,
getDefaultPenaltyValue(RepetitionPenaltyType::Presence)))
? mPresencePenaltiesDevice
: nullptr;
auto* frequencyPenalties = (mUseFrequencyPenalty
&& !ALL_OF(batchSlotsHost, mFrequencyPenalty, batchSize,
getDefaultPenaltyValue(RepetitionPenaltyType::Frequency)))
? mFrequencyPenaltiesDevice
: nullptr;
auto* minLengths = (mUseMinLengths
&& !ALL_OF(batchSlotsHost, mMinLengths, batchSize,
(int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength)))
? mMinLengthsDevice
: nullptr;
InvokeBatchApplyPenaltyParams<T> penaltyParams{logits, embeddingBias, penaltyWorkspace, nullptr, temperatures,
repetitionPenalties, presencePenalties, frequencyPenalties,
(mUseRepetitionPenalty || mUsePresencePenalty || mUseFrequencyPenalty), batchSize, 1, inputs.max_seq_len,
mVocabSize, mVocabSizePadded, outputs.output_ids_ptr.template getPtr<const int*>(), nullptr, inputLengths,
outputs.sequence_length->getPtr<const int>(), minLengths, inputs.end_ids.template getPtr<const int>(),
batchSlots, mStream};
invokeBatchApplyPenalty(penaltyParams);
sync_check_cuda_error();
#undef ALL_OF
runSampling(outputs, inputs);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class BaseSamplingLayer<float>;

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -32,8 +32,7 @@ namespace layers
{
//! \brief Base class for sampling layers.
//! Layer modifies logits in-place. However, when any of the requests skips the sampling layer,
//! logits are copied and modified in a temporary buffer.
//! Layer modifies logits in-place.
template <typename T>
class BaseSamplingLayer : public BaseLayer
{
@ -63,8 +62,12 @@ public:
int max_seq_len;
// optional parameters
std::optional<tc::Tensor> embedding_bias; // [vocabSizePadded]
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
curandState_t* curand_states; // [localBatchSize]
// Pointer to the workspace for sampling computation
void* sampling_workspace;
// Flag to mark that logits tensor contains probabilities
bool probs_computed;
};
// clang-format off
@ -80,8 +83,6 @@ public:
BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream,
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop);
BaseSamplingLayer(BaseSamplingLayer const& samplingLayer);
~BaseSamplingLayer() override = default;
// clang-format off
@ -92,9 +93,9 @@ public:
//!
//! \param outputs DecodingOutputParams struct with output tensors
//! \param inputs ForwardParams struct with input tensors and params
//! \param penaltyWorkspace
//! \param curandStatesDevice Properly initialized curand states buffer on device
// clang-format on
void forward(DecodingOutputParams& outputs, ForwardParams const& inputs, int* penaltyWorkspace);
virtual void forward(DecodingOutputParams& outputs, ForwardParams& inputs) = 0;
// clang-format off
//! \brief Virtual function that setups internal tensors of the layer with sampling params
@ -103,60 +104,28 @@ public:
//! Thus, it must be called only once for new requests.
//!
//! \param batchSize Maximum batch size configured in the system
//! \param batchSlots input tensor [batchSize], address map of the new requests
//! \param batchSlots input tensor [batchSize], address map of the new requests, in pinned memory
//! \param setupParams setup sampling parameters per request
// clang-format on
virtual void setup(size_t batchSize, int const* batchSlots, SetupParams const& setupParams) = 0;
virtual void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) = 0;
protected:
//! \brief setup of the base class, has to be called in the beginning of the derived's class setup
void setupBase(size_t batchSize, int const* batchSlots, SetupParams const& setupParams);
size_t getWorkspaceSize() const
{
return mSamplingWorkspaceSize;
}
// clang-format off
//! \brief Executes sampling logic of the derived class
//!
//! \param outputs DecodingOutputParams struct with output tensors
//! \param inputs ForwardParams struct with input tensors and params
// clang-format on
virtual void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) = 0;
virtual void freeBuffer();
size_t getAllocatedSize() const
{
return mAllocatedSize;
}
protected:
size_t mMaxBatchSize;
size_t mVocabSize;
size_t mVocabSizePadded;
size_t mSamplingWorkspaceSize;
void* mSamplingWorkspaceDevice = nullptr;
curandState_t* mCurandStatesDevice = nullptr;
uint64_t* mRandomSeedsDevice = nullptr;
float* mTemperaturesDevice = nullptr;
float* mRepetitionPenaltiesDevice = nullptr;
float* mPresencePenaltiesDevice = nullptr;
float* mFrequencyPenaltiesDevice = nullptr;
int32_t* mMinLengthsDevice = nullptr;
bool* mSkipDecodeDevice = nullptr;
T* mRuntimeLogitsDevice = nullptr;
void* mSetupWorkspaceDevice = nullptr;
std::vector<float> mTemperature;
std::vector<float> mRepetitionPenalty;
std::vector<float> mPresencePenalty;
std::vector<float> mFrequencyPenalty;
std::vector<int32_t> mMinLengths;
bool* mSkipDecodeHost = nullptr;
bool mSkipAny = false;
bool mUseTemperature = false;
bool mUseRepetitionPenalty = false;
bool mUsePresencePenalty = false;
bool mUseFrequencyPenalty = false;
bool mUseMinLengths = false;
private:
void allocateBuffer(size_t batchSize);
size_t mSamplingWorkspaceSize = 0;
size_t mAllocatedSize = 0;
};
} // namespace layers
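
The `batchSlots` tensor documented above acts as an address map: request `bi` of the current (possibly non-contiguous) batch owns slot `batchSlots[bi]` in every persistent per-slot buffer of size `maxBatchSize`. Below is a minimal sketch of that scatter pattern, under the assumption that this is essentially what helpers such as `fillBuffers` do with the per-request setup values.

```
#include <cstddef>
#include <cstdint>
#include <vector>

// Scatter per-request setup values into a persistent per-slot buffer.
// batchSlots[bi] is the slot owned by request bi; slots of requests that are
// not part of this setup call stay untouched.
void scatterToSlots(std::vector<float>& perSlotBuffer,  // size == maxBatchSize
    std::vector<float> const& perRequestValues,         // size == batchSize, or 1 to broadcast, or empty
    int32_t const* batchSlots, size_t batchSize, float defaultValue)
{
    for (size_t bi = 0; bi < batchSize; ++bi)
    {
        float value = defaultValue;
        if (perRequestValues.size() == 1)
        {
            value = perRequestValues.front(); // single value broadcast to all requests
        }
        else if (perRequestValues.size() == batchSize)
        {
            value = perRequestValues[bi]; // one value per request
        }
        perSlotBuffer[batchSlots[bi]] = value;
    }
}
```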

View File

@ -50,11 +50,10 @@ public:
// mandatory parameters
int step;
int ite;
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
std::optional<tc::Tensor> batch_slots; // [local_batch_size]
std::optional<tc::Tensor> batch_slots_host; // [local_batch_size]
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
std::optional<tc::Tensor> batch_slots; // [local_batch_size], on pinned memory
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
};
class DecodingOutputParams

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2022, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -15,14 +15,15 @@
*/
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/banBadWords.h"
#include "tensorrt_llm/kernels/banRepeatNgram.h"
#include "tensorrt_llm/kernels/decodingKernels.h"
#include "tensorrt_llm/kernels/penaltyKernels.h"
#include "tensorrt_llm/kernels/stopCriteriaKernels.h"
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
#include "tensorrt_llm/layers/fillBuffers.h"
#include "tensorrt_llm/layers/onlineBeamSearchLayer.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
@ -35,54 +36,6 @@ namespace tensorrt_llm
namespace layers
{
template <typename T>
void DynamicDecodeLayer<T>::initialize()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mOnlineBeamsearchDecode
= std::make_unique<OnlineBeamSearchLayer<T>>(vocab_size_, vocab_size_padded_, mStream, mAllocator);
mTopKDecode
= std::make_unique<TopKSamplingLayer<T>>(max_batch_size_, vocab_size_, vocab_size_padded_, mStream, mAllocator);
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(max_batch_size_, vocab_size_, vocab_size_padded_, mStream,
mAllocator, cuda_device_prop_, /* deterministic */ true);
mIdsPtrHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<int*>::value);
}
template <typename T>
DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t max_batch_size, size_t vocab_size, size_t vocab_size_padded,
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* cuda_device_prop)
: BaseLayer(stream, std::move(allocator))
, max_batch_size_(max_batch_size)
, vocab_size_(vocab_size)
, vocab_size_padded_(vocab_size_padded)
, cuda_device_prop_(cuda_device_prop)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
template <typename T>
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
template <typename T>
DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer)
: BaseLayer(dynamic_decode_layer)
, max_batch_size_(dynamic_decode_layer.max_batch_size_)
, vocab_size_(dynamic_decode_layer.vocab_size_)
, vocab_size_padded_(dynamic_decode_layer.vocab_size_padded_)
, cuda_device_prop_(dynamic_decode_layer.cuda_device_prop_)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
namespace
{
template <typename T>
@ -118,20 +71,166 @@ bool hasDiffRuntimeArgs(DecodingSetupParams const& params)
} // namespace
template <typename T>
void DynamicDecodeLayer<T>::setup(
size_t batch_size, size_t beam_width, int const* batch_slots, SetupParams const& setupParams)
DynamicDecodeLayer<T>::DynamicDecodeLayer(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth,
size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream, std::shared_ptr<IAllocator> allocator,
cudaDeviceProp* cudaDeviceProp)
: BaseLayer(stream, std::move(allocator))
, mDecodingMode(mode)
, mMaxBatchSize(maxBatchSize)
, mMaxBeamWidth(maxBeamWidth)
, mVocabSize(vocabSize)
, mVocabSizePadded(vocabSizePadded)
, mCudaDeviceProp(cudaDeviceProp)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
template <typename T>
DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamicDecodeLayer)
: BaseLayer(dynamicDecodeLayer)
, mDecodingMode(dynamicDecodeLayer.mDecodingMode)
, mMaxBatchSize(dynamicDecodeLayer.mMaxBatchSize)
, mMaxBeamWidth(dynamicDecodeLayer.mMaxBeamWidth)
, mVocabSize(dynamicDecodeLayer.mVocabSize)
, mVocabSizePadded(dynamicDecodeLayer.mVocabSizePadded)
, mCudaDeviceProp(dynamicDecodeLayer.mCudaDeviceProp)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
initialize();
}
template <typename T>
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
freeBuffer();
}
template <typename T>
void DynamicDecodeLayer<T>::initialize()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (beam_width == 1)
{ // sampling layers
typename TopPSamplingLayer<T>::SetupParams samplingParams;
mIdsPtrHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<int*>::value);
mLogitsPtrsHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<T*>::value);
samplingParams.temperature = setupParams.temperature;
samplingParams.min_length = setupParams.min_length;
samplingParams.repetition_penalty = setupParams.repetition_penalty;
samplingParams.presence_penalty = setupParams.presence_penalty;
samplingParams.frequency_penalty = setupParams.frequency_penalty;
allocateBuffer();
mCyclicStep = 0;
mRuntimeMaxSeqLen = 0;
mConfiguredBeamWidth = -1;
mTemperature.resize(mMaxBatchSize);
mRepetitionPenalty.resize(mMaxBatchSize);
mPresencePenalty.resize(mMaxBatchSize);
mFrequencyPenalty.resize(mMaxBatchSize);
mMinLength.resize(mMaxBatchSize);
if (!mDecodingMode.isNone())
{
mConfiguredBeamWidth = mMaxBeamWidth;
initializeLayers();
}
}
template <typename T>
void DynamicDecodeLayer<T>::allocateBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mZeroParentIdsDevice = mAllocator->reMalloc(mZeroParentIdsDevice, sizeof(int*) * 2 * mMaxBatchSize, false);
mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, sizeof(float) * mMaxBatchSize, false);
mRepetitionPenaltyDevice = mAllocator->reMalloc(mRepetitionPenaltyDevice, sizeof(float) * mMaxBatchSize, false);
mPresencePenaltyDevice = mAllocator->reMalloc(mPresencePenaltyDevice, sizeof(float) * mMaxBatchSize, false);
mFrequencyPenaltyDevice = mAllocator->reMalloc(mFrequencyPenaltyDevice, sizeof(float) * mMaxBatchSize, false);
mMinLengthDevice = mAllocator->reMalloc(mMinLengthDevice, sizeof(int32_t) * mMaxBatchSize, false);
mRuntimeLogitsDevice = mAllocator->reMalloc(
mRuntimeLogitsDevice, sizeof(T) * mMaxBatchSize * mMaxBeamWidth * mVocabSizePadded, false);
}
template <typename T>
void DynamicDecodeLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mAllocator->free((void**) &mZeroParentIdsDevice);
if (mPenaltyWorkspaceDevice != nullptr)
{
mAllocator->free((void**) &mPenaltyWorkspaceDevice);
}
if (mPenaltyWorkspacePrevDevice != nullptr)
{
mAllocator->free((void**) &mPenaltyWorkspacePrevDevice);
}
mAllocator->free((void**) (&mTemperatureDevice));
mAllocator->free((void**) (&mRepetitionPenaltyDevice));
mAllocator->free((void**) (&mPresencePenaltyDevice));
mAllocator->free((void**) (&mFrequencyPenaltyDevice));
mAllocator->free((void**) (&mMinLengthDevice));
mAllocator->free((void**) (&mRuntimeLogitsDevice));
}
template <typename T>
void DynamicDecodeLayer<T>::initializeLayers()
{
const size_t workspaceSize = sizeof(int) * mMaxBatchSize * mConfiguredBeamWidth * mVocabSize;
mPenaltyWorkspaceDevice = mAllocator->reMalloc(mPenaltyWorkspaceDevice, workspaceSize, false);
if (mDecodingMode.isTopKorTopP())
{
mSamplingLayer = std::make_unique<SamplingLayer<T>>(
mDecodingMode, mMaxBatchSize, mVocabSize, mVocabSizePadded, mStream, mAllocator, mCudaDeviceProp);
}
else if (mDecodingMode.isBeamSearch())
{
mOnlineBeamSearchDecode
= std::make_unique<OnlineBeamSearchLayer<T>>(mVocabSize, mVocabSizePadded, mStream, mAllocator);
mPenaltyWorkspacePrevDevice = mAllocator->reMalloc(mPenaltyWorkspacePrevDevice, workspaceSize, false);
}
else
{
TLLM_CHECK_WITH_INFO(false, "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch}");
}
}
template <typename T>
void DynamicDecodeLayer<T>::setup(
size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (mConfiguredBeamWidth == -1)
{
// This code path is kept only for the Python runtime.
// In the C++ runtime the given maxBeamWidth should always equal the runtime beamWidth.
TLLM_CHECK(mDecodingMode.isNone());
mConfiguredBeamWidth = beamWidth;
mDecodingMode = mConfiguredBeamWidth == 1 ? DecodingMode::TopKTopP() : DecodingMode::BeamSearch();
initializeLayers();
}
TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1)
|| (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth),
"Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth);
TLLM_CHECK_WITH_INFO(mConfiguredBeamWidth <= mMaxBeamWidth,
"Decoder is created with max beam width %lu, but %d was given", mMaxBeamWidth, mConfiguredBeamWidth);
setupLayers(batchSize, beamWidth, batchSlots, setupParams);
setupPenalties(batchSize, batchSlots, setupParams);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
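
A hedged construction sketch for the reworked interface: the decoding mode is now fixed at construction time (except for the Python-runtime fallback handled above), and `setup()` only validates the runtime beam width against it. The constructor signature, the `IAllocator` type, and the `DecodingMode::TopKTopP()` / `DecodingMode::BeamSearch()` factories come from this diff; the surrounding wiring is illustrative only.

```
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"

#include <memory>

// Illustrative wiring only; stream/allocator creation and SetupParams population
// are omitted, and DecodingMode's namespace is assumed to be visible here.
using namespace tensorrt_llm::layers;

template <typename T>
std::shared_ptr<DynamicDecodeLayer<T>> makeSamplingDecoder(size_t maxBatchSize, size_t vocabSize,
    size_t vocabSizePadded, cudaStream_t stream, std::shared_ptr<tensorrt_llm::common::IAllocator> allocator,
    cudaDeviceProp* deviceProp)
{
    // The mode is fixed at construction; DecodingMode::BeamSearch() would be used for maxBeamWidth > 1.
    return std::make_shared<DynamicDecodeLayer<T>>(DecodingMode::TopKTopP(), maxBatchSize,
        /*maxBeamWidth=*/1, vocabSize, vocabSizePadded, stream, std::move(allocator), deviceProp);
}
```

The subsequent `setup(batchSize, beamWidth, batchSlots, setupParams)` call then only configures per-request parameters, with `beamWidth == 1` for TopK/TopP and `1 < beamWidth <= maxBeamWidth` for beam search.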
template <typename T>
void DynamicDecodeLayer<T>::setupLayers(
size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (beamWidth == 1)
{ // sampling layers
TLLM_CHECK_WITH_INFO(
mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP");
typename TopPSamplingLayer<T>::SetupParams samplingParams;
samplingParams.runtime_top_k = setupParams.runtime_top_k;
samplingParams.runtime_top_p = setupParams.runtime_top_p;
@ -142,166 +241,162 @@ void DynamicDecodeLayer<T>::setup(
samplingParams.top_p_reset_ids = setupParams.top_p_reset_ids;
samplingParams.normalize_log_probs = setupParams.normalize_log_probs;
mTopKDecode->setup(batch_size, batch_slots, samplingParams);
mTopPDecode->setup(batch_size, batch_slots, samplingParams);
mSamplingLayer->setup(batchSize, batchSlots, samplingParams);
}
else
{ // beam search layer
TLLM_CHECK_WITH_INFO(
mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch");
typename OnlineBeamSearchLayer<T>::SetupParams beamSearchParams;
beamSearchParams.temperature = setupParams.temperature;
beamSearchParams.min_length = setupParams.min_length;
beamSearchParams.repetition_penalty = setupParams.repetition_penalty;
beamSearchParams.presence_penalty = setupParams.presence_penalty;
beamSearchParams.frequency_penalty = setupParams.frequency_penalty;
beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate;
beamSearchParams.length_penalty = setupParams.length_penalty;
has_diff_runtime_args_ = hasDiffRuntimeArgs(beamSearchParams);
mOnlineBeamsearchDecode->setup(batch_size, beamSearchParams);
mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams);
mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len)
void DynamicDecodeLayer<T>::setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mIdsPtrHost->resize(2 * batch_size);
zero_parent_ids = mAllocator->reMalloc(zero_parent_ids, sizeof(int*) * 2 * batch_size, false);
const size_t workspace_size = sizeof(int) * batch_size * beam_width * vocab_size_;
if (beam_width == 1)
{ // sampling layers
top_k_workspace = mAllocator->reMalloc(top_k_workspace, workspace_size, false);
top_p_workspace = mAllocator->reMalloc(top_p_workspace, workspace_size, false);
}
else
{ // beam search layer
beam_search_workspace_0 = mAllocator->reMalloc(beam_search_workspace_0, workspace_size, false);
beam_search_workspace_1 = mAllocator->reMalloc(beam_search_workspace_1, workspace_size, false);
}
mCyclicStep = 0;
}
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
std::vector<int32_t> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost = batchSlots ? batchSlots : batchSlotsVec.data();
template <typename T>
void DynamicDecodeLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mAllocator->free((void**) &zero_parent_ids);
if (top_k_workspace != nullptr)
// Setup penalties.
FillBuffers const fillBuffers{batchSize, mMaxBatchSize, mStream};
mUseTemperature = static_cast<bool>(setupParams.temperature);
mUseRepetitionPenalty = static_cast<bool>(setupParams.repetition_penalty);
mUsePresencePenalty = static_cast<bool>(setupParams.presence_penalty);
mUseFrequencyPenalty = static_cast<bool>(setupParams.frequency_penalty);
mUseMinLength = static_cast<bool>(setupParams.min_length);
if (mUseTemperature)
{
mAllocator->free((void**) &top_k_workspace);
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(DecodingPenaltyType::Temperature), mTemperature,
mTemperatureDevice, batchSlotsHost);
}
if (top_p_workspace != nullptr)
if (mUseRepetitionPenalty)
{
mAllocator->free((void**) &top_p_workspace);
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Repetition),
mRepetitionPenalty, mRepetitionPenaltyDevice, batchSlotsHost);
}
if (beam_search_workspace_0 != nullptr)
if (mUsePresencePenalty)
{
mAllocator->free((void**) &beam_search_workspace_0);
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Presence),
mPresencePenalty, mPresencePenaltyDevice, batchSlotsHost);
}
if (beam_search_workspace_1 != nullptr)
if (mUseFrequencyPenalty)
{
mAllocator->free((void**) &beam_search_workspace_1);
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Frequency),
mFrequencyPenalty, mFrequencyPenaltyDevice, batchSlotsHost);
}
if (mUseMinLength)
{
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(DecodingPenaltyType::MinLength), mMinLength,
mMinLengthDevice, batchSlotsHost);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const& params)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
const auto ite = params.ite;
const auto step = params.step;
const auto max_step = step;
auto const& logits = params.logits;
TLLM_CHECK(logits.shape.size() == 3);
auto const batch_size = logits.shape[0];
auto const beam_width = logits.shape[1];
auto const local_batch_size = static_cast<std::size_t>(params.local_batch_size);
auto const max_seq_len = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
TLLM_CHECK_WITH_INFO(params.logits || params.logits_vec, "Either logits or logits_vec have to be specified.");
TLLM_CHECK_WITH_INFO(
outputs.sequence_length.has_value(), "sequence_length tensor is mandatory in DynamicDecoderLayer.");
allocateBuffer(batch_size, beam_width, max_seq_len);
// std::vector<int*> ids_ptr_host;
auto idsPtrHost = runtime::bufferCast<int*>(*mIdsPtrHost);
for (int i = 0; i < batch_size; i++)
size_t batchSize = 0;
size_t beamWidth = 0;
size_t vocabSize = 0;
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
if (params.logits)
{
idsPtrHost[i] = outputs.output_ids.template getPtrWithOffset<int32_t>(i * beam_width * max_seq_len);
auto const& logitsShape = params.logits->shape;
TLLM_CHECK(logitsShape.size() == 3);
batchSize = logitsShape[0];
beamWidth = logitsShape[1];
vocabSize = logitsShape[2];
}
for (int i = 0; i < batch_size; i++)
else
{
if (beam_width > 1)
{
idsPtrHost[batch_size + i]
= outputs.parent_ids.value().template getPtrWithOffset<int32_t>(i * beam_width * max_seq_len);
}
else
{
idsPtrHost[batch_size + i] = zero_parent_ids + i * beam_width * max_seq_len;
}
TLLM_CHECK(params.logits_vec->size());
auto const& logitsShape = params.logits_vec.value()[0].shape;
TLLM_CHECK(logitsShape.size() == 3);
batchSize = params.logits_vec->size();
beamWidth = logitsShape[1];
vocabSize = logitsShape[2];
}
outputs.output_ids_ptr
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {batch_size, beam_width, max_seq_len}, idsPtrHost);
outputs.parent_ids_ptr
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {batch_size, beam_width, max_seq_len}, idsPtrHost + batch_size);
TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1)
|| (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth),
"Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth);
if (params.no_repeat_ngram_size)
if (!mLogitsPtrsHost->data())
{
const int* no_repeat_ngram_size_buf = params.no_repeat_ngram_size.value().template getPtr<const int>();
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<const int*>(),
reinterpret_cast<FinishedState*>(
params.finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
outputs.parent_ids_ptr.template getPtr<const int*>(), nullptr,
outputs.sequence_length->template getPtr<int>(), batch_size, local_batch_size, beam_width, max_seq_len,
no_repeat_ngram_size_buf, vocab_size_padded_, max_step, mStream);
mLogitsPtrsHost = runtime::BufferManager::pinnedPool(
ITensor::makeShape({static_cast<int32_t>(maxSeqLen), static_cast<int32_t>(mMaxBatchSize)}),
runtime::TRTDataType<T*>::value);
mIdsPtrHost = runtime::BufferManager::pinnedPool(
ITensor::makeShape({static_cast<int32_t>(maxSeqLen), static_cast<int32_t>(2 * mMaxBatchSize)}),
runtime::TRTDataType<int32_t*>::value);
mRuntimeMaxSeqLen = maxSeqLen;
}
if (params.bad_words_list)
{
const auto& bad_words = params.bad_words_list.value();
const int* bad_words_ptr = bad_words.template getPtr<const int>();
TLLM_CHECK_WITH_INFO(
bad_words.shape.size() == 2 || bad_words.shape.size() == 3, "Bad words dimension must be 2 or 3.");
std::vector<int32_t> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost = params.batch_slots ? params.batch_slots->template getPtr<const int>() : batchSlotsVec.data();
auto batchSlots = params.batch_slots ? params.batch_slots->template getPtr<const int>() : nullptr;
const bool is_matrix = bad_words.shape.size() == 2;
if (bad_words.shape.size() == 3)
{
TLLM_CHECK_WITH_INFO(bad_words.shape[0] == batch_size,
fmtstr("Shape of dim 0 of bad words is invalid. It "
"must be equal to batch size."
" However, it is %ld and the batch size is %ld.",
bad_words.shape[0], batch_size));
}
mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;
prepareIdsPtrs(outputs, batchSlotsHost, batchSize, beamWidth, maxSeqLen);
const bool shared_bad_words = is_matrix || bad_words.shape[0] == 1;
const size_t bad_words_len = bad_words.shape[is_matrix ? 1 : 2];
// Add check on batch size of bad words
const int id_offset = ite * local_batch_size;
const int decode_vocab_size_units_offset = id_offset * vocab_size_padded_;
auto logits = Tensor(MEMORY_GPU, std::is_same_v<T, float> ? DataType::TYPE_FP32 : DataType::TYPE_FP16,
{batchSize, beamWidth, mVocabSizePadded}, mRuntimeLogitsDevice);
invokeBanBadWords((T*) logits.getPtrWithOffset(decode_vocab_size_units_offset),
outputs.output_ids_ptr.template getPtr<const int*>(),
beam_width > 1 ? outputs.parent_ids_ptr.template getPtr<const int*>() : nullptr, nullptr, batch_size,
local_batch_size, beam_width,
shared_bad_words
? bad_words_ptr
: bad_words.template getPtrWithOffset<const int>(ite * local_batch_size * 2 * bad_words_len),
shared_bad_words, bad_words_len, vocab_size_padded_, outputs.sequence_length->template getPtr<int>(),
max_seq_len, mStream);
}
// Apply penalties
applyPenalties(outputs, params, batchSlotsHost, batchSlots, batchSize, beamWidth, maxSeqLen);
// Ban bad words and NGrams
banWords(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mVocabSizePadded, mStream);
// Main function that calls forward of the respective layers
layersForward(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen);
// Check if stop conditions are met
checkStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mStream);
// Copy nextIds and transpose logits when needed
prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, batchSize, beamWidth, maxSeqLen, mCyclicStep, mStream);
mCyclicStep += 1;
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
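
forward() now defers the pointer bookkeeping to prepareIdsPtrs(), but the underlying pattern is the one the removed loop spelled out: build a host-side array of per-request row pointers into the flat [batchSize, beamWidth, maxSeqLen] ids buffer so downstream kernels can address requests through output_ids_ptr / parent_ids_ptr. A minimal sketch of that pattern, assuming contiguous row-major storage:

```
#include <cstdint>
#include <vector>

// Build per-request row pointers into a flat [batchSize, beamWidth, maxSeqLen]
// buffer, mirroring the layout the removed idsPtrHost loop used.
std::vector<int32_t*> buildIdsPtrs(int32_t* outputIds, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
{
    std::vector<int32_t*> idsPtrs(batchSize);
    for (size_t bi = 0; bi < batchSize; ++bi)
    {
        idsPtrs[bi] = outputIds + bi * beamWidth * maxSeqLen; // request bi starts here
    }
    return idsPtrs;
}
```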
template <typename T>
void DynamicDecodeLayer<T>::layersForward(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const ite = params.ite;
auto const step = params.step;
// common inputs
auto const& end_ids = params.end_ids;
auto const& endIds = params.end_ids;
auto const localBatchSize = static_cast<std::size_t>(params.local_batch_size);
// dynamic decode GPT
if (beam_width > 1)
if (beamWidth > 1)
{
TLLM_CHECK_WITH_INFO(
mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch");
TLLM_CHECK_WITH_INFO(
params.src_cache_indirection.has_value(), "src_cache_indirection is mandatory in beam search.");
TLLM_CHECK_WITH_INFO(
@ -312,29 +407,26 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
// Batch beam search is not supported yet, so we compute the requests one by
// one when they have different runtime arguments.
const size_t dynamic_decode_batch_size = has_diff_runtime_args_ ? 1 : local_batch_size;
const int dynamic_decode_total_iteration = local_batch_size / dynamic_decode_batch_size;
const size_t dynamic_decode_batch_size = mHasDiffRuntimeArgs ? 1 : localBatchSize;
const int dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size;
for (uint32_t dynamic_ite = ite * dynamic_decode_total_iteration;
dynamic_ite < (ite + 1) * dynamic_decode_total_iteration; ++dynamic_ite)
for (uint32_t dynamic_ite = 0; dynamic_ite < dynamic_decode_total_iteration; ++dynamic_ite)
{
const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beam_width;
const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * vocab_size_padded_;
const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth;
const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded;
auto const logits_offset = logits.slice(
{dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, dynamic_decode_vocab_size_units_offset);
auto const end_id_offset
= end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
= endIds.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
typename BaseBeamSearchLayer<T>::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset,
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_attention_window),
static_cast<std::int32_t>(params.sink_token_length), static_cast<std::int32_t>(max_seq_len)};
dynamic_decode_input_tensors.embedding_bias = params.embedding_bias;
static_cast<std::int32_t>(params.sink_token_length), static_cast<std::int32_t>(maxSeqLen)};
if (params.input_lengths)
{
dynamic_decode_input_tensors.input_lengths
= params.input_lengths->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
= params.input_lengths->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
}
// common outputs
@ -345,122 +437,286 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
dynamic_decode_outputs.parent_ids_ptr = std::move(outputs.parent_ids_ptr);
dynamic_decode_outputs.sequence_length
= outputs.sequence_length->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
= outputs.sequence_length->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
dynamic_decode_outputs.finished
= outputs.finished->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
= outputs.finished->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
dynamic_decode_outputs.cum_log_probs
= outputs.cum_log_probs->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
= outputs.cum_log_probs->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
dynamic_decode_outputs.beamHypotheses = outputs.beamHypotheses;
dynamic_decode_outputs.output_log_probs = outputs.output_log_probs_tiled;
// Only OnlineBeamSearchLayer supports beam_search_diversity_rate
// when beamHypotheses is used.
mOnlineBeamsearchDecode->forward(
dynamic_decode_outputs, dynamic_decode_input_tensors, beam_search_workspace_0, beam_search_workspace_1);
std::swap(beam_search_workspace_0, beam_search_workspace_1);
mOnlineBeamSearchDecode->forward(dynamic_decode_outputs, dynamic_decode_input_tensors);
} // end of dynamic_ite
std::swap(mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice);
}
else
{ // beam_width == 1
{ // beamWidth == 1
TLLM_CHECK_WITH_INFO(
mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP");
// Sampling supports batched execution, so we always compute all sequences
// at once.
const size_t local_batch_offset = ite * local_batch_size * beam_width;
Tensor const logits_slice{
logits.slice({local_batch_size, beam_width, logits.shape[2]}, local_batch_offset * logits.shape[2])};
Tensor const end_id_slice{end_ids.slice({local_batch_size}, ite * local_batch_size)};
Tensor const logits_slice{logits.slice({localBatchSize, beamWidth, logits.shape[2]}, 0)};
Tensor const end_id_slice{endIds.slice({localBatchSize}, 0)};
typename BaseSamplingLayer<T>::ForwardParams decode_input_tensors{
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(max_seq_len)};
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(maxSeqLen)};
decode_input_tensors.embedding_bias = params.embedding_bias;
decode_input_tensors.finished = params.finished;
if (params.input_lengths)
{
auto& input_lengths = params.input_lengths.value();
decode_input_tensors.input_lengths
= input_lengths.slice({local_batch_size, beam_width}, local_batch_offset);
decode_input_tensors.input_lengths = input_lengths.slice({localBatchSize, beamWidth}, 0);
}
decode_input_tensors.batch_slots = params.batch_slots;
DecodingOutputParams decode_outputs(outputs.output_ids);
decode_outputs.output_ids_ptr = std::move(outputs.output_ids_ptr);
if (outputs.sequence_length)
{
decode_outputs.sequence_length
= outputs.sequence_length->slice({local_batch_size * beam_width}, local_batch_offset);
decode_outputs.sequence_length = outputs.sequence_length->slice({localBatchSize * beamWidth}, 0);
}
if (outputs.finished)
{
decode_outputs.finished = outputs.finished->slice({local_batch_size * beam_width}, local_batch_offset);
decode_outputs.finished = outputs.finished->slice({localBatchSize * beamWidth}, 0);
}
if (outputs.cum_log_probs)
{
decode_outputs.cum_log_probs
= outputs.cum_log_probs->slice({local_batch_size * beam_width}, local_batch_offset);
decode_outputs.cum_log_probs = outputs.cum_log_probs->slice({localBatchSize * beamWidth}, 0);
}
if (outputs.output_log_probs_tiled)
{
TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < max_seq_len);
TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < maxSeqLen);
Tensor& output_log_probs = outputs.output_log_probs_tiled.value();
size_t step_offset = mCyclicStep * batch_size * beam_width;
decode_outputs.output_log_probs
= output_log_probs.slice({1, local_batch_size * beam_width}, step_offset + local_batch_offset);
size_t step_offset = mCyclicStep * batchSize * beamWidth;
decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, step_offset);
}
// Run topk / topp decode layers.
// Batch sampling is supported. If the runtime arguments are, for example,
// topk = [4, 0, 4] and topp = [0.0, 0.5, 0.5],
// then topk_decode handles [4, x, 4 + 0.5] and
// topp_decode handles [x, 0.5, x],
// where entries marked "x" are skipped.
mTopKDecode->forward(decode_outputs, decode_input_tensors, top_k_workspace);
mTopPDecode->forward(decode_outputs, decode_input_tensors, top_p_workspace);
// Run TopK + TopP decode layers.
mSamplingLayer->forward(decode_outputs, decode_input_tensors);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
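
The unified SamplingLayer replaces the separate TopK/TopP forward calls, but the partitioning described by the old comment still applies conceptually: requests with top_k > 0 are handled by TopK sampling and the remaining ones by TopP, with each sub-layer skipping the slots it does not own. A minimal sketch of that skip-mask partition (not the actual SamplingLayer implementation):

```
#include <cstdint>
#include <vector>

// Partition requests between TopK and TopP sampling, mirroring the example from
// the original comment: top_k = {4, 0, 4}, top_p = {0.0, 0.5, 0.5}
// -> TopK handles requests 0 and 2, TopP handles request 1.
void buildSkipMasks(std::vector<int32_t> const& topK, std::vector<bool>& skipTopK, std::vector<bool>& skipTopP)
{
    size_t const batchSize = topK.size();
    skipTopK.assign(batchSize, false);
    skipTopP.assign(batchSize, false);
    for (size_t bi = 0; bi < batchSize; ++bi)
    {
        bool const useTopK = topK[bi] > 0;
        skipTopK[bi] = !useTopK; // the TopK layer skips pure TopP requests
        skipTopP[bi] = useTopK;  // the TopP layer skips requests TopK already handled
    }
}
```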
if (params.stop_words_list)
template <typename T>
void DynamicDecodeLayer<T>::applyPenalties(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlotsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto logitsPtrsHost = ITensor::slice(mLogitsPtrsHost, mCyclicStep, 1);
auto logitsPtrsHostData = reinterpret_cast<T const**>(runtime::bufferCast<int64_t>(*logitsPtrsHost));
for (int32_t bi = 0; bi < batchSize; bi++)
{
const size_t id_offset = ite * local_batch_size * beam_width;
const size_t stop_words_length = params.stop_words_list->shape[2];
invokeStopWordsCriterion(outputs.output_ids_ptr.template getPtr<const int*>(),
outputs.parent_ids_ptr.template getPtr<const int*>(),
params.stop_words_list->template getPtrWithOffset<const int>(
ite * local_batch_size * 2 * stop_words_length),
reinterpret_cast<FinishedState*>(
outputs.finished->template getPtrWithOffset<FinishedState::UnderlyingType>(id_offset)),
outputs.sequence_length->template getPtr<int>(), nullptr, stop_words_length, batch_size, beam_width,
max_seq_len, mStream);
if (params.logits_vec)
{
TLLM_CHECK_WITH_INFO(params.logits_vec->size() == batchSize,
"Logits vector size (%lu) is not equal to the batchSize (%lu)", params.logits_vec->size(), batchSize);
logitsPtrsHostData[bi] = params.logits_vec.value()[bi].template getPtr<T>();
}
else
{
logitsPtrsHostData[bi] = params.logits->template getPtrWithOffset<T>(bi * beamWidth * mVocabSizePadded);
}
}
int32_t const* inputLengths = nullptr;
if (params.input_lengths)
{
auto& input_lengths = params.input_lengths.value();
inputLengths = input_lengths.template getPtr<const int>();
}
auto* embeddingBias = params.embedding_bias ? params.embedding_bias->template getPtr<T const>() : nullptr;
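// GET_PENALTIES yields the device buffer for a penalty only when that penalty is enabled and at
// least one active batch slot differs from the default value returned by getDefaultPenaltyValue();
// otherwise it yields nullptr so the penalty kernel can skip that penalty entirely.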
#define GET_PENALTIES(capital_name, penalty_name, type) \
(mUse##capital_name \
&& !allOfBatchSlots(batchSlotsHost, m##capital_name.data(), batchSize, \
static_cast<type>(getDefaultPenaltyValue(DecodingPenaltyType::penalty_name)))) \
? m##capital_name##Device \
: nullptr;
auto* temperatures = GET_PENALTIES(Temperature, Temperature, float);
auto* repetitionPenalties = GET_PENALTIES(RepetitionPenalty, Repetition, float);
auto* presencePenalties = GET_PENALTIES(PresencePenalty, Presence, float);
auto* frequencyPenalties = GET_PENALTIES(FrequencyPenalty, Frequency, float);
auto* minLengths = GET_PENALTIES(MinLength, MinLength, int32_t);
#undef GET_PENALTIES
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T const* const*>(logitsPtrsHostData),
mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures,
repetitionPenalties, presencePenalties, frequencyPenalties,
(mUseRepetitionPenalty || mUsePresencePenalty || mUseFrequencyPenalty), batchSize,
static_cast<int32_t>(beamWidth), static_cast<int32_t>(maxSeqLen), mVocabSize, mVocabSizePadded,
outputs.output_ids_ptr.template getPtr<const int*>(), outputs.parent_ids_ptr.template getPtr<const int*>(),
inputLengths, outputs.sequence_length->template getPtr<const int>(), minLengths,
params.end_ids.template getPtr<const int>(), batchSlots, mStream};
invokeBatchApplyPenalty(penaltyParams);
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::banWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
banRepeatNGrams(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream);
banBadWords(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::banRepeatNGrams(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const max_step = params.step;
if (params.no_repeat_ngram_size)
{
const int* noRepeatNgramSizeBuf = params.no_repeat_ngram_size.value().template getPtr<const int>();
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<const int*>(),
reinterpret_cast<FinishedState*>(
params.finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
outputs.parent_ids_ptr.template getPtr<const int*>(), batchSlots,
outputs.sequence_length->template getPtr<int>(), batchSize, beamWidth, maxSeqLen,
params.no_repeat_ngram_size.value().template getPtr<const int>(), vocabSizePadded, max_step, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::banBadWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxBadWordsLength = params.max_bad_words_len;
if (maxBadWordsLength)
{
int32_t const** badWordsPtr = params.bad_words_ptr->template getPtr<int32_t const*>();
int32_t const* badWordsLens = params.bad_words_lengths->template getPtr<int32_t>();
invokeBanBadWords((T*) logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<int32_t const*>(),
beamWidth > 1 ? outputs.parent_ids_ptr.template getPtr<int32_t const*>() : nullptr, batchSlots, batchSize,
beamWidth, badWordsPtr, badWordsLens, maxBadWordsLength, vocabSizePadded,
outputs.sequence_length->template getPtr<int>(), maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::checkStopCriteria(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
checkStopWordsStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream);
checkMaxLengthStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const maxStopWordsLength = params.max_stop_words_len;
if (maxStopWordsLength)
{
invokeStopWordsCriterion(outputs.output_ids_ptr.template getPtr<int32_t const*>(),
outputs.parent_ids_ptr.template getPtr<int32_t const*>(),
params.stop_words_ptr->template getPtr<int32_t const*>(),
reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>()),
outputs.sequence_length->template getPtr<int32_t>(), batchSlots,
params.stop_words_lengths->template getPtr<int32_t const>(), maxStopWordsLength, batchSize, beamWidth,
maxSeqLen, stream);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
if (params.sequence_limit_length)
{
invokeLengthCriterion(
reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>()),
outputs.finished_sum ? outputs.finished_sum->template getPtr<int>() : nullptr,
params.sequence_limit_length->template getPtr<const uint32_t>(),
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, mStream);
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, stream);
sync_check_cuda_error();
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::prepareIdsPtrs(
OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(mIdsPtrHost, mCyclicStep, 1);
auto idsPtrHost = reinterpret_cast<int32_t**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
for (int bi = 0; bi < batchSize; bi++)
{
auto const batchSlot = batchSlots[bi];
idsPtrHost[batchSlot]
= outputs.output_ids.template getPtrWithOffset<int32_t>(batchSlot * beamWidth * maxSeqLen);
}
for (int bi = 0; bi < batchSize; bi++)
{
auto const batchSlot = batchSlots[bi];
if (beamWidth > 1)
{
idsPtrHost[mMaxBatchSize + batchSlot]
= outputs.parent_ids.value().template getPtrWithOffset<int32_t>(bi * beamWidth * maxSeqLen);
}
else
{
idsPtrHost[mMaxBatchSize + batchSlot] = mZeroParentIdsDevice + bi * beamWidth * maxSeqLen;
}
}
outputs.output_ids_ptr
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost);
outputs.parent_ids_ptr = Tensor(
MEMORY_GPU, DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost + mMaxBatchSize);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template <typename T>
void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardParams const& params,
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
auto idsPtrHost = reinterpret_cast<int32_t**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
invokeCopyNextStepIds(outputs.newTokens.template getPtr<int>(), idsPtrHost,
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, max_seq_len, mStream);
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, maxSeqLen, stream);
// Transpose the output log probs from [max_seq_len, bs, beam_width] to [batch_size, beam_width, max_seq_len]
// Transpose the output log probs from [maxSeqLen, bs, beamWidth] to [batchSize, beamWidth, maxSeqLen]
if (outputs.output_log_probs_tiled)
{
auto logProbsMaxSeqLen = outputs.output_log_probs_tiled.value().shape[0];
invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr<float>(),
outputs.output_log_probs_tiled.value().template getPtr<float>(),
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, logProbsMaxSeqLen,
mStream);
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, logProbsMaxSeqLen,
stream);
}
mCyclicStep += 1;
mCyclicStep = mCyclicStep % max_seq_len;
sync_check_cuda_error();
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
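// Plain-CPU illustration of the transpose above, showing the index mapping only; the real kernel
// also honours batch slots and runs asynchronously on the stream. transposeLogProbsHost is a
// hypothetical helper, not part of this change.
// Layout: [maxSeqLen, batchSize, beamWidth] -> [batchSize, beamWidth, maxSeqLen]
inline void transposeLogProbsHost(float* dst, float const* src, int const* sequenceLength, size_t batchSize,
    size_t beamWidth, size_t maxSeqLen)
{
    for (size_t bi = 0; bi < batchSize; ++bi)
    {
        for (size_t bw = 0; bw < beamWidth; ++bw)
        {
            for (size_t step = 0; step < static_cast<size_t>(sequenceLength[bi]); ++step)
            {
                dst[(bi * beamWidth + bw) * maxSeqLen + step] = src[(step * batchSize + bi) * beamWidth + bw];
            }
        }
    }
}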
template class DynamicDecodeLayer<float>;

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2022-2022, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -20,9 +20,9 @@
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
#include "tensorrt_llm/layers/baseLayer.h"
#include "tensorrt_llm/layers/onlineBeamSearchLayer.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/layers/samplingLayer.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include <optional>
@ -41,12 +41,14 @@ struct BeamHypotheses;
namespace layers
{
template <typename T>
class DynamicDecodeLayer : public BaseLayer
{
public:
DynamicDecodeLayer(size_t max_batch_size, size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
std::shared_ptr<tc::IAllocator> allocator, cudaDeviceProp* cuda_device_prop);
DynamicDecodeLayer(runtime::DecodingMode const& mode, size_t max_batch_size, size_t max_beam_width,
size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr<tc::IAllocator> allocator,
cudaDeviceProp* cuda_device_prop);
~DynamicDecodeLayer() override;
DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer);
@ -83,14 +85,15 @@ public:
{
public:
ForwardParams(int step, int ite, int maxInputLength, int maxAttentionWindow, int sinkTokenLength,
int localBatchSize, tc::Tensor logits, tc::Tensor endIds)
int localBatchSize, tc::Tensor endIds)
: step{step}
, ite{ite}
, max_input_length{maxInputLength}
, max_attention_window{maxAttentionWindow}
, sink_token_length{sinkTokenLength}
, local_batch_size{localBatchSize}
, logits{std::move(logits)}
, max_stop_words_len{0}
, max_bad_words_len{0}
, end_ids{std::move(endIds)}
{
}
@ -102,9 +105,16 @@ public:
int max_attention_window;
int sink_token_length;
int local_batch_size;
tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu
int max_stop_words_len;
int max_bad_words_len;
tc::Tensor end_ids; // [batch_size], on gpu
// One of these two fields has to be set
// DynamicDecodeLayer::forward checks for it
// Need both of these fields to support legacy code during transition period to the batched decoder
std::optional<tc::Tensor> logits; // [batch_size, beam_width, vocab_size_padded], on gpu
std::optional<std::vector<tc::Tensor>> logits_vec; // [batch_size], on gpu
// optional parameters
std::optional<tc::Tensor> finished; // [batch_size * beam_width], optional
std::optional<tc::Tensor> src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - the k/v cache
@ -112,9 +122,12 @@ public:
std::optional<tc::Tensor> sequence_limit_length; // [batch_size], on gpu
std::optional<tc::Tensor> embedding_bias; // [vocab_size_padded], on gpu
std::optional<tc::Tensor> input_lengths; // [batch_size, beam_width], on gpu
std::optional<tc::Tensor> bad_words_list; // [2, bad_words_length] or [batch_size, 2, bad_words_length], on gpu
std::optional<tc::Tensor> stop_words_list; // [batch_size, 2, stop_words_length], on gpu
std::optional<tc::Tensor> bad_words_ptr; // [2, bad_words_length] or [batch_size, 2, bad_words_length], on gpu
std::optional<tc::Tensor> bad_words_lengths; // [batch_size], on gpu
std::optional<tc::Tensor> stop_words_ptr; // [batch_size][2, stop_words_length], on gpu
std::optional<tc::Tensor> stop_words_lengths; // [batch_size], on gpu
std::optional<tc::Tensor> no_repeat_ngram_size; // [batch_size], optional
std::optional<tc::Tensor> batch_slots; // [batch_size], optional, in pinned memory
};
class OutputParams
@ -148,31 +161,88 @@ public:
};
void forward(OutputParams& outputs, ForwardParams const& params);
void allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len);
void allocateBuffer();
void freeBuffer();
private:
void initialize();
void initializeLayers();
std::unique_ptr<OnlineBeamSearchLayer<T>> mOnlineBeamsearchDecode;
std::unique_ptr<TopKSamplingLayer<T>> mTopKDecode;
std::unique_ptr<TopPSamplingLayer<T>> mTopPDecode;
void setupLayers(size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams);
void setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams);
size_t max_batch_size_;
size_t vocab_size_;
size_t vocab_size_padded_;
cudaDeviceProp* cuda_device_prop_;
int* zero_parent_ids = nullptr;
int* top_k_workspace = nullptr;
int* top_p_workspace = nullptr;
int* beam_search_workspace_0 = nullptr;
int* beam_search_workspace_1 = nullptr;
runtime::IBuffer::SharedPtr mIdsPtrHost;
void layersForward(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
bool has_diff_runtime_args_ = false;
void applyPenalties(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlotsHost,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
static void banWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream);
static void banRepeatNGrams(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream);
static void banBadWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
cudaStream_t stream);
static void checkStopCriteria(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlots,
size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
static void checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
static void checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params,
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
void prepareIdsPtrs(
OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
static void prepareOutputData(OutputParams& outputs, ForwardParams const& params,
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream);
private:
std::unique_ptr<OnlineBeamSearchLayer<T>> mOnlineBeamSearchDecode;
std::unique_ptr<SamplingLayer<T>> mSamplingLayer;
runtime::DecodingMode mDecodingMode;
size_t mMaxBatchSize;
size_t mMaxBeamWidth;
size_t mVocabSize;
size_t mVocabSizePadded;
cudaDeviceProp* mCudaDeviceProp;
int32_t* mZeroParentIdsDevice = nullptr;
int32_t* mPenaltyWorkspaceDevice = nullptr;
int32_t* mPenaltyWorkspacePrevDevice = nullptr;
runtime::ITensor::SharedPtr mIdsPtrHost;
runtime::ITensor::SharedPtr mLogitsPtrsHost;
float* mTemperatureDevice = nullptr;
float* mRepetitionPenaltyDevice = nullptr;
float* mPresencePenaltyDevice = nullptr;
float* mFrequencyPenaltyDevice = nullptr;
int32_t* mMinLengthDevice = nullptr;
T* mRuntimeLogitsDevice = nullptr;
std::vector<float> mTemperature;
std::vector<float> mRepetitionPenalty;
std::vector<float> mPresencePenalty;
std::vector<float> mFrequencyPenalty;
std::vector<int32_t> mMinLength;
bool mUseTemperature = false;
bool mUseRepetitionPenalty = false;
bool mUsePresencePenalty = false;
bool mUseFrequencyPenalty = false;
bool mUseMinLength = false;
bool mHasDiffRuntimeArgs = false;
int* h_pinned_finished_sum_ = nullptr;
int mCyclicStep = 0;
int32_t mCyclicStep = 0;
int32_t mRuntimeMaxSeqLen = 0;
int32_t mConfiguredBeamWidth = -1;
};
} // namespace layers

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -39,29 +39,31 @@ struct FillBuffers
template <typename T>
void operator()(std::optional<std::vector<T>> const& optParam, T const defaultValue, std::vector<T>& hostBuffer,
T* deviceBuffer, void* deviceTmpBuffer, const int* batchSlots) const
T* deviceBuffer, int32_t const* batchSlots) const
{
using tensorrt_llm::common::cudaAutoCpy;
hostBuffer.resize(batchSize);
if (!optParam)
for (size_t bi = 0; bi < batchSize; ++bi)
{
std::fill(std::begin(hostBuffer), std::end(hostBuffer), defaultValue);
auto const batchSlot = batchSlots ? batchSlots[bi] : bi;
if (!optParam)
{
hostBuffer[batchSlot] = defaultValue;
}
else if (optParam->size() == 1)
{
hostBuffer[batchSlot] = optParam->front();
}
else
{
TLLM_CHECK_WITH_INFO(optParam->size() == batchSize, "Argument vector size mismatch.");
hostBuffer[batchSlot] = optParam.value()[bi];
}
}
else if (optParam->size() == 1)
if (batchSlots)
{
std::fill(std::begin(hostBuffer), std::end(hostBuffer), optParam->front());
}
else
{
TLLM_CHECK_WITH_INFO(optParam->size() == batchSize, "Argument vector size mismatch.");
std::copy(optParam->begin(), optParam->end(), std::begin(hostBuffer));
}
if (deviceTmpBuffer && batchSlots)
{
cudaAutoCpy(reinterpret_cast<T*>(deviceTmpBuffer), hostBuffer.data(), batchSize, stream);
tensorrt_llm::kernels::invokeScatterDecodingParams(
reinterpret_cast<T*>(deviceTmpBuffer), deviceBuffer, batchSlots, batchSize, stream);
cudaAutoCpy(deviceBuffer, hostBuffer.data(), maxBatchSize, stream);
}
else
{
@ -70,6 +72,7 @@ struct FillBuffers
}
size_t batchSize;
size_t maxBatchSize;
cudaStream_t stream;
};

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -101,15 +101,13 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
BaseBeamSearchLayer<T>::setupBase(batch_size, setupParams);
allocateBuffer(batch_size);
mDiversityRate = setupParams.beam_search_diversity_rate.value_or(std::vector<float>(0.0f));
mLengthPenalty = setupParams.length_penalty.value_or(std::vector<float>(0.0f));
mDiversityRate.resize(batch_size);
mLengthPenalty.resize(batch_size);
FillBuffers const fillBuffers{batch_size, mStream};
FillBuffers const fillBuffers{batch_size, batch_size, mStream};
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (float*) nullptr,
(int*) nullptr);
fillBuffers(
setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (float*) nullptr, (int*) nullptr);
fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr);
fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

View File

@ -0,0 +1,215 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/layers/samplingLayer.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/common/memoryUtils.h"
#include "tensorrt_llm/kernels/decodingCommon.h"
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
#include <algorithm>
using namespace tensorrt_llm::common;
using namespace tensorrt_llm::kernels;
using namespace tensorrt_llm::runtime;
namespace tensorrt_llm
{
namespace layers
{
template <typename T>
void SamplingLayer<T>::allocateBuffer(size_t batchSize)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mSamplingWorkspaceSize = 0;
if (mDecodingMode.isTopK())
{
mSamplingWorkspaceSize = std::max(mSamplingWorkspaceSize, mTopKDecode->getWorkspaceSize());
}
if (mDecodingMode.isTopP())
{
mSamplingWorkspaceSize = std::max(mSamplingWorkspaceSize, mTopPDecode->getWorkspaceSize());
}
std::array<size_t, 4> deviceBufferSizes;
deviceBufferSizes[0] = sizeof(curandState_t) * batchSize;
deviceBufferSizes[1] = sizeof(uint64_t) * batchSize;
deviceBufferSizes[2] = sizeof(bool) * batchSize;
deviceBufferSizes[3] = mSamplingWorkspaceSize;
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[2], false);
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[3], false);
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), static_cast<size_t>(0));
TLLM_LOG_DEBUG("SamplingLayer allocated %lu bytes on GPU", bytesAllocated);
mAllocatedSize = bytesAllocated;
if (mDecodingMode.isTopK())
{
mAllocatedSize += mTopKDecode->getAllocatedSize();
}
if (mDecodingMode.isTopP())
{
mAllocatedSize += mTopPDecode->getAllocatedSize();
}
// host buffers.
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
TLLM_CHECK(mSkipDecodeHost != nullptr);
}
template <typename T>
void SamplingLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
mAllocator->free((void**) (&mCurandStatesDevice));
mAllocator->free((void**) (&mRandomSeedsDevice));
mAllocator->free((void**) (&mSkipDecodeDevice));
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
std::free(mSkipDecodeHost);
}
template <typename T>
SamplingLayer<T>::SamplingLayer(DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop)
: BaseSamplingLayer<T>(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr)
, mDecodingMode(mode)
{
TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "Beam search mode has been requested from Sampling Layer");
TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopP(), "Requested mode is neither TopK nor TopP");
if (mDecodingMode.isTopK())
{
mTopKDecode
= std::make_unique<TopKSamplingLayer<T>>(maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator);
}
if (mDecodingMode.isTopP())
{
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(
maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator, prop, /* deterministic */ true);
}
allocateBuffer(maxBatchSize);
}
template <typename T>
void SamplingLayer<T>::setup(const size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
// If the runtime argument carries a single random seed, use it to initialize
// the curand states of all sequences. If it carries [batchSize] random seeds,
// initialize each sequence's state with its own seed. If no random seed is
// given, initialize all states with seed 0.
if (setupParams.randomSeed)
{
if (setupParams.randomSeed->size() == 1)
{
invokeCurandInitialize(
mCurandStatesDevice, batchSlots, batchSize, setupParams.randomSeed->front(), mStream);
sync_check_cuda_error();
}
else
{
TLLM_CHECK_WITH_INFO(setupParams.randomSeed->size() == batchSize, "Random seed vector size mismatch.");
cudaAutoCpy(mRandomSeedsDevice, setupParams.randomSeed->data(), batchSize, mStream);
invokeCurandBatchInitialize(mCurandStatesDevice, batchSlots, batchSize, mRandomSeedsDevice, mStream);
sync_check_cuda_error();
}
}
else
{
// Initialize curand states using the default seed 0.
invokeCurandInitialize(mCurandStatesDevice, batchSlots, batchSize, 0, mStream);
}
if (mDecodingMode.isTopK())
{
mTopKDecode->setup(batchSize, batchSlots, setupParams);
}
if (mDecodingMode.isTopP())
{
mTopPDecode->setup(batchSize, batchSlots, setupParams);
}
}
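// Host-side analogue of the seeding policy described above, for illustration only: resolveSeeds
// is a hypothetical helper (assumes <optional>, <vector>, <cstdint>); the layer itself initializes
// curand states on the GPU via invokeCurandInitialize / invokeCurandBatchInitialize.
inline std::vector<uint64_t> resolveSeeds(std::optional<std::vector<uint64_t>> const& randomSeed, size_t batchSize)
{
    if (!randomSeed)
    {
        return std::vector<uint64_t>(batchSize, 0ULL); // no seed given: default seed 0
    }
    if (randomSeed->size() == 1)
    {
        return std::vector<uint64_t>(batchSize, randomSeed->front()); // broadcast a single seed
    }
    return *randomSeed; // one seed per request; size mismatches are rejected by the layer
}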
template <typename T>
void SamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
auto logits = inputs.logits.template getPtr<T>();
auto endIds = inputs.end_ids.template getPtr<const int>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
FinishedState* finishedInput = (inputs.finished)
? reinterpret_cast<FinishedState*>(inputs.finished->template getPtr<FinishedState::UnderlyingType>())
: nullptr;
std::vector<int32_t> batchSlotsVec(batchSize);
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
auto batchSlotsHost = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : batchSlotsVec.data();
bool skipTopK = !mDecodingMode.isTopK();
if (!skipTopK)
{
skipTopK = allOfBatchSlots(batchSlotsHost, mTopKDecode->getSkipDecodeHost(), batchSize, true);
}
bool skipTopP = !mDecodingMode.isTopP();
if (!skipTopP)
{
skipTopP = allOfBatchSlots(batchSlotsHost, mTopPDecode->getSkipDecodeHost(), batchSize, true);
}
// Compute probabilities when TopP is used or when cumLogProbs or outputLogProbs are requested
bool const skipSoftMax = skipTopP && cumLogProbs == nullptr && outputLogProbs == nullptr;
inputs.curand_states = mCurandStatesDevice;
inputs.sampling_workspace = mSamplingWorkspaceDevice;
inputs.probs_computed = !skipSoftMax;
invokeAddBiasSoftMax(logits, (T**) nullptr, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize,
mMaxBatchSize, /* bw */ 1, mVocabSize, mVocabSizePadded, skipSoftMax, /* batchSlotLogits */ false, mStream);
sync_check_cuda_error();
if (!skipTopK)
{
mTopKDecode->forward(outputs, inputs);
}
if (!skipTopP)
{
mTopPDecode->forward(outputs, inputs);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
template class SamplingLayer<float>;
template class SamplingLayer<half>;
} // namespace layers
} // namespace tensorrt_llm

View File

@ -0,0 +1,91 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <algorithm>
#include <curand_kernel.h>
#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/layers/baseSamplingLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/runtime/decodingMode.h"
namespace tc = tensorrt_llm::common;
namespace tensorrt_llm
{
namespace layers
{
template <typename T>
inline bool allOfBatchSlots(int32_t const* batchSlotsHost, T const* data, size_t batchSize, T value)
{
return std::all_of(batchSlotsHost, batchSlotsHost + batchSize, [&](int32_t b) { return data[b] == value; });
};
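// Usage sketch for allOfBatchSlots with made-up values (hypothetical example, not part of this
// change): the TopK pass can be skipped only if every active slot has its skip flag set.
inline bool allOfBatchSlotsExample()
{
    bool const skipHost[4] = {true, false, true, true};
    int32_t const slots[3] = {0, 1, 3};
    return allOfBatchSlots(slots, skipHost, 3, true); // false, because slot 1 still needs decoding
}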
//! \brief Top class for sampling layers.
//! It sets up and executes TopKSamplingLayer and TopPSamplingLayer samplings
template <typename T>
class SamplingLayer : public BaseSamplingLayer<T>
{
public:
using Base = BaseSamplingLayer<T>;
using SetupParams = typename Base::SetupParams;
using ForwardParams = typename Base::ForwardParams;
SamplingLayer(runtime::DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
cudaStream_t stream, std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop);
~SamplingLayer() override = default;
void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;
void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
private:
using Base::mMaxBatchSize;
using Base::mVocabSize;
using Base::mVocabSizePadded;
using Base::mSamplingWorkspaceSize;
using Base::mAllocatedSize;
using Base::mStream;
using Base::mAllocator;
runtime::DecodingMode mDecodingMode;
void* mSamplingWorkspaceDevice = nullptr;
curandState_t* mCurandStatesDevice = nullptr;
uint64_t* mRandomSeedsDevice = nullptr;
bool* mSkipDecodeDevice = nullptr;
bool* mSkipDecodeHost = nullptr;
bool mSkipAny = false;
std::unique_ptr<TopKSamplingLayer<T>> mTopKDecode;
std::unique_ptr<TopPSamplingLayer<T>> mTopPDecode;
private:
void allocateBuffer(size_t batchSize);
void freeBuffer();
};
} // namespace layers
} // namespace tensorrt_llm

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -74,46 +74,41 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batchSize)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
invokeTopKSampling<T>(nullptr, mSamplingWorkspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, mSkipDecodeDevice,
mNormalizeLogProbs);
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, nullptr,
mNormalizeLogProbs, false);
std::array<size_t, 4> deviceBufferSizes;
deviceBufferSizes[0] = mSamplingWorkspaceSize;
deviceBufferSizes[1] = sizeof(uint32_t) * batchSize;
deviceBufferSizes[2] = sizeof(float) * batchSize;
deviceBufferSizes[3] = std::max(deviceBufferSizes[1], deviceBufferSizes[2]);
deviceBufferSizes[0] = sizeof(uint32_t) * batchSize;
deviceBufferSizes[1] = sizeof(float) * batchSize;
deviceBufferSizes[2] = sizeof(bool) * batchSize;
deviceBufferSizes[3] = std::max(deviceBufferSizes[0], deviceBufferSizes[1]);
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[0], false);
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[1], false);
mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[2], false);
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[0], false);
mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[1], false);
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[2], false);
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[3], false);
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
TLLM_LOG_DEBUG("topKSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
mIsAllocateBuffer = true;
mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), static_cast<size_t>(0));
TLLM_LOG_DEBUG("topKSamplingLayer allocated %lu bytes on GPU", mAllocatedSize);
}
template <typename T>
void TopKSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (mIsAllocateBuffer)
{
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
mAllocator->free((void**) (&mRuntimeTopKDevice));
mAllocator->free((void**) (&mRuntimeTopPDevice));
mAllocator->free((void**) (&mSetupWorkspaceDevice));
}
BaseSamplingLayer<T>::freeBuffer();
mIsAllocateBuffer = false;
mAllocator->free((void**) (&mRuntimeTopKDevice));
mAllocator->free((void**) (&mRuntimeTopPDevice));
mAllocator->free((void**) (&mSkipDecodeDevice));
mAllocator->free((void**) (&mSetupWorkspaceDevice));
std::free(mSkipDecodeHost);
}
template <typename T>
void TopKSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots, SetupParams const& setupParams)
void TopKSamplingLayer<T>::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batchSize, batchSlots, setupParams);
uint32_t constexpr defaultTopK = 0;
auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector<uint32_t>{defaultTopK});
@ -161,29 +156,48 @@ void TopKSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
reinterpret_cast<float*>(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream);
}
dim3 block(std::min((int) batchSize, 256));
dim3 grid(divUp((int) batchSize, (int) block.x));
// support topK up to TOP_K_MAX.
setupTopKRuntimeArgs<TOP_K_MAX><<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize,
topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots);
{
dim3 block(std::min((int) batchSize, 256));
dim3 grid(divUp((int) batchSize, (int) block.x));
// support topK up to TOP_K_MAX.
setupTopKRuntimeArgs<TOP_K_MAX><<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice,
runtimeTopKSize, topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots);
}
cudaAutoCpy(mSkipDecodeHost, mSkipDecodeDevice, mMaxBatchSize, mStream);
std::vector<uint32_t> runtimeTopKs(mMaxBatchSize);
cudaAutoCpy(runtimeTopKs.data(), mRuntimeTopKDevice, mMaxBatchSize, mStream);
// TODO(nkorobov): find maxTopK using batch slot
mRuntimeMaxTopK = *std::max_element(std::begin(runtimeTopKs), std::end(runtimeTopKs));
{
uint32_t maxTopK = 0;
for (size_t bi = 0; bi < batchSize; ++bi)
{
uint32_t bid = bi;
if (batchSlots)
{
bid = batchSlots[bi];
}
maxTopK = std::max(maxTopK, runtimeTopKs[bid]);
}
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, maxTopK);
}
}
template <typename T>
void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs)
void TopKSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
// in case of skip any, the logit value is already copied and processed.
auto* logits = mSkipAny ? mRuntimeLogitsDevice : inputs.logits.template getPtr<T>();
auto* endIds = inputs.end_ids.template getPtr<const int>();
auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
auto logits = inputs.logits.template getPtr<T>();
auto endIds = inputs.end_ids.template getPtr<const int>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
auto curandStatesDevice = inputs.curand_states;
auto samplingWorkspaceDevice = inputs.sampling_workspace;
auto const probsComputed = inputs.probs_computed;
TLLM_CHECK_WITH_INFO(curandStatesDevice, "No curand states provided");
TLLM_CHECK_WITH_INFO(samplingWorkspaceDevice, "No sampling workspace provided");
FinishedState* finishedInput = (inputs.finished)
? reinterpret_cast<FinishedState*>(inputs.finished->template getPtr<FinishedState::UnderlyingType>())
@ -191,32 +205,16 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
FinishedState* finishedOutput = (outputs.finished)
? reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>())
: nullptr;
invokeAddBiasEndMask(
logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize, mVocabSizePadded, mStream);
sync_check_cuda_error();
float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
{
invokeAddBiasSoftMax(logits, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize,
mVocabSizePadded, mStream);
sync_check_cuda_error();
}
int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
invokeBatchTopKSampling(mSamplingWorkspaceDevice, mSamplingWorkspaceSize, logits,
invokeBatchTopKSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, logits,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, mCurandStatesDevice,
(int) mRuntimeMaxTopK, // useless because mRuntimeTopKDevice is never
// nullptr. Keep for legacy.
(int*) (mRuntimeTopKDevice),
1.0f, // useless because mRuntimeTopPDevice is never nullptr. Keep for
// legacy.
outputLogProbs, curandStatesDevice, (int) mRuntimeMaxTopK, (int*) (mRuntimeTopKDevice), 1.0f,
mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mSkipDecodeDevice,
mNormalizeLogProbs);
mNormalizeLogProbs, probsComputed);
sync_check_cuda_error();
}
@ -228,13 +226,6 @@ TopKSamplingLayer<T>::TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, s
allocateBuffer(mMaxBatchSize);
}
template <typename T>
TopKSamplingLayer<T>::TopKSamplingLayer(TopKSamplingLayer<T> const& topKSamplingLayer)
: BaseSamplingLayer<T>(topKSamplingLayer)
{
allocateBuffer(mMaxBatchSize);
}
template <typename T>
TopKSamplingLayer<T>::~TopKSamplingLayer()
{

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -36,46 +36,44 @@ class TopKSamplingLayer : public BaseSamplingLayer<T>
public:
using Base = BaseSamplingLayer<T>;
using SetupParams = typename Base::SetupParams;
using ForwardParams = typename Base::ForwardParams;
TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream,
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator);
TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampling_layer);
~TopKSamplingLayer();
void setup(size_t batchSize, int const* batch_slots, SetupParams const& setupParams) override;
void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;
protected:
void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) override;
void freeBuffer() override;
const bool* getSkipDecodeHost() const
{
return mSkipDecodeHost;
}
protected:
bool mNormalizeLogProbs = true;
uint32_t mRuntimeMaxTopK = 1;
uint32_t mRuntimeMaxTopK = 0;
uint32_t* mRuntimeTopKDevice = nullptr;
float* mRuntimeTopPDevice = nullptr;
void* mSetupWorkspaceDevice = nullptr;
bool* mSkipDecodeDevice = nullptr;
bool* mSkipDecodeHost = nullptr;
using Base::mMaxBatchSize;
using Base::mVocabSize;
using Base::mVocabSizePadded;
using Base::mSamplingWorkspaceSize;
using Base::mSamplingWorkspaceDevice;
using Base::mCurandStatesDevice;
using Base::mSkipDecodeDevice;
using Base::mSkipDecodeHost;
using Base::mSkipAny;
using Base::mRuntimeLogitsDevice;
using Base::mAllocatedSize;
using Base::mStream;
using Base::mAllocator;
using Base::mIsAllocateBuffer;
static constexpr uint32_t TOP_K_MAX = 1024;
private:
void allocateBuffer(size_t batchSize);
void freeBuffer();
};
} // namespace layers

View File

@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@ -34,34 +34,20 @@ namespace tensorrt_llm
namespace layers
{
static __global__ void set_topp_runtime_args(int batchSize, uint32_t top_k, uint32_t* top_ks, int top_ks_size,
float top_p, float* top_ps, int top_ps_size, bool* skip_decode, float* initial_top_p_buf, float* top_p_decay_buf,
float* top_p_min_buf, const int* batch_slots)
static __global__ void setTopPRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP,
float* topPs, int topPsSize, bool* skipDecode, const int* batchSlots, float* initialTopPBuf)
{
/**
* @brief Setup the runtime arguments for topp, broadcasting top_p to top_ps
and top_k to top_ks, verifying value ranges of top_p_decay/top_p_min.
*
* \param batchSize
* \param top_k
* \param top_ks [batchSize]
* \param top_ks_size
* \param top_p
* \param top_ps [batchSize]
* \param top_ps_size
* \param skip_decode [batchSize]
* \param initial_top_p_buf [batchSize]
* \param top_p_decay_buf [batchSize]
* \param top_p_min_buf [batchSize]
*
and top_k to top_ks.
*/
int index = blockIdx.x * blockDim.x + threadIdx.x;
for (int bi = index; bi < batchSize; bi += gridDim.x * blockDim.x)
{
auto const batch_slot = batch_slots != nullptr ? batch_slots[bi] : bi;
std::uint32_t k = top_ks_size > 1 ? top_ks[batch_slot] : top_k;
float p = top_ps_size > 1 ? top_ps[batch_slot] : top_p;
auto const batchSlot = batchSlots != nullptr ? batchSlots[bi] : bi;
std::uint32_t k = topKsSize > 1 ? topKs[batchSlot] : topK;
float p = topPsSize > 1 ? topPs[batchSlot] : topP;
if (k == 0 && p == 0.0f)
{
// TensorRT-LLM's topp implementation does not support topp = 0.0f, but it
@ -69,11 +55,11 @@ static __global__ void set_topp_runtime_args(int batchSize, uint32_t top_k, uint
// solution.
k = 1;
}
top_ks[batch_slot] = k;
top_ps[batch_slot] = p;
skip_decode[batch_slot] = k > 0;
topKs[batchSlot] = k;
topPs[batchSlot] = p;
skipDecode[batchSlot] = k > 0;
initial_top_p_buf[batch_slot] = top_ps[batch_slot];
initialTopPBuf[batchSlot] = topPs[batchSlot];
}
}
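// Scalar host-side rendering of the per-slot logic above, for illustration only; the kernel
// applies it across batchSlots on the GPU and setTopPRuntimeArgsHost is a hypothetical helper.
inline void setTopPRuntimeArgsHost(uint32_t topK, float topP, uint32_t& kOut, float& pOut, bool& skipTopP)
{
    uint32_t k = topK;
    float p = topP;
    if (k == 0 && p == 0.0f)
    {
        k = 1; // top_p == 0 falls back to greedy top-1, which the TopK path handles
    }
    kOut = k;
    pOut = p;
    skipTopP = (k > 0); // TopP is skipped for slots where TopK is active
}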
@ -81,10 +67,10 @@ template <typename T>
void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (is_deterministic_)
if (mIsDeterministic)
{
invokeTopPSampling<T>(nullptr, // workspace
mSamplingWorkspaceSize, cub_temp_storage_size_,
mSamplingWorkspaceSize, mCubTempStorageSize,
nullptr, // output_ids
nullptr, // sequence_length
nullptr, // finished_input_buffer
@ -92,8 +78,8 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs
topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, mCurandStatesDevice, batchSize,
mVocabSizePadded, nullptr, 0.f, mStream, mSkipDecodeDevice, nullptr);
mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mVocabSizePadded, nullptr,
0.f, mStream, nullptr, nullptr);
}
else
{
@ -105,68 +91,63 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
nullptr, // cum_log_probs
nullptr, // output_log_probs
nullptr, // log_probs)
mCurandStatesDevice, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, air_topp_block_num_,
mSkipDecodeDevice, nullptr);
nullptr, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, nullptr);
}
std::array<size_t, 11> deviceBufferSizes;
deviceBufferSizes[0] = mSamplingWorkspaceSize;
deviceBufferSizes[1] = sizeof(int32_t) * batchSize * mVocabSizePadded;
deviceBufferSizes[0] = sizeof(int32_t) * batchSize * mVocabSizePadded;
deviceBufferSizes[1] = sizeof(int32_t) * (batchSize + 1);
deviceBufferSizes[2] = sizeof(int32_t) * (batchSize + 1);
deviceBufferSizes[3] = sizeof(int32_t) * (batchSize + 1);
deviceBufferSizes[4] = sizeof(uint32_t) * batchSize;
deviceBufferSizes[3] = sizeof(uint32_t) * batchSize;
deviceBufferSizes[4] = sizeof(float) * batchSize;
deviceBufferSizes[5] = sizeof(float) * batchSize;
deviceBufferSizes[6] = sizeof(float) * batchSize;
deviceBufferSizes[7] = sizeof(float) * batchSize;
deviceBufferSizes[8] = sizeof(float) * batchSize;
deviceBufferSizes[9] = sizeof(int32_t) * batchSize;
deviceBufferSizes[10] = *std::max_element(&deviceBufferSizes[4], &deviceBufferSizes[10]);
deviceBufferSizes[8] = sizeof(int32_t) * batchSize;
deviceBufferSizes[9] = sizeof(bool) * batchSize;
deviceBufferSizes[10] = *std::max_element(&deviceBufferSizes[3], &deviceBufferSizes[9]);
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[0], true);
topp_id_vals_buf_ = mAllocator->reMalloc(topp_id_vals_buf_, deviceBufferSizes[1], false);
topp_offset_buf_ = mAllocator->reMalloc(topp_offset_buf_, deviceBufferSizes[2], false);
begin_topp_offset_buf_ = mAllocator->reMalloc(begin_topp_offset_buf_, deviceBufferSizes[3], false);
runtime_top_k_buf_ = mAllocator->reMalloc(runtime_top_k_buf_, deviceBufferSizes[4], false);
runtime_top_p_buf_ = mAllocator->reMalloc(runtime_top_p_buf_, deviceBufferSizes[5], false);
initial_top_p_buf_ = mAllocator->reMalloc(initial_top_p_buf_, deviceBufferSizes[6], false);
top_p_decay_buf_ = mAllocator->reMalloc(top_p_decay_buf_, deviceBufferSizes[7], false);
top_p_min_buf_ = mAllocator->reMalloc(top_p_min_buf_, deviceBufferSizes[8], false);
top_p_reset_ids_buf_ = mAllocator->reMalloc(top_p_reset_ids_buf_, deviceBufferSizes[9], false);
setup_workspace_buf_ = mAllocator->reMalloc(setup_workspace_buf_, deviceBufferSizes[10], false);
mTopPIdValsDevice = mAllocator->reMalloc(mTopPIdValsDevice, deviceBufferSizes[0], false);
mTopPOffsetDevice = mAllocator->reMalloc(mTopPOffsetDevice, deviceBufferSizes[1], false);
mBeginTopPOffsetDevice = mAllocator->reMalloc(mBeginTopPOffsetDevice, deviceBufferSizes[2], false);
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[3], false);
mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[4], false);
mInitialTopPDevice = mAllocator->reMalloc(mInitialTopPDevice, deviceBufferSizes[5], false);
mTopPDecayDevice = mAllocator->reMalloc(mTopPDecayDevice, deviceBufferSizes[6], false);
mTopPMinDevice = mAllocator->reMalloc(mTopPMinDevice, deviceBufferSizes[7], false);
mTopPResetIdsDevice = mAllocator->reMalloc(mTopPResetIdsDevice, deviceBufferSizes[8], false);
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[9], false);
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[10], false);
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
TLLM_LOG_DEBUG("topPSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
std::fill(mSkipDecodeHost, mSkipDecodeHost + batchSize, true);
mIsAllocateBuffer = true;
mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), static_cast<size_t>(0));
TLLM_LOG_DEBUG("topPSamplingLayer allocated %lu bytes on GPU", mAllocatedSize);
}
template <typename T>
void TopPSamplingLayer<T>::freeBuffer()
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
if (mIsAllocateBuffer)
{
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
mAllocator->free((void**) (&topp_id_vals_buf_));
mAllocator->free((void**) (&topp_offset_buf_));
mAllocator->free((void**) (&begin_topp_offset_buf_));
mAllocator->free((void**) (&runtime_top_k_buf_));
mAllocator->free((void**) (&runtime_top_p_buf_));
mAllocator->free((void**) (&initial_top_p_buf_));
mAllocator->free((void**) (&top_p_decay_buf_));
mAllocator->free((void**) (&top_p_min_buf_));
mAllocator->free((void**) (&top_p_reset_ids_buf_));
mAllocator->free((void**) (&setup_workspace_buf_));
}
BaseSamplingLayer<T>::freeBuffer();
mIsAllocateBuffer = false;
mAllocator->free((void**) (&mTopPIdValsDevice));
mAllocator->free((void**) (&mTopPOffsetDevice));
mAllocator->free((void**) (&mBeginTopPOffsetDevice));
mAllocator->free((void**) (&mRuntimeTopKDevice));
mAllocator->free((void**) (&mRuntimeTopPDevice));
mAllocator->free((void**) (&mInitialTopPDevice));
mAllocator->free((void**) (&mTopPDecayDevice));
mAllocator->free((void**) (&mTopPMinDevice));
mAllocator->free((void**) (&mTopPResetIdsDevice));
mAllocator->free((void**) (&mSkipDecodeDevice));
mAllocator->free((void**) (&mSetupWorkspaceDevice));
std::free(mSkipDecodeHost);
}
template <typename T>
void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots, SetupParams const& setupParams)
void TopPSamplingLayer<T>::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
BaseSamplingLayer<T>::setupBase(batchSize, batchSlots, setupParams);
uint32_t const defaultTopK = 0;
auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector<uint32_t>{defaultTopK});
@ -186,7 +167,16 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
if (runtimeTopPSize == 0)
{
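// No request in this setup call uses top_p: mark every active slot as skipped and return early.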
std::fill_n(mSkipDecodeHost, batchSize, true);
for (size_t bi = 0; bi < batchSize; ++bi)
{
int32_t bid = bi;
if (batchSlots)
{
bid = batchSlots[bi];
}
mSkipDecodeHost[bid] = true;
}
cudaAutoCpy(mSkipDecodeDevice, mSkipDecodeHost, mMaxBatchSize, mStream);
return;
}
@ -224,17 +214,17 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
{
TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
fmtstr("runtimeTopK.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopK.size(), batchSize));
cudaAutoCpy(reinterpret_cast<uint32_t*>(setup_workspace_buf_), runtimeTopK.data(), batchSize, mStream);
cudaAutoCpy(reinterpret_cast<uint32_t*>(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream);
invokeScatterDecodingParams(
reinterpret_cast<uint32_t*>(setup_workspace_buf_), runtime_top_k_buf_, batchSlots, batchSize, mStream);
reinterpret_cast<uint32_t*>(mSetupWorkspaceDevice), mRuntimeTopKDevice, batchSlots, batchSize, mStream);
}
if (runtimeTopPSize > 1)
{
TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize,
fmtstr("runtime_top_p.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopP.size(), batchSize));
cudaAutoCpy(reinterpret_cast<float*>(setup_workspace_buf_), runtimeTopP.data(), batchSize, mStream);
cudaAutoCpy(reinterpret_cast<float*>(mSetupWorkspaceDevice), runtimeTopP.data(), batchSize, mStream);
invokeScatterDecodingParams(
reinterpret_cast<float*>(setup_workspace_buf_), runtime_top_p_buf_, batchSlots, batchSize, mStream);
reinterpret_cast<float*>(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream);
}
auto fillBuffers
@ -246,50 +236,66 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
invokeScatterDecodingParams(deviceTmpBuffer, deviceBuffer, batchSlots, batchSize, mStream);
};
fillBuffers("top_p_decay", decayVec, reinterpret_cast<float*>(setup_workspace_buf_), top_p_decay_buf_);
fillBuffers("top_p_decay", decayVec, reinterpret_cast<float*>(mSetupWorkspaceDevice), mTopPDecayDevice);
fillBuffers("top_p_min", topPMinVec, reinterpret_cast<float*>(setup_workspace_buf_), top_p_min_buf_);
fillBuffers("top_p_min", topPMinVec, reinterpret_cast<float*>(mSetupWorkspaceDevice), mTopPMinDevice);
fillBuffers(
"top_p_reset_ids", topPResetIdsVec, reinterpret_cast<int32_t*>(setup_workspace_buf_), top_p_reset_ids_buf_);
"top_p_reset_ids", topPResetIdsVec, reinterpret_cast<int32_t*>(mSetupWorkspaceDevice), mTopPResetIdsDevice);
dim3 block(std::min((int) batchSize, 256));
dim3 grid(divUp((int) batchSize, (int) block.x));
set_topp_runtime_args<<<grid, block, 0, mStream>>>(batchSize, topK, runtime_top_k_buf_, runtimeTopKSize, topP,
runtime_top_p_buf_, runtimeTopPSize, mSkipDecodeDevice, initial_top_p_buf_, top_p_decay_buf_, top_p_min_buf_,
batchSlots);
sync_check_cuda_error();
{
dim3 block(std::min((int) batchSize, 256));
dim3 grid(divUp((int) batchSize, (int) block.x));
setTopPRuntimeArgs<<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize, topP,
mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots, mInitialTopPDevice);
sync_check_cuda_error();
}
cudaAutoCpy(mSkipDecodeHost, mSkipDecodeDevice, mMaxBatchSize, mStream);
std::vector<float> runtimeTopPs(mMaxBatchSize);
cudaAutoCpy(runtimeTopPs.data(), mRuntimeTopPDevice, mMaxBatchSize, mStream);
{
float maxTopP = 0.f;
for (size_t bi = 0; bi < batchSize; ++bi)
{
int32_t bid = bi;
if (batchSlots)
{
bid = batchSlots[bi];
}
maxTopP = std::max(maxTopP, runtimeTopPs[bid]);
}
mRuntimeMaxTopP = std::max(mRuntimeMaxTopP, maxTopP);
}
std::vector<float> runtime_top_ps(mMaxBatchSize);
cudaAutoCpy(runtime_top_ps.data(), runtime_top_p_buf_, mMaxBatchSize, mStream);
// TODO(nkorobov): find maxTopP using batch slots
mRuntimeMaxTopP = *std::max_element(std::begin(runtime_top_ps), std::end(runtime_top_ps));
if (!is_deterministic_)
if (!mIsDeterministic)
{
int smCnt = mCudaDeviceProp->multiProcessorCount;
air_topp_block_num_ = calcAirTopPBlockNum<T, int, float>(batchSize, (int) mVocabSizePadded, smCnt);
mAirTopPBlockNum = calcAirTopPBlockNum<T, int, float>(batchSize, (int) mVocabSizePadded, smCnt);
}
}
template <typename T>
void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs)
void TopPSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
{
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
auto const batchSize = inputs.logits.shape[0];
// in case of skip any, the logit value is already copied and processed.
auto* logits = !mSkipAny ? inputs.logits.template getPtr<T>() : mRuntimeLogitsDevice;
auto* endIds = inputs.end_ids.template getPtr<const int>();
auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
// Probabilities must already be computed, i.e. inputs.logits holds probabilities instead of raw logits
auto probs = inputs.logits.template getPtr<T>();
auto endIds = inputs.end_ids.template getPtr<const int>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
auto curandStatesDevice = inputs.curand_states;
auto samplingWorkspaceDevice = inputs.sampling_workspace;
if (is_deterministic_)
TLLM_CHECK_WITH_INFO(curandStatesDevice, "No curand states provided");
TLLM_CHECK_WITH_INFO(samplingWorkspaceDevice, "No sampling workspace provided");
if (mIsDeterministic)
{
invokeTopPInitialize(
topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, batchSize, mVocabSizePadded, mStream);
mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, batchSize, mVocabSizePadded, mStream);
sync_check_cuda_error();
}
@@ -299,33 +305,30 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
FinishedState* finishedOutput = (outputs.finished)
? reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>())
: nullptr;
invokeAddBiasSoftMax(logits, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize,
mVocabSizePadded, mStream);
sync_check_cuda_error();
float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
if (is_deterministic_)
if (mIsDeterministic)
{
invokeBatchTopPSampling<T>(mSamplingWorkspaceDevice, mSamplingWorkspaceSize, cub_temp_storage_size_,
invokeBatchTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize, mCubTempStorageSize,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, mCurandStatesDevice,
batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, runtime_top_p_buf_, mStream, mSkipDecodeDevice,
outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice,
batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice,
batchSlots);
sync_check_cuda_error();
invokeComputeToppDecay(runtime_top_p_buf_, initial_top_p_buf_,
outputs.output_ids_ptr.template getPtr<const int*>(), top_p_decay_buf_, top_p_min_buf_,
top_p_reset_ids_buf_, sequenceLength, batchSlots, batchSize, mStream);
invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice,
outputs.output_ids_ptr.template getPtr<const int*>(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice,
sequenceLength, batchSlots, batchSize, mStream);
sync_check_cuda_error();
}
else
{
invokeBatchAirTopPSampling<T>(mSamplingWorkspaceDevice, mSamplingWorkspaceSize,
invokeBatchAirTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize,
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
outputLogProbs, logits, mCurandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP,
runtime_top_p_buf_, mStream, air_topp_block_num_, mSkipDecodeDevice, batchSlots);
outputLogProbs, probs, curandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP,
mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots);
sync_check_cuda_error();
}
}
@@ -334,14 +337,7 @@ template <typename T>
TopPSamplingLayer<T>::TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded,
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop, bool isDeterministic)
: BaseSamplingLayer<T>(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), prop)
, is_deterministic_(isDeterministic)
{
allocateBuffer(mMaxBatchSize);
}
template <typename T>
TopPSamplingLayer<T>::TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer)
: BaseSamplingLayer<T>(top_p_sampling_layer)
, mIsDeterministic(isDeterministic)
{
allocateBuffer(mMaxBatchSize);
}
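
For context on the decay buffers renamed above (mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice and mInitialTopPDevice), the sketch below shows, on the host and in scalar form, the per-slot update that invokeComputeToppDecay appears to apply after every decoding step: top-p is multiplied by a decay factor, clamped from below, and restored to its initial value when the reset token id is generated. This is a hedged illustration only; the struct and the free function updateTopPForSlot are hypothetical names, not part of this commit.

```cpp
// Illustrative host-side sketch (not the CUDA kernel from this commit) of the
// per-slot top-p decay that invokeComputeToppDecay is expected to apply after
// each generated token. All names here are hypothetical.
#include <algorithm>
#include <cstdint>

struct TopPDecayState
{
    float runtimeTopP;  // current top-p for this batch slot
    float initialTopP;  // value restored on reset
    float decay;        // multiplicative decay in (0, 1]
    float minTopP;      // lower bound for the decayed top-p
    int32_t resetId;    // token id that restores initialTopP
};

inline void updateTopPForSlot(TopPDecayState& s, int32_t lastGeneratedId)
{
    if (lastGeneratedId == s.resetId)
    {
        // Generating the reset token restores the original top-p.
        s.runtimeTopP = s.initialTopP;
    }
    else
    {
        // Otherwise top-p shrinks each step but never drops below minTopP.
        s.runtimeTopP = std::max(s.runtimeTopP * s.decay, s.minTopP);
    }
}
```

In the batched kernel, the same update would run once per batch slot, indexed through batchSlots.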

View File

@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -29,60 +29,60 @@ namespace layers
{
//! \brief Layer to randomly sample tokens from TopP logits.
//! Layer expects probs precomputed in "logits" tensor
template <typename T>
class TopPSamplingLayer : public BaseSamplingLayer<T>
{
public:
using Base = BaseSamplingLayer<T>;
using SetupParams = typename Base::SetupParams;
using ForwardParams = typename Base::ForwardParams;
TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded, cudaStream_t stream,
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop, bool isDeterministic = true);
TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer);
~TopPSamplingLayer();
void setup(std::size_t batchSize, int const* batch_slots, SetupParams const& setupParams) override;
void setup(std::size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;
const bool* getSkipDecodeHost() const
{
return mSkipDecodeHost;
}
protected:
void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) override;
void freeBuffer() override;
uint32_t* mRuntimeTopKDevice = nullptr;
float* mRuntimeTopPDevice = nullptr;
float mRuntimeMaxTopP{0.f};
float* mInitialTopPDevice = nullptr;
float* mTopPDecayDevice = nullptr;
float* mTopPMinDevice = nullptr;
int32_t* mTopPResetIdsDevice = nullptr;
void* mSetupWorkspaceDevice = nullptr;
protected:
uint32_t* runtime_top_k_buf_ = nullptr;
float* runtime_top_p_buf_ = nullptr;
float mRuntimeMaxTopP;
float* initial_top_p_buf_ = nullptr;
float* top_p_decay_buf_ = nullptr;
float* top_p_min_buf_ = nullptr;
int32_t* top_p_reset_ids_buf_ = nullptr;
void* setup_workspace_buf_ = nullptr;
int32_t* topp_id_vals_buf_ = nullptr;
int32_t* topp_offset_buf_ = nullptr;
int32_t* begin_topp_offset_buf_ = nullptr;
std::size_t cub_temp_storage_size_;
bool is_deterministic_ = true;
int air_topp_block_num_;
int32_t* mTopPIdValsDevice = nullptr;
int32_t* mTopPOffsetDevice = nullptr;
int32_t* mBeginTopPOffsetDevice = nullptr;
bool* mSkipDecodeDevice = nullptr;
bool* mSkipDecodeHost = nullptr;
size_t mCubTempStorageSize;
bool mIsDeterministic = true;
int mAirTopPBlockNum;
using Base::mMaxBatchSize;
using Base::mVocabSize;
using Base::mVocabSizePadded;
using Base::mSamplingWorkspaceSize;
using Base::mSamplingWorkspaceDevice;
using Base::mCurandStatesDevice;
using Base::mSkipDecodeDevice;
using Base::mSkipDecodeHost;
using Base::mSkipAny;
using Base::mRuntimeLogitsDevice;
using Base::mAllocatedSize;
using Base::mStream;
using Base::mAllocator;
using Base::mIsAllocateBuffer;
using Base::mCudaDeviceProp;
private:
void allocateBuffer(std::size_t batchSize);
void freeBuffer();
};
} // namespace layers
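
The new class comment states that the layer expects probabilities already computed in the "logits" tensor. As a reference for what top-p (nucleus) selection means over such a distribution, here is a minimal, self-contained host-side sketch; it only illustrates the selection rule, not the batched CUDA path exercised by invokeBatchTopPSampling, and the function name sampleTopP is hypothetical.

```cpp
// Minimal host-side top-p (nucleus) sampling over precomputed probabilities.
// Purely illustrative; the layer in this diff does this in batched CUDA kernels.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

int sampleTopP(std::vector<float> const& probs, float topP, std::mt19937& gen)
{
    if (probs.empty())
    {
        return -1;
    }

    // Sort token ids by descending probability.
    std::vector<int> ids(probs.size());
    std::iota(ids.begin(), ids.end(), 0);
    std::sort(ids.begin(), ids.end(), [&](int a, int b) { return probs[a] > probs[b]; });

    // Keep the smallest prefix whose cumulative probability reaches topP
    // (always at least one token).
    float cumulative = 0.f;
    std::size_t nucleusSize = 0;
    do
    {
        cumulative += probs[ids[nucleusSize]];
        ++nucleusSize;
    } while (nucleusSize < ids.size() && cumulative < topP);

    // Draw uniformly in [0, cumulative) and walk the prefix to pick a token.
    std::uniform_real_distribution<float> dist(0.f, cumulative);
    float r = dist(gen);
    for (std::size_t i = 0; i < nucleusSize; ++i)
    {
        r -= probs[ids[i]];
        if (r <= 0.f)
        {
            return ids[i];
        }
    }
    return ids[nucleusSize - 1]; // numerical fallback
}
```

A call such as `std::mt19937 gen(0); int id = sampleTopP(probs, 0.9f, gen);` would draw one token id from the smallest set of tokens whose probabilities sum to at least 0.9.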

Some files were not shown because too many files have changed in this diff.