Update TensorRT-LLM (#2436)

Kaiyu Xie 2024-11-12 15:27:49 +08:00 committed by GitHub
parent b7868dd1bd
commit c629546ce4
337 changed files with 41747 additions and 3917 deletions


@ -6,8 +6,8 @@ TensorRT-LLM
[![Documentation](https://img.shields.io/badge/docs-latest-brightgreen.svg?style=flat)](https://nvidia.github.io/TensorRT-LLM/)
[![python](https://img.shields.io/badge/python-3.10.12-green)](https://www.python.org/downloads/release/python-31012/)
[![cuda](https://img.shields.io/badge/cuda-12.6.1-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.5.0-green)](https://developer.nvidia.com/tensorrt)
[![cuda](https://img.shields.io/badge/cuda-12.6.2-green)](https://developer.nvidia.com/cuda-downloads)
[![trt](https://img.shields.io/badge/TRT-10.6.0-green)](https://developer.nvidia.com/tensorrt)
[![version](https://img.shields.io/badge/release-0.15.0.dev-green)](./tensorrt_llm/version.py)
[![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE)


@ -182,211 +182,8 @@ struct BenchmarkParams
std::optional<texec::LookaheadDecodingConfig> executorLookaheadConfig;
std::optional<texec::LookaheadDecodingConfig> requestLookaheadConfig;
};
class InferenceRequestsAsyncSend
{
public:
InferenceRequestsAsyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
std::list<std::shared_ptr<InferenceRequest>> const& inferenceRequests, int const peer)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start send requests to rank %d", peer);
mNumNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
mRequest1 = comm->sendAsync(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
if (mNumNewWorkItems > 0)
{
for (auto const& infReq : inferenceRequests)
{
auto vpacked = infReq->serialize();
mPacked.push_back(static_cast<int64_t>(vpacked.size()));
mPacked.insert(mPacked.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
mVecSize = static_cast<int64_t>(mPacked.size());
mRequest2 = comm->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
mRequest3 = comm->sendAsync(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
}
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
~InferenceRequestsAsyncSend()
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
mRequest1->wait();
if (mRequest2)
mRequest2->wait();
if (mRequest3)
mRequest3->wait();
TLLM_LOG_DEBUG("end send requests");
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
private:
int64_t mNumNewWorkItems;
int64_t mVecSize;
std::vector<int64_t> mPacked;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest1;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest2;
std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest3;
};
} // namespace
void inferenceRequestsRecv(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
std::list<std::shared_ptr<InferenceRequest>>& inferenceRequests, int const peer)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_LOG_DEBUG("start recv requests from rank %d", peer);
int64_t numNewWorkItems = 0;
comm->recv(&numNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
if (numNewWorkItems > 0)
{
std::vector<int64_t> packed;
int64_t vecSize;
comm->recv(&vecSize, 1, mpi::MpiType::kINT64, peer, 1);
packed.resize(vecSize);
comm->recv(packed.data(), packed.size(), mpi::MpiType::kINT64, peer, 2);
int64_t* packed_ptr = packed.data();
for (int64_t count = 0; count < numNewWorkItems; ++count)
{
int64_t n = *(packed_ptr++);
auto infReq = InferenceRequest::deserialize(packed_ptr);
packed_ptr += n;
inferenceRequests.emplace_back(infReq);
}
}
TLLM_LOG_DEBUG("end recv requests from rank %d", peer);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
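// Illustrative sketch of how the two helpers above pair up between adjacent pipeline-parallel
// ranks: the sending rank posts the three asynchronous messages (item count, packed size,
// packed payload) by constructing the handle and must keep it alive until the destructor has
// waited on the MPI requests, while the peer rank issues the matching blocking receives.
// The function and variable names here are placeholders for illustration only.
void exchangeInferenceRequestsSketch(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
    std::list<std::shared_ptr<InferenceRequest>> const& outgoing, int peer, bool isSender)
{
    if (isSender)
    {
        // Sends are posted in the constructor; the destructor at end of scope waits on them.
        InferenceRequestsAsyncSend sendHandle(comm, outgoing, peer);
    }
    else
    {
        std::list<std::shared_ptr<InferenceRequest>> incoming;
        inferenceRequestsRecv(comm, incoming, peer);
    }
}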
// Class holding all information regarding a single work item.
// This includes the original request, the associated response factory,
// and state.
class WorkItem
{
public:
WorkItem(std::shared_ptr<InferenceRequest> inferenceRequest, uint64_t requestId)
: mInferenceRequest(std::move(inferenceRequest))
, mRequestId(requestId)
{
}
[[nodiscard]] uint64_t requestId() const
{
return mRequestId;
}
[[nodiscard]] std::shared_ptr<InferenceRequest> getInferenceRequest() const
{
return mInferenceRequest;
}
private:
std::shared_ptr<InferenceRequest> mInferenceRequest;
uint64_t mRequestId;
};
/// @brief Thread-safe queue of work items
class WorkItemsQueue
{
public:
void clear()
{
std::lock_guard<std::mutex> lock(mMutex);
mPendingWorkItems.clear();
mPendingWorkItemsReqIds.clear();
mInProgressWorkItems.clear();
}
// Note: this function should only be called while holding the lock
bool hasInProgressReqId(uint64_t const reqId) const
{
return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end());
}
// Note: this function should only be called while holding the lock
bool hasPendingReqId(uint64_t const reqId) const
{
return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end());
}
bool empty() const
{
return mPendingWorkItems.empty() && mInProgressWorkItems.empty() && mPendingWorkItemsReqIds.empty();
}
/// @brief Add a new work item to the queue
/// Throws an error if requestId already exists
void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
{
std::lock_guard<std::mutex> lock(mMutex);
TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
"requestId %lu is already in progress, request is ignored.", requestId);
auto workItem = std::make_shared<WorkItem>(request, requestId);
mPendingWorkItems.push_back(workItem);
mPendingWorkItemsReqIds.insert(workItem->requestId());
}
/// @brief Get a new work item from the queue, and move it to the list of
/// in progress work items if it hasn't been stopped
/// @return A tuple of the workItem and a boolean flag indicating if the work item
/// has been marked in progress
std::tuple<std::shared_ptr<WorkItem>, bool> pop()
{
std::lock_guard<std::mutex> lock(mMutex);
auto workItem = mPendingWorkItems.front();
mPendingWorkItems.pop_front();
mPendingWorkItemsReqIds.erase(workItem->requestId());
bool markedInProgress = false;
mInProgressWorkItems.emplace(workItem->requestId(), workItem);
markedInProgress = true;
return {workItem, markedInProgress};
}
size_t numPendingWorkItems() const
{
std::lock_guard<std::mutex> lock(mMutex);
return mPendingWorkItems.size();
}
size_t numInProgressWorkItems() const
{
std::lock_guard<std::mutex> lock(mMutex);
return mInProgressWorkItems.size();
}
size_t size() const
{
std::lock_guard<std::mutex> lock(mMutex);
return mPendingWorkItems.size() + mInProgressWorkItems.size();
}
/// @brief Mark a request as being finished
/// @param requestId
void markFinished(uint64_t const requestId)
{
std::lock_guard<std::mutex> lock(mMutex);
if (hasInProgressReqId(requestId))
{
mInProgressWorkItems.erase(requestId);
}
}
private:
/// Queue of work items
std::list<std::shared_ptr<WorkItem>> mPendingWorkItems;
/// requestIds of work items in the queue
std::set<uint64_t> mPendingWorkItemsReqIds;
/// work items currently in progress
std::unordered_map<uint64_t, std::shared_ptr<WorkItem>> mInProgressWorkItems;
mutable std::mutex mMutex;
};
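// Illustrative sketch of the WorkItemsQueue life cycle declared above: a request is pushed
// under its id, later popped (which moves it into the in-progress map), and finally marked
// finished once its last response has gone out. Function and variable names are placeholders.
void workItemsQueueSketch(WorkItemsQueue& queue, std::shared_ptr<InferenceRequest> const& request)
{
    auto const requestId = request->getRequestId();
    queue.push(request, requestId); // throws if the id is already pending or in progress
    auto [workItem, markedInProgress] = queue.pop();
    if (markedInProgress)
    {
        // ... hand workItem->getInferenceRequest() to the batch manager here ...
        queue.markFinished(workItem->requestId());
    }
}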
struct BenchInfo
{
BenchInfo() = default;
@ -748,7 +545,7 @@ public:
std::vector<std::string> headers = {"num_samples", "num_error_samples", "total_latency(ms)",
"seq_throughput(seq/sec)", "token_throughput(token/sec)", "avg_sequence_latency(ms)",
"max_sequence_latency(ms)", "min_sequence_latency(ms)", "p99_sequence_latency(ms)",
"p90_sequence_latency(ms)", "p50_sequence_latency(ms)"};
"p90_sequence_latency(ms)", "p50_sequence_latency(ms)", "avg_acceptance_rate(tokens/decoding steps)"};
if (mStreaming)
{
@ -772,7 +569,8 @@ public:
outputFile << "\n";
outputFile << mNumSamples << "," << mNumErrorSamples << "," << mTotalLatency << "," << mSeqThroughput
<< "," << mTokenThroughput << "," << mAvgSeqLatency << "," << mMaxSeqLatency << ","
<< mMinSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << "," << mP50SeqLatency;
<< mMinSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << "," << mP50SeqLatency
<< "," << mAcceptanceRate;
if (mStreaming)
{
outputFile << "," << mAvgFtLatency << "," << mMaxFtLatency << "," << mMinFtLatency << ","
@ -869,11 +667,9 @@ public:
std::optional<std::filesystem::path> const& encoderTrtEnginePath, TrtGptModelType modelType,
int32_t maxBeamWidth, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
BenchmarkParams const& benchmarkParams, std::shared_ptr<Recorder> recorder, std::chrono::milliseconds waitSleep,
std::optional<uint64_t> const staticEmulatedBatchSize, bool logIterationData,
texec::ModelType executorModelType)
bool logIterationData, texec::ModelType executorModelType)
: mRecorder(std::move(recorder))
, mWaitSleep(waitSleep)
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
, mConcurrency(benchmarkParams.concurrency)
, mActiveCount(0)
, mNumFinished(0)
@ -1048,7 +844,6 @@ private:
std::thread mCollectStatsThread;
std::shared_ptr<Recorder> mRecorder;
std::chrono::milliseconds mWaitSleep;
std::optional<int> mStaticEmulatedBatchSize;
std::optional<int> mConcurrency;
std::atomic<uint64_t> mActiveCount;
std::atomic<uint64_t> mNumFinished;
@ -1056,288 +851,6 @@ private:
bool mLogIterationData;
}; // class ExecutorServer
class GptServer
{
public:
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType,
TrtGptModelOptionalParams const& optionalParams, std::shared_ptr<Recorder> recorder,
std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
std::optional<SizeType32> const staticEmulatedBatchSize,
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput)
: mRecorder(std::move(recorder))
, mTerminateReqId(terminateReqId)
, mWaitSleep(waitSleep)
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
, mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
, mActiveCount(0)
{
auto const jsonConfig = GptJsonConfig::parse(trtEnginePath / "config.json");
mWorldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
jsonConfig.getPipelineParallelism(), optionalParams.deviceIds);
auto& comm = COMM_SESSION;
mCommTensorParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
comm.split(mWorldConfig.getPipelineParallelRank(), mWorldConfig.getTensorParallelRank()));
mCommPipelineParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
comm.split(mWorldConfig.getTensorParallelRank(), mWorldConfig.getPipelineParallelRank()));
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
{
if (logIterationData)
{
TLLM_LOG_INFO(log);
}
if (mStaticEmulatedBatchSize)
{
auto const json = nlohmann::json::parse(log);
auto const activeRequests = json["Active Request Count"];
TLLM_CHECK(activeRequests <= mStaticEmulatedBatchSize.value());
}
};
mBatchManager = std::make_shared<GptManager>(
trtEnginePath, modelType, [this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
std::string const& errMsg)
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
nullptr, iterationDataCallback, optionalParams, terminateReqId, excludeInputInOutput);
}
~GptServer()
{
if (mInferReqWaitThread)
{
mInferReqWaitThread->join();
mInferReqWaitThread.reset(nullptr);
}
mWorkItemsQueue.clear();
}
std::string getLayerProfileInfo()
{
return mBatchManager->getLayerProfileInfo();
}
void setLayerProfiler()
{
return mBatchManager->setLayerProfiler();
}
void enqueue(std::shared_ptr<InferenceRequest> const& request)
{
TLLM_CHECK(request != nullptr);
auto const requestId = request->getRequestId();
if (requestId == mTerminateReqId)
{
mWorkItemsQueue.push(request, requestId);
return;
}
// Enqueue
try
{
mRecorder->recordStart(request, requestId);
mWorkItemsQueue.push(request, requestId);
}
catch (tc::TllmException const& e)
{
throw;
}
catch (std::exception const& e)
{
TLLM_THROW("%s", e.what());
}
}
void resetBatchDeadline()
{
mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch();
}
void waitForEmpty() const
{
while (!mWorkItemsQueue.empty())
{
std::this_thread::sleep_for(mWaitSleep);
}
}
void waitBatchManager() const
{
mBatchManager->waitUntilTerminate();
}
void shutdown() const
{
mBatchManager->shutdown();
}
// Return up to max_num_requests inference requests.
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
{
if (mInferReqWaitThread)
{
mInferReqWaitThread->join();
mInferReqWaitThread.reset(nullptr);
}
std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
auto& comm = COMM_SESSION;
if (max_num_requests > 0)
{
auto rank = comm.getRank();
if (rank == 0)
{
auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
static_cast<int64_t>(max_num_requests));
bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load();
bool readyForNextBatch = numNewWorkItems > 0 && timeout;
if (mStaticEmulatedBatchSize)
{
if (numNewWorkItems > 0)
{
bool const previousBatchFinished = mActiveCount == 0;
bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
}
if (numNewWorkItems == 0 || readyForNextBatch)
{
// Timeout should only begin once we have at least 1 pending request.
// Reset timeout when no requests are pending or we submit a new batch.
resetBatchDeadline();
}
}
if (readyForNextBatch)
{
// Only add a single batch at a time when emulating static batching
auto const numItemsToAdd = std::min(
numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
mActiveCount += numItemsToAdd;
while (inferenceRequests.size() < numItemsToAdd)
{
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
if (markedInProgress)
{
inferenceRequests.emplace_back(workItem->getInferenceRequest());
}
else
{
auto warnStr = tc::fmtstr(
"request Id %lu has been stopped. Request is ignored.", workItem->requestId());
TLLM_LOG_WARNING(warnStr);
sendResponse(workItem->requestId(), {}, true, warnStr);
}
}
}
if (mWorldConfig.isTensorParallel())
{
auto numNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
if (numNewWorkItems > 0 || mBatchManager->getNumActiveRequests() > 0)
{
mCommTensorParallel->bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
}
if (numNewWorkItems > 0)
{
std::vector<int64_t> packed;
for (auto const& infReq : inferenceRequests)
{
auto vpacked = infReq->serialize();
packed.push_back(static_cast<int64_t>(vpacked.size()));
packed.insert(
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
}
mCommTensorParallel->bcast(packed, 0);
}
}
}
else
{
// subordinate ranks hang until master rank sends work
if (mWorldConfig.isFirstPipelineParallelRank())
{
int64_t numNewWorkItems = 0;
mCommTensorParallel->bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
if (numNewWorkItems > 0)
{
std::vector<int64_t> packed;
mCommTensorParallel->bcast(packed, 0);
int64_t* packed_ptr = packed.data();
for (int64_t count = 0; count < numNewWorkItems; ++count)
{
int64_t n = *(packed_ptr++);
auto infReq = InferenceRequest::deserialize(packed_ptr);
packed_ptr += n;
inferenceRequests.emplace_back(infReq);
}
}
}
else
{
auto const peer = mWorldConfig.getPipelineParallelRank() - 1;
inferenceRequestsRecv(mCommPipelineParallel, inferenceRequests, peer);
}
}
if (!mWorldConfig.isLastPipelineParallelRank())
{
auto const peer = mWorldConfig.getPipelineParallelRank() + 1;
auto inferReqAsyncSndHdl
= std::make_unique<InferenceRequestsAsyncSend>(mCommPipelineParallel, inferenceRequests, peer);
mInferReqWaitThread = std::make_unique<std::thread>([handle = std::move(inferReqAsyncSndHdl)]() {});
}
}
return inferenceRequests;
}
void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
bool final_response, [[maybe_unused]] std::string const& errMsg)
{
// `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs,
// cumLogProbs`. `contextLogits` and `generationLogits` are optional and only included when
// `gather_context_logits` and `gather_generation_logits` are enabled, respectively; enabling
// 'gather_all_token_logits' turns on both.
try
{
if (final_response)
{
mWorkItemsQueue.markFinished(requestId);
mRecorder->recordEnd(requestId, response_tensors, !errMsg.empty());
mActiveCount--;
}
else
{
if (errMsg.empty())
{
mRecorder->recordToken(requestId, response_tensors);
}
}
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what());
}
}
private:
std::shared_ptr<GptManager> mBatchManager;
std::shared_ptr<Recorder> mRecorder;
WorkItemsQueue mWorkItemsQueue;
std::optional<uint64_t> mTerminateReqId;
std::chrono::milliseconds mWaitSleep;
std::optional<SizeType32> mStaticEmulatedBatchSize;
std::chrono::milliseconds mBatchTimeout;
std::atomic<std::chrono::steady_clock::time_point::duration> mBatchDeadline;
std::atomic<uint64_t> mActiveCount;
WorldConfig mWorldConfig;
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommTensorParallel;
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommPipelineParallel;
std::unique_ptr<std::thread> mInferReqWaitThread;
}; // class GptServer
namespace
{
@ -1418,60 +931,6 @@ std::vector<double> computeTimeDelays(BenchmarkParams const& benchmarkParams, in
return timeDelays;
}
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample, bool streaming,
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
ITensor::SharedPtr const& loraConfig = nullptr,
std::optional<tensorrt_llm::executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt)
{
auto request = std::make_shared<InferenceRequest>(reqId);
auto const& inputIds = sample.inputIds;
request->setInputIds(bufferManager.copyFrom(
inputIds, ITensor::makeShape({static_cast<SizeType32>(inputIds.size())}), MemoryType::kCPU));
auto const requestOutputLen = sample.outputLen;
request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU));
request->setBeamWidth(beamWidthTensor);
if (eosId != nullptr)
{
request->setEndId(eosId);
}
if (padId != nullptr)
{
request->setPadId(padId);
}
if (returnContextLogits)
{
request->setReturnContextLogits(returnContextLogits);
}
if (returnGenerationLogits)
{
request->setReturnGenerationLogits(returnGenerationLogits);
}
if (sample.taskId >= 0)
{
uint64_t taskId = static_cast<uint64_t>(sample.taskId);
request->setLoraTaskId(bufferManager.copyFrom(&taskId, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL));
}
if (loraWeights)
{
request->setLoraWeights(loraWeights);
}
if (loraConfig)
{
request->setLoraConfig(loraConfig);
}
if (lookaheadConfig)
{
request->setLookaheadConfig(lookaheadConfig.value());
}
if (streaming)
{
request->setIsStreaming(true);
}
return request;
}
texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamWidth,
std::optional<SizeType32> const& eosId, std::optional<SizeType32> const& padId, bool streaming = false,
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false,
@ -1495,185 +954,6 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
}
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
std::optional<TokenIdType> const& eosId, std::optional<TokenIdType> const& padId,
BenchmarkParams const& benchmarkParams, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
std::optional<SizeType32> const staticEmulatedBatchSize,
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput,
std::string const& responsesJsonFile, std::optional<SizeType32> const maxPromptLen, bool dumpProfile)
{
TrtGptModelOptionalParams optionalParams;
if (benchmarkParams.maxTokensInPagedKvCache)
{
optionalParams.kvCacheConfig.maxTokens = benchmarkParams.maxTokensInPagedKvCache;
}
if (benchmarkParams.freeGpuMemoryFraction)
{
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
}
if (benchmarkParams.crossKvCacheFraction)
{
optionalParams.kvCacheConfig.crossKvCacheFraction = benchmarkParams.crossKvCacheFraction;
}
if (benchmarkParams.maxAttentionWindowVec)
{
optionalParams.kvCacheConfig.maxAttentionWindowVec = benchmarkParams.maxAttentionWindowVec;
}
if (benchmarkParams.sinkTokenLength)
{
optionalParams.kvCacheConfig.sinkTokenLength = benchmarkParams.sinkTokenLength;
}
optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse;
optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext;
optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap;
optionalParams.peftCacheManagerConfig.hostCacheSize = benchmarkParams.loraHostCacheSize;
optionalParams.peftCacheManagerConfig.numDeviceModuleLayer = benchmarkParams.loraDeviceNumModLayers;
optionalParams.peftCacheManagerConfig.numPutWorkers = 4;
optionalParams.peftCacheManagerConfig.numEnsureWorkers = 4;
optionalParams.peftCacheManagerConfig.numCopyStreams = 4;
optionalParams.kvCacheConfig.hostCacheSize = benchmarkParams.kvHostCacheSize;
optionalParams.kvCacheConfig.onboardBlocks = benchmarkParams.kvOnboardBlocks;
optionalParams.gpuWeightsPercent = benchmarkParams.gpuWeightsPercent;
optionalParams.maxBeamWidth = beamWidth;
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
optionalParams.maxNumTokens = benchmarkParams.maxNumTokens;
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
optionalParams.decodingConfig
= texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
: texec::DecodingMode::Auto(),
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
jsonConfig.getPipelineParallelism(), optionalParams.deviceIds);
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
ITensor::SharedPtr beamWidthTensor{
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)};
// Load dataset
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples, maxPromptLen);
auto const numSamples = samples.size();
auto recorder = std::make_shared<Recorder>(
opCsvFile, benchmarkParams.streaming, beamWidth, responsesJsonFile, excludeInputInOutput);
uint64_t terminateReqId = numSamples + 1;
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, optionalParams, recorder, terminateReqId,
waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData, excludeInputInOutput);
ITensor::SharedPtr eosIdTensor{
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNEDPOOL) : nullptr};
ITensor::SharedPtr padIdTensor{
padId ? bufferManager.copyFrom(&padId.value(), ITensor::makeShape({1}), MemoryType::kPINNEDPOOL) : nullptr};
ITensor::SharedPtr returnContextLogitsFlagTensor{returnContextLogits
? bufferManager.copyFrom(&returnContextLogits, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)
: nullptr};
ITensor::SharedPtr returnGenerationLogitsFlagTensor{returnGenerationLogits
? bufferManager.copyFrom(&returnGenerationLogits, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)
: nullptr};
if (worldConfig.getRank() == 0)
{
if (benchmarkParams.loraDir)
{
auto startLoraLoad = std::chrono::steady_clock::now();
LoraLib loras(benchmarkParams.loraDir.value());
SizeType32 reqId = 0;
gptServer->resetBatchDeadline();
for (auto const& [taskId, p] : loras.getLoras())
{
reqId++;
if (reqId == terminateReqId)
{
reqId++;
}
Sample s{std::vector<int32_t>{1, 2, 3, 4, 5}, 1, static_cast<int32_t>(taskId)};
auto r = makeRequest(reqId, s, benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor,
bufferManager, nullptr, nullptr, p.first, p.second);
gptServer->enqueue(r);
}
gptServer->waitForEmpty();
auto endLoraLoad = std::chrono::steady_clock::now();
printf("[BENCHMARK] time to preload LoRAs(ms) %.2f\n",
std::chrono::duration<float, std::milli>(endLoraLoad - startLoraLoad).count());
}
// Warm up
gptServer->resetBatchDeadline();
SizeType32 reqId = 0;
for (auto i = 0; i < warmUp; ++i)
{
++reqId;
if (i == terminateReqId)
++reqId;
auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
// Time delay
auto timeDelays = computeTimeDelays(benchmarkParams, numSamples - 1);
// Benchmark
recorder->initialize();
gptServer->resetBatchDeadline();
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr,
nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
if (i < numSamples - 1)
{
auto delayInMs = static_cast<int>(timeDelays.at(i) * 1000);
std::chrono::milliseconds delay(delayInMs);
std::this_thread::sleep_for(delay);
}
}
gptServer->waitForEmpty();
recorder->finalize();
recorder->calculateMetrics();
recorder->report();
recorder->writeOpMetricsToCsv();
recorder->dumpResponseSeqs();
if (dumpProfile)
{
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
gptServer->resetBatchDeadline();
gptServer->setLayerProfiler();
for (std::size_t i = 0; i < numSamples; ++i)
{
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor,
nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
gptServer->enqueue(request);
}
gptServer->waitForEmpty();
if (worldConfig.getRank() == 0)
{
printf("[BENCHMARK] Per layer performance profile\n%s\n", gptServer->getLayerProfileInfo().c_str());
}
}
// Send terminateReqId to terminate servers on all ranks
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
}
// Wait until benchmarking is done and batch manager is terminated
gptServer->waitBatchManager();
}
void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
std::optional<std::filesystem::path> const& encoderEngineDir, TrtGptModelType modelType,
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
@ -1698,16 +978,15 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
TLLM_CHECK_WITH_INFO(
decoderEngineDir.has_value(), "decoder models require a path to decoder engine in executor benchmark.");
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData,
executorModelType);
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
}
else if (executorModelType == texec::ModelType::kENCODER_DECODER)
{
TLLM_CHECK_WITH_INFO(encoderEngineDir.has_value(),
"encoder-decoder models require a path to encoder engine in executor benchmark.");
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType,
beamWidth, capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize,
logIterationData, executorModelType);
executorServer
= std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
try
{
std::ifstream decoderJsonConfigPath(decoderEngineDir.value() / "config.json");
@ -1733,8 +1012,7 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
TLLM_CHECK_WITH_INFO(
encoderEngineDir.has_value(), "encoder models require a path to encoder engine in executor benchmark.");
executorServer = std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), modelType, beamWidth,
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData,
executorModelType);
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
}
else
{
@ -1937,6 +1215,8 @@ int main(int argc, char* argv[])
options.add_options()("h,help", "Print usage");
options.add_options()("engine_dir, decoder_engine_dir", "Directory that store the engines of decoder models.",
cxxopts::value<std::string>());
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
options.add_options()(
"api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("executor"));
options.add_options()("type", "Batching type: IFB, UIFB (unfused IFB) or V1 (non-IFB) batching.",
@ -1971,8 +1251,6 @@ int main(int argc, char* argv[])
options.add_options()("max_batch_size", "The max runtime batch size when benchmarking", cxxopts::value<int>());
options.add_options()(
"max_num_tokens", "The max runtime number of tokens per batch when benchmarking", cxxopts::value<int>());
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
cxxopts::value<bool>()->default_value("false"));
options.add_options()(
"enable_batch_size_tuning", "Dynamic tuning of batch size", cxxopts::value<bool>()->default_value("false"));
options.add_options()("enable_exp_delays", "Enables exponential delay distr to mimic real world request arrival",
@ -1991,15 +1269,8 @@ int main(int argc, char* argv[])
"Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
options.add_options()("first_batch_delay",
"Delay before submitting the first batch of requests. This can be used to increase the size of the first "
"batch.",
cxxopts::value<int32_t>());
options.add_options()("static_emulated_batch_size",
"Emulate static batching performance with the provided batch size.", cxxopts::value<SizeType32>());
options.add_options()("static_emulated_timeout",
"Timeout (ms) before launching a partial batch in emulated static batching mode",
cxxopts::value<int32_t>()->default_value("500"));
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
cxxopts::value<std::string>()->default_value("error"));
options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
@ -2012,22 +1283,11 @@ int main(int argc, char* argv[])
options.add_options()("kv_host_cache_bytes",
"Size of secondary memory pool used for offloading kv cache blocks (in bytes).",
cxxopts::value<size_t>()->default_value("0"));
options.add_options()("kv_dont_onboard_blocks",
"If offloaded blocks should be onboarded to primary memory before reuse",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("exclude_input_in_output_seq",
"When enabled, GptManager will exclude the input sequence from output. (Only works if --api is gptManager)",
cxxopts::value<bool>());
options.add_options()("responses_json_file",
"When specified, dumps the responses to JSON file. (only works if --api is gptManager)",
cxxopts::value<std::string>()->default_value(""));
options.add_options()("kv_onboard_blocks", "If offloaded blocks should be onboarded to primary memory before reuse",
cxxopts::value<bool>()->default_value("true"));
options.add_options()(
"max_prompt_len", "Truncate all prompts from dataset to the length specified.", cxxopts::value<SizeType32>());
options.add_options()("dump_profile", "Print profile information per layer.", cxxopts::value<bool>());
options.add_options()("gpu_weights_percent",
"Specify the percentage of weights that reside on GPU (from 0.0 to 1.0).",
cxxopts::value<float>()->default_value("1.0"));
@ -2037,8 +1297,6 @@ int main(int argc, char* argv[])
options.add_options()("multi_block_mode",
"Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel",
cxxopts::value<bool>()->default_value("true"));
options.add_options()(
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
cxxopts::value<bool>()->default_value("false"));
options.add_options()("cuda_graph_cache_size",
@ -2051,7 +1309,6 @@ int main(int argc, char* argv[])
options.add_options()("executor_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]",
cxxopts::value<std::string>());
options.add_options()("request_lookahead_config",
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= "
"executor lookahead config",
@ -2073,9 +1330,6 @@ int main(int argc, char* argv[])
return 1;
}
// Argument: API
auto const api = result["api"].as<std::string>();
// Argument: Batching Type
auto const type = result["type"].as<std::string>();
TrtGptModelType modelType{TrtGptModelType::V1};
@ -2153,9 +1407,6 @@ int main(int argc, char* argv[])
}
}
// Argument: Enable TRT overlap
benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
// Argument: Enable dynamic tuning of batch size
benchmarkParams.enableBatchSizeTuning = result["enable_batch_size_tuning"].as<bool>();
@ -2228,7 +1479,7 @@ int main(int argc, char* argv[])
benchmarkParams.kvHostCacheSize = result["kv_host_cache_bytes"].as<size_t>();
// Argument: If offloaded blocks should be onboarded to primary memory before they are reused.
benchmarkParams.kvOnboardBlocks = !result["kv_dont_onboard_blocks"].as<bool>();
benchmarkParams.kvOnboardBlocks = result["kv_onboard_blocks"].as<bool>();
// Argument: Medusa choices for the Medusa speculative decoding.
if (result.count("medusa_choices"))
@ -2255,7 +1506,7 @@ int main(int argc, char* argv[])
// Argument: cuda_graph_mode
benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();
// Argument: cuda_graph_mode
// Argument: cuda_graph_cache_size
benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();
std::optional<TokenIdType> padId;
@ -2268,20 +1519,11 @@ int main(int argc, char* argv[])
// Argument: End-of-sentence token id
std::optional<TokenIdType> eosId = result["eos_id"].as<TokenIdType>();
std::optional<std::chrono::milliseconds> batchTimeout;
// Argument: first_batch_delay
if (result.count("first_batch_delay"))
{
batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as<int32_t>()};
}
std::optional<SizeType32> staticEmulatedBatchSize;
// Argument: Static emulated batch size
if (result.count("static_emulated_batch_size"))
{
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<SizeType32>();
batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as<int32_t>()};
}
// Argument: Scheduler policy
@ -2313,7 +1555,6 @@ int main(int argc, char* argv[])
}
// Argument: GPU weights percentage
std::istringstream ssGpuPercentArg;
auto gpuWeightsPercent = result["gpu_weights_percent"].as<float>();
if (gpuWeightsPercent < 0 || gpuWeightsPercent > 1)
{
@ -2351,29 +1592,11 @@ int main(int argc, char* argv[])
return 1;
}
// Argument: dump profile
bool dumpProfile = result["dump_profile"].as<bool>();
initTrtLlmPlugins(logger.get());
if (api == "gptManager")
{
try
{
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams,
capacitySchedulerPolicy, waitSleep, returnContextLogits, returnGenerationLogits,
staticEmulatedBatchSize, batchTimeout, logIterationData,
result["exclude_input_in_output_seq"].as<bool>(), result["responses_json_file"].as<std::string>(),
maxPromptLen, dumpProfile);
}
catch (std::exception const& e)
{
TLLM_LOG_ERROR(e.what());
return 1;
}
}
else if (api == "executor")
// Argument: API
auto const api = result["api"].as<std::string>();
if (api == "executor")
{
texec::ModelType executorModelType;
std::optional<std::string> decoderEngineDir = std::nullopt, encoderEngineDir = std::nullopt;
@ -2409,6 +1632,11 @@ int main(int argc, char* argv[])
return 1;
}
}
else if (api == "gptManager")
{
TLLM_LOG_ERROR("gptManager is deprecated, please use the executor API.");
return 1;
}
else
{
TLLM_LOG_ERROR("api parameter must be gptManager or executor");


@ -0,0 +1,50 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::batch_manager
{
namespace tle = tensorrt_llm::executor;
class AllocateKvCache : Algorithm
{
using KVCacheManager = tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager;
template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
public:
constexpr static auto name{"AllocateKvCache"};
using SizeType32 = tensorrt_llm::runtime::SizeType32;
AllocateKvCache() = default;
void operator()(KVCacheManager& kvCacheManager, RequestVector& contextRequests,
RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig,
OptionalRef<KVCacheManager> crossKvCacheManager = std::nullopt) const;
};
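// Hypothetical usage sketch: the algorithm object is stateless, so a caller simply invokes it
// each iteration with the KV cache manager and the scheduled request vectors; the optional
// cross-attention cache manager is only needed for encoder-decoder models. Names below are
// placeholders.
inline void allocateKvCacheSketch(kv_cache_manager::KVCacheManager& kvCacheManager,
    RequestVector& contextRequests, RequestVector const& generationRequests,
    runtime::ModelConfig const& modelConfig)
{
    AllocateKvCache const allocateKvCache{};
    allocateKvCache(kvCacheManager, contextRequests, generationRequests, modelConfig);
}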
} // namespace tensorrt_llm::batch_manager


@ -0,0 +1,44 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::batch_manager
{
namespace tle = tensorrt_llm::executor;
class AssignReqSeqSlots : Algorithm
{
using SizeType32 = tensorrt_llm::runtime::SizeType32;
public:
constexpr static auto name{"AssignReqSeqSlots"};
AssignReqSeqSlots() = default;
void operator()(SequenceSlotManager& seqSlotManager, RequestVector const& contextRequests,
RequestVector const& generationRequests) const;
};
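// Hypothetical usage sketch, mirroring the other batch-manager algorithms: invoked once per
// iteration with the slot manager and the scheduled requests. Names below are placeholders.
inline void assignReqSeqSlotsSketch(SequenceSlotManager& seqSlotManager,
    RequestVector const& contextRequests, RequestVector const& generationRequests)
{
    AssignReqSeqSlots const assignReqSeqSlots{};
    assignReqSeqSlots(seqSlotManager, contextRequests, generationRequests);
}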
} // namespace tensorrt_llm::batch_manager


@ -19,6 +19,7 @@
#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"
#include <variant>
@ -35,6 +36,7 @@ namespace tensorrt_llm::batch_manager
{
using tensorrt_llm::runtime::SizeType32;
using common::OptionalRef;
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
@ -69,8 +71,6 @@ class MaxRequestsScheduler : public BaseCapacityScheduler
{
public:
explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
@ -80,8 +80,6 @@ public:
private:
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
};
/// @brief Schedule requests using the MAX_UTILIZATION policy
@ -90,24 +88,21 @@ private:
class MaxUtilizationScheduler : public BaseCapacityScheduler
{
public:
MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(kv_cache_manager::KVCacheManager& kvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
private:
/// @return {fitsKvCache, fitsPeft}
std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
std::pair<bool, bool> trySchedulingRequestMaxUtilization(kv_cache_manager::KVCacheManager const& kvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
std::unordered_set<uint64_t>& seenTaskIds) const;
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
/// @brief Boolean that indicates if multiple micro batches might be in flight
bool mManyMicroBatches;
};
@ -117,36 +112,36 @@ class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
{
public:
GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::KVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
protected:
[[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
RequestList const& activeRequests, bool staticBatchScheduling) const;
template <bool StaticBatchScheduling>
[[nodiscard]] std::tuple<RequestVector, RequestVector> impl(kv_cache_manager::KVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
private:
SizeType32 mMaxNumRequests;
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
};
/// @brief Schedule requests using the STATIC_BATCH policy
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
{
public:
StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
StaticBatchScheduler(SizeType32 maxNumRequests,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
kv_cache_manager::KVCacheManager const& kvCacheManager,
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
};
class CapacityScheduler : public Algorithm
@ -154,29 +149,26 @@ class CapacityScheduler : public Algorithm
public:
constexpr static auto name{"CapacityScheduler"};
CapacityScheduler() = default;
CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
static CapacityScheduler make(SizeType32 maxNumRequests,
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
{
return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
noScheduleAfterState};
}
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
/**
* @brief Schedules requests according to the selected policy.
*
* @param kvCacheManager Required by MaxUtilizationScheduler (as a mutable ref) and by GuaranteedNoEvictScheduler
* and StaticBatchScheduler (as a const ref).
* @param crossKvCacheManager Optional; used by GuaranteedNoEvictScheduler and StaticBatchScheduler.
* @param peftCacheManager Optional; used by MaxUtilizationScheduler, GuaranteedNoEvictScheduler and
* StaticBatchScheduler.
* @param activeRequests The list of currently active requests to schedule.
* @return std::tuple<RequestVector, RequestVector> of fittingRequests and pausedRequests, respectively.
*/
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests,
OptionalRef<kv_cache_manager::KVCacheManager> kvCacheManager = std::nullopt,
OptionalRef<BasePeftCacheManager const> peftCacheManager = std::nullopt,
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager = std::nullopt) const;
private:
std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,

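// Hypothetical call-site sketch for the reworked CapacityScheduler interface above: the KV
// cache, cross KV cache and PEFT cache managers are no longer stored in the scheduler but are
// passed per call as OptionalRefs (assuming OptionalRef is constructible from a plain
// reference, as the std::nullopt defaults suggest). Names below are placeholders.
namespace tb = tensorrt_llm::batch_manager;

inline void capacitySchedulerCallSketch(tb::CapacityScheduler const& scheduler,
    tb::RequestList const& activeRequests, tb::kv_cache_manager::KVCacheManager& kvCacheManager,
    tb::BasePeftCacheManager const& peftCacheManager)
{
    // Returns {fittingRequests, pausedRequests}; the cross-attention manager is left as nullopt.
    auto const [fittingRequests, pausedRequests] = scheduler(activeRequests, kvCacheManager, peftCacheManager);
    (void) fittingRequests;
    (void) pausedRequests;
}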

@ -18,6 +18,7 @@
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include <chrono>
#include <vector>
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
@ -33,33 +34,62 @@ public:
// TODO(TRTLLM-1564): Don't use a separate `initialize` function. Ensure eviction policies can't sit in an
// intermediate state between construction and initialization.
virtual void initialize(std::vector<BlockPtr>& mAllBlocksById, std::vector<SizeType32> sizes,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt)
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority)
= 0;
/// @brief Get a free block from the specified cache level
/// @returns The pointer to the free block, along with whether it can be offloaded
virtual std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) = 0;
/// @brief Release a block. Prioritize the block for eviction if toFront=true
virtual void releaseBlock(BlockPtr block, bool toFront = false) = 0;
virtual void releaseBlock(BlockPtr block) = 0;
virtual void releaseBlock(BlockPtr block, bool toFront) = 0;
/// @brief Get the number of free blocks in the primary memory pool
virtual SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) = 0;
/// @brief Claim a free block. Called when the cache manager allocates or reuses a new block
virtual void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt) = 0;
virtual void claimBlock(BlockPtr block) = 0;
virtual void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
std::optional<std::chrono::milliseconds> durationMs)
= 0;
/// @brief Perform any per-iteration bookkeeping
virtual void refresh() = 0;
};
struct ExpiringBlockComparator
{
inline bool operator()(BlockPtr const& a, BlockPtr const& b) const
{
// If two blocks expire in the same millisecond, their expiration times will be equal. As a fallback, check the
// raw pointer values.
return a->getExpirationTime() != b->getExpirationTime() ? a->getExpirationTime() < b->getExpirationTime()
: a.get() < b.get();
}
};
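// Illustrative sketch: with the comparator above, an ordered std::set behaves like the
// min-heap of expiring blocks used below, so the earliest-expiring block is always at
// begin(); the raw-pointer fallback keeps two blocks that expire in the same millisecond
// distinct. The helper name is a placeholder.
inline BlockPtr earliestExpiringBlockSketch(std::set<BlockPtr, ExpiringBlockComparator> const& expiringBlocks)
{
    return expiringBlocks.empty() ? BlockPtr{} : *expiringBlocks.begin();
}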
class LRUEvictionPolicy : public BaseEvictionPolicy
{
public:
void initialize(std::vector<BlockPtr>& mAllBlocksById, std::vector<SizeType32> sizes,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt) override;
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority) override;
std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) override;
void releaseBlock(BlockPtr block, bool toFront = false) override;
void releaseBlock(BlockPtr block) override;
void releaseBlock(BlockPtr block, bool toFront) override;
SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) override;
void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt) override;
void claimBlock(BlockPtr block) override;
void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
std::optional<std::chrono::milliseconds> durationMs) override;
// Check the expiring blocks heap, and move expired blocks back to the default queue.
void refresh() override;
// Making this public and virtual makes it possible to test.
[[nodiscard]] virtual std::chrono::steady_clock::time_point::duration getTime() const;
private:
// Check if the block should be added to mFreeQueues.
bool isReleasedLeafBlock(BlockPtr block);
bool isReleasedLeafBlock(BlockPtr const& block);
// Queues of available leaf blocks, split by cache level and priority level
std::vector<std::vector<FreeBlocksQueue>> mFreeQueues;
@ -71,6 +101,8 @@ private:
std::vector<SizeType32> mNumFreeBlocksPerLevel;
// Secondary offload threshold. Blocks below this priority won't be evicted.
executor::RetentionPriority mSecondaryOffloadMinPriority;
// Ordered set of released blocks that carry an expiration time, earliest-expiring first
std::set<BlockPtr, ExpiringBlockComparator> mExpiringBlockHeap;
};
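// Hypothetical lifecycle sketch for the interface above; the blocks themselves come from the
// KV cache manager, so only the call order is illustrated and all names are placeholders.
inline void evictionPolicyLifecycleSketch(BaseEvictionPolicy& policy)
{
    // Take a free block from the primary cache level; the bool reports whether it may be offloaded.
    auto [block, canOffload] = policy.getFreeBlock(/*cacheLevel=*/0);
    (void) canOffload;
    // Claim it for a request without overriding its priority or retention duration.
    policy.claimBlock(block, std::nullopt, std::nullopt);
    // ... block is in use while the sequence is active ...
    policy.releaseBlock(block);
    // Once per iteration, move blocks whose retention duration elapsed back to the default queue.
    policy.refresh();
}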
} // namespace tensorrt_llm::batch_manager::eviction_policy


@ -64,6 +64,7 @@ auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
auto constexpr kLoraTaskId = "lora_task_id";
auto constexpr kNoRepeatNgramSizeTensorName = "noRepeatNgramSize";
auto constexpr kSkipCrossAttnBlocksTensorName = "skipCrossAttnBlocks";
// weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
// where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
// each of the in / out tensors are first flattened and then concatenated together in the format above.


@ -43,9 +43,9 @@ public:
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false, bool useUvm = false,
std::optional<size_t> hostCacheSize = std::nullopt, bool onboardBlocks = true,
std::optional<float> crossKvCacheFraction = std::nullopt,
std::optional<SizeType32> secondaryOffloadMinPriority = std::nullopt)
std::optional<SizeType32> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0)
: maxTokens{maxTokens}
, maxAttentionWindowVec{maxAttentionWindowVec}
, maxAttentionWindowVec{std::move(maxAttentionWindowVec)}
, sinkTokenLength{sinkTokenLength}
, freeGpuMemoryFraction{freeGpuMemoryFraction}
, enableBlockReuse(enableBlockReuse)
@ -54,6 +54,7 @@ public:
, onboardBlocks(onboardBlocks)
, crossKvCacheFraction{crossKvCacheFraction}
, secondaryOffloadMinPriority(secondaryOffloadMinPriority)
, eventBufferMaxSize(eventBufferMaxSize)
{
}
@ -61,7 +62,8 @@ public:
: KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindowVec(),
kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(),
kvCacheConfig.getEnableBlockReuse(), false, kvCacheConfig.getHostCacheSize(),
kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction())
kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction(),
kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig.getEventBufferMaxSize())
{
}
@ -71,7 +73,9 @@ public:
&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm
&& hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks
&& crossKvCacheFraction == other.crossKvCacheFraction;
&& crossKvCacheFraction == other.crossKvCacheFraction
&& secondaryOffloadMinPriority == other.secondaryOffloadMinPriority
&& eventBufferMaxSize == other.eventBufferMaxSize;
}
friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self);
@ -89,5 +93,7 @@ public:
std::optional<float> crossKvCacheFraction;
// The minimum priority level to allow blocks to be offloaded to secondary memory.
std::optional<SizeType32> secondaryOffloadMinPriority;
// Maximum size of the KV Cache event buffer
size_t eventBufferMaxSize;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager


@ -0,0 +1,96 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/executor/executor.h"
#include <chrono>
#include <condition_variable>
#include <deque>
#include <mutex>
#include <thread>
#include <vector>
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
using SizeType32 = tensorrt_llm::runtime::SizeType32;
class KVCacheBlock;
using BlockPtr = std::shared_ptr<KVCacheBlock>;
class KVCacheEventManager
{
public:
explicit KVCacheEventManager(size_t maxKVEventEntries);
~KVCacheEventManager();
KVCacheEventManager(KVCacheEventManager& other) = delete;
KVCacheEventManager& operator=(KVCacheEventManager& other) = delete;
KVCacheEventManager(KVCacheEventManager&& other) = delete;
KVCacheEventManager& operator=(KVCacheEventManager&& other) = delete;
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel);
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks);
void enqueueRemovedEvent(BlockPtr const& block);
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data);
// Get events in mEvents. If there are no events, wait for a maximum of `timeout` milliseconds.
std::deque<executor::KVCacheEvent> getEvents(std::optional<std::chrono::milliseconds> timeout);
// Clear the event buffer, and asynchronously move events to the event queue.
void flush();
// Worker thread which adds events to mEvents.
void worker();
private:
// Add an event to mEventQueue
void enqueueEvent(executor::KVCacheEvent&& event);
/// @brief Flag to terminate the worker
bool mRun;
/// @brief Worker thread
std::thread mWorkerThread;
/// @brief The deque of events
std::deque<executor::KVCacheEvent> mEvents;
/// @brief Lock for mEvents
std::mutex mEventsMutex;
/// @brief Condition variable for blocking read
std::condition_variable mEmptyCV;
/// @brief List of buffers awaiting insertion into mEvents. Consumed by the worker.
std::deque<std::deque<executor::KVCacheEvent>> mPendingEvents;
/// @brief Lock for mPendingEvents
std::mutex mPendingEventsMutex;
/// @brief Condition variable to notify worker thread
std::condition_variable mPendingEmptyCV;
/// @brief Buffer of events waiting to be added to the event queue. Only ever accessed by the forward pass thread.
std::deque<executor::KVCacheEvent> mEventQueue;
/// @brief The maximum size of the deque
size_t mMaxSize;
/// @brief An auto-incrementing event id counter
size_t mEventId;
};
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
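A minimal consumer sketch for the manager declared above (not part of the diff); it assumes an instance is available and uses a bounded wait so the caller is not blocked indefinitely.

#include <chrono>

void drainKvCacheEvents(tensorrt_llm::batch_manager::kv_cache_manager::KVCacheEventManager& manager)
{
    // Wait at most 100 ms if no events are currently buffered.
    auto events = manager.getEvents(std::chrono::milliseconds(100));
    for (auto const& event : events)
    {
        // Each event carries an auto-incrementing id and a payload; handle as needed.
        (void) event;
    }
}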

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/kernels/kvCacheIndex.h"
@ -40,11 +41,15 @@
namespace tensorrt_llm::batch_manager::eviction_policy
{
class BaseEvictionPolicy;
}
} // namespace tensorrt_llm::batch_manager::eviction_policy
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
static constexpr SizeType32 kPrimaryLevel = 0;
static constexpr SizeType32 kSecondaryLevel = 1;
class KVCacheBlock;
class KVCacheManager;
@ -67,6 +72,15 @@ struct BlockKey
LoraTaskIdType loraTaskId;
VecUniqueTokens uniqueTokens;
BlockKey() = default;
explicit BlockKey(bool hasLora, LoraTaskIdType loraTaskId, VecUniqueTokens uniqueTokens)
: hasLora{hasLora}
, loraTaskId{loraTaskId}
, uniqueTokens{std::move(uniqueTokens)}
{
}
bool operator==(BlockKey const& other) const noexcept
{
return (hasLora == other.hasLora && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens);
@ -78,9 +92,9 @@ struct BlockKey
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
struct BlockKeyHasher
{
std::size_t operator()(BlockKey const& blockKey) const noexcept
std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
{
size_t seed = blockKey.uniqueTokens.size();
size_t seed = blockKey.uniqueTokens.size() ^ parentHash;
for (auto const& uniqueToken : blockKey.uniqueTokens)
{
uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
@ -169,7 +183,9 @@ public:
[[nodiscard]] bool hasSchedulingRefs() const;
void setBlockKey(BlockKey& blockKey, bool isFull);
void setBlockKey(BlockKey const& blockKey, bool isFull);
BlockKey getBlockKey();
[[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;
@ -192,7 +208,19 @@ public:
void setPriority(executor::RetentionPriority priority);
executor::RetentionPriority getPriority() const;
[[nodiscard]] executor::RetentionPriority getPriority() const;
void setDurationMs(std::optional<std::chrono::milliseconds> durationMs);
[[nodiscard]] std::optional<std::chrono::milliseconds> getDurationMs() const;
void setExpirationTime(std::optional<std::chrono::steady_clock::time_point::duration> expirationTime);
[[nodiscard]] std::optional<std::chrono::steady_clock::time_point::duration> getExpirationTime() const;
void setHash(size_t hash);
size_t getHash() const;
private:
// Linear ID of block independent of pool
@ -225,6 +253,12 @@ private:
// Priority of the block
executor::RetentionPriority mPriority;
// Duration that the block's priority level applies for
std::optional<std::chrono::milliseconds> mDurationMs;
// Expiration time of the block
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
// Hash for the event manager
size_t mHash;
};
class GenerationRequest
@ -234,8 +268,7 @@ public:
explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
SizeType32 maxBlocks, SizeType32 numPools = 1,
executor::RetentionPriority decodeRetentionPriority
= executor::KvCacheRetentionConfig::kDefaultRetentionPriority)
executor::KvCacheRetentionConfig kvCacheRetentionConfig = executor::KvCacheRetentionConfig())
: mRequestId(requestId)
, mNumTokens(numTokens)
, mBeamWidth(beamWidth)
@ -243,7 +276,7 @@ public:
, mCacheBlockIndices{runtime::BufferManager::cpu(
runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value)}
, mDecodeRetentionPriority(decodeRetentionPriority)
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
{
auto cacheBlockIdsRange = runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices);
std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
@ -321,7 +354,12 @@ public:
[[nodiscard]] executor::RetentionPriority getDecodeRetentionPriority() const
{
return mDecodeRetentionPriority;
return mKvCacheRetentionConfig.getDecodeRetentionPriority();
}
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const
{
return mKvCacheRetentionConfig.getDecodeDurationMs();
}
private:
@ -336,7 +374,7 @@ private:
// Tensor of block indices allocated for each beam of the sequence
runtime::ITensor::SharedPtr mCacheBlockIndices;
// The retention priority to assign to decode blocks
executor::RetentionPriority mDecodeRetentionPriority;
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
};
// attach metadata to a pool pointer
@ -385,7 +423,8 @@ public:
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
~BlockManager();
@ -442,6 +481,9 @@ public:
return mMissedBlocks;
}
[[nodiscard]] std::deque<executor::KVCacheEvent> getLatestEvents(
std::optional<std::chrono::milliseconds> timeout) const;
[[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const noexcept
{
return getNumFreeBlocks() >= numRequired;
@ -522,13 +564,27 @@ public:
//! \brief Find first new block that must be allocated for context phase and return its concatenated token vectors.
//! \details Only full blocks are considered.
BlockKey findNewContextBlock(VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
{
return mBufferManager;
}
//! \brief Perform per-request bookkeeping
void refreshBlocks();
void flushIterationEvents()
{
if (mEventManager)
{
mEventManager->flush();
}
}
[[nodiscard]] static bool blockInRadixTree(BlockPtr const& block);
private:
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
@ -539,23 +595,23 @@ private:
//! \brief Store blocks in cached blocks.
//! \param blockKeys Key of each block.
//! \param blockIds Id of each block.
//! \param isChunkedContext Whether these blocks are being stored for chunked context.
void storeBlocks(std::list<BlockKey> blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
bool isChunkedContext = false);
void storeBlocks(std::vector<BlockKey> blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
//! \param blockKeys Key of each block.
//! \param sequence Sequence to which blocks are assigned.
//! \return Number of matched tokens from loaded blocks.
SizeType32 loadOrAllocateBlocks(std::list<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<std::optional<executor::RetentionPriority>> blockPriorities);
SizeType32 loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions);
//! \brief Find block least likely to be reused, free it if necessary and return.
[[nodiscard]] BlockPtr getFreeBlock(
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority);
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
//! \brief Free block from previous block and claim it from free blocks list.
void claimLeafBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt);
void claimLeafBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
//! \brief Compute pointer to raw KV block (K & V, all layers).
[[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(
@ -598,6 +654,8 @@ private:
CacheType mCacheType;
// Eviction Policy
std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;
// Event manager
std::shared_ptr<KVCacheEventManager> mEventManager;
// Statistics for block allocations/reuse
// Total number of blocks allocated by all requests
@ -630,14 +688,16 @@ public:
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true,
CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true,
CacheType cacheType = CacheType::kSELF,
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
void allocatePools(nvinfer1::DataType dtype, bool useUvm = false);
@ -705,6 +765,12 @@ public:
return mMaxBlocksPerSeq;
}
[[nodiscard]] std::deque<executor::KVCacheEvent> getLatestEvents(
std::optional<std::chrono::milliseconds> timeout = std::nullopt) const
{
return mBlockManager.getLatestEvents(timeout);
}
[[nodiscard]] BlockManager const& getBlockManager() const
{
return mBlockManager;
@ -790,7 +856,8 @@ public:
//! \brief Find first new block that must be allocated for context phase and return its concatenated token vector.
//! \details Only full blocks are considered.
BlockKey findNewContextBlock(VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
//! \brief Store full context blocks contributed by llmRequest.
//! \details These blocks become reusable from next step.
@ -804,6 +871,17 @@ public:
//! \brief Get the batch size that can fill the kv cache to the maximum capacity given the sequence length
[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 seqLen);
//! \brief Perform per-iteration bookkeeping
void refreshBlocks()
{
mBlockManager.refreshBlocks();
}
void flushIterationEvents()
{
mBlockManager.flushIterationEvents();
}
private:
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
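A small sketch (not part of the diff) of the per-iteration hooks added above, assuming a KVCacheManager instance is available: refreshBlocks() performs the per-iteration bookkeeping and flushIterationEvents() hands buffered KV cache events to the event worker.

void finishIteration(tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager& cacheManager)
{
    cacheManager.refreshBlocks();        // per-iteration bookkeeping
    cacheManager.flushIterationEvents(); // push buffered KV cache events to the event queue
}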

View File

@ -55,7 +55,6 @@ enum class LlmRequestState : int32_t
/// Waiting context-only request transmitting the kv cache
kDISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< For disaggregated serving only: transmitting the kv cache
kWAITING_TO_SEND_LOGITS = 10, ///< Generation phase completed, logits not sent yet
};
enum LlmRequestType
@ -110,7 +109,8 @@ public:
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<std::shared_ptr<VecTokenExtraIds>> inputTokenExtraIds = std::nullopt,
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt)
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
: mRequestId(requestId)
, mPromptLen(inputTokens->size())
, mMaxNewTokens(maxNewTokens)
@ -159,6 +159,7 @@ public:
, mNumReturnSequences(numReturnSequences)
, mEagleConfig(eagleConfig)
, mSequenceIndex(0)
, mSkipCrossAttnBlocks(std::move(skipCrossAttnBlocks))
{
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
{
@ -337,6 +338,16 @@ public:
mCrossAttentionMask = std::nullopt;
}
auto const& skipCrossAttnBlocks = req.getSkipCrossAttnBlocks();
if (skipCrossAttnBlocks.has_value())
{
mSkipCrossAttnBlocks = executor::detail::toITensor(skipCrossAttnBlocks.value());
}
else
{
mSkipCrossAttnBlocks = std::nullopt;
}
switch (req.getRequestType())
{
case executor::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION:
@ -1071,6 +1082,11 @@ public:
return mCrossAttentionMask.value_or(nullptr);
}
[[nodiscard]] TensorPtr const getSkipCrossAttnBlocks() const
{
return mSkipCrossAttnBlocks.value_or(nullptr);
}
[[nodiscard]] bool constexpr isStreaming() const noexcept
{
return mIsStreaming;
@ -1235,11 +1251,6 @@ public:
return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
}
[[nodiscard]] bool isCompleteWaitingToSendLogits() const noexcept
{
return mState == LlmRequestState::kWAITING_TO_SEND_LOGITS;
}
/// Determines whether the context is unchunked. Once a context has been chunked, even if only partially, it
/// differs from the unchunked state, which indicates the initial status.
[[nodiscard]] bool isFullContextRequest() const noexcept
@ -1342,7 +1353,7 @@ public:
[[nodiscard]] bool isFinished() const noexcept
{
return isGenerationCompleteState() || isDisaggContextTransmissionState() || isCompleteWaitingToSendLogits();
return isGenerationCompleteState() || isDisaggContextTransmissionState();
}
/// @brief Create a Response from the current state of the request
@ -1665,6 +1676,8 @@ protected:
SizeType32 mReusedBlocksPerRequest{0};
SizeType32 mMissedBlocksPerRequest{0};
std::optional<TensorPtr> mSkipCrossAttnBlocks;
private:
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
{
@ -1823,7 +1836,8 @@ public:
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<std::shared_ptr<VecTokenExtraIds>> inputTokenExtraIds = std::nullopt,
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt)
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
@ -1832,7 +1846,7 @@ public:
std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, std::move(encoderInputTokens),
returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures), std::move(encoderOutputLength),
std::move(crossAttentionMask), llmRequestType, std::move(inputTokenExtraIds), numReturnSequences,
std::move(eagleConfig))
std::move(eagleConfig), std::move(skipCrossAttnBlocks))
{
}
@ -1857,7 +1871,9 @@ public:
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1)
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
std::move(stopWordsList),
@ -1875,7 +1891,7 @@ public:
llmRequestType,
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
numReturnSequences)
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks)
{
}

View File

@ -50,68 +50,29 @@ public:
using SizeType32 = tensorrt_llm::runtime::SizeType32;
using ContextChunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy;
MicroBatchScheduler() = default;
explicit MicroBatchScheduler(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
explicit MicroBatchScheduler(std::optional<SizeType32> maxNumTokens = std::nullopt,
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
std::optional<SizeType32> maxContextLength = std::nullopt,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
static MicroBatchScheduler make(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
std::optional<SizeType32> maxContextLength = std::nullopt,
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
{
return MicroBatchScheduler{
maxBatchSize, maxNumTokens, ctxChunkConfig, maxContextLength, noScheduleUntilState, noScheduleAfterState};
}
std::tuple<RequestVector, RequestVector> operator()(
RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds);
RequestVector& activeRequests, ReqIdsSet const& inflightReqIds, SizeType32 maxBatchSizeRuntime) const;
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
static void setCtxRequestsChunkSize(RequestVector& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
std::optional<SizeType32> const& maxContextLength);
void setRuntimeMaxBatchSize(SizeType32 runtimeMaxBatchSize);
SizeType32 getMaxBatchSizeStatic() const
{
return mMaxBatchSize;
}
SizeType32 getMaxBatchSizeTunerRecommended() const
{
return mMaxBatchSizeTunerRecommended;
}
SizeType32 getMaxBatchSizeRuntime() const
{
return mMaxBatchSizeRuntime;
}
private:
template <ContextChunkingPolicy tPolicy>
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked,
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
std::optional<SizeType32> const& maxContextLength);
static void setCtxRequestsChunkSize(RequestVector& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);
/// After the chunk sizes have been determined, this function will discard
/// any draft tokens that don't fit.
static void fitDraftTokens(RequestVector const& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
static void fitDraftTokens(RequestVector& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);
/// The maximum number of requests returned by scheduleRequests
SizeType32 mMaxBatchSize;
/// The max batch size recommended by the dynamic tuner
SizeType32 mMaxBatchSizeTunerRecommended;
/// The min of mMaxBatchSize and mMaxBatchSizeTunerRecommended
SizeType32 mMaxBatchSizeRuntime;
/// The maximum number of tokens to include in a batch
std::optional<SizeType32> mMaxNumTokens;

View File

@ -0,0 +1,74 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "common.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::batch_manager
{
class BasePeftCacheManager;
class LlmRequest;
namespace kv_cache_manager
{
class KVCacheManager;
}
} // namespace tensorrt_llm::batch_manager
namespace tensorrt_llm::batch_manager
{
namespace tle = tensorrt_llm::executor;
class PauseRequests : Algorithm
{
using KVCacheManager = kv_cache_manager::KVCacheManager;
template <typename T>
using OptionalRef = common::OptionalRef<T>;
public:
constexpr static auto name{"PauseRequests"};
using SizeType32 = tensorrt_llm::runtime::SizeType32;
PauseRequests(SizeType32 maxInputLen)
: mMaxInputLen(maxInputLen)
{
}
void operator()(RequestVector& requestsToPause, ReqIdsSet& inflightReqIds, ReqIdsSet& reqIdsToPause,
bool pauseFlagged, SequenceSlotManager& seqSlotManager,
OptionalRef<KVCacheManager> kvCacheManager = std::nullopt,
OptionalRef<KVCacheManager> crossKvCacheManager = std::nullopt,
OptionalRef<BasePeftCacheManager> peftCacheManager = std::nullopt) const;
private:
SizeType32 mMaxInputLen;
};
} // namespace tensorrt_llm::batch_manager

View File

@ -17,6 +17,7 @@
#pragma once
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
#include "tensorrt_llm/common/cudaDriverWrapper.h"
#include "tensorrt_llm/common/cudaFp8Utils.h"
#include "tensorrt_llm/common/logger.h"
#include "tensorrt_llm/common/tllmException.h"
@ -24,7 +25,9 @@
#include <cinttypes>
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <driver_types.h>
#include <fstream>
#include <iostream>
#include <memory>
@ -318,9 +321,13 @@ inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
{
if (useUvm)
{
size_t freeSysMem, totalSysMem;
size_t freeSysMem = 0;
size_t totalSysMem = 0;
#ifndef _WIN32 // Linux
struct sysinfo info;
struct sysinfo info
{
};
sysinfo(&info);
totalSysMem = info.totalram * info.mem_unit;
freeSysMem = info.freeram * info.mem_unit;
@ -336,20 +343,38 @@ inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9));
return {freeSysMem, totalSysMem};
}
else
{
size_t free, total;
check_cuda_error(cudaMemGetInfo(&free, &total));
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
((double) total / 1e9), ((double) free / 1e9));
return {free, total};
}
size_t free = 0;
size_t total = 0;
check_cuda_error(cudaMemGetInfo(&free, &total));
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
((double) total / 1e9), ((double) free / 1e9));
return {free, total};
}
/// @brief Gets the memory allocation granularity for the current device.
///
/// @return size_t The smallest supported difference in allocation size for the current device.
inline size_t getAllocationGranularity()
{
auto const currentDevice = getDevice();
::CUmemAllocationProp prop = {};
prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = currentDevice;
prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE;
// Get the minimum granularity supported for allocation with cuMemCreate()
size_t granularity = 0;
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
return granularity;
}
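A brief sketch (not part of the diff) showing one intended use of the helper above: rounding a requested byte count up to the device allocation granularity before a cuMemCreate-style allocation. The function name is an assumption.

inline size_t roundUpToAllocationGranularity(size_t requestedBytes)
{
    auto const granularity = getAllocationGranularity();
    // Round up to the next multiple of the granularity (non-zero on success).
    return ((requestedBytes + granularity - 1) / granularity) * granularity;
}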
inline int getMultiProcessorCount()
{
int device_id;
int multi_processor_count;
int device_id = 0;
int multi_processor_count = 0;
check_cuda_error(cudaGetDevice(&device_id));
check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id));
return multi_processor_count;
@ -357,8 +382,8 @@ inline int getMultiProcessorCount()
inline int getMaxSharedMemoryPerBlockOptin()
{
int device_id;
int max_shared_memory_per_block;
int device_id = 0;
int max_shared_memory_per_block = 0;
check_cuda_error(cudaGetDevice(&device_id));
check_cuda_error(
cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
@ -368,8 +393,8 @@ inline int getMaxSharedMemoryPerBlockOptin()
template <typename T1, typename T2>
inline size_t divUp(const T1& a, const T2& n)
{
size_t tmp_a = static_cast<size_t>(a);
size_t tmp_n = static_cast<size_t>(n);
auto const tmp_a = static_cast<size_t>(a);
auto const tmp_n = static_cast<size_t>(n);
return (tmp_a + tmp_n - 1) / tmp_n;
}

View File

@ -1,12 +1,11 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,

View File

@ -97,6 +97,11 @@ public:
return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
}
static constexpr QuantMode w4a8QServe() noexcept
{
return QuantMode(BaseType(1u) << 10);
}
constexpr BaseType value() const noexcept
{
return mValue;
@ -169,7 +174,8 @@ public:
static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false)
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
bool useW4a8QServe = false)
{
QuantMode quantMode{};
if (quantizeWeights)
@ -218,6 +224,11 @@ public:
quantMode += fp8RowWise();
}
if (useW4a8QServe)
{
quantMode += w4a8QServe();
}
return quantMode;
}
@ -226,12 +237,17 @@ public:
return fromDescription(true, true, perToken, perChannel);
}
static constexpr QuantMode useQServe(bool perGroup)
{
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
}
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
{
return fromDescription(true, false, false, false, perGroup, useInt4Weights);
}
static const QuantMode fromQuantAlgo(
static QuantMode const fromQuantAlgo(
std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
{
QuantMode quantMode{};
@ -251,6 +267,14 @@ public:
{
quantMode = useWeightOnly(true, true);
}
else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
{
quantMode = useQServe(false);
}
else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
{
quantMode = useQServe(true);
}
else if (quantAlgo == "W4A16_GPTQ")
{
quantMode = useWeightOnly(true, true);

View File

@ -18,6 +18,8 @@
#include "tensorrt_llm/executor/tensor.h"
#include "tensorrt_llm/executor/types.h"
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/runtimeDefaults.h"
#include <chrono>
#include <cstdint>
@ -29,12 +31,18 @@
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>
namespace tensorrt_llm::mpi
{
class MpiComm;
}
} // namespace tensorrt_llm::mpi
namespace tensorrt_llm::batch_manager::kv_cache_manager
{
class KVCacheManager;
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
namespace tensorrt_llm::executor
{
@ -355,7 +363,7 @@ private:
class SpeculativeDecodingConfig
{
public:
explicit SpeculativeDecodingConfig(bool fastLogits);
explicit SpeculativeDecodingConfig(bool fastLogits = false);
bool operator==(SpeculativeDecodingConfig const& other) const;
@ -365,6 +373,20 @@ public:
using RetentionPriority = SizeType32;
struct RetentionPriorityAndDuration
{
RetentionPriorityAndDuration(std::optional<RetentionPriority> const& retentionPriority,
std::optional<std::chrono::milliseconds> const& durationMs)
: retentionPriority{retentionPriority}
, durationMs{durationMs}
{
}
std::optional<RetentionPriority> retentionPriority;
std::optional<std::chrono::milliseconds> durationMs;
};
/// @brief Configuration for the request's retention in the KV Cache
class KvCacheRetentionConfig
{
@ -376,14 +398,16 @@ public:
/// @brief A single entry to set block priorities over a token range. Earlier ranges always take priority over later
/// ones. For example, with a block size of 16, a range of [0, 17] would be applied to the first two blocks.
struct TokenRangeRetentionPriority
struct TokenRangeRetentionConfig
{
public:
explicit TokenRangeRetentionPriority(SizeType32 tokenStart, std::optional<SizeType32> tokenEnd = std::nullopt,
RetentionPriority priority = KvCacheRetentionConfig::kDefaultRetentionPriority)
explicit TokenRangeRetentionConfig(SizeType32 tokenStart, std::optional<SizeType32> tokenEnd = std::nullopt,
RetentionPriority priority = KvCacheRetentionConfig::kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> durationMs = std::nullopt)
: tokenStart{tokenStart}
, tokenEnd{tokenEnd}
, priority{priority}
, durationMs{durationMs}
{
TLLM_CHECK_WITH_INFO(priority >= KvCacheRetentionConfig::kMinRetentionPriority
&& priority <= KvCacheRetentionConfig::kMaxRetentionPriority,
@ -398,10 +422,15 @@ public:
std::optional<SizeType32> tokenEnd;
/// @brief The priority of this token range. Higher priorities are less likely to be evicted or offloaded.
RetentionPriority priority;
/// @brief The duration in ms that the block should remain at the given priority level. Set to std::nullopt to
/// have no expiration time, and keep the block at the given priority level until it gets reclaimed. After the
/// duration has passed, the block will be moved back to the `kDefaultRetentionPriority` level.
std::optional<std::chrono::milliseconds> durationMs;
bool operator==(TokenRangeRetentionPriority const& other) const
bool operator==(TokenRangeRetentionConfig const& other) const
{
return tokenStart == other.tokenStart && tokenEnd == other.tokenEnd && priority == other.priority;
return tokenStart == other.tokenStart && tokenEnd == other.tokenEnd && priority == other.priority
&& durationMs == other.durationMs;
}
};
@ -410,22 +439,28 @@ public:
{
}
KvCacheRetentionConfig(std::vector<TokenRangeRetentionPriority> const& tokenRangeRetentionPriorities,
RetentionPriority decodeRetentionPriority);
explicit KvCacheRetentionConfig(std::vector<TokenRangeRetentionConfig> const& tokenRangeRetentionPriorities,
RetentionPriority decodeRetentionPriority = kDefaultRetentionPriority,
std::optional<std::chrono::milliseconds> decodeDurationMs = std::nullopt);
[[nodiscard]] std::vector<TokenRangeRetentionPriority> getTokenRangeRetentionPriorities() const;
[[nodiscard]] std::vector<TokenRangeRetentionConfig> getTokenRangeRetentionConfigs() const;
[[nodiscard]] RetentionPriority getDecodeRetentionPriority() const;
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const;
/// @brief Convert the token range data into an entry per kv cache block for a given seqLen
std::vector<std::optional<RetentionPriority>> getPerBlockEvictionPolicy(SizeType32 blockSize, SizeType32 seqLen);
/// @brief Convert the token range data into an entry per kv block. Returns a tuple of vectors corresponding to the
/// priorities and durations for each block.
[[nodiscard]] std::vector<RetentionPriorityAndDuration> getPerBlockRetentionPriorityDuration(
SizeType32 blockSize, SizeType32 seqLen) const;
private:
/// @brief The token ranges and priority levels to update. Ranges must be non-overlapping. For example [(0, 64),
/// (100, 128), (70, 80)] is valid, whereas
/// [(0, 64), (60, 128)] is not.
std::vector<TokenRangeRetentionPriority> mTokenRangeRetentionPriorities;
std::vector<TokenRangeRetentionConfig> mTokenRangeRetentionConfigs;
/// @brief The priority level to assign to blocks allocated in the decode phase
RetentionPriority mDecodeRetentionPriority;
/// @brief The duration in ms that decode blocks should remain at their assigned priority level.
std::optional<std::chrono::milliseconds> mDecodeDurationMs;
};
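A hedged usage sketch (not part of the diff) for the extended retention API above: assuming a block size of 16, tokens [0, 32) cover the first two blocks, which get an elevated priority for 30 seconds before falling back to kDefaultRetentionPriority. The priority value 80 is an illustrative assumption within [kMinRetentionPriority, kMaxRetentionPriority].

#include <chrono>
#include <vector>

tensorrt_llm::executor::KvCacheRetentionConfig makeRetentionConfig()
{
    using tensorrt_llm::executor::KvCacheRetentionConfig;
    std::vector<KvCacheRetentionConfig::TokenRangeRetentionConfig> ranges{
        KvCacheRetentionConfig::TokenRangeRetentionConfig(
            /*tokenStart=*/0, /*tokenEnd=*/32, /*priority=*/80, std::chrono::milliseconds(30000))};
    return KvCacheRetentionConfig(ranges, KvCacheRetentionConfig::kDefaultRetentionPriority);
}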
/// @brief A class that holds information about the request
@ -466,6 +501,7 @@ public:
/// @param contextPhaseParams Generated token ID from context only executor.
/// @param numReturnSequences The number of sequences to return.
/// @param eagleConfig The EAGLE speculative decoding configuration
/// @param skipCrossAttnBlocks Whether to skip the cross attention transformer blocks.
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@ -486,7 +522,8 @@ public:
std::optional<Tensor> encoderInputFeatures = std::nullopt,
std::optional<SizeType32> encoderOutputLength = std::nullopt,
std::optional<Tensor> crossAttentionMask = std::nullopt, SizeType32 numReturnSequences = 1,
std::optional<EagleConfig> eagleConfig = std::nullopt);
std::optional<EagleConfig> eagleConfig = std::nullopt,
std::optional<Tensor> skipCrossAttnBlocks = std::nullopt);
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
static auto constexpr kBatchedPostProcessorName = "batched";
@ -526,6 +563,7 @@ public:
[[nodiscard]] RequestType getRequestType() const;
[[nodiscard]] SizeType32 getNumReturnSequences() const;
[[nodiscard]] std::optional<EagleConfig> getEagleConfig() const;
[[nodiscard]] std::optional<Tensor> getSkipCrossAttnBlocks() const;
void setStreaming(bool streaming);
void setSamplingConfig(SamplingConfig const& config);
@ -553,6 +591,7 @@ public:
void setCrossAttentionMask(Tensor crossAttentionMask);
void setNumReturnSequences(SizeType32 numReturnSequences);
void setEagleConfig(std::optional<EagleConfig> const& eagleConfig);
void setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks);
private:
friend class Serialization;
@ -568,6 +607,9 @@ struct SpeculativeDecodingFastLogitsInfo
/// @brief MPI world rank of the draft model leader
int32_t draftParticipantId;
/// @brief Returns the struct serialized into a tensor that can be used as generation logits input
[[nodiscard]] Tensor toTensor() const;
};
/// @brief Struct that holds the generation result
@ -735,7 +777,8 @@ public:
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
[[nodiscard]] bool getEnableBlockReuse() const;
[[nodiscard]] std::optional<SizeType32> getMaxTokens() const;
@ -746,6 +789,7 @@ public:
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
[[nodiscard]] bool getOnboardBlocks() const;
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
[[nodiscard]] size_t getEventBufferMaxSize() const;
void setEnableBlockReuse(bool enableBlockReuse);
void setMaxTokens(SizeType32 maxTokens);
@ -756,6 +800,8 @@ public:
void setHostCacheSize(size_t hostCacheSize);
void setOnboardBlocks(bool onboardBlocks);
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
void setEventBufferMaxSize(size_t eventBufferMaxSize);
void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults runtimeDefaults);
private:
friend class Serialization;
@ -797,6 +843,9 @@ private:
/// @brief Only blocks with priority > mSecondaryOffloadMinPriority can be offloaded to secondary memory.
std::optional<RetentionPriority> mSecondaryOffloadMinPriority;
/// @brief Max size of the KV cache event buffer
size_t mEventBufferMaxSize;
};
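A short sketch (not part of the diff) of wiring the new eventBufferMaxSize knob; it assumes KvCacheConfig is default-constructible with the defaults shown above and that a non-zero buffer size is what enables event tracking.

tensorrt_llm::executor::KvCacheConfig makeEventTrackingKvCacheConfig()
{
    tensorrt_llm::executor::KvCacheConfig config;
    config.setEnableBlockReuse(true);
    config.setEventBufferMaxSize(1024); // default is 0; a non-zero size is assumed to enable event tracking
    return config;
}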
/// @brief Configuration class for the runtime perf knobs
@ -1197,6 +1246,113 @@ private:
std::optional<SpeculativeDecodingConfig> mSpeculativeDecodingConfig;
};
struct KVCacheCreatedData
{
/// @brief The amount of blocks at each cache level
std::vector<SizeType32> numBlocksPerCacheLevel;
};
/// @brief An entry for a single block stored into the tree
struct KVCacheStoredBlockData
{
KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens const& tokens,
tensorrt_llm::runtime::LoraTaskIdType loraId, SizeType32 cacheLevel, SizeType32 priority)
: blockHash{blockHash}
, tokens{tokens}
, loraId{loraId}
, cacheLevel{cacheLevel}
, priority{priority}
{
}
/// @brief The hash of the block
IdType blockHash;
/// @brief The unique tokens of the block
tensorrt_llm::runtime::VecUniqueTokens tokens;
/// @brief The Lora task id of the block
tensorrt_llm::runtime::LoraTaskIdType loraId;
/// @brief The cache level of the block
SizeType32 cacheLevel;
/// @brief The priority of the block
SizeType32 priority;
};
struct KVCacheStoredData
{
/// @brief The parent of this sequence of stored blocks
std::optional<IdType> parentHash;
/// @brief A sequence of blocks. The parent of block `i` is block `i-1`
std::vector<KVCacheStoredBlockData> blocks;
};
struct KVCacheRemovedData
{
/// @brief The hashes of blocks being removed
std::vector<IdType> blockHashes;
};
template <typename T>
struct KVCacheEventDiff
{
T oldValue;
T newValue;
};
struct KVCacheUpdatedData
{
explicit KVCacheUpdatedData(IdType blockHash)
: blockHash{blockHash} {};
KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
{
cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};
return *this;
}
KVCacheUpdatedData& priorityUpdated(SizeType32 oldValue, SizeType32 newValue)
{
priority = KVCacheEventDiff<SizeType32>{oldValue, newValue};
return *this;
}
/// @brief The hash of the updated block
IdType blockHash;
/// @brief The updated value of the cacheLevel field
std::optional<KVCacheEventDiff<SizeType32>> cacheLevel = std::nullopt;
/// @brief The updated value of the priority field
std::optional<KVCacheEventDiff<SizeType32>> priority = std::nullopt;
};
using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVCacheRemovedData, KVCacheUpdatedData>;
struct KVCacheEvent
{
KVCacheEvent(IdType eventId, KVCacheEventData data);
/// @brief The unique id of this event
IdType eventId;
/// @brief The data corresponding to this event
KVCacheEventData data;
};
/// @brief Exposes a limited set of KV cache manager functionalities
class KVCacheEventManager
{
public:
KVCacheEventManager(std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager> kvCacheManager);
/// @brief Get the latest KV Cache events.
/// @param timeout The maximum time to wait for new events. If nullopt, will only return when new events are
/// available, or when the executor instance has shut down.
std::deque<KVCacheEvent> getLatestEvents(std::optional<std::chrono::milliseconds> timeout = std::nullopt);
private:
std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager> kvCacheManager;
};
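A hedged sketch (not part of the diff) of consuming KV cache events through the executor-level manager above; it assumes the executor was configured with event tracking enabled and simply dispatches on the payload variant.

#include <chrono>
#include <variant>

void pollKvCacheEvents(tensorrt_llm::executor::Executor& executor)
{
    auto maybeManager = executor.getKVCacheEventManager();
    if (!maybeManager.has_value())
    {
        return; // No event manager available (e.g. not a participant, or events disabled).
    }
    auto events = (*maybeManager)->getLatestEvents(std::chrono::milliseconds(50));
    for (auto const& event : events)
    {
        if (std::holds_alternative<tensorrt_llm::executor::KVCacheStoredData>(event.data))
        {
            // e.g. mirror stored blocks into an external index keyed by blockHash.
        }
    }
}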
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
class Executor
{
@ -1300,6 +1456,8 @@ public:
/// @brief Indicates if the current process participates in this executor instance
[[nodiscard]] bool isParticipant() const;
std::optional<std::shared_ptr<KVCacheEventManager>> getKVCacheEventManager() const;
private:
class Impl;
std::unique_ptr<Impl> mImpl;

View File

@ -164,12 +164,11 @@ public:
static void serialize(KvCacheRetentionConfig const& kvCacheRetentionConfig, std::ostream& os);
static size_t serializedSize(KvCacheRetentionConfig const& kvCacheRetentionConfig);
// TokenRangeRetentionPriority
static KvCacheRetentionConfig::TokenRangeRetentionPriority deserializeTokenRangeRetentionPriority(std::istream& is);
// TokenRangeRetentionConfig
static KvCacheRetentionConfig::TokenRangeRetentionConfig deserializeTokenRangeRetentionConfig(std::istream& is);
static void serialize(
KvCacheRetentionConfig::TokenRangeRetentionPriority const& tokenRangeRetentionPriority, std::ostream& os);
static size_t serializedSize(
KvCacheRetentionConfig::TokenRangeRetentionPriority const& tokenRangeRetentionPriority);
KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig, std::ostream& os);
static size_t serializedSize(KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig);
// DecodingConfig
static DecodingConfig deserializeDecodingConfig(std::istream& is);

View File

@ -58,6 +58,8 @@ public:
TensorPtr draftPaths;
//! [maxBatchSize] or [numGenSequences]
TensorPtr specDecodingGenerationLengths;
//! [maxBatchSize] or [numGenSequences]
TensorPtr specDecodingGenerationLengthsHost;
//! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
TensorPtr specDecodingPackedMasks;

View File

@ -18,6 +18,7 @@
#include "tensorrt_llm/runtime/common.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/runtimeDefaults.h"
#include "tensorrt_llm/runtime/worldConfig.h"
#include <filesystem>
@ -32,7 +33,8 @@ class GptJsonConfig
{
public:
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType32 tensorParallelism,
SizeType32 pipelineParallelism, SizeType32 gpusPerNode, ModelConfig modelConfig)
SizeType32 pipelineParallelism, SizeType32 gpusPerNode, ModelConfig modelConfig,
std::optional<RuntimeDefaults> runtimeDefaults = std::nullopt)
: mName(std::move(name))
, mVersion(std::move(version))
, mPrecision(std::move(precision))
@ -40,6 +42,7 @@ public:
, mPipelineParallelism{pipelineParallelism}
, mGpusPerNode{gpusPerNode}
, mModelConfig(std::move(modelConfig))
, mRuntimeDefaults(std::move(runtimeDefaults))
{
}
@ -94,6 +97,11 @@ public:
return mTensorParallelism * mPipelineParallelism;
}
[[nodiscard]] std::optional<RuntimeDefaults> getRuntimeDefaults() const
{
return mRuntimeDefaults;
}
[[nodiscard]] std::string engineFilename(WorldConfig const& worldConfig, std::string const& model) const;
[[nodiscard]] std::string engineFilename(WorldConfig const& worldConfig) const
@ -109,6 +117,7 @@ private:
SizeType32 const mPipelineParallelism;
SizeType32 const mGpusPerNode;
ModelConfig mModelConfig; // remove const qualifier because config has to mutable after json parsing
std::optional<RuntimeDefaults> mRuntimeDefaults;
};
} // namespace tensorrt_llm::runtime

View File

@ -71,4 +71,6 @@ public:
std::vector<runtime::IpcMemory> mIpcMemoryHandles;
};
void lamportInitializeAll(void* buffer_0, void* buffer_1, void* buffer_2, size_t size);
} // namespace tensorrt_llm::runtime

View File

@ -46,8 +46,8 @@ public:
}
private:
/// We use mc_sim_7b_63 from official Medusa implementation, i.e. one of the best trees with 63 nodes found for 7B
/// Vicuna model.
// We use mc_sim_7b_63 from official Medusa implementation, i.e. one of the best trees with 63 nodes found for 7B
// Vicuna model.
// We use it as the default if no other trees are specified at the server level.
MedusaChoices mDefaultMedusaChoices = {{0}, {0, 0}, {1}, {0, 1}, {2}, {0, 0, 0}, {1, 0}, {0, 2}, {3}, {0, 3}, {4},
{0, 4}, {2, 0}, {0, 5}, {0, 0, 1}, {5}, {0, 6}, {6}, {0, 7}, {0, 1, 0}, {1, 1}, {7}, {0, 8}, {0, 0, 2}, {3, 0},

View File

@ -137,6 +137,7 @@ public:
, mLogitsDtype(nvinfer1::DataType::kFLOAT)
, mUseShapeInference(true)
, mManageWeightsType(ManageWeightsType::kDisabled)
, mSkipCrossAttnBlocks(false)
{
TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers,
"Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers,
@ -760,6 +761,16 @@ public:
return sumLocalHeads;
}
[[nodiscard]] bool constexpr skipCrossAttnBlocks() const noexcept
{
return mSkipCrossAttnBlocks;
}
void constexpr setSkipCrossAttnBlocks(bool skipCrossAttnBlocks) noexcept
{
mSkipCrossAttnBlocks = skipCrossAttnBlocks;
}
private:
SizeType32 mVocabSize;
SizeType32 mNbLayers;
@ -821,6 +832,7 @@ private:
std::string mModelName;
std::vector<SizeType32> mNumKvHeadsPerAttentionLayer;
std::vector<SizeType32> mNumKvHeadsPerCrossAttentionLayer;
bool mSkipCrossAttnBlocks;
};
} // namespace tensorrt_llm::runtime

View File

@ -0,0 +1,38 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm::runtime
{
struct RuntimeDefaults
{
RuntimeDefaults(
std::optional<std::vector<SizeType32>> maxAttentionWindowVec, std::optional<SizeType32> sinkTokenLength)
: maxAttentionWindowVec(maxAttentionWindowVec)
, sinkTokenLength(sinkTokenLength)
{
}
std::optional<std::vector<SizeType32>> maxAttentionWindowVec;
std::optional<SizeType32> sinkTokenLength;
RuntimeDefaults() = default;
};
} // namespace tensorrt_llm::runtime
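A brief sketch (not part of the diff) of how these engine-provided defaults are meant to flow into the executor config via the fillEmptyFieldsFromRuntimeDefaults hook shown earlier; the concrete values are assumptions.

#include <vector>

void applyEngineDefaults(tensorrt_llm::executor::KvCacheConfig& config)
{
    tensorrt_llm::runtime::RuntimeDefaults const defaults(
        std::vector<tensorrt_llm::runtime::SizeType32>{4096}, /*sinkTokenLength=*/0);
    config.fillEmptyFieldsFromRuntimeDefaults(defaults);
}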

View File

@ -87,7 +87,7 @@ public:
[[nodiscard]] bool constexpr updatesPositionIds() const
{
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens | kEagle);
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
}
[[nodiscard]] bool constexpr requiresAttentionMask() const

View File

@ -23,10 +23,6 @@
namespace tensorrt_llm::runtime::utils
{
static SizeType32 constexpr PREFIX_CHUNK_SIZE_BITS = 4;
static SizeType32 constexpr PREFIX_MAX_VALUE = 16;
struct TreeNode
{
SizeType32 nodeId;
@ -36,11 +32,9 @@ struct TreeNode
std::vector<SizeType32> childLinearIndices;
};
void initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
SizeType32 initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32>& topKs,
ITensor::SharedPtr generationInputLengths, ITensor::SharedPtr positionOffsets, ITensor::SharedPtr treeIds,
ITensor::SharedPtr paths, ITensor::SharedPtr packedMask);
void dumpChoices(std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32> const& indices);
} // namespace tensorrt_llm::runtime::utils

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:24555ce900cefa5eb24441c503332c4997d8f1e695e29bfe72eef76eb01d4406
size 5389730
oid sha256:748a53a5f70813f0ddb5bb54a56cd07a4b9146917c12ec34504dc4384b00610b
size 5882210

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:33f9d6f50e53218bd935226fa364d28eed82b5624e606c08a9e51a63b5b2e15d
size 5507018
oid sha256:2350b7f07b5f30179ebf24f6e103dc17d4a656c95c171eaca684529120ca245a
size 6001974

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94319899700c8ff1bba3f9f3df4b3cde190a2da2d676e9c2af71f281e99e6cf8
size 1986712
oid sha256:8b28f05452036c1722a37ac625921cf4902cfb6c04fb01b9d958b9f40ff9be0b
size 1958384

View File

@ -1,2 +1,2 @@
279b2521d189ac03d35a7466330ea425 libtensorrt_llm_ucx_wrapper.so
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
0066a5a67ec747f565158bbbc398cca9 libtensorrt_llm_ucx_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf6dfef7af51dc17f06f70010c0a9197ff107c98042a010eabdc1c5a9931abbe
size 5239294
oid sha256:0132b1d4544101465ac37993ae20324c0c49ae978b0a3c8c95a03a08a17b5b36
size 5692876

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e7943e6af2196982ca90fb8c7fe813dc5dbd6dbd6839cd5706c27044ed3272cf
size 5202544
oid sha256:15ff5d0aeae4d3e776fdf3bb68af0cc5896b14f435b66a11fecc2111668fd089
size 5659602

View File

@ -1,2 +1,2 @@
1598761c1df1fd35b2180b599ad34f58 libtensorrt_llm_ucx_wrapper.so
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c4ebe7311b7423dd5f9ec822586743c9bcdfe088a9259647c8772d005ec64f79
size 34643904
oid sha256:f975b781b240c8489a48243a94dfdf0be6bfe6b862cf6ec6cbeacd5c66fae7af
size 36139148

View File

@ -1,2 +1,2 @@
f7660c6225ba8f9a9bce7a06365c3a60 tensorrt_llm_batch_manager_static.lib
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
f9557afc965818430dcae14ae7542adf tensorrt_llm_batch_manager_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -30,12 +30,10 @@
#include "cudaDriverWrapper.h"
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <stdio.h>
namespace tensorrt_llm
{
namespace common
namespace tensorrt_llm::common
{
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
@ -47,22 +45,21 @@ std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
{
return result;
}
else
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
std::lock_guard<std::mutex> lock(mutex);
result = instance.lock();
if (!result)
{
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
instance = result;
}
return result;
}
CUDADriverWrapper::CUDADriverWrapper()
: handle(dllOpen(CUDA_LIB_NAME))
{
handle = dllOpen(CUDA_LIB_NAME);
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
auto load_sym = [](void* handle, char const* name)
@ -71,21 +68,22 @@ CUDADriverWrapper::CUDADriverWrapper()
return ret;
};
*(void**) (&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*(void**) (&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*(void**) (&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*(void**) (&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*(void**) (&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*(void**) (&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*(void**) (&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*(void**) (&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*(void**) (&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*(void**) (&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*(void**) (&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*(void**) (&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
}
CUDADriverWrapper::~CUDADriverWrapper()
@ -98,6 +96,11 @@ CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) co
return (*_cuGetErrorName)(error, pStr);
}
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
{
return (*_cuGetErrorMessage)(error, pStr);
}
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
{
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
@ -181,5 +184,4 @@ CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, s
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
}
} // namespace common
} // namespace tensorrt_llm
} // namespace tensorrt_llm::common

View File

@ -17,33 +17,30 @@
#ifndef CUDA_DRIVER_WRAPPER_H
#define CUDA_DRIVER_WRAPPER_H
#include "tensorrt_llm/common/assert.h"
#include <cstdio>
#include <cuda.h>
#include <memory>
#include <mutex>
#define cuErrCheck(stat, wrap) \
{ \
cuErrCheck_((stat), wrap.get(), __FILE__, __LINE__); \
}
namespace tensorrt_llm
{
namespace common
namespace tensorrt_llm::common
{
class CUDADriverWrapper
{
// Use getInstance() instead.
CUDADriverWrapper();
public:
static std::shared_ptr<CUDADriverWrapper> getInstance();
~CUDADriverWrapper();
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
@ -84,7 +81,10 @@ public:
private:
void* handle;
CUDADriverWrapper();
CUresult (*_cuGetErrorName)(CUresult, char const**);
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
CUresult (*_cuModuleUnload)(CUmodule);
@ -108,17 +108,31 @@ private:
CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};
inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const* wrap, char const* file, int line)
template <typename T>
void checkDriver(
T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
if (stat != CUDA_SUCCESS)
if (result)
{
char const* msg = nullptr;
wrap->cuGetErrorName(stat, &msg);
fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
char const* errorName = nullptr;
char const* errorMsg = nullptr;
wrap.cuGetErrorName(result, &errorName);
wrap.cuGetErrorMessage(result, &errorMsg);
throw TllmException(
file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
}
}
} // namespace common
} // namespace tensorrt_llm
} // namespace tensorrt_llm::common
/*
* Macros compliant with TensorRT coding conventions
*/
#define TLLM_CU_CHECK(stat) \
do \
{ \
tensorrt_llm::common::checkDriver( \
(stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
} while (0)
#endif // CUDA_DRIVER_WRAPPER_H
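For reference, a minimal usage sketch of the new macro (the function name, module handle, and include path below are illustrative assumptions, not part of this change): a failing driver call now surfaces as a TllmException carrying both the CUDA error name and message instead of a bare fprintf.
#include "tensorrt_llm/common/cudaDriverWrapper.h" // assumed include path for this header
void loadMhaKernel(CUmodule hmod, CUfunction* func)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
    // Throws tensorrt_llm::common::TllmException on any non-zero CUresult.
    TLLM_CU_CHECK(driver->cuModuleGetFunction(func, hmod, "kernel_mha"));
}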

View File

@ -0,0 +1,54 @@
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <nlohmann/json.hpp>
#include <optional>
namespace nlohmann
{
template <typename T>
struct adl_serializer<std::optional<T>>
{
static void to_json(nlohmann::json& j, std::optional<T> const& opt)
{
if (opt == std::nullopt)
{
j = nullptr;
}
else
{
j = opt.value(); // this will call adl_serializer<T>::to_json which will
// find the free function to_json in T's namespace!
}
}
static void from_json(nlohmann::json const& j, std::optional<T>& opt)
{
if (j.is_null())
{
opt = std::nullopt;
}
else
{
opt = j.template get<T>(); // same as above, but with
// adl_serializer<T>::from_json
}
}
};
} // namespace nlohmann
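A short usage sketch of this specialization (variable names are illustrative): an empty std::optional serializes to JSON null and null parses back to std::nullopt, while a present value round-trips through the usual adl_serializer<T> path.
#include <cassert>
#include <nlohmann/json.hpp>
#include <optional>
int main()
{
    std::optional<int> maybeBeamWidth;              // empty
    nlohmann::json j = maybeBeamWidth;              // to_json above -> null
    assert(j.is_null());
    j = 4;                                          // present value
    auto parsed = j.get<std::optional<int>>();      // from_json above
    assert(parsed.has_value() && *parsed == 4);
    return 0;
}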

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8e633c968cb2712a79886a940a07a13f598543c2e936912d9099ef088a240d7c
size 2358334
oid sha256:33f66dba2f3024d979e38cf1aae4d10802c5a1fb0f4c801108c35824339eae5d
size 2419566

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:775cba2bfd47b779944fb2e7ce275d002fba5b5cdefabb409f8a48cc77f157f9
size 2391240
oid sha256:d224780476ce5f398f30ffbfa0d61bbd0aae5cb1538c8d4c0a16cdf8945ba5d3
size 2449532

View File

@ -1,3 +1,3 @@
b4c02793e0859c9a1cde80851235ac8b libtensorrt_llm_executor_static.a
773778beb42d127b4322a92290c497cb libtensorrt_llm_executor_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
ee532edbf35321d4ac0aadf8a3c6a3a5 libtensorrt_llm_executor_static.a
0bf468a19d4c353dcf421fc3e05a9d7d libtensorrt_llm_executor_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7155a2b04357cf1b5ea2cb46a49dd536d18943f4bce57fc8e3761aa11d4df943
size 3440434
oid sha256:9b21e2488bdb5c1e18e7aa129acb18087d031eea4f5b063910081ca09a3041a5
size 3494984

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ce7093244a223954c718dc3a60a50e452057e378f79f58971989b30e6d858feb
size 3357394
oid sha256:94964aa02020e38e869bf9ca18385ae379c8b9d1819ad02e10b23d8175cc9d82
size 3412104

View File

@ -1,3 +1,3 @@
09d25336d48b4281a27dc0e88e154ed8 libtensorrt_llm_executor_static.a
549f092db03a48c3be836f90b0938347 libtensorrt_llm_executor_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
ba01eba908f38eb582c22c1f822cfedf libtensorrt_llm_executor_static.a
ffe68ec0af94d364ec8db50a24ae0e8c libtensorrt_llm_executor_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06f2682b105be0fc48687afc88480a56f3037e3c885d1a042bc2d6065fc59436
size 22719724
oid sha256:67f59341edab284c309d39f2a0ad39e91f8afe198c4cf6ba838ae7adb54ad01d
size 23192460

View File

@ -1,2 +1,2 @@
45c167501f2f191bf2a3aa6e9d80ce9a tensorrt_llm_executor_static.lib
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
e3cd49147c73b0066dcb759df9556191 tensorrt_llm_executor_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -30,9 +30,7 @@
#include <unordered_map>
#include <vector>
namespace tensorrt_llm
{
namespace kernels
namespace tensorrt_llm::kernels
{
////////////////////////////////////////////////////////////////////////////////////////////////////
@ -85,19 +83,17 @@ public:
}
else
{
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin));
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
}
FusedMultiHeadAttentionKernelInfo funcInfo;
funcInfo.mMetaInfoIndex = i;
cuErrCheck(
mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName));
if (kernelMeta.mSharedMemBytes >= 48 * 1024)
{
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes),
mDriver);
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes));
}
mFunctions.insert(std::make_pair(hashID(kernelMeta), funcInfo));
int s = static_cast<int>(kernelMeta.mS);
@ -120,9 +116,8 @@ public:
const CUfunction func = findIter->second.mDeviceFunction;
void* kernelParams[] = {&params, nullptr};
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
}
virtual bool checkIfKernelExist(MHARunnerFixedParams params) const = 0;
@ -276,9 +271,8 @@ public:
if (!forceUnroll)
{
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
} // forceunroll = true for flash attention kernels
else if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
{
@ -327,9 +321,8 @@ public:
}
}
cuErrCheck(mDriver->cuLaunchKernel(func, block_size.x, block_size.y, block_size.z,
kernelMeta.mThreadsPerCTA, 1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, block_size.x, block_size.y, block_size.z,
kernelMeta.mThreadsPerCTA, 1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
}
else
{ // forceunroll = true for flash attention kernels
@ -344,9 +337,8 @@ public:
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
if (mSM == kSM_90 && !launch_params.flash_attention)
{
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
} // on Ampere/Ada/Volta flash attention, we launch blocks (steps, h, b)
else
{
@ -356,9 +348,8 @@ public:
// For cases exceeding 256 dimensions, the number of CTAs needs to be multiplied.
unroll *= (params.dv + 256 - 1) / 256;
}
cuErrCheck(mDriver->cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
}
}
}
@ -452,5 +443,4 @@ inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type typ
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
}
} // namespace kernels
} // namespace tensorrt_llm
} // namespace tensorrt_llm::kernels

View File

@ -882,7 +882,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
}
template <typename T>
bool is_lamport_supported(int token_num)
bool is_lamport_supported(int token_num, int hidden_size)
{
static char* disableLamportReduceNormFusionChar = std::getenv("DISABLE_LAMPORT_REDUCE_NORM_FUSION");
bool disableLamportReduceNormFusion = (disableLamportReduceNormFusionChar != nullptr);
@ -901,17 +901,21 @@ bool is_lamport_supported(int token_num)
{
return false;
}
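// The Lamport-style fused all-reduce + RMSNorm also requires a sufficiently large hidden size;
// below kLamportHiddenSizeThreshold this returns false and the non-Lamport path is taken.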
if (hidden_size < details::kLamportHiddenSizeThreshold)
{
return false;
}
return true;
}
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num)
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num, int hidden_size)
{
switch (dataType)
{
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num);
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num);
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num, hidden_size);
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num, hidden_size);
#ifdef ENABLE_BF16
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num);
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num, hidden_size);
#endif
default: return false;
}
@ -921,7 +925,7 @@ template <typename T, int RanksPerNode, bool Bias, bool Affine>
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
{
int token_num = params.elts_total / params.fusion_params.hidden_size;
if (is_lamport_supported<T>(token_num))
if (is_lamport_supported<T>(token_num, params.fusion_params.hidden_size))
{
lamport_style_one_shot_all_reduce_norm_kernel_launcher<T, RanksPerNode, Bias, Affine>(params, stream);
}
@ -1568,12 +1572,13 @@ void AllReduceDispatchType(AllReduceParams& params, AllReduceStrategyType strat,
}
}
AllReduceParams AllReduceParams::deserialize(
int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, AllReduceFusionOp op)
AllReduceParams AllReduceParams::deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
int token_num, int hidden_size, AllReduceFusionOp op)
{
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
int flag_offset;
if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM && reduce_fusion::is_lamport_supported(dataType, token_num))
if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM
&& reduce_fusion::is_lamport_supported(dataType, token_num, hidden_size))
{
flag_offset = 0;
}

View File

@ -38,6 +38,7 @@ static constexpr int kWarpSize = 32;
static constexpr int kMaxCtaSize = 1024;
static constexpr int kClusterMaxSize = 8;
static constexpr int kLamportTokenNumThreshold = 16;
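// Minimum hidden size for which the Lamport-style fused all-reduce + RMSNorm path is considered
// (checked in reduce_fusion::is_lamport_supported).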
static constexpr int kLamportHiddenSizeThreshold = 256;
}; // namespace reduce_fusion::details
// Warning: python definition is in tensorrt_llm/functional.py
@ -103,7 +104,7 @@ struct AllReduceParams
AllReduceFusionParams fusion_params;
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
int token_num, AllReduceFusionOp op);
int token_num, int hidden_size, AllReduceFusionOp op);
};
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);

View File

@ -703,6 +703,18 @@ extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nq
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
// MHA with beamWidth=4
extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin[];
@ -829,6 +841,18 @@ extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
// MHA with beamWidth=4
extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len;
@ -1258,6 +1282,18 @@ static const struct XQAKernelMetaInfo
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
// MHA with beamWidth=4
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"},
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"},

View File

@ -21,11 +21,7 @@
#include "tensorrt_llm/common/cudaUtils.h"
#include <cuda_runtime_api.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace jit
namespace tensorrt_llm::kernels::jit
{
CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
@ -135,9 +131,8 @@ void CubinObj::serialize(void* buffer_, size_t buffer_size) const noexcept
void CubinObj::launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams)
{
TLLM_CHECK(mInitialized);
cuErrCheck(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y, blockDim.z,
mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y,
blockDim.z, mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr));
}
void CubinObj::initialize()
@ -146,26 +141,25 @@ void CubinObj::initialize()
{
mDriver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
mModule = nullptr;
cuErrCheck(mDriver->cuModuleLoadData(&mModule, mContent.c_str()), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&mModule, mContent.c_str()));
TLLM_CHECK(mModule != nullptr);
mFunction = nullptr;
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName));
TLLM_CHECK(mFunction != nullptr);
// Populate mSharedMemBytes.
CUdeviceptr shmem_dev_ptr = 0;
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName));
TLLM_CHECK(shmem_dev_ptr != 0);
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);
TLLM_CU_CHECK(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)));
TLLM_CHECK(mSharedMemBytes > 0);
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (mSharedMemBytes >= 46 * 1024)
{
cuErrCheck(mDriver->cuFuncSetAttribute(
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
mDriver);
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes));
}
sync_check_cuda_error();
@ -177,11 +171,9 @@ CubinObj::~CubinObj()
{
if (mInitialized)
{
cuErrCheck(mDriver->cuModuleUnload(mModule), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleUnload(mModule));
mInitialized = false;
}
}
} // namespace jit
} // namespace kernels
} // namespace tensorrt_llm
} // namespace tensorrt_llm::kernels::jit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1397678ac1cab957f7d272750035ffd88b9e2b3b9d4f132073119d21c288b5da
size 82262624
oid sha256:53b2ebc1484d068fa60c8e5ad22bf2db40bd84963bb0d2e679bcec9f53b65c5d
size 82318536

View File

@ -1,2 +1,2 @@
5ea3eabf1c58887230ba5ebc583e0d3c libtensorrt_llm_nvrtc_wrapper.so
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
90df70c216d9aa2c85b8b097c853e4ba libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:17481ff01045ac335223451c83943a7f97b5c63ca2ab5da3e71d0909c8f4e68b
size 84578328
oid sha256:64aec9fe985b5dd0d38d9b76ee6f2fde14a183bfe44de9f0148fc482af086a48
size 84643008

View File

@ -1,2 +1,2 @@
270f246f5eccb170a759cda4787216f4 libtensorrt_llm_nvrtc_wrapper.so
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
232f492424a31204a2be2e67be299aef libtensorrt_llm_nvrtc_wrapper.so
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2c9102fab7a800539f1fabb01d3a443abdfb10ccac94ae5704883260088de71d
oid sha256:bed2713947315cf941533dd12b5b98270a2aabd584cc33bc2092be6dbf879959
size 1128448

View File

@ -1,3 +1,3 @@
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
3be88342e596bda98d75e3b2d31ae484 tensorrt_llm_nvrtc_wrapper.dll
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
aaa20992c207e46eab50dd90bcf3c405 tensorrt_llm_nvrtc_wrapper.dll
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -33,9 +33,7 @@
using namespace tensorrt_llm::common;
namespace tensorrt_llm
{
namespace kernels
namespace tensorrt_llm::kernels
{
class XQAKernelList
@ -77,13 +75,13 @@ public:
}
else
{
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin));
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
}
XQAKernelFuncInfo funcInfo{};
funcInfo.mMetaInfoIndex = i;
cuErrCheck(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName));
funcInfo.mSharedMemBytes = getGlobalVar<uint32_t>(mDriver, hmod, "smemSize", true).value();
funcInfo.mKernelType = getGlobalVar<XQAKernelType>(mDriver, hmod, "kernelType", false)
.value_or(XQAKernelType::kAMPERE_WARP_SPECIALIZED);
@ -91,9 +89,8 @@ public:
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
if (funcInfo.mSharedMemBytes >= 46 * 1024)
{
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, funcInfo.mSharedMemBytes),
mDriver);
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, funcInfo.mSharedMemBytes));
}
XQAKernelRuntimeHashKey hash_key{kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth,
kernelMeta.mNumQHeadsOverKV, kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache,
@ -256,9 +253,8 @@ public:
sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream));
sync_check_cuda_error();
}
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr));
}
else
{
@ -299,9 +295,8 @@ public:
{
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
}
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128, 1,
isGmmaKernel ? 3 : 2, shared_mem_bytes, stream, kernelParams, nullptr),
mDriver);
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128,
1, isGmmaKernel ? 3 : 2, shared_mem_bytes, stream, kernelParams, nullptr));
}
sync_check_cuda_error();
@ -417,7 +412,7 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo
}
bool const isGPTJBeam4Kernel = (xqaParams.head_size == 256 && xqaParams.beam_width == 4 && xqaParams.paged_kv_cache
&& (xqaParams.tokens_per_block == 64 || xqaParams.tokens_per_block == 128));
if (xqaParams.head_size != 128 && xqaParams.head_size != 256 && !isGPTJBeam4Kernel)
if (xqaParams.head_size != 64 && xqaParams.head_size != 128 && xqaParams.head_size != 256 && !isGPTJBeam4Kernel)
{
SUPPORT_RETURN_FALSE("head_size");
}
@ -512,5 +507,4 @@ void DecoderXQAImplPrecompiled::runWithKVBlockArray(
runDispatchBuffer<KVBlockArray>(xqa_params, kv_block_array, stream);
}
} // namespace kernels
} // namespace tensorrt_llm
} // namespace tensorrt_llm::kernels

View File

@ -19,9 +19,7 @@
#include <cstdint>
#include <type_traits>
namespace tensorrt_llm
{
namespace kernels
namespace tensorrt_llm::kernels
{
namespace
@ -75,10 +73,9 @@ CUtensorMap makeTensorMapForPagedKVCache(std::shared_ptr<CUDADriverWrapper> cons
}
}();
cuErrCheck(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
driver);
TLLM_CU_CHECK(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensorMap;
}
@ -107,10 +104,9 @@ CUtensorMap makeTensorMapForContiguousKVCache(std::shared_ptr<CUDADriverWrapper>
}
}();
cuErrCheck(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
driver);
TLLM_CU_CHECK(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
return tensorMap;
}
@ -139,5 +135,4 @@ template CUtensorMap makeTensorMapForKVCache(
template CUtensorMap makeTensorMapForKVCache(
std::shared_ptr<CUDADriverWrapper> const&, XQAParams const&, KVLinearBuffer const&);
} // namespace kernels
} // namespace tensorrt_llm
} // namespace tensorrt_llm::kernels

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9ddbaa50b4b51d158e163aa3160bfee88a2e0e3c987fa4a883e14066b9c09e2
size 21861322
oid sha256:8032548ca52a51b3245dcff4fd834e02b93a00e61146be5e418aff94d3e655cb
size 21863082

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cee9ceda43c0321aa2c374383eb00ee8347e5b7aadb57473dc16b2a3ceeef39
size 22133914
oid sha256:71ef05b9741f027279efedeec8f9d598299a1348cc0e37022424645f9efccd22
size 22111930

View File

@ -1,3 +1,3 @@
cc18fb3ea093606954fcfc7247418571 libtensorrt_llm_internal_cutlass_kernels_static.a
c9ae57fa6eb2d950cea0acb08483e978 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
f1820d73fc5cac7fa324d71933e5412a libtensorrt_llm_internal_cutlass_kernels_static.a
a8785db1cc11e3b571bc071d7abec1a8 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f7c5f0a38b061bc373c095c4d4790ffeb23ca6595f9dfc52f81f6a9f772dbf1
size 36622632
oid sha256:4b6917794ec6e67989fdcd0af3cc4d84713f3d8d4dcd822d2df2272117c66d6b
size 36626184

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c53931240f61f25a54aa74835c732440366c54c70b43372b5ad8f1b0a140562
size 36094714
oid sha256:e6d2f3c25a8ce88917ba512eba804f14827703fab6f9ac8d63043e2d95b6b281
size 36080026

View File

@ -1,3 +1,3 @@
fb560fdda660a1fb8a78ca605329697f libtensorrt_llm_internal_cutlass_kernels_static.a
0d2d1e33a0e588fc1653feaccecff418 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
9e6ff6d826caeea1e6e19c71f5d0986b libtensorrt_llm_internal_cutlass_kernels_static.a
2fca4e76d5f21089f00b1d39f624b80a libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:788694476443028c72cede1228a9beb9be431d5b871bb4e076fd2b3ddd184455
size 2669966
oid sha256:c2f34df6d47b7b2b6629358bb03b33eb193db067188e8b980598027b0ff85392
size 2669968

View File

@ -1,2 +1,2 @@
aebeb3c0b4864efa09724a47638098c2 tensorrt_llm_internal_cutlass_kernels_static.lib
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
95c2f50347d4de94e2e09cbf0cf99582 tensorrt_llm_internal_cutlass_kernels_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit

View File

@ -163,8 +163,10 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax
for (int prior_k = startk; prior_k < k_idx; ++prior_k)
{
int const prior_winning_expert = indices[k * block_row + prior_k];
int prior_winning_expert = indices[k * block_row + prior_k];
// Adjust the selected index to correct for the expert parallel transformation
prior_winning_expert = prior_winning_expert >= num_experts ? prior_winning_expert - num_experts
: prior_winning_expert + start_expert;
if (prior_winning_expert == expert)
{
inp_kvp = thread_kvp;

View File

@ -0,0 +1,74 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace kernels
{
namespace qserve
{
struct ParamsPerGroup
{
int8_t const* A;
int8_t const* B;
int8_t const* s2_zeros;
int8_t const* s2_scales;
half const* s1_scales;
half const* act_scales;
half* C;
int m;
int n;
int k;
};
struct ParamsPerChannel
{
int8_t const* A;
int8_t const* B;
half const* s1_scales;
half const* s1_szeros;
half const* act_sums;
half const* act_scales;
half* C;
int m;
int n;
int k;
};
class QServeGemmRunner
{
public:
void gemmPerGroup(ParamsPerGroup const& params, cudaStream_t stream);
void gemmPerChannel(ParamsPerChannel const& params, cudaStream_t stream);
// We do not use workspace for now.
// char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream);
// Returns desired workspace size in bytes.
size_t getWorkspaceSize(int const m, int const n, int const k);
// virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;
};
} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm
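A minimal sketch of driving the per-channel path, assuming the device buffers have already been allocated and quantized elsewhere; the wrapper function name and the relative include are illustrative (the .cu files in this change include the header the same way).
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include "qserveGemm.h" // assumed relative include
using namespace tensorrt_llm::kernels::qserve;
// dA: s8 activations [m, k]; dB: packed 4-bit weights; dC: fp16 output [m, n].
void runQServePerChannel(int8_t const* dA, int8_t const* dB, half const* dS1Scales,
    half const* dS1Szeros, half const* dActSums, half const* dActScales, half* dC,
    int m, int n, int k, cudaStream_t stream)
{
    // Field order matches the ParamsPerChannel declaration above.
    ParamsPerChannel params{dA, dB, dS1Scales, dS1Szeros, dActSums, dActScales, dC, m, n, k};
    QServeGemmRunner runner;
    runner.gemmPerChannel(params, stream);
}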

View File

@ -0,0 +1,618 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Implemented by Haotian Tang and Shang Yang.
// @article{lin2024qserve,
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and Han, Song},
// journal={arXiv preprint arXiv:2405.04532}, year={2024}
// }
#include "qserveGemm.h"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace qserve
{
#define OP_M 16
#define OP_N 8
#define OP_K 32
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
#define KERNEL_LAUNCH_CODE \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize \
= ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) \
* sizeof(int8_t); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf( \
"This kernel requires %d Bytes of shared memory, which exceeds " \
"device limit.\n", \
kSmemByteSize); \
return; \
} \
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
const int tile_shift = 1 << log_tile; \
dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(in_feats, kernel, wscales, ascales, w_szs, \
a_ssums, out_feats, num_in_feats, num_out_channels, num_in_channels);
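// Threadblock swizzling helpers: get_log_tile<8> picks a power-of-two tile width (1, 2, 4 or 8
// CTAs) from the number of CTA rows, and get_block_idx_mapping folds the launched (x, y) block
// index back into the logical (n, m) tile coordinates so that neighbouring CTAs work on nearby
// output tiles.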
template <int N>
inline __host__ __device__ int get_log_tile(int n)
{
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
}
inline __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
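// Converts a generic shared-memory pointer into the 32-bit .shared address form expected by the
// inline-PTX cp.async and ldmatrix instructions below.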
inline __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
"smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
}
inline __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
{
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
: "r"(addr));
}
inline __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
{
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
: "r"(addr));
}
// cp.async.cg copy helper adapted from lmdeploy
inline __device__ void cp_async_cg_A(uint32_t smem_int_ptr, uint4 const* __restrict__ src, bool mask)
{
int const cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
}
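// One m16n8k32 Tensor Core MMA: an s8 A fragment times an s8 B fragment, accumulated into four
// s32 registers of C_warp.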
__device__ inline void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp)
{
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=r"(((int*) C_warp)[0]), "=r"(((int*) C_warp)[1]), "=r"(((int*) C_warp)[2]), "=r"(((int*) C_warp)[3])
: "r"(((unsigned*) A_shared_warp)[0]), "r"(((unsigned*) A_shared_warp)[1]), "r"(((unsigned*) A_shared_warp)[2]),
"r"(((unsigned*) A_shared_warp)[3]), "r"(((unsigned*) B_shared_warp)[0]), "r"(((unsigned*) B_shared_warp)[1]),
"r"(((int*) C_warp)[0]), "r"(((int*) C_warp)[1]), "r"(((int*) C_warp)[2]), "r"(((int*) C_warp)[3]));
}
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ inline void global_to_share_one_stage_A(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool* preds)
{
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
int8_t* dst_hoisted = dst;
int8_t const* src_hoisted = src + global_iter_k * CTA_K;
if (mask)
{
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
// *dst_ptr = *src_ptr;
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, preds[global_iter]);
}
else
{
if (preds[global_iter])
*(uint4*) dst_ptr = *src_ptr;
}
}
}
}
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ inline void global_to_share_one_stage_B(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
constexpr int warps_per_row = CTA_K / 32;
constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
constexpr int kSmemCol = CTA_K;
int8_t* dst_hoisted = dst;
int8_t const* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;
#pragma unroll
for (int global_iter = 0; global_iter < total_global_iters; ++global_iter)
{
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, mask);
}
else
{
if (mask)
*(uint4*) dst_ptr = *src_ptr;
}
}
}
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ inline void global_to_share_one_stage_zeros(int8_t const* src, int8_t* dst, int global_ncols,
int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void* dst_ptr = (void*) (dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4* src_ptr = (uint4*) (src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
if (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
{
*(uint4*) dst_ptr = *src_ptr;
}
}
}
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ inline void share_to_reg_one_stage_A(
int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
void* addr_ptr = (void*) (src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
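// Loads packed 4-bit weights for one K slice from shared memory and splits the low/high nibble of
// each byte into separate 8-bit lanes so they can be fed to mma_m16n8k32 as s8 operands.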
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ inline void share_to_reg_one_stage_B(int8_t* src, int8_t* dst, int8_t* zeros, int8_t* scales_i8,
int warp_offset_m, int warp_offset_n, int k_0_0, int k_0_1, int shared_iters)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
uint4 loaded = *((uint4*) (src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol
+ k_0_1 * INTRIN_K + threadIdx.x);
auto ptr = (uint32_t*) dst + shared_iter * 8;
ptr[0] = loaded.x & 0x0F0F0F0F;
ptr[4] = (loaded.x & 0xF0F0F0F0) >> 4;
ptr[2] = loaded.y & 0x0F0F0F0F;
ptr[6] = (loaded.y & 0xF0F0F0F0) >> 4;
ptr[1] = loaded.z & 0x0F0F0F0F;
ptr[5] = (loaded.z & 0xF0F0F0F0) >> 4;
ptr[3] = loaded.w & 0x0F0F0F0F;
ptr[7] = (loaded.w & 0xF0F0F0F0) >> 4;
}
}
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(int8_t const* __restrict__ A, int8_t const* __restrict__ B,
half2 const* __restrict__ wscales, half const* __restrict__ ascales, half2 const* __restrict__ w_szs,
half const* __restrict__ a_ssums, half* __restrict__ C, int M, int64_t N, int64_t K)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
constexpr int SPLITK = 1;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_n = blockIdx.x;
int blockIdx_m = blockIdx.y;
int const log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
uint2 const block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
blockIdx_n = block_idx_mapping.x;
blockIdx_m = block_idx_mapping.y;
int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES;
extern __shared__ int8_t mem_shared[];
int8_t* A_shared = mem_shared;
int8_t* B_shared = mem_shared + kSmemSizeA;
int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int A_threads_per_row = CTA_K / PACK_SIZE;
constexpr int B_warps_per_row = CTA_K / 32;
constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;
int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
+ (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
int8_t const* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
int8_t const* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
+ (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;
bool A_g2s_preds[A_total_global_iters];
#pragma unroll
for (int i = 0; i < A_total_global_iters; i++)
{
A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
}
int* C_shared = reinterpret_cast<int*>(mem_shared);
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A_hoisted,
A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true,
A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B_hoisted,
B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared + warp_offset_k * PACK_SIZE,
B_shared_warp_[0], zeros_shared, scales_i8_shared, warp_offset_m, warp_offset_n, 0, 0, WARP_N / 32);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
int8_t* A_shared_this_compute_stage;
int8_t* B_shared_this_compute_stage;
int8_t* zeros_shared_this_compute_stage;
int8_t* scales_i8_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
zeros_shared_this_compute_stage = zeros_shared + compute_stage * CTA_N;
scales_i8_shared_this_compute_stage = scales_i8_shared + compute_stage * CTA_N;
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS,
WARP_M / INTRIN_M);
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared_this_compute_stage,
B_shared_warp_[(iter_k + 1) % 2], zeros_shared_this_compute_stage, scales_i8_shared_this_compute_stage,
warp_offset_m, warp_offset_n, k_0_0 + (iter_k == SHARED_K_ITERS - 1), (iter_k + 1) % SHARED_K_ITERS,
WARP_N / 32);
int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16));
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16 + 8));
}
}
if (iter_k < SHARED_K_ITERS - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == SHARED_K_ITERS - 2)
{
if constexpr (STAGES == 1 && SHARED_K_ITERS > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
if constexpr (SLICES > 1)
{
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
+= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n
+ ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N
+ (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2]
= C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
if (slice_id == 0)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
if (slice_id == 0)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
if (row_wb < M)
{
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
float2 wscale = __half22float2(*(wscales + col_wb / 2));
float2 w_sz = __half22float2(*(w_szs + col_wb / 2));
float ascale = __half2float(ascales[row_wb]);
float a_ssum = __half2float(a_ssums[row_wb]);
float2 psums = make_float2(
__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
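// Dequantize: scale the s32 partial sums by the per-channel weight scale and the per-row
// activation scale, then subtract the zero-point correction (per-channel scale*zero times the
// per-row activation sum).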
psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum;
psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum;
*reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
}
};
}
}
}
#endif
}
void qserveGemmPerChannelLaunch(ParamsPerChannel const& params, cudaStream_t stream)
{
auto in_feats = params.A;
auto kernel = params.B;
auto w_szs = reinterpret_cast<half2 const*>(params.s1_szeros);
auto a_ssums = params.act_sums;
auto wscales = reinterpret_cast<half2 const*>(params.s1_scales);
auto ascales = params.act_scales;
auto out_feats = params.C;
int num_out_feats = params.m;
int num_out_channels = params.n;
int num_in_feats = params.m;
int num_in_channels = params.k;
constexpr int G = 128;
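// Pick CTA/warp tile shapes and the software-pipeline depth (STAGES) based on the number of
// output rows M; larger M favours larger CTA_M tiles.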
if (num_out_feats > 256)
{
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats >= 128)
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int CTA_M = 32;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}
void QServeGemmRunner::gemmPerChannel(ParamsPerChannel const& params, cudaStream_t stream)
{
qserveGemmPerChannelLaunch(params, stream);
}
} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -0,0 +1,677 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Implemented by Haotian Tang and Shang Yang.
// @article{lin2024qserve,
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
// }
#include "qserveGemm.h"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>
namespace tensorrt_llm
{
namespace kernels
{
namespace qserve
{
#define OP_M 16
#define OP_N 8
#define OP_K 32
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
#if (__CUDACC_VER_MAJOR__ > 11) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4))
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
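// OP_M/OP_N/OP_K describe the mma.m16n8k32 int8 tensor-core tile, INTRIN_M/N/K the 16x16x32
// fragment handled per ldmatrix/mma pair, and PACK_SIZE the 16 bytes moved by each 128-bit
// vector copy. KERNEL_LAUNCH_CODE below derives the shared-memory budget for the chosen tile
// sizes (A + packed B + per-group scales/zeros, times STAGES), bails out above the ~99 KB
// dynamic limit, applies the threadblock swizzle, raises the dynamic shared-memory attribute
// and launches dense_kernel0.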
#define KERNEL_LAUNCH_CODE \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize \
= ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) \
* sizeof(int8_t); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf( \
"This kernel requires %d Bytes of shared memory, which exceeds " \
"device limit.\n", \
kSmemByteSize); \
return; \
} \
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
const int tile_shift = 1 << log_tile; \
dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(in_feats, kernel, zeros, scales_i8, wscales, \
ascales, out_feats, num_in_feats, num_out_channels, num_in_channels);
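// Threadblock swizzle: get_log_tile picks a tile width (up to 2^3 = 8) from the number of row
// blocks, and get_block_idx_mapping folds blockIdx.x/blockIdx.y back into (n, m) block
// coordinates so that consecutively scheduled CTAs walk down M inside a narrow column of N
// blocks. For example, with log_tile = 3, blockIdx.x = 10 and blockIdx.y = 0 map to block
// (n = 1, m = 2). The intent is better reuse of the weight tiles in L2.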
template <int N>
inline __host__ __device__ int get_log_tile(int n)
{
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
}
inline __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}
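// Converts a generic pointer into the 32-bit shared-memory window address expected by the
// ldmatrix / cp.async operands (cvta.to.shared followed by a narrowing cvt).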
inline __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
{
uint32_t smem_int_ptr;
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
"smem_ptr; }\n"
: "=r"(smem_int_ptr)
: "l"(ptr));
return smem_int_ptr;
}
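// The ldmatrix wrappers load four 8x8 16-bit tiles per warp from shared memory into registers
// (16 bytes per thread, written at shared_warp + ax0_0 * 16); the .trans variant additionally
// transposes each 8x8 tile.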
inline __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
{
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
: "r"(addr));
}
inline __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
{
asm volatile(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];"
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
: "r"(addr));
}
// function from lmdeploy
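// Predicated 16-byte cp.async.cg copy from global to shared memory: when mask is false the
// copy is skipped (used for the M-boundary predicates), and L2_CACHEHINT appends an
// .L2::128B hint on toolchains that support it.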
inline __device__ void cp_async_cg_A(uint32_t smem_int_ptr, uint4 const* __restrict__ src, bool mask)
{
int const cp_size = 16;
asm volatile("{"
" .reg .pred p;"
" setp.ne.b32 p, %0, 0;"
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
"}" ::"r"((int)mask),
"r"(smem_int_ptr),
"l"(src),
"n"(cp_size));
}
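// Int8 tensor-core MMA: accumulates a 16x8 int32 tile from a 16x32 A fragment and an 8x32 B
// fragment; the main loop issues it twice per 16x16 INTRIN tile (B offsets 0 and 8).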
__device__ inline void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp)
{
asm volatile(
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
: "=r"(((int*) C_warp)[0]), "=r"(((int*) C_warp)[1]), "=r"(((int*) C_warp)[2]), "=r"(((int*) C_warp)[3])
: "r"(((unsigned*) A_shared_warp)[0]), "r"(((unsigned*) A_shared_warp)[1]), "r"(((unsigned*) A_shared_warp)[2]),
"r"(((unsigned*) A_shared_warp)[3]), "r"(((unsigned*) B_shared_warp)[0]), "r"(((unsigned*) B_shared_warp)[1]),
"r"(((int*) C_warp)[0]), "r"(((int*) C_warp)[1]), "r"(((int*) C_warp)[2]), "r"(((int*) C_warp)[3]));
}
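// Stages one pipeline stage of the int8 activation tile A (CTA_M x CTA_K) from global to
// shared memory in 16-byte vectors. The copies are split across SHARED_K_ITERS so they can be
// interleaved with math in the main loop; with STAGES > 1 they go through cp.async, otherwise
// through plain predicated vector stores. preds masks rows beyond M.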
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ inline void global_to_share_one_stage_A(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool* preds)
{
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int threads_per_row = CTA_K / PACK_SIZE;
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
int8_t* dst_hoisted = dst;
int8_t const* src_hoisted = src + global_iter_k * CTA_K;
if (mask)
{
#pragma unroll
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
{
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, preds[global_iter]);
}
else
{
if (preds[global_iter])
*(uint4*) dst_ptr = *src_ptr;
}
}
}
}
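// Stages the packed 4-bit weight tile B: weights are stored as 32-channel packs of nibbles,
// so a CTA_N x CTA_K stage occupies CTA_N * CTA_K / 2 bytes of shared memory. mask disables
// copies for pipeline stages issued past the end of the K loop.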
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
__device__ inline void global_to_share_one_stage_B(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
constexpr int warps_per_row = CTA_K / 32;
constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
constexpr int kSmemCol = CTA_K;
int8_t* dst_hoisted = dst;
int8_t const* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;
#pragma unroll
for (int global_iter = 0; global_iter < total_global_iters; ++global_iter)
{
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, mask);
}
else
{
if (mask)
*(uint4*) dst_ptr = *src_ptr;
}
}
}
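// Stages the per-group (s2) zeros or int8 scales for the CTA's CTA_N output channels; only
// CTA_N / PACK_SIZE threads participate and the group row advances every G elements of K
// (g_idx = global_iter_k * CTA_K / G). The same helper is used for both tensors.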
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ inline void global_to_share_one_stage_zeros(int8_t const* src, int8_t* dst, int global_ncols,
int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
{
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
constexpr int threads_per_row = CTA_N / PACK_SIZE;
constexpr int kSmemCol = CTA_N;
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
int g_idx = global_iter_k * CTA_K / G;
void* dst_ptr = (void*) (dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
uint4 const* src_ptr
= (uint4 const*) (src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
    if constexpr (STAGES > 1)
{
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
cp_async_cg_A(addr, src_ptr, local_mask);
}
else
{
if (local_mask)
{
*(uint4*) dst_ptr = *src_ptr;
}
}
}
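// Loads the A fragments for one INTRIN_K slice with ldmatrix; the XOR term in ld_col_swizzled
// mirrors the swizzle applied when A was staged (see A_hoisted_col_swizzled), avoiding
// shared-memory bank conflicts.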
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
__device__ inline void share_to_reg_one_stage_A(
int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
void* addr_ptr = (void*) (src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
}
}
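// Loads one warp slice of packed 4-bit weights and dequantizes them to 8 bits in registers:
// each nibble is isolated, multiplied by its per-group 8-bit scale and offset by the
// per-group zero via the SIMD __vadd4, producing eight 32-bit registers (32 int8 weights)
// per iteration. This relies on QServe's two-level scheme keeping scale * q within 8 bits,
// so the plain 32-bit multiplies do not carry across byte lanes; __byte_perm broadcasts the
// relevant zero byte to all four lanes.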
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ inline void share_to_reg_one_stage_B(int8_t* src, int8_t* dst, int8_t* zeros, int8_t* scales_i8,
int warp_offset_m, int warp_offset_n, int k_0_0, int k_0_1, int shared_iters)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
uint4 loaded = *((uint4*) (src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol
+ k_0_1 * INTRIN_K + threadIdx.x);
uint32_t loaded_0 = loaded.x & 0x0F0F0F0F;
uint32_t loaded_4 = (loaded.x & 0xF0F0F0F0) >> 4;
uint32_t loaded_2 = loaded.y & 0x0F0F0F0F;
uint32_t loaded_6 = (loaded.y & 0xF0F0F0F0) >> 4;
uint32_t loaded_1 = loaded.z & 0x0F0F0F0F;
uint32_t loaded_5 = (loaded.z & 0xF0F0F0F0) >> 4;
uint32_t loaded_3 = loaded.w & 0x0F0F0F0F;
uint32_t loaded_7 = (loaded.w & 0xF0F0F0F0) >> 4;
auto ptr = (uint32_t*) dst + shared_iter * 8;
int scales_zeros_offset = warp_offset_n + (threadIdx.x / 4) * 4 + shared_iter * 32;
uint32_t packed_scales = *reinterpret_cast<uint32_t*>(scales_i8 + scales_zeros_offset);
uint32_t packed_zeros = *reinterpret_cast<uint32_t*>(zeros + scales_zeros_offset);
uint32_t scale_0 = packed_scales & 0xFF;
uint32_t zero_point_0 = __byte_perm(packed_zeros, 0, 0x00000000);
uint32_t ptr_0 = loaded_0 * scale_0;
uint32_t ptr_1 = loaded_1 * scale_0;
ptr[0] = __vadd4(ptr_0, zero_point_0);
ptr[1] = __vadd4(ptr_1, zero_point_0);
uint32_t scale_1 = (packed_scales & 0xFF00) >> 8;
uint32_t zero_point_1 = __byte_perm(packed_zeros, 0, 0x00001111);
uint32_t ptr_2 = loaded_2 * scale_1;
uint32_t ptr_3 = loaded_3 * scale_1;
ptr[2] = __vadd4(ptr_2, zero_point_1);
ptr[3] = __vadd4(ptr_3, zero_point_1);
uint32_t scale_2 = (packed_scales & 0xFF0000) >> 16;
uint32_t zero_point_2 = __byte_perm(packed_zeros, 0, 0x00002222);
uint32_t ptr_4 = loaded_4 * scale_2;
uint32_t ptr_5 = loaded_5 * scale_2;
ptr[4] = __vadd4(ptr_4, zero_point_2);
ptr[5] = __vadd4(ptr_5, zero_point_2);
uint32_t scale_3 = (packed_scales & 0xFF000000) >> 24;
uint32_t zero_point_3 = __byte_perm(packed_zeros, 0, 0x00003333);
uint32_t ptr_6 = loaded_6 * scale_3;
uint32_t ptr_7 = loaded_7 * scale_3;
ptr[6] = __vadd4(ptr_6, zero_point_3);
ptr[7] = __vadd4(ptr_7, zero_point_3);
}
}
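// dense_kernel0 computes C = A (int8 activations, M x K) times B (4-bit weights stored as
// N x K packed nibbles), with the per-group int8 dequantization of B done in registers, a
// multi-stage cp.async pipeline (STAGES), optional K-slicing across warps (SLICES), and an
// fp16 epilogue applying the first-level per-channel weight scales (wscales) and per-token
// activation scales (ascales).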
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
__global__ void dense_kernel0(int8_t const* __restrict__ A, int8_t const* __restrict__ B,
int8_t const* __restrict__ zeros, int8_t const* __restrict__ scales_i8, half2 const* __restrict__ wscales,
half const* __restrict__ ascales, half* __restrict__ C, int M, int64_t N, int64_t K)
{
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
constexpr int SPLITK = 1;
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
constexpr int SLICES = CTA_K / WARP_K;
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
int blockIdx_n = blockIdx.x;
int blockIdx_m = blockIdx.y;
int const log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
uint2 const block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
blockIdx_n = block_idx_mapping.x;
blockIdx_m = block_idx_mapping.y;
int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
constexpr int kSmemSizeScales = CTA_N * STAGES;
extern __shared__ int8_t mem_shared[];
int8_t* A_shared = mem_shared;
int8_t* B_shared = mem_shared + kSmemSizeA;
int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
constexpr int A_threads_per_row = CTA_K / PACK_SIZE;
constexpr int B_warps_per_row = CTA_K / 32;
constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;
int cta_offset_m = blockIdx_m * CTA_M;
int cta_offset_n = blockIdx_n * CTA_N;
int warp_mn = threadIdx.y % NUM_WARPS_MN;
int slice_id = threadIdx.y / NUM_WARPS_MN;
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
int warp_offset_k = slice_id * WARP_K;
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
C_warp[i] = 0;
int gemm_iters = (K + CTA_K - 1) / CTA_K;
int k_0_0_ld = 0;
int k_0_0 = 0;
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;
int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
+ (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
int8_t const* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
int8_t const* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
+ (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;
bool A_g2s_preds[A_total_global_iters];
#pragma unroll
for (int i = 0; i < A_total_global_iters; i++)
{
A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
}
int* C_shared = reinterpret_cast<int*>(mem_shared);
#pragma unroll
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
{
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A_hoisted,
A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true,
A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B_hoisted,
B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
zeros, zeros_shared + (k_0_0_ld) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(scales_i8,
scales_i8_shared + (k_0_0_ld) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters);
if constexpr (STAGES > 1)
__pipeline_commit();
}
if constexpr (STAGES > 1)
__pipeline_wait_prior(STAGES - 2);
__syncthreads();
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared + warp_offset_k * PACK_SIZE,
B_shared_warp_[0], zeros_shared, scales_i8_shared, warp_offset_m, warp_offset_n, 0, 0, WARP_N / 32);
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
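// Main K loop: shared->register loads for the next INTRIN_K slice are issued ahead of the
// MMAs on the current slice, and the global->shared copies for a later pipeline stage are
// spread over the SHARED_K_ITERS iterations; __pipeline_commit / __pipeline_wait_prior keep
// STAGES - 1 cp.async groups in flight.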
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
{
int ld_stage = k_0_0_ld % STAGES;
int compute_stage = k_0_0 % STAGES;
int8_t* A_shared_this_compute_stage;
int8_t* B_shared_this_compute_stage;
int8_t* zeros_shared_this_compute_stage;
int8_t* scales_i8_shared_this_compute_stage;
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
{
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
zeros_shared_this_compute_stage = zeros_shared + (compute_stage) *CTA_N;
scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage) *CTA_N;
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(A_shared_this_compute_stage,
A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS,
WARP_M / INTRIN_M);
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared_this_compute_stage,
B_shared_warp_[(iter_k + 1) % 2], zeros_shared_this_compute_stage, scales_i8_shared_this_compute_stage,
warp_offset_m, warp_offset_n, k_0_0 + (iter_k == SHARED_K_ITERS - 1), (iter_k + 1) % SHARED_K_ITERS,
WARP_N / 32);
int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
{
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16));
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16 + 8));
}
}
if (iter_k < SHARED_K_ITERS - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
}
if (iter_k == SHARED_K_ITERS - 2)
{
if constexpr (STAGES == 1 && SHARED_K_ITERS > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(zeros,
zeros_shared + (ld_stage) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(scales_i8,
scales_i8_shared + (ld_stage) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();
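// If CTA_K is split across warps (SLICES > 1), reduce the per-slice int32 accumulators
// through shared memory: slice z accumulates onto C_shared in turn, then slice 0 reads the
// combined sums back into C_warp before the global write-back below.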
if constexpr (SLICES > 1)
{
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
+= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n
+ ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N
+ (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2]
= C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
if (slice_id == 0)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}
int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
if (slice_id == 0)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
if (row_wb < M)
{
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
float2 wscale = __half22float2(*(wscales + col_wb / 2));
float ascale = __half2float(ascales[row_wb]);
float2 psums = make_float2(
__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
psums.x *= wscale.x * ascale;
psums.y *= wscale.y * ascale;
*reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
}
};
}
}
}
#endif
}
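// Launcher for the per-group (s1 + s2) GEMM: the CTA/warp tile configuration is chosen from
// the number of output rows (and from K for the mid-size case), and KERNEL_LAUNCH_CODE
// expands to the shared-memory sizing, grid swizzle and kernel launch with group size G = 128.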
void qserveGemmPerGroupLaunch(ParamsPerGroup const& params, cudaStream_t stream)
{
auto in_feats = params.A;
auto kernel = params.B;
auto zeros = params.s2_zeros;
auto scales_i8 = params.s2_scales;
auto wscales = reinterpret_cast<half2 const*>(params.s1_scales);
auto ascales = params.act_scales;
auto out_feats = params.C;
int num_out_feats = params.m;
int num_out_channels = params.n;
int num_in_feats = params.m;
int num_in_channels = params.k;
constexpr int G = 128;
if (num_out_feats > 128)
{
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats >= 128)
{
if (num_in_channels <= 4096)
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}
else
{
constexpr int CTA_M = 32;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}
void QServeGemmRunner::gemmPerGroup(ParamsPerGroup const& params, cudaStream_t stream)
{
qserveGemmPerGroupLaunch(params, stream);
}
size_t QServeGemmRunner::getWorkspaceSize(int const m, int const n, int const k)
{
// We do not use workspace for now.
return 0;
}
} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -29,13 +29,13 @@ namespace tensorrt_llm
namespace kernels
{
__global__ void quantizedKernel(char4* dst, float4 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, float4 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
float const scale = __ldg(scalePtr);
char4 tmp;
const float4 floatTmp = __ldg(src + idx);
float4 const floatTmp = __ldg(src + idx);
tmp.x = cuda_cast<int8_t>(floatTmp.x * scale);
tmp.y = cuda_cast<int8_t>(floatTmp.y * scale);
tmp.z = cuda_cast<int8_t>(floatTmp.z * scale);
@ -44,7 +44,7 @@ __global__ void quantizedKernel(char4* dst, float4 const* src, const int64_t siz
}
}
__global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, half2 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
@ -52,10 +52,10 @@ __global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t size
char4 tmp;
int srcId = idx << 1;
const uint2 h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
const half2 half2Tmp = reinterpret_cast<half2 const&>(h2.x);
const half2 half2Tmp2 = reinterpret_cast<half2 const&>(h2.y);
half2 const half2Tmp = reinterpret_cast<half2 const&>(h2.x);
half2 const half2Tmp2 = reinterpret_cast<half2 const&>(h2.y);
tmp.x = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.x) * scale);
tmp.y = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.y) * scale);
@ -66,7 +66,7 @@ __global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t size
}
#ifdef ENABLE_BF16
__global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
@ -74,10 +74,10 @@ __global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int
char4 tmp;
int srcId = idx << 1;
const uint2 h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
const __nv_bfloat162 bfloat162Tmp = reinterpret_cast<__nv_bfloat162 const&>(h2.x);
const __nv_bfloat162 bfloat162Tmp2 = reinterpret_cast<__nv_bfloat162 const&>(h2.y);
__nv_bfloat162 const bfloat162Tmp = reinterpret_cast<__nv_bfloat162 const&>(h2.x);
__nv_bfloat162 const bfloat162Tmp2 = reinterpret_cast<__nv_bfloat162 const&>(h2.y);
tmp.x = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.x) * scale);
tmp.y = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.y) * scale);
@ -91,7 +91,7 @@ __global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int
template <typename T>
void invokeQuantization(
int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize)
int8_t* dst, T const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize)
{
TLLM_CHECK_WITH_INFO(size % 4 == 0, "[ERROR][invokeQuantization] size should be a multiple of 4.\n");
@ -116,13 +116,13 @@ void invokeQuantization(
}
template void invokeQuantization<float>(
int8_t* dst, float const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
int8_t* dst, float const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
template void invokeQuantization<half>(
int8_t* dst, half const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
int8_t* dst, half const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
#ifdef ENABLE_BF16
template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const* src, const int64_t size,
template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const* src, int64_t const size,
float const* scalePtr, cudaStream_t stream, int maxGridSize);
#endif
@ -220,8 +220,8 @@ inline __device__ void quantizeAndStore(
////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T, typename QuantT, bool USE_SMEM>
__global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, bool hasFp8MinScaling)
__global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, bool hasFp8MinScaling)
{
// Smem buffer.
extern __shared__ uint4 smemBuffer[];
@ -244,6 +244,8 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
// The number of elements in the packed uint4 vec.
static constexpr int NUM_ELTS_PER_VEC = sizeof(uint4) / sizeof(T);
static constexpr int NUM_ELTS2_PER_VEC = sizeof(uint4) / sizeof(T2);
// The number of vectors in the column.
int const numColVecs = numCols / NUM_ELTS_PER_VEC;
// The vector pointers for src.
@ -253,10 +255,25 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
// T const* srcRow = src + blockIdx.x * numCols;
T2 localMax2 = cuda_cast<T2, T>(T(1e-6f));
float2 localSum2 = {0.f, 0.f};
for (int i = threadIdx.x; i < numColVecs; i += blockDim.x)
{
uint4 vec = srcVec[i];
clampAndAbsMax(localMax2, vec, clampMin2, clampMax2);
#pragma unroll
for (int i = 0; i < NUM_ELTS2_PER_VEC; ++i)
{
T2& val2 = reinterpret_cast<T2*>(&vec)[i];
val2 = cuda_clamp(val2, clampMin2, clampMax2);
localMax2 = cuda_max(localMax2, cuda_abs(val2));
// TODO: template the version that requires sum to avoid dynamic branching.
if (sumPtr != nullptr)
{
localSum2.x += cuda_cast<float>(val2.x);
localSum2.y += cuda_cast<float>(val2.y);
}
}
// Avoid reloading from global memory.
if constexpr (USE_SMEM)
{
@ -264,13 +281,22 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
}
}
float const rowMax = blockAllReduceMax(cuda_cast<float>(cuda_max<T, T2>(localMax2)));
if (threadIdx.x == 0)
{
scalePtr[blockIdx.x]
= hasFp8MinScaling ? cuda_max(rowMax / MAX_QUANT_VAL, MIN_SCALING_FACTOR) : (rowMax / MAX_QUANT_VAL);
}
if (sumPtr != nullptr)
{
float rowSum[1] = {cuda_sum<float>(localSum2)};
blockReduceSumV2<float, 1>(rowSum);
if (threadIdx.x == 0)
{
sumPtr[blockIdx.x] = rowSum[0];
}
}
float const scaleOrigQuant
= hasFp8MinScaling ? fminf(MAX_QUANT_VAL / rowMax, MIN_SCALING_FACTOR_RCP) : MAX_QUANT_VAL / rowMax;
for (int i = threadIdx.x; i < numColVecs; i += blockDim.x)
@ -283,12 +309,12 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, QuantMode quantMode, cudaStream_t stream)
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, QuantMode quantMode, cudaStream_t stream)
{
// each block is responsible for a single row
const dim3 block(512);
const dim3 grid(numRows);
dim3 const block(512);
dim3 const grid(numRows);
// The number of elements in the packed uint4 vec.
static constexpr int NUM_ELTS_PER_VEC = sizeof(uint4) / sizeof(T);
@ -311,19 +337,19 @@ void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows
// Do we use smem ?
if (useSmem)
{
perTokenQuantization<T, QuantT, true>
<<<grid, block, dynamicSmemSz, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, hasFp8MinScaling);
perTokenQuantization<T, QuantT, true><<<grid, block, dynamicSmemSz, stream>>>(
dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
}
else
{
perTokenQuantization<T, QuantT, false>
<<<grid, block, 0, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, hasFp8MinScaling);
<<<grid, block, 0, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
}
}
#define INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(T, QuantT) \
template void invokePerTokenQuantization(QuantT* dst, const T* src, const int64_t numRows, const int64_t numCols, \
float const* clampPtr, float* scalePtr, QuantMode quantMode, cudaStream_t stream)
float const* clampPtr, float* scalePtr, float* sumPtr, QuantMode quantMode, cudaStream_t stream)
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(float, int8_t);
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(half, int8_t);

View File

@ -26,11 +26,12 @@ namespace kernels
template <typename T>
void invokeQuantization(
int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0);
int8_t* dst, T const* src, int64_t const size, float const* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0);
template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream = 0);
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, tensorrt_llm::common::QuantMode quantMode,
cudaStream_t stream = 0);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -53,7 +53,7 @@ __inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, T const* gamm
template <typename T, typename QuantT, bool USE_SHMEM>
__global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps,
int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)
{
constexpr auto num_elems_T = num_elems<T>::value;
// using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
@ -86,13 +86,13 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
int const n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = input[bidx * n_elems + i];
T const val = input[bidx * n_elems + i];
if (USE_SHMEM)
{
shmem[i] = val;
}
const float_packed_t val_f = cuda_cast<float_packed_t>(val);
float_packed_t const val_f = cuda_cast<float_packed_t>(val);
local_var_sum += cuda_sum<float>(val_f * val_f);
}
@ -110,14 +110,17 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
bool const with_per_token_scaling = scale_orig_quant_per_token != nullptr;
bool const with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
bool const with_per_token_sum = sum_per_token != nullptr;
float_packed_t const scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f);
T_scalar amax = 1e-6f;
float local_sum = 0.f;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
int const index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(USE_SHMEM ? shmem[i] : input[index]);
float_packed_t const val_f = cuda_cast<float_packed_t>(USE_SHMEM ? shmem[i] : input[index]);
T val = cuda_cast<T>(compute_rmsnorm(val_f, s_variance, gamma, beta, i));
if (with_per_token_scaling)
@ -139,6 +142,11 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
{
normed_output[index] = val;
}
if (with_per_token_sum)
{
local_sum += cuda_sum<float>(cuda_cast<float_packed_t>(val));
}
}
if (with_per_token_scaling)
@ -165,13 +173,23 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
: abs_max_f / MAX_QUANT_VAL;
}
}
if (with_per_token_sum)
{
float packed_sum[1] = {local_sum};
blockReduceSumV2<float, 1>(packed_sum);
if (tidx == 0)
{
sum_per_token[bidx] = packed_sum[0];
}
}
}
template <typename T, typename QuantT>
void dispatch_rmsnorm_type_square_method(T const* input, T const* gamma, T const* beta, T* normed_output,
float const eps, int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling, const dim3 grid,
const dim3 block, const size_t shmem_size, cudaStream_t stream)
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling,
dim3 const grid, dim3 const block, size_t const shmem_size, cudaStream_t stream)
{
// Do we use shared memory to cache intermediate results.
bool use_shmem = true;
@ -186,32 +204,32 @@ void dispatch_rmsnorm_type_square_method(T const* input, T const* gamma, T const
if (use_shmem)
{
generalRmsNorm<T, QuantT, true><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output, eps,
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant,
hasFp8MinScaling);
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token,
normed_output_quant, hasFp8MinScaling);
}
else
{
generalRmsNorm<T, QuantT, false><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output, eps,
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant,
hasFp8MinScaling);
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token,
normed_output_quant, hasFp8MinScaling);
}
}
template <typename T, typename QuantT>
void dispatch_rmsnorm_type(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, int tokens,
int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token,
QuantT* normed_output_quant, bool const hasFp8MinScaling, const dim3 grid, const dim3 block,
const size_t shmem_size, cudaStream_t stream)
float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling, dim3 const grid, dim3 const block,
size_t const shmem_size, cudaStream_t stream)
{
dispatch_rmsnorm_type_square_method(input, gamma, beta, normed_output, eps, tokens, hidden_dim, clampPtr,
scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, hasFp8MinScaling, grid, block,
shmem_size, stream);
scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token, normed_output_quant, hasFp8MinScaling,
grid, block, shmem_size, stream);
}
template <typename T, typename QuantT>
void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens,
int const hidden_dim, QuantMode quantMode, cudaStream_t stream, float const* clampPtr, float const* scale,
float* dynamic_scale, QuantT* normed_output_quant)
float* dynamic_scale, float* sum_per_token, QuantT* normed_output_quant)
{
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
@ -219,7 +237,7 @@ void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta,
block.x = 32 * ((block.x + 31) / 32);
constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim * sizeof(T);
size_t const shmem_size = hidden_dim * sizeof(T);
bool const use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
@ -235,19 +253,19 @@ void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta,
using Tp = typename packed_as<T, vec_size>::type;
dispatch_rmsnorm_type(reinterpret_cast<Tp const*>(input), reinterpret_cast<Tp const*>(gamma),
reinterpret_cast<Tp const*>(beta), reinterpret_cast<Tp*>(out), eps, tokens, hidden_dim, clampPtr, scale,
dynamic_scale, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
dynamic_scale, sum_per_token, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
}
else
{
dispatch_rmsnorm_type(input, gamma, beta, out, eps, tokens, hidden_dim, clampPtr, scale, dynamic_scale,
normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
sum_per_token, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
}
}
#define INSTANTIATE_GENERAL_RMSNORM(T, QuantT) \
template void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, \
const int tokens, const int hidden_dim, QuantMode quantMode, cudaStream_t stream, float const* clampPtr, \
const float* scale, float* dynamic_scale, QuantT* normed_output_quant);
const float* scale, float* dynamic_scale, float* sum_per_token, QuantT* normed_output_quant);
INSTANTIATE_GENERAL_RMSNORM(float, int8_t);
INSTANTIATE_GENERAL_RMSNORM(half, int8_t);

View File

@ -31,7 +31,7 @@ template <typename T, typename QuantT>
void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens,
int const hidden_dim, tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream = 0,
float const* clampPtr = nullptr, float const* scale = nullptr, float* dynamic_scale = nullptr,
QuantT* out_quant = nullptr);
float* sum_per_token = nullptr, QuantT* out_quant = nullptr);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -231,6 +231,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto outputId = idx != -1
? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize
: vocabSize - 1;
outputId = outputId == -1 ? vocabSize - 1 : outputId;
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;

View File

@ -153,8 +153,7 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
if (isValid)
{
// Sequence length of the base model (without draft len for gen requests and all prompt tokens for ctx request)
auto const oldSequenceLength
= baseNetSequenceLengths[bid] - (numInputTokens - static_cast<SizeType32>(!isContextRequest));
auto const oldSequenceLength = baseNetSequenceLengths[bid] - numInputTokens;
for (SizeType32 ti = 0; ti < numDecodingTokens; ++ti)
{
TokenIdType token;
@ -249,14 +248,6 @@ __global__ void buildLeafMask(
{
auto const bid = static_cast<SizeType32>(blockIdx.x);
auto const level = static_cast<SizeType32>(blockIdx.y);
// Prefill mask setting all to leaves.
for (auto tid = static_cast<SizeType32>(threadIdx.x); tid < maxDecodingTokens;
tid += static_cast<SizeType32>(blockDim.x))
{
isLeafMask[bid * maxDecodingTokens + tid] = 1;
}
__syncthreads();
// For all paths
for (auto pathIdx = static_cast<SizeType32>(threadIdx.x); pathIdx < maxDecodingTokens;
@ -268,8 +259,8 @@ __global__ void buildLeafMask(
auto const tokenCurLevelOffset = flat_index3(bid, pathIdx, level, maxDecodingTokens, maxPathLen);
// Get token idx in the flattened draft tokens for the of the current level
auto const curNodeTokenIdx = paths[tokenCurLevelOffset];
// If token idx is not -1 (not terminated path)
// And the next level token is not -1 -- path is not terminating and token at current level has child.
// If token idx is not -1 (not terminated path) And the next
// level token is not -1 -- path is not terminating and token at current level has child.
if (curNodeTokenIdx != -1 && paths[tokenNextLevelOffset] != -1)
{
// Mark mask to 0.
@ -303,11 +294,6 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
// Init selected paths for CAS.
selectedPathsSmem[ii] = -1;
}
// Fill mask.
for (auto ii = static_cast<SizeType32>(threadIdx.x); ii < maxDecodingTokens;
ii += static_cast<SizeType32>(blockDim.x))
{
}
__syncthreads();
@ -320,7 +306,7 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
auto const tokenIdxLevel = paths[tokenCurLevelOffset];
// Check if this path is not terminated yet.
// And check if this node is not leaf.
if (tokenIdxLevel >= 0 && !isLeafMask[tokenIdxLevel])
if (tokenIdxLevel >= 0 && !isLeafMask[bid * maxDecodingTokens + tokenIdxLevel])
{
// Set this path as selected for this token.
atomicCAS(&selectedPathsSmem[tokenIdxLevel], -1, pi);
@ -372,7 +358,7 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
}
// If node is not leaf
if (!isLeafMask[ti])
if (!isLeafMask[bid * maxDecodingTokens + ti])
{
// Save its in-level index for output hidden state indices calculation.
nonLeavesInLevelOffsets[bid * maxDecodingTokens + ti] = nonLeavesInLevelCounter;
@ -506,15 +492,13 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
if (bid == 0)
{
// Save max draft length for the mask packing kernel.
maxGenerationLength[0] = maxDecodingTokens;
maxGenerationLength[0] = maxGenLength;
}
if (isValid)
{
// Fill spec decoding gen length.
// We do -1 here as attn expects golden token + draft tokens.
// We do not provide golden token, but just reduce the number of draft tokens.
specDecodingGenLengths[bid] = nextDraftLen - 1;
specDecodingGenLengths[bid] = nextDraftLen;
// Simply copy context len.
nextContextLengths[bid] = prevContextLengths[bid];
auto const sequenceLen = eagleNet0SequenceLengths[bid];
@ -524,6 +508,9 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// Fill cumulative sum for the mask packing kernel.
cumSumGenerationLengths[bid] = genLengthCumSum;
// Pos id is Ctx EagleNet seqLen (prompt + all accepted).
positionIds[bid] = sequenceLen;
SizeType32 lastTokenIdx{0};
for (SizeType32 ti = 0; ti < nextDraftLen; ++ti)
{
@ -535,8 +522,6 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// Get draft pos offset.
auto const posOffset = selectedDraftPosIds[bid * maxDecodingDraftTokens + ti];
specDecodingPositionOffsets[bid * maxDecodingTokens + ti] = posOffset;
// Pos id is Ctx EagleNet seqLen (prompt + all accepted) + pos offset
positionIds[outputIndexBase + ti] = sequenceLen + posOffset;
// hiddenStatesIndex is constructed having hidden states layout in mind.
// Hidden states are placed in memory as [maxPathLen - 1, batchSize, numOutputTokens] (this tensor is
@ -589,6 +574,11 @@ inline __device__ __host__ T divUp(T m, T n)
return (m + n - 1) / n;
}
__device__ SizeType32 positivePowerOfTwo(SizeType32 n)
{
return 1 << n;
}
//! @brief Takes mask of size [maxGenerationLength] filled with 1s and 0s defined in the shared memory
//! and packs it to bitmask of size [numPackedMasks] written to outputPtr.
//! numPackedMasks = ceil(maxGenerationLength / 32);
@ -609,20 +599,33 @@ __device__ __forceinline__ void maskToPackedMask(
auto const shMaskIndexEnd = maxGenerationLength - maskId * 32;
auto const validNumBits = shMaskIndexEnd - shMaskIndexStart;
SizeType32 packedMask = 0;
for (SizeType32 ii = 0; ii < validNumBits; ++ii)
auto const firstBit1 = (shMask[shMaskIndexStart] == '1') ? true : false;
SizeType32 mask31bits = 0;
if (validNumBits != 1)
{
auto const index = (validNumBits - 1) - (ii - shMaskIndexStart - 1) - 1;
packedMask += (shMask[shMaskIndexStart + ii] == '1') ? 1 << index : 0;
for (auto i = shMaskIndexStart + 1; i < shMaskIndexEnd; i++)
{
auto const index = (validNumBits - 1) - (i - shMaskIndexStart);
mask31bits += (shMask[i] == '1') ? positivePowerOfTwo(index) : 0;
}
}
outputPtr[maskId] = packedMask;
SizeType32 mask32bits;
if (validNumBits == 32)
{
mask32bits = firstBit1 ? mask31bits - positivePowerOfTwo(validNumBits - 1) : mask31bits;
}
else
{
mask32bits = firstBit1 ? mask31bits + positivePowerOfTwo(validNumBits - 1) : mask31bits;
}
outputPtr[maskId] = mask32bits;
}
}
}
__global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask, SizeType32 maxDraftTokens,
SizeType32* __restrict__ packedMask)
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 maxDecodingDraftTokens, SizeType32* __restrict__ packedMask)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.y);
auto const tokenIdx = static_cast<SizeType32>(blockIdx.x);
@ -635,7 +638,8 @@ __global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLength
}
auto const maxGenerationLength = maxGenerationLengths[0];
auto const numPackedMasks = divUp(maxDraftTokens + 1, 32);
auto const maxDecodingTokens = maxDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxDecodingTokens, 32);
auto const outputStartId = ((batchIdx == 0) ? 0 : cumGenerationLengths[batchIdx - 1]);
auto* outputPtr = packedMask + (outputStartId + tokenIdx) * numPackedMasks;
@ -650,8 +654,7 @@ __global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLength
}
else
{
bool const* maskPtr
= mask + batchIdx * maxGenerationLength * maxGenerationLength + tokenIdx * maxGenerationLength;
bool const* maskPtr = mask + batchIdx * maxDecodingTokens * maxDecodingTokens + tokenIdx * maxDecodingTokens;
extern __shared__ char shMask[];
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
@ -1071,7 +1074,7 @@ __global__ void packEagleGenerationLengths(PackEagleParams params)
if (threadIdx.x == 0 && isGenerationRequest)
{
params.outputSpecDecodingGenerationLengths[genIdx] = params.inputNextDraftLens[batchSlot];
params.outputSpecDecodingGenerationLengths[genIdx] = params.inputSpecDecodingGenerationLengths[batchSlot];
}
}
} // namespace
@ -1101,7 +1104,8 @@ __global__ void packEagleTensors(PackEagleParams params)
params.outputRandomDataValidation[batchIdx] = params.inputRandomDataValidation[batchSlot];
// 0 for ctx request and actual draft len for gen requests.
params.outputNextDraftLens[batchIdx] = isGenerationRequest ? params.inputNextDraftLens[batchSlot] : 0;
params.outputNextDraftLens[batchIdx]
= isGenerationRequest ? params.inputSpecDecodingGenerationLengths[batchSlot] - 1 : 0;
}
// Copy draft paths
@ -1186,6 +1190,8 @@ __global__ void unpackEagleData(UnpackEagleDataParams params)
// Copy next draft tokens to the slots.
params.outputNextDraftTokens[batchSlot * maxDecodingDraftTokens + ti]
= params.inputNextDraftTokens[bid * maxDecodingDraftTokens + ti];
params.outputUnpackedNextDraftTokens[batchSlot * maxDecodingDraftTokens + ti]
= params.inputNextDraftTokens[bid * maxDecodingDraftTokens + ti];
}
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < params.maxDecodingTokens * params.maxPathLength;
@ -1204,6 +1210,8 @@ __global__ void unpackEagleData(UnpackEagleDataParams params)
params.outputNumNewTokens[batchSlot] = acceptedLength;
// One thread copies next draft len to slot
params.outputNextDraftLengths[batchSlot] = params.inputNextDraftLens[bid];
// Set next gen len to slot.
params.outputNextGenerationLength[batchSlot] = params.inputNextDraftLens[bid] + 1;
// Set prev draft lengths needed for kv cache rewind in variable draft len.
params.outputPrevDraftLengths[batchSlot] = params.inputLastDraftLens[bid];
// Set random data for draft sampling kernels.

View File

@ -121,7 +121,7 @@ struct PrepareGenEagleNetInputsParams
//! output buffer [numOutputTokens]
//! Selected tokens ids.
runtime::TokenIdType* outputIds{nullptr};
//! output buffer [numOutputTokens]
//! output buffer [batchSize]
//! Position ids of the selected tokens.
runtime::SizeType32* positionIds{nullptr};
//! output buffer [batchSize]
@ -270,8 +270,6 @@ struct PackEagleParams
float const* inputRandomDataValidation{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType const* inputNextDraftTokens{nullptr};
//! [maxBatchSize]
runtime::SizeType32 const* inputNextDraftLens{nullptr};
//! [maxBatchSize, maxDecodingTokens, maxPathLen]
runtime::SizeType32 const* inputNextDraftPaths{nullptr};
//! [maxBatchSize]
@ -315,7 +313,6 @@ struct PackEagleParams
TLLM_CHECK(inputRandomDataSample);
TLLM_CHECK(inputRandomDataValidation);
TLLM_CHECK(inputNextDraftTokens);
TLLM_CHECK(inputNextDraftLens);
TLLM_CHECK(inputNextDraftPaths);
TLLM_CHECK(inputSpecDecodingGenerationLengths);
TLLM_CHECK(inputSpecDecodingPositionOffsets);
@ -327,9 +324,9 @@ struct PackEagleParams
TLLM_CHECK(outputNextDraftTokens);
TLLM_CHECK(outputNextDraftLens);
TLLM_CHECK(outputNextDraftPaths);
TLLM_CHECK(outputSpecDecodingGenerationLengths);
TLLM_CHECK(outputSpecDecodingPositionOffsets);
TLLM_CHECK(outputSpecDecodingPackedMasks);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingGenerationLengths) || numGenerationRequests == 0);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingPositionOffsets) || numGenerationRequests == 0);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingPackedMasks) || numGenerationRequests == 0);
TLLM_CHECK(maxGenerationLength);
TLLM_CHECK(cumSumGenerationLengths);
@ -377,6 +374,8 @@ struct UnpackEagleDataParams
//! [maxBatchSize]
runtime::SizeType32* outputSequenceLengths{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType* outputUnpackedNextDraftTokens{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType* outputNextDraftTokens{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputNextDraftLengths{nullptr};
@ -384,6 +383,8 @@ struct UnpackEagleDataParams
runtime::SizeType32* outputNextDraftPaths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputPrevDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputNextGenerationLength{nullptr};
//! [maxBatchSize, maxDecodingTokens]
runtime::SizeType32* outputPositionIds{nullptr};
@ -428,10 +429,12 @@ struct UnpackEagleDataParams
TLLM_CHECK(outputIds);
TLLM_CHECK(outputNumNewTokens);
TLLM_CHECK(outputSequenceLengths);
TLLM_CHECK(outputUnpackedNextDraftTokens);
TLLM_CHECK(outputNextDraftTokens);
TLLM_CHECK(outputNextDraftLengths);
TLLM_CHECK(outputNextDraftPaths);
TLLM_CHECK(outputPrevDraftLengths);
TLLM_CHECK(outputNextGenerationLength);
TLLM_CHECK(outputPositionIds);
TLLM_CHECK(outputRandDataSample);

File diff suppressed because it is too large

View File

@ -0,0 +1,38 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"
namespace tensorrt_llm
{
namespace kernels
{
template <typename T>
size_t invokeComputeTopkLastDimWorkspaceSize(
runtime::SizeType32 batchSize, runtime::SizeType32 inputLength, runtime::SizeType32 k, bool is_largest);
template <typename T>
void invokeTopkLastDim(runtime::SizeType32 batchSize, runtime::SizeType32 inputLength, runtime::SizeType32 k,
bool is_largest, void const* __restrict__ input, void* __restrict__ out_val, void* __restrict__ out_ind,
void* workspace, cudaStream_t stream);
} // namespace kernels
} // namespace tensorrt_llm

View File

@ -681,27 +681,33 @@ public:
{
}
//! Unpacked draft tokens
TensorPtr unpackedNextDraftTokens; // [maxBatchSize, maxDecodingDraftTokens] on gpu
//! Draft paths for the next iteration.
TensorPtr nextDraftPaths; // [maxBatchSize, maxDecodingTokens, maxPathLen]
TensorPtr nextDraftPaths; // [maxBatchSize, maxDecodingTokens, maxPathLen] on gpu
//! Randomly sampled data (between 0.f and 1.f)
TensorPtr randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
TensorPtr randomDataValidation; // [maxBatchSize] on gpu
//! Sampling temperature.
TensorPtr temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
TensorPtr generationLengths; // [maxBatchSize] on gpu
//! Next generation lengths.
TensorPtr generationLengthsHost; // [maxBatchSize] on pinned
//! Request types for ctx stage of the EagleNet0 (filled with 0s).
TensorPtr eagleNetCtxRequestTypesHost; //! [maxBatchSize]
TensorPtr eagleNetCtxRequestTypesHost; //! [maxBatchSize] on pinned
//! Context lengths of the context EagleNet0.
TensorPtr eagleNetCtxContextLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetCtxContextLengthsHost; //! [maxBatchSize] on pinned
//! Past kv lengths of the context EagleNet0.
TensorPtr eagleNetCtxPastKeyValueLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetCtxPastKeyValueLengthsHost; //! [maxBatchSize] on pinned
//! Request types for ctx stage of the EagleNetX (filled with 1s).
TensorPtr eagleNetGenRequestTypesHost; //! [maxBatchSize]
TensorPtr eagleNetGenRequestTypesHost; //! [maxBatchSize] on pinned
//! Context lengths of the generation EagleNetX.
TensorPtr eagleNetGenContextLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetGenContextLengthsHost; //! [maxBatchSize] on pinned
//! Past kv lengths of the generation EagleNetX.
TensorPtr eagleNetGenPastKeyValueLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetGenPastKeyValueLengthsHost; //! [maxBatchSize] on pinned
};
} // namespace tensorrt_llm::layers
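The host-side EagleNet tensors above are now annotated as pinned rather than plain host memory, presumably because the copies into them are issued asynchronously on a CUDA stream. A brief CUDA sketch of why page-locked memory matters for that pattern (function and variable names are illustrative):

#include <cuda_runtime.h>

// With pageable host memory, cudaMemcpyAsync falls back to a staged,
// effectively synchronous copy; with pinned (page-locked) memory the copy can
// overlap with other work on the stream.
void copyLengthsToHost(int const* deviceLengths, int* pinnedHostLengths,
    int maxBatchSize, cudaStream_t stream)
{
    // pinnedHostLengths is assumed to come from cudaMallocHost (or an
    // equivalent pinned allocator).
    cudaMemcpyAsync(pinnedHostLengths, deviceLengths, sizeof(int) * maxBatchSize,
        cudaMemcpyDeviceToHost, stream);
}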

View File

@ -160,8 +160,9 @@ void DynamicDecodeLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> co
auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);
TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logitsVec,
"If not explicit Draft Tokens mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
mDecodingMode.isExplicitDraftTokens() || mDecodingMode.isEagle() || params->logits || params->logitsVec,
"If not Explicit Draft Tokens or Eagle mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
baseOutputs->sequenceLength.has_value(), "sequenceLength tensor is required in DynamicDecoderLayer.");

View File

@ -155,8 +155,12 @@ void EagleDecodingLayer<T>::unpackData(EagleOutputs const& outputs, EagleInputs
params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
params.outputNumNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
params.outputSequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
// FIXME outputUnpackedNextDraftTokens is the same as outputNextDraftTokens.
// outputUnpackedNextDraftTokens is used in eagleBuffers and outputNextDraftTokens is used in the runtime
params.outputUnpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
params.outputNextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
params.outputNextGenerationLength = bufferCast<SizeType32>(*outputs.generationLengths);
params.outputNextDraftPaths = bufferCast<SizeType32>(*outputs.nextDraftPaths);
params.outputPrevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
@ -189,6 +193,8 @@ void EagleDecodingLayer<T>::unpackData(EagleOutputs const& outputs, EagleInputs
mBufferManager->copy(*mEagleNetGenContextLengths, *outputs.eagleNetGenContextLengthsHost);
mBufferManager->copy(*mEagleNetGenPastKeyValueLengths, *outputs.eagleNetGenPastKeyValueLengthsHost);
mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

View File

@ -148,10 +148,10 @@ void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
{
auto initWorkspaceSizes = getTopKInitWorkspaceSizes(batchSize);
calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(topKsPtr, topPsPtr);
DecodingLayerWorkspace::copyToWorkspace(*mBufferManager, runtimeTopK,
ITensor::wrap(topKsPtr, {1, {static_cast<ITensor::DimType64>(initWorkspaceSizes[0])}}));
DecodingLayerWorkspace::copyToWorkspace(*mBufferManager, runtimeTopP,
ITensor::wrap(topPsPtr, {1, {static_cast<ITensor::DimType64>(initWorkspaceSizes[1])}}));
DecodingLayerWorkspace::copyToWorkspace(
*mBufferManager, runtimeTopK, IBuffer::wrap(topKsPtr, initWorkspaceSizes[0] / sizeof(*topKsPtr)));
DecodingLayerWorkspace::copyToWorkspace(
*mBufferManager, runtimeTopP, IBuffer::wrap(topPsPtr, initWorkspaceSizes[1] / sizeof(*topPsPtr)));
}
auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
auto* skipTopKDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);
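The replacement above wraps the raw workspace pointers as flat buffers and converts the sizes returned by getTopKInitWorkspaceSizes from bytes to element counts, which reads as a fix for the earlier shape that used byte counts directly. A tiny standalone sketch of that byte-to-element conversion, assuming the workspace sizes are byte sizes and a 4-byte element type (both assumptions for illustration):

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main()
{
    // e.g. one entry of initWorkspaceSizes, in bytes (value is illustrative).
    std::size_t const topKWorkspaceBytes = 256;

    // Element count to expose when wrapping the workspace slice as a typed
    // buffer, i.e. the second argument the new code passes to IBuffer::wrap.
    std::size_t const numTopKElems = topKWorkspaceBytes / sizeof(std::uint32_t);

    std::printf("top-k runtime values in workspace: %zu\n", numTopKElems);
    return 0;
}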

View File

@ -60,7 +60,7 @@ static std::vector<std::unique_ptr<BaseLayer>> createLayers(executor::DecodingMo
std::vector<std::unique_ptr<BaseLayer>> layers;
auto layerTypes = createDecodingLayerTypes(mode);
// Only when draft tokens are predicted and decoded by the engine can we skip the penalty layer.
if (!mode.isExplicitDraftTokens())
if (!mode.isExplicitDraftTokens() && !mode.isEagle())
{
TLLM_CHECK_WITH_INFO(layerTypes.size() && layerTypes[0] == DecodingLayers_t::PENALTY_LAYER,
"Penalty layer is required to be the first layer for any decoder configuration");

View File

@ -61,7 +61,7 @@ LookaheadAlgorithm::LookaheadAlgorithm(
mEncodeMapMax = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
}
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g)
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g, uint64_t seed)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(w <= mMaxW, "lookahead requires setup w (%d) <= max_w (%d)", w, mMaxW);
@ -85,6 +85,9 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
BufferRange<TokenIdType> prefillRange(*mPrefills);
BufferRange<TokenIdType> pastRange(*mPastTokens);
BufferRange<TokenIdType> goldRange(*mGoldenTokens);
srand(seed);
auto randToken = [&promptRange](auto& item) { item = promptRange[rand() % promptRange.size()]; };
std::for_each(prefillRange.begin(), prefillRange.end(), randToken);
std::for_each(pastRange.begin(), pastRange.end(), [](auto& a) { a = -1; });
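The setup above now seeds the C runtime RNG and fills the prefill and past-token buffers with tokens drawn uniformly from the prompt, making the lookahead warm-up reproducible for a given seed. A self-contained sketch of that seeded-prefill idea, with a plain std::vector standing in for the tensor ranges:

#include <algorithm>
#include <cstdint>
#include <cstdlib>
#include <vector>

// Fill a prefill buffer with tokens sampled uniformly from the prompt,
// reproducibly for a given seed. Assumes a non-empty prompt.
std::vector<int32_t> randomPrefill(
    std::vector<int32_t> const& prompt, std::size_t prefillLen, uint64_t seed)
{
    std::srand(static_cast<unsigned>(seed));
    std::vector<int32_t> prefill(prefillLen);
    std::for_each(prefill.begin(), prefill.end(),
        [&prompt](int32_t& tok) { tok = prompt[std::rand() % prompt.size()]; });
    return prefill;
}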

View File

@ -39,7 +39,8 @@ public:
runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0);
//! @brief setup per request, fill internal states from @param prompt.
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g);
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g,
uint64_t seed);
//! @brief accept the new generated tokens.
//! LookaheadDecodingLayer need call once for the first token in generation phase.

View File

@ -163,7 +163,13 @@ void LookaheadDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth
"runtime w(%d) n(%d) g(%d) exceeds maxTokensPerStep(%d)", w, n, g,
mDecoderDomain.getMaxDecodingTokens());
PRINT_VALUES(mCpuAlgo->mPrompts[bi]);
mCpuAlgo->mAlgos[gbi].setup(mCpuAlgo->mPrompts[bi], w, n, g);
auto seed = DefaultDecodingParams::getSeed();
if (setupParams->randomSeed)
{
auto& seeds = setupParams->randomSeed.value();
seed = seeds.size() == 1 ? seeds[0] : seeds[bi];
}
mCpuAlgo->mAlgos[gbi].setup(mCpuAlgo->mPrompts[bi], w, n, g, seed);
}
for (runtime::SizeType32 bi = 0; bi < batchSize; bi++)
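The seed selection inserted above follows a broadcast-or-per-request rule: a single configured seed is shared by every slot, a full vector is indexed per batch slot, and the decoding default applies when nothing was configured. A small illustrative helper mirroring that rule (names are assumptions, not the library API):

#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// One seed is broadcast to all requests, a full vector is indexed per
// request, and a fallback default is used when nothing was configured.
uint64_t pickSeed(std::optional<std::vector<uint64_t>> const& randomSeed,
    std::size_t batchIdx, uint64_t defaultSeed)
{
    if (!randomSeed)
    {
        return defaultSeed;
    }
    auto const& seeds = *randomSeed;
    return seeds.size() == 1 ? seeds.front() : seeds[batchIdx];
}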

Some files were not shown because too many files have changed in this diff