Update TensorRT-LLM (#1274)

* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
Kaiyu Xie 2024-03-12 18:15:52 +08:00 committed by GitHub
parent 728cc0044b
commit 4bb65f216f
488 changed files with 23183 additions and 10468 deletions

@@ -59,6 +59,7 @@ PenaltyBreakString: 1000
 PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 60
 PointerAlignment: Left
+QualifierAlignment: Right
 ReflowComments: true
 SeparateDefinitionBlocks: Always
 SortIncludes: CaseSensitive
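The new `QualifierAlignment: Right` rule explains much of the mechanical churn in the C++ files below: clang-format now places cv-qualifiers to the right of the type they modify. A small illustration of the effect (not code from the repository; both declarations denote the same types):

```cpp
#include <string>

// Written in the old "west const" style:
void logError(const std::string& msg, const int* codes);

// What clang-format emits with QualifierAlignment: Right:
void logError(std::string const& msg, int const* codes);
```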

.gitignore (vendored)

@@ -17,6 +17,16 @@ venv/
 .local/
 .hypothesis/
 .idea/
+dump*/
+.trt-internal
+*.dot
+*.prof
+*.log
+*.pkl
+*.hdf5
+*.lock
+config.json
+/*.svg
 cpp/cmake-build-*
 cpp/.ccache/
 tensorrt_llm/libs

@@ -355,6 +355,9 @@ however, that it is recommended to use the C++ version.
 ## Troubleshooting
+* If you encounter accuracy issues in the generated text, you may want to increase
+  the internal precision in the attention layer. For that, pass the `--context_fmha_fp32_acc enable` option to
+  `trtllm-build`.
 * It's recommended to add options `--shm-size=1g --ulimit memlock=-1` to the
   docker or nvidia-docker run command. Otherwise you may see NCCL errors when

@@ -39,7 +39,6 @@ Take GPT-350M as an example for single GPU
 ```
 ./benchmarks/gptSessionBenchmark \
-    --model gpt_350m \
     --engine_dir "../../benchmarks/gpt_350m/" \
     --batch_size "1" \
     --input_output_len "60,20"
@@ -50,7 +49,6 @@ Take GPT-175B as an example for multiple GPUs
 Take GPT-175B as an example for multiple GPUs
 ```
 mpirun -n 8 ./benchmarks/gptSessionBenchmark \
-    --model gpt_175b \
     --engine_dir "../../benchmarks/gpt_175b/" \
     --batch_size "1" \
     --input_output_len "60,20"
@@ -125,7 +123,6 @@ cd cpp/build
 Take GPT-350M as an example for single GPU V1 batching
 ```
 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
     --type V1 \
     --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -135,7 +132,6 @@ Take GPT-350M as an example for single GPU V1 batching
 Take GPT-350M as an example for 2-GPU inflight batching
 ```
 mpirun -n 2 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
     --type IFB \
     --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -165,7 +161,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request
 Take GPT-350M as an example for single GPU with static batching
 ```
 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
     --type IFB \
     --static_emulated_batch_size 32 \

@@ -237,7 +237,7 @@ int main(int argc, char* argv[])
 benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
 logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_LOG_ERROR(e.what());
 return 1;

@@ -24,6 +24,7 @@
 #include "tensorrt_llm/common/stringUtils.h"
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/plugins/api/tllmPlugin.h"
+#include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/tllmLogger.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
@@ -64,20 +65,18 @@ struct BenchmarkParams
 class WorkItem
 {
 public:
-WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t requestId)
-: mInferenceRequest(ir)
+WorkItem(std::shared_ptr<InferenceRequest> inferenceRequest, uint64_t requestId)
+: mInferenceRequest(std::move(inferenceRequest))
 , mRequestId(requestId)
 {
 }
-~WorkItem() {}
-uint64_t requestId() const
+[[nodiscard]] uint64_t requestId() const
 {
 return mRequestId;
 }
-std::shared_ptr<InferenceRequest> getInferenceRequest() const
+[[nodiscard]] std::shared_ptr<InferenceRequest> getInferenceRequest() const
 {
 return mInferenceRequest;
 }
@@ -93,7 +92,7 @@ class WorkItemsQueue
 public:
 void clear()
 {
-std::lock_guard<std::mutex> lk(mMutex);
+std::lock_guard<std::mutex> lock(mMutex);
 mPendingWorkItems.clear();
 mPendingWorkItemsReqIds.clear();
 mInProgressWorkItems.clear();
@@ -289,7 +288,7 @@ public:
 if (outputFile.is_open())
 {
-for (const auto& header : headers)
+for (auto const& header : headers)
 {
 outputFile << header << ",";
 }
@@ -340,13 +339,12 @@ public:
 mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
 }
-~ExecutorServer() {}
 void enqueue(std::vector<texec::Request> requests, bool warmup = false)
 {
 try
 {
-std::vector<SizeType> inputLengths, maxNewTokens;
+std::vector<SizeType> inputLengths;
+std::vector<SizeType> maxNewTokens;
 for (auto const& request : requests)
 {
 inputLengths.push_back(request.getInputTokenIds().size());
@@ -363,11 +361,10 @@ public:
 mActiveCount++;
 }
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_THROW("%s", e.what());
 }
-return;
 }
 void waitForResponses(std::optional<SizeType> numRequests, bool warmup = false)
@@ -415,17 +412,16 @@ private:
 class GptServer
 {
 public:
-GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
+GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth,
 batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
 std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
-std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData)
+std::optional<SizeType> const staticEmulatedBatchSize,
+std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData)
 : mRecorder(std::move(recorder))
 , mTerminateReqId(terminateReqId)
 , mWaitSleep(waitSleep)
 , mStaticEmulatedBatchSize(staticEmulatedBatchSize)
-, mEmulatedBatchEndTimestamp(
-std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs))
-, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
+, mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
 , mActiveCount(0)
 {
 ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
@@ -473,16 +469,21 @@ public:
 mRecorder->recordStart(request, requestId);
 mWorkItemsQueue.push(request, requestId);
 }
-catch (const tc::TllmException& e)
+catch (tc::TllmException const& e)
 {
 throw;
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_THROW("%s", e.what());
 }
 }
+void resetBatchDeadline()
+{
+mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch();
+}
 void waitForEmpty() const
 {
 while (!mWorkItemsQueue.empty())
@@ -502,9 +503,9 @@ public:
 }
 // Return up to max_num_requests inference requests.
-std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
+std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
 {
-std::list<std::shared_ptr<InferenceRequest>> rval;
+std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
 auto& comm = COMM_SESSION;
 if (max_num_requests > 0)
 {
@@ -515,12 +516,12 @@ public:
 auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
 static_cast<int64_t>(max_num_requests));
-bool readyForNextBatch = numNewWorkItems > 0;
+bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load();
+bool readyForNextBatch = numNewWorkItems > 0 && timeout;
 if (mStaticEmulatedBatchSize)
 {
 if (numNewWorkItems > 0)
 {
-bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp;
 bool const previousBatchFinished = mActiveCount == 0;
 bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
 readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
@@ -529,26 +530,23 @@ public:
 {
 // Timeout should only begin once we have at least 1 pending request.
 // Reset timeout when no requests are pending or we submit a new batch.
-mEmulatedBatchEndTimestamp
-= std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs);
+resetBatchDeadline();
 }
 }
 if (readyForNextBatch)
 {
-int count = 0;
 // Only add a single batch at a time when emulating static batching
 auto const numItemsToAdd = std::min(
 numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
 mActiveCount += numItemsToAdd;
-while (count < numItemsToAdd)
+while (inferenceRequests.size() < numItemsToAdd)
 {
 auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
 if (markedInProgress)
 {
-rval.emplace_back(workItem->getInferenceRequest());
-count++;
+inferenceRequests.emplace_back(workItem->getInferenceRequest());
 }
 else
 {
@@ -561,14 +559,14 @@ public:
 }
 if (world_size > 1)
 {
-auto numNewWorkItems = static_cast<int64_t>(rval.size());
+auto numNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
 comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
 if (numNewWorkItems > 0)
 {
 std::vector<int64_t> packed;
-for (auto const& ir : rval)
+for (auto const& infReq : inferenceRequests)
 {
-auto vpacked = ir->serialize();
+auto vpacked = infReq->serialize();
 packed.push_back(static_cast<int64_t>(vpacked.size()));
 packed.insert(
 packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
@@ -590,18 +588,18 @@ public:
 for (int64_t count = 0; count < numNewWorkItems; ++count)
 {
 int64_t n = *(packed_ptr++);
-auto ir = InferenceRequest::deserialize(packed_ptr);
+auto infReq = InferenceRequest::deserialize(packed_ptr);
 packed_ptr += n;
-rval.emplace_back(ir);
+inferenceRequests.emplace_back(infReq);
 }
 }
 }
 }
-return rval;
+return inferenceRequests;
 }
 void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
-bool final_response, [[maybe_unused]] const std::string& errMsg)
+bool final_response, [[maybe_unused]] std::string const& errMsg)
 {
 // `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs,
 // cumLogProbs`. `contextLogits, generationLogits` are optional, only contained when `gather_context_logits` and
@@ -616,7 +614,7 @@ public:
 mActiveCount--;
 }
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what());
 }
@@ -628,9 +626,9 @@ private:
 WorkItemsQueue mWorkItemsQueue;
 std::optional<uint64_t> mTerminateReqId;
 std::chrono::milliseconds mWaitSleep;
-std::optional<int> mStaticEmulatedBatchSize;
-std::chrono::time_point<std::chrono::steady_clock> mEmulatedBatchEndTimestamp;
-int32_t mStaticEmulatedTimeoutMs;
+std::optional<SizeType> mStaticEmulatedBatchSize;
+std::chrono::milliseconds mBatchTimeout;
+std::atomic<std::chrono::steady_clock::time_point::duration> mBatchDeadline;
 std::atomic<uint64_t> mActiveCount;
 }; // class GptServer
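For context on the `mBatchTimeout`/`mBatchDeadline` members introduced above: the deadline is kept as an atomic duration-since-epoch so the submission thread can reset it while `getInferenceRequests` polls it without a mutex. A minimal sketch of that pattern, with hypothetical names (not the benchmark's actual class):

```cpp
#include <atomic>
#include <chrono>

class BatchDeadline
{
public:
    explicit BatchDeadline(std::chrono::milliseconds timeout)
        : mTimeout(timeout)
    {
        reset();
    }

    // Called when a new batch window starts (cf. resetBatchDeadline() above).
    void reset()
    {
        mDeadline = (std::chrono::steady_clock::now() + mTimeout).time_since_epoch();
    }

    // Polled by the scheduling thread (cf. the timeout check in getInferenceRequests()).
    [[nodiscard]] bool expired() const
    {
        return std::chrono::steady_clock::now().time_since_epoch() > mDeadline.load();
    }

private:
    std::chrono::milliseconds mTimeout;
    std::atomic<std::chrono::steady_clock::time_point::duration> mDeadline{};
};
```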
@@ -674,10 +672,9 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
 auto request = std::make_shared<InferenceRequest>(reqId);
 auto const& inputIds = sample.inputIds;
 request->setInputIds(bufferManager.copyFrom(
-inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
+inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kCPU));
 auto const requestOutputLen = sample.outputLen;
-request->setMaxNewTokens(
-bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
+request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU));
 request->setBeamWidth(beamWidthTensor);
 if (eosId != nullptr)
 {
@@ -704,14 +701,15 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWid
 {
 auto samplingConfig = texec::SamplingConfig{beamWidth};
 auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
-return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId);
+return {sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId};
 }
 void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
 std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
-std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
-batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
-bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs,
+std::optional<TokenIdType> const& eosId, std::optional<TokenIdType> const& padId,
+BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
+std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
+std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
 bool logIterationData)
 {
 auto const worldConfig = WorldConfig::mpi();
@@ -736,14 +734,14 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
 bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};
 // Load dataset
-const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
-const auto numSamples = samples.size();
-const int maxBeamWidth = beamWidth;
+auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
+auto const numSamples = samples.size();
+int const maxBeamWidth = beamWidth;
 auto recorder = std::make_shared<Recorder>(opCsvFile);
 uint64_t terminateReqId = numSamples + 1;
 auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
-recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData);
+recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData);
 ITensor::SharedPtr eosIdTensor{
 eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
@@ -761,6 +759,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
 if (worldConfig.getRank() == 0)
 {
 // Warm up
+gptServer->resetBatchDeadline();
 SizeType reqId = 0;
 for (auto i = 0; i < warmUp; ++i)
 {
@@ -774,6 +773,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
 // Benchmark
 recorder->initialize();
+gptServer->resetBatchDeadline();
 for (std::size_t i = 0; i < numSamples; ++i)
 {
 auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
@@ -806,23 +806,19 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
 batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
 bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData)
 {
-// Check that mpi size is 1 for now
-auto const worldConfig = WorldConfig::mpi();
-if (worldConfig.getSize() > 1)
-{
-TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1");
-}
+auto const& world = tensorrt_llm::mpi::MpiComm::world();
+auto worldRank = world.getRank();
 // Load dataset
-const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
-const auto numSamples = samples.size();
+auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
+auto const numSamples = samples.size();
 auto recorder = std::make_shared<Recorder>(opCsvFile);
 auto executorServer = std::make_shared<ExecutorServer>(engineDir, modelType, beamWidth, schedulerPolicy,
 benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData);
-if (worldConfig.getRank() == 0)
+if (worldRank == 0)
 {
 // Warm up
 {
@@ -849,7 +845,7 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
 delays.push_back(static_cast<int>(samples[i].delay * 1000));
 }
-bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; });
+bool hasDelay = std::any_of(delays.begin(), delays.end(), [](auto const& delay) { return delay > 0; });
 if (hasDelay && staticEmulatedBatchSize)
 {
 TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes");
@@ -910,9 +906,6 @@ int main(int argc, char* argv[])
 cxxopts::Options options(
 "TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
 options.add_options()("h,help", "Print usage");
-// TODO(rkobus): remove because unused
-options.add_options()(
-"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
 options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
 options.add_options()(
 "api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("gptManager"));
@@ -929,8 +922,8 @@ int main(int argc, char* argv[])
 options.add_options()(
 "warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
 options.add_options()(
-"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<int>()->default_value("-1"));
-options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<int>());
+"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<TokenIdType>()->default_value("-1"));
+options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<TokenIdType>());
 options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
 options.add_options()(
 "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@@ -949,11 +942,15 @@ int main(int argc, char* argv[])
 options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
 cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
+options.add_options()("first_batch_delay",
+"Delay before submitting the first batch of requests. This can be used to increase the size of the first "
+"batch.",
+cxxopts::value<int32_t>());
 options.add_options()("static_emulated_batch_size",
-"Emulate static batching performance with the provided batch size.", cxxopts::value<int>());
+"Emulate static batching performance with the provided batch size.", cxxopts::value<SizeType>());
 options.add_options()("static_emulated_timeout",
 "Timeout (ms) before launching a partial batch in emulated static batching mode",
-cxxopts::value<int>()->default_value("500"));
+cxxopts::value<int32_t>()->default_value("500"));
 options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
 cxxopts::value<std::string>()->default_value("error"));
 options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
@@ -1042,23 +1039,31 @@ int main(int argc, char* argv[])
 // Argument: Enable return context logits
 bool returnGenerationLogits = result["return_generation_logits"].as<bool>();
-std::optional<int32_t> padId;
+std::optional<TokenIdType> padId;
 // Argument: Padding token id
 if (result.count("pad_id"))
 {
-padId = result["pad_id"].as<int>();
+padId = result["pad_id"].as<TokenIdType>();
 }
 // Argument: End-of-sentence token id
-std::optional<int32_t> eosId = result["eos_id"].as<int>();
-std::optional<int> staticEmulatedBatchSize;
+std::optional<TokenIdType> eosId = result["eos_id"].as<TokenIdType>();
+std::optional<std::chrono::milliseconds> batchTimeout;
+// Argument: first_batch_delay
+if (result.count("first_batch_delay"))
+{
+batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as<int32_t>()};
+}
+std::optional<SizeType> staticEmulatedBatchSize;
 // Argument: Static emulated batch size
 if (result.count("static_emulated_batch_size"))
 {
-staticEmulatedBatchSize = result["static_emulated_batch_size"].as<int>();
+staticEmulatedBatchSize = result["static_emulated_batch_size"].as<SizeType>();
+batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as<int32_t>()};
 }
-auto const staticEmulatedTimeout = result["static_emulated_timeout"].as<int>();
 // Argument: Scheduler policy
 batch_scheduler::SchedulerPolicy schedulerPolicy;
@@ -1114,10 +1119,10 @@ int main(int argc, char* argv[])
 {
 benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
 maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
-waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout,
+waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
 logIterationData);
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_LOG_ERROR(e.what());
 return 1;
@@ -1131,7 +1136,7 @@ int main(int argc, char* argv[])
 beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep,
 returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData);
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_LOG_ERROR(e.what());
 return 1;

@@ -15,7 +15,6 @@
 * limitations under the License.
 */
 #include "tensorrt_llm/common/cudaUtils.h"
-#include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/plugins/api/tllmPlugin.h"
 #include "tensorrt_llm/runtime/gptJsonConfig.h"
 #include "tensorrt_llm/runtime/gptSession.h"
@@ -56,12 +55,11 @@ size_t monitorMemory(std::atomic_bool& done)
 return peakMem;
 }
-void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
-std::vector<int> const& batchSizes, int beamWidth, std::vector<std::vector<int>> const& inOutLen,
-std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
-GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits, bool disableForceMaxTokens)
+void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int> const& batchSizes, int beamWidth,
+std::vector<std::vector<int>> const& inOutLen, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
+int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits,
+bool disableForceMaxTokens)
 {
-std::string modelNameHyphen = modelName;
 std::filesystem::path jsonFileName = dataPath / "config.json";
 auto const json = GptJsonConfig::parse(jsonFileName);
 auto const modelConfig = json.getModelConfig();
@@ -69,7 +67,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 SizeType deviceCount{0};
 TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
 auto const worldConfig = WorldConfig::mpi(deviceCount, json.getTensorParallelism(), json.getPipelineParallelism());
-auto const enginePath = dataPath / json.engineFilename(worldConfig, modelNameHyphen);
+auto const enginePath = dataPath / json.engineFilename(worldConfig);
 auto const dtype = modelConfig.getDataType();
 auto const maxNumTokens = modelConfig.getMaxNumTokens();
 auto const useHalf = (dtype == nvinfer1::DataType::kHALF);
@@ -104,7 +102,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 auto& memoryCounter = MemoryCounters::getInstance();
 TLLM_LOG_INFO(memoryCounter.toString());
+std::atomic_bool done;
 for (auto const batchSize : batchSizes)
 {
 if (inputPacked && maxNumTokens != std::nullopt)
@@ -114,10 +112,11 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 "benchmark on %d tokens",
 maxNumTokens.value(), maxBatchSize * maxInputLength);
 }
-std::atomic_bool done = false;
+done = false;
+auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
+size_t peakMem;
 try
 {
-auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
 TLLM_LOG_INFO(memoryCounter.toString());
 std::vector<SizeType> inputLengthsHost(batchSize, maxInputLength);
@@ -205,7 +204,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 TLLM_LOG_INFO(memoryCounter.toString());
 done = true;
-size_t peakMem = peakMemFuture.get();
+peakMemFuture.wait();
+peakMem = peakMemFuture.get();
 printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000);
@@ -275,6 +275,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 std::size_t found = std::string(e.what()).find("out of memory");
 // We need to kill the memory monitor when OOM.
 done = true;
+peakMemFuture.wait();
+peakMem = peakMemFuture.get();
 // Unexpected error; rethrow
 if (found == std::string::npos)
@@ -297,6 +299,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
 {
 // We need to kill memory monitor when any other issue occurs
 done = true;
+peakMemFuture.wait();
+peakMem = peakMemFuture.get();
 throw;
 }
 }
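The gptSessionBenchmark changes above move the memory monitor's `std::async` call out of the `try` block and make every exit path set `done` and drain the future; otherwise an exception would leave the monitor thread running and the future's destructor would block. A self-contained sketch of that pattern (the memory query is a placeholder, not the benchmark's real API):

```cpp
#include <algorithm>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <future>
#include <thread>

// Placeholder for the real host/device memory query used by the benchmark.
static std::size_t queryMemoryUsage()
{
    return 0;
}

// Background monitor: poll until `done` is set, then return the peak value observed.
static std::size_t monitorPeak(std::atomic_bool& done)
{
    std::size_t peak = 0;
    while (!done)
    {
        peak = std::max(peak, queryMemoryUsage());
        std::this_thread::sleep_for(std::chrono::milliseconds(50));
    }
    return peak;
}

int main()
{
    std::atomic_bool done{false};
    auto peakFuture = std::async(std::launch::async, &monitorPeak, std::ref(done));
    std::size_t peak = 0;
    try
    {
        // ... run one benchmark configuration here ...
        done = true;
        peakFuture.wait();
        peak = peakFuture.get();
    }
    catch (...)
    {
        // Every exit path must stop the monitor and drain the future,
        // or the pending async task would block during unwinding.
        done = true;
        peakFuture.wait();
        peak = peakFuture.get();
        throw;
    }
    std::printf("peak memory: %zu bytes\n", peak);
    return 0;
}
```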
@@ -311,8 +315,6 @@ int main(int argc, char* argv[])
 cxxopts::Options options(
 "TensorRT-LLM C++ Runtime Benchmark", "TensorRT-LLM C++ Runtime Benchmark for GPT and GPT-like models.");
 options.add_options()("h,help", "Print usage");
-options.add_options()(
-"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
 options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
 options.add_options()("batch_size",
 "Specify batch size(s) you want to benchmark. Multiple batch sizes can be separated by \";\", example: "
@@ -459,11 +461,11 @@ int main(int argc, char* argv[])
 try
 {
-benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
-beamWidth, inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
-result["duration"].as<int>(), sessionConfig, enableCudaGraph, printAllLogits, disableForceMaxTokens);
+benchmarkGptSession(result["engine_dir"].as<std::string>(), batchSizes, beamWidth, inOutLen, logger,
+result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(), sessionConfig,
+enableCudaGraph, printAllLogits, disableForceMaxTokens);
 }
-catch (const std::exception& e)
+catch (std::exception const& e)
 {
 TLLM_LOG_ERROR(e.what());
 return 1;

@@ -86,6 +86,7 @@ class EncDecBuildConfig:
     max_output_len: Optional[int] = None
     builder_opt: Optional[int] = None
     n_mels: Optional[int] = None
+    skip_cross_qkv: bool = False
     def __post_init__(self) -> None:
         assert self.head_size is not None

@@ -89,7 +89,11 @@ class BaseBenchmark(object):
             (f'Engine world size ({world_size}) != Runtime world size ({self.world_size})')
         # Load config into self
         for key, value in self.config['pretrained_config'].items():
-            setattr(self, key, value)
+            if key == "ssm_cfg":
+                for ssm_key, ssm_value in value.items():
+                    setattr(self, "mamba_" + ssm_key, ssm_value)
+            else:
+                setattr(self, key, value)
         self.quant_mode = QuantMode.from_quant_algo(
             quant_algo=self.quantization['quant_algo'],

@@ -327,9 +327,16 @@ def main(args):
         torch.cuda.empty_cache()
         latencies = []
+        # Disable Host memory monitor when cuda graph is enabled for cuda graph performance.
+        disable_host_mem_monitor = False
+        if args.enable_cuda_graph:
+            logger.warning(
+                'Disable host memory monitor when cuda graph is enabled.')
+            disable_host_mem_monitor = True
         if not disable_mem_monitor:
-            memory_monitor = MemoryMonitor()
+            memory_monitor = MemoryMonitor(
+                disable_host_mem_monitor=disable_host_mem_monitor)
             memory_monitor.start()
         iter_idx = 0

@@ -648,9 +648,12 @@ def build_gpt(args):
                 'tp_size': world_size,
             },
         }
         config = PretrainedConfig.from_dict(config)
         tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
     elif family == "internlm":
+        quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
         config = {
             'architecture':
             'LLaMAForCausalLM',
@@ -673,8 +676,10 @@
             build_config['n_positions'],
             'hidden_act':
             build_config['hidden_act'],
-            'quantization':
-            quant_mode.to_dict(),
+            'quantization': {
+                'quant_algo': quant_algo,
+                'kv_cache_quant_algo': kv_cache_quant_algo
+            },
             'mapping': {
                 'world_size': world_size,
                 'tp_size': world_size
@@ -696,6 +701,7 @@
             "has_zero_point": True,
             "pre_quant_scale": False,
         })
         config = PretrainedConfig.from_dict(config)
         tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
     elif family == "qwen":
@@ -1038,6 +1044,7 @@ def enc_dec_build_helper(component, config, args):
         or quant_mode.is_int8_weight_only()),
         quant_mode=quant_mode,
         n_mels=n_mels,
+        skip_cross_qkv=config['skip_cross_qkv'],
     )
     # build engine

@@ -22,7 +22,7 @@ from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
 class MemoryMonitor:
-    def __init__(self, query_interval=0.1):
+    def __init__(self, query_interval=0.1, disable_host_mem_monitor=False):
         self.query_interval = query_interval  # second(s)
         self.mem_monitor_process = None
         # bytes
@@ -35,6 +35,8 @@
         self.signal_event = Event()  # Sending signal to subprocess
         self.peak_mem_queue = Queue()  # Receiving results from subprocess
+        self.disable_host_mem_monitor = disable_host_mem_monitor
     def start(self):
         self.mem_monitor_process = Process(target=self._upd_peak_memory_usage,
                                            args=(self.signal_event,
@@ -70,7 +72,10 @@
         peak_mem_queue.put((peak_host_used, peak_device_used))
     def get_memory_usage(self):
-        host_used, _, _ = host_memory_info(self.pid)
+        if self.disable_host_mem_monitor:
+            host_used = 0
+        else:
+            host_used, _, _ = host_memory_info(self.pid)
         device_used, _, _ = device_memory_info()
         return host_used, device_used

@@ -36,6 +36,7 @@ option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
 option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
 option(FAST_MATH "Compiling in fast math mode" OFF)
+option(INDEX_RANGE_CHECK "Compiling with index range checks" OFF)
 if(NVTX_DISABLE)
 add_compile_definitions("NVTX_DISABLE")
@@ -97,6 +98,11 @@ if(FAST_BUILD)
 message(WARNING "Skip some kernels to accelerate compilation")
 endif()
+if(INDEX_RANGE_CHECK)
+add_compile_definitions("INDEX_RANGE_CHECK")
+message(WARNING "Check index range to detect OOB accesses")
+endif()
 # Determine CUDA version before enabling the language extension
 check_language(CUDA)
 if(CMAKE_CUDA_COMPILER)
@@ -162,10 +168,6 @@ message(STATUS " version: ${CUDAToolkit_VERSION}")
 message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
 message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
-find_library(
-CUDNN_LIB cudnn
-HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR}
-PATH_SUFFIXES lib64 lib lib/x64)
 set(CUBLAS_LIB CUDA::cublas)
 set(CUBLASLT_LIB CUDA::cublasLt)
 set(CUDA_DRV_LIB CUDA::cuda_driver)

@@ -29,9 +29,9 @@ class InferenceRequest;
 class NamedTensor;
 using GetInferenceRequestsCallback = std::function<std::list<std::shared_ptr<InferenceRequest>>(int32_t)>;
-using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
+using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, std::string const&)>;
 using PollStopSignalCallback = std::function<std::unordered_set<uint64_t>()>;
 // json of stats as a string
-using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
+using ReturnBatchManagerStatsCallback = std::function<void(std::string const&)>;
 } // namespace tensorrt_llm::batch_manager

@@ -312,9 +312,9 @@ public:
 [[nodiscard]] std::vector<int64_t> serialize() const;
-static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed);
-static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr);
+static std::shared_ptr<InferenceRequest> deserialize(std::vector<int64_t> const& packed);
+static std::shared_ptr<InferenceRequest> deserialize(int64_t const* packed_ptr);
 };
 } // namespace tensorrt_llm::batch_manager

@@ -50,6 +50,13 @@ public:
 {
 }
+bool operator==(KvCacheConfig const& other) const
+{
+return maxTokens == other.maxTokens && maxAttentionWindow == other.maxAttentionWindow
+&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
+&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm;
+}
 std::optional<SizeType> maxTokens;
 std::optional<SizeType> maxAttentionWindow;
 std::optional<SizeType> sinkTokenLength;

@@ -176,6 +176,13 @@ public:
 mNumTokens += n;
 }
+void removeTokens(SizeType n)
+{
+TLLM_CHECK(n <= mNumTokens);
+TLLM_CHECK(mNumTokens - n >= 0);
+mNumTokens -= n;
+}
 [[nodiscard]] SizeType getSequenceSlotIdx() const
 {
 return mSeqSlotIdx;
@@ -214,6 +221,14 @@ public:
 }
 }
+void removeLastBlock()
+{
+for (auto& beamBlockIds : mCacheBlockIds)
+{
+beamBlockIds.pop_back();
+}
+}
 void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
 {
 mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
@@ -280,32 +295,40 @@ public:
 //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
 void schedulingReleaseBlocks(GenerationRequest& sequence);
-[[nodiscard]] SizeType getNumFreeBlocks() const
+//! \brief Release last block in the sequence
+void releaseLastBlock(GenerationRequest& sequence);
+[[nodiscard]] SizeType getNumFreeBlocks() const noexcept
 {
 return mFreeBlocks.size();
 }
-[[nodiscard]] SizeType getNumAllocatedBlocks() const
+[[nodiscard]] SizeType getNumReusedBlocks() const noexcept
+{
+return mReusedBlocks;
+}
+[[nodiscard]] SizeType getNumAllocatedBlocks() const noexcept
 {
 return getMaxNumBlocks() - getNumFreeBlocks();
 }
-[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
+[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const noexcept
 {
 return getNumFreeBlocks() >= numRequired;
 }
-[[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const
+[[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const noexcept
 {
 return mSchedulingNumFreeBlocks >= numRequired;
 }
-[[nodiscard]] SizeType getMaxNumBlocks() const
+[[nodiscard]] SizeType getMaxNumBlocks() const noexcept
 {
 return static_cast<SizeType>(mAllBlocksByIdx.size());
 }
-[[nodiscard]] SizeType getTokensPerBlock() const
+[[nodiscard]] SizeType getTokensPerBlock() const noexcept
 {
 return mTokensPerBlock;
 }
@@ -478,11 +501,15 @@ public:
 return mEnableBlockReuse;
 }
+void removeToken(SizeType seqSlotIdx);
+void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
 private:
 void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
 void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
 void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
-void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx);
+void updateNewBlockPointer(GenerationRequest const& seq, SizeType seqSlotIdx, SizeType blockIdx);
+void updateToken(SizeType seqSlotIdx, bool addToken);
 private:
 // Number of elements per one blocks
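The new `removeTokens`/`removeLastBlock` hooks (together with the `removeToken`/`rewindKVCache` declarations) give the manager a way to roll back a sequence's KV-cache tail. A purely illustrative sketch of the kind of bookkeeping involved, not the library's implementation:

```cpp
#include <vector>

// Hypothetical, simplified sequence state: a token count plus one id per allocated cache block.
struct SequenceSketch
{
    int numTokens = 0;
    int tokensPerBlock = 64;
    std::vector<int> blockIds;

    // Rewind the tail by `rewindLength` tokens, releasing trailing blocks that become empty.
    void rewind(int rewindLength)
    {
        for (int i = 0; i < rewindLength && numTokens > 0; ++i)
        {
            --numTokens;
            if (numTokens % tokensPerBlock == 0 && !blockIds.empty())
            {
                blockIds.pop_back();
            }
        }
    }
};
```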

@@ -474,7 +474,7 @@ public:
 return mDraftTokens->size();
 }
-void setReturnContextLogits(const bool returnContextLogits)
+void setReturnContextLogits(bool const returnContextLogits)
 {
 mReturnContextLogits = returnContextLogits;
 }
@@ -484,7 +484,7 @@ public:
 return mReturnContextLogits;
 }
-void setReturnGenerationLogits(const bool returnGenerationLogits)
+void setReturnGenerationLogits(bool const returnGenerationLogits)
 {
 mReturnGenerationLogits = returnGenerationLogits;
 }
@@ -556,6 +556,11 @@ public:
 return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
 }
+[[nodiscard]] bool isGenerationCompleteState() const noexcept
+{
+return mState == REQUEST_STATE_GENERATION_COMPLETE;
+}
 /// To determine whether the context is unchunked. When a context is chunked into only a part, it
 /// is still different from the unchunked state, which indicates the initial status.
 [[nodiscard]] bool isFullContextRequest() const noexcept

@@ -64,7 +64,7 @@ public:
 using TensorPtr = Base::TensorPtr;
 NamedTensor(
-nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
+nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, void const* _data = nullptr);
 NamedTensor(TensorPtr _tensor, std::string _name)
 : Base(std::move(_tensor), std::move(_name)){};
@@ -74,6 +74,10 @@ public:
 [[nodiscard]] std::vector<int64_t> serialize() const;
-static NamedTensor deserialize(const int64_t* packed);
+void serialize(int64_t* out, const size_t totalSize) const;
+[[nodiscard]] size_t serializedSize() const;
+static NamedTensor deserialize(int64_t const* packed);
 };
 } // namespace tensorrt_llm::batch_manager
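The added `serializedSize()` and `serialize(int64_t* out, size_t totalSize)` overload enable a two-phase flow: measure first, then write into a caller-owned buffer. A hedged sketch of how a caller might use it (the helper function and buffer layout are illustrative, not batch-manager code, and assume the NamedTensor header shown above is included):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using tensorrt_llm::batch_manager::NamedTensor;

std::vector<int64_t> packTensors(std::vector<NamedTensor> const& tensors)
{
    std::size_t total = 0;
    for (auto const& t : tensors)
    {
        total += t.serializedSize(); // pass 1: query the exact size
    }

    std::vector<int64_t> buffer(total);
    int64_t* out = buffer.data();
    for (auto const& t : tensors)
    {
        auto const n = t.serializedSize();
        t.serialize(out, n); // pass 2: write in place, no per-tensor allocation
        out += n;
    }
    return buffer;
}
```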

@@ -50,11 +50,19 @@ public:
 explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
 : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()),
-executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(),
-executorConfig.getEnableChunkedContext())
+executorConfig.getEnableTrtOverlap(),
+executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
+executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext())
 {
 }
+bool operator==(TrtGptModelOptionalParams const& other) const
+{
+return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
+&& deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
+&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
+}
 KvCacheConfig kvCacheConfig;
 bool enableTrtOverlap;

@@ -16,6 +16,7 @@
 #pragma once
+#include "tensorrt_llm/common/assert.h"
 #include <cstdint>
 namespace tensorrt_llm::common
@@ -80,11 +81,17 @@ public:
 [[nodiscard]] reference operator[](size_type index)
 {
+#ifdef INDEX_RANGE_CHECK
+TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
+#endif
 return mData[index];
 }
 [[nodiscard]] const_reference operator[](size_type index) const
 {
+#ifdef INDEX_RANGE_CHECK
+TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
+#endif
 return mData[index];
 }
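Together with the new `INDEX_RANGE_CHECK` CMake option (enabled via `-DINDEX_RANGE_CHECK=ON`), the guarded `operator[]` above turns out-of-bounds element access into a loud failure in diagnostic builds while costing nothing otherwise. A stand-alone illustration of the pattern, not the library's class:

```cpp
#include <cassert>
#include <cstddef>

template <typename T>
class CheckedView
{
public:
    CheckedView(T* data, std::size_t size)
        : mData(data)
        , mSize(size)
    {
    }

    T& operator[](std::size_t index)
    {
#ifdef INDEX_RANGE_CHECK
        // The real code uses TLLM_CHECK_WITH_INFO to throw with a formatted message;
        // assert() stands in for it in this sketch.
        assert(index < mSize && "Index is out of bounds");
#endif
        return mData[index];
    }

private:
    T* mData;
    std::size_t mSize;
};
```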

View File

@ -56,6 +56,7 @@ enum class MpiType
kUINT64, kUINT64,
kFP8, kFP8,
kBF16, kBF16,
kCHAR,
}; };
//! \brief For converting a C++ data type to a TensorRT data type. //! \brief For converting a C++ data type to a TensorRT data type.
@ -133,6 +134,12 @@ struct MpiTypeConverter<std::uint64_t>
static constexpr auto value = MpiType::kUINT64; static constexpr auto value = MpiType::kUINT64;
}; };
template <>
struct MpiTypeConverter<char>
{
static constexpr auto value = MpiType::kCHAR;
};
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template <> template <>
struct MpiTypeConverter<__nv_fp8_e4m3> struct MpiTypeConverter<__nv_fp8_e4m3>
@ -202,8 +209,8 @@ public:
~MpiComm() noexcept; ~MpiComm() noexcept;
// no copy // no copy
MpiComm(const MpiComm&) = delete; MpiComm(MpiComm const&) = delete;
MpiComm& operator=(const MpiComm&) = delete; MpiComm& operator=(MpiComm const&) = delete;
// move // move
MpiComm(MpiComm&&) noexcept; MpiComm(MpiComm&&) noexcept;
@ -253,7 +260,24 @@ public:
} }
} }
void bcast(std::vector<int64_t>& packed, int root) const; template <typename T>
void bcast(std::vector<T>& vec, int root) const
{
auto const rank = getRank();
auto vecSize = (rank == root) ? static_cast<int64_t>(vec.size()) : int64_t(0);
bcast(&vecSize, 1, MpiType::kINT64, root);
vec.resize(vecSize);
if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
{
auto const mpiType = MpiTypeConverter<std::remove_cv_t<T>>::value;
bcast(vec.data(), vec.size(), mpiType, root);
}
else
{
bcast(vec.data(), vec.size() * sizeof(T), MpiType::kBYTE, root);
}
}
void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const; void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
@ -297,8 +321,8 @@ public:
} }
} }
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const; void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const; void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
void barrier() const; void barrier() const;
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const; void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
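
The templated `MpiComm::bcast` added above follows a two-phase pattern: broadcast the element count first so non-root ranks can resize their vector, then broadcast the payload (element-wise for fundamental types, as raw bytes otherwise). Here is a standalone sketch of the same idea against raw MPI, assuming a trivially copyable element type; it is not the TensorRT-LLM wrapper itself.

```
#include <mpi.h>

#include <cstdint>
#include <type_traits>
#include <vector>

template <typename T>
void bcastVector(std::vector<T>& vec, int root, MPI_Comm comm)
{
    static_assert(std::is_trivially_copyable_v<T>, "payload must be trivially copyable");
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    // 1) Broadcast the element count so every rank can size its buffer.
    std::int64_t count = (rank == root) ? static_cast<std::int64_t>(vec.size()) : 0;
    MPI_Bcast(&count, 1, MPI_INT64_T, root, comm);
    vec.resize(static_cast<std::size_t>(count));

    // 2) Broadcast the payload as raw bytes.
    MPI_Bcast(vec.data(), static_cast<int>(count * sizeof(T)), MPI_BYTE, root, comm);
}

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    std::vector<float> data;
    if (rank == 0)
    {
        data = {1.0f, 2.0f, 3.0f};
    }
    bcastVector(data, /*root=*/0, MPI_COMM_WORLD);

    MPI_Finalize();
    return 0;
}
```

Run under `mpirun -n 2 ...`; every rank ends up holding rank 0's vector.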


@ -34,6 +34,9 @@
namespace tensorrt_llm::executor namespace tensorrt_llm::executor
{ {
class Model;
class Serialization;
/// @brief Sampling configuration /// @brief Sampling configuration
class SamplingConfig class SamplingConfig
{ {
@ -51,6 +54,8 @@ public:
~SamplingConfig(); ~SamplingConfig();
bool operator==(SamplingConfig const& other) const;
[[nodiscard]] SizeType getBeamWidth() const; [[nodiscard]] SizeType getBeamWidth() const;
[[nodiscard]] std::optional<SizeType> getTopK() const; [[nodiscard]] std::optional<SizeType> getTopK() const;
[[nodiscard]] std::optional<FloatType> getTopP() const; [[nodiscard]] std::optional<FloatType> getTopP() const;
@ -68,6 +73,7 @@ public:
[[nodiscard]] std::optional<SizeType> getEarlyStopping() const; [[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
private: private:
friend class Serialization;
SizeType mBeamWidth; SizeType mBeamWidth;
std::optional<SizeType> mTopK; std::optional<SizeType> mTopK;
std::optional<FloatType> mTopP; std::optional<FloatType> mTopP;
@ -86,12 +92,16 @@ private:
}; };
 /// @brief Configuration that controls the outputs of a Result
-struct OutputConfig
+class OutputConfig
 {
-    bool returnLogProbs{false};
-    bool returnContextLogits{false};
-    bool returnGenerationLogits{false};
-    bool excludeInputFromOutput{false};
+public:
+    OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
+        bool excludeInputFromOutput = false);
+
+    bool returnLogProbs;
+    bool returnContextLogits;
+    bool returnGenerationLogits;
+    bool excludeInputFromOutput;
 };
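
`OutputConfig` is now a class with a defaulted constructor rather than a plain aggregate, so callers construct it positionally. A hedged usage sketch based only on the constructor declared above; the `tensorrt_llm/executor/executor.h` include path is an assumption, not shown in this diff.

```
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::OutputConfig makeOutputConfig()
{
    // returnLogProbs, returnContextLogits, returnGenerationLogits, excludeInputFromOutput
    return tle::OutputConfig(/*returnLogProbs=*/true, /*returnContextLogits=*/false,
        /*returnGenerationLogits=*/false, /*excludeInputFromOutput=*/true);
}
```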
/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance /// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
@ -109,6 +119,7 @@ public:
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const; [[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
private: private:
friend class Serialization;
VecTokens mTokens; VecTokens mTokens;
std::optional<Tensor> mLogits; std::optional<Tensor> mLogits;
std::optional<FloatType> mAcceptanceThreshold; std::optional<FloatType> mAcceptanceThreshold;
@ -128,6 +139,7 @@ public:
[[nodiscard]] Tensor getEmbeddingTable() const; [[nodiscard]] Tensor getEmbeddingTable() const;
private: private:
friend class Serialization;
Tensor mEmbeddingTable; Tensor mEmbeddingTable;
}; };
@ -142,6 +154,8 @@ public:
[[nodiscard]] Tensor getConfig() const; [[nodiscard]] Tensor getConfig() const;
private: private:
friend class Serialization;
Tensor mWeights; Tensor mWeights;
Tensor mConfig; Tensor mConfig;
}; };
@ -207,6 +221,7 @@ public:
void setLoraConfig(LoraConfig loraConfig); void setLoraConfig(LoraConfig loraConfig);
private: private:
friend class Serialization;
class Impl; class Impl;
std::unique_ptr<Impl> mImpl; std::unique_ptr<Impl> mImpl;
}; };
@ -298,15 +313,49 @@ private:
SizeType const kDefaultIterStatsMaxIterations = 1000; SizeType const kDefaultIterStatsMaxIterations = 1000;
/// @brief A configuration class for the parallel execution parameters
/// Currently only supports commType = CommunicationType::kMPI
class ParallelConfig
{
public:
/// @brief Constructor
/// @param commType The communication type. See CommunicationType.
/// @param commMode The communication mode. See CommunicationMode.
/// @param deviceIds The IDs of the GPUs involved in the execution of the model
/// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the
/// model. The first participant is considered to be the leader.
ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
CommunicationMode commMode = CommunicationMode::kLEADER,
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
~ParallelConfig();
[[nodiscard]] CommunicationType getCommunicationType() const;
[[nodiscard]] CommunicationMode getCommunicationMode() const;
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
[[nodiscard]] std::optional<std::vector<SizeType>> getParticipantIds() const;
void setCommunicationType(CommunicationType type);
void setCommunicationMode(CommunicationMode mode);
void setDeviceIds(std::vector<SizeType> deviceIds);
void setParticipantIds(std::vector<SizeType> participantIds);
private:
CommunicationType mCommType;
CommunicationMode mCommMode;
std::optional<std::vector<SizeType>> mDeviceIds;
std::optional<std::vector<SizeType>> mParticipantIds;
};
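
A hedged sketch of building the new `ParallelConfig` for a two-GPU, leader-mode run, using only the constructor and setters declared above; the include path is assumed.

```
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::ParallelConfig makeTwoGpuLeaderConfig()
{
    tle::ParallelConfig parallelConfig(tle::CommunicationType::kMPI, tle::CommunicationMode::kLEADER);
    parallelConfig.setDeviceIds({0, 1});      // GPUs involved in executing the model
    parallelConfig.setParticipantIds({0, 1}); // MPI ranks; the first one is the leader
    return parallelConfig;
}
```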
 /// @brief Configuration class for the model executor
 class ExecutorConfig
 {
 public:
     ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
         KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
-        bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
-        SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
-        BatchingType batchingType = BatchingType::kINFLIGHT);
+        bool enableTrtOverlap = false, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
+        BatchingType batchingType = BatchingType::kINFLIGHT,
+        std::optional<ParallelConfig> parallelConfig = std::nullopt);

     [[nodiscard]] SizeType getMaxBeamWidth() const;
     [[nodiscard]] SchedulerConfig getSchedulerConfig() const;

@ -314,9 +363,9 @@ public:
     [[nodiscard]] bool getEnableChunkedContext() const;
     [[nodiscard]] bool getNormalizeLogProbs() const;
     [[nodiscard]] bool getEnableTrtOverlap() const;
-    [[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
     [[nodiscard]] SizeType getIterStatsMaxIterations() const;
     [[nodiscard]] BatchingType getBatchingType() const;
+    [[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;

     void setMaxBeamWidth(SizeType maxBeamWidth);
     void setSchedulerConfig(SchedulerConfig schedulerConfig);

@ -324,9 +373,9 @@ public:
     void setEnableChunkedContext(bool enableChunkedContext);
     void setNormalizeLogProbs(bool normalizeLogProbs);
     void setEnableTrtOverlap(bool enableTrtOverlap);
-    void setDeviceIds(std::optional<std::vector<SizeType>> deviceIds);
     void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
     void setBatchingType(BatchingType batchingType);
+    void setParallelConfig(ParallelConfig parallelConfig);

 private:
     SizeType mMaxBeamWidth;

@ -335,24 +384,11 @@ private:
     bool mEnableChunkedContext;
     bool mNormalizeLogProbs;
     bool mEnableTrtOverlap;
-    std::optional<std::vector<SizeType>> mDeviceIds;
     SizeType mIterStatsMaxIterations;
     BatchingType mBatchingType;
+    std::optional<ParallelConfig> mParallelConfig;
 };
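
With the constructor reshuffle above, device IDs no longer go into `ExecutorConfig` directly; they travel inside an optional `ParallelConfig`. A hedged sketch of the new flow, using only members declared in this header (include path assumed).

```
#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::ExecutorConfig makeExecutorConfig()
{
    tle::ExecutorConfig executorConfig(/*maxBeamWidth=*/1);

    tle::ParallelConfig parallelConfig;        // defaults: kMPI + kLEADER
    parallelConfig.setDeviceIds({0, 1, 2, 3}); // replaces the removed setDeviceIds on ExecutorConfig
    executorConfig.setParallelConfig(parallelConfig);

    return executorConfig;
}
```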
-/// TODO:
-/// @brief A class to identify processes involved in the execution of a model
-/// Currently only supports MPI communication
-class Communicator
-{
-public:
-    Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector<SizeType> const& commIds,
-        std::optional<SizeType> orchestratorId){};
-
-    ~Communicator() = default;
-};
-
-class Model;
-
 /// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
 class Executor
 {

@ -364,14 +400,12 @@ public:
     /// @param modelType The type of model
     /// @param executorConfig The configuration for the executor
     /// @param comm An optional inter-process communicator configuration
-    Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig,
-        std::optional<Communicator> comm = std::nullopt);
+    Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig);

     Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
-        ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
+        ExecutorConfig executorConfig);

-    Executor(
-        std::shared_ptr<Model> model, ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
+    Executor(std::shared_ptr<Model> model, ExecutorConfig executorConfig);

     ~Executor();
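
Since the `Communicator` argument is gone, an executor is now constructed from just the engine location, model type, and `ExecutorConfig`. A hedged sketch follows; the engine path is a placeholder, and `ModelType::kDECODER_ONLY` is assumed from the executor's `ModelType` enum rather than shown in this diff.

```
#include "tensorrt_llm/executor/executor.h"

#include <filesystem>

namespace tle = tensorrt_llm::executor;

int main()
{
    tle::ExecutorConfig executorConfig(/*maxBeamWidth=*/1);

    // Build the executor directly from an engine directory and the config.
    tle::Executor executor(
        std::filesystem::path{"/path/to/engine_dir"}, tle::ModelType::kDECODER_ONLY, executorConfig);
    return 0;
}
```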


@ -180,11 +180,11 @@ public:
~Tensor() = default; ~Tensor() = default;
Tensor(const Tensor& other) noexcept = default; Tensor(Tensor const& other) noexcept = default;
Tensor(Tensor&& other) noexcept = default; Tensor(Tensor&& other) noexcept = default;
Tensor& operator=(const Tensor& other) noexcept = default; Tensor& operator=(Tensor const& other) noexcept = default;
Tensor& operator=(Tensor&& other) noexcept = default; Tensor& operator=(Tensor&& other) noexcept = default;
@ -267,6 +267,7 @@ private:
friend std::shared_ptr<runtime::ITensor> const& detail::toITensor(Tensor const& tensor); friend std::shared_ptr<runtime::ITensor> const& detail::toITensor(Tensor const& tensor);
friend Tensor detail::ofITensor(std::shared_ptr<runtime::ITensor> tensor); friend Tensor detail::ofITensor(std::shared_ptr<runtime::ITensor> tensor);
friend class Serialization;
}; };
} // namespace tensorrt_llm::executor } // namespace tensorrt_llm::executor


@ -155,21 +155,16 @@ enum class SchedulerPolicy
     kGUARANTEED_NO_EVICT = 1,
 };

-enum class CommunicatorType
+enum class CommunicationType
 {
     kMPI = 0
 };

-enum class CommMode
+enum class CommunicationMode
 {
-    kLEADER,       // With the leader mode, only the leader will be returning from the executor constructor and
-                   // therefore only the leader can enqueue requests and get responses
-    kORCHESTRATOR, // With the orchestrator mode, only the orchestrator will be returning from the executor constructor
-                   // and therefore only the leader can enqueue requests and get responses. The orchestrator doesn't
-                   // participate in the computations
-    kALL,          // With the ALL mode, all participants are expected to make the same calls to the executor API,
-                   // so they all need to send the same requests. Responses will be the same for all participants
+    kLEADER, // With the leader mode, only the leader can enqueue requests. The requests will be
+             // broadcast to the workers. All participants can get responses via awaitResponses. The leader is the
+             // first participant in the provided participant IDs, or rank 0 if participant IDs are not provided
 };

 } // namespace tensorrt_llm::executor
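
A hedged sketch of the leader-mode flow described by the new `kLEADER` comment: only the leader rank enqueues, and every participant may poll for responses via `awaitResponses`. The `Request`/`Response` handling and the `enqueueRequest` name are assumptions not shown in this diff.

```
#include "tensorrt_llm/executor/executor.h"

#include <mpi.h>

namespace tle = tensorrt_llm::executor;

void leaderLoop(tle::Executor& executor, tle::Request const& request)
{
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);

    if (rank == 0) // rank 0 is the leader when no participant IDs are given
    {
        executor.enqueueRequest(request);
    }
    auto responses = executor.awaitResponses(); // available to all participants
    (void) responses;
}
```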


@ -81,6 +81,11 @@ public:
using UnderlyingType = uint8_t; using UnderlyingType = uint8_t;
bool operator==(DecodingMode const& other) const
{
return mState == other.mState;
}
private: private:
constexpr DecodingMode(UnderlyingType state) constexpr DecodingMode(UnderlyingType state)
: mState(state) : mState(state)


@ -17,10 +17,13 @@
#pragma once #pragma once
#include "tensorrt_llm/runtime/bufferManager.h" #include "tensorrt_llm/runtime/bufferManager.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingMode.h" #include "tensorrt_llm/runtime/decodingMode.h"
#include "tensorrt_llm/runtime/decodingOutput.h" #include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptModelConfig.h"
#include "tensorrt_llm/runtime/samplingConfig.h" #include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <curand_kernel.h> #include <curand_kernel.h>
#include <memory> #include <memory>
@ -59,7 +62,7 @@ public:
DecodingInput const& decodingInput, BufferManager const& manager) DecodingInput const& decodingInput, BufferManager const& manager)
= 0; = 0;
virtual const SamplingConfig& getSamplingConfig() = 0; virtual SamplingConfig const& getSamplingConfig() = 0;
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds, static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths, ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
@ -71,6 +74,11 @@ public:
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold, SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream); curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
static void updateKVCacheBasedOnAcceptedTokens(ITensor const& acceptedOffsets, ITensor const& packedAcceptedIds,
ITensor const& pointerArray, ITensor const& pastKeyValueLengths, GptModelConfig const& modelConfig,
WorldConfig const& worldConfig, BufferManager::CudaStreamPtr stream, SizeType rewindDraftTokenCount,
SizeType maxAttentionWindow, SizeType maxBlocksPerSeq, nvinfer1::DataType dtype);
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize, static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
BufferManager::CudaStreamPtr const& stream); BufferManager::CudaStreamPtr const& stream);
@ -97,7 +105,7 @@ public:
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput, void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
BufferManager const& manager) override; BufferManager const& manager) override;
const SamplingConfig& getSamplingConfig() override SamplingConfig const& getSamplingConfig() override
{ {
return mSamplingConfig; return mSamplingConfig;
} }


@ -153,6 +153,18 @@ public:
return mFinishedSum; return mFinishedSum;
} }
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokens() const override
{
return mNextDraftTokens;
}
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
[[nodiscard]] TensorPtr getNextDraftTokenLengths() const override
{
return mNextDraftTokenLengths;
}
private: private:
//! @brief Gather final beam search results for request `batchIdx`. //! @brief Gather final beam search results for request `batchIdx`.
[[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const; [[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
@ -204,6 +216,8 @@ private:
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
TensorPtr mNextDraftTokens;
TensorPtr mNextDraftTokenLengths;
SizeType mMaxSequenceLength{}; SizeType mMaxSequenceLength{};
SizeType mMaxAttentionWindow{}; SizeType mMaxAttentionWindow{};
SizeType mSinkTokenLength{}; SizeType mSinkTokenLength{};


@ -46,15 +46,10 @@ public:
         , endId{endId}
         , computeCumLogProbs(false)
         , computeLogProbs(false)
+        , generatedTokensPerStep(1)
     {
     }

-    // the number of tokens generated per step
-    SizeType generatedTokensPerStep() const
-    {
-        return draftTokens ? draftTokens->getSize() + 1 : 1;
-    }
-
     // mandatory parameters
     ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
     SizeType inputLen;  // the input length without draft tokens

@ -71,6 +66,7 @@ public:
     bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
     bool computeLogProbs;    // boolean that controls if logProbs should be computed for that request
+    SizeType generatedTokensPerStep;
 };
class Input class Input
@ -184,6 +180,12 @@ public:
std::vector<SamplingConfig> const& samplingConfigs) std::vector<SamplingConfig> const& samplingConfigs)
= 0; = 0;
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
virtual TensorPtr getNextDraftTokens() const = 0;
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
virtual TensorPtr getNextDraftTokenLengths() const = 0;
protected: protected:
IGptDecoderBatch() = default; IGptDecoderBatch() = default;
}; };


@ -36,7 +36,7 @@ public:
IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize); IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize);
~IpcMemory(); ~IpcMemory();
[[nodiscard]] const std::vector<void*>& getCommPtrsTensor() const [[nodiscard]] std::vector<void*> const& getCommPtrsTensor() const
{ {
return mCommPtrs; return mCommPtrs;
} }


@ -67,7 +67,7 @@ public:
// Fill the tasks tensor for the batch using the provided tasksHost // Fill the tasks tensor for the batch using the provided tasksHost
// Function assumes that the first numContextRequests requests in the batch are context requests // Function assumes that the first numContextRequests requests in the batch are context requests
void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests, void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests,
const std::vector<SizeType>& reqBeamWidths, const std::vector<SizeType>& reqPromptLengths, std::vector<SizeType> const& reqBeamWidths, std::vector<SizeType> const& reqPromptLengths,
BufferManager const& manager, bool packedInput); BufferManager const& manager, bool packedInput);
}; };


@ -43,7 +43,7 @@ private:
auto const hasValues = accessor(0).has_value(); auto const hasValues = accessor(0).has_value();
for (size_t ci = 0; ci < configs.size(); ++ci) for (size_t ci = 0; ci < configs.size(); ++ci)
{ {
const auto& configValue = accessor(ci); auto const& configValue = accessor(ci);
TLLM_CHECK(hasValues == configValue.has_value()); TLLM_CHECK(hasValues == configValue.has_value());
if (hasValues) if (hasValues)
{ {


@ -188,7 +188,6 @@ endif()
set(TRTLLM_LINK_LIBS set(TRTLLM_LINK_LIBS
${CUBLAS_LIB} ${CUBLAS_LIB}
${CUBLASLT_LIB} ${CUBLASLT_LIB}
${CUDNN_LIB}
${CMAKE_DL_LIBS} ${CMAKE_DL_LIBS}
${MPI_C_LIBRARIES} ${MPI_C_LIBRARIES}
${NCCL_LIB} ${NCCL_LIB}


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:0ecc134ad10a54b2953c772e72db2f71e84130d5736087b033e9e5b78594db6d oid sha256:c56ee13bb109917ab10df168ca15e6057436df1cd8b64a4268c6e7aae78a5ad8
size 2113376 size 2126310


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:9aa3f3d7f8313c099df8e9bd4c9707922a4f1c4025c4c99986acf6df781738c7 oid sha256:339532215fa4c16e68ca28ee23d0a0e09c9caefa7bd19b563d2f7b83cad6822e
size 2128450 size 2142070


@ -1,3 +1,3 @@
add62ff328028bbcded1af694fe758c5 libtensorrt_llm_batch_manager_static.a c9c505e2cb6e95b7cfc124c04ab1fcb3 libtensorrt_llm_batch_manager_static.a
9e8846e200e2aaaeace862741a90c3ab libtensorrt_llm_batch_manager_static.pre_cxx11.a 2f5cec5a5b42e0031bc2edc688c1e74b libtensorrt_llm_batch_manager_static.pre_cxx11.a
230623fa285048a2de5c54c2cc0f364fb9f2c559 commit 741fb083cc42933439ae54557b177b6d7064da4f commit


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:7b25de974b6ca5f0dcb279f16f38199167d1efc35c01770d3234bec2dfb5dc86 oid sha256:a4060f2d60472850344e5b5799f9ad88390f4ad9c056e3843f3bdbcc046ca68b
size 2097848 size 2106440


@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:5f06cee5ae2bcf393196265cd9a3ef832690cd4c5c53934bbfb169d50ab33c41 oid sha256:829f1ed5af0b0d2577e57fd13979706fe0b3636bd6338aac3c34a615f64afedc
size 2055004 size 2064310


@ -1,2 +1,2 @@
bb62a31b8e17dae284d784ba43d5bc02 libtensorrt_llm_batch_manager_static.a 2db5c985786dad3dd16c22ec54af0803 libtensorrt_llm_batch_manager_static.a
19327f59c7f5b6235e15b322d5f5a0f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a 96940249ff7b3ff09754b89ad25fcf9f libtensorrt_llm_batch_manager_static.pre_cxx11.a


@ -42,11 +42,11 @@ public:
virtual ~IAllocator() = default; virtual ~IAllocator() = default;
// no copying // no copying
IAllocator(const IAllocator&) = delete; IAllocator(IAllocator const&) = delete;
IAllocator& operator=(const IAllocator&) = delete; IAllocator& operator=(IAllocator const&) = delete;
template <typename T> template <typename T>
[[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true) [[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, bool const setZero = true)
{ {
TLLM_LOG_TRACE(__PRETTY_FUNCTION__); TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
// TODO martinma: why do we need this size extension? // TODO martinma: why do we need this size extension?


@ -23,7 +23,7 @@
namespace tensorrt_llm::common namespace tensorrt_llm::common
{ {
[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "") [[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
{ {
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str())); throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
} }
@ -38,8 +38,10 @@ public:
 #if defined(_WIN32)
 #define TLLM_LIKELY(x) (__assume((x) == 1), (x))
+#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
 #else
 #define TLLM_LIKELY(x) __builtin_expect((x), 1)
+#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
 #endif

 #define TLLM_CHECK(val) \

@ -61,20 +63,22 @@ public:
 #define TLLM_CHECK_DEBUG(val) \
     do \
     { \
-        if (DebugConfig::isCheckDebugEnabled()) \
+        if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
         { \
             TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
                                                 : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
         } \
     } while (0)

-#define TLLM_CHECK_DEBUG_WITH_INFO(val, info) \
+#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
     do \
     { \
-        if (DebugConfig::isCheckDebugEnabled()) \
+        if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
         { \
-            TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
-                                                : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
+            TLLM_LIKELY(static_cast<bool>(val)) \
+                ? ((void) 0) \
+                : tensorrt_llm::common::throwRuntimeError( \
+                    __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
         } \
     } while (0)
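
Two ideas land in the hunk above: the runtime `isCheckDebugEnabled()` test is wrapped in the new `TLLM_UNLIKELY` branch hint, and `TLLM_CHECK_DEBUG_WITH_INFO` gains a printf-style variadic message formatted through `fmtstr`. Here is a standalone sketch of both, with local names so it compiles on its own; it is not the TensorRT-LLM macro itself.

```
#include <cstdio>
#include <cstdlib>

#if defined(_WIN32)
#define SKETCH_UNLIKELY(x) (__assume((x) == 0), (x))
#else
#define SKETCH_UNLIKELY(x) __builtin_expect((x), 0)
#endif

static bool gDebugChecksEnabled = false; // imagine this is read once from an env var

// Checks only run when debug checking is enabled; the failure message is
// formatted printf-style from the variadic arguments.
#define SKETCH_CHECK_DEBUG_WITH_INFO(val, info, ...) \
    do \
    { \
        if (SKETCH_UNLIKELY(gDebugChecksEnabled) && !(val)) \
        { \
            std::fprintf(stderr, "Check failed at %s:%d: " info "\n", __FILE__, __LINE__, ##__VA_ARGS__); \
            std::abort(); \
        } \
    } while (0)

int main()
{
    int const numTokens = 7;
    // Costs one predicted-not-taken branch when checks are disabled;
    // aborts with a formatted message when enabled and the condition fails.
    SKETCH_CHECK_DEBUG_WITH_INFO(numTokens < 8, "numTokens %d must be < 8", numTokens);
    return 0;
}
```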


@ -42,7 +42,7 @@ CublasMMWrapper::~CublasMMWrapper()
mMutex = nullptr; mMutex = nullptr;
} }
CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper) CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
: mCublasHandle(wrapper.mCublasHandle) : mCublasHandle(wrapper.mCublasHandle)
, mCublasLtHandle(wrapper.mCublasLtHandle) , mCublasLtHandle(wrapper.mCublasLtHandle)
, mStream(wrapper.mStream) , mStream(wrapper.mStream)
@ -50,8 +50,8 @@ CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
{ {
} }
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
const int k, const int lda, const int ldb, const int ldc) int const k, int const lda, int const ldb, int const ldc)
{ {
// -------------------------------------- // --------------------------------------
// Create descriptors for the original matrices // Create descriptors for the original matrices
@ -79,15 +79,15 @@ void CublasMMWrapper::destroyDescriptors()
mCDesc = NULL; mCDesc = NULL;
} }
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc) void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
{ {
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f); Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
} }
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic) std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
{ {
if (heuristic) if (heuristic)
{ {
@ -102,8 +102,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
} }
} }
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta) void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
{ {
bool usingCublasLt = mAType == CUDA_R_16F; bool usingCublasLt = mAType == CUDA_R_16F;
@ -111,9 +111,9 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
/* usingCublasLt */ usingCublasLt); /* usingCublasLt */ usingCublasLt);
} }
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta, void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt) cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
{ {
half h_alpha = (half) (f_alpha); half h_alpha = (half) (f_alpha);
half h_beta = (half) (f_beta); half h_beta = (half) (f_beta);
@ -126,8 +126,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
int batch_count = 1; int batch_count = 1;
// fp32 use cublas as default // fp32 use cublas as default
// fp16 use cublasLt as default // fp16 use cublasLt as default
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha); void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta); void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE; int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
if (usingCublasLt) if (usingCublasLt)
@ -154,10 +154,10 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
} }
} }
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha, const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
const float f_beta) float const f_beta)
{ {
half h_alpha = (half) f_alpha; half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta; half h_beta = (half) f_beta;
@ -165,26 +165,26 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati
std::lock_guard<std::mutex> lock(*mMutex); std::lock_guard<std::mutex> lock(*mMutex);
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha); void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta); void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda, check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType, strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP)); mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
} }
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType) cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
{ {
half h_alpha = (half) f_alpha; half h_alpha = (half) f_alpha;
half h_beta = (half) f_beta; half h_beta = (half) f_beta;
std::lock_guard<std::mutex> lock(*mMutex); std::lock_guard<std::mutex> lock(*mMutex);
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0; bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha); void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta); void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda, check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType, strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
@ -267,8 +267,8 @@ void CublasMMWrapper::setStream(cudaStream_t stream)
mStream = stream; mStream = stream;
} }
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo) int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
{ {
TLLM_CHECK_WITH_INFO( TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
@ -291,12 +291,12 @@ bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr
} }
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa, std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc) cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
{ {
TLLM_CHECK_WITH_INFO( TLLM_CHECK_WITH_INFO(
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function"); descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc); auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
sync_check_cuda_error(); sync_check_cuda_error();


@ -65,39 +65,39 @@ public:
~CublasMMWrapper(); ~CublasMMWrapper();
CublasMMWrapper(const CublasMMWrapper& wrapper); CublasMMWrapper(CublasMMWrapper const& wrapper);
/********************** GEMMs **********************/ /********************** GEMMs **********************/
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
const int lda, const void* B, const int ldb, void* C, const int ldc); int const lda, void const* B, int const ldb, void* C, int const ldc);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
const int lda, const void* B, const int ldb, void* C, const int ldc, int const lda, void const* B, int const ldb, void* C, int const ldc,
const std::optional<cublasLtMatmulHeuristicResult_t>& algo); std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta); int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A, void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt); cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
const float f_beta = 0.0f); float const f_beta = 0.0f);
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType); int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
/********************** Tactic selection helpers **********************/ /********************** Tactic selection helpers **********************/
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo); int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb, std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
const int m, const int n, const int k, const int lda, const int ldb, const int ldc); int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle, std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc, cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
@ -126,8 +126,8 @@ public:
CublasDataType getCublasDataType(cudaDataType_t data_type); CublasDataType getCublasDataType(cudaDataType_t data_type);
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
const int lda, const int ldb, const int ldc); int const lda, int const ldb, int const ldc);
void destroyDescriptors(); void destroyDescriptors();
cublasHandle_t getCublasHandle() cublasHandle_t getCublasHandle()


@ -43,7 +43,7 @@ CUDADriverWrapper::CUDADriverWrapper()
handle = dllOpen(CUDA_LIB_NAME); handle = dllOpen(CUDA_LIB_NAME);
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly."); TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, const char* name) auto load_sym = [](void* handle, char const* name)
{ {
void* ret = dllGetSym(handle, name); void* ret = dllGetSym(handle, name);
return ret; return ret;
@ -69,7 +69,7 @@ CUDADriverWrapper::~CUDADriverWrapper()
dllClose(handle); dllClose(handle);
} }
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, const char** pStr) const CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
{ {
return (*_cuGetErrorName)(error, pStr); return (*_cuGetErrorName)(error, pStr);
} }
@ -94,7 +94,7 @@ CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
return (*_cuLinkDestroy)(state); return (*_cuLinkDestroy)(state);
} }
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, const void* image) const CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
{ {
return (*_cuModuleLoadData)(module, image); return (*_cuModuleLoadData)(module, image);
} }
@ -105,24 +105,24 @@ CUresult CUDADriverWrapper::cuLinkCreate(
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut); return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
} }
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
{ {
return (*_cuModuleGetFunction)(hfunc, hmod, name); return (*_cuModuleGetFunction)(hfunc, hmod, name);
} }
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
{ {
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name); return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
} }
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
unsigned int numOptions, CUjit_option* options, void** optionValues) const unsigned int numOptions, CUjit_option* options, void** optionValues) const
{ {
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues); return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
} }
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
{ {
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues); return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
} }


@ -37,7 +37,7 @@ public:
~CUDADriverWrapper(); ~CUDADriverWrapper();
CUresult cuGetErrorName(CUresult error, const char** pStr) const; CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const; CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
@ -47,19 +47,19 @@ public:
CUresult cuLinkDestroy(CUlinkState state) const; CUresult cuLinkDestroy(CUlinkState state) const;
CUresult cuModuleLoadData(CUmodule* module, const void* image) const; CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
CUresult cuLinkCreate( CUresult cuLinkCreate(
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const; unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const; CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const; CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions, CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
CUjit_option* options, void** optionValues) const; CUjit_option* options, void** optionValues) const;
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name, CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
unsigned int numOptions, CUjit_option* options, void** optionValues) const; unsigned int numOptions, CUjit_option* options, void** optionValues) const;
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
@ -72,18 +72,18 @@ public:
private: private:
void* handle; void* handle;
CUresult (*_cuGetErrorName)(CUresult, const char**); CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int); CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*); CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule); CUresult (*_cuModuleUnload)(CUmodule);
CUresult (*_cuLinkDestroy)(CUlinkState); CUresult (*_cuLinkDestroy)(CUlinkState);
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*); CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
CUresult (*_cuModuleLoadData)(CUmodule*, const void*); CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*); CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*); CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**); CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLinkAddData)( CUresult (*_cuLinkAddData)(
CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**); CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int, CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
unsigned int, unsigned int, unsigned int, CUstream, void**); unsigned int, unsigned int, unsigned int, CUstream, void**);
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ, CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
@ -91,11 +91,11 @@ private:
CUstream hStream, void** kernelParams, void** extra); CUstream hStream, void** kernelParams, void** extra);
}; };
inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line) inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)
{ {
if (stat != CUDA_SUCCESS) if (stat != CUDA_SUCCESS)
{ {
const char* msg = nullptr; char const* msg = nullptr;
wrap.cuGetErrorName(stat, &msg); wrap.cuGetErrorName(stat, &msg);
fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line); fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
} }


@ -121,16 +121,16 @@ void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaSt
} }
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>( template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
float* dst, const float* src, const int64_t numel, cudaStream_t stream); float* dst, float const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>( template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
float* dst, const __nv_fp8_e4m3* src, const int64_t numel, cudaStream_t stream); float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>( template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
half* dst, const half* src, const int64_t numel, cudaStream_t stream); half* dst, half const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>( template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
__nv_bfloat16* dst, const __nv_bfloat16* src, const int64_t numel, cudaStream_t stream); __nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
template void invokeFakeQuantize<float, half, float>( template void invokeFakeQuantize<float, half, float>(
half* dst, const float* src, const int64_t numel, cudaStream_t stream); half* dst, float const* src, const int64_t numel, cudaStream_t stream);
__device__ float atomicMaxExtd(float* address, float val) __device__ float atomicMaxExtd(float* address, float val)
{ {
@ -146,7 +146,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16"); static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
// The address in 64 bits. // The address in 64 bits.
uint64_t address_u64 = reinterpret_cast<const uint64_t&>(address); uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
// Pack the input value into 32 bits. // Pack the input value into 32 bits.
union union
@ -155,7 +155,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
uint16_t u[2]; uint16_t u[2];
} old, tmp = {}; } old, tmp = {};
const int loc = (address_u64 & 0x2) >> 1; int const loc = (address_u64 & 0x2) >> 1;
tmp.v[loc] = val; tmp.v[loc] = val;
// 4B aligned pointer. // 4B aligned pointer.
@ -223,7 +223,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
auto val = fabs(static_cast<float>(weights[i])); auto val = fabs(static_cast<float>(weights[i]));
max = max > val ? max : val; max = max > val ? max : val;
} }
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (std::is_same_v<T_S, float>) if constexpr (std::is_same_v<T_S, float>)
{ {
@ -231,7 +231,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
} }
else else
{ {
const auto address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col); auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0)) if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
atomicMaxExtd(quant_ptr + col, scale); atomicMaxExtd(quant_ptr + col, scale);
else else
@ -244,7 +244,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
} }
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN) else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
{ {
const auto nrows = size / n; auto const nrows = size / n;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{ {
float max = 0.f; float max = 0.f;
@ -256,7 +256,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
max = blockReduceMax<float>(max); max = blockReduceMax<float>(max);
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
quant_ptr[row] = scale; quant_ptr[row] = scale;
} }
} }
@ -272,7 +272,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
max = blockReduceMax<float>(max); max = blockReduceMax<float>(max);
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
atomicMaxExtd(quant_ptr, scale); atomicMaxExtd(quant_ptr, scale);
} }
} }
@ -326,19 +326,19 @@ __global__ void dynamicQuantizeMatrixPerToken(
extern __shared__ __align__(sizeof(float)) char _shmem[]; extern __shared__ __align__(sizeof(float)) char _shmem[];
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem); T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f); constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
const auto nrows = numel / lda; auto const nrows = numel / lda;
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x) for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
{ {
float max = 0.f; float max = 0.f;
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{ {
const auto in = input[row * lda + i]; auto const in = input[row * lda + i];
shmem[i] = in; shmem[i] = in;
auto val = fabs(static_cast<float>(in)); auto val = fabs(static_cast<float>(in));
max = max > val ? max : val; max = max > val ? max : val;
} }
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
const auto s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor); auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x) for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
{ {
// true means we are quantizing // true means we are quantizing
@ -359,7 +359,7 @@ void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T
{ {
dim3 grid(numel / lda); dim3 grid(numel / lda);
bool use_shmem = true; bool use_shmem = true;
const auto shmem_size = lda * sizeof(T_IN); auto const shmem_size = lda * sizeof(T_IN);
if (shmem_size >= (48 << 10)) if (shmem_size >= (48 << 10))
{ {
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>, cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
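The PER_TOKEN branch above reduces each row to its absolute maximum, divides by the FP8 E4M3 maximum, and clamps against min_scaling_factor so that an all-zero row still gets a usable scale. Below is a host-side sketch of the same arithmetic; it assumes the usual E4M3 maximum of 448 and the 1/(448*512) floor used in the kernel, and the function and variable names are illustrative, not the kernel's API.
```
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Illustrative constants mirroring the kernel: E4M3 max finite value and the scale floor.
constexpr float kFp8E4m3Max = 448.f;
constexpr float kMinScalingFactor = 1.f / (kFp8E4m3Max * 512.f);

// Compute one scale per row of a [nrows x lda] matrix, as the PER_TOKEN branch does.
std::vector<float> perTokenScales(std::vector<float> const& x, size_t nrows, size_t lda)
{
    std::vector<float> scales(nrows);
    for (size_t row = 0; row < nrows; ++row)
    {
        float absMax = 0.f;
        for (size_t i = 0; i < lda; ++i)
            absMax = std::max(absMax, std::fabs(x[row * lda + i]));
        scales[row] = std::max(absMax / kFp8E4m3Max, kMinScalingFactor);
    }
    return scales;
}

int main()
{
    std::vector<float> x = {0.5f, -3.f, 2.f, 100.f, -250.f, 7.f};
    auto s = perTokenScales(x, /*nrows=*/2, /*lda=*/3);
    std::printf("row scales: %g %g\n", s[0], s[1]);
}
```
Each row would then be quantized by dividing its elements by the row's scale before the cast to FP8.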


@ -181,37 +181,37 @@ struct PackType<__nv_fp8_e4m3, 8>
}; };
#endif #endif
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, const __nv_fp8x4_e4m3* in) __inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
{ {
const char4 tmp_val = reinterpret_cast<const char4*>(in)[0]; const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0], *out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0], *out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
} }
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(const __nv_fp8x2_e4m3* in) __inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
{ {
const char2 tmp_val = reinterpret_cast<const char2*>(in)[0]; const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0], __nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out; return out;
} }
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, const __nv_fp8x4_e4m3* in) __inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
{ {
const char4 tmp_val = reinterpret_cast<const char4*>(in)[0]; const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
*out1 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0], *out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
*out2 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0], *out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
} }
__inline__ __device__ half2 fp8x2_e4m3_to_half2(const __nv_fp8x2_e4m3* in) __inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
{ {
const char2 tmp_val = reinterpret_cast<const char2*>(in)[0]; const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
half2 out = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0], half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]); (float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
return out; return out;
} }
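These helpers unpack packed FP8 values by reinterpreting them as char2/char4 and converting each byte through __nv_fp8_e4m3. For reference, here is a standalone host-side decoder for a single E4M3 byte (1 sign bit, 4 exponent bits with bias 7, 3 mantissa bits; the all-ones exponent with mantissa 0b111 encodes NaN, and there is no infinity). It is purely illustrative and not the intrinsic path used above.
```
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decode one FP8 E4M3 byte to float (sign | 4-bit exponent, bias 7 | 3-bit mantissa).
float fp8E4m3ToFloat(uint8_t b)
{
    int const sign = (b >> 7) & 0x1;
    int const exp = (b >> 3) & 0xF;
    int const man = b & 0x7;
    float value;
    if (exp == 0)
        value = std::ldexp(man / 8.f, -6);            // subnormal: 2^-6 * m/8
    else if (exp == 0xF && man == 0x7)
        value = std::nanf("");                        // E4M3 has NaN but no infinity
    else
        value = std::ldexp(1.f + man / 8.f, exp - 7); // normal: 2^(e-7) * (1 + m/8)
    return sign ? -value : value;
}

int main()
{
    // 0x7E = 0 1111 110 -> 2^8 * 1.75 = 448, the largest finite E4M3 value.
    std::printf("0x7E -> %g\n", fp8E4m3ToFloat(0x7E));
    // 0x38 = 0 0111 000 -> 2^0 * 1.0 = 1
    std::printf("0x38 -> %g\n", fp8E4m3ToFloat(0x38));
}
```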


@ -32,14 +32,14 @@ namespace common
{ {
template <typename T> template <typename T>
inline __device__ T ldg(const T* val) inline __device__ T ldg(T const* val)
{ {
return __ldg(val); return __ldg(val);
} }
#if ENABLE_BF16 #if ENABLE_BF16
template <> template <>
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val) inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
{ {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0]; return val[0];
@ -49,7 +49,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
} }
template <> template <>
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val) inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
{ {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
return val[0]; return val[0];


@ -81,12 +81,12 @@ enum class OperationType
}; };
/* **************************** debug tools ********************************* */ /* **************************** debug tools ********************************* */
static const char* _cudaGetErrorEnum(cudaError_t error) static char const* _cudaGetErrorEnum(cudaError_t error)
{ {
return cudaGetErrorString(error); return cudaGetErrorString(error);
} }
static const char* _cudaGetErrorEnum(cublasStatus_t error) static char const* _cudaGetErrorEnum(cublasStatus_t error)
{ {
switch (error) switch (error)
{ {
@ -114,7 +114,7 @@ static const char* _cudaGetErrorEnum(cublasStatus_t error)
} }
template <typename T> template <typename T>
void check(T result, char const* const func, const char* const file, int const line) void check(T result, char const* const func, char const* const file, int const line)
{ {
if (result) if (result)
{ {
@ -133,7 +133,7 @@ inline bool isCudaLaunchBlocking()
if (firstCall) if (firstCall)
{ {
const char* env = std::getenv("CUDA_LAUNCH_BLOCKING"); char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
result = env != nullptr && std::string(env) == "1"; result = env != nullptr && std::string(env) == "1";
firstCall = false; firstCall = false;
} }
@ -141,12 +141,12 @@ inline bool isCudaLaunchBlocking()
return result; return result;
} }
inline void syncAndCheck(const char* const file, int const line) inline void syncAndCheck(char const* const file, int const line)
{ {
#ifndef NDEBUG #ifndef NDEBUG
const bool checkError = true; bool const checkError = true;
#else #else
const bool checkError = isCudaLaunchBlocking(); bool const checkError = isCudaLaunchBlocking();
#endif #endif
if (checkError) if (checkError)
@ -279,7 +279,7 @@ inline int getDeviceCount()
/// Get the memory info /// Get the memory info
/// \return The free and total amount of memory in bytes /// \return The free and total amount of memory in bytes
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(const bool useUvm) inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
{ {
if (useUvm) if (useUvm)
{ {
@ -351,7 +351,7 @@ auto constexpr ceilDiv(T numerator, U denominator)
} }
template <typename T> template <typename T>
void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string name = "") void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "")
{ {
if (buf == nullptr) if (buf == nullptr)
{ {
@ -390,9 +390,9 @@ void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string
} }
template <typename T> template <typename T>
void printToStream(const T* result, const int size, FILE* strm) void printToStream(T const* result, int const size, FILE* strm)
{ {
const bool split_rows = (strm == stdout); bool const split_rows = (strm == stdout);
if (result == nullptr) if (result == nullptr)
{ {
TLLM_LOG_WARNING("It is an nullptr, skip! \n"); TLLM_LOG_WARNING("It is an nullptr, skip! \n");
@ -414,13 +414,13 @@ void printToStream(const T* result, const int size, FILE* strm)
} }
template <typename T> template <typename T>
void printToScreen(const T* result, const int size) void printToScreen(T const* result, int const size)
{ {
printToStream(result, size, stdout); printToStream(result, size, stdout);
} }
template <typename T> template <typename T>
void print2dToStream(const T* result, const int r, const int c, const int stride, FILE* strm) void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm)
{ {
if (result == nullptr) if (result == nullptr)
{ {
@ -429,20 +429,20 @@ void print2dToStream(const T* result, const int r, const int c, const int stride
} }
for (int ri = 0; ri < r; ++ri) for (int ri = 0; ri < r; ++ri)
{ {
const T* ptr = result + ri * stride; T const* ptr = result + ri * stride;
printToStream(ptr, c, strm); printToStream(ptr, c, strm);
} }
fprintf(strm, "\n"); fprintf(strm, "\n");
} }
template <typename T> template <typename T>
void print2dToScreen(const T* result, const int r, const int c, const int stride) void print2dToScreen(T const* result, int const r, int const c, int const stride)
{ {
print2dToStream(result, r, c, stride, stdout); print2dToStream(result, r, c, stride, stdout);
} }
template <typename T> template <typename T>
void print2dToFile(std::string fname, const T* result, const int r, const int c, const int stride) void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride)
{ {
FILE* fp = fopen(fname.c_str(), "wt"); FILE* fp = fopen(fname.c_str(), "wt");
if (fp != nullptr) if (fp != nullptr)
@ -493,7 +493,7 @@ inline void print_element_(int64_t ill)
} }
template <typename T> template <typename T>
inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_ptr) inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr)
{ {
T* tmp; T* tmp;
if (is_device_ptr) if (is_device_ptr)
@ -538,14 +538,14 @@ inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_p
} }
} }
template void printMatrix(const float* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr);
template void printMatrix(const half* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void printMatrix(const __nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr);
#endif #endif
template void printMatrix(const uint32_t* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr);
template void printMatrix(const uint64_t* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr);
template void printMatrix(const int* ptr, int m, int k, int stride, bool is_device_ptr); template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr);
} // namespace tensorrt_llm::common } // namespace tensorrt_llm::common
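The debug helpers above wrap CUDA calls with check()/syncAndCheck(), turning an error code into a message that names the failing file and line. A self-contained sketch of that pattern as a macro follows; the macro name is invented and this is not the project's actual TLLM_* machinery.
```
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Minimal CUDA error check: report the error string with file and line, then abort.
#define CHECK_CUDA(call)                                                              \
    do                                                                                \
    {                                                                                 \
        cudaError_t err = (call);                                                     \
        if (err != cudaSuccess)                                                       \
        {                                                                             \
            std::fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(err), \
                __FILE__, __LINE__);                                                  \
            std::exit(EXIT_FAILURE);                                                  \
        }                                                                             \
    } while (0)

int main()
{
    float* d = nullptr;
    CHECK_CUDA(cudaMalloc(&d, 1024 * sizeof(float)));
    CHECK_CUDA(cudaFree(d));
}
```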


@ -25,7 +25,7 @@ namespace tensorrt_llm::common
// XQA kernels (optimized kernels for generation phase). // XQA kernels (optimized kernels for generation phase).
bool forceXQAKernels() bool forceXQAKernels()
{ {
const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA"); char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
static bool forceXQA = false; static bool forceXQA = false;
if (force_xqa_env_var != nullptr) if (force_xqa_env_var != nullptr)
{ {
@ -45,7 +45,7 @@ bool getEnvMmhaMultiblockDebug()
if (!init) if (!init)
{ {
init = true; init = true;
const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG"); char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
if (enable_mmha_debug_var) if (enable_mmha_debug_var)
{ {
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0') if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
@ -64,7 +64,7 @@ int getEnvMmhaBlocksPerSequence()
if (!init) if (!init)
{ {
init = true; init = true;
const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE"); char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
if (mmhaBlocksPerSequenceEnv) if (mmhaBlocksPerSequenceEnv)
{ {
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv); mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
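forceXQAKernels, getEnvMmhaMultiblockDebug, and getEnvMmhaBlocksPerSequence all follow one pattern: read an environment variable once, cache the parsed value, and return the cached result on later calls. A minimal sketch of that pattern, with a made-up variable name standing in for the real ones:
```
#include <cstdlib>
#include <cstring>
#include <cstdio>

// Parse a 0/1 environment flag once and cache it; later calls reuse the cached value.
bool getEnvFlag()
{
    static bool init = false;
    static bool enabled = false;
    if (!init)
    {
        init = true;
        char const* env = std::getenv("MYLIB_ENABLE_FEATURE"); // hypothetical variable name
        enabled = env != nullptr && std::strcmp(env, "1") == 0;
    }
    return enabled;
}

int main()
{
    std::printf("feature enabled: %d\n", getEnvFlag() ? 1 : 0);
}
```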


@ -65,5 +65,4 @@ Logger* Logger::getLogger()
thread_local Logger instance; thread_local Logger instance;
return &instance; return &instance;
} }
} // namespace tensorrt_llm::common } // namespace tensorrt_llm::common


@ -54,26 +54,26 @@ public:
#if defined(_MSC_VER) #if defined(_MSC_VER)
template <typename... Args> template <typename... Args>
void log(Level level, char const* format, const Args&... args); void log(Level level, char const* format, Args const&... args);
template <typename... Args> template <typename... Args>
void log(Level level, int rank, char const* format, const Args&... args); void log(Level level, int rank, char const* format, Args const&... args);
#else #else
template <typename... Args> template <typename... Args>
void log(Level level, char const* format, const Args&... args) __attribute__((format(printf, 3, 0))); void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
template <typename... Args> template <typename... Args>
void log(Level level, int rank, char const* format, const Args&... args) __attribute__((format(printf, 4, 0))); void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
#endif #endif
template <typename... Args> template <typename... Args>
void log(Level level, std::string const& format, const Args&... args) void log(Level level, std::string const& format, Args const&... args)
{ {
return log(level, format.c_str(), args...); return log(level, format.c_str(), args...);
} }
template <typename... Args> template <typename... Args>
void log(const Level level, const int rank, const std::string& format, const Args&... args) void log(const Level level, int const rank, std::string const& format, Args const&... args)
{ {
return log(level, rank, format.c_str(), args...); return log(level, rank, format.c_str(), args...);
} }
@ -122,7 +122,7 @@ private:
return fmtstr("%s[%s] ", kPREFIX, getLevelName(level)); return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
} }
static inline std::string getPrefix(const Level level, const int rank) static inline std::string getPrefix(const Level level, int const rank)
{ {
return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank); return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
} }
@ -148,7 +148,7 @@ void Logger::log(Logger::Level level, char const* format, Args const&... args)
} }
template <typename... Args> template <typename... Args>
void Logger::log(const Logger::Level level, const int rank, char const* format, const Args&... args) void Logger::log(const Logger::Level level, int const rank, char const* format, Args const&... args)
{ {
if (level_ <= level) if (level_ <= level)
{ {
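Logger::log filters by level, builds a level/rank prefix, and forwards the printf-style format string. The stand-in below shows only the level-gating idea; it uses C varargs for brevity, unlike the variadic templates above, and the names are illustrative.
```
#include <cstdarg>
#include <cstdio>

enum class Level { kDebug = 0, kInfo = 1, kWarning = 2, kError = 3 };

// Print only messages at or above the configured threshold, prefixed with the level name.
void logMessage(Level threshold, Level level, char const* format, ...)
{
    if (level < threshold)
        return;
    static char const* names[] = {"DEBUG", "INFO", "WARNING", "ERROR"};
    std::printf("[%s] ", names[static_cast<int>(level)]);
    va_list args;
    va_start(args, format);
    std::vprintf(format, args);
    va_end(args);
    std::printf("\n");
}

int main()
{
    logMessage(Level::kInfo, Level::kDebug, "dropped: %d", 1);     // below threshold, suppressed
    logMessage(Level::kInfo, Level::kWarning, "kept: rank %d", 0); // printed with [WARNING] prefix
}
```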


@ -112,63 +112,63 @@ template void deviceFill(int* devptr, size_t size, int value, cudaStream_t strea
template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream); template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream);
template <typename T> template <typename T>
void cudaD2Hcpy(T* tgt, const T* src, const size_t size) void cudaD2Hcpy(T* tgt, T const* src, const size_t size)
{ {
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost)); check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost));
} }
template void cudaD2Hcpy(float* tgt, const float* src, size_t size); template void cudaD2Hcpy(float* tgt, float const* src, size_t size);
template void cudaD2Hcpy(half* tgt, const half* src, size_t size); template void cudaD2Hcpy(half* tgt, half const* src, size_t size);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void cudaD2Hcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size); template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
#endif #endif
template void cudaD2Hcpy(int* tgt, const int* src, size_t size); template void cudaD2Hcpy(int* tgt, int const* src, size_t size);
template void cudaD2Hcpy(bool* tgt, const bool* src, size_t size); template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size);
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size); template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
#endif #endif
template void cudaD2Hcpy(unsigned long long* tgt, const unsigned long long* src, size_t size); template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
template void cudaD2Hcpy(unsigned int* tgt, const unsigned int* src, size_t size); template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size);
template void cudaD2Hcpy(int8_t* tgt, const int8_t* src, size_t size); template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size);
template <typename T> template <typename T>
void cudaH2Dcpy(T* tgt, const T* src, const size_t size) void cudaH2Dcpy(T* tgt, T const* src, const size_t size)
{ {
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice)); check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice));
} }
template void cudaH2Dcpy(float* tgt, const float* src, size_t size); template void cudaH2Dcpy(float* tgt, float const* src, size_t size);
template void cudaH2Dcpy(half* tgt, const half* src, size_t size); template void cudaH2Dcpy(half* tgt, half const* src, size_t size);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void cudaH2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size); template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
#endif #endif
template void cudaH2Dcpy(int* tgt, const int* src, size_t size); template void cudaH2Dcpy(int* tgt, int const* src, size_t size);
template void cudaH2Dcpy(bool* tgt, const bool* src, size_t size); template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size);
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size); template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
#endif #endif
template void cudaH2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size); template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
template void cudaH2Dcpy(unsigned int* tgt, const unsigned int* src, size_t size); template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size);
template void cudaH2Dcpy(int8_t* tgt, const int8_t* src, size_t size); template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size);
template <typename T> template <typename T>
void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream) void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
{ {
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream)); check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream));
} }
template void cudaD2Dcpy(float* tgt, const float* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(half* tgt, const half* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void cudaD2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
#endif #endif
template void cudaD2Dcpy(int* tgt, const int* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
template void cudaD2Dcpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_FP8 #ifdef ENABLE_FP8
template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream);
#endif #endif
template void cudaD2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream); template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
template <typename T_OUT, typename T_IN> template <typename T_OUT, typename T_IN>
__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size) __global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size)
@ -204,7 +204,7 @@ template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const si
#endif #endif
template <typename T> template <typename T>
void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream) void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
{ {
if (stream != NULL) if (stream != NULL)
{ {
@ -216,19 +216,19 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
} }
} }
template void cudaAutoCpy(float* tgt, const float* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(half* tgt, const half* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void cudaAutoCpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
#endif #endif
template void cudaAutoCpy(int* tgt, const int* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(uint8_t* tgt, const uint8_t* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(uint32_t* tgt, const uint32_t* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(unsigned long* tgt, const unsigned long* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(char* tgt, const char* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream);
template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream); template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream);
@ -242,7 +242,7 @@ template void cudaAutoCpy(
unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream); unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream);
template <typename T> template <typename T>
__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const int seq_offset) __global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset)
{ {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state; curandState_t local_state;
@ -254,7 +254,7 @@ __global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const i
} }
template <> template <>
__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, const int seq_offset) __global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, int const seq_offset)
{ {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state; curandState_t local_state;
@ -266,7 +266,7 @@ __global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size,
} }
template <> template <>
__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, const int seq_offset) __global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, int const seq_offset)
{ {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state; curandState_t local_state;
@ -278,7 +278,7 @@ __global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size
} }
template <> template <>
__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, const int seq_offset) __global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, int const seq_offset)
{ {
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
curandState_t local_state; curandState_t local_state;
@ -462,30 +462,30 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud
cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size); cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size);
} }
template void invokeCudaD2DcpyConvert(int8_t* tgt, const float* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, const int8_t* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, const int* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(half* tgt, const int* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, const float* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(half* tgt, const float* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, const half* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(uint32_t* tgt, const int* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, const uint32_t* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, const float* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, const half* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const float* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const int* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(float* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
template void invokeCudaD2DcpyConvert(int* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream); template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
#endif // ENABLE_BF16 #endif // ENABLE_BF16
template <typename T_IN, typename T_OUT> template <typename T_IN, typename T_OUT>
__global__ void cudaD2DScaleCpyConvert( __global__ void cudaD2DScaleCpyConvert(
T_OUT* dst, const T_IN* src, const float* scale, bool invert_scale, const size_t size) T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size)
{ {
const float scale_value = invert_scale ? 1.0f / scale[0] : scale[0]; float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x) for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
{ {
dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value); dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value);
@ -494,7 +494,7 @@ __global__ void cudaD2DScaleCpyConvert(
template <typename T_IN, typename T_OUT> template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert( void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream) T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream)
{ {
cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size); cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size);
} }
@ -524,7 +524,7 @@ void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaSt
} }
template <typename T> template <typename T>
void saveToBinary(const T* ptr, const size_t size, std::string filename) void saveToBinary(T const* ptr, const size_t size, std::string filename)
{ {
std::vector<T> h_ptr(size); std::vector<T> h_ptr(size);
@ -541,14 +541,14 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename)
out.write((char*) float_ptr.data(), size * sizeof(float)); out.write((char*) float_ptr.data(), size * sizeof(float));
} }
template void saveToBinary(const float* ptr, const size_t size, std::string filename); template void saveToBinary(float const* ptr, const size_t size, std::string filename);
template void saveToBinary(const half* ptr, const size_t size, std::string filename); template void saveToBinary(half const* ptr, const size_t size, std::string filename);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void saveToBinary(const __nv_bfloat16* ptr, const size_t size, std::string filename); template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename);
#endif // ENABLE_BF16 #endif // ENABLE_BF16
template <> template <>
void saveToBinary(const int* ptr, const size_t size, std::string filename) void saveToBinary(int const* ptr, const size_t size, std::string filename)
{ {
std::vector<int> h_ptr(size); std::vector<int> h_ptr(size);
cudaD2Hcpy(h_ptr.data(), ptr, size); cudaD2Hcpy(h_ptr.data(), ptr, size);
@ -831,7 +831,7 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt)
} }
template <typename T> template <typename T>
__global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_within_range) __global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range)
{ {
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x) for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
{ {
@ -844,7 +844,7 @@ __global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_
} }
template <typename T> template <typename T>
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream) bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
{ {
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream); cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
@ -858,12 +858,12 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_
} }
template bool invokeCheckRange<int>( template bool invokeCheckRange<int>(
const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream); int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
/* /*
* Determine the total workspace size based on a vector containing multiple variable sizes. * Determine the total workspace size based on a vector containing multiple variable sizes.
*/ */
size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTES) size_t calcAlignedSize(std::vector<size_t> const& sizes, const size_t ALIGN_BYTES)
{ {
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2 // Check ALIGN_BYTES is a power of 2
@ -885,7 +885,7 @@ size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTE
* of each variable. * of each variable.
*/ */
void calcAlignedPointers( void calcAlignedPointers(
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES) std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES)
{ {
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1); const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
// Check ALIGN_BYTES is a power of 2 // Check ALIGN_BYTES is a power of 2
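calcAlignedSize and calcAlignedPointers carve a single workspace allocation into per-variable slices, rounding each slice up to ALIGN_BYTES. A host-only sketch of the same bookkeeping with the 256-byte default alignment; function names here are illustrative.
```
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Round each requested size up to alignBytes and sum: the total workspace size.
size_t alignedTotalSize(std::vector<size_t> const& sizes, size_t alignBytes = 256)
{
    assert((alignBytes & (alignBytes - 1)) == 0 && "alignment must be a power of 2");
    size_t const mask = ~(alignBytes - 1);
    size_t total = 0;
    for (size_t s : sizes)
        total += (s + alignBytes - 1) & mask;
    return total;
}

// Slice a base pointer into aligned sub-buffers, one per requested size.
std::vector<void*> alignedPointers(void* base, std::vector<size_t> const& sizes, size_t alignBytes = 256)
{
    size_t const mask = ~(alignBytes - 1);
    std::vector<void*> ptrs;
    ptrs.reserve(sizes.size());
    auto addr = reinterpret_cast<uintptr_t>(base);
    for (size_t s : sizes)
    {
        ptrs.push_back(reinterpret_cast<void*>(addr));
        addr += (s + alignBytes - 1) & mask;
    }
    return ptrs;
}

int main()
{
    std::vector<size_t> const sizes = {100, 300, 17};
    std::vector<char> workspace(alignedTotalSize(sizes)); // 256 + 512 + 256 = 1024 bytes
    auto ptrs = alignedPointers(workspace.data(), sizes);
    std::printf("total: %zu, offset of slice 1: %zu\n", alignedTotalSize(sizes),
        static_cast<size_t>(reinterpret_cast<uintptr_t>(ptrs[1]) - reinterpret_cast<uintptr_t>(ptrs[0])));
}
```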


@ -40,16 +40,16 @@ template <typename T>
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0); void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
template <typename T> template <typename T>
void cudaD2Hcpy(T* tgt, const T* src, const size_t size); void cudaD2Hcpy(T* tgt, T const* src, const size_t size);
template <typename T> template <typename T>
void cudaH2Dcpy(T* tgt, const T* src, const size_t size); void cudaH2Dcpy(T* tgt, T const* src, const size_t size);
template <typename T> template <typename T>
void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL); void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);
template <typename T> template <typename T>
void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL); void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);
template <typename T> template <typename T>
void cudaRandomUniform(T* buffer, const size_t size); void cudaRandomUniform(T* buffer, const size_t size);
@ -234,9 +234,9 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud
template <typename T_IN, typename T_OUT> template <typename T_IN, typename T_OUT>
void invokeCudaD2DScaleCpyConvert( void invokeCudaD2DScaleCpyConvert(
T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0); T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0);
inline bool checkIfFileExist(const std::string& file_path) inline bool checkIfFileExist(std::string const& file_path)
{ {
std::ifstream in(file_path, std::ios::in | std::ios::binary); std::ifstream in(file_path, std::ios::in | std::ios::binary);
if (in.is_open()) if (in.is_open())
@ -248,7 +248,7 @@ inline bool checkIfFileExist(const std::string& file_path)
} }
template <typename T> template <typename T>
void saveToBinary(const T* ptr, const size_t size, std::string filename); void saveToBinary(T const* ptr, const size_t size, std::string filename);
template <typename T_IN, typename T_fake_type> template <typename T_IN, typename T_fake_type>
void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream); void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
@ -256,10 +256,10 @@ void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
size_t cuda_datatype_size(TRTLLMCudaDataType dt); size_t cuda_datatype_size(TRTLLMCudaDataType dt);
template <typename T> template <typename T>
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream); bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);
size_t calcAlignedSize(const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256); size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
void calcAlignedPointers( void calcAlignedPointers(
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256); std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
} // namespace common } // namespace common
} // namespace tensorrt_llm } // namespace tensorrt_llm


@ -50,6 +50,7 @@ MPI_Datatype getMpiDtype(MpiType dtype)
{MpiType::kUINT64, MPI_UINT64_T}, {MpiType::kUINT64, MPI_UINT64_T},
{MpiType::kFP8, MPI_UINT8_T}, {MpiType::kFP8, MPI_UINT8_T},
{MpiType::kBF16, MPI_UINT16_T}, {MpiType::kBF16, MPI_UINT16_T},
{MpiType::kCHAR, MPI_CHAR},
}; };
return dtype_map.at(dtype); return dtype_map.at(dtype);
} }
@ -126,23 +127,6 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm)); MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm));
} }
void MpiComm::bcast(std::vector<int64_t>& packed, int root) const
{
int64_t nWords1;
auto const rank = getRank();
if (rank == root)
{
nWords1 = static_cast<int64_t>(packed.size());
}
auto const mpiInt64 = MpiTypeConverter<int64_t>::value;
bcast(&nWords1, 1, mpiInt64, root);
if (rank != root)
{
packed.resize(nWords1);
}
bcast(packed.data(), packed.size(), mpiInt64, root);
}
void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{ {
MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm)); MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm));
@ -162,12 +146,12 @@ MpiComm MpiComm::split(int color, int key) const
return MpiComm{splitComm, true}; return MpiComm{splitComm, true};
} }
void MpiComm::allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{ {
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm)); MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
} }
void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{ {
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm)); MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
} }
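MpiComm keeps an enum-to-MPI_Datatype table (now extended with kCHAR) and forwards to the matching MPI call such as MPI_Allreduce or MPI_Allgather. The reduced sketch below uses an invented three-entry enum, assumes a standard MPI installation, and is not the project's actual wrapper.
```
#include <mpi.h>
#include <cstdio>
#include <unordered_map>

enum class Dtype { kInt32, kFloat, kChar }; // illustrative subset of the real enum

MPI_Datatype toMpiDtype(Dtype d)
{
    static const std::unordered_map<Dtype, MPI_Datatype> map{
        {Dtype::kInt32, MPI_INT32_T}, {Dtype::kFloat, MPI_FLOAT}, {Dtype::kChar, MPI_CHAR}};
    return map.at(d);
}

// Sum `count` elements across all ranks, in the spirit of MpiComm::allreduce.
void allreduceSum(void const* sendbuf, void* recvbuf, int count, Dtype dtype, MPI_Comm comm)
{
    MPI_Allreduce(sendbuf, recvbuf, count, toMpiDtype(dtype), MPI_SUM, comm);
}

int main(int argc, char** argv)
{
    MPI_Init(&argc, &argv);
    int rank = 0;
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    int local = rank + 1, sum = 0;
    allreduceSum(&local, &sum, 1, Dtype::kInt32, MPI_COMM_WORLD);
    if (rank == 0)
        std::printf("sum over ranks: %d\n", sum);
    MPI_Finalize();
}
```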


@ -39,7 +39,7 @@ public:
constexpr QuantMode(QuantMode const&) noexcept = default; constexpr QuantMode(QuantMode const&) noexcept = default;
constexpr QuantMode& operator=(const QuantMode& other) noexcept = default; constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;
static constexpr QuantMode none() noexcept static constexpr QuantMode none() noexcept
{ {
@ -276,32 +276,32 @@ public:
return quantMode; return quantMode;
} }
constexpr QuantMode operator+(const QuantMode& other) const noexcept constexpr QuantMode operator+(QuantMode const& other) const noexcept
{ {
return QuantMode(mValue | other.mValue); return QuantMode(mValue | other.mValue);
} }
constexpr QuantMode& operator+=(const QuantMode& other) noexcept constexpr QuantMode& operator+=(QuantMode const& other) noexcept
{ {
return *this = *this + other; return *this = *this + other;
} }
constexpr QuantMode operator-(const QuantMode& other) const noexcept constexpr QuantMode operator-(QuantMode const& other) const noexcept
{ {
return QuantMode(mValue & ~other.mValue); return QuantMode(mValue & ~other.mValue);
} }
constexpr QuantMode& operator-=(const QuantMode& other) noexcept constexpr QuantMode& operator-=(QuantMode const& other) noexcept
{ {
return *this = *this - other; return *this = *this - other;
} }
constexpr bool operator==(const QuantMode& other) const noexcept constexpr bool operator==(QuantMode const& other) const noexcept
{ {
return mValue == other.mValue; return mValue == other.mValue;
} }
constexpr bool operator!=(const QuantMode& other) const noexcept constexpr bool operator!=(QuantMode const& other) const noexcept
{ {
return !(*this == other); return !(*this == other);
} }
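QuantMode composes quantization options as a bit set: operator+ is a bitwise OR that enables flags, operator- clears them, and equality compares the raw value. A minimal stand-alone flag type in the same spirit, with invented flag names:
```
#include <cstdint>
#include <cstdio>

class Flags
{
public:
    constexpr Flags() noexcept = default;
    static constexpr Flags none() noexcept { return Flags(0); }
    static constexpr Flags int8Weights() noexcept { return Flags(1u << 0); }     // invented flag
    static constexpr Flags perTokenScaling() noexcept { return Flags(1u << 1); } // invented flag

    constexpr Flags operator+(Flags const& o) const noexcept { return Flags(mValue | o.mValue); }
    constexpr Flags operator-(Flags const& o) const noexcept { return Flags(mValue & ~o.mValue); }
    constexpr bool operator==(Flags const& o) const noexcept { return mValue == o.mValue; }
    constexpr bool has(Flags const& o) const noexcept { return (mValue & o.mValue) == o.mValue; }

private:
    constexpr explicit Flags(uint32_t v) noexcept : mValue(v) {}
    uint32_t mValue = 0;
};

int main()
{
    constexpr Flags mode = Flags::int8Weights() + Flags::perTokenScaling();
    static_assert((mode - Flags::perTokenScaling()) == Flags::int8Weights(), "operator- clears a flag");
    std::printf("has per-token scaling: %d\n", static_cast<int>(mode.has(Flags::perTokenScaling())));
}
```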


@ -63,11 +63,11 @@ struct BytesToType<16>
}; };
template <int Bytes> template <int Bytes>
__device__ inline void copy(const void* local, void* data) __device__ inline void copy(void const* local, void* data)
{ {
using T = typename BytesToType<Bytes>::type; using T = typename BytesToType<Bytes>::type;
const T* in = static_cast<const T*>(local); T const* in = static_cast<T const*>(local);
T* out = static_cast<T*>(data); T* out = static_cast<T*>(data);
*out = *in; *out = *in;
} }
@ -257,8 +257,8 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float*
cg::thread_block cta = cg::this_thread_block(); cg::thread_block cta = cg::this_thread_block();
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta); cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
const int tid = cta.thread_rank(); int const tid = cta.thread_rank();
const int blockz = blockDim.x; int const blockz = blockDim.x;
for (int i = 0; i < NUM; i++) for (int i = 0; i < NUM; i++)
{ {
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
@ -325,7 +325,7 @@ struct TopK
__device__ __forceinline__ void init() __device__ __forceinline__ void init()
{ {
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
for (int i = 0; i < MAX_K; i++) for (int i = 0; i < MAX_K; i++)
@ -337,7 +337,7 @@ struct TopK
}; };
template <typename T, int MAX_K> template <typename T, int MAX_K>
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b) __device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
{ {
TopK<T, MAX_K> res = a; TopK<T, MAX_K> res = a;
for (int i = 0; i < MAX_K; ++i) for (int i = 0; i < MAX_K; ++i)
@ -368,19 +368,19 @@ struct TopK_2
}; };
template <typename T> template <typename T>
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b) __device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
{ {
return a.u > b.u ? a : b; return a.u > b.u ? a : b;
} }
template <typename T> template <typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input) __device__ __forceinline__ T clamp_inf_for_half(float const input)
{ {
return input; return input;
} }
template <> template <>
__device__ __forceinline__ half clamp_inf_for_half(const float input) __device__ __forceinline__ half clamp_inf_for_half(float const input)
{ {
// clamp inf values to enable fp16 training // clamp inf values to enable fp16 training
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000); return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
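clamp_inf_for_half keeps large values representable in fp16 by clamping them just inside the half range before the cast, so downstream math never sees an infinity. The same idea on the host, assuming HALF_FLT_MAX is 65504 (the largest finite fp16 value) and the 1000-unit margin used above:
```
#include <algorithm>
#include <cstdio>

constexpr float kHalfFltMax = 65504.f; // largest finite fp16 value

// Clamp a float into a range that stays finite after conversion to fp16.
float clampForHalf(float x)
{
    return x > 0.f ? std::min(x, kHalfFltMax - 1000.f) : std::max(x, -kHalfFltMax + 1000.f);
}

int main()
{
    std::printf("%g -> %g\n", 1e9f, clampForHalf(1e9f));   // clamped to 64504
    std::printf("%g -> %g\n", -1e9f, clampForHalf(-1e9f)); // clamped to -64504
    std::printf("%g -> %g\n", 3.5f, clampForHalf(3.5f));   // unchanged
}
```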


@ -152,7 +152,7 @@ Tensor Tensor::slice(std::vector<size_t> shape, size_t offset) const
return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset)); return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset));
} }
TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map) TensorMap::TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map)
{ {
for (auto& kv : tensor_map) for (auto& kv : tensor_map)
{ {
@ -167,7 +167,7 @@ TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map)
} }
} }
TensorMap::TensorMap(const std::vector<Tensor>& tensor_map) TensorMap::TensorMap(std::vector<Tensor> const& tensor_map)
{ {
for (size_t i = 0; i < tensor_map.size(); i++) for (size_t i = 0; i < tensor_map.size(); i++)
{ {


@ -191,7 +191,7 @@ struct TensorDataType<int*>
}; };
template <> template <>
struct TensorDataType<const int*> struct TensorDataType<int const*>
{ {
static constexpr DataType value = TYPE_INT32_PTR; static constexpr DataType value = TYPE_INT32_PTR;
}; };
@ -419,8 +419,8 @@ private:
public: public:
TensorMap() = default; TensorMap() = default;
TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map); TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map);
TensorMap(const std::vector<Tensor>& tensor_map); TensorMap(std::vector<Tensor> const& tensor_map);
TensorMap(std::initializer_list<std::pair<std::string, Tensor>> tensor_map); TensorMap(std::initializer_list<std::pair<std::string, Tensor>> tensor_map);
~TensorMap(); ~TensorMap();
@ -429,7 +429,7 @@ public:
return tensor_map_.size(); return tensor_map_.size();
} }
inline bool contains(const std::string& key) const inline bool contains(std::string const& key) const
{ {
TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
return tensor_map_.find(key) != tensor_map_.end(); return tensor_map_.find(key) != tensor_map_.end();
@ -437,7 +437,7 @@ public:
std::vector<std::string> keys() const; std::vector<std::string> keys() const;
inline void insert(const std::string& key, const Tensor& value) inline void insert(std::string const& key, Tensor const& value)
{ {
TLLM_CHECK_WITH_INFO(!contains(key), fmtstr("Duplicated key %s", key.c_str())); TLLM_CHECK_WITH_INFO(!contains(key), fmtstr("Duplicated key %s", key.c_str()));
TLLM_CHECK_WITH_INFO( TLLM_CHECK_WITH_INFO(
@ -445,7 +445,7 @@ public:
tensor_map_.insert({key, value}); tensor_map_.insert({key, value});
} }
inline void insertIfValid(const std::string& key, const Tensor& value) inline void insertIfValid(std::string const& key, Tensor const& value)
{ {
if (value.isValid()) if (value.isValid())
{ {
@ -462,7 +462,7 @@ public:
Tensor at(int tmp) = delete; Tensor at(int tmp) = delete;
Tensor at(size_t tmp) = delete; Tensor at(size_t tmp) = delete;
inline Tensor& at(const std::string& key) inline Tensor& at(std::string const& key)
{ {
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
@ -471,7 +471,7 @@ public:
return tensor_map_.at(key); return tensor_map_.at(key);
} }
inline Tensor at(const std::string& key) const inline Tensor at(std::string const& key) const
{ {
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
fmtstr( fmtstr(
@ -479,7 +479,7 @@ public:
return tensor_map_.at(key); return tensor_map_.at(key);
} }
inline std::optional<Tensor> atOpt(const std::string& key) const inline std::optional<Tensor> atOpt(std::string const& key) const
{ {
if (contains(key)) if (contains(key))
return tensor_map_.at(key); return tensor_map_.at(key);
@ -487,7 +487,7 @@ public:
return std::nullopt; return std::nullopt;
} }
inline Tensor& at(const std::string& key, Tensor& default_tensor) inline Tensor& at(std::string const& key, Tensor& default_tensor)
{ {
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key)) if (contains(key))
@ -497,7 +497,7 @@ public:
return default_tensor; return default_tensor;
} }
inline Tensor at(const std::string& key, Tensor& default_tensor) const inline Tensor at(std::string const& key, Tensor& default_tensor) const
{ {
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key)) if (contains(key))
@ -507,7 +507,7 @@ public:
return default_tensor; return default_tensor;
} }
inline Tensor& at(const std::string& key, Tensor&& default_tensor) inline Tensor& at(std::string const& key, Tensor&& default_tensor)
{ {
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str()); TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
if (contains(key)) if (contains(key))
@ -517,7 +517,7 @@ public:
return default_tensor; return default_tensor;
} }
inline Tensor at(const std::string& key, Tensor&& default_tensor) const inline Tensor at(std::string const& key, Tensor&& default_tensor) const
{ {
if (contains(key)) if (contains(key))
{ {
@ -527,7 +527,7 @@ public:
} }
template <typename T> template <typename T>
inline T getVal(const std::string& key) const inline T getVal(std::string const& key) const
{ {
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
fmtstr( fmtstr(
@ -536,7 +536,7 @@ public:
} }
template <typename T> template <typename T>
inline std::optional<T> getValOpt(const std::string& key) const inline std::optional<T> getValOpt(std::string const& key) const
{ {
if (contains(key)) if (contains(key))
{ {
@ -549,7 +549,7 @@ public:
} }
template <typename T> template <typename T>
inline T getVal(const std::string& key, T default_value) const inline T getVal(std::string const& key, T default_value) const
{ {
if (contains(key)) if (contains(key))
{ {
@ -559,7 +559,7 @@ public:
} }
template <typename T> template <typename T>
inline T getValWithOffset(const std::string& key, size_t index) const inline T getValWithOffset(std::string const& key, size_t index) const
{ {
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
fmtstr( fmtstr(
@ -568,7 +568,7 @@ public:
} }
template <typename T> template <typename T>
inline T getValWithOffset(const std::string& key, size_t index, T default_value) const inline T getValWithOffset(std::string const& key, size_t index, T default_value) const
{ {
if (contains(key)) if (contains(key))
{ {
@ -578,7 +578,7 @@ public:
} }
template <typename T> template <typename T>
inline T* getPtr(const std::string& key) const inline T* getPtr(std::string const& key) const
{ {
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
fmtstr( fmtstr(
@ -587,7 +587,7 @@ public:
} }
template <typename T> template <typename T>
inline T* getPtr(const std::string& key, T* default_ptr) const inline T* getPtr(std::string const& key, T* default_ptr) const
{ {
if (contains(key)) if (contains(key))
{ {
@ -597,7 +597,7 @@ public:
} }
template <typename T> template <typename T>
inline T* getPtrWithOffset(const std::string& key, size_t index) const inline T* getPtrWithOffset(std::string const& key, size_t index) const
{ {
TLLM_CHECK_WITH_INFO(contains(key), TLLM_CHECK_WITH_INFO(contains(key),
fmtstr( fmtstr(
@ -606,7 +606,7 @@ public:
} }
template <typename T> template <typename T>
inline T* getPtrWithOffset(const std::string& key, size_t index, T* default_ptr) const inline T* getPtrWithOffset(std::string const& key, size_t index, T* default_ptr) const
{ {
if (contains(key)) if (contains(key))
{ {
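The header above gives TensorMap a family of at/getVal/getPtr accessors that either fail loudly on a missing key or fall back to a caller-supplied default. A stripped-down map with the same lookup-or-default behaviour, using a plain string-to-int map in place of tensors (keys are illustrative):
```
#include <cstdio>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Minimal lookup-or-default container in the spirit of TensorMap::at(key, default).
class IntMap
{
public:
    void insert(std::string const& key, int value)
    {
        if (contains(key))
            throw std::runtime_error("Duplicated key " + key);
        map_.insert({key, value});
    }

    bool contains(std::string const& key) const { return map_.find(key) != map_.end(); }

    int at(std::string const& key) const
    {
        if (!contains(key))
            throw std::runtime_error("Cannot find a value for key " + key);
        return map_.at(key);
    }

    int at(std::string const& key, int defaultValue) const
    {
        return contains(key) ? map_.at(key) : defaultValue;
    }

private:
    std::unordered_map<std::string, int> map_;
};

int main()
{
    IntMap m;
    m.insert("beam_width", 4);
    std::printf("beam_width = %d\n", m.at("beam_width"));
    std::printf("top_k = %d\n", m.at("top_k", /*defaultValue=*/1)); // missing key -> default
}
```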


@ -34,7 +34,7 @@ int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
#if !defined(_MSC_VER) #if !defined(_MSC_VER)
TllmException::TllmException(char const* file, std::size_t line, const std::string& msg) TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{""} : std::runtime_error{""}
{ {
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES); mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
@ -43,7 +43,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())}); std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
} }
#else #else
TllmException::TllmException(char const* file, std::size_t line, const std::string& msg) TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: mNbFrames{} : mNbFrames{}
, std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)} , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
{ {


@ -65,7 +65,7 @@ __forceinline__ __device__ float copysignf_pos(float a, float b)
__forceinline__ __device__ float tanh_opt(float x) __forceinline__ __device__ float tanh_opt(float x)
{ {
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750) #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
const float exp_val = -1.f * fabs(2 * x); float const exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x); return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else #else
return fast_tanh(x); return fast_tanh(x);
@ -76,7 +76,7 @@ __forceinline__ __device__ float tanh_opt(float x)
template <> template <>
struct GELU_taylor<float> struct GELU_taylor<float>
{ {
static const bool kIsHeavy = true; static bool const kIsHeavy = true;
CUTLASS_DEVICE CUTLASS_DEVICE
float operator()(float const& z) const float operator()(float const& z) const


@ -157,8 +157,8 @@ private:
MatrixCoord extent_real_; MatrixCoord extent_real_;
ElementwiseFunctor elementwise_; ElementwiseFunctor elementwise_;
const bool per_token_quant_; bool const per_token_quant_;
const bool per_channel_quant_; bool const per_channel_quant_;
AlphaScaleElementType* ptr_alpha_row_; AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_; AlphaScaleElementType* ptr_alpha_col_;


@ -65,7 +65,7 @@ namespace device
///////////////////////////////////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T_IN, typename T_OUT> template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk, __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
int64_t* splitk_buffer_offsets) int64_t* splitk_buffer_offsets)
{ {
// in_tensor: [problem_idx, k_partition, hidden_size] // in_tensor: [problem_idx, k_partition, hidden_size]
@ -73,9 +73,9 @@ __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const
// so, we need to use splitk_buffer_offsets. // so, we need to use splitk_buffer_offsets.
// out_tensor: problem_idx * [hidden_size] // out_tensor: problem_idx * [hidden_size]
const int problem_idx = blockIdx.y; int const problem_idx = blockIdx.y;
GemmCoord problem = problem_sizes[problem_idx]; GemmCoord problem = problem_sizes[problem_idx];
const int hidden_size = problem.m() * problem.n(); int const hidden_size = problem.m() * problem.n();
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk; const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
T_OUT* out_tensor_ = out_tensor[problem_idx]; T_OUT* out_tensor_ = out_tensor[problem_idx];
@ -143,7 +143,7 @@ protected:
private: private:
/// Get the number of tiles across all problems in a group /// Get the number of tiles across all problems in a group
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count) static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
{ {
int32_t tiles = 0; int32_t tiles = 0;
for (int32_t i = 0; i < problem_count; ++i) for (int32_t i = 0; i < problem_count; ++i)
@ -182,7 +182,7 @@ private:
/// Reorder `data` according to `indices` /// Reorder `data` according to `indices`
template <typename T> template <typename T>
static void reorder_array(T* data, const std::vector<size_t>& indices) static void reorder_array(T* data, std::vector<size_t> const& indices)
{ {
// For now, simply create a copy of the data and then copy over to the original. // For now, simply create a copy of the data and then copy over to the original.
std::vector<T> copy(indices.size()); std::vector<T> copy(indices.size());
@ -314,7 +314,7 @@ public:
/// Computes the number of threadblocks to launch for the grouped kernel /// Computes the number of threadblocks to launch for the grouped kernel
static int sufficient( static int sufficient(
const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1) cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
{ {
// Determine the number of blocks that would be launched to fill up a single // Determine the number of blocks that would be launched to fill up a single
// wave on the GPU with each SM having maximum occupancy. // wave on the GPU with each SM having maximum occupancy.
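
The `splitkReduction` kernel above sums the `splitk` partial results produced for each problem, locating them through per-problem buffer offsets because grouped problems can have different sizes. A CPU reference sketch of the same reduction (names and layout are illustrative, not the library's buffers):

```
#include <cstdint>
#include <vector>

void splitkReduceHost(std::vector<std::vector<float>>& out,    // [problem][hidden_size]
    std::vector<float> const& partials,                        // [problem, k_partition, hidden_size], flattened
    std::vector<int64_t> const& splitk_buffer_offsets,         // element offset of each problem's first partition
    int splitk)
{
    for (size_t p = 0; p < out.size(); ++p)
    {
        size_t const base = static_cast<size_t>(splitk_buffer_offsets[p]) * splitk;
        size_t const hidden_size = out[p].size();
        for (size_t i = 0; i < hidden_size; ++i)
        {
            float acc = 0.0f;
            for (int k = 0; k < splitk; ++k)
            {
                acc += partials[base + static_cast<size_t>(k) * hidden_size + i];
            }
            out[p][i] = acc; // element-wise sum over the k partitions
        }
    }
}
```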

View File

@ -142,7 +142,7 @@ struct GemmFpAIntB
Arguments() {} Arguments() {}
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Arguments(cutlass::gemm::GemmCoord const& problem_size, const int group_size, Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B, typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero, typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
typename Epilogue::OutputTileIterator::TensorRef ref_C, typename Epilogue::OutputTileIterator::TensorRef ref_C,
@ -206,7 +206,7 @@ struct GemmFpAIntB
} }
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, const int gemm_k_size, Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
void* workspace = nullptr) void* workspace = nullptr)
: problem_size(args.problem_size) : problem_size(args.problem_size)
, group_size(args.group_size) , group_size(args.group_size)

View File

@ -174,7 +174,7 @@ public:
/// Ctor /// Ctor
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op, Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C, ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k, ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k,
GemmCoord* host_problem_sizes = nullptr) GemmCoord* host_problem_sizes = nullptr)
: problem_count(problem_count) : problem_count(problem_count)

View File

@ -119,7 +119,7 @@ struct BaseMoeProblemVisitor
/// Get the grid shape /// Get the grid shape
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
{ {
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
@ -177,12 +177,12 @@ struct BaseMoeProblemVisitor
} }
CUTLASS_HOST_DEVICE CUTLASS_HOST_DEVICE
static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
{ {
return ProblemSizeHelper::tile_count(grid); return ProblemSizeHelper::tile_count(grid);
} }
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count) static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
{ {
int32_t total_tiles = 0; int32_t total_tiles = 0;
for (int32_t i = 0; i < problem_count; ++i) for (int32_t i = 0; i < problem_count; ++i)
@ -328,12 +328,12 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
} }
static size_t get_workspace_size( static size_t get_workspace_size(
const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count) cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
{ {
return 0; return 0;
} }
static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
int32_t block_count, void* host_workspace_ptr) int32_t block_count, void* host_workspace_ptr)
{ {
} }
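
`grid_shape` and `group_tile_count` above compute how many threadblock tiles cover each problem in the group; the `(x - 1 + tile) / tile` expression is the usual integer ceiling division. A small standalone sketch of that arithmetic (the tile sizes are made-up constants):

```
#include <cstdint>
#include <vector>

struct Coord
{
    int m, n;
};

constexpr int kTileM = 128;
constexpr int kTileN = 128;

int32_t tileCount(Coord const& problem)
{
    int32_t const tilesM = (problem.m + kTileM - 1) / kTileM; // ceil(m / kTileM)
    int32_t const tilesN = (problem.n + kTileN - 1) / kTileN; // ceil(n / kTileN)
    return tilesM * tilesN;
}

int32_t groupTileCount(std::vector<Coord> const& problems)
{
    int32_t tiles = 0;
    for (auto const& p : problems)
    {
        tiles += tileCount(p); // total threadblock tiles across the whole group
    }
    return tiles;
}
```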

View File

@ -60,7 +60,7 @@ namespace threadblock
template <typename WarpMma, int kExpansionFactor = 1> template <typename WarpMma, int kExpansionFactor = 1>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C, typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
const int warp_tileB_k_offset) int const warp_tileB_k_offset)
{ {
warp_mma(D, A, B, C); warp_mma(D, A, B, C);
} }
@ -68,7 +68,7 @@ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC&
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor> template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D, CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B, typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
typename WarpMma::FragmentC const& C, const int warp_tileB_k_offset) typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
{ {
warp_mma(D, A, B, C, warp_tileB_k_offset); warp_mma(D, A, B, C, warp_tileB_k_offset);
} }

View File

@ -572,8 +572,8 @@ public:
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_A_;
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{ {
this->warp_tile_iterator_B_.set_kgroup_index( this->warp_tile_iterator_B_.set_kgroup_index(

View File

@ -219,7 +219,7 @@ public:
///< Shared storage needed for internal use by threadblock-scoped GEMM ///< Shared storage needed for internal use by threadblock-scoped GEMM
typename Base::SharedStorage& shared_storage, typename Base::SharedStorage& shared_storage,
///< Group size for quantization. Not used by this main loop since it assumes per-column ///< Group size for quantization. Not used by this main loop since it assumes per-column
const int group_size, int const group_size,
///< ID within the threadblock ///< ID within the threadblock
int thread_idx, int thread_idx,
///< ID of warp ///< ID of warp
@ -534,8 +534,8 @@ public:
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_A_;
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{ {
this->warp_tile_iterator_B_.set_kgroup_index( this->warp_tile_iterator_B_.set_kgroup_index(

View File

@ -184,7 +184,7 @@ public:
CUTLASS_DEVICE CUTLASS_DEVICE
DqMmaPipelined(typename Base::SharedStorage& DqMmaPipelined(typename Base::SharedStorage&
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
const int group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
///< argument is not added, it does not affect compilation for sm>=80. ///< argument is not added, it does not affect compilation for sm>=80.
int thread_idx, ///< ID within the threadblock int thread_idx, ///< ID within the threadblock
@ -353,8 +353,8 @@ public:
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
++this->warp_tile_iterator_A_; ++this->warp_tile_iterator_A_;
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad; int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment. // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1) if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
{ {
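
The pattern repeated in the main loops above splits the warp-level K iteration index `warp_mma_k` into a B-fragment load index and a compute offset inside that fragment, because the quantized B operand is loaded at a coarser K granularity than A. A tiny host-side illustration of the `%` / `/` decomposition (the value 2 for `kNumKIterationsPerWarpBLoad` is an assumption for the example):

```
#include <cstdio>

int main()
{
    int const kNumKIterationsPerWarpBLoad = 2;
    for (int warp_mma_k = 0; warp_mma_k < 8; ++warp_mma_k)
    {
        int const compute_offset = warp_mma_k % kNumKIterationsPerWarpBLoad; // position within the loaded B fragment
        int const load_offset = warp_mma_k / kNumKIterationsPerWarpBLoad;    // which B fragment to use
        std::printf("warp_mma_k=%d -> load=%d, compute_offset=%d\n", warp_mma_k, load_offset, compute_offset);
    }
    return 0;
}
```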

View File

@ -218,7 +218,7 @@ public:
/// Performs a warp-level matrix multiply-accumulate operation /// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE CUTLASS_DEVICE
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C, void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
const int warp_tileB_k_offset) const int const warp_tileB_k_offset) const
{ {
using MmaOperandA = typename ArchMmaOperator::FragmentA; using MmaOperandA = typename ArchMmaOperator::FragmentA;

View File

@ -136,11 +136,11 @@ public:
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{ {
const int warp_offset = warp_idx_n * Shape::kN; int const warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4; int const quad = lane_idx / 4;
const int thread_offset = warp_offset + quad; int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset; pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp)) if constexpr (hasZero(QuantOp))
{ {
@ -149,7 +149,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{ {
} }
@ -165,7 +165,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{ {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB; using _MmaOperandB = typename ArchMmaOperator::FragmentB;
@ -174,7 +174,7 @@ public:
== FragmentDequantizedOperand::kElements, == FragmentDequantizedOperand::kElements,
""); "");
const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag); __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag); ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter) for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
@ -222,7 +222,7 @@ public:
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize( void dequantize(
FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag) FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{ {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16)) #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
using _MmaOperandB = typename ArchMmaOperator::FragmentB; using _MmaOperandB = typename ArchMmaOperator::FragmentB;
@ -231,8 +231,8 @@ public:
== FragmentDequantizedOperand::kElements, == FragmentDequantizedOperand::kElements,
""); "");
const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag); __nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
const __nv_bfloat16* zero_ptr = reinterpret_cast<const __nv_bfloat16*>(&zero_frag); __nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag); ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
CUTLASS_PRAGMA_UNROLL CUTLASS_PRAGMA_UNROLL
@ -335,11 +335,11 @@ public:
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_; static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
{ {
const int warp_offset = warp_idx_n * Shape::kN; int const warp_offset = warp_idx_n * Shape::kN;
const int quad = lane_idx / 4; int const quad = lane_idx / 4;
const int thread_offset = warp_offset + quad; int const thread_offset = warp_offset + quad;
pointer_scale_ = smem_scales.data() + thread_offset; pointer_scale_ = smem_scales.data() + thread_offset;
if constexpr (hasZero(QuantOp)) if constexpr (hasZero(QuantOp))
{ {
@ -348,7 +348,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx) : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
{ {
} }
@ -364,7 +364,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{ {
using _MmaOperandB = typename ArchMmaOperator::FragmentB; using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>; using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
@ -406,7 +406,7 @@ public:
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize( void dequantize(
FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag) FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
{ {
using _MmaOperandB = typename ArchMmaOperator::FragmentB; using _MmaOperandB = typename ArchMmaOperator::FragmentB;
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>; using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
@ -505,11 +505,11 @@ public:
static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
{ {
const int warp_offset = warp_idx_n * Shape::kN; int const warp_offset = warp_idx_n * Shape::kN;
const int base_col = lane_idx & 0xF8; int const base_col = lane_idx & 0xF8;
const int thread_offset = warp_offset + base_col; int const thread_offset = warp_offset + base_col;
pointer_ = smem_scales.data() + thread_offset; pointer_ = smem_scales.data() + thread_offset;
} }
@ -527,7 +527,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{ {
static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, ""); static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");
@ -591,11 +591,11 @@ public:
static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, ""); static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
CUTLASS_DEVICE CUTLASS_DEVICE
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx) MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
{ {
const int warp_offset = warp_idx_n * Shape::kN; int const warp_offset = warp_idx_n * Shape::kN;
const int base_col = lane_idx & 0xF8 + lane_idx % 4; int const base_col = lane_idx & 0xF8 + lane_idx % 4;
const int thread_offset = warp_offset + base_col; int const thread_offset = warp_offset + base_col;
pointer_ = smem_scales.data() + thread_offset; pointer_ = smem_scales.data() + thread_offset;
} }
@ -617,7 +617,7 @@ public:
} }
CUTLASS_DEVICE CUTLASS_DEVICE
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag) void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
{ {
using MmaOperandB = typename ArchMmaOperator::FragmentB; using MmaOperandB = typename ArchMmaOperator::FragmentB;
static constexpr int total_n_mmas = 2 * TileNIterations; static constexpr int total_n_mmas = 2 * TileNIterations;
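
The `PER_COLUMN_SCALE_ONLY` specializations above multiply each quantized weight fragment element by the scale of its output column. A host-side sketch of that scale-only dequantization on a plain row-major matrix (generic illustration, not the fragment-level tensor-core code):

```
#include <cstdint>
#include <vector>

std::vector<float> dequantizePerColumn(
    std::vector<int8_t> const& weights, std::vector<float> const& col_scales, int rows, int cols)
{
    std::vector<float> out(static_cast<size_t>(rows) * cols);
    for (int r = 0; r < rows; ++r)
    {
        for (int c = 0; c < cols; ++c)
        {
            // Every element in column c shares the same dequantization scale.
            out[static_cast<size_t>(r) * cols + c]
                = static_cast<float>(weights[static_cast<size_t>(r) * cols + c]) * col_scales[c];
        }
    }
    return out;
}
```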

View File

@ -167,8 +167,8 @@ public:
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment; static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
const int thread_row = thread_id / THREADS_PER_ROW; int const thread_row = thread_id / THREADS_PER_ROW;
const int thread_col = thread_id % THREADS_PER_ROW; int const thread_col = thread_id % THREADS_PER_ROW;
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8; const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8; const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
@ -182,11 +182,11 @@ public:
// a given iteration. The same threads will be responsible for issues reads since the number of scales // a given iteration. The same threads will be responsible for issues reads since the number of scales
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_ // read in a given iteration is a constant. Therefore, we should never have to update is_valid_
// outside of the constructor. // outside of the constructor.
const int global_row = threadblock_offset.row() + thread_row; int const global_row = threadblock_offset.row() + thread_row;
const int global_col = threadblock_offset.column() + thread_col * kAlignment; int const global_col = threadblock_offset.column() + thread_col * kAlignment;
const bool row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow; bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
const bool col_in_bounds = global_col < extent.column(); bool const col_in_bounds = global_col < extent.column();
is_valid_ = row_in_bounds && col_in_bounds; is_valid_ = row_in_bounds && col_in_bounds;
} }

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:4201c7241d53298ca52d4f1447cc9cbc4024f63b42a24cbcff82192cc10bed67 oid sha256:e1cdcabfbc5115c0d3228c567800d2706f1bc9e3752aaaa8148bcfe83be2c08c
size 576098 size 716756

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:2960feb2c7ad941a473408e2f6fd8c324f60f6af3c4d8f11217c676fd830e4cb oid sha256:ea48a79b211bc9857e7a881d6b9bc22580280e1d7cf3b30d6613466f4f440f8f
size 578660 size 721934

View File

@ -1,3 +1,3 @@
8a8d6505d9ef62cb2eeb8c75a5ee5bbb libtensorrt_llm_executor_static.a 56853a19cf213aa5330ea087c9d86a60 libtensorrt_llm_executor_static.a
e3b8edc619c99a7f125fe81bc8554ff0 libtensorrt_llm_executor_static.pre_cxx11.a 213487d55c816a1987aa79547091068f libtensorrt_llm_executor_static.pre_cxx11.a
230623fa285048a2de5c54c2cc0f364fb9f2c559 commit 741fb083cc42933439ae54557b177b6d7064da4f commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:cde295fa290b15b3d76b8e8b2cc435d7fceb2f456d8cb4d9b22ee2cf3ddbd344 oid sha256:499f3aac1b98c5b411f1dacdddf8521b2b1f600388b44e6f7aab5b3f0cdf1280
size 588504 size 721366

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1 version https://git-lfs.github.com/spec/v1
oid sha256:54ac66f3555bff4ed28ba0352bcb4a0f541346592cf109b491071b6374e5238c oid sha256:9c2c7e84be6b0e8baf296196ee9d7e84509bda2630ce3ada8a39dc498713ff48
size 562260 size 700000

View File

@ -1,2 +1,2 @@
ee96c6e2742539da0e8d732635f84449 libtensorrt_llm_executor_static.a dcca3b095dad76dac36611be6104f011 libtensorrt_llm_executor_static.a
9154564ed926ffbcdb83e7eac3504fa0 libtensorrt_llm_executor_static.pre_cxx11.a 6cae7ce493704f7ad8d724cf8a538e2c libtensorrt_llm_executor_static.pre_cxx11.a

View File

@ -25,9 +25,9 @@ namespace kernels
{ {
template <typename T> template <typename T>
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, __global__ void ban_repeat_ngram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
const int** parent_ids_buf, const int* batch_slots, int batch_size, int beam_width, int max_seq_len, int const** parent_ids_buf, int const* batch_slots, int batch_size, int beam_width, int max_seq_len,
const int* no_repeat_ngram_size_buf, int vocab_size_padded, const int* sequence_lengths) int const* no_repeat_ngram_size_buf, int vocab_size_padded, int const* sequence_lengths)
{ {
/** /**
* Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching * Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching
@ -46,13 +46,13 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
* in-bound positions only. For leftside out-of-boundary tokens, access by global memory. * in-bound positions only. For leftside out-of-boundary tokens, access by global memory.
*/ */
const int output_idx = blockIdx.x * blockDim.x + threadIdx.x; int const output_idx = blockIdx.x * blockDim.x + threadIdx.x;
const int local_batch_idx = blockIdx.y / beam_width; int const local_batch_idx = blockIdx.y / beam_width;
auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx; auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx;
const int beam_idx = blockIdx.y % beam_width; int const beam_idx = blockIdx.y % beam_width;
const bool beam_search = beam_width > 1; bool const beam_search = beam_width > 1;
const int no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot]; int const no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
const int step = sequence_lengths[batch_slot]; int const step = sequence_lengths[batch_slot];
// case 1: ngram_size == 0 --> this means no ngram limit // case 1: ngram_size == 0 --> this means no ngram limit
// case 2: generated length must be greater than ngram_size to do ngram check // case 2: generated length must be greater than ngram_size to do ngram check
@ -133,9 +133,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
} }
template <typename T> template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width, int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream) int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
{ {
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation // each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while // getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
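
As a CPU reference for the no-repeat-ngram rule the kernel above implements: every token that would complete an n-gram already present in the generated sequence is collected, so the caller can set its logit to a large negative value. A minimal sketch (single sequence, no beams or batch slots; returns nothing when `ngram_size == 0` or the sequence is still shorter than one n-gram):

```
#include <vector>

std::vector<int> bannedTokensForNoRepeatNgram(std::vector<int> const& output_ids, int ngram_size)
{
    std::vector<int> banned;
    int const step = static_cast<int>(output_ids.size());
    if (ngram_size == 0 || step < ngram_size)
    {
        return banned;
    }
    // The last (ngram_size - 1) generated tokens form the prefix that must not be extended into a repeat.
    int const prefix_start = step - (ngram_size - 1);
    for (int start = 0; start + ngram_size <= step; ++start)
    {
        bool match = true;
        for (int j = 0; j < ngram_size - 1; ++j)
        {
            if (output_ids[start + j] != output_ids[prefix_start + j])
            {
                match = false;
                break;
            }
        }
        if (match)
        {
            // Ban the token that completed the earlier occurrence of this n-gram.
            banned.push_back(output_ids[start + ngram_size - 1]);
        }
    }
    return banned;
}
```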

View File

@ -26,9 +26,9 @@ namespace kernels
{ {
template <typename T> template <typename T>
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf, void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width, int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream); int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
} // namespace kernels } // namespace kernels
} // namespace tensorrt_llm } // namespace tensorrt_llm

View File

@ -49,8 +49,8 @@ __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float
template <typename T, int MAX_K, int THREADBLOCK_SIZE> template <typename T, int MAX_K, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ __launch_bounds__(THREADBLOCK_SIZE) __global__
void beam_topK_kernel(const T* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, const bool* finished, void beam_topK_kernel(T const* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, bool const* finished,
const int* sequence_lengths, const int vocab_size, T diversity_rate, float length_penalty) int const* sequence_lengths, int const vocab_size, T diversity_rate, float length_penalty)
{ {
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
@ -59,7 +59,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
int block_id = blockIdx.x; // batch beam index. int block_id = blockIdx.x; // batch beam index.
TopK<T, MAX_K> partial; TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll #pragma unroll
@ -101,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
{ {
int thread_id = threadIdx.x; int thread_id = threadIdx.x;
int block_id = blockIdx.x; int block_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
TopK<T, MAX_K> partial; TopK<T, MAX_K> partial;
if (thread_id == 0) if (thread_id == 0)
@ -136,7 +136,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
int tid = threadIdx.x; int tid = threadIdx.x;
int bid = blockIdx.x; int bid = blockIdx.x;
TopK<T, MAX_K> partial; TopK<T, MAX_K> partial;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
#pragma unroll #pragma unroll
@ -167,32 +167,32 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
} }
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_> template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, __global__ void topk_stage_1_opt3(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size, T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
const float length_penalty, const int* end_ids) float const length_penalty, int const* end_ids)
{ {
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce; typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
const int tid = threadIdx.x; int const tid = threadIdx.x;
const int bid = blockIdx.x; int const bid = blockIdx.x;
const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index) int const row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam int const block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
const int tmp_log_buf_index = row_id * vocab_size; int const tmp_log_buf_index = row_id * vocab_size;
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k; int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
TopK_2<T> partial; TopK_2<T> partial;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
if (finished != nullptr && finished[row_id] == true) if (finished != nullptr && finished[row_id] == true)
{ {
if (tid < k) if (tid < k)
{ {
const int index = tmp_topk_buf_index + tid; int const index = tmp_topk_buf_index + tid;
if (block_lane == 0 && tid == 0) if (block_lane == 0 && tid == 0)
{ {
const int end_id = end_ids[row_id / k]; int const end_id = end_ids[row_id / k];
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id; topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id]; topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
} }
@ -226,7 +226,7 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
if (tid == 0) if (tid == 0)
{ {
const int index = tmp_topk_buf_index + ite; int const index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p; topk_tmp_id_buf[index] = total.p;
topk_tmp_val_buf[index] = total.u; topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL; tmp_log_probs[total.p] = -MAX_T_VAL;
@ -236,15 +236,15 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
} }
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_> template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, __global__ void topk_stage_2_opt3(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
BeamHypotheses beam_hyps, const int* end_ids, const int vocab_size, const int k) BeamHypotheses beam_hyps, int const* end_ids, int const vocab_size, int const k)
{ {
const int size = k * k * BLOCKS_PER_BEAM_; int const size = k * k * BLOCKS_PER_BEAM_;
const int tid = threadIdx.x; int const tid = threadIdx.x;
const int batch_id = blockIdx.x; int const batch_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]}; float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce; typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
@ -263,7 +263,7 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
__syncthreads(); __syncthreads();
if (beam_hyps.num_beams != nullptr) if (beam_hyps.num_beams != nullptr)
{ {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
{ {
// initialize the buffer // initialize the buffer
@ -304,9 +304,9 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
} }
else else
{ {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
const int num_beam = beam_hyps.num_beams[global_batch_idx]; int const num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam; int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of // If there are beam_width finished sentences, check that the score of
// selected candidatet is higher than min_normed_score or not. If // selected candidatet is higher than min_normed_score or not. If
@ -345,20 +345,20 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
} }
} }
} }
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
* beam_hyps.max_seq_len; * beam_hyps.max_seq_len;
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
for (int j = beam_hyps.step - 1; j >= 0; j--) for (int j = beam_hyps.step - 1; j >= 0; j--)
{ {
const int src_idx = j * beam_hyps.batch_size * k int const src_idx = j * beam_hyps.batch_size * k
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
prev_id = beam_hyps.parent_ids_src[src_idx]; prev_id = beam_hyps.parent_ids_src[src_idx];
} }
const int tgt_beam_idx = global_batch_idx * k + beam_idx; int const tgt_beam_idx = global_batch_idx * k + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx] beam_hyps.min_normed_scores[global_batch_idx]
@ -389,21 +389,21 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
} }
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM> template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf, __global__ void topk_stage_1_opt2_general(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size, T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
const float length_penalty) float const length_penalty)
{ {
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
const int tid = threadIdx.x; int const tid = threadIdx.x;
const int bid = blockIdx.x; int const bid = blockIdx.x;
const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs int const row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam int const block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
const int tmp_log_buf_index = row_id * vocab_size; int const tmp_log_buf_index = row_id * vocab_size;
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k; int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
TopK_2<T> partial; TopK_2<T> partial;
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM) for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM)
@ -426,7 +426,7 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
if (tid == 0) if (tid == 0)
{ {
const int index = tmp_topk_buf_index + ite; int const index = tmp_topk_buf_index + ite;
topk_tmp_id_buf[index] = total.p; topk_tmp_id_buf[index] = total.p;
topk_tmp_val_buf[index] = total.u; topk_tmp_val_buf[index] = total.u;
tmp_log_probs[total.p] = -MAX_T_VAL; tmp_log_probs[total.p] = -MAX_T_VAL;
@ -436,15 +436,15 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
} }
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM> template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids, __global__ void topk_stage_2_opt2_general(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
BeamHypotheses beam_hyps, const int* end_ids, const int k, const int vocab_size) BeamHypotheses beam_hyps, int const* end_ids, int const k, int const vocab_size)
{ {
const int size = k * k * BLOCKS_PER_BEAM; int const size = k * k * BLOCKS_PER_BEAM;
const int tid = threadIdx.x; int const tid = threadIdx.x;
const int batch_id = blockIdx.x; int const batch_id = blockIdx.x;
const bool IS_FP16 = std::is_same<T, half>::value; bool const IS_FP16 = std::is_same<T, half>::value;
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX; const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]}; float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce; typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
@ -463,7 +463,7 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
__syncthreads(); __syncthreads();
if (beam_hyps.num_beams != nullptr) if (beam_hyps.num_beams != nullptr)
{ {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0) if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
{ {
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX; beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
@ -503,9 +503,9 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
} }
else else
{ {
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id; int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty); float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
const int num_beam = beam_hyps.num_beams[global_batch_idx]; int const num_beam = beam_hyps.num_beams[global_batch_idx];
int beam_idx = num_beam; int beam_idx = num_beam;
// If there are beam_width finished sentences, check that the score of // If there are beam_width finished sentences, check that the score of
// selected candidatet is higher than min_normed_score or not. If // selected candidatet is higher than min_normed_score or not. If
@ -544,20 +544,20 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
} }
} }
} }
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx) int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
* beam_hyps.max_seq_len; * beam_hyps.max_seq_len;
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id]; beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k; int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
for (int j = beam_hyps.step - 1; j >= 0; j--) for (int j = beam_hyps.step - 1; j >= 0; j--)
{ {
const int src_idx = j * beam_hyps.batch_size * k int const src_idx = j * beam_hyps.batch_size * k
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id; + beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx]; beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
prev_id = beam_hyps.parent_ids_src[src_idx]; prev_id = beam_hyps.parent_ids_src[src_idx];
} }
const int tgt_beam_idx = global_batch_idx * k + beam_idx; int const tgt_beam_idx = global_batch_idx * k + beam_idx;
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step; beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
beam_hyps.normed_scores[tgt_beam_idx] = normed_score; beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
beam_hyps.min_normed_scores[global_batch_idx] beam_hyps.min_normed_scores[global_batch_idx]
@ -613,18 +613,18 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
template <typename T> template <typename T>
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width, bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids, int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
cudaStream_t stream) cudaStream_t stream)
{ {
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a // log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a
// token. // token.
const int vocab_size = vocab_size_padded_; int const vocab_size = vocab_size_padded_;
// Beam size should be less than or equal to vocab size. // Beam size should be less than or equal to vocab size.
assert(beam_width <= vocab_size); assert(beam_width <= vocab_size);
// Beam search needs the sequence lengths of beams to apply length penalty. // Beam search needs the sequence lengths of beams to apply length penalty.
assert(length_penalty == 0.0f || sequence_lengths != nullptr); assert(length_penalty == 0.0f || sequence_lengths != nullptr);
const int max_block_per_beam = 8; int const max_block_per_beam = 8;
int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float
int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int
int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float
@ -685,13 +685,13 @@ void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs,
#undef CASE_K_DIV #undef CASE_K_DIV
template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids, template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids,
BeamHypotheses* beam_hyps, const bool* finished, const int* sequence_lengths, const int batch_size, BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size,
const int beam_width, const int vocab_size_padded_, const float diversity_rate, const float length_penalty, int const beam_width, int const vocab_size_padded_, float const diversity_rate, float const length_penalty,
const int* end_ids, cudaStream_t stream); int const* end_ids, cudaStream_t stream);
template <typename T> template <typename T>
__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, __global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output,
const int* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model) int const* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model)
{ {
if (blockIdx.x == 0) if (blockIdx.x == 0)
{ {
@ -711,7 +711,7 @@ __global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length,
} }
template <typename T> template <typename T>
void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, const int* sequence_length, void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, int const* sequence_length,
const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model,
cudaStream_t stream) cudaStream_t stream)
{ {
@ -739,30 +739,30 @@ void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const
} }
} }
template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, const float* output, template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, float const* output,
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
const size_t d_model, cudaStream_t stream); const size_t d_model, cudaStream_t stream);
template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, const half* output, template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, half const* output,
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
const size_t d_model, cudaStream_t stream); const size_t d_model, cudaStream_t stream);
template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, const half2* output, template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, half2 const* output,
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
const size_t d_model, cudaStream_t stream); const size_t d_model, cudaStream_t stream);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length, template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length,
const __nv_bfloat16* output, const int* sequence_length, const size_t batch_size, const size_t beam_width, __nv_bfloat16 const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width,
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream); const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
#endif #endif
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished,
const float* cum_log_probs, const int batch_size, const int beam_width) float const* cum_log_probs, int const batch_size, int const beam_width)
{ {
const int bid = blockIdx.x; int const bid = blockIdx.x;
const int tgt_start_idx = beam_hyps.num_beams[bid]; int const tgt_start_idx = beam_hyps.num_beams[bid];
const int max_seq_len{beam_hyps.max_seq_len}; int const max_seq_len{beam_hyps.max_seq_len};
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]}; float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
if (beam_hyps.is_done[bid]) if (beam_hyps.is_done[bid])
{ {
return; return;
@ -771,10 +771,10 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
{ {
if (threadIdx.x == 0) if (threadIdx.x == 0)
{ {
const int src_beam_idx = bid * beam_width + beam_idx; int const src_beam_idx = bid * beam_width + beam_idx;
const int tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx; int const tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
const int last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1; int const last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx] beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx]
= beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx]; = beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx];
@ -810,8 +810,8 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
} }
} }
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs, void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
const int batch_size, const int beam_width, cudaStream_t stream) int const batch_size, int const beam_width, cudaStream_t stream)
{ {
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width); insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
} }
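
When finished beams are compared above, their cumulative log-probabilities are normalized by length; the BeamHypotheses header in the next file documents this as `cum_log / (length ^ length_penalty)`. A small worked sketch of why that normalization matters (the numbers are made up for illustration):

```
#include <cmath>
#include <cstdio>

float applyLengthPenalty(float cum_log_prob, int length, float length_penalty)
{
    // normed_score = cum_log_prob / length^length_penalty
    return cum_log_prob / std::pow(static_cast<float>(length), length_penalty);
}

int main()
{
    // A longer hypothesis with a lower raw score can still rank higher once both are normalized.
    float const short_beam = applyLengthPenalty(-6.0f, 10, 1.0f); // -0.60
    float const long_beam = applyLengthPenalty(-8.0f, 16, 1.0f);  // -0.50
    std::printf("short=%.3f  long=%.3f  -> the long beam ranks higher\n", short_beam, long_beam);
    return 0;
}
```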

View File

@ -35,57 +35,64 @@ namespace kernels
// After we collect `beam_width` beams, we will sort them by their norm_scores. // After we collect `beam_width` beams, we will sort them by their norm_scores.
struct BeamHypotheses struct BeamHypotheses
{ {
// TODO: simplify the pointers // BS: batch_size
// Pointers initialized in function prepareOutputs in gptDecoder.cpp // BM: beam_width
bool* is_done{nullptr}; // [batchSize], whether the batch is finished // mSL: max_seq_length
const int* input_lengths{nullptr}; // [batchSize] // %%: parameter name when we call [generation.py] dynamic_decoder.forward
float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
// Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu // Pointers initialized in these two functions:
const int* end_ids{nullptr}; // get from SoftmaxParams // [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
const int* output_ids_src{nullptr}; // for gatherTree bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done
const int* parent_ids_src{nullptr}; // for gatherTree float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs
const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading float* log_probs{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs
const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores
float* log_probs_src{nullptr}; // get from outputs.output_log_probs float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores
int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams
// For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_is_done
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing int const* input_lengths{nullptr}; // [BS*BM] %% context_length
// Other scalar values and buffers // Pointers initialized in [onlineBeamSearchLayer.cu] invokeSoftMax:
int batch_size{0}; int const* end_ids{nullptr}; // [BS*BM] %% self.end_ids
int beam_width{0}; FinishedState* finished; // [BS*BM] %% self.finished
int ite{0}; float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs
int local_batch_size{0}; float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
int max_seq_len{0}; int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer
int step{0}; // useless in online version of beam search int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
int vocab_size{0}; int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
float* diversity_rates{nullptr};
float* length_penalties{nullptr}; float* diversity_rates{nullptr}; // [BS] from SamplingConfig
int* early_stoppings{nullptr}; float* length_penalties{nullptr}; // [BS] from SamplingConfig
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs int* early_stoppings{nullptr}; // [BS] from SamplingConfig
// Pointers for function gatherTree
int const* output_ids_src{nullptr}; //
int const* parent_ids_src{nullptr}; //
// Scalar values
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs, always true for now
int batch_size{0}; //
int beam_width{0}; //
int ite{0}; // index of local_batch, always 0 if pp_size==1
int local_batch_size{0}; //
int max_seq_len{0}; //
int step{0}; // only used in [beamSearchTopkKernels.cu], always 0 in [onlineSoftmaxBeamsearchKernels*.cu.h]
int vocab_size{0}; // vocab_size_padded
}; };
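Annotation: the rewritten comments encode buffer shapes with the BS/BM/mSL abbreviations, but all of these members are flat device arrays, so an element at logical coordinates (batch, candidate beam, step) is reached with ordinary row-major offsets. A minimal sketch of that offset arithmetic follows; the `idx2d`/`idx3d` helpers and the sizes are ours, not from the repository.

```
#include <cstddef>
#include <cstdio>

// Row-major offsets matching the shape comments in BeamHypotheses.
// BS = batch_size, BM = beam_width, mSL = max_seq_length (abbreviations from the comments above).
std::size_t idx2d(std::size_t b, std::size_t j, std::size_t dim1)
{
    return b * dim1 + j; // e.g. cum_log_probs with shape [BS, BM*2]
}

std::size_t idx3d(std::size_t b, std::size_t j, std::size_t t, std::size_t dim1, std::size_t dim2)
{
    return (b * dim1 + j) * dim2 + t; // e.g. output_ids_tgt with shape [BS, BM*2, mSL]
}

int main()
{
    std::size_t const BS = 2, BM = 4, mSL = 128;
    // Offset of batch 1, candidate beam 5 (out of BM*2 = 8), token 10:
    std::printf("cum_log_probs offset : %zu\n", idx2d(1, 5, BM * 2));
    std::printf("output_ids_tgt offset: %zu\n", idx3d(1, 5, 10, BM * 2, mSL));
    return 0;
}
```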
template <typename T> template <typename T>
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps, void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width, bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids, int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
cudaStream_t stream); cudaStream_t stream);
template <typename T> template <typename T>
void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, const T* encoder_output, void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, T const* encoder_output,
const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, int const* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
const size_t d_model, cudaStream_t stream); const size_t d_model, cudaStream_t stream);
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs, void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
const int batch_size, const int beam_width, cudaStream_t stream); int const batch_size, int const beam_width, cudaStream_t stream);
void invokeCopyBatchMajorToGeneralPtr( void invokeCopyBatchMajorToGeneralPtr(
void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream); void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);

View File

@ -58,13 +58,13 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
else if (dtype == DATA_TYPE_INT32) else if (dtype == DATA_TYPE_INT32)
{ {
int32_t inorm = static_cast<int32_t>(norm); int32_t inorm = static_cast<int32_t>(norm);
alpha = reinterpret_cast<const uint32_t&>(inorm); alpha = reinterpret_cast<uint32_t const&>(inorm);
} }
else if (dtype == DATA_TYPE_BF16) else if (dtype == DATA_TYPE_BF16)
{ {
// TODO HACK!! BF16 Outputs are computed in FP32 for FP8. // TODO HACK!! BF16 Outputs are computed in FP32 for FP8.
// This is because cublas does not currently allow FP32 output. // This is because cublas does not currently allow FP32 output.
alpha = reinterpret_cast<const uint32_t&>(norm); alpha = reinterpret_cast<uint32_t const&>(norm);
} }
else else
{ {
@ -77,7 +77,7 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
class FusedMHARunnerV2::mhaImpl class FusedMHARunnerV2::mhaImpl
{ {
public: public:
mhaImpl(const Data_type data_type, const int numHeads, const int headSize, const float qScaling, int sm_) mhaImpl(const Data_type data_type, int const numHeads, int const headSize, float const qScaling, int sm_)
: mDataType(data_type) : mDataType(data_type)
, mNumHeads(numHeads) , mNumHeads(numHeads)
, mHeadSize(headSize) , mHeadSize(headSize)
@ -105,17 +105,17 @@ public:
// Shared setup function. // Shared setup function.
template <typename Params> template <typename Params>
void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size, void setup_params(Params& params, int const b, int const s_q, int const s_kv, int const sliding_window_size,
const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank) int const total_seqlen, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
{ {
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling)); float const inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
// Note that we apply scales and bias in the order of // Note that we apply scales and bias in the order of
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi // (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f; float const scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale; float const scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
const float scale_softmax = 1.f; // Seems to be only required for int8 float const scale_softmax = 1.f; // Seems to be only required for int8
const float scale_bmm2 = 1.f; float const scale_bmm2 = 1.f;
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType; Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
// Use exp2f optimization for warp-specialized ws kernels on Hopper. // Use exp2f optimization for warp-specialized ws kernels on Hopper.
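Annotation: the comment spells out the order `(bmm1_output * scale_bmm1 + alibi) * scale_after_alibi`, so the `scale_alibi` flag decides whether the ALiBi bias itself is multiplied by `1 / (sqrtf(head_size) * q_scaling)` or is added unscaled. A small numeric sketch with hypothetical values, independent of the runner, makes the difference concrete.

```
#include <cmath>
#include <cstdio>

int main()
{
    float const head_size = 64.0f, q_scaling = 1.0f;
    float const inv_sqrt_scale = 1.0f / (std::sqrt(head_size) * q_scaling); // 0.125

    float const bmm1_output = 32.0f; // raw Q*K^T value (made up)
    float const alibi = 2.0f;        // ALiBi bias for this position (made up)

    // scale_alibi == true: scale_bmm1 = 1, scale_after_alibi = inv_sqrt_scale
    float const alibi_scaled = (bmm1_output * 1.0f + alibi) * inv_sqrt_scale;

    // scale_alibi == false: scale_bmm1 = inv_sqrt_scale, scale_after_alibi = 1
    float const alibi_unscaled = (bmm1_output * inv_sqrt_scale + alibi) * 1.0f;

    std::printf("alibi scaled with the logits : %f\n", alibi_scaled);   // 4.25
    std::printf("alibi added after scaling    : %f\n", alibi_unscaled); // 6.00
    return 0;
}
```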
@ -153,8 +153,8 @@ public:
} }
// Support packed QKV. // Support packed QKV.
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi, void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank) bool const scale_alibi, int const tp_size, int const tp_rank)
{ {
// Determine launch parameters. // Determine launch parameters.
@ -165,10 +165,10 @@ public:
TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0."); TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0.");
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1); mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
const bool isSm70 = (sm == kSM_70); bool const isSm70 = (sm == kSM_70);
const bool isSm90 = (sm == kSM_90); bool const isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89); bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80); bool const isSm80 = (sm == kSM_80);
if (isSm70) if (isSm70)
{ {
mLaunchParams.flash_attention = true; mLaunchParams.flash_attention = true;
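Annotation: `padded_d` rounds a non-power-of-two head size up to the next power of two, with the power-of-two case detected by `(h & (h - 1)) == 0`. The standalone sketch below reproduces the same rounding with integer shifts instead of `pow`/`log2`; `roundUpToPow2` is our helper, not a function from the repository.

```
#include <cassert>

// Round head_size up to the next power of two; power-of-two inputs are returned unchanged.
int roundUpToPow2(int h)
{
    if ((h & (h - 1)) == 0)
    {
        return h;
    }
    int p = 1;
    while (p < h)
    {
        p <<= 1;
    }
    return p;
}

int main()
{
    assert(roundUpToPow2(64) == 64);
    assert(roundUpToPow2(80) == 128); // e.g. head_size 80 is padded to 128
    assert(roundUpToPow2(96) == 128);
    assert(roundUpToPow2(128) == 128);
    return 0;
}
```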
@ -238,9 +238,9 @@ public:
} }
// Support paged_kv_cache and chunked_attention. // Support paged_kv_cache and chunked_attention.
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi, int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank) bool const scale_alibi, int const tp_size, int const tp_rank)
{ {
// Determine launch parameters. // Determine launch parameters.
@ -253,9 +253,9 @@ public:
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1); mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256 // Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
const bool isSm90 = (sm == kSM_90); bool const isSm90 = (sm == kSM_90);
const bool isSm8x = (sm == kSM_86 || sm == kSM_89); bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
const bool isSm80 = (sm == kSM_80); bool const isSm80 = (sm == kSM_80);
// always use flash attention kernels. // always use flash attention kernels.
mLaunchParams.flash_attention = true; mLaunchParams.flash_attention = true;
@ -383,7 +383,7 @@ public:
// QKV [TOTAL, 3, h, d] // QKV [TOTAL, 3, h, d]
// NOTE: we may need to use actual seqlen to set oob_value // NOTE: we may need to use actual seqlen to set oob_value
const char* qkv_ptr = reinterpret_cast<const char*>(mParams.qkv_ptr); char const* qkv_ptr = reinterpret_cast<char const*>(mParams.qkv_ptr);
tensor_size_qkv[3] = mTotalSeqLen; tensor_size_qkv[3] = mTotalSeqLen;
// Q: STEP_Q // Q: STEP_Q
@ -467,7 +467,7 @@ public:
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B)); : (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
// Q ptr. // Q ptr.
const char* q_ptr = reinterpret_cast<const char*>(mPagedKVParams.q_ptr); char const* q_ptr = reinterpret_cast<char const*>(mPagedKVParams.q_ptr);
// Q: STEP_Q. // Q: STEP_Q.
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN, q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
@ -518,7 +518,7 @@ public:
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream); paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
} }
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads) void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
{ {
// BF16 FMHA only accumulates on FP32 // BF16 FMHA only accumulates on FP32
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc; mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
@ -541,11 +541,11 @@ public:
return MHARunner::fmha_supported(mHeadSize, sm); return MHARunner::fmha_supported(mHeadSize, sm);
} }
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) void run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{ {
mParams.qkv_ptr = qkvPtr; mParams.qkv_ptr = qkvPtr;
mParams.o_ptr = outputPtr; mParams.o_ptr = outputPtr;
mParams.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr); mParams.cu_seqlens = reinterpret_cast<int const*>(cuSeqlenPtr);
if (sm == kSM_90 && mLaunchParams.use_tma) if (sm == kSM_90 && mLaunchParams.use_tma)
{ {
@ -556,8 +556,8 @@ public:
xmmaKernel->run(mParams, mLaunchParams, stream); xmmaKernel->run(mParams, mLaunchParams, stream);
} }
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost, void run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr, const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream) cudaStream_t stream)
{ {
KVBlockArrayForContextFMHA pagedKVCacheForContextMHA; KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
@ -568,10 +568,10 @@ public:
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc); mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA; mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA;
mPagedKVParams.o_ptr = outputPtr; mPagedKVParams.o_ptr = outputPtr;
mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr); mPagedKVParams.cu_q_seqlens = reinterpret_cast<int const*>(cuQSeqlenPtr);
mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr); mPagedKVParams.cu_seqlens = reinterpret_cast<int const*>(cuKVSeqlenPtr);
// paged kv block device ptrs on host (used by tma descriptors). // paged kv block device ptrs on host (used by tma descriptors).
mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost); mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<int64_t const*>(pagedKVBlockPtrsOnHost);
if (sm == kSM_90 && mLaunchParams.use_tma) if (sm == kSM_90 && mLaunchParams.use_tma)
{ {
@ -587,7 +587,7 @@ public:
return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s); return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s);
} }
int getSFromMaxSeqLen(const int max_seq_len) int getSFromMaxSeqLen(int const max_seq_len)
{ {
int S = 1024; int S = 1024;
@ -625,35 +625,35 @@ private:
Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams; Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams;
Launch_params mLaunchParams; Launch_params mLaunchParams;
int sm; int sm;
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel; FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel;
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel; FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* pagedKVXmmaKernel;
bool use_flash_attention = false; bool use_flash_attention = false;
const Data_type mDataType; const Data_type mDataType;
const int mNumHeads; int const mNumHeads;
const int mHeadSize; int const mHeadSize;
const float mQScaling; float const mQScaling;
int mTotalSeqLen; int mTotalSeqLen;
}; };
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
FusedMHARunnerV2::FusedMHARunnerV2( FusedMHARunnerV2::FusedMHARunnerV2(
const Data_type data_type, const int numHeads, const int headSize, const float qScaling) const Data_type data_type, int const numHeads, int const headSize, float const qScaling)
: pimpl(new mhaImpl(data_type, numHeads, headSize, qScaling, tensorrt_llm::common::getSMVersion())) : pimpl(new mhaImpl(data_type, numHeads, headSize, qScaling, tensorrt_llm::common::getSMVersion()))
{ {
} }
FusedMHARunnerV2::~FusedMHARunnerV2() = default; FusedMHARunnerV2::~FusedMHARunnerV2() = default;
void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank) bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
{ {
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
} }
void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi, int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
const bool scale_alibi, const int tp_size, const int tp_rank) bool const scale_alibi, int const tp_size, int const tp_rank)
{ {
pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size, pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size,
total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank); total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
@ -665,18 +665,18 @@ bool FusedMHARunnerV2::fmha_supported()
} }
void FusedMHARunnerV2::setup_flags( void FusedMHARunnerV2::setup_flags(
const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads) bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
{ {
pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads); pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads);
} }
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream) void FusedMHARunnerV2::run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{ {
pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream); pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream);
} }
void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost, void FusedMHARunnerV2::run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr, const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream) cudaStream_t stream)
{ {
pimpl->run_paged_kv( pimpl->run_paged_kv(
@ -689,7 +689,7 @@ bool FusedMHARunnerV2::isValid(int s) const
} }
// static function to check if fmha is supported when building plugins // static function to check if fmha is supported when building plugins
bool MHARunner::fmha_supported(const int headSize, const int sm) bool MHARunner::fmha_supported(int const headSize, int const sm)
{ {
if (sm == kSM_70) if (sm == kSM_70)
{ {

View File

@ -41,33 +41,33 @@ namespace kernels
class MHARunner class MHARunner
{ {
public: public:
MHARunner(const Data_type dataType, const int numHeads, const int headSize, const float qScaling); MHARunner(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);
MHARunner() = default; MHARunner() = default;
virtual ~MHARunner() = default; virtual ~MHARunner() = default;
virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
= 0; = 0;
virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0) bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
= 0; = 0;
static bool fmha_supported(const int headSize, const int sm); static bool fmha_supported(int const headSize, int const sm);
virtual bool fmha_supported() = 0; virtual bool fmha_supported() = 0;
virtual void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, virtual void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
const int num_kv_heads /* MQA or GQA */) int const num_kv_heads /* MQA or GQA */)
= 0; = 0;
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0; virtual void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) = 0;
virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host, virtual void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output, const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
cudaStream_t stream) cudaStream_t stream)
= 0; = 0;
@ -86,28 +86,28 @@ public:
class FusedMHARunnerV2 : public MHARunner class FusedMHARunnerV2 : public MHARunner
{ {
public: public:
FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling); FusedMHARunnerV2(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);
~FusedMHARunnerV2(); // for pimpl ~FusedMHARunnerV2(); // for pimpl
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
const int tp_rank = 0) override; int const tp_rank = 0) override;
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence, void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
const int tp_rank = 0) override; int const tp_rank = 0) override;
bool fmha_supported() override; bool fmha_supported() override;
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override; void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) override;
void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host, void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output, const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
cudaStream_t stream) override; cudaStream_t stream) override;
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
const int num_kv_heads /* MQA or GQA */) override; int const num_kv_heads /* MQA or GQA */) override;
bool isValid(int s) const override; bool isValid(int s) const override;
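Annotation: taken together, the interface suggests a three-step life cycle: configure global flags, set per-launch shapes, then launch. The sketch below is only a plausible call sequence inferred from these declarations; the include path, the `DATA_TYPE_FP16` constant, and the concrete argument values are assumptions, and the real attention plugins wire the runner up with more machinery than shown here.

```
// Hypothetical usage sketch (not taken from the repository's plugins).
#include "tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.h" // assumed path
#include <cuda_runtime_api.h>

using namespace tensorrt_llm::kernels;

void runContextFmha(void const* packedQkv, void const* cuSeqlens, void* output,
    int batchSize, int maxSeqLen, int totalTokens, cudaStream_t stream)
{
    FusedMHARunnerV2 runner(DATA_TYPE_FP16, /*numHeads=*/32, /*headSize=*/128, /*qScaling=*/1.0f);

    // 1. Global flags: fp32 accumulation, padded sequences, causal masking, #KV heads (MQA/GQA).
    runner.setup_flags(/*force_fp32_acc=*/false, /*is_s_padded=*/false, /*causal_mask=*/true,
        /*num_kv_heads=*/32);

    // 2. Per-launch shapes: batch, sequence length, sliding window, total token count.
    runner.setup(batchSize, maxSeqLen, /*sliding_window_size=*/maxSeqLen, totalTokens);

    // 3. Launch on the packed QKV buffer.
    runner.run(packedQkv, cuSeqlens, output, stream);
}
```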

View File

@ -84,9 +84,9 @@ struct AlibiParams
struct Fused_multihead_attention_params_v2 struct Fused_multihead_attention_params_v2
{ {
// The QKV matrices. // The QKV matrices.
const void* qkv_ptr; void const* qkv_ptr;
// The mask to implement drop-out. // The mask to implement drop-out.
const void* packed_mask_ptr; void const* packed_mask_ptr;
// The O matrix (output). // The O matrix (output).
void* o_ptr; void* o_ptr;
@ -106,7 +106,7 @@ struct Fused_multihead_attention_params_v2
bool enable_i2f_trick; bool enable_i2f_trick;
// array of length b+1 holding prefix sum of actual sequence lengths // array of length b+1 holding prefix sum of actual sequence lengths
const int* cu_seqlens; int const* cu_seqlens;
// use C/32 Format. // use C/32 Format.
bool interleaved = false; bool interleaved = false;
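Annotation: `cu_seqlens` is the usual varlen-attention bookkeeping: for a batch of b sequences it holds b + 1 cumulative offsets, so sequence i occupies tokens `[cu_seqlens[i], cu_seqlens[i+1])` of the packed buffer. A quick sketch of how such an array is built (variable names and lengths are ours):

```
#include <cstddef>
#include <cstdio>
#include <vector>

int main()
{
    // Actual per-sequence lengths for a batch of b = 3 (made-up numbers).
    std::vector<int> seqLens = {3, 5, 2};

    // Exclusive prefix sum of length b + 1: {0, 3, 8, 10}.
    std::vector<int> cuSeqlens(seqLens.size() + 1, 0);
    for (std::size_t i = 0; i < seqLens.size(); ++i)
    {
        cuSeqlens[i + 1] = cuSeqlens[i] + seqLens[i];
    }

    for (std::size_t i = 0; i < seqLens.size(); ++i)
    {
        std::printf("sequence %zu -> tokens [%d, %d)\n", i, cuSeqlens[i], cuSeqlens[i + 1]);
    }
    return 0;
}
```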
@ -177,13 +177,13 @@ struct Fused_multihead_attention_params_v2
struct Fused_multihead_attention_paged_kv_params_v2 struct Fused_multihead_attention_paged_kv_params_v2
{ {
// The Q matrices. // The Q matrices.
const void* q_ptr; void const* q_ptr;
// Paged KV Cache buffer. // Paged KV Cache buffer.
KVBlockArrayForContextFMHA paged_kv_cache; KVBlockArrayForContextFMHA paged_kv_cache;
// The O matrix (output). // The O matrix (output).
void* o_ptr; void* o_ptr;
// The packed mask for random mask. // The packed mask for random mask.
const void* packed_mask_ptr; void const* packed_mask_ptr;
// The stride between rows of the Q matrices. // The stride between rows of the Q matrices.
int64_t q_stride_in_bytes; int64_t q_stride_in_bytes;
@ -211,9 +211,9 @@ struct Fused_multihead_attention_paged_kv_params_v2
AlibiParams alibi_params; AlibiParams alibi_params;
// array of length b+1 holding prefix sum of actual kv sequence lengths. // array of length b+1 holding prefix sum of actual kv sequence lengths.
const int* cu_seqlens; int const* cu_seqlens;
// Chunked attention (only handles one tile of Q). // Chunked attention (only handles one tile of Q).
const int* cu_q_seqlens; int const* cu_q_seqlens;
// q with shape [B, S, H, D] in const cache. // q with shape [B, S, H, D] in const cache.
cudaTmaDesc tma_desc_q; cudaTmaDesc tma_desc_q;
@ -301,7 +301,7 @@ struct Launch_params
// number of paged kv blocks for context sequence. // number of paged kv blocks for context sequence.
int blocks_per_context_sequence = 0; int blocks_per_context_sequence = 0;
// device ptrs on the host for paged kv cache. // device ptrs on the host for paged kv cache.
const int64_t* paged_kv_block_ptrs = nullptr; int64_t const* paged_kv_block_ptrs = nullptr;
// if flash attention is used (only FP16) // if flash attention is used (only FP16)
bool flash_attention = false; bool flash_attention = false;
// if warp_specialized kernels are used (only SM90 HGMMA + TMA) // if warp_specialized kernels are used (only SM90 HGMMA + TMA)

View File

@ -63,13 +63,13 @@ public:
return (uint64_t) s << 32 | d; return (uint64_t) s << 32 | d;
} }
virtual uint64_t hashID(const KernelMeta& kernelMeta) const virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{ {
return hashID(kernelMeta.mS, kernelMeta.mD); return hashID(kernelMeta.mS, kernelMeta.mD);
} }
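Annotation: the kernel registry is keyed by packing the discriminating attributes into one 64-bit integer: the base class uses `s << 32 | d`, and the V2 subclasses extend the low bits with boolean feature flags (interleaved, unroll, fp32 accumulation, and so on). The sketch below shows the same idea; the exact bit layout here is illustrative and not the one the V2 kernels use.

```
#include <cstdint>
#include <cstdio>
#include <unordered_map>

// Pack (sequence length, head dim, flags) into one lookup key, in the spirit of hashID above.
uint64_t makeKey(uint32_t s, uint32_t d, bool interleaved, bool unroll)
{
    return (static_cast<uint64_t>(s) << 32) | (static_cast<uint64_t>(d) << 16)
        | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

int main()
{
    std::unordered_map<uint64_t, char const*> kernels;
    kernels[makeKey(128, 64, false, true)] = "fmha_s128_d64_unroll";
    kernels[makeKey(256, 64, false, false)] = "fmha_s256_d64";

    auto const it = kernels.find(makeKey(128, 64, false, true));
    std::printf("found: %s\n", it != kernels.end() ? it->second : "none");
    return 0;
}
```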
TFusedMultiHeadAttentionXMMAKernel( TFusedMultiHeadAttentionXMMAKernel(
const TKernelMeta* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm) TKernelMeta const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
: mDataType(type) : mDataType(type)
, mKernelMeta(pMetaStart) , mKernelMeta(pMetaStart)
, mKernelMetaCount(nMetaCount) , mKernelMetaCount(nMetaCount)
@ -86,7 +86,7 @@ public:
for (unsigned int i = 0; i < mKernelMetaCount; ++i) for (unsigned int i = 0; i < mKernelMetaCount; ++i)
{ {
const auto& kernelMeta = mKernelMeta[i]; auto const& kernelMeta = mKernelMeta[i];
if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType) if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType)
{ {
CUmodule hmod{0}; CUmodule hmod{0};
@ -125,9 +125,9 @@ public:
virtual void run(TKernelParam& params, Launch_params& launch_params, cudaStream_t ss) const virtual void run(TKernelParam& params, Launch_params& launch_params, cudaStream_t ss) const
{ {
const auto findIter = mFunctions.find(hashID(params.s, params.d)); auto const findIter = mFunctions.find(hashID(params.s, params.d));
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction; const CUfunction func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr}; void* kernelParams[] = {&params, nullptr};
@ -142,10 +142,10 @@ protected:
tensorrt_llm::common::CUDADriverWrapper mDriver; tensorrt_llm::common::CUDADriverWrapper mDriver;
Data_type mDataType; Data_type mDataType;
const TKernelMeta* mKernelMeta; TKernelMeta const* mKernelMeta;
unsigned int mKernelMetaCount; unsigned int mKernelMetaCount;
unsigned int mSM; unsigned int mSM;
std::unordered_map<const unsigned char*, CUmodule> mModules; std::unordered_map<unsigned char const*, CUmodule> mModules;
struct FusedMultiHeadAttentionKernelInfo struct FusedMultiHeadAttentionKernelInfo
{ {
@ -161,14 +161,14 @@ template <typename TFusedMHAKernelList>
class TFusedMHAKernelFactory class TFusedMHAKernelFactory
{ {
public: public:
const TFusedMHAKernelList* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList, TFusedMHAKernelList const* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList,
unsigned int nbKernels, Data_type type, unsigned int sm) unsigned int nbKernels, Data_type type, unsigned int sm)
{ {
static std::mutex s_mutex; static std::mutex s_mutex;
std::lock_guard<std::mutex> lg(s_mutex); std::lock_guard<std::mutex> lg(s_mutex);
const auto id = hashID(type, sm); auto const id = hashID(type, sm);
const auto findIter = mKernels.find(id); auto const findIter = mKernels.find(id);
if (findIter == mKernels.end()) if (findIter == mKernels.end())
{ {
TFusedMHAKernelList* newKernel = new TFusedMHAKernelList{pKernelList, nbKernels, type, sm}; TFusedMHAKernelList* newKernel = new TFusedMHAKernelList{pKernelList, nbKernels, type, sm};
@ -214,7 +214,7 @@ class FusedMultiHeadAttentionXMMAKernelV2
Fused_multihead_attention_params_v2> Fused_multihead_attention_params_v2>
{ {
public: public:
FusedMultiHeadAttentionXMMAKernelV2(const FusedMultiHeadAttentionKernelMetaInfoV2* pMetaStart, FusedMultiHeadAttentionXMMAKernelV2(FusedMultiHeadAttentionKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm) unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2, : TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2,
Fused_multihead_attention_params_v2>(pMetaStart, nMetaCount, type, sm) Fused_multihead_attention_params_v2>(pMetaStart, nMetaCount, type, sm)
@ -231,7 +231,7 @@ public:
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull); | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
} }
virtual uint64_t hashID(const KernelMeta& kernelMeta) const virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{ {
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep, return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
@ -278,7 +278,7 @@ public:
} }
} }
const auto findIter auto const findIter
= mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll, = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll,
launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi, launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling)); static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@ -290,7 +290,7 @@ public:
launch_params.flash_attention, !launch_params.useKernelWithoutAlibi, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling); static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling);
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction; const CUfunction func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr}; void* kernelParams[] = {&params, nullptr};
@ -369,7 +369,7 @@ public:
using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernelV2>; using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernelV2>;
inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type type, unsigned int sm) inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type type, unsigned int sm)
{ {
return FusedMHAKernelFactoryV2::Get().getXMMAKernels( return FusedMHAKernelFactoryV2::Get().getXMMAKernels(
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm); sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
@ -384,7 +384,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2
Fused_multihead_attention_paged_kv_params_v2> Fused_multihead_attention_paged_kv_params_v2>
{ {
public: public:
FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart, FusedMultiHeadAttentionPagedKVXMMAKernelV2(FusedMultiHeadAttentionPagedKVKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm) unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2, : TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm) Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm)
@ -402,7 +402,7 @@ public:
| (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull); | (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
} }
virtual uint64_t hashID(const KernelMeta& kernelMeta) const virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{ {
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep, return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization, kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
@ -413,7 +413,7 @@ public:
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
{ {
const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, auto const findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention, launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention,
launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi, launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling)); static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@ -426,7 +426,7 @@ public:
!launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type), !launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type),
launch_params.granular_tiling); launch_params.granular_tiling);
const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex]; auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction; const CUfunction func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr}; void* kernelParams[] = {&params, nullptr};
@ -488,7 +488,7 @@ public:
using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>; using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>;
inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm) inline FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
{ {
return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2, return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2,
sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm); sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm);

View File

@ -186,7 +186,7 @@ public:
// set the descriptor. // set the descriptor.
int set_tma_desctriptor( int set_tma_desctriptor(
// ptr to gmem // ptr to gmem
const void* gmem_ptr, void const* gmem_ptr,
// format is really data_type in TMA terminology. // format is really data_type in TMA terminology.
cudaTmaDescFormat format, cudaTmaDescFormat format,
// interleave mode. // interleave mode.
@ -221,7 +221,7 @@ public:
// set the descriptor. // set the descriptor.
int set_tma_desctriptor( int set_tma_desctriptor(
// ptr to gmem // ptr to gmem
const void* gmem_ptr, void const* gmem_ptr,
// format is really data_type in TMA terminology. // format is really data_type in TMA terminology.
cudaTmaDescFormat format, cudaTmaDescFormat format,
// interleave mode. // interleave mode.

View File

@ -108,10 +108,10 @@ inline __device__ int4 add128b(T& a, T& b)
} }
__inline__ __device__ void multi_gpu_barrier( __inline__ __device__ void multi_gpu_barrier(
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx) uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
{ {
// At the end of the function, we know that at least block 0 from all other GPUs has reached that point. // At the end of the function, we know that at least block 0 from all other GPUs has reached that point.
volatile uint32_t* my_signals = signals[rank]; uint32_t volatile* my_signals = signals[rank];
if (tidx < world_size) if (tidx < world_size)
{ {
// The 1st block notifies the other ranks. // The 1st block notifies the other ranks.
@ -139,8 +139,8 @@ __global__ void multiGpuBarrierKernel(AllReduceParams params)
template <typename T, int RANKS_PER_NODE> template <typename T, int RANKS_PER_NODE>
static __global__ void oneShotAllReduceKernel(AllReduceParams params) static __global__ void oneShotAllReduceKernel(AllReduceParams params)
{ {
const int bidx = blockIdx.x; int const bidx = blockIdx.x;
const int tidx = threadIdx.x; int const tidx = threadIdx.x;
// The number of elements packed into one for comms // The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T); static constexpr int NUM_ELTS = 16 / sizeof(T);
@ -151,7 +151,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx); multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
// The source pointers. Distributed round-robin for the different warps. // The source pointers. Distributed round-robin for the different warps.
const T* src_d[RANKS_PER_NODE]; T const* src_d[RANKS_PER_NODE];
#pragma unroll #pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{ {
@ -172,7 +172,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
#pragma unroll #pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{ {
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][iter_offset]); vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
} }
// Sum the values from the different ranks. // Sum the values from the different ranks.
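Annotation: both all-reduce kernels move data in 16-byte packets. `NUM_ELTS = 16 / sizeof(T)` elements are handled per load through an `int4` (four 32-bit words), which is why the peer buffers are reinterpret-cast into packed 128-bit values. A CUDA-free sketch of the same packing arithmetic, with a host-side stand-in for `int4`:

```
#include <cstdint>
#include <cstdio>

// Host-side stand-in for CUDA's int4: four 32-bit words = one 16-byte packet.
struct Packet128
{
    int32_t x, y, z, w;
};

template <typename T>
constexpr int numEltsPerPacket()
{
    return static_cast<int>(16 / sizeof(T)); // NUM_ELTS in the kernels above
}

int main()
{
    static_assert(sizeof(Packet128) == 16, "one packet is 16 bytes");
    std::printf("fp16/bf16 elements per 16B packet: %d\n", numEltsPerPacket<uint16_t>()); // 8
    std::printf("fp32 elements per 16B packet     : %d\n", numEltsPerPacket<float>());    // 4
    return 0;
}
```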
@ -194,9 +194,9 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{ {
// The block index. // The block index.
const int bidx = blockIdx.x; int const bidx = blockIdx.x;
// The thread index with the block. // The thread index with the block.
const int tidx = threadIdx.x; int const tidx = threadIdx.x;
// The number of elements packed into one for comms // The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T); static constexpr int NUM_ELTS = 16 / sizeof(T);
@ -233,7 +233,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll #pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii) for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{ {
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][local_offset]); vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
} }
// Sum the values from the different ranks. // Sum the values from the different ranks.
@ -396,14 +396,14 @@ void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream)
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param); multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
} }
AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value) AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
{ {
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer); void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
AllReduceParams params; AllReduceParams params;
// Even plugins use ping buffers, odd plugins use pong. // Even plugins use ping buffers, odd plugins use pong.
// That way, we don't need to wait for other GPUs to be done // That way, we don't need to wait for other GPUs to be done
// before copying input tensor to workspace. // before copying input tensor to workspace.
const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize; auto const buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;
for (int i = 0; i < tpSize; ++i) for (int i = 0; i < tpSize; ++i)
{ {
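Annotation: `deserialize` alternates between two sets of communication buffers based on the parity of the barrier flag, so a plugin invocation can start writing its inputs without waiting for peers to finish reading the previous iteration's buffers. The sketch below only illustrates that selection, assuming a layout of `2 * tpSize` buffer slots stored back to back.

```
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    std::size_t const tpSize = 4;

    for (uint32_t flag = 0; flag < 4; ++flag)
    {
        // Even flags use the ping set, odd flags the pong set (as in AllReduceParams::deserialize).
        std::size_t const offset = (flag % 2 == 0) ? 0 : tpSize;
        std::printf("flag %u -> buffer slots [%zu, %zu)\n", flag, offset, offset + tpSize);
    }
    return 0;
}
```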

View File

@ -57,7 +57,7 @@ struct AllReduceParams
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE]; void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr; void* local_output_buffer_ptr;
static AllReduceParams deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value); static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
}; };
template <typename T> template <typename T>

View File

@ -70,7 +70,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
} }
bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape, bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only) int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only)
{ {
// All tile sizes have a k_tile of 64. // All tile sizes have a k_tile of 64.
@ -89,7 +89,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
return false; return false;
} }
const int k_elements_per_split = k / split_k_factor; int const k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0) if ((k_elements_per_split % k_tile) != 0)
{ {
return false; return false;
@ -97,9 +97,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
} }
// Check that the workspace has sufficient space for this split-k factor // Check that the workspace has sufficient space for this split-k factor
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim; int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
if (required_ws_bytes > workspace_bytes) if (required_ws_bytes > workspace_bytes)
{ {
@ -110,7 +110,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
} }
std::vector<CutlassTileConfig> get_candidate_tiles( std::vector<CutlassTileConfig> get_candidate_tiles(
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only) int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{ {
enum class CutlassGemmType : char enum class CutlassGemmType : char
{ {
@ -170,7 +170,7 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
} }
std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90( std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only) int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{ {
enum class CutlassGemmType : char enum class CutlassGemmType : char
{ {
@ -226,8 +226,8 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
return valid_tiles.count(tile) == 1; return valid_tiles.count(tile) == 1;
} }
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only, std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma) bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma)
{ {
if (sm == 90 && enable_hopper_gmma) if (sm == 90 && enable_hopper_gmma)
{ {
@ -235,14 +235,14 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only); = get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);
std::vector<CutlassGemmConfig> candidate_configs; std::vector<CutlassGemmConfig> candidate_configs;
for (const auto& tile_config : tiles) for (auto const& tile_config : tiles)
{ {
CutlassGemmConfig config( CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1); tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config); candidate_configs.push_back(config);
const bool has_m_mcast = supports_mcast_along_m(tile_config); bool const has_m_mcast = supports_mcast_along_m(tile_config);
const bool has_n_mcast = supports_mcast_along_n(tile_config); bool const has_n_mcast = supports_mcast_along_n(tile_config);
if (has_m_mcast) if (has_m_mcast)
{ {
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
@ -270,9 +270,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only); = get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);
std::vector<CutlassGemmConfig> candidate_configs; std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = int8_configs_only ? 3 : 2; int const min_stages = int8_configs_only ? 3 : 2;
const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2); int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (const auto& tile_config : tiles) for (auto const& tile_config : tiles)
{ {
for (int stages = min_stages; stages <= max_stages; ++stages) for (int stages = min_stages; stages <= max_stages; ++stages)
{ {
@ -292,9 +292,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
return candidate_configs; return candidate_configs;
} }
CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs, CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only) int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only)
{ {
if (occupancies.size() != candidate_configs.size()) if (occupancies.size() != candidate_configs.size())
@ -311,7 +311,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
int config_waves = INT_MAX; int config_waves = INT_MAX;
int current_m_tile = 0; int current_m_tile = 0;
const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit; int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii) for (int ii = 0; ii < candidate_configs.size(); ++ii)
{ {
CutlassGemmConfig candidate_config = candidate_configs[ii]; CutlassGemmConfig candidate_config = candidate_configs[ii];
@ -330,21 +330,21 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
continue; continue;
} }
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m; int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n; int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor) for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor)
{ {
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only)) if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only))
{ {
const int ctas_per_wave = occupancy * multi_processor_count; int const ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave); float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
const float current_score = float(num_waves_total) - num_waves_fractional; float const current_score = float(num_waves_total) - num_waves_fractional;
const float score_slack = 0.1f; float const score_slack = 0.1f;
if (current_score < config_score if (current_score < config_score
|| ((config_waves > num_waves_total) && (current_score < config_score + score_slack))) || ((config_waves > num_waves_total) && (current_score < config_score + score_slack)))
{ {
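Annotation: the selection loop scores each (tile, split-k) candidate by how evenly its CTA grid fills the GPU. The problem needs `ceil(m/tile_m) * ceil(n/tile_n) * split_k` CTAs, one "wave" is `occupancy * multi_processor_count` CTAs, and the score is the wasted fraction of the last wave (`num_waves_total - num_waves_fractional`), with a 0.1 slack that prefers fewer waves when scores are close. The `waveScore` helper and the numbers below are ours, mirroring that arithmetic.

```
#include <cstdio>

// Compute the "wasted last wave" score used by estimate_best_config_from_occupancies.
float waveScore(long m, long n, int tileM, int tileN, int splitK, int occupancy, int smCount,
    int* numWavesTotalOut)
{
    int const ctasInM = static_cast<int>((m + tileM - 1) / tileM);
    int const ctasInN = static_cast<int>((n + tileN - 1) / tileN);
    int const ctasPerWave = occupancy * smCount;
    int const ctasForProblem = ctasInM * ctasInN * splitK;
    int const numWavesTotal = (ctasForProblem + ctasPerWave - 1) / ctasPerWave;
    float const numWavesFractional = ctasForProblem / static_cast<float>(ctasPerWave);
    *numWavesTotalOut = numWavesTotal;
    return static_cast<float>(numWavesTotal) - numWavesFractional;
}

int main()
{
    // Hypothetical GEMM and hardware: 4096 x 4096 problem on 108 SMs.
    int waves = 0;
    float const score128 = waveScore(4096, 4096, 128, 128, /*splitK=*/1, /*occupancy=*/1, 108, &waves);
    std::printf("128x128 tiles: %d waves, score %.3f\n", waves, score128);

    float const score64 = waveScore(4096, 4096, 64, 128, /*splitK=*/1, /*occupancy=*/2, 108, &waves);
    std::printf(" 64x128 tiles: %d waves, score %.3f\n", waves, score64);
    return 0;
}
```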

View File

@ -27,13 +27,13 @@ namespace cutlass_kernels
{ {
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm, std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false,
const int max_split_k = 1, const bool enable_hopper_gmma = false); int const max_split_k = 1, bool const enable_hopper_gmma = false);
tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies( tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs, std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts, std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only); int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only);
} // namespace cutlass_kernels } // namespace cutlass_kernels
} // namespace kernels } // namespace kernels

View File

@ -158,8 +158,8 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15 // 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
// For int4, each group of 32 rows is permuted using the map below: // For int4, each group of 32 rows is permuted using the map below:
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31 // 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version) std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version)
{ {
// We only want to run this step for weight only quant. // We only want to run this step for weight only quant.
@ -170,19 +170,19 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
const int K = 16 / BITS_PER_ELT; int const K = 16 / BITS_PER_ELT;
const int ELTS_PER_BYTE = 8 / BITS_PER_ELT; int const ELTS_PER_BYTE = 8 / BITS_PER_ELT;
const int ELTS_PER_REG = 32 / BITS_PER_ELT; int const ELTS_PER_REG = 32 / BITS_PER_ELT;
const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor); uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor); uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);
int MMA_SHAPE_N = 8; int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K; int B_ROWS_PER_MMA = 8 * K;
const int elts_in_int32 = 32 / BITS_PER_ELT; int const elts_in_int32 = 32 / BITS_PER_ELT;
const int num_vec_cols = num_cols / elts_in_int32; int const num_vec_cols = num_cols / elts_in_int32;
TLLM_CHECK_WITH_INFO( TLLM_CHECK_WITH_INFO(
arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta."); arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta.");
@ -205,11 +205,11 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
for (int write_col = 0; write_col < num_vec_cols; ++write_col) for (int write_col = 0; write_col < num_vec_cols; ++write_col)
{ {
const int write_row = base_row + tile_row; int const write_row = base_row + tile_row;
const int tile_read_row int const tile_read_row
= 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG); = 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
const int read_row = base_row + tile_read_row; int const read_row = base_row + tile_read_row;
const int read_col = write_col; int const read_col = write_col;
const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col; const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col;
const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col; const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col;
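The permutation maps quoted in the comment above fall directly out of the `tile_read_row` formula in this hunk. The standalone snippet below simply replays that formula for int8 (`ELTS_PER_REG = 4`) and int4 (`ELTS_PER_REG = 8`) and prints the same two maps; it is a verification aid, not library code.

```cpp
// Reproduces the row-permutation maps from the comment above using the same formula
// as the kernel: read = 8*((r % ELTS_PER_REG)/2) + r % 2 + 2*(r / ELTS_PER_REG).
#include <cstdio>

int main()
{
    for (int bits : {8, 4})
    {
        int const elts_per_reg = 32 / bits;         // 4 for int8, 8 for int4
        int const rows_per_group = 8 * (16 / bits); // B_ROWS_PER_MMA: 16 for int8, 32 for int4
        std::printf("int%d:", bits);
        for (int r = 0; r < rows_per_group; ++r)
        {
            int const read_row = 8 * ((r % elts_per_reg) / 2) + r % 2 + 2 * (r / elts_per_reg);
            std::printf(" %d", read_row);
        }
        std::printf("\n"); // int8: 0 1 8 9 2 3 10 11 ...   int4: 0 1 8 9 16 17 24 25 ...
    }
    return 0;
}
```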
@ -227,9 +227,9 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
// issue for relatively large models. // issue for relatively large models.
template <QuantType quant_type> template <QuantType quant_type>
void subbyte_transpose_impl( void subbyte_transpose_impl(
int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector<size_t>& shape) int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape)
{ {
const int bits_per_elt = get_bits_in_quant_type(quant_type); int const bits_per_elt = get_bits_in_quant_type(quant_type);
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0]; const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
@ -240,7 +240,7 @@ void subbyte_transpose_impl(
const size_t col_bytes_trans = num_rows * bits_per_elt / 8; const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes; const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;
const uint8_t* input_byte_ptr = reinterpret_cast<const uint8_t*>(quantized_tensor); uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor); uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);
static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, ""); static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
@ -260,8 +260,8 @@ void subbyte_transpose_impl(
"num_col_bytes = %ld.", "num_col_bytes = %ld.",
VECTOR_WIDTH, col_bytes_trans, col_bytes)); VECTOR_WIDTH, col_bytes_trans, col_bytes));
const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1; int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1; int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
for (size_t expert = 0; expert < num_experts; ++expert) for (size_t expert = 0; expert < num_experts; ++expert)
{ {
@ -271,16 +271,16 @@ void subbyte_transpose_impl(
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1) for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1)
{ {
const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows); int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes); int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
for (int ii = 0; ii < M_TILE_L1; ++ii) for (int ii = 0; ii < M_TILE_L1; ++ii)
{ {
const int row = row_tile_start + ii; int const row = row_tile_start + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{ {
const int col = col_tile_start_byte + jj; int const col = col_tile_start_byte + jj;
const size_t logical_src_offset = matrix_offset + row * col_bytes + col; const size_t logical_src_offset = matrix_offset + row * col_bytes + col;
@ -313,11 +313,11 @@ void subbyte_transpose_impl(
// is square in the number of elements (not necessarily the number of bytes). // is square in the number of elements (not necessarily the number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj) for (int jj = ii + 1; jj < M_TILE_L1; ++jj)
{ {
const int ii_byte = ii / ELTS_PER_BYTE; int const ii_byte = ii / ELTS_PER_BYTE;
const int ii_bit_offset = ii % ELTS_PER_BYTE; int const ii_bit_offset = ii % ELTS_PER_BYTE;
const int jj_byte = jj / ELTS_PER_BYTE; int const jj_byte = jj / ELTS_PER_BYTE;
const int jj_bit_offset = jj % ELTS_PER_BYTE; int const jj_bit_offset = jj % ELTS_PER_BYTE;
uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset)); uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset)); uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
@ -338,15 +338,15 @@ void subbyte_transpose_impl(
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE; const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE; const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;
const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols); int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans); int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
for (int ii = 0; ii < M_TILE_L1; ++ii) for (int ii = 0; ii < M_TILE_L1; ++ii)
{ {
const int row = row_tile_start_trans + ii; int const row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH) for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{ {
const int col = col_tile_start_byte_trans + jj; int const col = col_tile_start_byte_trans + jj;
const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col; const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;
@ -364,8 +364,8 @@ void subbyte_transpose_impl(
} }
} }
void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type) std::vector<size_t> const& shape, QuantType quant_type)
{ {
if (quant_type == QuantType::INT8_WEIGHT_ONLY) if (quant_type == QuantType::INT8_WEIGHT_ONLY)
@ -409,7 +409,7 @@ void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num
void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts) void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts)
{ {
const int num_bytes = num_elts / 2; int const num_bytes = num_elts / 2;
// Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little // Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little
// instructions as possible in the CUDA code. // instructions as possible in the CUDA code.
@ -451,9 +451,9 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz
for (int dest_idx = 0; dest_idx < 8; ++dest_idx) for (int dest_idx = 0; dest_idx < 8; ++dest_idx)
{ {
const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1; int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
const int src_shift = 4 * src_idx; int const src_shift = 4 * src_idx;
const int dest_shift = 4 * dest_idx; int const dest_shift = 4 * dest_idx;
const uint32_t src_bits = (current_register >> src_shift) & 0xF; const uint32_t src_bits = (current_register >> src_shift) & 0xF;
transformed_register |= (src_bits << dest_shift); transformed_register |= (src_bits << dest_shift);
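The `src_idx` mapping in this hunk interleaves the eight nibbles of a 32-bit register so that destination positions 0..7 receive source nibbles 0 2 4 6 1 3 5 7 (after the signed-to-unsigned bias from step 1). A small standalone illustration using the same index formula:

```cpp
// Shows how the interleave above rearranges one 32-bit register of int4 values:
// nibble positions 0..7 end up holding source nibbles 0 2 4 6 1 3 5 7.
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t reg = 0;
    for (uint32_t i = 0; i < 8; ++i)
        reg |= i << (4 * i); // nibble i holds the value i

    uint32_t transformed = 0;
    for (int dest_idx = 0; dest_idx < 8; ++dest_idx)
    {
        int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
        uint32_t const src_bits = (reg >> (4 * src_idx)) & 0xF;
        transformed |= src_bits << (4 * dest_idx);
    }

    for (int i = 0; i < 8; ++i)
        std::printf("%u ", (transformed >> (4 * i)) & 0xF); // prints: 0 2 4 6 1 3 5 7
    std::printf("\n");
    return 0;
}
```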
@ -478,8 +478,8 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size
} }
} }
void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor, void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, LayoutDetails details) std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details)
{ {
// We only want to run this step for weight only quant. // We only want to run this step for weight only quant.
@ -490,23 +490,23 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int BITS_PER_ELT = get_bits_in_quant_type(quant_type); int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
const int elts_in_int32 = 32 / BITS_PER_ELT; int const elts_in_int32 = 32 / BITS_PER_ELT;
const int rows_per_tile = details.rows_per_column_tile; int const rows_per_tile = details.rows_per_column_tile;
TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32), TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows)); fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows));
const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor); uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor); uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);
TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile), TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows)); fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows));
const int num_vec_rows = num_rows / elts_in_int32; int const num_vec_rows = num_rows / elts_in_int32;
const int vec_rows_per_tile = rows_per_tile / elts_in_int32; int const vec_rows_per_tile = rows_per_tile / elts_in_int32;
const int interleave = details.columns_interleaved; int const interleave = details.columns_interleaved;
for (int expert = 0; expert < num_experts; ++expert) for (int expert = 0; expert < num_experts; ++expert)
{ {
@ -532,8 +532,8 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
} }
} }
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave) std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
{ {
int arch = getSMVersion(); int arch = getSMVersion();
if (force_interleave && arch == 90) if (force_interleave && arch == 90)
@ -546,7 +546,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D"); TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
size_t num_elts = 1; size_t num_elts = 1;
for (const auto& dim : shape) for (auto const& dim : shape)
{ {
num_elts *= dim; num_elts *= dim;
} }
@ -620,7 +620,7 @@ Outputs
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type, ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
bool force_interleave) bool force_interleave)
{ {
@ -633,8 +633,8 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1]; const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2]; const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
const int bits_in_type = get_bits_in_quant_type(quant_type); int const bits_in_type = get_bits_in_quant_type(quant_type);
const int bytes_per_out_col = num_cols * bits_in_type / 8; int const bytes_per_out_col = num_cols * bits_in_type / 8;
std::vector<int8_t> weight_buf; std::vector<int8_t> weight_buf;
if (unprocessed_quantized_weight == nullptr) if (unprocessed_quantized_weight == nullptr)
@ -643,15 +643,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
unprocessed_quantized_weight = weight_buf.data(); unprocessed_quantized_weight = weight_buf.data();
} }
const int input_mat_size = num_rows * num_cols; int const input_mat_size = num_rows * num_cols;
const int quantized_mat_size = num_rows * bytes_per_out_col; int const quantized_mat_size = num_rows * bytes_per_out_col;
const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1)); float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
std::vector<float> per_col_max(num_cols); std::vector<float> per_col_max(num_cols);
for (int expert = 0; expert < num_experts; ++expert) for (int expert = 0; expert < num_experts; ++expert)
{ {
const WeightType* current_weight = input_weight_ptr + expert * input_mat_size; WeightType const* current_weight = input_weight_ptr + expert * input_mat_size;
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size; int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
// First we find the per column max for this expert weight. // First we find the per column max for this expert weight.
@ -662,7 +662,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
for (int ii = 0; ii < num_rows; ++ii) for (int ii = 0; ii < num_rows; ++ii)
{ {
const WeightType* current_weight_row = current_weight + ii * num_cols; WeightType const* current_weight_row = current_weight + ii * num_cols;
for (int jj = 0; jj < num_cols; ++jj) for (int jj = 0; jj < num_cols; ++jj)
{ {
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj]))); per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
@ -681,15 +681,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
for (int ii = 0; ii < num_rows; ++ii) for (int ii = 0; ii < num_rows; ++ii)
{ {
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col; int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
const WeightType* current_weight_row = current_weight + ii * num_cols; WeightType const* current_weight_row = current_weight + ii * num_cols;
for (int jj = 0; jj < bytes_per_out_col; ++jj) for (int jj = 0; jj < bytes_per_out_col; ++jj)
{ {
if (quant_type == QuantType::INT8_WEIGHT_ONLY) if (quant_type == QuantType::INT8_WEIGHT_ONLY)
{ {
const float col_scale = per_col_max[jj]; float const col_scale = per_col_max[jj];
const float weight_elt = float(current_weight_row[jj]); float const weight_elt = float(current_weight_row[jj]);
const float scaled_weight = round(weight_elt / col_scale); float const scaled_weight = round(weight_elt / col_scale);
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight))); const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
current_quantized_weight_row[jj] = clipped_weight; current_quantized_weight_row[jj] = clipped_weight;
} }
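For the int8 branch above, the arithmetic is plain symmetric per-column quantization. The standalone example below assumes, as the `quant_range_scale` definition earlier suggests, that the per-column maximum has already been rescaled by `1/2^(bits-1)`, so `col_scale = col_max / 128` for int8; the concrete numbers are illustrative only.

```cpp
// Worked example of the int8 clamp-and-round above (values are illustrative).
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

int main()
{
    float const col_max = 0.5f;              // max |weight| observed in this column
    float const col_scale = col_max / 128.f; // assumes quant_range_scale = 1 / 2^(8-1)
    float const weight_elt = -0.26f;

    float const scaled_weight = std::round(weight_elt / col_scale); // -66.56 -> -67
    auto const clipped = static_cast<int8_t>(std::max(-128.f, std::min(127.f, scaled_weight)));

    std::printf("q = %d, dequantized ~= %f\n", clipped, clipped * col_scale); // -67, -0.261719
    return 0;
}
```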
@ -700,12 +700,12 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
int8_t packed_int4s = 0; int8_t packed_int4s = 0;
for (int packed_idx = 0; packed_idx < 2; ++packed_idx) for (int packed_idx = 0; packed_idx < 2; ++packed_idx)
{ {
const int input_idx = 2 * jj + packed_idx; int const input_idx = 2 * jj + packed_idx;
if (input_idx < num_cols) if (input_idx < num_cols)
{ {
const float col_scale = per_col_max[input_idx]; float const col_scale = per_col_max[input_idx];
const float weight_elt = float(current_weight_row[input_idx]); float const weight_elt = float(current_weight_row[input_idx]);
const float scaled_weight = round(weight_elt / col_scale); float const scaled_weight = round(weight_elt / col_scale);
int int_weight = int(scaled_weight); int int_weight = int(scaled_weight);
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight)); const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
@ -729,47 +729,47 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
} }
template void symmetric_quantize<half, float>( template void symmetric_quantize<half, float>(
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool); int8_t*, int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<half, half>( template void symmetric_quantize<half, half>(
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool); int8_t*, int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool); int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>( template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool); int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
#endif #endif
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave) std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
{ {
symmetric_quantize( symmetric_quantize(
processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave); processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
} }
template void symmetric_quantize<float, float>( template void symmetric_quantize<float, float>(
int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool); int8_t*, float*, float const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<half, float>( template void symmetric_quantize<half, float>(
int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool); int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool); template void symmetric_quantize<half, half>(int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);
#ifdef ENABLE_BF16 #ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>( template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool); int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, half>( template void symmetric_quantize<__nv_bfloat16, half>(
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool); int8_t*, __nv_bfloat16*, half const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<half, __nv_bfloat16>( template void symmetric_quantize<half, __nv_bfloat16>(
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool); int8_t*, half*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);
template void symmetric_quantize<__nv_bfloat16, float>( template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool); int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
#endif #endif
} // namespace cutlass_kernels } // namespace cutlass_kernels

View File

@ -38,26 +38,26 @@ int get_bits_in_quant_type(QuantType quant_type);
// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols] // Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
// 3-D shapes are [num_experts, num_rows, num_cols] // 3-D shapes are [num_experts, num_rows, num_cols]
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor, void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version); std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version);
void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type); std::vector<size_t> const& shape, QuantType quant_type);
void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type); void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight, void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false); std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave = false);
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr, void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave); std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave);
// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight // This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
// to implement a simple reference implementation. // to implement a simple reference implementation.
template <typename ComputeType, typename WeightType> template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight, void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type, ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
bool force_interleave); bool force_interleave);
} // namespace cutlass_kernels } // namespace cutlass_kernels
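Taken together, the declarations in this header give the host-side preparation path for weight-only GEMM: `symmetric_quantize` produces the per-column scales and the processed weight in the CUTLASS-ready layout (the comment above notes the unprocessed form is exposed only for reference tests). A hedged usage sketch follows; the enclosing `tensorrt_llm` namespace and the buffer-size conventions ([num_rows, num_cols], one scale per column) are taken from the surrounding code, and everything else is illustrative.

```cpp
// Hedged sketch: quantizing an fp16 weight matrix for int8 weight-only GEMM with the
// declarations above. Namespaces and buffer sizing are assumptions drawn from context.
#include <cstdint>
#include <vector>
#include <cuda_fp16.h>

using namespace tensorrt_llm::kernels::cutlass_kernels; // assumed enclosing namespaces

void quantizeWeightForCutlass(half const* fp16_weight, size_t num_rows, size_t num_cols,
    std::vector<int8_t>& processed, std::vector<half>& scales)
{
    std::vector<size_t> const shape{num_rows, num_cols}; // 2-D: [num_rows, num_cols]
    processed.resize(num_rows * num_cols);               // 8 bits per element
    scales.resize(num_cols);                             // one scale per column

    symmetric_quantize<half, half>(processed.data(), scales.data(), fp16_weight, shape,
        QuantType::INT8_WEIGHT_ONLY, /*force_interleave=*/false);
}
```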

View File

@ -58,27 +58,27 @@ public:
virtual ~CutlassFpAIntBGemmRunnerInterface() {} virtual ~CutlassFpAIntBGemmRunnerInterface() {}
virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0; = 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n,
int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) cudaStream_t stream)
= 0; = 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0; = 0;
virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size, void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0; = 0;
// Returns desired workspace size in bytes. // Returns desired workspace size in bytes.
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;
virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0; virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;
@ -96,20 +96,20 @@ public:
CutlassFpAIntBGemmRunner(); CutlassFpAIntBGemmRunner();
~CutlassFpAIntBGemmRunner(); ~CutlassFpAIntBGemmRunner();
void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override; cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k, void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override; cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override; char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override;
void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size, void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override; cudaStream_t stream) override;
@ -120,15 +120,15 @@ public:
// stream); // stream);
// Returns desired workspace size in bytes. // Returns desired workspace size in bytes.
size_t getWorkspaceSize(const int m, const int n, const int k) override; size_t getWorkspaceSize(int const m, int const n, int const k) override;
std::vector<tkc::CutlassGemmConfig> getConfigs() const override; std::vector<tkc::CutlassGemmConfig> getConfigs() const override;
private: private:
template <typename EpilogueTag> template <typename EpilogueTag>
void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr); const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);
private: private:
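A caller of the interface above typically asks the runner for its candidate configs and workspace requirement before launching; the sketch below shows that flow for the simplest (per-column scale) overload. Config selection, error checking, and allocations are simplified, and nothing beyond the member signatures shown above is taken from the library.

```cpp
// Hedged sketch of driving CutlassFpAIntBGemmRunnerInterface; A, B, weightScales and C
// are assumed to already be valid device pointers.
#include <cuda_runtime_api.h>

template <typename RunnerT>
void runWeightOnlyGemm(RunnerT& runner, void const* A, void const* B, void const* weightScales,
    void* C, int m, int n, int k, cudaStream_t stream)
{
    auto const configs = runner.getConfigs();       // candidate tile/stage configs
    size_t const workspaceBytes = runner.getWorkspaceSize(m, n, k);

    char* workspace = nullptr;
    cudaMalloc(&workspace, workspaceBytes);         // split-k reduction scratch

    // A real caller profiles the candidates; the first one is used here for brevity.
    runner.gemm(A, B, weightScales, C, m, n, k, configs.front(), workspace, workspaceBytes, stream);

    cudaStreamSynchronize(stream);
    cudaFree(workspace);
}
```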

View File

@ -52,8 +52,8 @@ namespace cutlass_kernels
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages> typename ThreadblockShape, typename WarpShape, int Stages>
void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales, void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales,
const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr) int* occupancy = nullptr)
{ {
@ -127,7 +127,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>; using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;
const int ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value int const ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
? n ? n
: k * GemmKernel::kInterleave; : k * GemmKernel::kInterleave;
@ -171,7 +171,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
} }
} }
const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0; int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f); ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k}, typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb}, {reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
@ -230,8 +230,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
// quantization is only supported on Ampere+ GPUs. // quantization is only supported on Ampere+ GPUs.

template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages> typename ThreadblockShape, typename WarpShape, int Stages>
void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr) int* occupancy = nullptr)
{ {
@ -261,8 +261,8 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape> typename ThreadblockShape, typename WarpShape>
void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points, void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr) int* occupancy = nullptr)
{ {
@ -300,9 +300,9 @@ constexpr bool is_fp8()
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag> typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr) cudaStream_t stream, int* occupancy = nullptr)
{ {
@ -412,9 +412,9 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
template <typename EpilogueTag> template <typename EpilogueTag>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
OutputType>::dispatch_to_arch<EpilogueTag>(const ActivationType* A, const WeightType* B, OutputType>::dispatch_to_arch<EpilogueTag>(ActivationType const* A, WeightType const* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy) char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -453,16 +453,16 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType, template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm( void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases, void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS) if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
|| (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY)) || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY))
{ {
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B, dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
(const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases, (ScaleZeroType const*) weight_scales, (ScaleZeroType const*) weight_zero_points, (BiasType const*) biases,
alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr); alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
} }
else else
@ -475,8 +475,8 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType, template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm( void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases, void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream) const size_t workspace_bytes, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -487,15 +487,15 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType, template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm( void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k, void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY) if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
{ {
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B, dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
(const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig, (ScaleZeroType const*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
workspace_ptr, workspace_bytes, stream, nullptr); workspace_ptr, workspace_bytes, stream, nullptr);
} }
else else
@ -507,7 +507,7 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType, template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm( void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k, void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -529,12 +529,12 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
typename BiasType, typename OutputType> typename BiasType, typename OutputType>
size_t size_t
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize( CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize(
const int m, const int n, const int k) int const m, int const n, int const k)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
// These are the min tile sizes for each config, which would launch the maximum number of blocks // These are the min tile sizes for each config, which would launch the maximum number of blocks
const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE); int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE); int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim. // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
return static_cast<size_t>(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); return static_cast<size_t>(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4);
} }
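The returned size is simply `ceil(m / MIN_M_TILE) * ceil(n / MIN_N_TILE) * SPLIT_K_LIMIT * 4` bytes, i.e. four bytes per block in the largest possible grid. A quick arithmetic check with placeholder constants (the real `MIN_M_TILE`, `MIN_N_TILE` and `SPLIT_K_LIMIT` are compile-time members of the runner and are not shown in this diff):

```cpp
// Illustrative only: evaluates the workspace formula above with assumed tile constants.
#include <cstddef>
#include <cstdio>

int main()
{
    int const minMTile = 16, minNTile = 64, splitKLimit = 7; // placeholder values
    int const m = 4096, n = 4096;

    auto ceilDiv = [](int a, int b) { return (a + b - 1) / b; };
    std::size_t const bytes = static_cast<std::size_t>(ceilDiv(m, minMTile))
        * ceilDiv(n, minNTile) * splitKLimit * 4;

    std::printf("%zu bytes\n", bytes); // 256 * 64 * 7 * 4 = 458752
    return 0;
}
```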

View File

@ -44,9 +44,9 @@ namespace cutlass_kernels
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType> typename MainloopScheduleType>
void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr) cudaStream_t stream, int* occupancy = nullptr)
{ {
@ -114,9 +114,9 @@ constexpr bool are_tile_shapes_supported()
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape> cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr) cudaStream_t stream, int* occupancy = nullptr)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -153,9 +153,9 @@ void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType*
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape> cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr) cudaStream_t stream, int* occupancy = nullptr)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -190,9 +190,9 @@ void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, con
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag> cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales, void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n, ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config, int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr) cudaStream_t stream, int* occupancy = nullptr)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

View File

@ -28,9 +28,9 @@ namespace cutlass_kernels
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType> typename MainloopScheduleType, typename EpilogueScheduleType>
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B, void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, float const alpha, OutputType* C, int m, int n, int k, int const group_size,
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
cudaStream_t stream, int* occupancy = nullptr); cudaStream_t stream, int* occupancy = nullptr);

View File

@ -59,9 +59,9 @@ namespace cutlass_kernels
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType, template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
typename MainloopScheduleType, typename EpilogueScheduleType> typename MainloopScheduleType, typename EpilogueScheduleType>
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B, void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases, ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config, float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy) char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{ {
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__); TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -233,7 +233,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const Weigh
StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1)); StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1));
// Use the output as the bias to avoid making a tma descriptor with a nullptr. // Use the output as the bias to avoid making a tma descriptor with a nullptr.
auto output_as_bias_type = reinterpret_cast<const CutlassBiasType*>(C); auto output_as_bias_type = reinterpret_cast<CutlassBiasType const*>(C);
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1}, typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A), {reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),

View File

@ -47,13 +47,13 @@ public:
virtual ~CutlassInt8GemmRunnerInterface() {} virtual ~CutlassInt8GemmRunnerInterface() {}
virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, virtual void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
const size_t workspaceBytes, cudaStream_t stream) const size_t workspaceBytes, cudaStream_t stream)
= 0; = 0;
// Returns desired workspace size in bytes. // Returns desired workspace size in bytes.
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0; virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;
virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0; virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;
@ -70,18 +70,18 @@ public:
CutlassInt8GemmRunner(); CutlassInt8GemmRunner();
~CutlassInt8GemmRunner(); ~CutlassInt8GemmRunner();
void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow, void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow,
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
const size_t workspaceBytes, cudaStream_t stream) override; const size_t workspaceBytes, cudaStream_t stream) override;
// Returns desired workspace size in bytes. // Returns desired workspace size in bytes.
size_t getWorkspaceSize(const int m, const int n, const int k) override; size_t getWorkspaceSize(int const m, int const n, int const k) override;
std::vector<tkc::CutlassGemmConfig> getConfigs() const override; std::vector<tkc::CutlassGemmConfig> getConfigs() const override;
private: private:
void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, void dispatchToArch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr, float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr); const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr);
int mSm; int mSm;

Some files were not shown because too many files have changed in this diff.