Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-14 06:27:45 +08:00)

Commit 4bb65f216f: Update TensorRT-LLM (#1274)
Parent commit: 728cc0044b

* Update TensorRT-LLM

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
Co-authored-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
@@ -59,6 +59,7 @@ PenaltyBreakString: 1000
 PenaltyExcessCharacter: 1000000
 PenaltyReturnTypeOnItsOwnLine: 60
 PointerAlignment: Left
+QualifierAlignment: Right
 ReflowComments: true
 SeparateDefinitionBlocks: Always
 SortIncludes: CaseSensitive
.gitignore (vendored), 10 lines changed

@@ -17,6 +17,16 @@ venv/
 .local/
 .hypothesis/
 .idea/
+dump*/
+.trt-internal
+*.dot
+*.prof
+*.log
+*.pkl
+*.hdf5
+*.lock
+config.json
+/*.svg
 cpp/cmake-build-*
 cpp/.ccache/
 tensorrt_llm/libs
@@ -355,6 +355,9 @@ however, that it is recommended to use the C++ version.

 ## Troubleshooting

+* If you encounter accuracy issues in the generated text, you may want to increase
+  the internal precision in the attention layer. For that, pass the `--context_fmha_fp32_acc enable` to
+  `trtllm-build`.

 * It's recommended to add options `--shm-size=1g --ulimit memlock=-1` to the
   docker or nvidia-docker run command. Otherwise you may see NCCL errors when
@@ -39,7 +39,6 @@ Take GPT-350M as an example for single GPU

 ```
 ./benchmarks/gptSessionBenchmark \
-    --model gpt_350m \
     --engine_dir "../../benchmarks/gpt_350m/" \
     --batch_size "1" \
     --input_output_len "60,20"
@@ -50,7 +49,6 @@ Take GPT-350M as an example for single GPU
 Take GPT-175B as an example for multiple GPUs
 ```
 mpirun -n 8 ./benchmarks/gptSessionBenchmark \
-    --model gpt_175b \
     --engine_dir "../../benchmarks/gpt_175b/" \
     --batch_size "1" \
     --input_output_len "60,20"
@@ -125,7 +123,6 @@ cd cpp/build
 Take GPT-350M as an example for single GPU V1 batching
 ```
 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
     --type V1 \
     --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -135,7 +132,6 @@ Take GPT-350M as an example for single GPU V1 batching
 Take GPT-350M as an example for 2-GPU inflight batching
 ```
 mpirun -n 2 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2-ib/fp16/2-gpu/ \
     --type IFB \
     --dataset ../../benchmarks/cpp/preprocessed_dataset.json
@@ -165,7 +161,6 @@ Given a `static_emulated_batch_size` of `n` the server will wait for `n` request
 Take GPT-350M as an example for single GPU with static batching
 ```
 ./benchmarks/gptManagerBenchmark \
-    --model gpt \
     --engine_dir ../../examples/gpt/trt_engine/gpt2/fp16/1-gpu/ \
     --type IFB \
     --static_emulated_batch_size 32 \
@@ -237,7 +237,7 @@ int main(int argc, char* argv[])
         benchmarkBert(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes, inLens,
             logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>());
     }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
     {
         TLLM_LOG_ERROR(e.what());
         return 1;
@@ -24,6 +24,7 @@
 #include "tensorrt_llm/common/stringUtils.h"
 #include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/plugins/api/tllmPlugin.h"
+#include "tensorrt_llm/runtime/common.h"
 #include "tensorrt_llm/runtime/tllmLogger.h"
 #include "tensorrt_llm/runtime/worldConfig.h"

@@ -64,20 +65,18 @@ struct BenchmarkParams
 class WorkItem
 {
 public:
-    WorkItem(std::shared_ptr<InferenceRequest> ir, uint64_t requestId)
-        : mInferenceRequest(ir)
+    WorkItem(std::shared_ptr<InferenceRequest> inferenceRequest, uint64_t requestId)
+        : mInferenceRequest(std::move(inferenceRequest))
         , mRequestId(requestId)
     {
     }

-    ~WorkItem() {}
-
-    uint64_t requestId() const
+    [[nodiscard]] uint64_t requestId() const
     {
         return mRequestId;
     }

-    std::shared_ptr<InferenceRequest> getInferenceRequest() const
+    [[nodiscard]] std::shared_ptr<InferenceRequest> getInferenceRequest() const
     {
         return mInferenceRequest;
     }
@@ -93,7 +92,7 @@ class WorkItemsQueue
 public:
     void clear()
     {
-        std::lock_guard<std::mutex> lk(mMutex);
+        std::lock_guard<std::mutex> lock(mMutex);
         mPendingWorkItems.clear();
         mPendingWorkItemsReqIds.clear();
         mInProgressWorkItems.clear();
@@ -289,7 +288,7 @@ public:

         if (outputFile.is_open())
         {
-            for (const auto& header : headers)
+            for (auto const& header : headers)
             {
                 outputFile << header << ",";
             }
@@ -340,13 +339,12 @@ public:
        mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
    }

-    ~ExecutorServer() {}
-
    void enqueue(std::vector<texec::Request> requests, bool warmup = false)
    {
        try
        {
-            std::vector<SizeType> inputLengths, maxNewTokens;
+            std::vector<SizeType> inputLengths;
+            std::vector<SizeType> maxNewTokens;
            for (auto const& request : requests)
            {
                inputLengths.push_back(request.getInputTokenIds().size());
@@ -363,11 +361,10 @@ public:
                mActiveCount++;
            }
        }
-        catch (const std::exception& e)
+        catch (std::exception const& e)
        {
            TLLM_THROW("%s", e.what());
        }
-        return;
    }

    void waitForResponses(std::optional<SizeType> numRequests, bool warmup = false)
@@ -415,17 +412,16 @@ private:
 class GptServer
 {
 public:
-    GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, int32_t maxBeamWidth,
+    GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType, SizeType maxBeamWidth,
         batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
         std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
-        std::optional<uint64_t> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs, bool logIterationData)
+        std::optional<SizeType> const staticEmulatedBatchSize,
+        std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData)
        : mRecorder(std::move(recorder))
        , mTerminateReqId(terminateReqId)
        , mWaitSleep(waitSleep)
        , mStaticEmulatedBatchSize(staticEmulatedBatchSize)
-        , mEmulatedBatchEndTimestamp(
-              std::chrono::steady_clock::now() + std::chrono::milliseconds(staticEmulatedTimeoutMs))
-        , mStaticEmulatedTimeoutMs(staticEmulatedTimeoutMs)
+        , mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
        , mActiveCount(0)
    {
        ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
@@ -473,16 +469,21 @@ public:
            mRecorder->recordStart(request, requestId);
            mWorkItemsQueue.push(request, requestId);
        }
-        catch (const tc::TllmException& e)
+        catch (tc::TllmException const& e)
        {
            throw;
        }
-        catch (const std::exception& e)
+        catch (std::exception const& e)
        {
            TLLM_THROW("%s", e.what());
        }
    }

+    void resetBatchDeadline()
+    {
+        mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch();
+    }
+
    void waitForEmpty() const
    {
        while (!mWorkItemsQueue.empty())
@@ -502,9 +503,9 @@ public:
    }

    // Return up to max_num_requests inference requests.
-    std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(const int max_num_requests)
+    std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
    {
-        std::list<std::shared_ptr<InferenceRequest>> rval;
+        std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
        auto& comm = COMM_SESSION;
        if (max_num_requests > 0)
        {
@@ -515,12 +516,12 @@ public:
            auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
                static_cast<int64_t>(max_num_requests));

-            bool readyForNextBatch = numNewWorkItems > 0;
+            bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load();
+            bool readyForNextBatch = numNewWorkItems > 0 && timeout;
            if (mStaticEmulatedBatchSize)
            {
                if (numNewWorkItems > 0)
                {
-                    bool const timeout = std::chrono::steady_clock::now() > mEmulatedBatchEndTimestamp;
                    bool const previousBatchFinished = mActiveCount == 0;
                    bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
                    readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
@@ -529,26 +530,23 @@ public:
                {
                    // Timeout should only begin once we have at least 1 pending request.
                    // Reset timeout when no requests are pending or we submit a new batch.
-                    mEmulatedBatchEndTimestamp
-                        = std::chrono::steady_clock::now() + std::chrono::milliseconds(mStaticEmulatedTimeoutMs);
+                    resetBatchDeadline();
                }
            }

            if (readyForNextBatch)
            {
-                int count = 0;
                // Only add a single batch at a time when emulating static batching
                auto const numItemsToAdd = std::min(
                    numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
                mActiveCount += numItemsToAdd;
-                while (count < numItemsToAdd)
+                while (inferenceRequests.size() < numItemsToAdd)
                {
                    auto [workItem, markedInProgress] = mWorkItemsQueue.pop();

                    if (markedInProgress)
                    {
-                        rval.emplace_back(workItem->getInferenceRequest());
-                        count++;
+                        inferenceRequests.emplace_back(workItem->getInferenceRequest());
                    }
                    else
                    {
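The readiness logic in the hunks above combines the new batch deadline with the optional emulated static batch size. Restated as a small self-contained predicate (simplified names, and omitting the deadline-reset side effect the benchmark performs when no work is pending), the decision looks roughly like this; it is a sketch, not the benchmark's code:

```cpp
#include <cstdint>
#include <optional>

// Simplified restatement of the readiness check above; illustrative only.
bool readyForNextBatch(int64_t numNewWorkItems, bool deadlineExpired, uint64_t activeCount,
    std::optional<int64_t> staticEmulatedBatchSize)
{
    if (!staticEmulatedBatchSize)
    {
        // Regular in-flight batching: any pending work plus an expired deadline.
        return numNewWorkItems > 0 && deadlineExpired;
    }
    // Emulated static batching: wait for the previous batch to drain, then go once
    // a full batch is available or the deadline expires.
    bool const previousBatchFinished = activeCount == 0;
    bool const haveEnoughForNextBatch = numNewWorkItems >= *staticEmulatedBatchSize;
    return numNewWorkItems > 0 && previousBatchFinished && (deadlineExpired || haveEnoughForNextBatch);
}
```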
@@ -561,14 +559,14 @@ public:
                }
            if (world_size > 1)
            {
-                auto numNewWorkItems = static_cast<int64_t>(rval.size());
+                auto numNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
                comm.bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
                if (numNewWorkItems > 0)
                {
                    std::vector<int64_t> packed;
-                    for (auto const& ir : rval)
+                    for (auto const& infReq : inferenceRequests)
                    {
-                        auto vpacked = ir->serialize();
+                        auto vpacked = infReq->serialize();
                        packed.push_back(static_cast<int64_t>(vpacked.size()));
                        packed.insert(
                            packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
@@ -590,18 +588,18 @@ public:
                for (int64_t count = 0; count < numNewWorkItems; ++count)
                {
                    int64_t n = *(packed_ptr++);
-                    auto ir = InferenceRequest::deserialize(packed_ptr);
+                    auto infReq = InferenceRequest::deserialize(packed_ptr);
                    packed_ptr += n;
-                    rval.emplace_back(ir);
+                    inferenceRequests.emplace_back(infReq);
                }
            }
        }
        }
-        return rval;
+        return inferenceRequests;
    }

    void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
-        bool final_response, [[maybe_unused]] const std::string& errMsg)
+        bool final_response, [[maybe_unused]] std::string const& errMsg)
    {
        // `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs,
        // cumLogProbs`. `contextLogits, generationLogits` are optional, only contained when `gather_context_logits` and
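When running with more than one rank, the requests pulled on rank 0 are serialized into a single flat `int64_t` buffer and broadcast, with each request prefixed by its length so the other ranks can split the buffer back into individual requests. A minimal sketch of that length-prefixed framing (hypothetical `pack`/`unpack` helpers, not functions from the benchmark) could look like this:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helpers illustrating the length-prefixed framing used above.
std::vector<int64_t> pack(std::vector<std::vector<int64_t>> const& serializedRequests)
{
    std::vector<int64_t> packed;
    for (auto const& request : serializedRequests)
    {
        packed.push_back(static_cast<int64_t>(request.size())); // length prefix
        packed.insert(packed.end(), request.begin(), request.end());
    }
    return packed;
}

std::vector<std::vector<int64_t>> unpack(int64_t const* packedPtr, int64_t numRequests)
{
    std::vector<std::vector<int64_t>> requests;
    for (int64_t i = 0; i < numRequests; ++i)
    {
        int64_t const n = *packedPtr++;                 // read the length prefix
        requests.emplace_back(packedPtr, packedPtr + n); // copy the payload
        packedPtr += n;
    }
    return requests;
}
```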
@@ -616,7 +614,7 @@ public:
                mActiveCount--;
            }
        }
-        catch (const std::exception& e)
+        catch (std::exception const& e)
        {
            TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what());
        }
@@ -628,9 +626,9 @@ private:
    WorkItemsQueue mWorkItemsQueue;
    std::optional<uint64_t> mTerminateReqId;
    std::chrono::milliseconds mWaitSleep;
-    std::optional<int> mStaticEmulatedBatchSize;
-    std::chrono::time_point<std::chrono::steady_clock> mEmulatedBatchEndTimestamp;
-    int32_t mStaticEmulatedTimeoutMs;
+    std::optional<SizeType> mStaticEmulatedBatchSize;
+    std::chrono::milliseconds mBatchTimeout;
+    std::atomic<std::chrono::steady_clock::time_point::duration> mBatchDeadline;
    std::atomic<uint64_t> mActiveCount;

 }; // class GptServer
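The members above swap a plain `steady_clock::time_point` plus a raw timeout integer for a `std::chrono::milliseconds` timeout and an atomic deadline stored as a duration since the clock's epoch. A minimal, self-contained sketch of that pattern (simplified names, not the benchmark's `GptServer`) is:

```cpp
#include <atomic>
#include <chrono>

// Minimal sketch of the batch-deadline pattern used above: the deadline is kept
// as a duration since the steady_clock epoch so it can live in a std::atomic
// and be reset/read from different threads without a mutex.
class BatchDeadline
{
public:
    explicit BatchDeadline(std::chrono::milliseconds timeout)
        : mTimeout(timeout)
    {
        reset();
    }

    // Equivalent in spirit to resetBatchDeadline() above.
    void reset()
    {
        mDeadline = (std::chrono::steady_clock::now() + mTimeout).time_since_epoch();
    }

    // True once the timeout has elapsed since the last reset.
    [[nodiscard]] bool expired() const
    {
        return std::chrono::steady_clock::now().time_since_epoch() > mDeadline.load();
    }

private:
    std::chrono::milliseconds mTimeout;
    std::atomic<std::chrono::steady_clock::time_point::duration> mDeadline{};
};
```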
@@ -674,10 +672,9 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
    auto request = std::make_shared<InferenceRequest>(reqId);
    auto const& inputIds = sample.inputIds;
    request->setInputIds(bufferManager.copyFrom(
-        inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kPINNED));
+        inputIds, ITensor::makeShape({static_cast<SizeType>(inputIds.size())}), MemoryType::kCPU));
    auto const requestOutputLen = sample.outputLen;
-    request->setMaxNewTokens(
-        bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kPINNED));
+    request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU));
    request->setBeamWidth(beamWidthTensor);
    if (eosId != nullptr)
    {
@@ -704,14 +701,15 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType const& beamWid
 {
    auto samplingConfig = texec::SamplingConfig{beamWidth};
    auto outputConfig = texec::OutputConfig{false, returnContextLogits, returnGenerationLogits, false};
-    return texec::Request(sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId);
+    return {sample.inputIds, sample.outputLen, streaming, samplingConfig, outputConfig, eosId, padId};
 }

 void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
    std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
-    std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
-    batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
-    bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, int const staticEmulatedTimeoutMs,
+    std::optional<TokenIdType> const& eosId, std::optional<TokenIdType> const& padId,
+    BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
+    std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
+    std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
    bool logIterationData)
 {
    auto const worldConfig = WorldConfig::mpi();
@@ -736,14 +734,14 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
        bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};

    // Load dataset
-    const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
-    const auto numSamples = samples.size();
+    auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
+    auto const numSamples = samples.size();

-    const int maxBeamWidth = beamWidth;
+    int const maxBeamWidth = beamWidth;
    auto recorder = std::make_shared<Recorder>(opCsvFile);
    uint64_t terminateReqId = numSamples + 1;
    auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
-        recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, staticEmulatedTimeoutMs, logIterationData);
+        recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData);

    ITensor::SharedPtr eosIdTensor{
        eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
@@ -761,6 +759,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
    if (worldConfig.getRank() == 0)
    {
        // Warm up
+        gptServer->resetBatchDeadline();
        SizeType reqId = 0;
        for (auto i = 0; i < warmUp; ++i)
        {
@@ -774,6 +773,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType

        // Benchmark
        recorder->initialize();
+        gptServer->resetBatchDeadline();
        for (std::size_t i = 0; i < numSamples; ++i)
        {
            auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
@@ -806,23 +806,19 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
    batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
    bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData)
 {
-    // Check that mpi size is 1 for now
-    auto const worldConfig = WorldConfig::mpi();
-    if (worldConfig.getSize() > 1)
-    {
-        TLLM_THROW("benchmarkExecutor does not yet support mpiSize > 1");
-    }
+    auto const& world = tensorrt_llm::mpi::MpiComm::world();
+    auto worldRank = world.getRank();

    // Load dataset
-    const auto samples = parseWorkloadJson(datasetPath, maxNumSamples);
-    const auto numSamples = samples.size();
+    auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
+    auto const numSamples = samples.size();

    auto recorder = std::make_shared<Recorder>(opCsvFile);

    auto executorServer = std::make_shared<ExecutorServer>(engineDir, modelType, beamWidth, schedulerPolicy,
        benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData);

-    if (worldConfig.getRank() == 0)
+    if (worldRank == 0)
    {
        // Warm up
        {
@@ -849,7 +845,7 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
            delays.push_back(static_cast<int>(samples[i].delay * 1000));
        }

-        bool hasDelay = std::any_of(delays.begin(), delays.end(), [](const auto& delay) { return delay > 0; });
+        bool hasDelay = std::any_of(delays.begin(), delays.end(), [](auto const& delay) { return delay > 0; });
        if (hasDelay && staticEmulatedBatchSize)
        {
            TLLM_THROW("Executor benchmark doesn't support delays with emulated static batch sizes");
@@ -910,9 +906,6 @@ int main(int argc, char* argv[])
    cxxopts::Options options(
        "TensorRT-LLM BatchManager Benchmark", "TensorRT-LLM BatchManager Benchmark for GPT and GPT-like models.");
    options.add_options()("h,help", "Print usage");
-    // TODO(rkobus): remove because unused
-    options.add_options()(
-        "m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
    options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
    options.add_options()(
        "api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("gptManager"));
@@ -929,8 +922,8 @@ int main(int argc, char* argv[])
    options.add_options()(
        "warm_up", "Specify warm up iterations before benchmark starts.", cxxopts::value<int>()->default_value("2"));
    options.add_options()(
-        "eos_id", "Specify the end-of-sequence token id.", cxxopts::value<int>()->default_value("-1"));
-    options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<int>());
+        "eos_id", "Specify the end-of-sequence token id.", cxxopts::value<TokenIdType>()->default_value("-1"));
+    options.add_options()("pad_id", "Specify the padding token id.", cxxopts::value<TokenIdType>());
    options.add_options()("max_tokens_in_paged_kvcache", "Max tokens in paged K-V Cache.", cxxopts::value<int>());
    options.add_options()(
        "kv_cache_free_gpu_mem_fraction", "K-V Cache Free Gpu Mem Fraction.", cxxopts::value<float>());
@@ -949,11 +942,15 @@ int main(int argc, char* argv[])
    options.add_options()("scheduler_policy", "Choose scheduler policy between max_utilization/guaranteed_no_evict.",
        cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));

+    options.add_options()("first_batch_delay",
+        "Delay before submitting the first batch of requests. This can be used to increase the size of the first "
+        "batch.",
+        cxxopts::value<int32_t>());
    options.add_options()("static_emulated_batch_size",
-        "Emulate static batching performance with the provided batch size.", cxxopts::value<int>());
+        "Emulate static batching performance with the provided batch size.", cxxopts::value<SizeType>());
    options.add_options()("static_emulated_timeout",
        "Timeout (ms) before launching a partial batch in emulated static batching mode",
-        cxxopts::value<int>()->default_value("500"));
+        cxxopts::value<int32_t>()->default_value("500"));
    options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
        cxxopts::value<std::string>()->default_value("error"));
    options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
@@ -1042,23 +1039,31 @@ int main(int argc, char* argv[])
    // Argument: Enable return context logits
    bool returnGenerationLogits = result["return_generation_logits"].as<bool>();

-    std::optional<int32_t> padId;
+    std::optional<TokenIdType> padId;
    // Argument: Padding token id
    if (result.count("pad_id"))
    {
-        padId = result["pad_id"].as<int>();
+        padId = result["pad_id"].as<TokenIdType>();
    }

    // Argument: End-of-sentence token id
-    std::optional<int32_t> eosId = result["eos_id"].as<int>();
+    std::optional<TokenIdType> eosId = result["eos_id"].as<TokenIdType>();

-    std::optional<int> staticEmulatedBatchSize;
+    std::optional<std::chrono::milliseconds> batchTimeout;
+    // Argument: first_batch_delay
+    if (result.count("first_batch_delay"))
+    {
+        batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as<int32_t>()};
+    }
+
+    std::optional<SizeType> staticEmulatedBatchSize;
    // Argument: Static emulated batch size
    if (result.count("static_emulated_batch_size"))
    {
-        staticEmulatedBatchSize = result["static_emulated_batch_size"].as<int>();
+        staticEmulatedBatchSize = result["static_emulated_batch_size"].as<SizeType>();
+
+        batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as<int32_t>()};
    }
-    auto const staticEmulatedTimeout = result["static_emulated_timeout"].as<int>();

    // Argument: Scheduler policy
    batch_scheduler::SchedulerPolicy schedulerPolicy;
@@ -1114,10 +1119,10 @@ int main(int argc, char* argv[])
    {
        benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
            maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
-            waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, staticEmulatedTimeout,
+            waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
            logIterationData);
    }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
    {
        TLLM_LOG_ERROR(e.what());
        return 1;
@@ -1131,7 +1136,7 @@ int main(int argc, char* argv[])
            beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep,
            returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData);
    }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
    {
        TLLM_LOG_ERROR(e.what());
        return 1;
@@ -15,7 +15,6 @@
  * limitations under the License.
  */
 #include "tensorrt_llm/common/cudaUtils.h"
-#include "tensorrt_llm/common/memoryUtils.h"
 #include "tensorrt_llm/plugins/api/tllmPlugin.h"
 #include "tensorrt_llm/runtime/gptJsonConfig.h"
 #include "tensorrt_llm/runtime/gptSession.h"
@@ -56,12 +55,11 @@ size_t monitorMemory(std::atomic_bool& done)
    return peakMem;
 }

-void benchmarkGptSession(std::string const& modelName, std::filesystem::path const& dataPath,
-    std::vector<int> const& batchSizes, int beamWidth, std::vector<std::vector<int>> const& inOutLen,
-    std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp, int numRuns, int duration,
-    GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits, bool disableForceMaxTokens)
+void benchmarkGptSession(std::filesystem::path const& dataPath, std::vector<int> const& batchSizes, int beamWidth,
+    std::vector<std::vector<int>> const& inOutLen, std::shared_ptr<nvinfer1::ILogger> const& logger, int warmUp,
+    int numRuns, int duration, GptSession::Config& sessionConfig, bool cudaGraphMode, bool printAllLogits,
+    bool disableForceMaxTokens)
 {
-    std::string modelNameHyphen = modelName;
    std::filesystem::path jsonFileName = dataPath / "config.json";
    auto const json = GptJsonConfig::parse(jsonFileName);
    auto const modelConfig = json.getModelConfig();
@@ -69,7 +67,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
    SizeType deviceCount{0};
    TLLM_CUDA_CHECK(cudaGetDeviceCount(&deviceCount));
    auto const worldConfig = WorldConfig::mpi(deviceCount, json.getTensorParallelism(), json.getPipelineParallelism());
-    auto const enginePath = dataPath / json.engineFilename(worldConfig, modelNameHyphen);
+    auto const enginePath = dataPath / json.engineFilename(worldConfig);
    auto const dtype = modelConfig.getDataType();
    auto const maxNumTokens = modelConfig.getMaxNumTokens();
    auto const useHalf = (dtype == nvinfer1::DataType::kHALF);
@@ -104,7 +102,7 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

    auto& memoryCounter = MemoryCounters::getInstance();
    TLLM_LOG_INFO(memoryCounter.toString());
+    std::atomic_bool done;
    for (auto const batchSize : batchSizes)
    {
        if (inputPacked && maxNumTokens != std::nullopt)
@@ -114,10 +112,11 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
                "benchmark on %d tokens",
                maxNumTokens.value(), maxBatchSize * maxInputLength);
        }
-        std::atomic_bool done = false;
+        done = false;
+        auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
+        size_t peakMem;
        try
        {
-            auto peakMemFuture = std::async(&monitorMemory, std::ref(done));
            TLLM_LOG_INFO(memoryCounter.toString());

            std::vector<SizeType> inputLengthsHost(batchSize, maxInputLength);
@@ -205,7 +204,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con

            TLLM_LOG_INFO(memoryCounter.toString());
            done = true;
-            size_t peakMem = peakMemFuture.get();
+            peakMemFuture.wait();
+            peakMem = peakMemFuture.get();

            printf("Benchmarking done. Iteration: %d, duration: %.2f sec.\n", iterIdx, curDuration / 1000);

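The change above hoists the memory-monitor future out of the `try` block so that every exit path can stop the monitor and then collect its result exactly once (`wait()` followed by `get()`). A stripped-down sketch of the same start/stop/collect pattern, with a stand-in worker instead of `monitorMemory`, might look like this:

```cpp
#include <atomic>
#include <chrono>
#include <cstddef>
#include <future>
#include <thread>

// Stand-in worker: polls until `done` is set, then reports how often it ran.
std::size_t pollUntilDone(std::atomic_bool& done)
{
    std::size_t iterations = 0;
    while (!done)
    {
        ++iterations;
        std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
    return iterations;
}

int main()
{
    std::atomic_bool done{false};
    auto future = std::async(std::launch::async, &pollUntilDone, std::ref(done));

    // ... run the benchmarked work here ...

    done = true;                      // ask the monitor to stop
    future.wait();                    // make sure it has finished
    auto const result = future.get(); // collect its result exactly once
    return result > 0 ? 0 : 1;
}
```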
@@ -275,6 +275,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
            std::size_t found = std::string(e.what()).find("out of memory");
            // We need to kill the memory monitor when OOM.
            done = true;
+            peakMemFuture.wait();
+            peakMem = peakMemFuture.get();

            // Unexpected error; rethrow
            if (found == std::string::npos)
@@ -297,6 +299,8 @@ void benchmarkGptSession(std::string const& modelName, std::filesystem::path con
        {
            // We need to kill memory monitor when any other issue occurs
            done = true;
+            peakMemFuture.wait();
+            peakMem = peakMemFuture.get();
            throw;
        }
    }
@@ -311,8 +315,6 @@ int main(int argc, char* argv[])
    cxxopts::Options options(
        "TensorRT-LLM C++ Runtime Benchmark", "TensorRT-LLM C++ Runtime Benchmark for GPT and GPT-like models.");
    options.add_options()("h,help", "Print usage");
-    options.add_options()(
-        "m,model", "Model name specified for engines.", cxxopts::value<std::string>()->default_value("gpt_350m"));
    options.add_options()("engine_dir", "Directory that store the engines.", cxxopts::value<std::string>());
    options.add_options()("batch_size",
        "Specify batch size(s) you want to benchmark. Multiple batch sizes can be separated by \";\", example: "
@@ -459,11 +461,11 @@ int main(int argc, char* argv[])

    try
    {
-        benchmarkGptSession(result["model"].as<std::string>(), result["engine_dir"].as<std::string>(), batchSizes,
-            beamWidth, inOutLen, logger, result["warm_up"].as<int>(), result["num_runs"].as<int>(),
-            result["duration"].as<int>(), sessionConfig, enableCudaGraph, printAllLogits, disableForceMaxTokens);
+        benchmarkGptSession(result["engine_dir"].as<std::string>(), batchSizes, beamWidth, inOutLen, logger,
+            result["warm_up"].as<int>(), result["num_runs"].as<int>(), result["duration"].as<int>(), sessionConfig,
+            enableCudaGraph, printAllLogits, disableForceMaxTokens);
    }
-    catch (const std::exception& e)
+    catch (std::exception const& e)
    {
        TLLM_LOG_ERROR(e.what());
        return 1;
@@ -86,6 +86,7 @@ class EncDecBuildConfig:
     max_output_len: Optional[int] = None
     builder_opt: Optional[int] = None
     n_mels: Optional[int] = None
+    skip_cross_qkv: bool = False

     def __post_init__(self) -> None:
         assert self.head_size is not None
@@ -89,7 +89,11 @@ class BaseBenchmark(object):
                 (f'Engine world size ({world_size}) != Runtime world size ({self.world_size})')
             # Load config into self
             for key, value in self.config['pretrained_config'].items():
-                setattr(self, key, value)
+                if key == "ssm_cfg":
+                    for ssm_key, ssm_value in value.items():
+                        setattr(self, "mamba_" + ssm_key, ssm_value)
+                else:
+                    setattr(self, key, value)

             self.quant_mode = QuantMode.from_quant_algo(
                 quant_algo=self.quantization['quant_algo'],
@@ -327,9 +327,16 @@ def main(args):

     torch.cuda.empty_cache()
     latencies = []
+    # Disable Host memory monitor when cuda graph is enabled for cuda graph performance.
+    disable_host_mem_monitor = False
+    if args.enable_cuda_graph:
+        logger.warning(
+            'Disable host memory monitor when cuda graph is enabled.')
+        disable_host_mem_monitor = True

     if not disable_mem_monitor:
-        memory_monitor = MemoryMonitor()
+        memory_monitor = MemoryMonitor(
+            disable_host_mem_monitor=disable_host_mem_monitor)
         memory_monitor.start()

     iter_idx = 0
@@ -648,9 +648,12 @@ def build_gpt(args):
                 'tp_size': world_size,
             },
         }
+
         config = PretrainedConfig.from_dict(config)
         tensorrt_llm_model = tensorrt_llm.models.BaichuanForCausalLM(config)
     elif family == "internlm":
+        quant_algo, kv_cache_quant_algo = get_quant_algo(args.quantization)
+
         config = {
             'architecture':
             'LLaMAForCausalLM',
@@ -673,8 +676,10 @@ def build_gpt(args):
             build_config['n_positions'],
             'hidden_act':
             build_config['hidden_act'],
-            'quantization':
-            quant_mode.to_dict(),
+            'quantization': {
+                'quant_algo': quant_algo,
+                'kv_cache_quant_algo': kv_cache_quant_algo
+            },
             'mapping': {
                 'world_size': world_size,
                 'tp_size': world_size
@@ -696,6 +701,7 @@ def build_gpt(args):
                 "has_zero_point": True,
                 "pre_quant_scale": False,
             })
+
         config = PretrainedConfig.from_dict(config)
         tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
     elif family == "qwen":
@@ -1038,6 +1044,7 @@ def enc_dec_build_helper(component, config, args):
             or quant_mode.is_int8_weight_only()),
         quant_mode=quant_mode,
         n_mels=n_mels,
+        skip_cross_qkv=config['skip_cross_qkv'],
     )

     # build engine
@@ -22,7 +22,7 @@ from tensorrt_llm.profiler import (MemUnitType, bytes_to_target_unit,

 class MemoryMonitor:

-    def __init__(self, query_interval=0.1):
+    def __init__(self, query_interval=0.1, disable_host_mem_monitor=False):
         self.query_interval = query_interval  # second(s)
         self.mem_monitor_process = None
         # bytes
@@ -35,6 +35,8 @@ class MemoryMonitor:
         self.signal_event = Event()  # Sending signal to subprocess
         self.peak_mem_queue = Queue()  # Receiving results from subprocess

+        self.disable_host_mem_monitor = disable_host_mem_monitor
+
     def start(self):
         self.mem_monitor_process = Process(target=self._upd_peak_memory_usage,
                                            args=(self.signal_event,
@@ -70,7 +72,10 @@ class MemoryMonitor:
         peak_mem_queue.put((peak_host_used, peak_device_used))

     def get_memory_usage(self):
-        host_used, _, _ = host_memory_info(self.pid)
+        if self.disable_host_mem_monitor:
+            host_used = 0
+        else:
+            host_used, _, _ = host_memory_info(self.pid)
         device_used, _, _ = device_memory_info()
         return host_used, device_used

@@ -36,6 +36,7 @@ option(NVTX_DISABLE "Disable all NVTX features" ON)
 option(WARNING_IS_ERROR "Treat all warnings as errors" OFF)
 option(FAST_BUILD "Skip compiling some kernels to accelerate compiling" OFF)
 option(FAST_MATH "Compiling in fast math mode" OFF)
+option(INDEX_RANGE_CHECK "Compiling with index range checks" OFF)

 if(NVTX_DISABLE)
   add_compile_definitions("NVTX_DISABLE")
@@ -97,6 +98,11 @@ if(FAST_BUILD)
   message(WARNING "Skip some kernels to accelerate compilation")
 endif()

+if(INDEX_RANGE_CHECK)
+  add_compile_definitions("INDEX_RANGE_CHECK")
+  message(WARNING "Check index range to detect OOB accesses")
+endif()
+
 # Determine CUDA version before enabling the language extension
 check_language(CUDA)
 if(CMAKE_CUDA_COMPILER)
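The new `INDEX_RANGE_CHECK` option only adds a compile definition; how the kernels actually consume it is not part of this diff. Purely as an illustration of how such a flag is commonly used, a guarded accessor could look like the following (hypothetical helper, not TensorRT-LLM code):

```cpp
#include <cassert>
#include <cstddef>

// Illustrative only: a bounds check that exists only when the build enables
// the INDEX_RANGE_CHECK compile definition, and compiles away otherwise.
template <typename T>
T& checkedAt(T* data, std::size_t size, std::size_t idx)
{
#ifdef INDEX_RANGE_CHECK
    assert(idx < size && "index out of range");
#else
    (void) size; // unused when range checks are compiled out
#endif
    return data[idx];
}
```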
@@ -162,10 +168,6 @@ message(STATUS "  version: ${CUDAToolkit_VERSION}")
 message(STATUS "  libraries: ${CUDAToolkit_LIBRARY_DIR}")
 message(STATUS "  include path: ${CUDAToolkit_INCLUDE_DIRS}")

-find_library(
-  CUDNN_LIB cudnn
-  HINTS ${CUDNN_ROOT_DIR} ${CUDAToolkit_LIBRARY_DIR}
-  PATH_SUFFIXES lib64 lib lib/x64)
 set(CUBLAS_LIB CUDA::cublas)
 set(CUBLASLT_LIB CUDA::cublasLt)
 set(CUDA_DRV_LIB CUDA::cuda_driver)
@@ -29,9 +29,9 @@ class InferenceRequest;
 class NamedTensor;

 using GetInferenceRequestsCallback = std::function<std::list<std::shared_ptr<InferenceRequest>>(int32_t)>;
-using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, const std::string&)>;
+using SendResponseCallback = std::function<void(uint64_t, std::list<NamedTensor> const&, bool, std::string const&)>;
 using PollStopSignalCallback = std::function<std::unordered_set<uint64_t>()>;
 // json of stats as a string
-using ReturnBatchManagerStatsCallback = std::function<void(const std::string&)>;
+using ReturnBatchManagerStatsCallback = std::function<void(std::string const&)>;

 } // namespace tensorrt_llm::batch_manager
@@ -312,9 +312,9 @@ public:

    [[nodiscard]] std::vector<int64_t> serialize() const;

-    static std::shared_ptr<InferenceRequest> deserialize(const std::vector<int64_t>& packed);
+    static std::shared_ptr<InferenceRequest> deserialize(std::vector<int64_t> const& packed);

-    static std::shared_ptr<InferenceRequest> deserialize(const int64_t* packed_ptr);
+    static std::shared_ptr<InferenceRequest> deserialize(int64_t const* packed_ptr);
 };

 } // namespace tensorrt_llm::batch_manager
@@ -50,6 +50,13 @@ public:
    {
    }

+    bool operator==(KvCacheConfig const& other) const
+    {
+        return maxTokens == other.maxTokens && maxAttentionWindow == other.maxAttentionWindow
+            && sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
+            && enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm;
+    }
+
    std::optional<SizeType> maxTokens;
    std::optional<SizeType> maxAttentionWindow;
    std::optional<SizeType> sinkTokenLength;
@@ -176,6 +176,13 @@ public:
        mNumTokens += n;
    }

+    void removeTokens(SizeType n)
+    {
+        TLLM_CHECK(n <= mNumTokens);
+        TLLM_CHECK(mNumTokens - n >= 0);
+        mNumTokens -= n;
+    }
+
    [[nodiscard]] SizeType getSequenceSlotIdx() const
    {
        return mSeqSlotIdx;
@@ -214,6 +221,14 @@ public:
        }
    }

+    void removeLastBlock()
+    {
+        for (auto& beamBlockIds : mCacheBlockIds)
+        {
+            beamBlockIds.pop_back();
+        }
+    }
+
    void setNumPrepopulatedTokens(std::vector<int> numPrepopulatedTokens)
    {
        mNumPrepopulatedTokens = std::move(numPrepopulatedTokens);
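The new `removeTokens` and `removeLastBlock` helpers give the per-sequence bookkeeping a way to roll back recently added tokens, for example when rewinding the KV cache. As a rough, assumption-laden schematic of how token removal and block removal could fit together under a fixed tokens-per-block layout (not the actual `KVCacheManager` logic):

```cpp
#include <cassert>
#include <vector>

// Schematic only: simplified token/block bookkeeping for one sequence.
class SequenceBlocks
{
public:
    explicit SequenceBlocks(int tokensPerBlock)
        : mTokensPerBlock(tokensPerBlock)
    {
    }

    void addToken(int blockId)
    {
        if (mNumTokens % mTokensPerBlock == 0)
        {
            mBlockIds.push_back(blockId); // first token lands in a fresh block
        }
        ++mNumTokens;
    }

    void rewind(int numTokens)
    {
        assert(numTokens <= mNumTokens);
        mNumTokens -= numTokens;
        // Drop trailing blocks that no longer hold any tokens.
        int const blocksNeeded = (mNumTokens + mTokensPerBlock - 1) / mTokensPerBlock;
        while (static_cast<int>(mBlockIds.size()) > blocksNeeded)
        {
            mBlockIds.pop_back();
        }
    }

private:
    int mTokensPerBlock;
    int mNumTokens{0};
    std::vector<int> mBlockIds;
};
```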
@@ -280,32 +295,40 @@ public:
    //! \brief Simulate freeing all blocks for that sequence to check impact on number of free blocks
    void schedulingReleaseBlocks(GenerationRequest& sequence);

-    [[nodiscard]] SizeType getNumFreeBlocks() const
+    //! \brief Release last block in the sequence
+    void releaseLastBlock(GenerationRequest& sequence);
+
+    [[nodiscard]] SizeType getNumFreeBlocks() const noexcept
    {
        return mFreeBlocks.size();
    }

-    [[nodiscard]] SizeType getNumAllocatedBlocks() const
+    [[nodiscard]] SizeType getNumReusedBlocks() const noexcept
+    {
+        return mReusedBlocks;
+    }
+
+    [[nodiscard]] SizeType getNumAllocatedBlocks() const noexcept
    {
        return getMaxNumBlocks() - getNumFreeBlocks();
    }

-    [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const
+    [[nodiscard]] bool hasFreeBlocks(SizeType numRequired = 1) const noexcept
    {
        return getNumFreeBlocks() >= numRequired;
    }

-    [[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const
+    [[nodiscard]] bool schedulingHasFreeBlocks(SizeType numRequired = 1) const noexcept
    {
        return mSchedulingNumFreeBlocks >= numRequired;
    }

-    [[nodiscard]] SizeType getMaxNumBlocks() const
+    [[nodiscard]] SizeType getMaxNumBlocks() const noexcept
    {
        return static_cast<SizeType>(mAllBlocksByIdx.size());
    }

-    [[nodiscard]] SizeType getTokensPerBlock() const
+    [[nodiscard]] SizeType getTokensPerBlock() const noexcept
    {
        return mTokensPerBlock;
    }
@@ -478,11 +501,15 @@ public:
        return mEnableBlockReuse;
    }

+    void removeToken(SizeType seqSlotIdx);
+    void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
+
 private:
    void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
    void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
    void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
-    void updateNewBlockPointer(const GenerationRequest& seq, SizeType seqSlotIdx, SizeType blockIdx);
+    void updateNewBlockPointer(GenerationRequest const& seq, SizeType seqSlotIdx, SizeType blockIdx);
+    void updateToken(SizeType seqSlotIdx, bool addToken);

 private:
    // Number of elements per one blocks
@@ -474,7 +474,7 @@ public:
         return mDraftTokens->size();
     }
 
-    void setReturnContextLogits(const bool returnContextLogits)
+    void setReturnContextLogits(bool const returnContextLogits)
     {
         mReturnContextLogits = returnContextLogits;
     }
@@ -484,7 +484,7 @@ public:
         return mReturnContextLogits;
     }
 
-    void setReturnGenerationLogits(const bool returnGenerationLogits)
+    void setReturnGenerationLogits(bool const returnGenerationLogits)
     {
         mReturnGenerationLogits = returnGenerationLogits;
     }
@@ -556,6 +556,11 @@ public:
         return mState == REQUEST_STATE_GENERATION_IN_PROGRESS;
     }
 
+    [[nodiscard]] bool isGenerationCompleteState() const noexcept
+    {
+        return mState == REQUEST_STATE_GENERATION_COMPLETE;
+    }
+
     /// To determine whether the context is unchunked. When a context is chunked into only a part, it
     /// is still different from the unchunked state, which indicates the initial status.
     [[nodiscard]] bool isFullContextRequest() const noexcept
@@ -64,7 +64,7 @@ public:
     using TensorPtr = Base::TensorPtr;
 
     NamedTensor(
-        nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, const void* _data = nullptr);
+        nvinfer1::DataType _type, std::vector<int64_t> const& _shape, std::string _name, void const* _data = nullptr);
 
     NamedTensor(TensorPtr _tensor, std::string _name)
         : Base(std::move(_tensor), std::move(_name)){};
@@ -74,6 +74,10 @@ public:
 
     [[nodiscard]] std::vector<int64_t> serialize() const;
 
-    static NamedTensor deserialize(const int64_t* packed);
+    void serialize(int64_t* out, const size_t totalSize) const;
+
+    [[nodiscard]] size_t serializedSize() const;
+
+    static NamedTensor deserialize(int64_t const* packed);
 };
 } // namespace tensorrt_llm::batch_manager
@@ -50,11 +50,19 @@ public:
 
     explicit TrtGptModelOptionalParams(executor::ExecutorConfig const& executorConfig)
         : TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()),
-            executorConfig.getEnableTrtOverlap(), executorConfig.getDeviceIds(), executorConfig.getNormalizeLogProbs(),
-            executorConfig.getEnableChunkedContext())
+            executorConfig.getEnableTrtOverlap(),
+            executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
+            executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext())
     {
     }
 
+    bool operator==(TrtGptModelOptionalParams const& other) const
+    {
+        return kvCacheConfig == other.kvCacheConfig && enableTrtOverlap == other.enableTrtOverlap
+            && deviceIds == other.deviceIds && normalizeLogProbs == other.normalizeLogProbs
+            && enableChunkedContext == other.enableChunkedContext && decodingMode == other.decodingMode;
+    }
+
     KvCacheConfig kvCacheConfig;
 
     bool enableTrtOverlap;
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include "tensorrt_llm/common/assert.h"
 #include <cstdint>
 
 namespace tensorrt_llm::common
@@ -80,11 +81,17 @@ public:
 
     [[nodiscard]] reference operator[](size_type index)
     {
+#ifdef INDEX_RANGE_CHECK
+        TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
+#endif
         return mData[index];
     }
 
     [[nodiscard]] const_reference operator[](size_type index) const
     {
+#ifdef INDEX_RANGE_CHECK
+        TLLM_CHECK_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);
+#endif
         return mData[index];
     }
 
@@ -56,6 +56,7 @@ enum class MpiType
     kUINT64,
     kFP8,
     kBF16,
+    kCHAR,
 };
 
 //! \brief For converting a C++ data type to a TensorRT data type.
@@ -133,6 +134,12 @@ struct MpiTypeConverter<std::uint64_t>
     static constexpr auto value = MpiType::kUINT64;
 };
 
+template <>
+struct MpiTypeConverter<char>
+{
+    static constexpr auto value = MpiType::kCHAR;
+};
+
 #ifdef ENABLE_FP8
 template <>
 struct MpiTypeConverter<__nv_fp8_e4m3>
@@ -202,8 +209,8 @@ public:
     ~MpiComm() noexcept;
 
     // no copy
-    MpiComm(const MpiComm&) = delete;
-    MpiComm& operator=(const MpiComm&) = delete;
+    MpiComm(MpiComm const&) = delete;
+    MpiComm& operator=(MpiComm const&) = delete;
 
     // move
     MpiComm(MpiComm&&) noexcept;
@@ -253,7 +260,24 @@ public:
         }
     }
 
-    void bcast(std::vector<int64_t>& packed, int root) const;
+    template <typename T>
+    void bcast(std::vector<T>& vec, int root) const
+    {
+        auto const rank = getRank();
+        auto vecSize = (rank == root) ? static_cast<int64_t>(vec.size()) : int64_t(0);
+        bcast(&vecSize, 1, MpiType::kINT64, root);
+        vec.resize(vecSize);
+
+        if constexpr (std::is_fundamental_v<std::remove_cv_t<T>>)
+        {
+            auto const mpiType = MpiTypeConverter<std::remove_cv_t<T>>::value;
+            bcast(vec.data(), vec.size(), mpiType, root);
+        }
+        else
+        {
+            bcast(vec.data(), vec.size() * sizeof(T), MpiType::kBYTE, root);
+        }
+    }
 
     void send(void const* buffer, std::size_t size, MpiType dtype, int dest, int tag) const;
 
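For reference, a minimal usage sketch of the templated bcast() added above (illustrative only; the comm object, MyPodStruct, and the helper calls are assumptions, not part of this change):

    // Root rank fills the vectors; other ranks receive them resized automatically.
    std::vector<int64_t> packed;
    std::vector<MyPodStruct> records; // trivially copyable struct, broadcast as raw bytes
    if (comm.getRank() == 0)
    {
        packed = buildPackedRequest(); // hypothetical helper
        records = buildRecords();      // hypothetical helper
    }
    comm.bcast(packed, /*root=*/0);  // fundamental element type: uses MpiTypeConverter<int64_t>::value
    comm.bcast(records, /*root=*/0); // non-fundamental element type: falls back to MpiType::kBYTE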
@@ -297,8 +321,8 @@ public:
         }
     }
 
-    void allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
-    void allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const;
+    void allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const;
+    void allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const;
     void barrier() const;
 
     void mprobe(int source, int tag, MPI_Message* msg, MPI_Status* status) const;
@@ -34,6 +34,9 @@
 namespace tensorrt_llm::executor
 {
 
+class Model;
+class Serialization;
+
 /// @brief Sampling configuration
 class SamplingConfig
 {
@@ -51,6 +54,8 @@ public:
 
     ~SamplingConfig();
 
+    bool operator==(SamplingConfig const& other) const;
+
     [[nodiscard]] SizeType getBeamWidth() const;
     [[nodiscard]] std::optional<SizeType> getTopK() const;
     [[nodiscard]] std::optional<FloatType> getTopP() const;
@@ -68,6 +73,7 @@ public:
     [[nodiscard]] std::optional<SizeType> getEarlyStopping() const;
 
 private:
+    friend class Serialization;
     SizeType mBeamWidth;
     std::optional<SizeType> mTopK;
     std::optional<FloatType> mTopP;
@@ -86,12 +92,16 @@ private:
 };
 
 /// @brief Configuration that controls the outputs of a Result
-struct OutputConfig
+class OutputConfig
 {
-    bool returnLogProbs{false};
-    bool returnContextLogits{false};
-    bool returnGenerationLogits{false};
-    bool excludeInputFromOutput{false};
+public:
+    OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
+        bool excludeInputFromOutput = false);
+
+    bool returnLogProbs;
+    bool returnContextLogits;
+    bool returnGenerationLogits;
+    bool excludeInputFromOutput;
 };
 
 /// @brief Configuration for speculative decoding. Allows to include draft tokens, draft logits and specify acceptance
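Since OutputConfig is now a class with an explicit constructor rather than an aggregate, call sites construct it positionally; a minimal sketch (illustrative only, parameter names per the declaration above):

    using tensorrt_llm::executor::OutputConfig;
    // Request log probs and generation logits, and drop the prompt tokens from the output.
    OutputConfig outputConfig(/*returnLogProbs=*/true, /*returnContextLogits=*/false,
        /*returnGenerationLogits=*/true, /*excludeInputFromOutput=*/true);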
@@ -109,6 +119,7 @@ public:
     [[nodiscard]] std::optional<FloatType> getAcceptanceThreshold() const;
 
 private:
+    friend class Serialization;
     VecTokens mTokens;
     std::optional<Tensor> mLogits;
     std::optional<FloatType> mAcceptanceThreshold;
@@ -128,6 +139,7 @@ public:
     [[nodiscard]] Tensor getEmbeddingTable() const;
 
 private:
+    friend class Serialization;
     Tensor mEmbeddingTable;
 };
 
@@ -142,6 +154,8 @@ public:
     [[nodiscard]] Tensor getConfig() const;
 
 private:
+    friend class Serialization;
+
     Tensor mWeights;
     Tensor mConfig;
 };
@@ -207,6 +221,7 @@ public:
     void setLoraConfig(LoraConfig loraConfig);
 
 private:
+    friend class Serialization;
     class Impl;
     std::unique_ptr<Impl> mImpl;
 };
@@ -298,15 +313,49 @@ private:
 
 SizeType const kDefaultIterStatsMaxIterations = 1000;
 
+/// @brief A configuration class for the parallel execution parameters
+/// Currently only supports commType = CommunicationType::kMPI
+class ParallelConfig
+{
+public:
+    /// @brief Constructor
+    /// @param commType The communication type. See CommunicationType.
+    /// @param commMode The communication mode. See CommunicationMode.
+    /// @param deviceIds The IDs of the GPUs involved in the execution of the model
+    /// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the
+    /// model. The first participant is considered to be the leader.
+    ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
+        CommunicationMode commMode = CommunicationMode::kLEADER,
+        std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
+        std::optional<std::vector<SizeType>> participantIds = std::nullopt);
+    ~ParallelConfig();
+
+    [[nodiscard]] CommunicationType getCommunicationType() const;
+    [[nodiscard]] CommunicationMode getCommunicationMode() const;
+    [[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
+    [[nodiscard]] std::optional<std::vector<SizeType>> getParticipantIds() const;
+
+    void setCommunicationType(CommunicationType type);
+    void setCommunicationMode(CommunicationMode mode);
+    void setDeviceIds(std::vector<SizeType> deviceIds);
+    void setParticipantIds(std::vector<SizeType> participantIds);
+
+private:
+    CommunicationType mCommType;
+    CommunicationMode mCommMode;
+    std::optional<std::vector<SizeType>> mDeviceIds;
+    std::optional<std::vector<SizeType>> mParticipantIds;
+};
+
 /// @brief Configuration class for the model executor
 class ExecutorConfig
 {
 public:
     ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
         KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
-        bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
-        SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
-        BatchingType batchingType = BatchingType::kINFLIGHT);
+        bool enableTrtOverlap = false, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
+        BatchingType batchingType = BatchingType::kINFLIGHT,
+        std::optional<ParallelConfig> parallelConfig = std::nullopt);
 
     [[nodiscard]] SizeType getMaxBeamWidth() const;
     [[nodiscard]] SchedulerConfig getSchedulerConfig() const;
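A minimal sketch of how the new ParallelConfig might be wired into ExecutorConfig (illustrative only; the device IDs and the use of setParallelConfig, declared further below, are assumptions about a typical 4-GPU leader-mode setup):

    using namespace tensorrt_llm::executor;
    ParallelConfig parallelConfig(CommunicationType::kMPI, CommunicationMode::kLEADER,
        std::vector<SizeType>{0, 1, 2, 3}); // deviceIds; participantIds left unset
    ExecutorConfig executorConfig(/*maxBeamWidth=*/1);
    executorConfig.setParallelConfig(parallelConfig);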
@@ -314,9 +363,9 @@ public:
     [[nodiscard]] bool getEnableChunkedContext() const;
     [[nodiscard]] bool getNormalizeLogProbs() const;
     [[nodiscard]] bool getEnableTrtOverlap() const;
-    [[nodiscard]] std::optional<std::vector<SizeType>> getDeviceIds() const;
     [[nodiscard]] SizeType getIterStatsMaxIterations() const;
     [[nodiscard]] BatchingType getBatchingType() const;
+    [[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
 
     void setMaxBeamWidth(SizeType maxBeamWidth);
     void setSchedulerConfig(SchedulerConfig schedulerConfig);
@@ -324,9 +373,9 @@ public:
     void setEnableChunkedContext(bool enableChunkedContext);
     void setNormalizeLogProbs(bool normalizeLogProbs);
     void setEnableTrtOverlap(bool enableTrtOverlap);
-    void setDeviceIds(std::optional<std::vector<SizeType>> deviceIds);
     void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
     void setBatchingType(BatchingType batchingType);
+    void setParallelConfig(ParallelConfig parallelConfig);
 
 private:
     SizeType mMaxBeamWidth;
@@ -335,24 +384,11 @@ private:
     bool mEnableChunkedContext;
     bool mNormalizeLogProbs;
     bool mEnableTrtOverlap;
-    std::optional<std::vector<SizeType>> mDeviceIds;
     SizeType mIterStatsMaxIterations;
     BatchingType mBatchingType;
+    std::optional<ParallelConfig> mParallelConfig;
 };
 
-/// TODO:
-/// @brief A class to identify processes involved in the execution of a model
-/// Currently only supports MPI communication
-class Communicator
-{
-public:
-    Communicator(CommunicatorType commType, CommMode mode, SizeType currentId, std::vector<SizeType> const& commIds,
-        std::optional<SizeType> orchestratorId){};
-    ~Communicator() = default;
-};
-
-class Model;
-
 /// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
 class Executor
 {
@@ -364,14 +400,12 @@ public:
     /// @param modelType The type of model
     /// @param executorConfig The configuration for the executor
     /// @param comm An optional inter-process communicator configuration
-    Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig,
-        std::optional<Communicator> comm = std::nullopt);
+    Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig);
 
     Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
-        ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
+        ExecutorConfig executorConfig);
 
-    Executor(
-        std::shared_ptr<Model> model, ExecutorConfig executorConfig, std::optional<Communicator> comm = std::nullopt);
+    Executor(std::shared_ptr<Model> model, ExecutorConfig executorConfig);
 
     ~Executor();
 
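With the std::optional<Communicator> parameter gone, an executor is now built from the model plus ExecutorConfig alone, and parallelism comes in through ExecutorConfig's ParallelConfig; a minimal sketch (illustrative only; the engine path and the ModelType enumerator are assumptions):

    using namespace tensorrt_llm::executor;
    Executor executor(std::filesystem::path{"/path/to/engine_dir"}, ModelType::kDECODER_ONLY, executorConfig);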
@@ -180,11 +180,11 @@ public:
 
     ~Tensor() = default;
 
-    Tensor(const Tensor& other) noexcept = default;
+    Tensor(Tensor const& other) noexcept = default;
 
     Tensor(Tensor&& other) noexcept = default;
 
-    Tensor& operator=(const Tensor& other) noexcept = default;
+    Tensor& operator=(Tensor const& other) noexcept = default;
 
     Tensor& operator=(Tensor&& other) noexcept = default;
 
@@ -267,6 +267,7 @@ private:
 
     friend std::shared_ptr<runtime::ITensor> const& detail::toITensor(Tensor const& tensor);
     friend Tensor detail::ofITensor(std::shared_ptr<runtime::ITensor> tensor);
+    friend class Serialization;
 };
 
 } // namespace tensorrt_llm::executor
@@ -155,21 +155,16 @@ enum class SchedulerPolicy
     kGUARANTEED_NO_EVICT = 1,
 };
 
-enum class CommunicatorType
+enum class CommunicationType
 {
     kMPI = 0
 };
 
-enum class CommMode
+enum class CommunicationMode
 {
-    kLEADER,       // With the leader mode, only the leader will be returning from the executor constructor and
-                   // therefore only the leader can enqueue requests and get responses
-    kORCHESTRATOR, // With the orchestrator mode, only the orchestrator will be returning from the executor constructor
-                   // and therefore only the leader can enqueue requests and get responses The orchestrator doesn't
-                   // participate in the computations
-    kALL,          // With the ALL mode, all participants are expected to make the same calls to the executor API
-                   // So they all need to send the same requests
-                   // Responses will be the same for all participants
+    kLEADER, // With the leader mode, only the leader can enqueue requests. The requests will be
+             // broadcasted to the workers. All participants can get response via awaitResponses. The leader is the
+             // first participant in the provided participant IDS, or 0 if participant ID is not provided
 };
 
 } // namespace tensorrt_llm::executor
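Per the kLEADER comment above, only the leader enqueues requests while any participant may await responses; a minimal call-pattern sketch (illustrative only; the rank check and the request object are assumptions):

    // Every participant constructs the executor; only the leader feeds it.
    if (myRank == leaderRank)
    {
        auto requestId = executor.enqueueRequest(request);
    }
    auto responses = executor.awaitResponses(); // valid on any participant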
@@ -81,6 +81,11 @@ public:
 
     using UnderlyingType = uint8_t;
 
+    bool operator==(DecodingMode const& other) const
+    {
+        return mState == other.mState;
+    }
+
 private:
     constexpr DecodingMode(UnderlyingType state)
         : mState(state)
@@ -17,10 +17,13 @@
 #pragma once
 
 #include "tensorrt_llm/runtime/bufferManager.h"
+#include "tensorrt_llm/runtime/cudaStream.h"
 #include "tensorrt_llm/runtime/decodingInput.h"
 #include "tensorrt_llm/runtime/decodingMode.h"
 #include "tensorrt_llm/runtime/decodingOutput.h"
+#include "tensorrt_llm/runtime/gptModelConfig.h"
 #include "tensorrt_llm/runtime/samplingConfig.h"
+#include "tensorrt_llm/runtime/worldConfig.h"
 #include <curand_kernel.h>
 
 #include <memory>
@@ -59,7 +62,7 @@ public:
         DecodingInput const& decodingInput, BufferManager const& manager)
         = 0;
 
-    virtual const SamplingConfig& getSamplingConfig() = 0;
+    virtual SamplingConfig const& getSamplingConfig() = 0;
 
     static void acceptDraftTokensByIds(ITensor const& targetTokenIds, ITensor const& draftTokenIds,
         ITensor const& contextLengths, ITensor const& numDraftTokens, ITensor& sequenceLengths,
@@ -71,6 +74,11 @@ public:
         SizeType vocabSize, SizeType vocabSizePadded, bool useRandomAcceptThreshold, float randomAcceptThreshold,
         curandState_t* curandState, BufferManager::CudaStreamPtr const& stream);
 
+    static void updateKVCacheBasedOnAcceptedTokens(ITensor const& acceptedOffsets, ITensor const& packedAcceptedIds,
+        ITensor const& pointerArray, ITensor const& pastKeyValueLengths, GptModelConfig const& modelConfig,
+        WorldConfig const& worldConfig, BufferManager::CudaStreamPtr stream, SizeType rewindDraftTokenCount,
+        SizeType maxAttentionWindow, SizeType maxBlocksPerSeq, nvinfer1::DataType dtype);
+
     static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
         size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
         BufferManager::CudaStreamPtr const& stream);
@@ -97,7 +105,7 @@ public:
     void gatherTree(ITensor& finalOutputIds, DecodingOutput const& decodingOutput, DecodingInput const& decodingInput,
         BufferManager const& manager) override;
 
-    const SamplingConfig& getSamplingConfig() override
+    SamplingConfig const& getSamplingConfig() override
     {
         return mSamplingConfig;
     }
@@ -153,6 +153,18 @@ public:
         return mFinishedSum;
     }
 
+    //! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
+    [[nodiscard]] TensorPtr getNextDraftTokens() const override
+    {
+        return mNextDraftTokens;
+    }
+
+    //! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
+    [[nodiscard]] TensorPtr getNextDraftTokenLengths() const override
+    {
+        return mNextDraftTokenLengths;
+    }
+
 private:
     //! @brief Gather final beam search results for request `batchIdx`.
     [[nodiscard]] CudaEvent postProcessRequest(SizeType batchIdx) const;
@@ -204,6 +216,8 @@ private:
     TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
     TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
     TensorPtr mTargetLogitsPtrs;       // [maxBatchSize], float*, pointers to target logits, pinned
+    TensorPtr mNextDraftTokens;
+    TensorPtr mNextDraftTokenLengths;
     SizeType mMaxSequenceLength{};
     SizeType mMaxAttentionWindow{};
     SizeType mSinkTokenLength{};
@@ -46,15 +46,10 @@ public:
         , endId{endId}
         , computeCumLogProbs(false)
         , computeLogProbs(false)
+        , generatedTokensPerStep(1)
     {
     }
 
-    // the number of tokens generated per step
-    SizeType generatedTokensPerStep() const
-    {
-        return draftTokens ? draftTokens->getSize() + 1 : 1;
-    }
-
     // mandatory parameters
     ConstTensorPtr ids; // [inputSeqLen], the input sequence of token ids, on gpu
     SizeType inputLen;  // the input length without draft tokens
@@ -71,6 +66,7 @@ public:
 
     bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
     bool computeLogProbs;    // boolean that controls if cumLogProbs should be computed for that request
+    SizeType generatedTokensPerStep;
 };
 
 class Input
@@ -184,6 +180,12 @@ public:
         std::vector<SamplingConfig> const& samplingConfigs)
         = 0;
 
+    //! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
+    virtual TensorPtr getNextDraftTokens() const = 0;
+
+    //! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
+    virtual TensorPtr getNextDraftTokenLengths() const = 0;
+
 protected:
     IGptDecoderBatch() = default;
 };
@@ -36,7 +36,7 @@ public:
     IpcMemory(WorldConfig const& worldConfig, std::size_t bufferSize);
     ~IpcMemory();
 
-    [[nodiscard]] const std::vector<void*>& getCommPtrsTensor() const
+    [[nodiscard]] std::vector<void*> const& getCommPtrsTensor() const
     {
         return mCommPtrs;
     }
@@ -67,7 +67,7 @@ public:
     // Fill the tasks tensor for the batch using the provided tasksHost
     // Function assumes that the first numContextRequests requests in the batch are context requests
     void fillTasksTensor(TensorPtr tasksHost, const SizeType batchSize, const SizeType numContextRequests,
-        const std::vector<SizeType>& reqBeamWidths, const std::vector<SizeType>& reqPromptLengths,
+        std::vector<SizeType> const& reqBeamWidths, std::vector<SizeType> const& reqPromptLengths,
         BufferManager const& manager, bool packedInput);
 };
 
@@ -43,7 +43,7 @@ private:
         auto const hasValues = accessor(0).has_value();
         for (size_t ci = 0; ci < configs.size(); ++ci)
         {
-            const auto& configValue = accessor(ci);
+            auto const& configValue = accessor(ci);
             TLLM_CHECK(hasValues == configValue.has_value());
             if (hasValues)
             {
@@ -188,7 +188,6 @@ endif()
 set(TRTLLM_LINK_LIBS
     ${CUBLAS_LIB}
     ${CUBLASLT_LIB}
-    ${CUDNN_LIB}
     ${CMAKE_DL_LIBS}
     ${MPI_C_LIBRARIES}
     ${NCCL_LIB}
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0ecc134ad10a54b2953c772e72db2f71e84130d5736087b033e9e5b78594db6d
-size 2113376
+oid sha256:c56ee13bb109917ab10df168ca15e6057436df1cd8b64a4268c6e7aae78a5ad8
+size 2126310
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:9aa3f3d7f8313c099df8e9bd4c9707922a4f1c4025c4c99986acf6df781738c7
-size 2128450
+oid sha256:339532215fa4c16e68ca28ee23d0a0e09c9caefa7bd19b563d2f7b83cad6822e
+size 2142070
@@ -1,3 +1,3 @@
-add62ff328028bbcded1af694fe758c5 libtensorrt_llm_batch_manager_static.a
-9e8846e200e2aaaeace862741a90c3ab libtensorrt_llm_batch_manager_static.pre_cxx11.a
-230623fa285048a2de5c54c2cc0f364fb9f2c559 commit
+c9c505e2cb6e95b7cfc124c04ab1fcb3 libtensorrt_llm_batch_manager_static.a
+2f5cec5a5b42e0031bc2edc688c1e74b libtensorrt_llm_batch_manager_static.pre_cxx11.a
+741fb083cc42933439ae54557b177b6d7064da4f commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:7b25de974b6ca5f0dcb279f16f38199167d1efc35c01770d3234bec2dfb5dc86
-size 2097848
+oid sha256:a4060f2d60472850344e5b5799f9ad88390f4ad9c056e3843f3bdbcc046ca68b
+size 2106440
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:5f06cee5ae2bcf393196265cd9a3ef832690cd4c5c53934bbfb169d50ab33c41
-size 2055004
+oid sha256:829f1ed5af0b0d2577e57fd13979706fe0b3636bd6338aac3c34a615f64afedc
+size 2064310
@@ -1,2 +1,2 @@
-bb62a31b8e17dae284d784ba43d5bc02 libtensorrt_llm_batch_manager_static.a
-19327f59c7f5b6235e15b322d5f5a0f4 libtensorrt_llm_batch_manager_static.pre_cxx11.a
+2db5c985786dad3dd16c22ec54af0803 libtensorrt_llm_batch_manager_static.a
+96940249ff7b3ff09754b89ad25fcf9f libtensorrt_llm_batch_manager_static.pre_cxx11.a
@@ -42,11 +42,11 @@ public:
     virtual ~IAllocator() = default;
 
     // no copying
-    IAllocator(const IAllocator&) = delete;
-    IAllocator& operator=(const IAllocator&) = delete;
+    IAllocator(IAllocator const&) = delete;
+    IAllocator& operator=(IAllocator const&) = delete;
 
     template <typename T>
-    [[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, const bool setZero = true)
+    [[nodiscard]] T* reMalloc(T* ptr, size_t sizeBytes, bool const setZero = true)
     {
         TLLM_LOG_TRACE(__PRETTY_FUNCTION__);
         // TODO martinma: why do we need this size extension?
@@ -23,7 +23,7 @@
 
 namespace tensorrt_llm::common
 {
-[[noreturn]] inline void throwRuntimeError(const char* const file, int const line, std::string const& info = "")
+[[noreturn]] inline void throwRuntimeError(char const* const file, int const line, std::string const& info = "")
 {
     throw TllmException(file, line, fmtstr("[TensorRT-LLM][ERROR] Assertion failed: %s", info.c_str()));
 }
@@ -38,8 +38,10 @@ public:
 
 #if defined(_WIN32)
 #define TLLM_LIKELY(x) (__assume((x) == 1), (x))
+#define TLLM_UNLIKELY(x) (__assume((x) == 0), (x))
 #else
 #define TLLM_LIKELY(x) __builtin_expect((x), 1)
+#define TLLM_UNLIKELY(x) __builtin_expect((x), 0)
 #endif
 
 #define TLLM_CHECK(val) \
@@ -61,20 +63,22 @@ public:
 #define TLLM_CHECK_DEBUG(val) \
     do \
     { \
-        if (DebugConfig::isCheckDebugEnabled()) \
+        if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
         { \
             TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
                                                 : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, #val); \
         } \
     } while (0)
 
-#define TLLM_CHECK_DEBUG_WITH_INFO(val, info) \
+#define TLLM_CHECK_DEBUG_WITH_INFO(val, info, ...) \
     do \
     { \
-        if (DebugConfig::isCheckDebugEnabled()) \
+        if (TLLM_UNLIKELY(DebugConfig::isCheckDebugEnabled())) \
         { \
-            TLLM_LIKELY(static_cast<bool>(val)) ? ((void) 0) \
-                                                : tensorrt_llm::common::throwRuntimeError(__FILE__, __LINE__, info); \
+            TLLM_LIKELY(static_cast<bool>(val)) \
+            ? ((void) 0) \
+            : tensorrt_llm::common::throwRuntimeError( \
+                __FILE__, __LINE__, tensorrt_llm::common::fmtstr(info, ##__VA_ARGS__)); \
         } \
     } while (0)
 
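With the new variadic parameter, TLLM_CHECK_DEBUG_WITH_INFO accepts printf-style arguments that are forwarded to fmtstr(); a minimal usage sketch (illustrative only, mirroring the TLLM_CHECK_WITH_INFO call added in the arrayView change above):

    // The message is only formatted when debug checks are enabled and the condition fails.
    TLLM_CHECK_DEBUG_WITH_INFO(index < mSize, "Index %lu is out of bounds [0, %lu)", index, mSize);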
@@ -42,7 +42,7 @@ CublasMMWrapper::~CublasMMWrapper()
     mMutex = nullptr;
 }
 
-CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
+CublasMMWrapper::CublasMMWrapper(CublasMMWrapper const& wrapper)
     : mCublasHandle(wrapper.mCublasHandle)
     , mCublasLtHandle(wrapper.mCublasLtHandle)
     , mStream(wrapper.mStream)
@@ -50,8 +50,8 @@ CublasMMWrapper::CublasMMWrapper(const CublasMMWrapper& wrapper)
 {
 }
 
-void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
-    const int k, const int lda, const int ldb, const int ldc)
+void CublasMMWrapper::createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, int const lda, int const ldb, int const ldc)
 {
     // --------------------------------------
     // Create descriptors for the original matrices
@@ -79,15 +79,15 @@ void CublasMMWrapper::destroyDescriptors()
     mCDesc = NULL;
 }
 
-void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
-    const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc)
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc)
 {
     Gemm(transa, transb, m, n, k, A, lda, B, ldb, C, ldc, 1.0f, 0.0f);
 }
 
-void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
-    const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc,
-    const std::optional<cublasLtMatmulHeuristicResult_t>& heuristic)
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc,
+    std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic)
 {
     if (heuristic)
     {
@@ -102,8 +102,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
     }
 }
 
-void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
-    const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta)
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta)
 {
     bool usingCublasLt = mAType == CUDA_R_16F;
 
@@ -111,9 +111,9 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
         /* usingCublasLt */ usingCublasLt);
 }
 
-void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
-    const void* A, const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
-    const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt)
+void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
+    void const* A, int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
+    cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt)
 {
     half h_alpha = (half) (f_alpha);
     half h_beta = (half) (f_beta);
@@ -126,8 +126,8 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
     int batch_count = 1;
     // fp32 use cublas as default
     // fp16 use cublasLt as default
-    const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
-    const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void*>(&f_beta);
     int workspaceSize = mCublasWorkspace == NULL ? 0 : CUBLAS_WORKSPACE_SIZE;
 
     if (usingCublasLt)
@@ -154,10 +154,10 @@ void CublasMMWrapper::Gemm(cublasOperation_t transa, cublasOperation_t transb, c
     }
 }
 
-void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
-    const int k, const void* A, const int lda, const int64_t strideA, const void* B, const int ldb,
-    const int64_t strideB, void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha,
-    const float f_beta)
+void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, void const* A, int const lda, const int64_t strideA, void const* B, int const ldb,
+    const int64_t strideB, void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha,
+    float const f_beta)
 {
     half h_alpha = (half) f_alpha;
     half h_beta = (half) f_beta;
@@ -165,26 +165,26 @@ void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperati
     std::lock_guard<std::mutex> lock(*mMutex);
 
     int isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
-    const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
-    const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
 
     check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, mAType, lda,
         strideA, B, mBType, ldb, strideB, beta, C, mCType, ldc, strideC, batchCount, mComputeType,
         mAType == CUDA_R_32F ? CUBLAS_GEMM_DEFAULT : CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 }
 
-void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
-    const int k, const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA,
-    const void* B, cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C,
-    cudaDataType_t CType, const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType)
+void CublasMMWrapper::stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA,
+    void const* B, cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C,
+    cudaDataType_t CType, int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType)
 {
     half h_alpha = (half) f_alpha;
     half h_beta = (half) f_beta;
 
     std::lock_guard<std::mutex> lock(*mMutex);
     bool isFp16ComputeType = mComputeType == CUBLAS_COMPUTE_16F ? 1 : 0;
-    const void* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<const void*>(&f_alpha);
-    const void* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<const void*>(&f_beta);
+    void const* alpha = isFp16ComputeType ? reinterpret_cast<void*>(&h_alpha) : reinterpret_cast<void const*>(&f_alpha);
+    void const* beta = isFp16ComputeType ? reinterpret_cast<void*>(&h_beta) : reinterpret_cast<void const*>(&f_beta);
 
     check_cuda_error(cublasGemmStridedBatchedEx(getCublasHandle(), transa, transb, m, n, k, alpha, A, AType, lda,
         strideA, B, BType, ldb, strideB, beta, C, CType, ldc, strideC, batchCount, computeType,
@@ -267,8 +267,8 @@ void CublasMMWrapper::setStream(cudaStream_t stream)
     mStream = stream;
 }
 
-bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n,
-    const int k, const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo)
+bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n,
+    int const k, int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo)
 {
     TLLM_CHECK_WITH_INFO(
         descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
@@ -291,12 +291,12 @@ bool CublasMMWrapper::checkTactic(cublasOperation_t transa, cublasOperation_t tr
 }
 
 std::vector<cublasLtMatmulHeuristicResult_t> CublasMMWrapper::getTactics(cublasOperation_t transa,
-    cublasOperation_t transb, const int m, const int n, const int k, const int lda, const int ldb, const int ldc)
+    cublasOperation_t transb, int const m, int const n, int const k, int const lda, int const ldb, int const ldc)
 {
     TLLM_CHECK_WITH_INFO(
         descriptorsCreated(), "Descriptors are not created! Call createDescriptors before calling this function");
 
-    const auto heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
+    auto const heuristics = getTactics(getCublasLtHandle(), mOperationDesc, mADesc, mBDesc, mCDesc, mCDesc);
 
     sync_check_cuda_error();
 
@ -65,39 +65,39 @@ public:
|
|||||||
|
|
||||||
~CublasMMWrapper();
|
~CublasMMWrapper();
|
||||||
|
|
||||||
CublasMMWrapper(const CublasMMWrapper& wrapper);
|
CublasMMWrapper(CublasMMWrapper const& wrapper);
|
||||||
|
|
||||||
/********************** GEMMs **********************/
|
/********************** GEMMs **********************/
|
||||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||||
const int lda, const void* B, const int ldb, void* C, const int ldc);
|
int const lda, void const* B, int const ldb, void* C, int const ldc);
|
||||||
|
|
||||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||||
const int lda, const void* B, const int ldb, void* C, const int ldc,
|
int const lda, void const* B, int const ldb, void* C, int const ldc,
|
||||||
const std::optional<cublasLtMatmulHeuristicResult_t>& algo);
|
std::optional<cublasLtMatmulHeuristicResult_t> const& algo);
|
||||||
|
|
||||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||||
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta);
|
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta);
|
||||||
|
|
||||||
void Gemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k, const void* A,
|
void Gemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k, void const* A,
|
||||||
const int lda, const void* B, const int ldb, void* C, const int ldc, float f_alpha, float f_beta,
|
int const lda, void const* B, int const ldb, void* C, int const ldc, float f_alpha, float f_beta,
|
||||||
const cublasLtMatmulAlgo_t& algo, bool hasAlgo, bool usingCublasLt);
|
cublasLtMatmulAlgo_t const& algo, bool hasAlgo, bool usingCublasLt);
|
||||||
|
|
||||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||||
const void* A, const int lda, const int64_t strideA, const void* B, const int ldb, const int64_t strideB,
|
void const* A, int const lda, const int64_t strideA, void const* B, int const ldb, const int64_t strideB,
|
||||||
void* C, const int ldc, const int64_t strideC, const int batchCount, const float f_alpha = 1.0f,
|
void* C, int const ldc, const int64_t strideC, int const batchCount, float const f_alpha = 1.0f,
|
||||||
const float f_beta = 0.0f);
|
float const f_beta = 0.0f);
|
||||||
|
|
||||||
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
void stridedBatchedGemm(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||||
const float f_alpha, const void* A, cudaDataType_t AType, const int lda, const int64_t strideA, const void* B,
|
float const f_alpha, void const* A, cudaDataType_t AType, int const lda, const int64_t strideA, void const* B,
|
||||||
cudaDataType_t BType, const int ldb, const int64_t strideB, const float f_beta, void* C, cudaDataType_t CType,
|
cudaDataType_t BType, int const ldb, const int64_t strideB, float const f_beta, void* C, cudaDataType_t CType,
|
||||||
const int ldc, const int64_t strideC, const int batchCount, cudaDataType_t computeType);
|
int const ldc, const int64_t strideC, int const batchCount, cudaDataType_t computeType);
|
||||||
|
|
||||||
/********************** Tactic selection helpers **********************/
|
/********************** Tactic selection helpers **********************/
|
||||||
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
bool checkTactic(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||||
const int lda, const int ldb, const int ldc, const cublasLtMatmulAlgo_t& algo);
|
int const lda, int const ldb, int const ldc, cublasLtMatmulAlgo_t const& algo);
|
||||||
|
|
||||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasOperation_t transa, cublasOperation_t transb,
|
||||||
const int m, const int n, const int k, const int lda, const int ldb, const int ldc);
|
int const m, int const n, int const k, int const lda, int const ldb, int const ldc);
|
||||||
|
|
||||||
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
std::vector<cublasLtMatmulHeuristicResult_t> getTactics(cublasLtHandle_t lightHandle,
|
||||||
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
cublasLtMatmulDesc_t computeDesc, cublasLtMatrixLayout_t Adesc, cublasLtMatrixLayout_t Bdesc,
|
||||||
@ -126,8 +126,8 @@ public:
|
|||||||
|
|
||||||
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
CublasDataType getCublasDataType(cudaDataType_t data_type);
|
||||||
|
|
||||||
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, const int m, const int n, const int k,
|
void createDescriptors(cublasOperation_t transa, cublasOperation_t transb, int const m, int const n, int const k,
|
||||||
const int lda, const int ldb, const int ldc);
|
int const lda, int const ldb, int const ldc);
|
||||||
void destroyDescriptors();
|
void destroyDescriptors();
|
||||||
|
|
||||||
cublasHandle_t getCublasHandle()
|
cublasHandle_t getCublasHandle()
|
||||||
|
|||||||
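Every hunk in this change moves `const` to the right of the type it qualifies ("east const"); the two spellings name the same type, so the rewrite is purely syntactic and changes neither overload resolution nor ABI. A minimal sketch showing the equivalence (the `scale` function here is hypothetical, not part of the wrapper above):

```cpp
#include <type_traits>

// West-const and east-const spellings of the same parameter types.
void scale(const int n, const float* data);
void scale(int const n, float const* data); // redeclaration of the same function, not an overload

static_assert(std::is_same_v<const float*, float const*>,
    "moving const to the right of the type is only a stylistic change");
```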
@@ -43,7 +43,7 @@ CUDADriverWrapper::CUDADriverWrapper()
 handle = dllOpen(CUDA_LIB_NAME);
 TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");

-auto load_sym = [](void* handle, const char* name)
+auto load_sym = [](void* handle, char const* name)
 {
 void* ret = dllGetSym(handle, name);
 return ret;

@@ -69,7 +69,7 @@ CUDADriverWrapper::~CUDADriverWrapper()
 dllClose(handle);
 }

-CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, const char** pStr) const
+CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) const
 {
 return (*_cuGetErrorName)(error, pStr);
 }

@@ -94,7 +94,7 @@ CUresult CUDADriverWrapper::cuLinkDestroy(CUlinkState state) const
 return (*_cuLinkDestroy)(state);
 }

-CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, const void* image) const
+CUresult CUDADriverWrapper::cuModuleLoadData(CUmodule* module, void const* image) const
 {
 return (*_cuModuleLoadData)(module, image);
 }

@@ -105,24 +105,24 @@ CUresult CUDADriverWrapper::cuLinkCreate(
 return (*_cuLinkCreate)(numOptions, options, optionValues, stateOut);
 }

-CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const
+CUresult CUDADriverWrapper::cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const
 {
 return (*_cuModuleGetFunction)(hfunc, hmod, name);
 }

-CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const
+CUresult CUDADriverWrapper::cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const
 {
 return (*_cuModuleGetGlobal)(dptr, bytes, hmod, name);
 }

-CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path,
+CUresult CUDADriverWrapper::cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path,
 unsigned int numOptions, CUjit_option* options, void** optionValues) const
 {
 return (*_cuLinkAddFile)(state, type, path, numOptions, options, optionValues);
 }

 CUresult CUDADriverWrapper::cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size,
-const char* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
+char const* name, unsigned int numOptions, CUjit_option* options, void** optionValues) const
 {
 return (*_cuLinkAddData)(state, type, data, size, name, numOptions, options, optionValues);
 }
@@ -37,7 +37,7 @@ public:

 ~CUDADriverWrapper();

-CUresult cuGetErrorName(CUresult error, const char** pStr) const;
+CUresult cuGetErrorName(CUresult error, char const** pStr) const;

 CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;

@@ -47,19 +47,19 @@ public:

 CUresult cuLinkDestroy(CUlinkState state) const;

-CUresult cuModuleLoadData(CUmodule* module, const void* image) const;
+CUresult cuModuleLoadData(CUmodule* module, void const* image) const;

 CUresult cuLinkCreate(
 unsigned int numOptions, CUjit_option* options, void** optionValues, CUlinkState* stateOut) const;

-CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, const char* name) const;
+CUresult cuModuleGetFunction(CUfunction* hfunc, CUmodule hmod, char const* name) const;

-CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, const char* name) const;
+CUresult cuModuleGetGlobal(CUdeviceptr* dptr, size_t* bytes, CUmodule hmod, char const* name) const;

-CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, const char* path, unsigned int numOptions,
+CUresult cuLinkAddFile(CUlinkState state, CUjitInputType type, char const* path, unsigned int numOptions,
 CUjit_option* options, void** optionValues) const;

-CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, const char* name,
+CUresult cuLinkAddData(CUlinkState state, CUjitInputType type, void* data, size_t size, char const* name,
 unsigned int numOptions, CUjit_option* options, void** optionValues) const;

 CUresult cuLaunchCooperativeKernel(CUfunction f, unsigned int gridDimX, unsigned int gridDimY,

@@ -72,18 +72,18 @@ public:

 private:
 void* handle;
-CUresult (*_cuGetErrorName)(CUresult, const char**);
+CUresult (*_cuGetErrorName)(CUresult, char const**);
 CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
 CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
 CUresult (*_cuModuleUnload)(CUmodule);
 CUresult (*_cuLinkDestroy)(CUlinkState);
 CUresult (*_cuLinkCreate)(unsigned int, CUjit_option*, void**, CUlinkState*);
-CUresult (*_cuModuleLoadData)(CUmodule*, const void*);
+CUresult (*_cuModuleLoadData)(CUmodule*, void const*);
-CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, const char*);
+CUresult (*_cuModuleGetFunction)(CUfunction*, CUmodule, char const*);
-CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, const char*);
+CUresult (*_cuModuleGetGlobal)(CUdeviceptr*, size_t*, CUmodule, char const*);
-CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, const char*, unsigned int, CUjit_option*, void**);
+CUresult (*_cuLinkAddFile)(CUlinkState, CUjitInputType, char const*, unsigned int, CUjit_option*, void**);
 CUresult (*_cuLinkAddData)(
-CUlinkState, CUjitInputType, void*, size_t, const char*, unsigned int, CUjit_option*, void**);
+CUlinkState, CUjitInputType, void*, size_t, char const*, unsigned int, CUjit_option*, void**);
 CUresult (*_cuLaunchCooperativeKernel)(CUfunction, unsigned int, unsigned int, unsigned int, unsigned int,
 unsigned int, unsigned int, unsigned int, CUstream, void**);
 CUresult (*_cuLaunchKernel)(CUfunction f, unsigned int gridDimX, unsigned int gridDimY, unsigned int gridDimZ,

@@ -91,11 +91,11 @@ private:
 CUstream hStream, void** kernelParams, void** extra);
 };

-inline void cuErrCheck_(CUresult stat, const CUDADriverWrapper& wrap, const char* file, int line)
+inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const& wrap, char const* file, int line)
 {
 if (stat != CUDA_SUCCESS)
 {
-const char* msg = nullptr;
+char const* msg = nullptr;
 wrap.cuGetErrorName(stat, &msg);
 fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
 }
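The wrapper above resolves each CUDA driver entry point at runtime through `dllOpen`/`dllGetSym` and stores the result in a function-pointer member, so the library does not need to link against the driver stubs directly. A rough sketch of the same pattern using plain POSIX `dlopen`/`dlsym` (the repository's `dllOpen`/`dllGetSym` helpers are platform wrappers around calls like these; `CUresult` is modeled here as `int` for brevity):

```cpp
#include <dlfcn.h>
#include <cstdio>

// Simplified signature of one driver entry point, resolved by name at runtime.
using cuGetErrorName_t = int (*)(int error, char const** pStr);

int main()
{
    // Open the driver library; "libcuda.so.1" is the usual soname on Linux.
    void* handle = dlopen("libcuda.so.1", RTLD_LAZY);
    if (handle == nullptr)
    {
        std::fprintf(stderr, "driver library not found: %s\n", dlerror());
        return 1;
    }
    auto load_sym = [](void* h, char const* name) { return dlsym(h, name); };
    auto cuGetErrorName = reinterpret_cast<cuGetErrorName_t>(load_sym(handle, "cuGetErrorName"));
    std::printf("cuGetErrorName resolved: %s\n", cuGetErrorName ? "yes" : "no");
    dlclose(handle);
    return 0;
}
```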
@@ -121,16 +121,16 @@ void invokeFakeQuantize(T_OUT* dst, const T_IN* src, const int64_t numel, cudaSt
 }

 template void invokeFakeQuantize<__nv_fp8_e4m3, float, float>(
-float* dst, const float* src, const int64_t numel, cudaStream_t stream);
+float* dst, float const* src, const int64_t numel, cudaStream_t stream);
 template void invokeFakeQuantize<float, float, __nv_fp8_e4m3>(
-float* dst, const __nv_fp8_e4m3* src, const int64_t numel, cudaStream_t stream);
+float* dst, __nv_fp8_e4m3 const* src, const int64_t numel, cudaStream_t stream);
 template void invokeFakeQuantize<__nv_fp8_e4m3, half, half>(
-half* dst, const half* src, const int64_t numel, cudaStream_t stream);
+half* dst, half const* src, const int64_t numel, cudaStream_t stream);
 template void invokeFakeQuantize<__nv_fp8_e4m3, __nv_bfloat16, __nv_bfloat16>(
-__nv_bfloat16* dst, const __nv_bfloat16* src, const int64_t numel, cudaStream_t stream);
+__nv_bfloat16* dst, __nv_bfloat16 const* src, const int64_t numel, cudaStream_t stream);

 template void invokeFakeQuantize<float, half, float>(
-half* dst, const float* src, const int64_t numel, cudaStream_t stream);
+half* dst, float const* src, const int64_t numel, cudaStream_t stream);

 __device__ float atomicMaxExtd(float* address, float val)
 {

@@ -146,7 +146,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 static_assert(std::is_same_v<T, half> | std::is_same_v<T, __nv_bfloat16>, "T needs to be either half or bfloat16");
 // The address in 64 bits.
-uint64_t address_u64 = reinterpret_cast<const uint64_t&>(address);
+uint64_t address_u64 = reinterpret_cast<uint64_t const&>(address);

 // Pack the input value into 32 bits.
 union

@@ -155,7 +155,7 @@ inline __device__ T atomicMaxExtdV2(T* address, T val)
 uint16_t u[2];
 } old, tmp = {};

-const int loc = (address_u64 & 0x2) >> 1;
+int const loc = (address_u64 & 0x2) >> 1;
 tmp.v[loc] = val;

 // 4B aligned pointer.

@@ -223,7 +223,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
 auto val = fabs(static_cast<float>(weights[i]));
 max = max > val ? max : val;
 }
-const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
+auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 if constexpr (std::is_same_v<T_S, float>)
 {

@@ -231,7 +231,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
 }
 else
 {
-const auto address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
+auto const address_u64 = reinterpret_cast<uint64_t>(quant_ptr + col);
 if ((col == 0 && address_u64 % 4 != 0) || (col == n - 1 && address_u64 % 4 == 0))
 atomicMaxExtd(quant_ptr + col, scale);
 else

@@ -244,7 +244,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
 }
 else if (QUANTIZE_MODE == QuantizeMode::PER_TOKEN)
 {
-const auto nrows = size / n;
+auto const nrows = size / n;
 for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
 {
 float max = 0.f;

@@ -256,7 +256,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
 max = blockReduceMax<float>(max);
 if (threadIdx.x == 0)
 {
-const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
+auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
 quant_ptr[row] = scale;
 }
 }

@@ -272,7 +272,7 @@ __global__ void computeFP8QuantizeScale(T_S* quant_ptr, const T_W* weights, cons
 max = blockReduceMax<float>(max);
 if (threadIdx.x == 0)
 {
-const auto scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
+auto const scale = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
 atomicMaxExtd(quant_ptr, scale);
 }
 }

@@ -326,19 +326,19 @@ __global__ void dynamicQuantizeMatrixPerToken(
 extern __shared__ __align__(sizeof(float)) char _shmem[];
 T_IN* shmem = reinterpret_cast<T_IN*>(_shmem);
 constexpr float min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
-const auto nrows = numel / lda;
+auto const nrows = numel / lda;
 for (int64_t row = blockIdx.x; row < nrows; row += gridDim.x)
 {
 float max = 0.f;
 for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
 {
-const auto in = input[row * lda + i];
+auto const in = input[row * lda + i];
 shmem[i] = in;
 auto val = fabs(static_cast<float>(in));
 max = max > val ? max : val;
 }
 max = blockAllReduceMax<float>(max); // __syncthreads() called so we can read shmem
-const auto s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
+auto const s = (T_S) std::max(max / FP8_E4M3_MAX, min_scaling_factor);
 for (int64_t i = threadIdx.x; i < lda; i += blockDim.x)
 {
 // true means we are quantizing

@@ -359,7 +359,7 @@ void invokeComputeScalesAndQuantizeMatrix(T_OUT* output, T_S* quant_ptr, const T
 {
 dim3 grid(numel / lda);
 bool use_shmem = true;
-const auto shmem_size = lda * sizeof(T_IN);
+auto const shmem_size = lda * sizeof(T_IN);
 if (shmem_size >= (48 << 10))
 {
 cudaError_t ret = cudaFuncSetAttribute(dynamicQuantizeMatrixPerToken<T_OUT, T_S, T_IN>,
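`computeFP8QuantizeScale` and `dynamicQuantizeMatrixPerToken` both derive a scale as the maximum absolute value divided by the largest representable FP8 E4M3 value, clamped from below by `min_scaling_factor` so that all-zero rows never produce a zero scale. A host-side sketch of the per-token reduction only, assuming `kFp8E4m3Max = 448` (the largest finite E4M3 value) stands in for the kernel's `FP8_E4M3_MAX`; the actual quantization step that consumes these scales is not reproduced here:

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Per-row (per-token) FP8 scale: max(|x|) / FP8_MAX, clamped from below.
std::vector<float> computePerTokenScales(std::vector<float> const& x, size_t lda)
{
    constexpr float kFp8E4m3Max = 448.0f;
    constexpr float kMinScalingFactor = 1.0f / (kFp8E4m3Max * 512.0f); // same clamp as the kernel

    size_t const nrows = x.size() / lda;
    std::vector<float> scales(nrows);
    for (size_t row = 0; row < nrows; ++row)
    {
        float maxAbs = 0.0f;
        for (size_t i = 0; i < lda; ++i)
        {
            maxAbs = std::max(maxAbs, std::fabs(x[row * lda + i]));
        }
        scales[row] = std::max(maxAbs / kFp8E4m3Max, kMinScalingFactor);
    }
    return scales;
}
```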
@@ -181,37 +181,37 @@ struct PackType<__nv_fp8_e4m3, 8>
 };
 #endif

-__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, const __nv_fp8x4_e4m3* in)
+__inline__ __device__ void fp8x4_e4m3_to_bfloat2(__nv_bfloat162* out1, __nv_bfloat162* out2, __nv_fp8x4_e4m3 const* in)
 {
-const char4 tmp_val = reinterpret_cast<const char4*>(in)[0];
+const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
-*out1 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
+*out1 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
-*out2 = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0],
+*out2 = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
 }

-__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(const __nv_fp8x2_e4m3* in)
+__inline__ __device__ __nv_bfloat162 fp8x2_e4m3_to_bfloat2(__nv_fp8x2_e4m3 const* in)
 {
-const char2 tmp_val = reinterpret_cast<const char2*>(in)[0];
+const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
-__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
+__nv_bfloat162 out = __nv_bfloat162((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
 return out;
 }

-__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, const __nv_fp8x4_e4m3* in)
+__inline__ __device__ void fp8x4_e4m3_to_half2(half2* out1, half2* out2, __nv_fp8x4_e4m3 const* in)
 {
-const char4 tmp_val = reinterpret_cast<const char4*>(in)[0];
+const char4 tmp_val = reinterpret_cast<char4 const*>(in)[0];
-*out1 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
+*out1 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
-*out2 = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.z)[0],
+*out2 = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.z)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.w)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.w)[0]);
 }

-__inline__ __device__ half2 fp8x2_e4m3_to_half2(const __nv_fp8x2_e4m3* in)
+__inline__ __device__ half2 fp8x2_e4m3_to_half2(__nv_fp8x2_e4m3 const* in)
 {
-const char2 tmp_val = reinterpret_cast<const char2*>(in)[0];
+const char2 tmp_val = reinterpret_cast<char2 const*>(in)[0];
-half2 out = half2((float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.x)[0],
+half2 out = half2((float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.x)[0],
-(float) reinterpret_cast<const __nv_fp8_e4m3*>(&tmp_val.y)[0]);
+(float) reinterpret_cast<__nv_fp8_e4m3 const*>(&tmp_val.y)[0]);
 return out;
 }

@@ -32,14 +32,14 @@ namespace common
 {

 template <typename T>
-inline __device__ T ldg(const T* val)
+inline __device__ T ldg(T const* val)
 {
 return __ldg(val);
 }

 #if ENABLE_BF16
 template <>
-inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
+inline __device__ __nv_bfloat162 ldg(__nv_bfloat162 const* val)
 {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 return val[0];

@@ -49,7 +49,7 @@ inline __device__ __nv_bfloat162 ldg(const __nv_bfloat162* val)
 }

 template <>
-inline __device__ __nv_bfloat16 ldg(const __nv_bfloat16* val)
+inline __device__ __nv_bfloat16 ldg(__nv_bfloat16 const* val)
 {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 return val[0];
@@ -81,12 +81,12 @@ enum class OperationType
 };

 /* **************************** debug tools ********************************* */
-static const char* _cudaGetErrorEnum(cudaError_t error)
+static char const* _cudaGetErrorEnum(cudaError_t error)
 {
 return cudaGetErrorString(error);
 }

-static const char* _cudaGetErrorEnum(cublasStatus_t error)
+static char const* _cudaGetErrorEnum(cublasStatus_t error)
 {
 switch (error)
 {

@@ -114,7 +114,7 @@ static const char* _cudaGetErrorEnum(cublasStatus_t error)
 }

 template <typename T>
-void check(T result, char const* const func, const char* const file, int const line)
+void check(T result, char const* const func, char const* const file, int const line)
 {
 if (result)
 {

@@ -133,7 +133,7 @@ inline bool isCudaLaunchBlocking()

 if (firstCall)
 {
-const char* env = std::getenv("CUDA_LAUNCH_BLOCKING");
+char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
 result = env != nullptr && std::string(env) == "1";
 firstCall = false;
 }

@@ -141,12 +141,12 @@ inline bool isCudaLaunchBlocking()
 return result;
 }

-inline void syncAndCheck(const char* const file, int const line)
+inline void syncAndCheck(char const* const file, int const line)
 {
 #ifndef NDEBUG
-const bool checkError = true;
+bool const checkError = true;
 #else
-const bool checkError = isCudaLaunchBlocking();
+bool const checkError = isCudaLaunchBlocking();
 #endif

 if (checkError)

@@ -279,7 +279,7 @@ inline int getDeviceCount()

 /// Get the memory info
 /// \return The free and total amount of memory in bytes
-inline std::tuple<size_t, size_t> getDeviceMemoryInfo(const bool useUvm)
+inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
 {
 if (useUvm)
 {

@@ -351,7 +351,7 @@ auto constexpr ceilDiv(T numerator, U denominator)
 }

 template <typename T>
-void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string name = "")
+void printAbsMean(T const* buf, uint64_t size, cudaStream_t stream, std::string name = "")
 {
 if (buf == nullptr)
 {

@@ -390,9 +390,9 @@ void printAbsMean(const T* buf, uint64_t size, cudaStream_t stream, std::string
 }

 template <typename T>
-void printToStream(const T* result, const int size, FILE* strm)
+void printToStream(T const* result, int const size, FILE* strm)
 {
-const bool split_rows = (strm == stdout);
+bool const split_rows = (strm == stdout);
 if (result == nullptr)
 {
 TLLM_LOG_WARNING("It is an nullptr, skip! \n");

@@ -414,13 +414,13 @@ void printToStream(const T* result, const int size, FILE* strm)
 }

 template <typename T>
-void printToScreen(const T* result, const int size)
+void printToScreen(T const* result, int const size)
 {
 printToStream(result, size, stdout);
 }

 template <typename T>
-void print2dToStream(const T* result, const int r, const int c, const int stride, FILE* strm)
+void print2dToStream(T const* result, int const r, int const c, int const stride, FILE* strm)
 {
 if (result == nullptr)
 {

@@ -429,20 +429,20 @@ void print2dToStream(const T* result, const int r, const int c, const int stride
 }
 for (int ri = 0; ri < r; ++ri)
 {
-const T* ptr = result + ri * stride;
+T const* ptr = result + ri * stride;
 printToStream(ptr, c, strm);
 }
 fprintf(strm, "\n");
 }

 template <typename T>
-void print2dToScreen(const T* result, const int r, const int c, const int stride)
+void print2dToScreen(T const* result, int const r, int const c, int const stride)
 {
 print2dToStream(result, r, c, stride, stdout);
 }

 template <typename T>
-void print2dToFile(std::string fname, const T* result, const int r, const int c, const int stride)
+void print2dToFile(std::string fname, T const* result, int const r, int const c, int const stride)
 {
 FILE* fp = fopen(fname.c_str(), "wt");
 if (fp != nullptr)

@@ -493,7 +493,7 @@ inline void print_element_(int64_t ill)
 }

 template <typename T>
-inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_ptr)
+inline void printMatrix(T const* ptr, int m, int k, int stride, bool is_device_ptr)
 {
 T* tmp;
 if (is_device_ptr)

@@ -538,14 +538,14 @@ inline void printMatrix(const T* ptr, int m, int k, int stride, bool is_device_p
 }
 }

-template void printMatrix(const float* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(float const* ptr, int m, int k, int stride, bool is_device_ptr);
-template void printMatrix(const half* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(half const* ptr, int m, int k, int stride, bool is_device_ptr);
 #ifdef ENABLE_BF16
-template void printMatrix(const __nv_bfloat16* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(__nv_bfloat16 const* ptr, int m, int k, int stride, bool is_device_ptr);
 #endif
-template void printMatrix(const uint32_t* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(uint32_t const* ptr, int m, int k, int stride, bool is_device_ptr);
-template void printMatrix(const uint64_t* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(uint64_t const* ptr, int m, int k, int stride, bool is_device_ptr);
-template void printMatrix(const int* ptr, int m, int k, int stride, bool is_device_ptr);
+template void printMatrix(int const* ptr, int m, int k, int stride, bool is_device_ptr);

 } // namespace tensorrt_llm::common

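`syncAndCheck` only pays for a device synchronization when checking is actually wanted: always in debug builds, and in release builds only when `CUDA_LAUNCH_BLOCKING=1` is set, which `isCudaLaunchBlocking` reads once and caches. A condensed sketch of that gating logic under those assumptions (error reporting reduced to `fprintf`; the project's real helper goes through its own check and logging machinery):

```cpp
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <string>

static bool launchBlockingRequested()
{
    // Read the environment variable once and cache the answer for later calls.
    static bool const result = []
    {
        char const* env = std::getenv("CUDA_LAUNCH_BLOCKING");
        return env != nullptr && std::string(env) == "1";
    }();
    return result;
}

inline void syncAndCheckSketch(char const* file, int line)
{
#ifndef NDEBUG
    bool const checkError = true;                      // debug builds: always check
#else
    bool const checkError = launchBlockingRequested(); // release builds: opt-in via env var
#endif
    if (checkError)
    {
        cudaError_t const status = cudaDeviceSynchronize();
        if (status != cudaSuccess)
        {
            std::fprintf(stderr, "CUDA error %s at %s:%d\n", cudaGetErrorString(status), file, line);
        }
    }
}
```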
@@ -25,7 +25,7 @@ namespace tensorrt_llm::common
 // XQA kernels (optimized kernels for generation phase).
 bool forceXQAKernels()
 {
-const char* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
+char const* force_xqa_env_var = getenv("TRTLLM_FORCE_XQA");
 static bool forceXQA = false;
 if (force_xqa_env_var != nullptr)
 {

@@ -45,7 +45,7 @@ bool getEnvMmhaMultiblockDebug()
 if (!init)
 {
 init = true;
-const char* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
+char const* enable_mmha_debug_var = std::getenv("TRTLLM_ENABLE_MMHA_MULTI_BLOCK_DEBUG");
 if (enable_mmha_debug_var)
 {
 if (enable_mmha_debug_var[0] == '1' && enable_mmha_debug_var[1] == '\0')

@@ -64,7 +64,7 @@ int getEnvMmhaBlocksPerSequence()
 if (!init)
 {
 init = true;
-const char* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
+char const* mmhaBlocksPerSequenceEnv = std::getenv("TRTLLM_MMHA_BLOCKS_PER_SEQUENCE");
 if (mmhaBlocksPerSequenceEnv)
 {
 mmhaBlocksPerSequence = std::atoi(mmhaBlocksPerSequenceEnv);
@@ -65,5 +65,4 @@ Logger* Logger::getLogger()
 thread_local Logger instance;
 return &instance;
 }

 } // namespace tensorrt_llm::common
@@ -54,26 +54,26 @@ public:

 #if defined(_MSC_VER)
 template <typename... Args>
-void log(Level level, char const* format, const Args&... args);
+void log(Level level, char const* format, Args const&... args);

 template <typename... Args>
-void log(Level level, int rank, char const* format, const Args&... args);
+void log(Level level, int rank, char const* format, Args const&... args);
 #else
 template <typename... Args>
-void log(Level level, char const* format, const Args&... args) __attribute__((format(printf, 3, 0)));
+void log(Level level, char const* format, Args const&... args) __attribute__((format(printf, 3, 0)));

 template <typename... Args>
-void log(Level level, int rank, char const* format, const Args&... args) __attribute__((format(printf, 4, 0)));
+void log(Level level, int rank, char const* format, Args const&... args) __attribute__((format(printf, 4, 0)));
 #endif

 template <typename... Args>
-void log(Level level, std::string const& format, const Args&... args)
+void log(Level level, std::string const& format, Args const&... args)
 {
 return log(level, format.c_str(), args...);
 }

 template <typename... Args>
-void log(const Level level, const int rank, const std::string& format, const Args&... args)
+void log(const Level level, int const rank, std::string const& format, Args const&... args)
 {
 return log(level, rank, format.c_str(), args...);
 }

@@ -122,7 +122,7 @@ private:
 return fmtstr("%s[%s] ", kPREFIX, getLevelName(level));
 }

-static inline std::string getPrefix(const Level level, const int rank)
+static inline std::string getPrefix(const Level level, int const rank)
 {
 return fmtstr("%s[%s][%d] ", kPREFIX, getLevelName(level), rank);
 }

@@ -148,7 +148,7 @@ void Logger::log(Logger::Level level, char const* format, Args const&... args)
 }

 template <typename... Args>
-void Logger::log(const Logger::Level level, const int rank, char const* format, const Args&... args)
+void Logger::log(const Logger::Level level, int const rank, char const* format, Args const&... args)
 {
 if (level_ <= level)
 {
@@ -112,63 +112,63 @@ template void deviceFill(int* devptr, size_t size, int value, cudaStream_t strea
 template void deviceFill(bool* devptr, size_t size, bool value, cudaStream_t stream);

 template <typename T>
-void cudaD2Hcpy(T* tgt, const T* src, const size_t size)
+void cudaD2Hcpy(T* tgt, T const* src, const size_t size)
 {
 check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToHost));
 }

-template void cudaD2Hcpy(float* tgt, const float* src, size_t size);
+template void cudaD2Hcpy(float* tgt, float const* src, size_t size);
-template void cudaD2Hcpy(half* tgt, const half* src, size_t size);
+template void cudaD2Hcpy(half* tgt, half const* src, size_t size);
 #ifdef ENABLE_BF16
-template void cudaD2Hcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size);
+template void cudaD2Hcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
 #endif
-template void cudaD2Hcpy(int* tgt, const int* src, size_t size);
+template void cudaD2Hcpy(int* tgt, int const* src, size_t size);
-template void cudaD2Hcpy(bool* tgt, const bool* src, size_t size);
+template void cudaD2Hcpy(bool* tgt, bool const* src, size_t size);
 #ifdef ENABLE_FP8
-template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size);
+template void cudaD2Hcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
 #endif
-template void cudaD2Hcpy(unsigned long long* tgt, const unsigned long long* src, size_t size);
+template void cudaD2Hcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
-template void cudaD2Hcpy(unsigned int* tgt, const unsigned int* src, size_t size);
+template void cudaD2Hcpy(unsigned int* tgt, unsigned int const* src, size_t size);
-template void cudaD2Hcpy(int8_t* tgt, const int8_t* src, size_t size);
+template void cudaD2Hcpy(int8_t* tgt, int8_t const* src, size_t size);

 template <typename T>
-void cudaH2Dcpy(T* tgt, const T* src, const size_t size)
+void cudaH2Dcpy(T* tgt, T const* src, const size_t size)
 {
 check_cuda_error(cudaMemcpy(tgt, src, sizeof(T) * size, cudaMemcpyHostToDevice));
 }

-template void cudaH2Dcpy(float* tgt, const float* src, size_t size);
+template void cudaH2Dcpy(float* tgt, float const* src, size_t size);
-template void cudaH2Dcpy(half* tgt, const half* src, size_t size);
+template void cudaH2Dcpy(half* tgt, half const* src, size_t size);
 #ifdef ENABLE_BF16
-template void cudaH2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size);
+template void cudaH2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size);
 #endif
-template void cudaH2Dcpy(int* tgt, const int* src, size_t size);
+template void cudaH2Dcpy(int* tgt, int const* src, size_t size);
-template void cudaH2Dcpy(bool* tgt, const bool* src, size_t size);
+template void cudaH2Dcpy(bool* tgt, bool const* src, size_t size);
 #ifdef ENABLE_FP8
-template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size);
+template void cudaH2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size);
 #endif
-template void cudaH2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size);
+template void cudaH2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size);
-template void cudaH2Dcpy(unsigned int* tgt, const unsigned int* src, size_t size);
+template void cudaH2Dcpy(unsigned int* tgt, unsigned int const* src, size_t size);
-template void cudaH2Dcpy(int8_t* tgt, const int8_t* src, size_t size);
+template void cudaH2Dcpy(int8_t* tgt, int8_t const* src, size_t size);

 template <typename T>
-void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
+void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
 {
 check_cuda_error(cudaMemcpyAsync(tgt, src, sizeof(T) * size, cudaMemcpyDeviceToDevice, stream));
 }

-template void cudaD2Dcpy(float* tgt, const float* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
-template void cudaD2Dcpy(half* tgt, const half* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
 #ifdef ENABLE_BF16
-template void cudaD2Dcpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
 #endif
-template void cudaD2Dcpy(int* tgt, const int* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
-template void cudaD2Dcpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
-template void cudaD2Dcpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
 #ifdef ENABLE_FP8
-template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, const __nv_fp8_e4m3* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(__nv_fp8_e4m3* tgt, __nv_fp8_e4m3 const* src, size_t size, cudaStream_t stream);
 #endif
-template void cudaD2Dcpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream);
+template void cudaD2Dcpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);

 template <typename T_OUT, typename T_IN>
 __global__ void cudaCast(T_OUT* dst, T_IN* src, const size_t size)
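`cudaD2Hcpy`, `cudaH2Dcpy`, and `cudaD2Dcpy` are thin typed wrappers over `cudaMemcpy`/`cudaMemcpyAsync` that compute the byte count from `sizeof(T) * size`, so callers pass element counts rather than byte sizes. A small standalone sketch of the same idea (without the project's `check_cuda_error` macro):

```cpp
#include <cuda_runtime.h>
#include <vector>

template <typename T>
void copyDeviceToHost(T* dst, T const* src, size_t count)
{
    // Element count in, byte count out: the wrapper owns the sizeof(T) factor.
    cudaMemcpy(dst, src, sizeof(T) * count, cudaMemcpyDeviceToHost);
}

int main()
{
    std::vector<float> host(1024, 0.0f);
    float* device = nullptr;
    cudaMalloc(&device, sizeof(float) * host.size());
    cudaMemset(device, 0, sizeof(float) * host.size());
    copyDeviceToHost(host.data(), device, host.size());
    cudaFree(device);
    return 0;
}
```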
@@ -204,7 +204,7 @@ template void invokeCudaCast(__nv_fp8_e4m3* dst, half const* const src, const si
 #endif

 template <typename T>
-void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
+void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream)
 {
 if (stream != NULL)
 {

@@ -216,19 +216,19 @@ void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream)
 }
 }

-template void cudaAutoCpy(float* tgt, const float* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(float* tgt, float const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(half* tgt, const half* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(half* tgt, half const* src, size_t size, cudaStream_t stream);
 #ifdef ENABLE_BF16
-template void cudaAutoCpy(__nv_bfloat16* tgt, const __nv_bfloat16* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(__nv_bfloat16* tgt, __nv_bfloat16 const* src, size_t size, cudaStream_t stream);
 #endif
-template void cudaAutoCpy(int* tgt, const int* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(int* tgt, int const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(bool* tgt, const bool* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(bool* tgt, bool const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(int8_t* tgt, const int8_t* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(int8_t* tgt, int8_t const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(uint8_t* tgt, const uint8_t* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(uint8_t* tgt, uint8_t const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(uint32_t* tgt, const uint32_t* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(uint32_t* tgt, uint32_t const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(unsigned long long* tgt, const unsigned long long* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(unsigned long long* tgt, unsigned long long const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(unsigned long* tgt, const unsigned long* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(unsigned long* tgt, unsigned long const* src, size_t size, cudaStream_t stream);
-template void cudaAutoCpy(char* tgt, const char* src, size_t size, cudaStream_t stream);
+template void cudaAutoCpy(char* tgt, char const* src, size_t size, cudaStream_t stream);

 template void cudaAutoCpy(float const** tgt, float const* const* src, size_t size, cudaStream_t stream);
 template void cudaAutoCpy(half const** tgt, half const* const* src, size_t size, cudaStream_t stream);

@@ -242,7 +242,7 @@ template void cudaAutoCpy(
 unsigned long long const** tgt, unsigned long long const* const* src, size_t size, cudaStream_t stream);

 template <typename T>
-__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const int seq_offset)
+__global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, int const seq_offset)
 {
 const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandState_t local_state;

@@ -254,7 +254,7 @@ __global__ void cuda_random_uniform_kernel(T* buffer, const size_t size, const i
 }

 template <>
-__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, const int seq_offset)
+__global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size, int const seq_offset)
 {
 const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandState_t local_state;

@@ -266,7 +266,7 @@ __global__ void cuda_random_uniform_kernel<int>(int* buffer, const size_t size,
 }

 template <>
-__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, const int seq_offset)
+__global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size, int const seq_offset)
 {
 const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandState_t local_state;

@@ -278,7 +278,7 @@ __global__ void cuda_random_uniform_kernel<bool>(bool* buffer, const size_t size
 }

 template <>
-__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, const int seq_offset)
+__global__ void cuda_random_uniform_kernel<char>(char* buffer, const size_t size, int const seq_offset)
 {
 const size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
 curandState_t local_state;
@@ -462,30 +462,30 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud
 cudaD2DcpyConvert<<<256, 256, 0, stream>>>(tgt, src, size);
 }

-template void invokeCudaD2DcpyConvert(int8_t* tgt, const float* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(int8_t* tgt, float const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(float* tgt, const int8_t* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(float* tgt, int8_t const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(float* tgt, const int* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(float* tgt, int const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(half* tgt, const int* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(half* tgt, int const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(float* tgt, const float* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(float* tgt, float const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(half* tgt, const float* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(half* tgt, float const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(float* tgt, const half* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(float* tgt, half const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(uint32_t* tgt, const int* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(uint32_t* tgt, int const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(int* tgt, const uint32_t* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(int* tgt, uint32_t const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(int* tgt, const float* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(int* tgt, float const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(int* tgt, const half* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(int* tgt, half const* src, const size_t size, cudaStream_t stream);

 #ifdef ENABLE_BF16
-template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const float* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, float const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, const int* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(__nv_bfloat16* tgt, int const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(float* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(float* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
-template void invokeCudaD2DcpyConvert(int* tgt, const __nv_bfloat16* src, const size_t size, cudaStream_t stream);
+template void invokeCudaD2DcpyConvert(int* tgt, __nv_bfloat16 const* src, const size_t size, cudaStream_t stream);
 #endif // ENABLE_BF16

 template <typename T_IN, typename T_OUT>
 __global__ void cudaD2DScaleCpyConvert(
-T_OUT* dst, const T_IN* src, const float* scale, bool invert_scale, const size_t size)
+T_OUT* dst, const T_IN* src, float const* scale, bool invert_scale, const size_t size)
 {
-const float scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
+float const scale_value = invert_scale ? 1.0f / scale[0] : scale[0];
 for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < size; tid += blockDim.x * gridDim.x)
 {
 dst[tid] = cuda_cast<T_OUT>(cuda_cast<float>(src[tid]) * scale_value);

@@ -494,7 +494,7 @@ __global__ void cudaD2DScaleCpyConvert(

 template <typename T_IN, typename T_OUT>
 void invokeCudaD2DScaleCpyConvert(
-T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream)
+T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream)
 {
 cudaD2DScaleCpyConvert<<<256, 256, 0, stream>>>(tgt, src, scale, invert_scale, size);
 }
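`cudaD2DScaleCpyConvert` reads a single device-side scale, optionally inverts it, and writes `dst[i] = cast(float(src[i]) * scale_value)` over a grid-stride loop. A host-side reference of the same arithmetic (plain C++; the CUDA version additionally goes through the project's `cuda_cast` helpers for the half/bf16/fp8 conversions):

```cpp
#include <cstddef>

// Reference semantics of the scale-copy-convert kernel above.
template <typename TIn, typename TOut>
void scaleCopyConvert(TOut* dst, TIn const* src, float const* scale, bool invertScale, size_t size)
{
    // scale points at a single value; invertScale selects multiply-by-1/s instead of multiply-by-s.
    float const scaleValue = invertScale ? 1.0f / scale[0] : scale[0];
    for (size_t i = 0; i < size; ++i)
    {
        dst[i] = static_cast<TOut>(static_cast<float>(src[i]) * scaleValue);
    }
}
```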
@@ -524,7 +524,7 @@ void invokeCudaD2DcpyFloat2Half(half* dst, float* src, const size_t size, cudaSt
 }

 template <typename T>
-void saveToBinary(const T* ptr, const size_t size, std::string filename)
+void saveToBinary(T const* ptr, const size_t size, std::string filename)
 {

 std::vector<T> h_ptr(size);

@@ -541,14 +541,14 @@ void saveToBinary(const T* ptr, const size_t size, std::string filename)
 out.write((char*) float_ptr.data(), size * sizeof(float));
 }

-template void saveToBinary(const float* ptr, const size_t size, std::string filename);
+template void saveToBinary(float const* ptr, const size_t size, std::string filename);
-template void saveToBinary(const half* ptr, const size_t size, std::string filename);
+template void saveToBinary(half const* ptr, const size_t size, std::string filename);
 #ifdef ENABLE_BF16
-template void saveToBinary(const __nv_bfloat16* ptr, const size_t size, std::string filename);
+template void saveToBinary(__nv_bfloat16 const* ptr, const size_t size, std::string filename);
 #endif // ENABLE_BF16

 template <>
-void saveToBinary(const int* ptr, const size_t size, std::string filename)
+void saveToBinary(int const* ptr, const size_t size, std::string filename)
 {
 std::vector<int> h_ptr(size);
 cudaD2Hcpy(h_ptr.data(), ptr, size);
@ -831,7 +831,7 @@ size_t cuda_datatype_size(TRTLLMCudaDataType dt)
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_within_range)
|
__global__ void check_range(T const* buffer, size_t size, T min, T max, bool* d_within_range)
|
||||||
{
|
{
|
||||||
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
|
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += blockDim.x * gridDim.x)
|
||||||
{
|
{
|
||||||
@ -844,7 +844,7 @@ __global__ void check_range(const T* buffer, size_t size, T min, T max, bool* d_
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
|
bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream)
|
||||||
{
|
{
|
||||||
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
|
cudaMemsetAsync(d_within_range, true, sizeof(bool), stream);
|
||||||
|
|
||||||
@ -858,12 +858,12 @@ bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_
|
|||||||
}
|
}
|
||||||
|
|
||||||
template bool invokeCheckRange<int>(
|
template bool invokeCheckRange<int>(
|
||||||
const int* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
|
int const* buffer, const size_t size, int min, int max, bool* d_within_range, cudaStream_t stream);
|
||||||
|
|
||||||
/*
|
/*
|
||||||
* Determine the total workspace size based on a vector containing multiple variable sizes.
|
* Determine the total workspace size based on a vector containing multiple variable sizes.
|
||||||
*/
|
*/
|
||||||
size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTES)
|
size_t calcAlignedSize(std::vector<size_t> const& sizes, const size_t ALIGN_BYTES)
|
||||||
{
|
{
|
||||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||||
// Check ALIGN_BYTES is a power of 2
|
// Check ALIGN_BYTES is a power of 2
|
||||||
@ -885,7 +885,7 @@ size_t calcAlignedSize(const std::vector<size_t>& sizes, const size_t ALIGN_BYTE
|
|||||||
* of each variable.
|
* of each variable.
|
||||||
*/
|
*/
|
||||||
void calcAlignedPointers(
|
void calcAlignedPointers(
|
||||||
std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES)
|
std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES)
|
||||||
{
|
{
|
||||||
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
const size_t ALIGN_MASK = ~(ALIGN_BYTES - 1);
|
||||||
// Check ALIGN_BYTES is a power of 2
|
// Check ALIGN_BYTES is a power of 2
|
||||||
|
|||||||
@@ -40,16 +40,16 @@ template <typename T>
 void deviceFill(T* devptr, size_t size, T value, cudaStream_t stream = 0);

 template <typename T>
-void cudaD2Hcpy(T* tgt, const T* src, const size_t size);
+void cudaD2Hcpy(T* tgt, T const* src, const size_t size);

 template <typename T>
-void cudaH2Dcpy(T* tgt, const T* src, const size_t size);
+void cudaH2Dcpy(T* tgt, T const* src, const size_t size);

 template <typename T>
-void cudaD2Dcpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL);
+void cudaD2Dcpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);

 template <typename T>
-void cudaAutoCpy(T* tgt, const T* src, const size_t size, cudaStream_t stream = NULL);
+void cudaAutoCpy(T* tgt, T const* src, const size_t size, cudaStream_t stream = NULL);

 template <typename T>
 void cudaRandomUniform(T* buffer, const size_t size);
@@ -234,9 +234,9 @@ void invokeCudaD2DcpyConvert(T_OUT* tgt, const T_IN* src, const size_t size, cud

 template <typename T_IN, typename T_OUT>
 void invokeCudaD2DScaleCpyConvert(
-T_OUT* tgt, const T_IN* src, const float* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0);
+T_OUT* tgt, const T_IN* src, float const* scale, bool invert_scale, const size_t size, cudaStream_t stream = 0);

-inline bool checkIfFileExist(const std::string& file_path)
+inline bool checkIfFileExist(std::string const& file_path)
 {
 std::ifstream in(file_path, std::ios::in | std::ios::binary);
 if (in.is_open())
@@ -248,7 +248,7 @@ inline bool checkIfFileExist(const std::string& file_path)
 }

 template <typename T>
-void saveToBinary(const T* ptr, const size_t size, std::string filename);
+void saveToBinary(T const* ptr, const size_t size, std::string filename);

 template <typename T_IN, typename T_fake_type>
 void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
@@ -256,10 +256,10 @@ void invokeFakeCast(T_IN* input_ptr, const size_t size, cudaStream_t stream);
 size_t cuda_datatype_size(TRTLLMCudaDataType dt);

 template <typename T>
-bool invokeCheckRange(const T* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);
+bool invokeCheckRange(T const* buffer, const size_t size, T min, T max, bool* d_within_range, cudaStream_t stream);

-size_t calcAlignedSize(const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
+size_t calcAlignedSize(std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
 void calcAlignedPointers(
-std::vector<void*>& outPtrs, const void* p, const std::vector<size_t>& sizes, size_t ALIGN_BYTES = 256);
+std::vector<void*>& outPtrs, void const* p, std::vector<size_t> const& sizes, size_t ALIGN_BYTES = 256);
 } // namespace common
 } // namespace tensorrt_llm
@@ -50,6 +50,7 @@ MPI_Datatype getMpiDtype(MpiType dtype)
 {MpiType::kUINT64, MPI_UINT64_T},
 {MpiType::kFP8, MPI_UINT8_T},
 {MpiType::kBF16, MPI_UINT16_T},
+{MpiType::kCHAR, MPI_CHAR},
 };
 return dtype_map.at(dtype);
 }
@@ -126,23 +127,6 @@ void MpiComm::bcast(void* buffer, size_t size, MpiType dtype, int root) const
 MPICHECK(MPI_Bcast(buffer, size, getMpiDtype(dtype), root, mComm));
 }

-void MpiComm::bcast(std::vector<int64_t>& packed, int root) const
-{
-int64_t nWords1;
-auto const rank = getRank();
-if (rank == root)
-{
-nWords1 = static_cast<int64_t>(packed.size());
-}
-auto const mpiInt64 = MpiTypeConverter<int64_t>::value;
-bcast(&nWords1, 1, mpiInt64, root);
-if (rank != root)
-{
-packed.resize(nWords1);
-}
-bcast(packed.data(), packed.size(), mpiInt64, root);
-}
-
 void MpiComm::send(void const* buffer, size_t size, MpiType dtype, int dest, int tag) const
 {
 MPICHECK(MPI_Send(buffer, size, getMpiDtype(dtype), dest, tag, mComm));
@@ -162,12 +146,12 @@ MpiComm MpiComm::split(int color, int key) const
 return MpiComm{splitComm, true};
 }

-void MpiComm::allreduce(const void* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
+void MpiComm::allreduce(void const* sendbuf, void* recvbuf, int count, MpiType dtype, MpiOp op) const
 {
 MPICHECK(MPI_Allreduce(sendbuf, recvbuf, count, getMpiDtype(dtype), getMpiOp(op), mComm));
 }

-void MpiComm::allgather(const void* sendbuf, void* recvbuf, int count, MpiType dtype) const
+void MpiComm::allgather(void const* sendbuf, void* recvbuf, int count, MpiType dtype) const
 {
 MPICHECK(MPI_Allgather(sendbuf, count, getMpiDtype(dtype), recvbuf, count, getMpiDtype(dtype), mComm));
 }
@@ -39,7 +39,7 @@ public:

 constexpr QuantMode(QuantMode const&) noexcept = default;

-constexpr QuantMode& operator=(const QuantMode& other) noexcept = default;
+constexpr QuantMode& operator=(QuantMode const& other) noexcept = default;

 static constexpr QuantMode none() noexcept
 {
@@ -276,32 +276,32 @@ public:
 return quantMode;
 }

-constexpr QuantMode operator+(const QuantMode& other) const noexcept
+constexpr QuantMode operator+(QuantMode const& other) const noexcept
 {
 return QuantMode(mValue | other.mValue);
 }

-constexpr QuantMode& operator+=(const QuantMode& other) noexcept
+constexpr QuantMode& operator+=(QuantMode const& other) noexcept
 {
 return *this = *this + other;
 }

-constexpr QuantMode operator-(const QuantMode& other) const noexcept
+constexpr QuantMode operator-(QuantMode const& other) const noexcept
 {
 return QuantMode(mValue & ~other.mValue);
 }

-constexpr QuantMode& operator-=(const QuantMode& other) noexcept
+constexpr QuantMode& operator-=(QuantMode const& other) noexcept
 {
 return *this = *this - other;
 }

-constexpr bool operator==(const QuantMode& other) const noexcept
+constexpr bool operator==(QuantMode const& other) const noexcept
 {
 return mValue == other.mValue;
 }

-constexpr bool operator!=(const QuantMode& other) const noexcept
+constexpr bool operator!=(QuantMode const& other) const noexcept
 {
 return !(*this == other);
 }
@@ -63,11 +63,11 @@ struct BytesToType<16>
 };

 template <int Bytes>
-__device__ inline void copy(const void* local, void* data)
+__device__ inline void copy(void const* local, void* data)
 {
 using T = typename BytesToType<Bytes>::type;

-const T* in = static_cast<const T*>(local);
+T const* in = static_cast<T const*>(local);
 T* out = static_cast<T*>(data);
 *out = *in;
 }
@@ -257,8 +257,8 @@ __inline__ __device__ void cgBlockReduceSumElements(float* element_list, float*
 cg::thread_block cta = cg::this_thread_block();
 cg::thread_block_tile<32> tile = cg::tiled_partition<32>(cta);

-const int tid = cta.thread_rank();
+int const tid = cta.thread_rank();
-const int blockz = blockDim.x;
+int const blockz = blockDim.x;
 for (int i = 0; i < NUM; i++)
 {
 #if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
@@ -325,7 +325,7 @@ struct TopK

 __device__ __forceinline__ void init()
 {
-const bool IS_FP16 = std::is_same<T, half>::value;
+bool const IS_FP16 = std::is_same<T, half>::value;
 const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;

 for (int i = 0; i < MAX_K; i++)
@@ -337,7 +337,7 @@ struct TopK
 };

 template <typename T, int MAX_K>
-__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(const TopK<T, MAX_K>& a, const TopK<T, MAX_K>& b)
+__device__ __forceinline__ TopK<T, MAX_K> reduce_topk_op(TopK<T, MAX_K> const& a, TopK<T, MAX_K> const& b)
 {
 TopK<T, MAX_K> res = a;
 for (int i = 0; i < MAX_K; ++i)
@@ -368,19 +368,19 @@ struct TopK_2
 };

 template <typename T>
-__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(const TopK_2<T>& a, const TopK_2<T>& b)
+__device__ __forceinline__ TopK_2<T> reduce_topk_op_2(TopK_2<T> const& a, TopK_2<T> const& b)
 {
 return a.u > b.u ? a : b;
 }

 template <typename T>
-__device__ __forceinline__ T clamp_inf_for_half(const float input)
+__device__ __forceinline__ T clamp_inf_for_half(float const input)
 {
 return input;
 }

 template <>
-__device__ __forceinline__ half clamp_inf_for_half(const float input)
+__device__ __forceinline__ half clamp_inf_for_half(float const input)
 {
 // clamp inf values to enable fp16 training
 return input > 0.0f ? (half) min(input, HALF_FLT_MAX - 1000) : (half) max(input, -HALF_FLT_MAX + 1000);
@@ -152,7 +152,7 @@ Tensor Tensor::slice(std::vector<size_t> shape, size_t offset) const
 return Tensor(this->where, this->type, shape, this->getPtrWithOffset(offset));
 }

-TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map)
+TensorMap::TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map)
 {
 for (auto& kv : tensor_map)
 {
@@ -167,7 +167,7 @@ TensorMap::TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map)
 }
 }

-TensorMap::TensorMap(const std::vector<Tensor>& tensor_map)
+TensorMap::TensorMap(std::vector<Tensor> const& tensor_map)
 {
 for (size_t i = 0; i < tensor_map.size(); i++)
 {
@@ -191,7 +191,7 @@ struct TensorDataType<int*>
 };

 template <>
-struct TensorDataType<const int*>
+struct TensorDataType<int const*>
 {
 static constexpr DataType value = TYPE_INT32_PTR;
 };
@@ -419,8 +419,8 @@ private:

 public:
 TensorMap() = default;
-TensorMap(const std::unordered_map<std::string, Tensor>& tensor_map);
+TensorMap(std::unordered_map<std::string, Tensor> const& tensor_map);
-TensorMap(const std::vector<Tensor>& tensor_map);
+TensorMap(std::vector<Tensor> const& tensor_map);
 TensorMap(std::initializer_list<std::pair<std::string, Tensor>> tensor_map);
 ~TensorMap();

@@ -429,7 +429,7 @@ public:
 return tensor_map_.size();
 }

-inline bool contains(const std::string& key) const
+inline bool contains(std::string const& key) const
 {
 TLLM_LOG_TRACE("%s for key: %s", __PRETTY_FUNCTION__, key.c_str());
 return tensor_map_.find(key) != tensor_map_.end();
@@ -437,7 +437,7 @@ public:

 std::vector<std::string> keys() const;

-inline void insert(const std::string& key, const Tensor& value)
+inline void insert(std::string const& key, Tensor const& value)
 {
 TLLM_CHECK_WITH_INFO(!contains(key), fmtstr("Duplicated key %s", key.c_str()));
 TLLM_CHECK_WITH_INFO(
@@ -445,7 +445,7 @@ public:
 tensor_map_.insert({key, value});
 }

-inline void insertIfValid(const std::string& key, const Tensor& value)
+inline void insertIfValid(std::string const& key, Tensor const& value)
 {
 if (value.isValid())
 {
@@ -462,7 +462,7 @@ public:
 Tensor at(int tmp) = delete;
 Tensor at(size_t tmp) = delete;

-inline Tensor& at(const std::string& key)
+inline Tensor& at(std::string const& key)
 {
 TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
 TLLM_CHECK_WITH_INFO(contains(key),
@@ -471,7 +471,7 @@ public:
 return tensor_map_.at(key);
 }

-inline Tensor at(const std::string& key) const
+inline Tensor at(std::string const& key) const
 {
 TLLM_CHECK_WITH_INFO(contains(key),
 fmtstr(
@@ -479,7 +479,7 @@ public:
 return tensor_map_.at(key);
 }

-inline std::optional<Tensor> atOpt(const std::string& key) const
+inline std::optional<Tensor> atOpt(std::string const& key) const
 {
 if (contains(key))
 return tensor_map_.at(key);
@@ -487,7 +487,7 @@ public:
 return std::nullopt;
 }

-inline Tensor& at(const std::string& key, Tensor& default_tensor)
+inline Tensor& at(std::string const& key, Tensor& default_tensor)
 {
 TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
 if (contains(key))
@@ -497,7 +497,7 @@ public:
 return default_tensor;
 }

-inline Tensor at(const std::string& key, Tensor& default_tensor) const
+inline Tensor at(std::string const& key, Tensor& default_tensor) const
 {
 TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
 if (contains(key))
@@ -507,7 +507,7 @@ public:
 return default_tensor;
 }

-inline Tensor& at(const std::string& key, Tensor&& default_tensor)
+inline Tensor& at(std::string const& key, Tensor&& default_tensor)
 {
 TLLM_LOG_TRACE("%s for key %s", __PRETTY_FUNCTION__, key.c_str());
 if (contains(key))
@@ -517,7 +517,7 @@ public:
 return default_tensor;
 }

-inline Tensor at(const std::string& key, Tensor&& default_tensor) const
+inline Tensor at(std::string const& key, Tensor&& default_tensor) const
 {
 if (contains(key))
 {
@@ -527,7 +527,7 @@ public:
 }

 template <typename T>
-inline T getVal(const std::string& key) const
+inline T getVal(std::string const& key) const
 {
 TLLM_CHECK_WITH_INFO(contains(key),
 fmtstr(
@@ -536,7 +536,7 @@ public:
 }

 template <typename T>
-inline std::optional<T> getValOpt(const std::string& key) const
+inline std::optional<T> getValOpt(std::string const& key) const
 {
 if (contains(key))
 {
@@ -549,7 +549,7 @@ public:
 }

 template <typename T>
-inline T getVal(const std::string& key, T default_value) const
+inline T getVal(std::string const& key, T default_value) const
 {
 if (contains(key))
 {
@@ -559,7 +559,7 @@ public:
 }

 template <typename T>
-inline T getValWithOffset(const std::string& key, size_t index) const
+inline T getValWithOffset(std::string const& key, size_t index) const
 {
 TLLM_CHECK_WITH_INFO(contains(key),
 fmtstr(
@@ -568,7 +568,7 @@ public:
 }

 template <typename T>
-inline T getValWithOffset(const std::string& key, size_t index, T default_value) const
+inline T getValWithOffset(std::string const& key, size_t index, T default_value) const
 {
 if (contains(key))
 {
@@ -578,7 +578,7 @@ public:
 }

 template <typename T>
-inline T* getPtr(const std::string& key) const
+inline T* getPtr(std::string const& key) const
 {
 TLLM_CHECK_WITH_INFO(contains(key),
 fmtstr(
@@ -587,7 +587,7 @@ public:
 }

 template <typename T>
-inline T* getPtr(const std::string& key, T* default_ptr) const
+inline T* getPtr(std::string const& key, T* default_ptr) const
 {
 if (contains(key))
 {
@@ -597,7 +597,7 @@ public:
 }

 template <typename T>
-inline T* getPtrWithOffset(const std::string& key, size_t index) const
+inline T* getPtrWithOffset(std::string const& key, size_t index) const
 {
 TLLM_CHECK_WITH_INFO(contains(key),
 fmtstr(
@@ -606,7 +606,7 @@ public:
 }

 template <typename T>
-inline T* getPtrWithOffset(const std::string& key, size_t index, T* default_ptr) const
+inline T* getPtrWithOffset(std::string const& key, size_t index, T* default_ptr) const
 {
 if (contains(key))
 {
@@ -34,7 +34,7 @@ int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;

 #if !defined(_MSC_VER)

-TllmException::TllmException(char const* file, std::size_t line, const std::string& msg)
+TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
 : std::runtime_error{""}
 {
 mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
@@ -43,7 +43,7 @@ TllmException::TllmException(char const* file, std::size_t line, const std::stri
 std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
 }
 #else
-TllmException::TllmException(char const* file, std::size_t line, const std::string& msg)
+TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
 : mNbFrames{}
 , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
 {
@@ -65,7 +65,7 @@ __forceinline__ __device__ float copysignf_pos(float a, float b)
 __forceinline__ __device__ float tanh_opt(float x)
 {
 #if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
-const float exp_val = -1.f * fabs(2 * x);
+float const exp_val = -1.f * fabs(2 * x);
 return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
 #else
 return fast_tanh(x);
@@ -76,7 +76,7 @@ __forceinline__ __device__ float tanh_opt(float x)
 template <>
 struct GELU_taylor<float>
 {
-static const bool kIsHeavy = true;
+static bool const kIsHeavy = true;

 CUTLASS_DEVICE
 float operator()(float const& z) const
@@ -157,8 +157,8 @@ private:
 MatrixCoord extent_real_;
 ElementwiseFunctor elementwise_;

-const bool per_token_quant_;
+bool const per_token_quant_;
-const bool per_channel_quant_;
+bool const per_channel_quant_;

 AlphaScaleElementType* ptr_alpha_row_;
 AlphaScaleElementType* ptr_alpha_col_;
@@ -65,7 +65,7 @@ namespace device
 /////////////////////////////////////////////////////////////////////////////////////////////////

 template <typename T_IN, typename T_OUT>
-__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const GemmCoord* problem_sizes, int splitk,
+__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
 int64_t* splitk_buffer_offsets)
 {
 // in_tensor: [problem_idx, k_partition, hidden_size]
@@ -73,9 +73,9 @@ __global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, const
 // so, we need to use splitk_buffer_offsets.
 // out_tensor: problem_idx * [hidden_size]

-const int problem_idx = blockIdx.y;
+int const problem_idx = blockIdx.y;
 GemmCoord problem = problem_sizes[problem_idx];
-const int hidden_size = problem.m() * problem.n();
+int const hidden_size = problem.m() * problem.n();
 const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
 T_OUT* out_tensor_ = out_tensor[problem_idx];

@@ -143,7 +143,7 @@ protected:

 private:
 /// Get the number of tiles across all problems in a group
-static int32_t group_tile_count(const cutlass::gemm::GemmCoord* problem_sizes_ptr, int problem_count)
+static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
 {
 int32_t tiles = 0;
 for (int32_t i = 0; i < problem_count; ++i)
@@ -182,7 +182,7 @@ private:

 /// Reorder `data` according to `indices`
 template <typename T>
-static void reorder_array(T* data, const std::vector<size_t>& indices)
+static void reorder_array(T* data, std::vector<size_t> const& indices)
 {
 // For now, simply create a copy of the data and then copy over to the original.
 std::vector<T> copy(indices.size());
@@ -314,7 +314,7 @@ public:

 /// Computes the number of threadblocks to launch for the grouped kernel
 static int sufficient(
-const cutlass::gemm::GemmCoord* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
+cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
 {
 // Determine the number of blocks that would be launched to fill up a single
 // wave on the GPU with each SM having maximum occupancy.
@@ -142,7 +142,7 @@ struct GemmFpAIntB
 Arguments() {}

 CUTLASS_HOST_DEVICE
-Arguments(cutlass::gemm::GemmCoord const& problem_size, const int group_size,
+Arguments(cutlass::gemm::GemmCoord const& problem_size, int const group_size,
 typename Mma::IteratorA::TensorRef ref_A, typename Mma::IteratorB::TensorRef ref_B,
 typename Mma::IteratorScale::TensorRef ref_scale, typename Mma::IteratorScale::TensorRef ref_zero,
 typename Epilogue::OutputTileIterator::TensorRef ref_C,
@@ -206,7 +206,7 @@ struct GemmFpAIntB
 }

 CUTLASS_HOST_DEVICE
-Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, const int gemm_k_size,
+Params(Arguments const& args, cutlass::gemm::GemmCoord const& grid_tiled_shape, int const gemm_k_size,
 void* workspace = nullptr)
 : problem_size(args.problem_size)
 , group_size(args.group_size)
@@ -174,7 +174,7 @@ public:
 /// Ctor
 CUTLASS_HOST_DEVICE
 Arguments(int problem_count, int threadblock_count, int group_size, typename EpilogueOutputOp::Params output_op,
-const ElementA* ptr_A, const ElementB* ptr_B, const ElementScale* weight_scales, const ElementC* ptr_C,
+ElementA const* ptr_A, ElementB const* ptr_B, ElementScale const* weight_scales, ElementC const* ptr_C,
 ElementC* ptr_D, int64_t* total_rows_before_expert, int64_t gemm_n, int64_t gemm_k,
 GemmCoord* host_problem_sizes = nullptr)
 : problem_count(problem_count)
@@ -119,7 +119,7 @@ struct BaseMoeProblemVisitor

 /// Get the grid shape
 CUTLASS_HOST_DEVICE
-static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem)
+static cutlass::gemm::GemmCoord grid_shape(cutlass::gemm::GemmCoord const& problem)
 {

 return cutlass::gemm::GemmCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM),
@@ -177,12 +177,12 @@ struct BaseMoeProblemVisitor
 }

 CUTLASS_HOST_DEVICE
-static int32_t tile_count(const cutlass::gemm::GemmCoord& grid)
+static int32_t tile_count(cutlass::gemm::GemmCoord const& grid)
 {
 return ProblemSizeHelper::tile_count(grid);
 }

-static int32_t group_tile_count(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count)
+static int32_t group_tile_count(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count)
 {
 int32_t total_tiles = 0;
 for (int32_t i = 0; i < problem_count; ++i)
@@ -328,12 +328,12 @@ struct MoeProblemVisitor<ProblemSizeHelper, ThreadblockShape, GroupScheduleMode:
 }

 static size_t get_workspace_size(
-const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
+cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count, int32_t block_count)
 {
 return 0;
 }

-static void host_precompute(const cutlass::gemm::GemmCoord* host_problem_sizes_ptr, int32_t problem_count,
+static void host_precompute(cutlass::gemm::GemmCoord const* host_problem_sizes_ptr, int32_t problem_count,
 int32_t block_count, void* host_workspace_ptr)
 {
 }
@@ -60,7 +60,7 @@ namespace threadblock
 template <typename WarpMma, int kExpansionFactor = 1>
 CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
 typename WarpMma::FragmentA const& A, typename WarpMma::FragmentB const& B, typename WarpMma::FragmentC const& C,
-const int warp_tileB_k_offset)
+int const warp_tileB_k_offset)
 {
 warp_mma(D, A, B, C);
 }
@@ -68,7 +68,7 @@ CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC&
 template <typename WarpMma, int kExpansionFactor = WarpMma::kExpansionFactor>
 CUTLASS_DEVICE void run_warp_mma(WarpMma& warp_mma, typename WarpMma::FragmentC& D,
 typename WarpMma::TransformedFragmentA const& A, typename WarpMma::TransformedFragmentB const& B,
-typename WarpMma::FragmentC const& C, const int warp_tileB_k_offset)
+typename WarpMma::FragmentC const& C, int const warp_tileB_k_offset)
 {
 warp_mma(D, A, B, C, warp_tileB_k_offset);
 }
@@ -572,8 +572,8 @@ public:
 this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
 ++this->warp_tile_iterator_A_;

-const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
-const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
 if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
 {
 this->warp_tile_iterator_B_.set_kgroup_index(
@@ -219,7 +219,7 @@ public:
 ///< Shared storage needed for internal use by threadblock-scoped GEMM
 typename Base::SharedStorage& shared_storage,
 ///< Group size for quantization. Not used by this main loop since it assumes per-column
-const int group_size,
+int const group_size,
 ///< ID within the threadblock
 int thread_idx,
 ///< ID of warp
@@ -534,8 +534,8 @@ public:
 this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
 ++this->warp_tile_iterator_A_;

-const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
-const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
 if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
 {
 this->warp_tile_iterator_B_.set_kgroup_index(
@@ -184,7 +184,7 @@ public:
 CUTLASS_DEVICE
 DqMmaPipelined(typename Base::SharedStorage&
 shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM
-const int group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
+int const group_size, ///< Will not be used, just to adapt to finegrained modifications and make the compilation
 ///< successful. Because DqMmaPipelined is only enabled for sm<80, so even if this
 ///< argument is not added, it does not affect compilation for sm>=80.
 int thread_idx, ///< ID within the threadblock
@@ -353,8 +353,8 @@ public:
 this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
 ++this->warp_tile_iterator_A_;

-const int warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_compute_offset = warp_mma_k % Base::kNumKIterationsPerWarpBLoad;
-const int warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
+int const warp_tileB_k_load_offset = warp_mma_k / Base::kNumKIterationsPerWarpBLoad;
 // We are just about to finish computing on a fragment of B, so initiate the load for the next fragment.
 if (warp_tileB_k_compute_offset == Base::kNumKIterationsPerWarpBLoad - 1)
 {
@@ -218,7 +218,7 @@ public:
 /// Performs a warp-level matrix multiply-accumulate operation
 CUTLASS_DEVICE
 void operator()(FragmentC& D, TransformedFragmentA const& A, TransformedFragmentB const& B, FragmentC const& C,
-const int warp_tileB_k_offset) const
+int const warp_tileB_k_offset) const
 {

 using MmaOperandA = typename ArchMmaOperator::FragmentA;
@@ -136,11 +136,11 @@ public:
 static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
 {
-const int warp_offset = warp_idx_n * Shape::kN;
+int const warp_offset = warp_idx_n * Shape::kN;
-const int quad = lane_idx / 4;
+int const quad = lane_idx / 4;
-const int thread_offset = warp_offset + quad;
+int const thread_offset = warp_offset + quad;
 pointer_scale_ = smem_scales.data() + thread_offset;
 if constexpr (hasZero(QuantOp))
 {
@@ -149,7 +149,7 @@ public:
 }

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
 : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
 {
 }
@@ -165,7 +165,7 @@ public:
 }

 CUTLASS_DEVICE
-void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
+void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
 using _MmaOperandB = typename ArchMmaOperator::FragmentB;
@@ -174,7 +174,7 @@ public:
 == FragmentDequantizedOperand::kElements,
 "");

-const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
+__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
 ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
 CUTLASS_PRAGMA_UNROLL
 for (int mma_n_iter = 0; mma_n_iter < MmaOperator::MmaIterations::kColumn; ++mma_n_iter)
@@ -222,7 +222,7 @@ public:

 CUTLASS_DEVICE
 void dequantize(
-FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag)
+FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
 {
 #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) && defined(ENABLE_BF16))
 using _MmaOperandB = typename ArchMmaOperator::FragmentB;
@@ -231,8 +231,8 @@ public:
 == FragmentDequantizedOperand::kElements,
 "");

-const __nv_bfloat16* scale_ptr = reinterpret_cast<const __nv_bfloat16*>(&scale_frag);
+__nv_bfloat16 const* scale_ptr = reinterpret_cast<__nv_bfloat16 const*>(&scale_frag);
-const __nv_bfloat16* zero_ptr = reinterpret_cast<const __nv_bfloat16*>(&zero_frag);
+__nv_bfloat16 const* zero_ptr = reinterpret_cast<__nv_bfloat16 const*>(&zero_frag);

 ExpandedMmaOperandB* operand_frag_ptr = reinterpret_cast<ExpandedMmaOperandB*>(&operand_frag);
 CUTLASS_PRAGMA_UNROLL
@@ -335,11 +335,11 @@ public:
 static constexpr WeightOnlyQuantOp QuantOp = QuantOp_;

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, TensorRef smem_zeros, int const warp_idx_n, int const lane_idx)
 {
-const int warp_offset = warp_idx_n * Shape::kN;
+int const warp_offset = warp_idx_n * Shape::kN;
-const int quad = lane_idx / 4;
+int const quad = lane_idx / 4;
-const int thread_offset = warp_offset + quad;
+int const thread_offset = warp_offset + quad;
 pointer_scale_ = smem_scales.data() + thread_offset;
 if constexpr (hasZero(QuantOp))
 {
@@ -348,7 +348,7 @@ public:
 }

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
 : MmaTensorOpDequantizer(smem_scales, TensorRef(), warp_idx_n, lane_idx)
 {
 }
@@ -364,7 +364,7 @@ public:
 }

 CUTLASS_DEVICE
-void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
+void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
 {
 using _MmaOperandB = typename ArchMmaOperator::FragmentB;
 using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
@@ -406,7 +406,7 @@ public:

 CUTLASS_DEVICE
 void dequantize(
-FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag, const FragmentScale& zero_frag)
+FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag, FragmentScale const& zero_frag)
 {
 using _MmaOperandB = typename ArchMmaOperator::FragmentB;
 using ExpandedMmaOperandB = Array<typename _MmaOperandB::Element, kExpansionFactor * _MmaOperandB::kElements>;
@@ -505,11 +505,11 @@ public:
 static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
 {
-const int warp_offset = warp_idx_n * Shape::kN;
+int const warp_offset = warp_idx_n * Shape::kN;
-const int base_col = lane_idx & 0xF8;
+int const base_col = lane_idx & 0xF8;
-const int thread_offset = warp_offset + base_col;
+int const thread_offset = warp_offset + base_col;
 pointer_ = smem_scales.data() + thread_offset;
 }

@@ -527,7 +527,7 @@ public:
 }

 CUTLASS_DEVICE
-void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
+void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
 {
 static_assert(FragmentScale::kElements == FragmentDequantizedOperand::kElements, "");

@@ -591,11 +591,11 @@ public:
 static_assert(QuantOp == WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, "");

 CUTLASS_DEVICE
-MmaTensorOpDequantizer(TensorRef smem_scales, const int warp_idx_n, const int lane_idx)
+MmaTensorOpDequantizer(TensorRef smem_scales, int const warp_idx_n, int const lane_idx)
 {
-const int warp_offset = warp_idx_n * Shape::kN;
+int const warp_offset = warp_idx_n * Shape::kN;
-const int base_col = lane_idx & 0xF8 + lane_idx % 4;
+int const base_col = lane_idx & 0xF8 + lane_idx % 4;
-const int thread_offset = warp_offset + base_col;
+int const thread_offset = warp_offset + base_col;
 pointer_ = smem_scales.data() + thread_offset;
 }

@@ -617,7 +617,7 @@ public:
 }

 CUTLASS_DEVICE
-void dequantize(FragmentDequantizedOperand& operand_frag, const FragmentScale& scale_frag)
+void dequantize(FragmentDequantizedOperand& operand_frag, FragmentScale const& scale_frag)
 {
 using MmaOperandB = typename ArchMmaOperator::FragmentB;
 static constexpr int total_n_mmas = 2 * TileNIterations;
@@ -167,8 +167,8 @@ public:

 static constexpr int THREADS_PER_ROW = Shape::kColumn / kAlignment;

-const int thread_row = thread_id / THREADS_PER_ROW;
+int const thread_row = thread_id / THREADS_PER_ROW;
-const int thread_col = thread_id % THREADS_PER_ROW;
+int const thread_col = thread_id % THREADS_PER_ROW;

 const LongIndex thread_row_byte_offset = thread_row * params_.stride_ * sizeof_bits<Element>::value / 8;
 const LongIndex thread_col_byte_offset = thread_col * kAlignment * sizeof_bits<Element>::value / 8;
@@ -182,11 +182,11 @@ public:
 // a given iteration. The same threads will be responsible for issues reads since the number of scales
 // read in a given iteration is a constant. Therefore, we should never have to update is_valid_
 // outside of the constructor.
-const int global_row = threadblock_offset.row() + thread_row;
+int const global_row = threadblock_offset.row() + thread_row;
-const int global_col = threadblock_offset.column() + thread_col * kAlignment;
+int const global_col = threadblock_offset.column() + thread_col * kAlignment;

-const bool row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
+bool const row_in_bounds = global_row < extent.row() && thread_row < Shape::kRow;
-const bool col_in_bounds = global_col < extent.column();
+bool const col_in_bounds = global_col < extent.column();

 is_valid_ = row_in_bounds && col_in_bounds;
 }
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:4201c7241d53298ca52d4f1447cc9cbc4024f63b42a24cbcff82192cc10bed67
+oid sha256:e1cdcabfbc5115c0d3228c567800d2706f1bc9e3752aaaa8148bcfe83be2c08c
-size 576098
+size 716756
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:2960feb2c7ad941a473408e2f6fd8c324f60f6af3c4d8f11217c676fd830e4cb
+oid sha256:ea48a79b211bc9857e7a881d6b9bc22580280e1d7cf3b30d6613466f4f440f8f
-size 578660
+size 721934
@@ -1,3 +1,3 @@
-8a8d6505d9ef62cb2eeb8c75a5ee5bbb libtensorrt_llm_executor_static.a
+56853a19cf213aa5330ea087c9d86a60 libtensorrt_llm_executor_static.a
-e3b8edc619c99a7f125fe81bc8554ff0 libtensorrt_llm_executor_static.pre_cxx11.a
+213487d55c816a1987aa79547091068f libtensorrt_llm_executor_static.pre_cxx11.a
-230623fa285048a2de5c54c2cc0f364fb9f2c559 commit
+741fb083cc42933439ae54557b177b6d7064da4f commit
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:cde295fa290b15b3d76b8e8b2cc435d7fceb2f456d8cb4d9b22ee2cf3ddbd344
+oid sha256:499f3aac1b98c5b411f1dacdddf8521b2b1f600388b44e6f7aab5b3f0cdf1280
-size 588504
+size 721366
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:54ac66f3555bff4ed28ba0352bcb4a0f541346592cf109b491071b6374e5238c
+oid sha256:9c2c7e84be6b0e8baf296196ee9d7e84509bda2630ce3ada8a39dc498713ff48
-size 562260
+size 700000
@@ -1,2 +1,2 @@
-ee96c6e2742539da0e8d732635f84449 libtensorrt_llm_executor_static.a
+dcca3b095dad76dac36611be6104f011 libtensorrt_llm_executor_static.a
-9154564ed926ffbcdb83e7eac3504fa0 libtensorrt_llm_executor_static.pre_cxx11.a
+6cae7ce493704f7ad8d724cf8a538e2c libtensorrt_llm_executor_static.pre_cxx11.a
@ -25,9 +25,9 @@ namespace kernels
|
|||||||
{
|
{
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
__global__ void ban_repeat_ngram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||||
const int** parent_ids_buf, const int* batch_slots, int batch_size, int beam_width, int max_seq_len,
|
int const** parent_ids_buf, int const* batch_slots, int batch_size, int beam_width, int max_seq_len,
|
||||||
const int* no_repeat_ngram_size_buf, int vocab_size_padded, const int* sequence_lengths)
|
int const* no_repeat_ngram_size_buf, int vocab_size_padded, int const* sequence_lengths)
|
||||||
{
|
{
|
||||||
/**
|
/**
|
||||||
* Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching
|
* Find subsequences that match the last (ngram_size - 1) generated tokens. The next-tokens of those matching
|
||||||
@ -46,13 +46,13 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
|
|||||||
* in-bound positions only. For leftside out-of-boundary tokens, access by global memory.
|
* in-bound positions only. For leftside out-of-boundary tokens, access by global memory.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
const int output_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
int const output_idx = blockIdx.x * blockDim.x + threadIdx.x;
|
||||||
const int local_batch_idx = blockIdx.y / beam_width;
|
int const local_batch_idx = blockIdx.y / beam_width;
|
||||||
auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx;
|
auto const batch_slot = batch_slots != nullptr ? batch_slots[local_batch_idx] : local_batch_idx;
|
||||||
const int beam_idx = blockIdx.y % beam_width;
|
int const beam_idx = blockIdx.y % beam_width;
|
||||||
const bool beam_search = beam_width > 1;
|
bool const beam_search = beam_width > 1;
|
||||||
const int no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
|
int const no_repeat_ngram_size = no_repeat_ngram_size_buf[batch_slot];
|
||||||
const int step = sequence_lengths[batch_slot];
|
int const step = sequence_lengths[batch_slot];
|
||||||
|
|
||||||
// case 1: ngram_size == 0 --> this means no ngram limit
|
// case 1: ngram_size == 0 --> this means no ngram limit
|
||||||
// case 2: generated length must be greater than ngram_size to do ngram check
|
// case 2: generated length must be greater than ngram_size to do ngram check
|
||||||
@ -133,9 +133,9 @@ __global__ void ban_repeat_ngram(T* logits, const int** output_ids_buf, const Fi
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||||
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
|
int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
|
||||||
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
|
int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream)
|
||||||
{
|
{
|
||||||
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
|
// each input in the local batch can have different no_repeat_ngram_size. Use max for shmem allocation
|
||||||
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
|
// getting the max of current batch and allocate shmem as needed is ideal. But here the ngram_buf is on GPU, while
|
||||||
|
|||||||
@ -26,9 +26,9 @@ namespace kernels
|
|||||||
{
|
{
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeBanRepeatNgram(T* logits, const int** output_ids_buf, const FinishedState* finished_buf,
|
void invokeBanRepeatNgram(T* logits, int const** output_ids_buf, FinishedState const* finished_buf,
|
||||||
const int** parent_ids_buf, const int* batch_slot, const int* sequence_lengths, int batch_size, int beam_width,
|
int const** parent_ids_buf, int const* batch_slot, int const* sequence_lengths, int batch_size, int beam_width,
|
||||||
int max_seq_len, const int* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
|
int max_seq_len, int const* no_repeat_ngram_size_buf, int vocab_size_padded, size_t max_step, cudaStream_t stream);
|
||||||
|
|
||||||
} // namespace kernels
|
} // namespace kernels
|
||||||
} // namespace tensorrt_llm
|
} // namespace tensorrt_llm
|
||||||
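The `ban_repeat_ngram` kernel above bans tokens that would complete an n-gram already present in the generated sequence: the last `ngram_size - 1` generated tokens are compared against every earlier window, and the logit of each token that followed a matching window is pushed to negative infinity. A simplified single-sequence CPU sketch of that idea (illustrative only; it ignores beam search, batch slots, and the shared-memory staging the kernel relies on):

```
#include <limits>
#include <vector>

// Ban tokens that would repeat an existing n-gram of size `ngram_size` in one sequence.
void banRepeatNgram(std::vector<float>& logits, std::vector<int> const& output_ids, int ngram_size)
{
    int const step = static_cast<int>(output_ids.size());
    if (ngram_size == 0 || step <= ngram_size)
    {
        return; // no n-gram limit, or not enough tokens generated yet
    }

    // Compare the trailing (ngram_size - 1) tokens against every earlier window.
    for (int start = 0; start + ngram_size - 1 < step; ++start)
    {
        bool match = true;
        for (int i = 0; i < ngram_size - 1; ++i)
        {
            if (output_ids[start + i] != output_ids[step - (ngram_size - 1) + i])
            {
                match = false;
                break;
            }
        }
        if (match)
        {
            int const banned_token = output_ids[start + ngram_size - 1];
            logits[banned_token] = -std::numeric_limits<float>::infinity();
        }
    }
}

int main()
{
    std::vector<float> logits(10, 0.f);
    std::vector<int> ids = {3, 5, 7, 3, 5}; // trailing 2-gram (3, 5) already seen at position 0
    banRepeatNgram(logits, ids, /*ngram_size=*/3);
    return logits[7] == -std::numeric_limits<float>::infinity() ? 0 : 1; // token 7 followed (3, 5) before
}
```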
|
|||||||
@ -49,8 +49,8 @@ __device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float
|
|||||||
|
|
||||||
template <typename T, int MAX_K, int THREADBLOCK_SIZE>
|
template <typename T, int MAX_K, int THREADBLOCK_SIZE>
|
||||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||||
void beam_topK_kernel(const T* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, const bool* finished,
|
void beam_topK_kernel(T const* log_probs, int* topk_tmp_id_buf, T* topk_tmp_val_buf, bool const* finished,
|
||||||
const int* sequence_lengths, const int vocab_size, T diversity_rate, float length_penalty)
|
int const* sequence_lengths, int const vocab_size, T diversity_rate, float length_penalty)
|
||||||
{
|
{
|
||||||
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
|
typedef cub::BlockReduce<TopK<T, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
@ -59,7 +59,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
|||||||
int block_id = blockIdx.x; // batch beam index.
|
int block_id = blockIdx.x; // batch beam index.
|
||||||
TopK<T, MAX_K> partial;
|
TopK<T, MAX_K> partial;
|
||||||
|
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
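`beam_topK_kernel` applies `apply_length_penalty` before comparing candidate scores, and the `BeamHypotheses` comments later in this commit describe the normalization as `cum_log / (length ^ length_penalty)`. The sketch below illustrates that formula only; the repository's actual `apply_length_penalty` lives in a header that is not part of this diff, so treat the function here as an assumption-labeled stand-in:

```
#include <cmath>
#include <cstdio>

// normed_score = cum_log_prob / (length ^ length_penalty)
// Since cumulative log-probs are negative, a larger penalty favors longer beams; 0 disables it.
float applyLengthPenalty(float cum_log_prob, int length, float length_penalty)
{
    if (length_penalty == 0.0f || length == 1)
    {
        return cum_log_prob;
    }
    return cum_log_prob / std::pow(static_cast<float>(length), length_penalty);
}

int main()
{
    float const cum_log_prob = -6.0f; // sum of token log-probabilities along a beam
    std::printf("no penalty:  %f\n", applyLengthPenalty(cum_log_prob, 12, 0.0f));
    std::printf("penalty=1.0: %f\n", applyLengthPenalty(cum_log_prob, 12, 1.0f));
    std::printf("penalty=2.0: %f\n", applyLengthPenalty(cum_log_prob, 12, 2.0f));
    return 0;
}
```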
@ -101,7 +101,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
|||||||
{
|
{
|
||||||
int thread_id = threadIdx.x;
|
int thread_id = threadIdx.x;
|
||||||
int block_id = blockIdx.x;
|
int block_id = blockIdx.x;
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
TopK<T, MAX_K> partial;
|
TopK<T, MAX_K> partial;
|
||||||
if (thread_id == 0)
|
if (thread_id == 0)
|
||||||
@ -136,7 +136,7 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
|||||||
int tid = threadIdx.x;
|
int tid = threadIdx.x;
|
||||||
int bid = blockIdx.x;
|
int bid = blockIdx.x;
|
||||||
TopK<T, MAX_K> partial;
|
TopK<T, MAX_K> partial;
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
@ -167,32 +167,32 @@ __launch_bounds__(THREADBLOCK_SIZE) __global__
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
||||||
__global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
__global__ void topk_stage_1_opt3(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||||
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size,
|
T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
|
||||||
const float length_penalty, const int* end_ids)
|
float const length_penalty, int const* end_ids)
|
||||||
{
|
{
|
||||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
int const tid = threadIdx.x;
|
||||||
const int bid = blockIdx.x;
|
int const bid = blockIdx.x;
|
||||||
|
|
||||||
const int row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
|
int const row_id = bid / BLOCKS_PER_BEAM_; // row id for log_probs (batchbeam index)
|
||||||
const int block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
|
int const block_lane = bid % BLOCKS_PER_BEAM_; // block id for a beam
|
||||||
const int tmp_log_buf_index = row_id * vocab_size;
|
int const tmp_log_buf_index = row_id * vocab_size;
|
||||||
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
|
int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM_ * k + block_lane * k;
|
||||||
TopK_2<T> partial;
|
TopK_2<T> partial;
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
|
|
||||||
if (finished != nullptr && finished[row_id] == true)
|
if (finished != nullptr && finished[row_id] == true)
|
||||||
{
|
{
|
||||||
if (tid < k)
|
if (tid < k)
|
||||||
{
|
{
|
||||||
const int index = tmp_topk_buf_index + tid;
|
int const index = tmp_topk_buf_index + tid;
|
||||||
if (block_lane == 0 && tid == 0)
|
if (block_lane == 0 && tid == 0)
|
||||||
{
|
{
|
||||||
const int end_id = end_ids[row_id / k];
|
int const end_id = end_ids[row_id / k];
|
||||||
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
|
topk_tmp_id_buf[index] = tmp_log_buf_index + end_id;
|
||||||
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
|
topk_tmp_val_buf[index] = log_probs[tmp_log_buf_index + end_id];
|
||||||
}
|
}
|
||||||
@ -226,7 +226,7 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
|
|||||||
|
|
||||||
if (tid == 0)
|
if (tid == 0)
|
||||||
{
|
{
|
||||||
const int index = tmp_topk_buf_index + ite;
|
int const index = tmp_topk_buf_index + ite;
|
||||||
topk_tmp_id_buf[index] = total.p;
|
topk_tmp_id_buf[index] = total.p;
|
||||||
topk_tmp_val_buf[index] = total.u;
|
topk_tmp_val_buf[index] = total.u;
|
||||||
tmp_log_probs[total.p] = -MAX_T_VAL;
|
tmp_log_probs[total.p] = -MAX_T_VAL;
|
||||||
@ -236,15 +236,15 @@ __global__ void topk_stage_1_opt3(const T* __restrict log_probs, T* tmp_log_prob
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
template <typename T, int BLOCK_SIZE_, int BLOCKS_PER_BEAM_>
|
||||||
__global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
__global__ void topk_stage_2_opt3(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||||
BeamHypotheses beam_hyps, const int* end_ids, const int vocab_size, const int k)
|
BeamHypotheses beam_hyps, int const* end_ids, int const vocab_size, int const k)
|
||||||
{
|
{
|
||||||
const int size = k * k * BLOCKS_PER_BEAM_;
|
int const size = k * k * BLOCKS_PER_BEAM_;
|
||||||
const int tid = threadIdx.x;
|
int const tid = threadIdx.x;
|
||||||
const int batch_id = blockIdx.x;
|
int const batch_id = blockIdx.x;
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||||
|
|
||||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE_> BlockReduce;
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
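The two kernels touched here form a two-stage top-k: `topk_stage_1_opt3` lets each of `BLOCKS_PER_BEAM_` blocks scan its slice of the vocabulary and write out `k` local candidates, and `topk_stage_2_opt3` then works over the `k * k * BLOCKS_PER_BEAM_` candidates per batch entry (the `size` constant above). The host-side sketch below shows only the split/merge pattern with the standard library; it leaves out CUB, shared memory, and the beam-hypothesis bookkeeping of the real kernels:

```
#include <algorithm>
#include <cstdio>
#include <functional>
#include <utility>
#include <vector>

using Candidate = std::pair<float, int>; // (log_prob, token id)

// Stage 1: each "block" keeps the k best entries of its own slice of the scores.
std::vector<Candidate> topkStage1(std::vector<float> const& scores, int k, int blocks_per_beam)
{
    std::vector<Candidate> partial;
    int const slice = static_cast<int>(scores.size() + blocks_per_beam - 1) / blocks_per_beam;
    for (int b = 0; b < blocks_per_beam; ++b)
    {
        std::vector<Candidate> local;
        int const end = std::min((b + 1) * slice, static_cast<int>(scores.size()));
        for (int i = b * slice; i < end; ++i)
        {
            local.emplace_back(scores[i], i);
        }
        std::partial_sort(local.begin(), local.begin() + std::min<size_t>(k, local.size()), local.end(),
            std::greater<Candidate>());
        local.resize(std::min<size_t>(k, local.size()));
        partial.insert(partial.end(), local.begin(), local.end());
    }
    return partial; // up to k * blocks_per_beam candidates
}

// Stage 2: merge the per-block candidates into the final top-k.
std::vector<Candidate> topkStage2(std::vector<Candidate> partial, int k)
{
    std::partial_sort(partial.begin(), partial.begin() + std::min<size_t>(k, partial.size()), partial.end(),
        std::greater<Candidate>());
    partial.resize(std::min<size_t>(k, partial.size()));
    return partial;
}

int main()
{
    std::vector<float> scores = {-3.1f, -0.2f, -5.0f, -1.4f, -0.9f, -2.2f, -0.1f, -4.4f};
    auto final_topk = topkStage2(topkStage1(scores, /*k=*/2, /*blocks_per_beam=*/2), /*k=*/2);
    for (auto const& c : final_topk)
    {
        std::printf("token %d: %f\n", c.second, c.first);
    }
    return 0;
}
```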
@ -263,7 +263,7 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (beam_hyps.num_beams != nullptr)
|
if (beam_hyps.num_beams != nullptr)
|
||||||
{
|
{
|
||||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||||
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
||||||
{
|
{
|
||||||
// initialize the buffer
|
// initialize the buffer
|
||||||
@ -304,9 +304,9 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||||
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
int const num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||||
int beam_idx = num_beam;
|
int beam_idx = num_beam;
|
||||||
// If there are beam_width finished sentences, check that the score of
|
// If there are beam_width finished sentences, check that the score of
|
||||||
// selected candidate is higher than min_normed_score or not. If
|
// selected candidate is higher than min_normed_score or not. If
|
||||||
@ -345,20 +345,20 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||||
* beam_hyps.max_seq_len;
|
* beam_hyps.max_seq_len;
|
||||||
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
||||||
|
|
||||||
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
||||||
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
||||||
{
|
{
|
||||||
const int src_idx = j * beam_hyps.batch_size * k
|
int const src_idx = j * beam_hyps.batch_size * k
|
||||||
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
||||||
|
|
||||||
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
||||||
prev_id = beam_hyps.parent_ids_src[src_idx];
|
prev_id = beam_hyps.parent_ids_src[src_idx];
|
||||||
}
|
}
|
||||||
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
|
int const tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
||||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||||
beam_hyps.min_normed_scores[global_batch_idx]
|
beam_hyps.min_normed_scores[global_batch_idx]
|
||||||
@ -389,21 +389,21 @@ __global__ void topk_stage_2_opt3(const int* __restrict topk_tmp_id_buf, T* topk
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
||||||
__global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
__global__ void topk_stage_1_opt2_general(T const* __restrict log_probs, T* tmp_log_probs, int* topk_tmp_id_buf,
|
||||||
T* topk_tmp_val_buf, const bool* finished, const int* sequence_lengths, const int k, const int vocab_size,
|
T* topk_tmp_val_buf, bool const* finished, int const* sequence_lengths, int const k, int const vocab_size,
|
||||||
const float length_penalty)
|
float const length_penalty)
|
||||||
{
|
{
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
|
|
||||||
const int tid = threadIdx.x;
|
int const tid = threadIdx.x;
|
||||||
const int bid = blockIdx.x;
|
int const bid = blockIdx.x;
|
||||||
const int row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
|
int const row_id = bid / BLOCKS_PER_BEAM; // row id for log_probs
|
||||||
const int block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
|
int const block_lane = bid % BLOCKS_PER_BEAM; // block id for a beam
|
||||||
const int tmp_log_buf_index = row_id * vocab_size;
|
int const tmp_log_buf_index = row_id * vocab_size;
|
||||||
const int tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
|
int const tmp_topk_buf_index = row_id * BLOCKS_PER_BEAM * k + block_lane * k;
|
||||||
TopK_2<T> partial;
|
TopK_2<T> partial;
|
||||||
|
|
||||||
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM)
|
for (int elem_id = tid + block_lane * BLOCK_SIZE; elem_id < vocab_size; elem_id += BLOCK_SIZE * BLOCKS_PER_BEAM)
|
||||||
@ -426,7 +426,7 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
|
|||||||
|
|
||||||
if (tid == 0)
|
if (tid == 0)
|
||||||
{
|
{
|
||||||
const int index = tmp_topk_buf_index + ite;
|
int const index = tmp_topk_buf_index + ite;
|
||||||
topk_tmp_id_buf[index] = total.p;
|
topk_tmp_id_buf[index] = total.p;
|
||||||
topk_tmp_val_buf[index] = total.u;
|
topk_tmp_val_buf[index] = total.u;
|
||||||
tmp_log_probs[total.p] = -MAX_T_VAL;
|
tmp_log_probs[total.p] = -MAX_T_VAL;
|
||||||
@ -436,15 +436,15 @@ __global__ void topk_stage_1_opt2_general(const T* __restrict log_probs, T* tmp_
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
template <typename T, int BLOCK_SIZE, int BLOCKS_PER_BEAM>
|
||||||
__global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
__global__ void topk_stage_2_opt2_general(int const* __restrict topk_tmp_id_buf, T* topk_tmp_val_buf, int* ids,
|
||||||
BeamHypotheses beam_hyps, const int* end_ids, const int k, const int vocab_size)
|
BeamHypotheses beam_hyps, int const* end_ids, int const k, int const vocab_size)
|
||||||
{
|
{
|
||||||
const int size = k * k * BLOCKS_PER_BEAM;
|
int const size = k * k * BLOCKS_PER_BEAM;
|
||||||
const int tid = threadIdx.x;
|
int const tid = threadIdx.x;
|
||||||
const int batch_id = blockIdx.x;
|
int const batch_id = blockIdx.x;
|
||||||
const bool IS_FP16 = std::is_same<T, half>::value;
|
bool const IS_FP16 = std::is_same<T, half>::value;
|
||||||
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
const T MAX_T_VAL = (IS_FP16) ? HALF_FLT_MAX : FLT_MAX;
|
||||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[batch_id]};
|
||||||
|
|
||||||
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
typedef cub::BlockReduce<TopK_2<T>, BLOCK_SIZE> BlockReduce;
|
||||||
__shared__ typename BlockReduce::TempStorage temp_storage;
|
__shared__ typename BlockReduce::TempStorage temp_storage;
|
||||||
@ -463,7 +463,7 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
|||||||
__syncthreads();
|
__syncthreads();
|
||||||
if (beam_hyps.num_beams != nullptr)
|
if (beam_hyps.num_beams != nullptr)
|
||||||
{
|
{
|
||||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||||
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
if (beam_hyps.num_beams[global_batch_idx] == 0 && tid == 0)
|
||||||
{
|
{
|
||||||
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
beam_hyps.min_normed_scores[global_batch_idx] = FLT_MAX;
|
||||||
@ -503,9 +503,9 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
|||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
const int global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
int const global_batch_idx = beam_hyps.ite * beam_hyps.local_batch_size + batch_id;
|
||||||
const float normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
float const normed_score = apply_length_penalty(s_val[total.p], beam_hyps.step, length_penalty);
|
||||||
const int num_beam = beam_hyps.num_beams[global_batch_idx];
|
int const num_beam = beam_hyps.num_beams[global_batch_idx];
|
||||||
int beam_idx = num_beam;
|
int beam_idx = num_beam;
|
||||||
// If there are beam_width finished sentences, check that the score of
|
// If there are beam_width finished sentences, check that the score of
|
||||||
// selected candidate is higher than min_normed_score or not. If
|
// selected candidate is higher than min_normed_score or not. If
|
||||||
@ -544,20 +544,20 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
const int tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
int const tgt_id_offset = ((batch_id + beam_hyps.ite * beam_hyps.local_batch_size) * k + beam_idx)
|
||||||
* beam_hyps.max_seq_len;
|
* beam_hyps.max_seq_len;
|
||||||
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
beam_hyps.output_ids_tgt[tgt_id_offset + beam_hyps.step] = end_ids[batch_id];
|
||||||
|
|
||||||
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
int prev_id = (topk_tmp_id_buf[batch_id * size + total.p] / vocab_size) % k;
|
||||||
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
for (int j = beam_hyps.step - 1; j >= 0; j--)
|
||||||
{
|
{
|
||||||
const int src_idx = j * beam_hyps.batch_size * k
|
int const src_idx = j * beam_hyps.batch_size * k
|
||||||
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
+ beam_hyps.ite * beam_hyps.local_batch_size * k + batch_id * k + prev_id;
|
||||||
|
|
||||||
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
beam_hyps.output_ids_tgt[tgt_id_offset + j] = beam_hyps.output_ids_src[src_idx];
|
||||||
prev_id = beam_hyps.parent_ids_src[src_idx];
|
prev_id = beam_hyps.parent_ids_src[src_idx];
|
||||||
}
|
}
|
||||||
const int tgt_beam_idx = global_batch_idx * k + beam_idx;
|
int const tgt_beam_idx = global_batch_idx * k + beam_idx;
|
||||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = beam_hyps.step;
|
||||||
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
beam_hyps.normed_scores[tgt_beam_idx] = normed_score;
|
||||||
beam_hyps.min_normed_scores[global_batch_idx]
|
beam_hyps.min_normed_scores[global_batch_idx]
|
||||||
@ -613,18 +613,18 @@ __global__ void topk_stage_2_opt2_general(const int* __restrict topk_tmp_id_buf,
|
|||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
||||||
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width,
|
bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
|
||||||
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids,
|
int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
|
||||||
cudaStream_t stream)
|
cudaStream_t stream)
|
||||||
{
|
{
|
||||||
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a
|
// log_probs: (batch, beam, vocab) cumulative log_probs of beams ending with a
|
||||||
// token.
|
// token.
|
||||||
const int vocab_size = vocab_size_padded_;
|
int const vocab_size = vocab_size_padded_;
|
||||||
// Beam size should be less than or equal to vocab size.
|
// Beam size should be less than or equal to vocab size.
|
||||||
assert(beam_width <= vocab_size);
|
assert(beam_width <= vocab_size);
|
||||||
// Beam search needs the sequence lengths of beams to apply length penalty.
|
// Beam search needs the sequence lengths of beams to apply length penalty.
|
||||||
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
|
assert(length_penalty == 0.0f || sequence_lengths != nullptr);
|
||||||
const int max_block_per_beam = 8;
|
int const max_block_per_beam = 8;
|
||||||
int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float
|
int temp_log_probs_buf_size = batch_size * beam_width * vocab_size; // type float
|
||||||
int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int
|
int topk_tmp_ids_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type int
|
||||||
int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float
|
int topk_tmp_val_buf_size = batch_size * beam_width * beam_width * max_block_per_beam; // type float
|
||||||
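`invokeTopkBeamSearch` sizes its scratch space from the three buffers listed just above: a tiled copy of the log-probs with `batch_size * beam_width * vocab_size` floats, plus id and value candidate buffers of `batch_size * beam_width * beam_width * max_block_per_beam` entries each, with `max_block_per_beam` fixed at 8. A back-of-the-envelope sketch of those element counts for illustrative shapes (the real function also rounds the buffers up for alignment, which is ignored here):

```
#include <cstdint>
#include <cstdio>

int main()
{
    // Illustrative shapes; the expressions mirror the buffer sizes in invokeTopkBeamSearch.
    int const batch_size = 8;
    int const beam_width = 4;
    int const vocab_size = 32000; // padded vocabulary size
    int const max_block_per_beam = 8;

    std::int64_t const temp_log_probs_elems = std::int64_t(batch_size) * beam_width * vocab_size;
    std::int64_t const topk_tmp_ids_elems = std::int64_t(batch_size) * beam_width * beam_width * max_block_per_beam;
    std::int64_t const topk_tmp_val_elems = topk_tmp_ids_elems;

    std::int64_t const bytes = temp_log_probs_elems * sizeof(float)
        + topk_tmp_ids_elems * sizeof(int)
        + topk_tmp_val_elems * sizeof(float);

    std::printf("temp log-probs: %lld floats\n", static_cast<long long>(temp_log_probs_elems));
    std::printf("top-k ids/vals: %lld entries each\n", static_cast<long long>(topk_tmp_ids_elems));
    std::printf("approx. workspace: %.2f MiB (before padding/alignment)\n", bytes / (1024.0 * 1024.0));
    return 0;
}
```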
@ -685,13 +685,13 @@ void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs,
|
|||||||
#undef CASE_K_DIV
|
#undef CASE_K_DIV
|
||||||
|
|
||||||
template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids,
|
template void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, float* log_probs, int* ids,
|
||||||
BeamHypotheses* beam_hyps, const bool* finished, const int* sequence_lengths, const int batch_size,
|
BeamHypotheses* beam_hyps, bool const* finished, int const* sequence_lengths, int const batch_size,
|
||||||
const int beam_width, const int vocab_size_padded_, const float diversity_rate, const float length_penalty,
|
int const beam_width, int const vocab_size_padded_, float const diversity_rate, float const length_penalty,
|
||||||
const int* end_ids, cudaStream_t stream);
|
int const* end_ids, cudaStream_t stream);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output,
|
__global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output,
|
||||||
const int* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model)
|
int const* sequence_length, const uint32_t batch_size, const uint32_t beam_width, const uint32_t d_model)
|
||||||
{
|
{
|
||||||
if (blockIdx.x == 0)
|
if (blockIdx.x == 0)
|
||||||
{
|
{
|
||||||
@ -711,7 +711,7 @@ __global__ void tileEncoderResults(T* tiled_output, int* tiled_sequence_length,
|
|||||||
}
|
}
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const T* output, const int* sequence_length,
|
void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, T const* output, int const* sequence_length,
|
||||||
const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model,
|
const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len, const size_t d_model,
|
||||||
cudaStream_t stream)
|
cudaStream_t stream)
|
||||||
{
|
{
|
||||||
@ -739,30 +739,30 @@ void invokeTileEncoderResults(T* tiled_output, int* tiled_sequence_length, const
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, const float* output,
|
template void invokeTileEncoderResults(float* tiled_output, int* tiled_sequence_length, float const* output,
|
||||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||||
const size_t d_model, cudaStream_t stream);
|
const size_t d_model, cudaStream_t stream);
|
||||||
|
|
||||||
template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, const half* output,
|
template void invokeTileEncoderResults(half* tiled_output, int* tiled_sequence_length, half const* output,
|
||||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||||
const size_t d_model, cudaStream_t stream);
|
const size_t d_model, cudaStream_t stream);
|
||||||
|
|
||||||
template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, const half2* output,
|
template void invokeTileEncoderResults(half2* tiled_output, int* tiled_sequence_length, half2 const* output,
|
||||||
const int* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
int const* sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||||
const size_t d_model, cudaStream_t stream);
|
const size_t d_model, cudaStream_t stream);
|
||||||
#ifdef ENABLE_BF16
|
#ifdef ENABLE_BF16
|
||||||
template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length,
|
template void invokeTileEncoderResults(__nv_bfloat16* tiled_output, int* tiled_sequence_length,
|
||||||
const __nv_bfloat16* output, const int* sequence_length, const size_t batch_size, const size_t beam_width,
|
__nv_bfloat16 const* output, int const* sequence_length, const size_t batch_size, const size_t beam_width,
|
||||||
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
|
const size_t mem_max_seq_len, const size_t d_model, cudaStream_t stream);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished,
|
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished,
|
||||||
const float* cum_log_probs, const int batch_size, const int beam_width)
|
float const* cum_log_probs, int const batch_size, int const beam_width)
|
||||||
{
|
{
|
||||||
const int bid = blockIdx.x;
|
int const bid = blockIdx.x;
|
||||||
const int tgt_start_idx = beam_hyps.num_beams[bid];
|
int const tgt_start_idx = beam_hyps.num_beams[bid];
|
||||||
const int max_seq_len{beam_hyps.max_seq_len};
|
int const max_seq_len{beam_hyps.max_seq_len};
|
||||||
const float length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
|
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
|
||||||
if (beam_hyps.is_done[bid])
|
if (beam_hyps.is_done[bid])
|
||||||
{
|
{
|
||||||
return;
|
return;
|
||||||
@ -771,10 +771,10 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
|
|||||||
{
|
{
|
||||||
if (threadIdx.x == 0)
|
if (threadIdx.x == 0)
|
||||||
{
|
{
|
||||||
const int src_beam_idx = bid * beam_width + beam_idx;
|
int const src_beam_idx = bid * beam_width + beam_idx;
|
||||||
const int tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
|
int const tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
|
||||||
|
|
||||||
const int last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
|
int const last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
|
||||||
|
|
||||||
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx]
|
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx]
|
||||||
= beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx];
|
= beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx];
|
||||||
@ -810,8 +810,8 @@ __global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedSta
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||||
const int batch_size, const int beam_width, cudaStream_t stream)
|
int const batch_size, int const beam_width, cudaStream_t stream)
|
||||||
{
|
{
|
||||||
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
|
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
|
||||||
}
|
}
|
||||||
|
|||||||
@ -35,57 +35,64 @@ namespace kernels
|
|||||||
// After we collect `beam_width` beams, we will sort them by their norm_scores.
|
// After we collect `beam_width` beams, we will sort them by their norm_scores.
|
||||||
struct BeamHypotheses
|
struct BeamHypotheses
|
||||||
{
|
{
|
||||||
// TODO: simplify the pointers
|
// BS: batch_size
|
||||||
// Pointers initialized in function prepareOutputs in gptDecoder.cpp
|
// BM: beam_width
|
||||||
bool* is_done{nullptr}; // [batchSize], whether the batch is finished
|
// mSL: max_seq_length
|
||||||
const int* input_lengths{nullptr}; // [batchSize]
|
// %%: parameter name when we call [generation.py] dynamic_decoder.forward
|
||||||
float* cum_log_probs{nullptr}; // [batchSize, 2 * beamWidth], outputs.cum_log_probs->template getPtr<float>()
|
|
||||||
float* log_probs{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen], not used?
|
|
||||||
float* min_normed_scores{nullptr}; // [batchSize], worst normed scores for each batch
|
|
||||||
float* normed_scores{nullptr}; // [batchSize, 2 * beamWidth], cum_log / (length ^ length_penalty)
|
|
||||||
int* num_beams{nullptr}; // [batchSize], count of finished beams for each batch
|
|
||||||
int* output_ids_tgt{nullptr}; // [batchSize, 2 * beamWidth, maxSeqLen],
|
|
||||||
int* sequence_lengths_tgt{nullptr}; // [batchSize, 2 * beamWidth], different from sequence_lengths_src
|
|
||||||
|
|
||||||
// Pointers initialized in function invokeSoftMax in onlineBeamSearchLayer.cu
|
// Pointers initialized in these two functions:
|
||||||
const int* end_ids{nullptr}; // get from SoftmaxParams
|
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
|
||||||
const int* output_ids_src{nullptr}; // for gatherTree
|
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done
|
||||||
const int* parent_ids_src{nullptr}; // for gatherTree
|
float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs
|
||||||
const int** output_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
float* log_probs{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs
|
||||||
const int** parent_ids_src_ptr{nullptr}; // get from BeamSearchOutputParams for reading
|
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores
|
||||||
float* log_probs_src{nullptr}; // get from outputs.output_log_probs
|
float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores
|
||||||
int* sequence_lengths_src{nullptr}; // get from BeamSearchOutputParams
|
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams
|
||||||
// For reading in function invokeTopkSoftMax but reading and writing in function invokeUpdate
|
int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_is_done
|
||||||
int** output_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt
|
||||||
int** parent_ids_tgt_ptr{nullptr}; // get from BeamSearchOutputParams for writing
|
int const* input_lengths{nullptr}; // [BS*BM] %% context_length
|
||||||
|
|
||||||
// Other scalar values and buffers
|
// Pointers initialized in [onlineBeamSearchLayer.cu] invokeSoftMax:
|
||||||
int batch_size{0};
|
int const* end_ids{nullptr}; // [BS*BM] %% self.end_ids
|
||||||
int beam_width{0};
|
FinishedState* finished; // [BS*BM] %% self.finished
|
||||||
int ite{0};
|
float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||||
int local_batch_size{0};
|
float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
|
||||||
int max_seq_len{0};
|
int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer
|
||||||
int step{0}; // useless in online version of beam search
|
int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
|
||||||
int vocab_size{0};
|
int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] from [dynamicDecodeLayer.cpp]
|
||||||
float* diversity_rates{nullptr};
|
|
||||||
float* length_penalties{nullptr};
|
float* diversity_rates{nullptr}; // [BS] from SamplingConfig
|
||||||
int* early_stoppings{nullptr};
|
float* length_penalties{nullptr}; // [BS] from SamplingConfig
|
||||||
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs
|
int* early_stoppings{nullptr}; // [BS] from SamplingConfig
|
||||||
|
|
||||||
|
// Pointers for function gatherTree
|
||||||
|
int const* output_ids_src{nullptr}; //
|
||||||
|
int const* parent_ids_src{nullptr}; //
|
||||||
|
|
||||||
|
// Scalar values
|
||||||
|
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs, always true now
|
||||||
|
int batch_size{0}; //
|
||||||
|
int beam_width{0}; //
|
||||||
|
int ite{0}; // index of local_batch, always 0 if pp_size==1
|
||||||
|
int local_batch_size{0}; //
|
||||||
|
int max_seq_len{0}; //
|
||||||
|
int step{0}; // only used in [beamSearchTopkKernels.cu], always 0 in [onlineSoftmaxBeamsearchKernels*.cu.h]
|
||||||
|
int vocab_size{0}; // vocab_size_padded
|
||||||
};
|
};
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
||||||
const bool* finished, const int* sequence_lengths, const int batch_size, const int beam_width,
|
bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
|
||||||
const int vocab_size_padded_, const T diversity_rate, const float length_penalty, const int* end_ids,
|
int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
|
||||||
cudaStream_t stream);
|
cudaStream_t stream);
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, const T* encoder_output,
|
void invokeTileEncoderResults(T* tiled_encoder_output, int* tiled_encoder_sequence_length, T const* encoder_output,
|
||||||
const int* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
int const* encoder_sequence_length, const size_t batch_size, const size_t beam_width, const size_t mem_max_seq_len,
|
||||||
const size_t d_model, cudaStream_t stream);
|
const size_t d_model, cudaStream_t stream);
|
||||||
|
|
||||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, const FinishedState* finished, const float* cum_log_probs,
|
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||||
const int batch_size, const int beam_width, cudaStream_t stream);
|
int const batch_size, int const beam_width, cudaStream_t stream);
|
||||||
|
|
||||||
void invokeCopyBatchMajorToGeneralPtr(
|
void invokeCopyBatchMajorToGeneralPtr(
|
||||||
void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);
|
void* output_ids_ptr, int* output_ids, int batch_size, int beam_width, int max_seq_len, cudaStream_t stream);
|
||||||
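The reworked `BeamHypotheses` comments encode every buffer's shape with the shorthand BS (batch_size), BM (beam_width) and mSL (max_seq_length). As a concrete reading of that notation, the hedged sketch below sizes ordinary host vectors the same way; the real buffers are device allocations wired up by the decoder, not containers created like this:

```
#include <cstdio>
#include <vector>

int main()
{
    int const BS = 2;   // batch_size
    int const BM = 4;   // beam_width
    int const mSL = 16; // max_seq_length

    // Shapes follow the comments in the BeamHypotheses struct.
    std::vector<unsigned char> is_done(BS);             // [BS]
    std::vector<float> cum_log_probs(BS * 2 * BM);      // [BS, BM*2]
    std::vector<float> normed_scores(BS * 2 * BM);      // [BS, BM*2]
    std::vector<int> num_beams(BS);                     // [BS]
    std::vector<int> output_ids_tgt(BS * 2 * BM * mSL); // [BS, BM*2, mSL]
    std::vector<int> sequence_lengths_tgt(BS * 2 * BM); // [BS, BM*2]
    std::vector<float> log_probs_tiled(mSL * BS * BM);  // [mSL, BS, BM]

    // Row-major indexing that matches the [BS, BM*2, mSL] layout:
    auto token_index = [&](int batch, int hyp, int step) { return (batch * 2 * BM + hyp) * mSL + step; };
    output_ids_tgt[token_index(1, 5, 3)] = 42;

    std::printf("output_ids_tgt holds %zu ints\n", output_ids_tgt.size());
    return 0;
}
```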
|
|||||||
@ -58,13 +58,13 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
|
|||||||
else if (dtype == DATA_TYPE_INT32)
|
else if (dtype == DATA_TYPE_INT32)
|
||||||
{
|
{
|
||||||
int32_t inorm = static_cast<int32_t>(norm);
|
int32_t inorm = static_cast<int32_t>(norm);
|
||||||
alpha = reinterpret_cast<const uint32_t&>(inorm);
|
alpha = reinterpret_cast<uint32_t const&>(inorm);
|
||||||
}
|
}
|
||||||
else if (dtype == DATA_TYPE_BF16)
|
else if (dtype == DATA_TYPE_BF16)
|
||||||
{
|
{
|
||||||
// TODO HACK!! BF16 Outputs are computed in FP32 for FP8.
|
// TODO HACK!! BF16 Outputs are computed in FP32 for FP8.
|
||||||
// This is because cublas does not allow current FP32 output.
|
// This is because cublas does not allow current FP32 output.
|
||||||
alpha = reinterpret_cast<const uint32_t&>(norm);
|
alpha = reinterpret_cast<uint32_t const&>(norm);
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
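`set_alpha` stores the scale as a raw 32-bit word whose bit pattern is interpreted later according to the kernel's data type, which is why the value is first converted (for INT32) and then reinterpreted as `uint32_t const&`. The sketch below expresses the same bit-level packing with `memcpy`; the function names are placeholders, not the repository's API:

```
#include <cstdint>
#include <cstdio>
#include <cstring>

// Pack a float scale into a raw 32-bit word, bit for bit.
std::uint32_t packFloatBits(float norm)
{
    std::uint32_t alpha = 0;
    std::memcpy(&alpha, &norm, sizeof(alpha)); // same bits, different type
    return alpha;
}

// For INT32 outputs the value is converted first, then its bits are packed.
std::uint32_t packInt32Bits(float norm)
{
    std::int32_t const inorm = static_cast<std::int32_t>(norm);
    std::uint32_t alpha = 0;
    std::memcpy(&alpha, &inorm, sizeof(alpha));
    return alpha;
}

int main()
{
    std::printf("float bits of 1.0f: 0x%08x\n", packFloatBits(1.0f));  // 0x3f800000
    std::printf("int32 bits of 127 : 0x%08x\n", packInt32Bits(127.f)); // 0x0000007f
    return 0;
}
```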
@ -77,7 +77,7 @@ static inline void set_alpha(uint32_t& alpha, float norm, Data_type dtype)
|
|||||||
class FusedMHARunnerV2::mhaImpl
|
class FusedMHARunnerV2::mhaImpl
|
||||||
{
|
{
|
||||||
public:
|
public:
|
||||||
mhaImpl(const Data_type data_type, const int numHeads, const int headSize, const float qScaling, int sm_)
|
mhaImpl(const Data_type data_type, int const numHeads, int const headSize, float const qScaling, int sm_)
|
||||||
: mDataType(data_type)
|
: mDataType(data_type)
|
||||||
, mNumHeads(numHeads)
|
, mNumHeads(numHeads)
|
||||||
, mHeadSize(headSize)
|
, mHeadSize(headSize)
|
||||||
@ -105,17 +105,17 @@ public:
|
|||||||
|
|
||||||
// Shared setup function.
|
// Shared setup function.
|
||||||
template <typename Params>
|
template <typename Params>
|
||||||
void setup_params(Params& params, const int b, const int s_q, const int s_kv, const int sliding_window_size,
|
void setup_params(Params& params, int const b, int const s_q, int const s_kv, int const sliding_window_size,
|
||||||
const int total_seqlen, const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
|
int const total_seqlen, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||||
{
|
{
|
||||||
|
|
||||||
const float inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
|
float const inv_sqrt_scale = (1.f / (sqrtf(mHeadSize) * mQScaling));
|
||||||
// Note that we apply scales and bias in the order of
|
// Note that we apply scales and bias in the order of
|
||||||
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
|
// (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi
|
||||||
const float scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
|
float const scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
|
||||||
const float scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
|
float const scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
|
||||||
const float scale_softmax = 1.f; // Seems to be only required for int8
|
float const scale_softmax = 1.f; // Seems to be only required for int8
|
||||||
const float scale_bmm2 = 1.f;
|
float const scale_bmm2 = 1.f;
|
||||||
|
|
||||||
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
|
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
|
||||||
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
|
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
|
||||||
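`setup_params` computes one `1 / (sqrt(head_size) * q_scaling)` factor and, depending on `scale_alibi`, folds it into the logits either before or after the ALiBi bias is added, following the order `(bmm1_output * scale_bmm1 + alibi) * scale_after_alibi` stated in the comment. A scalar sketch of that decision (illustrative only; the kernels apply it per element with packed math):

```
#include <cmath>
#include <cstdio>

// Apply softmax scaling either before or after the ALiBi bias,
// matching (bmm1_output * scale_bmm1 + alibi) * scale_after_alibi.
float scaledScore(float bmm1_output, float alibi_bias, int head_size, float q_scaling, bool scale_alibi)
{
    float const inv_sqrt_scale = 1.f / (std::sqrt(static_cast<float>(head_size)) * q_scaling);
    float const scale_after_alibi = scale_alibi ? inv_sqrt_scale : 1.0f;
    float const scale_bmm1 = scale_alibi ? 1.0f : inv_sqrt_scale;
    return (bmm1_output * scale_bmm1 + alibi_bias) * scale_after_alibi;
}

int main()
{
    // With scale_alibi the bias is scaled together with the logits; without it, only the logits are scaled.
    std::printf("scale_alibi=true : %f\n", scaledScore(8.f, 0.5f, 64, 1.f, true));
    std::printf("scale_alibi=false: %f\n", scaledScore(8.f, 0.5f, 64, 1.f, false));
    return 0;
}
```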
@ -153,8 +153,8 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Support packed QKV.
|
// Support packed QKV.
|
||||||
void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||||
{
|
{
|
||||||
|
|
||||||
// Determine launch parameters.
|
// Determine launch parameters.
|
||||||
@ -165,10 +165,10 @@ public:
|
|||||||
TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0.");
|
TLLM_CHECK_WITH_INFO(mHeadSize > 0, "Head size should be greater than 0.");
|
||||||
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
||||||
|
|
||||||
const bool isSm70 = (sm == kSM_70);
|
bool const isSm70 = (sm == kSM_70);
|
||||||
const bool isSm90 = (sm == kSM_90);
|
bool const isSm90 = (sm == kSM_90);
|
||||||
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
|
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||||
const bool isSm80 = (sm == kSM_80);
|
bool const isSm80 = (sm == kSM_80);
|
||||||
if (isSm70)
|
if (isSm70)
|
||||||
{
|
{
|
||||||
mLaunchParams.flash_attention = true;
|
mLaunchParams.flash_attention = true;
|
||||||
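`mLaunchParams.padded_d` rounds a non-power-of-two head size up to the next power of two through `pow`/`log2`, while power-of-two sizes pass through unchanged because `(d & (d - 1)) == 0`. An integer-only sketch that gives the same result for the small positive head sizes involved, without the floating-point round trip:

```
#include <cassert>
#include <cstdio>

// Round d up to the next power of two; powers of two are returned unchanged.
int paddedHeadSize(int d)
{
    assert(d > 0);
    if ((d & (d - 1)) == 0)
    {
        return d;
    }
    int padded = 1;
    while (padded < d)
    {
        padded <<= 1;
    }
    return padded;
}

int main()
{
    std::printf("64  -> %d\n", paddedHeadSize(64));  // already a power of two
    std::printf("80  -> %d\n", paddedHeadSize(80));  // padded to 128
    std::printf("96  -> %d\n", paddedHeadSize(96));  // padded to 128
    std::printf("112 -> %d\n", paddedHeadSize(112)); // padded to 128
    return 0;
}
```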
@ -238,9 +238,9 @@ public:
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Support paged_kv_cache and chunked_attention.
|
// Support paged_kv_cache and chunked_attention.
|
||||||
void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
|
void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||||
const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
|
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||||
const bool scale_alibi, const int tp_size, const int tp_rank)
|
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||||
{
|
{
|
||||||
|
|
||||||
// Determine launch parameters.
|
// Determine launch parameters.
|
||||||
@ -253,9 +253,9 @@ public:
|
|||||||
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
mLaunchParams.padded_d = (mHeadSize & (mHeadSize - 1)) == 0 ? mHeadSize : pow(2, int(log2(mHeadSize)) + 1);
|
||||||
|
|
||||||
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
|
// Hopper: fallback to original fmha_v2 when head_size <= 64 and seq_len <= 256
|
||||||
const bool isSm90 = (sm == kSM_90);
|
bool const isSm90 = (sm == kSM_90);
|
||||||
const bool isSm8x = (sm == kSM_86 || sm == kSM_89);
|
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||||
const bool isSm80 = (sm == kSM_80);
|
bool const isSm80 = (sm == kSM_80);
|
||||||
|
|
||||||
// always use flash attention kernels.
|
// always use flash attention kernels.
|
||||||
mLaunchParams.flash_attention = true;
|
mLaunchParams.flash_attention = true;
|
||||||
@ -383,7 +383,7 @@ public:
|
|||||||
|
|
||||||
// QKV [TOTAL, 3, h, d]
|
// QKV [TOTAL, 3, h, d]
|
||||||
// NOTE: we may need to use actual seqlen to set oob_value
|
// NOTE: we may need to use actual seqlen to set oob_value
|
||||||
const char* qkv_ptr = reinterpret_cast<const char*>(mParams.qkv_ptr);
|
char const* qkv_ptr = reinterpret_cast<char const*>(mParams.qkv_ptr);
|
||||||
tensor_size_qkv[3] = mTotalSeqLen;
|
tensor_size_qkv[3] = mTotalSeqLen;
|
||||||
|
|
||||||
// Q: STEP_Q
|
// Q: STEP_Q
|
||||||
@ -467,7 +467,7 @@ public:
|
|||||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
|
||||||
|
|
||||||
// Q ptr.
|
// Q ptr.
|
||||||
const char* q_ptr = reinterpret_cast<const char*>(mPagedKVParams.q_ptr);
|
char const* q_ptr = reinterpret_cast<char const*>(mPagedKVParams.q_ptr);
|
||||||
|
|
||||||
// Q: STEP_Q.
|
// Q: STEP_Q.
|
||||||
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
|
q_tma_descriptor.set_tma_desctriptor(q_ptr, cudaTmaDescFormat::F16_RN,
|
||||||
@ -518,7 +518,7 @@ public:
|
|||||||
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
|
paged_kv_tma_descriptor.copy_to_device(mPagedKVParams.tma_desc_paged_kv, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
|
void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
|
||||||
{
|
{
|
||||||
// BF16 FMHA only accumulates on FP32
|
// BF16 FMHA only accumulates on FP32
|
||||||
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||||
@ -541,11 +541,11 @@ public:
|
|||||||
return MHARunner::fmha_supported(mHeadSize, sm);
|
return MHARunner::fmha_supported(mHeadSize, sm);
|
||||||
}
|
}
|
||||||
|
|
||||||
void run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
void run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
|
||||||
{
|
{
|
||||||
mParams.qkv_ptr = qkvPtr;
|
mParams.qkv_ptr = qkvPtr;
|
||||||
mParams.o_ptr = outputPtr;
|
mParams.o_ptr = outputPtr;
|
||||||
mParams.cu_seqlens = reinterpret_cast<const int*>(cuSeqlenPtr);
|
mParams.cu_seqlens = reinterpret_cast<int const*>(cuSeqlenPtr);
|
||||||
|
|
||||||
if (sm == kSM_90 && mLaunchParams.use_tma)
|
if (sm == kSM_90 && mLaunchParams.use_tma)
|
||||||
{
|
{
|
||||||
@ -556,8 +556,8 @@ public:
|
|||||||
xmmaKernel->run(mParams, mLaunchParams, stream);
|
xmmaKernel->run(mParams, mLaunchParams, stream);
|
||||||
}
|
}
|
||||||
|
|
||||||
void run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
|
void run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
|
||||||
const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
|
const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
|
||||||
cudaStream_t stream)
|
cudaStream_t stream)
|
||||||
{
|
{
|
||||||
KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
|
KVBlockArrayForContextFMHA pagedKVCacheForContextMHA;
|
||||||
@ -568,10 +568,10 @@ public:
|
|||||||
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
|
mPagedKVParams.tma_desc_paged_kv = reinterpret_cast<cudaTmaDesc*>(pagedKVTmaDesc);
mPagedKVParams.paged_kv_cache = pagedKVCacheForContextMHA;
mPagedKVParams.o_ptr = outputPtr;
-mPagedKVParams.cu_q_seqlens = reinterpret_cast<const int*>(cuQSeqlenPtr);
+mPagedKVParams.cu_q_seqlens = reinterpret_cast<int const*>(cuQSeqlenPtr);
-mPagedKVParams.cu_seqlens = reinterpret_cast<const int*>(cuKVSeqlenPtr);
+mPagedKVParams.cu_seqlens = reinterpret_cast<int const*>(cuKVSeqlenPtr);
// paged kv block device ptrs on host (used by tma descriptors).
-mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<const int64_t*>(pagedKVBlockPtrsOnHost);
+mLaunchParams.paged_kv_block_ptrs = reinterpret_cast<int64_t const*>(pagedKVBlockPtrsOnHost);

if (sm == kSM_90 && mLaunchParams.use_tma)
{
@@ -587,7 +587,7 @@ public:
return pagedKVXmmaKernel->isValid(s) && xmmaKernel->isValid(s);
}

-int getSFromMaxSeqLen(const int max_seq_len)
+int getSFromMaxSeqLen(int const max_seq_len)
{
int S = 1024;

@@ -625,35 +625,35 @@ private:
Fused_multihead_attention_paged_kv_params_v2 mPagedKVParams;
Launch_params mLaunchParams;
int sm;
-const FusedMultiHeadAttentionXMMAKernelV2* xmmaKernel;
+FusedMultiHeadAttentionXMMAKernelV2 const* xmmaKernel;
-const FusedMultiHeadAttentionPagedKVXMMAKernelV2* pagedKVXmmaKernel;
+FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* pagedKVXmmaKernel;
bool use_flash_attention = false;
const Data_type mDataType;
-const int mNumHeads;
+int const mNumHeads;
-const int mHeadSize;
+int const mHeadSize;
-const float mQScaling;
+float const mQScaling;
int mTotalSeqLen;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

FusedMHARunnerV2::FusedMHARunnerV2(
-const Data_type data_type, const int numHeads, const int headSize, const float qScaling)
+const Data_type data_type, int const numHeads, int const headSize, float const qScaling)
: pimpl(new mhaImpl(data_type, numHeads, headSize, qScaling, tensorrt_llm::common::getSMVersion()))
{
}

FusedMHARunnerV2::~FusedMHARunnerV2() = default;

-void FusedMHARunnerV2::setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
+void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
-const bool has_alibi, const bool scale_alibi, const int tp_size, const int tp_rank)
+bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
{
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
}

-void FusedMHARunnerV2::setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
+void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
-const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen, const bool has_alibi,
+int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
-const bool scale_alibi, const int tp_size, const int tp_rank)
+bool const scale_alibi, int const tp_size, int const tp_rank)
{
pimpl->setup_paged_kv(b, s_q, s_kv, blocks_per_context_sequence, tokens_per_kv_block, sliding_window_size,
total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
@@ -665,18 +665,18 @@ bool FusedMHARunnerV2::fmha_supported()
}

void FusedMHARunnerV2::setup_flags(
-const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask, const int num_kv_heads)
+bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
{
pimpl->setup_flags(force_fp32_acc, is_s_padded, causal_mask, num_kv_heads);
}

-void FusedMHARunnerV2::run(const void* qkvPtr, const void* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
+void FusedMHARunnerV2::run(void const* qkvPtr, void const* cuSeqlenPtr, void* outputPtr, cudaStream_t stream)
{
pimpl->run(qkvPtr, cuSeqlenPtr, outputPtr, stream);
}

-void FusedMHARunnerV2::run_paged_kv(const void* qPtr, void* pagedKVTmaDesc, const void* pagedKVBlockPtrsOnHost,
+void FusedMHARunnerV2::run_paged_kv(void const* qPtr, void* pagedKVTmaDesc, void const* pagedKVBlockPtrsOnHost,
-const KVBlockArray pagedKVCache, const void* cuQSeqlenPtr, const void* cuKVSeqlenPtr, void* outputPtr,
+const KVBlockArray pagedKVCache, void const* cuQSeqlenPtr, void const* cuKVSeqlenPtr, void* outputPtr,
cudaStream_t stream)
{
pimpl->run_paged_kv(
@@ -689,7 +689,7 @@ bool FusedMHARunnerV2::isValid(int s) const
}

// static function to check if fmha is supported when building plugins
-bool MHARunner::fmha_supported(const int headSize, const int sm)
+bool MHARunner::fmha_supported(int const headSize, int const sm)
{
if (sm == kSM_70)
{

@@ -41,33 +41,33 @@ namespace kernels
class MHARunner
{
public:
-MHARunner(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);
+MHARunner(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);

MHARunner() = default;

virtual ~MHARunner() = default;

-virtual void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
+virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
-const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
+bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
= 0;

-virtual void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
+virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
-const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
+int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
-const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1, const int tp_rank = 0)
+bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
= 0;

-static bool fmha_supported(const int headSize, const int sm);
+static bool fmha_supported(int const headSize, int const sm);

virtual bool fmha_supported() = 0;

-virtual void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
+virtual void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
-const int num_kv_heads /* MQA or GQA */)
+int const num_kv_heads /* MQA or GQA */)
= 0;

-virtual void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) = 0;
+virtual void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) = 0;

-virtual void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
+virtual void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
-const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
+const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
cudaStream_t stream)
= 0;

@@ -86,28 +86,28 @@ public:
class FusedMHARunnerV2 : public MHARunner
{
public:
-FusedMHARunnerV2(const Data_type dataType, const int numHeads, const int headSize, const float qScaling);
+FusedMHARunnerV2(const Data_type dataType, int const numHeads, int const headSize, float const qScaling);

~FusedMHARunnerV2(); // for pimpl

-void setup(const int b, const int s, const int sliding_window_size, const int total_seqlen,
+void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
-const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
+bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
-const int tp_rank = 0) override;
+int const tp_rank = 0) override;

-void setup_paged_kv(const int b, const int s_q, const int s_kv, const int blocks_per_context_sequence,
+void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
-const int tokens_per_kv_block, const int sliding_window_size, const int total_seqlen,
+int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
-const bool has_alibi = false, const bool scale_alibi = false, const int tp_size = 1,
+bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
-const int tp_rank = 0) override;
+int const tp_rank = 0) override;

bool fmha_supported() override;

-void run(const void* input, const void* cu_seqlens, void* output, cudaStream_t stream) override;
+void run(void const* input, void const* cu_seqlens, void* output, cudaStream_t stream) override;
-void run_paged_kv(const void* q_input, void* paged_kv_tma_desc, const void* paged_kv_block_ptrs_on_host,
+void run_paged_kv(void const* q_input, void* paged_kv_tma_desc, void const* paged_kv_block_ptrs_on_host,
-const KVBlockArray paged_kv_cache, const void* cu_q_seqlens, const void* cu_kv_seqlens, void* output,
+const KVBlockArray paged_kv_cache, void const* cu_q_seqlens, void const* cu_kv_seqlens, void* output,
cudaStream_t stream) override;

-void setup_flags(const bool force_fp32_acc, const bool is_s_padded, const bool causal_mask,
+void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask,
-const int num_kv_heads /* MQA or GQA */) override;
+int const num_kv_heads /* MQA or GQA */) override;

bool isValid(int s) const override;
@@ -84,9 +84,9 @@ struct AlibiParams
struct Fused_multihead_attention_params_v2
{
// The QKV matrices.
-const void* qkv_ptr;
+void const* qkv_ptr;
// The mask to implement drop-out.
-const void* packed_mask_ptr;
+void const* packed_mask_ptr;
// The O matrix (output).
void* o_ptr;

@@ -106,7 +106,7 @@ struct Fused_multihead_attention_params_v2
bool enable_i2f_trick;

// array of length b+1 holding prefix sum of actual sequence lengths
-const int* cu_seqlens;
+int const* cu_seqlens;

// use C/32 Format.
bool interleaved = false;
@@ -177,13 +177,13 @@ struct Fused_multihead_attention_params_v2
struct Fused_multihead_attention_paged_kv_params_v2
{
// The Q matrices.
-const void* q_ptr;
+void const* q_ptr;
// Paged KV Cache buffer.
KVBlockArrayForContextFMHA paged_kv_cache;
// The O matrix (output).
void* o_ptr;
// The packed mask for random mask.
-const void* packed_mask_ptr;
+void const* packed_mask_ptr;

// The stride between rows of the Q matrices.
int64_t q_stride_in_bytes;
@@ -211,9 +211,9 @@ struct Fused_multihead_attention_paged_kv_params_v2
AlibiParams alibi_params;

// array of length b+1 holding prefix sum of actual kv sequence lengths.
-const int* cu_seqlens;
+int const* cu_seqlens;
// Chunked attention (only handles one tile of Q).
-const int* cu_q_seqlens;
+int const* cu_q_seqlens;

// q with shape [B, S, H, D] in const cache.
cudaTmaDesc tma_desc_q;
@@ -301,7 +301,7 @@ struct Launch_params
// number of paged kv blocks for context sequence.
int blocks_per_context_sequence = 0;
// device ptrs on the host for paged kv cache.
-const int64_t* paged_kv_block_ptrs = nullptr;
+int64_t const* paged_kv_block_ptrs = nullptr;
// if flash attention is used (only FP16)
bool flash_attention = false;
// if warp_specialized kernels are used (only SM90 HGMMA + TMA)

@@ -63,13 +63,13 @@ public:
return (uint64_t) s << 32 | d;
}

-virtual uint64_t hashID(const KernelMeta& kernelMeta) const
+virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD);
}

TFusedMultiHeadAttentionXMMAKernel(
-const TKernelMeta* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
+TKernelMeta const* pMetaStart, unsigned int nMetaCount, Data_type type, unsigned int sm)
: mDataType(type)
, mKernelMeta(pMetaStart)
, mKernelMetaCount(nMetaCount)
@@ -86,7 +86,7 @@ public:

for (unsigned int i = 0; i < mKernelMetaCount; ++i)
{
-const auto& kernelMeta = mKernelMeta[i];
+auto const& kernelMeta = mKernelMeta[i];
if (kernelMeta.mSM == mSM && kernelMeta.mDataType == mDataType)
{
CUmodule hmod{0};
@@ -125,9 +125,9 @@ public:

virtual void run(TKernelParam& params, Launch_params& launch_params, cudaStream_t ss) const
{
-const auto findIter = mFunctions.find(hashID(params.s, params.d));
+auto const findIter = mFunctions.find(hashID(params.s, params.d));

-const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
+auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@@ -142,10 +142,10 @@ protected:
tensorrt_llm::common::CUDADriverWrapper mDriver;

Data_type mDataType;
-const TKernelMeta* mKernelMeta;
+TKernelMeta const* mKernelMeta;
unsigned int mKernelMetaCount;
unsigned int mSM;
-std::unordered_map<const unsigned char*, CUmodule> mModules;
+std::unordered_map<unsigned char const*, CUmodule> mModules;

struct FusedMultiHeadAttentionKernelInfo
{
@@ -161,14 +161,14 @@ template <typename TFusedMHAKernelList>
class TFusedMHAKernelFactory
{
public:
-const TFusedMHAKernelList* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList,
+TFusedMHAKernelList const* getXMMAKernels(const typename TFusedMHAKernelList::KernelMeta* pKernelList,
unsigned int nbKernels, Data_type type, unsigned int sm)
{
static std::mutex s_mutex;
std::lock_guard<std::mutex> lg(s_mutex);

-const auto id = hashID(type, sm);
+auto const id = hashID(type, sm);
-const auto findIter = mKernels.find(id);
+auto const findIter = mKernels.find(id);
if (findIter == mKernels.end())
{
TFusedMHAKernelList* newKernel = new TFusedMHAKernelList{pKernelList, nbKernels, type, sm};
@@ -214,7 +214,7 @@ class FusedMultiHeadAttentionXMMAKernelV2
Fused_multihead_attention_params_v2>
{
public:
-FusedMultiHeadAttentionXMMAKernelV2(const FusedMultiHeadAttentionKernelMetaInfoV2* pMetaStart,
+FusedMultiHeadAttentionXMMAKernelV2(FusedMultiHeadAttentionKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionKernelMetaInfoV2,
Fused_multihead_attention_params_v2>(pMetaStart, nMetaCount, type, sm)
@@ -231,7 +231,7 @@ public:
| (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

-virtual uint64_t hashID(const KernelMeta& kernelMeta) const
+virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{

return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
@@ -278,7 +278,7 @@ public:
}
}

-const auto findIter
+auto const findIter
= mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved, forceUnroll,
launch_params.force_fp32_acc, launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@@ -290,7 +290,7 @@ public:
launch_params.flash_attention, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling);

-const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
+auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@@ -369,7 +369,7 @@ public:

using FusedMHAKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionXMMAKernelV2>;

-inline const FusedMultiHeadAttentionXMMAKernelV2* getXMMAKernelsV2(Data_type type, unsigned int sm)
+inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type type, unsigned int sm)
{
return FusedMHAKernelFactoryV2::Get().getXMMAKernels(
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
@@ -384,7 +384,7 @@ class FusedMultiHeadAttentionPagedKVXMMAKernelV2
Fused_multihead_attention_paged_kv_params_v2>
{
public:
-FusedMultiHeadAttentionPagedKVXMMAKernelV2(const FusedMultiHeadAttentionPagedKVKernelMetaInfoV2* pMetaStart,
+FusedMultiHeadAttentionPagedKVXMMAKernelV2(FusedMultiHeadAttentionPagedKVKernelMetaInfoV2 const* pMetaStart,
unsigned int nMetaCount, Data_type type, unsigned int sm)
: TFusedMultiHeadAttentionXMMAKernel<FusedMultiHeadAttentionPagedKVKernelMetaInfoV2,
Fused_multihead_attention_paged_kv_params_v2>(pMetaStart, nMetaCount, type, sm)
@@ -402,7 +402,7 @@ public:
| (flash_attention ? 4ull : 0ull) | (interleaved ? 2ull : 0ull) | (unroll ? 1ull : 0ull);
}

-virtual uint64_t hashID(const KernelMeta& kernelMeta) const
+virtual uint64_t hashID(KernelMeta const& kernelMeta) const
{
return hashID(kernelMeta.mS, kernelMeta.mD, kernelMeta.mInterleaved, kernelMeta.mUnrollStep,
kernelMeta.mFP32Accumulation, kernelMeta.mFlashAttention, kernelMeta.mWarpSpecialization,
@@ -413,7 +413,7 @@ public:
Fused_multihead_attention_paged_kv_params_v2& params, Launch_params& launch_params, cudaStream_t stream) const
{

-const auto findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
+auto const findIter = mFunctions.find(hashID(launch_params.kernel_s, params.d, launch_params.interleaved,
launch_params.force_unroll, launch_params.force_fp32_acc, launch_params.flash_attention,
launch_params.warp_specialization, !launch_params.useKernelWithoutAlibi,
static_cast<int>(launch_params.attention_mask_type), launch_params.granular_tiling));
@@ -426,7 +426,7 @@ public:
!launch_params.useKernelWithoutAlibi, static_cast<int>(launch_params.attention_mask_type),
launch_params.granular_tiling);

-const auto& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
+auto const& kernelMeta = mKernelMeta[findIter->second.mMetaInfoIndex];
const CUfunction func = findIter->second.mDeviceFunction;

void* kernelParams[] = {&params, nullptr};
@@ -488,7 +488,7 @@ public:

using FusedMHAPagedKVKernelFactoryV2 = TFusedMHAKernelFactory<FusedMultiHeadAttentionPagedKVXMMAKernelV2>;

-inline const FusedMultiHeadAttentionPagedKVXMMAKernelV2* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
+inline FusedMultiHeadAttentionPagedKVXMMAKernelV2 const* getPagedKVXMMAKernelsV2(Data_type type, unsigned int sm)
{
return FusedMHAPagedKVKernelFactoryV2::Get().getXMMAKernels(sMhaPagedKVKernelMetaInfosV2,
sizeof(sMhaPagedKVKernelMetaInfosV2) / sizeof(sMhaPagedKVKernelMetaInfosV2[0]), type, sm);
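Illustrative sketch (not part of the commit): the kernel classes above select a precompiled kernel by packing its properties into a 64-bit key, as in hashID returning (uint64_t) s << 32 | d, and looking that key up in an unordered_map of loaded device functions. The minimal standalone C++ sketch below shows only that lookup pattern; the kernel names and the map payload are hypothetical stand-ins, and the real V2 hash also folds in variant flags that are not modeled here.

#include <cstdint>
#include <iostream>
#include <unordered_map>

// Illustrative only: pack (S, D) into the high/low halves of a 64-bit key,
// mirroring "return (uint64_t) s << 32 | d;" from the kernel list above.
static uint64_t hashID(int s, int d)
{
    return (uint64_t) s << 32 | (uint32_t) d;
}

int main()
{
    // Stand-in for the map of loaded device-function handles keyed by hashID.
    std::unordered_map<uint64_t, char const*> functions;
    functions[hashID(128, 64)] = "fmha_v2_fp16_128_64_kernel"; // hypothetical name
    functions[hashID(256, 64)] = "fmha_v2_fp16_256_64_kernel"; // hypothetical name

    auto const findIter = functions.find(hashID(128, 64));
    if (findIter != functions.end())
    {
        std::cout << "selected: " << findIter->second << "\n";
    }
    return 0;
}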
@@ -186,7 +186,7 @@ public:
// set the desctriptor.
int set_tma_desctriptor(
// ptr to gmem
-const void* gmem_ptr,
+void const* gmem_ptr,
// format is really data_type in TMA terminology.
cudaTmaDescFormat format,
// interleave mode.
@@ -221,7 +221,7 @@ public:
// set the desctriptor.
int set_tma_desctriptor(
// ptr to gmem
-const void* gmem_ptr,
+void const* gmem_ptr,
// format is really data_type in TMA terminology.
cudaTmaDescFormat format,
// interleave mode.

@@ -108,10 +108,10 @@ inline __device__ int4 add128b(T& a, T& b)
}

__inline__ __device__ void multi_gpu_barrier(
-uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, const int tidx, const int bidx)
+uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
{
// At the end of the function, we now that has least block 0 from all others GPUs have reached that point.
-volatile uint32_t* my_signals = signals[rank];
+uint32_t volatile* my_signals = signals[rank];
if (tidx < world_size)
{
// The 1st block notifies the other ranks.
@@ -139,8 +139,8 @@ __global__ void multiGpuBarrierKernel(AllReduceParams params)
template <typename T, int RANKS_PER_NODE>
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
{
-const int bidx = blockIdx.x;
+int const bidx = blockIdx.x;
-const int tidx = threadIdx.x;
+int const tidx = threadIdx.x;

// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
@@ -151,7 +151,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);

// The source pointers. Distributed round-robin for the different warps.
-const T* src_d[RANKS_PER_NODE];
+T const* src_d[RANKS_PER_NODE];
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
@@ -172,7 +172,7 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
-vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][iter_offset]);
+vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
}

// Sum the values from the different ranks.
@@ -194,9 +194,9 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
{

// The block index.
-const int bidx = blockIdx.x;
+int const bidx = blockIdx.x;
// The thread index with the block.
-const int tidx = threadIdx.x;
+int const tidx = threadIdx.x;

// The number of elements packed into one for comms
static constexpr int NUM_ELTS = 16 / sizeof(T);
@@ -233,7 +233,7 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
#pragma unroll
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
{
-vals[ii].packed = *reinterpret_cast<const int4*>(&src_d[ii][local_offset]);
+vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
}

// Sum the values from the different ranks.
@@ -396,14 +396,14 @@ void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream)
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
}

-AllReduceParams AllReduceParams::deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
+AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
{
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
AllReduceParams params;
// Even plugins use ping buffers, odd plugins use pong.
// That way, we don't need to wait for other GPUs to be done
// before copying input tensor to workspace.
-const auto buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;
+auto const buffer_offset = (flag_value % 2 == 0) ? 0 : tpSize;

for (int i = 0; i < tpSize; ++i)
{

@@ -57,7 +57,7 @@ struct AllReduceParams
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
void* local_output_buffer_ptr;

-static AllReduceParams deserialize(const int32_t* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
+static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
};

template <typename T>
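Illustrative sketch (not part of the commit): AllReduceParams::deserialize above switches between two sets of communication buffers based on the parity of flag_value, so that consecutive all-reduce plugins alternate "ping" and "pong" workspaces instead of waiting on each other. The small standalone sketch below shows only that offset selection, with placeholder buffer labels instead of the real serialized pointers:

#include <cstdint>
#include <cstdio>
#include <vector>

// Illustrative only: with tpSize ranks, the first tpSize pointers are the "ping"
// buffers and the next tpSize are the "pong" buffers; the flag parity picks the set.
static size_t pickBufferOffset(uint32_t flag_value, size_t tpSize)
{
    return (flag_value % 2 == 0) ? 0 : tpSize;
}

int main()
{
    size_t const tpSize = 4;
    std::vector<char const*> buffers
        = {"ping0", "ping1", "ping2", "ping3", "pong0", "pong1", "pong2", "pong3"};

    for (uint32_t flag = 0; flag < 3; ++flag)
    {
        size_t const offset = pickBufferOffset(flag, tpSize);
        // flag 0 -> ping0, flag 1 -> pong0, flag 2 -> ping0, ...
        std::printf("flag %u -> rank 0 uses %s\n", static_cast<unsigned>(flag), buffers[offset + 0]);
    }
    return 0;
}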
@@ -70,7 +70,7 @@ TileShape get_cta_shape_for_config(CutlassTileConfig tile_config)
}

bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k, const TileShape tile_shape,
-const int split_k_factor, const size_t workspace_bytes, const bool is_weight_only)
+int const split_k_factor, const size_t workspace_bytes, bool const is_weight_only)
{

// All tile sizes have a k_tile of 64.
@@ -89,7 +89,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
return false;
}

-const int k_elements_per_split = k / split_k_factor;
+int const k_elements_per_split = k / split_k_factor;
if ((k_elements_per_split % k_tile) != 0)
{
return false;
@@ -97,9 +97,9 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}

// Check that the workspace has sufficient space for this split-k factor
-const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
+int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
-const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
+int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
-const int required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;
+int const required_ws_bytes = split_k_factor == 1 ? 0 : sizeof(int) * ctas_in_m_dim * ctas_in_n_dim;

if (required_ws_bytes > workspace_bytes)
{
@@ -110,7 +110,7 @@ bool is_valid_split_k_factor(const int64_t m, const int64_t n, const int64_t k,
}

std::vector<CutlassTileConfig> get_candidate_tiles(
-const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
+int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{
enum class CutlassGemmType : char
{
@@ -170,7 +170,7 @@ std::vector<CutlassTileConfig> get_candidate_tiles(
}

std::vector<CutlassTileConfigSM90> get_candidate_tiles_sm90(
-const int sm, const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only)
+int const sm, bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only)
{
enum class CutlassGemmType : char
{
@@ -226,8 +226,8 @@ bool supports_mcast_along_n(const CutlassTileConfigSM90 tile)
return valid_tiles.count(tile) == 1;
}

-std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weight_only, const bool simt_configs_only,
+std::vector<CutlassGemmConfig> get_candidate_configs(int sm, bool const is_weight_only, bool const simt_configs_only,
-const bool int8_configs_only, const int max_split_k, const bool enable_hopper_gmma)
+bool const int8_configs_only, int const max_split_k, bool const enable_hopper_gmma)
{
if (sm == 90 && enable_hopper_gmma)
{
@@ -235,14 +235,14 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles_sm90(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
-for (const auto& tile_config : tiles)
+for (auto const& tile_config : tiles)
{
CutlassGemmConfig config(
tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO, ClusterShape::ClusterShape_1x1x1);
candidate_configs.push_back(config);

-const bool has_m_mcast = supports_mcast_along_m(tile_config);
+bool const has_m_mcast = supports_mcast_along_m(tile_config);
-const bool has_n_mcast = supports_mcast_along_n(tile_config);
+bool const has_n_mcast = supports_mcast_along_n(tile_config);
if (has_m_mcast)
{
CutlassGemmConfig config(tile_config, MainloopScheduleType::AUTO, EpilogueScheduleType::AUTO,
@@ -270,9 +270,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
= get_candidate_tiles(sm, is_weight_only, simt_configs_only, int8_configs_only);

std::vector<CutlassGemmConfig> candidate_configs;
-const int min_stages = int8_configs_only ? 3 : 2;
+int const min_stages = int8_configs_only ? 3 : 2;
-const int max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
+int const max_stages = int8_configs_only ? 6 : (sm >= 80 ? 4 : 2);
-for (const auto& tile_config : tiles)
+for (auto const& tile_config : tiles)
{
for (int stages = min_stages; stages <= max_stages; ++stages)
{
@@ -292,9 +292,9 @@ std::vector<CutlassGemmConfig> get_candidate_configs(int sm, const bool is_weigh
return candidate_configs;
}

-CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<CutlassGemmConfig>& candidate_configs,
+CutlassGemmConfig estimate_best_config_from_occupancies(std::vector<CutlassGemmConfig> const& candidate_configs,
-const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
+std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
-const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only)
+int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only)
{

if (occupancies.size() != candidate_configs.size())
@@ -311,7 +311,7 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
int config_waves = INT_MAX;
int current_m_tile = 0;

-const int max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
+int const max_split_k = n >= multi_processor_count * 256 ? 1 : split_k_limit;
for (int ii = 0; ii < candidate_configs.size(); ++ii)
{
CutlassGemmConfig candidate_config = candidate_configs[ii];
@@ -330,21 +330,21 @@ CutlassGemmConfig estimate_best_config_from_occupancies(const std::vector<Cutlas
continue;
}

-const int ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
+int const ctas_in_m_dim = (m + tile_shape.m - 1) / tile_shape.m;
-const int ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;
+int const ctas_in_n_dim = (n + tile_shape.n - 1) / tile_shape.n;

for (int split_k_factor = 1; split_k_factor <= max_split_k; ++split_k_factor)
{
if (is_valid_split_k_factor(m, n, k, tile_shape, split_k_factor, workspace_bytes, is_weight_only))
{
-const int ctas_per_wave = occupancy * multi_processor_count;
+int const ctas_per_wave = occupancy * multi_processor_count;
-const int ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;
+int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor;

-const int num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
+int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave;
-const float num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
+float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);
-const float current_score = float(num_waves_total) - num_waves_fractional;
+float const current_score = float(num_waves_total) - num_waves_fractional;

-const float score_slack = 0.1f;
+float const score_slack = 0.1f;
if (current_score < config_score
|| ((config_waves > num_waves_total) && (current_score < config_score + score_slack)))
{
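Illustrative sketch (not part of the commit): the heuristic above scores each candidate GEMM configuration by how completely its CTAs fill whole waves of the GPU; num_waves_total is the ceiling wave count, num_waves_fractional the exact ratio, and their difference measures how idle the last wave is. The standalone recomputation below uses made-up problem sizes; the SM count and occupancy are assumptions chosen only for the example:

#include <cstdio>

int main()
{
    // Hypothetical numbers, for illustration only.
    int const multi_processor_count = 108; // assumed SM count
    int const occupancy = 2;               // assumed CTAs resident per SM for this config
    int const ctas_in_m_dim = 16;
    int const ctas_in_n_dim = 20;
    int const split_k_factor = 1;

    int const ctas_per_wave = occupancy * multi_processor_count;                 // 216
    int const ctas_for_problem = ctas_in_m_dim * ctas_in_n_dim * split_k_factor; // 320

    // Same arithmetic as the heuristic above: ceil-divide for whole waves,
    // exact ratio for the fractional wave count.
    int const num_waves_total = (ctas_for_problem + ctas_per_wave - 1) / ctas_per_wave; // 2
    float const num_waves_fractional = ctas_for_problem / float(ctas_per_wave);         // ~1.48
    float const current_score = float(num_waves_total) - num_waves_fractional;          // ~0.52

    // A score near 0 means the last wave is almost full; near 1 means it is mostly idle.
    std::printf("waves=%d fractional=%.2f score=%.2f\n", num_waves_total, num_waves_fractional, current_score);
    return 0;
}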
@@ -27,13 +27,13 @@ namespace cutlass_kernels
{

std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> get_candidate_configs(int sm,
-const bool is_weight_only, const bool simt_configs_only, const bool int8_configs_only = false,
+bool const is_weight_only, bool const simt_configs_only, bool const int8_configs_only = false,
-const int max_split_k = 1, const bool enable_hopper_gmma = false);
+int const max_split_k = 1, bool const enable_hopper_gmma = false);

tensorrt_llm::cutlass_extensions::CutlassGemmConfig estimate_best_config_from_occupancies(
-const std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig>& candidate_configs,
+std::vector<tensorrt_llm::cutlass_extensions::CutlassGemmConfig> const& candidate_configs,
-const std::vector<int>& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
+std::vector<int> const& occupancies, const int64_t m, const int64_t n, const int64_t k, const int64_t num_experts,
-const int split_k_limit, const size_t workspace_bytes, const int multi_processor_count, const int is_weight_only);
+int const split_k_limit, const size_t workspace_bytes, int const multi_processor_count, int const is_weight_only);

} // namespace cutlass_kernels
} // namespace kernels

@@ -158,8 +158,8 @@ LayoutDetails getLayoutDetailsForTransform(QuantType quant_type, int arch)
// 0 1 8 9 2 3 10 11 4 5 12 13 6 7 14 15
// For int4, each group of 32 rows is permuted using the map below:
// 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
-void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor,
+void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
-const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version)
+std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version)
{

// We only want to run this step for weight only quant.
@@ -170,19 +170,19 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

-const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
+int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
-const int K = 16 / BITS_PER_ELT;
+int const K = 16 / BITS_PER_ELT;
-const int ELTS_PER_BYTE = 8 / BITS_PER_ELT;
+int const ELTS_PER_BYTE = 8 / BITS_PER_ELT;
-const int ELTS_PER_REG = 32 / BITS_PER_ELT;
+int const ELTS_PER_REG = 32 / BITS_PER_ELT;

-const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
+uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(permuted_quantized_tensor);

int MMA_SHAPE_N = 8;
int B_ROWS_PER_MMA = 8 * K;
-const int elts_in_int32 = 32 / BITS_PER_ELT;
+int const elts_in_int32 = 32 / BITS_PER_ELT;

-const int num_vec_cols = num_cols / elts_in_int32;
+int const num_vec_cols = num_cols / elts_in_int32;

TLLM_CHECK_WITH_INFO(
arch_version >= 75, "Unsupported Arch. Pre-volta not supported. Column interleave not needed on Volta.");
@@ -205,11 +205,11 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8

for (int write_col = 0; write_col < num_vec_cols; ++write_col)
{
-const int write_row = base_row + tile_row;
+int const write_row = base_row + tile_row;
-const int tile_read_row
+int const tile_read_row
= 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
-const int read_row = base_row + tile_read_row;
+int const read_row = base_row + tile_read_row;
-const int read_col = write_col;
+int const read_col = write_col;

const int64_t read_offset = matrix_offset + int64_t(read_row) * num_vec_cols + read_col;
const int64_t write_offset = matrix_offset + int64_t(write_row) * num_vec_cols + write_col;
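Illustrative sketch (not part of the commit): for int4 weights, BITS_PER_ELT is 4, so ELTS_PER_REG is 8 and each tile covers 32 rows; the tile_read_row expression above then reproduces exactly the 32-entry permutation quoted in the comment (0 1 8 9 16 17 24 25 ...). A small standalone check that prints the map:

#include <cstdio>

int main()
{
    // int4 case from the function above: 32 bits per register / 4 bits per element.
    int const ELTS_PER_REG = 8;
    int const B_ROWS_PER_MMA = 32;

    // For each destination (write) row in the tile, print the source (read) row,
    // using the same index arithmetic as permute_B_rows_for_mixed_gemm.
    for (int tile_row = 0; tile_row < B_ROWS_PER_MMA; ++tile_row)
    {
        int const tile_read_row
            = 8 * (((tile_row % ELTS_PER_REG) / 2)) + tile_row % 2 + 2 * (tile_row / ELTS_PER_REG);
        std::printf("%d ", tile_read_row);
    }
    // Prints: 0 1 8 9 16 17 24 25 2 3 10 11 18 19 26 27 4 5 12 13 20 21 28 29 6 7 14 15 22 23 30 31
    std::printf("\n");
    return 0;
}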
@@ -227,9 +227,9 @@ void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8
// issue for relatively large models.
template <QuantType quant_type>
void subbyte_transpose_impl(
-int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor, const std::vector<size_t>& shape)
+int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor, std::vector<size_t> const& shape)
{
-const int bits_per_elt = get_bits_in_quant_type(quant_type);
+int const bits_per_elt = get_bits_in_quant_type(quant_type);

TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");
const size_t num_experts = shape.size() == 2 ? 1 : shape[0];
@@ -240,7 +240,7 @@ void subbyte_transpose_impl(
const size_t col_bytes_trans = num_rows * bits_per_elt / 8;
const size_t num_bytes = size_t(num_experts) * num_rows * col_bytes;

-const uint8_t* input_byte_ptr = reinterpret_cast<const uint8_t*>(quantized_tensor);
+uint8_t const* input_byte_ptr = reinterpret_cast<uint8_t const*>(quantized_tensor);
uint8_t* output_byte_ptr = reinterpret_cast<uint8_t*>(transposed_quantized_tensor);

static_assert(quant_type == QuantType::INT8_WEIGHT_ONLY || quant_type == QuantType::PACKED_INT4_WEIGHT_ONLY, "");
@@ -260,8 +260,8 @@ void subbyte_transpose_impl(
"num_col_bytes = %ld.",
VECTOR_WIDTH, col_bytes_trans, col_bytes));

-const int num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
+int const num_m_tiles = (num_rows + M_TILE_L1 - 1) / M_TILE_L1;
-const int num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;
+int const num_n_tiles = (col_bytes + N_TILE_L1 - 1) / N_TILE_L1;

for (size_t expert = 0; expert < num_experts; ++expert)
{
@@ -271,16 +271,16 @@ void subbyte_transpose_impl(
for (size_t col_tile_start_byte = 0; col_tile_start_byte < col_bytes; col_tile_start_byte += N_TILE_L1)
{

-const int row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
+int const row_limit = std::min(row_tile_start + M_TILE_L1, num_rows);
-const int col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);
+int const col_limit = std::min(col_tile_start_byte + N_TILE_L1, col_bytes);

for (int ii = 0; ii < M_TILE_L1; ++ii)
{
-const int row = row_tile_start + ii;
+int const row = row_tile_start + ii;

for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
-const int col = col_tile_start_byte + jj;
+int const col = col_tile_start_byte + jj;

const size_t logical_src_offset = matrix_offset + row * col_bytes + col;

@@ -313,11 +313,11 @@ void subbyte_transpose_impl(
// is square in the number of elements (not necessarily the number of bytes).
for (int jj = ii + 1; jj < M_TILE_L1; ++jj)
{
-const int ii_byte = ii / ELTS_PER_BYTE;
+int const ii_byte = ii / ELTS_PER_BYTE;
-const int ii_bit_offset = ii % ELTS_PER_BYTE;
+int const ii_bit_offset = ii % ELTS_PER_BYTE;

-const int jj_byte = jj / ELTS_PER_BYTE;
+int const jj_byte = jj / ELTS_PER_BYTE;
-const int jj_bit_offset = jj % ELTS_PER_BYTE;
+int const jj_bit_offset = jj % ELTS_PER_BYTE;

uint8_t src_elt = 0xF & (cache_buf[ii][jj_byte] >> (4 * jj_bit_offset));
uint8_t tgt_elt = 0xF & (cache_buf[jj][ii_byte] >> (4 * ii_bit_offset));
@@ -338,15 +338,15 @@ void subbyte_transpose_impl(
const size_t row_tile_start_trans = col_tile_start_byte * ELTS_PER_BYTE;
const size_t col_tile_start_byte_trans = row_tile_start / ELTS_PER_BYTE;

-const int row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
+int const row_limit_trans = std::min(row_tile_start_trans + M_TILE_L1, num_cols);
-const int col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);
+int const col_limit_trans = std::min(col_tile_start_byte_trans + N_TILE_L1, col_bytes_trans);

for (int ii = 0; ii < M_TILE_L1; ++ii)
{
-const int row = row_tile_start_trans + ii;
+int const row = row_tile_start_trans + ii;
for (int jj = 0; jj < N_TILE_L1; jj += VECTOR_WIDTH)
{
-const int col = col_tile_start_byte_trans + jj;
+int const col = col_tile_start_byte_trans + jj;

const size_t logical_tgt_offset = matrix_offset + row * col_bytes_trans + col;

@@ -364,8 +364,8 @@ void subbyte_transpose_impl(
}
}

-void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor,
+void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
-const std::vector<size_t>& shape, QuantType quant_type)
+std::vector<size_t> const& shape, QuantType quant_type)
{

if (quant_type == QuantType::INT8_WEIGHT_ONLY)
@@ -409,7 +409,7 @@ void add_bias_and_interleave_int8s_inplace(int8_t* int8_tensor, const size_t num

void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const size_t num_elts)
{
-const int num_bytes = num_elts / 2;
+int const num_bytes = num_elts / 2;

// Step 1 will be to transform all the int4s to unsigned in order to make the dequantize take as little
// instructions as possible in the CUDA code.
@@ -451,9 +451,9 @@ void add_bias_and_interleave_int4s_inplace(int8_t* packed_int4_tensor, const siz

for (int dest_idx = 0; dest_idx < 8; ++dest_idx)
{
-const int src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
+int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
-const int src_shift = 4 * src_idx;
+int const src_shift = 4 * src_idx;
-const int dest_shift = 4 * dest_idx;
+int const dest_shift = 4 * dest_idx;

const uint32_t src_bits = (current_register >> src_shift) & 0xF;
transformed_register |= (src_bits << dest_shift);
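Illustrative sketch (not part of the commit): the dest_idx/src_idx arithmetic above shuffles the eight 4-bit elements of a 32-bit register so that even-indexed source elements land in the low four nibbles and odd-indexed ones in the high four. A standalone sketch of that shuffle applied to one synthetic register value:

#include <cstdint>
#include <cstdio>

int main()
{
    // One packed register holding eight 4-bit elements 0..7 (element i stored in nibble i).
    uint32_t current_register = 0x76543210u;
    uint32_t transformed_register = 0;

    // Same per-nibble mapping as the loop above: destination nibbles 0..3 take the
    // even source elements, destination nibbles 4..7 take the odd ones.
    for (int dest_idx = 0; dest_idx < 8; ++dest_idx)
    {
        int const src_idx = dest_idx < 4 ? 2 * dest_idx : 2 * (dest_idx - 4) + 1;
        int const src_shift = 4 * src_idx;
        int const dest_shift = 4 * dest_idx;
        uint32_t const src_bits = (current_register >> src_shift) & 0xF;
        transformed_register |= (src_bits << dest_shift);
    }

    // 0x76543210 -> 0x75316420: nibble order becomes 0,2,4,6,1,3,5,7 (low to high).
    std::printf("0x%08X -> 0x%08X\n", (unsigned) current_register, (unsigned) transformed_register);
    return 0;
}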
@@ -478,8 +478,8 @@ void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size
}
}

-void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const int8_t* quantized_tensor,
+void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, int8_t const* quantized_tensor,
-const std::vector<size_t>& shape, QuantType quant_type, LayoutDetails details)
+std::vector<size_t> const& shape, QuantType quant_type, LayoutDetails details)
{

// We only want to run this step for weight only quant.
@@ -490,23 +490,23 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];

-const int BITS_PER_ELT = get_bits_in_quant_type(quant_type);
+int const BITS_PER_ELT = get_bits_in_quant_type(quant_type);
-const int elts_in_int32 = 32 / BITS_PER_ELT;
+int const elts_in_int32 = 32 / BITS_PER_ELT;

-const int rows_per_tile = details.rows_per_column_tile;
+int const rows_per_tile = details.rows_per_column_tile;

TLLM_CHECK_WITH_INFO(!(num_rows % elts_in_int32),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", elts_in_int32, num_rows));

-const uint32_t* input_byte_ptr = reinterpret_cast<const uint32_t*>(quantized_tensor);
+uint32_t const* input_byte_ptr = reinterpret_cast<uint32_t const*>(quantized_tensor);
uint32_t* output_byte_ptr = reinterpret_cast<uint32_t*>(interleaved_quantized_tensor);

TLLM_CHECK_WITH_INFO(!(num_rows % rows_per_tile),
fmtstr("The number of rows must be a multiple of %d but the number of rows is %ld.", rows_per_tile, num_rows));

-const int num_vec_rows = num_rows / elts_in_int32;
+int const num_vec_rows = num_rows / elts_in_int32;
-const int vec_rows_per_tile = rows_per_tile / elts_in_int32;
+int const vec_rows_per_tile = rows_per_tile / elts_in_int32;
-const int interleave = details.columns_interleaved;
+int const interleave = details.columns_interleaved;

for (int expert = 0; expert < num_experts; ++expert)
{
@@ -532,8 +532,8 @@ void interleave_column_major_tensor(int8_t* interleaved_quantized_tensor, const
}
}

-void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
+void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
-const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
+std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
{
int arch = getSMVersion();
if (force_interleave && arch == 90)
@@ -546,7 +546,7 @@ void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, co
TLLM_CHECK_WITH_INFO(shape.size() == 2 || shape.size() == 3, "Shape must be 2-D or 3-D");

size_t num_elts = 1;
-for (const auto& dim : shape)
+for (auto const& dim : shape)
{
num_elts *= dim;
}
@@ -620,7 +620,7 @@ Outputs

template <typename ComputeType, typename WeightType>
void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
|
ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
|
||||||
bool force_interleave)
|
bool force_interleave)
|
||||||
{
|
{
|
||||||
|
|
||||||
@ -633,8 +633,8 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
|||||||
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
const size_t num_rows = shape.size() == 2 ? shape[0] : shape[1];
|
||||||
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
const size_t num_cols = shape.size() == 2 ? shape[1] : shape[2];
|
||||||
|
|
||||||
const int bits_in_type = get_bits_in_quant_type(quant_type);
|
int const bits_in_type = get_bits_in_quant_type(quant_type);
|
||||||
const int bytes_per_out_col = num_cols * bits_in_type / 8;
|
int const bytes_per_out_col = num_cols * bits_in_type / 8;
|
||||||
|
|
||||||
std::vector<int8_t> weight_buf;
|
std::vector<int8_t> weight_buf;
|
||||||
if (unprocessed_quantized_weight == nullptr)
|
if (unprocessed_quantized_weight == nullptr)
|
||||||
@ -643,15 +643,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
|||||||
unprocessed_quantized_weight = weight_buf.data();
|
unprocessed_quantized_weight = weight_buf.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
const int input_mat_size = num_rows * num_cols;
|
int const input_mat_size = num_rows * num_cols;
|
||||||
const int quantized_mat_size = num_rows * bytes_per_out_col;
|
int const quantized_mat_size = num_rows * bytes_per_out_col;
|
||||||
const float quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
|
float const quant_range_scale = 1.f / float(1 << (bits_in_type - 1));
|
||||||
|
|
||||||
std::vector<float> per_col_max(num_cols);
|
std::vector<float> per_col_max(num_cols);
|
||||||
|
|
||||||
for (int expert = 0; expert < num_experts; ++expert)
|
for (int expert = 0; expert < num_experts; ++expert)
|
||||||
{
|
{
|
||||||
const WeightType* current_weight = input_weight_ptr + expert * input_mat_size;
|
WeightType const* current_weight = input_weight_ptr + expert * input_mat_size;
|
||||||
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
|
int8_t* current_quantized_weight = unprocessed_quantized_weight + expert * quantized_mat_size;
|
||||||
|
|
||||||
// First we find the per column max for this expert weight.
|
// First we find the per column max for this expert weight.
|
||||||
@ -662,7 +662,7 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
|||||||
|
|
||||||
for (int ii = 0; ii < num_rows; ++ii)
|
for (int ii = 0; ii < num_rows; ++ii)
|
||||||
{
|
{
|
||||||
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
WeightType const* current_weight_row = current_weight + ii * num_cols;
|
||||||
for (int jj = 0; jj < num_cols; ++jj)
|
for (int jj = 0; jj < num_cols; ++jj)
|
||||||
{
|
{
|
||||||
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
|
per_col_max[jj] = std::max(per_col_max[jj], std::abs(float(current_weight_row[jj])));
|
||||||
@ -681,15 +681,15 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
|
|||||||
for (int ii = 0; ii < num_rows; ++ii)
|
for (int ii = 0; ii < num_rows; ++ii)
|
||||||
{
|
{
|
||||||
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
|
int8_t* current_quantized_weight_row = current_quantized_weight + ii * bytes_per_out_col;
|
||||||
const WeightType* current_weight_row = current_weight + ii * num_cols;
|
WeightType const* current_weight_row = current_weight + ii * num_cols;
|
||||||
for (int jj = 0; jj < bytes_per_out_col; ++jj)
|
for (int jj = 0; jj < bytes_per_out_col; ++jj)
|
||||||
{
|
{
|
||||||
|
|
||||||
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
if (quant_type == QuantType::INT8_WEIGHT_ONLY)
|
||||||
{
|
{
|
||||||
const float col_scale = per_col_max[jj];
|
float const col_scale = per_col_max[jj];
|
||||||
const float weight_elt = float(current_weight_row[jj]);
|
float const weight_elt = float(current_weight_row[jj]);
|
||||||
const float scaled_weight = round(weight_elt / col_scale);
|
float const scaled_weight = round(weight_elt / col_scale);
|
||||||
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
|
const int8_t clipped_weight = int8_t(std::max(-128.f, std::min(127.f, scaled_weight)));
|
||||||
current_quantized_weight_row[jj] = clipped_weight;
|
current_quantized_weight_row[jj] = clipped_weight;
|
||||||
}
|
}
|
||||||
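Reading the int8 branch is easier with the whole recipe in one place. The following is a minimal reference sketch, not the library routine: it mirrors the per-column max and round-and-clip loops above, uses a 1/128 range factor in line with `quant_range_scale` for 8-bit types, and adds a small epsilon purely so the standalone example tolerates all-zero columns.

```
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Reference sketch: quantize a row-major [num_rows x num_cols] float matrix to int8 with one
// scale per column, i.e. w is approximately q * scale[col].
void symmetric_quantize_reference(std::vector<int8_t>& quantized, std::vector<float>& scales,
    std::vector<float> const& weights, size_t num_rows, size_t num_cols)
{
    quantized.resize(num_rows * num_cols);
    scales.assign(num_cols, 0.f);

    // Per-column absolute maximum, as in the first loop nest of the function above.
    for (size_t ii = 0; ii < num_rows; ++ii)
        for (size_t jj = 0; jj < num_cols; ++jj)
            scales[jj] = std::max(scales[jj], std::abs(weights[ii * num_cols + jj]));

    // Turn the maxima into scales; 1/128 mirrors quant_range_scale for 8 bits, and the epsilon
    // only protects this standalone example from division by zero.
    for (size_t jj = 0; jj < num_cols; ++jj)
        scales[jj] = std::max(scales[jj], 1e-12f) / 128.f;

    // Round, clip to [-128, 127] and store, as in the INT8_WEIGHT_ONLY branch above.
    for (size_t ii = 0; ii < num_rows; ++ii)
        for (size_t jj = 0; jj < num_cols; ++jj)
        {
            float const scaled = std::round(weights[ii * num_cols + jj] / scales[jj]);
            quantized[ii * num_cols + jj] = int8_t(std::max(-128.f, std::min(127.f, scaled)));
        }
}
```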
@@ -700,12 +700,12 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
 int8_t packed_int4s = 0;
 for (int packed_idx = 0; packed_idx < 2; ++packed_idx)
 {
-const int input_idx = 2 * jj + packed_idx;
+int const input_idx = 2 * jj + packed_idx;
 if (input_idx < num_cols)
 {
-const float col_scale = per_col_max[input_idx];
-const float weight_elt = float(current_weight_row[input_idx]);
-const float scaled_weight = round(weight_elt / col_scale);
+float const col_scale = per_col_max[input_idx];
+float const weight_elt = float(current_weight_row[input_idx]);
+float const scaled_weight = round(weight_elt / col_scale);
 int int_weight = int(scaled_weight);
 const int8_t clipped_weight = std::max(-8, std::min(7, int_weight));
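The int4 branch quantizes two adjacent columns into one output byte. The shift that does the packing sits outside this excerpt, so the sketch below assumes the common low-nibble-first layout; treat it as an illustration rather than the exact library code.

```
#include <algorithm>
#include <cstdint>

// Clip two already-scaled weights to the signed 4-bit range and pack them into one byte.
// Low-nibble-first packing is an assumption; the exact shift used by the elided line in the
// function above is not shown in this diff.
int8_t pack_two_int4(int v0, int v1)
{
    int const vals[2] = {v0, v1};
    uint8_t packed = 0;
    for (int packed_idx = 0; packed_idx < 2; ++packed_idx)
    {
        int const clipped = std::max(-8, std::min(7, vals[packed_idx]));
        packed |= uint8_t(clipped & 0x0F) << (4 * packed_idx);
    }
    return int8_t(packed);
}
```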
@@ -729,47 +729,47 @@ void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_
 }

 template void symmetric_quantize<half, float>(
-int8_t*, int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<half, half>(
-int8_t*, int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);

 #ifdef ENABLE_BF16
 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
-int8_t*, int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<__nv_bfloat16, float>(
-int8_t*, int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
 #endif

 template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
-const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave)
+void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
+std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave)
 {
 symmetric_quantize(
 processed_quantized_weight, nullptr, scale_ptr, input_weight_ptr, shape, quant_type, force_interleave);
 }

 template void symmetric_quantize<float, float>(
-int8_t*, float*, const float*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, float*, float const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<half, float>(
-int8_t*, half*, const float*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, half*, float const*, std::vector<size_t> const&, QuantType, bool);

-template void symmetric_quantize<half, half>(int8_t*, half*, const half*, const std::vector<size_t>&, QuantType, bool);
+template void symmetric_quantize<half, half>(int8_t*, half*, half const*, std::vector<size_t> const&, QuantType, bool);

 #ifdef ENABLE_BF16
 template void symmetric_quantize<__nv_bfloat16, __nv_bfloat16>(
-int8_t*, __nv_bfloat16*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, __nv_bfloat16*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<__nv_bfloat16, half>(
-int8_t*, __nv_bfloat16*, const half*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, __nv_bfloat16*, half const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<half, __nv_bfloat16>(
-int8_t*, half*, const __nv_bfloat16*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, half*, __nv_bfloat16 const*, std::vector<size_t> const&, QuantType, bool);

 template void symmetric_quantize<__nv_bfloat16, float>(
-int8_t*, __nv_bfloat16*, const float*, const std::vector<size_t>&, QuantType, bool);
+int8_t*, __nv_bfloat16*, float const*, std::vector<size_t> const&, QuantType, bool);
 #endif

 } // namespace cutlass_kernels

(diff continues in another file; file name not shown in this view)

@@ -38,26 +38,26 @@ int get_bits_in_quant_type(QuantType quant_type);

 // Shapes here can be 2 or 3D. 2-D shapes are [num_rows, num_cols]
 // 3-D shapes are [num_experts, num_rows, num_cols]
-void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, const int8_t* quantized_tensor,
-const std::vector<size_t>& shape, QuantType quant_type, const int64_t arch_version);
+void permute_B_rows_for_mixed_gemm(int8_t* permuted_quantized_tensor, int8_t const* quantized_tensor,
+std::vector<size_t> const& shape, QuantType quant_type, const int64_t arch_version);

-void subbyte_transpose(int8_t* transposed_quantized_tensor, const int8_t* quantized_tensor,
-const std::vector<size_t>& shape, QuantType quant_type);
+void subbyte_transpose(int8_t* transposed_quantized_tensor, int8_t const* quantized_tensor,
+std::vector<size_t> const& shape, QuantType quant_type);

 void add_bias_and_interleave_quantized_tensor_inplace(int8_t* tensor, const size_t num_elts, QuantType quant_type);

-void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, const int8_t* row_major_quantized_weight,
-const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave = false);
+void preprocess_weights_for_mixed_gemm(int8_t* preprocessed_quantized_weight, int8_t const* row_major_quantized_weight,
+std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave = false);

 template <typename ComputeType, typename WeightType>
-void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, const WeightType* input_weight_ptr,
-const std::vector<size_t>& shape, QuantType quant_type, bool force_interleave);
+void symmetric_quantize(int8_t* processed_quantized_weight, ComputeType* scale_ptr, WeightType const* input_weight_ptr,
+std::vector<size_t> const& shape, QuantType quant_type, bool force_interleave);

 // This is exposed so that we can write tests that use the processed weights for CUTLASS but the unprocessed weight
 // to implement a simple reference implementation.
 template <typename ComputeType, typename WeightType>
 void symmetric_quantize(int8_t* processed_quantized_weight, int8_t* unprocessed_quantized_weight,
-ComputeType* scale_ptr, const WeightType* input_weight_ptr, const std::vector<size_t>& shape, QuantType quant_type,
+ComputeType* scale_ptr, WeightType const* input_weight_ptr, std::vector<size_t> const& shape, QuantType quant_type,
 bool force_interleave);

 } // namespace cutlass_kernels
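Taken together, these declarations are the host-side preprocessing API for weight-only GEMM. Below is a hypothetical usage sketch: the include path and namespace are assumptions (they are not shown in this diff), while the `symmetric_quantize<half, half>` signature and `QuantType::INT8_WEIGHT_ONLY` come from the hunks above. Keeping the unprocessed buffer is exactly what the comment about a reference implementation describes: tests can run a plain GEMM on it and compare against the CUTLASS-formatted weights.

```
#include <cuda_fp16.h>

#include <cstdint>
#include <vector>

// Both lines below are assumptions for illustration; neither the header path nor the full
// namespace appears in this excerpt.
#include "cutlass_preprocessors.h"
using namespace tensorrt_llm::kernels::cutlass_kernels;

void prepare_int8_weight_only(half const* fp16_weights, size_t num_rows, size_t num_cols,
    std::vector<int8_t>& processed, std::vector<int8_t>& unprocessed, std::vector<half>& scales)
{
    std::vector<size_t> const shape{num_rows, num_cols}; // 2-D shape: [num_rows, num_cols]
    processed.resize(num_rows * num_cols);
    unprocessed.resize(num_rows * num_cols);
    scales.resize(num_cols);

    // Quantize and lay out the weights for the CUTLASS mixed-type GEMM in one call. The
    // unprocessed buffer is kept so a simple reference GEMM can be checked against it.
    symmetric_quantize<half, half>(processed.data(), unprocessed.data(), scales.data(),
        fp16_weights, shape, QuantType::INT8_WEIGHT_ONLY, /*force_interleave=*/false);
}
```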
(diff continues in another file; file name not shown in this view)

@@ -58,27 +58,27 @@ public:

 virtual ~CutlassFpAIntBGemmRunnerInterface() {}

-virtual void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
+virtual void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 = 0;

-virtual void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n,
+virtual void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n,
 int k, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
 cudaStream_t stream)
 = 0;

-virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
-const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
+virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
+void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
 char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 = 0;

-virtual void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
-const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
+virtual void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
+void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 = 0;

 // Returns desired workspace size in bytes.
-virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
+virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;

 virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;

@@ -96,20 +96,20 @@ public:
 CutlassFpAIntBGemmRunner();
 ~CutlassFpAIntBGemmRunner();

-void gemm(const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
+void gemm(void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
 cudaStream_t stream) override;

-void gemm(const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
+void gemm(void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
 cudaStream_t stream) override;

-void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
-const void* biases, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
+void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
+void const* biases, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
 char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream) override;

-void gemm(const void* A, const void* B, const void* weight_scales, const void* weight_zero_points,
-const void* biases, const float alpha, void* C, int m, int n, int k, const int group_size,
+void gemm(void const* A, void const* B, void const* weight_scales, void const* weight_zero_points,
+void const* biases, float const alpha, void* C, int m, int n, int k, int const group_size,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes,
 cudaStream_t stream) override;

@@ -120,15 +120,15 @@ public:
 // stream);

 // Returns desired workspace size in bytes.
-size_t getWorkspaceSize(const int m, const int n, const int k) override;
+size_t getWorkspaceSize(int const m, int const n, int const k) override;

 std::vector<tkc::CutlassGemmConfig> getConfigs() const override;

 private:
 template <typename EpilogueTag>
-void dispatch_to_arch(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
+void dispatch_to_arch(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace_ptr,
 const size_t workspace_bytes, cudaStream_t stream, int* occupancy = nullptr);

 private:
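The runner interface above is driven in three steps: size the workspace, pick a config, launch. The sketch below shows that order under assumed template arguments (FP16 activations and scales, int8 weights, per-column scaling) and an assumed include path; only the constructor and method signatures are taken from this diff.

```
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>

// Assumed header and namespace for illustration; neither appears in this excerpt.
#include "fpA_intB_gemm.h"

// Hypothetical template arguments: activations, scales, biases and output in half, weights as int8.
using Runner = tensorrt_llm::kernels::cutlass_kernels::CutlassFpAIntBGemmRunner<half, uint8_t,
    cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY, half, half, half>;

void run_weight_only_gemm(void const* d_act, void const* d_weights, void const* d_scales, void* d_out,
    int m, int n, int k, cudaStream_t stream)
{
    Runner runner;

    // Ask the runner for its worst-case scratch requirement before launching.
    size_t const workspace_bytes = runner.getWorkspaceSize(m, n, k);
    char* d_workspace = nullptr;
    cudaMalloc(&d_workspace, workspace_bytes);

    // A config would normally be chosen by profiling; the first reported one is used here.
    auto const configs = runner.getConfigs();
    runner.gemm(d_act, d_weights, d_scales, d_out, m, n, k, configs.front(), d_workspace,
        workspace_bytes, stream);

    cudaFree(d_workspace);
}
```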
(diff continues in another file; file name not shown in this view)

@@ -52,8 +52,8 @@ namespace cutlass_kernels

 template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
 typename ThreadblockShape, typename WarpShape, int Stages>
-void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T* weight_scales,
-const T* weight_zero_points, const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
+void generic_mixed_gemm_kernelLauncher(T const* A, WeightType const* B, T const* weight_scales,
+T const* weight_zero_points, T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
 tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
 int* occupancy = nullptr)
 {

@@ -127,7 +127,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*

 using Gemm = cutlass::gemm::device::GemmUniversalBaseCompat<GemmKernel>;

-const int ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
+int const ldb = cutlass::platform::is_same<cutlass::layout::RowMajor, typename MixedGemmArchTraits::LayoutB>::value
 ? n
 : k * GemmKernel::kInterleave;

@@ -171,7 +171,7 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
 }
 }

-const int ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
+int const ld_scale_zero = cutlass::isFinegrained(QuantOp) ? n : 0;
 ElementAccumulator output_op_beta = (biases == nullptr) ? ElementAccumulator(0.f) : ElementAccumulator(1.f);
 typename Gemm::Arguments args({m, n, k}, group_size, {reinterpret_cast<ElementType*>(const_cast<T*>(A)), k},
 {reinterpret_cast<CutlassWeightType*>(const_cast<WeightType*>(B)), ldb},

@@ -230,8 +230,8 @@ void generic_mixed_gemm_kernelLauncher(const T* A, const WeightType* B, const T*
 // quanitzation is only supported on Ampere+ GPUs.
 template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
 typename ThreadblockShape, typename WarpShape, int Stages>
-void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
-const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
+void filter_and_run_mixed_gemm(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
+T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
 tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
 int* occupancy = nullptr)
 {

@@ -261,8 +261,8 @@ void filter_and_run_mixed_gemm(const T* A, const WeightType* B, const T* weight_

 template <typename T, typename WeightType, typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag,
 typename ThreadblockShape, typename WarpShape>
-void dispatch_gemm_config(const T* A, const WeightType* B, const T* weight_scales, const T* weight_zero_points,
-const T* biases, const float alpha, T* C, int m, int n, int k, const int group_size,
+void dispatch_gemm_config(T const* A, WeightType const* B, T const* weight_scales, T const* weight_zero_points,
+T const* biases, float const alpha, T* C, int m, int n, int k, int const group_size,
 tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes, cudaStream_t stream,
 int* occupancy = nullptr)
 {

@@ -300,9 +300,9 @@ constexpr bool is_fp8()

 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 typename arch, cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
-void dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
+void dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
 cudaStream_t stream, int* occupancy = nullptr)
 {

@@ -412,9 +412,9 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
 typename BiasType, typename OutputType>
 template <typename EpilogueTag>
 void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType,
-OutputType>::dispatch_to_arch<EpilogueTag>(const ActivationType* A, const WeightType* B,
-const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
-const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
+OutputType>::dispatch_to_arch<EpilogueTag>(ActivationType const* A, WeightType const* B,
+ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
+float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
 char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream, int* occupancy)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -453,16 +453,16 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
 template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
 typename BiasType, typename OutputType>
 void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
-const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
-const float alpha, void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig,
+void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
+float const alpha, void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig,
 char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
 if constexpr ((QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS)
 || (QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY))
 {
-dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
-(const ScaleZeroType*) weight_scales, (const ScaleZeroType*) weight_zero_points, (const BiasType*) biases,
+dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
+(ScaleZeroType const*) weight_scales, (ScaleZeroType const*) weight_zero_points, (BiasType const*) biases,
 alpha, (OutputType*) C, m, n, k, group_size, gemmConfig, workspace_ptr, workspace_bytes, stream, nullptr);
 }
 else

@@ -475,8 +475,8 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
 template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
 typename BiasType, typename OutputType>
 void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
-const void* A, const void* B, const void* weight_scales, const void* weight_zero_points, const void* biases,
-void* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
+void const* A, void const* B, void const* weight_scales, void const* weight_zero_points, void const* biases,
+void* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr,
 const size_t workspace_bytes, cudaStream_t stream)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -487,15 +487,15 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
 template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
 typename BiasType, typename OutputType>
 void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
-const void* A, const void* B, const void* weight_scales, const float alpha, void* C, int m, int n, int k,
+void const* A, void const* B, void const* weight_scales, float const alpha, void* C, int m, int n, int k,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

 if constexpr (QuantOp == cutlass::WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY)
 {
-dispatch_to_arch<tkc::EpilogueOpBias>((const ActivationType*) A, (const WeightType*) B,
-(const ScaleZeroType*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
+dispatch_to_arch<tkc::EpilogueOpBias>((ActivationType const*) A, (WeightType const*) B,
+(ScaleZeroType const*) weight_scales, nullptr, nullptr, alpha, (OutputType*) C, m, n, k, k, gemmConfig,
 workspace_ptr, workspace_bytes, stream, nullptr);
 }
 else

@@ -507,7 +507,7 @@ void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType
 template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuantOp QuantOp, typename ScaleZeroType,
 typename BiasType, typename OutputType>
 void CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::gemm(
-const void* A, const void* B, const void* weight_scales, void* C, int m, int n, int k,
+void const* A, void const* B, void const* weight_scales, void* C, int m, int n, int k,
 tkc::CutlassGemmConfig gemmConfig, char* workspace_ptr, const size_t workspace_bytes, cudaStream_t stream)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -529,12 +529,12 @@ template <typename ActivationType, typename WeightType, cutlass::WeightOnlyQuant
 typename BiasType, typename OutputType>
 size_t
 CutlassFpAIntBGemmRunner<ActivationType, WeightType, QuantOp, ScaleZeroType, BiasType, OutputType>::getWorkspaceSize(
-const int m, const int n, const int k)
+int const m, int const n, int const k)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);
 // These are the min tile sizes for each config, which would launch the maximum number of blocks
-const int max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
-const int max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
+int const max_grid_m = cutlass::ceil_div(m, MIN_M_TILE);
+int const max_grid_n = cutlass::ceil_div(n, MIN_N_TILE);
 // We need 4 bytes per block in the worst case. We launch split_k_limit in z dim.
 return static_cast<size_t>(max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4);
 }
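The returned size is split-k bookkeeping only: one 4-byte slot per threadblock over the densest grid the runner may launch. A back-of-envelope check with placeholder constants (the real MIN_M_TILE, MIN_N_TILE and SPLIT_K_LIMIT values are defined elsewhere in the class and are not shown here):

```
#include <cstdio>

int main()
{
    // Placeholder values for illustration only; the actual constants live in the runner class.
    int const m = 4096, n = 4096;
    int const MIN_M_TILE = 32, MIN_N_TILE = 64, SPLIT_K_LIMIT = 7;

    int const max_grid_m = (m + MIN_M_TILE - 1) / MIN_M_TILE; // cutlass::ceil_div
    int const max_grid_n = (n + MIN_N_TILE - 1) / MIN_N_TILE;

    // 4 bytes per block, times the maximum split-k depth launched in the z dimension.
    printf("%d bytes\n", max_grid_m * max_grid_n * SPLIT_K_LIMIT * 4); // 229376 with these numbers
    return 0;
}
```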
(diff continues in another file; file name not shown in this view)

@@ -44,9 +44,9 @@ namespace cutlass_kernels
 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
 typename MainloopScheduleType>
-void sm90_dispatch_epilogue_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+void sm90_dispatch_epilogue_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
 cudaStream_t stream, int* occupancy = nullptr)
 {

@@ -114,9 +114,9 @@ constexpr bool are_tile_shapes_supported()

 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape>
-void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+void sm90_dispatch_mainloop_schedules(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
 cudaStream_t stream, int* occupancy = nullptr)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -153,9 +153,9 @@ void sm90_dispatch_mainloop_schedules(const ActivationType* A, const WeightType*

 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape>
-void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
+void sm90_dispatch_gemm_config(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, tkc::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
 cudaStream_t stream, int* occupancy = nullptr)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -190,9 +190,9 @@ void sm90_dispatch_gemm_config(const ActivationType* A, const WeightType* B, con

 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag>
-void sm90_dispatch_gemm_to_cutlass(const ActivationType* A, const WeightType* B, const ScaleZeroType* weight_scales,
-const ScaleZeroType* weight_zero_points, const BiasType* biases, const float alpha, OutputType* C, int m, int n,
-int k, const int group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
+void sm90_dispatch_gemm_to_cutlass(ActivationType const* A, WeightType const* B, ScaleZeroType const* weight_scales,
+ScaleZeroType const* weight_zero_points, BiasType const* biases, float const alpha, OutputType* C, int m, int n,
+int k, int const group_size, char* workspace, size_t workspace_bytes, tkc::CutlassGemmConfig gemm_config,
 cudaStream_t stream, int* occupancy = nullptr)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

(diff continues in another file; file name not shown in this view)

@@ -28,9 +28,9 @@ namespace cutlass_kernels
 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
 typename MainloopScheduleType, typename EpilogueScheduleType>
-void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
-const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
-const float alpha, OutputType* C, int m, int n, int k, const int group_size,
+void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
+ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
+float const alpha, OutputType* C, int m, int n, int k, int const group_size,
 tensorrt_llm::cutlass_extensions::CutlassGemmConfig gemm_config, char* workspace, size_t workspace_bytes,
 cudaStream_t stream, int* occupancy = nullptr);

(diff continues in another file; file name not shown in this view)

@@ -59,9 +59,9 @@ namespace cutlass_kernels
 template <typename ActivationType, typename WeightType, typename ScaleZeroType, typename BiasType, typename OutputType,
 cutlass::WeightOnlyQuantOp QuantOp, typename EpilogueTag, typename CTAShape, typename ClusterShape,
 typename MainloopScheduleType, typename EpilogueScheduleType>
-void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const WeightType* B,
-const ScaleZeroType* weight_scales, const ScaleZeroType* weight_zero_points, const BiasType* biases,
-const float alpha, OutputType* C, int m, int n, int k, const int group_size, tkc::CutlassGemmConfig gemm_config,
+void sm90_generic_mixed_gemm_kernelLauncher(ActivationType const* A, WeightType const* B,
+ScaleZeroType const* weight_scales, ScaleZeroType const* weight_zero_points, BiasType const* biases,
+float const alpha, OutputType* C, int m, int n, int k, int const group_size, tkc::CutlassGemmConfig gemm_config,
 char* workspace, size_t workspace_bytes, cudaStream_t stream, int* occupancy)
 {
 TLLM_LOG_DEBUG(__PRETTY_FUNCTION__);

@@ -233,7 +233,7 @@ void sm90_generic_mixed_gemm_kernelLauncher(const ActivationType* A, const Weigh
 StrideS stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(n, cutlass_scale_k, 1));

 // Use the output as the bias to avoid making a tma descriptor with a nullptr.
-auto output_as_bias_type = reinterpret_cast<const CutlassBiasType*>(C);
+auto output_as_bias_type = reinterpret_cast<CutlassBiasType const*>(C);

 typename Gemm::Arguments args{cutlass::gemm::GemmUniversalMode::kGemm, {n, m, k, 1},
 {reinterpret_cast<CutlassWeightType const*>(B), stride_B, reinterpret_cast<CutlassActivationType const*>(A),

(diff continues in another file; file name not shown in this view)

@@ -47,13 +47,13 @@ public:

 virtual ~CutlassInt8GemmRunnerInterface() {}

-virtual void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol,
-const float* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
+virtual void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
+float const* alphaRow, void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
 const size_t workspaceBytes, cudaStream_t stream)
 = 0;

 // Returns desired workspace size in bytes.
-virtual size_t getWorkspaceSize(const int m, const int n, const int k) = 0;
+virtual size_t getWorkspaceSize(int const m, int const n, int const k) = 0;

 virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;

@@ -70,18 +70,18 @@ public:
 CutlassInt8GemmRunner();
 ~CutlassInt8GemmRunner();

-void gemm(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol, const float* alphaRow,
+void gemm(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol, float const* alphaRow,
 void* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
 const size_t workspaceBytes, cudaStream_t stream) override;

 // Returns desired workspace size in bytes.
-size_t getWorkspaceSize(const int m, const int n, const int k) override;
+size_t getWorkspaceSize(int const m, int const n, int const k) override;

 std::vector<tkc::CutlassGemmConfig> getConfigs() const override;

 private:
-void dispatchToArch(const int8_t* A, const int8_t* B, tk::QuantMode quantOption, const float* alphaCol,
-const float* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
+void dispatchToArch(int8_t const* A, int8_t const* B, tk::QuantMode quantOption, float const* alphaCol,
+float const* alphaRow, T* C, int m, int n, int k, tkc::CutlassGemmConfig gemmConfig, char* workspacePtr,
 const size_t workspaceBytes, cudaStream_t stream, int* occupancy = nullptr);

 int mSm;
Some files were not shown because too many files have changed in this diff