Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#1274)

* Update TensorRT-LLM

---------

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>

Parent: 728cc0044b
Commit: 4bb65f216f
@@ -59,6 +59,7 @@ PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 60
PointerAlignment: Left
QualifierAlignment: Right
ReflowComments: true
SeparateDefinitionBlocks: Always
SortIncludes: CaseSensitive
.gitignore (vendored): 10 changed lines
@@ -17,6 +17,16 @@ venv/
.local/
.hypothesis/
.idea/
dump*/
.trt-internal
*.dot
*.prof
*.log
*.pkl
*.hdf5
*.lock
config.json
/*.svg
cpp/cmake-build-*
cpp/.ccache/
tensorrt_llm/libs
@@ -355,6 +355,9 @@ however, that it is recommended to use the C++ version.

## Troubleshooting

* If you encounter accuracy issues in the generated text, you may want to increase
  the internal precision in the attention layer. To do so, pass `--context_fmha_fp32_acc enable` to
  `trtllm-build`.

* It is recommended to add the options `--shm-size=1g --ulimit memlock=-1` to the
  docker or nvidia-docker run command. Otherwise you may see NCCL errors when
@@ -39,7 +39,6 @@ Take GPT-350M as an example for single GPU

```
./benchmarks/gptSessionBenchmark \
    --model gpt_350m \
    --engine_dir "../../benchmarks/gpt_350m/" \
    --batch_size "1" \
    --input_output_len "60,20"
@@ -50,7 +49,6 @@ Take GPT-350M as an example for single GPU

Take GPT-175B as an example for multiple GPUs

```
mpirun -n 8 ./benchmarks/gptSessionBenchmark \
    --model gpt_175b \
    --engine_dir "../../benchmarks/gpt_175b/" \
    --batch_size "1" \
    --input_output_len "60,20"
@@ -125,7 +123,6 @@ cd cpp/build

Take GPT-350M as an example for single GPU V1 batching

```
./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
    --type V1 \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -135,7 +132,6 @@ Take GPT-350M as an example for single GPU V1 batching

Take GPT-350M as an example for 2-GPU inflight batching

```
mpirun -n 2 ./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
    --type IFB \
    --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -165,7 +161,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request

Take GPT-350M as an example for single GPU with static batching

```
./benchmarks/gptManagerBenchmark \
    --model gpt \
    --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
    --type IFB \
    --static_emulated_batch_size 32 \
@@ -237,7 +237,7 @@ int main(int argc, char* argv[])
        benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
            logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
    }
    catch (const std::exception& e)
    catch (std::exception const& e)
    {
        TLLM_LOG_ERROR(e.what());
        return 1;
@@ -24,6 +24,7 @@
#include "tensorrt_llm/common/stringUtils.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/tllmLogger.h"
#include "tensorrt_llm/runtime/worldConfig.h"
@@ -64,20 +65,18 @@ struct BenchmarkParams
class WorkItem
{
public:
    WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t requestId)
        : mInferenceRequest(ir)
    WorkItem(std::shared_ptr<InferenceRequest> inferenceRequest, uint64_t requestId)
        : mInferenceRequest(std::move(inferenceRequest))
        , mRequestId(requestId)
    {
    }

    ~WorkItem() {}

    uint64_t requestId() const
    [[nodiscard]] uint64_t requestId() const
    {
        return mRequestId;
    }

    std::shared_ptr<InferenceRequest> getInferenceRequest() const
    [[nodiscard]] std::shared_ptr<InferenceRequest> getInferenceRequest() const
    {
        return mInferenceRequest;
    }
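The constructor change above, taking the `std::shared_ptr` by value and moving it into the member, is a common C++ idiom. A minimal standalone sketch (illustrative, not taken from the repository):

```cpp
#include <memory>
#include <utility>

// Passing the shared_ptr by value lets the caller decide whether to copy or move;
// moving it into the member avoids one atomic reference-count increment compared
// to binding a const reference and then copying from it.
struct Holder
{
    explicit Holder(std::shared_ptr<int> ptr)
        : mPtr(std::move(ptr))
    {
    }

    std::shared_ptr<int> mPtr;
};
```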
@ -93,7 +92,7 @@ class WorkItemsQueue
|
||||
public:
|
||||
void clear()
|
||||
{
|
||||
std::lock_guard<std::mutex> lk(mMutex);
|
||||
std::lock_guard<std::mutex> lock(mMutex);
|
||||
mPendingWorkItems.clear();
|
||||
mPendingWorkItemsReqIds.clear();
|
||||
mInProgressWorkItems.clear();
|
||||
@ -289,7 +288,7 @@ public:
|
||||
|
||||
if (outputFile.is_open())
|
||||
{
|
||||
for (const auto& header : headers)
|
||||
for (auto const& header : headers)
|
||||
{
|
||||
outputFile << header << ",";
|
||||
}
|
||||
@ -340,13 +339,12 @@ public:
|
||||
mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
|
||||
}
|
||||
|
||||
~ExecutorServer() {}
|
||||
|
||||
void enqueue(std::vector<texec::Request> requests, bool warmup = false)
|
||||
{
|
||||
try
|
||||
{
|
||||
std::vector<SizeType> inputLengths, maxNewTokens;
|
||||
std::vector<SizeType> inputLengths;
|
||||
std::vector<SizeType> maxNewTokens;
|
||||
for (auto const& request : requests)
|
||||
{
|
||||
inputLengths.push_back(request.getInputTokenIds().size());
|
||||
@ -363,11 +361,10 @@ public:
|
||||
mActiveCount++;
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_THROW("%s", e.what());
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
void waitForResponses(std::optional<SizeType> numRequests, bool warmup = false)
|
||||
@ -415,17 +412,16 @@ private:
|
||||
class GptServer
|
||||
{
|
||||
public:
|
||||
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
|
||||
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
|
||||
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
|
||||
std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData)
|
||||
std::optional<SizeType> const staticEmulatedBatchSize,
|
||||
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData)
|
||||
: mRecorder(std::move(recorder))
|
||||
, mTerminateReqId(terminateReqId)
|
||||
, mWaitSleep(waitSleep)
|
||||
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
|
||||
, mEmulatedBatchEndTimestamp(
|
||||
std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs))
|
||||
, mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
|
||||
, mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
|
||||
, mActiveCount(0)
|
||||
{
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
|
||||
@ -473,16 +469,21 @@ public:
|
||||
mRecorder->recordStart(request, requestId);
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
}
|
||||
catch (const tc::TllmException& e)
|
||||
catch (tc::TllmException const& e)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_THROW("%s", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void resetBatchDeadline()
|
||||
{
|
||||
mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch();
|
||||
}
|
||||
|
||||
void waitForEmpty() const
|
||||
{
|
||||
while (!mWorkItemsQueue.empty())
|
||||
@ -502,9 +503,9 @@ public:
|
||||
}
|
||||
|
||||
// Return up to max_num_requests inference requests.
|
||||
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
|
||||
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
|
||||
{
|
||||
std::list<std::shared_ptr<InferenceRequest>> rval;
|
||||
std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
|
||||
auto& comm = COMM_SESSION;
|
||||
if (max_num_requests > 0)
|
||||
{
|
||||
@ -515,12 +516,12 @@ public:
|
||||
auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
|
||||
static_cast<int64_t>(max_num_requests));
|
||||
|
||||
bool readyForNextBatch = numNewWorkItems > 0;
|
||||
bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load();
|
||||
bool readyForNextBatch = numNewWorkItems > 0 && timeout;
|
||||
if (mStaticEmulatedBatchSize)
|
||||
{
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp;
|
||||
bool const previousBatchFinished = mActiveCount == 0;
|
||||
bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
|
||||
readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
|
||||
@ -529,26 +530,23 @@ public:
|
||||
{
|
||||
// Timeout should only begin once we have at least 1 pending request.
|
||||
// Reset timeout when no requests are pending or we submit a new batch.
|
||||
mEmulatedBatchEndTimestamp
|
||||
= std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs);
|
||||
resetBatchDeadline();
|
||||
}
|
||||
}
|
||||
|
||||
if (readyForNextBatch)
|
||||
{
|
||||
int count = 0;
|
||||
// Only add a single batch at a time when emulating static batching
|
||||
auto const numItemsToAdd = std::min(
|
||||
numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
|
||||
mActiveCount += numItemsToAdd;
|
||||
while (count < numItemsToAdd)
|
||||
while (inferenceRequests.size() < numItemsToAdd)
|
||||
{
|
||||
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
|
||||
|
||||
if (markedInProgress)
|
||||
{
|
||||
rval.emplace_back(workItem->getInferenceRequest());
|
||||
count++;
|
||||
inferenceRequests.emplace_back(workItem->getInferenceRequest());
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -561,14 +559,14 @@ public:
|
||||
}
|
||||
if (world_size > 1)
|
||||
{
|
||||
auto numNewWorkItems = static_cast<int64_t>(rval.size());
|
||||
auto numNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
|
||||
comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
for (auto const& ir : rval)
|
||||
for (auto const& infReq : inferenceRequests)
|
||||
{
|
||||
auto vpacked = ir->serialize();
|
||||
auto vpacked = infReq->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
packed.insert(
|
||||
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
|
||||
@ -590,18 +588,18 @@ public:
|
||||
for (int64_t count = 0; count < numNewWorkItems; ++count)
|
||||
{
|
||||
int64_t n = *(packed_ptr++);
|
||||
auto ir = InferenceRequest::deserialize(packed_ptr);
|
||||
auto infReq = InferenceRequest::deserialize(packed_ptr);
|
||||
packed_ptr += n;
|
||||
rval.emplace_back(ir);
|
||||
inferenceRequests.emplace_back(infReq);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return rval;
|
||||
return inferenceRequests;
|
||||
}
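The `world_size > 1` branch above uses a length-prefixed packing scheme so that a whole batch of serialized requests can be broadcast as one flat `int64_t` buffer. A standalone sketch of that scheme, with hypothetical helper names rather than the repository's API:

```cpp
#include <cstdint>
#include <vector>

// Pack: each request is already flattened to int64_t words; store it as
// [word count, words...] so the receiver knows where each request ends.
std::vector<int64_t> pack(std::vector<std::vector<int64_t>> const& requests)
{
    std::vector<int64_t> packed;
    for (auto const& words : requests)
    {
        packed.push_back(static_cast<int64_t>(words.size()));
        packed.insert(packed.end(), words.begin(), words.end());
    }
    return packed;
}

// Unpack: walk the buffer with a pointer, mirroring the deserialize loop above.
std::vector<std::vector<int64_t>> unpack(std::vector<int64_t> const& packed, int64_t numRequests)
{
    std::vector<std::vector<int64_t>> requests;
    auto const* ptr = packed.data();
    for (int64_t i = 0; i < numRequests; ++i)
    {
        int64_t const n = *(ptr++);
        requests.emplace_back(ptr, ptr + n);
        ptr += n;
    }
    return requests;
}
```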
|
||||
|
||||
void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
|
||||
bool final_response, [[maybe_unused]] const std::string& errMsg)
|
||||
bool final_response, [[maybe_unused]] std::string const& errMsg)
|
||||
{
|
||||
// `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs,
|
||||
// cumLogProbs`. `contextLogits, generationLogits` are optional, only contained when `gather_context_logits` and
|
||||
@ -616,7 +614,7 @@ public:
|
||||
mActiveCount--;
|
||||
}
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what());
|
||||
}
|
||||
@ -628,9 +626,9 @@ private:
|
||||
WorkItemsQueue mWorkItemsQueue;
|
||||
std::optional<uint64_t> mTerminateReqId;
|
||||
std::chrono::milliseconds mWaitSleep;
|
||||
std::optional<int> mStaticEmulatedBatchSize;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mEmulatedBatchEndTimestamp;
|
||||
int32_t mStaticEmulatedTimeoutMs;
|
||||
std::optional<SizeType> mStaticEmulatedBatchSize;
|
||||
std::chrono::milliseconds mBatchTimeout;
|
||||
std::atomic<std::chrono::steady_clock::time_point::duration> mBatchDeadline;
|
||||
std::atomic<uint64_t> mActiveCount;
|
||||
|
||||
}; // class GptServer
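The new `mBatchTimeout`/`mBatchDeadline` members replace the earlier `mStaticEmulatedTimeoutMs` and `mEmulatedBatchEndTimestamp` pair. Storing the deadline as an atomic duration-since-epoch lets `resetBatchDeadline()` and the scheduling loop touch it without a mutex. A minimal standalone sketch of that pattern (illustrative class and method names):

```cpp
#include <atomic>
#include <chrono>

class BatchDeadline
{
public:
    explicit BatchDeadline(std::chrono::milliseconds timeout)
        : mTimeout(timeout)
    {
        reset();
    }

    // Push the deadline `timeout` milliseconds into the future.
    void reset()
    {
        mDeadline = (std::chrono::steady_clock::now() + mTimeout).time_since_epoch();
    }

    // True once the stored deadline has passed.
    [[nodiscard]] bool expired() const
    {
        return std::chrono::steady_clock::now().time_since_epoch() > mDeadline.load();
    }

private:
    std::chrono::milliseconds mTimeout;
    std::atomic<std::chrono::steady_clock::time_point::duration> mDeadline{};
};
```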
|
||||
@ -674,10 +672,9 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
|
||||
auto request = std::make_shared<InferenceRequest>(reqId);
|
||||
auto const& inputIds = sample.inputIds;
|
||||
request->setInputIds(bufferManager.copyFrom(
|
||||
inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
|
||||
inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kCPU));
|
||||
auto const requestOutputLen = sample.outputLen;
|
||||
request->setMaxNewTokens(
|
||||
bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
|
||||
request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU));
|
||||
request->setBeamWidth(beamWidthTensor);
|
||||
if (eosId != nullptr)
|
||||
{
|
||||
@ -704,14 +701,15 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWid
|
||||
{
|
||||
auto samplingConfig = texec::SamplingConfig{beamWidth};
|
||||
auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
|
||||
return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId);
|
||||
return {sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId};
|
||||
}
|
||||
|
||||
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
|
||||
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
|
||||
std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
|
||||
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs,
|
||||
std::optional<TokenIdType> const& eosId, std::optional<TokenIdType> const& padId,
|
||||
BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
|
||||
bool logIterationData)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi();
|
||||
@ -736,14 +734,14 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};
|
||||
|
||||
// Load dataset
|
||||
const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
const auto numSamples = samples.size();
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
const int maxBeamWidth = beamWidth;
|
||||
int const maxBeamWidth = beamWidth;
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile);
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
|
||||
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData);
|
||||
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData);
|
||||
|
||||
ITensor::SharedPtr eosIdTensor{
|
||||
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
|
||||
@ -761,6 +759,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
// Warm up
|
||||
gptServer->resetBatchDeadline();
|
||||
SizeType reqId = 0;
|
||||
for (auto i = 0; i < warmUp; ++i)
|
||||
{
|
||||
@ -774,6 +773,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
|
||||
// Benchmark
|
||||
recorder->initialize();
|
||||
gptServer->resetBatchDeadline();
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
|
||||
@ -806,23 +806,19 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
|
||||
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData)
|
||||
{
|
||||
// Check that mpi size is 1 for now
|
||||
auto const worldConfig = WorldConfig::mpi();
|
||||
if (worldConfig.getSize() > 1)
|
||||
{
|
||||
TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1");
|
||||
}
|
||||
auto const& world = tensorrt_llm::mpi::MpiComm::world();
|
||||
auto worldRank = world.getRank();
|
||||
|
||||
// Load dataset
|
||||
const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
const auto numSamples = samples.size();
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile);
|
||||
|
||||
auto executorServer = std::make_shared<ExecutorServer>(engineDir, modelType, beamWidth, schedulerPolicy,
|
||||
benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData);
|
||||
|
||||
if (worldConfig.getRank() == 0)
|
||||
if (worldRank == 0)
|
||||
{
|
||||
// Warm up
|
||||
{
|
||||
@ -849,7 +845,7 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
|
||||
delays.push_back(static_cast<int>(samples[i].delay * 1000));
|
||||
}
|
||||
|
||||
bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; });
|
||||
bool hasDelay = std::any_of(delays.begin(), delays.end(), [](auto const& delay) { return delay > 0; });
|
||||
if (hasDelay && staticEmulatedBatchSize)
|
||||
{
|
||||
TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes");
|
||||
@ -910,9 +906,6 @@ int main(int argc, char* argv[])
|
||||
cxxopts::Options options(
|
||||
"TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
|
||||
options.add_options()("h,help", "Print usage");
|
||||
// TODO(rkobus): remove because unused
|
||||
options.add_options()(
|
||||
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
|
||||
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
|
||||
options.add_options()(
|
||||
"api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("gptManager"));
|
||||
@ -929,8 +922,8 @@ int main(int argc, char* argv[])
|
||||
options.add_options()(
|
||||
"warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
|
||||
options.add_options()(
|
||||
"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<int>()->default_value("-1"));
|
||||
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<int>());
|
||||
"eos_id", "Specify the end-of-sequence token id.", cxxopts::value<TokenIdType>()->default_value("-1"));
|
||||
options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<TokenIdType>());
|
||||
options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
|
||||
options.add_options()(
|
||||
"kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
|
||||
@ -949,11 +942,15 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
|
||||
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
|
||||
|
||||
options.add_options()("first_batch_delay",
|
||||
"Delay before submitting the first batch of requests. This can be used to increase the size of the first "
|
||||
"batch.",
|
||||
cxxopts::value<int32_t>());
|
||||
options.add_options()("static_emulated_batch_size",
|
||||
"Emulate static batching performance with the provided batch size.", cxxopts::value<int>());
|
||||
"Emulate static batching performance with the provided batch size.", cxxopts::value<SizeType>());
|
||||
options.add_options()("static_emulated_timeout",
|
||||
"Timeout (ms) before launching a partial batch in emulated static batching mode",
|
||||
cxxopts::value<int>()->default_value("500"));
|
||||
cxxopts::value<int32_t>()->default_value("500"));
|
||||
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
|
||||
cxxopts::value<std::string>()->default_value("error"));
|
||||
options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
|
||||
@ -1042,23 +1039,31 @@ int main(int argc, char* argv[])
|
||||
// Argument: Enable return context logits
|
||||
bool returnGenerationLogits = result["return_generation_logits"].as<bool>();
|
||||
|
||||
std::optional<int32_t> padId;
|
||||
std::optional<TokenIdType> padId;
|
||||
// Argument: Padding token id
|
||||
if (result.count("pad_id"))
|
||||
{
|
||||
padId = result["pad_id"].as<int>();
|
||||
padId = result["pad_id"].as<TokenIdType>();
|
||||
}
|
||||
|
||||
// Argument: End-of-sentence token id
|
||||
std::optional<int32_t> eosId = result["eos_id"].as<int>();
|
||||
std::optional<TokenIdType> eosId = result["eos_id"].as<TokenIdType>();
|
||||
|
||||
std::optional<int> staticEmulatedBatchSize;
|
||||
std::optional<std::chrono::milliseconds> batchTimeout;
|
||||
// Argument: first_batch_delay
|
||||
if (result.count("first_batch_delay"))
|
||||
{
|
||||
batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as<int32_t>()};
|
||||
}
|
||||
|
||||
std::optional<SizeType> staticEmulatedBatchSize;
|
||||
// Argument: Static emulated batch size
|
||||
if (result.count("static_emulated_batch_size"))
|
||||
{
|
||||
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<int>();
|
||||
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<SizeType>();
|
||||
|
||||
batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as<int32_t>()};
|
||||
}
|
||||
auto const staticEmulatedTimeout = result["static_emulated_timeout"].as<int>();
|
||||
|
||||
// Argument: Scheduler policy
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy;
|
||||
@ -1114,10 +1119,10 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
|
||||
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
|
||||
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout,
|
||||
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
|
||||
logIterationData);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
@ -1131,7 +1136,7 @@ int main(int argc, char* argv[])
|
||||
beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep,
|
||||
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
|
||||
@ -15,7 +15,6 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/plugins/api/tllmPlugin.h"
|
||||
#include "tensorrt_llm/runtime/gptJsonConfig.h"
|
||||
#include "tensorrt_llm/runtime/gptSession.h"
|
||||
@ -56,12 +55,11 @@ size_t monitorMemory(std::atomic_bool& done)
|
||||
return peakMem;
|
||||
}
|
||||
|
||||
void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
|
||||
std::vector<int> const& batchSizes, int beamWidth, std::vector<std::vector<int>> const& inOutLen,
|
||||
std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
|
||||
GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits, bool disableForceMaxTokens)
|
||||
void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int> const& batchSizes, int beamWidth,
|
||||
std::vector<std::vector<int>> const& inOutLen, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
|
||||
int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits,
|
||||
bool disableForceMaxTokens)
|
||||
{
|
||||
std::string modelNameHyphen = modelName;
|
||||
std::filesystem::path jsonFileName = dataPath / "config.json";
|
||||
auto const json = GptJsonConfig::parse(jsonFileName);
|
||||
auto const modelConfig = json.getModelConfig();
|
||||
@ -69,7 +67,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
SizeType deviceCount{0};
|
||||
TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
|
||||
auto const worldConfig = WorldConfig::mpi(deviceCount, json.getTensorParallelism(), json.getPipelineParallelism());
|
||||
auto const enginePath = dataPath / json.engineFilename(worldConfig, modelNameHyphen);
|
||||
auto const enginePath = dataPath / json.engineFilename(worldConfig);
|
||||
auto const dtype = modelConfig.getDataType();
|
||||
auto const maxNumTokens = modelConfig.getMaxNumTokens();
|
||||
auto const useHalf = (dtype == nvinfer1::DataType::kHALF);
|
||||
@ -104,7 +102,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
|
||||
auto& memoryCounter = MemoryCounters::getInstance();
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
|
||||
std::atomic_bool done;
|
||||
for (auto const batchSize : batchSizes)
|
||||
{
|
||||
if (inputPacked && maxNumTokens != std::nullopt)
|
||||
@ -114,10 +112,11 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
"benchmark on %d tokens",
|
||||
maxNumTokens.value(), maxBatchSize * maxInputLength);
|
||||
}
|
||||
std::atomic_bool done = false;
|
||||
done = false;
|
||||
auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
|
||||
size_t peakMem;
|
||||
try
|
||||
{
|
||||
auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
|
||||
std::vector<SizeType> inputLengthsHost(batchSize, maxInputLength);
|
||||
@ -205,7 +204,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
|
||||
TLLM_LOG_INFO(memoryCounter.toString());
|
||||
done = true;
|
||||
size_t peakMem = peakMemFuture.get();
|
||||
peakMemFuture.wait();
|
||||
peakMem = peakMemFuture.get();
|
||||
|
||||
printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000);
|
||||
|
||||
@ -275,6 +275,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
std::size_t found = std::string(e.what()).find("out of memory");
|
||||
// We need to kill the memory monitor when OOM.
|
||||
done = true;
|
||||
peakMemFuture.wait();
|
||||
peakMem = peakMemFuture.get();
|
||||
|
||||
// Unexpected error; rethrow
|
||||
if (found == std::string::npos)
|
||||
@ -297,6 +299,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
|
||||
{
|
||||
// We need to kill memory monitor when any other issue occurs
|
||||
done = true;
|
||||
peakMemFuture.wait();
|
||||
peakMem = peakMemFuture.get();
|
||||
throw;
|
||||
}
|
||||
}
|
||||
@ -311,8 +315,6 @@ int main(int argc, char* argv[])
|
||||
cxxopts::Options options(
|
||||
"TensorRT-LLM C++ Runtime Benchmark", "TensorRT-LLM C++ Runtime Benchmark for GPT and GPT-like models.");
|
||||
options.add_options()("h,help", "Print usage");
|
||||
options.add_options()(
|
||||
"m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
|
||||
options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
|
||||
options.add_options()("batch_size",
|
||||
"Specify batch size(s) you want to benchmark. Multiple batch sizes can be separated by \";\", example: "
|
||||
@ -459,11 +461,11 @@ int main(int argc, char* argv[])
|
||||
|
||||
try
|
||||
{
|
||||
benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
|
||||
beamWidth, inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
|
||||
result["duration"].as<int>(), sessionConfig, enableCudaGraph, printAllLogits, disableForceMaxTokens);
|
||||
benchmarkGptSession(result["engine_dir"].as<std::string>(), batchSizes, beamWidth, inOutLen, logger,
|
||||
result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(), sessionConfig,
|
||||
enableCudaGraph, printAllLogits, disableForceMaxTokens);
|
||||
}
|
||||
catch (const std::exception& e)
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
|
||||
@ -86,6 +86,7 @@ class EncDecBuildConfig:
|
||||
max_output_len: Optional[int] = None
|
||||
builder_opt: Optional[int] = None
|
||||
n_mels: Optional[int] = None
|
||||
skip_cross_qkv: bool = False
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
assert self.head_size is not None
|
||||
|
||||
@ -89,6 +89,10 @@ class BaseBenchmark(object):
|
||||
(f'Engine world size ({world_size}) != Runtime world size ({self.world_size})')
|
||||
# Load config into self
|
||||
for key, value in self.config['pretrained_config'].items():
|
||||
if key == "ssm_cfg":
|
||||
for ssm_key, ssm_value in value.items():
|
||||
setattr(self, "mamba_" + ssm_key, ssm_value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
self.quant_mode = QuantMode.from_quant_algo(
|
||||
|
||||
@ -327,9 +327,16 @@ def main(args):
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
latencies = []
|
||||
# Disable Host memory monitor when cuda graph is enabled for cuda graph performance.
|
||||
disable_host_mem_monitor = False
|
||||
if args.enable_cuda_graph:
|
||||
logger.warning(
|
||||
'Disable host memory monitor when cuda graph is enabled.')
|
||||
disable_host_mem_monitor = True
|
||||
|
||||
if not disable_mem_monitor:
|
||||
memory_monitor = MemoryMonitor()
|
||||
memory_monitor = MemoryMonitor(
|
||||
disable_host_mem_monitor=disable_host_mem_monitor)
|
||||
memory_monitor.start()
|
||||
|
||||
iter_idx = 0
|
||||
|
||||
@ -648,9 +648,12 @@ def build_gpt(args):
|
||||
'tp_size': world_size,
|
||||
},
|
||||
}
|
||||
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
|
||||
elif family == "internlm":
|
||||
quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
|
||||
|
||||
config = {
|
||||
'architecture':
|
||||
'LLaMAForCausalLM',
|
||||
@ -673,8 +676,10 @@ def build_gpt(args):
|
||||
build_config['n_positions'],
|
||||
'hidden_act':
|
||||
build_config['hidden_act'],
|
||||
'quantization':
|
||||
quant_mode.to_dict(),
|
||||
'quantization': {
|
||||
'quant_algo': quant_algo,
|
||||
'kv_cache_quant_algo': kv_cache_quant_algo
|
||||
},
|
||||
'mapping': {
|
||||
'world_size': world_size,
|
||||
'tp_size': world_size
|
||||
@ -696,6 +701,7 @@ def build_gpt(args):
|
||||
"has_zero_point": True,
|
||||
"pre_quant_scale": False,
|
||||
})
|
||||
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
|
||||
elif family == "qwen":
|
||||
@ -1038,6 +1044,7 @@ def enc_dec_build_helper(component, config, args):
|
||||
or quant_mode.is_int8_weight_only()),
|
||||
quant_mode=quant_mode,
|
||||
n_mels=n_mels,
|
||||
skip_cross_qkv=config['skip_cross_qkv'],
|
||||
)
|
||||
|
||||
# build engine
|
||||
|
||||
@ -22,7 +22,7 @@ from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,
|
||||
|
||||
class MemoryMonitor:
|
||||
|
||||
def __init__(self, query_interval=0.1):
|
||||
def __init__(self, query_interval=0.1, disable_host_mem_monitor=False):
|
||||
self.query_interval = query_interval # second(s)
|
||||
self.mem_monitor_process = None
|
||||
# bytes
|
||||
@ -35,6 +35,8 @@ class MemoryMonitor:
|
||||
self.signal_event = Event() # Sending signal to subprocess
|
||||
self.peak_mem_queue = Queue() # Receiving results from subprocess
|
||||
|
||||
self.disable_host_mem_monitor = disable_host_mem_monitor
|
||||
|
||||
def start(self):
|
||||
self.mem_monitor_process = Process(target=self._upd_peak_memory_usage,
|
||||
args=(self.signal_event,
|
||||
@ -70,6 +72,9 @@ class MemoryMonitor:
|
||||
peak_mem_queue.put((peak_host_used, peak_device_used))
|
||||
|
||||
def get_memory_usage(self):
|
||||
if self.disable_host_mem_monitor:
|
||||
host_used = 0
|
||||
else:
|
||||
host_used, _, _ = host_memory_info(self.pid)
|
||||
device_used, _, _ = device_memory_info()
|
||||
return host_used, device_used
|
||||
|
||||
@ -36,6 +36,7 @@ option(NVTX_DISABLE "Disable all NVTX features" ON)
|
||||
option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
|
||||
option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
|
||||
option(FAST_MATH "Compiling in fast math mode" OFF)
|
||||
option(INDEX_RANGE_CHECK "Compiling with index range checks" OFF)
|
||||
|
||||
if(NVTX_DISABLE)
|
||||
add_compile_definitions("NVTX_DISABLE")
|
||||
@ -97,6 +98,11 @@ if(FAST_BUILD)
|
||||
message(WARNING "Skip some kernels to accelerate compilation")
|
||||
endif()
|
||||
|
||||
if(INDEX_RANGE_CHECK)
|
||||
add_compile_definitions("INDEX_RANGE_CHECK")
|
||||
message(WARNING "Check index range to detect OOB accesses")
|
||||
endif()
|
||||
|
||||
# Determine CUDA version before enabling the language extension
|
||||
check_language(CUDA)
|
||||
if(CMAKE_CUDA_COMPILER)
|
||||
@ -162,10 +168,6 @@ message(STATUS " version: ${CUDAToolkit_VERSION}")
|
||||
message(STATUS " libraries: ${CUDAToolkit_LIBRARY_DIR}")
|
||||
message(STATUS " include path: ${CUDAToolkit_INCLUDE_DIRS}")
|
||||
|
||||
find_library(
|
||||
CUDNN_LIB cudnn
|
||||
HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR}
|
||||
PATH_SUFFIXES lib64 lib lib/x64)
|
||||
set(CUBLAS_LIB CUDA::cublas)
|
||||
set(CUBLASLT_LIB CUDA::cublasLt)
|
||||
set(CUDA_DRV_LIB CUDA::cuda_driver)
|
||||
|
||||
@ -29,9 +29,9 @@ class InferenceRequest;
|
||||
class NamedTensor;
|
||||
|
||||
using GetInferenceRequestsCallback = std::function<std::list<std::shared_ptr<InferenceRequest>>(int32_t)>;
|
||||
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
|
||||
using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, std::string const&)>;
|
||||
using PollStopSignalCallback = std::function<std::unordered_set<uint64_t>()>;
|
||||
// json of stats as a string
|
||||
using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
|
||||
using ReturnBatchManagerStatsCallback = std::function<void(std::string const&)>;
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -312,9 +312,9 @@ public:
|
||||
|
||||
[[nodiscard]] std::vector<int64_t> serialize() const;
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed);
|
||||
static std::shared_ptr<InferenceRequest> deserialize(std::vector<int64_t> const& packed);
|
||||
|
||||
static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr);
|
||||
static std::shared_ptr<InferenceRequest> deserialize(int64_t const* packed_ptr);
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -50,6 +50,13 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(KvCacheConfig const& other) const
|
||||
{
|
||||
return maxTokens == other.maxTokens && maxAttentionWindow == other.maxAttentionWindow
|
||||
&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
|
||||
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm;
|
||||
}
|
||||
|
||||
std::optional<SizeType> maxTokens;
|
||||
std::optional<SizeType> maxAttentionWindow;
|
||||
std::optional<SizeType> sinkTokenLength;
|
||||
|
||||
@ -176,6 +176,13 @@ public:
|
||||
mNumTokens += n;
|
||||
}
|
||||
|
||||
void removeTokens(SizeType n)
|
||||
{
|
||||
TLLM_CHECK(n <= mNumTokens);
|
||||
TLLM_CHECK(mNumTokens - n >= 0);
|
||||
mNumTokens -= n;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getSequenceSlotIdx() const
|
||||
{
|
||||
return mSeqSlotIdx;
|
||||
@ -214,6 +221,14 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void removeLastBlock()
|
||||
{
|
||||
for (auto& beamBlockIds : mCacheBlockIds)
|
||||
{
|
||||
beamBlockIds.pop_back();
|
||||
}
|
||||
}
|
||||
|
||||
void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
|
||||
{
|
||||
mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
|
||||
@ -280,32 +295,40 @@ public:
|
||||
//! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
|
||||
void schedulingReleaseBlocks(GenerationRequest& sequence);
|
||||
|
||||
[[nodiscard]] SizeType getNumFreeBlocks() const
|
||||
//! \brief Release last block in the sequence
|
||||
void releaseLastBlock(GenerationRequest& sequence);
|
||||
|
||||
[[nodiscard]] SizeType getNumFreeBlocks() const noexcept
|
||||
{
|
||||
return mFreeBlocks.size();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getNumAllocatedBlocks() const
|
||||
[[nodiscard]] SizeType getNumReusedBlocks() const noexcept
|
||||
{
|
||||
return mReusedBlocks;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getNumAllocatedBlocks() const noexcept
|
||||
{
|
||||
return getMaxNumBlocks() - getNumFreeBlocks();
|
||||
}
|
||||
|
||||
[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
|
||||
[[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const noexcept
|
||||
{
|
||||
return getNumFreeBlocks() >= numRequired;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const
|
||||
[[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const noexcept
|
||||
{
|
||||
return mSchedulingNumFreeBlocks >= numRequired;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxNumBlocks() const
|
||||
[[nodiscard]] SizeType getMaxNumBlocks() const noexcept
|
||||
{
|
||||
return static_cast<SizeType>(mAllBlocksByIdx.size());
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getTokensPerBlock() const
|
||||
[[nodiscard]] SizeType getTokensPerBlock() const noexcept
|
||||
{
|
||||
return mTokensPerBlock;
|
||||
}
|
||||
@ -478,11 +501,15 @@ public:
|
||||
return mEnableBlockReuse;
|
||||
}
|
||||
|
||||
void removeToken(SizeType seqSlotIdx);
|
||||
void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
|
||||
|
||||
private:
|
||||
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
|
||||
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx);
|
||||
void updateNewBlockPointer(GenerationRequest const& seq, SizeType seqSlotIdx, SizeType blockIdx);
|
||||
void updateToken(SizeType seqSlotIdx, bool addToken);
|
||||
|
||||
private:
|
||||
// Number of elements per one blocks
|
||||
|
||||
@ -474,7 +474,7 @@ public:
|
||||
return mDraftTokens->size();
|
||||
}
|
||||
|
||||
void setReturnContextLogits(const bool returnContextLogits)
|
||||
void setReturnContextLogits(bool const returnContextLogits)
|
||||
{
|
||||
mReturnContextLogits = returnContextLogits;
|
||||
}
|
||||
@ -484,7 +484,7 @@ public:
|
||||
return mReturnContextLogits;
|
||||
}
|
||||
|
||||
void setReturnGenerationLogits(const bool returnGenerationLogits)
|
||||
void setReturnGenerationLogits(bool const returnGenerationLogits)
|
||||
{
|
||||
mReturnGenerationLogits = returnGenerationLogits;
|
||||
}
|
||||
@ -556,6 +556,11 @@ public:
|
||||
return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool isGenerationCompleteState() const noexcept
|
||||
{
|
||||
return mState == REQUEST_STATE_GENERATION_COMPLETE;
|
||||
}
|
||||
|
||||
/// To determine whether the context is unchunked. When a context is chunked into only a part, it
|
||||
/// is still different from the unchunked state, which indicates the initial status.
|
||||
[[nodiscard]] bool isFullContextRequest() const noexcept
|
||||
|
||||
@ -64,7 +64,7 @@ public:
|
||||
using TensorPtr = Base::TensorPtr;
|
||||
|
||||
NamedTensor(
|
||||
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
|
||||
nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, void const* _data = nullptr);
|
||||
|
||||
NamedTensor(TensorPtr _tensor, std::string _name)
|
||||
: Base(std::move(_tensor), std::move(_name)){};
|
||||
@ -74,6 +74,10 @@ public:
|
||||
|
||||
[[nodiscard]] std::vector<int64_t> serialize() const;
|
||||
|
||||
static NamedTensor deserialize(const int64_t* packed);
|
||||
void serialize(int64_t* out, const size_t totalSize) const;
|
||||
|
||||
[[nodiscard]] size_t serializedSize() const;
|
||||
|
||||
static NamedTensor deserialize(int64_t const* packed);
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -50,11 +50,19 @@ public:
|
||||
|
||||
explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
|
||||
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()),
|
||||
executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(),
|
||||
executorConfig.getEnableChunkedContext())
|
||||
executorConfig.getEnableTrtOverlap(),
|
||||
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
|
||||
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext())
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(TrtGptModelOptionalParams const& other) const
|
||||
{
|
||||
return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
|
||||
&& deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
|
||||
&& enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
|
||||
}
|
||||
|
||||
KvCacheConfig kvCacheConfig;
|
||||
|
||||
bool enableTrtOverlap;
|
||||
|
||||
@ -16,6 +16,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdint>
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
@ -80,11 +81,17 @@ public:
|
||||
|
||||
[[nodiscard]] reference operator[](size_type index)
|
||||
{
|
||||
#ifdef INDEX_RANGE_CHECK
|
||||
TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
|
||||
#endif
|
||||
return mData[index];
|
||||
}
|
||||
|
||||
[[nodiscard]] const_reference operator[](size_type index) const
|
||||
{
|
||||
#ifdef INDEX_RANGE_CHECK
|
||||
TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
|
||||
#endif
|
||||
return mData[index];
|
||||
}
|
||||
|
||||
|
||||
@ -56,6 +56,7 @@ enum class MpiType
|
||||
kUINT64,
|
||||
kFP8,
|
||||
kBF16,
|
||||
kCHAR,
|
||||
};
|
||||
|
||||
//! \brief For converting a C++ data type to a TensorRT data type.
|
||||
@ -133,6 +134,12 @@ struct MpiTypeConverter<std::uint64_t>
|
||||
static constexpr auto value = MpiType::kUINT64;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct MpiTypeConverter<char>
|
||||
{
|
||||
static constexpr auto value = MpiType::kCHAR;
|
||||
};
|
||||
|
||||
#ifdef ENABLE_FP8
|
||||
template <>
|
||||
struct MpiTypeConverter<__nv_fp8_e4m3>
|
||||
@ -202,8 +209,8 @@ public:
|
||||
~MpiComm() noexcept;
|
||||
|
||||
// no copy
|
||||
MpiComm(const MpiComm&) = delete;
|
||||
MpiComm& operator=(const MpiComm&) = delete;
|
||||
MpiComm(MpiComm const&) = delete;
|
||||
MpiComm& operator=(MpiComm const&) = delete;
|
||||
|
||||
// move
|
||||
MpiComm(MpiComm&&) noexcept;
|
||||
@ -253,7 +260,24 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void bcast(std::vector<int64_t>& packed, int root) const;
|
||||
template <typename T>
|
||||
void bcast(std::vector<T>& vec, int root) const
|
||||
{
|
||||
auto const rank = getRank();
|
||||
auto vecSize = (rank == root) ? static_cast<int64_t>(vec.size()) : int64_t(0);
|
||||
bcast(&vecSize, 1, MpiType::kINT64, root);
|
||||
vec.resize(vecSize);
|
||||
|
||||
if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
|
||||
{
|
||||
auto const mpiType = MpiTypeConverter<std::remove_cv_t<T>>::value;
|
||||
bcast(vec.data(), vec.size(), mpiType, root);
|
||||
}
|
||||
else
|
||||
{
|
||||
bcast(vec.data(), vec.size() * sizeof(T), MpiType::kBYTE, root);
|
||||
}
|
||||
}
|
||||
|
||||
void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
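The templated `bcast` above follows the usual two-phase broadcast for variable-length data: the root first broadcasts the element count so the other ranks can resize, then the payload itself (typed for fundamental `T`, byte-wise otherwise). A rough equivalent written directly against the MPI C API, as a sketch rather than the wrapper's implementation:

```cpp
#include <mpi.h>

#include <cstddef>
#include <cstdint>
#include <vector>

template <typename T>
void bcastVector(std::vector<T>& vec, int root, MPI_Comm comm)
{
    int rank = 0;
    MPI_Comm_rank(comm, &rank);

    // Phase 1: agree on the size.
    std::int64_t size = (rank == root) ? static_cast<std::int64_t>(vec.size()) : 0;
    MPI_Bcast(&size, 1, MPI_INT64_T, root, comm);
    vec.resize(static_cast<std::size_t>(size));

    // Phase 2: ship the contents; raw bytes are enough for trivially copyable T.
    MPI_Bcast(vec.data(), static_cast<int>(size * static_cast<std::int64_t>(sizeof(T))), MPI_BYTE, root, comm);
}
```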
|
||||
|
||||
@ -297,8 +321,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
|
||||
void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const;
|
||||
void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
|
||||
void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
|
||||
void barrier() const;
|
||||
|
||||
void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
|
||||
|
||||
@ -34,6 +34,9 @@
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
|
||||
class Model;
|
||||
class Serialization;
|
||||
|
||||
/// @brief Sampling configuration
|
||||
class SamplingConfig
|
||||
{
|
||||
@ -51,6 +54,8 @@ public:
|
||||
|
||||
~SamplingConfig();
|
||||
|
||||
bool operator==(SamplingConfig const& other) const;
|
||||
|
||||
[[nodiscard]] SizeType getBeamWidth() const;
|
||||
[[nodiscard]] std::optional<SizeType> getTopK() const;
|
||||
[[nodiscard]] std::optional<FloatType> getTopP() const;
|
||||
@ -68,6 +73,7 @@ public:
|
||||
[[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
SizeType mBeamWidth;
|
||||
std::optional<SizeType> mTopK;
|
||||
std::optional<FloatType> mTopP;
|
||||
@ -86,12 +92,16 @@ private:
|
||||
};
|
||||
|
||||
/// @brief Configuration that controls the outputs of a Result
|
||||
struct OutputConfig
|
||||
class OutputConfig
|
||||
{
|
||||
bool returnLogProbs{false};
|
||||
bool returnContextLogits{false};
|
||||
bool returnGenerationLogits{false};
|
||||
bool excludeInputFromOutput{false};
|
||||
public:
|
||||
OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
|
||||
bool excludeInputFromOutput = false);
|
||||
|
||||
bool returnLogProbs;
|
||||
bool returnContextLogits;
|
||||
bool returnGenerationLogits;
|
||||
bool excludeInputFromOutput;
|
||||
};
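With `OutputConfig` now a class exposing a defaulted constructor rather than an aggregate, callers construct it positionally. A hedged usage sketch (the flag values are illustrative):

```cpp
#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Ask for log-probs and generation logits, keep the remaining outputs disabled.
texec::OutputConfig const outputConfig(
    /*returnLogProbs=*/true, /*returnContextLogits=*/false,
    /*returnGenerationLogits=*/true, /*excludeInputFromOutput=*/false);
```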
|
||||
|
||||
/// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
|
||||
@ -109,6 +119,7 @@ public:
|
||||
[[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
VecTokens mTokens;
|
||||
std::optional<Tensor> mLogits;
|
||||
std::optional<FloatType> mAcceptanceThreshold;
|
||||
@ -128,6 +139,7 @@ public:
|
||||
[[nodiscard]] Tensor getEmbeddingTable() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
Tensor mEmbeddingTable;
|
||||
};
|
||||
|
||||
@ -142,6 +154,8 @@ public:
|
||||
[[nodiscard]] Tensor getConfig() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
|
||||
Tensor mWeights;
|
||||
Tensor mConfig;
|
||||
};
|
||||
@ -207,6 +221,7 @@ public:
|
||||
void setLoraConfig(LoraConfig loraConfig);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
};
|
||||
@ -298,15 +313,49 @@ private:
|
||||
|
||||
SizeType const kDefaultIterStatsMaxIterations = 1000;
|
||||
|
||||
/// @brief A configuration class for the parallel execution parameters
|
||||
/// Currently only supports commType = CommunicationType::kMPI
|
||||
class ParallelConfig
|
||||
{
|
||||
public:
|
||||
/// @brief Constructor
|
||||
/// @param commType The communication type. See CommunicationType.
|
||||
/// @param commMode The communication mode. See CommunicationMode.
|
||||
/// @param deviceIds The IDs of the GPUs involved in the execution of the model
|
||||
/// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the
|
||||
/// model. The first participant is considered to be the leader.
|
||||
ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
|
||||
CommunicationMode commMode = CommunicationMode::kLEADER,
|
||||
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
|
||||
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
|
||||
~ParallelConfig();
|
||||
|
||||
[[nodiscard]] CommunicationType getCommunicationType() const;
|
||||
[[nodiscard]] CommunicationMode getCommunicationMode() const;
|
||||
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
|
||||
[[nodiscard]] std::optional<std::vector<SizeType>> getParticipantIds() const;
|
||||
|
||||
void setCommunicationType(CommunicationType type);
|
||||
void setCommunicationMode(CommunicationMode mode);
|
||||
void setDeviceIds(std::vector<SizeType> deviceIds);
|
||||
void setParticipantIds(std::vector<SizeType> participantIds);
|
||||
|
||||
private:
|
||||
CommunicationType mCommType;
|
||||
CommunicationMode mCommMode;
|
||||
std::optional<std::vector<SizeType>> mDeviceIds;
|
||||
std::optional<std::vector<SizeType>> mParticipantIds;
|
||||
};
|
||||
|
||||
/// @brief Configuration class for the model executor
|
||||
class ExecutorConfig
|
||||
{
|
||||
public:
|
||||
ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
|
||||
KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
|
||||
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
|
||||
SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
|
||||
BatchingType batchingType = BatchingType::kINFLIGHT);
|
||||
bool enableTrtOverlap = false, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
|
||||
BatchingType batchingType = BatchingType::kINFLIGHT,
|
||||
std::optional<ParallelConfig> parallelConfig = std::nullopt);
|
||||
|
||||
[[nodiscard]] SizeType getMaxBeamWidth() const;
|
||||
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
|
||||
@ -314,9 +363,9 @@ public:
|
||||
[[nodiscard]] bool getEnableChunkedContext() const;
|
||||
[[nodiscard]] bool getNormalizeLogProbs() const;
|
||||
[[nodiscard]] bool getEnableTrtOverlap() const;
|
||||
[[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
|
||||
[[nodiscard]] SizeType getIterStatsMaxIterations() const;
|
||||
[[nodiscard]] BatchingType getBatchingType() const;
|
||||
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
|
||||
|
||||
void setMaxBeamWidth(SizeType maxBeamWidth);
|
||||
void setSchedulerConfig(SchedulerConfig schedulerConfig);
|
||||
@ -324,9 +373,9 @@ public:
|
||||
void setEnableChunkedContext(bool enableChunkedContext);
|
||||
void setNormalizeLogProbs(bool normalizeLogProbs);
|
||||
void setEnableTrtOverlap(bool enableTrtOverlap);
|
||||
void setDeviceIds(std::optional<std::vector<SizeType>> deviceIds);
|
||||
void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
|
||||
void setBatchingType(BatchingType batchingType);
|
||||
void setParallelConfig(ParallelConfig parallelConfig);
|
||||
|
||||
private:
|
||||
SizeType mMaxBeamWidth;
|
||||
@ -335,24 +384,11 @@ private:
|
||||
bool mEnableChunkedContext;
|
||||
bool mNormalizeLogProbs;
|
||||
bool mEnableTrtOverlap;
|
||||
std::optional<std::vector<SizeType>> mDeviceIds;
|
||||
SizeType mIterStatsMaxIterations;
|
||||
BatchingType mBatchingType;
|
||||
std::optional<ParallelConfig> mParallelConfig;
|
||||
};
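Together with the removal of `ExecutorConfig::setDeviceIds`, device selection now travels through `ParallelConfig`. A hedged usage sketch of the new API; the class and method names come from the header above, while the values and include path are illustrative:

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <vector>

namespace texec = tensorrt_llm::executor;

texec::ExecutorConfig makeExecutorConfig()
{
    // Run with MPI in leader mode on GPUs 0 and 1.
    texec::ParallelConfig parallelConfig(texec::CommunicationType::kMPI, texec::CommunicationMode::kLEADER,
        std::vector<texec::SizeType>{0, 1});

    texec::ExecutorConfig executorConfig(/*maxBeamWidth=*/1);
    executorConfig.setParallelConfig(parallelConfig); // replaces setDeviceIds(...)
    return executorConfig;
}
```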
|
||||
|
||||
/// TODO:
|
||||
/// @brief A class to identify processes involved in the execution of a model
|
||||
/// Currently only supports MPI communication
|
||||
class Communicator
|
||||
{
|
||||
public:
|
||||
Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector<SizeType> const& commIds,
|
||||
std::optional<SizeType> orchestratorId){};
|
||||
~Communicator() = default;
|
||||
};
|
||||
|
||||
class Model;
|
||||
|
||||
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
|
||||
class Executor
|
||||
{
|
||||
@ -364,14 +400,12 @@ public:
|
||||
/// @param modelType The type of model
|
||||
/// @param executorConfig The configuration for the executor
|
||||
/// @param comm An optional inter-process communicator configuration
|
||||
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig,
|
||||
std::optional<Communicator> comm = std::nullopt);
|
||||
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig);
|
||||
|
||||
Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
|
||||
ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
|
||||
ExecutorConfig executorConfig);
|
||||
|
||||
Executor(
|
||||
std::shared_ptr<Model> model, ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
|
||||
Executor(std::shared_ptr<Model> model, ExecutorConfig executorConfig);
|
||||
|
||||
~Executor();
|
||||
|
||||
|
||||
@ -180,11 +180,11 @@ public:
|
||||
|
||||
~Tensor() = default;
|
||||
|
||||
Tensor(const Tensor& other) noexcept = default;
|
||||
Tensor(Tensor const& other) noexcept = default;
|
||||
|
||||
Tensor(Tensor&& other) noexcept = default;
|
||||
|
||||
Tensor& operator=(const Tensor& other) noexcept = default;
|
||||
Tensor& operator=(Tensor const& other) noexcept = default;
|
||||
|
||||
Tensor& operator=(Tensor&& other) noexcept = default;
|
||||
|
||||
@ -267,6 +267,7 @@ private:
|
||||
|
||||
friend std::shared_ptr<runtime::ITensor> const& detail::toITensor(Tensor const& tensor);
|
||||
friend Tensor detail::ofITensor(std::shared_ptr<runtime::ITensor> tensor);
|
||||
friend class Serialization;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::executor
|
||||
|
||||
@@ -155,21 +155,16 @@ enum class SchedulerPolicy
    kGUARANTEED_NO_EVICT = 1,
};

enum class CommunicatorType
enum class CommunicationType
{
    kMPI = 0
};

enum class CommMode
enum class CommunicationMode
{
    kLEADER,       // With the leader mode, only the leader will be returning from the executor constructor, and
                   // therefore only the leader can enqueue requests and get responses.
    kORCHESTRATOR, // With the orchestrator mode, only the orchestrator will be returning from the executor
                   // constructor, and therefore only the leader can enqueue requests and get responses. The
                   // orchestrator doesn't participate in the computations.
    kALL,          // With the ALL mode, all participants are expected to make the same calls to the executor API,
                   // so they all need to send the same requests.
                   // Responses will be the same for all participants.
    kLEADER,       // With the leader mode, only the leader can enqueue requests. The requests will be
                   // broadcasted to the workers. All participants can get responses via awaitResponses. The leader
                   // is the first participant in the provided participant IDs, or 0 if participant IDs are not
                   // provided.
};

} // namespace tensorrt_llm::executor
@ -81,6 +81,11 @@ public:
|
||||
|
||||
using UnderlyingType = uint8_t;
|
||||
|
||||
bool operator==(DecodingMode const& other) const
|
||||
{
|
||||
return mState == other.mState;
|
||||
}
|
||||
|
||||
private:
|
||||
constexpr DecodingMode(UnderlyingType state)
|
||||
: mState(state)
|
||||
|
||||
@ -17,10 +17,13 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
#include "tensorrt_llm/runtime/decodingInput.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/decodingOutput.h"
|
||||
#include "tensorrt_llm/runtime/gptModelConfig.h"
|
||||
#include "tensorrt_llm/runtime/samplingConfig.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
#include <curand_kernel.h>
|
||||
|
||||
#include <memory>
|
||||
@ -59,7 +62,7 @@ public:
|
||||
DecodingInput const& decodingInput, BufferManager const& manager)
|
||||
= 0;
|
||||
|
||||
virtual const SamplingConfig& getSamplingConfig() = 0;
|
||||
virtual SamplingConfig const& getSamplingConfig() = 0;
|
||||
|
||||
static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
|
||||
ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
|
||||
@ -71,6 +74,11 @@ public:
|
||||
SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
|
||||
curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
|
||||
|
||||
static void updateKVCacheBasedOnAcceptedTokens(ITensor const& acceptedOffsets, ITensor const& packedAcceptedIds,
|
||||
ITensor const& pointerArray, ITensor const& pastKeyValueLengths, GptModelConfig const& modelConfig,
|
||||
WorldConfig const& worldConfig, BufferManager::CudaStreamPtr stream, SizeType rewindDraftTokenCount,
|
||||
SizeType maxAttentionWindow, SizeType maxBlocksPerSeq, nvinfer1::DataType dtype);
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
|
||||
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
@ -97,7 +105,7 @@ public:
|
||||
void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
|
||||
BufferManager const& manager) override;
|
||||
|
||||
const SamplingConfig& getSamplingConfig() override
|
||||
SamplingConfig const& getSamplingConfig() override
|
||||
{
|
||||
return mSamplingConfig;
|
||||
}
|
||||
|
||||
@ -153,6 +153,18 @@ public:
|
||||
return mFinishedSum;
|
||||
}
|
||||
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokens() const override
|
||||
{
|
||||
return mNextDraftTokens;
|
||||
}
|
||||
|
||||
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokenLengths() const override
|
||||
{
|
||||
return mNextDraftTokenLengths;
|
||||
}
|
||||
|
||||
private:
|
||||
//! @brief Gather final beam search results for request `batchIdx`.
|
||||
[[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
|
||||
@ -204,6 +216,8 @@ private:
|
||||
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
|
||||
TensorPtr mNextDraftTokens;
|
||||
TensorPtr mNextDraftTokenLengths;
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxAttentionWindow{};
|
||||
SizeType mSinkTokenLength{};
|
||||
|
||||
@ -46,15 +46,10 @@ public:
|
||||
, endId{endId}
|
||||
, computeCumLogProbs(false)
|
||||
, computeLogProbs(false)
|
||||
, generatedTokensPerStep(1)
|
||||
{
|
||||
}
|
||||
|
||||
// the number of tokens generated per step
|
||||
SizeType generatedTokensPerStep() const
|
||||
{
|
||||
return draftTokens ? draftTokens->getSize() + 1 : 1;
|
||||
}
|
||||
|
||||
// mandatory parameters
|
||||
ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
|
||||
SizeType inputLen; // the input length without draft tokens
|
||||
@ -71,6 +66,7 @@ public:
|
||||
|
||||
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
SizeType generatedTokensPerStep;
|
||||
};
|
||||
|
||||
class Input
|
||||
@ -184,6 +180,12 @@ public:
|
||||
std::vector<SamplingConfig> const& samplingConfigs)
|
||||
= 0;
|
||||
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokens() const = 0;
|
||||
|
||||
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokenLengths() const = 0;
|
||||
|
||||
protected:
|
||||
IGptDecoderBatch() = default;
|
||||
};
|
||||
|
||||
@ -36,7 +36,7 @@ public:
|
||||
IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize);
|
||||
~IpcMemory();
|
||||
|
||||
[[nodiscard]] const std::vector<void*>& getCommPtrsTensor() const
|
||||
[[nodiscard]] std::vector<void*> const& getCommPtrsTensor() const
|
||||
{
|
||||
return mCommPtrs;
|
||||
}
|
||||
|
||||
@ -67,7 +67,7 @@ public:
|
||||
// Fill the tasks tensor for the batch using the provided tasksHost
|
||||
// Function assumes that the first numContextRequests requests in the batch are context requests
|
||||
void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests,
|
||||
const std::vector<SizeType>& reqBeamWidths, const std::vector<SizeType>& reqPromptLengths,
|
||||
std::vector<SizeType> const& reqBeamWidths, std::vector<SizeType> const& reqPromptLengths,
|
||||
BufferManager const& manager, bool packedInput);
|
||||
};
|
||||
|
||||
|
||||
@ -43,7 +43,7 @@ private:
|
||||
auto const hasValues = accessor(0).has_value();
|
||||
for (size_t ci = 0; ci < configs.size(); ++ci)
|
||||
{
|
||||
const auto& configValue = accessor(ci);
|
||||
auto const& configValue = accessor(ci);
|
||||
TLLM_CHECK(hasValues == configValue.has_value());
|
||||
if (hasValues)
|
||||
{
|
||||
|
||||
@ -188,7 +188,6 @@ endif()
|
||||
set(TRTLLM_LINK_LIBS
|
||||
${CUBLAS_LIB}
|
||||
${CUBLASLT_LIB}
|
||||
${CUDNN_LIB}
|
||||
${CMAKE_DL_LIBS}
|
||||
${MPI_C_LIBRARIES}
|
||||
${NCCL_LIB}
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:0ecc134ad10a54b2953c772e72db2f71e84130d5736087b033e9e5b78594db6d
|
||||
size 2113376
|
||||
oid sha256:c56ee13bb109917ab10df168ca15e6057436df1cd8b64a4268c6e7aae78a5ad8
|
||||
size 2126310
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9aa3f3d7f8313c099df8e9bd4c9707922a4f1c4025c4c99986acf6df781738c7
|
||||
size 2128450
|
||||
oid sha256:339532215fa4c16e68ca28ee23d0a0e09c9caefa7bd19b563d2f7b83cad6822e
|
||||
size 2142070
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
add62ff328028bbcded1af694fe758c5 libtensorrt_llm_batch_manager_static.a
|
||||
9e8846e200e2aaaeace862741a90c3ab libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
230623fa285048a2de5c54c2cc0f364fb9f2c559 commit
|
||||
c9c505e2cb6e95b7cfc124c04ab1fcb3 libtensorrt_llm_batch_manager_static.a
|
||||
2f5cec5a5b42e0031bc2edc688c1e74b libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
741fb083cc42933439ae54557b177b6d7064da4f commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7b25de974b6ca5f0dcb279f16f38199167d1efc35c01770d3234bec2dfb5dc86
|
||||
size 2097848
|
||||
oid sha256:a4060f2d60472850344e5b5799f9ad88390f4ad9c056e3843f3bdbcc046ca68b
|
||||
size 2106440
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5f06cee5ae2bcf393196265cd9a3ef832690cd4c5c53934bbfb169d50ab33c41
|
||||
size 2055004
|
||||
oid sha256:829f1ed5af0b0d2577e57fd13979706fe0b3636bd6338aac3c34a615f64afedc
|
||||
size 2064310
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
bb62a31b8e17dae284d784ba43d5bc02 libtensorrt_llm_batch_manager_static.a
|
||||
19327f59c7f5b6235e15b322d5f5a0f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
2db5c985786dad3dd16c22ec54af0803 libtensorrt_llm_batch_manager_static.a
|
||||
96940249ff7b3ff09754b89ad25fcf9f libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
|
||||
@ -42,11 +42,11 @@ public:
|
||||
virtual ~IAllocator() = default;
|
||||
|
||||
// no copying
|
||||
IAllocator(const IAllocator&) = delete;
|
||||
IAllocator& operator=(const IAllocator&) = delete;
|
||||
IAllocator(IAllocator const&) = delete;
|
||||
IAllocator& operator=(IAllocator const&) = delete;
|
||||
|
||||
template <typename T>
|
||||
[[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true)
|
||||
[[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, bool const setZero = true)
|
||||
{
|
||||
TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
|
||||
// TODO martinma: why do we need this size extension?
|
||||
|
||||
@ -23,7 +23,7 @@
|
||||
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
|
||||
[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
|
||||
{
|
||||
throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
|
||||
}
|
||||
@ -38,8 +38,10 @@ public:
|
||||
|
||||
#if defined(_WIN32)
|
||||
#define TLLM_LIKELY(x) (__assume((x) == 1), (x))
|
||||
#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
|
||||
#else
|
||||
#define TLLM_LIKELY(x) __builtin_expect((x), 1)
|
||||
#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
|
||||
#endif
|
||||
|
||||
#define TLLM_CHECK(val) \
|
||||
@ -61,20 +63,22 @@ public:
|
||||
#define TLLM_CHECK_DEBUG(val) \
|
||||
do \
|
||||
{ \
|
||||
if (DebugConfig::isCheckDebugEnabled()) \
|
||||
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info) \
|
||||
#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
|
||||
do \
|
||||
{ \
|
||||
if (DebugConfig::isCheckDebugEnabled()) \
|
||||
if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
|
||||
{ \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
|
||||
TLLM_LIKELY(static_cast<bool>(val)) \
|
||||
? ((void) 0) \
|
||||
: tensorrt_llm::common::throwRuntimeError( \
|
||||
__FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
|
||||
@ -42,7 +42,7 @@ CublasMMWrapper::~CublasMMWrapper()
|
||||
mMutex = nullptr;
|
||||
}
|
||||
|
||||
CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
|
||||
CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
|
||||
: mCublasHandle(wrapper.mCublasHandle)
|
||||
, mCublasLtHandle(wrapper.mCublasLtHandle)
|
||||
, mStream(wrapper.mStream)
|
||||
@ -50,8 +50,8 @@ CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
|
||||
{
|
||||
}
|
||||
|
||||
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const int lda, const int ldb, const int ldc)
|
||||
void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc)
|
||||
{
|
||||
// --------------------------------------
|
||||
// Create descriptors for the original matrices
|
||||
@ -79,15 +79,15 @@ void CublasMMWrapper::destroyDescriptors()
|
||||
mCDesc = NULL;
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc)
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
|
||||
{
|
||||
Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc,
|
||||
const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic)
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
|
||||
{
|
||||
if (heuristic)
|
||||
{
|
||||
@ -102,8 +102,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta)
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
|
||||
{
|
||||
bool usingCublasLt = mAType == CUDA_R_16F;
|
||||
|
||||
@ -111,9 +111,9 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
|
||||
/* usingCublasLt */ usingCublasLt);
|
||||
}
|
||||
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt)
|
||||
void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
|
||||
{
|
||||
half h_alpha = (half) (f_alpha);
|
||||
half h_beta = (half) (f_beta);
|
||||
@ -126,8 +126,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
|
||||
int batch_count = 1;
|
||||
// fp32 use cublas as default
|
||||
// fp16 use cublasLt as default
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
|
||||
int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
|
||||
|
||||
if (usingCublasLt)
|
||||
@ -154,10 +154,10 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
|
||||
}
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
|
||||
const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha,
|
||||
const float f_beta)
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
|
||||
const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
|
||||
float const f_beta)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
@ -165,26 +165,26 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati
|
||||
std::lock_guard<std::mutex> lock(*mMutex);
|
||||
|
||||
int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
|
||||
strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
|
||||
mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
|
||||
}
|
||||
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
|
||||
const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
|
||||
cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType)
|
||||
void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
|
||||
void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
|
||||
cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
|
||||
{
|
||||
half h_alpha = (half) f_alpha;
|
||||
half h_beta = (half) f_beta;
|
||||
|
||||
std::lock_guard<std::mutex> lock(*mMutex);
|
||||
bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
|
||||
const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
|
||||
const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
|
||||
void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
|
||||
void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
|
||||
|
||||
check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
|
||||
strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
|
||||
@ -267,8 +267,8 @@ void CublasMMWrapper::setStream(cudaStream_t stream)
|
||||
mStream = stream;
|
||||
}
|
||||
|
||||
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
|
||||
const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo)
|
||||
bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
|
||||
int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
@ -291,12 +291,12 @@ bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr
|
||||
}
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
|
||||
cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
|
||||
cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
|
||||
|
||||
const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
|
||||
auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
|
||||
@ -65,39 +65,39 @@ public:
|
||||
|
||||
~CublasMMWrapper();
|
||||
|
||||
CublasMMWrapper(const CublasMMWrapper& wrapper);
|
||||
CublasMMWrapper(CublasMMWrapper const& wrapper);
|
||||
|
||||
/********************** GEMMs **********************/
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc);
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc,
|
||||
const std::optional<cublasLtMatmulHeuristicResult_t>& algo);
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta);
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
|
||||
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
||||
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
|
||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt);
|
||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
|
||||
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
|
||||
const float f_beta = 0.0f);
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
|
||||
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
|
||||
float const f_beta = 0.0f);
|
||||
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
|
||||
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
|
||||
const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType);
|
||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
|
||||
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
|
||||
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
|
||||
|
||||
/********************** Tactic selection helpers **********************/
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo);
|
||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
|
||||
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
|
||||
|
||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||
@ -126,8 +126,8 @@ public:
|
||||
|
||||
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
||||
|
||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
||||
const int lda, const int ldb, const int ldc);
|
||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||
int const lda, int const ldb, int const ldc);
|
||||
void destroyDescriptors();
|
||||
|
||||
cublasHandle_t getCublasHandle()
|
||||
|
||||
@ -43,7 +43,7 @@ CUDADriverWrapper::CUDADriverWrapper()
|
||||
handle = dllOpen(CUDA_LIB_NAME);
|
||||
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
|
||||
|
||||
auto load_sym = [](void* handle, const char* name)
|
||||
auto load_sym = [](void* handle, char const* name)
|
||||
{
|
||||
void* ret = dllGetSym(handle, name);
|
||||
return ret;
|
||||
@ -69,7 +69,7 @@ CUDADriverWrapper::~CUDADriverWrapper()
|
||||
dllClose(handle);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, const char** pStr) const
|
||||
CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorName)(error, pStr);
|
||||
}
|
||||
@ -94,7 +94,7 @@ CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
|
||||
return (*_cuLinkDestroy)(state);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, const void* image) const
|
||||
CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
|
||||
{
|
||||
return (*_cuModuleLoadData)(module, image);
|
||||
}
|
||||
@ -105,24 +105,24 @@ CUresult CUDADriverWrapper::cuLinkCreate(
|
||||
return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const
|
||||
CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetFunction)(hfunc, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const
|
||||
CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
|
||||
{
|
||||
return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
|
||||
CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
|
||||
const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
|
||||
{
|
||||
return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
|
||||
}
|
||||
|
||||
@ -37,7 +37,7 @@ public:
|
||||
|
||||
~CUDADriverWrapper();
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, const char** pStr) const;
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
@ -47,19 +47,19 @@ public:
|
||||
|
||||
CUresult cuLinkDestroy(CUlinkState state) const;
|
||||
|
||||
CUresult cuModuleLoadData(CUmodule* module, const void* image) const;
|
||||
CUresult cuModuleLoadData(CUmodule* module, void const* image) const;
|
||||
|
||||
CUresult cuLinkCreate(
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;
|
||||
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const;
|
||||
CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const;
|
||||
CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;
|
||||
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions,
|
||||
CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
|
||||
CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name,
|
||||
CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
|
||||
unsigned int numOptions, CUjit_option* options, void** optionValues) const;
|
||||
|
||||
CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,
|
||||
@ -72,18 +72,18 @@ public:
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUresult (*_cuGetErrorName)(CUresult, const char**);
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
CUresult (*_cuLinkDestroy)(CUlinkState);
|
||||
CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, const void*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
|
||||
CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
|
||||
CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
|
||||
CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLinkAddData)(
|
||||
CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
|
||||
CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
|
||||
CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
|
||||
unsigned int, unsigned int, unsigned int, CUstream, void**);
|
||||
CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,
|
||||
@ -91,11 +91,11 @@ private:
|
||||
CUstream hStream, void** kernelParams, void** extra);
|
||||
};
|
||||
|
||||
inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line)
|
||||
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)
|
||||
{
|
||||
if (stat != CUDA_SUCCESS)
|
||||
{
|
||||
const char* msg = nullptr;
|
||||
char const* msg = nullptr;
|
||||
wrap.cuGetErrorName(stat, &msg);
|
||||
fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
|
||||
}
|
||||
|
||||
@ -121,16 +121,16 @@ void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaSt
|
||||
}
|
||||
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
|
||||
float* dst, const float* src, const int64_t numel, cudaStream_t stream);
|
||||
float* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
|
||||
float* dst, const __nv_fp8_e4m3* src, const int64_t numel, cudaStream_t stream);
|
||||
float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
|
||||
half* dst, const half* src, const int64_t numel, cudaStream_t stream);
|
||||
half* dst, half const* src, const int64_t numel, cudaStream_t stream);
|
||||
template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
|
||||
__nv_bfloat16* dst, const __nv_bfloat16* src, const int64_t numel, cudaStream_t stream);
|
||||
__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
template void invokeFakeQuantize<float, half, float>(
|
||||
half* dst, const float* src, const int64_t numel, cudaStream_t stream);
|
||||
half* dst, float const* src, const int64_t numel, cudaStream_t stream);
|
||||
|
||||
__device__ float atomicMaxExtd(float* address, float val)
|
||||
{
|
||||
@ -146,7 +146,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
|
||||
// The address in 64 bits.
|
||||
uint64_t address_u64 = reinterpret_cast<const uint64_t&>(address);
|
||||
uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);
|
||||
|
||||
// Pack the input value into 32 bits.
|
||||
union
|
||||
@ -155,7 +155,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
|
||||
uint16_t u[2];
|
||||
} old, tmp = {};
|
||||
|
||||
const int loc = (address_u64 & 0x2) >> 1;
|
||||
int const loc = (address_u64 & 0x2) >> 1;
|
||||
tmp.v[loc] = val;
|
||||
|
||||
// 4B aligned pointer.
|
||||
@ -223,7 +223,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
|
||||
auto val = fabs(static_cast<float>(weights[i]));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
|
||||
if constexpr (std::is_same_v<T_S, float>)
|
||||
{
|
||||
@ -231,7 +231,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
|
||||
auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
|
||||
if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
|
||||
atomicMaxExtd(quant_ptr + col, scale);
|
||||
else
|
||||
@ -244,7 +244,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
|
||||
}
|
||||
else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
|
||||
{
|
||||
const auto nrows = size / n;
|
||||
auto const nrows = size / n;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
@ -256,7 +256,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
quant_ptr[row] = scale;
|
||||
}
|
||||
}
|
||||
@ -272,7 +272,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
|
||||
max = blockReduceMax<float>(max);
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
atomicMaxExtd(quant_ptr, scale);
|
||||
}
|
||||
}
|
||||
@ -326,19 +326,19 @@ __global__ void dynamicQuantizeMatrixPerToken(
|
||||
extern __shared__ __align__(sizeof(float)) char _shmem[];
|
||||
T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
|
||||
constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
const auto nrows = numel / lda;
|
||||
auto const nrows = numel / lda;
|
||||
for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
|
||||
{
|
||||
float max = 0.f;
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
const auto in = input[row * lda + i];
|
||||
auto const in = input[row * lda + i];
|
||||
shmem[i] = in;
|
||||
auto val = fabs(static_cast<float>(in));
|
||||
max = max > val ? max : val;
|
||||
}
|
||||
max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
|
||||
const auto s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
|
||||
for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
|
||||
{
|
||||
// true means we are quantizing
|
||||
@ -359,7 +359,7 @@ void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T
|
||||
{
|
||||
dim3 grid(numel / lda);
|
||||
bool use_shmem = true;
|
||||
const auto shmem_size = lda * sizeof(T_IN);
|
||||
auto const shmem_size = lda * sizeof(T_IN);
|
||||
if (shmem_size >= (48 << 10))
|
||||
{
|
||||
cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
|
||||
|
||||
@ -181,37 +181,37 @@ struct PackType<__nv_fp8_e4m3, 8>
|
||||
};
|
||||
#endif
|
||||
|
||||
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, const __nv_fp8x4_e4m3* in)
|
||||
__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
|
||||
{
|
||||
const char4 tmp_val = reinterpret_cast<const char4*>(in)[0];
|
||||
*out1 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
|
||||
*out2 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]);
|
||||
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
|
||||
*out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
*out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
|
||||
}
|
||||
|
||||
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(const __nv_fp8x2_e4m3* in)
|
||||
__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
|
||||
{
|
||||
const char2 tmp_val = reinterpret_cast<const char2*>(in)[0];
|
||||
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
|
||||
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
|
||||
__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
return out;
|
||||
}
|
||||
|
||||
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, const __nv_fp8x4_e4m3* in)
|
||||
__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
|
||||
{
|
||||
const char4 tmp_val = reinterpret_cast<const char4*>(in)[0];
|
||||
*out1 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
|
||||
*out2 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]);
|
||||
const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
|
||||
*out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
*out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
|
||||
}
|
||||
|
||||
__inline__ __device__ half2 fp8x2_e4m3_to_half2(const __nv_fp8x2_e4m3* in)
|
||||
__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
|
||||
{
|
||||
const char2 tmp_val = reinterpret_cast<const char2*>(in)[0];
|
||||
half2 out = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
|
||||
const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
|
||||
half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
|
||||
(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@ -32,14 +32,14 @@ namespace common
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
inline __device__ T ldg(const T* val)
|
||||
inline __device__ T ldg(T const* val)
|
||||
{
|
||||
return __ldg(val);
|
||||
}
|
||||
|
||||
#if ENABLE_BF16
|
||||
template <>
|
||||
inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
|
||||
inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
@ -49,7 +49,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
|
||||
}
|
||||
|
||||
template <>
|
||||
inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)
|
||||
inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
|
||||
{
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
return val[0];
|
||||
|
||||
@ -81,12 +81,12 @@ enum class OperationType
|
||||
};
|
||||
|
||||
/* **************************** debug tools ********************************* */
|
||||
static const char* _cudaGetErrorEnum(cudaError_t error)
|
||||
static char const* _cudaGetErrorEnum(cudaError_t error)
|
||||
{
|
||||
return cudaGetErrorString(error);
|
||||
}
|
||||
|
||||
static const char* _cudaGetErrorEnum(cublasStatus_t error)
|
||||
static char const* _cudaGetErrorEnum(cublasStatus_t error)
|
||||
{
|
||||
switch (error)
|
||||
{
|
||||
@ -114,7 +114,7 @@ static const char* _cudaGetErrorEnum(cublasStatus_t error)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void check(T result, char const* const func, const char* const file, int const line)
|
||||
void check(T result, char const* const func, char const* const file, int const line)
|
||||
{
|
||||
if (result)
|
||||
{
|
||||
@ -133,7 +133,7 @@ inline bool isCudaLaunchBlocking()
|
||||
|
||||
if (firstCall)
|
||||
{
|
||||
const char* env = std::getenv("CUDA_LAUNCH_BLOCKING");
|
||||
char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
|
||||
result = env != nullptr && std::string(env) == "1";
|
||||
firstCall = false;
|
||||
}
|
||||
@ -141,12 +141,12 @@ inline bool isCudaLaunchBlocking()
|
||||
return result;
|
||||
}
|
||||
|
||||
inline void syncAndCheck(const char* const file, int const line)
|
||||
inline void syncAndCheck(char const* const file, int const line)
|
||||
{
|
||||
#ifndef NDEBUG
|
||||
const bool checkError = true;
|
||||
bool const checkError = true;
|
||||
#else
|
||||
const bool checkError = isCudaLaunchBlocking();
|
||||
bool const checkError = isCudaLaunchBlocking();
|
||||
#endif
|
||||
|
||||
if (checkError)
|
||||
@ -279,7 +279,7 @@ inline int getDeviceCount()
|
||||
|
||||
/// Get the memory info
|
||||
/// \return The free and total amount of memory in bytes
|
||||
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(const bool useUvm)
|
||||
inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
|
||||
{
|
||||
if (useUvm)
|
||||
{
|
||||
@ -351,7 +351,7 @@ auto constexpr ceilDiv(T numerator, U denominator)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string name = "")
|
||||
void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "")
|
||||
{
|
||||
if (buf == nullptr)
|
||||
{
|
||||
@ -390,9 +390,9 @@ void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printToStream(const T* result, const int size, FILE* strm)
|
||||
void printToStream(T const* result, int const size, FILE* strm)
|
||||
{
|
||||
const bool split_rows = (strm == stdout);
|
||||
bool const split_rows = (strm == stdout);
|
||||
if (result == nullptr)
|
||||
{
|
||||
TLLM_LOG_WARNING("It is an nullptr, skip! \n");
|
||||
@ -414,13 +414,13 @@ void printToStream(const T* result, const int size, FILE* strm)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void printToScreen(const T* result, const int size)
|
||||
void printToScreen(T const* result, int const size)
|
||||
{
|
||||
printToStream(result, size, stdout);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToStream(const T* result, const int r, const int c, const int stride, FILE* strm)
|
||||
void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm)
|
||||
{
|
||||
if (result == nullptr)
|
||||
{
|
||||
@ -429,20 +429,20 @@ void print2dToStream(const T* result, const int r, const int c, const int stride
|
||||
}
|
||||
for (int ri = 0; ri < r; ++ri)
|
||||
{
|
||||
const T* ptr = result + ri * stride;
|
||||
T const* ptr = result + ri * stride;
|
||||
printToStream(ptr, c, strm);
|
||||
}
|
||||
fprintf(strm, "\n");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToScreen(const T* result, const int r, const int c, const int stride)
|
||||
void print2dToScreen(T const* result, int const r, int const c, int const stride)
|
||||
{
|
||||
print2dToStream(result, r, c, stride, stdout);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void print2dToFile(std::string fname, const T* result, const int r, const int c, const int stride)
|
||||
void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride)
|
||||
{
|
||||
FILE* fp = fopen(fname.c_str(), "wt");
|
||||
if (fp != nullptr)
|
||||
@ -493,7 +493,7 @@ inline void print_element_(int64_t ill)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_ptr)
|
||||
inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr)
|
||||
{
|
||||
T* tmp;
|
||||
if (is_device_ptr)
|
||||
@ -538,14 +538,14 @@ inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_p
|
||||
}
|
||||
}
|
||||
|
||||
template void printMatrix(const float* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(const half* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
#ifdef ENABLE_BF16
|
||||
template void printMatrix(const __nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
#endif
|
||||
template void printMatrix(const uint32_t* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(const uint64_t* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(const int* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr);
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
|
||||
@ -25,7 +25,7 @@ namespace tensorrt_llm::common
|
||||
// XQA kernels (optimized kernels for generation phase).
|
||||
bool forceXQAKernels()
|
||||
{
|
||||
const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
|
||||
char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
|
||||
static bool forceXQA = false;
|
||||
if (force_xqa_env_var != nullptr)
|
||||
{
|
||||
@ -45,7 +45,7 @@ bool getEnvMmhaMultiblockDebug()
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
|
||||
char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
|
||||
if (enable_mmha_debug_var)
|
||||
{
|
||||
if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')
|
||||
@ -64,7 +64,7 @@ int getEnvMmhaBlocksPerSequence()
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
|
||||
char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
|
||||
if (mmhaBlocksPerSequenceEnv)
|
||||
{
|
||||
mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
|
||||
|
||||
@ -65,5 +65,4 @@ Logger* Logger::getLogger()
|
||||
thread_local Logger instance;
|
||||
return &instance;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -54,26 +54,26 @@ public:
|
||||
|
||||
#if defined(_MSC_VER)
|
||||
template <typename... Args>
|
||||
void log(Level level, char const* format, const Args&... args);
|
||||
void log(Level level, char const* format, Args const&... args);
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, int rank, char const* format, const Args&... args);
|
||||
void log(Level level, int rank, char const* format, Args const&... args);
|
||||
#else
|
||||
template <typename... Args>
|
||||
void log(Level level, char const* format, const Args&... args) __attribute__((format(printf, 3, 0)));
|
||||
void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, int rank, char const* format, const Args&... args) __attribute__((format(printf, 4, 0)));
|
||||
void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
|
||||
#endif
|
||||
|
||||
template <typename... Args>
|
||||
void log(Level level, std::string const& format, const Args&... args)
|
||||
void log(Level level, std::string const& format, Args const&... args)
|
||||
{
|
||||
return log(level, format.c_str(), args...);
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void log(const Level level, const int rank, const std::string& format, const Args&... args)
|
||||
void log(const Level level, int const rank, std::string const& format, Args const&... args)
|
||||
{
|
||||
return log(level, rank, format.c_str(), args...);
|
||||
}
|
||||
@ -122,7 +122,7 @@ private:
|
||||
return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
|
||||
}
|
||||
|
||||
static inline std::string getPrefix(const Level level, const int rank)
|
||||
static inline std::string getPrefix(const Level level, int const rank)
|
||||
{
|
||||
return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
|
||||
}
|
||||
@ -148,7 +148,7 @@ void Logger::log(Logger::Level level, char const* format, Args const&... args)
|
||||
}
|
||||
|
||||
template <typename... Args>
|
||||
void Logger::log(const Logger::Level level, const int rank, char const* format, const Args&... args)
|
||||
void Logger::log(const Logger::Level level, int const rank, char const* format, Args const&... args)
|
||||
{
|
||||
if (level_ <= level)
|
||||
{
|
||||
|
||||
@ -112,63 +112,63 @@ template void deviceFill(int* devptr, size_t size, int value, cudaStream_t strea
template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream);

template <typename T>
void cudaD2Hcpy(T* tgt, const T* src, const size_t size)
void cudaD2Hcpy(T* tgt, T const* src, const size_t size)
{
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost));
}

template void cudaD2Hcpy(float* tgt, const float* src, size_t size);
template void cudaD2Hcpy(half* tgt, const half* src, size_t size);
template void cudaD2Hcpy(float* tgt, float const* src, size_t size);
template void cudaD2Hcpy(half* tgt, half const* src, size_t size);
#ifdef ENABLE_BF16
template void cudaD2Hcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size);
template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
#endif
template void cudaD2Hcpy(int* tgt, const int* src, size_t size);
template void cudaD2Hcpy(bool* tgt, const bool* src, size_t size);
template void cudaD2Hcpy(int* tgt, int const* src, size_t size);
template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size);
#ifdef ENABLE_FP8
template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size);
template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
#endif
template void cudaD2Hcpy(unsigned long long* tgt, const unsigned long long* src, size_t size);
template void cudaD2Hcpy(unsigned int* tgt, const unsigned int* src, size_t size);
template void cudaD2Hcpy(int8_t* tgt, const int8_t* src, size_t size);
template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size);
template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size);
|
||||
template <typename T>
|
||||
void cudaH2Dcpy(T* tgt, const T* src, const size_t size)
|
||||
void cudaH2Dcpy(T* tgt, T const* src, const size_t size)
|
||||
{
|
||||
check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template void cudaH2Dcpy(float* tgt, const float* src, size_t size);
|
||||
template void cudaH2Dcpy(half* tgt, const half* src, size_t size);
|
||||
template void cudaH2Dcpy(float* tgt, float const* src, size_t size);
|
||||
template void cudaH2Dcpy(half* tgt, half const* src, size_t size);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaH2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size);
|
||||
template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaH2Dcpy(int* tgt, const int* src, size_t size);
|
||||
template void cudaH2Dcpy(bool* tgt, const bool* src, size_t size);
|
||||
template void cudaH2Dcpy(int* tgt, int const* src, size_t size);
|
||||
template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size);
|
||||
template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
|
||||
#endif
|
||||
template void cudaH2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size);
|
||||
template void cudaH2Dcpy(unsigned int* tgt, const unsigned int* src, size_t size);
|
||||
template void cudaH2Dcpy(int8_t* tgt, const int8_t* src, size_t size);
|
||||
template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
|
||||
template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size);
|
||||
template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
|
||||
void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream));
|
||||
}
|
||||
|
||||
template void cudaD2Dcpy(float* tgt, const float* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(half* tgt, const half* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaD2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaD2Dcpy(int* tgt, const int* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_FP8
|
||||
template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaD2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream);
|
||||
template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template <typename T_OUT, typename T_IN>
|
||||
__global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size)
|
||||
@ -204,7 +204,7 @@ template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const si
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
|
||||
void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
if (stream != NULL)
|
||||
{
|
||||
@ -216,19 +216,19 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
|
||||
}
|
||||
}
|
||||
|
||||
template void cudaAutoCpy(float* tgt, const float* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(half* tgt, const half* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void cudaAutoCpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
|
||||
#endif
|
||||
template void cudaAutoCpy(int* tgt, const int* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint8_t* tgt, const uint8_t* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint32_t* tgt, const uint32_t* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long* tgt, const unsigned long* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(char* tgt, const char* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream);
|
||||
template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream);
|
||||
@ -242,7 +242,7 @@ template void cudaAutoCpy(
|
||||
unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const int seq_offset)
|
||||
__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
@ -254,7 +254,7 @@ __global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const i
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, const int seq_offset)
|
||||
__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
@ -266,7 +266,7 @@ __global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size,
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, const int seq_offset)
|
||||
__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
@ -278,7 +278,7 @@ __global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size
|
||||
}
|
||||
|
||||
template <>
|
||||
__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, const int seq_offset)
|
||||
__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, int const seq_offset)
|
||||
{
|
||||
const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
curandState_t local_state;
|
||||
@ -462,30 +462,30 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud
|
||||
cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size);
|
||||
}
|
||||
|
||||
template void invokeCudaD2DcpyConvert(int8_t* tgt, const float* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, const int8_t* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, const int* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, const int* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, const float* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, const float* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, const half* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(uint32_t* tgt, const int* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, const uint32_t* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, const float* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, const half* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream);
|
||||
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const float* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const int* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
|
||||
template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void cudaD2DScaleCpyConvert(
|
||||
T_OUT* dst, const T_IN* src, const float* scale, bool invert_scale, const size_t size)
|
||||
T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size)
|
||||
{
|
||||
const float scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
|
||||
float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
|
||||
for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
|
||||
{
|
||||
dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value);
|
||||
@ -494,7 +494,7 @@ __global__ void cudaD2DScaleCpyConvert(
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DScaleCpyConvert(
|
||||
T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream)
|
||||
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream)
|
||||
{
|
||||
cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size);
|
||||
}
|
||||
@ -524,7 +524,7 @@ void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaSt
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void saveToBinary(const T* ptr, const size_t size, std::string filename)
|
||||
void saveToBinary(T const* ptr, const size_t size, std::string filename)
|
||||
{
|
||||
|
||||
std::vector<T> h_ptr(size);
|
||||
@ -541,14 +541,14 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename)
|
||||
out.write((char*) float_ptr.data(), size * sizeof(float));
|
||||
}
|
||||
|
||||
template void saveToBinary(const float* ptr, const size_t size, std::string filename);
|
||||
template void saveToBinary(const half* ptr, const size_t size, std::string filename);
|
||||
template void saveToBinary(float const* ptr, const size_t size, std::string filename);
|
||||
template void saveToBinary(half const* ptr, const size_t size, std::string filename);
|
||||
#ifdef ENABLE_BF16
|
||||
template void saveToBinary(const __nv_bfloat16* ptr, const size_t size, std::string filename);
|
||||
template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename);
|
||||
#endif // ENABLE_BF16
|
||||
|
||||
template <>
|
||||
void saveToBinary(const int* ptr, const size_t size, std::string filename)
|
||||
void saveToBinary(int const* ptr, const size_t size, std::string filename)
|
||||
{
|
||||
std::vector<int> h_ptr(size);
|
||||
cudaD2Hcpy(h_ptr.data(), ptr, size);
|
||||
@ -831,7 +831,7 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_within_range)
|
||||
__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range)
|
||||
{
|
||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
|
||||
{
|
||||
@ -844,7 +844,7 @@ __global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
|
||||
bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
|
||||
{
|
||||
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
|
||||
|
||||
@ -858,12 +858,12 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_
|
||||
}
|
||||
|
||||
template bool invokeCheckRange<int>(
|
||||
const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
|
||||
int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
|
||||
|
||||
/*
|
||||
* Determine the total workspace size based on a vector containing multiple variable sizes.
|
||||
*/
|
||||
size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTES)
|
||||
size_t calcAlignedSize(std::vector<size_t> const& sizes, const size_t ALIGN_BYTES)
|
||||
{
|
||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||
// Check ALIGN_BYTES is a power of 2
|
||||
@ -885,7 +885,7 @@ size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTE
|
||||
* of each variable.
|
||||
*/
|
||||
void calcAlignedPointers(
|
||||
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES)
|
||||
std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES)
|
||||
{
|
||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||
// Check ALIGN_BYTES is a power of 2
|
||||
|
||||
@ -40,16 +40,16 @@ template <typename T>
|
||||
void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Hcpy(T* tgt, const T* src, const size_t size);
|
||||
void cudaD2Hcpy(T* tgt, T const* src, const size_t size);
|
||||
|
||||
template <typename T>
|
||||
void cudaH2Dcpy(T* tgt, const T* src, const size_t size);
|
||||
void cudaH2Dcpy(T* tgt, T const* src, const size_t size);
|
||||
|
||||
template <typename T>
|
||||
void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL);
|
||||
void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);
|
||||
|
||||
template <typename T>
|
||||
void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL);
|
||||
void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);
|
||||
|
||||
template <typename T>
|
||||
void cudaRandomUniform(T* buffer, const size_t size);
|
||||
@ -234,9 +234,9 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
void invokeCudaD2DScaleCpyConvert(
|
||||
T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0);
|
||||
T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0);
|
||||
|
||||
inline bool checkIfFileExist(const std::string& file_path)
|
||||
inline bool checkIfFileExist(std::string const& file_path)
|
||||
{
|
||||
std::ifstream in(file_path, std::ios::in | std::ios::binary);
|
||||
if (in.is_open())
|
||||
@ -248,7 +248,7 @@ inline bool checkIfFileExist(const std::string& file_path)
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void saveToBinary(const T* ptr, const size_t size, std::string filename);
|
||||
void saveToBinary(T const* ptr, const size_t size, std::string filename);
|
||||
|
||||
template <typename T_IN, typename T_fake_type>
|
||||
void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
|
||||
@ -256,10 +256,10 @@ void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
|
||||
size_t cuda_datatype_size(TRTLLMCudaDataType dt);
|
||||
|
||||
template <typename T>
|
||||
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);
|
||||
bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);
|
||||
|
||||
size_t calcAlignedSize(const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
|
||||
size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
|
||||
void calcAlignedPointers(
|
||||
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
|
||||
std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -50,6 +50,7 @@ MPI_Datatype getMpiDtype(MpiType dtype)
{MpiType::kUINT64, MPI_UINT64_T},
{MpiType::kFP8, MPI_UINT8_T},
{MpiType::kBF16, MPI_UINT16_T},
{MpiType::kCHAR, MPI_CHAR},
};
return dtype_map.at(dtype);
}
@ -126,23 +127,6 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm));
}

void MpiComm::bcast(std::vector<int64_t>& packed, int root) const
{
int64_t nWords1;
auto const rank = getRank();
if (rank == root)
{
nWords1 = static_cast<int64_t>(packed.size());
}
auto const mpiInt64 = MpiTypeConverter<int64_t>::value;
bcast(&nWords1, 1, mpiInt64, root);
if (rank != root)
{
packed.resize(nWords1);
}
bcast(packed.data(), packed.size(), mpiInt64, root);
}

void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
{
MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm));
@ -162,12 +146,12 @@ MpiComm MpiComm::split(int color, int key) const
return MpiComm{splitComm, true};
}

void MpiComm::allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
{
MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
}

void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const
void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
{
MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
}
@ -39,7 +39,7 @@ public:

constexpr QuantMode(QuantMode const&) noexcept = default;

constexpr QuantMode& operator=(const QuantMode& other) noexcept = default;
constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;

static constexpr QuantMode none() noexcept
{
@ -276,32 +276,32 @@ public:
return quantMode;
}

constexpr QuantMode operator+(const QuantMode& other) const noexcept
constexpr QuantMode operator+(QuantMode const& other) const noexcept
{
return QuantMode(mValue | other.mValue);
}

constexpr QuantMode& operator+=(const QuantMode& other) noexcept
constexpr QuantMode& operator+=(QuantMode const& other) noexcept
{
return *this = *this + other;
}

constexpr QuantMode operator-(const QuantMode& other) const noexcept
constexpr QuantMode operator-(QuantMode const& other) const noexcept
{
return QuantMode(mValue & ~other.mValue);
}

constexpr QuantMode& operator-=(const QuantMode& other) noexcept
constexpr QuantMode& operator-=(QuantMode const& other) noexcept
{
return *this = *this - other;
}

constexpr bool operator==(const QuantMode& other) const noexcept
constexpr bool operator==(QuantMode const& other) const noexcept
{
return mValue == other.mValue;
}

constexpr bool operator!=(const QuantMode& other) const noexcept
constexpr bool operator!=(QuantMode const& other) const noexcept
{
return !(*this == other);
}
@ -63,11 +63,11 @@ struct BytesToType<16>
|
||||
};
|
||||
|
||||
template <int Bytes>
|
||||
__device__ inline void copy(const void* local, void* data)
|
||||
__device__ inline void copy(void const* local, void* data)
|
||||
{
|
||||
using T = typename BytesToType<Bytes>::type;
|
||||
|
||||
const T* in = static_cast<const T*>(local);
|
||||
T const* in = static_cast<T const*>(local);
|
||||
T* out = static_cast<T*>(data);
|
||||
*out = *in;
|
||||
}
|
||||
@ -257,8 +257,8 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float*
|
||||
cg::thread_block cta = cg::this_thread_block();
|
||||
cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);
|
||||
|
||||
const int tid = cta.thread_rank();
|
||||
const int blockz = blockDim.x;
|
||||
int const tid = cta.thread_rank();
|
||||
int const blockz = blockDim.x;
|
||||
for (int i = 0; i < NUM; i++)
|
||||
{
|
||||
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
|
||||
@ -325,7 +325,7 @@ struct TopK
|
||||
|
||||
__device__ __forceinline__ void init()
|
||||
{
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
for (int i = 0; i < MAX_K; i++)
|
||||
@ -337,7 +337,7 @@ struct TopK
|
||||
};
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b)
|
||||
__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
|
||||
{
|
||||
TopK<T, MAX_K> res = a;
|
||||
for (int i = 0; i < MAX_K; ++i)
|
||||
@ -368,19 +368,19 @@ struct TopK_2
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
|
||||
__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
|
||||
{
|
||||
return a.u > b.u ? a : b;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T clamp_inf_for_half(const float input)
|
||||
__device__ __forceinline__ T clamp_inf_for_half(float const input)
|
||||
{
|
||||
return input;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ __forceinline__ half clamp_inf_for_half(const float input)
|
||||
__device__ __forceinline__ half clamp_inf_for_half(float const input)
|
||||
{
|
||||
// clamp inf values to enable fp16 training
|
||||
return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
|
||||
|
||||
@ -152,7 +152,7 @@ Tensor Tensor::slice(std::vector<size_t> shape, size_t offset) const
|
||||
return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset));
|
||||
}
|
||||
|
||||
TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map)
|
||||
TensorMap::TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map)
|
||||
{
|
||||
for (auto& kv : tensor_map)
|
||||
{
|
||||
@ -167,7 +167,7 @@ TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map)
|
||||
}
|
||||
}
|
||||
|
||||
TensorMap::TensorMap(const std::vector<Tensor>& tensor_map)
|
||||
TensorMap::TensorMap(std::vector<Tensor> const& tensor_map)
|
||||
{
|
||||
for (size_t i = 0; i < tensor_map.size(); i++)
|
||||
{
|
||||
|
||||
@ -191,7 +191,7 @@ struct TensorDataType<int*>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TensorDataType<const int*>
|
||||
struct TensorDataType<int const*>
|
||||
{
|
||||
static constexpr DataType value = TYPE_INT32_PTR;
|
||||
};
|
||||
@ -419,8 +419,8 @@ private:
|
||||
|
||||
public:
|
||||
TensorMap() = default;
|
||||
TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map);
|
||||
TensorMap(const std::vector<Tensor>& tensor_map);
|
||||
TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map);
|
||||
TensorMap(std::vector<Tensor> const& tensor_map);
|
||||
TensorMap(std::initializer_list<std::pair<std::string, Tensor>> tensor_map);
|
||||
~TensorMap();
|
||||
|
||||
@ -429,7 +429,7 @@ public:
|
||||
return tensor_map_.size();
|
||||
}
|
||||
|
||||
inline bool contains(const std::string& key) const
|
||||
inline bool contains(std::string const& key) const
|
||||
{
|
||||
TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
|
||||
return tensor_map_.find(key) != tensor_map_.end();
|
||||
@ -437,7 +437,7 @@ public:
|
||||
|
||||
std::vector<std::string> keys() const;
|
||||
|
||||
inline void insert(const std::string& key, const Tensor& value)
|
||||
inline void insert(std::string const& key, Tensor const& value)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!contains(key), fmtstr("Duplicated key %s", key.c_str()));
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
@ -445,7 +445,7 @@ public:
|
||||
tensor_map_.insert({key, value});
|
||||
}
|
||||
|
||||
inline void insertIfValid(const std::string& key, const Tensor& value)
|
||||
inline void insertIfValid(std::string const& key, Tensor const& value)
|
||||
{
|
||||
if (value.isValid())
|
||||
{
|
||||
@ -462,7 +462,7 @@ public:
|
||||
Tensor at(int tmp) = delete;
|
||||
Tensor at(size_t tmp) = delete;
|
||||
|
||||
inline Tensor& at(const std::string& key)
|
||||
inline Tensor& at(std::string const& key)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
@ -471,7 +471,7 @@ public:
|
||||
return tensor_map_.at(key);
|
||||
}
|
||||
|
||||
inline Tensor at(const std::string& key) const
|
||||
inline Tensor at(std::string const& key) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
fmtstr(
|
||||
@ -479,7 +479,7 @@ public:
|
||||
return tensor_map_.at(key);
|
||||
}
|
||||
|
||||
inline std::optional<Tensor> atOpt(const std::string& key) const
|
||||
inline std::optional<Tensor> atOpt(std::string const& key) const
|
||||
{
|
||||
if (contains(key))
|
||||
return tensor_map_.at(key);
|
||||
@ -487,7 +487,7 @@ public:
|
||||
return std::nullopt;
|
||||
}
|
||||
|
||||
inline Tensor& at(const std::string& key, Tensor& default_tensor)
|
||||
inline Tensor& at(std::string const& key, Tensor& default_tensor)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
|
||||
if (contains(key))
|
||||
@ -497,7 +497,7 @@ public:
|
||||
return default_tensor;
|
||||
}
|
||||
|
||||
inline Tensor at(const std::string& key, Tensor& default_tensor) const
|
||||
inline Tensor at(std::string const& key, Tensor& default_tensor) const
|
||||
{
|
||||
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
|
||||
if (contains(key))
|
||||
@ -507,7 +507,7 @@ public:
|
||||
return default_tensor;
|
||||
}
|
||||
|
||||
inline Tensor& at(const std::string& key, Tensor&& default_tensor)
|
||||
inline Tensor& at(std::string const& key, Tensor&& default_tensor)
|
||||
{
|
||||
TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
|
||||
if (contains(key))
|
||||
@ -517,7 +517,7 @@ public:
|
||||
return default_tensor;
|
||||
}
|
||||
|
||||
inline Tensor at(const std::string& key, Tensor&& default_tensor) const
|
||||
inline Tensor at(std::string const& key, Tensor&& default_tensor) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
@ -527,7 +527,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T getVal(const std::string& key) const
|
||||
inline T getVal(std::string const& key) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
fmtstr(
|
||||
@ -536,7 +536,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::optional<T> getValOpt(const std::string& key) const
|
||||
inline std::optional<T> getValOpt(std::string const& key) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
@ -549,7 +549,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T getVal(const std::string& key, T default_value) const
|
||||
inline T getVal(std::string const& key, T default_value) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
@ -559,7 +559,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T getValWithOffset(const std::string& key, size_t index) const
|
||||
inline T getValWithOffset(std::string const& key, size_t index) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
fmtstr(
|
||||
@ -568,7 +568,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T getValWithOffset(const std::string& key, size_t index, T default_value) const
|
||||
inline T getValWithOffset(std::string const& key, size_t index, T default_value) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
@ -578,7 +578,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T* getPtr(const std::string& key) const
|
||||
inline T* getPtr(std::string const& key) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
fmtstr(
|
||||
@ -587,7 +587,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T* getPtr(const std::string& key, T* default_ptr) const
|
||||
inline T* getPtr(std::string const& key, T* default_ptr) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
@ -597,7 +597,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T* getPtrWithOffset(const std::string& key, size_t index) const
|
||||
inline T* getPtrWithOffset(std::string const& key, size_t index) const
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(contains(key),
|
||||
fmtstr(
|
||||
@ -606,7 +606,7 @@ public:
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline T* getPtrWithOffset(const std::string& key, size_t index, T* default_ptr) const
|
||||
inline T* getPtrWithOffset(std::string const& key, size_t index, T* default_ptr) const
|
||||
{
|
||||
if (contains(key))
|
||||
{
|
||||
|
||||
@ -34,7 +34,7 @@ int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;

#if !defined(_MSC_VER)

TllmException::TllmException(char const* file, std::size_t line, const std::string& msg)
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: std::runtime_error{""}
{
mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
@ -43,7 +43,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri
std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
}
#else
TllmException::TllmException(char const* file, std::size_t line, const std::string& msg)
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
: mNbFrames{}
, std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
{
@ -65,7 +65,7 @@ __forceinline__ __device__ float copysignf_pos(float a, float b)
__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
const float exp_val = -1.f * fabs(2 * x);
float const exp_val = -1.f * fabs(2 * x);
return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
return fast_tanh(x);
@ -76,7 +76,7 @@ __forceinline__ __device__ float tanh_opt(float x)
template <>
struct GELU_taylor<float>
{
static const bool kIsHeavy = true;
static bool const kIsHeavy = true;

CUTLASS_DEVICE
float operator()(float const& z) const

@ -157,8 +157,8 @@ private:
MatrixCoord extent_real_;
ElementwiseFunctor elementwise_;

const bool per_token_quant_;
const bool per_channel_quant_;
bool const per_token_quant_;
bool const per_channel_quant_;

AlphaScaleElementType* ptr_alpha_row_;
AlphaScaleElementType* ptr_alpha_col_;
@ -65,7 +65,7 @@ namespace device
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T_IN, typename T_OUT>
|
||||
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk,
|
||||
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
|
||||
int64_t* splitk_buffer_offsets)
|
||||
{
|
||||
// in_tensor: [problem_idx, k_partition, hidden_size]
|
||||
@ -73,9 +73,9 @@ __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const
|
||||
// so, we need to use splitk_buffer_offsets.
|
||||
// out_tensor: problem_idx * [hidden_size]
|
||||
|
||||
const int problem_idx = blockIdx.y;
|
||||
int const problem_idx = blockIdx.y;
|
||||
GemmCoord problem = problem_sizes[problem_idx];
|
||||
const int hidden_size = problem.m() * problem.n();
|
||||
int const hidden_size = problem.m() * problem.n();
|
||||
const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
|
||||
T_OUT* out_tensor_ = out_tensor[problem_idx];
|
||||
|
||||
@ -143,7 +143,7 @@ protected:
|
||||
|
||||
private:
|
||||
/// Get the number of tiles across all problems in a group
|
||||
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count)
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
|
||||
{
|
||||
int32_t tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
@ -182,7 +182,7 @@ private:
|
||||
|
||||
/// Reorder `data` according to `indices`
|
||||
template <typename T>
|
||||
static void reorder_array(T* data, const std::vector<size_t>& indices)
|
||||
static void reorder_array(T* data, std::vector<size_t> const& indices)
|
||||
{
|
||||
// For now, simply create a copy of the data and then copy over to the original.
|
||||
std::vector<T> copy(indices.size());
|
||||
@ -314,7 +314,7 @@ public:
|
||||
|
||||
/// Computes the number of threadblocks to launch for the grouped kernel
|
||||
static int sufficient(
|
||||
const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
|
||||
cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
|
||||
{
|
||||
// Determine the number of blocks that would be launched to fill up a single
|
||||
// wave on the GPU with each SM having maximum occupancy.
|
||||
|
||||
@ -142,7 +142,7 @@ struct GemmFpAIntB
|
||||
Arguments() {}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(cutlass::gemm::GemmCoord const& problem_size, const int group_size,
|
||||
Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
|
||||
typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
|
||||
typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
|
||||
typename Epilogue::OutputTileIterator::TensorRef ref_C,
|
||||
@ -206,7 +206,7 @@ struct GemmFpAIntB
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, const int gemm_k_size,
|
||||
Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
|
||||
void* workspace = nullptr)
|
||||
: problem_size(args.problem_size)
|
||||
, group_size(args.group_size)
|
||||
|
||||
@ -174,7 +174,7 @@ public:
|
||||
/// Ctor
|
||||
CUTLASS_HOST_DEVICE
|
||||
Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
|
||||
const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C,
|
||||
ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
|
||||
ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k,
|
||||
GemmCoord* host_problem_sizes = nullptr)
|
||||
: problem_count(problem_count)
|
||||
|
||||
@ -119,7 +119,7 @@ struct BaseMoeProblemVisitor
|
||||
|
||||
/// Get the grid shape
|
||||
CUTLASS_HOST_DEVICE
|
||||
static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem)
|
||||
static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
|
||||
{
|
||||
|
||||
return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
|
||||
@ -177,12 +177,12 @@ struct BaseMoeProblemVisitor
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static int32_t tile_count(const cutlass::gemm::GemmCoord& grid)
|
||||
static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
|
||||
{
|
||||
return ProblemSizeHelper::tile_count(grid);
|
||||
}
|
||||
|
||||
static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count)
|
||||
static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
|
||||
{
|
||||
int32_t total_tiles = 0;
|
||||
for (int32_t i = 0; i < problem_count; ++i)
|
||||
@ -328,12 +328,12 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(
|
||||
const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
|
||||
cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
|
||||
{
|
||||
return 0;
|
||||
}
|
||||
|
||||
static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
|
||||
static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
|
||||
int32_t block_count, void* host_workspace_ptr)
|
||||
{
|
||||
}
|
||||
|
||||
@ -60,7 +60,7 @@ namespace threadblock
|
||||
template <typename WarpMma, int kExpansionFactor = 1>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
|
||||
const int warp_tileB_k_offset)
|
||||
int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C);
|
||||
}
|
||||
@ -68,7 +68,7 @@ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC&
|
||||
template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
|
||||
CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
|
||||
typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
|
||||
typename WarpMma::FragmentC const& C, const int warp_tileB_k_offset)
|
||||
typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
|
||||
{
|
||||
warp_mma(D, A, B, C, warp_tileB_k_offset);
|
||||
}
|
||||
|
||||
@ -572,8 +572,8 @@ public:
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
|
||||
@ -219,7 +219,7 @@ public:
|
||||
///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
typename Base::SharedStorage& shared_storage,
|
||||
///< Group size for quantization. Not used by this main loop since it assumes per-column
|
||||
const int group_size,
|
||||
int const group_size,
|
||||
///< ID within the threadblock
|
||||
int thread_idx,
|
||||
///< ID of warp
|
||||
@ -534,8 +534,8 @@ public:
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
this->warp_tile_iterator_B_.set_kgroup_index(
|
||||
|
||||
@ -184,7 +184,7 @@ public:
|
||||
CUTLASS_DEVICE
|
||||
DqMmaPipelined(typename Base::SharedStorage&
|
||||
shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
|
||||
const int group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
|
||||
int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
|
||||
///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
|
||||
///< argument is not added, it does not affect compilation for sm>=80.
|
||||
int thread_idx, ///< ID within the threadblock
|
||||
@ -353,8 +353,8 @@ public:
|
||||
this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
|
||||
++this->warp_tile_iterator_A_;
|
||||
|
||||
const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
|
||||
int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
|
||||
// We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
|
||||
if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
|
||||
{
|
||||
|
||||
@ -218,7 +218,7 @@ public:
|
||||
/// Performs a warp-level matrix multiply-accumulate operation
|
||||
CUTLASS_DEVICE
|
||||
void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
|
||||
const int warp_tileB_k_offset) const
|
||||
int const warp_tileB_k_offset) const
|
||||
{
|
||||
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
|
||||
@ -136,11 +136,11 @@ public:
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
const int warp_offset = warp_idx_n * Shape::kN;
|
||||
const int quad = lane_idx / 4;
|
||||
const int thread_offset = warp_offset + quad;
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
@ -149,7 +149,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
@ -165,7 +165,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
@ -174,7 +174,7 @@ public:
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
|
||||
@ -222,7 +222,7 @@ public:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag)
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
@ -231,8 +231,8 @@ public:
|
||||
== FragmentDequantizedOperand::kElements,
|
||||
"");
|
||||
|
||||
const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
|
||||
const __nv_bfloat16* zero_ptr = reinterpret_cast<const __nv_bfloat16*>(&zero_frag);
|
||||
__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
|
||||
__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);
|
||||
|
||||
ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
@ -335,11 +335,11 @@ public:
|
||||
static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
const int warp_offset = warp_idx_n * Shape::kN;
|
||||
const int quad = lane_idx / 4;
|
||||
const int thread_offset = warp_offset + quad;
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const quad = lane_idx / 4;
|
||||
int const thread_offset = warp_offset + quad;
|
||||
pointer_scale_ = smem_scales.data() + thread_offset;
|
||||
if constexpr (hasZero(QuantOp))
|
||||
{
|
||||
@ -348,7 +348,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
: MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
|
||||
{
|
||||
}
|
||||
@ -364,7 +364,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
@ -406,7 +406,7 @@ public:
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(
|
||||
FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag)
|
||||
FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
|
||||
{
|
||||
using _MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
|
||||
@ -505,11 +505,11 @@ public:
|
||||
static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
const int warp_offset = warp_idx_n * Shape::kN;
|
||||
const int base_col = lane_idx & 0xF8;
|
||||
const int thread_offset = warp_offset + base_col;
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const base_col = lane_idx & 0xF8;
|
||||
int const thread_offset = warp_offset + base_col;
|
||||
pointer_ = smem_scales.data() + thread_offset;
|
||||
}
|
||||
|
||||
@ -527,7 +527,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");
|
||||
|
||||
@ -591,11 +591,11 @@ public:
|
||||
static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");
|
||||
|
||||
CUTLASS_DEVICE
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
|
||||
MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
|
||||
{
|
||||
const int warp_offset = warp_idx_n * Shape::kN;
|
||||
const int base_col = lane_idx & 0xF8 + lane_idx % 4;
|
||||
const int thread_offset = warp_offset + base_col;
|
||||
int const warp_offset = warp_idx_n * Shape::kN;
|
||||
int const base_col = lane_idx & 0xF8 + lane_idx % 4;
|
||||
int const thread_offset = warp_offset + base_col;
|
||||
pointer_ = smem_scales.data() + thread_offset;
|
||||
}
|
||||
|
||||
@ -617,7 +617,7 @@ public:
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
|
||||
void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
|
||||
{
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
static constexpr int total_n_mmas = 2 * TileNIterations;
|
||||
|
||||
@ -167,8 +167,8 @@ public:
|
||||
|
||||
static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;
|
||||
|
||||
const int thread_row = thread_id / THREADS_PER_ROW;
|
||||
const int thread_col = thread_id % THREADS_PER_ROW;
|
||||
int const thread_row = thread_id / THREADS_PER_ROW;
|
||||
int const thread_col = thread_id % THREADS_PER_ROW;
|
||||
|
||||
const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
|
||||
const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
|
||||
@ -182,11 +182,11 @@ public:
|
||||
// a given iteration. The same threads will be responsible for issues reads since the number of scales
|
||||
// read in a given iteration is a constant. Therefore, we should never have to update is_valid_
|
||||
// outside of the constructor.
|
||||
const int global_row = threadblock_offset.row() + thread_row;
|
||||
const int global_col = threadblock_offset.column() + thread_col * kAlignment;
|
||||
int const global_row = threadblock_offset.row() + thread_row;
|
||||
int const global_col = threadblock_offset.column() + thread_col * kAlignment;
|
||||
|
||||
const bool row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
|
||||
const bool col_in_bounds = global_col < extent.column();
|
||||
bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
|
||||
bool const col_in_bounds = global_col < extent.column();
|
||||
|
||||
is_valid_ = row_in_bounds && col_in_bounds;
|
||||
}
|
||||
|
||||
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4201c7241d53298ca52d4f1447cc9cbc4024f63b42a24cbcff82192cc10bed67
size 576098
oid sha256:e1cdcabfbc5115c0d3228c567800d2706f1bc9e3752aaaa8148bcfe83be2c08c
size 716756

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2960feb2c7ad941a473408e2f6fd8c324f60f6af3c4d8f11217c676fd830e4cb
size 578660
oid sha256:ea48a79b211bc9857e7a881d6b9bc22580280e1d7cf3b30d6613466f4f440f8f
size 721934

@ -1,3 +1,3 @@
8a8d6505d9ef62cb2eeb8c75a5ee5bbb libtensorrt_llm_executor_static.a
e3b8edc619c99a7f125fe81bc8554ff0 libtensorrt_llm_executor_static.pre_cxx11.a
230623fa285048a2de5c54c2cc0f364fb9f2c559 commit
56853a19cf213aa5330ea087c9d86a60 libtensorrt_llm_executor_static.a
213487d55c816a1987aa79547091068f libtensorrt_llm_executor_static.pre_cxx11.a
741fb083cc42933439ae54557b177b6d7064da4f commit

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:cde295fa290b15b3d76b8e8b2cc435d7fceb2f456d8cb4d9b22ee2cf3ddbd344
size 588504
oid sha256:499f3aac1b98c5b411f1dacdddf8521b2b1f600388b44e6f7aab5b3f0cdf1280
size 721366

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:54ac66f3555bff4ed28ba0352bcb4a0f541346592cf109b491071b6374e5238c
size 562260
oid sha256:9c2c7e84be6b0e8baf296196ee9d7e84509bda2630ce3ada8a39dc498713ff48
size 700000

@ -1,2 +1,2 @@
ee96c6e2742539da0e8d732635f84449 libtensorrt_llm_executor_static.a
9154564ed926ffbcdb83e7eac3504fa0 libtensorrt_llm_executor_static.pre_cxx11.a
dcca3b095dad76dac36611be6104f011 libtensorrt_llm_executor_static.a
6cae7ce493704f7ad8d724cf8a538e2c libtensorrt_llm_executor_static.pre_cxx11.a
@ -25,9 +25,9 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int** parent_ids_buf, const int* batch_slots, int batch_size, int beam_width, int max_seq_len,
|
||||
const int* no_repeat_ngram_size_buf, int vocab_size_padded, const int* sequence_lengths)
|
||||
__global__ void ban_repeat_ngram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||
int const** parent_ids_buf, int const* batch_slots, int batch_size, int beam_width, int max_seq_len,
|
||||
int const* no_repeat_ngram_size_buf, int vocab_size_padded, int const* sequence_lengths)
|
||||
{
|
||||
/**
|
||||
* Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching
|
||||
@ -46,13 +46,13 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
|
||||
* in-bound positions only. For leftside out-of-boundary tokens, access by global memory.
|
||||
*/
|
||||
|
||||
const int output_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
const int local_batch_idx = blockIdx.y / beam_width;
|
||||
int const output_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int const local_batch_idx = blockIdx.y / beam_width;
|
||||
auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx;
|
||||
const int beam_idx = blockIdx.y % beam_width;
|
||||
const bool beam_search = beam_width > 1;
|
||||
const int no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
|
||||
const int step = sequence_lengths[batch_slot];
|
||||
int const beam_idx = blockIdx.y % beam_width;
|
||||
bool const beam_search = beam_width > 1;
|
||||
int const no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
|
||||
int const step = sequence_lengths[batch_slot];
|
||||
|
||||
// case 1: ngram_size == 0 --> this means no ngram limit
|
||||
// case 2: generated length must be greater than ngram_size to do ngram check
|
||||
@ -133,9 +133,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
|
||||
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
|
||||
void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||
int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
|
||||
int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
|
||||
{
|
||||
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
|
||||
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
|
||||
|
||||
@ -26,9 +26,9 @@ namespace kernels
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
||||
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
|
||||
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
|
||||
void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||
int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
|
||||
int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
@ -49,8 +49,8 @@ __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float
|
||||
|
||||
template <typename T, int MAX_K, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void beam_topK_kernel(const T* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, const bool* finished,
|
||||
const int* sequence_lengths, const int vocab_size, T diversity_rate, float length_penalty)
|
||||
void beam_topK_kernel(T const* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, bool const* finished,
|
||||
int const* sequence_lengths, int const vocab_size, T diversity_rate, float length_penalty)
|
||||
{
|
||||
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@ -59,7 +59,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
int block_id = blockIdx.x; // batch beam index.
|
||||
TopK<T, MAX_K> partial;
|
||||
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
#pragma unroll
|
||||
@ -101,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
{
|
||||
int thread_id = threadIdx.x;
|
||||
int block_id = blockIdx.x;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
TopK<T, MAX_K> partial;
|
||||
if (thread_id == 0)
|
||||
@ -136,7 +136,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
int tid = threadIdx.x;
|
||||
int bid = blockIdx.x;
|
||||
TopK<T, MAX_K> partial;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
#pragma unroll
|
||||
@ -167,32 +167,32 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
||||
__global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size,
|
||||
const float length_penalty, const int* end_ids)
|
||||
__global__ void topk_stage_1_opt3(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||
T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
|
||||
float const length_penalty, int const* end_ids)
|
||||
{
|
||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
|
||||
const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
|
||||
const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
|
||||
const int tmp_log_buf_index = row_id * vocab_size;
|
||||
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
|
||||
int const row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
|
||||
int const block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
|
||||
int const tmp_log_buf_index = row_id * vocab_size;
|
||||
int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
|
||||
TopK_2<T> partial;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
if (finished != nullptr && finished[row_id] == true)
|
||||
{
|
||||
if (tid < k)
|
||||
{
|
||||
const int index = tmp_topk_buf_index + tid;
|
||||
int const index = tmp_topk_buf_index + tid;
|
||||
if (block_lane == 0 && tid == 0)
|
||||
{
|
||||
const int end_id = end_ids[row_id / k];
|
||||
int const end_id = end_ids[row_id / k];
|
||||
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
|
||||
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
|
||||
}
|
||||
@ -226,7 +226,7 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
const int index = tmp_topk_buf_index + ite;
|
||||
int const index = tmp_topk_buf_index + ite;
|
||||
topk_tmp_id_buf[index] = total.p;
|
||||
topk_tmp_val_buf[index] = total.u;
|
||||
tmp_log_probs[total.p] = -MAX_T_VAL;
|
||||
@ -236,15 +236,15 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
||||
__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||
BeamHypotheses beam_hyps, const int* end_ids, const int vocab_size, const int k)
|
||||
__global__ void topk_stage_2_opt3(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||
BeamHypotheses beam_hyps, int const* end_ids, int const vocab_size, int const k)
|
||||
{
|
||||
const int size = k * k * BLOCKS_PER_BEAM_;
|
||||
const int tid = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
int const size = k * k * BLOCKS_PER_BEAM_;
|
||||
int const tid = threadIdx.x;
|
||||
int const batch_id = blockIdx.x;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||
|
||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@ -263,7 +263,7 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
||||
__syncthreads();
|
||||
if (beam_hyps.num_beams != nullptr)
|
||||
{
|
||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
||||
{
|
||||
// initialize the buffer
|
||||
@ -304,9 +304,9 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
||||
}
|
||||
else
|
||||
{
|
||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||
int const num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int beam_idx = num_beam;
|
||||
// If there are beam_width finished sentences, check that the score of
|
||||
// selected candidatet is higher than min_normed_score or not. If
|
||||
@ -345,20 +345,20 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
||||
}
|
||||
}
|
||||
}
|
||||
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||
int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||
* beam_hyps.max_seq_len;
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
||||
|
||||
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
||||
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
||||
{
|
||||
const int src_idx = j * beam_hyps.batch_size * k
|
||||
int const src_idx = j * beam_hyps.batch_size * k
|
||||
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
||||
prev_id = beam_hyps.parent_ids_src[src_idx];
|
||||
}
|
||||
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||
int const tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
@ -389,21 +389,21 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
||||
__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size,
|
||||
const float length_penalty)
|
||||
__global__ void topk_stage_1_opt2_general(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||
T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
|
||||
float const length_penalty)
|
||||
{
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
const int bid = blockIdx.x;
|
||||
const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
|
||||
const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
|
||||
const int tmp_log_buf_index = row_id * vocab_size;
|
||||
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
int const row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
|
||||
int const block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
|
||||
int const tmp_log_buf_index = row_id * vocab_size;
|
||||
int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
|
||||
TopK_2<T> partial;
|
||||
|
||||
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM)
|
||||
@ -426,7 +426,7 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
const int index = tmp_topk_buf_index + ite;
|
||||
int const index = tmp_topk_buf_index + ite;
|
||||
topk_tmp_id_buf[index] = total.p;
|
||||
topk_tmp_val_buf[index] = total.u;
|
||||
tmp_log_probs[total.p] = -MAX_T_VAL;
|
||||
@ -436,15 +436,15 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
|
||||
}
|
||||
|
||||
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
||||
__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||
BeamHypotheses beam_hyps, const int* end_ids, const int k, const int vocab_size)
|
||||
__global__ void topk_stage_2_opt2_general(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||
BeamHypotheses beam_hyps, int const* end_ids, int const k, int const vocab_size)
|
||||
{
|
||||
const int size = k * k * BLOCKS_PER_BEAM;
|
||||
const int tid = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
||||
int const size = k * k * BLOCKS_PER_BEAM;
|
||||
int const tid = threadIdx.x;
|
||||
int const batch_id = blockIdx.x;
|
||||
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||
|
||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||
@ -463,7 +463,7 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
||||
__syncthreads();
|
||||
if (beam_hyps.num_beams != nullptr)
|
||||
{
|
||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
||||
{
|
||||
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
||||
@ -503,9 +503,9 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
||||
}
|
||||
else
|
||||
{
|
||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||
float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||
int const num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||
int beam_idx = num_beam;
|
||||
// If there are beam_width finished sentences, check that the score of
|
||||
// selected candidatet is higher than min_normed_score or not. If
|
||||
@ -544,20 +544,20 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
||||
}
|
||||
}
|
||||
}
|
||||
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||
int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||
* beam_hyps.max_seq_len;
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
||||
|
||||
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
||||
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
||||
{
|
||||
const int src_idx = j * beam_hyps.batch_size * k
|
||||
int const src_idx = j * beam_hyps.batch_size * k
|
||||
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
||||
prev_id = beam_hyps.parent_ids_src[src_idx];
|
||||
}
|
||||
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||
int const tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||
beam_hyps.min_normed_scores[global_batch_idx]
|
||||
@ -613,18 +613,18 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
||||
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width,
|
||||
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids,
|
||||
bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
|
||||
int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a
|
||||
// token.
|
||||
const int vocab_size = vocab_size_padded_;
|
||||
int const vocab_size = vocab_size_padded_;
|
||||
// Beam size should be less than or equal to vocab size.
|
||||
assert(beam_width <= vocab_size);
|
||||
// Beam search needs the sequence lengths of beams to apply length penalty.
|
||||
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
|
||||
const int max_block_per_beam = 8;
|
||||
int const max_block_per_beam = 8;
|
||||
int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float
|
||||
int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int
|
||||
int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float
|
||||
@ -685,13 +685,13 @@ void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs,
|
||||
#undef CASE_K_DIV
|
||||
|
||||
template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids,
|
||||
BeamHypotheses* beam_hyps, const bool* finished, const int* sequence_lengths, const int batch_size,
|
||||
const int beam_width, const int vocab_size_padded_, const float diversity_rate, const float length_penalty,
|
||||
const int* end_ids, cudaStream_t stream);
|
||||
BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size,
|
||||
int const beam_width, int const vocab_size_padded_, float const diversity_rate, float const length_penalty,
|
||||
int const* end_ids, cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output,
|
||||
const int* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model)
|
||||
__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output,
|
||||
int const* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model)
|
||||
{
|
||||
if (blockIdx.x == 0)
|
||||
{
|
||||
@ -711,7 +711,7 @@ __global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length,
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, const int* sequence_length,
|
||||
void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, int const* sequence_length,
|
||||
const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
@ -739,30 +739,30 @@ void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const
|
||||
}
|
||||
}
|
||||
|
||||
template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, const float* output,
|
||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, float const* output,
|
||||
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
const size_t d_model, cudaStream_t stream);
|
||||
|
||||
template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, const half* output,
|
||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, half const* output,
|
||||
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
const size_t d_model, cudaStream_t stream);
|
||||
|
||||
template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, const half2* output,
|
||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, half2 const* output,
|
||||
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
const size_t d_model, cudaStream_t stream);
|
||||
#ifdef ENABLE_BF16
|
||||
template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length,
|
||||
const __nv_bfloat16* output, const int* sequence_length, const size_t batch_size, const size_t beam_width,
|
||||
__nv_bfloat16 const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width,
|
||||
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
|
||||
#endif
|
||||
|
||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished,
|
||||
const float* cum_log_probs, const int batch_size, const int beam_width)
|
||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished,
|
||||
float const* cum_log_probs, int const batch_size, int const beam_width)
|
||||
{
|
||||
const int bid = blockIdx.x;
|
||||
const int tgt_start_idx = beam_hyps.num_beams[bid];
|
||||
const int max_seq_len{beam_hyps.max_seq_len};
|
||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
|
||||
int const bid = blockIdx.x;
|
||||
int const tgt_start_idx = beam_hyps.num_beams[bid];
|
||||
int const max_seq_len{beam_hyps.max_seq_len};
|
||||
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
|
||||
if (beam_hyps.is_done[bid])
|
||||
{
|
||||
return;
|
||||
@ -771,10 +771,10 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
const int src_beam_idx = bid * beam_width + beam_idx;
|
||||
const int tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
|
||||
int const src_beam_idx = bid * beam_width + beam_idx;
|
||||
int const tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
|
||||
|
||||
const int last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
|
||||
int const last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx]
|
||||
= beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx];
|
||||
@ -810,8 +810,8 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
|
||||
}
|
||||
}
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
||||
const int batch_size, const int beam_width, cudaStream_t stream)
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||
int const batch_size, int const beam_width, cudaStream_t stream)
|
||||
{
|
||||
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
|
||||
}
|
||||
|
||||
@ -35,57 +35,64 @@ namespace kernels
|
||||
// After we collect `beam_width` beams, we will sort them by their norm_scores.
|
||||
struct BeamHypotheses
|
||||
{
|
||||
// TODO: simplify the pointers
|
||||
// Pointers initialized in function prepareOutputs in gptDecoder.cpp
|
||||
bool* is_done{nullptr}; // [batchSize], whether the batch is finished
|
||||
const int* input_lengths{nullptr}; // [batchSize]
|
||||
float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
|
||||
float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
|
||||
float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
|
||||
float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
|
||||
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
|
||||
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
|
||||
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
|
||||
// BS: batch_size
|
||||
// BM: beam_width
|
||||
// mSL: max_seq_length
|
||||
// %%: parameter name when we call [generation.py] dynamic_decoder.forward
|
||||
|
||||
// Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu
|
||||
const int* end_ids{nullptr}; // get from SoftmaxParams
|
||||
const int* output_ids_src{nullptr}; // for gatherTree
|
||||
const int* parent_ids_src{nullptr}; // for gatherTree
|
||||
const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
||||
const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
||||
float* log_probs_src{nullptr}; // get from outputs.output_log_probs
|
||||
int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams
|
||||
// For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate
|
||||
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
||||
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
||||
// Pointers initialized in these two functions:
|
||||
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
|
||||
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done
|
||||
float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs
|
||||
float* log_probs{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs
|
||||
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores
|
||||
float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores
|
||||
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams
|
||||
int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_is_done
|
||||
int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt
|
||||
int const* input_lengths{nullptr}; // [BS*BM] %% context_length
|
||||
|
||||
// Other scalar values and buffers
|
||||
int batch_size{0};
|
||||
int beam_width{0};
|
||||
int ite{0};
|
||||
int local_batch_size{0};
|
||||
int max_seq_len{0};
|
||||
int step{0}; // useless in online version of beam search
|
||||
int vocab_size{0};
|
||||
float* diversity_rates{nullptr};
|
||||
float* length_penalties{nullptr};
|
||||
int* early_stoppings{nullptr};
|
||||
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs
|
||||
// Pointers initialized in [onlineBeamSearchLayer.cu] invokeSoftMax:
|
||||
int const* end_ids{nullptr}; // [BS*BM] %% self.end_ids
|
||||
FinishedState* finished; // [BS*BM] %% self.finished
|
||||
float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||
float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
|
||||
int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer
|
||||
int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
|
||||
int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
|
||||
|
||||
float* diversity_rates{nullptr}; // [BS] from SamplingConfig
|
||||
float* length_penalties{nullptr}; // [BS] from SamplingConfig
|
||||
int* early_stoppings{nullptr}; // [BS] from SamplingConfig
|
||||
|
||||
// Pointers for function gatherTree
|
||||
int const* output_ids_src{nullptr}; //
|
||||
int const* parent_ids_src{nullptr}; //
|
||||
|
||||
// Scalar values
|
||||
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs, always be true now
|
||||
int batch_size{0}; //
|
||||
int beam_width{0}; //
|
||||
int ite{0}; // index of local_batch, always be 0 if pp_size==1
|
||||
int local_batch_size{0}; //
|
||||
int max_seq_len{0}; //
|
||||
int step{0}; // only used in [beamSearchTopkKernels.cu], always be 0 in [onlineSoftmaxBeamsearchKernels*.cu.h]
|
||||
int vocab_size{0}; // vocab_size_padded
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
||||
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width,
|
||||
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids,
|
||||
bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
|
||||
int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, const T* encoder_output,
|
||||
const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, T const* encoder_output,
|
||||
int const* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||
const size_t d_model, cudaStream_t stream);
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
||||
const int batch_size, const int beam_width, cudaStream_t stream);
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||
int const batch_size, int const beam_width, cudaStream_t stream);
|
||||
|
||||
void invokeCopyBatchMajorToGeneralPtr(
|
||||
void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);
|
||||
|
||||
@ -58,13 +58,13 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
|
||||
else if (dtype == DATA_TYPE_INT32)
|
||||
{
|
||||
int32_t inorm = static_cast<int32_t>(norm);
|
||||
alpha = reinterpret_cast<const uint32_t&>(inorm);
|
||||
alpha = reinterpret_cast<uint32_t const&>(inorm);
|
||||
}
|
||||
else if (dtype == DATA_TYPE_BF16)
|
||||
{
|
||||
// TODO HACK!! BF16 Outputs are computed in FP32 for FP8.
|
||||
// This is because cublas does not allow current FP32 output.
|
||||
alpha = reinterpret_cast<const uint32_t&>(norm);
|
||||
alpha = reinterpret_cast<uint32_t const&>(norm);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -77,7 +77,7 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
|
||||
class FusedMHARunnerV2::mhaImpl
|
||||
{
|
||||
public:
|
||||
mhaImpl(const Data_type data_type, const int numHeads, const int headSize, const float qScaling, int sm_)
|
||||
mhaImpl(const Data_type data_type, int const numHeads, int const headSize, float const qScaling, int sm_)
|
||||
: mDataType(data_type)
|
||||
, mNumHeads(numHeads)
|
||||
, mHeadSize(headSize)
|
||||
@ -105,17 +105,17 @@ public:
|
||||
|
||||
// Shared setup function.
|
||||
template <typename Params>
|
||||
void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size,
|
||||
const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void setup_params(Params& params, int const b, int const s_q, int const s_kv, int const sliding_window_size,
|
||||
int const total_seqlen, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
|
||||
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
|
||||
float const inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
|
||||
// Note that we apply scales and bias in the order of
|
||||
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
|
||||
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
|
||||
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
|
||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
||||
const float scale_bmm2 = 1.f;
|
||||
float const scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
|
||||
float const scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
|
||||
float const scale_softmax = 1.f; // Seems to be only required for int8
|
||||
float const scale_bmm2 = 1.f;
|
||||
|
||||
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
|
||||
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
|
||||
@ -153,8 +153,8 @@ public:
|
||||
}
|
||||
|
||||
// Support packed QKV.
|
||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
|
||||
// Determine launch parameters.
|
||||
@ -165,10 +165,10 @@ public:
|
||||
TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0.");
|
||||
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
||||
|
||||
const bool isSm70 = (sm == kSM_70);
|
||||
const bool isSm90 = (sm == kSM_90);
|
||||
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
const bool isSm80 = (sm == kSM_80);
|
||||
bool const isSm70 = (sm == kSM_70);
|
||||
bool const isSm90 = (sm == kSM_90);
|
||||
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
bool const isSm80 = (sm == kSM_80);
|
||||
if (isSm70)
|
||||
{
|
||||
mLaunchParams.flash_attention = true;
|
||||
@ -238,9 +238,9 @@ public:
|
||||
}
|
||||
|
||||
// Support paged_kv_cache and chunked_attention.
|
||||
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
|
||||
// Determine launch parameters.
|
||||
@ -253,9 +253,9 @@ public:
|
||||
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
||||
|
||||
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
|
||||
const bool isSm90 = (sm == kSM_90);
|
||||
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
const bool isSm80 = (sm == kSM_80);
|
||||
bool const isSm90 = (sm == kSM_90);
|
||||
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
bool const isSm80 = (sm == kSM_80);
|
||||
|
||||
// always use flash attention kernels.
|
||||
mLaunchParams.flash_attention = true;
|
||||
@ -383,7 +383,7 @@ public:
|
||||
|
||||
// QKV [TOTAL, 3, h, d]
|
||||
// NOTE: we may need to use actual seqlen to set oob_value
|
||||
const char* qkv_ptr = reinterpret_cast<const char*>(mParams.qkv_ptr);
|
||||
char const* qkv_ptr = reinterpret_cast<char const*>(mParams.qkv_ptr);
|
||||
tensor_size_qkv[3] = mTotalSeqLen;
|
||||
|
||||
// Q: STEP_Q
|
||||
@ -467,7 +467,7 @@ public:
|
||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
||||
|
||||
// Q ptr.
|
||||
const char* q_ptr = reinterpret_cast<const char*>(mPagedKVParams.q_ptr);
|
||||
char const* q_ptr = reinterpret_cast<char const*>(mPagedKVParams.q_ptr);
|
||||
|
||||
// Q: STEP_Q.
|
||||
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
|
||||
@ -518,7 +518,7 @@ public:
|
||||
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
|
||||
}
|
||||
|
||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
|
||||
void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
|
||||
{
|
||||
// BF16 FMHA only accumulates on FP32
|
||||
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
@ -541,11 +541,11 @@ public:
|
||||
return MHARunner::fmha_supported(mHeadSize, sm);
|
||||
}
|
||||
|
||||
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
void run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
{
|
||||
mParams.qkv_ptr = qkvPtr;
|
||||
mParams.o_ptr = outputPtr;
|
||||
mParams.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
|
||||
mParams.cu_seqlens = reinterpret_cast<int const*>(cuSeqlenPtr);
|
||||
|
||||
if (sm == kSM_90 && mLaunchParams.use_tma)
|
||||
{
|
||||
@ -556,8 +556,8 @@ public:
|
||||
xmmaKernel->run(mParams, mLaunchParams, stream);
|
||||
}
|
||||
|
||||
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
||||
void run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
|
||||
@ -568,10 +568,10 @@ public:
|
||||
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
|
||||
mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA;
|
||||
mPagedKVParams.o_ptr = outputPtr;
|
||||
mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
|
||||
mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
|
||||
mPagedKVParams.cu_q_seqlens = reinterpret_cast<int const*>(cuQSeqlenPtr);
|
||||
mPagedKVParams.cu_seqlens = reinterpret_cast<int const*>(cuKVSeqlenPtr);
|
||||
// paged kv block device ptrs on host (used by tma descriptors).
|
||||
mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
|
||||
mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<int64_t const*>(pagedKVBlockPtrsOnHost);
|
||||
|
||||
if (sm == kSM_90 && mLaunchParams.use_tma)
|
||||
{
|
||||
@ -587,7 +587,7 @@ public:
|
||||
return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s);
|
||||
}
|
||||
|
||||
int getSFromMaxSeqLen(const int max_seq_len)
|
||||
int getSFromMaxSeqLen(int const max_seq_len)
|
||||
{
|
||||
int S = 1024;
|
||||
|
||||
@ -625,35 +625,35 @@ private:
|
||||
Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams;
|
||||
Launch_params mLaunchParams;
|
||||
int sm;
|
||||
const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
|
||||
const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;
|
||||
FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel;
|
||||
FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* pagedKVXmmaKernel;
|
||||
bool use_flash_attention = false;
|
||||
const Data_type mDataType;
|
||||
const int mNumHeads;
|
||||
const int mHeadSize;
|
||||
const float mQScaling;
|
||||
int const mNumHeads;
|
||||
int const mHeadSize;
|
||||
float const mQScaling;
|
||||
int mTotalSeqLen;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
FusedMHARunnerV2::FusedMHARunnerV2(
|
||||
const Data_type data_type, const int numHeads, const int headSize, const float qScaling)
|
||||
const Data_type data_type, int const numHeads, int const headSize, float const qScaling)
|
||||
: pimpl(new mhaImpl(data_type, numHeads, headSize, qScaling, tensorrt_llm::common::getSMVersion()))
|
||||
{
|
||||
}
|
||||
|
||||
FusedMHARunnerV2::~FusedMHARunnerV2() = default;
|
||||
|
||||
void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
||||
void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size,
|
||||
total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
@ -665,18 +665,18 @@ bool FusedMHARunnerV2::fmha_supported()
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::setup_flags(
|
||||
const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
|
||||
bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
|
||||
{
|
||||
pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
void FusedMHARunnerV2::run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||
{
|
||||
pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
||||
void FusedMHARunnerV2::run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
|
||||
const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
pimpl->run_paged_kv(
|
||||
@ -689,7 +689,7 @@ bool FusedMHARunnerV2::isValid(int s) const
|
||||
}
|
||||
|
||||
// static function to check if fmha is supported when building plugins
|
||||
bool MHARunner::fmha_supported(const int headSize, const int sm)
|
||||
bool MHARunner::fmha_supported(int const headSize, int const sm)
|
||||
{
|
||||
if (sm == kSM_70)
|
||||
{
|
||||
|
||||
@ -41,33 +41,33 @@ namespace kernels
|
||||
class MHARunner
|
||||
{
|
||||
public:
|
||||
MHARunner(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);
|
||||
MHARunner(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);
|
||||
|
||||
MHARunner() = default;
|
||||
|
||||
virtual ~MHARunner() = default;
|
||||
|
||||
virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
|
||||
virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
static bool fmha_supported(const int headSize, const int sm);
|
||||
static bool fmha_supported(int const headSize, int const sm);
|
||||
|
||||
virtual bool fmha_supported() = 0;
|
||||
|
||||
virtual void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
|
||||
const int num_kv_heads /* MQA or GQA */)
|
||||
virtual void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
|
||||
int const num_kv_heads /* MQA or GQA */)
|
||||
= 0;
|
||||
|
||||
virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
|
||||
virtual void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) = 0;
|
||||
|
||||
virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
|
||||
virtual void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
|
||||
cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
@ -86,28 +86,28 @@ public:
|
||||
class FusedMHARunnerV2 : public MHARunner
|
||||
{
|
||||
public:
|
||||
FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);
|
||||
FusedMHARunnerV2(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);
|
||||
|
||||
~FusedMHARunnerV2(); // for pimpl
|
||||
|
||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
|
||||
const int tp_rank = 0) override;
|
||||
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
|
||||
int const tp_rank = 0) override;
|
||||
|
||||
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
|
||||
const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
|
||||
const int tp_rank = 0) override;
|
||||
void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
|
||||
int const tp_rank = 0) override;
|
||||
|
||||
bool fmha_supported() override;
|
||||
|
||||
void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
|
||||
void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
|
||||
void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) override;
|
||||
void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
|
||||
const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
|
||||
cudaStream_t stream) override;
|
||||
|
||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
|
||||
const int num_kv_heads /* MQA or GQA */) override;
|
||||
void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
|
||||
int const num_kv_heads /* MQA or GQA */) override;
|
||||
|
||||
bool isValid(int s) const override;
|
||||
|
||||
|
||||
@ -84,9 +84,9 @@ struct AlibiParams
|
||||
struct Fused_multihead_attention_params_v2
|
||||
{
|
||||
// The QKV matrices.
|
||||
const void* qkv_ptr;
|
||||
void const* qkv_ptr;
|
||||
// The mask to implement drop-out.
|
||||
const void* packed_mask_ptr;
|
||||
void const* packed_mask_ptr;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
|
||||
@ -106,7 +106,7 @@ struct Fused_multihead_attention_params_v2
|
||||
bool enable_i2f_trick;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual sequence lengths
|
||||
const int* cu_seqlens;
|
||||
int const* cu_seqlens;
|
||||
|
||||
// use C/32 Format.
|
||||
bool interleaved = false;
|
||||
@ -177,13 +177,13 @@ struct Fused_multihead_attention_params_v2
|
||||
struct Fused_multihead_attention_paged_kv_params_v2
|
||||
{
|
||||
// The Q matrices.
|
||||
const void* q_ptr;
|
||||
void const* q_ptr;
|
||||
// Paged KV Cache buffer.
|
||||
KVBlockArrayForContextFMHA paged_kv_cache;
|
||||
// The O matrix (output).
|
||||
void* o_ptr;
|
||||
// The packed mask for random mask.
|
||||
const void* packed_mask_ptr;
|
||||
void const* packed_mask_ptr;
|
||||
|
||||
// The stride between rows of the Q matrices.
|
||||
int64_t q_stride_in_bytes;
|
||||
@ -211,9 +211,9 @@ struct Fused_multihead_attention_paged_kv_params_v2
|
||||
AlibiParams alibi_params;
|
||||
|
||||
// array of length b+1 holding prefix sum of actual kv sequence lengths.
|
||||
const int* cu_seqlens;
|
||||
int const* cu_seqlens;
|
||||
// Chunked attention (only handles one tile of Q).
|
||||
const int* cu_q_seqlens;
|
||||
int const* cu_q_seqlens;
|
||||
|
||||
// q with shape [B, S, H, D] in const cache.
|
||||
cudaTmaDesc tma_desc_q;
|
||||
@ -301,7 +301,7 @@ struct Launch_params
|
||||
// number of paged kv blocks for context sequence.
|
||||
int blocks_per_context_sequence = 0;
|
||||
// device ptrs on the host for paged kv cache.
|
||||
const int64_t* paged_kv_block_ptrs = nullptr;
|
||||
int64_t const* paged_kv_block_ptrs = nullptr;
|
||||
// if flash attention is used (only FP16)
|
||||
bool flash_attention = false;
|
||||
// if warp_specialized kernels are used (only SM90 HGMMA + TMA)
|
||||
|
||||
@ -63,13 +63,13 @@ public:
return (uint64_t) s << 32 | d;
}

virtual uint64_t hashID(const KernelMeta& kernelMeta) const
virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD);
}

TFusedMultiHeadAttentionXMMAKernel(
const TKernelMeta* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
TKernelMeta const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
: mDataType(type)
, mKernelMeta(pMetaStart)
, mKernelMetaCount(nMetaCount)
@ -86,7 +86,7 @@ public:

for (unsigned int i = 0; i < mKernelMetaCount; ++i)
{
const auto& kernelMeta = mKernelMeta[i];
auto const& kernelMeta = mKernelMeta[i];
if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType)
{
CUmodule hmod{0};
@ -125,9 +125,9 @@ public:

virtual void run(TKernelParam& params, Launch_params& launch_params, cudaStream_t ss) const
{
const auto findIter = mFunctions.find(hashID(params.s, params.d));
auto const findIter = mFunctions.find(hashID(params.s, params.d));

const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@ -142,10 +142,10 @@ protected:
tensorrt_llm::common::CUDADriverWrapper mDriver;

Data_type mDataType;
const TKernelMeta* mKernelMeta;
TKernelMeta const* mKernelMeta;
unsigned int mKernelMetaCount;
unsigned int mSM;
std::unordered_map<const unsigned char*, CUmodule> mModules;
std::unordered_map<unsigned char const*, CUmodule> mModules;

struct FusedMultiHeadAttentionKernelInfo
{
@ -161,14 +161,14 @@ template <typename TFusedMHAKernelList>
class TFusedMHAKernelFactory
{
public:
const TFusedMHAKernelList* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList,
TFusedMHAKernelList const* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList,
unsigned int nbKernels, Data_type type, unsigned int sm)
{
static std::mutex s_mutex;
std::lock_guard<std::mutex> lg(s_mutex);

const auto id = hashID(type, sm);
const auto findIter = mKernels.find(id);
auto const id = hashID(type, sm);
auto const findIter = mKernels.find(id);
if (findIter == mKernels.end())
{
TFusedMHAKernelList* newKernel = new TFusedMHAKernelList{pKernelList, nbKernels, type, sm};
@ -214,7 +214,7 @@ class FusedMultiHeadAttentionXMMAKernelV2
Fused_multihead_attention_params_v2>
{
public:
FusedMultiHeadAttentionXMMAKernelV2(const FusedMultiHeadAttentionKernelMetaInfoV2* pMetaStart,
FusedMultiHeadAttentionXMMAKernelV2(FusedMultiHeadAttentionKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2,
Fused_multihead_attention_params_v2>(pMetaStart, nMetaCount, type, sm)
@ -231,7 +231,7 @@ public:
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

virtual uint64_t hashID(const KernelMeta& kernelMeta) const
virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{

return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
@ -278,7 +278,7 @@ public:
}
}

const auto findIter
auto const findIter
= mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll,
launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@ -290,7 +290,7 @@ public:
launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling);

const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@ -369,7 +369,7 @@ public:

using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernelV2>;

inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type type, unsigned int sm)
inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type type, unsigned int sm)
{
return FusedMHAKernelFactoryV2::Get().getXMMAKernels(
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
@ -384,7 +384,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2
Fused_multihead_attention_paged_kv_params_v2>
{
public:
FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart,
FusedMultiHeadAttentionPagedKVXMMAKernelV2(FusedMultiHeadAttentionPagedKVKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm)
@ -402,7 +402,7 @@ public:
| (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

virtual uint64_t hashID(const KernelMeta& kernelMeta) const
virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
@ -413,7 +413,7 @@ public:
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
{

const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
auto const findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention,
launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@ -426,7 +426,7 @@ public:
!launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type),
launch_params.granular_tiling);

const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@ -488,7 +488,7 @@ public:

using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>;

inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
inline FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
{
return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2,
sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm);

@ -186,7 +186,7 @@ public:
// set the desctriptor.
int set_tma_desctriptor(
// ptr to gmem
const void* gmem_ptr,
void const* gmem_ptr,
// format is really data_type in TMA terminology.
cudaTmaDescFormat format,
// interleave mode.
@ -221,7 +221,7 @@ public:
// set the desctriptor.
int set_tma_desctriptor(
// ptr to gmem
const void* gmem_ptr,
void const* gmem_ptr,
// format is really data_type in TMA terminology.
cudaTmaDescFormat format,
// interleave mode.

@ -108,10 +108,10 @@ inline __device__ int4 add128b(T& a, T& b)
}

__inline__ __device__ void multi_gpu_barrier(
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx)
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
{
// At the end of the function, we now that has least block 0 from all others GPUs have reached that point.
volatile uint32_t* my_signals = signals[rank];
uint32_t volatile* my_signals = signals[rank];
if (tidx < world_size)
{
// The 1st block notifies the other ranks.
@ -139,8 +139,8 @@ __global__ void multiGpuBarrierKernel(AllReduceParams params)
template <typename T, int RANKS_PER_NODE>
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
{
const int bidx = blockIdx.x;
const int tidx = threadIdx.x;
int const bidx = blockIdx.x;
int const tidx = threadIdx.x;

// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
@ -151,7 +151,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);

// The source pointers. Distributed round-robin for the different warps.
const T* src_d[RANKS_PER_NODE];
T const* src_d[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
@ -172,7 +172,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][iter_offset]);
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
}

// Sum the values from the different ranks.
@ -194,9 +194,9 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{

// The block index.
const int bidx = blockIdx.x;
int const bidx = blockIdx.x;
// The thread index with the block.
const int tidx = threadIdx.x;
int const tidx = threadIdx.x;

// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
@ -233,7 +233,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][local_offset]);
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
}

// Sum the values from the different ranks.
@ -396,14 +396,14 @@ void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream)
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
}

AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
{
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
AllReduceParams params;
// Even plugins use ping buffers, odd plugins use pong.
// That way, we don't need to wait for other GPUs to be done
// before copying input tensor to workspace.
const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;
auto const buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;

for (int i = 0; i < tpSize; ++i)
{

@ -57,7 +57,7 @@ struct AllReduceParams
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;

static AllReduceParams deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
};

template <typename T>

@ -70,7 +70,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
}

bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only)
int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only)
{

// All tile sizes have a k_tile of 64.
@ -89,7 +89,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
return false;
}

const int k_elements_per_split = k / split_k_factor;
int const k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0)
{
return false;
@ -97,9 +97,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}

// Check that the workspace has sufficient space for this split-k factor
const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes)
{
@ -110,7 +110,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}

std::vector<CutlassTileConfig> get_candidate_tiles(
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{
enum class CutlassGemmType : char
{
@ -170,7 +170,7 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}

std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{
enum class CutlassGemmType : char
{
@ -226,8 +226,8 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
return valid_tiles.count(tile) == 1;
}

std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma)
std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma)
{
if (sm == 90 && enable_hopper_gmma)
{
@ -235,14 +235,14 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
for (const auto& tile_config : tiles)
for (auto const& tile_config : tiles)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);

const bool has_m_mcast = supports_mcast_along_m(tile_config);
const bool has_n_mcast = supports_mcast_along_n(tile_config);
bool const has_m_mcast = supports_mcast_along_m(tile_config);
bool const has_n_mcast = supports_mcast_along_n(tile_config);
if (has_m_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
@ -270,9 +270,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
const int min_stages = int8_configs_only ? 3 : 2;
const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (const auto& tile_config : tiles)
int const min_stages = int8_configs_only ? 3 : 2;
int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
for (auto const& tile_config : tiles)
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
@ -292,9 +292,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
return candidate_configs;
}

CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only)
CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only)
{

if (occupancies.size() != candidate_configs.size())
@ -311,7 +311,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
int config_waves = INT_MAX;
int current_m_tile = 0;

const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii)
{
CutlassGemmConfig candidate_config = candidate_configs[ii];
@ -330,21 +330,21 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
continue;
}

const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor)
{
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only))
{
const int ctas_per_wave = occupancy * multi_processor_count;
const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
int const ctas_per_wave = occupancy * multi_processor_count;
int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
const float current_score = float(num_waves_total) - num_waves_fractional;
int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
float const current_score = float(num_waves_total) - num_waves_fractional;

const float score_slack = 0.1f;
float const score_slack = 0.1f;
if (current_score < config_score
|| ((config_waves > num_waves_total) && (current_score < config_score + score_slack)))
{

@ -27,13 +27,13 @@ namespace cutlass_kernels
{

std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
const int max_split_k = 1, const bool enable_hopper_gmma = false);
bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false,
int const max_split_k = 1, bool const enable_hopper_gmma = false);

tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,
const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only);
std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only);

} // namespace cutlass_kernels
} // namespace kernels

@ -158,8 +158,8 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
// For int4, each group of 32 rows is permuted using the map below:
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version)
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version)
{

// We only want to run this step for weight only quant.
@ -170,19 +170,19 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
const int K = 16 / BITS_PER_ELT;
const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
const int ELTS_PER_REG = 32 / BITS_PER_ELT;
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
int const K = 16 / BITS_PER_ELT;
int const ELTS_PER_BYTE = 8 / BITS_PER_ELT;
int const ELTS_PER_REG = 32 / BITS_PER_ELT;

const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);

int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K;
const int elts_in_int32 = 32 / BITS_PER_ELT;
int const elts_in_int32 = 32 / BITS_PER_ELT;

const int num_vec_cols = num_cols / elts_in_int32;
int const num_vec_cols = num_cols / elts_in_int32;

TLLM_CHECK_WITH_INFO(
arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta.");
@ -205,11 +205,11 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8

for (int write_col = 0; write_col < num_vec_cols; ++write_col)
{
const int write_row = base_row + tile_row;
const int tile_read_row
int const write_row = base_row + tile_row;
int const tile_read_row
= 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
const int read_row = base_row + tile_read_row;
const int read_col = write_col;
int const read_row = base_row + tile_read_row;
int const read_col = write_col;

const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col;
const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col;
@ -227,9 +227,9 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
// issue for relatively large models.
template <QuantType quant_type>
void subbyte_transpose_impl(
int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector<size_t>& shape)
int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape)
{
const int bits_per_elt = get_bits_in_quant_type(quant_type);
int const bits_per_elt = get_bits_in_quant_type(quant_type);

TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
@ -240,7 +240,7 @@ void subbyte_transpose_impl(
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;

const uint8_t* input_byte_ptr = reinterpret_cast<const uint8_t*>(quantized_tensor);
uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);

static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
@ -260,8 +260,8 @@ void subbyte_transpose_impl(
"num_col_bytes = %ld.",
VECTOR_WIDTH, col_bytes_trans, col_bytes));

const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;

for (size_t expert = 0; expert < num_experts; ++expert)
{
@ -271,16 +271,16 @@ void subbyte_transpose_impl(
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1)
{

const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);

for (int ii = 0; ii < M_TILE_L1; ++ii)
{
const int row = row_tile_start + ii;
int const row = row_tile_start + ii;

for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
const int col = col_tile_start_byte + jj;
int const col = col_tile_start_byte + jj;

const size_t logical_src_offset = matrix_offset + row * col_bytes + col;

@ -313,11 +313,11 @@ void subbyte_transpose_impl(
// is square in the number of elements (not necessarily the number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj)
{
const int ii_byte = ii / ELTS_PER_BYTE;
const int ii_bit_offset = ii % ELTS_PER_BYTE;
int const ii_byte = ii / ELTS_PER_BYTE;
int const ii_bit_offset = ii % ELTS_PER_BYTE;

const int jj_byte = jj / ELTS_PER_BYTE;
const int jj_bit_offset = jj % ELTS_PER_BYTE;
int const jj_byte = jj / ELTS_PER_BYTE;
int const jj_bit_offset = jj % ELTS_PER_BYTE;

uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
@ -338,15 +338,15 @@ void subbyte_transpose_impl(
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;

const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);

for (int ii = 0; ii < M_TILE_L1; ++ii)
{
const int row = row_tile_start_trans + ii;
int const row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
const int col = col_tile_start_byte_trans + jj;
int const col = col_tile_start_byte_trans + jj;

const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;

@ -364,8 +364,8 @@ void subbyte_transpose_impl(
}
}

void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type)
void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type)
{

if (quant_type == QuantType::INT8_WEIGHT_ONLY)
@ -409,7 +409,7 @@ void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num

void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts)
{
const int num_bytes = num_elts / 2;
int const num_bytes = num_elts / 2;

// Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little
// instructions as possible in the CUDA code.
@ -451,9 +451,9 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz

for (int dest_idx = 0; dest_idx < 8; ++dest_idx)
{
const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
const int src_shift = 4 * src_idx;
const int dest_shift = 4 * dest_idx;
int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
int const src_shift = 4 * src_idx;
int const dest_shift = 4 * dest_idx;

const uint32_t src_bits = (current_register >> src_shift) & 0xF;
transformed_register |= (src_bits << dest_shift);
@ -478,8 +478,8 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size
}
}

void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, LayoutDetails details)
void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details)
{

// We only want to run this step for weight only quant.
@ -490,23 +490,23 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
const int elts_in_int32 = 32 / BITS_PER_ELT;
int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
int const elts_in_int32 = 32 / BITS_PER_ELT;

const int rows_per_tile = details.rows_per_column_tile;
int const rows_per_tile = details.rows_per_column_tile;

TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows));

const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);

TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows));

const int num_vec_rows = num_rows / elts_in_int32;
const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
const int interleave = details.columns_interleaved;
int const num_vec_rows = num_rows / elts_in_int32;
int const vec_rows_per_tile = rows_per_tile / elts_in_int32;
int const interleave = details.columns_interleaved;

for (int expert = 0; expert < num_experts; ++expert)
{
@ -532,8 +532,8 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
}
}

void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
{
int arch = getSMVersion();
if (force_interleave && arch == 90)
@ -546,7 +546,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");

size_t num_elts = 1;
for (const auto& dim : shape)
for (auto const& dim : shape)
{
num_elts *= dim;
}
@ -620,7 +620,7 @@ Outputs

template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
bool force_interleave)
{

@ -633,8 +633,8 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

const int bits_in_type = get_bits_in_quant_type(quant_type);
const int bytes_per_out_col = num_cols * bits_in_type / 8;
int const bits_in_type = get_bits_in_quant_type(quant_type);
int const bytes_per_out_col = num_cols * bits_in_type / 8;

std::vector<int8_t> weight_buf;
if (unprocessed_quantized_weight == nullptr)
@ -643,15 +643,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
unprocessed_quantized_weight = weight_buf.data();
}

const int input_mat_size = num_rows * num_cols;
const int quantized_mat_size = num_rows * bytes_per_out_col;
const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
int const input_mat_size = num_rows * num_cols;
int const quantized_mat_size = num_rows * bytes_per_out_col;
float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1));

std::vector<float> per_col_max(num_cols);

for (int expert = 0; expert < num_experts; ++expert)
{
const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
WeightType const* current_weight = input_weight_ptr + expert * input_mat_size;
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;

// First we find the per column max for this expert weight.
@ -662,7 +662,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_

for (int ii = 0; ii < num_rows; ++ii)
{
const WeightType* current_weight_row = current_weight + ii * num_cols;
WeightType const* current_weight_row = current_weight + ii * num_cols;
for (int jj = 0; jj < num_cols; ++jj)
{
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
@ -681,15 +681,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
for (int ii = 0; ii < num_rows; ++ii)
{
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
const WeightType* current_weight_row = current_weight + ii * num_cols;
WeightType const* current_weight_row = current_weight + ii * num_cols;
for (int jj = 0; jj < bytes_per_out_col; ++jj)
{

if (quant_type == QuantType::INT8_WEIGHT_ONLY)
{
const float col_scale = per_col_max[jj];
const float weight_elt = float(current_weight_row[jj]);
const float scaled_weight = round(weight_elt / col_scale);
float const col_scale = per_col_max[jj];
float const weight_elt = float(current_weight_row[jj]);
float const scaled_weight = round(weight_elt / col_scale);
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
current_quantized_weight_row[jj] = clipped_weight;
}
@ -700,12 +700,12 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
int8_t packed_int4s = 0;
for (int packed_idx = 0; packed_idx < 2; ++packed_idx)
{
const int input_idx = 2 * jj + packed_idx;
int const input_idx = 2 * jj + packed_idx;
if (input_idx < num_cols)
{
const float col_scale = per_col_max[input_idx];
const float weight_elt = float(current_weight_row[input_idx]);
const float scaled_weight = round(weight_elt / col_scale);
float const col_scale = per_col_max[input_idx];
float const weight_elt = float(current_weight_row[input_idx]);
float const scaled_weight = round(weight_elt / col_scale);
int int_weight = int(scaled_weight);
const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));

@ -729,47 +729,47 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
}

template void symmetric_quantize<half, float>(
int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
int8_t*, int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<half, half>(
int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
int8_t*, int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);

#ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
#endif

template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
{
symmetric_quantize(
processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
}

template void symmetric_quantize<float, float>(
int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool);
int8_t*, float*, float const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<half, float>(
int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
template void symmetric_quantize<half, half>(int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);

#ifdef ENABLE_BF16
template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<__nv_bfloat16, half>(
int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool);
int8_t*, __nv_bfloat16*, half const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<half, __nv_bfloat16>(
int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
int8_t*, half*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

template void symmetric_quantize<__nv_bfloat16, float>(
int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
#endif

} // namespace cutlass_kernels

@ -38,26 +38,26 @@ int get_bits_in_quant_type(QuantType quant_type);

// Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
// 3-D shapes are [num_experts, num_rows, num_cols]
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version);
void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version);

void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor,
const std::vector<size_t>& shape, QuantType quant_type);
void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
std::vector<size_t> const& shape, QuantType quant_type);

void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);

void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false);
void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave = false);

template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave);
void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave);

// This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
// to implement a simple reference implementation.
template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
bool force_interleave);

} // namespace cutlass_kernels

@ -58,27 +58,27 @@ public:

virtual ~CutlassFpAIntBGemmRunnerInterface() {}

virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;

virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n,
virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n,
int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream)
= 0;

virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;

virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
= 0;

// Returns desired workspace size in bytes.
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;

virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;

@ -96,20 +96,20 @@ public:
CutlassFpAIntBGemmRunner();
~CutlassFpAIntBGemmRunner();

void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;

void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;

void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override;

void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
cudaStream_t stream) override;

@ -120,15 +120,15 @@ public:
// stream);

// Returns desired workspace size in bytes.
size_t getWorkspaceSize(const int m, const int n, const int k) override;
size_t getWorkspaceSize(int const m, int const n, int const k) override;

std::vector<tkc::CutlassGemmConfig> getConfigs() const override;

private:
template <typename EpilogueTag>
void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);

private:

@ -52,8 +52,8 @@ namespace cutlass_kernels

template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales,
const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales,
T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
@ -127,7 +127,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*

using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

const int ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
int const ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
? n
: k * GemmKernel::kInterleave;

@ -171,7 +171,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
}
}

const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
{reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},
@ -230,8 +230,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
// quanitzation is only supported on Ampere+ GPUs.
template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape, int Stages>
void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
@ -261,8 +261,8 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_

template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
typename ThreadblockShape, typename WarpShape>
void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
int* occupancy = nullptr)
{
@ -300,9 +300,9 @@ constexpr bool is_fp8()

template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
cudaStream_t stream, int* occupancy = nullptr)
{

@ -412,9 +412,9 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
typename BiasType, typename OutputType>
template <typename EpilogueTag>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
OutputType>::dispatch_to_arch<EpilogueTag>(const ActivationType* A, const WeightType* B,
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
OutputType>::dispatch_to_arch<EpilogueTag>(ActivationType const* A, WeightType const* B,
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -453,16 +453,16 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
|| (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY))
{
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
(const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases,
dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
(ScaleZeroType const*) weight_scales, (ScaleZeroType const*) weight_zero_points, (BiasType const*) biases,
alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
}
else
@ -475,8 +475,8 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
@ -487,15 +487,15 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
{
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
{
dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
(const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
(ScaleZeroType const*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
workspace_ptr, workspace_bytes, stream, nullptr);
}
else
@ -507,7 +507,7 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
typename BiasType, typename OutputType>
void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
|
||||
tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -529,12 +529,12 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
|
||||
typename BiasType, typename OutputType>
|
||||
size_t
|
||||
CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize(
|
||||
const int m, const int n, const int k)
|
||||
int const m, int const n, int const k)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
// These are the min tile sizes for each config, which would launch the maximum number of blocks
|
||||
const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
|
||||
const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
|
||||
int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
|
||||
int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
|
||||
// We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
|
||||
return static_cast<size_t>(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4);
|
||||
}
|
||||
|
||||
@ -44,9 +44,9 @@ namespace cutlass_kernels
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType>
|
||||
void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
|
||||
@ -114,9 +114,9 @@ constexpr bool are_tile_shapes_supported()
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
|
||||
void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -153,9 +153,9 @@ void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType*
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
|
||||
void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -190,9 +190,9 @@ void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, con
|
||||
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
|
||||
void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
|
||||
const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
|
||||
int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
|
||||
void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
|
||||
ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
|
||||
int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
|
||||
cudaStream_t stream, int* occupancy = nullptr)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
|
||||
@ -28,9 +28,9 @@ namespace cutlass_kernels
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
|
||||
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
|
||||
const float alpha, OutputType* C, int m, int n, int k, const int group_size,
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
|
||||
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
|
||||
float const alpha, OutputType* C, int m, int n, int k, int const group_size,
|
||||
tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
|
||||
cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
|
||||
@ -59,9 +59,9 @@ namespace cutlass_kernels
|
||||
template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
|
||||
cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
|
||||
typename MainloopScheduleType, typename EpilogueScheduleType>
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
|
||||
const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
|
||||
const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
|
||||
ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
|
||||
float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
|
||||
char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
|
||||
{
|
||||
TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
|
||||
@ -233,7 +233,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const Weigh
|
||||
StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1));
|
||||
|
||||
// Use the output as the bias to avoid making a tma descriptor with a nullptr.
|
||||
auto output_as_bias_type = reinterpret_cast<const CutlassBiasType*>(C);
|
||||
auto output_as_bias_type = reinterpret_cast<CutlassBiasType const*>(C);
|
||||
|
||||
typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
|
||||
{reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),
|
||||
|
||||
@ -47,13 +47,13 @@ public:
|
||||
|
||||
virtual ~CutlassInt8GemmRunnerInterface() {}
|
||||
|
||||
virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol,
|
||||
const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
|
||||
virtual void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
|
||||
float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
|
||||
const size_t workspaceBytes, cudaStream_t stream)
|
||||
= 0;
|
||||
|
||||
// Returns desired workspace size in bytes.
|
||||
virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
|
||||
virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;
|
||||
|
||||
virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;
|
||||
|
||||
@ -70,18 +70,18 @@ public:
|
||||
CutlassInt8GemmRunner();
|
||||
~CutlassInt8GemmRunner();
|
||||
|
||||
void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow,
|
||||
void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow,
|
||||
void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
|
||||
const size_t workspaceBytes, cudaStream_t stream) override;
|
||||
|
||||
// Returns desired workspace size in bytes.
|
||||
size_t getWorkspaceSize(const int m, const int n, const int k) override;
|
||||
size_t getWorkspaceSize(int const m, int const n, int const k) override;
|
||||
|
||||
std::vector<tkc::CutlassGemmConfig> getConfigs() const override;
|
||||
|
||||
private:
|
||||
void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol,
|
||||
const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
|
||||
void dispatchToArch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
|
||||
float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
|
||||
const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr);
|
||||
|
||||
int mSm;
|
||||
|
||||
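The hunks above all apply one mechanical edit: the `const` qualifier is moved to the right of the type it qualifies (`const T*` becomes `T const*`, `const int` becomes `int const`). Below is a minimal standalone sketch, not part of the commit, illustrating that the two spellings name the same type, so the signature changes cannot alter behavior or ABI:

// Hypothetical self-contained check; names used here are illustrative only.
#include <type_traits>

int main()
{
    static_assert(std::is_same<const float*, float const*>::value,
        "west-const and east-const spell the same pointer-to-const type");
    static_assert(std::is_same<const int, int const>::value,
        "the same holds for value parameters such as 'int const group_size'");
    return 0;
}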
Some files were not shown because too many files have changed in this diff.