Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#1098)

* Update TensorRT-LLM
* update submodule
* Remove unused binaries

parent 0ab9d17a59
commit 0f041b7b57

3rdparty/cutlass (vendored submodule)
@@ -1 +1 @@
Subproject commit 39c6a83f231d6db2bc6b9c251e7add77d68cbfb4
Subproject commit 8236f30675bbe98f81d11c05764b77bfcb25b8cc
@ -73,7 +73,7 @@ Run a preprocessing script to prepare/generate dataset into a json that gptManag

This tool can be used in 2 different modes of traffic generation.

1 – Dataset
##### 1 – Dataset

“Prompt”, “Instruction” (optional) and “Answer” specified as sentences in a Json file

@@ -90,7 +90,7 @@ python3 prepare_dataset.py \
    --max-input-len 300
```
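
For reference, the json produced by this preprocessing step and consumed by `gptManagerBenchmark` is a top-level `"samples"` array whose entries carry `"input_ids"`, `"output_len"` and `"delay"` (see `parseWorkloadJson` in `gptManagerBenchmark.cpp` further down in this commit). Below is a minimal C++ sketch of reading that layout; the struct and function names are illustrative only, not part of the benchmark:

```
#include <cstdint>
#include <fstream>
#include <string>
#include <vector>

#include <nlohmann/json.hpp>

// Illustrative mirror of parseWorkloadJson: one entry per request with the
// prompt token ids, the requested output length, and a delay in seconds
// before the request is enqueued.
struct WorkloadSample
{
    std::vector<int32_t> inputIds;
    int32_t outputLen{};
    float delay{};
};

std::vector<WorkloadSample> loadWorkload(std::string const& datasetPath)
{
    std::ifstream jsonStream(datasetPath);
    auto const json = nlohmann::json::parse(jsonStream);

    std::vector<WorkloadSample> samples;
    for (auto const& sample : json["samples"])
    {
        WorkloadSample s;
        s.inputIds = sample["input_ids"].get<std::vector<int32_t>>();
        s.outputLen = sample["output_len"].get<int32_t>();
        s.delay = sample["delay"].get<float>();
        samples.push_back(std::move(s));
    }
    return samples;
}
```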

2 – Normal token length distribution
##### 2 – Normal token length distribution

This mode allows the user to generate normal token length distributions with a specified mean and standard deviation.
For example, setting mean=100 and std dev=10 would generate requests where 95.4% of the values fall in the <80,120> range, following the normal probability distribution. Setting std dev=0 will generate all requests with the same mean number of tokens.
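
As a rough illustration of what this mode does, here is a minimal C++ sketch of such a sampler; the helper below is hypothetical and not part of `prepare_dataset.py` (which implements the mode in Python):

```
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical helper: draws per-request token counts from N(mean, stddev),
// rounds to whole tokens and clamps to at least one token. With stddev == 0
// every request gets exactly `mean` tokens, as described above.
std::vector<int32_t> sampleTokenLengths(std::size_t numRequests, float mean, float stddev, std::uint64_t seed)
{
    std::vector<int32_t> lengths(numRequests, std::max<int32_t>(1, static_cast<int32_t>(std::lround(mean))));
    if (stddev > 0.f)
    {
        std::mt19937_64 engine{seed};
        std::normal_distribution<float> dist{mean, stddev};
        for (auto& len : lengths)
        {
            len = std::max<int32_t>(1, static_cast<int32_t>(std::lround(dist(engine))));
        }
    }
    return lengths;
}
```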
@@ -140,3 +140,17 @@ mpirun -n 2 ./benchmarks/gptManagerBenchmark \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
    --max_num_samples 500
```

To emulate `gptSessionBenchmark` static batching, you can use the `--static_emulated_batch_size` and `--static_emulated_timeout` arguments.
Given a `static_emulated_batch_size` of `n`, the server will wait for `n` requests to arrive before submitting them to the batch manager at once. If the `static_emulated_timeout` (in ms) is reached before `n` requests are collected, the batch will be submitted prematurely with the current request count.

Take GPT-350M as an example for single GPU with static batching
```
./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
    --type IFB \
    --static_emulated_batch_size 32 \
    --static_emulated_timeout 100 \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
```
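
The flush rule can be summarised as: submit the pending requests only once the previous emulated batch has drained and either the timeout has expired or a full batch has accumulated. The sketch below condenses the logic added to `GptServer::getInferenceRequests` in `gptManagerBenchmark.cpp` later in this commit; the free-standing function is illustrative, but the variable names follow that code:

```
#include <chrono>
#include <cstdint>
#include <optional>

// Condensed, illustrative version of the emulated static-batching decision in
// gptManagerBenchmark.cpp. Without --static_emulated_batch_size, any pending
// work is forwarded immediately.
bool readyForNextBatch(std::int64_t numNewWorkItems, std::uint64_t activeCount,
    std::optional<std::int64_t> staticEmulatedBatchSize,
    std::chrono::steady_clock::time_point emulatedBatchEndTimestamp)
{
    if (!staticEmulatedBatchSize)
    {
        return numNewWorkItems > 0;
    }
    if (numNewWorkItems <= 0)
    {
        return false;
    }
    bool const timeout = std::chrono::steady_clock::now() > emulatedBatchEndTimestamp;
    bool const previousBatchFinished = activeCount == 0;
    bool const haveEnoughForNextBatch = numNewWorkItems >= staticEmulatedBatchSize.value();
    return previousBatchFinished && (timeout || haveEnoughForNextBatch);
}
```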
Binary file not shown.
@@ -27,12 +27,13 @@
#include "tensorrt_llm/runtime/worldConfig.h"

#include <chrono>
#include <cstdint>
#include <cxxopts.hpp>
#include <iostream>
#include <nlohmann/json.hpp>
#include <random>
#include <string>
#include <thread>
#include <utility>

using namespace tensorrt_llm::batch_manager;
using namespace tensorrt_llm::runtime;
@ -104,7 +105,7 @@ public:
|
||||
|
||||
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
|
||||
"requestId %lu is already in progress, request is ignored.", requestId);
|
||||
|
||||
@ -119,14 +120,14 @@ public:
|
||||
/// has been marked in progress
|
||||
std::tuple<std::shared_ptr<WorkItem>, bool> pop()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
|
||||
auto workItem = mPendingWorkItems.front();
|
||||
mPendingWorkItems.pop_front();
|
||||
mPendingWorkItemsReqIds.erase(workItem->requestId());
|
||||
|
||||
bool markedInProgress;
|
||||
mInProgressWorkItems.emplace(std::make_pair(workItem->requestId(), workItem));
|
||||
bool markedInProgress = false;
|
||||
mInProgressWorkItems.emplace(workItem->requestId(), workItem);
|
||||
markedInProgress = true;
|
||||
|
||||
return {workItem, markedInProgress};
|
||||
@ -134,19 +135,19 @@ public:
|
||||
|
||||
size_t numPendingWorkItems() const
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
return mPendingWorkItems.size();
|
||||
}
|
||||
|
||||
size_t numInProgressWorkItems() const
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
return mInProgressWorkItems.size();
|
||||
}
|
||||
|
||||
size_t size() const
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
return mPendingWorkItems.size() + mInProgressWorkItems.size();
|
||||
}
|
||||
|
||||
@ -154,7 +155,7 @@ public:
|
||||
/// @param requestId
|
||||
void markFinished(const uint64_t requestId)
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
if (hasInProgressReqId(requestId))
|
||||
{
|
||||
mInProgressWorkItems.erase(requestId);
|
||||
@ -175,12 +176,13 @@ private:
|
||||
|
||||
struct BenchInfo
|
||||
{
|
||||
BenchInfo() {}
|
||||
BenchInfo() = default;
|
||||
|
||||
BenchInfo(int _inputLength, int _outputLength, std::chrono::time_point<std::chrono::steady_clock> _start)
|
||||
: inputLength(_inputLength)
|
||||
, outputLength(_outputLength)
|
||||
, start(_start)
|
||||
, latency()
|
||||
{
|
||||
}
|
||||
|
||||
@ -194,9 +196,9 @@ struct BenchInfo
|
||||
class Recorder
|
||||
{
|
||||
public:
|
||||
Recorder(std::string opCsvFile)
|
||||
explicit Recorder(std::string opCsvFile)
|
||||
: mOpCsvFile(std::move(opCsvFile))
|
||||
{
|
||||
mOpCsvFile = opCsvFile;
|
||||
}
|
||||
|
||||
void initialize()
|
||||
@ -286,11 +288,11 @@ private:
|
||||
|
||||
std::chrono::time_point<std::chrono::steady_clock> mStart;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mEnd;
|
||||
int mNumSamples;
|
||||
float mTotalLatency;
|
||||
float mSeqThroughput;
|
||||
float mAvgSeqLatency;
|
||||
float mTokenThroughput;
|
||||
int mNumSamples{};
|
||||
float mTotalLatency{};
|
||||
float mSeqThroughput{};
|
||||
float mAvgSeqLatency{};
|
||||
float mTokenThroughput{};
|
||||
std::string mOpCsvFile;
|
||||
}; // class Recorder
|
||||
|
||||
@ -299,25 +301,39 @@ class GptServer
|
||||
public:
|
||||
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
|
||||
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, int waitSleep)
|
||||
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
|
||||
std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
|
||||
: mRecorder(std::move(recorder))
|
||||
, mTerminateReqId(terminateReqId)
|
||||
, mWaitSleep(waitSleep)
|
||||
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
|
||||
, mEmulatedBatchEndTimestamp(
|
||||
std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs))
|
||||
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
|
||||
, mActiveCount(0)
|
||||
{
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback{nullptr};
|
||||
if (optionalParams.logIterationData)
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback = [this, &optionalParams](std::string const& log)
|
||||
{
|
||||
iterationDataCallback = [this](const std::string& s) { return TLLM_LOG_INFO(s); };
|
||||
}
|
||||
if (optionalParams.logIterationData)
|
||||
{
|
||||
TLLM_LOG_INFO(log);
|
||||
}
|
||||
|
||||
if (mStaticEmulatedBatchSize)
|
||||
{
|
||||
auto const json = nlohmann::json::parse(log);
|
||||
auto const activeRequests = json["Active Request Count"];
|
||||
TLLM_CHECK(activeRequests <= mStaticEmulatedBatchSize.value());
|
||||
}
|
||||
};
|
||||
|
||||
mBatchManager = std::make_shared<GptManager>(
|
||||
trtEnginePath, modelType, maxBeamWidth, schedulerPolicy,
|
||||
[this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
|
||||
[this](uint64_t requestId, std::list<NamedTensor> response_tensors, bool final_response,
|
||||
const std::string& errMsg)
|
||||
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
|
||||
std::string const& errMsg)
|
||||
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
|
||||
nullptr, iterationDataCallback, optionalParams, terminateReqId);
|
||||
|
||||
mRecorder = recorder;
|
||||
mTerminateReqId = terminateReqId;
|
||||
mWaitSleep = waitSleep;
|
||||
}
|
||||
|
||||
~GptServer()
|
||||
@ -353,10 +369,9 @@ public:
|
||||
|
||||
void waitForEmpty() const
|
||||
{
|
||||
while (mWorkItemsQueue.size() > 0)
|
||||
while (!mWorkItemsQueue.empty())
|
||||
{
|
||||
std::chrono::milliseconds timespan(mWaitSleep);
|
||||
std::this_thread::sleep_for(timespan);
|
||||
std::this_thread::sleep_for(mWaitSleep);
|
||||
}
|
||||
}
|
||||
|
||||
@ -365,6 +380,11 @@ public:
|
||||
mBatchManager->waitUntilTerminate();
|
||||
}
|
||||
|
||||
void shutdown() const
|
||||
{
|
||||
mBatchManager->shutdown();
|
||||
}
|
||||
|
||||
// Return up to max_num_requests inference requests.
|
||||
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
|
||||
{
|
||||
@ -376,17 +396,40 @@ public:
|
||||
auto rank = comm.getRank();
|
||||
if (rank == 0)
|
||||
{
|
||||
auto num_new_work_items = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
|
||||
auto numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
|
||||
static_cast<int64_t>(max_num_requests));
|
||||
if (world_size > 1)
|
||||
{
|
||||
comm.bcast(&num_new_work_items, 1, mpi::MpiType::kINT64, 0);
|
||||
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
}
|
||||
|
||||
if (num_new_work_items > 0)
|
||||
bool readyForNextBatch = numNewWorkItems > 0;
|
||||
if (mStaticEmulatedBatchSize)
|
||||
{
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp;
|
||||
bool const previousBatchFinished = mActiveCount == 0;
|
||||
bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
|
||||
readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
|
||||
}
|
||||
if (numNewWorkItems == 0 || readyForNextBatch)
|
||||
{
|
||||
// Timeout should only begin once we have at least 1 pending request.
|
||||
// Reset timeout when no requests are pending or we submit a new batch.
|
||||
mEmulatedBatchEndTimestamp
|
||||
= std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs);
|
||||
}
|
||||
}
|
||||
|
||||
if (readyForNextBatch)
|
||||
{
|
||||
int count = 0;
|
||||
while (count < num_new_work_items)
|
||||
// Only add a single batch at a time when emulating static batching
|
||||
auto const numItemsToAdd = std::min(
|
||||
numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
|
||||
mActiveCount += numItemsToAdd;
|
||||
while (count < numItemsToAdd)
|
||||
{
|
||||
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
|
||||
|
||||
@ -420,14 +463,14 @@ public:
|
||||
else
|
||||
{
|
||||
// subordinate ranks hang until master rank sends work
|
||||
int64_t num_new_work_items;
|
||||
comm.bcast(&num_new_work_items, 1, mpi::MpiType::kINT64, 0);
|
||||
if (num_new_work_items > 0)
|
||||
int64_t numNewWorkItems = 0;
|
||||
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
comm.bcast(packed, 0);
|
||||
int64_t* packed_ptr = packed.data();
|
||||
for (int64_t count = 0; count < num_new_work_items; ++count)
|
||||
for (int64_t count = 0; count < numNewWorkItems; ++count)
|
||||
{
|
||||
int64_t n = *(packed_ptr++);
|
||||
auto ir = InferenceRequest::deserialize(packed_ptr);
|
||||
@ -453,6 +496,7 @@ public:
|
||||
{
|
||||
mWorkItemsQueue.markFinished(requestId);
|
||||
mRecorder->recordEnd(requestId);
|
||||
mActiveCount--;
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
@ -466,15 +510,27 @@ private:
|
||||
std::shared_ptr<Recorder> mRecorder;
|
||||
WorkItemsQueue mWorkItemsQueue;
|
||||
std::optional<uint64_t> mTerminateReqId;
|
||||
int mWaitSleep;
|
||||
std::chrono::milliseconds mWaitSleep;
|
||||
std::optional<int> mStaticEmulatedBatchSize;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mEmulatedBatchEndTimestamp;
|
||||
int32_t mStaticEmulatedTimeoutMs;
|
||||
std::atomic<uint64_t> mActiveCount;
|
||||
|
||||
}; // class GptServer
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<float>> parseWorkloadJson(
|
||||
std::filesystem::path const& datasetPath, int maxNumSamples)
|
||||
struct Sample
|
||||
{
|
||||
std::vector<int32_t> inputIds;
|
||||
int32_t outputLen;
|
||||
float delay;
|
||||
};
|
||||
|
||||
using Samples = std::vector<Sample>;
|
||||
|
||||
Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSamples)
|
||||
{
|
||||
auto constexpr allowExceptions = true;
|
||||
auto constexpr ignoreComments = true;
|
||||
@ -482,37 +538,29 @@ std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<
|
||||
std::ifstream jsonStream(datasetPath);
|
||||
auto json = nlohmann::json::parse(jsonStream, nullptr, allowExceptions, ignoreComments);
|
||||
|
||||
std::vector<std::vector<int32_t>> inputIds;
|
||||
std::vector<int32_t> outputLens;
|
||||
std::vector<float> delays;
|
||||
Samples samples;
|
||||
|
||||
long int numSamples = 0;
|
||||
|
||||
for (auto& sample : json["samples"])
|
||||
for (auto const& sample : json["samples"])
|
||||
{
|
||||
if (numSamples >= maxNumSamples)
|
||||
if (samples.size() >= maxNumSamples)
|
||||
break;
|
||||
numSamples++;
|
||||
inputIds.push_back(sample["input_ids"]);
|
||||
outputLens.push_back(sample["output_len"]);
|
||||
delays.push_back(sample["delay"]);
|
||||
samples.emplace_back(Sample{sample["input_ids"], sample["output_len"], sample["delay"]});
|
||||
}
|
||||
return std::make_tuple(inputIds, outputLens, delays);
|
||||
return samples;
|
||||
}
|
||||
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
|
||||
std::tuple<std::vector<std::vector<int32_t>>, std::vector<int32_t>, std::vector<float>> const& dataset,
|
||||
std::size_t sample_idx, ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId,
|
||||
ITensor::SharedPtr const& padId, BufferManager const& bufferManager,
|
||||
ITensor::SharedPtr const& returnContextLogits = nullptr, ITensor::SharedPtr const& returnGenerationLogits = nullptr)
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample,
|
||||
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
|
||||
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
|
||||
ITensor::SharedPtr const& returnGenerationLogits = nullptr)
|
||||
{
|
||||
auto request = std::make_shared<InferenceRequest>(reqId);
|
||||
auto const& inputIds = (std::get<0>(dataset))[sample_idx];
|
||||
auto const& inputIds = sample.inputIds;
|
||||
request->setInputIds(bufferManager.copyFrom(
|
||||
inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
|
||||
auto const request_output_len = (std::get<1>(dataset))[sample_idx];
|
||||
auto const requestOutputLen = sample.outputLen;
|
||||
request->setMaxNewTokens(
|
||||
bufferManager.copyFrom(&request_output_len, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
|
||||
bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
|
||||
request->setBeamWidth(beamWidthTensor);
|
||||
if (eosId != nullptr)
|
||||
{
|
||||
@ -533,29 +581,15 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId,
|
||||
return request;
|
||||
}
|
||||
|
||||
void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::filesystem::path const& engineDir,
|
||||
std::string const& type, std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples,
|
||||
int beamWidth, int warmUp, const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, TrtGptModelOptionalParams const& optionalParams,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, int waitSleep, bool returnContextLogits,
|
||||
bool returnGenerationLogits)
|
||||
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
|
||||
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
|
||||
const std::optional<int32_t>& eosId, const std::optional<int32_t>& padId,
|
||||
TrtGptModelOptionalParams const& optionalParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi();
|
||||
|
||||
TrtGptModelType modelType;
|
||||
if (type == "V1")
|
||||
{
|
||||
modelType = TrtGptModelType::V1;
|
||||
}
|
||||
else if (type == "IFB")
|
||||
{
|
||||
modelType = TrtGptModelType::InflightFusedBatching;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
|
||||
}
|
||||
|
||||
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
|
||||
|
||||
ITensor::SharedPtr beamWidthTensor{
|
||||
@ -563,14 +597,13 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
|
||||
|
||||
// Load dataset
|
||||
const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
const auto numSamples = (std::get<0>(samples)).size();
|
||||
const auto numSamples = samples.size();
|
||||
|
||||
const int maxBeamWidth = beamWidth;
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile);
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
|
||||
auto gptServer = std::make_shared<GptServer>(
|
||||
engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder, terminateReqId, waitSleep);
|
||||
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
|
||||
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs);
|
||||
|
||||
ITensor::SharedPtr eosIdTensor{
|
||||
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
|
||||
@ -594,7 +627,7 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
|
||||
++reqId;
|
||||
if (i == terminateReqId)
|
||||
++reqId;
|
||||
auto request = makeRequest(reqId, samples, 0, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
auto request = makeRequest(reqId, samples[0], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
@ -603,11 +636,10 @@ void benchmarkGptManager([[maybe_unused]] std::string const& modelName, std::fil
|
||||
recorder->initialize();
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples, i, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
|
||||
auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
|
||||
returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
|
||||
|
||||
gptServer->enqueue(request);
|
||||
auto delayInMs = int(std::get<2>(samples)[i] * 1000);
|
||||
auto delayInMs = static_cast<int>(samples[i].delay * 1000);
|
||||
|
||||
if (delayInMs != 0)
|
||||
{
|
||||
@ -635,6 +667,7 @@ int main(int argc, char* argv[])
|
||||
cxxopts::Options options(
|
||||
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
|
||||
options.add_options()("h,help", "Print usage");
|
||||
// TODO(rkobus): remove because unused
|
||||
options.add_options()(
|
||||
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
|
||||
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
|
||||
@ -668,6 +701,11 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
|
||||
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
|
||||
|
||||
options.add_options()("static_emulated_batch_size",
|
||||
"Emulate static batching performance with the provided batch size.", cxxopts::value<int>());
|
||||
options.add_options()("static_emulated_timeout",
|
||||
"Timeout (ms) before launching a partial batch in emulated static batching mode",
|
||||
cxxopts::value<int>()->default_value("500"));
|
||||
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
|
||||
cxxopts::value<std::string>()->default_value("error"));
|
||||
options.add_options()(
|
||||
@ -693,6 +731,20 @@ int main(int argc, char* argv[])
|
||||
|
||||
// Argument: Batching Type
|
||||
auto const type = result["type"].as<std::string>();
|
||||
TrtGptModelType modelType{TrtGptModelType::V1};
|
||||
if (type == "V1")
|
||||
{
|
||||
modelType = TrtGptModelType::V1;
|
||||
}
|
||||
else if (type == "IFB")
|
||||
{
|
||||
modelType = TrtGptModelType::InflightFusedBatching;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("Unexpected batching type: %s", type.c_str());
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: Dataset
|
||||
auto const datasetPath = result["dataset"].as<std::string>();
|
||||
@ -705,7 +757,7 @@ int main(int argc, char* argv[])
|
||||
auto const beamWidth = result["beam_width"].as<int>();
|
||||
|
||||
// Argument: wait_sleep
|
||||
auto const waitSleep = result["wait_sleep"].as<int>();
|
||||
auto const waitSleep = std::chrono::milliseconds(result["wait_sleep"].as<int>());
|
||||
|
||||
TrtGptModelOptionalParams optionalParams;
|
||||
// Argument: Max tokens in paged K-V Cache
|
||||
@ -765,6 +817,14 @@ int main(int argc, char* argv[])
|
||||
eosId = result["eos_id"].as<int>();
|
||||
}
|
||||
|
||||
std::optional<int> staticEmulatedBatchSize;
|
||||
// Argument: Static emulated batch size
|
||||
if (result.count("static_emulated_batch_size"))
|
||||
{
|
||||
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<int>();
|
||||
}
|
||||
auto const staticEmulatedTimeout = result["static_emulated_timeout"].as<int>();
|
||||
|
||||
// Argument: Scheduler policy
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy;
|
||||
auto const schedulerPolicyArg = result["scheduler_policy"].as<std::string>();
|
||||
@ -815,9 +875,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
benchmarkGptManager(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), type,
|
||||
datasetPath, opCsvFile, maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, logger,
|
||||
optionalParams, schedulerPolicy, waitSleep, returnContextLogits, returnGenerationLogits);
|
||||
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile, maxNumSamples,
|
||||
beamWidth, result["warm_up"].as<int>(), eosId, padId, optionalParams, schedulerPolicy, waitSleep,
|
||||
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
{
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@ -704,6 +704,7 @@ _allowed_configs = {
|
||||
hidden_act="relu",
|
||||
n_positions=512,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1024,
|
||||
max_decoder_input_len=1,
|
||||
@ -724,6 +725,7 @@ _allowed_configs = {
|
||||
hidden_act="relu",
|
||||
n_positions=512,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1024,
|
||||
max_decoder_input_len=1,
|
||||
@ -744,6 +746,7 @@ _allowed_configs = {
|
||||
hidden_act="relu",
|
||||
n_positions=512,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1024,
|
||||
max_decoder_input_len=1,
|
||||
@ -764,6 +767,7 @@ _allowed_configs = {
|
||||
hidden_act="relu",
|
||||
n_positions=512,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1024,
|
||||
max_decoder_input_len=1,
|
||||
@ -784,6 +788,7 @@ _allowed_configs = {
|
||||
hidden_act="relu",
|
||||
n_positions=512,
|
||||
num_buckets=32,
|
||||
max_distance=128,
|
||||
max_batch_size=8,
|
||||
max_encoder_input_len=1024,
|
||||
max_decoder_input_len=1,
|
||||
@ -947,7 +952,7 @@ _allowed_configs = {
|
||||
)),
|
||||
"baichuan_7b":
|
||||
ModelConfig(name="baichuan_7b",
|
||||
family="baichuan_7b",
|
||||
family="baichuan",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=32,
|
||||
@ -964,7 +969,7 @@ _allowed_configs = {
|
||||
)),
|
||||
"baichuan2_7b_chat":
|
||||
ModelConfig(name="baichuan2_7b_chat",
|
||||
family="baichuan_7b",
|
||||
family="baichuan",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=32,
|
||||
@ -981,7 +986,7 @@ _allowed_configs = {
|
||||
)),
|
||||
"baichuan_13b_chat":
|
||||
ModelConfig(name="baichuan_13b_chat",
|
||||
family="baichuan_13b",
|
||||
family="baichuan",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=40,
|
||||
@ -998,7 +1003,7 @@ _allowed_configs = {
|
||||
)),
|
||||
"baichuan2_13b_chat":
|
||||
ModelConfig(name="baichuan2_13b_chat",
|
||||
family="baichuan_13b",
|
||||
family="baichuan",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=40,
|
||||
|
||||
@ -613,36 +613,42 @@ def build_gpt(args):
|
||||
})
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.FalconForCausalLM(config)
|
||||
elif family == "baichuan_7b":
|
||||
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=None,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
position_embedding_type=PositionEmbeddingType.rope_gpt_neox,
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=build_config['inter_size'],
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size),
|
||||
quant_mode=quant_mode)
|
||||
elif family == "baichuan_13b":
|
||||
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(
|
||||
num_layers=build_config['num_layers'],
|
||||
num_heads=build_config['num_heads'],
|
||||
num_kv_heads=None,
|
||||
hidden_size=build_config['hidden_size'],
|
||||
vocab_size=build_config['vocab_size'],
|
||||
hidden_act=build_config['hidden_act'],
|
||||
max_position_embeddings=build_config['n_positions'],
|
||||
position_embedding_type=PositionEmbeddingType.alibi,
|
||||
dtype=kv_dtype,
|
||||
mlp_hidden_size=build_config['inter_size'],
|
||||
mapping=tensorrt_llm.Mapping(world_size=world_size,
|
||||
tp_size=world_size),
|
||||
quant_mode=quant_mode)
|
||||
elif family == "baichuan":
|
||||
config = {
|
||||
'architecture':
|
||||
'BaichuanForCausalLM',
|
||||
'dtype':
|
||||
args.dtype,
|
||||
'logits_dtype':
|
||||
'float32',
|
||||
'vocab_size':
|
||||
build_config['vocab_size'],
|
||||
'max_position_embeddings':
|
||||
build_config['n_positions'],
|
||||
'hidden_size':
|
||||
build_config['hidden_size'],
|
||||
'num_hidden_layers':
|
||||
build_config['num_layers'],
|
||||
'num_attention_heads':
|
||||
build_config['num_heads'],
|
||||
'num_key_value_heads':
|
||||
build_config['num_heads'],
|
||||
'hidden_act':
|
||||
build_config['hidden_act'],
|
||||
'intermediate_size':
|
||||
build_config['inter_size'],
|
||||
'position_embedding_type':
|
||||
'alibi_with_scale' if '7b' in args.model else 'rope_gpt_neox',
|
||||
'quantization': {
|
||||
'group_size': 128
|
||||
},
|
||||
'mapping': {
|
||||
'world_size': world_size,
|
||||
'tp_size': world_size,
|
||||
},
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
|
||||
elif family == "internlm":
|
||||
config = {
|
||||
'architecture':
|
||||
@ -748,6 +754,8 @@ def build_gpt(args):
|
||||
network.plugin_config.set_context_fmha(ContextFMHAType.enabled)
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_moe_plugin(dtype=args.dtype)
|
||||
|
||||
if args.quantization is None or "fp8" not in args.quantization:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
|
||||
|
||||
@ -787,7 +795,7 @@ def build_gpt(args):
|
||||
max_beam_width=max_beam_width)
|
||||
if family in [
|
||||
'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
|
||||
'gptj', "mamba"
|
||||
'gptj', 'mamba', 'baichuan'
|
||||
]:
|
||||
tensorrt_llm_model(**inputs)
|
||||
else:
|
||||
@ -1013,6 +1021,8 @@ def enc_dec_build_helper(component, config, args):
|
||||
tp_size=world_size,
|
||||
pp_size=1) # TP only
|
||||
|
||||
fp16_clamping = (args.dtype == 'float16') and ('t5' in family)
|
||||
|
||||
if component == 'encoder':
|
||||
tllm_model = tensorrt_llm.models.EncoderModel(
|
||||
num_layers=config['num_layers'],
|
||||
@ -1040,7 +1050,8 @@ def enc_dec_build_helper(component, config, args):
|
||||
dtype=dtype,
|
||||
use_parallel_embedding=False, # by default
|
||||
embedding_sharding_dim=0, # by default
|
||||
mapping=mapping)
|
||||
mapping=mapping,
|
||||
fp16_clamping=fp16_clamping)
|
||||
elif component == 'decoder':
|
||||
tllm_model = tensorrt_llm.models.DecoderModel(
|
||||
num_layers=config['num_layers'],
|
||||
@ -1073,7 +1084,8 @@ def enc_dec_build_helper(component, config, args):
|
||||
embedding_sharding_dim=0, # by default
|
||||
mapping=mapping,
|
||||
rescale_before_lm_head=rescale_before_lm_head,
|
||||
logits_dtype='float32') # by default
|
||||
logits_dtype='float32', # by default
|
||||
fp16_clamping=fp16_clamping)
|
||||
|
||||
# Module -> Network
|
||||
engine_name = get_engine_name(args.model, args.dtype, world_size,
|
||||
@ -1129,6 +1141,7 @@ def enc_dec_build_helper(component, config, args):
|
||||
hidden_size=hidden_size,
|
||||
head_size=builder_config.head_size,
|
||||
max_batch_size=builder_config.max_batch_size,
|
||||
max_beam_width=builder_config.max_beam_width,
|
||||
vocab_size=builder_config.vocab_size,
|
||||
num_layers=builder_config.num_layers,
|
||||
gpt_attention_plugin=network.plugin_config.gpt_attention_plugin,
|
||||
|
||||
@ -89,6 +89,7 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
hidden_size=hidden_size,
|
||||
head_size=config["builder_config"]["head_size"],
|
||||
max_batch_size=config["builder_config"]["max_batch_size"],
|
||||
max_beam_width=config["builder_config"]["max_beam_width"],
|
||||
vocab_size=config["builder_config"]["vocab_size"],
|
||||
num_layers=config["builder_config"]["num_layers"],
|
||||
gpt_attention_plugin=config["plugin_config"]
|
||||
|
||||
@ -95,6 +95,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
if args.mode == 'plugin':
|
||||
self.use_gpt_attention_plugin = True
|
||||
self.remove_input_padding = True
|
||||
self.use_moe_plugin = True
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
self.use_gpt_attention_plugin = True
|
||||
|
||||
@ -110,6 +111,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.num_kv_heads = self.num_heads
|
||||
model_config = tensorrt_llm.runtime.ModelConfig(
|
||||
max_batch_size=self.max_batch_size,
|
||||
max_beam_width=self.num_beams,
|
||||
vocab_size=self.vocab_size,
|
||||
num_layers=self.num_layers,
|
||||
num_heads=self.num_heads // self.world_size,
|
||||
|
||||
@@ -172,8 +172,12 @@ get_filename_component(TRT_LLM_ROOT_DIR ${CMAKE_SOURCE_DIR} PATH)

set(3RDPARTY_DIR ${TRT_LLM_ROOT_DIR}/3rdparty)
include_directories(
  ${CUDAToolkit_INCLUDE_DIRS} ${CUDNN_ROOT_DIR}/include ${NCCL_INCLUDE_DIR}
  ${3RDPARTY_DIR}/cutlass/include ${3RDPARTY_DIR}/NVTX/include
  ${CUDAToolkit_INCLUDE_DIRS}
  ${CUDNN_ROOT_DIR}/include
  ${NCCL_INCLUDE_DIR}
  ${3RDPARTY_DIR}/cutlass/include
  ${3RDPARTY_DIR}/cutlass/tools/util/include
  ${3RDPARTY_DIR}/NVTX/include
  ${3RDPARTY_DIR}/json/include)
|
||||
|
||||
# TRT dependencies
|
||||
|
||||
@ -72,6 +72,8 @@ public:
|
||||
|
||||
BatchManagerErrorCode_t shutdown();
|
||||
|
||||
SizeType getNumActiveRequests();
|
||||
|
||||
virtual ~GptManager();
|
||||
|
||||
protected:
|
||||
|
||||
@ -354,6 +354,11 @@ public:
|
||||
mDraftLogits = draftLogits;
|
||||
}
|
||||
|
||||
SizeType getNumDraftTokens() const
|
||||
{
|
||||
return mDraftTokens->size();
|
||||
}
|
||||
|
||||
void setReturnContextLogits(const bool returnContextLogits)
|
||||
{
|
||||
mReturnContextLogits = returnContextLogits;
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
@ -35,13 +36,15 @@ public:
|
||||
|
||||
explicit TrtGptModelOptionalParams(KvCacheConfig const& kvCacheConfig = KvCacheConfig{},
|
||||
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt,
|
||||
bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false)
|
||||
bool normalizeLogProbs = true, bool logIterationData = false, bool enableChunkedContext = false,
|
||||
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt)
|
||||
: kvCacheConfig{kvCacheConfig}
|
||||
, enableTrtOverlap{enableTrtOverlap}
|
||||
, deviceIds(deviceIds)
|
||||
, normalizeLogProbs{normalizeLogProbs}
|
||||
, logIterationData{logIterationData}
|
||||
, enableChunkedContext{enableChunkedContext}
|
||||
, decodingMode{decodingMode}
|
||||
{
|
||||
}
|
||||
|
||||
@ -51,6 +54,7 @@ public:
|
||||
bool normalizeLogProbs;
|
||||
bool logIterationData;
|
||||
bool enableChunkedContext;
|
||||
std::optional<runtime::DecodingMode> decodingMode;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -21,6 +21,7 @@
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
@ -36,6 +37,8 @@ public:
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, sinkTokenLength{sinkTokenLength}
|
||||
, maxBatchSize{maxBatchSize}
|
||||
, maxStopWordsLen{0}
|
||||
, maxBadWordsLen{0}
|
||||
, logits{std::move(logits)}
|
||||
, endIds{std::move(endIds)}
|
||||
{
|
||||
@ -49,20 +52,28 @@ public:
|
||||
SizeType maxAttentionWindow;
|
||||
SizeType sinkTokenLength;
|
||||
SizeType maxBatchSize;
|
||||
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
|
||||
TensorPtr endIds; // [batchSize * beamWidth], on gpu
|
||||
SizeType maxStopWordsLen; // The maximum value in the `stopWordsLens` tensor
|
||||
SizeType maxBadWordsLen; // The maximum value in the `badWordsLens` tensor
|
||||
TensorPtr logits; // [batchSize, beamWidth, vocabSizePadded], on gpu
|
||||
std::optional<std::vector<TensorPtr>>
|
||||
logitsVec; // vector of size [batchSize] contains logits of size [beamWidth, vocabSizePadded], on gpu
|
||||
TensorPtr endIds; // [maxBatchSize * beamWidth], on gpu
|
||||
|
||||
// optional parameters
|
||||
TensorPtr finished; // [maxBatchSize, beamWidth], finished states at current iteration.
|
||||
// If true for some request, the decoding step of it is skipped, on gpu
|
||||
TensorPtr sequenceLimitLength; // [maxBatchSize], on gpu
|
||||
TensorPtr embeddingBias; // [vocabSizePadded], on gpu
|
||||
TensorPtr embeddingBias; // [maxBatchSize, vocabSizePadded], on gpu
|
||||
TensorPtr lengths; // [maxBatchSize, beamWidth], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength] or [batchSize, 2, badWordsLength], on gpu
|
||||
TensorPtr badWordsList; // [2, badWordsLength] or [maxBatchSize, 2, badWordsLength], on gpu
|
||||
TensorPtr badWordsPtrs; // [maxBatchSize][2, badWordsLength], on gpu
|
||||
TensorPtr badWordsLens; // [maxBatchSize], on gpu
|
||||
TensorPtr stopWordsList; // [maxBatchSize, 2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsPtrs; // [maxBatchSize][2, stopWordsLength], on gpu
|
||||
TensorPtr stopWordsLens; // [maxBatchSize], on gpu
|
||||
TensorPtr noRepeatNgramSize; // [maxBatchSize], on gpu
|
||||
TensorPtr
|
||||
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, on gpu
|
||||
batchSlots; // [batchSize], optional, address map of the linear batch id to to the seq slots, int32_t, pinned
|
||||
|
||||
// parameters for beam search
|
||||
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
|
||||
|
||||
cpp/include/tensorrt_llm/runtime/decodingMode.h (new file, 138 lines)
@ -0,0 +1,138 @@
|
||||
/*
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
|
||||
class DecodingMode
|
||||
{
|
||||
public:
|
||||
static auto constexpr None()
|
||||
{
|
||||
return DecodingMode{kNone};
|
||||
}
|
||||
|
||||
static auto constexpr TopK()
|
||||
{
|
||||
return DecodingMode{kTopK};
|
||||
}
|
||||
|
||||
static auto constexpr TopP()
|
||||
{
|
||||
return DecodingMode{kTopP};
|
||||
}
|
||||
|
||||
static auto constexpr TopKTopP()
|
||||
{
|
||||
return DecodingMode{kTopKTopP};
|
||||
}
|
||||
|
||||
static auto constexpr BeamSearch()
|
||||
{
|
||||
return DecodingMode{kBeamSearch};
|
||||
}
|
||||
|
||||
bool constexpr isNone()
|
||||
{
|
||||
return mState == 0;
|
||||
}
|
||||
|
||||
bool constexpr isTopK()
|
||||
{
|
||||
return anyBitSet(kTopK);
|
||||
}
|
||||
|
||||
bool constexpr isTopP()
|
||||
{
|
||||
return anyBitSet(kTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKorTopP()
|
||||
{
|
||||
return anyBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isTopKandTopP()
|
||||
{
|
||||
return allBitSet(kTopKTopP);
|
||||
}
|
||||
|
||||
bool constexpr isBeamSearch()
|
||||
{
|
||||
return anyBitSet(kBeamSearch);
|
||||
}
|
||||
|
||||
using UnderlyingType = uint8_t;
|
||||
|
||||
private:
|
||||
constexpr DecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
{
|
||||
}
|
||||
|
||||
// No mode specified. Config will be determined from the beam width of the first request at runtime
|
||||
// TopKTopP if beamWidth == 1, BeamSearch otherwise
|
||||
static UnderlyingType constexpr kNone{0};
|
||||
static UnderlyingType constexpr kTopK{1u << 0};
|
||||
static UnderlyingType constexpr kTopP{1u << 1};
|
||||
static UnderlyingType constexpr kBeamSearch{1u << 2};
|
||||
static UnderlyingType constexpr kTopKTopP{kTopK | kTopP};
|
||||
|
||||
bool constexpr anyBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) != 0;
|
||||
}
|
||||
|
||||
bool constexpr allBitSet(UnderlyingType bits) const
|
||||
{
|
||||
return (mState & bits) == bits;
|
||||
}
|
||||
|
||||
UnderlyingType mState{};
|
||||
};
|
||||
|
||||
static_assert(DecodingMode::None().isNone());
|
||||
static_assert(!DecodingMode::None().isTopK());
|
||||
static_assert(!DecodingMode::None().isTopP());
|
||||
static_assert(!DecodingMode::None().isBeamSearch());
|
||||
|
||||
static_assert(DecodingMode::TopK().isTopK());
|
||||
static_assert(DecodingMode::TopK().isTopKorTopP());
|
||||
static_assert(!DecodingMode::TopK().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopK().isTopP());
|
||||
static_assert(!DecodingMode::TopK().isBeamSearch());
|
||||
|
||||
static_assert(DecodingMode::TopP().isTopP());
|
||||
static_assert(DecodingMode::TopP().isTopKorTopP());
|
||||
static_assert(!DecodingMode::TopP().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopP().isTopK());
|
||||
static_assert(!DecodingMode::TopP().isBeamSearch());
|
||||
|
||||
static_assert(DecodingMode::TopKTopP().isTopK());
|
||||
static_assert(DecodingMode::TopKTopP().isTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKorTopP());
|
||||
static_assert(DecodingMode::TopKTopP().isTopKandTopP());
|
||||
static_assert(!DecodingMode::TopKTopP().isBeamSearch());
|
||||
|
||||
static_assert(DecodingMode::BeamSearch().isBeamSearch());
|
||||
static_assert(!DecodingMode::BeamSearch().isTopKorTopP());
|
||||
|
||||
} // namespace runtime
|
||||
} // namespace tensorrt_llm
|
||||
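
For a sense of how these bit flags compose, a small illustrative usage sketch (the `main` wrapper is not part of the commit; the header path is the new file above):

```
#include <iostream>

#include "tensorrt_llm/runtime/decodingMode.h"

// Illustrative only: the combined TopKTopP mode satisfies both single-mode
// queries as well as the "and" query, but not beam search.
int main()
{
    using tensorrt_llm::runtime::DecodingMode;

    auto mode = DecodingMode::TopKTopP();
    std::cout << std::boolalpha;
    std::cout << "isTopK:        " << mode.isTopK() << '\n';        // true
    std::cout << "isTopP:        " << mode.isTopP() << '\n';        // true
    std::cout << "isTopKandTopP: " << mode.isTopKandTopP() << '\n'; // true
    std::cout << "isBeamSearch:  " << mode.isBeamSearch() << '\n';  // false
    return 0;
}
```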
@ -75,7 +75,7 @@ public:
|
||||
// Set to true by decoding if any of the stop conditions are met or if DecodingInput.finished is
|
||||
// true. In beam search and to determine whether to stop according to
|
||||
// DecodingInput.sequenceLimitLength, on gpu
|
||||
TensorPtr finishedSum; // [1], the sum of finished sequences, in pinned memory
|
||||
TensorPtr finishedSum; // [batchSize], the sum of finished sequences per request, in pinned memory
|
||||
|
||||
// mandatory parameters for beam search
|
||||
TensorPtr logProbs; // [batchSize, beamWidth, maxSeqLen], must be float*, on gpu
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "tensorrt_llm/common/cudaAllocator.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include <curand_kernel.h>
|
||||
@ -44,9 +45,13 @@ namespace runtime
|
||||
class IGptDecoder
|
||||
{
|
||||
public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
virtual ~IGptDecoder() = default;
|
||||
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) = 0;
|
||||
virtual void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt)
|
||||
= 0;
|
||||
|
||||
virtual bool forward(DecodingOutput& output, DecodingInput const& input) = 0;
|
||||
|
||||
@ -58,18 +63,19 @@ public:
|
||||
|
||||
virtual const SamplingConfig& getSamplingConfig() = 0;
|
||||
|
||||
static void acceptDraftTokensByIds(const ITensor& targetTokenIds, const ITensor& draftTokenIds,
|
||||
const ITensor& contextLengths, const ITensor& numDraftTokens, ITensor& sequenceLengths,
|
||||
const ITensor& finishedVec, ITensor& finishedFinal, ITensor& finishedSum,
|
||||
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
|
||||
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
|
||||
ITensor const& finishedVec, ITensor& finishedFinal, ITensor& finishedSum, ITensor const& batchSlots,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static void acceptDraftTokensByLogits(ITensor& draftLogits, const ITensor& targetLogits, ITensor& draftProbs,
|
||||
ITensor& targetProbs, const ITensor& numDraftTokens, ITensor& finished, SizeType vocabSize,
|
||||
SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
static void acceptDraftTokensByLogits(ITensor& draftLogits, ITensor const& targetLogits, ITensor& draftProbs,
|
||||
ITensor& targetProbs, ITensor const& numDraftTokens, ITensor& finished, ITensor const& batchSlots,
|
||||
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(nvinfer1::DataType dtype, size_t maxBatchSize, size_t vocabSize,
|
||||
size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream);
|
||||
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
|
||||
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -80,9 +86,11 @@ public:
|
||||
using CudaStreamPtr = BufferManager::CudaStreamPtr;
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
GptDecoder(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream);
|
||||
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
|
||||
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream);
|
||||
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength) override;
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
|
||||
|
||||
bool forward(DecodingOutput& output, DecodingInput const& input) override;
|
||||
|
||||
@ -105,15 +113,18 @@ private:
|
||||
SamplingConfig mSamplingConfig;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(nvinfer1::DataType dtype, size_t maxBatchSize, size_t vocabSize,
|
||||
size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream)
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream)
|
||||
{
|
||||
switch (dtype)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return std::make_unique<GptDecoder<float>>(maxBatchSize, vocabSize, vocabSizePadded, stream);
|
||||
return std::make_unique<GptDecoder<float>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return std::make_unique<GptDecoder<half>>(maxBatchSize, vocabSize, vocabSizePadded, stream);
|
||||
return std::make_unique<GptDecoder<half>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
@ -41,20 +41,21 @@ class GptDecoderBatch : public IGptDecoderBatch
|
||||
public:
|
||||
using CudaStreamPtr = std::shared_ptr<CudaStream>;
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
using SharedConstPtr = ITensor::SharedConstPtr;
|
||||
|
||||
GptDecoderBatch(std::size_t vocabSize, std::size_t vocabSizePadded, CudaStreamPtr stream);
|
||||
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
|
||||
SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype) override;
|
||||
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
void newRequest(
|
||||
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig) override;
|
||||
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
|
||||
nvinfer1::DataType dtype) override;
|
||||
|
||||
void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
|
||||
|
||||
void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
|
||||
std::vector<SamplingConfig> const& samplingConfigs) override;
|
||||
|
||||
TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) override;
|
||||
|
||||
void forwardSync(decoder_batch::Token const& e) override;
|
||||
@ -161,6 +162,9 @@ private:
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
CudaEvent postProcessRequest(SizeType batchIdx) const;
|
||||
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
private:
|
||||
std::size_t const mVocabSize;
|
||||
std::size_t const mVocabSizePadded;
|
||||
@ -180,8 +184,6 @@ private:
|
||||
DecodingInputPtr mJointDecodingInput;
|
||||
DecodingOutputPtr mJointDecodingOutput;
|
||||
|
||||
std::vector<TensorPtr> mDraftTokenIds;
|
||||
std::vector<TensorPtr> mDraftLogits;
|
||||
std::vector<bool> mAcceptByLogits;
|
||||
TensorPtr mNumDraftTokens;
|
||||
TensorPtr mCurandStates;
|
||||
@ -193,16 +195,28 @@ private:
|
||||
std::vector<SizeType> mBeamWidths;
|
||||
std::vector<SizeType> mGeneratedTokensPerStep;
|
||||
|
||||
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
|
||||
// for each generated token of maxTokensPerStep, on gpu
|
||||
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
|
||||
// for each generated token of maxTokensPerStep, on gpu
|
||||
TensorPtr mDraftProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
TensorPtr mTargetProbs; // [batchSize, maxDraftTokens+1, beamWidth, vocabPadded], temporary data for speculative
|
||||
// decoding accept by logits kernel, on gpu
|
||||
TensorPtr mDraftTokenIds; // [batchSize, maxDraftTokens+1], draft token indices, on gpu
|
||||
TensorPtr mDraftLogits; // [batchSize, maxDraftTokens+1, vocabSizePadded], draft token logits, on gpu
|
||||
|
||||
TensorPtr mBatchSlotsSetup; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsDecoder; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxAttentionWindow{};
|
||||
SizeType mSinkTokenLength{};
|
||||
SizeType mActualBatchSize{};
|
||||
SizeType mMaxTokensPerStep{};
|
||||
SizeType mMaxStopWordsLen{};
|
||||
SizeType mMaxBadWordsLen{};
|
||||
|
||||
bool mFusedDecoder{false};
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -19,6 +19,7 @@
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/loraModule.h"
|
||||
#include "tensorrt_llm/runtime/medusaModule.h"
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -62,6 +63,7 @@ public:
|
||||
, mPagedContextFMHA(false)
|
||||
, mUseLoraPlugin(false)
|
||||
, mMlpHiddenSize(0)
|
||||
, mMedusaModule(std::nullopt)
|
||||
{
|
||||
}
|
||||
|
||||
@ -341,6 +343,31 @@ public:
|
||||
mMlpHiddenSize = mlpHiddenSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getMaxLoraRank() const noexcept
|
||||
{
|
||||
return mMaxLoraRank;
|
||||
}
|
||||
|
||||
void constexpr setMaxLoraRank(SizeType maxLoraRank) noexcept
|
||||
{
|
||||
mMaxLoraRank = maxLoraRank;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr useMedusa() const noexcept
|
||||
{
|
||||
return mMedusaModule.has_value();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<MedusaModule> getMedusaModule() const noexcept
|
||||
{
|
||||
return mMedusaModule;
|
||||
}
|
||||
|
||||
void setMedusaModule(MedusaModule const& medusaModule) noexcept
|
||||
{
|
||||
mMedusaModule = medusaModule;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType mVocabSize;
|
||||
SizeType mNbLayers;
|
||||
@ -374,5 +401,8 @@ private:
|
||||
bool mUseLoraPlugin;
|
||||
std::vector<LoraModule> mLoraModules;
|
||||
SizeType mMlpHiddenSize;
|
||||
SizeType mMaxLoraRank;
|
||||
|
||||
std::optional<MedusaModule> mMedusaModule;
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -20,6 +20,7 @@
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaEvent.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/generationInput.h"
|
||||
#include "tensorrt_llm/runtime/generationOutput.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
@ -90,6 +91,7 @@ public:
|
||||
KvCacheConfig kvCacheConfig{};
|
||||
std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
|
||||
std::optional<SizeType> genMicroBatchSize = std::nullopt;
|
||||
std::optional<DecodingMode> decodingMode = std::nullopt;
|
||||
};
|
||||
|
||||
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
@ -146,7 +148,8 @@ private:
|
||||
void createContexts();
|
||||
void createBuffers(SizeType numMicroBatches);
|
||||
void createDecoders(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow, SizeType sinkTokenLength,
|
||||
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches);
|
||||
SizeType maxSequenceLength, nvinfer1::DataType logitsType, bool decoderPerRequest, SizeType numMicroBatches,
|
||||
DecodingMode const& decodingMode);
|
||||
void createKvCacheManager(SizeType batchSize, SizeType beamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, SizeType maxSequenceLength, KvCacheConfig const& config);
|
||||
void createCustomAllReduceWorkspace(SizeType batchSize, SizeType beamWidth, SizeType maxSequenceLength);
|
||||
|
||||
@ -141,11 +141,6 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
using TokenPtr = std::unique_ptr<decoder_batch::Token const>;
|
||||
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
virtual void newRequest(
|
||||
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig)
|
||||
= 0;
|
||||
|
||||
//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
|
||||
virtual TokenPtr forwardAsync(decoder_batch::Output& output, decoder_batch::Input const& input) = 0;
|
||||
|
||||
@ -186,6 +181,11 @@ public:
|
||||
|
||||
virtual std::vector<SizeType> getNbSteps() const = 0;
|
||||
|
||||
//! @brief Initialize batched decoder at seqSlots with a new `requests`.
|
||||
virtual void newRequests(std::vector<SizeType> const& seqSlots, std::vector<decoder_batch::Request> const& requests,
|
||||
std::vector<SamplingConfig> const& samplingConfigs)
|
||||
= 0;
|
||||
|
||||
protected:
|
||||
IGptDecoderBatch() = default;
|
||||
};
|
||||
|
||||
@ -17,6 +17,7 @@
#pragma once

#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/generationInput.h"
#include "tensorrt_llm/runtime/generationOutput.h"
#include "tensorrt_llm/runtime/iTensor.h"
@ -74,8 +75,9 @@ public:
using TensorPtr = std::shared_ptr<ITensor>;

//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
virtual void setup(SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, nvinfer1::DataType dtype)
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
bool fusedDecoder, nvinfer1::DataType dtype)
= 0;

//! @brief Initialize the decoder with new batch of inputs.

@ -26,17 +26,67 @@ namespace tensorrt_llm::runtime

class SamplingConfig
{
private:
using FloatType = float;

template <typename T>
using OptVec = std::optional<std::vector<T>>;

private:
template <typename T>
static OptVec<T> fuseValues(
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
{
std::vector<T> values;
auto const hasValues = accessor(0).has_value();
for (size_t ci = 0; ci < configs.size(); ++ci)
{
const auto& configValue = accessor(ci);
TLLM_CHECK(hasValues == configValue.has_value());
if (hasValues)
{
TLLM_CHECK(configValue.value().size() == 1);
values.push_back(configValue.value().front());
}
}

if (!hasValues)
{
return std::nullopt;
}
return std::make_optional<std::vector<T>>(values);
}

public:
explicit SamplingConfig(SizeType beamWidth = 1)
: beamWidth{beamWidth}
{
}

explicit SamplingConfig(std::vector<SamplingConfig> const& configs)
{
TLLM_CHECK(configs.size() > 0);
beamWidth = configs.front().beamWidth;
temperature = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].temperature; });
minLength = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].minLength; });
repetitionPenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].repetitionPenalty; });
presencePenalty
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].presencePenalty; });
topK = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topK; });
topP = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topP; });
randomSeed = fuseValues<uint64_t>(configs, [&configs](SizeType ci) { return configs[ci].randomSeed; });
topPDecay = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPDecay; });
topPMin = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].topPMin; });
topPResetIds = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].topPResetIds; });
beamSearchDiversityRate
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].beamSearchDiversityRate; });
lengthPenalty = fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].lengthPenalty; });
draftAcceptanceThreshold
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
}

public:
SizeType beamWidth;

OptVec<FloatType> temperature; // [1] or [batch_size] on cpu

@ -138,6 +138,7 @@ set(TRTLLM_LINK_LIBS
${TRT_LIB}
common_src
kernels_src
cutlass_src
layers_src
runtime_src)


@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6a8b82a255fc93e99bfca1bb9975f8ac524a980e25c6678fbed0e64b7d8e1841
size 1949506
@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0bfef429b1985539c3956ada86c0578ad9b783c6b79d0c5123e7e23a18f3356b
size 1966228
@ -1,3 +0,0 @@
86bf72386b323b73b0fd95f564270c8b libtensorrt_llm_batch_manager_static.a
93e03895d79092f5bf81a4233078d0b3 libtensorrt_llm_batch_manager_static.pre_cxx11.a
b3fa820622b86294b498b661362a06ec386a6e1b commit
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a94a642407d43a81d7c8221b4158bd6958c6fb57c3c1c39446e15ac8471f7b41
size 1897882
oid sha256:0268f64b0c2540e07bf05ad458f7aa33c9d6e65fc4f5c85cd8d0946d658ffeb8
size 2092012

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c1e32ebb36f74e6b971fe8898b9a1d1763332f5b2cdbf089443471de6087b12
size 1871190
oid sha256:89ae0be676e7aa9b562f6745636f7d77198f87b83ec6295aff74273767e4fca7
size 2071180

@ -1,2 +1,2 @@
ba7b3bdcc6754724cc9405c2699b1900 libtensorrt_llm_batch_manager_static.a
b7bf0d41e6bde342352c6a5eee2d5ad0 libtensorrt_llm_batch_manager_static.pre_cxx11.a
63c3f64faa14f9d5d66b7e186a6cc80b libtensorrt_llm_batch_manager_static.a
dbcc1bbe80d977c1655d32ef69b36578 libtensorrt_llm_batch_manager_static.pre_cxx11.a

@ -25,7 +25,7 @@ namespace tensorrt_llm::common::nvtx
{
inline nvtx3::color nextColor()
{
#if !defined(NVTX_DISABLE)
#ifndef NVTX_DISABLE
constexpr std::array kColors{nvtx3::color{0xff00ff00}, nvtx3::color{0xff0000ff}, nvtx3::color{0xffffff00},
nvtx3::color{0xffff00ff}, nvtx3::color{0xff00ffff}, nvtx3::color{0xffff0000}, nvtx3::color{0xffffffff}};
constexpr auto numColors = kColors.size();

@ -25,7 +25,7 @@ namespace tensorrt_llm
namespace cutlass_extensions
{

template <typename GemmKernel>
template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{

@ -39,7 +39,14 @@ inline int compute_occupancy_for_kernel()
tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
tensorrt_llm::common::check_cuda_error(
cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
}
if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
{
// This should mean that
@ -51,8 +58,17 @@ inline int compute_occupancy_for_kernel()
}

int max_active_blocks = -1;
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
if constexpr (enable_cutlass_3x)
{
tensorrt_llm::common::check_cuda_error(
cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks, cutlass::device_kernel<GemmKernel>,
128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups), smem_size));
}
else
{
tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
}

return max_active_blocks;
}

@ -251,7 +251,6 @@ struct GemmFpAIntB
|
||||
CUTLASS_HOST_DEVICE
|
||||
static Status can_implement(Arguments const& args)
|
||||
{
|
||||
|
||||
static int const kAlignmentA
|
||||
= (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<32>>::value) ? 32
|
||||
: (platform::is_same<typename Mma::IteratorA::Layout, layout::ColumnMajorInterleaved<64>>::value)
|
||||
@ -340,19 +339,6 @@ struct GemmFpAIntB
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The dummy template parameter is not used and exists so that we can compile this code using
|
||||
// a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in
|
||||
// a namespace
|
||||
template <bool B, typename dummy = void>
|
||||
struct KernelRunner
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
};
|
||||
|
||||
// Initializes the fine grained scale+bias iterator. Needed since the fine grained iterator
|
||||
// has a different constructor signature than a regular cutlass iterator
|
||||
template <typename IteratorScale, WeightOnlyQuantOp op, std::enable_if_t<isFinegrained(op), bool> = true>
|
||||
@ -375,169 +361,177 @@ struct GemmFpAIntB
|
||||
return IteratorScale(params, pointer_scale, extent, thread_id, threadblock_offset);
|
||||
}
|
||||
|
||||
template <typename dummy>
|
||||
struct KernelRunner<true, dummy>
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
// Compute threadblock location
|
||||
ThreadblockSwizzle threadblock_swizzle;
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_tile_offset
|
||||
= threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// Early exit if CTA is out of range
|
||||
if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
|
||||
|| params.grid_tiled_shape.n() <= threadblock_tile_offset.n())
|
||||
{
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
||||
|
||||
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
|
||||
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
|
||||
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
|
||||
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
|
||||
params.gather_B_indices);
|
||||
|
||||
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
|
||||
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
|
||||
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
|
||||
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0)
|
||||
{
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(),
|
||||
params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(),
|
||||
params.problem_size.mn(), thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k())
|
||||
{
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
|
||||
{
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM,
|
||||
threadblock_tile_offset.k() * params.gemm_k_size,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{threadblock_tile_offset.k() * params.gemm_k_size * kInterleave,
|
||||
threadblock_tile_offset.n() * Mma::Shape::kN / kInterleave};
|
||||
|
||||
typename MatrixCoord::Index fg_row_offset = threadblock_tile_offset.k() * params.gemm_k_size / 64;
|
||||
typename MatrixCoord::Index scale_row_offset = isFinegrained(Mma::QuantOp) ? fg_row_offset : 0;
|
||||
cutlass::MatrixCoord tb_offset_scale{scale_row_offset, threadblock_tile_offset.n() * Mma::Shape::kN};
|
||||
|
||||
// Problem size is a function of threadblock index in the K dimension
|
||||
int problem_size_k = min(params.problem_size.k(), (threadblock_tile_offset.k() + 1) * params.gemm_k_size);
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size_k - tb_offset_A.column() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(params.params_A, params.ref_A.data(),
|
||||
{params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A, params.gather_A_indices);
|
||||
|
||||
typename Mma::IteratorB iterator_B(params.params_B, params.ref_B.data(),
|
||||
{problem_size_k * kInterleave, params.problem_size.n() / kInterleave}, thread_idx, tb_offset_B,
|
||||
params.gather_B_indices);
|
||||
|
||||
typename MatrixCoord::Index scale_row_extent = isFinegrained(Mma::QuantOp) ? problem_size_k / 64 : 1;
|
||||
typename Mma::IteratorScale iterator_scale = initialize_scale<typename Mma::IteratorScale, Mma::QuantOp>(
|
||||
params.params_scale, params.ref_scale.data(), params.ref_zero.data(),
|
||||
{scale_row_extent, params.problem_size.n()}, thread_idx, tb_offset_scale, params.group_size);
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Main loop
|
||||
//
|
||||
// Construct thread-scoped matrix multiply
|
||||
Mma mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
if (!kSplitKSerial || gemm_k_iterations > 0)
|
||||
{
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
//
|
||||
// Masked tile iterators constructed from members
|
||||
//
|
||||
|
||||
threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);
|
||||
|
||||
// assume identity swizzle
|
||||
MatrixCoord threadblock_offset(
|
||||
threadblock_tile_offset.m() * Mma::Shape::kM, threadblock_tile_offset.n() * Mma::Shape::kN);
|
||||
|
||||
int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();
|
||||
|
||||
// Construct the semaphore.
|
||||
Semaphore semaphore(params.semaphore + block_idx, thread_idx);
|
||||
|
||||
// If performing a reduction via split-K, fetch the initial synchronization
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// Fetch the synchronization lock initially but do not block.
|
||||
semaphore.fetch();
|
||||
|
||||
// Indicate which position in a serial reduction the output operator is currently updating
|
||||
output_op.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
|
||||
}
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(params.params_C, params.ref_C.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(params.params_D, params.ref_D.data(), params.problem_size.mn(),
|
||||
thread_idx, threadblock_offset, params.scatter_D_indices);
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Wait on the semaphore - this latency may have been covered by iterator construction
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
// For subsequent threadblocks, the source matrix is held in the 'D' tensor.
|
||||
if (threadblock_tile_offset.k())
|
||||
{
|
||||
iterator_C = iterator_D;
|
||||
}
|
||||
|
||||
semaphore.wait(threadblock_tile_offset.k());
|
||||
}
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
//
|
||||
// Release the semaphore
|
||||
//
|
||||
|
||||
if (kSplitKSerial && params.grid_tiled_shape.k() > 1)
|
||||
{
|
||||
|
||||
int lock = 0;
|
||||
if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1)
|
||||
{
|
||||
|
||||
// The final threadblock resets the semaphore for subsequent grids.
|
||||
lock = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Otherwise, the semaphore is incremented
|
||||
lock = threadblock_tile_offset.k() + 1;
|
||||
}
|
||||
|
||||
semaphore.release(lock);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
@ -549,14 +543,13 @@ struct GemmFpAIntB
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
run_kernel<arch::Sm70>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ <= 900)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
CUTLASS_NOT_IMPLEMENTED(); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
|
||||
@ -407,7 +407,7 @@ public:
|
||||
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
|
||||
// Compute threadblock location
|
||||
@ -534,6 +534,48 @@ public:
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(epilogue_visitor, accumulators);
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<ArchTag, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
to the ArchTag of the cutlass kernel operator.
|
||||
*/
|
||||
/// Executes one GEMM
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 720)
|
||||
run_kernel<arch::Sm70>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 720) && (__CUDA_ARCH__ < 750)
|
||||
run_kernel<arch::Sm72>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
// TODO - replace with CUTLASS_NOT_IMPLEMENTED() and upgrade to 3.x kernels.
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -80,7 +80,8 @@ struct LayoutDetailsB<bfloat16_t, Arch, typename platform::enable_if<Arch::kMinC
// Specializations for Turing+ when B is quantized. These can use the operator OpMultiplyAddDequantizeInterleavedBToA,
// which signals that we want to dequantize after loading from smem.
template <typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB < uint8_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;

@ -95,7 +96,8 @@ public:
};

template <typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75>::type>
struct LayoutDetailsB < uint4b_t,
Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 75 && Arch::kMinComputeCapability<90>::type>
{
static constexpr int ThreadblockK = 64;

@ -109,6 +111,24 @@ public:
using Operator = cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
};

template <typename Arch>
struct LayoutDetailsB<uint8_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};

template <typename Arch>
struct LayoutDetailsB<uint4b_t, Arch, typename platform::enable_if<Arch::kMinComputeCapability >= 90>::type>
{
static constexpr int ThreadblockK = 64;
using Layout = layout::ColumnMajor;
static constexpr int ElementsPerAccess = 128 / cutlass::sizeof_bits<half_t>::value;
using Operator = cutlass::arch::OpMultiplyAdd;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass

@ -321,174 +321,170 @@ public:
|
||||
return 0;
|
||||
}
|
||||
|
||||
// The dummy template parameter is not used and exists so that we can compile this code using
|
||||
// a standard earlier than C++17. Prior to C++17, fully specialized templates HAD to exists in
|
||||
// a namespace
|
||||
template <bool B, typename dummy = void>
|
||||
struct KernelRunner
|
||||
CUTLASS_DEVICE
|
||||
void run_kernel_(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
int loop = 0;
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
loop++;
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
const int64_t rows_to_jump
|
||||
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B
|
||||
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
|
||||
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
auto CreateMMA = [&]()
|
||||
{
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
else
|
||||
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
};
|
||||
Mma mma = CreateMMA();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
|
||||
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
{
|
||||
const MatrixCoord scale_extent = {1, problem_size.n()};
|
||||
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
|
||||
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
|
||||
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
|
||||
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
|
||||
|
||||
LayoutC layout_C(0);
|
||||
LayoutC layout_D(gemm_n);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename CompilationArch>
|
||||
CUTLASS_DEVICE void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
if constexpr (platform::is_same<KernelArch, CompilationArch>::value)
|
||||
{
|
||||
run_kernel_(params, shared_storage);
|
||||
}
|
||||
else
|
||||
{
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename dummy>
|
||||
struct KernelRunner<true, dummy>
|
||||
{
|
||||
CUTLASS_DEVICE
|
||||
static void run_kernel(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
//
|
||||
// These types shadow the type-level definitions and support the ability to implement
|
||||
// a 'transposed' GEMM that computes the transposed problems.
|
||||
//
|
||||
using ElementA = typename Mma::IteratorA::Element;
|
||||
using LayoutA = typename Mma::IteratorA::Layout;
|
||||
using ElementB = typename Mma::IteratorB::Element;
|
||||
using LayoutB = typename Mma::IteratorB::Layout;
|
||||
using ElementC = typename Epilogue::OutputTileIterator::Element;
|
||||
using LayoutC = typename Epilogue::OutputTileIterator::Layout;
|
||||
static constexpr int kInterleave = Mma::IteratorB::Shape::kRow / Mma::Shape::kK;
|
||||
static_assert(platform::is_same<LayoutB, layout::RowMajor>::value && kInterleave == 1
|
||||
|| platform::is_same<LayoutB, layout::ColumnMajor>::value && kInterleave >= 1,
|
||||
"B must be row major/col major OR col major interleaved.");
|
||||
|
||||
//
|
||||
// Problem visitor.
|
||||
//
|
||||
ProblemVisitor problem_visitor(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x);
|
||||
|
||||
const int64_t gemm_k = params.problem_visitor.gemm_k;
|
||||
const int64_t gemm_n = params.problem_visitor.gemm_n;
|
||||
int64_t bytes_per_expert_matrix = (gemm_k * gemm_n / 8) * cutlass::sizeof_bits<ElementB>::value;
|
||||
|
||||
// Outer 'persistent' loop to iterate over tiles
|
||||
int loop = 0;
|
||||
while (problem_visitor.next_tile())
|
||||
{
|
||||
loop++;
|
||||
|
||||
GemmCoord problem_size = problem_visitor.problem_size();
|
||||
int32_t problem_idx = problem_visitor.problem_index();
|
||||
int32_t cta_idx = int32_t(problem_visitor.threadblock_idx());
|
||||
|
||||
GemmCoord grid_shape = problem_visitor.grid_shape(problem_size);
|
||||
|
||||
cutlass::gemm::GemmCoord threadblock_offset(
|
||||
int(cta_idx / grid_shape.n()) * Mma::Shape::kM, int(cta_idx % grid_shape.n()) * Mma::Shape::kN, 0);
|
||||
|
||||
// Load element pointers. Exchange pointers and strides if working on the transpose
|
||||
const int64_t rows_to_jump
|
||||
= problem_idx == 0 ? 0 : params.problem_visitor.last_row_for_problem[problem_idx - 1];
|
||||
ElementA* ptr_A = reinterpret_cast<ElementA*>(params.ptr_A) + rows_to_jump * gemm_k;
|
||||
typename LayoutA::LongIndex ldm_A = gemm_k;
|
||||
|
||||
char* byte_ptr_B = ((char*) params.ptr_B) + problem_idx * bytes_per_expert_matrix;
|
||||
ElementB* ptr_B = reinterpret_cast<ElementB*>(byte_ptr_B);
|
||||
typename LayoutB::LongIndex ldm_B
|
||||
= platform::is_same<layout::RowMajor, LayoutB>::value ? gemm_n : gemm_k * kInterleave;
|
||||
|
||||
// Compute initial location in logical coordinates
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
threadblock_offset.m(),
|
||||
0,
|
||||
};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.n() / kInterleave};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_scale{0, threadblock_offset.n()};
|
||||
|
||||
// Compute position within threadblock
|
||||
int thread_idx = threadIdx.x;
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename Mma::IteratorA iterator_A(
|
||||
LayoutA(ldm_A), ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A);
|
||||
|
||||
typename Mma::IteratorB iterator_B(LayoutB(ldm_B), ptr_B,
|
||||
{problem_size.k() * kInterleave, problem_size.n() / kInterleave}, thread_idx, tb_offset_B);
|
||||
|
||||
typename Mma::FragmentC accumulators;
|
||||
|
||||
accumulators.clear();
|
||||
|
||||
// Broadcast the warp_id computed by lane 0 to ensure dependent code
|
||||
// is compiled as warp-uniform.
|
||||
int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
|
||||
|
||||
int lane_idx = threadIdx.x % 32;
|
||||
|
||||
//
|
||||
// Matrix multiply phase
|
||||
//
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
auto CreateMMA = [&]()
|
||||
{
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
return Mma(shared_storage.main_loop, params.group_size, thread_idx, warp_idx, lane_idx);
|
||||
else
|
||||
return Mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);
|
||||
};
|
||||
Mma mma = CreateMMA();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK;
|
||||
|
||||
// Wait for all threads to finish their epilogue phases from the previous tile.
|
||||
__syncthreads();
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
ElementScale* weight_scale_ptr = params.weight_scales + problem_idx * problem_size.n();
|
||||
|
||||
if constexpr (use_dq_gemm<Mma>::value)
|
||||
{
|
||||
const MatrixCoord scale_extent = {1, problem_size.n()};
|
||||
typename Mma::IteratorScale iterator_scale(Mma::IteratorScale::Layout(scale_extent.column()),
|
||||
weight_scale_ptr, scale_extent, thread_idx, tb_offset_scale);
|
||||
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, iterator_scale, accumulators);
|
||||
}
|
||||
else
|
||||
{
|
||||
mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);
|
||||
}
|
||||
|
||||
//
|
||||
// Epilogue
|
||||
//
|
||||
|
||||
EpilogueOutputOp output_op(params.output_op);
|
||||
|
||||
ElementC* ptr_C = reinterpret_cast<ElementC*>(params.ptr_C) + problem_idx * gemm_n;
|
||||
ElementC* ptr_D = reinterpret_cast<ElementC*>(params.ptr_D) + rows_to_jump * gemm_n;
|
||||
|
||||
LayoutC layout_C(0);
|
||||
LayoutC layout_D(gemm_n);
|
||||
|
||||
typename Epilogue::OutputTileIterator::Params params_C(layout_C);
|
||||
typename Epilogue::OutputTileIterator::Params params_D(layout_D);
|
||||
|
||||
// Tile iterator loading from source tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_C(
|
||||
params_C, ptr_C, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
// Tile iterator writing to destination tensor.
|
||||
typename Epilogue::OutputTileIterator iterator_D(
|
||||
params_D, ptr_D, problem_size.mn(), thread_idx, threadblock_offset.mn());
|
||||
|
||||
Epilogue epilogue(shared_storage.epilogue, thread_idx, warp_idx, lane_idx);
|
||||
|
||||
// Execute the epilogue operator to update the destination tensor.
|
||||
epilogue(output_op, iterator_D, accumulators, iterator_C);
|
||||
|
||||
// Next tile
|
||||
problem_visitor.advance(gridDim.x);
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/*
|
||||
To improve compilation speed, we do not compile the device operator if the CUDA_ARCH does not correspond
|
||||
@ -498,19 +494,20 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
void operator()(Params const& params, SharedStorage& shared_storage)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm70>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm75>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
#elif defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
// TODO Update the arch to Sm90 once CUTLASS hopper specialisations are available
|
||||
static constexpr bool compile_needed = platform::is_same<KernelArch, arch::Sm80>::value;
|
||||
KernelRunner<compile_needed>::run_kernel(params, shared_storage);
|
||||
#if defined(__CUDA_ARCH__)
|
||||
#if (__CUDA_ARCH__ >= 700) && (__CUDA_ARCH__ < 750)
|
||||
run_kernel<arch::Sm70>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 750) && (__CUDA_ARCH__ < 800)
|
||||
run_kernel<arch::Sm75>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 800) && (__CUDA_ARCH__ < 900)
|
||||
run_kernel<arch::Sm80>(params, shared_storage);
|
||||
#elif (__CUDA_ARCH__ >= 900)
|
||||
run_kernel<arch::Sm80>(
|
||||
params, shared_storage); // Don't compile these for Hopper or later. Use CUTLASS 3.x kernels.
|
||||
#else
|
||||
static_assert(
|
||||
false, "Invalid architecture being compiled. Only Volta+ supported in weight-only quantization kernels.");
|
||||
#endif
|
||||
#else
|
||||
CUTLASS_NOT_IMPLEMENTED();
|
||||
#endif
|
||||
|
||||
@ -60,12 +60,81 @@ enum class SplitKStyle
// SPLIT_K_PARALLEL // Not supported yet
};

enum class CutlassTileConfigSM90
{
// Signals that we should run heuristics do choose a config
Undefined,

// Signals that we should run heuristics do choose a config
ChooseWithHeuristic,

// CTA configs for M=64
CtaShape64x16x128B,
CtaShape64x32x128B,
CtaShape64x64x128B,
CtaShape64x128x128B,
CtaShape64x256x128B,

// CTA configs for M=128
CtaShape128x16x128B,
CtaShape128x32x128B,
CtaShape128x64x128B,
CtaShape128x128x128B,
CtaShape128x256x128B,

};

enum class MainloopScheduleType
{
AUTO // Automatically selects between pingpong and cooperative schedules on Hopper. On older architectures, this
// defaults to the "legacy" main loop schedule.
};

enum class EpilogueScheduleType
{
AUTO // Automatically chooses an epilogue schedule compatible with the selected main loop schedule for Hopper. For
// architectures older than hopper, the epilogue is always performed by the same thread block as the main loop.
};

enum class ClusterShape
{
ClusterShape_1x1x1,
ClusterShape_2x1x1,
ClusterShape_1x2x1,
ClusterShape_2x2x1
};

struct CutlassGemmConfig
{
CutlassTileConfig tile_config = CutlassTileConfig::ChooseWithHeuristic;
SplitKStyle split_k_style = SplitKStyle::NO_SPLIT_K;
int split_k_factor = -1;
int stages = -1;

// config options for sm90
CutlassTileConfigSM90 tile_config_sm90 = CutlassTileConfigSM90::ChooseWithHeuristic;
MainloopScheduleType mainloop_schedule = MainloopScheduleType::AUTO;
EpilogueScheduleType epilogue_schedule = EpilogueScheduleType::AUTO;
ClusterShape cluster_shape = ClusterShape::ClusterShape_1x1x1;

CutlassGemmConfig() {}

CutlassGemmConfig(CutlassTileConfig tile_config, SplitKStyle split_k_style, int split_k_factor, int stages)
: tile_config(tile_config)
, split_k_style(split_k_style)
, split_k_factor(split_k_factor)
, stages(stages)
{
}

CutlassGemmConfig(CutlassTileConfigSM90 tile_config_sm90, MainloopScheduleType mainloop_schedule,
EpilogueScheduleType epilogue_schedule, ClusterShape cluster_shape)
: tile_config_sm90(tile_config_sm90)
, mainloop_schedule(mainloop_schedule)
, epilogue_schedule(epilogue_schedule)
, cluster_shape(cluster_shape)
{
}
};

} // namespace cutlass_extensions

@ -18,6 +18,10 @@
file(GLOB_RECURSE SRC_CPP *.cpp)
file(GLOB_RECURSE SRC_CU *.cu)

# Exclude files in the cutlass_kernels folder
list(FILTER SRC_CPP EXCLUDE REGEX "cutlass_kernels/.*")
list(FILTER SRC_CU EXCLUDE REGEX "cutlass_kernels/.*")

# skip mmha 48, 80, 96, 112, 144, 160, 192 and 224 for fast build
if(FAST_BUILD)
list(FILTER SRC_CU EXCLUDE REGEX
@ -27,3 +31,5 @@ endif()
add_library(kernels_src OBJECT ${SRC_CPP} ${SRC_CU})
set_property(TARGET kernels_src PROPERTY POSITION_INDEPENDENT_CODE ON)
set_property(TARGET kernels_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)

add_subdirectory(cutlass_kernels)

@ -25,42 +25,43 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
__global__ void ban_bad_words(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slots,
|
||||
int batch_size, int beam_width, const int* bad_words, size_t bad_words_len, bool share_words, int vocab_size_padded,
|
||||
const int* sequence_lengths, const int max_seq_len)
|
||||
__global__ void ban_bad_words(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
|
||||
int32_t const* batch_slots, int32_t beam_width, int32_t const** bad_words_ptrs, int32_t const* bad_words_lens,
|
||||
int32_t vocab_size_padded, int32_t const* sequence_lengths, const int32_t max_seq_len)
|
||||
{
|
||||
const int id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int batch_idx = blockIdx.y / beam_width;
|
||||
const int beam_idx = blockIdx.y % beam_width;
|
||||
int32_t const id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int32_t const batch_idx = blockIdx.y / beam_width;
|
||||
int32_t const beam_idx = blockIdx.y % beam_width;
|
||||
auto const batch_slot = batch_slots != nullptr ? batch_slots[batch_idx] : batch_idx;
|
||||
auto const batch_beam_idx = batch_slot * beam_width + beam_idx;
|
||||
|
||||
const int* base_bad_words = share_words ? bad_words : bad_words + batch_slot * 2 * bad_words_len;
|
||||
const int* base_bad_words_offsets = base_bad_words + bad_words_len;
|
||||
int32_t const* base_bad_words = bad_words_ptrs[batch_slot];
|
||||
auto const bad_words_len = bad_words_lens[batch_slot];
|
||||
int32_t const* base_bad_words_offsets = base_bad_words + bad_words_len;
|
||||
|
||||
if (id >= bad_words_len || base_bad_words_offsets[id] < 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const int item_end = base_bad_words_offsets[id];
|
||||
const int item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
|
||||
const int item_size = item_end - item_start;
|
||||
auto const item_end = base_bad_words_offsets[id];
|
||||
auto const item_start = (id > 0) ? base_bad_words_offsets[id - 1] : 0;
|
||||
auto const item_size = item_end - item_start;
|
||||
|
||||
/* The single-token case unconditionally bans the token */
|
||||
bool should_ban = item_size == 1;
|
||||
const int current_step{sequence_lengths[batch_beam_idx]};
|
||||
int32_t const current_step{sequence_lengths[batch_beam_idx]};
|
||||
/* Multi-token case and enough previously generated tokens to look for a match
|
||||
*/
|
||||
if (item_size > 1 && current_step >= item_size - 1)
|
||||
{
|
||||
should_ban = true;
|
||||
int parent_id = beam_idx;
|
||||
const bool gather_beam = beam_width > 1;
|
||||
int32_t parent_id = beam_idx;
|
||||
bool const gather_beam = beam_width > 1;
|
||||
|
||||
for (int token_idx = item_size - 2; token_idx >= 0; token_idx--)
|
||||
for (int32_t token_idx = item_size - 2; token_idx >= 0; token_idx--)
|
||||
{
|
||||
const int previous_token
|
||||
auto const previous_token
|
||||
= output_ids_ptr[batch_slot][parent_id * max_seq_len + current_step - (item_size - 1) + token_idx];
|
||||
|
||||
if (previous_token != base_bad_words[item_start + token_idx])
|
||||
@ -85,42 +86,46 @@ __global__ void ban_bad_words(T* logits, const int** output_ids_ptr, const int**
|
||||
|
||||
if (should_ban)
|
||||
{
|
||||
int banned_token = base_bad_words[item_end - 1];
|
||||
if (0 < banned_token && banned_token < vocab_size_padded)
|
||||
auto banned_token = base_bad_words[item_end - 1];
|
||||
if (0 <= banned_token && banned_token < vocab_size_padded)
|
||||
{
|
||||
logits[batch_slot * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token]
|
||||
logits[batch_idx * beam_width * vocab_size_padded + beam_idx * vocab_size_padded + banned_token]
|
||||
= static_cast<T>(-INFINITY);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeBanBadWords(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slot,
|
||||
int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words, size_t bad_words_len,
|
||||
int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream)
|
||||
void invokeBanBadWords(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
|
||||
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
|
||||
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
|
||||
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream)
|
||||
{
|
||||
dim3 block, grid;
|
||||
constexpr size_t max_blocks{256};
|
||||
block.x = min(((bad_words_len + 32 - 1) / 32) * 32, max_blocks);
|
||||
grid.x = (bad_words_len + block.x - 1) / block.x;
|
||||
grid.y = local_batch_size * beam_width;
|
||||
constexpr int32_t max_blocks{256};
|
||||
block.x = min(((max_bad_words_len + 32 - 1) / 32) * 32, max_blocks);
|
||||
grid.x = (max_bad_words_len + block.x - 1) / block.x;
|
||||
grid.y = batch_size * beam_width;
|
||||
|
||||
ban_bad_words<<<grid, block, 0, stream>>>(logits, output_ids_ptr, parent_ids_ptr, batch_slot, batch_size,
|
||||
beam_width, bad_words, bad_words_len, share_words, vocab_size_padded, sequence_lengths, max_seq_len);
|
||||
ban_bad_words<<<grid, block, 0, stream>>>(logits, output_ids_ptr, parent_ids_ptr, batch_slot, beam_width, bad_words,
|
||||
bad_words_lens, vocab_size_padded, sequence_lengths, max_seq_len);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
template void invokeBanBadWords(half* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
|
||||
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
|
||||
template void invokeBanBadWords(half* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
|
||||
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
|
||||
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
|
||||
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeBanBadWords(__nv_bfloat16* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
|
||||
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
|
||||
template void invokeBanBadWords(__nv_bfloat16* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
|
||||
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
|
||||
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
|
||||
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
|
||||
#endif
|
||||
template void invokeBanBadWords(float* logits, const int** output_ids_ptr, const int** parent_ids_ptr,
|
||||
const int* batch_slot, int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words,
|
||||
size_t bad_words_len, int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
|
||||
template void invokeBanBadWords(float* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
|
||||
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
|
||||
int32_t const* bad_words_lens, int32_t max_bad_words_len, int32_t vocab_size_padded,
|
||||
int32_t const* sequence_lengths, int32_t max_seq_len, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -25,9 +25,10 @@ namespace kernels
{

template <typename T>
void invokeBanBadWords(T* logits, const int** output_ids_ptr, const int** parent_ids_ptr, const int* batch_slot,
int batch_size, int local_batch_size, int beam_width, const int* bad_words, bool share_words, size_t bad_words_len,
int vocab_size_padded, const int* sequence_lengths, int max_seq_len, cudaStream_t stream);
void invokeBanBadWords(T* logits, int32_t const** output_ids_ptr, int32_t const** parent_ids_ptr,
int32_t const* batch_slot, int32_t batch_size, int32_t beam_width, int32_t const** bad_words,
int32_t const* bad_words_len, int32_t max_bad_words_len, int32_t vocab_size_padded, int32_t const* sequence_lengths,
int32_t max_seq_len, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm

@ -134,9 +134,8 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi

template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size,
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded,
size_t max_step, cudaStream_t stream)
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
{
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
@ -149,7 +148,7 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedS
constexpr size_t max_blocks{256};
block.x = min(((max_step + 32 - 1) / 32) * 32, max_blocks);
grid.x = (max_step + block.x - 1) / block.x;
grid.y = local_batch_size * beam_width;
grid.y = batch_size * beam_width;

// dynamically allocate shared memory of int[blockDim + 2*(ngram_size - 1)], where ngram_size - 1 is for boundary
// token's ngram and for most recent tokens
@ -162,8 +161,8 @@ void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedS
#define INVOKE_BAN_REPEAT_NGRAM(T) \
template void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, \
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, \
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, \
int vocab_size_padded, size_t max_step, cudaStream_t stream);
int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, \
cudaStream_t stream);

INVOKE_BAN_REPEAT_NGRAM(float)
INVOKE_BAN_REPEAT_NGRAM(half)

@ -27,9 +27,8 @@ namespace kernels

template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size,
int local_batch_size, int beam_width, int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded,
size_t max_step, cudaStream_t stream);
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm

91
cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt
Normal file
@ -0,0 +1,91 @@
|
||||
#
|
||||
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
|
||||
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
|
||||
# use this file except in compliance with the License. You may obtain a copy of
|
||||
# the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||
# License for the specific language governing permissions and limitations under
|
||||
# the License.
|
||||
#
|
||||
|
||||
file(GLOB_RECURSE SRC_CPP *.cpp)
|
||||
file(GLOB_RECURSE SRC_CU *.cu)
|
||||
|
||||
# This can happen when not building for Torch
|
||||
if(NOT Python3_EXECUTABLE)
|
||||
find_package(
|
||||
Python3
|
||||
COMPONENTS Interpreter
|
||||
REQUIRED)
|
||||
endif()
|
||||
|
||||
execute_process(
|
||||
WORKING_DIRECTORY ${3RDPARTY_DIR}/cutlass/python/
|
||||
COMMAND ${Python3_EXECUTABLE} setup_library.py develop --user
|
||||
RESULT_VARIABLE _CUTLASS_LIBRARY_SUCCESS)
|
||||
|
||||
if(NOT _CUTLASS_LIBRARY_SUCCESS MATCHES 0)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Failed to set up the CUTLASS library due to ${_CUTLASS_LIBRARY_SUCCESS}."
|
||||
)
|
||||
endif()
|
||||
|
||||
set_directory_properties(
|
||||
PROPERTIES CMAKE_CONFIGURE_DEPENDS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/python/generate_kernels.py)
|
||||
|
||||
execute_process(
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/python/
|
||||
COMMAND ${Python3_EXECUTABLE} generate_kernels.py -o
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
RESULT_VARIABLE _KERNEL_GEN_SUCCESS)
|
||||
|
||||
if(NOT _KERNEL_GEN_SUCCESS MATCHES 0)
|
||||
message(
|
||||
FATAL_ERROR
|
||||
"Failed to generate CUTLASS kernel instantiations due to ${_KERNEL_GEN_SUCCESS}."
|
||||
)
|
||||
endif()
|
||||
|
||||
file(GLOB_RECURSE CU_INSTANTIATIONS ${CMAKE_CURRENT_BINARY_DIR}/*.cu)
|
||||
|
||||
add_library(cutlass_src OBJECT ${SRC_CPP} ${SRC_CU} ${CU_INSTANTIATIONS})
|
||||
set_property(TARGET cutlass_src PROPERTY POSITION_INDEPENDENT_CODE ON)
|
||||
set_property(TARGET cutlass_src PROPERTY CUDA_RESOLVE_DEVICE_SYMBOLS ON)
|
||||
|
||||
# Note - we deliberately do not include 90a PTX (even when 9.0+PTX is
|
||||
# specified). This is because sm_90a has arch conditional instructions that are
|
||||
# not forward compatible. As a result, it does not make sense to embed PTX into
|
||||
# the binary anyway.
|
||||
if("9.0" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR "9.0+PTX" IN_LIST TORCH_CUDA_ARCH_LIST
|
||||
OR TORCH_CUDA_ARCH_LIST STREQUAL "Auto")
|
||||
message(STATUS "MANUALLY APPENDING FLAG TO COMPILE FOR SM_90a.")
|
||||
target_compile_options(
|
||||
cutlass_src
|
||||
PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-gencode=arch=compute_90a,code=sm_90a>)
|
||||
|
||||
# Hopper kernels require cuda lib for TMA APIs
|
||||
target_link_libraries(cutlass_src PRIVATE CUDA::cuda_driver)
|
||||
|
||||
# No kernels should be parsed, unless hopper is specified. This is a build
|
||||
# time improvement
|
||||
target_compile_definitions(cutlass_src
|
||||
PRIVATE COMPILE_HOPPER_MIXED_INPUT_GEMMS)
|
||||
endif()
|
||||
|
||||
# Suppress GCC note: the ABI for passing parameters with 64-byte alignment has
|
||||
# changed in GCC 4.6 This note appears for kernels using TMA and clutters the
|
||||
# compilation output.
|
||||
if(NOT WIN32)
|
||||
target_compile_options(
|
||||
cutlass_src PRIVATE $<$<COMPILE_LANGUAGE:CUDA>:-Xcompiler=-Wno-psabi>)
|
||||
endif()
|
||||
@ -30,6 +30,7 @@
|
||||
#endif // #ifndef _WIN32
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <set>
|
||||
#include <vector>
|
||||
|
||||
using namespace tensorrt_llm::cutlass_extensions;
|
||||
@ -164,9 +165,103 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
|
||||
const bool int8_configs_only, const int max_split_k)
|
||||
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
|
||||
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
|
||||
{
|
||||
enum class CutlassGemmType : char
|
||||
{
|
||||
Default,
|
||||
WeightOnly,
|
||||
Simt,
|
||||
Int8
|
||||
};
|
||||
|
||||
CutlassGemmType gemm_type = CutlassGemmType::Default;
|
||||
if (simt_configs_only)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Simt;
|
||||
}
|
||||
else if (is_weight_only)
|
||||
{
|
||||
gemm_type = CutlassGemmType::WeightOnly;
|
||||
}
|
||||
else if (int8_configs_only)
|
||||
{
|
||||
gemm_type = CutlassGemmType::Int8;
|
||||
}
|
||||
|
||||
switch (gemm_type)
|
||||
{
|
||||
case CutlassGemmType::WeightOnly:
|
||||
return {CutlassTileConfigSM90::CtaShape64x16x128B, CutlassTileConfigSM90::CtaShape64x32x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x64x128B, CutlassTileConfigSM90::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x16x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
default: throw std::runtime_error("get_candidate_tiles_sm90 only supports WeightOnly now.");
|
||||
}
|
||||
}
|
||||
|
||||
// We only compile CUTLASS kernels with multi-cast along M if the M tile is >= 128. This is purely to improve
|
||||
// compilation speed.
|
||||
bool supports_mcast_along_m(const CutlassTileConfigSM90 tile)
|
||||
{
|
||||
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape128x16x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x32x128B, CutlassTileConfigSM90::CtaShape128x64x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x128x128B, CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
return valid_tiles.count(tile) == 1;
|
||||
}
|
||||
|
||||
// We only compile CUTLASS kernels with multi-cast along N if the N tile is >= 128. This is purely to improve
|
||||
// compilation speed.
|
||||
bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
|
||||
{
|
||||
std::set<CutlassTileConfigSM90> valid_tiles{CutlassTileConfigSM90::CtaShape64x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape64x256x128B, CutlassTileConfigSM90::CtaShape128x128x128B,
|
||||
CutlassTileConfigSM90::CtaShape128x256x128B};
|
||||
return valid_tiles.count(tile) == 1;
|
||||
}
|
||||
|
||||
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
|
||||
const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma)
|
||||
{
|
||||
if (sm == 90 && enable_hopper_gmma)
|
||||
{
|
||||
std::vector<CutlassTileConfigSM90> tiles
|
||||
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);
|
||||
|
||||
std::vector<CutlassGemmConfig> candidate_configs;
|
||||
for (const auto& tile_config : tiles)
|
||||
{
|
||||
CutlassGemmConfig config(
|
||||
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
|
||||
const bool has_m_mcast = supports_mcast_along_m(tile_config);
|
||||
const bool has_n_mcast = supports_mcast_along_n(tile_config);
|
||||
if (has_m_mcast)
|
||||
{
|
||||
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x1x1);
|
||||
candidate_configs.push_back(config);
|
||||
}
|
||||
|
||||
if (has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_1x2x1);
|
||||
candidate_configs.push_back(config);
|
||||
}
|
||||
|
||||
if (has_m_mcast && has_n_mcast)
|
||||
{
|
||||
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
|
||||
ClusterShape::ClusterShape_2x2x1);
|
||||
candidate_configs.push_back(config);
|
||||
}
|
||||
}
|
||||
return candidate_configs;
|
||||
}
|
||||
std::vector<CutlassTileConfig> tiles
|
||||
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
|
||||
|
||||
@ -177,7 +272,7 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
|
||||
{
|
||||
for (int stages = min_stages; stages <= max_stages; ++stages)
|
||||
{
|
||||
CutlassGemmConfig config{tile_config, SplitKStyle::NO_SPLIT_K, 1, stages};
|
||||
CutlassGemmConfig config(tile_config, SplitKStyle::NO_SPLIT_K, 1, stages);
|
||||
candidate_configs.push_back(config);
|
||||
if (sm >= 75)
|
||||
{
|
||||
@ -253,8 +348,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
|
||||
config_waves = num_waves_total;
|
||||
SplitKStyle split_style
|
||||
= split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
|
||||
best_config = CutlassGemmConfig{
|
||||
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
|
||||
best_config = CutlassGemmConfig(
|
||||
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
|
||||
current_m_tile = tile_shape.m;
|
||||
}
|
||||
else if (current_score == config_score
|
||||
@ -264,8 +359,8 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
|
||||
// Prefer deeper pipeline or smaller split-k
|
||||
SplitKStyle split_style
|
||||
= split_k_factor > 1 ? SplitKStyle::SPLIT_K_SERIAL : SplitKStyle::NO_SPLIT_K;
|
||||
best_config = CutlassGemmConfig{
|
||||
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages};
|
||||
best_config = CutlassGemmConfig(
|
||||
candidate_config.tile_config, split_style, split_k_factor, candidate_config.stages);
|
||||
current_m_tile = tile_shape.m;
|
||||
config_waves = num_waves_total;
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ namespace cutlass_kernels
|
||||
|
||||
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
|
||||
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
|
||||
const int max_split_k = 1);
|
||||
const int max_split_k = 1, const bool enable_hopper_gmma = false);
|
||||
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
|
||||
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,
|
||||
|
||||
@ -138,10 +138,14 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm75>(quant_type);
|
||||
}
|
||||
else if (arch >= 80 && arch <= 90)
|
||||
else if (arch >= 80 && arch <= 89)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm80>(quant_type);
|
||||
}
|
||||
else if (arch == 90)
|
||||
{
|
||||
return getLayoutDetailsForArch<cutlass::arch::Sm90>(quant_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");
|
||||
@ -532,6 +536,7 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
|
||||
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
|
||||
const std::vector<size_t>& shape, QuantType quant_type)
|
||||
{
|
||||
const int arch = getSMVersion();
|
||||
LayoutDetails details = getLayoutDetailsForTransform(quant_type);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
|
||||
@ -551,7 +556,6 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
|
||||
// Works on row major data, so issue this permutation first.
|
||||
if (details.uses_imma_ldsm)
|
||||
{
|
||||
const int arch = getSMVersion();
|
||||
permute_B_rows_for_mixed_gemm(dst_buf.data(), src_buf.data(), shape, quant_type, arch);
|
||||
src_buf.swap(dst_buf);
|
||||
}
|
||||
@ -568,7 +572,10 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
|
||||
src_buf.swap(dst_buf);
|
||||
}
|
||||
|
||||
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
||||
if (arch >= 70 && arch < 90)
|
||||
{
|
||||
add_bias_and_interleave_quantized_tensor_inplace(src_buf.data(), num_elts, quant_type);
|
||||
}
|
||||
std::copy(src_buf.begin(), src_buf.end(), preprocessed_quantized_weight);
|
||||
}
|
||||
|
||||
|
||||
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/float8.h"
|
||||
#include "cutlass/half.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Tllm to Cutlass
|
||||
|
||||
template <typename T>
|
||||
struct TllmToCutlassTypeAdapter
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TllmToCutlassTypeAdapter<half>
|
||||
{
|
||||
using type = cutlass::half_t;
|
||||
};
|
||||
|
||||
#if defined(ENABLE_BF16)
|
||||
template <>
|
||||
struct TllmToCutlassTypeAdapter<__nv_bfloat16>
|
||||
{
|
||||
using type = cutlass::bfloat16_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
template <>
|
||||
struct TllmToCutlassTypeAdapter<__nv_fp8_e4m3>
|
||||
{
|
||||
using type = cutlass::float_e4m3_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TllmToCutlassTypeAdapter<__nv_fp8_e5m2>
|
||||
{
|
||||
using type = cutlass::float_e5m2_t;
|
||||
};
|
||||
#endif
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// Cutlass to Tllm
|
||||
|
||||
template <typename T>
|
||||
struct CutlassToTllmTypeAdapter
|
||||
{
|
||||
using type = T;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CutlassToTllmTypeAdapter<cutlass::half_t>
|
||||
{
|
||||
using type = half;
|
||||
};
|
||||
|
||||
#if defined(ENABLE_BF16)
|
||||
template <>
|
||||
struct CutlassToTllmTypeAdapter<cutlass::bfloat16_t>
|
||||
{
|
||||
using type = __nv_bfloat16;
|
||||
};
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_FP8)
|
||||
template <>
|
||||
struct CutlassToTllmTypeAdapter<cutlass::float_e4m3_t>
|
||||
{
|
||||
using type = __nv_fp8_e4m3;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CutlassToTllmTypeAdapter<cutlass::float_e5m2_t>
|
||||
{
|
||||
using type = __nv_fp8_e5m2;
|
||||
};
|
||||
#endif
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
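The two adapter families above are pure compile-time type maps between TRT-LLM device types and their CUTLASS equivalents. A minimal sketch of how they might be used from a translation unit that already includes this new header (the `static_assert`s and the `asCutlassPtr` helper are illustrative additions, not part of the change):

```
#include <type_traits>

using tensorrt_llm::kernels::cutlass_kernels::CutlassToTllmTypeAdapter;
using tensorrt_llm::kernels::cutlass_kernels::TllmToCutlassTypeAdapter;

// Mapping a TRT-LLM type to CUTLASS and back recovers the original type.
static_assert(std::is_same_v<TllmToCutlassTypeAdapter<half>::type, cutlass::half_t>);
static_assert(std::is_same_v<CutlassToTllmTypeAdapter<cutlass::half_t>::type, half>);

// Hypothetical helper: view a TRT-LLM device pointer as its CUTLASS counterpart before
// handing it to a CUTLASS GEMM argument struct.
template <typename T>
typename TllmToCutlassTypeAdapter<T>::type* asCutlassPtr(T* ptr)
{
    return reinterpret_cast<typename TllmToCutlassTypeAdapter<T>::type*>(ptr);
}
```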
|
||||
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, /*Scale and Zero Type*/
|
||||
half, /*Bias Type*/
half /*Output Type*/
|
||||
>;
|
||||
#endif
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY, half, /*Scale and Zero Type*/
|
||||
half, /*Bias Type*/
half /*Output Type*/
|
||||
>;
|
||||
#endif
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,35 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
#ifdef ENABLE_FP8
|
||||
template class CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, /*Activation Type*/
|
||||
cutlass::int4b_t, /*Weight Type*/
|
||||
cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, half, /*Scale and Zero Type*/
|
||||
half, /*Bias Type*/
half /*Output Type*/
|
||||
>;
|
||||
#endif
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
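Each of the three new translation units above instantiates the runner for an FP8 activation with int4 weights and `half` scales, bias, and output, differing only in the `WeightOnlyQuantOp`. A rough usage sketch against the runner interface shown later in this diff; the function name `runFp8Int4GemmSketch` and the alpha value are placeholders, and buffer allocation, weight preprocessing, and config selection are elided:

```
#include <cuda_runtime_api.h>

using tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner;

void runFp8Int4GemmSketch(void const* act_fp8, void const* weights_int4, void const* scales, void const* zeros,
    void const* bias, void* out_fp16, char* workspace, int m, int n, int k, int group_size, cudaStream_t stream)
{
    // FP8 activations, int4 weights, fine-grained scales and zero-points, fp16 scale/bias/output.
    CutlassFpAIntBGemmRunner<__nv_fp8_e4m3, cutlass::int4b_t,
        cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS, half, half, half>
        runner;

    // Pick a candidate config; `workspace` must point to at least this many device bytes.
    auto const configs = runner.getConfigs();
    size_t const ws_bytes = runner.getWorkspaceSize(m, n, k);

    // New in this change: gemm overloads that take an explicit float alpha applied in the epilogue.
    runner.gemm(act_fp8, weights_int4, scales, zeros, bias, /*alpha=*/1.25f, out_fp16, m, n, k, group_size,
        configs.front(), workspace, ws_bytes, stream);
}
```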
|
||||
@ -62,11 +62,21 @@ public:
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n,
|
||||
int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
|
||||
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
|
||||
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
|
||||
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
// Returns desired workspace size in bytes.
|
||||
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
|
||||
|
||||
@ -78,7 +88,8 @@ protected:
|
||||
static constexpr int MIN_N_TILE = 64;
|
||||
};
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp,
|
||||
typename ScaleZeroType = ActivationType, typename BiasType = ActivationType, typename OutputType = ActivationType>
|
||||
class CutlassFpAIntBGemmRunner : public virtual CutlassFpAIntBGemmRunnerInterface
|
||||
{
|
||||
public:
|
||||
@ -89,10 +100,19 @@ public:
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
|
||||
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
|
||||
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override;
|
||||
|
||||
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
|
||||
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
// Disabled since the fused GEMM, activation kernels will not be used in v1.
|
||||
|
||||
// void gemm_bias_act(const T* A, const WeightType* B, const T* weight_scales, const T* biases, T* C, int m, int n,
|
||||
@ -106,9 +126,10 @@ public:
|
||||
|
||||
private:
|
||||
template <typename EpilogueTag>
|
||||
void dispatch_to_arch(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
|
||||
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);
|
||||
void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
|
||||
const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
private:
|
||||
int sm_;
|
||||
|
||||
@ -38,6 +38,7 @@
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/fpA_intB_gemm_template_sm90.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
@ -52,7 +53,7 @@ namespace cutlass_kernels
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales,
|
||||
const T* weight_zero_points, const T* biases, T* C, int m, int n, int k, const int group_size,
|
||||
const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
@ -177,7 +178,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_scales)), ld_scale_zero},
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(weight_zero_points)), ld_scale_zero},
|
||||
{reinterpret_cast<ElementType*>(const_cast<T*>(biases)), 0}, {reinterpret_cast<ElementType*>(C), n},
|
||||
gemm_config.split_k_factor, {ElementAccumulator(1.f), output_op_beta});
|
||||
gemm_config.split_k_factor, {ElementAccumulator(alpha), output_op_beta});
|
||||
|
||||
// This assertion is enabled because, for the column interleaved layout, K MUST be a multiple of
|
||||
// threadblockK. The reason for this is that the default pitchlinear iterators are used to handle walking over the
|
||||
@ -230,8 +231,9 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape, int Stages>
|
||||
void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
|
||||
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr)
|
||||
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -252,16 +254,17 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_
|
||||
else
|
||||
{
|
||||
generic_mixed_gemm_kernelLauncher<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape,
|
||||
Stages>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
Stages>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config,
|
||||
workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
|
||||
typename ThreadblockShape, typename WarpShape>
|
||||
void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
|
||||
const T* biases, T* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr)
|
||||
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
|
||||
int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -269,18 +272,18 @@ void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scale
|
||||
{
|
||||
case 2:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 2>(A, B,
|
||||
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case 3:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 3>(A, B,
|
||||
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case 4:
|
||||
filter_and_run_mixed_gemm<T, WeightType, arch, QuantOp, EpilogueTag, ThreadblockShape, WarpShape, 4>(A, B,
|
||||
weight_scales, weight_zero_points, biases, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace,
|
||||
workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
std::string err_msg = "dispatch_gemm_config does not support stages " + std::to_string(gemm_config.stages);
|
||||
@ -289,58 +292,89 @@ void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scale
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
|
||||
void dispatch_gemm_to_cutlass(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
|
||||
const T* biases, T* C, int m, int n, int k, const int group_size, char* workspace, size_t workspace_bytes,
|
||||
tkc::CutlassGemmConfig gemm_config, cudaStream_t stream, int* occupancy = nullptr)
|
||||
template <typename T>
|
||||
constexpr bool is_fp8()
|
||||
{
|
||||
return std::is_same_v<T, __nv_fp8_e4m3> || std::is_same_v<T, __nv_fp8_e5m2>;
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
|
||||
void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
||||
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
|
||||
// for mixed type gemms.
|
||||
switch (gemm_config.tile_config)
|
||||
// Don't instantiate configs that are not supported pre-hopper. Produce a sensible error instead.
|
||||
constexpr bool any_is_fp8 = is_fp8<ActivationType>() || is_fp8<WeightType>() || is_fp8<ScaleZeroType>()
|
||||
|| is_fp8<BiasType>() || is_fp8<OutputType>();
|
||||
|
||||
constexpr bool all_types_are_the_same = std::is_same_v<ActivationType, ScaleZeroType>
|
||||
&& std::is_same_v<ActivationType, BiasType> && std::is_same_v<ActivationType, OutputType>;
|
||||
|
||||
constexpr bool is_valid_pre_hopper = all_types_are_the_same && !any_is_fp8;
|
||||
|
||||
if constexpr (is_valid_pre_hopper)
|
||||
{
|
||||
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<32, 128, 64>,
|
||||
cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
|
||||
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<64, 128, 64>,
|
||||
cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
|
||||
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
if (arch::kMinComputeCapability < 75)
|
||||
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
||||
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the
|
||||
// best for mixed type gemms.
|
||||
switch (gemm_config.tile_config)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
|
||||
case tkc::CutlassTileConfig::CtaShape32x128x64_WarpShape32x32x64:
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<32, 128, 64>, cutlass::gemm::GemmShape<32, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape64x128x64_WarpShape64x32x64:
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<64, 128, 64>, cutlass::gemm::GemmShape<64, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfig::CtaShape128x128x64_WarpShape128x32x64:
|
||||
if (arch::kMinComputeCapability < 75)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Invalid config on Volta");
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatch_gemm_config<ActivationType, WeightType, arch, QuantOp, EpilogueTag,
|
||||
cutlass::gemm::GemmShape<128, 128, 64>, cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes,
|
||||
stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case tkc::CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by "
|
||||
"heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatch_gemm_config<T, WeightType, arch, QuantOp, EpilogueTag, cutlass::gemm::GemmShape<128, 128, 64>,
|
||||
cutlass::gemm::GemmShape<128, 32, 64>>(A, B, weight_scales, weight_zero_points, biases, C, m, n, k,
|
||||
group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
}
|
||||
break;
|
||||
case tkc::CutlassTileConfig::Undefined:
|
||||
throw std::runtime_error("[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case tkc::CutlassTileConfig::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] gemm config should have already been set by "
|
||||
"heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][fpA_intB][dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
// This is not a limitation in CUTLASS. We just do not need to support this case.
|
||||
std::string err_msg = "The activation type must equal the scale, bias and output types on Ampere and earlier.";
|
||||
throw std::runtime_error("[TensorRT-LLm Error][dispatch_gemm_to_cutlass] " + err_msg);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::CutlassFpAIntBGemmRunner()
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
|
||||
OutputType>::CutlassFpAIntBGemmRunner()
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
int device{-1};
|
||||
@ -349,42 +383,52 @@ CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::CutlassFpAIntBGemmRunner()
|
||||
tk::check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count_, cudaDevAttrMultiProcessorCount, device));
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::~CutlassFpAIntBGemmRunner()
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
|
||||
OutputType>::~CutlassFpAIntBGemmRunner()
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
template <typename EpilogueTag>
|
||||
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::dispatch_to_arch<EpilogueTag>(const T* A, const WeightType* B,
|
||||
const T* weight_scales, const T* weight_zero_points, const T* biases, T* C, int m, int n, int k,
|
||||
const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy)
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
|
||||
OutputType>::dispatch_to_arch<EpilogueTag>(const ActivationType* A, const WeightType* B,
|
||||
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
|
||||
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
if (sm_ >= 70 && sm_ < 75)
|
||||
{
|
||||
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm70, QuantOp, EpilogueTag>(A, B, weight_scales,
|
||||
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
|
||||
occupancy);
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm70,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 75 && sm_ < 80)
|
||||
{
|
||||
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm75, QuantOp, EpilogueTag>(A, B, weight_scales,
|
||||
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
|
||||
occupancy);
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm75,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ >= 80 && sm_ <= 90)
|
||||
else if (sm_ >= 80 && sm_ < 90)
|
||||
{
|
||||
dispatch_gemm_to_cutlass<T, WeightType, cutlass::arch::Sm80, QuantOp, EpilogueTag>(A, B, weight_scales,
|
||||
weight_zero_points, biases, C, m, n, k, group_size, workspace_ptr, workspace_bytes, gemm_config, stream,
|
||||
occupancy);
|
||||
dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, cutlass::arch::Sm80,
|
||||
QuantOp, EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
workspace_ptr, workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else if (sm_ == 90)
|
||||
{
|
||||
sm90_dispatch_gemm_to_cutlass<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size, workspace_ptr,
|
||||
workspace_bytes, gemm_config, stream, occupancy);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLm Error][CutlassFpAIntBGemmRunner][GEMM Dispatch] Arch unsupported for CUTLASS mixed type "
|
||||
"[TensorRT-LLM Error][CutlassFpAIntBGemmRunner][dispatch_to_arch] Arch unsupported for CUTLASS mixed type "
|
||||
"GEMM");
|
||||
}
|
||||
}
|
||||
@ -424,18 +468,20 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::dispatch_to_arch<Epilogue
|
||||
// }
|
||||
// }
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const void* B, const void* weight_scales,
|
||||
const void* weight_zero_points, const void* biases, void* C, int m, int n, int k, const int group_size,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
|
||||
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
|
||||
const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
|
||||
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
|
||||
|| (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY))
|
||||
{
|
||||
dispatch_to_arch<tkc::EpilogueOpBias>((const T*) A, (const WeightType*) B, (const T*) weight_scales,
|
||||
(const T*) weight_zero_points, (const T*) biases, (T*) C, m, n, k, group_size, gemmConfig, workspace_ptr,
|
||||
workspace_bytes, stream, nullptr);
|
||||
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
|
||||
(const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases,
|
||||
alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -444,17 +490,31 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const void* B, const void* weight_scales,
|
||||
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
|
||||
cudaStream_t stream)
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
|
||||
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
|
||||
void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
|
||||
const size_t workspace_bytes, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
gemm(A, B, weight_scales, weight_zero_points, biases, 1.f, C, m, n, k, group_size, gemmConfig, workspace_ptr,
|
||||
workspace_bytes, stream);
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
|
||||
const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
|
||||
{
|
||||
dispatch_to_arch<tkc::EpilogueOpDefault>((const T*) A, (const WeightType*) B, (const T*) weight_scales, nullptr,
|
||||
nullptr, (T*) C, m, n, k, k, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
|
||||
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
|
||||
(const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
|
||||
workspace_ptr, workspace_bytes, stream, nullptr);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -462,17 +522,32 @@ void CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::gemm(const void* A, const
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
std::vector<tkc::CutlassGemmConfig> CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getConfigs() const
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
|
||||
const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
{
|
||||
static constexpr bool is_weight_only = !std::is_same<T, WeightType>::value;
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
gemm(A, B, weight_scales, 1.f, C, m, n, k, gemmConfig, workspace_ptr, workspace_bytes, stream);
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
std::vector<tkc::CutlassGemmConfig>
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getConfigs() const
|
||||
{
|
||||
static constexpr bool is_weight_only = !std::is_same<ActivationType, WeightType>::value;
|
||||
std::vector<tkc::CutlassGemmConfig> candidateConfigs
|
||||
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT);
|
||||
= get_candidate_configs(sm_, is_weight_only, false, false, SPLIT_K_LIMIT, true);
|
||||
return candidateConfigs;
|
||||
}
|
||||
|
||||
template <typename T, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp>
|
||||
size_t CutlassFpAIntBGemmRunner<T, WeightType, QuantOp>::getWorkspaceSize(const int m, const int n, const int k)
|
||||
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
|
||||
typename BiasType, typename OutputType>
|
||||
size_t
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize(
|
||||
const int m, const int n, const int k)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
||||
|
||||
@ -0,0 +1,275 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "cute/numeric/integral_constant.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_heuristic.h"
|
||||
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
// This filters out invalid template combinations that we DON'T want instantiated in CUTLASS. For example,
|
||||
// instantiating SM=75, Stages=3 is invalid so we would need to filter that out. Fine-grained
// quantization is only supported on Ampere+ GPUs.
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType>
|
||||
void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
switch (gemm_config.epilogue_schedule)
|
||||
{
|
||||
case tkc::EpilogueScheduleType::AUTO:
|
||||
using EpilogueScheduleType = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{},
|
||||
cutlass::epilogue::TmaWarpSpecialized, cutlass::epilogue::TmaWarpSpecializedCooperative>;
|
||||
sm90_generic_mixed_gemm_kernelLauncher<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, ClusterShape, MainloopScheduleType, EpilogueScheduleType>(A, B, weight_scales,
|
||||
weight_zero_points, biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream,
|
||||
occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_epilogue_schedules] epilogue schedule config is invalid for "
|
||||
"mixed "
|
||||
"type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
/*
|
||||
1x1x1 cluster shape is supported for any tile shape.

2x1x1 cluster shape is only supported when the M tile is at least 128.

1x2x1 cluster shape is only supported when the N tile is at least 128.

2x2x1 cluster shape is only supported when both the M and N tiles are at least 128.

We make the above restrictions to improve compilation speed in TRT-LLM by pruning kernels
that may not be very useful in practice.
|
||||
*/
|
||||
template <typename CTAShape, typename ClusterShape>
|
||||
constexpr bool are_tile_shapes_supported()
|
||||
{
|
||||
constexpr int cta_m = get<0>(CTAShape{});
|
||||
constexpr int cta_n = get<1>(CTAShape{});
|
||||
constexpr int cga_m = get<0>(ClusterShape{});
|
||||
constexpr int cga_n = get<1>(ClusterShape{});
|
||||
|
||||
if constexpr (cga_m == _1{} && cga_n == _1{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _1{} && cta_m >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _1{} && cga_n == _2{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else if constexpr (cga_m == _2{} && cga_n == _2{} && cta_m >= _128{} && cta_n >= _128{})
|
||||
{
|
||||
return true;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
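// Illustrative only (assumes CTA shapes are cute Shapes of the form Shape<_M, _N, _K>, matching the
// cluster shapes used in the dispatch below): a few concrete cases of the rule described above.
//   are_tile_shapes_supported<Shape<_64, _16, _128>, Shape<_1, _1, _1>>()    -> true  (1x1x1 always allowed)
//   are_tile_shapes_supported<Shape<_64, _16, _128>, Shape<_2, _1, _1>>()    -> false (M tile below 128)
//   are_tile_shapes_supported<Shape<_128, _256, _128>, Shape<_2, _2, _1>>()  -> true  (both tiles at least 128)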
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
|
||||
void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
constexpr bool tile_shapes_supported = are_tile_shapes_supported<CTAShape, ClusterShape>();
|
||||
|
||||
if constexpr (tile_shapes_supported)
|
||||
{
|
||||
switch (gemm_config.mainloop_schedule)
|
||||
{
|
||||
case tkc::MainloopScheduleType::AUTO:
|
||||
using KernelScheduleType = cute::conditional_t<size<0>(CTAShape{}) == Int<64>{},
|
||||
cutlass::gemm::KernelTmaWarpSpecializedPingpongMixedInput,
|
||||
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
|
||||
sm90_dispatch_epilogue_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, ClusterShape, KernelScheduleType>(A, B, weight_scales, weight_zero_points,
|
||||
biases, alpha, C, m, n, k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] mainloop schedule config is invalid "
|
||||
"for "
|
||||
"mixed type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_mainloop_schedules] Unsupported CTA and Cluster shapes for "
|
||||
"mixed type GEMM.");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
|
||||
void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
switch (gemm_config.cluster_shape)
|
||||
{
|
||||
case tkc::ClusterShape::ClusterShape_1x1x1:
|
||||
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, Shape<_1, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
|
||||
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_2x1x1:
|
||||
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, Shape<_2, _1, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
|
||||
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_1x2x1:
|
||||
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, Shape<_1, _2, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
|
||||
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::ClusterShape::ClusterShape_2x2x1:
|
||||
sm90_dispatch_mainloop_schedules<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp,
|
||||
EpilogueTag, CTAShape, Shape<_2, _2, _1>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n,
|
||||
k, group_size, gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
|
||||
"[TensorRT-LLM Error][fpA_intB][dispatch_CGA_config] Config is invalid for mixed type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
|
||||
void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
// Note that SIMT configs are omitted here since they are not supported for fpA_intB.
|
||||
// We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best
|
||||
// for mixed type gemms.
|
||||
|
||||
constexpr int Ktile = 128 / sizeof(ActivationType);
|
||||
using _Ktile = Int<Ktile>;
|
||||
switch (gemm_config.tile_config_sm90)
|
||||
{
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x16x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_64, _16, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x32x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_64, _32, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x64x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_64, _64, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x128x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_64, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape64x256x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_64, _256, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x16x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_128, _16, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x32x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_128, _32, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x64x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_128, _64, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x128x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_128, _128, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::CtaShape128x256x128B:
|
||||
sm90_dispatch_gemm_config<ActivationType, WeightType, ScaleZeroType, BiasType, OutputType, QuantOp, EpilogueTag,
|
||||
Shape<_128, _256, _Ktile>>(A, B, weight_scales, weight_zero_points, biases, alpha, C, m, n, k, group_size,
|
||||
gemm_config, workspace, workspace_bytes, stream, occupancy);
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::Undefined:
|
||||
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] gemm config undefined.");
|
||||
break;
|
||||
case tkc::CutlassTileConfigSM90::ChooseWithHeuristic:
|
||||
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] gemm config should have already been set by "
|
||||
"heuristic.");
|
||||
break;
|
||||
default:
|
||||
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB][sm90_dispatch_gemm_to_cutlass] Config is invalid for mixed type GEMM.");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
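The K tile used by `sm90_dispatch_gemm_to_cutlass` above is fixed at 128 bytes of activations, i.e. `Ktile = 128 / sizeof(ActivationType)`, which matches the `128B` suffix in the `CtaShape…x128B` config names. A quick standalone Python sketch of that arithmetic, for illustration only:

```python
# Sketch of the K-tile arithmetic used by sm90_dispatch_gemm_to_cutlass:
# Ktile = 128 bytes / sizeof(ActivationType).
SIZEOF_BYTES = {"half": 2, "bfloat16": 2, "fp8_e4m3": 1}

for name, nbytes in SIZEOF_BYTES.items():
    ktile = 128 // nbytes
    print(f"{name:9s}: Ktile = {ktile}")  # 64 for 16-bit activations, 128 for FP8
```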
@ -0,0 +1,39 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
#include "cutlass_extensions/weight_only_quant_op.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
|
||||
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
|
||||
const float alpha, OutputType* C, int m, int n, int k, const int group_size,
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,298 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef _WIN32
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif // #ifndef _WIN32
|
||||
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "cutlass_extensions/compute_occupancy.h"
|
||||
#include "cutlass_extensions/epilogue_helpers.h"
|
||||
#include "cutlass_extensions/gemm_configs.h"
|
||||
|
||||
#ifndef _WIN32
|
||||
#pragma GCC diagnostic pop
|
||||
#endif // #ifndef _WIN32
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/cutlass_type_conversion.h"
|
||||
#include "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.h"
|
||||
|
||||
namespace tk = tensorrt_llm::common;
|
||||
namespace tkc = tensorrt_llm::cutlass_extensions;
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace cutlass_kernels
|
||||
{
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
|
||||
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
|
||||
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
#ifdef COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
using CutlassActivationType = typename TllmToCutlassTypeAdapter<ActivationType>::type;
|
||||
|
||||
// For FAST_BUILD, only instantiate kernels with a 128x128x128B CTA shape and a 1x1x1 cluster shape.
#ifdef FAST_BUILD
|
||||
constexpr int TILE_K = 128 * 8 / cutlass::sizeof_bits<CutlassActivationType>::value;
|
||||
using SupportedCtaShape = Shape<_128, _128, cute::Int<TILE_K>>;
|
||||
using SupportedCgaShape = Shape<_1, _1, _1>;
|
||||
|
||||
if constexpr (cute::is_same_v<SupportedCtaShape, CTAShape> && cute::is_same_v<SupportedCgaShape, ClusterShape>)
|
||||
{
|
||||
#endif // FAST_BUILD
|
||||
using CutlassWeightType__ = typename TllmToCutlassTypeAdapter<WeightType>::type;
|
||||
// We need to remap this since SM90 uses a different layout for the weight matrix.
|
||||
using CutlassWeightType_ = std::conditional_t<std::is_same_v<CutlassWeightType__, cutlass::uint4b_t>,
|
||||
cutlass::int4b_t, CutlassWeightType__>;
|
||||
|
||||
using CutlassWeightType
|
||||
= std::conditional_t<std::is_same_v<CutlassWeightType_, uint8_t>, int8_t, CutlassWeightType_>;
|
||||
|
||||
using CutlassScaleZeroType = typename TllmToCutlassTypeAdapter<ScaleZeroType>::type;
|
||||
using CutlassBiasType = typename TllmToCutlassTypeAdapter<BiasType>::type;
|
||||
using CutlassOutputType = typename TllmToCutlassTypeAdapter<OutputType>::type;
|
||||
|
||||
static_assert(std::is_same_v<CutlassActivationType, cutlass::half_t>
|
||||
|| std::is_same_v<CutlassActivationType, cutlass::bfloat16_t>
|
||||
|| std::is_same_v<CutlassActivationType, cutlass::float_e4m3_t>
|
||||
|| std::is_same_v<CutlassActivationType, cutlass::float_e5m2_t>,
"Activation type must be bfloat16, half, or FP8");
|
||||
|
||||
static_assert(std::is_same_v<CutlassWeightType, int8_t> || std::is_same_v<CutlassWeightType, cutlass::int4b_t>
|
||||
|| std::is_same_v<CutlassWeightType, cutlass::float_e4m3_t>
|
||||
|| std::is_same_v<CutlassWeightType, cutlass::float_e5m2_t>,
|
||||
"Weight type must be fp8, int8_t or int4_t");
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand
|
||||
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<CutlassActivationType>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand
|
||||
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<CutlassWeightType>::value;
|
||||
|
||||
// This example manually swaps and transposes, so keep transpose of input layouts
|
||||
using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
|
||||
using ElementZero = CutlassScaleZeroType;
|
||||
using ElementScale = CutlassScaleZeroType;
|
||||
|
||||
// C/D matrix configuration. We reuse the C operand for the bias and set the stride for broadcast.
|
||||
using LayoutBias = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentBias = 128 / cutlass::sizeof_bits<CutlassBiasType>::value;
|
||||
|
||||
// D matrix configuration
|
||||
using LayoutOutput = cutlass::layout::RowMajor;
|
||||
constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<CutlassOutputType>::value;
|
||||
|
||||
// Core kernel configurations
|
||||
using ElementAccumulator = float; // Element type for internal accumulation
|
||||
using ElementCompute = float; // Element type for epilogue computation
|
||||
using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
|
||||
using TileShape = CTAShape; // Threadblock-level tile size
|
||||
using KernelSchedule = MainloopScheduleType;
|
||||
using EpilogueSchedule = EpilogueScheduleType;
|
||||
|
||||
// Shrink the N dimension to match CTA_N if needed
|
||||
constexpr int epi_tile_M = cute::min(shape<0>(TileShape{}), 128); // 64 or 128
|
||||
constexpr int epi_tile_N = cute::min(shape<1>(TileShape{}), 32); // Allow this to be 16 for some small N tiles.
|
||||
using EpilogueTileType = cute::Shape<cute::Int<epi_tile_M>, cute::Int<epi_tile_N>>;
|
||||
|
||||
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
|
||||
static_assert(std::is_same_v<EpilogueTag, tensorrt_llm::cutlass_extensions::EpilogueOpBias>, "");
|
||||
using EVT_bias_addition = cutlass::epilogue::fusion::Sm90EVT<
|
||||
cutlass::epilogue::fusion::Sm90Compute<cutlass::homogeneous_multiply_add, CutlassOutputType, ElementCompute,
|
||||
RoundStyle>, // alpha * acc + bias
|
||||
cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementAccumulator>, // alpha
|
||||
cutlass::epilogue::fusion::Sm90AccFetch, // acc
|
||||
cutlass::epilogue::fusion::Sm90ColBroadcast<0, TileShape, CutlassBiasType, Stride<_1, _0, _0>,
|
||||
AlignmentBias> // bias
|
||||
>;
|
||||
|
||||
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<ArchTag, OperatorClass,
|
||||
TileShape, ClusterShape, EpilogueTileType, ElementAccumulator, ElementAccumulator,
|
||||
// Transpose layout of D here since we use the explicit swap + transpose trick
|
||||
// Void C since we don't use it. Prevents smem allocation.
|
||||
void, typename cutlass::layout::LayoutTranspose<LayoutBias>::type, AlignmentBias, CutlassOutputType,
|
||||
typename cutlass::layout::LayoutTranspose<LayoutOutput>::type, AlignmentOutput, EpilogueSchedule,
|
||||
EVT_bias_addition>::CollectiveOp;
|
||||
|
||||
using PackedScaleZero = cute::tuple<CutlassWeightType, ElementScale, ElementZero>;
|
||||
using PackedScale = cute::tuple<CutlassWeightType, ElementScale>;
|
||||
using ElementBCollectiveInfo = std::conditional_t<cutlass::hasZero(QuantOp), PackedScaleZero, PackedScale>;
|
||||
|
||||
// We swap A and B operands to the builder here
|
||||
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<ArchTag, OperatorClass,
|
||||
ElementBCollectiveInfo, LayoutB_Transpose, AlignmentB, CutlassActivationType, LayoutA_Transpose, AlignmentA,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<Shape<int, int, int, int>, // Indicates ProblemShape
|
||||
CollectiveMainloop, CollectiveEpilogue>;
|
||||
|
||||
if (occupancy != nullptr)
|
||||
{
|
||||
*occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel, true>();
|
||||
return;
|
||||
}
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
|
||||
using StrideA = typename GemmKernel::StrideA;
|
||||
using StrideB = typename GemmKernel::StrideB;
|
||||
using StrideC = typename GemmKernel::StrideC;
|
||||
using StrideD = typename GemmKernel::StrideD;
|
||||
using StrideS = typename CollectiveMainloop::StrideScale;
|
||||
|
||||
if (weight_scales == nullptr)
|
||||
{
|
||||
throw std::runtime_error("Weight scales must always be set to a non-null value.");
|
||||
}
|
||||
|
||||
if constexpr (cutlass::isFinegrained(QuantOp))
|
||||
{
|
||||
int cta_shape_k = cute::size<2>(TileShape{});
|
||||
if (group_size % cta_shape_k != 0)
|
||||
{
|
||||
std::string err_msg = "The group size must be a multiple of " + std::to_string(cta_shape_k);
throw std::runtime_error("[TensorRT-LLM Error][fpA_intB Runner] " + err_msg);
|
||||
}
|
||||
|
||||
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)
|
||||
{
|
||||
if (weight_zero_points != nullptr)
|
||||
{
|
||||
throw std::runtime_error("Weight zero-points must be null for scale-only fine-grained quantization");
|
||||
}
|
||||
}
|
||||
else if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
|
||||
{
|
||||
if (weight_zero_points == nullptr)
|
||||
{
|
||||
throw std::runtime_error("Weight zero-points must be valid for fine-grained scale-and-zeros quantization");
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (group_size != k)
|
||||
{
|
||||
throw std::runtime_error("Invalid group size for per column scaling kernels.");
|
||||
}
|
||||
|
||||
if (weight_zero_points != nullptr)
|
||||
{
|
||||
throw std::runtime_error("Weight zero-points must be null when running per column scaling");
|
||||
}
|
||||
}
|
||||
|
||||
auto cutlass_scale_k = (k + group_size - 1) / group_size;
|
||||
StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, 1));
|
||||
StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, 1));
|
||||
StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(n, m, 1));
|
||||
StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1));
|
||||
|
||||
// Use the output as the bias to avoid making a tma descriptor with a nullptr.
|
||||
auto output_as_bias_type = reinterpret_cast<const CutlassBiasType*>(C);
|
||||
|
||||
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
|
||||
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
|
||||
stride_A, reinterpret_cast<ElementScale const*>(weight_scales), stride_S, group_size,
|
||||
reinterpret_cast<ElementZero const*>(weight_zero_points)},
|
||||
{{}, output_as_bias_type, stride_D, reinterpret_cast<CutlassOutputType*>(C), stride_D}};
|
||||
|
||||
args.epilogue.thread = {
|
||||
{alpha}, // alpha args
|
||||
{}, // accumulator
|
||||
{reinterpret_cast<CutlassBiasType const*>(biases), CutlassBiasType(0.f)}, // bias args
|
||||
{} // end multiply_add
|
||||
};
|
||||
|
||||
Gemm gemm;
|
||||
if (gemm.get_workspace_size(args) > workspace_bytes)
|
||||
{
|
||||
TLLM_LOG_ERROR("[TensorRT-LLM Error][fpA_intB Runner] The given workspace size is insufficient.");
|
||||
}
|
||||
|
||||
auto can_implement = gemm.can_implement(args);
|
||||
if (can_implement != cutlass::Status::kSuccess)
|
||||
{
|
||||
std::string err_msg = "fpA_intB cutlass kernel will fail for the given params. Error: "
|
||||
+ std::string(cutlassGetStatusString(can_implement));
|
||||
std::cout << err_msg << std::endl;
|
||||
throw std::runtime_error("[TensorRT-LLM Error][fpA_intB Runner] " + err_msg);
|
||||
}
|
||||
|
||||
auto init_status = gemm.initialize(args, workspace, stream);
|
||||
if (init_status != cutlass::Status::kSuccess)
|
||||
{
|
||||
std::string err_msg = "Failed to initialize cutlass fpA_intB gemm. Error: "
|
||||
+ std::string(cutlassGetStatusString(init_status));
|
||||
throw std::runtime_error("[TensorRT-LLM Error][fpA_intB Runner] " + err_msg);
|
||||
}
|
||||
|
||||
auto run_status = gemm.run(stream);
|
||||
if (run_status != cutlass::Status::kSuccess)
|
||||
{
|
||||
std::string err_msg
|
||||
= "Failed to run cutlass fpA_intB gemm. Error: " + std::string(cutlassGetStatusString(run_status));
|
||||
throw std::runtime_error("[TensorRT-LLM Error][fpA_intB Runner] " + err_msg);
|
||||
}
|
||||
#ifdef FAST_BUILD
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("[TensorRT-LLM Error][fpA_intB Runner] Requested config was not instantiated because FAST_BUILD is enabled.");
|
||||
}
|
||||
#endif // FAST_BUILD
|
||||
|
||||
#else // COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
throw std::runtime_error(
"[TensorRT-LLM Error][fpA_intB Runner] Please recompile with support for Hopper by passing 90-real as an arch "
|
||||
"to build_wheel.py.");
|
||||
#endif // COMPILE_HOPPER_MIXED_INPUT_GEMMS
|
||||
}
|
||||
|
||||
} // namespace cutlass_kernels
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
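Before the next file, a brief standalone Python sketch (names are illustrative, not part of the launcher) of the runtime bookkeeping done by `sm90_generic_mixed_gemm_kernelLauncher` above: the number of scale/zero groups along K, the clamped epilogue tile, and the swapped `{n, m, k, 1}` problem shape used for the A/B swap-and-transpose trick.

```python
# Sketch of the runtime shape bookkeeping in sm90_generic_mixed_gemm_kernelLauncher.
def launcher_shapes(m: int, n: int, k: int, group_size: int, cta_m: int, cta_n: int):
    # Number of scale/zero groups along K (one per group_size elements).
    cutlass_scale_k = (k + group_size - 1) // group_size
    # Epilogue tile is clamped to at most 128 x 32, per the comments in the launcher.
    epi_tile = (min(cta_m, 128), min(cta_n, 32))
    # A and B are swapped (and layouts transposed), so the problem shape is (n, m, k, 1).
    problem_shape = (n, m, k, 1)
    return cutlass_scale_k, epi_tile, problem_shape

print(launcher_shapes(m=16, n=4096, k=4096, group_size=128, cta_m=64, cta_n=16))
# -> (32, (64, 16), (4096, 16, 4096, 1))
```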
@ -0,0 +1,304 @@
|
||||
import argparse
|
||||
import enum
|
||||
import os
|
||||
from itertools import product
|
||||
|
||||
from cutlass_library import *
|
||||
|
||||
|
||||
################################################################################
|
||||
# Epilogue Tag enum and string utils
|
||||
class TrtLlm_EpilogueTag(enum.Enum):
|
||||
epilogue_op_default = enum_auto()
|
||||
epilogue_op_bias = enum_auto()
|
||||
|
||||
|
||||
EpiTagNames = {
|
||||
TrtLlm_EpilogueTag.epilogue_op_default: "lc", # linear combination
|
||||
TrtLlm_EpilogueTag.epilogue_op_bias:
|
||||
"lc_bias" # linear combination with bias addition
|
||||
}
|
||||
|
||||
EpiTag = {
|
||||
TrtLlm_EpilogueTag.epilogue_op_default:
|
||||
"tensorrt_llm::cutlass_extensions::EpilogueOpDefault",
|
||||
TrtLlm_EpilogueTag.epilogue_op_bias:
|
||||
"tensorrt_llm::cutlass_extensions::EpilogueOpBias"
|
||||
}
|
||||
|
||||
|
||||
################################################################################
|
||||
# Quantization Operation and string utils
|
||||
class TrtLlm_QuantOp(enum.Enum):
|
||||
per_column_scale_only = enum_auto()
|
||||
finegrained_scale_only = enum_auto()
|
||||
finegrained_scale_and_zeros = enum_auto()
|
||||
|
||||
|
||||
QuantOpNames = {
|
||||
TrtLlm_QuantOp.per_column_scale_only: "cs",
|
||||
TrtLlm_QuantOp.finegrained_scale_only: "fgs",
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros: "fgsz"
|
||||
}
|
||||
|
||||
QuantOpTag = {
|
||||
TrtLlm_QuantOp.per_column_scale_only:
|
||||
"cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY",
|
||||
TrtLlm_QuantOp.finegrained_scale_only:
|
||||
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY",
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros:
|
||||
"cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS"
|
||||
}
|
||||
|
||||
################################################################################
|
||||
# The activations, biases, scales and zeros are instantiated using CUDA types,
|
||||
# not CUTLASS types. This map materializes the name of the CUDA type.
|
||||
CudaTypeName = {
|
||||
DataType.e4m3: "__nv_fp8_e4m3",
|
||||
DataType.bf16: "__nv_bfloat16",
|
||||
DataType.f16: "half"
|
||||
}
|
||||
|
||||
|
||||
################################################################################
|
||||
# A data structure holding all info to instantiate gemm launchers in TRT LLM.
|
||||
class TrtLlm_GemmLauncher:
|
||||
|
||||
def __init__(self,
|
||||
gemm_kind,
|
||||
arch,
|
||||
act_type,
|
||||
weight_type,
|
||||
scalezero_type,
|
||||
bias_type,
|
||||
output_type,
|
||||
quant_op,
|
||||
epi_tag,
|
||||
cta_shape,
|
||||
warp_shape,
|
||||
stages,
|
||||
cga_shape=None,
|
||||
mainloop_schedule=None,
|
||||
epi_schedule=None):
|
||||
self.gemm_kind = gemm_kind
|
||||
self.arch = arch
|
||||
self.act_type = act_type
|
||||
self.weight_type = weight_type
|
||||
self.scalezero_type = scalezero_type
|
||||
self.bias_type = bias_type
|
||||
self.output_type = output_type
|
||||
self.quant_op = quant_op
|
||||
self.epi_tag = epi_tag
|
||||
self.cta_shape = cta_shape
|
||||
self.warp_shape = warp_shape
|
||||
self.stages = stages
|
||||
self.cga_shape = cga_shape
|
||||
self.mainloop_schedule = mainloop_schedule
|
||||
self.epi_schedule = epi_schedule
|
||||
|
||||
def __repr__(self):
|
||||
kernel_prefix = "{}_sm{}_{}_{}_{}_{}_{}_{}_{}_{}x{}x{}_{}x{}x{}_{}".format(
|
||||
GemmKindNames[self.gemm_kind], self.arch,
|
||||
DataTypeNames[self.act_type], DataTypeNames[self.weight_type],
|
||||
DataTypeNames[self.scalezero_type], DataTypeNames[self.bias_type],
|
||||
DataTypeNames[self.output_type], QuantOpNames[self.quant_op],
|
||||
EpiTagNames[self.epi_tag], self.cta_shape[0], self.cta_shape[1],
|
||||
self.cta_shape[2], self.warp_shape[0], self.warp_shape[1],
|
||||
self.warp_shape[2], self.stages)
|
||||
|
||||
hopper_suffix = "_{}x{}x{}{}{}".format(
|
||||
self.cga_shape[0], self.cga_shape[1], self.cga_shape[2],
|
||||
KernelScheduleSuffixes[self.mainloop_schedule],
|
||||
EpilogueScheduleSuffixes[self.epi_schedule])
|
||||
|
||||
if self.arch == 90:
|
||||
return kernel_prefix + hopper_suffix
|
||||
elif self.arch > 90:
|
||||
raise ValueError(f"SM{self.arch} not supported yet.")
|
||||
return kernel_prefix
|
||||
|
||||
|
||||
################################################################################
|
||||
def tuple_to_cute_shape(shape):
|
||||
return f"cute::Shape<cute::Int<{shape[0]}>, cute::Int<{shape[1]}>, cute::Int<{shape[2]}>>"
|
||||
|
||||
|
||||
def instantiate_operation(operation):
|
||||
|
||||
act_tag = CudaTypeName[operation.act_type]
|
||||
weight_tag = DataTypeTag[operation.weight_type]
|
||||
scale_zero_tag = CudaTypeName[operation.scalezero_type]
|
||||
bias_tag = CudaTypeName[operation.bias_type]
|
||||
out_tag = CudaTypeName[operation.output_type]
|
||||
|
||||
quant_op = QuantOpTag[operation.quant_op]
|
||||
epi_tag = EpiTag[operation.epi_tag]
|
||||
|
||||
cute_cta_shape = tuple_to_cute_shape(operation.cta_shape)
|
||||
cute_cga_shape = tuple_to_cute_shape(operation.cga_shape)
|
||||
|
||||
kernel_sched = KernelScheduleTag[operation.mainloop_schedule]
|
||||
|
||||
# Here, we must append MixedInput depending on the schedule, since we know the types are different.
|
||||
# It is a workaround since the CUTLASS library did not have the MixedInput schedules at the time of writing.
|
||||
if operation.mainloop_schedule in [
|
||||
KernelScheduleType.TmaWarpSpecializedCooperative,
|
||||
KernelScheduleType.TmaWarpSpecializedPingpong,
|
||||
KernelScheduleType.TmaWarpSpecialized
|
||||
]:
|
||||
kernel_sched += "MixedInput"
|
||||
epi_sched = EpilogueScheduleTag[operation.epi_schedule]
|
||||
|
||||
instantiation = f"""
|
||||
template void sm90_generic_mixed_gemm_kernelLauncher<{act_tag}, {weight_tag}, {scale_zero_tag}, {bias_tag}, {out_tag},
|
||||
{quant_op}, {epi_tag},
|
||||
{cute_cta_shape}, {cute_cga_shape},
|
||||
{kernel_sched}, {epi_sched}> (
|
||||
const {act_tag}*, const {weight_tag}*, const {scale_zero_tag}*, const {scale_zero_tag}*, const {bias_tag}*, const float,
|
||||
{out_tag}*, int, int, int, const int, tensorrt_llm::cutlass_extensions::CutlassGemmConfig, char*, size_t, cudaStream_t, int*
|
||||
);
|
||||
"""
|
||||
return instantiation
|
||||
|
||||
|
||||
def get_file_content(launcher_inl_files, operations):
|
||||
|
||||
include_list = list()
|
||||
for file in launcher_inl_files:
|
||||
include_list.append(f"#include \"{file}\"")
|
||||
includes = "\n".join(include_list)
|
||||
|
||||
insts_list = list()
|
||||
for op in operations:
|
||||
insts_list.append(instantiate_operation(op))
|
||||
instantiations = "\n".join(insts_list)
|
||||
|
||||
file_content = f"""{includes}
|
||||
namespace tensorrt_llm
|
||||
{{
|
||||
namespace kernels
|
||||
{{
|
||||
namespace cutlass_kernels
|
||||
{{
|
||||
|
||||
{instantiations}
|
||||
|
||||
}} // namespace cutlass_kernels
|
||||
}} // namespace kernels
|
||||
}} // namespace tensorrt_llm
|
||||
"""
|
||||
return file_content
|
||||
|
||||
|
||||
def write_file(launcher_inl_files, operations, output_file):
|
||||
with open(output_file, mode="w") as f:
|
||||
f.write(get_file_content(launcher_inl_files, operations))
|
||||
|
||||
|
||||
def is_op_valid(op):
|
||||
tile_m, tile_n, _ = op.cta_shape
|
||||
cga_m, cga_n, _ = op.cga_shape
|
||||
|
||||
if cga_m == 1 and cga_n == 1:
|
||||
return True
|
||||
|
||||
if cga_m == 2 and cga_n == 1 and tile_m >= 128:
|
||||
return True
|
||||
|
||||
if cga_m == 1 and cga_n == 2 and tile_n >= 128:
|
||||
return True
|
||||
|
||||
if cga_m == 2 and cga_n == 2 and tile_m >= 128 and tile_n >= 128:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
################################################################################
|
||||
def generate_sm90_operations():
|
||||
arch = 90
|
||||
|
||||
# For legacy reasons, we use unsigned weight types when the activations are fp16 / bf16.
|
||||
# Takes the form (activation_type, weight_type, scalezero_type, bias_type, output_type)
|
||||
supported_dtypes = [
|
||||
(DataType.e4m3, DataType.s4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.f16, DataType.u4, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.bf16, DataType.u4, DataType.bf16, DataType.bf16,
|
||||
DataType.bf16),
|
||||
(DataType.f16, DataType.u8, DataType.f16, DataType.f16, DataType.f16),
|
||||
(DataType.bf16, DataType.u8, DataType.bf16, DataType.bf16,
|
||||
DataType.bf16)
|
||||
]
|
||||
|
||||
quant_ops = [
|
||||
TrtLlm_QuantOp.per_column_scale_only,
|
||||
TrtLlm_QuantOp.finegrained_scale_only,
|
||||
TrtLlm_QuantOp.finegrained_scale_and_zeros
|
||||
]
|
||||
|
||||
epi_tags = [TrtLlm_EpilogueTag.epilogue_op_bias]
|
||||
|
||||
M_TILES = [64, 128]
|
||||
N_TILES = [16, 32, 64, 128, 256]
|
||||
cta_shapes_mn = product(M_TILES, N_TILES)
|
||||
|
||||
warp_shape = [4, 1, 1]
|
||||
stages = 0 # auto
|
||||
|
||||
cga_shapes = product([1, 2], [1, 2], [1])
|
||||
|
||||
partial_args = product(supported_dtypes, quant_ops, epi_tags, cta_shapes_mn,
|
||||
cga_shapes)
|
||||
|
||||
operations = list()
|
||||
for dtype_combo, quant_op, epi_tag, cta_shape_mn, cga_shape in partial_args:
|
||||
max_k_bits = 128 * 8
|
||||
cta_shape_k = max_k_bits // DataTypeSize[dtype_combo[0]]
|
||||
cta_shape_mnk = cta_shape_mn + (cta_shape_k, )
|
||||
|
||||
use_coop = cta_shape_mn[0] == 128
|
||||
mainloop_schedule = KernelScheduleType.TmaWarpSpecializedCooperative if use_coop else KernelScheduleType.TmaWarpSpecializedPingpong
|
||||
epi_schedule = EpilogueScheduleType.TmaWarpSpecializedCooperative if use_coop else EpilogueScheduleType.TmaWarpSpecialized
|
||||
|
||||
operation = TrtLlm_GemmLauncher(GemmKind.Gemm, arch, *dtype_combo, quant_op, epi_tag, cta_shape_mnk, \
|
||||
warp_shape, stages, cga_shape, mainloop_schedule, epi_schedule)
|
||||
|
||||
if is_op_valid(operation):
|
||||
operations.append(operation)
|
||||
return operations
|
||||
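The schedule choice in `generate_sm90_operations` mirrors the `AUTO` dispatch in the C++ code earlier in this diff: 128-row CTAs get the cooperative schedules, 64-row CTAs get pingpong. A tiny standalone sketch of that rule (illustrative names only, not part of the generator):

```python
# Mirrors 'use_coop' above and the cute::conditional_t AUTO dispatch:
# 64-row CTAs -> pingpong mainloop, 128-row CTAs -> cooperative mainloop/epilogue.
def pick_schedules(cta_m: int):
    if cta_m == 128:
        return ("TmaWarpSpecializedCooperative", "TmaWarpSpecializedCooperative")
    return ("TmaWarpSpecializedPingpong", "TmaWarpSpecialized")

print(pick_schedules(64))   # ('TmaWarpSpecializedPingpong', 'TmaWarpSpecialized')
print(pick_schedules(128))  # ('TmaWarpSpecializedCooperative', 'TmaWarpSpecializedCooperative')
```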
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description='Generate CUTLASS kernel instantiation files for TensorRT-LLM')
|
||||
|
||||
# Add the output_dir argument with short and long options
|
||||
parser.add_argument('-o',
|
||||
'--output_dir',
|
||||
type=str,
|
||||
required=True,
|
||||
help='Path to the output directory')
|
||||
|
||||
# Parse the command line arguments
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get the absolute path of the provided directory
|
||||
output_dir = os.path.abspath(args.output_dir)
|
||||
|
||||
hopper_inl = "tensorrt_llm/kernels/cutlass_kernels/fpA_intB_gemm/launchers/fpA_intB_launcher_sm90.inl"
|
||||
|
||||
# The goal here is to group kernels with common instantiations together in order to reduce template instantiation overheads.
|
||||
# Template instantiation dominates the time in a compilation unit, so it is the most important factor to improve.
|
||||
operations = generate_sm90_operations()
|
||||
op_groups = dict()
|
||||
for op in operations:
|
||||
dict_key = (op.gemm_kind, op.arch, op.cta_shape[0])
|
||||
op_group = op_groups.get(dict_key, list())
|
||||
op_group.append(op)
|
||||
op_groups[dict_key] = op_group
|
||||
|
||||
file_counter = 1
|
||||
for key, value in op_groups.items():
|
||||
out_file = os.path.join(
|
||||
output_dir, f"cutlass_kernel_file_{file_counter}.generated.cu")
|
||||
write_file([hopper_inl], value, out_file)
|
||||
file_counter += 1
|
||||
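Usage note (not part of the diff): given the argparse setup above, the generator is invoked as `python3 generate_kernels.py -o <output_dir>`; it groups operations by `(gemm_kind, arch, cta_shape[0])` and writes one `cutlass_kernel_file_<N>.generated.cu` per group into that directory.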
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -64,93 +64,113 @@ void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, c
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void addBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int vocabSize, const int vocabSizePadded)
|
||||
__global__ void addBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
|
||||
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
|
||||
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits)
|
||||
{
|
||||
auto const batchIdx = blockIdx.x;
|
||||
auto const beamIdx = blockIdx.y;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
|
||||
const FinishedState finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty();
|
||||
auto const batchIdxLogits = batchSlotsLogits ? batchSlot : batchIdx;
|
||||
FinishedState const finishState
|
||||
= finished != nullptr ? finished[beamIdx * maxBatchSize + batchSlot] : FinishedState::empty();
|
||||
if (finishState.isSkipDecoding())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
auto logitsPtr = logitsPtrs ? logitsPtrs[batchIdx] + beamIdx * vocabSizePadded
|
||||
: logits + (batchIdxLogits * beamWidth + beamIdx) * vocabSizePadded;
|
||||
|
||||
bool finish = finishState.isFinished();
|
||||
int offset = batchIdx * vocabSizePadded;
|
||||
int offset = (batchIdxLogits * beamWidth + beamIdx) * vocabSizePadded;
|
||||
|
||||
float maxVal = -1 * FLT_MAX;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
__shared__ float sMaxVal;
|
||||
__shared__ float sSumVal;
|
||||
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
auto logit = logitsPtr[tid];
|
||||
if (tid < vocabSize)
|
||||
{
|
||||
if (finish && endIds != nullptr)
|
||||
{
|
||||
logits[offset + tid] = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
logit = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
}
|
||||
else
|
||||
{
|
||||
T bias_val = (bias != nullptr) ? bias[tid] : (T) 0.0f;
|
||||
logits[offset + tid] += bias_val;
|
||||
logit += bias_val;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
logits[offset + tid] = -MAX_T_VAL;
|
||||
logit = -MAX_T_VAL;
|
||||
}
|
||||
maxVal = max(maxVal, (float) logits[offset + tid]);
|
||||
maxVal = max(maxVal, (float) logit);
|
||||
logitsPtr[tid] = logit;
|
||||
}
|
||||
|
||||
maxVal = blockReduceMax<float>((float) maxVal);
|
||||
if (threadIdx.x == 0)
|
||||
if (!skipSoftMax)
|
||||
{
|
||||
sMaxVal = maxVal;
|
||||
}
|
||||
__syncthreads();
|
||||
maxVal = blockReduceMax<float>((float) maxVal);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
sMaxVal = maxVal;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float sumVal = 0.0f;
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
probs[offset + tid] = __expf((float) logits[offset + tid] - sMaxVal);
|
||||
sumVal += (float) probs[offset + tid];
|
||||
}
|
||||
float sumVal = 0.0f;
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
probs[offset + tid] = __expf((float) logitsPtr[tid] - sMaxVal);
|
||||
sumVal += (float) probs[offset + tid];
|
||||
}
|
||||
|
||||
sumVal = blockReduceSum<float>(sumVal);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
sSumVal = sumVal;
|
||||
}
|
||||
__syncthreads();
|
||||
sumVal = blockReduceSum<float>(sumVal);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
sSumVal = sumVal;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
probs[offset + tid] = ((float) probs[offset + tid] / (sSumVal + 1e-6f));
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
probs[offset + tid] = ((float) probs[offset + tid] / (sSumVal + 1e-6f));
|
||||
}
|
||||
}
|
||||
}
|
||||
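For reference, a standalone Python sketch (not part of the kernel) of what `addBiasSoftMax` computes for a single (batch, beam) row: apply the bias or the end-id mask, mask the padded vocabulary tail, then a max-stabilized softmax. `MAX_T_VAL` mirrors the half-precision path; indexing across batch slots and beams is omitted.

```python
import math

MAX_T_VAL = 65504.0  # HALF_FLT_MAX, as used by the kernel for T = half

def add_bias_softmax_row(logits, vocab_size, bias=None, finished=False, end_id=None,
                         skip_softmax=False):
    # 'logits' covers vocabSizePadded entries for one (batch, beam) row.
    out = list(logits)
    for tid in range(len(out)):
        if tid < vocab_size:
            if finished and end_id is not None:
                out[tid] = MAX_T_VAL if tid == end_id else -MAX_T_VAL
            elif bias is not None:
                out[tid] += bias[tid]
        else:
            out[tid] = -MAX_T_VAL  # mask the padded tail of the vocabulary
    if skip_softmax:
        return out
    max_val = max(out)                    # blockReduceMax
    exps = [math.exp(v - max_val) for v in out]
    denom = sum(exps) + 1e-6              # blockReduceSum, plus the kernel's 1e-6
    return [e / denom for e in exps]

print(add_bias_softmax_row([1.0, 2.0, 0.5, 0.0], vocab_size=3, bias=[0.1, 0.0, -0.1]))
```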
|
||||
template <typename T>
|
||||
void invokeAddBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream)
|
||||
void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
|
||||
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
|
||||
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(batchSize);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
dim3 grid(batchSize, beamWidth);
|
||||
auto const vocabRoundedToWarp = roundUp(vocabSize, 32);
|
||||
dim3 block(min(vocabRoundedToWarp, 1024));
|
||||
// vocabSize is usually large (e.g., 30000), so cap the block size at 1024 threads.
|
||||
addBiasSoftMax<<<grid, block, 0, stream>>>(
|
||||
logits, probs, bias, endIds, finished, batchSlots, vocabSize, vocabSizePadded);
|
||||
addBiasSoftMax<<<grid, block, 0, stream>>>(logits, logitsPtrs, probs, bias, endIds, finished, batchSlots, batchSize,
|
||||
maxBatchSize, beamWidth, vocabSize, vocabSizePadded, skipSoftMax, batchSlotsLogits);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template void invokeAddBiasSoftMax(float* logits, float* probs, const float* bias, const int* endIds,
|
||||
const FinishedState* finished, const int* batchSlots, const int m, const int nPadded, const int n,
|
||||
cudaStream_t stream);
|
||||
template void invokeAddBiasSoftMax(float* logits, float** logitsPtrs, float* probs, float const* bias,
|
||||
int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
|
||||
bool batchSlotsLogits, cudaStream_t stream);
|
||||
|
||||
template void invokeAddBiasSoftMax(half* logits, half* probs, const half* bias, const int* endIds,
|
||||
const FinishedState* finished, const int* batchSlots, const int m, const int nPadded, const int n,
|
||||
cudaStream_t stream);
|
||||
template void invokeAddBiasSoftMax(half* logits, half** logitsPtrs, half* probs, half const* bias,
|
||||
int32_t const* endIds, FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax,
|
||||
bool batchSlotsLogits, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void scatterDecodingParamsKernel(T const* src, T* dst, int const* batchSlots, int batchSize)
|
||||
|
||||
@ -169,19 +169,28 @@ void invokeCurandBatchInitialize(curandState_t* states, const int* batchSlots, c
|
||||
//! endId token. Otherwise, adds bias per token if bias pointer is not nullptr.
|
||||
//!
|
||||
//! \param logits input/output buffer [maxBatchSize, vocabSize]. Logits to be modified by mask and bias.
|
||||
//! If nullptr, logitsPtrs has to be provided.
|
||||
//! \param logitsPtrs input/output buffer [maxBatchSize][vocabSize]. Vector of pointers to the logits.
|
||||
//! If nullptr, logits has to be provided.
|
||||
//! \param probs output buffer [maxBatchSize, vocabSize]. Probabilities of logits compute by softmax.
|
||||
//! Can be the same pointer as logits
|
||||
//! \param bias input buffer [vocabSize]. Bias to logit per token. Ignored if nullptr
|
||||
//! \param endIds input buffer [maxBatchSize]. EOS token ids per request
|
||||
//! \param finished input buffer [maxBatchSize] with flags set to true if request has finished the generation
|
||||
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
|
||||
//! \param batchSize batch size
|
||||
//! \param batchSize current batch size
|
||||
//! \param maxBatchSize max batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param vocabSize unpadded vocab size
|
||||
//! \param vocabSizePadded padded vocab size
|
||||
//! \param skipSoftMax flag to skip softmax computation
|
||||
//! \param batchSlotsLogits flag to use batchSlot as index for logits and probs
|
||||
//! \param stream stream
|
||||
template <typename T>
|
||||
void invokeAddBiasSoftMax(T* logits, T* probs, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
|
||||
void invokeAddBiasSoftMax(T* logits, T** logitsPtrs, T* probs, T const* bias, int32_t const* endIds,
|
||||
FinishedState const* finished, int32_t const* batchSlots, int32_t batchSize, int32_t maxBatchSize,
|
||||
int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded, bool skipSoftMax, bool batchSlotsLogits,
|
||||
cudaStream_t stream);
|
||||
|
||||
//! \brief Distributes values located in src to dst according to the indices from batchSlots
|
||||
//!
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -499,26 +499,26 @@ void invokeTransposeLogProbs(float* outputLogProbs, float* outputLogProbsTiled,
|
||||
outputLogProbs, outputLogProbsTiled, sequenceLengths, batchSlots, batchSize, beamWidth, maxSeqLen);
|
||||
}
|
||||
|
||||
__global__ void acceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
|
||||
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
|
||||
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens)
|
||||
__global__ void acceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
|
||||
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
|
||||
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t maxSeqLen, int32_t maxDraftTokens)
|
||||
{
|
||||
int threadFinishedCount = 0;
|
||||
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < batchSize * beamWidth;
|
||||
index += blockDim.x * gridDim.x)
|
||||
for (int batchIdx = threadIdx.x; batchIdx < batchSize; batchIdx += blockDim.x)
|
||||
{
|
||||
const auto numDraftTokens = numsDraftTokens[index];
|
||||
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlot];
|
||||
|
||||
const auto contextLength = contextLengths[index];
|
||||
auto& sequenceLength = sequenceLengths[index];
|
||||
auto const contextLength = contextLengths[batchSlot];
|
||||
auto& sequenceLength = sequenceLengths[batchSlot];
|
||||
int finishedDraftIdx = 0;
|
||||
for (int ti = contextLength; ti < min(sequenceLength, contextLength + numDraftTokens); ++ti, ++finishedDraftIdx)
|
||||
{
|
||||
const auto draftIdx = ti - contextLength;
|
||||
const auto targetTokenIdx = index * maxSeqLen + ti;
|
||||
const auto draftTokenIdx = index * maxDraftTokens + draftIdx;
|
||||
auto const draftIdx = ti - contextLength;
|
||||
auto const targetTokenIdx = batchSlot * maxSeqLen + ti;
|
||||
auto const draftTokenIdx = batchSlot * maxDraftTokens + draftIdx;
|
||||
// Check if draft tokens are the same as target tokens
|
||||
const bool accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
|
||||
bool const accepted = draftIds[draftTokenIdx] == targetIds[targetTokenIdx];
|
||||
if (!accepted)
|
||||
{
|
||||
// Set sequence length to the numAcceptedTokens + 1
|
||||
@ -527,65 +527,57 @@ __global__ void acceptDraftTokensByIds(const int* draftIds, const int* targetIds
|
||||
break;
|
||||
}
|
||||
}
|
||||
FinishedState finishState = finished[finishedDraftIdx * batchSize * beamWidth + index];
|
||||
finishedFinal[index] = finishState;
|
||||
threadFinishedCount += static_cast<int>(finishState.isFinished());
|
||||
}
|
||||
FinishedState finishState = finished[finishedDraftIdx * maxBatchSize + batchSlot];
|
||||
finishedFinal[batchSlot] = finishState;
|
||||
|
||||
if (finishedSum)
|
||||
{
|
||||
int blockFinishedCount = 0;
|
||||
if (blockDim.x <= 32)
|
||||
if (finishedSum)
|
||||
{
|
||||
blockFinishedCount = warpReduceSum(threadFinishedCount);
|
||||
}
|
||||
else
|
||||
{
|
||||
blockFinishedCount = blockReduceSum(threadFinishedCount);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
finishedSum[0] = blockFinishedCount;
|
||||
finishedSum[batchSlot] = static_cast<int>(finishState.isFinished());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
|
||||
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
|
||||
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens, cudaStream_t stream)
|
||||
void invokeAcceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
|
||||
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
|
||||
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t beamWidth, int32_t maxSeqLen, int32_t maxDraftTokens, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK(beamWidth == 1);
|
||||
dim3 block(min(256, batchSize * beamWidth));
|
||||
dim3 block(min(1024, batchSize));
|
||||
dim3 grid(1);
|
||||
acceptDraftTokensByIds<<<grid, block, 0, stream>>>(draftIds, targetIds, contextLengths, numsDraftTokens,
|
||||
sequenceLengths, finished, finishedFinal, finishedSum, batchSize, beamWidth, maxSeqLen, maxDraftTokens);
|
||||
sequenceLengths, finished, finishedFinal, finishedSum, batchSlots, batchSize, maxBatchSize, maxSeqLen,
|
||||
maxDraftTokens);
|
||||
}
|
||||
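A standalone Python sketch (not part of the kernel) of the acceptance rule in `acceptDraftTokensByIds` for one sequence; the batchSlot/maxSeqLen flattening and the FinishedState propagation are omitted for clarity.

```python
def accept_draft_tokens_by_ids(draft_ids, target_ids, context_length, num_draft_tokens,
                               sequence_length):
    # Walk the draft tokens appended after the context and keep them only while they
    # match the target model's tokens.
    seq_len = sequence_length
    accepted = 0
    for ti in range(context_length, min(sequence_length, context_length + num_draft_tokens)):
        draft_idx = ti - context_length
        if draft_ids[draft_idx] != target_ids[ti]:
            # Mismatch: keep the accepted prefix plus the corrected target token.
            seq_len = ti + 1
            break
        accepted += 1
    return accepted, seq_len

# Draft proposed [5, 7, 9] after a 2-token context; the target agrees on 5, 7 but not 9.
print(accept_draft_tokens_by_ids([5, 7, 9], [1, 2, 5, 7, 8, 0], context_length=2,
                                 num_draft_tokens=3, sequence_length=5))  # -> (2, 5)
```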
|
||||
template <typename T>
|
||||
__global__ void acceptDraftTokensByLogitsKernel(const T* draftProbs, T* targetProbs, const int* numsDraftTokens,
|
||||
FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth, int vocabSize,
|
||||
bool randomThreshold, float constantThreshold)
|
||||
__global__ void acceptDraftTokensByLogitsKernel(T const* draftProbs, T* targetProbs, int32_t const* numsDraftTokens,
|
||||
FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t maxDraftTokens, int32_t beamWidth, int32_t vocabSize, bool randomThreshold,
|
||||
float constantThreshold)
|
||||
{
|
||||
const auto bid = blockIdx.x;
|
||||
const auto draftTokenIdx = blockIdx.y;
|
||||
const auto batchIdx = bid / beamWidth;
|
||||
auto const bid = blockIdx.x;
|
||||
auto const draftTokenIdx = blockIdx.y;
|
||||
auto const batchIdx = bid / beamWidth;
|
||||
auto const beamIdx = bid % beamWidth;
|
||||
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
|
||||
|
||||
const auto numDraftTokens = numsDraftTokens[bid];
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
|
||||
|
||||
if (draftTokenIdx >= numDraftTokens)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
const auto logitsOffset = draftTokenIdx * batchSize * beamWidth * vocabSize + bid * vocabSize;
|
||||
const auto draftProbsBatch = draftProbs + logitsOffset;
|
||||
const auto targetProbsBatch = targetProbs + logitsOffset;
|
||||
auto const logitsOffset = (batchSlot * maxDraftTokens + draftTokenIdx) * beamWidth * vocabSize;
|
||||
auto const draftProbsBatch = draftProbs + logitsOffset;
|
||||
auto const targetProbsBatch = targetProbs + logitsOffset;
|
||||
|
||||
int rejected = 0;
|
||||
int32_t rejected = 0;
|
||||
auto vocabSizePadded = static_cast<int32_t>((vocabSize + blockDim.x - 1) / blockDim.x) * blockDim.x;
|
||||
|
||||
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
|
||||
for (int32_t vIdx = threadIdx.x; vIdx < vocabSizePadded; vIdx += blockDim.x)
|
||||
{
|
||||
if (rejected > 0)
|
||||
{
|
||||
@ -594,35 +586,41 @@ __global__ void acceptDraftTokensByLogitsKernel(const T* draftProbs, T* targetPr
|
||||
|
||||
// FIXME(nkorobov): We compare probability distributions, but it might make sense to compare probabilities of
|
||||
// the selected tokens based on https://arxiv.org/pdf/2302.01318.pdf
|
||||
const auto threshold = randomThreshold ? curand_uniform(curandState + batchIdx) : constantThreshold;
|
||||
|
||||
const auto targetProb = static_cast<float>(targetProbsBatch[vIdx]);
|
||||
const auto draftProb = static_cast<float>(draftProbsBatch[vIdx]);
|
||||
bool const pred = vIdx < vocabSize;
|
||||
auto const threshold
|
||||
= pred ? (randomThreshold ? curand_uniform(curandState + batchSlot) : constantThreshold) : 0.f;
|
||||
auto const targetProb = pred ? static_cast<float>(targetProbsBatch[vIdx]) : 1.f;
|
||||
auto const draftProb = pred ? static_cast<float>(draftProbsBatch[vIdx]) : 0.f;
|
||||
|
||||
rejected = __syncthreads_count(targetProb < threshold * draftProb);
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
finished[draftTokenIdx * batchSize * beamWidth + bid]
|
||||
finished[draftTokenIdx * maxBatchSize * beamWidth + batchSlotBeamWidth]
|
||||
= rejected > 0 ? FinishedState::skipDecoding() : FinishedState::empty();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetProbs, T* targetLogits,
|
||||
const int* numsDraftTokens, FinishedState* finished, int batchSize, int beamWidth, int vocabSize)
|
||||
__global__ void correctAcceptedStatesAndLogits(T const* draftProbs, T* targetProbs, T** targetLogits,
|
||||
int32_t const* numsDraftTokens, FinishedState* finished, int32_t const* batchSlots, int32_t batchSize,
|
||||
int32_t maxBatchSize, int32_t maxDraftTokens, int32_t beamWidth, int32_t vocabSize)
|
||||
{
|
||||
const auto bid = blockIdx.x;
|
||||
const auto numDraftTokens = numsDraftTokens[bid];
|
||||
auto const bid = blockIdx.x;
|
||||
auto const batchIdx = bid / beamWidth;
|
||||
auto const beamIdx = bid % beamWidth;
|
||||
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
auto const batchSlotBeamWidth = batchSlot * beamWidth + beamIdx;
|
||||
auto const numDraftTokens = numsDraftTokens[batchSlotBeamWidth];
|
||||
|
||||
__shared__ int numAcceptedTokens;
|
||||
__shared__ int32_t numAcceptedTokens;
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
numAcceptedTokens = numDraftTokens;
|
||||
bool cummulativeSkipDecoding = false;
|
||||
for (int ti = 0; ti < numDraftTokens + 1; ++ti)
|
||||
for (int32_t ti = 0; ti < numDraftTokens + 1; ++ti)
|
||||
{
|
||||
auto& finishedState = finished[ti * batchSize * beamWidth + bid];
|
||||
auto& finishedState = finished[ti * maxBatchSize * beamWidth + batchSlotBeamWidth];
|
||||
bool localSkipDecoding = finishedState.isSkipDecoding();
|
||||
if (cummulativeSkipDecoding == false && localSkipDecoding == true)
|
||||
{
|
||||
@ -637,15 +635,15 @@ __global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetPro
|
||||
|
||||
if (numAcceptedTokens < numDraftTokens)
|
||||
{
|
||||
const auto logitsIdx = numAcceptedTokens * batchSize * beamWidth * vocabSize + bid * vocabSize;
|
||||
const auto draftProbBatch = draftProbs + logitsIdx;
|
||||
auto const logitsIdx = (batchSlot * maxDraftTokens + numAcceptedTokens) * beamWidth * vocabSize;
|
||||
auto const draftProbBatch = draftProbs + logitsIdx;
|
||||
auto targetProbBatch = targetProbs + logitsIdx;
|
||||
auto targetLogitsBatch = targetLogits + logitsIdx;
|
||||
auto targetLogitsBatch = targetLogits[bid] + numAcceptedTokens * beamWidth * vocabSize;
|
||||
|
||||
float sumProbs = 0.f;
|
||||
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
|
||||
for (int32_t vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
|
||||
{
|
||||
const auto correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
|
||||
auto const correctedProb = max(static_cast<float>(targetProbBatch[vIdx] - draftProbBatch[vIdx]), 0.f);
|
||||
sumProbs += correctedProb;
|
||||
targetProbBatch[vIdx] = correctedProb;
|
||||
}
|
||||
@ -658,49 +656,52 @@ __global__ void correctAcceptedStatesAndLogits(const T* draftProbs, T* targetPro
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
for (int vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
|
||||
for (int32_t vIdx = threadIdx.x; vIdx < vocabSize; vIdx += blockDim.x)
|
||||
{
|
||||
const auto correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
|
||||
auto const correctedNormProb = static_cast<float>(targetProbBatch[vIdx]) / sumProbsShared;
|
||||
targetLogitsBatch[vIdx] = __logf(correctedNormProb / (1.f - correctedNormProb));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T* targetLogits, T* draftProbs, T* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream)
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream)
{
TLLM_CHECK(beamWidth == 1);
{
invokeAddBiasSoftMax(draftLogits, draftProbs, (T*) (nullptr), nullptr, finished, nullptr,
batchSize * beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, stream);
invokeAddBiasSoftMax(targetLogits, targetProbs, (T*) (nullptr), nullptr, finished, nullptr,
batchSize * beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, stream);
invokeAddBiasSoftMax(draftLogits, (T**) (nullptr), draftProbs, (T*) (nullptr), nullptr, finished, batchSlots,
batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, /* skip softmax */ false,
/* batchSlotLogits */ true, stream);
invokeAddBiasSoftMax((T*) (nullptr), targetLogits, targetProbs, (T*) (nullptr), nullptr, finished, batchSlots,
batchSize, maxBatchSize, beamWidth * maxDraftTokens, vocabSize, vocabSizePadded, /* skip softmax */ false,
/* batchSlotLogits */ true, stream);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth, maxDraftTokens);
acceptDraftTokensByLogitsKernel<<<grid, block, 0, stream>>>(draftProbs, targetProbs, numsDraftTokens, finished,
curandState, batchSize, beamWidth, vocabSizePadded, randomThreshold, constantThreshold);
curandState, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded,
randomThreshold, constantThreshold);
}
{
dim3 block(1024);
dim3 grid(batchSize * beamWidth);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(
draftProbs, targetProbs, targetLogits, numsDraftTokens, finished, batchSize, beamWidth, vocabSizePadded);
correctAcceptedStatesAndLogits<<<grid, block, 0, stream>>>(draftProbs, targetProbs, targetLogits,
numsDraftTokens, finished, batchSlots, batchSize, maxBatchSize, maxDraftTokens, beamWidth, vocabSizePadded);
}
}

template void acceptDraftTokensByLogits(float* draftLogits, float* targetLogits, float* draftProbs, float* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half* targetLogits, half* draftProbs, half* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
cudaStream_t stream);
template void acceptDraftTokensByLogits(float* draftLogits, float** targetLogits, float* draftProbs, float* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);
template void acceptDraftTokensByLogits(half* draftLogits, half** targetLogits, half* draftProbs, half* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm

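The rejection test and the logit correction above are easier to follow outside of the CUDA indexing. Below is a minimal single-sequence CPU sketch of the same logic; `acceptDraftTokenSketch` is a hypothetical helper, not a library API, and it mirrors the distribution-level comparison used by `acceptDraftTokensByLogitsKernel` (see the FIXME above) rather than the per-token rule of the referenced paper.

```
#include <algorithm>
#include <random>
#include <vector>

// 'draft' and 'target' are per-position probability distributions Q and P over the vocabulary.
// Returns true if the draft token is kept; otherwise rewrites 'target' in place to
// norm(max(0, P - Q)), the distribution that the corrected logits encode.
bool acceptDraftTokenSketch(std::vector<float>& target, const std::vector<float>& draft, std::mt19937& gen,
    bool randomThreshold, float constantThreshold)
{
    std::uniform_real_distribution<float> uniform(0.f, 1.f);
    const float threshold = randomThreshold ? uniform(gen) : constantThreshold;

    // Mirrors __syncthreads_count(targetProb < threshold * draftProb): the draft token is
    // rejected if any vocab entry violates P(x) >= threshold * Q(x).
    bool rejected = false;
    for (size_t v = 0; v < target.size(); ++v)
    {
        rejected |= target[v] < threshold * draft[v];
    }
    if (!rejected)
    {
        return true;
    }

    // Correction applied by correctAcceptedStatesAndLogits: P'(x) = norm(max(0, P(x) - Q(x))).
    float sum = 0.f;
    for (size_t v = 0; v < target.size(); ++v)
    {
        target[v] = std::max(target[v] - draft[v], 0.f);
        sum += target[v];
    }
    if (sum > 0.f)
    {
        for (float& p : target)
        {
            p /= sum; // the kernel then stores log(p / (1 - p)) as the corrected logit
        }
    }
    return false;
}
```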
@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -31,22 +31,22 @@ namespace kernels
struct gatherTreeParam
{
// TODO rename the parameters
int* beams = nullptr; // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
int* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
int maxSequenceLengthFinalStep = 0;
const int* inputLengths = nullptr; // [batchSize, beamWidth]
int32_t* beams = nullptr; // [batchSize, beamWidth, maxSeqLen], workspace to put intermediate outputIds
int32_t* sequenceLengths = nullptr; // [batchSize, beamWidth], total lengths of each query
int32_t maxSequenceLengthFinalStep = 0;
int32_t const* inputLengths = nullptr; // [batchSize, beamWidth]
// response input lengths (used to slice the ids during postprocessing)
int* responseInputLengths = nullptr;
int maxSeqLen = 0;
int batchSize = 0;
int beamWidth = 0;
const int* stepIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
const int* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
const int* endTokens = nullptr; // [batchSize], end token ids of each query
int* outputIds = nullptr; // the buffer to put finalized ids
int32_t* responseInputLengths = nullptr;
int32_t maxSeqLen = 0;
int32_t batchSize = 0;
int32_t beamWidth = 0;
int32_t const* stepIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
int32_t const* parentIds = nullptr; // [maxSeqLen, batchSize, beamWidth]
int32_t const* endTokens = nullptr; // [batchSize], end token ids of each query
int32_t* outputIds = nullptr; // the buffer to put finalized ids
cudaStream_t stream;
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu
float* cumLogProbs = nullptr; // [batchSize, beamWidth]
float lengthPenalty = 1.0f; // on cpu
};

/*
@ -54,15 +54,16 @@ Do gatherTree on beam search to get final result.
*/
void invokeGatherTree(gatherTreeParam param);

void invokeFinalize(int* outputIds, int* sequenceLengths, float* cumLogProbs, float* outputLogProbs,
const int* topKOutputIds, const int* topKSequenceLengths, const float* scores, const float* topKCumLogProbs,
const float* topKLogProbs, const int* numBeams, const int* inputLengths, const int beamWidth, const int maxSeqLen,
const int batchSize, cudaStream_t stream);
void invokeFinalize(int32_t* outputIds, int32_t* sequenceLengths, float* cumLogProbs, float* outputLogProbs,
int32_t const* topKOutputIds, int32_t const* topKSequenceLengths, float const* scores, float const* topKCumLogProbs,
float const* topKLogProbs, int32_t const* numBeams, int32_t const* inputLengths, int32_t beamWidth,
int32_t maxSeqLen, int32_t batchSize, cudaStream_t stream);

void invokeInitializeOutput(int* outputIds, const int* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream);
void invokeInitializeOutput(
int32_t* outputIds, const int32_t* endIds, int batchBeam, int maxSeqLen, cudaStream_t stream);

void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequenceLengths, const int* batchSlots,
int batchSize, int beamWidth, int maxSeqLen, cudaStream_t stream);
void invokeCopyNextStepIds(int32_t* nextStepIds, int32_t** outputIdsPtr, int32_t const* sequenceLengths,
int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream);

//! \brief Accepts or rejects draft tokens based on the equality of draft and target tokens
//! for speculative decoding. Target token is accepted if targetToken == draftToken.
@ -79,14 +80,17 @@ void invokeCopyNextStepIds(int* nextStepIds, int** outputIdsPtr, const int* sequ
//! \param finished input buffer [maxDraftTokens + 1, batchSize] finished states at each decoding iteration
//! \param finishedFinal output buffer [batchSize] finished states after accepting/rejecting tokens
//! \param finishedSum output buffer [1] total number of requests in batch that finished the execution
//! \param batchSize batch size
//! \param batchSlots
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param maxSeqLen maximum sequence length
//! \param maxDraftTokens maximum number of draft tokens
//! \param stream stream
void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, const int* contextLengths,
const int* numsDraftTokens, int* sequenceLengths, const FinishedState* finished, FinishedState* finishedFinal,
int* finishedSum, int batchSize, int beamWidth, int maxSeqLen, int maxDraftTokens, cudaStream_t stream);
void invokeAcceptDraftTokensByIds(int32_t const* draftIds, int32_t const* targetIds, int32_t const* contextLengths,
int32_t const* numsDraftTokens, int32_t* sequenceLengths, FinishedState const* finished,
FinishedState* finishedFinal, int32_t* finishedSum, int32_t const* batchSlots, int32_t batchSize,
int32_t maxBatchSize, int32_t beamWidth, int32_t maxSeqLen, int32_t maxDraftTokens, cudaStream_t stream);

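The ID-based acceptance rule documented above reduces to a prefix match. The following single-sequence illustration is hypothetical (it is not a function in the library) and only restates the documented behavior:

```
#include <cstdint>
#include <vector>

// Draft tokens are accepted left to right while they equal the target model's tokens;
// the first mismatch ends the accepted run. The kernel uses this count to update
// sequenceLengths and the finished states.
int32_t countAcceptedDraftTokens(const std::vector<int32_t>& draftIds, const std::vector<int32_t>& targetIds)
{
    int32_t accepted = 0;
    while (accepted < static_cast<int32_t>(draftIds.size()) && draftIds[accepted] == targetIds[accepted])
    {
        ++accepted;
    }
    return accepted;
}
```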
//! \brief Performs probabilistic acceptance of draft tokens based on their probability distributions.
//! Corrects targetLogits for the next to the last accepted token
@ -94,7 +98,8 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//!
//! \param draftLogits input/output buffer [draftTokens, batchSize, beamWidth, vocabSize].
//! Initially contains token logits of the draft model.
//! \param targetLogits input/output buffer [draftTokens+1, batchSize, beamWidth, vocabSize].
//! \param targetLogits input/output buffer [batchSize][draftTokens+1, beamWidth, vocabSize].
//! Vector of pointers to the logits.
//! Initially contains token logits of the target model.
//! It is modified in-place for the token next to the last accepted one, as
//! P'(x) = norm(max(0, P_{n+1}(x) - Q_{n+1}(x))), where N < maxDraftTokens is the number of accepted tokens.
@ -107,7 +112,9 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//! At each step sets to NOT_FINISHED if token is accepted or SKIP_DECODING if token is not accepted
//! \param curandState input buffer [batchSize]. Curand states properly
//! initialized using invokeCurandInitialize per request.
//! \param batchSize batch size
//! \param batchSlots
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param beamWidth beam width
//! \param vocabSize unpadded vocab size
//! \param vocabSizePadded padded vocab size
@ -116,17 +123,18 @@ void invokeAcceptDraftTokensByIds(const int* draftIds, const int* targetIds, con
//! \param constantThreshold threshold used to accept tokens if randomThreshold is false
//! \param stream stream
template <typename T>
void acceptDraftTokensByLogits(T* draftLogits, T* targetLogits, T* draftProbs, T* targetProbs,
const int* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int batchSize, int beamWidth,
int vocabSize, int vocabSizePadded, int maxDraftTokens, bool randomThreshold, float constantThreshold,
void acceptDraftTokensByLogits(T* draftLogits, T** targetLogits, T* draftProbs, T* targetProbs,
int32_t const* numsDraftTokens, FinishedState* finished, curandState_t* curandState, int32_t const* batchSlots,
int32_t batchSize, int32_t maxBatchSize, int32_t beamWidth, int32_t vocabSize, int32_t vocabSizePadded,
int32_t maxDraftTokens, bool randomThreshold, float constantThreshold, cudaStream_t stream);

void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, int32_t const* sequence_lengths,
int32_t const* batchSlots, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, cudaStream_t stream);

void invokeAcceptTokens(int32_t const* draft_tokens, int32_t const* target_tokens, int32_t const* context_lengths,
int32_t const* nums_draft_tokens, int32_t* sequence_lengths, bool const* finished, bool* finished_final,
int32_t* finished_sum, int32_t batch_size, int32_t beam_width, int32_t max_seq_len, int32_t max_draft_tokens,
cudaStream_t stream);

void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_tiled, const int* sequence_lengths,
const int* batchSlots, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);

void invokeAcceptTokens(const int* draft_tokens, const int* target_tokens, const int* context_lengths,
const int* nums_draft_tokens, int* sequence_lengths, const bool* finished, bool* finished_final, int* finished_sum,
int batch_size, int beam_width, int max_seq_len, int max_draft_tokens, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm

@ -30,7 +30,8 @@ static constexpr int kUpdateKVCacheKernelShmSize = 16384;
|
||||
template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
|
||||
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
|
||||
const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices,
|
||||
const int32_t* pastKeyValueLengths, int rewindDraftTokenCount, int eltCountPerHead)
|
||||
const int32_t* pastKeyValueLengths, int rewindDraftTokenCommonCount, const int* rewindDraftTokenSeparateAdjustments,
|
||||
int eltCountPerHead)
|
||||
{
|
||||
int seqIdx = blockIdx.x;
|
||||
int headIdx = blockIdx.y;
|
||||
@ -46,7 +47,11 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
|
||||
return;
|
||||
}
|
||||
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
|
||||
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCount;
|
||||
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCommonCount;
|
||||
if (rewindDraftTokenSeparateAdjustments != nullptr)
|
||||
{
|
||||
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqIdx];
|
||||
}
|
||||
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
|
||||
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
|
||||
__shared__ char loadSmemBuffer[kUpdateKVCacheKernelShmSize];
|
||||
@ -121,7 +126,7 @@ template <typename KVCacheBuffer, int MaxLayerCount>
|
||||
void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
|
||||
const int* seqAcceptedDraftTokenOffsets, const IndexType* packedAcceptedDraftTokensIndices,
|
||||
const int32_t* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
|
||||
int rewindDraftTokenCount, cudaStream_t stream)
|
||||
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream)
|
||||
{
|
||||
// make sure launch buffer is enough
|
||||
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
|
||||
@ -144,7 +149,7 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
|
||||
kvCacheBufferArray[i] = kvCacheBuffers[i];
|
||||
}
|
||||
void (*pKernelFunc)(
|
||||
std::array<KVCacheBuffer, MaxLayerCount>, const int*, const IndexType*, const int32_t*, int, int)
|
||||
std::array<KVCacheBuffer, MaxLayerCount>, const int*, const IndexType*, const int32_t*, int, const int*, int)
|
||||
= nullptr;
|
||||
switch (alignedBytes)
|
||||
{
|
||||
@ -176,7 +181,8 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
|
||||
}
|
||||
}
|
||||
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
|
||||
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCount, eltCountPerHead);
|
||||
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
|
||||
rewindDraftTokenSeparateAdjustments, eltCountPerHead);
|
||||
TLLM_CUDA_CHECK(cudaGetLastError());
|
||||
}
|
||||
|
||||
@ -191,14 +197,17 @@ void updateKVCacheDraftTokenLocationBatched(const KVCacheBuffer* kvCacheBuffers,
|
||||
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
|
||||
* @param seqCount : Count of sequence
|
||||
* @param numKVHeads : Number of KV heads
|
||||
* @param sizeInBytesPerKVHead :
|
||||
* @param rewindDraftTokenCount
|
||||
* @param stream
|
||||
* @param sizeInBytesPerKVHead : Size of each KV head
|
||||
* @param rewindDraftTokenCommonCount : Common count to rewind
|
||||
* @param rewindDraftTokenSeparateAdjustments : Separate adjustment to rewind for each sequence, if nullptr, just use
|
||||
* rewindDraftTokenCommonCount, else use rewindDraftTokenSeparateAdjustments[i] + rewindDraftTokenCommonCount
|
||||
* @param stream : CUDA stream to use.
|
||||
*/
|
||||
template <typename KVCacheBuffer>
|
||||
void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int layerCount, int seqCount,
|
||||
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount, cudaStream_t stream)
|
||||
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
int startLayer = 0;
|
||||
static constexpr int kMaxLayersPerIter = 32;
|
||||
@ -207,7 +216,8 @@ void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const
|
||||
int microBatchLayerCount = std::min(layerCount - startLayer, kMaxLayersPerIter);
|
||||
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
|
||||
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
|
||||
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCount, stream);
|
||||
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
|
||||
rewindDraftTokenSeparateAdjustments, stream);
|
||||
startLayer += microBatchLayerCount;
|
||||
}
|
||||
}
|
||||
@ -215,7 +225,7 @@ void updateKVCacheDraftTokenLocation(const KVCacheBuffer* kvCacheBuffers, const
|
||||
void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
|
||||
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
|
||||
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
|
||||
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream)
|
||||
{
|
||||
std::vector<KVLinearBuffer> kvLinearBuffers;
|
||||
kvLinearBuffers.reserve(layerCount);
|
||||
@ -227,13 +237,14 @@ void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffse
|
||||
}
|
||||
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
|
||||
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
|
||||
rewindDraftTokenCount, stream);
|
||||
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
|
||||
}
|
||||
|
||||
void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
|
||||
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
|
||||
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
|
||||
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
|
||||
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
std::vector<KVBlockArray> kvBlockArrays;
|
||||
kvBlockArrays.reserve(layerCount);
|
||||
@ -245,7 +256,47 @@ void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffset
|
||||
}
|
||||
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
|
||||
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
|
||||
rewindDraftTokenCount, stream);
|
||||
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
|
||||
}
|
||||
|
||||
void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
|
||||
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
|
||||
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
|
||||
{
|
||||
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
|
||||
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
|
||||
rewindDraftTokenCount, nullptr, maxKVCacheLen, stream);
|
||||
}
|
||||
|
||||
void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
|
||||
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
|
||||
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
|
||||
{
|
||||
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
|
||||
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
|
||||
rewindDraftTokenCount, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
|
||||
}
|
||||
|
||||
void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
|
||||
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
|
||||
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream)
|
||||
{
|
||||
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
|
||||
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
|
||||
rewindDraftTokenCounts, maxKVCacheLen, stream);
|
||||
}
|
||||
|
||||
void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
|
||||
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
|
||||
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
|
||||
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
|
||||
{
|
||||
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
|
||||
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
|
||||
rewindDraftTokenCounts, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels::parallel_decoding
|
||||
|
||||
@ -27,14 +27,141 @@ namespace tensorrt_llm::kernels::parallel_decoding

using IndexType = int;

void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
/*!
* Update Linear KV cache using common rewind count.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream);

void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
/*!
* Update Block KV cache using common rewind count.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCount : Count to rewind
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationCommonRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);

/*!
* Update Linear KV cache using separate rewind count for each sequence.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
* one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream);

/*!
* Update Block KV cache using separate rewind count for each sequence.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
* one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocationSeparateRewind(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);

/*!
* Update Linear KV cache using both common rewind and separate rewind count for each sequence. The common
* rewindDraftTokenCommonCount and rewind count of each sequence in rewindDraftTokenSeparateAdjustments will be added
* together for the final rewind count. This saves one addition when both are needed.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pastKeyValueList : Past key value list, which is the pointer array of each KVLinear cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
* rewind adjustment for one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param stream : CUDA stream to use.
*/
void updateLinearKVCacheDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream);

/*!
* Update Block KV cache using both common rewind and separate rewind count for each sequence. The common
* rewindDraftTokenCommonCount and rewind count of each sequence in rewindDraftTokenSeparateAdjustments will be added
* together for the final rewind count. This saves one addition when both are needed.
* @param seqAcceptedDraftTokenOffsets : Array of length seqCount + 1, like [0, 3, 5]
* @param packedAcceptedDraftTokensIndices : Array of length seqAcceptedDraftTokenOffsets[seqCount], each value is in
* range [0, maxDraftTokenCount - 1]
* @param pastKeyValueLengths : Array of length seqCount, meaning how many tokens are already in KV cache
* @param pointerArray : Pointer array of each Block KV cache.
* @param layerCount : Count of layers
* @param seqCount : Count of sequences
* @param numKVHeads : Number of KV heads
* @param sizeInBytesPerKVHead : Size of each KV head
* @param rewindDraftTokenCommonCount : Common token count to rewind
* @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
* rewind adjustment for one sequence.
* @param maxKVCacheLen : Maximum length of each KV cache
* @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
* @param tokensPerBlock : Tokens per block of Block KV cache
* @param stream : CUDA stream to use.
*/
void updateKVBlockArrayDraftTokenLocation(const int* seqAcceptedDraftTokenOffsets,
const IndexType* packedAcceptedDraftTokensIndices, const int32_t* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream);

} // namespace tensorrt_llm::kernels::parallel_decoding

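How the common and per-sequence rewind values combine is shown by the kernel earlier in this commit. The snippet below is an illustration only (the helper name is made up), restating the index arithmetic used by `updateKVCacheDraftTokenLocationBatchedKernel`:

```
#include <cstdint>

// With the CommonRewind wrappers, separateAdjustments is nullptr; with the SeparateRewind
// wrappers, the common count is 0 and only the per-sequence array is used; the combined
// entry points may pass both, and the two values are simply added per sequence.
inline int effectiveTokenStartIdx(
    const int32_t* pastKeyValueLengths, const int* separateAdjustments, int seqIdx, int commonRewind)
{
    int tokenStartIdx = pastKeyValueLengths[seqIdx] - commonRewind;
    if (separateAdjustments != nullptr)
    {
        tokenStartIdx -= separateAdjustments[seqIdx]; // total rewind = common + per-sequence adjustment
    }
    return tokenStartIdx; // base position from which accepted draft tokens are relocated
}
```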
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -30,35 +30,37 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
__global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorkspace, const int* penaltyWorkspacePrev,
|
||||
const float* temperatures, const float* repetitionPenalties, const float* presencePenalties,
|
||||
const float* frequencyPenalties, const bool accumulateVocab, const int maxSeqLen, const int vocabSize,
|
||||
const int vocabSizePadded, const int** outputIdsPtr, const int** parentIdsPtr, const int* inputLengths,
|
||||
const int* sequenceLengths, const int* minLengths, const int* endIds, const int* batchSlots)
|
||||
__global__ void batchApplyPenalty(T const* const* inputLogits, T* outputLogits, T const* biases,
|
||||
int32_t* penaltyWorkspace, int32_t const* penaltyWorkspacePrev, float const* temperatures,
|
||||
float const* repetitionPenalties, float const* presencePenalties, float const* frequencyPenalties,
|
||||
const bool accumulateVocab, int32_t const maxSeqLen, int32_t const vocabSize, int32_t const vocabSizePadded,
|
||||
int32_t const** outputIdsPtr, int32_t const** parentIdsPtr, int32_t const* inputLengths,
|
||||
int32_t const* sequenceLengths, int32_t const* minLengths, int32_t const* endIds, int32_t const* batchSlots)
|
||||
{
|
||||
const int beamWidth = gridDim.y;
|
||||
const int batchIdx = blockIdx.x;
|
||||
const int beamIdx = blockIdx.y;
|
||||
const int batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
const int batchBeamIdx = batchIdx * beamWidth + beamIdx;
|
||||
const int batchSlotBeamIdx = batchSlot * beamWidth + beamIdx;
|
||||
const int inputLen = inputLengths == nullptr ? 0 : inputLengths[batchSlotBeamIdx];
|
||||
const int currentStep = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlotBeamIdx];
|
||||
int32_t const beamWidth = gridDim.y;
|
||||
int32_t const batchIdx = blockIdx.x;
|
||||
int32_t const beamIdx = blockIdx.y;
|
||||
int32_t const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
|
||||
int32_t const batchBeamIdx = batchIdx * beamWidth + beamIdx;
|
||||
int32_t const batchSlotBeamIdx = batchSlot * beamWidth + beamIdx;
|
||||
int32_t const inputLen = inputLengths == nullptr ? 0 : inputLengths[batchSlotBeamIdx];
|
||||
int32_t const currentStep = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlotBeamIdx];
|
||||
T const* biasBase = biases + batchSlot * vocabSizePadded;
|
||||
// Initialize or update the number of occurrences of tokens
|
||||
if (accumulateVocab)
|
||||
{
|
||||
penaltyWorkspace += batchBeamIdx * vocabSize;
|
||||
if (currentStep <= inputLen)
|
||||
{ // Context phase
|
||||
for (int index = threadIdx.x; index < vocabSize; index += blockDim.x)
|
||||
for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x)
|
||||
{
|
||||
penaltyWorkspace[index] = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int step = threadIdx.x; step < inputLen; step += blockDim.x)
|
||||
for (int32_t step = threadIdx.x; step < inputLen; step += blockDim.x)
|
||||
{
|
||||
// All beams in the context phase are identical
|
||||
int penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
|
||||
int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + step];
|
||||
if (penaltyIndex < vocabSize)
|
||||
{
|
||||
atomicAdd(&penaltyWorkspace[penaltyIndex], 1);
|
||||
@ -69,9 +71,9 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
{ // Generation phase
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
int parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2];
|
||||
int32_t parentBeam = parentIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 2];
|
||||
penaltyWorkspacePrev += (batchIdx * beamWidth + parentBeam) * vocabSize;
|
||||
for (int index = threadIdx.x; index < vocabSize; index += blockDim.x)
|
||||
for (int32_t index = threadIdx.x; index < vocabSize; index += blockDim.x)
|
||||
{
|
||||
penaltyWorkspace[index] = penaltyWorkspacePrev[index];
|
||||
}
|
||||
@ -79,7 +81,7 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
}
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
|
||||
int32_t penaltyIndex = outputIdsPtr[batchSlot][beamIdx * maxSeqLen + currentStep - 1];
|
||||
if (penaltyIndex < vocabSize)
|
||||
{
|
||||
penaltyWorkspace[penaltyIndex] += 1;
|
||||
@ -89,7 +91,8 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
__syncthreads();
|
||||
}
|
||||
// Apply bias and penalties
|
||||
logits += batchBeamIdx * vocabSizePadded;
|
||||
auto const inLogitsPtr = inputLogits[batchIdx] + beamIdx * vocabSizePadded;
|
||||
auto outLogitsPtr = outputLogits + batchBeamIdx * vocabSizePadded;
|
||||
const T MASK_VAL = (std::is_same<T, half>::value) ? -HALF_FLT_MAX : -FLT_MAX;
|
||||
float invTemperature, repetitionPenalty, presencePenalty, frequencyPenalty;
|
||||
if (temperatures != nullptr)
|
||||
@ -108,22 +111,22 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
{
|
||||
frequencyPenalty = frequencyPenalties[batchSlot];
|
||||
}
|
||||
for (int index = threadIdx.x; index < vocabSizePadded; index += blockDim.x)
|
||||
for (int32_t index = threadIdx.x; index < vocabSizePadded; index += blockDim.x)
|
||||
{
|
||||
if (index < vocabSize)
|
||||
{
|
||||
float logit = (float) logits[index];
|
||||
float logit = (float) inLogitsPtr[index];
|
||||
// Bias
|
||||
if (biases != nullptr)
|
||||
{
|
||||
logit += (float) biases[index];
|
||||
logit += (float) biasBase[index];
|
||||
}
|
||||
// Temperature
|
||||
if (temperatures != nullptr)
|
||||
{
|
||||
logit *= invTemperature;
|
||||
}
|
||||
int numOccurences = penaltyWorkspace[index];
|
||||
int32_t numOccurences = penaltyWorkspace[index];
|
||||
if (numOccurences > 0)
|
||||
{
|
||||
// Repetition
|
||||
@ -142,11 +145,11 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
logit -= frequencyPenalty * numOccurences;
|
||||
}
|
||||
}
|
||||
logits[index] = logit;
|
||||
outLogitsPtr[index] = logit;
|
||||
}
|
||||
else
|
||||
{
|
||||
logits[index] = MASK_VAL;
|
||||
outLogitsPtr[index] = MASK_VAL;
|
||||
}
|
||||
}
|
||||
if (minLengths != nullptr)
|
||||
@ -155,7 +158,7 @@ __global__ void batchApplyPenalty(T* logits, const T* biases, int* penaltyWorksp
|
||||
// Min length
|
||||
if ((threadIdx.x == 0) && (currentStep - inputLen < minLengths[batchSlot]))
|
||||
{
|
||||
logits[endIds[batchSlot]] = MASK_VAL;
|
||||
outLogitsPtr[endIds[batchSlot]] = MASK_VAL;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -166,11 +169,11 @@ void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams<T>& params)
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
dim3 block(256);
|
||||
dim3 grid(params.batchSize, params.beamWidth);
|
||||
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.logits, params.biases, params.penaltyWorkspace,
|
||||
params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties, params.presencePenalties,
|
||||
params.frequencyPenalties, params.accumulateVocab, params.maxSeqLen, params.vocabSize, params.vocabSizePadded,
|
||||
params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths, params.minLengths,
|
||||
params.endIds, params.batchSlots);
|
||||
batchApplyPenalty<T><<<grid, block, 0, params.stream>>>(params.inputLogits, params.outputLogits, params.biases,
|
||||
params.penaltyWorkspace, params.penaltyWorkspacePrev, params.temperatures, params.repetitionPenalties,
|
||||
params.presencePenalties, params.frequencyPenalties, params.accumulateVocab, params.maxSeqLen, params.vocabSize,
|
||||
params.vocabSizePadded, params.outputIdsPtr, params.parentIdsPtr, params.inputLengths, params.sequenceLengths,
|
||||
params.minLengths, params.endIds, params.batchSlots);
|
||||
}
|
||||
|
||||
template void invokeBatchApplyPenalty(const InvokeBatchApplyPenaltyParams<float>& params);
|
||||
|
||||
@ -1,5 +1,5 @@
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@ -28,7 +28,8 @@ namespace kernels
template <typename T>
struct InvokeBatchApplyPenaltyParams
{
T* logits;
T const* const* inputLogits;
T* outputLogits;
const T* biases;
int* penaltyWorkspace;
const int* penaltyWorkspacePrev;

@ -24,7 +24,7 @@ namespace tensorrt_llm
namespace kernels
{

enum class RepetitionPenaltyType
enum class DecodingPenaltyType
{
Temperature, // the temperature penalty
Repetition, // the repetition penalty
@ -33,15 +33,15 @@ enum class RepetitionPenaltyType
MinLength, // the min length penalty
};

inline float getDefaultPenaltyValue(RepetitionPenaltyType penalty_type)
inline float getDefaultPenaltyValue(DecodingPenaltyType penalty_type)
{
switch (penalty_type)
{
case RepetitionPenaltyType::Temperature: return 1.0f;
case RepetitionPenaltyType::Repetition: return 1.0f;
case RepetitionPenaltyType::Presence: return 0.0f;
case RepetitionPenaltyType::Frequency: return 0.0f;
case RepetitionPenaltyType::MinLength: return 1.0f;
case DecodingPenaltyType::Temperature: return 1.0f;
case DecodingPenaltyType::Repetition: return 1.0f;
case DecodingPenaltyType::Presence: return 0.0f;
case DecodingPenaltyType::Frequency: return 0.0f;
case DecodingPenaltyType::MinLength: return 1.0f;
default: break;
}
return 0.0f;

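The default values above leave logits unchanged. The sketch below shows how such penalties are typically combined per token; only the temperature and frequency branches are fully visible in the `batchApplyPenalty` diff earlier in this commit, so the repetition and presence branches here are conventional assumptions, not the library's authoritative formula, and the helper name is hypothetical.

```
// Illustrative per-token combination of the penalties enumerated above.
inline float applyPenaltiesSketch(float logit, int numOccurrences, float temperature, float repetitionPenalty,
    float presencePenalty, float frequencyPenalty)
{
    logit /= temperature; // temperature scaling (default 1.0f is a no-op)
    if (numOccurrences > 0)
    {
        // Assumed conventional repetition rule: shrink logits of already-seen tokens.
        logit = logit > 0.f ? logit / repetitionPenalty : logit * repetitionPenalty;
        logit -= presencePenalty;                   // assumed flat penalty for having appeared at all
        logit -= frequencyPenalty * numOccurrences; // matches the frequency branch shown in the kernel
    }
    return logit;
}
```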
@ -23,11 +23,12 @@ struct Vec2Type<__nv_bfloat16>
|
||||
#endif
|
||||
}; // namespace
|
||||
|
||||
template <typename T, int kProcessRows, typename AccessType>
|
||||
__global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols)
|
||||
template <typename T_in, typename T_out, int kProcessRows, typename AccessType>
|
||||
__global__ void apply_per_channel_scale(
|
||||
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols)
|
||||
{
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
|
||||
T scale[kElems], act_vec[kElems];
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
|
||||
T_in scale[kElems], act_vec[kElems];
|
||||
int col_offset = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int row_offset = blockIdx.y;
|
||||
if (col_offset * kElems >= cols || row_offset * kProcessRows >= rows)
|
||||
@ -39,13 +40,13 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
|
||||
for (int i = 0; i < kProcessRows; ++i)
|
||||
{
|
||||
*reinterpret_cast<AccessType*>(act_vec) = reinterpret_cast<const AccessType*>(act + i * cols)[col_offset];
|
||||
if constexpr ((std::is_same_v<T, half>
|
||||
if constexpr ((std::is_same_v<T_in, half>
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
|| std::is_same_v<T, __nv_bfloat16>
|
||||
|| std::is_same_v<T_in, __nv_bfloat16>
|
||||
#endif
|
||||
) &&(kElems % 2 == 0))
|
||||
{
|
||||
using Vec2 = typename Vec2Type<T>::type;
|
||||
using Vec2 = typename Vec2Type<T_in>::type;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; j += 2)
|
||||
{
|
||||
@ -58,58 +59,77 @@ __global__ void apply_per_channel_scale(T* smoothed_act, const T* act, const T*
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; ++j)
|
||||
{
|
||||
act_vec[j] = static_cast<T>(static_cast<float>(act_vec[j]) * static_cast<float>(scale[j]));
|
||||
act_vec[j] = static_cast<T_in>(static_cast<float>(act_vec[j]) * static_cast<float>(scale[j]));
|
||||
}
|
||||
}
|
||||
if constexpr (std::is_same_v<T_in, T_out>)
|
||||
{
|
||||
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset]
|
||||
= *reinterpret_cast<AccessType*>(act_vec);
|
||||
}
|
||||
else
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kElems; ++j)
|
||||
{
|
||||
(smoothed_act + i * cols)[col_offset * kElems + j] = static_cast<T_out>(act_vec[j]);
|
||||
}
|
||||
}
|
||||
reinterpret_cast<AccessType*>(smoothed_act + i * cols)[col_offset] = *reinterpret_cast<AccessType*>(act_vec);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int kProcessRows, typename AccessType = float4>
|
||||
template <typename T_in, typename T_out, int kProcessRows, typename AccessType = float4>
|
||||
void apply_per_channel_scale_kernel_launcher_(
|
||||
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream = 0)
|
||||
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0)
|
||||
{
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T);
|
||||
static constexpr int kElems = sizeof(AccessType) / sizeof(T_in);
|
||||
dim3 block(128);
|
||||
dim3 grid((cols / kElems + block.x - 1) / block.x, (rows + kProcessRows - 1) / kProcessRows);
|
||||
apply_per_channel_scale<T, kProcessRows, AccessType>
|
||||
apply_per_channel_scale<T_in, T_out, kProcessRows, AccessType>
|
||||
<<<grid, block, 0, stream>>>(smoothed_act, act, per_channel_scale, rows, cols);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename T_in, typename T_out>
|
||||
void apply_per_channel_scale_kernel_launcher(
|
||||
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream)
|
||||
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream)
|
||||
{
|
||||
int elems = rows * cols;
|
||||
if (elems < 2048 * 2048)
|
||||
{
|
||||
apply_per_channel_scale_kernel_launcher_<T, 1, float4>(
|
||||
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 1, float4>(
|
||||
smoothed_act, act, per_channel_scale, rows, cols, stream);
|
||||
}
|
||||
else if (elems < 4096 * 4096)
|
||||
{
|
||||
apply_per_channel_scale_kernel_launcher_<T, 4, float4>(
|
||||
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 4, float4>(
|
||||
smoothed_act, act, per_channel_scale, rows, cols, stream);
|
||||
}
|
||||
else if (elems < 8192 * 8192)
|
||||
{
|
||||
apply_per_channel_scale_kernel_launcher_<T, 8, float4>(
|
||||
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 8, float4>(
|
||||
smoothed_act, act, per_channel_scale, rows, cols, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_per_channel_scale_kernel_launcher_<T, 16, float4>(
|
||||
apply_per_channel_scale_kernel_launcher_<T_in, T_out, 16, float4>(
|
||||
smoothed_act, act, per_channel_scale, rows, cols, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_PREQUANT_SCALE(T) \
|
||||
template void apply_per_channel_scale_kernel_launcher<T>( \
|
||||
T * smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream)
|
||||
#define INSTANTIATE_PREQUANT_SCALE(T_in, T_out) \
|
||||
template void apply_per_channel_scale_kernel_launcher<T_in, T_out>( \
|
||||
T_out * smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream)
|
||||
|
||||
INSTANTIATE_PREQUANT_SCALE(half, half);
|
||||
#if defined(ENABLE_FP8)
|
||||
INSTANTIATE_PREQUANT_SCALE(half, __nv_fp8_e4m3);
|
||||
#endif
|
||||
|
||||
INSTANTIATE_PREQUANT_SCALE(half);
|
||||
#if defined(ENABLE_BF16)
|
||||
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16);
|
||||
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_bfloat16);
|
||||
#if defined(ENABLE_FP8)
|
||||
INSTANTIATE_PREQUANT_SCALE(__nv_bfloat16, __nv_fp8_e4m3);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
} // namespace kernels
|
||||
|
||||
@ -17,6 +17,7 @@
#pragma once

#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
#include <cuda_runtime_api.h>

@ -32,9 +33,9 @@ namespace tensorrt_llm
namespace kernels
{

template <typename T>
template <typename T_in, typename T_out = T_in>
void apply_per_channel_scale_kernel_launcher(
T* smoothed_act, const T* act, const T* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);
T_out* smoothed_act, const T_in* act, const T_in* per_channel_scale, int rows, int cols, cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm

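The split into `T_in`/`T_out` lets the pre-quant scaling write a narrower output type than the input. A sketch of a call site, matching the (half, __nv_fp8_e4m3) instantiation added in this commit; the include path and function name `prequantScaleExample` are assumptions for illustration, and the FP8 instantiation is only available when the library is built with ENABLE_FP8:

```
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include "tensorrt_llm/kernels/preQuantScaleKernel.h" // assumed header path for the declaration above

// Scale a [rows, cols] half activation tensor by a per-column factor and emit FP8 output.
// Buffer allocation and filling are omitted; the pointers are assumed to be valid device memory.
void prequantScaleExample(__nv_fp8_e4m3* smoothedActDevice, const half* actDevice,
    const half* perChannelScaleDevice, int rows, int cols, cudaStream_t stream)
{
    tensorrt_llm::kernels::apply_per_channel_scale_kernel_launcher<half, __nv_fp8_e4m3>(
        smoothedActDevice, actDevice, perChannelScaleDevice, rows, cols, stream);
}
```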
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -301,7 +301,7 @@ __device__ void vectorizedProcess(size_t threadRank, size_t numThreads, T const*
|
||||
}
|
||||
|
||||
/**
|
||||
* Fused filtering of the current pass and building histogram for the next pass (see steps 4 & 1 in `airTopPSsampling`
|
||||
* Fused filtering of the current pass and building histogram for the next pass (see steps 4 & 1 in `airTopPSampling`
|
||||
* description).
|
||||
*/
|
||||
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
|
||||
@ -418,7 +418,7 @@ __device__ __forceinline__ void filterAndHistogram(T const* inBuf, IdxT const* i
|
||||
}
|
||||
|
||||
/**
|
||||
* Replace histogram with its own prefix sum (step 2 in `airTopPSsampling` description)
|
||||
* Replace histogram with its own prefix sum (step 2 in `airTopPSampling` description)
|
||||
*/
|
||||
template <typename IdxT, int BitsPerPass, int BlockSize>
|
||||
__device__ void scan(volatile IdxT* histogram)
|
||||
@ -472,7 +472,7 @@ __device__ void scan(volatile IdxT* histogram)
|
||||
|
||||
/**
* Calculate in which bucket the k-th value will fall
* (steps 3 in `airTopPSsampling` description)
* (steps 3 in `airTopPSampling` description)
*/
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
__device__ void chooseBucket(
@ -486,13 +486,17 @@ __device__ void chooseBucket(

// one and only one thread will satisfy this condition, so counter is
// written by only one thread
if ((prev < sum && cur >= sum) || (sum <= 0 && i == 0))
// Add strict check for negative cases.
if ((sum > 0 && prev < sum && cur >= sum) || (sum <= 0 && prev == 0 && cur != 0))
{
counter->sum = sum - prev; // how many values still are there to find
counter->len = countHistogram[i]; // cur - prev; // number of values in next pass
typename cub::Traits<T>::UnsignedBits bucket = i;
int startBit = calcsStartBit<T, BitsPerPass>(pass);
counter->kthValueBits |= bucket << startBit;
if (countHistogram[i]) // Only check meaningful ones
{
counter->sum = sum - prev; // how many values still are there to find
counter->len = countHistogram[i]; // cur - prev; // number of values in next pass
typename cub::Traits<T>::UnsignedBits bucket = i;
int startBit = calcsStartBit<T, BitsPerPass>(pass);
counter->kthValueBits |= bucket << startBit;
}
}
}
}
@ -533,7 +537,7 @@ __device__ void epilogue(T const value, IdxT const index, float* outputLogProbs,
|
||||
|
||||
/**
|
||||
* Find the target element.
|
||||
* (steps 4 in `airTopPSsampling` description)
|
||||
* (steps 4 in `airTopPSampling` description)
|
||||
*/
|
||||
template <typename T, typename IdxT, typename AccT, int BitsPerPass>
|
||||
__device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen, Counter<T, IdxT, AccT>* counter,
|
||||
@ -603,7 +607,7 @@ __device__ void lastFilter(T const* inBuf, IdxT const* inIdxBuf, IdxT currentLen
|
||||
* their indices.
|
||||
*/
|
||||
template <typename T, typename IdxT, typename AccT, int BitsPerPass, int BlockSize, bool is_fused_filter = false>
|
||||
__global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids,
|
||||
__global__ void airTopPSampling(Counter<T, IdxT, AccT>* counters, AccT* histograms, IdxT* countHistograms, IdxT** ids,
|
||||
int* sequenceLengths, FinishedState const* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, IdxT const* endIds, int const batchSize, bool const* skipDecode, int const pass, T* buf1,
|
||||
IdxT* idxBuf1, T* buf2, IdxT* idxBuf2, int32_t const* batchSlots)
|
||||
@ -697,6 +701,7 @@ __global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histogr
|
||||
earlyStop);
|
||||
|
||||
__syncthreads();
|
||||
__threadfence();
|
||||
|
||||
bool isLastBlock = false;
|
||||
if (threadIdx.x == 0)
|
||||
@ -711,13 +716,42 @@ __global__ void airTopPSsampling(Counter<T, IdxT, AccT>* counters, AccT* histogr
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
__shared__ IdxT maxBucket;
|
||||
if (pass > 0)
|
||||
{
|
||||
// Avoid the scenario where currentSum is larger than the meaningful maximum prefix sum.
|
||||
// This situation happens because these two values are calculted in different ways.
|
||||
// So the precision loss during the calculation is also different.
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
maxBucket = 0;
|
||||
}
|
||||
__syncthreads();
|
||||
for (int i = threadIdx.x; i < numBuckets; i += blockDim.x)
|
||||
{
|
||||
if (countHistogram[i])
|
||||
{
|
||||
atomicMax(&maxBucket, i);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
scan<AccT, BitsPerPass, BlockSize>(histogram);
|
||||
__syncthreads();
|
||||
if (pass == 0)
|
||||
{
|
||||
currentSum = histogram[numBuckets - 1] * counter->p;
|
||||
}
|
||||
__syncthreads();
|
||||
else
|
||||
{
|
||||
if (currentSum > histogram[maxBucket])
|
||||
{
|
||||
currentSum = histogram[maxBucket];
|
||||
}
|
||||
}
|
||||
|
||||
chooseBucket<T, IdxT, AccT, BitsPerPass>(counter, histogram, countHistogram, currentSum, pass);
|
||||
__syncthreads();
|
||||
@ -818,7 +852,7 @@ unsigned calcAirTopPBlockNum(int batchSize, IdxT len, int smCnt)
|
||||
|
||||
int activeBlocks;
|
||||
cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&activeBlocks, airTopPSsampling<T, IdxT, AccT, BitsPerPass, BlockSize, false>, BlockSize, 0);
|
||||
&activeBlocks, airTopPSampling<T, IdxT, AccT, BitsPerPass, BlockSize, false>, BlockSize, 0);
|
||||
activeBlocks *= smCnt;
|
||||
|
||||
IdxT bestNumBlocks = 0;
|
||||
@ -909,13 +943,13 @@ void invokeBatchAirTopPSampling(void* workspace, size_t& workspaceSize, int** ou
|
||||
dim3 grid(blockNum, batchSize);
|
||||
// Sample with Top P given sorted tokens
|
||||
int constexpr numPasses = calcNumPasses<T, BitsPerPass>();
|
||||
auto kernel = airTopPSsampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, false>;
|
||||
auto kernel = airTopPSampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, false>;
|
||||
|
||||
for (int pass = 0; pass < numPasses; ++pass)
|
||||
{
|
||||
if (pass == numPasses - 1)
|
||||
{
|
||||
kernel = airTopPSsampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, true>;
|
||||
kernel = airTopPSampling<T, IdxT, AccT, BitsPerPass, SAMPLING_BLOCK_SIZE, true>;
|
||||
}
|
||||
|
||||
kernel<<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(counters, histograms, countHistograms, outputIds,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -36,59 +36,6 @@ namespace tensorrt_llm
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
__global__ void addBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int vocabSize, const int vocabSizePadded)
|
||||
{
|
||||
auto const batchIdx = blockIdx.x;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
|
||||
FinishedState const finishState = finished != nullptr ? finished[batchSlot] : FinishedState::empty();
|
||||
if (finishState.isSkipDecoding())
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
bool finish = finishState.isFinished();
|
||||
int offset = batchIdx * vocabSizePadded;
|
||||
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
for (int tid = threadIdx.x; tid < vocabSizePadded; tid += blockDim.x)
|
||||
{
|
||||
if (tid >= vocabSize)
|
||||
{
|
||||
logits[offset + tid] = -MAX_T_VAL;
|
||||
}
|
||||
else if (finish)
|
||||
{
|
||||
logits[offset + tid] = (tid == endIds[batchSlot]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (bias != nullptr)
|
||||
{
|
||||
logits[offset + tid] += bias[tid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeAddBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream)
|
||||
{
|
||||
dim3 grid(batchSize);
|
||||
dim3 block(min(vocabSizePadded, 1024));
|
||||
// n is the vocabSize, e.g., 30000, 7000.... vocabSize is usually very big.
|
||||
addBiasEndMask<<<grid, block, 0, stream>>>(logits, bias, endIds, finished, batchSlots, vocabSize, vocabSizePadded);
|
||||
}
|
||||
|
||||
template void invokeAddBiasEndMask(float* logits, const float* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
|
||||
|
||||
template void invokeAddBiasEndMask(half* logits, const half* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
|
||||
|
||||
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
||||
__global__ void topKStage1(const T* __restrict logProbs, T* tmpLogProbs, int* topKTmpIdBuf, T* topKTmpValBuf,
|
||||
const FinishedState* finished, const int maxTopK, const int* topKs, const int vocabSize, const int* endIds,
|
||||
@ -176,7 +123,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
curandState_t* curandstate, const int* endIds, const int vocabSize, const bool* skipDecode, const int* batchSlots,
|
||||
const bool normalizeLogProbs)
|
||||
const bool normalizeLogProbs, const bool logitHasProbs)
|
||||
{
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
T const MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
@ -240,7 +187,7 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
|
||||
// when cumLogProbs are computed, topKTmpValBuf (logits_buf_) are
|
||||
// already pre-processed by softmax_kernel
|
||||
if (cumLogProbs == nullptr && outputLogProbs == nullptr)
|
||||
if (!logitHasProbs)
|
||||
{
|
||||
total.u = __expf(total.u - maxLogit);
|
||||
}
|
||||
@ -309,7 +256,8 @@ __global__ void topKStage2Sampling(const int* __restrict topKTmpIdBuf, T* topKTm
|
||||
topKStage2Sampling<T, BLOCK_SIZE_2_, BLOCKS_PER_BEAM_> \
|
||||
<<<batchSize, BLOCK_SIZE_2_, K_MAX * sizeof(int) + K_MAX * sizeof(float), stream>>>(topKTmpIdBuf, \
|
||||
topKTmpValBuf, ids, sequenceLengths, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, maxTopK, \
|
||||
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs); \
|
||||
topKs, topP, topPs, curandstate, endIds, vocabSize, skipDecode, batchSlots, normalizeLogProbs, \
|
||||
logitsHasProbs); \
|
||||
break;
|
||||
|
||||
template <typename T>
|
||||
@ -317,9 +265,9 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
|
||||
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
|
||||
const bool* skipDecode, const bool normalizeLogProbs)
|
||||
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs)
|
||||
{
|
||||
TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
// Not allow an ambiguous inputs topP and topPs.
|
||||
assert(topP == 1.0f || topPs == nullptr);
|
||||
@ -341,6 +289,11 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
|
||||
return;
|
||||
}
|
||||
|
||||
if (maxTopK == 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
T* tempLogProbs = (T*) workspace;
|
||||
int* topKTmpIdBuf = (int*) (tempLogProbs + tempLogProbsBufSize);
|
||||
T* topKTmpValBuf = (T*) (topKTmpIdBuf + topKTmpIdsBufSize);
|
||||
@ -367,6 +320,8 @@ void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* lo
|
||||
CASE_K(1024, 256, 256, 8, normalizeLogProbs);
|
||||
default: throw std::domain_error(fmtstr("top-k kernel supports 1<=k<=1024 but got k=%d", maxTopK));
|
||||
}
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
#undef CASE_K
|
||||
@ -375,37 +330,37 @@ template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, co
|
||||
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
|
||||
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
|
||||
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs);
|
||||
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
template void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
|
||||
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP,
|
||||
const float* topPs, const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream,
|
||||
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs);
|
||||
const int batchSize, const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
template <typename T>
|
||||
void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
|
||||
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
|
||||
const bool normalizeLogProbs)
|
||||
const bool normalizeLogProbs, const bool logitsHasProbs)
|
||||
{
|
||||
invokeBatchTopKSampling(workspace, workspaceSize, logProbs, ids, sequenceLengths, finishedInput, finishedOutput,
|
||||
cumLogProbs, outputLogProbs, curandstate, topK, nullptr, topP, nullptr, vocabSizePadded, endIds, batchSlots,
|
||||
stream, batchSize, skipDecode, normalizeLogProbs);
|
||||
stream, batchSize, skipDecode, normalizeLogProbs, logitsHasProbs);
|
||||
}
|
||||
|
||||
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const float* logProbs, int** ids,
|
||||
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
|
||||
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
|
||||
const bool normalizeLogProbs);
|
||||
const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
template void invokeTopKSampling(void* workspace, size_t& workspaceSize, const half* logProbs, int** ids,
|
||||
int* sequenceLengths, const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs,
|
||||
float* outputLogProbs, curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded,
|
||||
const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
|
||||
const bool normalizeLogProbs);
|
||||
const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -34,7 +34,7 @@ namespace kernels
|
||||
//! buffer.
|
||||
//! \param workspaceSize size of the workspace in bytes
|
||||
//! \param logProbs input buffer [batchSize x vocabSizePadded].
|
||||
//! Log probabilities of each token in the vocab. If cumLogProbs or outputLogProbs are specified,
|
||||
//! Log probabilities of each token in the vocab. If logitsHasProbs is true,
|
||||
//! logProbs must contain **just** probabilities instead of log probabilities.
|
||||
//! \param outputIds output buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
|
||||
//! \param sequenceLength input/output buffer [maxBatchSize]. Current sequence length of the request up to, but excluding endId token
|
||||
@ -61,13 +61,14 @@ namespace kernels
|
||||
//! \param batchSize batch size
|
||||
//! \param skipDecode input buffer [maxBatchSize]. Flags whether to skip decoding per request
|
||||
//! \param normalizeLogProbs when set to True outputLogProbs are normalized to TopK
|
||||
//! \param logitsHasProbs flag to highlight that logProbs contains probabilities
|
||||
// clang-format on
|
||||
template <typename T>
|
||||
void invokeBatchTopKSampling(void* workspace, size_t& workspaceSize, const T* logProbs, int** ids, int* sequenceLengths,
|
||||
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int maxTopK, const int* topKs, const float topP, const float* topPs,
|
||||
const int vocabSizePadded, const int* endIds, const int* batchSlots, cudaStream_t stream, const int batchSize,
|
||||
const bool* skipDecode, const bool normalizeLogProbs);
|
||||
const bool* skipDecode, const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
//! \brief Specialization of invokeBatchTopKSampling with topPs=nullptr and topKs=nullptr
|
||||
template <typename T>
|
||||
@ -75,24 +76,7 @@ void invokeTopKSampling(void* workspace, size_t& workspaceSize, const T* logProb
|
||||
const FinishedState* finishedInput, FinishedState* finishedOutput, float* cumLogProbs, float* outputLogProbs,
|
||||
curandState_t* curandstate, const int topK, const float topP, const int vocabSizePadded, const int* endIds,
|
||||
const int* batchSlots, cudaStream_t stream, const int batchSize, const bool* skipDecode,
|
||||
const bool normalizeLogProbs);
|
||||
|
||||
//! \brief Applies mask and bias to logits. Sets -MAX_FLT value for tokens in range [vocabSize; vocabSizePadded) to
|
||||
//! prevent them being chosen If request finished the generation, sets MAX_FLT to endId token and -MAX_FLT to all other
|
||||
//! tokens forcing to choose endId token. Otherwise, adds bias per token if bias pointer is not nullptr.
|
||||
//!
|
||||
//! \param logits input/output buffer [batchSize, vocabSize]. Logits to be modified.
|
||||
//! \param bias input buffer [vocabSize]. Bias to logit per token. Ignored if nullptr
|
||||
//! \param endIds input buffer [maxBatchSize]. EOS token ids per request
|
||||
//! \param finished input buffer [maxBatchSize] with flags set to true if request has finished the generation
|
||||
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
|
||||
//! \param batchSize batch size
|
||||
//! \param vocabSize unpadded vocab size
|
||||
//! \param vocabSizePadded padded vocab size
|
||||
//! \param stream stream
|
||||
template <typename T>
|
||||
void invokeAddBiasEndMask(T* logits, const T* bias, const int* endIds, const FinishedState* finished,
|
||||
const int* batchSlots, const int batchSize, const int vocabSize, const int vocabSizePadded, cudaStream_t stream);
|
||||
const bool normalizeLogProbs, const bool logitsHasProbs);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -310,6 +310,8 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
|
||||
curandState_t* curandstate, int const batchSize, size_t const vocabSizePadded, int const* endIds,
|
||||
float const maxTopP, float const* topPs, cudaStream_t stream, bool const* skipDecode, int const* batchSlots)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
int const vocabSize = vocabSizePadded;
|
||||
|
||||
size_t sortedLogProbBufSize = batchSize * vocabSize * sizeof(T); // type T
|
||||
@ -338,6 +340,7 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
|
||||
// If the most probable token exceeds P, we skip sorting by setting beginOffsetBuf[bi] = offsetBuf[bi]
|
||||
topPBeamTopKKernel<T, BLOCK_SIZE><<<batchSize, BLOCK_SIZE, 0, stream>>>(logProbs, sortedIdVals, sortedLogProbs,
|
||||
finishedInput, vocabSize, offsetBuf, beginOffsetBuf, maxTopP, topPs, skipDecode, batchSlots);
|
||||
sync_check_cuda_error();
|
||||
|
||||
// Sort tokens by probability in descending order
|
||||
check_cuda_error(cub::DeviceSegmentedRadixSort::SortPairsDescending(cubTempStorage, cubTempStorageSize, logProbs,
|
||||
@ -352,6 +355,9 @@ void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cub
|
||||
topPSsampling<T, SAMPLING_BLOCK_SIZE><<<grid, SAMPLING_BLOCK_SIZE, 0, stream>>>(sortedLogProbs, sortedIdVals,
|
||||
outputIds, sequenceLength, finishedInput, finishedOutput, cumLogProbs, outputLogProbs, beginOffsetBuf,
|
||||
offsetBuf + 1, vocabSize, curandstate, maxTopP, topPs, endIds, batchSize, skipDecode, batchSlots);
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template void invokeBatchTopPSampling(void* workspace, size_t& workspaceSize, size_t& cubTempStorageSize,
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2020-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
||||
@ -24,43 +24,44 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
__global__ void stopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
|
||||
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
|
||||
int beamWidth, int maxSeqLen)
|
||||
__global__ void stopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
|
||||
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLens,
|
||||
int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen)
|
||||
{
|
||||
int const id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int const batchIdx = blockIdx.y / beamWidth;
|
||||
int const beamIdx = blockIdx.y % beamWidth;
|
||||
int32_t const id = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int32_t const batchIdx = blockIdx.y / beamWidth;
|
||||
int32_t const beamIdx = blockIdx.y % beamWidth;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
|
||||
auto const batchBeamIdx = batchSlot * beamWidth + beamIdx;
|
||||
|
||||
int const* baseStopWords = stopWords + batchSlot * 2 * stopWordsLen;
|
||||
int const* baseOffsets = baseStopWords + stopWordsLen;
|
||||
auto const* baseStopWords = stopWords[batchSlot];
|
||||
auto const stopWordsLen = stopWordsLens[batchSlot];
|
||||
auto const* baseOffsets = baseStopWords + stopWordsLen;
|
||||
|
||||
if (id >= stopWordsLen || baseOffsets[id] < 0)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
int const itemEnd = baseOffsets[id];
|
||||
int const itemStart = (id > 0) ? baseOffsets[id - 1] : 0;
|
||||
int const itemSize = itemEnd - itemStart;
|
||||
auto const itemEnd = baseOffsets[id];
|
||||
auto const itemStart = (id > 0) ? baseOffsets[id - 1] : 0;
|
||||
auto const itemSize = itemEnd - itemStart;
|
||||
|
||||
// The single-token case unconditionally bans the token
|
||||
bool shouldStop = false;
|
||||
|
||||
// Need to minus 1 because the sequenceLengths is updated in this step
|
||||
int const currentStep = sequenceLengths[batchBeamIdx] - 1;
|
||||
auto const currentStep = sequenceLengths[batchBeamIdx] - 1;
|
||||
// Enough previously generated tokens to look for a match
|
||||
if (currentStep + 1 >= itemSize)
|
||||
{
|
||||
shouldStop = true;
|
||||
int parentId = beamIdx;
|
||||
auto parentId = beamIdx;
|
||||
bool const gatherBeam = beamWidth > 1;
|
||||
|
||||
for (int tokenIdx = itemSize - 1; tokenIdx >= 0; tokenIdx--)
|
||||
for (int32_t tokenIdx = itemSize - 1; tokenIdx >= 0; tokenIdx--)
|
||||
{
|
||||
int const previousToken
|
||||
auto const previousToken
|
||||
= outputIds[batchSlot][parentId * maxSeqLen + currentStep - (itemSize - 1) + tokenIdx];
|
||||
if (previousToken != baseStopWords[itemStart + tokenIdx])
|
||||
{
|
||||
@ -88,15 +89,16 @@ __global__ void stopWordsCriterion(const int** outputIds, const int** parentIds,
|
||||
}
|
||||
}
|
||||
|
||||
void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
|
||||
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
|
||||
int beamWidth, int maxSeqLen, cudaStream_t stream)
|
||||
void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
|
||||
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLen,
|
||||
int32_t maxStopWordsLen, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream)
|
||||
{
|
||||
// Check if we have sampled a word from the stopWords list. If so, stop the sequence.
|
||||
dim3 block, grid;
|
||||
constexpr size_t maxBlockSize{256};
|
||||
block.x = min(((stopWordsLen + 32 - 1) / 32) * 32, maxBlockSize);
|
||||
grid.x = (stopWordsLen + block.x - 1) / block.x;
|
||||
constexpr int32_t maxBlockSize{256};
|
||||
|
||||
block.x = min(((maxStopWordsLen + 32 - 1) / 32) * 32, maxBlockSize);
|
||||
grid.x = (maxStopWordsLen + block.x - 1) / block.x;
|
||||
grid.y = batchSize * beamWidth;
|
||||
|
||||
stopWordsCriterion<<<grid, block, 0, stream>>>(outputIds, parentIds, stopWords, finished, sequenceLengths,
|
||||
@ -104,15 +106,15 @@ void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, cons
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
__global__ void lengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
|
||||
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth)
|
||||
__global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
|
||||
{
|
||||
int threadFinishedCount = 0;
|
||||
for (int index = threadIdx.x; index < batchSize * beamWidth; index += blockDim.x)
|
||||
int32_t threadFinishedCount = 0;
|
||||
auto const batchIdx = blockIdx.x;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
|
||||
|
||||
for (int32_t beamIdx = threadIdx.x; beamIdx < beamWidth; beamIdx += blockDim.x)
|
||||
{
|
||||
int const batchIdx = index / beamWidth;
|
||||
int const beamIdx = index % beamWidth;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[batchIdx] : batchIdx;
|
||||
auto const batchSlotBeamWidthIdx = batchSlot * beamWidth + beamIdx;
|
||||
|
||||
auto finishState = finished[batchSlotBeamWidthIdx];
|
||||
@ -140,19 +142,20 @@ __global__ void lengthCriterion(FinishedState* finished, int* finishedSum, const
|
||||
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
finishedSum[0] = blockFinishedCount;
|
||||
finishedSum[batchSlot] = blockFinishedCount;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void invokeLengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
|
||||
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth, cudaStream_t stream)
|
||||
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// Check if we have attained the sequence length limit. If so, stop the
|
||||
// sequence. In addition, check if all sequences are stopped and return the
|
||||
// result in shouldStop
|
||||
dim3 block{min(512, uint32_t(batchSize * beamWidth))};
|
||||
dim3 grid{1};
|
||||
dim3 block{min(512, uint32_t(beamWidth))};
|
||||
dim3 grid{uint32_t(batchSize)};
|
||||
|
||||
lengthCriterion<<<grid, block, 0, stream>>>(
|
||||
finished, finishedSum, sequenceLimitLength, sequenceLengths, batchSlots, batchSize, beamWidth);
|
||||
|
||||
@ -27,7 +27,7 @@ namespace kernels
|
||||
//! \param outputIds input buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with output tokens per request
|
||||
//! \param parentIds input buffer [maxBatchSize][maxSeqLen]. Contains pointers to rows with parent ids. Applicable when
|
||||
//! beamWidth > 1
|
||||
//! \param stopWords input buffer [maxBatchSize, 2, stopWordsLen]. For each instance in batch the first row
|
||||
//! \param stopWords input buffer [maxBatchSize][2, stopWordsLen]. For each instance in batch the first row
|
||||
//! is the token ids of the stop words. The second row is the exclusive prefix sum of the word lengths.
|
||||
//! In case all the words are made of a single token,
|
||||
//! the inner-most dimension of the tensor must be increased by 1.
|
||||
@ -37,14 +37,15 @@ namespace kernels
|
||||
//! \param sequenceLengths input buffer [maxBatchSize, beamWidth]. Current sequence
|
||||
//! lengths of the request tokens.
|
||||
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
|
||||
//! \param stopWordsLen cumulative length of all stop words
|
||||
//! \param stopWordsLen input buffer [maxBatchSize], cumulative length of all stop words per request
|
||||
//! \param maxStopWordsLen maximum stopWordsLen over all requests in the batch
|
||||
//! \param batchSize batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param maxSeqLen maximum length of the sequence
|
||||
//! \param stream stream
|
||||
void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, const int* stopWords,
|
||||
FinishedState* finished, const int* sequenceLengths, const int* batchSlots, size_t stopWordsLen, int batchSize,
|
||||
int beamWidth, int maxSeqLen, cudaStream_t stream);
|
||||
void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentIds, int32_t const** stopWords,
|
||||
FinishedState* finished, int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t const* stopWordsLen,
|
||||
int32_t maxStopWordsLen, int32_t batchSize, int32_t beamWidth, int32_t maxSeqLen, cudaStream_t stream);
|
||||
|
||||
//! \brief Sets finished states based on the sequenceLimitLength and computes number of finished sequences in the batch.
|
||||
//!
|
||||
@ -59,7 +60,8 @@ void invokeStopWordsCriterion(const int** outputIds, const int** parentIds, cons
|
||||
//! \param batchSize batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param stream stream
|
||||
void invokeLengthCriterion(FinishedState* finished, int* finishedSum, const uint32_t* sequenceLimitLength,
|
||||
const int* sequenceLengths, const int* batchSlots, int batchSize, int beamWidth, cudaStream_t stream);
|
||||
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
|
||||
cudaStream_t stream);
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -274,9 +274,9 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
// src QKV: [batch, time, 3, head_num, size_per_head]
// head_num != kv_head_num:
// src QKV: [batch, time, head_num * size_per_head + 2 * kv_head_num * size_per_head]
const int src_q_idx = token_idx * n + hidden_idx;
const int src_k_idx = token_idx * n + src_k_offset + hidden_idx_kv;
const int src_v_idx = token_idx * n + src_v_offset + hidden_idx_kv;
auto const src_q_idx = static_cast<size_t>(token_idx) * n + hidden_idx;
auto const src_k_idx = static_cast<size_t>(token_idx) * n + src_k_offset + hidden_idx_kv;
auto const src_v_idx = static_cast<size_t>(token_idx) * n + src_v_offset + hidden_idx_kv;

Vec_type q, k, v;
Vec_type q_bias, k_bias, v_bias;

@ -56,7 +56,8 @@ enum class WeightOnlyActivationFunctionType
enum class WeightOnlyActivationType
{
FP16,
BF16
BF16,
FP8
};

struct WeightOnlyParams

@ -41,6 +41,12 @@ struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajorTileInterl
static constexpr bool value = true;
};

template <>
struct SupportedLayout<cutlass::uint4b_t, cutlass::layout::ColumnMajor>
{
static constexpr bool value = true;
};

template <typename TypeB, typename Arch>
bool isEnabled()
{
@ -53,16 +59,20 @@ bool isEnabledForArch(int arch)
{
if (arch >= 70 && arch < 75)
{
return isEnabled<TypeB, cutlass::arch::Sm70>();
return false;
}
else if (arch >= 75 && arch < 80)
{
return isEnabled<TypeB, cutlass::arch::Sm75>();
}
else if (arch >= 80 && arch <= 90)
else if (arch >= 80 && arch < 90)
{
return isEnabled<TypeB, cutlass::arch::Sm80>();
}
else if (arch >= 90)
{
return isEnabled<TypeB, cutlass::arch::Sm90>();
}
else
{
TLLM_CHECK_WITH_INFO(false, "Unsupported Arch");

75
cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h
Normal file
@ -0,0 +1,75 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_fp8.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
enum class KernelType
|
||||
{
|
||||
W4A16,
|
||||
W4A8
|
||||
};
|
||||
|
||||
struct Params
|
||||
{
|
||||
void* act;
|
||||
void* act_scale;
|
||||
void* weight;
|
||||
void* scales;
|
||||
void* zeros;
|
||||
void* bias;
|
||||
void* out;
|
||||
float alpha;
|
||||
int m;
|
||||
int n;
|
||||
int k;
|
||||
int groupsize;
|
||||
KernelType type;
|
||||
|
||||
Params(void* _act, void* _act_scale, void* _weight, void* _scales, void* _zeros, void* _bias, void* _out,
|
||||
float _alpha, int _m, int _n, int _k, int _groupsize, KernelType _type)
|
||||
: act(_act)
|
||||
, act_scale(_act_scale)
|
||||
, weight(_weight)
|
||||
, scales(_scales)
|
||||
, zeros(_zeros)
|
||||
, bias(_bias)
|
||||
, out(_out)
|
||||
, alpha(_alpha)
|
||||
, m(_m)
|
||||
, n(_n)
|
||||
, k(_k)
|
||||
, groupsize(_groupsize)
|
||||
, type(_type)
|
||||
{
|
||||
}
|
||||
};
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
389
cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h
Normal file
@ -0,0 +1,389 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
|
||||
struct ConverterI4ToF16
|
||||
{
|
||||
__device__ __forceinline__ static void convert(uint32_t& src, uint4& dst)
|
||||
{
|
||||
uint32_t* r = reinterpret_cast<uint32_t*>(&dst);
|
||||
uint32_t prmt_indices[4] = {0x4040, 0x4141, 0x4242, 0x4343};
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 4; ++ii)
|
||||
{
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" prmt.b32 %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "=r"(r[ii])
|
||||
: "r"(src), "n"(0), "r"(prmt_indices[ii]));
|
||||
}
|
||||
static constexpr uint32_t xor_mask = 0x64806408;
|
||||
static constexpr uint32_t and_mask = 0xFFF0FF0F;
|
||||
static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa;
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 4; ++ii)
|
||||
{
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" lop3.b32 %0, %0, %1, %2, %3;\n"
|
||||
"}\n"
|
||||
: "+r"(r[ii])
|
||||
: "n"(and_mask), "n"(xor_mask), "n"(immLut));
|
||||
}
|
||||
static constexpr uint32_t hfma_bias_rep = 0xD480E408;
|
||||
static constexpr uint32_t hfma_scale_rep = 0x2C003C00;
|
||||
|
||||
const half2& hfma_bias = reinterpret_cast<const half2&>(hfma_bias_rep);
|
||||
const half2& hfma_scale = reinterpret_cast<const half2&>(hfma_scale_rep);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < 4; ++ii)
|
||||
{
|
||||
__half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]);
|
||||
fp16x2_val = __hfma2(hfma_scale, fp16x2_val, hfma_bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ __forceinline__ static void convert(void* src, void* dst)
|
||||
{
|
||||
static_assert(N == 8 || N == 16);
|
||||
convert(reinterpret_cast<uint32_t*>(src)[0], reinterpret_cast<uint4*>(dst)[0]);
|
||||
if constexpr (N == 16)
|
||||
{
|
||||
convert(reinterpret_cast<uint32_t*>(src)[1], reinterpret_cast<uint4*>(dst)[1]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TVec, int N, bool Enable, typename TSrc>
|
||||
__device__ __forceinline__ void load(void* dst, TSrc* src, int stride)
|
||||
{
|
||||
if constexpr (Enable)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
reinterpret_cast<TVec*>(dst)[ii] = reinterpret_cast<TVec*>(src + ii * stride)[0];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int K, bool Enable>
|
||||
__device__ __forceinline__ void apply_scale(void* act, void* act_scale)
|
||||
{
|
||||
static_assert(K % 2 == 0);
|
||||
static constexpr int VecK = K / 2;
|
||||
if constexpr (Enable)
|
||||
{
|
||||
half2* pa = reinterpret_cast<half2*>(act);
|
||||
half2* pb = reinterpret_cast<half2*>(act_scale);
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M; ++m)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int k = 0; k < VecK; ++k)
|
||||
{
|
||||
pa[m * VecK + k] = __hmul2(pa[m * VecK + k], pb[k]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, int K, bool EnableZero>
|
||||
__device__ __forceinline__ void dequantize(void* w, void* quantized_w, void* scales, void* zeros)
|
||||
{
|
||||
using Converter = ConverterI4ToF16;
|
||||
static_assert(K % 2 == 0);
|
||||
static constexpr int VecK = K / 2;
|
||||
#pragma unroll
|
||||
for (int n = 0; n < N; ++n)
|
||||
{
|
||||
ConverterI4ToF16::convert<K>(
|
||||
reinterpret_cast<uint8_t*>(quantized_w) + n * K / 2, reinterpret_cast<half*>(w) + n * K);
|
||||
half2 vec_scale = __half2half2(reinterpret_cast<half*>(scales)[n]);
|
||||
half2 vec_zero = __half2half2(__float2half_rn(0.f));
|
||||
if constexpr (EnableZero)
|
||||
{
|
||||
vec_zero = __half2half2(reinterpret_cast<half*>(zeros)[n]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int k = 0; k < VecK; ++k)
|
||||
{
|
||||
reinterpret_cast<half2*>(w)[n * VecK + k]
|
||||
= __hfma2(reinterpret_cast<half2*>(w)[n * VecK + k], vec_scale, vec_zero);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int N, int K>
|
||||
__device__ __forceinline__ void pack_to_vec2(void* dst, void* src)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int n = 0; n < N; n += 2)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int k = 0; k < K; ++k)
|
||||
{
|
||||
reinterpret_cast<half*>(dst)[n * K + k * 2] = reinterpret_cast<half*>(src)[n * K + k];
|
||||
reinterpret_cast<half*>(dst)[n * K + k * 2 + 1] = reinterpret_cast<half*>(src)[(n + 1) * K + k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int M, int N, int K>
|
||||
__device__ __forceinline__ void mma(void* acc, void* w_pack2, void* act)
|
||||
{
|
||||
static_assert(N % 2 == 0);
|
||||
static constexpr int VecN = N / 2;
|
||||
#pragma unroll
|
||||
for (int m = 0; m < M; ++m)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int n = 0; n < VecN; ++n)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int k = 0; k < K; ++k)
|
||||
{
|
||||
reinterpret_cast<half2*>(acc)[m * VecN + n] = __hfma2(reinterpret_cast<half2*>(w_pack2)[n * K + k],
|
||||
__half2half2(reinterpret_cast<half*>(act)[m * K + k]), reinterpret_cast<half2*>(acc)[m * VecN + n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T warp_reduce_sum(T& val)
|
||||
{
|
||||
val += __shfl_xor_sync(~0, val, 16);
|
||||
val += __shfl_xor_sync(~0, val, 8);
|
||||
val += __shfl_xor_sync(~0, val, 4);
|
||||
val += __shfl_xor_sync(~0, val, 2);
|
||||
val += __shfl_xor_sync(~0, val, 1);
|
||||
return val;
|
||||
}
|
||||
|
||||
template <int CtaM, int CtaN, int Threads, bool EnableBias>
|
||||
__device__ __forceinline__ void epilogue(void* out, int stride, void* tile_acc, void* bias, float alpha)
|
||||
{
|
||||
static constexpr int WarpSize = 32;
|
||||
static constexpr int WarpNum = Threads / WarpSize;
|
||||
static constexpr int AlignShmemSize = (CtaM * CtaN + 31) / 32 * 32;
|
||||
static_assert(Threads % WarpSize == 0);
|
||||
__shared__ float shmem[AlignShmemSize * WarpNum];
|
||||
int tid = threadIdx.x;
|
||||
int warp_id = tid / WarpSize, lane_id = tid % WarpSize;
|
||||
#pragma unroll
|
||||
for (int m = 0; m < CtaM; ++m)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int n = 0; n < CtaN; ++n)
|
||||
{
|
||||
float v = __half2float(reinterpret_cast<half*>(tile_acc)[m * CtaN + n]);
|
||||
v = warp_reduce_sum(v);
|
||||
if (lane_id == 0)
|
||||
{
|
||||
shmem[warp_id * AlignShmemSize + m * CtaN + n] = v;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int ii = tid; ii < CtaM * CtaN; ii += Threads)
|
||||
{
|
||||
int m = ii / CtaN, n = ii % CtaN;
|
||||
float val = 0.f, v_bias = 0.f;
|
||||
if constexpr (EnableBias)
|
||||
{
|
||||
v_bias = static_cast<float>(reinterpret_cast<half*>(bias)[n]);
|
||||
}
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < WarpNum; ++jj)
|
||||
{
|
||||
val += shmem[jj * AlignShmemSize + ii];
|
||||
}
|
||||
reinterpret_cast<half*>(out)[m * stride + n] = __float2half_rn(alpha * val + v_bias);
|
||||
}
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__device__ __forceinline__ void fill(void* tile, half v)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < N; ++ii)
|
||||
{
|
||||
reinterpret_cast<half*>(tile)[ii] = v;
|
||||
}
|
||||
}
|
||||
|
||||
struct Fp16Details
|
||||
{
|
||||
using ActDataType = half;
|
||||
static constexpr int StepK = 8;
|
||||
using AccessTypeAct = float4;
|
||||
using AccessTypeActScale = float4;
|
||||
using AccessTypeW = float;
|
||||
|
||||
template <int CtaM>
|
||||
__device__ __forceinline__ static void load_act(void* dst, void* src, int stride)
|
||||
{
|
||||
load<AccessTypeAct, CtaM, true>(dst, reinterpret_cast<ActDataType*>(src), stride);
|
||||
}
|
||||
};
|
||||
|
||||
struct Fp8Details
|
||||
{
|
||||
using ActDataType = __nv_fp8_e4m3;
|
||||
static constexpr int StepK = 8;
|
||||
using AccessTypeAct = float2;
|
||||
using AccessTypeActScale = float4;
|
||||
using AccessTypeW = float;
|
||||
|
||||
__device__ __forceinline__ static void conversion(void* dst, void* src)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < StepK / 4; ++ii)
|
||||
{
|
||||
asm volatile(
|
||||
"{\n"
|
||||
".reg .b16 lo, hi;\n"
|
||||
"mov.b32 {lo, hi}, %2;\n"
|
||||
"cvt.rn.f16x2.e4m3x2 %0, lo;\n"
|
||||
"cvt.rn.f16x2.e4m3x2 %1, hi;\n"
|
||||
"}\n"
|
||||
: "=r"(reinterpret_cast<uint32_t*>(dst)[ii * 2]), "=r"(reinterpret_cast<uint32_t*>(dst)[ii * 2 + 1])
|
||||
: "r"(reinterpret_cast<uint32_t*>(src)[ii]));
|
||||
}
|
||||
#else
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < StepK; ++ii)
|
||||
{
|
||||
reinterpret_cast<half*>(dst)[ii] = static_cast<half>(reinterpret_cast<ActDataType*>(src)[ii]);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <int CtaM>
|
||||
__device__ __forceinline__ static void load_act(void* dst, void* src, int stride)
|
||||
{
|
||||
ActDataType vec[CtaM * StepK];
|
||||
load<AccessTypeAct, CtaM, true>(vec, reinterpret_cast<ActDataType*>(src), stride);
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < CtaM; ++ii)
|
||||
{
|
||||
conversion(reinterpret_cast<half*>(dst) + ii * StepK, vec + ii * StepK);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
|
||||
bool EnableBias>
|
||||
__global__ void kernel(typename Details::ActDataType* act, half* act_scale, uint8_t* weight, half* scales, half* zeros,
|
||||
half* bias, half* out, float alpha, int m, int n, int k)
|
||||
{
|
||||
// ArgType ArgName DataType Shape Layout
|
||||
//
|
||||
// input act fp16 [m, k] RowMajor
|
||||
// input act_scale fp16 [1, k] RowMajor
|
||||
// input weight int4b [k, n] ColumnMajor
|
||||
// input scales fp16 [k / GroupSize, n] RowMajor
|
||||
// input zeros fp16 [k / GroupSize, n] RowMajor
|
||||
// input bias fp16 [1, n] RowMajor
|
||||
// output out fp16 [m, n] RowMajor
|
||||
|
||||
using AccessTypeActScale = typename Details::AccessTypeActScale;
|
||||
using AccessTypeW = typename Details::AccessTypeW;
|
||||
static constexpr int StepK = Details::StepK;
|
||||
|
||||
static constexpr bool Mandatory = true;
|
||||
static constexpr int CtaK = StepK * Threads;
|
||||
static_assert(CtaN % 2 == 0);
|
||||
|
||||
const int m_tile_id = blockIdx.x, n_tile_id = blockIdx.y, tid = threadIdx.x;
|
||||
const int m_offset = m_tile_id * CtaM, n_offset = n_tile_id * CtaN;
|
||||
|
||||
act += m_offset * k;
|
||||
weight += n_offset * k / 2;
|
||||
scales += n_offset;
|
||||
zeros += n_offset;
|
||||
bias += n_offset;
|
||||
out += m_offset * n + n_offset;
|
||||
|
||||
half tile_a[StepK * CtaM], tile_w[StepK * CtaN], tile_w_pack2[StepK * CtaN];
|
||||
half tile_acc[CtaM * CtaN];
|
||||
fill<CtaM * CtaN>(tile_acc, __float2half_rn(0.f));
|
||||
|
||||
for (int idx_k = tid * StepK; idx_k < k; idx_k += CtaK)
|
||||
{
|
||||
half vec_act_scale[StepK];
|
||||
half vec_scale[CtaN], vec_zero[CtaN];
|
||||
uint8_t tile_w_quantized[StepK * CtaN / 2];
|
||||
// Load Data
|
||||
Details::load_act<CtaM>(tile_a, act + idx_k, k);
|
||||
load<AccessTypeActScale, 1, EnableActScale>(vec_act_scale, act_scale + idx_k, 0);
|
||||
load<AccessTypeW, CtaN, Mandatory>(tile_w_quantized, weight + idx_k / 2, k / 2);
|
||||
load<half, CtaN, Mandatory>(vec_scale, scales + idx_k / GroupSize * n, 1);
|
||||
load<half, CtaN, EnableZero>(vec_zero, zeros + idx_k / GroupSize * n, 1);
|
||||
// Dequantize Data
|
||||
apply_scale<CtaM, StepK, EnableActScale>(tile_a, vec_act_scale);
|
||||
dequantize<CtaN, StepK, EnableZero>(tile_w, tile_w_quantized, vec_scale, vec_zero);
|
||||
// Rearrange
|
||||
pack_to_vec2<CtaN, StepK>(tile_w_pack2, tile_w);
|
||||
// MMA
|
||||
mma<CtaM, CtaN, StepK>(tile_acc, tile_w_pack2, tile_a);
|
||||
}
|
||||
// Epilogue
|
||||
epilogue<CtaM, CtaN, Threads, EnableBias>(out, n, tile_acc, bias, alpha);
|
||||
}
|
||||
|
||||
template <typename Details, int CtaM, int CtaN, int Threads, int GroupSize, bool EnableActScale, bool EnableZero,
|
||||
bool EnableBias>
|
||||
void exec_kernel(Params& params, cudaStream_t s)
|
||||
{
|
||||
if (params.m % CtaM || params.n % CtaN)
|
||||
{
|
||||
throw std::runtime_error("launch failed");
|
||||
}
|
||||
dim3 grid(params.m / CtaM, params.n / CtaN);
|
||||
dim3 block(Threads);
|
||||
// clang-format off
|
||||
kernel<Details, CtaM, CtaN, Threads, GroupSize, EnableActScale, EnableZero, EnableBias><<<grid, block, 0, s>>>(
|
||||
reinterpret_cast<typename Details::ActDataType*>(params.act),
|
||||
reinterpret_cast<half*>(params.act_scale),
|
||||
reinterpret_cast<uint8_t*>(params.weight),
|
||||
reinterpret_cast<half*>(params.scales),
|
||||
reinterpret_cast<half*>(params.zeros),
|
||||
reinterpret_cast<half*>(params.bias),
|
||||
reinterpret_cast<half*>(params.out),
|
||||
params.alpha,
|
||||
params.m, params.n, params.k
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,123 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernel.h"
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace weight_only
|
||||
{
|
||||
#define DISPATCHER_FOR_M(target_m, CtaM, CtaN, Threads) \
|
||||
do \
|
||||
{ \
|
||||
if (params.m == target_m) \
|
||||
{ \
|
||||
exec_kernel<Details, CtaM, CtaN, Threads, GroupSize, EnableActScale, EnableZero, EnableBias>(params, s); \
|
||||
return; \
|
||||
} \
|
||||
} while (0);
|
||||
|
||||
template <typename Details, int GroupSize, bool EnableActScale, bool EnableZero, bool EnableBias>
|
||||
void dispatcher(Params& params, cudaStream_t s)
|
||||
{
|
||||
// clang-format off
|
||||
DISPATCHER_FOR_M(1, 1, 8, 128);
|
||||
DISPATCHER_FOR_M(2, 2, 4, 128);
|
||||
DISPATCHER_FOR_M(3, 3, 16, 128);
|
||||
DISPATCHER_FOR_M(4, 4, 16, 128);
|
||||
DISPATCHER_FOR_M(5, 5, 16, 128);
|
||||
DISPATCHER_FOR_M(6, 6, 16, 128);
|
||||
DISPATCHER_FOR_M(7, 7, 16, 128);
|
||||
DISPATCHER_FOR_M(8, 8, 16, 128);
|
||||
DISPATCHER_FOR_M(9, 9, 8, 128);
|
||||
DISPATCHER_FOR_M(10, 10, 8, 128);
|
||||
DISPATCHER_FOR_M(11, 11, 8, 128);
|
||||
DISPATCHER_FOR_M(12, 12, 8, 128);
|
||||
DISPATCHER_FOR_M(13, 13, 8, 128);
|
||||
DISPATCHER_FOR_M(14, 14, 8, 128);
|
||||
DISPATCHER_FOR_M(15, 15, 8, 128);
|
||||
DISPATCHER_FOR_M(16, 16, 8, 128);
|
||||
// clang-format on
|
||||
throw std::runtime_error("unsupported m");
|
||||
}
|
||||
|
||||
template <typename Details, int GroupSize>
|
||||
void check_pointer(Params& params, cudaStream_t s)
|
||||
{
|
||||
if (params.act_scale && params.zeros && params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, true, true, true>(params, s);
|
||||
}
|
||||
else if (params.act_scale && params.zeros && !params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, true, true, false>(params, s);
|
||||
}
|
||||
else if (params.act_scale && !params.zeros && params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, true, false, true>(params, s);
|
||||
}
|
||||
else if (!params.act_scale && params.zeros && params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, false, true, true>(params, s);
|
||||
}
|
||||
else if (!params.act_scale && !params.zeros && params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, false, false, true>(params, s);
|
||||
}
|
||||
else if (params.act_scale && !params.zeros && !params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, true, false, false>(params, s);
|
||||
}
|
||||
else if (!params.act_scale && params.zeros && !params.bias)
|
||||
{
|
||||
dispatcher<Details, GroupSize, false, true, false>(params, s);
|
||||
}
|
||||
else
|
||||
{
|
||||
dispatcher<Details, GroupSize, false, false, false>(params, s);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Details>
|
||||
void select_gs(Params& params, cudaStream_t s)
|
||||
{
|
||||
if (params.groupsize == 64)
|
||||
{
|
||||
check_pointer<Details, 64>(params, s);
|
||||
}
|
||||
else if (params.groupsize == 128)
|
||||
{
|
||||
check_pointer<Details, 128>(params, s);
|
||||
}
|
||||
}
|
||||
|
||||
void kernel_launcher(Params& params, cudaStream_t s)
|
||||
{
|
||||
if (params.type == KernelType::W4A16)
|
||||
{
|
||||
select_gs<Fp16Details>(params, s);
|
||||
}
|
||||
else if (params.type == KernelType::W4A8)
|
||||
{
|
||||
select_gs<Fp8Details>(params, s);
|
||||
}
|
||||
}
|
||||
} // namespace weight_only
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -0,0 +1,30 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/common.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace weight_only
{

void kernel_launcher(Params& params, cudaStream_t s);
}
} // namespace kernels
} // namespace tensorrt_llm
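For readers skimming the diff, here is a minimal usage sketch of the new sm90 weight-only GEMV entry point declared above. It is not part of this commit: the wrapper function name, the device buffers, and the chosen group size are illustrative assumptions, and the pointers are expected to be allocated and filled elsewhere.

```
// Illustrative sketch only: driving the new weight_only::kernel_launcher.
// Assumes m is one of the dispatched sizes (1..16), n is a multiple of the CTA tile,
// and groupsize 128 is used (64 is the other dispatched case).
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/sm90/kernelLauncher.h"

namespace wo = tensorrt_llm::kernels::weight_only;

void runW4A16Gemv(void* act, void* weight, void* scales, void* out, int m, int n, int k, cudaStream_t stream)
{
    // No act_scale, zeros or bias here, so the launcher picks the specialization
    // with EnableActScale = EnableZero = EnableBias = false.
    wo::Params params(act, /*act_scale=*/nullptr, weight, scales, /*zeros=*/nullptr, /*bias=*/nullptr,
        out, /*alpha=*/1.f, m, n, k, /*groupsize=*/128, wo::KernelType::W4A16);
    wo::kernel_launcher(params, stream);
}
```

As the kernelLauncher.cu hunk above shows, the launcher dispatches on m, the group size, and which optional tensors (act_scale, zeros, bias) are non-null.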
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -16,9 +16,7 @@
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/penaltyKernels.h"
|
||||
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/fillBuffers.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -111,11 +109,6 @@ void BaseBeamSearchLayer<T>::freeBuffer()
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
if (mIsAllocateBuffer)
|
||||
{
|
||||
mAllocator->free((void**) (&temperature_buf_));
|
||||
mAllocator->free((void**) (&min_lengths_buf_));
|
||||
mAllocator->free((void**) (&repetition_penalty_buf_));
|
||||
mAllocator->free((void**) (&presence_penalty_buf_));
|
||||
mAllocator->free((void**) (&frequency_penalty_buf_));
|
||||
mIsAllocateBuffer = false;
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
@ -125,12 +118,6 @@ template <typename T>
|
||||
void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
temperature_buf_ = mAllocator->reMalloc(temperature_buf_, sizeof(float) * batch_size, false);
|
||||
min_lengths_buf_ = mAllocator->reMalloc(min_lengths_buf_, sizeof(int) * batch_size, false);
|
||||
repetition_penalty_buf_ = mAllocator->reMalloc(repetition_penalty_buf_, sizeof(float) * batch_size, false);
|
||||
presence_penalty_buf_ = mAllocator->reMalloc(presence_penalty_buf_, sizeof(float) * batch_size, false);
|
||||
frequency_penalty_buf_ = mAllocator->reMalloc(frequency_penalty_buf_, sizeof(float) * batch_size, false);
|
||||
|
||||
mIsAllocateBuffer = true;
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -138,47 +125,13 @@ void BaseBeamSearchLayer<T>::allocateBuffer(size_t batch_size)
|
||||
template <typename T>
|
||||
void BaseBeamSearchLayer<T>::setupBase(size_t batch_size, SetupParams const& setupParams)
|
||||
{
|
||||
allocateBuffer(batch_size);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
// Setup penalties.
|
||||
FillBuffers const fillBuffers{batch_size, mStream};
|
||||
|
||||
use_temperature_ = static_cast<bool>(setupParams.temperature);
|
||||
use_repetition_penalty_ = static_cast<bool>(setupParams.repetition_penalty);
|
||||
use_presence_penalty_ = static_cast<bool>(setupParams.presence_penalty);
|
||||
use_frequency_penalty_ = static_cast<bool>(setupParams.frequency_penalty);
|
||||
use_min_lengths_ = static_cast<bool>(setupParams.min_length);
|
||||
if (use_temperature_)
|
||||
{
|
||||
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(RepetitionPenaltyType::Temperature), mTemperature,
|
||||
temperature_buf_, (float*) nullptr, (int*) nullptr);
|
||||
}
|
||||
if (use_repetition_penalty_)
|
||||
{
|
||||
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Repetition),
|
||||
mRepetitionPenalty, repetition_penalty_buf_, (float*) nullptr, (int*) nullptr);
|
||||
}
|
||||
if (use_presence_penalty_)
|
||||
{
|
||||
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Presence),
|
||||
mPresencePenalty, presence_penalty_buf_, (float*) nullptr, (int*) nullptr);
|
||||
}
|
||||
if (use_frequency_penalty_)
|
||||
{
|
||||
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency),
|
||||
mFrequencyPenalty, frequency_penalty_buf_, (float*) nullptr, (int*) nullptr);
|
||||
}
|
||||
if (use_min_lengths_)
|
||||
{
|
||||
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength), mMinLengths,
|
||||
min_lengths_buf_, (int*) nullptr, (int*) nullptr);
|
||||
}
|
||||
allocateBuffer(batch_size);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params,
|
||||
int* penalty_workspace, const int* penalty_workspace_prev)
|
||||
void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardParams const& params)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s", __PRETTY_FUNCTION__);
|
||||
Tensor& output_ids_ptr = outputs.output_ids_ptr;
|
||||
@ -195,47 +148,6 @@ void BaseBeamSearchLayer<T>::forward(BeamSearchOutputParams& outputs, ForwardPar
|
||||
Tensor const& logits = params.logits;
|
||||
const auto local_batch_size = logits.shape[0];
|
||||
|
||||
#define ALL_OF(p_, sz_, dt_, v_) (std::all_of(p_, p_ + sz_, [&](dt_ b) { return b == v_; }))
|
||||
|
||||
const T* embedding_bias = params.embedding_bias ? params.embedding_bias->template getPtr<const T>() : nullptr;
|
||||
auto* temperatures = (use_temperature_
|
||||
&& !ALL_OF(std::begin(mTemperature) + ite * local_batch_size, local_batch_size, float,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Temperature)))
|
||||
? temperature_buf_ + ite * local_batch_size
|
||||
: nullptr;
|
||||
auto* repetition_penalties
|
||||
= (use_repetition_penalty_
|
||||
&& !ALL_OF(std::begin(mRepetitionPenalty) + ite * local_batch_size, local_batch_size, float,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Repetition)))
|
||||
? repetition_penalty_buf_ + ite * local_batch_size
|
||||
: nullptr;
|
||||
auto* presence_penalties = (use_presence_penalty_
|
||||
&& !ALL_OF(std::begin(mPresencePenalty) + ite * local_batch_size, local_batch_size,
|
||||
float, getDefaultPenaltyValue(RepetitionPenaltyType::Presence)))
|
||||
? presence_penalty_buf_ + ite * local_batch_size
|
||||
: nullptr;
|
||||
auto* frequency_penalties = (use_frequency_penalty_
|
||||
&& !ALL_OF(std::begin(mFrequencyPenalty) + ite * local_batch_size, local_batch_size,
|
||||
float, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency)))
|
||||
? frequency_penalty_buf_ + ite * local_batch_size
|
||||
: nullptr;
|
||||
auto* min_lengths = (use_min_lengths_
|
||||
&& !ALL_OF(std::begin(mMinLengths) + ite * local_batch_size, local_batch_size, int,
|
||||
(int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength)))
|
||||
? min_lengths_buf_ + ite * local_batch_size
|
||||
: nullptr;
|
||||
|
||||
InvokeBatchApplyPenaltyParams<T> penalty_params{logits.getPtr<T>(), embedding_bias,
|
||||
penalty_workspace + ite * local_batch_size * beam_width * vocab_size_,
|
||||
penalty_workspace_prev + ite * local_batch_size * beam_width * vocab_size_, temperatures, repetition_penalties,
|
||||
presence_penalties, frequency_penalties,
|
||||
(use_repetition_penalty_ || use_presence_penalty_ || use_frequency_penalty_), local_batch_size, beam_width,
|
||||
max_seq_len, vocab_size_, vocab_size_padded_, output_ids_ptr.template getPtr<const int*>(),
|
||||
outputs.parent_ids_ptr.template getPtr<const int*>(), input_lengths, sequence_length, min_lengths,
|
||||
params.end_ids.template getPtr<const int>(), nullptr, mStream};
|
||||
invokeBatchApplyPenalty(penalty_params);
|
||||
sync_check_cuda_error();
|
||||
|
||||
invokeSoftMax(outputs, params);
|
||||
sync_check_cuda_error();
|
||||
|
||||
|
||||
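The `ALL_OF` check in `BaseBeamSearchLayer<T>::forward` above only hands a penalty buffer to `invokeBatchApplyPenalty` when at least one request in the batch deviates from the default value; otherwise the kernel receives `nullptr` and skips that penalty. A minimal host-side sketch of that rule, with `defaultTemperature` and the `std::vector` standing in for the real device buffers:

```
#include <algorithm>
#include <cstdio>
#include <vector>

int main()
{
    // Stand-ins for temperature_buf_ and getDefaultPenaltyValue(Temperature).
    float const defaultTemperature = 1.0f;
    std::vector<float> temperature{1.0f, 1.0f, 0.7f};

    // Mirror of the ALL_OF macro: pass the buffer only if some entry is non-default.
    bool const allDefault = std::all_of(
        temperature.begin(), temperature.end(), [&](float t) { return t == defaultTemperature; });
    float const* temperaturesForKernel = allDefault ? nullptr : temperature.data();

    std::printf("apply temperature penalty: %s\n", temperaturesForKernel ? "yes" : "no");
    return 0;
}
```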
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -19,7 +19,6 @@
|
||||
#include "tensorrt_llm/common/tensor.h"
|
||||
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/kernels/penaltyTypes.h"
|
||||
#include "tensorrt_llm/layers/baseLayer.h"
|
||||
#include "tensorrt_llm/layers/decodingParams.h"
|
||||
|
||||
@ -72,8 +71,7 @@ public:
|
||||
tc::Tensor src_cache_indirection; // [local_batch_size, beam_width, max_seq_len]
|
||||
|
||||
// optional parameters
|
||||
std::optional<tc::Tensor> embedding_bias; // [vocab_size_padded]
|
||||
std::optional<tc::Tensor> input_lengths; // [local_batch_size * beam_width]
|
||||
std::optional<tc::Tensor> input_lengths; // [local_batch_size * beam_width]
|
||||
};
|
||||
|
||||
class BeamSearchOutputParams : public DecodingOutputParams
|
||||
@ -96,8 +94,7 @@ public:
|
||||
parent_ids_ptr; // [batch_size] int*, each array is [beam_width, max_seq_len], necessary in beam search
|
||||
};
|
||||
|
||||
void forward(BeamSearchOutputParams& outputs, ForwardParams const& params, int* penalty_workspace,
|
||||
const int* penalty_workspace_prev);
|
||||
void forward(BeamSearchOutputParams& outputs, ForwardParams const& params);
|
||||
|
||||
protected:
|
||||
// meta data
|
||||
@ -107,24 +104,6 @@ protected:
|
||||
size_t topk_softmax_workspace_size_;
|
||||
void* topk_softmax_workspace_ = nullptr;
|
||||
|
||||
float* temperature_buf_;
|
||||
float* repetition_penalty_buf_;
|
||||
float* presence_penalty_buf_;
|
||||
float* frequency_penalty_buf_;
|
||||
int* min_lengths_buf_;
|
||||
|
||||
std::vector<float> mTemperature;
|
||||
std::vector<float> mRepetitionPenalty;
|
||||
std::vector<float> mPresencePenalty;
|
||||
std::vector<float> mFrequencyPenalty;
|
||||
std::vector<int> mMinLengths;
|
||||
|
||||
bool use_temperature_ = false;
|
||||
bool use_repetition_penalty_ = false;
|
||||
bool use_presence_penalty_ = false;
|
||||
bool use_frequency_penalty_ = false;
|
||||
bool use_min_lengths_ = false;
|
||||
|
||||
virtual void invokeSoftMax(BeamSearchOutputParams& outputs, SoftmaxParams const& params) = 0;
|
||||
|
||||
void setupBase(size_t batch_size, SetupParams const& setupParams);
|
||||
|
||||
@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.

@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -32,64 +32,6 @@ namespace tensorrt_llm
|
||||
{
|
||||
namespace layers
|
||||
{
|
||||
template <typename T>
|
||||
void BaseSamplingLayer<T>::allocateBuffer(size_t batchSize)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
std::array<size_t, 10> deviceBufferSizes;
|
||||
deviceBufferSizes[0] = sizeof(curandState_t) * batchSize;
|
||||
deviceBufferSizes[1] = sizeof(uint64_t) * batchSize;
|
||||
deviceBufferSizes[2] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[3] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[4] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[5] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[6] = sizeof(int) * batchSize;
|
||||
deviceBufferSizes[7] = sizeof(T) * batchSize * mVocabSizePadded;
|
||||
deviceBufferSizes[8] = sizeof(bool) * batchSize;
|
||||
deviceBufferSizes[9] = sizeof(float) * batchSize;
|
||||
|
||||
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
|
||||
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
|
||||
mTemperaturesDevice = mAllocator->reMalloc(mTemperaturesDevice, deviceBufferSizes[2], false);
|
||||
mRepetitionPenaltiesDevice = mAllocator->reMalloc(mRepetitionPenaltiesDevice, deviceBufferSizes[3], false);
|
||||
mPresencePenaltiesDevice = mAllocator->reMalloc(mPresencePenaltiesDevice, deviceBufferSizes[4], false);
|
||||
mFrequencyPenaltiesDevice = mAllocator->reMalloc(mFrequencyPenaltiesDevice, deviceBufferSizes[5], false);
|
||||
mMinLengthsDevice = mAllocator->reMalloc(mMinLengthsDevice, deviceBufferSizes[6], false);
|
||||
mRuntimeLogitsDevice = mAllocator->reMalloc(mRuntimeLogitsDevice, deviceBufferSizes[7], false);
|
||||
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[8], false);
|
||||
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[9], false);
|
||||
|
||||
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
|
||||
TLLM_LOG_DEBUG("baseSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
|
||||
|
||||
// host buffers.
|
||||
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
|
||||
TLLM_CHECK(mSkipDecodeHost != nullptr);
|
||||
|
||||
mIsAllocateBuffer = true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BaseSamplingLayer<T>::freeBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
if (mIsAllocateBuffer)
|
||||
{
|
||||
mAllocator->free((void**) (&mCurandStatesDevice));
|
||||
mAllocator->free((void**) (&mRandomSeedsDevice));
|
||||
mAllocator->free((void**) (&mTemperaturesDevice));
|
||||
mAllocator->free((void**) (&mRepetitionPenaltiesDevice));
|
||||
mAllocator->free((void**) (&mPresencePenaltiesDevice));
|
||||
mAllocator->free((void**) (&mFrequencyPenaltiesDevice));
|
||||
mAllocator->free((void**) (&mMinLengthsDevice));
|
||||
mAllocator->free((void**) (&mRuntimeLogitsDevice));
|
||||
mAllocator->free((void**) (&mSkipDecodeDevice));
|
||||
mAllocator->free((void**) (&mSetupWorkspaceDevice));
|
||||
std::free(mSkipDecodeHost);
|
||||
mIsAllocateBuffer = false;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BaseSamplingLayer<T>::BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
|
||||
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop)
|
||||
@ -98,166 +40,6 @@ BaseSamplingLayer<T>::BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, s
|
||||
, mVocabSize(vocabSize)
|
||||
, mVocabSizePadded(vocabSizePadded)
|
||||
{
|
||||
allocateBuffer(maxBatchSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BaseSamplingLayer<T>::BaseSamplingLayer(BaseSamplingLayer const& samplingLayer)
|
||||
: BaseLayer(samplingLayer)
|
||||
, mMaxBatchSize(samplingLayer.mMaxBatchSize)
|
||||
, mVocabSize(samplingLayer.mVocabSize)
|
||||
, mVocabSizePadded(samplingLayer.mVocabSizePadded)
|
||||
, mSamplingWorkspaceSize(samplingLayer.mSamplingWorkspaceSize)
|
||||
{
|
||||
allocateBuffer(mMaxBatchSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BaseSamplingLayer<T>::setupBase(const size_t batchSize, int const* batchSlots, SetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
|
||||
// If runtime argument has single random seed, using this random seed to
|
||||
// initialize the random table of all sentences. If the argument has
|
||||
// [batchSize] random seeds, initializing the random table by different
|
||||
// random seeds respectively. If no random seed, initialize the random table
|
||||
// of all sentences by 0 directly.
|
||||
if (setupParams.randomSeed)
|
||||
{
|
||||
if (setupParams.randomSeed->size() == 1)
|
||||
{
|
||||
invokeCurandInitialize(
|
||||
mCurandStatesDevice, batchSlots, batchSize, setupParams.randomSeed->front(), mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(setupParams.randomSeed->size() == batchSize, "Random seed vector size mismatch.");
|
||||
cudaAutoCpy(mRandomSeedsDevice, setupParams.randomSeed->data(), batchSize, mStream);
|
||||
invokeCurandBatchInitialize(mCurandStatesDevice, batchSlots, batchSize, mRandomSeedsDevice, mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Initialize curand states using the default seed 0.
|
||||
invokeCurandInitialize(mCurandStatesDevice, batchSlots, batchSize, 0, mStream);
|
||||
}
|
||||
|
||||
// Setup penalties.
|
||||
FillBuffers const fillBuffers{batchSize, mStream};
|
||||
|
||||
mUseTemperature = static_cast<bool>(setupParams.temperature);
|
||||
mUseRepetitionPenalty = static_cast<bool>(setupParams.repetition_penalty);
|
||||
mUsePresencePenalty = static_cast<bool>(setupParams.presence_penalty);
|
||||
mUseFrequencyPenalty = static_cast<bool>(setupParams.frequency_penalty);
|
||||
mUseMinLengths = static_cast<bool>(setupParams.min_length);
|
||||
if (mUseTemperature)
|
||||
{
|
||||
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(RepetitionPenaltyType::Temperature), mTemperature,
|
||||
mTemperaturesDevice, mSetupWorkspaceDevice, batchSlots);
|
||||
}
|
||||
if (mUseRepetitionPenalty)
|
||||
{
|
||||
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Repetition),
|
||||
mRepetitionPenalty, mRepetitionPenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
|
||||
}
|
||||
if (mUsePresencePenalty)
|
||||
{
|
||||
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Presence),
|
||||
mPresencePenalty, mPresencePenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
|
||||
}
|
||||
if (mUseFrequencyPenalty)
|
||||
{
|
||||
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(RepetitionPenaltyType::Frequency),
|
||||
mFrequencyPenalty, mFrequencyPenaltiesDevice, mSetupWorkspaceDevice, batchSlots);
|
||||
}
|
||||
if (mUseMinLengths)
|
||||
{
|
||||
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength), mMinLengths,
|
||||
mMinLengthsDevice, mSetupWorkspaceDevice, batchSlots);
|
||||
}
|
||||
}
|
||||
|
||||
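The seed handling at the top of `setupBase` above follows the policy stated in its comment: a single seed is broadcast to every sentence, a per-request vector seeds each sentence individually, and no seed falls back to 0. A small host-only sketch of that dispatch, with hypothetical `initAll` / `initPerRequest` helpers in place of `invokeCurandInitialize` / `invokeCurandBatchInitialize`:

```
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <optional>
#include <vector>

// Hypothetical stand-ins for the curand initialization kernels.
void initAll(size_t batchSize, uint64_t seed)
{
    std::printf("initialize %zu curand states with seed %llu\n", batchSize, static_cast<unsigned long long>(seed));
}

void initPerRequest(std::vector<uint64_t> const& seeds)
{
    std::printf("initialize %zu curand states with per-request seeds\n", seeds.size());
}

void setupSeeds(size_t batchSize, std::optional<std::vector<uint64_t>> const& randomSeed)
{
    if (!randomSeed)
    {
        initAll(batchSize, 0); // no seed given: default to 0
    }
    else if (randomSeed->size() == 1)
    {
        initAll(batchSize, randomSeed->front()); // single seed: broadcast to all sentences
    }
    else
    {
        // The real code additionally checks randomSeed->size() == batchSize.
        initPerRequest(*randomSeed);
    }
}

int main()
{
    setupSeeds(4, std::nullopt);
    setupSeeds(4, std::vector<uint64_t>{42});
    setupSeeds(4, std::vector<uint64_t>{1, 2, 3, 4});
    return 0;
}
```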
template <typename T>
|
||||
void BaseSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams const& inputs, int* penaltyWorkspace)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
auto const step = inputs.step;
|
||||
auto* const inputLengths = inputs.input_lengths ? inputs.input_lengths->template getPtr<const int>() : nullptr;
|
||||
|
||||
auto* logits = inputs.logits.template getPtr<T>();
|
||||
TLLM_CHECK_WITH_INFO((inputs.batch_slots_host.has_value() ^ inputs.batch_slots.has_value()) == 0,
|
||||
"either both batch_slots_host and batch_slots have to be provided or neither of them");
|
||||
auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
|
||||
std::vector<int32_t> batchSlotsVec(batchSize);
|
||||
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
|
||||
auto* batchSlotsHost
|
||||
= inputs.batch_slots_host ? inputs.batch_slots_host->template getPtr<const int>() : batchSlotsVec.data();
|
||||
|
||||
#define ALL_OF(addrs_, p_, sz_, v_) (std::all_of(addrs_, addrs_ + sz_, [&](int32_t b) { return p_[b] == v_; }))
|
||||
|
||||
if (ALL_OF(batchSlotsHost, mSkipDecodeHost, batchSize, true))
|
||||
{
|
||||
// No sample in the current batch to do TopX sampling.
|
||||
return;
|
||||
}
|
||||
mSkipAny = std::any_of(batchSlotsHost, batchSlotsHost + batchSize,
|
||||
[this](int32_t batchSlot) { return this->mSkipDecodeHost[batchSlot]; });
|
||||
if (mSkipAny)
|
||||
{
|
||||
// A TopX Sampling layer directly changes the logit values. In case of
|
||||
// skip_any==true, meaning topk and topp layers will run simultaneously for
|
||||
// a batch in the same step. We copy the logits to an internal buffer, not
|
||||
// affecting the other sampling layers.
|
||||
TLLM_CHECK(inputs.logits.size() == batchSize * mVocabSizePadded);
|
||||
cudaD2Dcpy(mRuntimeLogitsDevice, logits, inputs.logits.size(), mStream);
|
||||
logits = mRuntimeLogitsDevice;
|
||||
}
|
||||
|
||||
auto* embeddingBias = inputs.embedding_bias ? inputs.embedding_bias->template getPtr<T const>() : nullptr;
|
||||
auto* temperatures = (mUseTemperature
|
||||
&& !ALL_OF(batchSlotsHost, mTemperature, batchSize,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Temperature)))
|
||||
? mTemperaturesDevice
|
||||
: nullptr;
|
||||
auto* repetitionPenalties = (mUseRepetitionPenalty
|
||||
&& !ALL_OF(batchSlotsHost, mRepetitionPenalty, batchSize,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Repetition)))
|
||||
? mRepetitionPenaltiesDevice
|
||||
: nullptr;
|
||||
auto* presencePenalties = (mUsePresencePenalty
|
||||
&& !ALL_OF(batchSlotsHost, mPresencePenalty, batchSize,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Presence)))
|
||||
? mPresencePenaltiesDevice
|
||||
: nullptr;
|
||||
auto* frequencyPenalties = (mUseFrequencyPenalty
|
||||
&& !ALL_OF(batchSlotsHost, mFrequencyPenalty, batchSize,
|
||||
getDefaultPenaltyValue(RepetitionPenaltyType::Frequency)))
|
||||
? mFrequencyPenaltiesDevice
|
||||
: nullptr;
|
||||
auto* minLengths = (mUseMinLengths
|
||||
&& !ALL_OF(batchSlotsHost, mMinLengths, batchSize,
|
||||
(int) getDefaultPenaltyValue(RepetitionPenaltyType::MinLength)))
|
||||
? mMinLengthsDevice
|
||||
: nullptr;
|
||||
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{logits, embeddingBias, penaltyWorkspace, nullptr, temperatures,
|
||||
repetitionPenalties, presencePenalties, frequencyPenalties,
|
||||
(mUseRepetitionPenalty || mUsePresencePenalty || mUseFrequencyPenalty), batchSize, 1, inputs.max_seq_len,
|
||||
mVocabSize, mVocabSizePadded, outputs.output_ids_ptr.template getPtr<const int*>(), nullptr, inputLengths,
|
||||
outputs.sequence_length->getPtr<const int>(), minLengths, inputs.end_ids.template getPtr<const int>(),
|
||||
batchSlots, mStream};
|
||||
invokeBatchApplyPenalty(penaltyParams);
|
||||
sync_check_cuda_error();
|
||||
#undef ALL_OF
|
||||
|
||||
runSampling(outputs, inputs);
|
||||
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template class BaseSamplingLayer<float>;
|
||||
|
||||
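`BaseSamplingLayer<T>::forward` above returns early when every request skips this layer and copies the logits into `mRuntimeLogitsDevice` when only some do, so its in-place penalty and sampling work cannot disturb requests owned by the other sampling layer running in the same step. A plain host-side sketch of that decision, with `std::vector` replacing the device buffers and illustrative names:

```
#include <algorithm>
#include <vector>

// Returns the buffer this layer may modify in place, or nullptr if it should
// return early because every request skips it.
std::vector<float>* selectLogits(
    std::vector<float>& logits, std::vector<float>& scratch, std::vector<bool> const& skipDecode)
{
    bool const skipAll = std::all_of(skipDecode.begin(), skipDecode.end(), [](bool s) { return s; });
    if (skipAll)
    {
        return nullptr; // nothing in this batch for this layer
    }
    bool const skipAny = std::any_of(skipDecode.begin(), skipDecode.end(), [](bool s) { return s; });
    if (skipAny)
    {
        scratch = logits; // work on a private copy so the other layer sees untouched logits
        return &scratch;
    }
    return &logits; // sole owner of this batch: safe to modify in place
}

int main()
{
    std::vector<float> logits{0.1f, 0.2f, 0.3f};
    std::vector<float> scratch;
    auto* target = selectLogits(logits, scratch, std::vector<bool>{false, true, false});
    (void) target;
    return 0;
}
```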
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -32,8 +32,7 @@ namespace layers
|
||||
{
|
||||
|
||||
//! \brief Base class for sampling layers.
|
||||
//! Layer modifies logits in-place. However, when any of the requests skips the sampling layer,
|
||||
//! logits are copied and modified in temporary buffer.
|
||||
//! Layer modifies logits in-place.
|
||||
template <typename T>
|
||||
class BaseSamplingLayer : public BaseLayer
|
||||
{
|
||||
@ -63,8 +62,12 @@ public:
|
||||
int max_seq_len;
|
||||
|
||||
// optional parameters
|
||||
std::optional<tc::Tensor> embedding_bias; // [vocabSizePadded]
|
||||
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
|
||||
std::optional<tc::Tensor> input_lengths; // [localBatchSize]
|
||||
curandState_t* curand_states; // [localBatchSize]
|
||||
// Pointer to the workspace for sampling computation
|
||||
void* sampling_workspace;
|
||||
// Flag to mark that logits tensor contains probabilities
|
||||
bool probs_computed;
|
||||
};
|
||||
|
||||
// clang-format off
|
||||
@ -80,8 +83,6 @@ public:
|
||||
BaseSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream,
|
||||
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop);
|
||||
|
||||
BaseSamplingLayer(BaseSamplingLayer const& samplingLayer);
|
||||
|
||||
~BaseSamplingLayer() override = default;
|
||||
|
||||
// clang-format off
|
||||
@ -92,9 +93,9 @@ public:
|
||||
//!
|
||||
//! \param outputs DecodingOutputParams struct with output tensors
|
||||
//! \param inputs ForwardParams struct with input tensors and params
|
||||
//! \param penaltyWorkspace
|
||||
//! \param curandStatesDevice Properly initialized curand states buffer on device
|
||||
// clang-format on
|
||||
void forward(DecodingOutputParams& outputs, ForwardParams const& inputs, int* penaltyWorkspace);
|
||||
virtual void forward(DecodingOutputParams& outputs, ForwardParams& inputs) = 0;
|
||||
|
||||
// clang-format off
|
||||
//! \brief Virtual function that setups internal tensors of the layer with sampling params
|
||||
@ -103,60 +104,28 @@ public:
|
||||
//! Thus, it must be called only once for new requests.
|
||||
//!
|
||||
//! \param batchSize Maximum batch size configured in the system
|
||||
//! \param batchSlots input tensor [batchSize], address map of the new requests
|
||||
//! \param batchSlots input tensor [batchSize], address map of the new requests, in pinned memory
|
||||
//! \param setupParams setup sampling parameters per request
|
||||
// clang-format on
|
||||
virtual void setup(size_t batchSize, int const* batchSlots, SetupParams const& setupParams) = 0;
|
||||
virtual void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) = 0;
|
||||
|
||||
protected:
|
||||
//! \brief setup of the base class, has to be called in the beginning of the derived's class setup
|
||||
void setupBase(size_t batchSize, int const* batchSlots, SetupParams const& setupParams);
|
||||
size_t getWorkspaceSize() const
|
||||
{
|
||||
return mSamplingWorkspaceSize;
|
||||
}
|
||||
|
||||
// clang-format off
|
||||
//! \brief Executes sampling logic of the derived class
|
||||
//!
|
||||
//! \param outputs DecodingOutputParams struct with output tensors
|
||||
//! \param inputs ForwardParams struct with input tensors and params
|
||||
// clang-format on
|
||||
virtual void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) = 0;
|
||||
|
||||
virtual void freeBuffer();
|
||||
size_t getAllocatedSize() const
|
||||
{
|
||||
return mAllocatedSize;
|
||||
}
|
||||
|
||||
protected:
|
||||
size_t mMaxBatchSize;
|
||||
size_t mVocabSize;
|
||||
size_t mVocabSizePadded;
|
||||
|
||||
size_t mSamplingWorkspaceSize;
|
||||
void* mSamplingWorkspaceDevice = nullptr;
|
||||
curandState_t* mCurandStatesDevice = nullptr;
|
||||
uint64_t* mRandomSeedsDevice = nullptr;
|
||||
|
||||
float* mTemperaturesDevice = nullptr;
|
||||
float* mRepetitionPenaltiesDevice = nullptr;
|
||||
float* mPresencePenaltiesDevice = nullptr;
|
||||
float* mFrequencyPenaltiesDevice = nullptr;
|
||||
int32_t* mMinLengthsDevice = nullptr;
|
||||
bool* mSkipDecodeDevice = nullptr;
|
||||
T* mRuntimeLogitsDevice = nullptr;
|
||||
void* mSetupWorkspaceDevice = nullptr;
|
||||
|
||||
std::vector<float> mTemperature;
|
||||
std::vector<float> mRepetitionPenalty;
|
||||
std::vector<float> mPresencePenalty;
|
||||
std::vector<float> mFrequencyPenalty;
|
||||
std::vector<int32_t> mMinLengths;
|
||||
bool* mSkipDecodeHost = nullptr;
|
||||
bool mSkipAny = false;
|
||||
|
||||
bool mUseTemperature = false;
|
||||
bool mUseRepetitionPenalty = false;
|
||||
bool mUsePresencePenalty = false;
|
||||
bool mUseFrequencyPenalty = false;
|
||||
bool mUseMinLengths = false;
|
||||
|
||||
private:
|
||||
void allocateBuffer(size_t batchSize);
|
||||
size_t mSamplingWorkspaceSize = 0;
|
||||
size_t mAllocatedSize = 0;
|
||||
};
|
||||
|
||||
} // namespace layers
|
||||
|
||||
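Several of the setup paths above take a `batchSlots` tensor that maps densely packed new requests onto their persistent slots in the max-batch-size buffers (the header below now notes that it lives in pinned memory). A host-only sketch of that address-map idea, with illustrative names:

```
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    size_t const maxBatchSize = 8;
    std::vector<float> temperatureBuffer(maxBatchSize, 1.0f); // persistent, one entry per slot

    std::vector<int32_t> batchSlots{2, 5};          // two new requests land in slots 2 and 5
    std::vector<float> newTemperatures{0.7f, 1.3f}; // their per-request setup values

    // Scatter the densely packed setup values into the persistent slots.
    for (size_t i = 0; i < batchSlots.size(); ++i)
    {
        temperatureBuffer[batchSlots[i]] = newTemperatures[i];
    }

    for (size_t slot = 0; slot < maxBatchSize; ++slot)
    {
        std::printf("slot %zu: temperature %.1f\n", slot, temperatureBuffer[slot]);
    }
    return 0;
}
```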
@ -50,11 +50,10 @@ public:
// mandatory parameters
int step;
int ite;
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
std::optional<tc::Tensor> batch_slots; // [local_batch_size]
std::optional<tc::Tensor> batch_slots_host; // [local_batch_size]
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
tc::Tensor logits; // [local_batch_size, beam_width, vocab_size_padded]
tc::Tensor end_ids; // [local_batch_size]
std::optional<tc::Tensor> batch_slots; // [local_batch_size], on pinned memory
std::optional<tc::Tensor> finished; // [batch_size * beam_width]
};
|
||||
class DecodingOutputParams
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -15,14 +15,15 @@
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/layers/dynamicDecodeLayer.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/banBadWords.h"
|
||||
#include "tensorrt_llm/kernels/banRepeatNgram.h"
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/kernels/penaltyKernels.h"
|
||||
#include "tensorrt_llm/kernels/stopCriteriaKernels.h"
|
||||
#include "tensorrt_llm/layers/baseBeamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/fillBuffers.h"
|
||||
#include "tensorrt_llm/layers/onlineBeamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/topKSamplingLayer.h"
|
||||
#include "tensorrt_llm/layers/topPSamplingLayer.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
|
||||
@ -35,54 +36,6 @@ namespace tensorrt_llm
|
||||
namespace layers
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::initialize()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mOnlineBeamsearchDecode
|
||||
= std::make_unique<OnlineBeamSearchLayer<T>>(vocab_size_, vocab_size_padded_, mStream, mAllocator);
|
||||
|
||||
mTopKDecode
|
||||
= std::make_unique<TopKSamplingLayer<T>>(max_batch_size_, vocab_size_, vocab_size_padded_, mStream, mAllocator);
|
||||
|
||||
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(max_batch_size_, vocab_size_, vocab_size_padded_, mStream,
|
||||
mAllocator, cuda_device_prop_, /* deterministic */ true);
|
||||
|
||||
mIdsPtrHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<int*>::value);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::DynamicDecodeLayer(size_t max_batch_size, size_t vocab_size, size_t vocab_size_padded,
|
||||
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* cuda_device_prop)
|
||||
: BaseLayer(stream, std::move(allocator))
|
||||
, max_batch_size_(max_batch_size)
|
||||
, vocab_size_(vocab_size)
|
||||
, vocab_size_padded_(vocab_size_padded)
|
||||
, cuda_device_prop_(cuda_device_prop)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
initialize();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
freeBuffer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer)
|
||||
: BaseLayer(dynamic_decode_layer)
|
||||
, max_batch_size_(dynamic_decode_layer.max_batch_size_)
|
||||
, vocab_size_(dynamic_decode_layer.vocab_size_)
|
||||
, vocab_size_padded_(dynamic_decode_layer.vocab_size_padded_)
|
||||
, cuda_device_prop_(dynamic_decode_layer.cuda_device_prop_)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
initialize();
|
||||
}
|
||||
|
||||
namespace
|
||||
{
|
||||
template <typename T>
|
||||
@ -118,20 +71,166 @@ bool hasDiffRuntimeArgs(DecodingSetupParams const& params)
|
||||
} // namespace
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::setup(
|
||||
size_t batch_size, size_t beam_width, int const* batch_slots, SetupParams const& setupParams)
|
||||
DynamicDecodeLayer<T>::DynamicDecodeLayer(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth,
|
||||
size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream, std::shared_ptr<IAllocator> allocator,
|
||||
cudaDeviceProp* cudaDeviceProp)
|
||||
: BaseLayer(stream, std::move(allocator))
|
||||
, mDecodingMode(mode)
|
||||
, mMaxBatchSize(maxBatchSize)
|
||||
, mMaxBeamWidth(maxBeamWidth)
|
||||
, mVocabSize(vocabSize)
|
||||
, mVocabSizePadded(vocabSizePadded)
|
||||
, mCudaDeviceProp(cudaDeviceProp)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
initialize();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::DynamicDecodeLayer(DynamicDecodeLayer const& dynamicDecodeLayer)
|
||||
: BaseLayer(dynamicDecodeLayer)
|
||||
, mDecodingMode(dynamicDecodeLayer.mDecodingMode)
|
||||
, mMaxBatchSize(dynamicDecodeLayer.mMaxBatchSize)
|
||||
, mMaxBeamWidth(dynamicDecodeLayer.mMaxBeamWidth)
|
||||
, mVocabSize(dynamicDecodeLayer.mVocabSize)
|
||||
, mVocabSizePadded(dynamicDecodeLayer.mVocabSizePadded)
|
||||
, mCudaDeviceProp(dynamicDecodeLayer.mCudaDeviceProp)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
initialize();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
DynamicDecodeLayer<T>::~DynamicDecodeLayer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
freeBuffer();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::initialize()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
|
||||
if (beam_width == 1)
|
||||
{ // sampling layers
|
||||
typename TopPSamplingLayer<T>::SetupParams samplingParams;
|
||||
mIdsPtrHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<int*>::value);
|
||||
mLogitsPtrsHost = runtime::BufferManager::pinned(ITensor::makeShape({}), runtime::TRTDataType<T*>::value);
|
||||
|
||||
samplingParams.temperature = setupParams.temperature;
|
||||
samplingParams.min_length = setupParams.min_length;
|
||||
samplingParams.repetition_penalty = setupParams.repetition_penalty;
|
||||
samplingParams.presence_penalty = setupParams.presence_penalty;
|
||||
samplingParams.frequency_penalty = setupParams.frequency_penalty;
|
||||
allocateBuffer();
|
||||
|
||||
mCyclicStep = 0;
|
||||
mRuntimeMaxSeqLen = 0;
|
||||
mConfiguredBeamWidth = -1;
|
||||
|
||||
mTemperature.resize(mMaxBatchSize);
|
||||
mRepetitionPenalty.resize(mMaxBatchSize);
|
||||
mPresencePenalty.resize(mMaxBatchSize);
|
||||
mFrequencyPenalty.resize(mMaxBatchSize);
|
||||
mMinLength.resize(mMaxBatchSize);
|
||||
|
||||
if (!mDecodingMode.isNone())
|
||||
{
|
||||
mConfiguredBeamWidth = mMaxBeamWidth;
|
||||
initializeLayers();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::allocateBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mZeroParentIdsDevice = mAllocator->reMalloc(mZeroParentIdsDevice, sizeof(int*) * 2 * mMaxBatchSize, false);
|
||||
mTemperatureDevice = mAllocator->reMalloc(mTemperatureDevice, sizeof(float) * mMaxBatchSize, false);
|
||||
mRepetitionPenaltyDevice = mAllocator->reMalloc(mRepetitionPenaltyDevice, sizeof(float) * mMaxBatchSize, false);
|
||||
mPresencePenaltyDevice = mAllocator->reMalloc(mPresencePenaltyDevice, sizeof(float) * mMaxBatchSize, false);
|
||||
mFrequencyPenaltyDevice = mAllocator->reMalloc(mFrequencyPenaltyDevice, sizeof(float) * mMaxBatchSize, false);
|
||||
mMinLengthDevice = mAllocator->reMalloc(mMinLengthDevice, sizeof(int32_t) * mMaxBatchSize, false);
|
||||
mRuntimeLogitsDevice = mAllocator->reMalloc(
|
||||
mRuntimeLogitsDevice, sizeof(T) * mMaxBatchSize * mMaxBeamWidth * mVocabSizePadded, false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::freeBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mAllocator->free((void**) &mZeroParentIdsDevice);
|
||||
if (mPenaltyWorkspaceDevice != nullptr)
|
||||
{
|
||||
mAllocator->free((void**) &mPenaltyWorkspaceDevice);
|
||||
}
|
||||
if (mPenaltyWorkspacePrevDevice != nullptr)
|
||||
{
|
||||
mAllocator->free((void**) &mPenaltyWorkspacePrevDevice);
|
||||
}
|
||||
mAllocator->free((void**) (&mTemperatureDevice));
|
||||
mAllocator->free((void**) (&mRepetitionPenaltyDevice));
|
||||
mAllocator->free((void**) (&mPresencePenaltyDevice));
|
||||
mAllocator->free((void**) (&mFrequencyPenaltyDevice));
|
||||
mAllocator->free((void**) (&mMinLengthDevice));
|
||||
mAllocator->free((void**) (&mRuntimeLogitsDevice));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::initializeLayers()
|
||||
{
|
||||
const size_t workspaceSize = sizeof(int) * mMaxBatchSize * mConfiguredBeamWidth * mVocabSize;
|
||||
mPenaltyWorkspaceDevice = mAllocator->reMalloc(mPenaltyWorkspaceDevice, workspaceSize, false);
|
||||
|
||||
if (mDecodingMode.isTopKorTopP())
|
||||
{
|
||||
mSamplingLayer = std::make_unique<SamplingLayer<T>>(
|
||||
mDecodingMode, mMaxBatchSize, mVocabSize, mVocabSizePadded, mStream, mAllocator, mCudaDeviceProp);
|
||||
}
|
||||
else if (mDecodingMode.isBeamSearch())
|
||||
{
|
||||
mOnlineBeamSearchDecode
|
||||
= std::make_unique<OnlineBeamSearchLayer<T>>(mVocabSize, mVocabSizePadded, mStream, mAllocator);
|
||||
mPenaltyWorkspacePrevDevice = mAllocator->reMalloc(mPenaltyWorkspacePrevDevice, workspaceSize, false);
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(false, "Decoding mode is none of the supported {TopK, TopP, TopKTopP, BeamSearch}");
|
||||
}
|
||||
}
|
||||
|
||||
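`initializeLayers` above instantiates either the fused `SamplingLayer` or the `OnlineBeamSearchLayer` from `mDecodingMode`, and `setup` below derives that mode from the runtime beam width when none was configured at construction. A small sketch of the selection rule; `DecodingKind` is an illustrative stand-in for the real `DecodingMode`:

```
#include <cstddef>
#include <cstdio>

// beamWidth == 1 selects the combined TopK/TopP sampling path,
// anything larger selects beam search.
enum class DecodingKind { TopKTopP, BeamSearch };

DecodingKind pickMode(size_t beamWidth)
{
    return beamWidth == 1 ? DecodingKind::TopKTopP : DecodingKind::BeamSearch;
}

int main()
{
    for (size_t beamWidth : {1, 4})
    {
        std::printf("beamWidth %zu -> %s\n", beamWidth,
            pickMode(beamWidth) == DecodingKind::TopKTopP ? "TopKTopP sampling" : "beam search");
    }
    return 0;
}
```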
template <typename T>
|
||||
void DynamicDecodeLayer<T>::setup(
|
||||
size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
if (mConfiguredBeamWidth == -1)
|
||||
{
|
||||
// This code is left only for Python runtime
|
||||
// In C++ runtime given maxBeamWidth should always be equal to the runtime beamWidth
|
||||
TLLM_CHECK(mDecodingMode.isNone());
|
||||
mConfiguredBeamWidth = beamWidth;
|
||||
mDecodingMode = mConfiguredBeamWidth == 1 ? DecodingMode::TopKTopP() : DecodingMode::BeamSearch();
|
||||
initializeLayers();
|
||||
}
|
||||
|
||||
TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1)
|
||||
|| (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth),
|
||||
"Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth);
|
||||
TLLM_CHECK_WITH_INFO(mConfiguredBeamWidth <= mMaxBeamWidth,
|
||||
"Decoder is created with max beam width %lu, but %d was given", mMaxBeamWidth, mConfiguredBeamWidth);
|
||||
|
||||
setupLayers(batchSize, beamWidth, batchSlots, setupParams);
|
||||
|
||||
setupPenalties(batchSize, batchSlots, setupParams);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::setupLayers(
|
||||
size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
if (beamWidth == 1)
|
||||
{ // sampling layers
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP");
|
||||
typename TopPSamplingLayer<T>::SetupParams samplingParams;
|
||||
|
||||
samplingParams.runtime_top_k = setupParams.runtime_top_k;
|
||||
samplingParams.runtime_top_p = setupParams.runtime_top_p;
|
||||
@ -142,166 +241,162 @@ void DynamicDecodeLayer<T>::setup(
|
||||
samplingParams.top_p_reset_ids = setupParams.top_p_reset_ids;
|
||||
samplingParams.normalize_log_probs = setupParams.normalize_log_probs;
|
||||
|
||||
mTopKDecode->setup(batch_size, batch_slots, samplingParams);
|
||||
mTopPDecode->setup(batch_size, batch_slots, samplingParams);
|
||||
mSamplingLayer->setup(batchSize, batchSlots, samplingParams);
|
||||
}
|
||||
else
|
||||
{ // beam search layer
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch");
|
||||
typename OnlineBeamSearchLayer<T>::SetupParams beamSearchParams;
|
||||
|
||||
beamSearchParams.temperature = setupParams.temperature;
|
||||
beamSearchParams.min_length = setupParams.min_length;
|
||||
beamSearchParams.repetition_penalty = setupParams.repetition_penalty;
|
||||
beamSearchParams.presence_penalty = setupParams.presence_penalty;
|
||||
beamSearchParams.frequency_penalty = setupParams.frequency_penalty;
|
||||
|
||||
beamSearchParams.beam_search_diversity_rate = setupParams.beam_search_diversity_rate;
|
||||
beamSearchParams.length_penalty = setupParams.length_penalty;
|
||||
|
||||
has_diff_runtime_args_ = hasDiffRuntimeArgs(beamSearchParams);
|
||||
mOnlineBeamsearchDecode->setup(batch_size, beamSearchParams);
|
||||
mHasDiffRuntimeArgs = hasDiffRuntimeArgs(beamSearchParams);
|
||||
mOnlineBeamSearchDecode->setup(batchSize, beamSearchParams);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len)
|
||||
void DynamicDecodeLayer<T>::setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mIdsPtrHost->resize(2 * batch_size);
|
||||
zero_parent_ids = mAllocator->reMalloc(zero_parent_ids, sizeof(int*) * 2 * batch_size, false);
|
||||
const size_t workspace_size = sizeof(int) * batch_size * beam_width * vocab_size_;
|
||||
if (beam_width == 1)
|
||||
{ // sampling layers
|
||||
top_k_workspace = mAllocator->reMalloc(top_k_workspace, workspace_size, false);
|
||||
top_p_workspace = mAllocator->reMalloc(top_p_workspace, workspace_size, false);
|
||||
}
|
||||
else
|
||||
{ // beam search layer
|
||||
beam_search_workspace_0 = mAllocator->reMalloc(beam_search_workspace_0, workspace_size, false);
|
||||
beam_search_workspace_1 = mAllocator->reMalloc(beam_search_workspace_1, workspace_size, false);
|
||||
}
|
||||
mCyclicStep = 0;
|
||||
}
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
std::vector<int32_t> batchSlotsVec(batchSize);
|
||||
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
|
||||
auto batchSlotsHost = batchSlots ? batchSlots : batchSlotsVec.data();
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::freeBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mAllocator->free((void**) &zero_parent_ids);
|
||||
if (top_k_workspace != nullptr)
|
||||
// Setup penalties.
|
||||
FillBuffers const fillBuffers{batchSize, mMaxBatchSize, mStream};
|
||||
|
||||
mUseTemperature = static_cast<bool>(setupParams.temperature);
|
||||
mUseRepetitionPenalty = static_cast<bool>(setupParams.repetition_penalty);
|
||||
mUsePresencePenalty = static_cast<bool>(setupParams.presence_penalty);
|
||||
mUseFrequencyPenalty = static_cast<bool>(setupParams.frequency_penalty);
|
||||
mUseMinLength = static_cast<bool>(setupParams.min_length);
|
||||
if (mUseTemperature)
|
||||
{
|
||||
mAllocator->free((void**) &top_k_workspace);
|
||||
fillBuffers(setupParams.temperature, getDefaultPenaltyValue(DecodingPenaltyType::Temperature), mTemperature,
|
||||
mTemperatureDevice, batchSlotsHost);
|
||||
}
|
||||
if (top_p_workspace != nullptr)
|
||||
if (mUseRepetitionPenalty)
|
||||
{
|
||||
mAllocator->free((void**) &top_p_workspace);
|
||||
fillBuffers(setupParams.repetition_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Repetition),
|
||||
mRepetitionPenalty, mRepetitionPenaltyDevice, batchSlotsHost);
|
||||
}
|
||||
if (beam_search_workspace_0 != nullptr)
|
||||
if (mUsePresencePenalty)
|
||||
{
|
||||
mAllocator->free((void**) &beam_search_workspace_0);
|
||||
fillBuffers(setupParams.presence_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Presence),
|
||||
mPresencePenalty, mPresencePenaltyDevice, batchSlotsHost);
|
||||
}
|
||||
if (beam_search_workspace_1 != nullptr)
|
||||
if (mUseFrequencyPenalty)
|
||||
{
|
||||
mAllocator->free((void**) &beam_search_workspace_1);
|
||||
fillBuffers(setupParams.frequency_penalty, getDefaultPenaltyValue(DecodingPenaltyType::Frequency),
|
||||
mFrequencyPenalty, mFrequencyPenaltyDevice, batchSlotsHost);
|
||||
}
|
||||
if (mUseMinLength)
|
||||
{
|
||||
fillBuffers(setupParams.min_length, (int) getDefaultPenaltyValue(DecodingPenaltyType::MinLength), mMinLength,
|
||||
mMinLengthDevice, batchSlotsHost);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const& params)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
const auto ite = params.ite;
|
||||
const auto step = params.step;
|
||||
const auto max_step = step;
|
||||
auto const& logits = params.logits;
|
||||
TLLM_CHECK(logits.shape.size() == 3);
|
||||
|
||||
auto const batch_size = logits.shape[0];
|
||||
auto const beam_width = logits.shape[1];
|
||||
auto const local_batch_size = static_cast<std::size_t>(params.local_batch_size);
|
||||
|
||||
auto const max_seq_len = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
|
||||
TLLM_CHECK_WITH_INFO(params.logits || params.logits_vec, "Either logits or logits_vec have to be specified.");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
outputs.sequence_length.has_value(), "sequence_length tensor is mandatory in DynamicDecoderLayer.");
|
||||
allocateBuffer(batch_size, beam_width, max_seq_len);
|
||||
|
||||
// std::vector<int*> ids_ptr_host;
|
||||
auto idsPtrHost = runtime::bufferCast<int*>(*mIdsPtrHost);
|
||||
for (int i = 0; i < batch_size; i++)
|
||||
size_t batchSize = 0;
|
||||
size_t beamWidth = 0;
|
||||
size_t vocabSize = 0;
|
||||
auto const maxSeqLen = outputs.output_ids.shape[outputs.output_ids.shape.size() - 1];
|
||||
if (params.logits)
|
||||
{
|
||||
idsPtrHost[i] = outputs.output_ids.template getPtrWithOffset<int32_t>(i * beam_width * max_seq_len);
|
||||
auto const& logitsShape = params.logits->shape;
|
||||
TLLM_CHECK(logitsShape.size() == 3);
|
||||
batchSize = logitsShape[0];
|
||||
beamWidth = logitsShape[1];
|
||||
vocabSize = logitsShape[2];
|
||||
}
|
||||
for (int i = 0; i < batch_size; i++)
|
||||
else
|
||||
{
|
||||
if (beam_width > 1)
|
||||
{
|
||||
idsPtrHost[batch_size + i]
|
||||
= outputs.parent_ids.value().template getPtrWithOffset<int32_t>(i * beam_width * max_seq_len);
|
||||
}
|
||||
else
|
||||
{
|
||||
idsPtrHost[batch_size + i] = zero_parent_ids + i * beam_width * max_seq_len;
|
||||
}
|
||||
TLLM_CHECK(params.logits_vec->size());
|
||||
auto const& logitsShape = params.logits_vec.value()[0].shape;
|
||||
TLLM_CHECK(logitsShape.size() == 3);
|
||||
batchSize = params.logits_vec->size();
|
||||
beamWidth = logitsShape[1];
|
||||
vocabSize = logitsShape[2];
|
||||
}
|
||||
|
||||
outputs.output_ids_ptr
|
||||
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {batch_size, beam_width, max_seq_len}, idsPtrHost);
|
||||
outputs.parent_ids_ptr
|
||||
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {batch_size, beam_width, max_seq_len}, idsPtrHost + batch_size);
|
||||
TLLM_CHECK_WITH_INFO((mConfiguredBeamWidth == 1 && beamWidth == 1)
|
||||
|| (mConfiguredBeamWidth > 1 && beamWidth > 1 && beamWidth <= mConfiguredBeamWidth),
|
||||
"Decoder is configured with beam width %d, but %lu was given", mConfiguredBeamWidth, beamWidth);
|
||||
|
||||
if (params.no_repeat_ngram_size)
|
||||
if (!mLogitsPtrsHost->data())
|
||||
{
|
||||
const int* no_repeat_ngram_size_buf = params.no_repeat_ngram_size.value().template getPtr<const int>();
|
||||
|
||||
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<const int*>(),
|
||||
reinterpret_cast<FinishedState*>(
|
||||
params.finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
|
||||
outputs.parent_ids_ptr.template getPtr<const int*>(), nullptr,
|
||||
outputs.sequence_length->template getPtr<int>(), batch_size, local_batch_size, beam_width, max_seq_len,
|
||||
no_repeat_ngram_size_buf, vocab_size_padded_, max_step, mStream);
|
||||
mLogitsPtrsHost = runtime::BufferManager::pinnedPool(
|
||||
ITensor::makeShape({static_cast<int32_t>(maxSeqLen), static_cast<int32_t>(mMaxBatchSize)}),
|
||||
runtime::TRTDataType<T*>::value);
|
||||
mIdsPtrHost = runtime::BufferManager::pinnedPool(
|
||||
ITensor::makeShape({static_cast<int32_t>(maxSeqLen), static_cast<int32_t>(2 * mMaxBatchSize)}),
|
||||
runtime::TRTDataType<int32_t*>::value);
|
||||
mRuntimeMaxSeqLen = maxSeqLen;
|
||||
}
|
||||
|
||||
if (params.bad_words_list)
|
||||
{
|
||||
const auto& bad_words = params.bad_words_list.value();
|
||||
const int* bad_words_ptr = bad_words.template getPtr<const int>();
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
bad_words.shape.size() == 2 || bad_words.shape.size() == 3, "Bad words dimension must be 2 or 3.");
|
||||
std::vector<int32_t> batchSlotsVec(batchSize);
|
||||
std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
|
||||
auto batchSlotsHost = params.batch_slots ? params.batch_slots->template getPtr<const int>() : batchSlotsVec.data();
|
||||
auto batchSlots = params.batch_slots ? params.batch_slots->template getPtr<const int>() : nullptr;
|
||||
|
||||
const bool is_matrix = bad_words.shape.size() == 2;
|
||||
if (bad_words.shape.size() == 3)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(bad_words.shape[0] == batch_size,
|
||||
fmtstr("Shape of dim 0 of bad words is invalid. It "
|
||||
"must be equal to batch size."
|
||||
" However, it is %ld and the batch size is %ld.",
|
||||
bad_words.shape[0], batch_size));
|
||||
}
|
||||
mCyclicStep = mCyclicStep % mRuntimeMaxSeqLen;
|
||||
prepareIdsPtrs(outputs, batchSlotsHost, batchSize, beamWidth, maxSeqLen);
|
||||
|
||||
const bool shared_bad_words = is_matrix || bad_words.shape[0] == 1;
|
||||
const size_t bad_words_len = bad_words.shape[is_matrix ? 1 : 2];
|
||||
// Add check on batch size of bad words
|
||||
const int id_offset = ite * local_batch_size;
|
||||
const int decode_vocab_size_units_offset = id_offset * vocab_size_padded_;
|
||||
auto logits = Tensor(MEMORY_GPU, std::is_same_v<T, float> ? DataType::TYPE_FP32 : DataType::TYPE_FP16,
|
||||
{batchSize, beamWidth, mVocabSizePadded}, mRuntimeLogitsDevice);
|
||||
|
||||
invokeBanBadWords((T*) logits.getPtrWithOffset(decode_vocab_size_units_offset),
|
||||
outputs.output_ids_ptr.template getPtr<const int*>(),
|
||||
beam_width > 1 ? outputs.parent_ids_ptr.template getPtr<const int*>() : nullptr, nullptr, batch_size,
|
||||
local_batch_size, beam_width,
|
||||
shared_bad_words
|
||||
? bad_words_ptr
|
||||
: bad_words.template getPtrWithOffset<const int>(ite * local_batch_size * 2 * bad_words_len),
|
||||
shared_bad_words, bad_words_len, vocab_size_padded_, outputs.sequence_length->template getPtr<int>(),
|
||||
max_seq_len, mStream);
|
||||
}
|
||||
// Apply penalties
|
||||
applyPenalties(outputs, params, batchSlotsHost, batchSlots, batchSize, beamWidth, maxSeqLen);
|
||||
|
||||
// Ban bad words and NGrams
|
||||
banWords(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mVocabSizePadded, mStream);
|
||||
|
||||
// Main function that calls forward of the respective layers
|
||||
layersForward(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen);
|
||||
|
||||
// Check if stop conditions are met
|
||||
checkStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, mStream);
|
||||
|
||||
// Copy nextIds and transpose logits when needed
|
||||
prepareOutputData(outputs, params, mIdsPtrHost, batchSlots, batchSize, beamWidth, maxSeqLen, mCyclicStep, mStream);
|
||||
|
||||
mCyclicStep += 1;
|
||||
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
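`forward` above exposes the flat `output_ids` buffer to the layers through `output_ids_ptr`, an array holding one pointer per request at offset `i * beam_width * max_seq_len` (the refactored code builds it in `prepareIdsPtrs`). A host-memory sketch of that pointer table, with plain vectors standing in for the pinned and GPU buffers:

```
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    size_t const batchSize = 2, beamWidth = 2, maxSeqLen = 4;
    std::vector<int32_t> outputIds(batchSize * beamWidth * maxSeqLen, 0);

    // One pointer per request into the flat [batchSize, beamWidth, maxSeqLen] buffer.
    std::vector<int32_t*> idsPtrHost(batchSize);
    for (size_t i = 0; i < batchSize; ++i)
    {
        idsPtrHost[i] = outputIds.data() + i * beamWidth * maxSeqLen;
    }

    idsPtrHost[1][0] = 42; // token written through the per-request pointer
    std::printf("outputIds[%zu] = %d\n", 1 * beamWidth * maxSeqLen, outputIds[1 * beamWidth * maxSeqLen]);
    return 0;
}
```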
template <typename T>
|
||||
void DynamicDecodeLayer<T>::layersForward(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const ite = params.ite;
|
||||
auto const step = params.step;
|
||||
|
||||
// common inputs
|
||||
auto const& end_ids = params.end_ids;
|
||||
auto const& endIds = params.end_ids;
|
||||
auto const localBatchSize = static_cast<std::size_t>(params.local_batch_size);
|
||||
|
||||
// dynamic decode GPT
|
||||
if (beam_width > 1)
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mDecodingMode.isBeamSearch(), "beamWidth > 1 is given, but decoder is not configured as BeamSearch");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
params.src_cache_indirection.has_value(), "src_cache_indirection is mandatory in beam search.");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
@ -312,29 +407,26 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
|
||||
|
||||
// Because we still not support batch beam search now, so we need to compute
|
||||
// one by one if there are different runtime arguments.
|
||||
const size_t dynamic_decode_batch_size = has_diff_runtime_args_ ? 1 : local_batch_size;
|
||||
const int dynamic_decode_total_iteration = local_batch_size / dynamic_decode_batch_size;
|
||||
const size_t dynamic_decode_batch_size = mHasDiffRuntimeArgs ? 1 : localBatchSize;
|
||||
const int dynamic_decode_total_iteration = localBatchSize / dynamic_decode_batch_size;
|
||||
|
||||
for (uint32_t dynamic_ite = ite * dynamic_decode_total_iteration;
|
||||
dynamic_ite < (ite + 1) * dynamic_decode_total_iteration; ++dynamic_ite)
|
||||
for (uint32_t dynamic_ite = 0; dynamic_ite < dynamic_decode_total_iteration; ++dynamic_ite)
|
||||
{
|
||||
const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beam_width;
|
||||
const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * vocab_size_padded_;
|
||||
const int dynamic_id_offset = dynamic_ite * dynamic_decode_batch_size * beamWidth;
|
||||
const int dynamic_decode_vocab_size_units_offset = dynamic_id_offset * mVocabSizePadded;
|
||||
|
||||
auto const logits_offset = logits.slice(
|
||||
{dynamic_decode_batch_size, logits.shape[1], logits.shape[2]}, dynamic_decode_vocab_size_units_offset);
|
||||
auto const end_id_offset
|
||||
= end_ids.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
|
||||
= endIds.slice({dynamic_decode_batch_size}, dynamic_ite * dynamic_decode_batch_size);
|
||||
typename BaseBeamSearchLayer<T>::ForwardParams dynamic_decode_input_tensors{step, ite, logits_offset,
|
||||
end_id_offset, *params.src_cache_indirection, static_cast<std::int32_t>(params.max_attention_window),
|
||||
static_cast<std::int32_t>(params.sink_token_length), static_cast<std::int32_t>(max_seq_len)};
|
||||
|
||||
dynamic_decode_input_tensors.embedding_bias = params.embedding_bias;
|
||||
static_cast<std::int32_t>(params.sink_token_length), static_cast<std::int32_t>(maxSeqLen)};
|
||||
|
||||
if (params.input_lengths)
|
||||
{
|
||||
dynamic_decode_input_tensors.input_lengths
|
||||
= params.input_lengths->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
|
||||
= params.input_lengths->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
|
||||
}
|
||||
|
||||
// common outputs
|
||||
@ -345,122 +437,286 @@ void DynamicDecodeLayer<T>::forward(OutputParams& outputs, ForwardParams const&
|
||||
dynamic_decode_outputs.parent_ids_ptr = std::move(outputs.parent_ids_ptr);
|
||||
|
||||
dynamic_decode_outputs.sequence_length
|
||||
= outputs.sequence_length->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
|
||||
= outputs.sequence_length->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
|
||||
dynamic_decode_outputs.finished
|
||||
= outputs.finished->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
|
||||
= outputs.finished->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
|
||||
dynamic_decode_outputs.cum_log_probs
|
||||
= outputs.cum_log_probs->slice({dynamic_decode_batch_size * beam_width}, dynamic_id_offset);
|
||||
= outputs.cum_log_probs->slice({dynamic_decode_batch_size * beamWidth}, dynamic_id_offset);
|
||||
|
||||
dynamic_decode_outputs.beamHypotheses = outputs.beamHypotheses;
|
||||
dynamic_decode_outputs.output_log_probs = outputs.output_log_probs_tiled;
|
||||
|
||||
// only OnlineBeamSearchLayer support beam_search_diversity_rate
|
||||
// when beamHypotheses is used
|
||||
mOnlineBeamsearchDecode->forward(
|
||||
dynamic_decode_outputs, dynamic_decode_input_tensors, beam_search_workspace_0, beam_search_workspace_1);
|
||||
std::swap(beam_search_workspace_0, beam_search_workspace_1);
|
||||
mOnlineBeamSearchDecode->forward(dynamic_decode_outputs, dynamic_decode_input_tensors);
|
||||
} // end of dynamic_ite
|
||||
std::swap(mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice);
|
||||
}
|
||||
else
|
||||
{ // beam_width == 1
|
||||
{ // beamWidth == 1
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
mDecodingMode.isTopKorTopP(), "beamWidth == 1 is given, but decoder is not configured as TopK or TopP");
|
||||
|
||||
// In sampling, we have supported batch sampling. So, we always compute all
|
||||
// sentences once.
|
||||
const size_t local_batch_offset = ite * local_batch_size * beam_width;
|
||||
|
||||
Tensor const logits_slice{
|
||||
logits.slice({local_batch_size, beam_width, logits.shape[2]}, local_batch_offset * logits.shape[2])};
|
||||
Tensor const end_id_slice{end_ids.slice({local_batch_size}, ite * local_batch_size)};
|
||||
Tensor const logits_slice{logits.slice({localBatchSize, beamWidth, logits.shape[2]}, 0)};
|
||||
Tensor const end_id_slice{endIds.slice({localBatchSize}, 0)};
|
||||
typename BaseSamplingLayer<T>::ForwardParams decode_input_tensors{
|
||||
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(max_seq_len)};
|
||||
step, ite, logits_slice, end_id_slice, static_cast<std::int32_t>(maxSeqLen)};
|
||||
|
||||
decode_input_tensors.embedding_bias = params.embedding_bias;
|
||||
decode_input_tensors.finished = params.finished;
|
||||
|
||||
if (params.input_lengths)
|
||||
{
|
||||
auto& input_lengths = params.input_lengths.value();
|
||||
decode_input_tensors.input_lengths
|
||||
= input_lengths.slice({local_batch_size, beam_width}, local_batch_offset);
|
||||
decode_input_tensors.input_lengths = input_lengths.slice({localBatchSize, beamWidth}, 0);
|
||||
}
|
||||
decode_input_tensors.batch_slots = params.batch_slots;
|
||||
|
||||
DecodingOutputParams decode_outputs(outputs.output_ids);
|
||||
decode_outputs.output_ids_ptr = std::move(outputs.output_ids_ptr);
|
||||
if (outputs.sequence_length)
|
||||
{
|
||||
decode_outputs.sequence_length
|
||||
= outputs.sequence_length->slice({local_batch_size * beam_width}, local_batch_offset);
|
||||
decode_outputs.sequence_length = outputs.sequence_length->slice({localBatchSize * beamWidth}, 0);
|
||||
}
|
||||
if (outputs.finished)
|
||||
{
|
||||
decode_outputs.finished = outputs.finished->slice({local_batch_size * beam_width}, local_batch_offset);
|
||||
decode_outputs.finished = outputs.finished->slice({localBatchSize * beamWidth}, 0);
|
||||
}
|
||||
if (outputs.cum_log_probs)
|
||||
{
|
||||
decode_outputs.cum_log_probs
|
||||
= outputs.cum_log_probs->slice({local_batch_size * beam_width}, local_batch_offset);
|
||||
decode_outputs.cum_log_probs = outputs.cum_log_probs->slice({localBatchSize * beamWidth}, 0);
|
||||
}
|
||||
if (outputs.output_log_probs_tiled)
|
||||
{
|
||||
TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < max_seq_len);
|
||||
TLLM_CHECK(0 <= mCyclicStep && mCyclicStep < maxSeqLen);
|
||||
Tensor& output_log_probs = outputs.output_log_probs_tiled.value();
|
||||
size_t step_offset = mCyclicStep * batch_size * beam_width;
|
||||
decode_outputs.output_log_probs
|
||||
= output_log_probs.slice({1, local_batch_size * beam_width}, step_offset + local_batch_offset);
|
||||
size_t step_offset = mCyclicStep * batchSize * beamWidth;
|
||||
decode_outputs.output_log_probs = output_log_probs.slice({1, localBatchSize * beamWidth}, step_offset);
|
||||
}
|
||||
|
||||
// Run topk / topp decode layers.
|
||||
// Currently, we support batch sampling. If the runtime arguments are like
|
||||
// topk = [4, 0, 4]. topp = [0.0, 0.5, 0.5]
|
||||
// then topk_decode handles [4, x, 4 + 0.5]
|
||||
// topp_decode handles [x, 0.5, x]
|
||||
// where "x" are skipped.
|
||||
mTopKDecode->forward(decode_outputs, decode_input_tensors, top_k_workspace);
|
||||
mTopPDecode->forward(decode_outputs, decode_input_tensors, top_p_workspace);
|
||||
// Run TopK + TopP decode layers.
|
||||
mSamplingLayer->forward(decode_outputs, decode_input_tensors);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
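The comment in the hunk above describes how per-request top-k/top-p arguments are divided between the two sampling passes. A minimal sketch of that split, with a hypothetical helper name and under the assumption that k > 0 selects the top-k pass while k == 0 with p > 0 selects the top-p pass:

```
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not part of TensorRT-LLM: decide which requests the
// top-k pass and the top-p pass each handle, mirroring the comment above.
struct SamplingSplit
{
    std::vector<bool> handledByTopK; // k > 0: drawn by the top-k pass (a top-p value may still apply there)
    std::vector<bool> handledByTopP; // k == 0 && p > 0: drawn by the top-p pass
};

inline SamplingSplit splitSamplingWork(std::vector<uint32_t> const& topK, std::vector<float> const& topP)
{
    assert(topK.size() == topP.size());
    SamplingSplit split{std::vector<bool>(topK.size()), std::vector<bool>(topK.size())};
    for (std::size_t bi = 0; bi < topK.size(); ++bi)
    {
        // e.g. topk = [4, 0, 4], topp = [0.0, 0.5, 0.5]
        // => top-k handles requests 0 and 2, top-p handles request 1.
        split.handledByTopK[bi] = topK[bi] > 0;
        split.handledByTopP[bi] = topK[bi] == 0 && topP[bi] > 0.0f;
    }
    return split;
}
```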
|
||||
|
||||
if (params.stop_words_list)
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::applyPenalties(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlotsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto logitsPtrsHost = ITensor::slice(mLogitsPtrsHost, mCyclicStep, 1);
|
||||
auto logitsPtrsHostData = reinterpret_cast<T const**>(runtime::bufferCast<int64_t>(*logitsPtrsHost));
|
||||
for (int32_t bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
const size_t id_offset = ite * local_batch_size * beam_width;
|
||||
const size_t stop_words_length = params.stop_words_list->shape[2];
|
||||
|
||||
invokeStopWordsCriterion(outputs.output_ids_ptr.template getPtr<const int*>(),
|
||||
outputs.parent_ids_ptr.template getPtr<const int*>(),
|
||||
params.stop_words_list->template getPtrWithOffset<const int>(
|
||||
ite * local_batch_size * 2 * stop_words_length),
|
||||
reinterpret_cast<FinishedState*>(
|
||||
outputs.finished->template getPtrWithOffset<FinishedState::UnderlyingType>(id_offset)),
|
||||
outputs.sequence_length->template getPtr<int>(), nullptr, stop_words_length, batch_size, beam_width,
|
||||
max_seq_len, mStream);
|
||||
if (params.logits_vec)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(params.logits_vec->size() == batchSize,
|
||||
"Logits vector size (%lu) is not equal to the batchSize (%lu)", params.logits_vec->size(), batchSize);
|
||||
logitsPtrsHostData[bi] = params.logits_vec.value()[bi].template getPtr<T>();
|
||||
}
|
||||
else
|
||||
{
|
||||
logitsPtrsHostData[bi] = params.logits->template getPtrWithOffset<T>(bi * beamWidth * mVocabSizePadded);
|
||||
}
|
||||
}
|
||||
|
||||
int32_t const* inputLengths = nullptr;
|
||||
if (params.input_lengths)
|
||||
{
|
||||
auto& input_lengths = params.input_lengths.value();
|
||||
inputLengths = input_lengths.template getPtr<const int>();
|
||||
}
|
||||
auto* embeddingBias = params.embedding_bias ? params.embedding_bias->template getPtr<T const>() : nullptr;
|
||||
#define GET_PENALTIES(capital_name, penalty_name, type) \
|
||||
(mUse##capital_name \
|
||||
&& !allOfBatchSlots(batchSlotsHost, m##capital_name.data(), batchSize, \
|
||||
static_cast<type>(getDefaultPenaltyValue(DecodingPenaltyType::penalty_name)))) \
|
||||
? m##capital_name##Device \
|
||||
: nullptr;
|
||||
|
||||
auto* temperatures = GET_PENALTIES(Temperature, Temperature, float);
|
||||
auto* repetitionPenalties = GET_PENALTIES(RepetitionPenalty, Repetition, float);
|
||||
auto* presencePenalties = GET_PENALTIES(PresencePenalty, Presence, float);
|
||||
auto* frequencyPenalties = GET_PENALTIES(FrequencyPenalty, Frequency, float);
|
||||
auto* minLengths = GET_PENALTIES(MinLength, MinLength, int32_t);
|
||||
|
||||
#undef GET_PENALTIES
|
||||
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T const* const*>(logitsPtrsHostData),
|
||||
mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures,
|
||||
repetitionPenalties, presencePenalties, frequencyPenalties,
|
||||
(mUseRepetitionPenalty || mUsePresencePenalty || mUseFrequencyPenalty), batchSize,
|
||||
static_cast<int32_t>(beamWidth), static_cast<int32_t>(maxSeqLen), mVocabSize, mVocabSizePadded,
|
||||
outputs.output_ids_ptr.template getPtr<const int*>(), outputs.parent_ids_ptr.template getPtr<const int*>(),
|
||||
inputLengths, outputs.sequence_length->template getPtr<const int>(), minLengths,
|
||||
params.end_ids.template getPtr<const int>(), batchSlots, mStream};
|
||||
invokeBatchApplyPenalty(penaltyParams);
|
||||
sync_check_cuda_error();
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::banWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
banRepeatNGrams(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream);
|
||||
banBadWords(logits, outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, vocabSizePadded, stream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::banRepeatNGrams(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const max_step = params.step;
|
||||
if (params.no_repeat_ngram_size)
|
||||
{
|
||||
const int* noRepeatNgramSizeBuf = params.no_repeat_ngram_size.value().template getPtr<const int>();
|
||||
|
||||
invokeBanRepeatNgram(logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<const int*>(),
|
||||
reinterpret_cast<FinishedState*>(
|
||||
params.finished.value_or(Tensor{}).template getPtr<FinishedState::UnderlyingType>()),
|
||||
outputs.parent_ids_ptr.template getPtr<const int*>(), batchSlots,
|
||||
outputs.sequence_length->template getPtr<int>(), batchSize, beamWidth, maxSeqLen,
|
||||
params.no_repeat_ngram_size.value().template getPtr<const int>(), vocabSizePadded, max_step, stream);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::banBadWords(Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const maxBadWordsLength = params.max_bad_words_len;
|
||||
if (maxBadWordsLength)
|
||||
{
|
||||
int32_t const** badWordsPtr = params.bad_words_ptr->template getPtr<int32_t const*>();
|
||||
int32_t const* badWordsLens = params.bad_words_lengths->template getPtr<int32_t>();
|
||||
|
||||
invokeBanBadWords((T*) logits.template getPtr<T>(), outputs.output_ids_ptr.template getPtr<int32_t const*>(),
|
||||
beamWidth > 1 ? outputs.parent_ids_ptr.template getPtr<int32_t const*>() : nullptr, batchSlots, batchSize,
|
||||
beamWidth, badWordsPtr, badWordsLens, maxBadWordsLength, vocabSizePadded,
|
||||
outputs.sequence_length->template getPtr<int>(), maxSeqLen, stream);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::checkStopCriteria(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
checkStopWordsStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream);
|
||||
checkMaxLengthStopCriteria(outputs, params, batchSlots, batchSize, beamWidth, maxSeqLen, stream);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto const maxStopWordsLength = params.max_stop_words_len;
|
||||
if (maxStopWordsLength)
|
||||
{
|
||||
invokeStopWordsCriterion(outputs.output_ids_ptr.template getPtr<int32_t const*>(),
|
||||
outputs.parent_ids_ptr.template getPtr<int32_t const*>(),
|
||||
params.stop_words_ptr->template getPtr<int32_t const*>(),
|
||||
reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>()),
|
||||
outputs.sequence_length->template getPtr<int32_t>(), batchSlots,
|
||||
params.stop_words_lengths->template getPtr<int32_t const>(), maxStopWordsLength, batchSize, beamWidth,
|
||||
maxSeqLen, stream);
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
if (params.sequence_limit_length)
|
||||
{
|
||||
invokeLengthCriterion(
|
||||
reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>()),
|
||||
outputs.finished_sum ? outputs.finished_sum->template getPtr<int>() : nullptr,
|
||||
params.sequence_limit_length->template getPtr<const uint32_t>(),
|
||||
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, mStream);
|
||||
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, stream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::prepareIdsPtrs(
|
||||
OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto idsPtrHostSlice = ITensor::slice(mIdsPtrHost, mCyclicStep, 1);
|
||||
auto idsPtrHost = reinterpret_cast<int32_t**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
|
||||
for (int bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
idsPtrHost[batchSlot]
|
||||
= outputs.output_ids.template getPtrWithOffset<int32_t>(batchSlot * beamWidth * maxSeqLen);
|
||||
}
|
||||
for (int bi = 0; bi < batchSize; bi++)
|
||||
{
|
||||
auto const batchSlot = batchSlots[bi];
|
||||
if (beamWidth > 1)
|
||||
{
|
||||
idsPtrHost[mMaxBatchSize + batchSlot]
|
||||
= outputs.parent_ids.value().template getPtrWithOffset<int32_t>(bi * beamWidth * maxSeqLen);
|
||||
}
|
||||
else
|
||||
{
|
||||
idsPtrHost[mMaxBatchSize + batchSlot] = mZeroParentIdsDevice + bi * beamWidth * maxSeqLen;
|
||||
}
|
||||
}
|
||||
|
||||
outputs.output_ids_ptr
|
||||
= Tensor(MEMORY_GPU, DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost);
|
||||
outputs.parent_ids_ptr = Tensor(
|
||||
MEMORY_GPU, DataType::TYPE_INT32_PTR, {mMaxBatchSize, beamWidth, maxSeqLen}, idsPtrHost + mMaxBatchSize);
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardParams const& params,
|
||||
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
|
||||
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
|
||||
auto idsPtrHost = reinterpret_cast<int32_t**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
|
||||
invokeCopyNextStepIds(outputs.newTokens.template getPtr<int>(), idsPtrHost,
|
||||
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, max_seq_len, mStream);
|
||||
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, maxSeqLen, stream);
|
||||
|
||||
// Transpose the output log probs from [max_seq_len, bs, beam_width] to [batch_size, beam_width, max_seq_len]
|
||||
// Transpose the output log probs from [maxSeqLen, bs, beamWidth] to [batchSize, beamWidth, maxSeqLen]
|
||||
if (outputs.output_log_probs_tiled)
|
||||
{
|
||||
auto logProbsMaxSeqLen = outputs.output_log_probs_tiled.value().shape[0];
|
||||
|
||||
invokeTransposeLogProbs(outputs.output_log_probs.value().template getPtr<float>(),
|
||||
outputs.output_log_probs_tiled.value().template getPtr<float>(),
|
||||
outputs.sequence_length->template getPtr<int>(), nullptr, batch_size, beam_width, logProbsMaxSeqLen,
|
||||
mStream);
|
||||
outputs.sequence_length->template getPtr<int>(), batchSlots, batchSize, beamWidth, logProbsMaxSeqLen,
|
||||
stream);
|
||||
}
|
||||
|
||||
mCyclicStep += 1;
|
||||
mCyclicStep = mCyclicStep % max_seq_len;
|
||||
|
||||
sync_check_cuda_error();
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
template class DynamicDecodeLayer<float>;
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2022, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -20,9 +20,9 @@
|
||||
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
|
||||
#include "tensorrt_llm/layers/baseLayer.h"
|
||||
#include "tensorrt_llm/layers/onlineBeamSearchLayer.h"
|
||||
#include "tensorrt_llm/layers/topKSamplingLayer.h"
|
||||
#include "tensorrt_llm/layers/topPSamplingLayer.h"
|
||||
#include "tensorrt_llm/layers/samplingLayer.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <optional>
|
||||
@ -41,12 +41,14 @@ struct BeamHypotheses;
|
||||
|
||||
namespace layers
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
class DynamicDecodeLayer : public BaseLayer
|
||||
{
|
||||
public:
|
||||
DynamicDecodeLayer(size_t max_batch_size, size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream,
|
||||
std::shared_ptr<tc::IAllocator> allocator, cudaDeviceProp* cuda_device_prop);
|
||||
DynamicDecodeLayer(runtime::DecodingMode const& mode, size_t max_batch_size, size_t max_beam_width,
|
||||
size_t vocab_size, size_t vocab_size_padded, cudaStream_t stream, std::shared_ptr<tc::IAllocator> allocator,
|
||||
cudaDeviceProp* cuda_device_prop);
|
||||
|
||||
~DynamicDecodeLayer() override;
|
||||
DynamicDecodeLayer(DynamicDecodeLayer const& dynamic_decode_layer);
|
||||
@ -83,14 +85,15 @@ public:
|
||||
{
|
||||
public:
|
||||
ForwardParams(int step, int ite, int maxInputLength, int maxAttentionWindow, int sinkTokenLength,
|
||||
int localBatchSize, tc::Tensor logits, tc::Tensor endIds)
|
||||
int localBatchSize, tc::Tensor endIds)
|
||||
: step{step}
|
||||
, ite{ite}
|
||||
, max_input_length{maxInputLength}
|
||||
, max_attention_window{maxAttentionWindow}
|
||||
, sink_token_length{sinkTokenLength}
|
||||
, local_batch_size{localBatchSize}
|
||||
, logits{std::move(logits)}
|
||||
, max_stop_words_len{0}
|
||||
, max_bad_words_len{0}
|
||||
, end_ids{std::move(endIds)}
|
||||
{
|
||||
}
|
||||
@ -102,9 +105,16 @@ public:
|
||||
int max_attention_window;
|
||||
int sink_token_length;
|
||||
int local_batch_size;
|
||||
tc::Tensor logits; // [batch_size, beam_width, vocab_size_padded], on gpu
|
||||
int max_stop_words_len;
|
||||
int max_bad_words_len;
|
||||
tc::Tensor end_ids; // [batch_size], on gpu
|
||||
|
||||
// One of these two fields has to be set
|
||||
// DynamicDecodeLayer::forward checks for it
|
||||
// Need both of these fields to support legacy code during transition period to the batched decoder
|
||||
std::optional<tc::Tensor> logits; // [batch_size, beam_width, vocab_size_padded], on gpu
|
||||
std::optional<std::vector<tc::Tensor>> logits_vec; // [batch_size], on gpu
|
||||
|
||||
// optional parameters
|
||||
std::optional<tc::Tensor> finished; // [batch_size * beam_width], optional
|
||||
std::optional<tc::Tensor> src_cache_indirection; // [local_batch_size, beam_width, max_seq_len] - the k/v cache
|
||||
@ -112,9 +122,12 @@ public:
|
||||
std::optional<tc::Tensor> sequence_limit_length; // [batch_size], on gpu
|
||||
std::optional<tc::Tensor> embedding_bias; // [vocab_size_padded], on gpu
|
||||
std::optional<tc::Tensor> input_lengths; // [batch_size, beam_width], on gpu
|
||||
std::optional<tc::Tensor> bad_words_list; // [2, bad_words_length] or [batch_size, 2, bad_words_length], on gpu
|
||||
std::optional<tc::Tensor> stop_words_list; // [batch_size, 2, stop_words_length], on gpu
|
||||
std::optional<tc::Tensor> bad_words_ptr; // [2, bad_words_length] or [batch_size, 2, bad_words_length], on gpu
|
||||
std::optional<tc::Tensor> bad_words_lengths; // [batch_size], on gpu
|
||||
std::optional<tc::Tensor> stop_words_ptr; // [batch_size][2, stop_words_length], on gpu
|
||||
std::optional<tc::Tensor> stop_words_lengths; // [batch_size], on gpu
|
||||
std::optional<tc::Tensor> no_repeat_ngram_size; // [batch_size], optional
|
||||
std::optional<tc::Tensor> batch_slots; // [batch_size], optional, in pinned memory
|
||||
};
|
||||
|
||||
class OutputParams
|
||||
@ -148,31 +161,88 @@ public:
|
||||
};
|
||||
|
||||
void forward(OutputParams& outputs, ForwardParams const& params);
|
||||
void allocateBuffer(size_t batch_size, size_t beam_width, size_t max_seq_len);
|
||||
void allocateBuffer();
|
||||
void freeBuffer();
|
||||
|
||||
private:
|
||||
void initialize();
|
||||
void initializeLayers();
|
||||
|
||||
std::unique_ptr<OnlineBeamSearchLayer<T>> mOnlineBeamsearchDecode;
|
||||
std::unique_ptr<TopKSamplingLayer<T>> mTopKDecode;
|
||||
std::unique_ptr<TopPSamplingLayer<T>> mTopPDecode;
|
||||
void setupLayers(size_t batchSize, size_t beamWidth, int32_t const* batchSlots, SetupParams const& setupParams);
|
||||
void setupPenalties(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams);
|
||||
|
||||
size_t max_batch_size_;
|
||||
size_t vocab_size_;
|
||||
size_t vocab_size_padded_;
|
||||
cudaDeviceProp* cuda_device_prop_;
|
||||
int* zero_parent_ids = nullptr;
|
||||
int* top_k_workspace = nullptr;
|
||||
int* top_p_workspace = nullptr;
|
||||
int* beam_search_workspace_0 = nullptr;
|
||||
int* beam_search_workspace_1 = nullptr;
|
||||
runtime::IBuffer::SharedPtr mIdsPtrHost;
|
||||
void layersForward(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
|
||||
|
||||
bool has_diff_runtime_args_ = false;
|
||||
void applyPenalties(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlotsHost,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
|
||||
|
||||
static void banWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream);
|
||||
static void banRepeatNGrams(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream);
|
||||
static void banBadWords(tc::Tensor& logits, OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, size_t vocabSizePadded,
|
||||
cudaStream_t stream);
|
||||
|
||||
static void checkStopCriteria(OutputParams& outputs, ForwardParams const& params, int32_t const* batchSlots,
|
||||
size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
|
||||
static void checkMaxLengthStopCriteria(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
|
||||
static void checkStopWordsStopCriteria(OutputParams& outputs, ForwardParams const& params,
|
||||
int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen, cudaStream_t stream);
|
||||
|
||||
void prepareIdsPtrs(
|
||||
OutputParams& outputs, int32_t const* batchSlots, size_t batchSize, size_t beamWidth, size_t maxSeqLen);
|
||||
static void prepareOutputData(OutputParams& outputs, ForwardParams const& params,
|
||||
runtime::ITensor::SharedPtr const& idsPtrsHost, int32_t const* batchSlots, size_t batchSize, size_t beamWidth,
|
||||
size_t maxSeqLen, int32_t cyclicStep, cudaStream_t stream);
|
||||
|
||||
private:
|
||||
std::unique_ptr<OnlineBeamSearchLayer<T>> mOnlineBeamSearchDecode;
|
||||
std::unique_ptr<SamplingLayer<T>> mSamplingLayer;
|
||||
|
||||
runtime::DecodingMode mDecodingMode;
|
||||
size_t mMaxBatchSize;
|
||||
size_t mMaxBeamWidth;
|
||||
size_t mVocabSize;
|
||||
size_t mVocabSizePadded;
|
||||
|
||||
cudaDeviceProp* mCudaDeviceProp;
|
||||
|
||||
int32_t* mZeroParentIdsDevice = nullptr;
|
||||
int32_t* mPenaltyWorkspaceDevice = nullptr;
|
||||
int32_t* mPenaltyWorkspacePrevDevice = nullptr;
|
||||
runtime::ITensor::SharedPtr mIdsPtrHost;
|
||||
runtime::ITensor::SharedPtr mLogitsPtrsHost;
|
||||
|
||||
float* mTemperatureDevice = nullptr;
|
||||
float* mRepetitionPenaltyDevice = nullptr;
|
||||
float* mPresencePenaltyDevice = nullptr;
|
||||
float* mFrequencyPenaltyDevice = nullptr;
|
||||
int32_t* mMinLengthDevice = nullptr;
|
||||
T* mRuntimeLogitsDevice = nullptr;
|
||||
|
||||
std::vector<float> mTemperature;
|
||||
std::vector<float> mRepetitionPenalty;
|
||||
std::vector<float> mPresencePenalty;
|
||||
std::vector<float> mFrequencyPenalty;
|
||||
std::vector<int32_t> mMinLength;
|
||||
|
||||
bool mUseTemperature = false;
|
||||
bool mUseRepetitionPenalty = false;
|
||||
bool mUsePresencePenalty = false;
|
||||
bool mUseFrequencyPenalty = false;
|
||||
bool mUseMinLength = false;
|
||||
|
||||
bool mHasDiffRuntimeArgs = false;
|
||||
int* h_pinned_finished_sum_ = nullptr;
|
||||
|
||||
int mCyclicStep = 0;
|
||||
int32_t mCyclicStep = 0;
|
||||
int32_t mRuntimeMaxSeqLen = 0;
|
||||
int32_t mConfiguredBeamWidth = -1;
|
||||
};
|
||||
|
||||
} // namespace layers
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -39,29 +39,31 @@ struct FillBuffers
|
||||
|
||||
template <typename T>
|
||||
void operator()(std::optional<std::vector<T>> const& optParam, T const defaultValue, std::vector<T>& hostBuffer,
|
||||
T* deviceBuffer, void* deviceTmpBuffer, const int* batchSlots) const
|
||||
T* deviceBuffer, int32_t const* batchSlots) const
|
||||
{
|
||||
using tensorrt_llm::common::cudaAutoCpy;
|
||||
|
||||
hostBuffer.resize(batchSize);
|
||||
if (!optParam)
|
||||
for (size_t bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
std::fill(std::begin(hostBuffer), std::end(hostBuffer), defaultValue);
|
||||
auto const batchSlot = batchSlots ? batchSlots[bi] : bi;
|
||||
if (!optParam)
|
||||
{
|
||||
hostBuffer[batchSlot] = defaultValue;
|
||||
}
|
||||
else if (optParam->size() == 1)
|
||||
{
|
||||
hostBuffer[batchSlot] = optParam->front();
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(optParam->size() == batchSize, "Argument vector size mismatch.");
|
||||
hostBuffer[batchSlot] = optParam.value()[bi];
|
||||
}
|
||||
}
|
||||
else if (optParam->size() == 1)
|
||||
|
||||
if (batchSlots)
|
||||
{
|
||||
std::fill(std::begin(hostBuffer), std::end(hostBuffer), optParam->front());
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(optParam->size() == batchSize, "Argument vector size mismatch.");
|
||||
std::copy(optParam->begin(), optParam->end(), std::begin(hostBuffer));
|
||||
}
|
||||
if (deviceTmpBuffer && batchSlots)
|
||||
{
|
||||
cudaAutoCpy(reinterpret_cast<T*>(deviceTmpBuffer), hostBuffer.data(), batchSize, stream);
|
||||
tensorrt_llm::kernels::invokeScatterDecodingParams(
|
||||
reinterpret_cast<T*>(deviceTmpBuffer), deviceBuffer, batchSlots, batchSize, stream);
|
||||
cudaAutoCpy(deviceBuffer, hostBuffer.data(), maxBatchSize, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -70,6 +72,7 @@ struct FillBuffers
|
||||
}
|
||||
|
||||
size_t batchSize;
|
||||
size_t maxBatchSize;
|
||||
cudaStream_t stream;
|
||||
};
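The hunk above rewrites FillBuffers so that optional per-request parameters are written into a host buffer indexed by batch slot before being copied or scattered to the GPU. A minimal host-only sketch of that fill pattern (hypothetical helper name, no CUDA involved):

```
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical host-only helper (not the library's FillBuffers): broadcast an
// optional per-request parameter into a host buffer indexed by batch slot.
template <typename T>
void fillHostBufferBySlot(std::optional<std::vector<T>> const& optParam, T defaultValue,
    std::vector<T>& hostBuffer, std::vector<int32_t> const& batchSlots)
{
    for (std::size_t bi = 0; bi < batchSlots.size(); ++bi)
    {
        auto const slot = batchSlots[bi]; // hostBuffer is assumed sized for the maximum batch size
        if (!optParam)
        {
            hostBuffer[slot] = defaultValue; // no user value: use the default
        }
        else if (optParam->size() == 1)
        {
            hostBuffer[slot] = optParam->front(); // one value broadcast to every request
        }
        else
        {
            assert(optParam->size() == batchSlots.size()); // one value per request
            hostBuffer[slot] = (*optParam)[bi];
        }
    }
}
```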
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@ -101,15 +101,13 @@ void OnlineBeamSearchLayer<T>::setup(size_t batch_size, SetupParams const& setup
    BaseBeamSearchLayer<T>::setupBase(batch_size, setupParams);
    allocateBuffer(batch_size);

    mDiversityRate = setupParams.beam_search_diversity_rate.value_or(std::vector<float>(0.0f));
    mLengthPenalty = setupParams.length_penalty.value_or(std::vector<float>(0.0f));
    mDiversityRate.resize(batch_size);
    mLengthPenalty.resize(batch_size);

    FillBuffers const fillBuffers{batch_size, mStream};
    FillBuffers const fillBuffers{batch_size, batch_size, mStream};

    fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (float*) nullptr,
        (int*) nullptr);
    fillBuffers(
        setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (float*) nullptr, (int*) nullptr);
    fillBuffers(setupParams.beam_search_diversity_rate, 0.0f, mDiversityRate, diversity_rates_buf_, (int*) nullptr);
    fillBuffers(setupParams.length_penalty, 0.0f, mLengthPenalty, length_penalties_buf_, (int*) nullptr);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
||||
215
cpp/tensorrt_llm/layers/samplingLayer.cpp
Normal file
@ -0,0 +1,215 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/layers/samplingLayer.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
using namespace tensorrt_llm::kernels;
|
||||
using namespace tensorrt_llm::runtime;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace layers
|
||||
{
|
||||
template <typename T>
|
||||
void SamplingLayer<T>::allocateBuffer(size_t batchSize)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
|
||||
mSamplingWorkspaceSize = 0;
|
||||
if (mDecodingMode.isTopK())
|
||||
{
|
||||
mSamplingWorkspaceSize = std::max(mSamplingWorkspaceSize, mTopKDecode->getWorkspaceSize());
|
||||
}
|
||||
if (mDecodingMode.isTopP())
|
||||
{
|
||||
mSamplingWorkspaceSize = std::max(mSamplingWorkspaceSize, mTopPDecode->getWorkspaceSize());
|
||||
}
|
||||
|
||||
std::array<size_t, 4> deviceBufferSizes;
|
||||
deviceBufferSizes[0] = sizeof(curandState_t) * batchSize;
|
||||
deviceBufferSizes[1] = sizeof(uint64_t) * batchSize;
|
||||
deviceBufferSizes[2] = sizeof(bool) * batchSize;
|
||||
deviceBufferSizes[3] = mSamplingWorkspaceSize;
|
||||
|
||||
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
|
||||
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[1], false);
|
||||
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[2], false);
|
||||
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[3], false);
|
||||
|
||||
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), 0);
|
||||
TLLM_LOG_DEBUG("SamplingLayer allocated %d bytes on GPU", bytesAllocated);
|
||||
|
||||
mAllocatedSize = bytesAllocated;
|
||||
if (mDecodingMode.isTopK())
|
||||
{
|
||||
mAllocatedSize += mTopKDecode->getAllocatedSize();
|
||||
}
|
||||
if (mDecodingMode.isTopP())
|
||||
{
|
||||
mAllocatedSize += mTopPDecode->getAllocatedSize();
|
||||
}
|
||||
|
||||
// host buffers.
|
||||
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
|
||||
TLLM_CHECK(mSkipDecodeHost != nullptr);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void SamplingLayer<T>::freeBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
mAllocator->free((void**) (&mCurandStatesDevice));
|
||||
mAllocator->free((void**) (&mRandomSeedsDevice));
|
||||
mAllocator->free((void**) (&mSkipDecodeDevice));
|
||||
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
|
||||
std::free(mSkipDecodeHost);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
SamplingLayer<T>::SamplingLayer(DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
|
||||
cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop)
|
||||
: BaseSamplingLayer<T>(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), nullptr)
|
||||
, mDecodingMode(mode)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!mDecodingMode.isBeamSearch(), "Beam search mode has been requested from Sampling Layer");
|
||||
TLLM_CHECK_WITH_INFO(mDecodingMode.isTopKorTopP(), "Requested mode is neither TopK nor TopP");
|
||||
if (mDecodingMode.isTopK())
|
||||
{
|
||||
mTopKDecode
|
||||
= std::make_unique<TopKSamplingLayer<T>>(maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator);
|
||||
}
|
||||
|
||||
if (mDecodingMode.isTopP())
|
||||
{
|
||||
mTopPDecode = std::make_unique<TopPSamplingLayer<T>>(
|
||||
maxBatchSize, vocabSize, vocabSizePadded, mStream, mAllocator, prop, /* deterministic */ true);
|
||||
}
|
||||
|
||||
allocateBuffer(maxBatchSize);
|
||||
}
|
||||
|
||||
template <typename T>
void SamplingLayer<T>::setup(const size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);

    // If runtime argument has single random seed, using this random seed to
    // initialize the random table of all sentences. If the argument has
    // [batchSize] random seeds, initializing the random table by different
    // random seeds respectively. If no random seed, initialize the random table
    // of all sentences by 0 directly.
    if (setupParams.randomSeed)
    {
        if (setupParams.randomSeed->size() == 1)
        {
            invokeCurandInitialize(
                mCurandStatesDevice, batchSlots, batchSize, setupParams.randomSeed->front(), mStream);
            sync_check_cuda_error();
        }
        else
        {
            TLLM_CHECK_WITH_INFO(setupParams.randomSeed->size() == batchSize, "Random seed vector size mismatch.");
            cudaAutoCpy(mRandomSeedsDevice, setupParams.randomSeed->data(), batchSize, mStream);
            invokeCurandBatchInitialize(mCurandStatesDevice, batchSlots, batchSize, mRandomSeedsDevice, mStream);
            sync_check_cuda_error();
        }
    }
    else
    {
        // Initialize curand states using the default seed 0.
        invokeCurandInitialize(mCurandStatesDevice, batchSlots, batchSize, 0, mStream);
    }

    if (mDecodingMode.isTopK())
    {
        mTopKDecode->setup(batchSize, batchSlots, setupParams);
    }
    if (mDecodingMode.isTopP())
    {
        mTopPDecode->setup(batchSize, batchSlots, setupParams);
    }
}
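SamplingLayer<T>::setup above accepts a single random seed, one seed per request, or no seed at all. A small illustrative helper (assumed name, not part of the library) that resolves those three cases into one seed per batch entry:

```
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical helper: resolve the three seeding cases handled by setup()
// above into one seed per batch entry.
inline std::vector<uint64_t> resolveSeeds(
    std::optional<std::vector<uint64_t>> const& randomSeed, std::size_t batchSize)
{
    if (!randomSeed)
    {
        return std::vector<uint64_t>(batchSize, 0); // no seed given: seed everything with 0
    }
    if (randomSeed->size() == 1)
    {
        return std::vector<uint64_t>(batchSize, randomSeed->front()); // single seed broadcast
    }
    return *randomSeed; // expected to hold one seed per request
}
```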
template <typename T>
void SamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];

    auto logits = inputs.logits.template getPtr<T>();
    auto endIds = inputs.end_ids.template getPtr<const int>();
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
    float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
    float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;

    FinishedState* finishedInput = (inputs.finished)
        ? reinterpret_cast<FinishedState*>(inputs.finished->template getPtr<FinishedState::UnderlyingType>())
        : nullptr;

    std::vector<int32_t> batchSlotsVec(batchSize);
    std::iota(batchSlotsVec.begin(), batchSlotsVec.end(), 0);
    auto batchSlotsHost = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : batchSlotsVec.data();

    bool skipTopK = !mDecodingMode.isTopK();
    if (!skipTopK)
    {
        skipTopK = allOfBatchSlots(batchSlotsHost, mTopKDecode->getSkipDecodeHost(), batchSize, true);
    }

    bool skipTopP = !mDecodingMode.isTopP();
    if (!skipTopP)
    {
        skipTopP = allOfBatchSlots(batchSlotsHost, mTopPDecode->getSkipDecodeHost(), batchSize, true);
    }

    // Compute probabilities either for TopP or if cumLogProbs or outputLogProbs are specified
    bool const skipSoftMax = skipTopP && cumLogProbs == nullptr && outputLogProbs == nullptr;

    inputs.curand_states = mCurandStatesDevice;
    inputs.sampling_workspace = mSamplingWorkspaceDevice;
    inputs.probs_computed = !skipSoftMax;

    invokeAddBiasSoftMax(logits, (T**) nullptr, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize,
        mMaxBatchSize, /* bw */ 1, mVocabSize, mVocabSizePadded, skipSoftMax, /* batchSlotLogits */ false, mStream);
    sync_check_cuda_error();

    if (!skipTopK)
    {
        mTopKDecode->forward(outputs, inputs);
    }

    if (!skipTopP)
    {
        mTopPDecode->forward(outputs, inputs);
    }

    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class SamplingLayer<float>;
template class SamplingLayer<half>;

} // namespace layers
} // namespace tensorrt_llm
91
cpp/tensorrt_llm/layers/samplingLayer.h
Normal file
@ -0,0 +1,91 @@
/*
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <curand_kernel.h>

#include "tensorrt_llm/common/tensor.h"
#include "tensorrt_llm/layers/baseSamplingLayer.h"
#include "tensorrt_llm/layers/decodingParams.h"
#include "tensorrt_llm/layers/topKSamplingLayer.h"
#include "tensorrt_llm/layers/topPSamplingLayer.h"
#include "tensorrt_llm/runtime/decodingMode.h"

namespace tc = tensorrt_llm::common;

namespace tensorrt_llm
{
namespace layers
{

template <typename T>
inline bool allOfBatchSlots(int32_t const* batchSlotsHost, T const* data, size_t batchSize, T value)
{
    return std::all_of(batchSlotsHost, batchSlotsHost + batchSize, [&](int32_t b) { return data[b] == value; });
};

//! \brief Top class for sampling layers.
//! It sets up and executes TopKSamplingLayer and TopPSamplingLayer samplings
template <typename T>
class SamplingLayer : public BaseSamplingLayer<T>
{
public:
    using Base = BaseSamplingLayer<T>;
    using SetupParams = typename Base::SetupParams;
    using ForwardParams = typename Base::ForwardParams;

    SamplingLayer(runtime::DecodingMode const& mode, size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded,
        cudaStream_t stream, std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop);

    ~SamplingLayer() override = default;

    void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;

    void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;

private:
    using Base::mMaxBatchSize;
    using Base::mVocabSize;
    using Base::mVocabSizePadded;
    using Base::mSamplingWorkspaceSize;
    using Base::mAllocatedSize;

    using Base::mStream;
    using Base::mAllocator;

    runtime::DecodingMode mDecodingMode;

    void* mSamplingWorkspaceDevice = nullptr;
    curandState_t* mCurandStatesDevice = nullptr;
    uint64_t* mRandomSeedsDevice = nullptr;

    bool* mSkipDecodeDevice = nullptr;

    bool* mSkipDecodeHost = nullptr;
    bool mSkipAny = false;

    std::unique_ptr<TopKSamplingLayer<T>> mTopKDecode;
    std::unique_ptr<TopPSamplingLayer<T>> mTopPDecode;

private:
    void allocateBuffer(size_t batchSize);
    void freeBuffer();
};

} // namespace layers
} // namespace tensorrt_llm
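A short usage sketch for the allOfBatchSlots helper declared above (the wrapper name is hypothetical): SamplingLayer<T>::forward applies the same check to decide whether an entire top-k or top-p pass can be skipped.

```
#include <cstddef>
#include <cstdint>

// Illustrative wrapper, assuming the allOfBatchSlots helper above is in scope
// (e.g. inside namespace tensorrt_llm::layers):
inline bool canSkipSamplingPass(int32_t const* batchSlotsHost, bool const* skipDecodeHost, std::size_t batchSize)
{
    // True only when every active slot already has its skip-decode flag set for this pass.
    return allOfBatchSlots(batchSlotsHost, skipDecodeHost, batchSize, true);
}
```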
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -74,46 +74,41 @@ void TopKSamplingLayer<T>::allocateBuffer(size_t const batchSize)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
invokeTopKSampling<T>(nullptr, mSamplingWorkspaceSize, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr,
|
||||
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, mSkipDecodeDevice,
|
||||
mNormalizeLogProbs);
|
||||
nullptr, nullptr, TOP_K_MAX, 1.0f, mVocabSizePadded, nullptr, nullptr, mStream, batchSize, nullptr,
|
||||
mNormalizeLogProbs, false);
|
||||
|
||||
std::array<size_t, 4> deviceBufferSizes;
|
||||
deviceBufferSizes[0] = mSamplingWorkspaceSize;
|
||||
deviceBufferSizes[1] = sizeof(uint32_t) * batchSize;
|
||||
deviceBufferSizes[2] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[3] = std::max(deviceBufferSizes[1], deviceBufferSizes[2]);
|
||||
deviceBufferSizes[0] = sizeof(uint32_t) * batchSize;
|
||||
deviceBufferSizes[1] = sizeof(float) * batchSize;
|
||||
deviceBufferSizes[2] = sizeof(bool) * batchSize;
|
||||
deviceBufferSizes[3] = std::max(deviceBufferSizes[0], deviceBufferSizes[1]);
|
||||
|
||||
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[0], false);
|
||||
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[1], false);
|
||||
mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[2], false);
|
||||
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[0], false);
|
||||
mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[1], false);
|
||||
mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[2], false);
|
||||
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[3], false);
|
||||
|
||||
auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
|
||||
TLLM_LOG_DEBUG("topKSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
|
||||
mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
|
||||
|
||||
mIsAllocateBuffer = true;
|
||||
mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), 0);
|
||||
TLLM_LOG_DEBUG("topKSamplingLayer allocated %lu bytes on GPU", mAllocatedSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TopKSamplingLayer<T>::freeBuffer()
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
if (mIsAllocateBuffer)
|
||||
{
|
||||
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
|
||||
mAllocator->free((void**) (&mRuntimeTopKDevice));
|
||||
mAllocator->free((void**) (&mRuntimeTopPDevice));
|
||||
mAllocator->free((void**) (&mSetupWorkspaceDevice));
|
||||
}
|
||||
BaseSamplingLayer<T>::freeBuffer();
|
||||
mIsAllocateBuffer = false;
|
||||
mAllocator->free((void**) (&mRuntimeTopKDevice));
|
||||
mAllocator->free((void**) (&mRuntimeTopPDevice));
|
||||
mAllocator->free((void**) (&mSkipDecodeDevice));
|
||||
mAllocator->free((void**) (&mSetupWorkspaceDevice));
|
||||
std::free(mSkipDecodeHost);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TopKSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots, SetupParams const& setupParams)
|
||||
void TopKSamplingLayer<T>::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
BaseSamplingLayer<T>::setupBase(batchSize, batchSlots, setupParams);
|
||||
|
||||
uint32_t constexpr defaultTopK = 0;
|
||||
auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector<uint32_t>{defaultTopK});
|
||||
@ -161,29 +156,48 @@ void TopKSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
|
||||
reinterpret_cast<float*>(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream);
|
||||
}
|
||||
|
||||
dim3 block(std::min((int) batchSize, 256));
|
||||
dim3 grid(divUp((int) batchSize, (int) block.x));
|
||||
// support topK up to TOP_K_MAX.
|
||||
setupTopKRuntimeArgs<TOP_K_MAX><<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize,
|
||||
topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots);
|
||||
{
|
||||
dim3 block(std::min((int) batchSize, 256));
|
||||
dim3 grid(divUp((int) batchSize, (int) block.x));
|
||||
// support topK up to TOP_K_MAX.
|
||||
setupTopKRuntimeArgs<TOP_K_MAX><<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice,
|
||||
runtimeTopKSize, topP, mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots);
|
||||
}
|
||||
|
||||
cudaAutoCpy(mSkipDecodeHost, mSkipDecodeDevice, mMaxBatchSize, mStream);
|
||||
std::vector<uint32_t> runtimeTopKs(mMaxBatchSize);
|
||||
cudaAutoCpy(runtimeTopKs.data(), mRuntimeTopKDevice, mMaxBatchSize, mStream);
|
||||
// TODO(nkorobov): find maxTopK using batch slot
|
||||
mRuntimeMaxTopK = *std::max_element(std::begin(runtimeTopKs), std::end(runtimeTopKs));
|
||||
{
|
||||
uint32_t maxTopK = 0;
|
||||
for (size_t bi = 0; bi < batchSize; ++bi)
|
||||
{
|
||||
uint32_t bid = bi;
|
||||
if (batchSlots)
|
||||
{
|
||||
bid = batchSlots[bi];
|
||||
}
|
||||
maxTopK = std::max(maxTopK, runtimeTopKs[bid]);
|
||||
}
|
||||
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, maxTopK);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs)
|
||||
void TopKSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
|
||||
// in case of skip any, the logit value is already copied and processed.
|
||||
auto* logits = mSkipAny ? mRuntimeLogitsDevice : inputs.logits.template getPtr<T>();
|
||||
auto* endIds = inputs.end_ids.template getPtr<const int>();
|
||||
auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
|
||||
auto logits = inputs.logits.template getPtr<T>();
|
||||
auto endIds = inputs.end_ids.template getPtr<const int>();
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
|
||||
auto curandStatesDevice = inputs.curand_states;
|
||||
auto samplingWorkspaceDevice = inputs.sampling_workspace;
|
||||
auto const probsComputed = inputs.probs_computed;
|
||||
|
||||
TLLM_CHECK_WITH_INFO(curandStatesDevice, "No curand states provided");
|
||||
TLLM_CHECK_WITH_INFO(samplingWorkspaceDevice, "No sampling workspace provided");
|
||||
|
||||
FinishedState* finishedInput = (inputs.finished)
|
||||
? reinterpret_cast<FinishedState*>(inputs.finished->template getPtr<FinishedState::UnderlyingType>())
|
||||
@ -191,32 +205,16 @@ void TopKSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
|
||||
FinishedState* finishedOutput = (outputs.finished)
|
||||
? reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>())
|
||||
: nullptr;
|
||||
invokeAddBiasEndMask(
|
||||
logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize, mVocabSizePadded, mStream);
|
||||
sync_check_cuda_error();
|
||||
|
||||
float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
|
||||
float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
|
||||
|
||||
if (cumLogProbs != nullptr || outputLogProbs != nullptr)
|
||||
{
|
||||
invokeAddBiasSoftMax(logits, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize,
|
||||
mVocabSizePadded, mStream);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;
|
||||
|
||||
invokeBatchTopKSampling(mSamplingWorkspaceDevice, mSamplingWorkspaceSize, logits,
|
||||
invokeBatchTopKSampling(samplingWorkspaceDevice, mSamplingWorkspaceSize, logits,
|
||||
outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
|
||||
outputLogProbs, mCurandStatesDevice,
|
||||
(int) mRuntimeMaxTopK, // useless because mRuntimeTopKDevice is never
|
||||
// nullptr. Keep for legacy.
|
||||
(int*) (mRuntimeTopKDevice),
|
||||
1.0f, // useless because mRuntimeTopPDevice is never nullptr. Keep for
|
||||
// legacy.
|
||||
outputLogProbs, curandStatesDevice, (int) mRuntimeMaxTopK, (int*) (mRuntimeTopKDevice), 1.0f,
|
||||
mRuntimeTopPDevice, mVocabSizePadded, endIds, batchSlots, mStream, batchSize, mSkipDecodeDevice,
|
||||
mNormalizeLogProbs);
|
||||
mNormalizeLogProbs, probsComputed);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
@ -228,13 +226,6 @@ TopKSamplingLayer<T>::TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, s
|
||||
allocateBuffer(mMaxBatchSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TopKSamplingLayer<T>::TopKSamplingLayer(TopKSamplingLayer<T> const& topKSamplingLayer)
|
||||
: BaseSamplingLayer<T>(topKSamplingLayer)
|
||||
{
|
||||
allocateBuffer(mMaxBatchSize);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TopKSamplingLayer<T>::~TopKSamplingLayer()
|
||||
{
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -36,46 +36,44 @@ class TopKSamplingLayer : public BaseSamplingLayer<T>
|
||||
public:
|
||||
using Base = BaseSamplingLayer<T>;
|
||||
using SetupParams = typename Base::SetupParams;
|
||||
using ForwardParams = typename Base::ForwardParams;
|
||||
|
||||
TopKSamplingLayer(size_t maxBatchSize, size_t vocabSize, size_t vocabSizePadded, cudaStream_t stream,
|
||||
std::shared_ptr<tensorrt_llm::common::IAllocator> allocator);
|
||||
TopKSamplingLayer(TopKSamplingLayer<T> const& top_k_sampling_layer);
|
||||
~TopKSamplingLayer();
|
||||
|
||||
void setup(size_t batchSize, int const* batch_slots, SetupParams const& setupParams) override;
|
||||
void setup(size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
|
||||
void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;
|
||||
|
||||
protected:
|
||||
void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) override;
|
||||
|
||||
void freeBuffer() override;
|
||||
const bool* getSkipDecodeHost() const
|
||||
{
|
||||
return mSkipDecodeHost;
|
||||
}
|
||||
|
||||
protected:
|
||||
bool mNormalizeLogProbs = true;
|
||||
uint32_t mRuntimeMaxTopK = 1;
|
||||
uint32_t mRuntimeMaxTopK = 0;
|
||||
uint32_t* mRuntimeTopKDevice = nullptr;
|
||||
float* mRuntimeTopPDevice = nullptr;
|
||||
void* mSetupWorkspaceDevice = nullptr;
|
||||
bool* mSkipDecodeDevice = nullptr;
|
||||
bool* mSkipDecodeHost = nullptr;
|
||||
|
||||
using Base::mMaxBatchSize;
|
||||
using Base::mVocabSize;
|
||||
using Base::mVocabSizePadded;
|
||||
|
||||
using Base::mSamplingWorkspaceSize;
|
||||
using Base::mSamplingWorkspaceDevice;
|
||||
using Base::mCurandStatesDevice;
|
||||
using Base::mSkipDecodeDevice;
|
||||
using Base::mSkipDecodeHost;
|
||||
using Base::mSkipAny;
|
||||
using Base::mRuntimeLogitsDevice;
|
||||
using Base::mAllocatedSize;
|
||||
|
||||
using Base::mStream;
|
||||
using Base::mAllocator;
|
||||
using Base::mIsAllocateBuffer;
|
||||
|
||||
static constexpr uint32_t TOP_K_MAX = 1024;
|
||||
|
||||
private:
|
||||
void allocateBuffer(size_t batchSize);
|
||||
void freeBuffer();
|
||||
};
|
||||
|
||||
} // namespace layers
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
* Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -34,34 +34,20 @@ namespace tensorrt_llm
|
||||
namespace layers
|
||||
{
|
||||
|
||||
static __global__ void set_topp_runtime_args(int batchSize, uint32_t top_k, uint32_t* top_ks, int top_ks_size,
|
||||
float top_p, float* top_ps, int top_ps_size, bool* skip_decode, float* initial_top_p_buf, float* top_p_decay_buf,
|
||||
float* top_p_min_buf, const int* batch_slots)
|
||||
static __global__ void setTopPRuntimeArgs(int batchSize, uint32_t topK, uint32_t* topKs, int topKsSize, float topP,
|
||||
float* topPs, int topPsSize, bool* skipDecode, const int* batchSlots, float* initialTopPBuf)
|
||||
{
|
||||
/**
|
||||
* @brief Setup the runtime arguments for topp, broadcasting top_p to top_ps
|
||||
and top_k to top_ks, verifying value ranges of top_p_decay/top_p_min.
|
||||
*
|
||||
* \param batchSize
|
||||
* \param top_k
|
||||
* \param top_ks [batchSize]
|
||||
* \param top_ks_size
|
||||
* \param top_p
|
||||
* \param top_ps [batchSize]
|
||||
* \param top_ps_size
|
||||
* \param skip_decode [batchSize]
|
||||
* \param initial_top_p_buf [batchSize]
|
||||
* \param top_p_decay_buf [batchSize]
|
||||
* \param top_p_min_buf [batchSize]
|
||||
*
|
||||
and top_k to top_ks.
|
||||
*/
|
||||
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
for (int bi = index; bi < batchSize; bi += gridDim.x * blockDim.x)
|
||||
{
|
||||
auto const batch_slot = batch_slots != nullptr ? batch_slots[bi] : bi;
|
||||
std::uint32_t k = top_ks_size > 1 ? top_ks[batch_slot] : top_k;
|
||||
float p = top_ps_size > 1 ? top_ps[batch_slot] : top_p;
|
||||
auto const batchSlot = batchSlots != nullptr ? batchSlots[bi] : bi;
|
||||
std::uint32_t k = topKsSize > 1 ? topKs[batchSlot] : topK;
|
||||
float p = topPsSize > 1 ? topPs[batchSlot] : topP;
|
||||
if (k == 0 && p == 0.0f)
|
||||
{
|
||||
// TensorRT-LLM's topp implementation does not support topp = 0.0f, but it
|
||||
@ -69,11 +55,11 @@ static __global__ void set_topp_runtime_args(int batchSize, uint32_t top_k, uint
|
||||
// solution.
|
||||
k = 1;
|
||||
}
|
||||
top_ks[batch_slot] = k;
|
||||
top_ps[batch_slot] = p;
|
||||
skip_decode[batch_slot] = k > 0;
|
||||
topKs[batchSlot] = k;
|
||||
topPs[batchSlot] = p;
|
||||
skipDecode[batchSlot] = k > 0;
|
||||
|
||||
initial_top_p_buf[batch_slot] = top_ps[batch_slot];
|
||||
initialTopPBuf[batchSlot] = topPs[batchSlot];
|
||||
}
|
||||
}
|
||||
|
||||
@ -81,10 +67,10 @@ template <typename T>
void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
    if (is_deterministic_)
    if (mIsDeterministic)
    {
        invokeTopPSampling<T>(nullptr, // workspace
            mSamplingWorkspaceSize, cub_temp_storage_size_,
            mSamplingWorkspaceSize, mCubTempStorageSize,
            nullptr, // output_ids
            nullptr, // sequence_length
            nullptr, // finished_input_buffer
@ -92,8 +78,8 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
            nullptr, // cum_log_probs
            nullptr, // output_log_probs
            nullptr, // log_probs
            topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, mCurandStatesDevice, batchSize,
            mVocabSizePadded, nullptr, 0.f, mStream, mSkipDecodeDevice, nullptr);
            mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, nullptr, batchSize, mVocabSizePadded, nullptr,
            0.f, mStream, nullptr, nullptr);
    }
    else
    {
@ -105,68 +91,63 @@ void TopPSamplingLayer<T>::allocateBuffer(size_t batchSize)
            nullptr, // cum_log_probs
            nullptr, // output_log_probs
            nullptr, // log_probs)
            mCurandStatesDevice, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, air_topp_block_num_,
            mSkipDecodeDevice, nullptr);
            nullptr, batchSize, mVocabSizePadded, nullptr, 0.f, mStream, mAirTopPBlockNum, nullptr, nullptr);
    }

    std::array<size_t, 11> deviceBufferSizes;
    deviceBufferSizes[0] = mSamplingWorkspaceSize;
    deviceBufferSizes[1] = sizeof(int32_t) * batchSize * mVocabSizePadded;
    deviceBufferSizes[0] = sizeof(int32_t) * batchSize * mVocabSizePadded;
    deviceBufferSizes[1] = sizeof(int32_t) * (batchSize + 1);
    deviceBufferSizes[2] = sizeof(int32_t) * (batchSize + 1);
    deviceBufferSizes[3] = sizeof(int32_t) * (batchSize + 1);
    deviceBufferSizes[4] = sizeof(uint32_t) * batchSize;
    deviceBufferSizes[3] = sizeof(uint32_t) * batchSize;
    deviceBufferSizes[4] = sizeof(float) * batchSize;
    deviceBufferSizes[5] = sizeof(float) * batchSize;
    deviceBufferSizes[6] = sizeof(float) * batchSize;
    deviceBufferSizes[7] = sizeof(float) * batchSize;
    deviceBufferSizes[8] = sizeof(float) * batchSize;
    deviceBufferSizes[9] = sizeof(int32_t) * batchSize;
    deviceBufferSizes[10] = *std::max_element(&deviceBufferSizes[4], &deviceBufferSizes[10]);
    deviceBufferSizes[8] = sizeof(int32_t) * batchSize;
    deviceBufferSizes[9] = sizeof(bool) * batchSize;
    deviceBufferSizes[10] = *std::max_element(&deviceBufferSizes[3], &deviceBufferSizes[9]);

    mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[0], true);
    topp_id_vals_buf_ = mAllocator->reMalloc(topp_id_vals_buf_, deviceBufferSizes[1], false);
    topp_offset_buf_ = mAllocator->reMalloc(topp_offset_buf_, deviceBufferSizes[2], false);
    begin_topp_offset_buf_ = mAllocator->reMalloc(begin_topp_offset_buf_, deviceBufferSizes[3], false);
    runtime_top_k_buf_ = mAllocator->reMalloc(runtime_top_k_buf_, deviceBufferSizes[4], false);
    runtime_top_p_buf_ = mAllocator->reMalloc(runtime_top_p_buf_, deviceBufferSizes[5], false);
    initial_top_p_buf_ = mAllocator->reMalloc(initial_top_p_buf_, deviceBufferSizes[6], false);
    top_p_decay_buf_ = mAllocator->reMalloc(top_p_decay_buf_, deviceBufferSizes[7], false);
    top_p_min_buf_ = mAllocator->reMalloc(top_p_min_buf_, deviceBufferSizes[8], false);
    top_p_reset_ids_buf_ = mAllocator->reMalloc(top_p_reset_ids_buf_, deviceBufferSizes[9], false);
    setup_workspace_buf_ = mAllocator->reMalloc(setup_workspace_buf_, deviceBufferSizes[10], false);
    mTopPIdValsDevice = mAllocator->reMalloc(mTopPIdValsDevice, deviceBufferSizes[0], false);
    mTopPOffsetDevice = mAllocator->reMalloc(mTopPOffsetDevice, deviceBufferSizes[1], false);
    mBeginTopPOffsetDevice = mAllocator->reMalloc(mBeginTopPOffsetDevice, deviceBufferSizes[2], false);
    mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[3], false);
    mRuntimeTopPDevice = mAllocator->reMalloc(mRuntimeTopPDevice, deviceBufferSizes[4], false);
    mInitialTopPDevice = mAllocator->reMalloc(mInitialTopPDevice, deviceBufferSizes[5], false);
    mTopPDecayDevice = mAllocator->reMalloc(mTopPDecayDevice, deviceBufferSizes[6], false);
    mTopPMinDevice = mAllocator->reMalloc(mTopPMinDevice, deviceBufferSizes[7], false);
    mTopPResetIdsDevice = mAllocator->reMalloc(mTopPResetIdsDevice, deviceBufferSizes[8], false);
    mSkipDecodeDevice = mAllocator->reMalloc(mSkipDecodeDevice, deviceBufferSizes[9], false);
    mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[10], false);

    auto const bytesAllocated = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), (size_t) 0);
    TLLM_LOG_DEBUG("topPSamplingLayer allocated %lu bytes on GPU", (size_t) bytesAllocated);
    mSkipDecodeHost = (bool*) std::realloc(mSkipDecodeHost, sizeof(bool) * batchSize);
    std::fill(mSkipDecodeHost, mSkipDecodeHost + batchSize, true);

    mIsAllocateBuffer = true;
    mAllocatedSize = std::accumulate(deviceBufferSizes.begin(), deviceBufferSizes.end(), 0);
    TLLM_LOG_DEBUG("topPSamplingLayer allocated %lu bytes on GPU", mAllocatedSize);
}
template <typename T>
void TopPSamplingLayer<T>::freeBuffer()
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
    if (mIsAllocateBuffer)
    {
        mAllocator->free((void**) (&mSamplingWorkspaceDevice));
        mAllocator->free((void**) (&topp_id_vals_buf_));
        mAllocator->free((void**) (&topp_offset_buf_));
        mAllocator->free((void**) (&begin_topp_offset_buf_));
        mAllocator->free((void**) (&runtime_top_k_buf_));
        mAllocator->free((void**) (&runtime_top_p_buf_));
        mAllocator->free((void**) (&initial_top_p_buf_));
        mAllocator->free((void**) (&top_p_decay_buf_));
        mAllocator->free((void**) (&top_p_min_buf_));
        mAllocator->free((void**) (&top_p_reset_ids_buf_));
        mAllocator->free((void**) (&setup_workspace_buf_));
    }
    BaseSamplingLayer<T>::freeBuffer();
    mIsAllocateBuffer = false;
    mAllocator->free((void**) (&mTopPIdValsDevice));
    mAllocator->free((void**) (&mTopPOffsetDevice));
    mAllocator->free((void**) (&mBeginTopPOffsetDevice));
    mAllocator->free((void**) (&mRuntimeTopKDevice));
    mAllocator->free((void**) (&mRuntimeTopPDevice));
    mAllocator->free((void**) (&mInitialTopPDevice));
    mAllocator->free((void**) (&mTopPDecayDevice));
    mAllocator->free((void**) (&mTopPMinDevice));
    mAllocator->free((void**) (&mTopPResetIdsDevice));
    mAllocator->free((void**) (&mSkipDecodeDevice));
    mAllocator->free((void**) (&mSetupWorkspaceDevice));
    std::free(mSkipDecodeHost);
}
template <typename T>
void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots, SetupParams const& setupParams)
void TopPSamplingLayer<T>::setup(size_t const batchSize, int32_t const* batchSlots, SetupParams const& setupParams)
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
    BaseSamplingLayer<T>::setupBase(batchSize, batchSlots, setupParams);

    uint32_t const defaultTopK = 0;
    auto runtimeTopK = setupParams.runtime_top_k.value_or(std::vector<uint32_t>{defaultTopK});
@ -186,7 +167,16 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,

    if (runtimeTopPSize == 0)
    {
        std::fill_n(mSkipDecodeHost, batchSize, true);
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            int32_t bid = bi;
            if (batchSlots)
            {
                bid = batchSlots[bi];
            }
            mSkipDecodeHost[bid] = true;
        }
        cudaAutoCpy(mSkipDecodeDevice, mSkipDecodeHost, mMaxBatchSize, mStream);
        return;
    }

@ -224,17 +214,17 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
    {
        TLLM_CHECK_WITH_INFO(runtimeTopK.size() == batchSize,
            fmtstr("runtimeTopK.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopK.size(), batchSize));
        cudaAutoCpy(reinterpret_cast<uint32_t*>(setup_workspace_buf_), runtimeTopK.data(), batchSize, mStream);
        cudaAutoCpy(reinterpret_cast<uint32_t*>(mSetupWorkspaceDevice), runtimeTopK.data(), batchSize, mStream);
        invokeScatterDecodingParams(
            reinterpret_cast<uint32_t*>(setup_workspace_buf_), runtime_top_k_buf_, batchSlots, batchSize, mStream);
            reinterpret_cast<uint32_t*>(mSetupWorkspaceDevice), mRuntimeTopKDevice, batchSlots, batchSize, mStream);
    }
    if (runtimeTopPSize > 1)
    {
        TLLM_CHECK_WITH_INFO(runtimeTopP.size() == batchSize,
            fmtstr("runtime_top_p.size() (%lu) == batchSize (%lu) is not satisfied!", runtimeTopP.size(), batchSize));
        cudaAutoCpy(reinterpret_cast<float*>(setup_workspace_buf_), runtimeTopP.data(), batchSize, mStream);
        cudaAutoCpy(reinterpret_cast<float*>(mSetupWorkspaceDevice), runtimeTopP.data(), batchSize, mStream);
        invokeScatterDecodingParams(
            reinterpret_cast<float*>(setup_workspace_buf_), runtime_top_p_buf_, batchSlots, batchSize, mStream);
            reinterpret_cast<float*>(mSetupWorkspaceDevice), mRuntimeTopPDevice, batchSlots, batchSize, mStream);
    }

    auto fillBuffers
@ -246,50 +236,66 @@ void TopPSamplingLayer<T>::setup(size_t const batchSize, int const* batchSlots,
        invokeScatterDecodingParams(deviceTmpBuffer, deviceBuffer, batchSlots, batchSize, mStream);
    };

    fillBuffers("top_p_decay", decayVec, reinterpret_cast<float*>(setup_workspace_buf_), top_p_decay_buf_);
    fillBuffers("top_p_decay", decayVec, reinterpret_cast<float*>(mSetupWorkspaceDevice), mTopPDecayDevice);

    fillBuffers("top_p_min", topPMinVec, reinterpret_cast<float*>(setup_workspace_buf_), top_p_min_buf_);
    fillBuffers("top_p_min", topPMinVec, reinterpret_cast<float*>(mSetupWorkspaceDevice), mTopPMinDevice);

    fillBuffers(
        "top_p_reset_ids", topPResetIdsVec, reinterpret_cast<int32_t*>(setup_workspace_buf_), top_p_reset_ids_buf_);
        "top_p_reset_ids", topPResetIdsVec, reinterpret_cast<int32_t*>(mSetupWorkspaceDevice), mTopPResetIdsDevice);

    dim3 block(std::min((int) batchSize, 256));
    dim3 grid(divUp((int) batchSize, (int) block.x));
    set_topp_runtime_args<<<grid, block, 0, mStream>>>(batchSize, topK, runtime_top_k_buf_, runtimeTopKSize, topP,
        runtime_top_p_buf_, runtimeTopPSize, mSkipDecodeDevice, initial_top_p_buf_, top_p_decay_buf_, top_p_min_buf_,
        batchSlots);
    sync_check_cuda_error();
    {
        dim3 block(std::min((int) batchSize, 256));
        dim3 grid(divUp((int) batchSize, (int) block.x));
        setTopPRuntimeArgs<<<grid, block, 0, mStream>>>(batchSize, topK, mRuntimeTopKDevice, runtimeTopKSize, topP,
            mRuntimeTopPDevice, runtimeTopPSize, mSkipDecodeDevice, batchSlots, mInitialTopPDevice);
        sync_check_cuda_error();
    }

    cudaAutoCpy(mSkipDecodeHost, mSkipDecodeDevice, mMaxBatchSize, mStream);
    std::vector<float> runtimeTopPs(mMaxBatchSize);
    cudaAutoCpy(runtimeTopPs.data(), mRuntimeTopPDevice, mMaxBatchSize, mStream);
    {
        float maxTopP = 0.f;
        for (size_t bi = 0; bi < batchSize; ++bi)
        {
            int32_t bid = bi;
            if (batchSlots)
            {
                bid = batchSlots[bi];
            }
            maxTopP = std::max(maxTopP, runtimeTopPs[bid]);
        }
        mRuntimeMaxTopP = std::max(mRuntimeMaxTopP, maxTopP);
    }

    std::vector<float> runtime_top_ps(mMaxBatchSize);
    cudaAutoCpy(runtime_top_ps.data(), runtime_top_p_buf_, mMaxBatchSize, mStream);
    // TODO(nkorobov): find maxTopP using batch slots
    mRuntimeMaxTopP = *std::max_element(std::begin(runtime_top_ps), std::end(runtime_top_ps));

    if (!is_deterministic_)
    if (!mIsDeterministic)
    {
        int smCnt = mCudaDeviceProp->multiProcessorCount;
        air_topp_block_num_ = calcAirTopPBlockNum<T, int, float>(batchSize, (int) mVocabSizePadded, smCnt);
        mAirTopPBlockNum = calcAirTopPBlockNum<T, int, float>(batchSize, (int) mVocabSizePadded, smCnt);
    }
}
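#include <cstddef>
#include <cstdint>

// Illustrative sketch of the scatter that setup() delegates to invokeScatterDecodingParams when a
// per-request top_k/top_p vector is supplied: values arrive ordered by request index bi and land in
// persistent per-slot buffers addressed through batchSlots. scatterBySlot is a made-up name for this
// sketch; when only a single value is supplied, setup() instead passes the scalar to the kernel,
// which broadcasts it to every active slot.
template <typename T>
void scatterBySlot(T const* values, std::int32_t const* batchSlots, std::size_t batchSize, T* dest)
{
    for (std::size_t bi = 0; bi < batchSize; ++bi)
    {
        // With no slot table, the request index doubles as the slot index.
        auto const slot = batchSlots != nullptr ? batchSlots[bi] : static_cast<std::int32_t>(bi);
        dest[slot] = values[bi];
    }
}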
template <typename T>
void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs)
void TopPSamplingLayer<T>::forward(DecodingOutputParams& outputs, ForwardParams& inputs)
{
    TLLM_LOG_TRACE(__PRETTY_FUNCTION__);

    auto const batchSize = inputs.logits.shape[0];

    // in case of skip any, the logit value is already copied and processed.
    auto* logits = !mSkipAny ? inputs.logits.template getPtr<T>() : mRuntimeLogitsDevice;
    auto* endIds = inputs.end_ids.template getPtr<const int>();
    auto* batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
    // Probabilities must be already computed instead of logits
    auto probs = inputs.logits.template getPtr<T>();
    auto endIds = inputs.end_ids.template getPtr<const int>();
    auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<const int>() : nullptr;
    auto curandStatesDevice = inputs.curand_states;
    auto samplingWorkspaceDevice = inputs.sampling_workspace;

    if (is_deterministic_)
    TLLM_CHECK_WITH_INFO(curandStatesDevice, "No curand states provided");
    TLLM_CHECK_WITH_INFO(samplingWorkspaceDevice, "No sampling workspace provided");

    if (mIsDeterministic)
    {
        invokeTopPInitialize(
            topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, batchSize, mVocabSizePadded, mStream);
            mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, batchSize, mVocabSizePadded, mStream);
        sync_check_cuda_error();
    }

@ -299,33 +305,30 @@ void TopPSamplingLayer<T>::runSampling(DecodingOutputParams& outputs, DecodingPa
    FinishedState* finishedOutput = (outputs.finished)
        ? reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>())
        : nullptr;
    invokeAddBiasSoftMax(logits, logits, (T*) (nullptr), endIds, finishedInput, batchSlots, batchSize, mVocabSize,
        mVocabSizePadded, mStream);
    sync_check_cuda_error();

    float* cumLogProbs = (outputs.cum_log_probs) ? outputs.cum_log_probs->template getPtr<float>() : nullptr;
    float* outputLogProbs = (outputs.output_log_probs) ? outputs.output_log_probs->template getPtr<float>() : nullptr;
    int* sequenceLength = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<int>() : nullptr;

    if (is_deterministic_)
    if (mIsDeterministic)
    {
        invokeBatchTopPSampling<T>(mSamplingWorkspaceDevice, mSamplingWorkspaceSize, cub_temp_storage_size_,
        invokeBatchTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize, mCubTempStorageSize,
            outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
            outputLogProbs, logits, topp_id_vals_buf_, topp_offset_buf_, begin_topp_offset_buf_, mCurandStatesDevice,
            batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, runtime_top_p_buf_, mStream, mSkipDecodeDevice,
            outputLogProbs, probs, mTopPIdValsDevice, mTopPOffsetDevice, mBeginTopPOffsetDevice, curandStatesDevice,
            batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP, mRuntimeTopPDevice, mStream, mSkipDecodeDevice,
            batchSlots);
        sync_check_cuda_error();
        invokeComputeToppDecay(runtime_top_p_buf_, initial_top_p_buf_,
            outputs.output_ids_ptr.template getPtr<const int*>(), top_p_decay_buf_, top_p_min_buf_,
            top_p_reset_ids_buf_, sequenceLength, batchSlots, batchSize, mStream);
        invokeComputeToppDecay(mRuntimeTopPDevice, mInitialTopPDevice,
            outputs.output_ids_ptr.template getPtr<const int*>(), mTopPDecayDevice, mTopPMinDevice, mTopPResetIdsDevice,
            sequenceLength, batchSlots, batchSize, mStream);
        sync_check_cuda_error();
    }
    else
    {
        invokeBatchAirTopPSampling<T>(mSamplingWorkspaceDevice, mSamplingWorkspaceSize,
        invokeBatchAirTopPSampling<T>(samplingWorkspaceDevice, mSamplingWorkspaceSize,
            outputs.output_ids_ptr.template getPtr<int*>(), sequenceLength, finishedInput, finishedOutput, cumLogProbs,
            outputLogProbs, logits, mCurandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP,
            runtime_top_p_buf_, mStream, air_topp_block_num_, mSkipDecodeDevice, batchSlots);
            outputLogProbs, probs, curandStatesDevice, batchSize, mVocabSizePadded, endIds, mRuntimeMaxTopP,
            mRuntimeTopPDevice, mStream, mAirTopPBlockNum, mSkipDecodeDevice, batchSlots);
        sync_check_cuda_error();
    }
}
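#include <algorithm>
#include <numeric>
#include <random>
#include <vector>

// Illustrative host-side reference of the top-p selection rule that the deterministic path realizes with
// invokeBatchTopPSampling: given probabilities (already normalized upstream by invokeAddBiasSoftMax),
// draw a threshold in [0, topP) and walk the tokens in descending-probability order until the cumulative
// mass crosses it. sampleTopPReference is a made-up name; the CUDA kernels use device-side sorting,
// prefix sums and curand rather than std::sort and std::mt19937, and additionally honor skip flags,
// end ids and log-prob outputs that this sketch omits.
inline int sampleTopPReference(std::vector<float> const& probs, float topP, std::mt19937& gen)
{
    std::vector<int> order(probs.size());
    std::iota(order.begin(), order.end(), 0);
    std::sort(order.begin(), order.end(), [&probs](int a, int b) { return probs[a] > probs[b]; });

    std::uniform_real_distribution<float> dist(0.0f, topP);
    float const threshold = dist(gen);
    float cumulative = 0.0f;
    for (int token : order)
    {
        cumulative += probs[token];
        if (cumulative >= threshold)
        {
            return token;
        }
    }
    return order.back(); // numerical safety net when probabilities do not sum exactly to one
}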
@ -334,14 +337,7 @@ template <typename T>
TopPSamplingLayer<T>::TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded,
    cudaStream_t stream, std::shared_ptr<IAllocator> allocator, cudaDeviceProp* prop, bool isDeterministic)
    : BaseSamplingLayer<T>(maxBatchSize, vocabSize, vocabSizePadded, stream, std::move(allocator), prop)
    , is_deterministic_(isDeterministic)
{
    allocateBuffer(mMaxBatchSize);
}

template <typename T>
TopPSamplingLayer<T>::TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer)
    : BaseSamplingLayer<T>(top_p_sampling_layer)
    , mIsDeterministic(isDeterministic)
{
    allocateBuffer(mMaxBatchSize);
}
@ -1,5 +1,5 @@
/*
 * Copyright (c) 2019-2023, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2019-2024, NVIDIA CORPORATION. All rights reserved.
 * Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@ -29,60 +29,60 @@ namespace layers
{

//! \brief Layer to randomly sample tokens from TopP logits.
//! Layer expects probs precomputed in "logits" tensor
template <typename T>
class TopPSamplingLayer : public BaseSamplingLayer<T>
{
public:
    using Base = BaseSamplingLayer<T>;
    using SetupParams = typename Base::SetupParams;
    using ForwardParams = typename Base::ForwardParams;

    TopPSamplingLayer(std::size_t maxBatchSize, std::size_t vocabSize, std::size_t vocabSizePadded, cudaStream_t stream,
        std::shared_ptr<tensorrt_llm::common::IAllocator> allocator, cudaDeviceProp* prop, bool isDeterministic = true);
    TopPSamplingLayer(TopPSamplingLayer<T> const& top_p_sampling_layer);
    ~TopPSamplingLayer();

    void setup(std::size_t batchSize, int const* batch_slots, SetupParams const& setupParams) override;
    void setup(std::size_t batchSize, int32_t const* batchSlots, SetupParams const& setupParams) override;
    void forward(DecodingOutputParams& outputs, ForwardParams& inputs) override;

    const bool* getSkipDecodeHost() const
    {
        return mSkipDecodeHost;
    }

protected:
    void runSampling(DecodingOutputParams& outputs, DecodingParams const& inputs) override;
    void freeBuffer() override;
    uint32_t* mRuntimeTopKDevice = nullptr;
    float* mRuntimeTopPDevice = nullptr;
    float mRuntimeMaxTopP{0.f};
    float* mInitialTopPDevice = nullptr;
    float* mTopPDecayDevice = nullptr;
    float* mTopPMinDevice = nullptr;
    int32_t* mTopPResetIdsDevice = nullptr;
    void* mSetupWorkspaceDevice = nullptr;

protected:
    uint32_t* runtime_top_k_buf_ = nullptr;
    float* runtime_top_p_buf_ = nullptr;
    float mRuntimeMaxTopP;
    float* initial_top_p_buf_ = nullptr;
    float* top_p_decay_buf_ = nullptr;
    float* top_p_min_buf_ = nullptr;
    int32_t* top_p_reset_ids_buf_ = nullptr;
    void* setup_workspace_buf_ = nullptr;

    int32_t* topp_id_vals_buf_ = nullptr;
    int32_t* topp_offset_buf_ = nullptr;
    int32_t* begin_topp_offset_buf_ = nullptr;
    std::size_t cub_temp_storage_size_;
    bool is_deterministic_ = true;
    int air_topp_block_num_;
    int32_t* mTopPIdValsDevice = nullptr;
    int32_t* mTopPOffsetDevice = nullptr;
    int32_t* mBeginTopPOffsetDevice = nullptr;
    bool* mSkipDecodeDevice = nullptr;
    bool* mSkipDecodeHost = nullptr;
    size_t mCubTempStorageSize;
    bool mIsDeterministic = true;
    int mAirTopPBlockNum;

    using Base::mMaxBatchSize;
    using Base::mVocabSize;
    using Base::mVocabSizePadded;

    using Base::mSamplingWorkspaceSize;
    using Base::mSamplingWorkspaceDevice;
    using Base::mCurandStatesDevice;
    using Base::mSkipDecodeDevice;
    using Base::mSkipDecodeHost;
    using Base::mSkipAny;
    using Base::mRuntimeLogitsDevice;
    using Base::mAllocatedSize;

    using Base::mStream;
    using Base::mAllocator;
    using Base::mIsAllocateBuffer;
    using Base::mCudaDeviceProp;

private:
    void allocateBuffer(std::size_t batchSize);
    void freeBuffer();
};

} // namespace layers