Mirror of https://github.com/NVIDIA/TensorRT-LLM.git (synced 2026-01-13 22:18:36 +08:00)

Update TensorRT-LLM (#2436)

This commit is contained in:
parent b7868dd1bd
commit c629546ce4
@ -6,8 +6,8 @@ TensorRT-LLM

[](https://nvidia.github.io/TensorRT-LLM/)
[](https://www.python.org/downloads/release/python-31012/)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](https://developer.nvidia.com/cuda-downloads)
[](https://developer.nvidia.com/tensorrt)
[](./tensorrt_llm/version.py)
[](./LICENSE)
@ -182,211 +182,8 @@ struct BenchmarkParams
    std::optional<texec::LookaheadDecodingConfig> executorLookaheadConfig;
    std::optional<texec::LookaheadDecodingConfig> requestLookaheadConfig;
};

class InferenceRequestsAsyncSend
{
public:
    InferenceRequestsAsyncSend(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
        std::list<std::shared_ptr<InferenceRequest>> const& inferenceRequests, int const peer)
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        TLLM_LOG_DEBUG("start send requests to rank %d", peer);
        mNumNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
        mRequest1 = comm->sendAsync(&mNumNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
        if (mNumNewWorkItems > 0)
        {
            for (auto const& infReq : inferenceRequests)
            {
                auto vpacked = infReq->serialize();
                mPacked.push_back(static_cast<int64_t>(vpacked.size()));
                mPacked.insert(mPacked.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
            }
            mVecSize = static_cast<int64_t>(mPacked.size());
            mRequest2 = comm->sendAsync(&mVecSize, 1, mpi::MpiType::kINT64, peer, 1);
            mRequest3 = comm->sendAsync(mPacked.data(), mPacked.size(), mpi::MpiType::kINT64, peer, 2);
        }
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

    ~InferenceRequestsAsyncSend()
    {
        TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
        mRequest1->wait();
        if (mRequest2)
            mRequest2->wait();
        if (mRequest3)
            mRequest3->wait();
        TLLM_LOG_DEBUG("end send requests");
        TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
    }

private:
    int64_t mNumNewWorkItems;
    int64_t mVecSize;
    std::vector<int64_t> mPacked;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest1;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest2;
    std::shared_ptr<tensorrt_llm::mpi::MpiRequest> mRequest3;
};
} // namespace

void inferenceRequestsRecv(std::shared_ptr<tensorrt_llm::mpi::MpiComm> comm,
    std::list<std::shared_ptr<InferenceRequest>>& inferenceRequests, int const peer)
{
    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
    TLLM_LOG_DEBUG("start recv requests from rank %d", peer);
    int64_t numNewWorkItems = 0;
    comm->recv(&numNewWorkItems, 1, mpi::MpiType::kINT64, peer, 0);
    if (numNewWorkItems > 0)
    {
        std::vector<int64_t> packed;
        int64_t vecSize;
        comm->recv(&vecSize, 1, mpi::MpiType::kINT64, peer, 1);
        packed.resize(vecSize);
        comm->recv(packed.data(), packed.size(), mpi::MpiType::kINT64, peer, 2);
        int64_t* packed_ptr = packed.data();
        for (int64_t count = 0; count < numNewWorkItems; ++count)
        {
            int64_t n = *(packed_ptr++);
            auto infReq = InferenceRequest::deserialize(packed_ptr);
            packed_ptr += n;
            inferenceRequests.emplace_back(infReq);
        }
    }
    TLLM_LOG_DEBUG("end recv requests from rank %d", peer);
    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

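// Illustrative aside (not part of the original file): InferenceRequestsAsyncSend and inferenceRequestsRecv
// above form a matched pair. Per batch they exchange three ordered messages -- the request count (tag 0),
// the packed payload size (tag 1), and the payload itself (tag 2) -- and each request inside the payload is
// framed as [serialized length n][n int64 words]. The hypothetical helper below models only that framing
// with plain std::vector (it assumes <vector> and <cstdint> are available), so the layout can be reasoned
// about or tested without MPI.
struct PackedBatch
{
    // Pack a list of already-serialized requests into the on-the-wire layout used by the sender above.
    static std::vector<int64_t> pack(std::vector<std::vector<int64_t>> const& serializedRequests)
    {
        std::vector<int64_t> packed;
        for (auto const& vpacked : serializedRequests)
        {
            packed.push_back(static_cast<int64_t>(vpacked.size())); // length prefix
            packed.insert(packed.end(), vpacked.begin(), vpacked.end());
        }
        return packed;
    }

    // Split the packed payload back into per-request buffers (mirrors the receive loop above).
    static std::vector<std::vector<int64_t>> unpack(std::vector<int64_t> const& packed, int64_t numRequests)
    {
        std::vector<std::vector<int64_t>> out;
        int64_t const* ptr = packed.data();
        for (int64_t i = 0; i < numRequests; ++i)
        {
            int64_t const n = *(ptr++);
            out.emplace_back(ptr, ptr + n); // copy the n serialized words of this request
            ptr += n;
        }
        return out;
    }
};
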
// Class holding all information regarding a single work item.
// This includes the original request, the associated response factory,
// and state.
class WorkItem
{
public:
    WorkItem(std::shared_ptr<InferenceRequest> inferenceRequest, uint64_t requestId)
        : mInferenceRequest(std::move(inferenceRequest))
        , mRequestId(requestId)
    {
    }

    [[nodiscard]] uint64_t requestId() const
    {
        return mRequestId;
    }

    [[nodiscard]] std::shared_ptr<InferenceRequest> getInferenceRequest() const
    {
        return mInferenceRequest;
    }

private:
    std::shared_ptr<InferenceRequest> mInferenceRequest;
    uint64_t mRequestId;
};

/// @brief Thread-safe queue of work items
class WorkItemsQueue
{
public:
    void clear()
    {
        std::lock_guard<std::mutex> lock(mMutex);
        mPendingWorkItems.clear();
        mPendingWorkItemsReqIds.clear();
        mInProgressWorkItems.clear();
    }

    // Note: this function should only be called while holding the lock
    bool hasInProgressReqId(uint64_t const reqId) const
    {
        return (mInProgressWorkItems.find(reqId) != mInProgressWorkItems.end());
    }

    // Note: this function should only be called while holding the lock
    bool hasPendingReqId(uint64_t const reqId) const
    {
        return (mPendingWorkItemsReqIds.find(reqId) != mPendingWorkItemsReqIds.end());
    }

    bool empty() const
    {
        return mPendingWorkItems.empty() && mInProgressWorkItems.empty() && mPendingWorkItemsReqIds.empty();
    }

    /// @brief Add a new work item to the queue
    /// Throws an error if requestId already exists
    void push(std::shared_ptr<InferenceRequest> request, uint64_t requestId)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        TLLM_CHECK_WITH_INFO(!hasInProgressReqId(requestId) && !hasPendingReqId(requestId),
            "requestId %lu is already in progress, request is ignored.", requestId);

        auto workItem = std::make_shared<WorkItem>(request, requestId);
        mPendingWorkItems.push_back(workItem);
        mPendingWorkItemsReqIds.insert(workItem->requestId());
    }

    /// @brief Get a new work item from the queue, and move it to the list of
    /// in progress work items if it hasn't been stopped
    /// @return A tuple of the workItem and a boolean flag indicating if the work item
    /// has been marked in progress
    std::tuple<std::shared_ptr<WorkItem>, bool> pop()
    {
        std::lock_guard<std::mutex> lock(mMutex);

        auto workItem = mPendingWorkItems.front();
        mPendingWorkItems.pop_front();
        mPendingWorkItemsReqIds.erase(workItem->requestId());

        bool markedInProgress = false;
        mInProgressWorkItems.emplace(workItem->requestId(), workItem);
        markedInProgress = true;

        return {workItem, markedInProgress};
    }

    size_t numPendingWorkItems() const
    {
        std::lock_guard<std::mutex> lock(mMutex);
        return mPendingWorkItems.size();
    }

    size_t numInProgressWorkItems() const
    {
        std::lock_guard<std::mutex> lock(mMutex);
        return mInProgressWorkItems.size();
    }

    size_t size() const
    {
        std::lock_guard<std::mutex> lock(mMutex);
        return mPendingWorkItems.size() + mInProgressWorkItems.size();
    }

    /// @brief Mark a request as being finished
    /// @param requestId
    void markFinished(uint64_t const requestId)
    {
        std::lock_guard<std::mutex> lock(mMutex);
        if (hasInProgressReqId(requestId))
        {
            mInProgressWorkItems.erase(requestId);
        }
    }

private:
    /// Queue of work items
    std::list<std::shared_ptr<WorkItem>> mPendingWorkItems;
    /// requestIds of work items in the queue
    std::set<uint64_t> mPendingWorkItemsReqIds;

    /// work items currently in progress
    std::unordered_map<uint64_t, std::shared_ptr<WorkItem>> mInProgressWorkItems;

    mutable std::mutex mMutex;
};

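// Minimal usage sketch (illustrative only, not part of the benchmark source): requests enter through
// push(), workers drain them with pop(), and markFinished() retires them once the final response is
// sent. `queue`, `request`, and `requestId` are placeholder names.
//
//     WorkItemsQueue queue;
//     queue.push(request, requestId);                 // producer side (e.g. GptServer::enqueue)
//     auto [workItem, inProgress] = queue.pop();      // consumer side (e.g. getInferenceRequests)
//     if (inProgress)
//     {
//         // ... hand workItem->getInferenceRequest() to the batch manager ...
//         queue.markFinished(workItem->requestId());  // on the final response
//     }
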
struct BenchInfo
{
    BenchInfo() = default;
@ -748,7 +545,7 @@ public:
    std::vector<std::string> headers = {"num_samples", "num_error_samples", "total_latency(ms)",
        "seq_throughput(seq/sec)", "token_throughput(token/sec)", "avg_sequence_latency(ms)",
        "max_sequence_latency(ms)", "min_sequence_latency(ms)", "p99_sequence_latency(ms)",
        "p90_sequence_latency(ms)", "p50_sequence_latency(ms)"};
        "p90_sequence_latency(ms)", "p50_sequence_latency(ms)", "avg_acceptance_rate(tokens/decoding steps)"};
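    // Clarifying note added for readability (not in the original diff): per its header, the new
    // "avg_acceptance_rate" column is reported in tokens per decoding step, e.g. 300 generated tokens
    // over 200 decoding steps is logged as 1.5; values above 1.0 arise when speculative decoding
    // (e.g. Medusa or lookahead) accepts more than one token per step on average.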
|
||||
if (mStreaming)
|
||||
{
|
||||
@ -772,7 +569,8 @@ public:
|
||||
outputFile << "\n";
|
||||
outputFile << mNumSamples << "," << mNumErrorSamples << "," << mTotalLatency << "," << mSeqThroughput
|
||||
<< "," << mTokenThroughput << "," << mAvgSeqLatency << "," << mMaxSeqLatency << ","
|
||||
<< mMinSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << "," << mP50SeqLatency;
|
||||
<< mMinSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << "," << mP50SeqLatency
|
||||
<< "," << mAcceptanceRate;
|
||||
if (mStreaming)
|
||||
{
|
||||
outputFile << "," << mAvgFtLatency << "," << mMaxFtLatency << "," << mMinFtLatency << ","
|
||||
@ -869,11 +667,9 @@ public:
|
||||
std::optional<std::filesystem::path> const& encoderTrtEnginePath, TrtGptModelType modelType,
|
||||
int32_t maxBeamWidth, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
|
||||
BenchmarkParams const& benchmarkParams, std::shared_ptr<Recorder> recorder, std::chrono::milliseconds waitSleep,
|
||||
std::optional<uint64_t> const staticEmulatedBatchSize, bool logIterationData,
|
||||
texec::ModelType executorModelType)
|
||||
bool logIterationData, texec::ModelType executorModelType)
|
||||
: mRecorder(std::move(recorder))
|
||||
, mWaitSleep(waitSleep)
|
||||
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
|
||||
, mConcurrency(benchmarkParams.concurrency)
|
||||
, mActiveCount(0)
|
||||
, mNumFinished(0)
|
||||
@ -1048,7 +844,6 @@ private:
|
||||
std::thread mCollectStatsThread;
|
||||
std::shared_ptr<Recorder> mRecorder;
|
||||
std::chrono::milliseconds mWaitSleep;
|
||||
std::optional<int> mStaticEmulatedBatchSize;
|
||||
std::optional<int> mConcurrency;
|
||||
std::atomic<uint64_t> mActiveCount;
|
||||
std::atomic<uint64_t> mNumFinished;
|
||||
@ -1056,288 +851,6 @@ private:
|
||||
bool mLogIterationData;
|
||||
}; // class ExecutorServer
|
||||
|
||||
class GptServer
|
||||
{
|
||||
public:
|
||||
GptServer(std::filesystem::path const& trtEnginePath, TrtGptModelType modelType,
|
||||
TrtGptModelOptionalParams const& optionalParams, std::shared_ptr<Recorder> recorder,
|
||||
std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
|
||||
std::optional<SizeType32> const staticEmulatedBatchSize,
|
||||
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput)
|
||||
: mRecorder(std::move(recorder))
|
||||
, mTerminateReqId(terminateReqId)
|
||||
, mWaitSleep(waitSleep)
|
||||
, mStaticEmulatedBatchSize(staticEmulatedBatchSize)
|
||||
, mBatchTimeout(batchTimeout.value_or(std::chrono::milliseconds{0}))
|
||||
, mActiveCount(0)
|
||||
{
|
||||
auto const jsonConfig = GptJsonConfig::parse(trtEnginePath / "config.json");
|
||||
mWorldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
|
||||
jsonConfig.getPipelineParallelism(), optionalParams.deviceIds);
|
||||
auto& comm = COMM_SESSION;
|
||||
mCommTensorParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
|
||||
comm.split(mWorldConfig.getPipelineParallelRank(), mWorldConfig.getTensorParallelRank()));
|
||||
mCommPipelineParallel = std::make_shared<tensorrt_llm::mpi::MpiComm>(
|
||||
comm.split(mWorldConfig.getTensorParallelRank(), mWorldConfig.getPipelineParallelRank()));
|
||||
|
||||
ReturnBatchManagerStatsCallback iterationDataCallback = [this, logIterationData](std::string const& log)
|
||||
{
|
||||
if (logIterationData)
|
||||
{
|
||||
TLLM_LOG_INFO(log);
|
||||
}
|
||||
|
||||
if (mStaticEmulatedBatchSize)
|
||||
{
|
||||
auto const json = nlohmann::json::parse(log);
|
||||
auto const activeRequests = json["Active Request Count"];
|
||||
TLLM_CHECK(activeRequests <= mStaticEmulatedBatchSize.value());
|
||||
}
|
||||
};
|
||||
|
||||
mBatchManager = std::make_shared<GptManager>(
|
||||
trtEnginePath, modelType, [this](int max_num_requests) { return getInferenceRequests(max_num_requests); },
|
||||
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
|
||||
std::string const& errMsg)
|
||||
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
|
||||
nullptr, iterationDataCallback, optionalParams, terminateReqId, excludeInputInOutput);
|
||||
}
|
||||
|
||||
~GptServer()
|
||||
{
|
||||
if (mInferReqWaitThread)
|
||||
{
|
||||
mInferReqWaitThread->join();
|
||||
mInferReqWaitThread.reset(nullptr);
|
||||
}
|
||||
|
||||
mWorkItemsQueue.clear();
|
||||
}
|
||||
|
||||
std::string getLayerProfileInfo()
|
||||
{
|
||||
return mBatchManager->getLayerProfileInfo();
|
||||
}
|
||||
|
||||
void setLayerProfiler()
|
||||
{
|
||||
return mBatchManager->setLayerProfiler();
|
||||
}
|
||||
|
||||
void enqueue(std::shared_ptr<InferenceRequest> const& request)
|
||||
{
|
||||
TLLM_CHECK(request != nullptr);
|
||||
auto const requestId = request->getRequestId();
|
||||
if (requestId == mTerminateReqId)
|
||||
{
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
return;
|
||||
}
|
||||
|
||||
// Enqueue
|
||||
try
|
||||
{
|
||||
mRecorder->recordStart(request, requestId);
|
||||
mWorkItemsQueue.push(request, requestId);
|
||||
}
|
||||
catch (tc::TllmException const& e)
|
||||
{
|
||||
throw;
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_THROW("%s", e.what());
|
||||
}
|
||||
}
|
||||
|
||||
void resetBatchDeadline()
|
||||
{
|
||||
mBatchDeadline = (std::chrono::steady_clock::now() + mBatchTimeout).time_since_epoch();
|
||||
}
|
||||
|
||||
void waitForEmpty() const
|
||||
{
|
||||
while (!mWorkItemsQueue.empty())
|
||||
{
|
||||
std::this_thread::sleep_for(mWaitSleep);
|
||||
}
|
||||
}
|
||||
|
||||
void waitBatchManager() const
|
||||
{
|
||||
mBatchManager->waitUntilTerminate();
|
||||
}
|
||||
|
||||
void shutdown() const
|
||||
{
|
||||
mBatchManager->shutdown();
|
||||
}
|
||||
|
||||
// Return up to max_num_requests inference requests.
|
||||
std::list<std::shared_ptr<InferenceRequest>> getInferenceRequests(int const max_num_requests)
|
||||
{
|
||||
if (mInferReqWaitThread)
|
||||
{
|
||||
mInferReqWaitThread->join();
|
||||
mInferReqWaitThread.reset(nullptr);
|
||||
}
|
||||
std::list<std::shared_ptr<InferenceRequest>> inferenceRequests;
|
||||
auto& comm = COMM_SESSION;
|
||||
if (max_num_requests > 0)
|
||||
{
|
||||
auto rank = comm.getRank();
|
||||
if (rank == 0)
|
||||
{
|
||||
auto const numNewWorkItems = std::min(static_cast<int64_t>(mWorkItemsQueue.numPendingWorkItems()),
|
||||
static_cast<int64_t>(max_num_requests));
|
||||
|
||||
bool const timeout = std::chrono::steady_clock::now().time_since_epoch() > mBatchDeadline.load();
|
||||
bool readyForNextBatch = numNewWorkItems > 0 && timeout;
|
||||
if (mStaticEmulatedBatchSize)
|
||||
{
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
bool const previousBatchFinished = mActiveCount == 0;
|
||||
bool const haveEnoughForNextBatch = numNewWorkItems >= mStaticEmulatedBatchSize.value();
|
||||
readyForNextBatch = previousBatchFinished && (timeout || haveEnoughForNextBatch);
|
||||
}
|
||||
if (numNewWorkItems == 0 || readyForNextBatch)
|
||||
{
|
||||
// Timeout should only begin once we have at least 1 pending request.
|
||||
// Reset timeout when no requests are pending or we submit a new batch.
|
||||
resetBatchDeadline();
|
||||
}
|
||||
}
|
||||
|
||||
if (readyForNextBatch)
|
||||
{
|
||||
// Only add a single batch at a time when emulating static batching
|
||||
auto const numItemsToAdd = std::min(
|
||||
numNewWorkItems, static_cast<int64_t>(mStaticEmulatedBatchSize.value_or(numNewWorkItems)));
|
||||
mActiveCount += numItemsToAdd;
|
||||
while (inferenceRequests.size() < numItemsToAdd)
|
||||
{
|
||||
auto [workItem, markedInProgress] = mWorkItemsQueue.pop();
|
||||
|
||||
if (markedInProgress)
|
||||
{
|
||||
inferenceRequests.emplace_back(workItem->getInferenceRequest());
|
||||
}
|
||||
else
|
||||
{
|
||||
auto warnStr = tc::fmtstr(
|
||||
"request Id %lu has been stopped. Request is ignored.", workItem->requestId());
|
||||
TLLM_LOG_WARNING(warnStr);
|
||||
sendResponse(workItem->requestId(), {}, true, warnStr);
|
||||
}
|
||||
}
|
||||
}
|
||||
if (mWorldConfig.isTensorParallel())
|
||||
{
|
||||
auto numNewWorkItems = static_cast<int64_t>(inferenceRequests.size());
|
||||
if (numNewWorkItems > 0 || mBatchManager->getNumActiveRequests() > 0)
|
||||
{
|
||||
mCommTensorParallel->bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
}
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
for (auto const& infReq : inferenceRequests)
|
||||
{
|
||||
auto vpacked = infReq->serialize();
|
||||
packed.push_back(static_cast<int64_t>(vpacked.size()));
|
||||
packed.insert(
|
||||
packed.end(), std::move_iterator(vpacked.begin()), std::move_iterator(vpacked.end()));
|
||||
}
|
||||
mCommTensorParallel->bcast(packed, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// subordinate ranks hang until master rank sends work
|
||||
if (mWorldConfig.isFirstPipelineParallelRank())
|
||||
{
|
||||
int64_t numNewWorkItems = 0;
|
||||
mCommTensorParallel->bcast(&numNewWorkItems, 1, mpi::MpiType::kINT64, 0);
|
||||
if (numNewWorkItems > 0)
|
||||
{
|
||||
std::vector<int64_t> packed;
|
||||
mCommTensorParallel->bcast(packed, 0);
|
||||
int64_t* packed_ptr = packed.data();
|
||||
for (int64_t count = 0; count < numNewWorkItems; ++count)
|
||||
{
|
||||
int64_t n = *(packed_ptr++);
|
||||
auto infReq = InferenceRequest::deserialize(packed_ptr);
|
||||
packed_ptr += n;
|
||||
inferenceRequests.emplace_back(infReq);
|
||||
}
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
auto const peer = mWorldConfig.getPipelineParallelRank() - 1;
|
||||
inferenceRequestsRecv(mCommPipelineParallel, inferenceRequests, peer);
|
||||
}
|
||||
}
|
||||
if (!mWorldConfig.isLastPipelineParallelRank())
|
||||
{
|
||||
auto const peer = mWorldConfig.getPipelineParallelRank() + 1;
|
||||
auto inferReqAsyncSndHdl
|
||||
= std::make_unique<InferenceRequestsAsyncSend>(mCommPipelineParallel, inferenceRequests, peer);
|
||||
mInferReqWaitThread = std::make_unique<std::thread>([handle = std::move(inferReqAsyncSndHdl)]() {});
|
||||
}
|
||||
}
|
||||
return inferenceRequests;
|
||||
}
|
||||
|
||||
void sendResponse(uint64_t requestId, [[maybe_unused]] std::list<NamedTensor> const& response_tensors,
|
||||
bool final_response, [[maybe_unused]] std::string const& errMsg)
|
||||
{
|
||||
// `response_tensors` contains `outputIds, sequenceLength, [contextLogits, generationLogits], logProbs,
|
||||
// cumLogProbs`. `contextLogits, generationLogits` are optional, only contained when `gather_context_logits` and
|
||||
// `gather_generation_logits` are enabled respectively. Or enable 'gather_all_token_logits' to enable both of
|
||||
// them.
|
||||
try
|
||||
{
|
||||
|
||||
if (final_response)
|
||||
{
|
||||
mWorkItemsQueue.markFinished(requestId);
|
||||
mRecorder->recordEnd(requestId, response_tensors, !errMsg.empty());
|
||||
mActiveCount--;
|
||||
}
|
||||
else
|
||||
{
|
||||
if (errMsg.empty())
|
||||
{
|
||||
mRecorder->recordToken(requestId, response_tensors);
|
||||
}
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR("Failed to send response for requestId %lu\n%s", requestId, e.what());
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<GptManager> mBatchManager;
|
||||
std::shared_ptr<Recorder> mRecorder;
|
||||
WorkItemsQueue mWorkItemsQueue;
|
||||
std::optional<uint64_t> mTerminateReqId;
|
||||
std::chrono::milliseconds mWaitSleep;
|
||||
std::optional<SizeType32> mStaticEmulatedBatchSize;
|
||||
std::chrono::milliseconds mBatchTimeout;
|
||||
std::atomic<std::chrono::steady_clock::time_point::duration> mBatchDeadline;
|
||||
std::atomic<uint64_t> mActiveCount;
|
||||
WorldConfig mWorldConfig;
|
||||
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommTensorParallel;
|
||||
std::shared_ptr<tensorrt_llm::mpi::MpiComm> mCommPipelineParallel;
|
||||
std::unique_ptr<std::thread> mInferReqWaitThread;
|
||||
|
||||
}; // class GptServer
|
||||
|
||||
namespace
|
||||
{
|
||||
|
||||
@ -1418,60 +931,6 @@ std::vector<double> computeTimeDelays(BenchmarkParams const& benchmarkParams, in
|
||||
return timeDelays;
|
||||
}
|
||||
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample, bool streaming,
|
||||
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
|
||||
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
|
||||
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
|
||||
ITensor::SharedPtr const& loraConfig = nullptr,
|
||||
std::optional<tensorrt_llm::executor::LookaheadDecodingConfig> lookaheadConfig = std::nullopt)
|
||||
{
|
||||
auto request = std::make_shared<InferenceRequest>(reqId);
|
||||
auto const& inputIds = sample.inputIds;
|
||||
request->setInputIds(bufferManager.copyFrom(
|
||||
inputIds, ITensor::makeShape({static_cast<SizeType32>(inputIds.size())}), MemoryType::kCPU));
|
||||
auto const requestOutputLen = sample.outputLen;
|
||||
request->setMaxNewTokens(bufferManager.copyFrom(&requestOutputLen, ITensor::makeShape({1, 1}), MemoryType::kCPU));
|
||||
request->setBeamWidth(beamWidthTensor);
|
||||
if (eosId != nullptr)
|
||||
{
|
||||
request->setEndId(eosId);
|
||||
}
|
||||
if (padId != nullptr)
|
||||
{
|
||||
request->setPadId(padId);
|
||||
}
|
||||
if (returnContextLogits)
|
||||
{
|
||||
request->setReturnContextLogits(returnContextLogits);
|
||||
}
|
||||
if (returnGenerationLogits)
|
||||
{
|
||||
request->setReturnGenerationLogits(returnGenerationLogits);
|
||||
}
|
||||
if (sample.taskId >= 0)
|
||||
{
|
||||
uint64_t taskId = static_cast<uint64_t>(sample.taskId);
|
||||
request->setLoraTaskId(bufferManager.copyFrom(&taskId, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL));
|
||||
}
|
||||
if (loraWeights)
|
||||
{
|
||||
request->setLoraWeights(loraWeights);
|
||||
}
|
||||
if (loraConfig)
|
||||
{
|
||||
request->setLoraConfig(loraConfig);
|
||||
}
|
||||
if (lookaheadConfig)
|
||||
{
|
||||
request->setLookaheadConfig(lookaheadConfig.value());
|
||||
}
|
||||
if (streaming)
|
||||
{
|
||||
request->setIsStreaming(true);
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamWidth,
|
||||
std::optional<SizeType32> const& eosId, std::optional<SizeType32> const& padId, bool streaming = false,
|
||||
bool const& returnContextLogits = false, bool const& returnGenerationLogits = false,
|
||||
@ -1495,185 +954,6 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
|
||||
encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
|
||||
}
|
||||
|
||||
void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType modelType,
|
||||
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
|
||||
std::optional<TokenIdType> const& eosId, std::optional<TokenIdType> const& padId,
|
||||
BenchmarkParams const& benchmarkParams, texec::CapacitySchedulerPolicy capacitySchedulerPolicy,
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<SizeType32> const staticEmulatedBatchSize,
|
||||
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput,
|
||||
std::string const& responsesJsonFile, std::optional<SizeType32> const maxPromptLen, bool dumpProfile)
|
||||
{
|
||||
TrtGptModelOptionalParams optionalParams;
|
||||
|
||||
if (benchmarkParams.maxTokensInPagedKvCache)
|
||||
{
|
||||
optionalParams.kvCacheConfig.maxTokens = benchmarkParams.maxTokensInPagedKvCache;
|
||||
}
|
||||
if (benchmarkParams.freeGpuMemoryFraction)
|
||||
{
|
||||
optionalParams.kvCacheConfig.freeGpuMemoryFraction = benchmarkParams.freeGpuMemoryFraction;
|
||||
}
|
||||
if (benchmarkParams.crossKvCacheFraction)
|
||||
{
|
||||
optionalParams.kvCacheConfig.crossKvCacheFraction = benchmarkParams.crossKvCacheFraction;
|
||||
}
|
||||
if (benchmarkParams.maxAttentionWindowVec)
|
||||
{
|
||||
optionalParams.kvCacheConfig.maxAttentionWindowVec = benchmarkParams.maxAttentionWindowVec;
|
||||
}
|
||||
if (benchmarkParams.sinkTokenLength)
|
||||
{
|
||||
optionalParams.kvCacheConfig.sinkTokenLength = benchmarkParams.sinkTokenLength;
|
||||
}
|
||||
optionalParams.kvCacheConfig.enableBlockReuse = benchmarkParams.enableBlockReuse;
|
||||
optionalParams.enableChunkedContext = benchmarkParams.enableChunkedContext;
|
||||
optionalParams.enableTrtOverlap = benchmarkParams.enableTrtOverlap;
|
||||
optionalParams.peftCacheManagerConfig.hostCacheSize = benchmarkParams.loraHostCacheSize;
|
||||
optionalParams.peftCacheManagerConfig.numDeviceModuleLayer = benchmarkParams.loraDeviceNumModLayers;
|
||||
optionalParams.peftCacheManagerConfig.numPutWorkers = 4;
|
||||
optionalParams.peftCacheManagerConfig.numEnsureWorkers = 4;
|
||||
optionalParams.peftCacheManagerConfig.numCopyStreams = 4;
|
||||
optionalParams.kvCacheConfig.hostCacheSize = benchmarkParams.kvHostCacheSize;
|
||||
optionalParams.kvCacheConfig.onboardBlocks = benchmarkParams.kvOnboardBlocks;
|
||||
optionalParams.gpuWeightsPercent = benchmarkParams.gpuWeightsPercent;
|
||||
optionalParams.maxBeamWidth = beamWidth;
|
||||
optionalParams.maxBatchSize = benchmarkParams.maxBatchSize;
|
||||
optionalParams.maxNumTokens = benchmarkParams.maxNumTokens;
|
||||
optionalParams.schedulerConfig = texec::SchedulerConfig{capacitySchedulerPolicy};
|
||||
optionalParams.decodingConfig
|
||||
= texec::DecodingConfig(benchmarkParams.medusaChoices.has_value() ? texec::DecodingMode::Medusa()
|
||||
: benchmarkParams.executorLookaheadConfig.has_value() ? texec::DecodingMode::Lookahead()
|
||||
: texec::DecodingMode::Auto(),
|
||||
benchmarkParams.executorLookaheadConfig, benchmarkParams.medusaChoices);
|
||||
optionalParams.extendedRuntimePerfKnobConfig = texec::ExtendedRuntimePerfKnobConfig(benchmarkParams.multiBlockMode,
|
||||
benchmarkParams.enableContextFMHAFP32Acc, benchmarkParams.cudaGraphMode, benchmarkParams.cudaGraphCacheSize);
|
||||
|
||||
auto const jsonConfig = GptJsonConfig::parse(engineDir / "config.json");
|
||||
auto const worldConfig = WorldConfig::mpi(jsonConfig.getGpusPerNode(), jsonConfig.getTensorParallelism(),
|
||||
jsonConfig.getPipelineParallelism(), optionalParams.deviceIds);
|
||||
|
||||
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
|
||||
|
||||
ITensor::SharedPtr beamWidthTensor{
|
||||
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)};
|
||||
|
||||
// Load dataset
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples, maxPromptLen);
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
auto recorder = std::make_shared<Recorder>(
|
||||
opCsvFile, benchmarkParams.streaming, beamWidth, responsesJsonFile, excludeInputInOutput);
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, optionalParams, recorder, terminateReqId,
|
||||
waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData, excludeInputInOutput);
|
||||
|
||||
ITensor::SharedPtr eosIdTensor{
|
||||
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNEDPOOL) : nullptr};
|
||||
ITensor::SharedPtr padIdTensor{
|
||||
padId ? bufferManager.copyFrom(&padId.value(), ITensor::makeShape({1}), MemoryType::kPINNEDPOOL) : nullptr};
|
||||
|
||||
ITensor::SharedPtr returnContextLogitsFlagTensor{returnContextLogits
|
||||
? bufferManager.copyFrom(&returnContextLogits, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)
|
||||
: nullptr};
|
||||
|
||||
ITensor::SharedPtr returnGenerationLogitsFlagTensor{returnGenerationLogits
|
||||
? bufferManager.copyFrom(&returnGenerationLogits, ITensor::makeShape({1}), MemoryType::kPINNEDPOOL)
|
||||
: nullptr};
|
||||
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
if (benchmarkParams.loraDir)
|
||||
{
|
||||
auto startLoraLoad = std::chrono::steady_clock::now();
|
||||
LoraLib loras(benchmarkParams.loraDir.value());
|
||||
SizeType32 reqId = 0;
|
||||
gptServer->resetBatchDeadline();
|
||||
for (auto const& [taskId, p] : loras.getLoras())
|
||||
{
|
||||
reqId++;
|
||||
if (reqId == terminateReqId)
|
||||
{
|
||||
reqId++;
|
||||
}
|
||||
Sample s{std::vector<int32_t>{1, 2, 3, 4, 5}, 1, static_cast<int32_t>(taskId)};
|
||||
auto r = makeRequest(reqId, s, benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor,
|
||||
bufferManager, nullptr, nullptr, p.first, p.second);
|
||||
gptServer->enqueue(r);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
auto endLoraLoad = std::chrono::steady_clock::now();
|
||||
printf("[BENCHMARK] time to preload LoRAs(ms) %.2f\n",
|
||||
std::chrono::duration<float, std::milli>(endLoraLoad - startLoraLoad).count());
|
||||
}
|
||||
|
||||
// Warm up
|
||||
gptServer->resetBatchDeadline();
|
||||
SizeType32 reqId = 0;
|
||||
for (auto i = 0; i < warmUp; ++i)
|
||||
{
|
||||
++reqId;
|
||||
if (i == terminateReqId)
|
||||
++reqId;
|
||||
auto request = makeRequest(reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
|
||||
padIdTensor, bufferManager, nullptr, nullptr, nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
|
||||
// Time delay
|
||||
auto timeDelays = computeTimeDelays(benchmarkParams, numSamples - 1);
|
||||
|
||||
// Benchmark
|
||||
recorder->initialize();
|
||||
gptServer->resetBatchDeadline();
|
||||
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
|
||||
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor, nullptr,
|
||||
nullptr, benchmarkParams.requestLookaheadConfig);
|
||||
gptServer->enqueue(request);
|
||||
|
||||
if (i < numSamples - 1)
|
||||
{
|
||||
auto delayInMs = static_cast<int>(timeDelays.at(i) * 1000);
|
||||
std::chrono::milliseconds delay(delayInMs);
|
||||
std::this_thread::sleep_for(delay);
|
||||
}
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
recorder->finalize();
|
||||
recorder->calculateMetrics();
|
||||
recorder->report();
|
||||
recorder->writeOpMetricsToCsv();
|
||||
recorder->dumpResponseSeqs();
|
||||
if (dumpProfile)
|
||||
{
|
||||
// Do per-layer profiling after normal benchmarking to avoid introducing perf overhead.
|
||||
gptServer->resetBatchDeadline();
|
||||
gptServer->setLayerProfiler();
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
|
||||
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor,
|
||||
nullptr, nullptr, benchmarkParams.requestLookaheadConfig);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
if (worldConfig.getRank() == 0)
|
||||
{
|
||||
printf("[BENCHMARK] Per layer performance profile\n%s\n", gptServer->getLayerProfileInfo().c_str());
|
||||
}
|
||||
}
|
||||
// Send terminateReqId to terminate servers on all ranks
|
||||
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
|
||||
}
|
||||
// Wait until benchmarking is done and batch manager is terminated
|
||||
gptServer->waitBatchManager();
|
||||
}
|
||||
|
||||
void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
|
||||
std::optional<std::filesystem::path> const& encoderEngineDir, TrtGptModelType modelType,
|
||||
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
|
||||
@ -1698,16 +978,15 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
decoderEngineDir.has_value(), "decoder models require a path to decoder engine in executor benchmark.");
|
||||
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), std::nullopt, modelType, beamWidth,
|
||||
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData,
|
||||
executorModelType);
|
||||
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
|
||||
}
|
||||
else if (executorModelType == texec::ModelType::kENCODER_DECODER)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(encoderEngineDir.has_value(),
|
||||
"encoder-decoder models require a path to encoder engine in executor benchmark.");
|
||||
executorServer = std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType,
|
||||
beamWidth, capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize,
|
||||
logIterationData, executorModelType);
|
||||
executorServer
|
||||
= std::make_shared<ExecutorServer>(decoderEngineDir.value(), encoderEngineDir.value(), modelType, beamWidth,
|
||||
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
|
||||
try
|
||||
{
|
||||
std::ifstream decoderJsonConfigPath(decoderEngineDir.value() / "config.json");
|
||||
@ -1733,8 +1012,7 @@ void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngine
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
encoderEngineDir.has_value(), "encoder models require a path to encoder engine in executor benchmark.");
|
||||
executorServer = std::make_shared<ExecutorServer>(std::nullopt, encoderEngineDir.value(), modelType, beamWidth,
|
||||
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData,
|
||||
executorModelType);
|
||||
capacitySchedulerPolicy, benchmarkParams, recorder, waitSleep, logIterationData, executorModelType);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -1937,6 +1215,8 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("h,help", "Print usage");
|
||||
options.add_options()("engine_dir, decoder_engine_dir", "Directory that store the engines of decoder models.",
|
||||
cxxopts::value<std::string>());
|
||||
options.add_options()(
|
||||
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
|
||||
options.add_options()(
|
||||
"api", "API type: gptManager or executor.", cxxopts::value<std::string>()->default_value("executor"));
|
||||
options.add_options()("type", "Batching type: IFB, UIFB (unfused IFB) or V1 (non-IFB) batching.",
|
||||
@ -1971,8 +1251,6 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("max_batch_size", "The max runtime batch size when benchmarking", cxxopts::value<int>());
|
||||
options.add_options()(
|
||||
"max_num_tokens", "The max runtime number of tokens per batch when benchmarking", cxxopts::value<int>());
|
||||
options.add_options()("enable_trt_overlap", "Overlap TRT context preparation and execution",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options.add_options()(
|
||||
"enable_batch_size_tuning", "Dynamic tuning of batch size", cxxopts::value<bool>()->default_value("false"));
|
||||
options.add_options()("enable_exp_delays", "Enables exponential delay distr to mimic real world request arrival",
|
||||
@ -1991,15 +1269,8 @@ int main(int argc, char* argv[])
|
||||
"Choose scheduler policy between max_utilization/guaranteed_no_evict/static_batch.",
|
||||
cxxopts::value<std::string>()->default_value("guaranteed_no_evict"));
|
||||
|
||||
options.add_options()("first_batch_delay",
|
||||
"Delay before submitting the first batch of requests. This can be used to increase the size of the first "
|
||||
"batch.",
|
||||
cxxopts::value<int32_t>());
|
||||
options.add_options()("static_emulated_batch_size",
|
||||
"Emulate static batching performance with the provided batch size.", cxxopts::value<SizeType32>());
|
||||
options.add_options()("static_emulated_timeout",
|
||||
"Timeout (ms) before launching a partial batch in emulated static batching mode",
|
||||
cxxopts::value<int32_t>()->default_value("500"));
|
||||
options.add_options()("log_level", "Choose log level between verbose/info/warning/error/internal_error.",
|
||||
cxxopts::value<std::string>()->default_value("error"));
|
||||
options.add_options()("log_iteration_data", "On each decoder iteration, print batch state metadata.",
|
||||
@ -2012,22 +1283,11 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("kv_host_cache_bytes",
|
||||
"Size of secondary memory pool used for offloading kv cache blocks (in bytes).",
|
||||
cxxopts::value<size_t>()->default_value("0"));
|
||||
options.add_options()("kv_dont_onboard_blocks",
|
||||
"If offloaded blocks should be onboarded to primary memory before reuse",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
|
||||
options.add_options()("exclude_input_in_output_seq",
|
||||
"When enabled, GptManager will exclude the input sequence from output. (Only works if --api is gptManager)",
|
||||
cxxopts::value<bool>());
|
||||
|
||||
options.add_options()("responses_json_file",
|
||||
"When specified, dumps the responses to JSON file. (only works if --api is gptManager)",
|
||||
cxxopts::value<std::string>()->default_value(""));
|
||||
|
||||
options.add_options()("kv_onboard_blocks", "If offloaded blocks should be onboarded to primary memory before reuse",
|
||||
cxxopts::value<bool>()->default_value("true"));
|
||||
options.add_options()(
|
||||
"max_prompt_len", "Truncate all prompts from dataset to the length specified.", cxxopts::value<SizeType32>());
|
||||
|
||||
options.add_options()("dump_profile", "Print profile information per layer.", cxxopts::value<bool>());
|
||||
options.add_options()("gpu_weights_percent",
|
||||
"Specify the percentage of weights that reside on GPU (from 0.0 to 1.0).",
|
||||
cxxopts::value<float>()->default_value("1.0"));
|
||||
@ -2037,8 +1297,6 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("multi_block_mode",
|
||||
"Distribute the work across multiple CUDA thread-blocks on the GPU for masked MHA kernel",
|
||||
cxxopts::value<bool>()->default_value("true"));
|
||||
options.add_options()(
|
||||
"encoder_engine_dir", "Directory that store the engines of the encoder models.", cxxopts::value<std::string>());
|
||||
options.add_options()("cuda_graph_mode", "When enabled, inference is executed with cuda graph.",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
options.add_options()("cuda_graph_cache_size",
|
||||
@ -2051,7 +1309,6 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("executor_lookahead_config",
|
||||
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size]",
|
||||
cxxopts::value<std::string>());
|
||||
|
||||
options.add_options()("request_lookahead_config",
|
||||
"lookahead config in the format of [max_window_size, max_ngram_size, max_verification_set_size], and each <= "
|
||||
"executor lookahead config",
|
||||
@ -2073,9 +1330,6 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: API
|
||||
auto const api = result["api"].as<std::string>();
|
||||
|
||||
// Argument: Batching Type
|
||||
auto const type = result["type"].as<std::string>();
|
||||
TrtGptModelType modelType{TrtGptModelType::V1};
|
||||
@ -2153,9 +1407,6 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
}
|
||||
|
||||
// Argument: Enable TRT overlap
|
||||
benchmarkParams.enableTrtOverlap = result["enable_trt_overlap"].as<bool>();
|
||||
|
||||
// Argument: Enable dynamic tuning of batch size
|
||||
benchmarkParams.enableBatchSizeTuning = result["enable_batch_size_tuning"].as<bool>();
|
||||
|
||||
@ -2228,7 +1479,7 @@ int main(int argc, char* argv[])
|
||||
benchmarkParams.kvHostCacheSize = result["kv_host_cache_bytes"].as<size_t>();
|
||||
|
||||
// Argument: If offloaded blocks should be onboarded to primary memory before they are reused.
|
||||
benchmarkParams.kvOnboardBlocks = !result["kv_dont_onboard_blocks"].as<bool>();
|
||||
benchmarkParams.kvOnboardBlocks = result["kv_onboard_blocks"].as<bool>();
|
||||
|
||||
// Argument: Medusa choices for the Medusa speculative decoding.
|
||||
if (result.count("medusa_choices"))
|
||||
@ -2255,7 +1506,7 @@ int main(int argc, char* argv[])
|
||||
// Argument: cuda_graph_mode
|
||||
benchmarkParams.cudaGraphMode = result["cuda_graph_mode"].as<bool>();
|
||||
|
||||
// Argument: cuda_graph_mode
|
||||
// Argument: cuda_graph_cache_size
|
||||
benchmarkParams.cudaGraphCacheSize = result["cuda_graph_cache_size"].as<SizeType32>();
|
||||
|
||||
std::optional<TokenIdType> padId;
|
||||
@ -2268,20 +1519,11 @@ int main(int argc, char* argv[])
|
||||
// Argument: End-of-sentence token id
|
||||
std::optional<TokenIdType> eosId = result["eos_id"].as<TokenIdType>();
|
||||
|
||||
std::optional<std::chrono::milliseconds> batchTimeout;
|
||||
// Argument: first_batch_delay
|
||||
if (result.count("first_batch_delay"))
|
||||
{
|
||||
batchTimeout = std::chrono::milliseconds{result["first_batch_delay"].as<int32_t>()};
|
||||
}
|
||||
|
||||
std::optional<SizeType32> staticEmulatedBatchSize;
|
||||
// Argument: Static emulated batch size
|
||||
if (result.count("static_emulated_batch_size"))
|
||||
{
|
||||
staticEmulatedBatchSize = result["static_emulated_batch_size"].as<SizeType32>();
|
||||
|
||||
batchTimeout = std::chrono::milliseconds{result["static_emulated_timeout"].as<int32_t>()};
|
||||
}
|
||||
|
||||
// Argument: Scheduler policy
|
||||
@ -2313,7 +1555,6 @@ int main(int argc, char* argv[])
|
||||
}
|
||||
|
||||
// Argument: GPU weights percentage
|
||||
std::istringstream ssGpuPercentArg;
|
||||
auto gpuWeightsPercent = result["gpu_weights_percent"].as<float>();
|
||||
if (gpuWeightsPercent < 0 || gpuWeightsPercent > 1)
|
||||
{
|
||||
@ -2351,29 +1592,11 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: dump profile
|
||||
bool dumpProfile = result["dump_profile"].as<bool>();
|
||||
|
||||
initTrtLlmPlugins(logger.get());
|
||||
|
||||
if (api == "gptManager")
|
||||
{
|
||||
try
|
||||
{
|
||||
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
|
||||
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams,
|
||||
capacitySchedulerPolicy, waitSleep, returnContextLogits, returnGenerationLogits,
|
||||
staticEmulatedBatchSize, batchTimeout, logIterationData,
|
||||
result["exclude_input_in_output_seq"].as<bool>(), result["responses_json_file"].as<std::string>(),
|
||||
maxPromptLen, dumpProfile);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
TLLM_LOG_ERROR(e.what());
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if (api == "executor")
|
||||
// Argument: API
|
||||
auto const api = result["api"].as<std::string>();
|
||||
if (api == "executor")
|
||||
{
|
||||
texec::ModelType executorModelType;
|
||||
std::optional<std::string> decoderEngineDir = std::nullopt, encoderEngineDir = std::nullopt;
|
||||
@ -2409,6 +1632,11 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
else if (api == "gptManager")
|
||||
{
|
||||
TLLM_LOG_ERROR("gptManager is deprecated, please use the executor API.");
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
TLLM_LOG_ERROR("api parameter must be gptManager or executor");
|
||||
|
||||
cpp/include/tensorrt_llm/batch_manager/allocateKvCache.h (new file, 50 lines)
@ -0,0 +1,50 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "common.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm::batch_manager
{

namespace tle = tensorrt_llm::executor;

class AllocateKvCache : Algorithm
{
    using KVCacheManager = tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager;

    template <typename T>
    using OptionalRef = tensorrt_llm::common::OptionalRef<T>;

public:
    constexpr static auto name{"AllocateKvCache"};

    using SizeType32 = tensorrt_llm::runtime::SizeType32;

    AllocateKvCache() = default;

    void operator()(KVCacheManager& kvCacheManager, RequestVector& contextRequests,
        RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig,
        OptionalRef<KVCacheManager> crossKvCacheManager = std::nullopt) const;
};

} // namespace tensorrt_llm::batch_manager
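For orientation, a minimal sketch (not taken from this diff) of how the new algorithm object might be
driven; it assumes only the declarations above, and the wrapper function plus all variable names are
placeholders:

    // Hypothetical per-iteration call site for AllocateKvCache.
    void allocateIterationKvCache(kv_cache_manager::KVCacheManager& kvCacheManager, RequestVector& contextRequests,
        RequestVector const& generationRequests, runtime::ModelConfig const& modelConfig)
    {
        AllocateKvCache allocateKvCache;
        // Allocate KV-cache blocks for this iteration's context and generation requests; the optional
        // cross-attention cache manager argument is left at its std::nullopt default here.
        allocateKvCache(kvCacheManager, contextRequests, generationRequests, modelConfig);
    }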
cpp/include/tensorrt_llm/batch_manager/assignReqSeqSlots.h (new file, 44 lines)
@ -0,0 +1,44 @@
/*
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "common.h"
#include "tensorrt_llm/batch_manager/llmRequest.h"
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
#include "tensorrt_llm/common/algorithm.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm::batch_manager
{

namespace tle = tensorrt_llm::executor;

class AssignReqSeqSlots : Algorithm
{
    using SizeType32 = tensorrt_llm::runtime::SizeType32;

public:
    constexpr static auto name{"AssignReqSeqSlots"};

    AssignReqSeqSlots() = default;

    void operator()(SequenceSlotManager& seqSlotManager, RequestVector const& contextRequests,
        RequestVector const& generationRequests) const;
};

} // namespace tensorrt_llm::batch_manager
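Likewise, a hedged sketch (not from the diff) of invoking AssignReqSeqSlots, assuming the declarations
above; the wrapper function and variable names are placeholders:

    // Hypothetical per-iteration call site for AssignReqSeqSlots.
    void assignIterationSeqSlots(SequenceSlotManager& seqSlotManager, RequestVector const& contextRequests,
        RequestVector const& generationRequests)
    {
        AssignReqSeqSlots assignReqSeqSlots;
        // Bind each context/generation request to a sequence slot before the forward pass is built.
        assignReqSeqSlots(seqSlotManager, contextRequests, generationRequests);
    }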
@ -19,6 +19,7 @@
|
||||
#include "common.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h"
|
||||
#include "tensorrt_llm/common/algorithm.h"
|
||||
#include "tensorrt_llm/common/optionalRef.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include <variant>
|
||||
|
||||
@ -35,6 +36,7 @@ namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
using tensorrt_llm::runtime::SizeType32;
|
||||
using common::OptionalRef;
|
||||
|
||||
/// @brief This scheduler takes into account the given request capacity and the KV cache capacity.
|
||||
/// Depending on the CapacitySchedulerPolicy it will schedule already started and new requests,
|
||||
@ -69,8 +71,6 @@ class MaxRequestsScheduler : public BaseCapacityScheduler
|
||||
{
|
||||
public:
|
||||
explicit MaxRequestsScheduler(SizeType32 maxNumRequests,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
@ -80,8 +80,6 @@ public:
|
||||
|
||||
private:
|
||||
SizeType32 mMaxNumRequests;
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
|
||||
};
|
||||
|
||||
/// @brief Schedule requests using the MAX_UTILIZATION policy
|
||||
@ -90,24 +88,21 @@ private:
|
||||
class MaxUtilizationScheduler : public BaseCapacityScheduler
|
||||
{
|
||||
public:
|
||||
MaxUtilizationScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> peftCacheManager, bool manyMicroBatches,
|
||||
MaxUtilizationScheduler(SizeType32 maxNumRequests, bool manyMicroBatches,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(kv_cache_manager::KVCacheManager& kvCacheManager,
|
||||
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
|
||||
|
||||
private:
|
||||
/// @return {fitsKvCache, fitsPeft}
|
||||
std::pair<bool, bool> trySchedulingRequestMaxUtilization(std::shared_ptr<LlmRequest> const& req,
|
||||
std::pair<bool, bool> trySchedulingRequestMaxUtilization(kv_cache_manager::KVCacheManager const& kvCacheManager,
|
||||
OptionalRef<BasePeftCacheManager const> peftCacheManager, std::shared_ptr<LlmRequest> const& req,
|
||||
RequestVector& scheduledRequests, SizeType32& numScheduledBlocks, SizeType32& numScheduledPeftPages,
|
||||
std::unordered_set<uint64_t>& seenTaskIds) const;
|
||||
|
||||
SizeType32 mMaxNumRequests;
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
|
||||
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
|
||||
/// @brief Boolean that indicates if multiple micro batches might be in flight
|
||||
bool mManyMicroBatches;
|
||||
};
|
||||
@ -117,36 +112,36 @@ class GuaranteedNoEvictScheduler : public BaseCapacityScheduler
|
||||
{
|
||||
public:
|
||||
GuaranteedNoEvictScheduler(SizeType32 maxNumRequests,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
|
||||
kv_cache_manager::KVCacheManager const& kvCacheManager,
|
||||
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
|
||||
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
|
||||
|
||||
protected:
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> forwardImpl(
|
||||
RequestList const& activeRequests, bool staticBatchScheduling) const;
|
||||
template <bool StaticBatchScheduling>
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> impl(kv_cache_manager::KVCacheManager const& kvCacheManager,
|
||||
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
|
||||
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
|
||||
|
||||
private:
|
||||
SizeType32 mMaxNumRequests;
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mKvCacheManager{nullptr};
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> mCrossKvCacheManager{nullptr};
|
||||
std::shared_ptr<BasePeftCacheManager> mPeftCacheManager{nullptr};
|
||||
};
|
||||
|
||||
/// @brief Schedule requests using the STATIC_BATCH policy
|
||||
class StaticBatchScheduler : public GuaranteedNoEvictScheduler
|
||||
{
|
||||
public:
|
||||
StaticBatchScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
|
||||
StaticBatchScheduler(SizeType32 maxNumRequests,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
|
||||
[[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(
|
||||
kv_cache_manager::KVCacheManager const& kvCacheManager,
|
||||
OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager,
|
||||
OptionalRef<BasePeftCacheManager const> peftCacheManager, RequestList const& activeRequests) const;
|
||||
};
|
||||
|
||||
class CapacityScheduler : public Algorithm
|
||||
@ -154,29 +149,26 @@ class CapacityScheduler : public Algorithm
|
||||
public:
|
||||
constexpr static auto name{"CapacityScheduler"};
|
||||
|
||||
CapacityScheduler() = default;
|
||||
|
||||
CapacityScheduler(SizeType32 maxNumRequests, std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
|
||||
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
|
||||
explicit CapacityScheduler(SizeType32 maxNumRequests, executor::CapacitySchedulerPolicy capacitySchedulerPolicy,
|
||||
bool hasKvCacheManager, std::optional<bool> manyMicroBatches = std::nullopt,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
static CapacityScheduler make(SizeType32 maxNumRequests,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> kvCacheManager,
|
||||
std::shared_ptr<kv_cache_manager::KVCacheManager> crossKvCacheManager,
|
||||
std::shared_ptr<BasePeftCacheManager> peftCacheManager,
|
||||
executor::CapacitySchedulerPolicy capacitySchedulerPolicy, bool manyMicroBatches = false,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
|
||||
{
|
||||
return CapacityScheduler{maxNumRequests, std::move(kvCacheManager), std::move(crossKvCacheManager),
|
||||
std::move(peftCacheManager), capacitySchedulerPolicy, manyMicroBatches, noScheduleUntilState,
|
||||
noScheduleAfterState};
|
||||
}
|
||||
|
||||
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests) const;
    /**
     * @brief Schedules requests following the selected policy.
     *
     * @param kvCacheManager Required in MaxUtilizationScheduler (as a ref) and in GuaranteedNoEvictScheduler and
     * StaticBatchScheduler (as a const ref).
     * @param crossKvCacheManager Optional, used in GuaranteedNoEvictScheduler and StaticBatchScheduler.
     * @param peftCacheManager Optional, used in MaxUtilizationScheduler, GuaranteedNoEvictScheduler and
     * StaticBatchScheduler.
     * @param activeRequests
     * @return std::tuple<RequestVector, RequestVector>, fittingRequests and pausedRequests respectively.
     */
    [[nodiscard]] std::tuple<RequestVector, RequestVector> operator()(RequestList const& activeRequests,
        OptionalRef<kv_cache_manager::KVCacheManager> kvCacheManager = std::nullopt,
        OptionalRef<BasePeftCacheManager const> peftCacheManager = std::nullopt,
        OptionalRef<kv_cache_manager::KVCacheManager const> crossKvCacheManager = std::nullopt) const;
|
||||
private:
|
||||
std::variant<std::monostate, MaxRequestsScheduler, MaxUtilizationScheduler, GuaranteedNoEvictScheduler,
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <vector>
|
||||
|
||||
using namespace tensorrt_llm::batch_manager::kv_cache_manager;
|
||||
@ -33,33 +34,62 @@ public:
|
||||
// TODO(TRTLLM-1564): Don't use a separate `initialize` function. Ensure eviction policies can't be left in an
// intermediate state between construction and initialization.
|
||||
virtual void initialize(std::vector<BlockPtr>& mAllBlocksById, std::vector<SizeType32> sizes,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt)
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority)
|
||||
= 0;
|
||||
|
||||
/// @brief Get a free block from the specified cache level
|
||||
/// @returns The pointer to the free block, along with whether it can be offloaded
|
||||
virtual std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) = 0;
|
||||
/// @brief Release a block. Prioritize the block for eviction if toFront=true
|
||||
virtual void releaseBlock(BlockPtr block, bool toFront = false) = 0;
|
||||
virtual void releaseBlock(BlockPtr block) = 0;
|
||||
virtual void releaseBlock(BlockPtr block, bool toFront) = 0;
|
||||
/// @brief Get the number of free blocks at the given cache level
|
||||
virtual SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) = 0;
|
||||
/// @brief Claim a free block. Called when the cache manager allocates or reuses a new block
|
||||
virtual void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt) = 0;
|
||||
virtual void claimBlock(BlockPtr block) = 0;
|
||||
virtual void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
|
||||
std::optional<std::chrono::milliseconds> durationMs)
|
||||
= 0;
|
||||
/// @brief Perform any per-iteration bookkeeping
|
||||
virtual void refresh() = 0;
|
||||
};
|
||||
|
||||
struct ExpiringBlockComparator
|
||||
{
|
||||
inline bool operator()(BlockPtr const& a, BlockPtr const& b) const
|
||||
{
|
||||
// If two blocks expire in the same millisecond, their expiration times will be equal. As a fallback, check the
|
||||
// raw pointer values.
|
||||
return a->getExpirationTime() != b->getExpirationTime() ? a->getExpirationTime() < b->getExpirationTime()
|
||||
: a.get() < b.get();
|
||||
}
|
||||
};
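The comparator above is what keeps `mExpiringBlockHeap` ordered: blocks sort by expiration time, and the raw-pointer fallback guarantees that two blocks expiring in the same millisecond are still treated as distinct set elements. Below is a minimal, self-contained sketch of the same ordering rule in a `std::set`, using a simplified stand-in block type rather than the real `KVCacheBlock`:

```cpp
#include <chrono>
#include <iostream>
#include <memory>
#include <set>

// Simplified stand-in for KVCacheBlock: only the expiration time matters here.
struct Block
{
    std::chrono::steady_clock::time_point::duration expirationTime;
    auto getExpirationTime() const { return expirationTime; }
};
using BlockPtr = std::shared_ptr<Block>;

// Same ordering rule as ExpiringBlockComparator: order by expiration time,
// break ties by raw pointer so two distinct blocks never compare equivalent.
struct ExpiringBlockComparator
{
    bool operator()(BlockPtr const& a, BlockPtr const& b) const
    {
        return a->getExpirationTime() != b->getExpirationTime() ? a->getExpirationTime() < b->getExpirationTime()
                                                                : a.get() < b.get();
    }
};

int main()
{
    using namespace std::chrono_literals;
    std::set<BlockPtr, ExpiringBlockComparator> expiring;
    expiring.insert(std::make_shared<Block>(Block{10ms}));
    expiring.insert(std::make_shared<Block>(Block{5ms}));
    expiring.insert(std::make_shared<Block>(Block{5ms})); // expires in the same millisecond, still kept

    // Iteration visits blocks in expiration order; ties fall back to pointer order.
    for (auto const& block : expiring)
    {
        std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(block->getExpirationTime()).count()
                  << " ms\n"; // prints 5, 5, 10
    }
    return 0;
}
```

Without the pointer fallback, the second 5 ms block would be rejected as a duplicate key, which is why the comparator never reports two distinct blocks as equivalent.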
class LRUEvictionPolicy : public BaseEvictionPolicy
|
||||
{
|
||||
public:
|
||||
void initialize(std::vector<BlockPtr>& mAllBlocksById, std::vector<SizeType32> sizes,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt) override;
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority) override;
|
||||
std::tuple<BlockPtr, bool> getFreeBlock(SizeType32 cacheLevel) override;
|
||||
void releaseBlock(BlockPtr block, bool toFront = false) override;
|
||||
|
||||
void releaseBlock(BlockPtr block) override;
|
||||
void releaseBlock(BlockPtr block, bool toFront) override;
|
||||
|
||||
SizeType32 getNumFreeBlocks(SizeType32 cacheLevel) override;
|
||||
void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt) override;
|
||||
|
||||
void claimBlock(BlockPtr block) override;
|
||||
void claimBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority,
|
||||
std::optional<std::chrono::milliseconds> durationMs) override;
|
||||
|
||||
// Check the expiring blocks heap, and move expired blocks back to the default queue.
|
||||
void refresh() override;
|
||||
|
||||
// Making this public and virtual makes it possible to test.
|
||||
[[nodiscard]] virtual std::chrono::steady_clock::time_point::duration getTime() const;
|
||||
|
||||
private:
|
||||
// Check if the block should be added to mFreeQueues.
|
||||
bool isReleasedLeafBlock(BlockPtr block);
|
||||
bool isReleasedLeafBlock(BlockPtr const& block);
|
||||
|
||||
// Queues of available leaf blocks, split by cache level and priority level
|
||||
std::vector<std::vector<FreeBlocksQueue>> mFreeQueues;
|
||||
@ -71,6 +101,8 @@ private:
|
||||
std::vector<SizeType32> mNumFreeBlocksPerLevel;
|
||||
// Secondary offload threshold. Blocks below this priority won't be offloaded to secondary memory.
|
||||
executor::RetentionPriority mSecondaryOffloadMinPriority;
|
||||
// Heap of block times
|
||||
std::set<BlockPtr, ExpiringBlockComparator> mExpiringBlockHeap;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::eviction_policy
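`LRUEvictionPolicy` keeps its reusable leaf blocks in `mFreeQueues`, split by cache level and priority, and `getFreeBlock` reports both the chosen block and whether it may be offloaded. The following is a hedged sketch of that lookup shape with simplified types, where the queue index stands in for the block priority; the real policy's data structures and its exact offload comparison are not reproduced here.

```cpp
#include <cstddef>
#include <cstdint>
#include <deque>
#include <memory>
#include <tuple>
#include <vector>

// Illustrative stand-ins; the real block, queue and priority types live in the
// kv_cache_manager / eviction_policy headers above.
struct Block { int32_t priority; };
using BlockPtr = std::shared_ptr<Block>;
using FreeBlocksQueue = std::deque<BlockPtr>;

// Sketch of the getFreeBlock() idea: walk the free queues of one cache level
// from lowest to highest priority and take the first available block. The
// second tuple element mirrors "can this block be offloaded", i.e. whether its
// priority reaches the secondary-offload threshold (simplified comparison).
std::tuple<BlockPtr, bool> getFreeBlockSketch(
    std::vector<std::vector<FreeBlocksQueue>>& freeQueues, // indexed [cacheLevel][priority]
    int32_t cacheLevel, int32_t secondaryOffloadMinPriority)
{
    for (std::size_t priority = 0; priority < freeQueues[cacheLevel].size(); ++priority)
    {
        auto& queue = freeQueues[cacheLevel][priority];
        if (!queue.empty())
        {
            BlockPtr block = queue.front();
            queue.pop_front();
            bool const canOffload = static_cast<int32_t>(priority) >= secondaryOffloadMinPriority;
            return {block, canOffload};
        }
    }
    return {nullptr, false}; // no free block available at this cache level
}
```

Scanning from the lowest priority upward is what makes low-priority blocks the first candidates to be reused or evicted.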
@ -64,6 +64,7 @@ auto constexpr kPromptEmbeddingTableName = "prompt_embedding_table";
|
||||
auto constexpr kPromptVocabSizeName = "prompt_vocab_size";
|
||||
auto constexpr kLoraTaskId = "lora_task_id";
|
||||
auto constexpr kNoRepeatNgramSizeTensorName = "noRepeatNgramSize";
|
||||
auto constexpr kSkipCrossAttnBlocksTensorName = "skipCrossAttnBlocks";
|
||||
// weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ]
|
||||
// where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer
|
||||
// each of the in / out tensors are first flattened and then concatenated together in the format above.
|
||||
|
||||
@ -43,9 +43,9 @@ public:
|
||||
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false, bool useUvm = false,
|
||||
std::optional<size_t> hostCacheSize = std::nullopt, bool onboardBlocks = true,
|
||||
std::optional<float> crossKvCacheFraction = std::nullopt,
|
||||
std::optional<SizeType32> secondaryOffloadMinPriority = std::nullopt)
|
||||
std::optional<SizeType32> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0)
|
||||
: maxTokens{maxTokens}
|
||||
, maxAttentionWindowVec{maxAttentionWindowVec}
|
||||
, maxAttentionWindowVec{std::move(maxAttentionWindowVec)}
|
||||
, sinkTokenLength{sinkTokenLength}
|
||||
, freeGpuMemoryFraction{freeGpuMemoryFraction}
|
||||
, enableBlockReuse(enableBlockReuse)
|
||||
@ -54,6 +54,7 @@ public:
|
||||
, onboardBlocks(onboardBlocks)
|
||||
, crossKvCacheFraction{crossKvCacheFraction}
|
||||
, secondaryOffloadMinPriority(secondaryOffloadMinPriority)
|
||||
, eventBufferMaxSize(eventBufferMaxSize)
|
||||
{
|
||||
}
|
||||
|
||||
@ -61,7 +62,8 @@ public:
|
||||
: KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindowVec(),
|
||||
kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(),
|
||||
kvCacheConfig.getEnableBlockReuse(), false, kvCacheConfig.getHostCacheSize(),
|
||||
kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction())
|
||||
kvCacheConfig.getOnboardBlocks(), kvCacheConfig.getCrossKvCacheFraction(),
|
||||
kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig.getEventBufferMaxSize())
|
||||
{
|
||||
}
|
||||
|
||||
@ -71,7 +73,9 @@ public:
|
||||
&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
|
||||
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm
|
||||
&& hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks
|
||||
&& crossKvCacheFraction == other.crossKvCacheFraction;
|
||||
&& crossKvCacheFraction == other.crossKvCacheFraction
|
||||
&& secondaryOffloadMinPriority == other.secondaryOffloadMinPriority
|
||||
&& eventBufferMaxSize == other.eventBufferMaxSize;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, KvCacheConfig const& self);
|
||||
@ -89,5 +93,7 @@ public:
|
||||
std::optional<float> crossKvCacheFraction;
|
||||
// The minimum priority level to allow blocks to be offloaded to secondary memory.
|
||||
std::optional<SizeType32> secondaryOffloadMinPriority;
|
||||
// Maximum size of the KV Cache event buffer
|
||||
size_t eventBufferMaxSize;
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
96
cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
Normal file
@ -0,0 +1,96 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <condition_variable>
|
||||
#include <deque>
|
||||
#include <mutex>
|
||||
#include <thread>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
class KVCacheBlock;
|
||||
using BlockPtr = std::shared_ptr<KVCacheBlock>;
|
||||
|
||||
class KVCacheEventManager
|
||||
{
|
||||
public:
|
||||
explicit KVCacheEventManager(size_t maxKVEventEntries);
|
||||
|
||||
~KVCacheEventManager();
|
||||
KVCacheEventManager(KVCacheEventManager& other) = delete;
|
||||
KVCacheEventManager& operator=(KVCacheEventManager& other) = delete;
|
||||
KVCacheEventManager(KVCacheEventManager&& other) = delete;
|
||||
KVCacheEventManager& operator=(KVCacheEventManager&& other) = delete;
|
||||
|
||||
void enqueueCreatedEvent(std::vector<SizeType32> const& numBlocksPerCacheLevel);
|
||||
|
||||
void enqueueStoredEvent(std::vector<BlockPtr> const& blocks);
|
||||
|
||||
void enqueueRemovedEvent(BlockPtr const& block);
|
||||
|
||||
void enqueueUpdatedEvent(executor::KVCacheUpdatedData const& data);
|
||||
|
||||
// Get events in mEvents. If there are no events, wait for a maximum of `timeout` milliseconds.
|
||||
std::deque<executor::KVCacheEvent> getEvents(std::optional<std::chrono::milliseconds> timeout);
|
||||
|
||||
// Clear the event buffer, and asynchronously move events to the event queue.
|
||||
void flush();
|
||||
|
||||
// Worker thread which adds events to mEvents.
|
||||
void worker();
|
||||
|
||||
private:
|
||||
// Add an event to mEventQueue
|
||||
void enqueueEvent(executor::KVCacheEvent&& event);
|
||||
|
||||
/// @brief Flag to terminate the worker
|
||||
bool mRun;
|
||||
/// @brief Worker thread
|
||||
std::thread mWorkerThread;
|
||||
|
||||
/// @brief The deque of events
|
||||
std::deque<executor::KVCacheEvent> mEvents;
|
||||
/// @brief Lock for mEvents
|
||||
std::mutex mEventsMutex;
|
||||
/// @brief Condition variable for blocking read
|
||||
std::condition_variable mEmptyCV;
|
||||
|
||||
/// @brief List of event buffers awaiting insertion into mEvents. Consumed by the worker.
|
||||
std::deque<std::deque<executor::KVCacheEvent>> mPendingEvents;
|
||||
/// @brief Lock for mPendingEvents
|
||||
std::mutex mPendingEventsMutex;
|
||||
/// @brief Condition variable to notify worker thread
|
||||
std::condition_variable mPendingEmptyCV;
|
||||
|
||||
/// @brief Buffer of events for the current iteration, waiting to be flushed. Only ever accessed by the forward pass thread.
|
||||
std::deque<executor::KVCacheEvent> mEventQueue;
|
||||
|
||||
/// @brief The maximum size of the deque
|
||||
size_t mMaxSize;
|
||||
/// @brief An auto-incrementing event id counter
|
||||
size_t mEventId;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
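The declarations above describe a double-buffered handoff: the forward-pass thread appends to `mEventQueue`, `flush()` moves that buffer into `mPendingEvents` and wakes the worker, the worker drains pending buffers into the bounded `mEvents` deque, and `getEvents()` blocks on `mEmptyCV` until events arrive or the timeout expires. Below is a generic, hedged sketch of that pattern with a placeholder event type; it mirrors the member layout above but is not the real implementation.

```cpp
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <thread>
#include <utility>

// Placeholder payload; not executor::KVCacheEvent.
struct Event { std::size_t id; };

class EventBufferSketch
{
public:
    explicit EventBufferSketch(std::size_t maxSize)
        : mMaxSize{maxSize}
        , mWorkerThread{[this] { worker(); }}
    {
    }

    ~EventBufferSketch()
    {
        {
            std::lock_guard<std::mutex> lock(mPendingMutex);
            mRun = false;
        }
        mPendingCV.notify_all();
        mWorkerThread.join();
    }

    // Forward-pass thread: buffer an event for the current iteration.
    void enqueueEvent(Event event) { mEventQueue.push_back(event); }

    // Forward-pass thread: hand the finished iteration buffer to the worker.
    void flush()
    {
        {
            std::lock_guard<std::mutex> lock(mPendingMutex);
            mPendingEvents.push_back(std::exchange(mEventQueue, {}));
        }
        mPendingCV.notify_one();
    }

    // Consumer: wait up to `timeout` for events, then take whatever is available.
    std::deque<Event> getEvents(std::chrono::milliseconds timeout)
    {
        std::unique_lock<std::mutex> lock(mEventsMutex);
        mEmptyCV.wait_for(lock, timeout, [this] { return !mEvents.empty(); });
        return std::exchange(mEvents, {});
    }

private:
    // Worker: drain pending buffers into the bounded mEvents deque.
    void worker()
    {
        while (true)
        {
            std::deque<Event> batch;
            {
                std::unique_lock<std::mutex> lock(mPendingMutex);
                mPendingCV.wait(lock, [this] { return !mRun || !mPendingEvents.empty(); });
                if (!mRun && mPendingEvents.empty())
                {
                    return;
                }
                batch = std::move(mPendingEvents.front());
                mPendingEvents.pop_front();
            }
            std::lock_guard<std::mutex> lock(mEventsMutex);
            for (auto& event : batch)
            {
                if (mEvents.size() >= mMaxSize)
                {
                    mEvents.pop_front(); // bounded buffer: drop the oldest event
                }
                mEvents.push_back(event);
            }
            mEmptyCV.notify_all();
        }
    }

    std::size_t mMaxSize;
    bool mRun{true};
    std::deque<Event> mEventQueue;                // forward-pass thread only
    std::deque<std::deque<Event>> mPendingEvents; // handoff buffers for the worker
    std::mutex mPendingMutex;
    std::condition_variable mPendingCV;
    std::deque<Event> mEvents; // consumed via getEvents()
    std::mutex mEventsMutex;
    std::condition_variable mEmptyCV;
    std::thread mWorkerThread; // declared last so every other member is initialized before it starts
};
```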
@ -17,6 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
|
||||
#include "tensorrt_llm/common/optionalRef.h"
|
||||
#include "tensorrt_llm/kernels/kvCacheIndex.h"
|
||||
@ -40,11 +41,15 @@
|
||||
namespace tensorrt_llm::batch_manager::eviction_policy
|
||||
{
|
||||
class BaseEvictionPolicy;
|
||||
}
|
||||
} // namespace tensorrt_llm::batch_manager::eviction_policy
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
|
||||
static constexpr SizeType32 kPrimaryLevel = 0;
|
||||
|
||||
static constexpr SizeType32 kSecondaryLevel = 1;
|
||||
|
||||
class KVCacheBlock;
|
||||
class KVCacheManager;
|
||||
|
||||
@ -67,6 +72,15 @@ struct BlockKey
|
||||
LoraTaskIdType loraTaskId;
|
||||
VecUniqueTokens uniqueTokens;
|
||||
|
||||
BlockKey() = default;
|
||||
|
||||
explicit BlockKey(bool hasLora, LoraTaskIdType loraTaskId, VecUniqueTokens uniqueTokens)
|
||||
: hasLora{hasLora}
|
||||
, loraTaskId{loraTaskId}
|
||||
, uniqueTokens{std::move(uniqueTokens)}
|
||||
{
|
||||
}
|
||||
|
||||
bool operator==(BlockKey const& other) const noexcept
|
||||
{
|
||||
return (hasLora == other.hasLora && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens);
|
||||
@ -78,9 +92,9 @@ struct BlockKey
|
||||
// Based on https://stackoverflow.com/questions/20511347/a-good-hash-function-for-a-vector/72073933#72073933
|
||||
struct BlockKeyHasher
|
||||
{
|
||||
std::size_t operator()(BlockKey const& blockKey) const noexcept
|
||||
std::size_t operator()(BlockKey const& blockKey, std::size_t parentHash = 0) const noexcept
|
||||
{
|
||||
size_t seed = blockKey.uniqueTokens.size();
|
||||
size_t seed = blockKey.uniqueTokens.size() ^ parentHash;
|
||||
for (auto const& uniqueToken : blockKey.uniqueTokens)
|
||||
{
|
||||
uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
|
||||
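The updated hasher seeds with `uniqueTokens.size() ^ parentHash`, so a block's hash depends on its parent's hash and therefore on the whole token prefix leading to it. A small illustration of that chaining idea over plain token ids follows; the mixing step below is a generic boost-style combine, not the exact constants used in this header.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustration only: chain a child block's hash from its parent's hash, in the
// spirit of BlockKeyHasher's `seed = uniqueTokens.size() ^ parentHash`.
std::size_t hashBlock(std::vector<int32_t> const& tokenIds, std::size_t parentHash)
{
    std::size_t seed = tokenIds.size() ^ parentHash;
    for (int32_t tokenId : tokenIds)
    {
        // Boost-style hash_combine; the upstream header uses its own mixing kernel.
        seed ^= static_cast<std::size_t>(static_cast<uint32_t>(tokenId)) + 0x9e3779b97f4a7c15ull + (seed << 6)
            + (seed >> 2);
    }
    return seed;
}

// Usage: hash a chain of blocks so that each hash depends on the whole prefix.
// std::size_t h0 = hashBlock(firstBlockTokens, /*parentHash=*/0);
// std::size_t h1 = hashBlock(secondBlockTokens, h0);
```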
@ -169,7 +183,9 @@ public:
|
||||
|
||||
[[nodiscard]] bool hasSchedulingRefs() const;
|
||||
|
||||
void setBlockKey(BlockKey& blockKey, bool isFull);
|
||||
void setBlockKey(BlockKey const& blockKey, bool isFull);
|
||||
|
||||
BlockKey getBlockKey();
|
||||
|
||||
[[nodiscard]] VecUniqueTokens const& getUniqueTokens() const;
|
||||
|
||||
@ -192,7 +208,19 @@ public:
|
||||
|
||||
void setPriority(executor::RetentionPriority priority);
|
||||
|
||||
executor::RetentionPriority getPriority() const;
|
||||
[[nodiscard]] executor::RetentionPriority getPriority() const;
|
||||
|
||||
void setDurationMs(std::optional<std::chrono::milliseconds> durationMs);
|
||||
|
||||
[[nodiscard]] std::optional<std::chrono::milliseconds> getDurationMs() const;
|
||||
|
||||
void setExpirationTime(std::optional<std::chrono::steady_clock::time_point::duration> expirationTime);
|
||||
|
||||
[[nodiscard]] std::optional<std::chrono::steady_clock::time_point::duration> getExpirationTime() const;
|
||||
|
||||
void setHash(size_t hash);
|
||||
|
||||
size_t getHash() const;
|
||||
|
||||
private:
|
||||
// Linear ID of block independent of pool
|
||||
@ -225,6 +253,12 @@ private:
|
||||
|
||||
// Priority of the block
|
||||
executor::RetentionPriority mPriority;
|
||||
// Duration that the block's priority level applies for
|
||||
std::optional<std::chrono::milliseconds> mDurationMs;
|
||||
// Expiration time of the block
|
||||
std::optional<std::chrono::steady_clock::time_point::duration> mExpirationTime;
|
||||
// Hash for the event manager
|
||||
size_t mHash;
|
||||
};
|
||||
|
||||
class GenerationRequest
|
||||
@ -234,8 +268,7 @@ public:
|
||||
|
||||
explicit GenerationRequest(LlmRequest::RequestIdType requestId, SizeType32 numTokens, SizeType32 beamWidth,
|
||||
SizeType32 maxBlocks, SizeType32 numPools = 1,
|
||||
executor::RetentionPriority decodeRetentionPriority
|
||||
= executor::KvCacheRetentionConfig::kDefaultRetentionPriority)
|
||||
executor::KvCacheRetentionConfig kvCacheRetentionConfig = executor::KvCacheRetentionConfig())
|
||||
: mRequestId(requestId)
|
||||
, mNumTokens(numTokens)
|
||||
, mBeamWidth(beamWidth)
|
||||
@ -243,7 +276,7 @@ public:
|
||||
, mCacheBlockIndices{runtime::BufferManager::cpu(
|
||||
runtime::ITensor::makeShape({numPools, beamWidth, 2, maxBlocks}),
|
||||
runtime::TRTDataType<tensorrt_llm::kernels::KVCacheIndex>::value)}
|
||||
, mDecodeRetentionPriority(decodeRetentionPriority)
|
||||
, mKvCacheRetentionConfig(std::move(kvCacheRetentionConfig))
|
||||
{
|
||||
auto cacheBlockIdsRange = runtime::BufferRange<tensorrt_llm::kernels::KVCacheIndex>(*mCacheBlockIndices);
|
||||
std::fill(cacheBlockIdsRange.begin(), cacheBlockIdsRange.end(),
|
||||
@ -321,7 +354,12 @@ public:
|
||||
|
||||
[[nodiscard]] executor::RetentionPriority getDecodeRetentionPriority() const
|
||||
{
|
||||
return mDecodeRetentionPriority;
|
||||
return mKvCacheRetentionConfig.getDecodeRetentionPriority();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const
|
||||
{
|
||||
return mKvCacheRetentionConfig.getDecodeDurationMs();
|
||||
}
|
||||
|
||||
private:
|
||||
@ -336,7 +374,7 @@ private:
|
||||
// Tensor of block indices allocated for each beam of the sequence
|
||||
runtime::ITensor::SharedPtr mCacheBlockIndices;
|
||||
// The retention priority to assign to decode blocks
|
||||
executor::RetentionPriority mDecodeRetentionPriority;
|
||||
executor::KvCacheRetentionConfig mKvCacheRetentionConfig;
|
||||
};
|
||||
|
||||
// attach metadata to a pool pointer
|
||||
@ -385,7 +423,8 @@ public:
|
||||
SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool,
|
||||
SizeType32 maxNumSequences, std::shared_ptr<runtime::CudaStream> stream, bool onboardBlocks,
|
||||
CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
|
||||
|
||||
~BlockManager();
|
||||
|
||||
@ -442,6 +481,9 @@ public:
|
||||
return mMissedBlocks;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::deque<executor::KVCacheEvent> getLatestEvents(
|
||||
std::optional<std::chrono::milliseconds> timeout) const;
|
||||
|
||||
[[nodiscard]] bool hasFreeBlocks(SizeType32 numRequired = 1) const noexcept
|
||||
{
|
||||
return getNumFreeBlocks() >= numRequired;
|
||||
@ -522,13 +564,27 @@ public:
|
||||
|
||||
//! \brief Find first new block that must be allocated for context phase and return its concatenated token vector.
|
||||
//! \details Only full blocks are considered.
|
||||
BlockKey findNewContextBlock(VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
|
||||
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
|
||||
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
|
||||
|
||||
[[nodiscard]] runtime::BufferManager const& getBufferManager() const
|
||||
{
|
||||
return mBufferManager;
|
||||
}
|
||||
|
||||
//! \brief Perform per-iteration bookkeeping
|
||||
void refreshBlocks();
|
||||
|
||||
void flushIterationEvents()
|
||||
{
|
||||
if (mEventManager)
|
||||
{
|
||||
mEventManager->flush();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] static bool blockInRadixTree(BlockPtr const& block);
|
||||
|
||||
private:
|
||||
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
|
||||
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType32 beamIdx);
|
||||
@ -539,23 +595,23 @@ private:
|
||||
//! \brief Store blocks in cached blocks.
|
||||
//! \param blockKeys Key of each block.
|
||||
//! \param blockIds Id of each block.
|
||||
//! \param isChunkedContext Whether these blocks are being stored for chunked context.
|
||||
void storeBlocks(std::list<BlockKey> blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds,
|
||||
bool isChunkedContext = false);
|
||||
void storeBlocks(std::vector<BlockKey> blockKeys, std::vector<KVCacheBlock::IdType> const& blockIds);
|
||||
|
||||
//! \brief Try to load blocks from cache. Allocate new blocks if necessary.
|
||||
//! \param blockKeys Key of each block.
|
||||
//! \param sequence Sequence to which blocks are assigned.
|
||||
//! \return Number of matched tokens from loaded blocks.
|
||||
SizeType32 loadOrAllocateBlocks(std::list<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
|
||||
GenerationRequest& sequence, std::vector<std::optional<executor::RetentionPriority>> blockPriorities);
|
||||
SizeType32 loadOrAllocateBlocks(std::vector<BlockKey> const& blockKeys, SizeType32 numContextBlocks,
|
||||
GenerationRequest& sequence, std::vector<executor::RetentionPriorityAndDuration> const& perBlockRetentions);
|
||||
|
||||
//! \brief Find block least likely to be reused, free it if necessary and return.
|
||||
[[nodiscard]] BlockPtr getFreeBlock(
|
||||
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority);
|
||||
executor::RetentionPriority = executor::KvCacheRetentionConfig::kDefaultRetentionPriority,
|
||||
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
|
||||
|
||||
//! \brief Free block from previous block and claim it from free blocks list.
|
||||
void claimLeafBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt);
|
||||
void claimLeafBlock(BlockPtr block, std::optional<executor::RetentionPriority> priority = std::nullopt,
|
||||
std::optional<std::chrono::milliseconds> durationMs = std::nullopt);
|
||||
|
||||
//! \brief Compute pointer to raw KV block (K & V, all layers).
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(
|
||||
@ -598,6 +654,8 @@ private:
|
||||
CacheType mCacheType;
|
||||
// Eviction Policy
|
||||
std::shared_ptr<BaseEvictionPolicy> mEvictionPolicy;
|
||||
// Event manager
|
||||
std::shared_ptr<KVCacheEventManager> mEventManager;
|
||||
|
||||
// Statistics for block allocations/reuse
|
||||
// Total number of blocks allocated by all requests
|
||||
@ -630,14 +688,16 @@ public:
|
||||
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = false, bool onboardBlocks = true,
|
||||
CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
|
||||
|
||||
KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
|
||||
SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences,
|
||||
SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, bool useOneMoreBlock,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = true, bool onboardBlocks = true,
|
||||
CacheType cacheType = CacheType::kSELF,
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
|
||||
std::optional<executor::RetentionPriority> secondaryOffloadMinPriority = std::nullopt,
|
||||
std::shared_ptr<KVCacheEventManager> eventManager = nullptr);
|
||||
|
||||
void allocatePools(nvinfer1::DataType dtype, bool useUvm = false);
|
||||
|
||||
@ -705,6 +765,12 @@ public:
|
||||
return mMaxBlocksPerSeq;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::deque<executor::KVCacheEvent> getLatestEvents(
|
||||
std::optional<std::chrono::milliseconds> timeout = std::nullopt) const
|
||||
{
|
||||
return mBlockManager.getLatestEvents(timeout);
|
||||
}
|
||||
|
||||
[[nodiscard]] BlockManager const& getBlockManager() const
|
||||
{
|
||||
return mBlockManager;
|
||||
@ -790,7 +856,8 @@ public:
|
||||
|
||||
//! \brief Find first new block that must be allocated for context phase and return its concatenated token vector.
|
||||
//! \details Only full blocks are considered.
|
||||
BlockKey findNewContextBlock(VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
|
||||
[[nodiscard]] std::optional<BlockKey> findNewContextBlock(
|
||||
VecUniqueTokens const& uniqueTokens, LlmRequest const& llmRequest) const;
|
||||
|
||||
//! \brief Store full context blocks contributed by llmRequest.
|
||||
//! \details These blocks become reusable from next step.
|
||||
@ -804,6 +871,17 @@ public:
|
||||
//! \brief Get the batch size that can fill the kv cache to the maximum capacity given the sequence length
|
||||
[[nodiscard]] SizeType32 getMaxCapacityBatchSize(SizeType32 seqLen);
|
||||
|
||||
//! \brief Perform per-iteration bookkeeping
|
||||
void refreshBlocks()
|
||||
{
|
||||
mBlockManager.refreshBlocks();
|
||||
}
|
||||
|
||||
void flushIterationEvents()
|
||||
{
|
||||
mBlockManager.flushIterationEvents();
|
||||
}
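`refreshBlocks()` and `flushIterationEvents()` are the per-iteration hooks, while `getLatestEvents()` is the polling side. Below is a hedged sketch of where they sit relative to a serving step; the step itself is elided, and event polling would normally run on a separate thread.

```cpp
#include <chrono>

#include "tensorrt_llm/batch_manager/kvCacheManager.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

// Hedged sketch of the per-iteration hooks declared above.
void afterIterationSketch(kvcm::KVCacheManager& kvCacheManager)
{
    // ... run one scheduling / forward-pass step here (elided) ...

    kvCacheManager.refreshBlocks();        // per-iteration bookkeeping (e.g. block expiry)
    kvCacheManager.flushIterationEvents(); // hand this iteration's KV cache events to the worker

    // A monitoring thread can then poll for the accumulated events:
    auto events = kvCacheManager.getLatestEvents(std::chrono::milliseconds(100));
    (void) events;
}
```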
private:
|
||||
void setOffsets(kernels::KVCacheIndex* offsetsPtr, nvinfer1::Dims const& offsetsShape, SizeType32 beamIdx,
|
||||
SizeType32 blockIdx, KVCacheBlock::IdType blockId) const;
|
||||
|
||||
@ -55,7 +55,6 @@ enum class LlmRequestState : int32_t
|
||||
/// Waiting context-only request transmitting the kv cache
|
||||
kDISAGG_CONTEXT_COMPLETE = 8, ///< Context-only request finished kv cache transmission.
|
||||
kDISAGG_GENERATION_TRANS_IN_PROGRESS = 9, ///< For disaggregated serving only: transmitting the kv cache
|
||||
kWAITING_TO_SEND_LOGITS = 10, ///< Generation phase completed, logits not sent yet
|
||||
};
|
||||
|
||||
enum LlmRequestType
|
||||
@ -110,7 +109,8 @@ public:
|
||||
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
|
||||
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
|
||||
std::optional<std::shared_ptr<VecTokenExtraIds>> inputTokenExtraIds = std::nullopt,
|
||||
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt)
|
||||
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
|
||||
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
|
||||
: mRequestId(requestId)
|
||||
, mPromptLen(inputTokens->size())
|
||||
, mMaxNewTokens(maxNewTokens)
|
||||
@ -159,6 +159,7 @@ public:
|
||||
, mNumReturnSequences(numReturnSequences)
|
||||
, mEagleConfig(eagleConfig)
|
||||
, mSequenceIndex(0)
|
||||
, mSkipCrossAttnBlocks(std::move(skipCrossAttnBlocks))
|
||||
{
|
||||
if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
|
||||
{
|
||||
@ -337,6 +338,16 @@ public:
|
||||
mCrossAttentionMask = std::nullopt;
|
||||
}
|
||||
|
||||
auto const& skipCrossAttnBlocks = req.getSkipCrossAttnBlocks();
|
||||
if (skipCrossAttnBlocks.has_value())
|
||||
{
|
||||
mSkipCrossAttnBlocks = executor::detail::toITensor(skipCrossAttnBlocks.value());
|
||||
}
|
||||
else
|
||||
{
|
||||
mSkipCrossAttnBlocks = std::nullopt;
|
||||
}
|
||||
|
||||
switch (req.getRequestType())
|
||||
{
|
||||
case executor::RequestType::REQUEST_TYPE_CONTEXT_AND_GENERATION:
|
||||
@ -1071,6 +1082,11 @@ public:
|
||||
return mCrossAttentionMask.value_or(nullptr);
|
||||
}
|
||||
|
||||
[[nodiscard]] TensorPtr const getSkipCrossAttnBlocks() const
|
||||
{
|
||||
return mSkipCrossAttnBlocks.value_or(nullptr);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isStreaming() const noexcept
|
||||
{
|
||||
return mIsStreaming;
|
||||
@ -1235,11 +1251,6 @@ public:
|
||||
return mState == LlmRequestState::kDISAGG_CONTEXT_COMPLETE;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool isCompleteWaitingToSendLogits() const noexcept
|
||||
{
|
||||
return mState == LlmRequestState::kWAITING_TO_SEND_LOGITS;
|
||||
}
|
||||
|
||||
/// Determines whether the context is unchunked. Once a context has been chunked, even into a single part, it
/// differs from the unchunked state, which indicates the initial status.
|
||||
[[nodiscard]] bool isFullContextRequest() const noexcept
|
||||
@ -1342,7 +1353,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool isFinished() const noexcept
|
||||
{
|
||||
return isGenerationCompleteState() || isDisaggContextTransmissionState() || isCompleteWaitingToSendLogits();
|
||||
return isGenerationCompleteState() || isDisaggContextTransmissionState();
|
||||
}
|
||||
|
||||
/// @brief Create a Response from the current state of the request
|
||||
@ -1665,6 +1676,8 @@ protected:
|
||||
SizeType32 mReusedBlocksPerRequest{0};
|
||||
SizeType32 mMissedBlocksPerRequest{0};
|
||||
|
||||
std::optional<TensorPtr> mSkipCrossAttnBlocks;
|
||||
|
||||
private:
|
||||
void initialize(VecTokens const& inputTokens, bool outputLogProbs)
|
||||
{
|
||||
@ -1823,7 +1836,8 @@ public:
|
||||
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
|
||||
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
|
||||
std::optional<std::shared_ptr<VecTokenExtraIds>> inputTokenExtraIds = std::nullopt,
|
||||
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt)
|
||||
SizeType32 numReturnSequences = 1, std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
|
||||
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
|
||||
std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
|
||||
std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
|
||||
@ -1832,7 +1846,7 @@ public:
|
||||
std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, std::move(encoderInputTokens),
|
||||
returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures), std::move(encoderOutputLength),
|
||||
std::move(crossAttentionMask), llmRequestType, std::move(inputTokenExtraIds), numReturnSequences,
|
||||
std::move(eagleConfig))
|
||||
std::move(eagleConfig), std::move(skipCrossAttnBlocks))
|
||||
{
|
||||
}
|
||||
|
||||
@ -1857,7 +1871,9 @@ public:
|
||||
std::optional<SizeType32> encoderOutputLength = std::nullopt,
|
||||
std::optional<TensorPtr> crossAttentionMask = std::nullopt,
|
||||
LlmRequestType llmRequestType = LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION,
|
||||
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1)
|
||||
std::optional<VecTokenExtraIds> inputTokenExtraIds = std::nullopt, SizeType32 numReturnSequences = 1,
|
||||
std::optional<executor::EagleConfig> eagleConfig = std::nullopt,
|
||||
std::optional<TensorPtr> skipCrossAttnBlocks = std::nullopt)
|
||||
: Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
|
||||
samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
|
||||
std::move(stopWordsList),
|
||||
@ -1875,7 +1891,7 @@ public:
|
||||
llmRequestType,
|
||||
inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
|
||||
: std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
|
||||
numReturnSequences)
|
||||
numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks)
|
||||
{
|
||||
}
|
||||
|
||||
|
||||
@ -50,68 +50,29 @@ public:
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
using ContextChunkingPolicy = tensorrt_llm::executor::ContextChunkingPolicy;
|
||||
|
||||
MicroBatchScheduler() = default;
|
||||
|
||||
explicit MicroBatchScheduler(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
|
||||
explicit MicroBatchScheduler(std::optional<SizeType32> maxNumTokens = std::nullopt,
|
||||
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
|
||||
std::optional<SizeType32> maxContextLength = std::nullopt,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE);
|
||||
|
||||
static MicroBatchScheduler make(SizeType32 maxBatchSize, std::optional<SizeType32> maxNumTokens = std::nullopt,
|
||||
std::optional<batch_scheduler::ContextChunkingConfig> ctxChunkConfig = std::nullopt,
|
||||
std::optional<SizeType32> maxContextLength = std::nullopt,
|
||||
LlmRequestState noScheduleUntilState = LlmRequestState::kCONTEXT_INIT,
|
||||
LlmRequestState noScheduleAfterState = LlmRequestState::kGENERATION_COMPLETE)
|
||||
{
|
||||
return MicroBatchScheduler{
|
||||
maxBatchSize, maxNumTokens, ctxChunkConfig, maxContextLength, noScheduleUntilState, noScheduleAfterState};
|
||||
}
|
||||
|
||||
std::tuple<RequestVector, RequestVector> operator()(
|
||||
RequestVector const& activeRequests, ReqIdsSet const& inflightReqIds);
|
||||
RequestVector& activeRequests, ReqIdsSet const& inflightReqIds, SizeType32 maxBatchSizeRuntime) const;
|
||||
|
||||
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
|
||||
static void setCtxRequestsChunkSize(RequestVector& contextsToBeChunked, ContextChunkingPolicy ctxChunkPolicy,
|
||||
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
|
||||
std::optional<SizeType32> const& maxContextLength);
|
||||
|
||||
void setRuntimeMaxBatchSize(SizeType32 runtimeMaxBatchSize);
|
||||
|
||||
SizeType32 getMaxBatchSizeStatic() const
|
||||
{
|
||||
return mMaxBatchSize;
|
||||
}
|
||||
|
||||
SizeType32 getMaxBatchSizeTunerRecommended() const
|
||||
{
|
||||
return mMaxBatchSizeTunerRecommended;
|
||||
}
|
||||
|
||||
SizeType32 getMaxBatchSizeRuntime() const
|
||||
{
|
||||
return mMaxBatchSizeRuntime;
|
||||
}
|
||||
|
||||
private:
|
||||
template <ContextChunkingPolicy tPolicy>
|
||||
static void setCtxRequestsChunkSize(RequestVector const& contextsToBeChunked,
|
||||
std::optional<SizeType32> ctxTokensCapacity, SizeType32 chunkUnitSize,
|
||||
std::optional<SizeType32> const& maxContextLength);
|
||||
static void setCtxRequestsChunkSize(RequestVector& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
|
||||
SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);
|
||||
|
||||
/// After the chunk sizes have been determined, this function will discard
|
||||
/// any draft tokens that don't fit.
|
||||
static void fitDraftTokens(RequestVector const& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
|
||||
static void fitDraftTokens(RequestVector& contextsToBeChunked, std::optional<SizeType32> ctxTokensCapacity,
|
||||
SizeType32 chunkUnitSize, std::optional<SizeType32> const& maxContextLength);
|
||||
|
||||
/// The maximum number of requests returned by scheduleRequests
|
||||
SizeType32 mMaxBatchSize;
|
||||
|
||||
/// The max batch size recommended by the dynamic tuner
|
||||
SizeType32 mMaxBatchSizeTunerRecommended;
|
||||
|
||||
/// The min of mMaxBatchSize and mMaxBatchSizeTunerRecommended
|
||||
SizeType32 mMaxBatchSizeRuntime;
|
||||
|
||||
/// The maximum number of tokens to include in a batch
|
||||
std::optional<SizeType32> mMaxNumTokens;
|
||||
|
||||
|
||||
74
cpp/include/tensorrt_llm/batch_manager/pauseRequests.h
Normal file
@ -0,0 +1,74 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
|
||||
#include "tensorrt_llm/batch_manager/peftCacheManager.h"
|
||||
#include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
|
||||
#include "tensorrt_llm/common/algorithm.h"
|
||||
#include "tensorrt_llm/common/optionalRef.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
class BasePeftCacheManager;
|
||||
class LlmRequest;
|
||||
|
||||
namespace kv_cache_manager
|
||||
{
|
||||
|
||||
class KVCacheManager;
|
||||
|
||||
}
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
namespace tensorrt_llm::batch_manager
|
||||
{
|
||||
|
||||
namespace tle = tensorrt_llm::executor;
|
||||
|
||||
class PauseRequests : Algorithm
|
||||
{
|
||||
using KVCacheManager = kv_cache_manager::KVCacheManager;
|
||||
|
||||
template <typename T>
|
||||
using OptionalRef = common::OptionalRef<T>;
|
||||
|
||||
public:
|
||||
constexpr static auto name{"PauseRequests"};
|
||||
|
||||
using SizeType32 = tensorrt_llm::runtime::SizeType32;
|
||||
|
||||
PauseRequests(SizeType32 maxInputLen)
|
||||
: mMaxInputLen(maxInputLen)
|
||||
{
|
||||
}
|
||||
|
||||
void operator()(RequestVector& requestsToPause, ReqIdsSet& inflightReqIds, ReqIdsSet& reqIdsToPause,
|
||||
bool pauseFlagged, SequenceSlotManager& seqSlotManager,
|
||||
OptionalRef<KVCacheManager> kvCacheManager = std::nullopt,
|
||||
OptionalRef<KVCacheManager> crossKvCacheManager = std::nullopt,
|
||||
OptionalRef<BasePeftCacheManager> peftCacheManager = std::nullopt) const;
|
||||
|
||||
private:
|
||||
SizeType32 mMaxInputLen;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
@ -17,6 +17,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/common/cudaBf16Wrapper.h"
|
||||
#include "tensorrt_llm/common/cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/cudaFp8Utils.h"
|
||||
#include "tensorrt_llm/common/logger.h"
|
||||
#include "tensorrt_llm/common/tllmException.h"
|
||||
@ -24,7 +25,9 @@
|
||||
#include <cinttypes>
|
||||
#include <cublasLt.h>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <driver_types.h>
|
||||
#include <fstream>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
@ -318,9 +321,13 @@ inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
|
||||
{
|
||||
if (useUvm)
|
||||
{
|
||||
size_t freeSysMem, totalSysMem;
|
||||
size_t freeSysMem = 0;
|
||||
size_t totalSysMem = 0;
|
||||
#ifndef _WIN32 // Linux
|
||||
struct sysinfo info;
|
||||
struct sysinfo info
|
||||
{
|
||||
};
|
||||
|
||||
sysinfo(&info);
|
||||
totalSysMem = info.totalram * info.mem_unit;
|
||||
freeSysMem = info.freeram * info.mem_unit;
|
||||
@ -336,20 +343,38 @@ inline std::tuple<size_t, size_t> getDeviceMemoryInfo(bool const useUvm)
|
||||
((double) totalSysMem / 1e9), ((double) freeSysMem / 1e9));
|
||||
return {freeSysMem, totalSysMem};
|
||||
}
|
||||
else
|
||||
{
|
||||
size_t free, total;
|
||||
check_cuda_error(cudaMemGetInfo(&free, &total));
|
||||
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
|
||||
((double) total / 1e9), ((double) free / 1e9));
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
size_t free = 0;
|
||||
size_t total = 0;
|
||||
check_cuda_error(cudaMemGetInfo(&free, &total));
|
||||
TLLM_LOG_DEBUG("Using GPU memory for KV cache, total memory %0.2f GB, available memory %0.2f GB",
|
||||
((double) total / 1e9), ((double) free / 1e9));
|
||||
return {free, total};
|
||||
}
|
||||
|
||||
/// @brief Gets the memory allocation granularity for the current device.
///
/// @return size_t The minimum allocation granularity, in bytes, supported by the current device.
|
||||
inline size_t getAllocationGranularity()
|
||||
{
|
||||
auto const currentDevice = getDevice();
|
||||
::CUmemAllocationProp prop = {};
|
||||
|
||||
prop.type = ::CU_MEM_ALLOCATION_TYPE_PINNED;
|
||||
prop.location.type = ::CU_MEM_LOCATION_TYPE_DEVICE;
|
||||
prop.location.id = currentDevice;
|
||||
prop.requestedHandleTypes = ::CU_MEM_HANDLE_TYPE_NONE;
|
||||
|
||||
// Get the minimum granularity supported for allocation with cuMemCreate()
|
||||
size_t granularity = 0;
|
||||
TLLM_CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));
|
||||
return granularity;
|
||||
}
|
||||
|
||||
inline int getMultiProcessorCount()
|
||||
{
|
||||
int device_id;
|
||||
int multi_processor_count;
|
||||
int device_id = 0;
|
||||
int multi_processor_count = 0;
|
||||
check_cuda_error(cudaGetDevice(&device_id));
|
||||
check_cuda_error(cudaDeviceGetAttribute(&multi_processor_count, cudaDevAttrMultiProcessorCount, device_id));
|
||||
return multi_processor_count;
|
||||
@ -357,8 +382,8 @@ inline int getMultiProcessorCount()
|
||||
|
||||
inline int getMaxSharedMemoryPerBlockOptin()
|
||||
{
|
||||
int device_id;
|
||||
int max_shared_memory_per_block;
|
||||
int device_id = 0;
|
||||
int max_shared_memory_per_block = 0;
|
||||
check_cuda_error(cudaGetDevice(&device_id));
|
||||
check_cuda_error(
|
||||
cudaDeviceGetAttribute(&max_shared_memory_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device_id));
|
||||
@ -368,8 +393,8 @@ inline int getMaxSharedMemoryPerBlockOptin()
|
||||
template <typename T1, typename T2>
|
||||
inline size_t divUp(const T1& a, const T2& n)
|
||||
{
|
||||
size_t tmp_a = static_cast<size_t>(a);
|
||||
size_t tmp_n = static_cast<size_t>(n);
|
||||
auto const tmp_a = static_cast<size_t>(a);
|
||||
auto const tmp_n = static_cast<size_t>(n);
|
||||
return (tmp_a + tmp_n - 1) / tmp_n;
|
||||
}
|
||||
|
||||
|
||||
@ -1,12 +1,11 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
* Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
@ -97,6 +97,11 @@ public:
|
||||
return QuantMode(BaseType(1u) << 3 | BaseType(1u) << 4 | BaseType(1u) << 9);
|
||||
}
|
||||
|
||||
static constexpr QuantMode w4a8QServe() noexcept
|
||||
{
|
||||
return QuantMode(BaseType(1u) << 10);
|
||||
}
|
||||
|
||||
constexpr BaseType value() const noexcept
|
||||
{
|
||||
return mValue;
|
||||
@ -169,7 +174,8 @@ public:
|
||||
|
||||
static constexpr QuantMode fromDescription(bool quantizeWeights = false, bool quantizeActivations = false,
|
||||
bool perToken = false, bool perChannel = false, bool perGroup = false, bool useInt4Weights = false,
|
||||
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false)
|
||||
bool useInt8KvCache = false, bool useFp8KvCache = false, bool useFp8Qdq = false, bool useFp8RowWise = false,
|
||||
bool useW4a8QServe = false)
|
||||
{
|
||||
QuantMode quantMode{};
|
||||
if (quantizeWeights)
|
||||
@ -218,6 +224,11 @@ public:
|
||||
quantMode += fp8RowWise();
|
||||
}
|
||||
|
||||
if (useW4a8QServe)
|
||||
{
|
||||
quantMode += w4a8QServe();
|
||||
}
|
||||
|
||||
return quantMode;
|
||||
}
|
||||
|
||||
@ -226,12 +237,17 @@ public:
|
||||
return fromDescription(true, true, perToken, perChannel);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useQServe(bool perGroup)
|
||||
{
|
||||
return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true);
|
||||
}
|
||||
|
||||
static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false)
|
||||
{
|
||||
return fromDescription(true, false, false, false, perGroup, useInt4Weights);
|
||||
}
|
||||
|
||||
static const QuantMode fromQuantAlgo(
|
||||
static QuantMode const fromQuantAlgo(
|
||||
std::optional<std::string> quantAlgo = std::nullopt, std::optional<std::string> kvCacheQuantAlgo = std::nullopt)
|
||||
{
|
||||
QuantMode quantMode{};
|
||||
@ -251,6 +267,14 @@ public:
|
||||
{
|
||||
quantMode = useWeightOnly(true, true);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_QSERVE_PER_GROUP")
|
||||
{
|
||||
quantMode = useQServe(false);
|
||||
}
|
||||
else if (quantAlgo == "W4A8_QSERVE_PER_CHANNEL")
|
||||
{
|
||||
quantMode = useQServe(true);
|
||||
}
|
||||
else if (quantAlgo == "W4A16_GPTQ")
|
||||
{
|
||||
quantMode = useWeightOnly(true, true);
|
||||
|
||||
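The new QServe entries extend the same bit-flag scheme: `w4a8QServe()` contributes bit 10, `fromDescription` folds it in, and `fromQuantAlgo` maps the `W4A8_QSERVE_*` strings onto it. A hedged usage sketch follows; the header path and the `tensorrt_llm::common` namespace are assumptions here, and only the member functions come from the declarations above.

```cpp
#include <string>

#include "tensorrt_llm/common/quantization.h" // assumed location of QuantMode

using tensorrt_llm::common::QuantMode; // assumed namespace

// Hedged sketch: resolve a quantization algorithm string to a QuantMode and
// test the new QServe bit via value() and w4a8QServe(), both declared above.
bool usesQServe(std::string const& quantAlgo)
{
    QuantMode const mode = QuantMode::fromQuantAlgo(quantAlgo);
    return (mode.value() & QuantMode::w4a8QServe().value()) != 0;
}

// e.g. usesQServe("W4A8_QSERVE_PER_GROUP") is expected to return true.
```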
@ -18,6 +18,8 @@
|
||||
|
||||
#include "tensorrt_llm/executor/tensor.h"
|
||||
#include "tensorrt_llm/executor/types.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/runtimeDefaults.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cstdint>
|
||||
@ -29,12 +31,18 @@
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <utility>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::mpi
|
||||
{
|
||||
class MpiComm;
|
||||
}
|
||||
} // namespace tensorrt_llm::mpi
|
||||
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
class KVCacheManager;
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
{
|
||||
@ -355,7 +363,7 @@ private:
|
||||
class SpeculativeDecodingConfig
|
||||
{
|
||||
public:
|
||||
explicit SpeculativeDecodingConfig(bool fastLogits);
|
||||
explicit SpeculativeDecodingConfig(bool fastLogits = false);
|
||||
|
||||
bool operator==(SpeculativeDecodingConfig const& other) const;
|
||||
|
||||
@ -365,6 +373,20 @@ public:
|
||||
|
||||
using RetentionPriority = SizeType32;
|
||||
|
||||
struct RetentionPriorityAndDuration
|
||||
{
|
||||
|
||||
RetentionPriorityAndDuration(std::optional<RetentionPriority> const& retentionPriority,
|
||||
std::optional<std::chrono::milliseconds> const& durationMs)
|
||||
: retentionPriority{retentionPriority}
|
||||
, durationMs{durationMs}
|
||||
{
|
||||
}
|
||||
|
||||
std::optional<RetentionPriority> retentionPriority;
|
||||
std::optional<std::chrono::milliseconds> durationMs;
|
||||
};
|
||||
|
||||
/// @brief Configuration for the request's retention in the KV Cache
|
||||
class KvCacheRetentionConfig
|
||||
{
|
||||
@ -376,14 +398,16 @@ public:
|
||||
|
||||
/// @brief A single entry to set block priorities over a token range. Earlier ranges always take priority over later
|
||||
/// ones. For example, with a block size of 16, a range of [0, 17] would be applied to the first two blocks.
|
||||
struct TokenRangeRetentionPriority
|
||||
struct TokenRangeRetentionConfig
|
||||
{
|
||||
public:
|
||||
explicit TokenRangeRetentionPriority(SizeType32 tokenStart, std::optional<SizeType32> tokenEnd = std::nullopt,
|
||||
RetentionPriority priority = KvCacheRetentionConfig::kDefaultRetentionPriority)
|
||||
explicit TokenRangeRetentionConfig(SizeType32 tokenStart, std::optional<SizeType32> tokenEnd = std::nullopt,
|
||||
RetentionPriority priority = KvCacheRetentionConfig::kDefaultRetentionPriority,
|
||||
std::optional<std::chrono::milliseconds> durationMs = std::nullopt)
|
||||
: tokenStart{tokenStart}
|
||||
, tokenEnd{tokenEnd}
|
||||
, priority{priority}
|
||||
, durationMs{durationMs}
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(priority >= KvCacheRetentionConfig::kMinRetentionPriority
|
||||
&& priority <= KvCacheRetentionConfig::kMaxRetentionPriority,
|
||||
@ -398,10 +422,15 @@ public:
|
||||
std::optional<SizeType32> tokenEnd;
|
||||
/// @brief The priority of this token range. Higher priorities are less likely to be evicted or offloaded.
|
||||
RetentionPriority priority;
|
||||
/// @brief The duration in ms that the block should remain at the given priority level. Set to std::nullopt to
|
||||
/// have no expiration time, and keep the block at the given priority level until it gets reclaimed. After the
|
||||
/// duration has passed, the block will be moved back to the `kDefaultRetentionPriority` level.
|
||||
std::optional<std::chrono::milliseconds> durationMs;
|
||||
|
||||
bool operator==(TokenRangeRetentionPriority const& other) const
|
||||
bool operator==(TokenRangeRetentionConfig const& other) const
|
||||
{
|
||||
return tokenStart == other.tokenStart && tokenEnd == other.tokenEnd && priority == other.priority;
|
||||
return tokenStart == other.tokenStart && tokenEnd == other.tokenEnd && priority == other.priority
|
||||
&& durationMs == other.durationMs;
|
||||
}
|
||||
};
|
||||
|
||||
@ -410,22 +439,28 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
KvCacheRetentionConfig(std::vector<TokenRangeRetentionPriority> const& tokenRangeRetentionPriorities,
|
||||
RetentionPriority decodeRetentionPriority);
|
||||
explicit KvCacheRetentionConfig(std::vector<TokenRangeRetentionConfig> const& tokenRangeRetentionPriorities,
|
||||
RetentionPriority decodeRetentionPriority = kDefaultRetentionPriority,
|
||||
std::optional<std::chrono::milliseconds> decodeDurationMs = std::nullopt);
|
||||
|
||||
[[nodiscard]] std::vector<TokenRangeRetentionPriority> getTokenRangeRetentionPriorities() const;
|
||||
[[nodiscard]] std::vector<TokenRangeRetentionConfig> getTokenRangeRetentionConfigs() const;
|
||||
[[nodiscard]] RetentionPriority getDecodeRetentionPriority() const;
|
||||
[[nodiscard]] std::optional<std::chrono::milliseconds> getDecodeDurationMs() const;
|
||||
|
||||
/// @brief Convert the token range data into an entry per kv cache block for a given seqLen
|
||||
std::vector<std::optional<RetentionPriority>> getPerBlockEvictionPolicy(SizeType32 blockSize, SizeType32 seqLen);
|
||||
/// @brief Convert the token range data into an entry per kv block. Returns a tuple of vectors corresponding to the
|
||||
/// priorities and durations for each block.
|
||||
[[nodiscard]] std::vector<RetentionPriorityAndDuration> getPerBlockRetentionPriorityDuration(
|
||||
SizeType32 blockSize, SizeType32 seqLen) const;
|
||||
|
||||
private:
|
||||
/// @brief The token ranges and priority levels to update. Ranges must be non-overlapping. For example [(0, 64),
|
||||
/// (100, 128), (70, 80)] is valid, whereas
|
||||
/// [(0, 64), (60, 128)] is not.
|
||||
std::vector<TokenRangeRetentionPriority> mTokenRangeRetentionPriorities;
|
||||
std::vector<TokenRangeRetentionConfig> mTokenRangeRetentionConfigs;
|
||||
/// @brief The priority level to assign to blocks allocated in the decode phase
|
||||
RetentionPriority mDecodeRetentionPriority;
|
||||
/// @brief The duration in ms that decode blocks should remain at their assigned priority level.
|
||||
std::optional<std::chrono::milliseconds> mDecodeDurationMs;
|
||||
};
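Putting the pieces together, a request can pin a token range at an elevated priority for a limited time and give decode blocks their own priority. A hedged usage sketch follows; the numeric priorities and the duration are illustrative only and must fall inside the allowed retention-priority range.

```cpp
#include <chrono>
#include <optional>
#include <vector>

#include "tensorrt_llm/executor/executor.h"

using tensorrt_llm::executor::KvCacheRetentionConfig;

// Hedged usage sketch: only the constructor shapes come from the declarations
// above; the priority values (80, 40) and the 30 s duration are illustrative.
KvCacheRetentionConfig makeRetentionConfigSketch()
{
    using TokenRange = KvCacheRetentionConfig::TokenRangeRetentionConfig;
    std::vector<TokenRange> ranges{
        // Keep the first 64 tokens (e.g. a shared system prompt) at priority 80 for 30 seconds.
        TokenRange(/*tokenStart=*/0, /*tokenEnd=*/64, /*priority=*/80, std::chrono::milliseconds(30'000)),
    };
    // Decode blocks get priority 40 with no expiration.
    return KvCacheRetentionConfig(ranges, /*decodeRetentionPriority=*/40, std::nullopt);
}
```

Per the comments above, the 0-64 token range maps onto the leading blocks of the sequence, and once the 30-second duration elapses those blocks drop back to the default retention priority.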
/// @brief A class that holds information about the request
|
||||
@ -466,6 +501,7 @@ public:
|
||||
/// @param contextPhaseParams Generated token ID from context only executor.
|
||||
/// @param numReturnSequences The number of returning sequences.
|
||||
/// @param eagleConfig The EAGLE speculative decoding configuration
|
||||
/// @param skipCrossAttnBlocks Skip the cross attention transformer blocks or not.
|
||||
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
|
||||
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
|
||||
std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
|
||||
@ -486,7 +522,8 @@ public:
|
||||
std::optional<Tensor> encoderInputFeatures = std::nullopt,
|
||||
std::optional<SizeType32> encoderOutputLength = std::nullopt,
|
||||
std::optional<Tensor> crossAttentionMask = std::nullopt, SizeType32 numReturnSequences = 1,
|
||||
std::optional<EagleConfig> eagleConfig = std::nullopt);
|
||||
std::optional<EagleConfig> eagleConfig = std::nullopt,
|
||||
std::optional<Tensor> skipCrossAttnBlocks = std::nullopt);
|
||||
|
||||
/// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
|
||||
static auto constexpr kBatchedPostProcessorName = "batched";
|
||||
@ -526,6 +563,7 @@ public:
|
||||
[[nodiscard]] RequestType getRequestType() const;
|
||||
[[nodiscard]] SizeType32 getNumReturnSequences() const;
|
||||
[[nodiscard]] std::optional<EagleConfig> getEagleConfig() const;
|
||||
[[nodiscard]] std::optional<Tensor> getSkipCrossAttnBlocks() const;
|
||||
|
||||
void setStreaming(bool streaming);
|
||||
void setSamplingConfig(SamplingConfig const& config);
|
||||
@ -553,6 +591,7 @@ public:
|
||||
void setCrossAttentionMask(Tensor crossAttentionMask);
|
||||
void setNumReturnSequences(SizeType32 numReturnSequences);
|
||||
void setEagleConfig(std::optional<EagleConfig> const& eagleConfig);
|
||||
void setSkipCrossAttnBlocks(Tensor skipCrossAttnBlocks);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -568,6 +607,9 @@ struct SpeculativeDecodingFastLogitsInfo
|
||||
|
||||
/// @brief MPI world rank of the draft model leader
|
||||
int32_t draftParticipantId;
|
||||
|
||||
/// @brief Returns the struct serialized into a tensor that can be used as generation logits input
|
||||
[[nodiscard]] Tensor toTensor() const;
|
||||
};
|
||||
|
||||
/// @brief Struct that holds the generation result
|
||||
@ -735,7 +777,8 @@ public:
|
||||
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
|
||||
std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
|
||||
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
|
||||
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt);
|
||||
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
|
||||
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
|
||||
|
||||
[[nodiscard]] bool getEnableBlockReuse() const;
|
||||
[[nodiscard]] std::optional<SizeType32> getMaxTokens() const;
|
||||
@ -746,6 +789,7 @@ public:
|
||||
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
|
||||
[[nodiscard]] bool getOnboardBlocks() const;
|
||||
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
|
||||
[[nodiscard]] size_t getEventBufferMaxSize() const;
|
||||
|
||||
void setEnableBlockReuse(bool enableBlockReuse);
|
||||
void setMaxTokens(SizeType32 maxTokens);
|
||||
@ -756,6 +800,8 @@ public:
|
||||
void setHostCacheSize(size_t hostCacheSize);
|
||||
void setOnboardBlocks(bool onboardBlocks);
|
||||
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
|
||||
void setEventBufferMaxSize(size_t eventBufferMaxSize);
|
||||
void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults runtimeDefaults);
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
@ -797,6 +843,9 @@ private:
|
||||
|
||||
/// @brief Only blocks with priority > mSecondaryOffloadMinPriority can be offloaded to secondary memory.
|
||||
std::optional<RetentionPriority> mSecondaryOffloadMinPriority;
|
||||
|
||||
/// @brief Max size of the KV cache event buffer
|
||||
size_t mEventBufferMaxSize;
|
||||
};
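The event buffer is enabled purely through configuration. Below is a hedged sketch of wiring up the new fields via the setters listed above; the values are illustrative, and the class is assumed to be `tensorrt_llm::executor::KvCacheConfig` as the surrounding hunks suggest.

```cpp
#include "tensorrt_llm/executor/executor.h"

// Hedged sketch: enable the KV cache event buffer and restrict secondary
// offloading through the setters declared above; the numeric values are
// illustrative only.
void configureKvCacheEventsSketch(tensorrt_llm::executor::KvCacheConfig& kvCacheConfig)
{
    kvCacheConfig.setEventBufferMaxSize(32768);       // retain up to 32768 pending events
    kvCacheConfig.setSecondaryOffloadMinPriority(50); // only higher-priority blocks may be offloaded
    kvCacheConfig.setOnboardBlocks(true);             // onboard secondary blocks before reuse
}
```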
/// @brief Configuration class for the runtime perf knobs
|
||||
@ -1197,6 +1246,113 @@ private:
|
||||
std::optional<SpeculativeDecodingConfig> mSpeculativeDecodingConfig;
|
||||
};
|
||||
|
||||
struct KVCacheCreatedData
{
    /// @brief The amount of blocks at each cache level
    std::vector<SizeType32> numBlocksPerCacheLevel;
};

/// @brief An entry for a single block stored into the tree
struct KVCacheStoredBlockData
{
    KVCacheStoredBlockData(IdType blockHash, tensorrt_llm::runtime::VecUniqueTokens const& tokens,
        tensorrt_llm::runtime::LoraTaskIdType loraId, SizeType32 cacheLevel, SizeType32 priority)
        : blockHash{blockHash}
        , tokens{tokens}
        , loraId{loraId}
        , cacheLevel{cacheLevel}
        , priority{priority}
    {
    }

    /// @brief The hash of the block
    IdType blockHash;
    /// @brief The unique tokens of the block
    tensorrt_llm::runtime::VecUniqueTokens tokens;
    /// @brief The Lora task id of the block
    tensorrt_llm::runtime::LoraTaskIdType loraId;
    /// @brief The cache level of the block
    SizeType32 cacheLevel;
    /// @brief The priority of the block
    SizeType32 priority;
};

struct KVCacheStoredData
{
    /// @brief The parent of this sequence of stored blocks
    std::optional<IdType> parentHash;
    /// @brief A sequence of blocks. The parent of block `i` is block `i-1`
    std::vector<KVCacheStoredBlockData> blocks;
};

struct KVCacheRemovedData
{
    /// @brief The hashes of blocks being removed
    std::vector<IdType> blockHashes;
};

template <typename T>
struct KVCacheEventDiff
{
    T oldValue;
    T newValue;
};

struct KVCacheUpdatedData
{
    explicit KVCacheUpdatedData(IdType blockHash)
        : blockHash{blockHash} {};

    KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
    {
        cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};
        return *this;
    }

    KVCacheUpdatedData& priorityUpdated(SizeType32 oldValue, SizeType32 newValue)
    {
        priority = KVCacheEventDiff<SizeType32>{oldValue, newValue};
        return *this;
    }

    /// @brief The hash of the updated block
    IdType blockHash;
    /// @brief The updated value of the cacheLevel field
    std::optional<KVCacheEventDiff<SizeType32>> cacheLevel = std::nullopt;
    /// @brief The updated value of the priority field
    std::optional<KVCacheEventDiff<SizeType32>> priority = std::nullopt;
};

using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVCacheRemovedData, KVCacheUpdatedData>;

struct KVCacheEvent
{
    KVCacheEvent(IdType eventId, KVCacheEventData data);

    /// @brief The unique id of this event
    IdType eventId;
    /// @brief The data corresponding to this event
    KVCacheEventData data;
};

/// @brief Exposes a limited set of KV cache manager functionalities
class KVCacheEventManager
{
public:
    KVCacheEventManager(std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager> kvCacheManager);

    /// @brief Get the latest KV Cache events.
    /// @param timeout The maximum time to wait for new events. If nullopt, will only return when new events are
    /// available, or when the executor instance has shutdown.
    std::deque<KVCacheEvent> getLatestEvents(std::optional<std::chrono::milliseconds> timeout = std::nullopt);

private:
    std::shared_ptr<tensorrt_llm::batch_manager::kv_cache_manager::KVCacheManager> kvCacheManager;
};

/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
class Executor
{

@ -1300,6 +1456,8 @@ public:

    /// @brief Indicates if the current process participates in this executor instance
    [[nodiscard]] bool isParticipant() const;

    std::optional<std::shared_ptr<KVCacheEventManager>> getKVCacheEventManager() const;

private:
    class Impl;
    std::unique_ptr<Impl> mImpl;

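The declarations above define the new KV cache event API: an event buffer on the cache manager, a tagged variant of event payloads, and an accessor on the executor. A minimal polling sketch follows, assuming the types live in namespace tensorrt_llm::executor and that an Executor instance is already constructed with a non-zero eventBufferMaxSize; the header path and the 100 ms timeout are illustrative assumptions, not taken from this commit.

#include <chrono>
#include <type_traits>
#include <variant>
// #include "tensorrt_llm/executor/executor.h"  // assumed header for the declarations above

namespace texec = tensorrt_llm::executor; // assumed namespace alias

void drainKvCacheEvents(texec::Executor& executor)
{
    // Returns std::optional<std::shared_ptr<KVCacheEventManager>>; empty when events are unavailable
    // (e.g. the event buffer was left at its default size of 0).
    auto manager = executor.getKVCacheEventManager();
    if (!manager)
    {
        return;
    }

    // Wait up to 100 ms for new events (nullopt would block until events arrive or shutdown).
    auto events = (*manager)->getLatestEvents(std::chrono::milliseconds(100));
    for (auto const& event : events)
    {
        std::visit(
            [](auto const& data)
            {
                using T = std::decay_t<decltype(data)>;
                if constexpr (std::is_same_v<T, texec::KVCacheStoredData>)
                {
                    // data.blocks is ordered: the parent of block i is block i-1.
                }
                // KVCacheCreatedData, KVCacheRemovedData and KVCacheUpdatedData are handled analogously.
            },
            event.data);
    }
}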
@ -164,12 +164,11 @@ public:
|
||||
static void serialize(KvCacheRetentionConfig const& kvCacheRetentionConfig, std::ostream& os);
|
||||
static size_t serializedSize(KvCacheRetentionConfig const& kvCacheRetentionConfig);
|
||||
|
||||
// TokenRangeRetentionPriority
|
||||
static KvCacheRetentionConfig::TokenRangeRetentionPriority deserializeTokenRangeRetentionPriority(std::istream& is);
|
||||
// TokenRangeRetentionConfig
|
||||
static KvCacheRetentionConfig::TokenRangeRetentionConfig deserializeTokenRangeRetentionConfig(std::istream& is);
|
||||
static void serialize(
|
||||
KvCacheRetentionConfig::TokenRangeRetentionPriority const& tokenRangeRetentionPriority, std::ostream& os);
|
||||
static size_t serializedSize(
|
||||
KvCacheRetentionConfig::TokenRangeRetentionPriority const& tokenRangeRetentionPriority);
|
||||
KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig, std::ostream& os);
|
||||
static size_t serializedSize(KvCacheRetentionConfig::TokenRangeRetentionConfig const& tokenRangeRetentionConfig);
|
||||
|
||||
// DecodingConfig
|
||||
static DecodingConfig deserializeDecodingConfig(std::istream& is);
|
||||
|
||||
@ -58,6 +58,8 @@ public:
|
||||
TensorPtr draftPaths;
|
||||
//! [maxBatchSize] or [numGenSequences]
|
||||
TensorPtr specDecodingGenerationLengths;
|
||||
//! [maxBatchSize] or [numGenSequences]
|
||||
TensorPtr specDecodingGenerationLengthsHost;
|
||||
//! [maxBatchSize, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
//! or [numGenSequences, maxDecodingTokens, ceil(maxDecodingTokens / 32)]
|
||||
TensorPtr specDecodingPackedMasks;
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/modelConfig.h"
|
||||
#include "tensorrt_llm/runtime/runtimeDefaults.h"
|
||||
#include "tensorrt_llm/runtime/worldConfig.h"
|
||||
|
||||
#include <filesystem>
|
||||
@ -32,7 +33,8 @@ class GptJsonConfig
|
||||
{
|
||||
public:
|
||||
GptJsonConfig(std::string name, std::string version, std::string precision, SizeType32 tensorParallelism,
|
||||
SizeType32 pipelineParallelism, SizeType32 gpusPerNode, ModelConfig modelConfig)
|
||||
SizeType32 pipelineParallelism, SizeType32 gpusPerNode, ModelConfig modelConfig,
|
||||
std::optional<RuntimeDefaults> runtimeDefaults = std::nullopt)
|
||||
: mName(std::move(name))
|
||||
, mVersion(std::move(version))
|
||||
, mPrecision(std::move(precision))
|
||||
@ -40,6 +42,7 @@ public:
|
||||
, mPipelineParallelism{pipelineParallelism}
|
||||
, mGpusPerNode{gpusPerNode}
|
||||
, mModelConfig(std::move(modelConfig))
|
||||
, mRuntimeDefaults(std::move(runtimeDefaults))
|
||||
{
|
||||
}
|
||||
|
||||
@ -94,6 +97,11 @@ public:
|
||||
return mTensorParallelism * mPipelineParallelism;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<RuntimeDefaults> getRuntimeDefaults() const
|
||||
{
|
||||
return mRuntimeDefaults;
|
||||
}
|
||||
|
||||
[[nodiscard]] std::string engineFilename(WorldConfig const& worldConfig, std::string const& model) const;
|
||||
|
||||
[[nodiscard]] std::string engineFilename(WorldConfig const& worldConfig) const
|
||||
@ -109,6 +117,7 @@ private:
|
||||
SizeType32 const mPipelineParallelism;
|
||||
SizeType32 const mGpusPerNode;
|
||||
ModelConfig mModelConfig; // remove const qualifier because config has to mutable after json parsing
|
||||
std::optional<RuntimeDefaults> mRuntimeDefaults;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -71,4 +71,6 @@ public:
|
||||
std::vector<runtime::IpcMemory> mIpcMemoryHandles;
|
||||
};
|
||||
|
||||
void lamportInitializeAll(void* buffer_0, void* buffer_1, void* buffer_2, size_t size);
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -46,8 +46,8 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
/// We use mc_sim_7b_63 from official Medusa implementation, i.e. one of the best trees with 63 nodes found for 7B
|
||||
/// Vicuna model.
|
||||
// We use mc_sim_7b_63 from official Medusa implementation, i.e. one of the best trees with 63 nodes found for 7B
|
||||
// Vicuna model.
|
||||
// We use it as default, if no other are trees are specified on the server level.
|
||||
MedusaChoices mDefaultMedusaChoices = {{0}, {0, 0}, {1}, {0, 1}, {2}, {0, 0, 0}, {1, 0}, {0, 2}, {3}, {0, 3}, {4},
|
||||
{0, 4}, {2, 0}, {0, 5}, {0, 0, 1}, {5}, {0, 6}, {6}, {0, 7}, {0, 1, 0}, {1, 1}, {7}, {0, 8}, {0, 0, 2}, {3, 0},
|
||||
|
||||
@ -137,6 +137,7 @@ public:
|
||||
, mLogitsDtype(nvinfer1::DataType::kFLOAT)
|
||||
, mUseShapeInference(true)
|
||||
, mManageWeightsType(ManageWeightsType::kDisabled)
|
||||
, mSkipCrossAttnBlocks(false)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(mNbLayers >= mNbAttentionLayers + mNbRnnLayers,
|
||||
"Number of layers (%d) expected to be >= number of attention (%d) + number of rnn layers (%d)", mNbLayers,
|
||||
@ -760,6 +761,16 @@ public:
|
||||
return sumLocalHeads;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr skipCrossAttnBlocks() const noexcept
|
||||
{
|
||||
return mSkipCrossAttnBlocks;
|
||||
}
|
||||
|
||||
void constexpr setSkipCrossAttnBlocks(bool skipCrossAttnBlocks) noexcept
|
||||
{
|
||||
mSkipCrossAttnBlocks = skipCrossAttnBlocks;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType32 mVocabSize;
|
||||
SizeType32 mNbLayers;
|
||||
@ -821,6 +832,7 @@ private:
|
||||
std::string mModelName;
|
||||
std::vector<SizeType32> mNumKvHeadsPerAttentionLayer;
|
||||
std::vector<SizeType32> mNumKvHeadsPerCrossAttentionLayer;
|
||||
bool mSkipCrossAttnBlocks;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
cpp/include/tensorrt_llm/runtime/runtimeDefaults.h (new file, 38 lines)

@ -0,0 +1,38 @@

/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm::runtime
{
struct RuntimeDefaults
{
    RuntimeDefaults(
        std::optional<std::vector<SizeType32>> maxAttentionWindowVec, std::optional<SizeType32> sinkTokenLength)
        : maxAttentionWindowVec(maxAttentionWindowVec)
        , sinkTokenLength(sinkTokenLength)
    {
    }

    std::optional<std::vector<SizeType32>> maxAttentionWindowVec;
    std::optional<SizeType32> sinkTokenLength;

    RuntimeDefaults() = default;
};

} // namespace tensorrt_llm::runtime

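RuntimeDefaults is the carrier for engine-level defaults that GptJsonConfig now exposes and that KvCacheConfig can absorb via fillEmptyFieldsFromRuntimeDefaults (declared earlier in this commit). A small sketch of that flow, assuming KvCacheConfig lives in the executor namespace and using an illustrative 4096-token attention window; the function name applyDefaultsSketch and the kvCacheConfig argument are hypothetical.

#include <optional>
#include <vector>
// #include "tensorrt_llm/runtime/runtimeDefaults.h"  // the new header above

void applyDefaultsSketch(tensorrt_llm::executor::KvCacheConfig& kvCacheConfig)
{
    using tensorrt_llm::runtime::RuntimeDefaults;
    using tensorrt_llm::runtime::SizeType32;

    // Engine-level defaults: one 4096-token attention window, no sink token length specified.
    RuntimeDefaults defaults(std::vector<SizeType32>{4096}, std::nullopt);

    // Fields the user left unset are presumably filled from the defaults; explicit settings are kept.
    kvCacheConfig.fillEmptyFieldsFromRuntimeDefaults(defaults);
}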
@ -87,7 +87,7 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr updatesPositionIds() const
|
||||
{
|
||||
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens | kEagle);
|
||||
return anyBitSet(kLookaheadDecoding | kExplicitDraftTokens);
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr requiresAttentionMask() const
|
||||
|
||||
@ -23,10 +23,6 @@
|
||||
|
||||
namespace tensorrt_llm::runtime::utils
|
||||
{
|
||||
|
||||
static SizeType32 constexpr PREFIX_CHUNK_SIZE_BITS = 4;
|
||||
static SizeType32 constexpr PREFIX_MAX_VALUE = 16;
|
||||
|
||||
struct TreeNode
|
||||
{
|
||||
SizeType32 nodeId;
|
||||
@ -36,11 +32,9 @@ struct TreeNode
|
||||
std::vector<SizeType32> childLinearIndices;
|
||||
};
|
||||
|
||||
void initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
|
||||
SizeType32 initTensorsFromChoices(SpeculativeDecodingModule const& speculativeDecodingModule,
|
||||
std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32>& topKs,
|
||||
ITensor::SharedPtr generationInputLengths, ITensor::SharedPtr positionOffsets, ITensor::SharedPtr treeIds,
|
||||
ITensor::SharedPtr paths, ITensor::SharedPtr packedMask);
|
||||
|
||||
void dumpChoices(std::vector<std::vector<SizeType32>> const& choices, std::vector<SizeType32> const& indices);
|
||||
|
||||
} // namespace tensorrt_llm::runtime::utils
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:24555ce900cefa5eb24441c503332c4997d8f1e695e29bfe72eef76eb01d4406
|
||||
size 5389730
|
||||
oid sha256:748a53a5f70813f0ddb5bb54a56cd07a4b9146917c12ec34504dc4384b00610b
|
||||
size 5882210
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:33f9d6f50e53218bd935226fa364d28eed82b5624e606c08a9e51a63b5b2e15d
|
||||
size 5507018
|
||||
oid sha256:2350b7f07b5f30179ebf24f6e103dc17d4a656c95c171eaca684529120ca245a
|
||||
size 6001974
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:94319899700c8ff1bba3f9f3df4b3cde190a2da2d676e9c2af71f281e99e6cf8
|
||||
size 1986712
|
||||
oid sha256:8b28f05452036c1722a37ac625921cf4902cfb6c04fb01b9d958b9f40ff9be0b
|
||||
size 1958384
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
279b2521d189ac03d35a7466330ea425 libtensorrt_llm_ucx_wrapper.so
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
0066a5a67ec747f565158bbbc398cca9 libtensorrt_llm_ucx_wrapper.so
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:bf6dfef7af51dc17f06f70010c0a9197ff107c98042a010eabdc1c5a9931abbe
|
||||
size 5239294
|
||||
oid sha256:0132b1d4544101465ac37993ae20324c0c49ae978b0a3c8c95a03a08a17b5b36
|
||||
size 5692876
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e7943e6af2196982ca90fb8c7fe813dc5dbd6dbd6839cd5706c27044ed3272cf
|
||||
size 5202544
|
||||
oid sha256:15ff5d0aeae4d3e776fdf3bb68af0cc5896b14f435b66a11fecc2111668fd089
|
||||
size 5659602
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
1598761c1df1fd35b2180b599ad34f58 libtensorrt_llm_ucx_wrapper.so
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c4ebe7311b7423dd5f9ec822586743c9bcdfe088a9259647c8772d005ec64f79
|
||||
size 34643904
|
||||
oid sha256:f975b781b240c8489a48243a94dfdf0be6bfe6b862cf6ec6cbeacd5c66fae7af
|
||||
size 36139148
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
f7660c6225ba8f9a9bce7a06365c3a60 tensorrt_llm_batch_manager_static.lib
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
f9557afc965818430dcae14ae7542adf tensorrt_llm_batch_manager_static.lib
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -30,12 +30,10 @@
|
||||
|
||||
#include "cudaDriverWrapper.h"
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <stdio.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
@ -47,22 +45,21 @@ std::shared_ptr<CUDADriverWrapper> CUDADriverWrapper::getInstance()
|
||||
{
|
||||
return result;
|
||||
}
|
||||
else
|
||||
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
result = instance.lock();
|
||||
if (!result)
|
||||
{
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
result = std::shared_ptr<CUDADriverWrapper>(new CUDADriverWrapper());
|
||||
instance = result;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
CUDADriverWrapper::CUDADriverWrapper()
|
||||
: handle(dllOpen(CUDA_LIB_NAME))
|
||||
{
|
||||
handle = dllOpen(CUDA_LIB_NAME);
|
||||
|
||||
TLLM_CHECK_WITH_INFO(handle != nullptr, "CUDA driver library is not open correctly.");
|
||||
|
||||
auto load_sym = [](void* handle, char const* name)
|
||||
@ -71,21 +68,22 @@ CUDADriverWrapper::CUDADriverWrapper()
|
||||
return ret;
|
||||
};
|
||||
|
||||
*(void**) (&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
|
||||
*(void**) (&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
|
||||
*(void**) (&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
|
||||
*(void**) (&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
|
||||
*(void**) (&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
|
||||
*(void**) (&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*(void**) (&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*(void**) (&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*(void**) (&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*(void**) (&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*(void**) (&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*(void**) (&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*(void**) (&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*(void**) (&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*(void**) (&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
*reinterpret_cast<void**>(&_cuGetErrorName) = load_sym(handle, "cuGetErrorName");
|
||||
*reinterpret_cast<void**>(&_cuGetErrorMessage) = load_sym(handle, "cuGetErrorMessage");
|
||||
*reinterpret_cast<void**>(&_cuFuncSetAttribute) = load_sym(handle, "cuFuncSetAttribute");
|
||||
*reinterpret_cast<void**>(&_cuLinkComplete) = load_sym(handle, "cuLinkComplete");
|
||||
*reinterpret_cast<void**>(&_cuModuleUnload) = load_sym(handle, "cuModuleUnload");
|
||||
*reinterpret_cast<void**>(&_cuLinkDestroy) = load_sym(handle, "cuLinkDestroy");
|
||||
*reinterpret_cast<void**>(&_cuModuleLoadData) = load_sym(handle, "cuModuleLoadData");
|
||||
*reinterpret_cast<void**>(&_cuLinkCreate) = load_sym(handle, "cuLinkCreate_v2");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetFunction) = load_sym(handle, "cuModuleGetFunction");
|
||||
*reinterpret_cast<void**>(&_cuModuleGetGlobal) = load_sym(handle, "cuModuleGetGlobal_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddFile) = load_sym(handle, "cuLinkAddFile_v2");
|
||||
*reinterpret_cast<void**>(&_cuLinkAddData) = load_sym(handle, "cuLinkAddData_v2");
|
||||
*reinterpret_cast<void**>(&_cuLaunchCooperativeKernel) = load_sym(handle, "cuLaunchCooperativeKernel");
|
||||
*reinterpret_cast<void**>(&_cuLaunchKernel) = load_sym(handle, "cuLaunchKernel");
|
||||
*reinterpret_cast<void**>(&_cuTensorMapEncodeTiled) = load_sym(handle, "cuTensorMapEncodeTiled");
|
||||
*reinterpret_cast<void**>(&_cuMemcpyDtoH) = load_sym(handle, "cuMemcpyDtoH_v2");
|
||||
}
|
||||
|
||||
CUDADriverWrapper::~CUDADriverWrapper()
|
||||
@ -98,6 +96,11 @@ CUresult CUDADriverWrapper::cuGetErrorName(CUresult error, char const** pStr) co
|
||||
return (*_cuGetErrorName)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuGetErrorMessage(CUresult error, char const** pStr) const
|
||||
{
|
||||
return (*_cuGetErrorMessage)(error, pStr);
|
||||
}
|
||||
|
||||
CUresult CUDADriverWrapper::cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const
|
||||
{
|
||||
return (*_cuFuncSetAttribute)(hfunc, attrib, value);
|
||||
@ -181,5 +184,4 @@ CUresult CUDADriverWrapper::cuMemcpyDtoH(void* dstHost, CUdeviceptr srcDevice, s
|
||||
return (*_cuMemcpyDtoH)(dstHost, srcDevice, ByteCount);
|
||||
}
|
||||
|
||||
} // namespace common
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -17,33 +17,30 @@
|
||||
#ifndef CUDA_DRIVER_WRAPPER_H
|
||||
#define CUDA_DRIVER_WRAPPER_H
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include <cstdio>
|
||||
#include <cuda.h>
|
||||
#include <memory>
|
||||
#include <mutex>
|
||||
|
||||
#define cuErrCheck(stat, wrap) \
|
||||
{ \
|
||||
cuErrCheck_((stat), wrap.get(), __FILE__, __LINE__); \
|
||||
}
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace common
|
||||
namespace tensorrt_llm::common
|
||||
{
|
||||
|
||||
class CUDADriverWrapper
|
||||
{
|
||||
// Use getInstance() instead.
|
||||
CUDADriverWrapper();
|
||||
|
||||
public:
|
||||
static std::shared_ptr<CUDADriverWrapper> getInstance();
|
||||
|
||||
~CUDADriverWrapper();
|
||||
CUDADriverWrapper(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper const&) = delete;
|
||||
CUDADriverWrapper(CUDADriverWrapper&&) = delete;
|
||||
CUDADriverWrapper operator=(CUDADriverWrapper&&) = delete;
|
||||
|
||||
CUresult cuGetErrorName(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuGetErrorMessage(CUresult error, char const** pStr) const;
|
||||
|
||||
CUresult cuFuncSetAttribute(CUfunction hfunc, CUfunction_attribute attrib, int value) const;
|
||||
|
||||
CUresult cuLinkComplete(CUlinkState state, void** cubinOut, size_t* sizeOut) const;
|
||||
@ -84,7 +81,10 @@ public:
|
||||
|
||||
private:
|
||||
void* handle;
|
||||
CUDADriverWrapper();
|
||||
|
||||
CUresult (*_cuGetErrorName)(CUresult, char const**);
|
||||
CUresult (*_cuGetErrorMessage)(CUresult, char const**);
|
||||
CUresult (*_cuFuncSetAttribute)(CUfunction, CUfunction_attribute, int);
|
||||
CUresult (*_cuLinkComplete)(CUlinkState, void**, size_t*);
|
||||
CUresult (*_cuModuleUnload)(CUmodule);
|
||||
@ -108,17 +108,31 @@ private:

    CUresult (*_cuMemcpyDtoH)(void* dstHost, CUdeviceptr srcDevice, size_t ByteCount);
};

inline void cuErrCheck_(CUresult stat, CUDADriverWrapper const* wrap, char const* file, int line)
template <typename T>
void checkDriver(
    T result, CUDADriverWrapper const& wrap, char const* const func, char const* const file, int const line)
{
    if (stat != CUDA_SUCCESS)
    if (result)
    {
        char const* msg = nullptr;
        wrap->cuGetErrorName(stat, &msg);
        fprintf(stderr, "CUDA Error: %s %s %d\n", msg, file, line);
        char const* errorName = nullptr;
        char const* errorMsg = nullptr;
        wrap.cuGetErrorName(result, &errorName);
        wrap.cuGetErrorMessage(result, &errorMsg);
        throw TllmException(
            file, line, fmtstr("[TensorRT-LLM][ERROR] CUDA driver error in %s: %s: %s", func, errorName, errorMsg));
    }
}

} // namespace common
} // namespace tensorrt_llm
} // namespace tensorrt_llm::common

/*
 * Macros compliant with TensorRT coding conventions
 */
#define TLLM_CU_CHECK(stat)                                                         \
    do                                                                              \
    {                                                                               \
        tensorrt_llm::common::checkDriver(                                          \
            (stat), *tensorrt_llm::common::CUDADriverWrapper::getInstance(), #stat, __FILE__, __LINE__); \
    } while (0)

#endif // CUDA_DRIVER_WRAPPER_H

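The rest of the commit replaces the old cuErrCheck(stat, driver) pattern with this TLLM_CU_CHECK macro, which throws a TllmException carrying both the error name and the new cuGetErrorMessage text instead of printing to stderr. A minimal before/after sketch, assuming the wrapper header path shown above and using "kernel_mha" only because that function name appears in the XQA kernel tables later in this diff:

#include <cuda.h>
// #include "tensorrt_llm/common/cudaDriverWrapper.h"  // assumed header for the wrapper above

void loadCubinSketch(void const* cubinData)
{
    auto driver = tensorrt_llm::common::CUDADriverWrapper::getInstance();

    CUmodule module{nullptr};
    // Before: cuErrCheck(driver->cuModuleLoadData(&module, cubinData), driver);
    // After: a failing driver call throws instead of only logging.
    TLLM_CU_CHECK(driver->cuModuleLoadData(&module, cubinData));

    CUfunction func{nullptr};
    TLLM_CU_CHECK(driver->cuModuleGetFunction(&func, module, "kernel_mha"));

    TLLM_CU_CHECK(driver->cuModuleUnload(module));
}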
cpp/tensorrt_llm/common/jsonSerializeOptional.h (new file, 54 lines)
@ -0,0 +1,54 @@
|
||||
/*
|
||||
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: Apache-2.0
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <optional>
|
||||
|
||||
namespace nlohmann
|
||||
{
|
||||
|
||||
template <typename T>
|
||||
struct adl_serializer<std::optional<T>>
|
||||
{
|
||||
static void to_json(nlohmann::json& j, std::optional<T> const& opt)
|
||||
{
|
||||
if (opt == std::nullopt)
|
||||
{
|
||||
j = nullptr;
|
||||
}
|
||||
else
|
||||
{
|
||||
j = opt.value(); // this will call adl_serializer<T>::to_json which will
|
||||
// find the free function to_json in T's namespace!
|
||||
}
|
||||
}
|
||||
|
||||
static void from_json(nlohmann::json const& j, std::optional<T>& opt)
|
||||
{
|
||||
if (j.is_null())
|
||||
{
|
||||
opt = std::nullopt;
|
||||
}
|
||||
else
|
||||
{
|
||||
opt = j.template get<T>(); // same as above, but with
|
||||
// adl_serializer<T>::from_json
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace nlohmann
|
||||
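With the adl_serializer specialization above, std::optional fields round-trip through nlohmann::json as null. A short usage sketch; the Sample struct and the "max_tokens" key are made up for illustration and are not part of this commit.

#include <nlohmann/json.hpp>
#include <optional>
// #include "tensorrt_llm/common/jsonSerializeOptional.h"  // the new header above

struct Sample
{
    std::optional<int> maxTokens;
};

void to_json(nlohmann::json& j, Sample const& s)
{
    j = nlohmann::json{{"max_tokens", s.maxTokens}}; // std::nullopt serializes as JSON null
}

void from_json(nlohmann::json const& j, Sample& s)
{
    s.maxTokens = j.at("max_tokens").get<std::optional<int>>(); // JSON null parses back to std::nullopt
}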
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8e633c968cb2712a79886a940a07a13f598543c2e936912d9099ef088a240d7c
|
||||
size 2358334
|
||||
oid sha256:33f66dba2f3024d979e38cf1aae4d10802c5a1fb0f4c801108c35824339eae5d
|
||||
size 2419566
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:775cba2bfd47b779944fb2e7ce275d002fba5b5cdefabb409f8a48cc77f157f9
|
||||
size 2391240
|
||||
oid sha256:d224780476ce5f398f30ffbfa0d61bbd0aae5cb1538c8d4c0a16cdf8945ba5d3
|
||||
size 2449532
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
b4c02793e0859c9a1cde80851235ac8b libtensorrt_llm_executor_static.a
|
||||
773778beb42d127b4322a92290c497cb libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
ee532edbf35321d4ac0aadf8a3c6a3a5 libtensorrt_llm_executor_static.a
|
||||
0bf468a19d4c353dcf421fc3e05a9d7d libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:7155a2b04357cf1b5ea2cb46a49dd536d18943f4bce57fc8e3761aa11d4df943
|
||||
size 3440434
|
||||
oid sha256:9b21e2488bdb5c1e18e7aa129acb18087d031eea4f5b063910081ca09a3041a5
|
||||
size 3494984
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ce7093244a223954c718dc3a60a50e452057e378f79f58971989b30e6d858feb
|
||||
size 3357394
|
||||
oid sha256:94964aa02020e38e869bf9ca18385ae379c8b9d1819ad02e10b23d8175cc9d82
|
||||
size 3412104
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
09d25336d48b4281a27dc0e88e154ed8 libtensorrt_llm_executor_static.a
|
||||
549f092db03a48c3be836f90b0938347 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
ba01eba908f38eb582c22c1f822cfedf libtensorrt_llm_executor_static.a
|
||||
ffe68ec0af94d364ec8db50a24ae0e8c libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:06f2682b105be0fc48687afc88480a56f3037e3c885d1a042bc2d6065fc59436
|
||||
size 22719724
|
||||
oid sha256:67f59341edab284c309d39f2a0ad39e91f8afe198c4cf6ba838ae7adb54ad01d
|
||||
size 23192460
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
45c167501f2f191bf2a3aa6e9d80ce9a tensorrt_llm_executor_static.lib
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
e3cd49147c73b0066dcb759df9556191 tensorrt_llm_executor_static.lib
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -30,9 +30,7 @@
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -85,19 +83,17 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin));
|
||||
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
|
||||
}
|
||||
|
||||
FusedMultiHeadAttentionKernelInfo funcInfo;
|
||||
funcInfo.mMetaInfoIndex = i;
|
||||
cuErrCheck(
|
||||
mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName));
|
||||
if (kernelMeta.mSharedMemBytes >= 48 * 1024)
|
||||
{
|
||||
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, kernelMeta.mSharedMemBytes));
|
||||
}
|
||||
mFunctions.insert(std::make_pair(hashID(kernelMeta), funcInfo));
|
||||
int s = static_cast<int>(kernelMeta.mS);
|
||||
@ -120,9 +116,8 @@ public:
|
||||
const CUfunction func = findIter->second.mDeviceFunction;
|
||||
|
||||
void* kernelParams[] = {¶ms, nullptr};
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
|
||||
}
|
||||
|
||||
virtual bool checkIfKernelExist(MHARunnerFixedParams params) const = 0;
|
||||
@ -276,9 +271,8 @@ public:
|
||||
|
||||
if (!forceUnroll)
|
||||
{
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, 1, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
|
||||
} // forceunroll = true for flash attention kernels
|
||||
else if (mSM == kSM_90 && launch_params.flash_attention && launch_params.warp_specialization)
|
||||
{
|
||||
@ -327,9 +321,8 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, block_size.x, block_size.y, block_size.z,
|
||||
kernelMeta.mThreadsPerCTA, 1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, block_size.x, block_size.y, block_size.z,
|
||||
kernelMeta.mThreadsPerCTA, 1, 1, kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
|
||||
}
|
||||
else
|
||||
{ // forceunroll = true for flash attention kernels
|
||||
@ -344,9 +337,8 @@ public:
|
||||
// on Hopper non-flash-attention, we still launch blocks (h, b, steps)
|
||||
if (mSM == kSM_90 && !launch_params.flash_attention)
|
||||
{
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, params.h, params.b, unroll, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
|
||||
} // on Ampere/Ada/Volta flash attention, we launch blocks (steps, h, b)
|
||||
else
|
||||
{
|
||||
@ -356,9 +348,8 @@ public:
|
||||
// For cases exceeding 256 dimensions, the number of CTAs needs to be multiplied.
|
||||
unroll *= (params.dv + 256 - 1) / 256;
|
||||
}
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, unroll, params.h, params.b, kernelMeta.mThreadsPerCTA, 1, 1,
|
||||
kernelMeta.mSharedMemBytes, stream, kernelParams, nullptr));
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -452,5 +443,4 @@ inline FusedMultiHeadAttentionXMMAKernelV2 const* getXMMAKernelsV2(Data_type typ
|
||||
sMhaKernelMetaInfosV2, sizeof(sMhaKernelMetaInfosV2) / sizeof(sMhaKernelMetaInfosV2[0]), type, sm);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -882,7 +882,7 @@ static __global__ void __launch_bounds__(1024, 1) one_shot_all_reduce_norm_kerne
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool is_lamport_supported(int token_num)
|
||||
bool is_lamport_supported(int token_num, int hidden_size)
|
||||
{
|
||||
static char* disableLamportReduceNormFusionChar = std::getenv("DISABLE_LAMPORT_REDUCE_NORM_FUSION");
|
||||
bool disableLamportReduceNormFusion = (disableLamportReduceNormFusionChar != nullptr);
|
||||
@ -901,17 +901,21 @@ bool is_lamport_supported(int token_num)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if (hidden_size < details::kLamportHiddenSizeThreshold)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num)
|
||||
bool is_lamport_supported(nvinfer1::DataType dataType, int token_num, int hidden_size)
|
||||
{
|
||||
switch (dataType)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num);
|
||||
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num);
|
||||
case nvinfer1::DataType::kFLOAT: return is_lamport_supported<float>(token_num, hidden_size);
|
||||
case nvinfer1::DataType::kHALF: return is_lamport_supported<half>(token_num, hidden_size);
|
||||
#ifdef ENABLE_BF16
|
||||
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num);
|
||||
case nvinfer1::DataType::kBF16: return is_lamport_supported<__nv_bfloat16>(token_num, hidden_size);
|
||||
#endif
|
||||
default: return false;
|
||||
}
|
||||
@ -921,7 +925,7 @@ template <typename T, int RanksPerNode, bool Bias, bool Affine>
|
||||
void one_shot_all_reduce_norm_kernel_launcher(AllReduceParams& params, cudaStream_t stream)
|
||||
{
|
||||
int token_num = params.elts_total / params.fusion_params.hidden_size;
|
||||
if (is_lamport_supported<T>(token_num))
|
||||
if (is_lamport_supported<T>(token_num, params.fusion_params.hidden_size))
|
||||
{
|
||||
lamport_style_one_shot_all_reduce_norm_kernel_launcher<T, RanksPerNode, Bias, Affine>(params, stream);
|
||||
}
|
||||
@ -1568,12 +1572,13 @@ void AllReduceDispatchType(AllReduceParams& params, AllReduceStrategyType strat,
|
||||
}
|
||||
}
|
||||
|
||||
AllReduceParams AllReduceParams::deserialize(
|
||||
int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType, int token_num, AllReduceFusionOp op)
|
||||
AllReduceParams AllReduceParams::deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
|
||||
int token_num, int hidden_size, AllReduceFusionOp op)
|
||||
{
|
||||
void* const* buffer_ptrs = reinterpret_cast<void* const*>(buffer);
|
||||
int flag_offset;
|
||||
if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM && reduce_fusion::is_lamport_supported(dataType, token_num))
|
||||
if (op == AllReduceFusionOp::RESIDUAL_RMS_NORM
|
||||
&& reduce_fusion::is_lamport_supported(dataType, token_num, hidden_size))
|
||||
{
|
||||
flag_offset = 0;
|
||||
}
|
||||
|
||||
@ -38,6 +38,7 @@ static constexpr int kWarpSize = 32;
|
||||
static constexpr int kMaxCtaSize = 1024;
|
||||
static constexpr int kClusterMaxSize = 8;
|
||||
static constexpr int kLamportTokenNumThreshold = 16;
|
||||
static constexpr int kLamportHiddenSizeThreshold = 256;
|
||||
}; // namespace reduce_fusion::details
|
||||
|
||||
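The hunks above add a hidden-size gate to the Lamport-style fused allreduce+norm path: is_lamport_supported now also receives hidden_size and bails out below the new kLamportHiddenSizeThreshold. A sketch of the combined check under the constants and environment variable visible in this diff; the token-count condition is not fully visible in the hunk, so it is elided here.

#include <cstdlib>

namespace sketch
{
constexpr int kLamportTokenNumThreshold = 16;   // from reduce_fusion::details above
constexpr int kLamportHiddenSizeThreshold = 256; // new constant added in this change

bool isLamportSupported(int tokenNum, int hiddenSize)
{
    // Runtime opt-out, mirroring the DISABLE_LAMPORT_REDUCE_NORM_FUSION check in the kernel source.
    if (std::getenv("DISABLE_LAMPORT_REDUCE_NORM_FUSION") != nullptr)
    {
        return false;
    }
    (void) tokenNum; // the token-count condition against kLamportTokenNumThreshold is elided (not shown in the hunk)
    // New gate from this change: small hidden sizes fall back to the non-Lamport path.
    if (hiddenSize < kLamportHiddenSizeThreshold)
    {
        return false;
    }
    return true;
}
} // namespace sketch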
// Warning: python definition is in tensorrt_llm/functional.py
|
||||
@ -103,7 +104,7 @@ struct AllReduceParams
|
||||
AllReduceFusionParams fusion_params;
|
||||
|
||||
static AllReduceParams deserialize(int64_t* buffer, size_t tpSize, size_t tpRank, nvinfer1::DataType dataType,
|
||||
int token_num, AllReduceFusionOp op);
|
||||
int token_num, int hidden_size, AllReduceFusionOp op);
|
||||
};
|
||||
|
||||
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);
|
||||
|
||||
@ -703,6 +703,18 @@ extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nq
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin[];
|
||||
extern unsigned long long xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin[];
|
||||
|
||||
// MHA with beamWidth=4
|
||||
extern unsigned long long xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin[];
|
||||
@ -829,6 +841,18 @@ extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len;
|
||||
extern uint32_t xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len;
|
||||
|
||||
// MHA with beamWidth=4
|
||||
extern uint32_t xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len;
|
||||
@ -1258,6 +1282,18 @@ static const struct XQAKernelMetaInfo
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 16, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 128, 1, 0, 32, 128, true, true, kSM_90, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_128_beam_1_kvt_e4m3_pagedKV_128_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_fp16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_INT8, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_E4M3, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_fp16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_BF16, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_bf16_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_INT8, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_int8_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 64, 1, 0, 16, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_16_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_BF16, DATA_TYPE_E4M3, 64, 1, 0, 32, 64, true, true, kSM_90, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin, xqa_kernel_dt_bf16_d_64_beam_1_kvt_e4m3_pagedKV_64_nqpkv_0_m_32_sm_90_cubin_len, "kernel_mha"},
|
||||
// MHA with beamWidth=4
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 64, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_64_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"},
|
||||
{ DATA_TYPE_FP16, DATA_TYPE_FP16, 256, 4, 1, 1, 128, true, false, kSM_90, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin, xqa_kernel_dt_fp16_d_256_beam_4_kvt_fp16_pagedKV_128_nqpkv_1_m_1_sm_90_cubin_len, "kernel_mha"},
|
||||
|
||||
[12 file diffs suppressed because they are too large]
@ -21,11 +21,7 @@
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include <cuda_runtime_api.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace jit
|
||||
namespace tensorrt_llm::kernels::jit
|
||||
{
|
||||
|
||||
CubinObj::CubinObj(void const* buffer_, size_t buffer_size)
|
||||
@ -135,9 +131,8 @@ void CubinObj::serialize(void* buffer_, size_t buffer_size) const noexcept
|
||||
void CubinObj::launch(dim3 gridDim, dim3 blockDim, CUstream hStream, void** kernelParams)
|
||||
{
|
||||
TLLM_CHECK(mInitialized);
|
||||
cuErrCheck(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y, blockDim.z,
|
||||
mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(mFunction, gridDim.x, gridDim.y, gridDim.z, blockDim.x, blockDim.y,
|
||||
blockDim.z, mSharedMemBytes, hStream, kernelParams, /*extra=*/nullptr));
|
||||
}
|
||||
|
||||
void CubinObj::initialize()
|
||||
@ -146,26 +141,25 @@ void CubinObj::initialize()
|
||||
{
|
||||
mDriver = tensorrt_llm::common::CUDADriverWrapper::getInstance();
|
||||
mModule = nullptr;
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&mModule, mContent.c_str()), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&mModule, mContent.c_str()));
|
||||
TLLM_CHECK(mModule != nullptr);
|
||||
mFunction = nullptr;
|
||||
cuErrCheck(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&mFunction, mModule, kFuncName));
|
||||
TLLM_CHECK(mFunction != nullptr);
|
||||
|
||||
// Populate mSharedMemBytes.
|
||||
CUdeviceptr shmem_dev_ptr = 0;
|
||||
cuErrCheck(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleGetGlobal(&shmem_dev_ptr, nullptr, mModule, kSmemName));
|
||||
TLLM_CHECK(shmem_dev_ptr != 0);
|
||||
cuErrCheck(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuMemcpyDtoH(&mSharedMemBytes, shmem_dev_ptr, sizeof(unsigned int)));
|
||||
|
||||
TLLM_CHECK(mSharedMemBytes > 0);
|
||||
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
|
||||
if (mSharedMemBytes >= 46 * 1024)
|
||||
{
|
||||
cuErrCheck(mDriver->cuFuncSetAttribute(
|
||||
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(
|
||||
mFunction, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, mSharedMemBytes));
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
@ -177,11 +171,9 @@ CubinObj::~CubinObj()
|
||||
{
|
||||
if (mInitialized)
|
||||
{
|
||||
cuErrCheck(mDriver->cuModuleUnload(mModule), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleUnload(mModule));
|
||||
mInitialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace jit
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::kernels::jit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:1397678ac1cab957f7d272750035ffd88b9e2b3b9d4f132073119d21c288b5da
|
||||
size 82262624
|
||||
oid sha256:53b2ebc1484d068fa60c8e5ad22bf2db40bd84963bb0d2e679bcec9f53b65c5d
|
||||
size 82318536
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
5ea3eabf1c58887230ba5ebc583e0d3c libtensorrt_llm_nvrtc_wrapper.so
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
90df70c216d9aa2c85b8b097c853e4ba libtensorrt_llm_nvrtc_wrapper.so
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:17481ff01045ac335223451c83943a7f97b5c63ca2ab5da3e71d0909c8f4e68b
|
||||
size 84578328
|
||||
oid sha256:64aec9fe985b5dd0d38d9b76ee6f2fde14a183bfe44de9f0148fc482af086a48
|
||||
size 84643008
|
||||
|
||||
@ -1,2 +1,2 @@
|
||||
270f246f5eccb170a759cda4787216f4 libtensorrt_llm_nvrtc_wrapper.so
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
232f492424a31204a2be2e67be299aef libtensorrt_llm_nvrtc_wrapper.so
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2c9102fab7a800539f1fabb01d3a443abdfb10ccac94ae5704883260088de71d
|
||||
oid sha256:bed2713947315cf941533dd12b5b98270a2aabd584cc33bc2092be6dbf879959
|
||||
size 1128448
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
c5f36e093e875c8ea84523fb1566d986 tensorrt_llm_nvrtc_wrapper.lib
|
||||
3be88342e596bda98d75e3b2d31ae484 tensorrt_llm_nvrtc_wrapper.dll
|
||||
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
|
||||
aaa20992c207e46eab50dd90bcf3c405 tensorrt_llm_nvrtc_wrapper.dll
|
||||
1c2eb102257f836cd50faf985e693241d7a84dbe commit
|
||||
@ -33,9 +33,7 @@
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
class XQAKernelList
|
||||
@ -77,13 +75,13 @@ public:
|
||||
}
|
||||
else
|
||||
{
|
||||
cuErrCheck(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleLoadData(&hmod, kernelMeta.mCubin));
|
||||
mModules.insert(std::make_pair(kernelMeta.mCubin, hmod));
|
||||
}
|
||||
|
||||
XQAKernelFuncInfo funcInfo{};
|
||||
funcInfo.mMetaInfoIndex = i;
|
||||
cuErrCheck(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName), mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuModuleGetFunction(&funcInfo.mDeviceFunction, hmod, kernelMeta.mFuncName));
|
||||
funcInfo.mSharedMemBytes = getGlobalVar<uint32_t>(mDriver, hmod, "smemSize", true).value();
|
||||
funcInfo.mKernelType = getGlobalVar<XQAKernelType>(mDriver, hmod, "kernelType", false)
|
||||
.value_or(XQAKernelType::kAMPERE_WARP_SPECIALIZED);
|
||||
@ -91,9 +89,8 @@ public:
|
||||
/* Set 46KB threshold here because we have to take static/driver shared memory into consideration. */
|
||||
if (funcInfo.mSharedMemBytes >= 46 * 1024)
|
||||
{
|
||||
cuErrCheck(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, funcInfo.mSharedMemBytes),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuFuncSetAttribute(funcInfo.mDeviceFunction,
|
||||
CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, funcInfo.mSharedMemBytes));
|
||||
}
|
||||
XQAKernelRuntimeHashKey hash_key{kernelMeta.mKVDataType, kernelMeta.mHeadDim, kernelMeta.mBeamWidth,
|
||||
kernelMeta.mNumQHeadsOverKV, kernelMeta.mMTileSize, kernelMeta.mTokensPerPage, kernelMeta.mPagedKVCache,
|
||||
@ -256,9 +253,8 @@ public:
|
||||
sizeof(int) * xqaParams.batch_size * qSeqLen * xqaParams.num_kv_heads, stream));
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
|
||||
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads * nbTokenBlocksPerGrp,
|
||||
xqaParams.batch_size, 128, 1, 2, shared_mem_bytes, stream, kernelParams, nullptr));
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -299,9 +295,8 @@ public:
|
||||
{
|
||||
multi_block = computeMultiBlockCount(xqaParams, xqaParams.batch_size, multiprocessor_count);
|
||||
}
|
||||
cuErrCheck(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128, 1,
|
||||
isGmmaKernel ? 3 : 2, shared_mem_bytes, stream, kernelParams, nullptr),
|
||||
mDriver);
|
||||
TLLM_CU_CHECK(mDriver->cuLaunchKernel(func, multi_block, xqaParams.num_kv_heads, xqaParams.batch_size, 128,
|
||||
1, isGmmaKernel ? 3 : 2, shared_mem_bytes, stream, kernelParams, nullptr));
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
@ -417,7 +412,7 @@ bool DecoderXQAImplPrecompiled::shouldUse(XQAParams const& xqaParams, bool forCo
|
||||
}
|
||||
bool const isGPTJBeam4Kernel = (xqaParams.head_size == 256 && xqaParams.beam_width == 4 && xqaParams.paged_kv_cache
|
||||
&& (xqaParams.tokens_per_block == 64 || xqaParams.tokens_per_block == 128));
|
||||
if (xqaParams.head_size != 128 && xqaParams.head_size != 256 && !isGPTJBeam4Kernel)
|
||||
if (xqaParams.head_size != 64 && xqaParams.head_size != 128 && xqaParams.head_size != 256 && !isGPTJBeam4Kernel)
|
||||
{
|
||||
SUPPORT_RETURN_FALSE("head_size");
|
||||
}
|
||||
@ -512,5 +507,4 @@ void DecoderXQAImplPrecompiled::runWithKVBlockArray(
|
||||
runDispatchBuffer<KVBlockArray>(xqa_params, kv_block_array, stream);
|
||||
}
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -19,9 +19,7 @@
|
||||
#include <cstdint>
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
namespace
|
||||
@ -75,10 +73,9 @@ CUtensorMap makeTensorMapForPagedKVCache(std::shared_ptr<CUDADriverWrapper> cons
|
||||
}
|
||||
}();
|
||||
|
||||
cuErrCheck(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
|
||||
driver);
|
||||
TLLM_CU_CHECK(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensorMap;
|
||||
}
|
||||
|
||||
@ -107,10 +104,9 @@ CUtensorMap makeTensorMapForContiguousKVCache(std::shared_ptr<CUDADriverWrapper>
|
||||
}
|
||||
}();
|
||||
|
||||
cuErrCheck(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle,
|
||||
CU_TENSOR_MAP_L2_PROMOTION_NONE, CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE),
|
||||
driver);
|
||||
TLLM_CU_CHECK(driver->cuTensorMapEncodeTiled(&tensorMap, dataType, 4, const_cast<void*>(addr), globalDims,
|
||||
globalStrides, boxDims, elemStrides, CU_TENSOR_MAP_INTERLEAVE_NONE, swizzle, CU_TENSOR_MAP_L2_PROMOTION_NONE,
|
||||
CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
|
||||
return tensorMap;
|
||||
}
|
||||
|
||||
@ -139,5 +135,4 @@ template CUtensorMap makeTensorMapForKVCache(
|
||||
template CUtensorMap makeTensorMapForKVCache(
|
||||
std::shared_ptr<CUDADriverWrapper> const&, XQAParams const&, KVLinearBuffer const&);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b9ddbaa50b4b51d158e163aa3160bfee88a2e0e3c987fa4a883e14066b9c09e2
size 21861322
oid sha256:8032548ca52a51b3245dcff4fd834e02b93a00e61146be5e418aff94d3e655cb
size 21863082

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cee9ceda43c0321aa2c374383eb00ee8347e5b7aadb57473dc16b2a3ceeef39
size 22133914
oid sha256:71ef05b9741f027279efedeec8f9d598299a1348cc0e37022424645f9efccd22
size 22111930

@ -1,3 +1,3 @@
cc18fb3ea093606954fcfc7247418571 libtensorrt_llm_internal_cutlass_kernels_static.a
c9ae57fa6eb2d950cea0acb08483e978 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
f1820d73fc5cac7fa324d71933e5412a libtensorrt_llm_internal_cutlass_kernels_static.a
a8785db1cc11e3b571bc071d7abec1a8 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1f7c5f0a38b061bc373c095c4d4790ffeb23ca6595f9dfc52f81f6a9f772dbf1
size 36622632
oid sha256:4b6917794ec6e67989fdcd0af3cc4d84713f3d8d4dcd822d2df2272117c66d6b
size 36626184

@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c53931240f61f25a54aa74835c732440366c54c70b43372b5ad8f1b0a140562
size 36094714
oid sha256:e6d2f3c25a8ce88917ba512eba804f14827703fab6f9ac8d63043e2d95b6b281
size 36080026

@ -1,3 +1,3 @@
fb560fdda660a1fb8a78ca605329697f libtensorrt_llm_internal_cutlass_kernels_static.a
0d2d1e33a0e588fc1653feaccecff418 libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
9e6ff6d826caeea1e6e19c71f5d0986b libtensorrt_llm_internal_cutlass_kernels_static.a
2fca4e76d5f21089f00b1d39f624b80a libtensorrt_llm_internal_cutlass_kernels_static.pre_cxx11.a
1c2eb102257f836cd50faf985e693241d7a84dbe commit
@ -1,3 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:788694476443028c72cede1228a9beb9be431d5b871bb4e076fd2b3ddd184455
size 2669966
oid sha256:c2f34df6d47b7b2b6629358bb03b33eb193db067188e8b980598027b0ff85392
size 2669968

@ -1,2 +1,2 @@
aebeb3c0b4864efa09724a47638098c2 tensorrt_llm_internal_cutlass_kernels_static.lib
92c307ad86369ee668e2a6eb9d8d5e7ce549f4bb commit
95c2f50347d4de94e2e09cbf0cf99582 tensorrt_llm_internal_cutlass_kernels_static.lib
1c2eb102257f836cd50faf985e693241d7a84dbe commit
@ -163,8 +163,10 @@ __launch_bounds__(TPB) __global__ void moeTopK(float const* inputs_after_softmax

for (int prior_k = startk; prior_k < k_idx; ++prior_k)
{
int const prior_winning_expert = indices[k * block_row + prior_k];

int prior_winning_expert = indices[k * block_row + prior_k];
// Adjust the selected index to correct for the expert parallel transformation
prior_winning_expert = prior_winning_expert >= num_experts ? prior_winning_expert - num_experts
: prior_winning_expert + start_expert;
if (prior_winning_expert == expert)
{
inp_kvp = thread_kvp;

74
cpp/tensorrt_llm/kernels/qserveGemm.h
Normal file
@ -0,0 +1,74 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm
{
namespace kernels
{
namespace qserve
{

struct ParamsPerGroup
{
int8_t const* A;
int8_t const* B;
int8_t const* s2_zeros;
int8_t const* s2_scales;
half const* s1_scales;
half const* act_scales;
half* C;
int m;
int n;
int k;
};

struct ParamsPerChannel
{
int8_t const* A;
int8_t const* B;
half const* s1_scales;
half const* s1_szeros;
half const* act_sums;
half const* act_scales;
half* C;
int m;
int n;
int k;
};

class QServeGemmRunner
{
public:
void gemmPerGroup(ParamsPerGroup const& params, cudaStream_t stream);
void gemmPerChannel(ParamsPerChannel const& params, cudaStream_t stream);

// We do not use workspace for now.
// char* workspacePtr, const size_t workspaceBytes, cudaStream_t stream);

// Returns desired workspace size in bytes.
size_t getWorkspaceSize(int const m, int const n, int const k);

// virtual std::vector<tkc::CutlassGemmConfig> getConfigs() const = 0;
};

} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm
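A minimal caller sketch for the new QServeGemmRunner interface, covering the per-channel (W4A8, per-channel weight quantization) path. Everything here is illustrative: the buffer names, the include path, and the assumed meaning of s1_szeros and act_sums (per-channel scale-times-zero-point products and per-token activation sums, as suggested by the kernel epilogue) are not spelled out in this diff.

// Hypothetical caller sketch; names and layouts are assumptions, not TensorRT-LLM API.
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include "tensorrt_llm/kernels/qserveGemm.h" // assumed include path for the header added above

using tensorrt_llm::kernels::qserve::ParamsPerChannel;
using tensorrt_llm::kernels::qserve::QServeGemmRunner;

void runPerChannelGemm(int8_t const* dActs, // [m, k] int8 activations
    int8_t const* dPackedWeights,           // [n, k] 4-bit weights in the QServe packed layout (layout assumed)
    half const* dWeightScales,              // [n] per-channel weight scales (s1_scales)
    half const* dWeightScaleZeros,          // [n] per-channel scale*zero products (s1_szeros)
    half const* dActSums,                   // [m] per-token activation sums (act_sums)
    half const* dActScales,                 // [m] per-token activation scales (act_scales)
    half* dOut,                             // [m, n] fp16 output
    int m, int n, int k, cudaStream_t stream)
{
    QServeGemmRunner runner;
    ParamsPerChannel params{dActs, dPackedWeights, dWeightScales, dWeightScaleZeros, dActSums, dActScales, dOut, m, n, k};
    // The runner currently needs no workspace, so a single call launches the kernel.
    runner.gemmPerChannel(params, stream);
}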
618
cpp/tensorrt_llm/kernels/qserveGemmPerChannel.cu
Normal file
@ -0,0 +1,618 @@
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

// Implemented by Haotian Tang and Shang Yang.
// @article{lin2024qserve,
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
// }

#include "qserveGemm.h"
#include <cuda_fp16.h>
#include <cuda_pipeline_primitives.h>

namespace tensorrt_llm
{
namespace kernels
{
namespace qserve
{

#define OP_M 16
#define OP_N 8
#define OP_K 32
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
#define L2_CACHEHINT(size) ".L2::" #size "B"
#else
#define L2_CACHEHINT(size)
#endif
#define KERNEL_LAUNCH_CODE \
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
constexpr int kSmemByteSize \
= ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) \
* sizeof(int8_t); \
if (kSmemByteSize >= 99 * 1024) \
{ \
printf( \
"This kernel requires %d Bytes of shared memory, which exceeds " \
"device limit.\n", \
kSmemByteSize); \
return; \
} \
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
int num_blocks_n = num_out_channels / CTA_N / 1; \
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
const int tile_shift = 1 << log_tile; \
dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(in_feats, kernel, wscales, ascales, w_szs, \
a_ssums, out_feats, num_in_feats, num_out_channels, num_in_channels);

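As a quick sanity check of the shared-memory budget that KERNEL_LAUNCH_CODE computes, the same arithmetic can be evaluated on the host for the largest tiling selected later in this file (CTA_M=128, CTA_N=128, CTA_K=64, STAGES=3, G=128, no padding). This is an illustrative standalone snippet, not part of the kernel sources:

// Mirrors the kSmemByteSize formula from KERNEL_LAUNCH_CODE for one tiling.
#include <cstdio>

int main()
{
    constexpr int CTA_M = 128, CTA_N = 128, CTA_K = 64, STAGES = 3, G = 128;
    constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2);
    constexpr int kSmemByteSize
        = ((CTA_M * CTA_K + CTA_N * CTA_K / 2) * STAGES + SCALES_SMEM_SIZE) * static_cast<int>(sizeof(int8_t));
    printf("dynamic shared memory: %d bytes\n", kSmemByteSize); // 37632 bytes, well under the 99 * 1024 guard
    return 0;
}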
template <int N>
inline __host__ __device__ int get_log_tile(int n)
{
if (N >= 8 && n >= 6)
return 3;
else if (N >= 4 && n >= 3)
return 2;
else if (N >= 2 && n >= 2)
return 1;
else
return 0;
}

inline __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
{
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
}

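The two helpers above implement the usual threadblock-rasterization trick: the grid is launched as (num_blocks_n * tile_shift, ceil(num_blocks_m / tile_shift)) and get_block_idx_mapping() folds blockIdx.x back into an (n, m) tile pair, so that blocks scheduled close together cover a compact strip of output tiles and reuse the same activation rows across neighbouring n tiles. A host-only sketch of the same mapping, for reference only:

// Host-side reference for the swizzled block mapping used by dense_kernel0.
#include <cstdio>
#include <utility>

static int get_log_tile_ref(int m_tiles) // same thresholds as get_log_tile<8>
{
    if (m_tiles >= 6)
        return 3;
    if (m_tiles >= 3)
        return 2;
    if (m_tiles >= 2)
        return 1;
    return 0;
}

static std::pair<int, int> map_block(int bx, int by, int log_tile) // mirrors get_block_idx_mapping
{
    return {bx >> log_tile, (by << log_tile) + (bx & ((1 << log_tile) - 1))};
}

int main()
{
    int const num_blocks_m = 7, num_blocks_n = 4;
    int const log_tile = get_log_tile_ref(num_blocks_m); // 3, so tile_shift == 8
    int const tile_shift = 1 << log_tile;
    for (int by = 0; by < (num_blocks_m + tile_shift - 1) / tile_shift; ++by)
        for (int bx = 0; bx < num_blocks_n * tile_shift; ++bx)
        {
            auto [n_idx, m_idx] = map_block(bx, by, log_tile);
            // Padded m tiles (m_idx >= num_blocks_m) still launch; the kernel masks them with predicates.
            printf("launched (%d,%d) -> output tile (n=%d, m=%d)%s\n", bx, by, n_idx, m_idx,
                m_idx < num_blocks_m ? "" : " [padded]");
        }
    return 0;
}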
inline __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
|
||||
{
|
||||
uint32_t smem_int_ptr;
|
||||
|
||||
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
|
||||
"smem_ptr; }\n"
|
||||
: "=r"(smem_int_ptr)
|
||||
: "l"(ptr));
|
||||
|
||||
return smem_int_ptr;
|
||||
}
|
||||
|
||||
inline __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
|
||||
{
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
|
||||
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
inline __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
|
||||
{
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
|
||||
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
// function from lmdeploy
|
||||
inline __device__ void cp_async_cg_A(uint32_t smem_int_ptr, uint4 const* __restrict__ src, bool mask)
|
||||
{
|
||||
int const cp_size = 16;
|
||||
asm volatile("{"
|
||||
" .reg .pred p;"
|
||||
" setp.ne.b32 p, %0, 0;"
|
||||
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
|
||||
"}" ::"r"((int)mask),
|
||||
"r"(smem_int_ptr),
|
||||
"l"(src),
|
||||
"n"(cp_size));
|
||||
}
|
||||
|
||||
__device__ inline void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp)
|
||||
{
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
|
||||
: "=r"(((int*) C_warp)[0]), "=r"(((int*) C_warp)[1]), "=r"(((int*) C_warp)[2]), "=r"(((int*) C_warp)[3])
|
||||
: "r"(((unsigned*) A_shared_warp)[0]), "r"(((unsigned*) A_shared_warp)[1]), "r"(((unsigned*) A_shared_warp)[2]),
|
||||
"r"(((unsigned*) A_shared_warp)[3]), "r"(((unsigned*) B_shared_warp)[0]), "r"(((unsigned*) B_shared_warp)[1]),
|
||||
"r"(((int*) C_warp)[0]), "r"(((int*) C_warp)[1]), "r"(((int*) C_warp)[2]), "r"(((int*) C_warp)[3]));
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
|
||||
__device__ inline void global_to_share_one_stage_A(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
|
||||
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool* preds)
|
||||
{
|
||||
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
|
||||
constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int threads_per_row = CTA_K / PACK_SIZE;
|
||||
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
|
||||
int8_t* dst_hoisted = dst;
|
||||
int8_t const* src_hoisted = src + global_iter_k * CTA_K;
|
||||
|
||||
if (mask)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
|
||||
{
|
||||
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
|
||||
|
||||
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
|
||||
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
|
||||
// *dst_ptr = *src_ptr;
|
||||
if constexpr (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, preds[global_iter]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (preds[global_iter])
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
|
||||
__device__ inline void global_to_share_one_stage_B(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
|
||||
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
|
||||
{
|
||||
constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
|
||||
constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
|
||||
constexpr int warps_per_row = CTA_K / 32;
|
||||
constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
|
||||
constexpr int kSmemCol = CTA_K;
|
||||
int8_t* dst_hoisted = dst;
|
||||
int8_t const* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;
|
||||
|
||||
#pragma unroll
|
||||
for (int global_iter = 0; global_iter < total_global_iters; ++global_iter)
|
||||
{
|
||||
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
|
||||
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
|
||||
if constexpr (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (mask)
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
|
||||
__device__ inline void global_to_share_one_stage_zeros(int8_t const* src, int8_t* dst, int global_ncols,
|
||||
int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
|
||||
{
|
||||
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
|
||||
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
|
||||
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
|
||||
constexpr int threads_per_row = CTA_N / PACK_SIZE;
|
||||
constexpr int kSmemCol = CTA_N;
|
||||
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
|
||||
int g_idx = global_iter_k * CTA_K / G;
|
||||
|
||||
void* dst_ptr = (void*) (dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
|
||||
uint4* src_ptr = (uint4*) (src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
|
||||
if (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, local_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (local_mask)
|
||||
{
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
|
||||
__device__ inline void share_to_reg_one_stage_A(
|
||||
int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters)
|
||||
{
|
||||
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
|
||||
int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;
|
||||
|
||||
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
|
||||
{
|
||||
int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
|
||||
int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
|
||||
void* addr_ptr = (void*) (src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
|
||||
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
|
||||
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
|
||||
}
|
||||
}
|
||||
|
||||
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
__device__ inline void share_to_reg_one_stage_B(int8_t* src, int8_t* dst, int8_t* zeros, int8_t* scales_i8,
int warp_offset_m, int warp_offset_n, int k_0_0, int k_0_1, int shared_iters)
{
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
#pragma unroll
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
{
uint4 loaded = *((uint4*) (src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol
+ k_0_1 * INTRIN_K + threadIdx.x);

auto ptr = (uint32_t*) dst + shared_iter * 8;
ptr[0] = loaded.x & 0x0F0F0F0F;
ptr[4] = (loaded.x & 0xF0F0F0F0) >> 4;
ptr[2] = loaded.y & 0x0F0F0F0F;
ptr[6] = (loaded.y & 0xF0F0F0F0) >> 4;
ptr[1] = loaded.z & 0x0F0F0F0F;
ptr[5] = (loaded.z & 0xF0F0F0F0) >> 4;
ptr[3] = loaded.w & 0x0F0F0F0F;
ptr[7] = (loaded.w & 0xF0F0F0F0) >> 4;
}
}

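The eight masked writes above split every packed byte into two 4-bit weights: the low nibbles of each 32-bit word land in one destination register and the high nibbles, shifted right by four, in another, with the interleaved 0/2/1/3 destination order presumably chosen to line up with the MMA fragment layout. A host-only check of the mask-and-shift step, illustrative and not part of the kernel:

// Minimal host check of the nibble unpack: each 32-bit word holds eight 4-bit weights.
#include <cstdint>
#include <cstdio>

int main()
{
    uint32_t packed = 0x87654321u;               // eight 4-bit values 1..8, least significant nibble first
    uint32_t low = packed & 0x0F0F0F0Fu;         // bytes 0x01 0x03 0x05 0x07
    uint32_t high = (packed & 0xF0F0F0F0u) >> 4; // bytes 0x02 0x04 0x06 0x08
    printf("low=0x%08X high=0x%08X\n", low, high);
    return 0;
}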
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
|
||||
__global__ void dense_kernel0(int8_t const* __restrict__ A, int8_t const* __restrict__ B,
|
||||
half2 const* __restrict__ wscales, half const* __restrict__ ascales, half2 const* __restrict__ w_szs,
|
||||
half const* __restrict__ a_ssums, half* __restrict__ C, int M, int64_t N, int64_t K)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
constexpr int SPLITK = 1;
|
||||
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
|
||||
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
|
||||
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
|
||||
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
|
||||
constexpr int SLICES = CTA_K / WARP_K;
|
||||
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
|
||||
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
|
||||
|
||||
int blockIdx_n = blockIdx.x;
|
||||
int blockIdx_m = blockIdx.y;
|
||||
int const log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
|
||||
uint2 const block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
|
||||
blockIdx_n = block_idx_mapping.x;
|
||||
blockIdx_m = block_idx_mapping.y;
|
||||
|
||||
int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
|
||||
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
|
||||
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
|
||||
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
|
||||
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
|
||||
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
|
||||
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
|
||||
|
||||
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
|
||||
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
|
||||
constexpr int kSmemSizeScales = CTA_N * STAGES;
|
||||
|
||||
extern __shared__ int8_t mem_shared[];
|
||||
int8_t* A_shared = mem_shared;
|
||||
|
||||
int8_t* B_shared = mem_shared + kSmemSizeA;
|
||||
int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
|
||||
int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
|
||||
|
||||
int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
|
||||
int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
|
||||
constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int A_threads_per_row = CTA_K / PACK_SIZE;
|
||||
|
||||
constexpr int B_warps_per_row = CTA_K / 32;
|
||||
constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;
|
||||
|
||||
int cta_offset_m = blockIdx_m * CTA_M;
|
||||
int cta_offset_n = blockIdx_n * CTA_N;
|
||||
int warp_mn = threadIdx.y % NUM_WARPS_MN;
|
||||
int slice_id = threadIdx.y / NUM_WARPS_MN;
|
||||
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
|
||||
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
|
||||
int warp_offset_k = slice_id * WARP_K;
|
||||
|
||||
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
|
||||
C_warp[i] = 0;
|
||||
|
||||
int gemm_iters = (K + CTA_K - 1) / CTA_K;
|
||||
int k_0_0_ld = 0;
|
||||
int k_0_0 = 0;
|
||||
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
|
||||
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
|
||||
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
|
||||
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;
|
||||
|
||||
int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
|
||||
int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
|
||||
+ (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
|
||||
int8_t const* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
|
||||
int8_t const* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
|
||||
+ (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;
|
||||
|
||||
bool A_g2s_preds[A_total_global_iters];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < A_total_global_iters; i++)
|
||||
{
|
||||
A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
|
||||
}
|
||||
|
||||
int* C_shared = reinterpret_cast<int*>(mem_shared);
|
||||
|
||||
#pragma unroll
|
||||
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
|
||||
{
|
||||
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A_hoisted,
|
||||
A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true,
|
||||
A_g2s_preds);
|
||||
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B_hoisted,
|
||||
B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
|
||||
|
||||
if constexpr (STAGES > 1)
|
||||
__pipeline_commit();
|
||||
}
|
||||
if constexpr (STAGES > 1)
|
||||
__pipeline_wait_prior(STAGES - 2);
|
||||
__syncthreads();
|
||||
|
||||
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
|
||||
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
|
||||
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared + warp_offset_k * PACK_SIZE,
|
||||
B_shared_warp_[0], zeros_shared, scales_i8_shared, warp_offset_m, warp_offset_n, 0, 0, WARP_N / 32);
|
||||
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
|
||||
|
||||
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
|
||||
{
|
||||
int ld_stage = k_0_0_ld % STAGES;
|
||||
int compute_stage = k_0_0 % STAGES;
|
||||
int8_t* A_shared_this_compute_stage;
|
||||
int8_t* B_shared_this_compute_stage;
|
||||
int8_t* zeros_shared_this_compute_stage;
|
||||
int8_t* scales_i8_shared_this_compute_stage;
|
||||
|
||||
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
|
||||
{
|
||||
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
|
||||
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
|
||||
zeros_shared_this_compute_stage = zeros_shared + (compute_stage) *CTA_N;
|
||||
scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage) *CTA_N;
|
||||
|
||||
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(A_shared_this_compute_stage,
|
||||
A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS,
|
||||
WARP_M / INTRIN_M);
|
||||
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared_this_compute_stage,
|
||||
B_shared_warp_[(iter_k + 1) % 2], zeros_shared_this_compute_stage, scales_i8_shared_this_compute_stage,
|
||||
warp_offset_m, warp_offset_n, k_0_0 + (iter_k == SHARED_K_ITERS - 1), (iter_k + 1) % SHARED_K_ITERS,
|
||||
WARP_N / 32);
|
||||
int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
|
||||
int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
|
||||
{
|
||||
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
|
||||
{
|
||||
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
|
||||
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16));
|
||||
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
|
||||
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16 + 8));
|
||||
}
|
||||
}
|
||||
|
||||
if (iter_k < SHARED_K_ITERS - 1)
|
||||
{
|
||||
if constexpr (STAGES == 1)
|
||||
__syncthreads();
|
||||
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
|
||||
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
|
||||
k_0_0_ld < gemm_iters, A_g2s_preds);
|
||||
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
|
||||
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
|
||||
k_0_0_ld < gemm_iters);
|
||||
}
|
||||
|
||||
if (iter_k == SHARED_K_ITERS - 2)
|
||||
{
|
||||
if constexpr (STAGES == 1 && SHARED_K_ITERS > 2)
|
||||
{
|
||||
__syncthreads();
|
||||
}
|
||||
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
|
||||
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
|
||||
iter_k + 1, k_0_0_ld < gemm_iters, A_g2s_preds);
|
||||
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
|
||||
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
|
||||
iter_k + 1, k_0_0_ld < gemm_iters);
|
||||
if constexpr (STAGES > 1)
|
||||
{
|
||||
__pipeline_commit();
|
||||
__pipeline_wait_prior(STAGES - 2);
|
||||
}
|
||||
compute_stage = (k_0_0 + 1) % STAGES;
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__pipeline_commit();
|
||||
__pipeline_wait_prior(0);
|
||||
__syncthreads();
|
||||
|
||||
if constexpr (SLICES > 1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int z = 0; z < SLICES; ++z)
|
||||
{
|
||||
if (slice_id == z)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
|
||||
{
|
||||
if (z > 0)
|
||||
{
|
||||
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
|
||||
+= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n
|
||||
+ ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N
|
||||
+ (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
|
||||
}
|
||||
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
|
||||
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
|
||||
+ (local_id % 2) + (threadIdx.x % 4) * 2]
|
||||
= C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
if (slice_id == 0)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
|
||||
{
|
||||
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
|
||||
= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
|
||||
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
|
||||
+ (local_id % 2) + (threadIdx.x % 4) * 2];
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
|
||||
int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
|
||||
if (slice_id == 0)
|
||||
{
|
||||
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
|
||||
{
|
||||
int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
|
||||
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
|
||||
{
|
||||
int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
|
||||
int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
|
||||
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
|
||||
{
|
||||
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
|
||||
if (row_wb < M)
|
||||
{
|
||||
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
|
||||
float2 wscale = __half22float2(*(wscales + col_wb / 2));
|
||||
float2 w_sz = __half22float2(*(w_szs + col_wb / 2));
|
||||
float ascale = __half2float(ascales[row_wb]);
|
||||
float a_ssum = __half2float(a_ssums[row_wb]);
|
||||
float2 psums = make_float2(
|
||||
__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
|
||||
psums.x = psums.x * wscale.x * ascale - w_sz.x * a_ssum;
|
||||
psums.y = psums.y * wscale.y * ascale - w_sz.y * a_ssum;
|
||||
*reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
|
||||
}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void qserveGemmPerChannelLaunch(ParamsPerChannel const& params, cudaStream_t stream)
{
auto in_feats = params.A;
auto kernel = params.B;
auto w_szs = reinterpret_cast<half2 const*>(params.s1_szeros);
auto a_ssums = params.act_sums;
auto wscales = reinterpret_cast<half2 const*>(params.s1_scales);
auto ascales = params.act_scales;

auto out_feats = params.C;

int num_out_feats = params.m;
int num_out_channels = params.n;
int num_in_feats = params.m;
int num_in_channels = params.k;

constexpr int G = 128;

if (num_out_feats > 256)
{
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats >= 128)
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int CTA_M = 32;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}

void QServeGemmRunner::gemmPerChannel(ParamsPerChannel const& params, cudaStream_t stream)
{
qserveGemmPerChannelLaunch(params, stream);
}

} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm
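The per-channel write-back in dense_kernel0 above computes psums = acc * wscale * ascale - w_sz * a_ssum per pair of output elements. One plausible reading, assuming act_sums holds the per-token sums of the already-scaled activations and s1_szeros the per-channel scale-times-zero-point products (neither convention is stated in this diff), is

C_{ij} \approx s_a^{(i)} s_w^{(j)} \sum_k A^q_{ik} W^q_{jk}
  - \underbrace{s_w^{(j)} z_w^{(j)}}_{\text{s1\_szeros}[j]} \, \underbrace{s_a^{(i)} \textstyle\sum_k A^q_{ik}}_{\text{act\_sums}[i]}
  = s_a^{(i)} \sum_k A^q_{ik} \, s_w^{(j)} \bigl(W^q_{jk} - z_w^{(j)}\bigr),

so the zero-point correction of the 4-bit weights costs a single fused multiply-subtract per output element instead of extra work in the inner loop; the per-group path in qserveGemmPerGroup.cu below instead folds its level-2 zero points into the int8 operands with __vadd4 before the MMA.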
677
cpp/tensorrt_llm/kernels/qserveGemmPerGroup.cu
Normal file
@ -0,0 +1,677 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
// Implemented by Haotian Tang and Shang Yang.
|
||||
// @article{lin2024qserve,
|
||||
// title={QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving},
|
||||
// author={Lin*, Yujun and Tang*, Haotian and Yang*, Shang and Zhang, Zhekai and Xiao, Guangxuan and Gan, Chuang and
|
||||
// Han, Song}, journal={arXiv preprint arXiv:2405.04532}, year={2024}
|
||||
// }
|
||||
|
||||
#include "qserveGemm.h"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_pipeline_primitives.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
namespace qserve
|
||||
{
|
||||
|
||||
#define OP_M 16
|
||||
#define OP_N 8
|
||||
#define OP_K 32
|
||||
#define INTRIN_M 16
|
||||
#define INTRIN_N 16
|
||||
#define INTRIN_K 32
|
||||
#define WARP_SIZE 32
|
||||
#define SMEM_PAD_A 0
|
||||
#define SMEM_PAD_B 0
|
||||
#define PACK_SIZE 16
|
||||
#if (__CUDACC_VER_MAJOR__ >= 11) && (__CUDACC_VER_MINOR__ >= 4)
|
||||
#define L2_CACHEHINT(size) ".L2::" #size "B"
|
||||
#else
|
||||
#define L2_CACHEHINT(size)
|
||||
#endif
|
||||
#define KERNEL_LAUNCH_CODE \
|
||||
constexpr int NUM_WARPS = (CTA_M / WARP_M) * (CTA_N / WARP_N) * (CTA_K / WARP_K); \
|
||||
constexpr int SCALES_SMEM_SIZE = (G >= CTA_K) ? (CTA_N * STAGES * 2) : (CTA_N * (CTA_K / G) * STAGES * 2); \
|
||||
constexpr int kSmemByteSize \
|
||||
= ((CTA_M * (CTA_K + SMEM_PAD_A) + CTA_N * (CTA_K + SMEM_PAD_B) / 2) * STAGES + SCALES_SMEM_SIZE) \
|
||||
* sizeof(int8_t); \
|
||||
if (kSmemByteSize >= 99 * 1024) \
|
||||
{ \
|
||||
printf( \
|
||||
"This kernel requires %d Bytes of shared memory, which exceeds " \
|
||||
"device limit.\n", \
|
||||
kSmemByteSize); \
|
||||
return; \
|
||||
} \
|
||||
int num_blocks_m = (num_out_feats + CTA_M - 1) / CTA_M; \
|
||||
int num_blocks_n = num_out_channels / CTA_N / 1; \
|
||||
const int log_tile = get_log_tile<8>((num_out_feats + CTA_M - 1) / CTA_M); \
|
||||
const int tile_shift = 1 << log_tile; \
|
||||
dim3 num_blocks(num_blocks_n* tile_shift, (num_blocks_m + tile_shift - 1) / tile_shift); \
|
||||
dim3 threads_per_block(WARP_SIZE, NUM_WARPS); \
|
||||
auto kernel_func = dense_kernel0<CTA_M, CTA_N, CTA_K, WARP_M, WARP_N, WARP_K, STAGES, G>; \
|
||||
cudaFuncSetAttribute(kernel_func, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemByteSize); \
|
||||
kernel_func<<<num_blocks, threads_per_block, kSmemByteSize, stream>>>(in_feats, kernel, zeros, scales_i8, wscales, \
|
||||
ascales, out_feats, num_in_feats, num_out_channels, num_in_channels);
|
||||
|
||||
template <int N>
|
||||
inline __host__ __device__ int get_log_tile(int n)
|
||||
{
|
||||
if (N >= 8 && n >= 6)
|
||||
return 3;
|
||||
else if (N >= 4 && n >= 3)
|
||||
return 2;
|
||||
else if (N >= 2 && n >= 2)
|
||||
return 1;
|
||||
else
|
||||
return 0;
|
||||
}
|
||||
|
||||
inline __device__ uint2 get_block_idx_mapping(int blockIdx_x, int blockIdx_y, int log_tile)
|
||||
{
|
||||
return make_uint2((blockIdx_x >> log_tile), (blockIdx_y << log_tile) + ((blockIdx_x) & ((1 << (log_tile)) - 1)));
|
||||
}
|
||||
|
||||
inline __device__ uint32_t cast_smem_ptr_to_uint(void const* const ptr)
|
||||
{
|
||||
uint32_t smem_int_ptr;
|
||||
|
||||
asm("{.reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, "
|
||||
"smem_ptr; }\n"
|
||||
: "=r"(smem_int_ptr)
|
||||
: "l"(ptr));
|
||||
|
||||
return smem_int_ptr;
|
||||
}
|
||||
|
||||
inline __device__ void ldmatrix_m8n8_x4_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
|
||||
{
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
|
||||
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
inline __device__ void ldmatrix_m8n8_x4_trans_b16(int8_t* shared_warp, int ax0_0, uint32_t addr)
|
||||
{
|
||||
asm volatile(
|
||||
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
|
||||
"{%0, %1, %2, %3}, [%4];"
|
||||
: "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[0]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[1]),
|
||||
"=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[2]), "=r"(((unsigned*) (shared_warp + (ax0_0 * 16)))[3])
|
||||
: "r"(addr));
|
||||
}
|
||||
|
||||
// function from lmdeploy
|
||||
inline __device__ void cp_async_cg_A(uint32_t smem_int_ptr, uint4 const* __restrict__ src, bool mask)
|
||||
{
|
||||
int const cp_size = 16;
|
||||
asm volatile("{"
|
||||
" .reg .pred p;"
|
||||
" setp.ne.b32 p, %0, 0;"
|
||||
" @p cp.async.cg.shared.global" L2_CACHEHINT(128) " [%1], [%2], %3;"
|
||||
"}" ::"r"((int)mask),
|
||||
"r"(smem_int_ptr),
|
||||
"l"(src),
|
||||
"n"(cp_size));
|
||||
}
|
||||
|
||||
__device__ inline void mma_m16n8k32(void* C_warp, void* A_shared_warp, void* B_shared_warp)
|
||||
{
|
||||
asm volatile(
|
||||
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32"
|
||||
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};"
|
||||
: "=r"(((int*) C_warp)[0]), "=r"(((int*) C_warp)[1]), "=r"(((int*) C_warp)[2]), "=r"(((int*) C_warp)[3])
|
||||
: "r"(((unsigned*) A_shared_warp)[0]), "r"(((unsigned*) A_shared_warp)[1]), "r"(((unsigned*) A_shared_warp)[2]),
|
||||
"r"(((unsigned*) A_shared_warp)[3]), "r"(((unsigned*) B_shared_warp)[0]), "r"(((unsigned*) B_shared_warp)[1]),
|
||||
"r"(((int*) C_warp)[0]), "r"(((int*) C_warp)[1]), "r"(((int*) C_warp)[2]), "r"(((int*) C_warp)[3]));
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
|
||||
__device__ inline void global_to_share_one_stage_A(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
|
||||
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask, bool* preds)
|
||||
{
|
||||
constexpr int total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int partial_global_iters = total_global_iters / SHARED_K_ITERS;
|
||||
constexpr int cta_step_m_or_n = (CTA_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int warp_step_m_or_n = (WARP_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int threads_per_row = CTA_K / PACK_SIZE;
|
||||
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
|
||||
int8_t* dst_hoisted = dst;
|
||||
int8_t const* src_hoisted = src + global_iter_k * CTA_K;
|
||||
|
||||
if (mask)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int _global_iter = 0; _global_iter < partial_global_iters; ++_global_iter)
|
||||
{
|
||||
int global_iter = shared_iter_k * partial_global_iters + _global_iter;
|
||||
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol);
|
||||
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols);
|
||||
if constexpr (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, preds[global_iter]);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (preds[global_iter])
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int SHARED_K_ITERS, int STAGES>
|
||||
__device__ inline void global_to_share_one_stage_B(int8_t const* src, int8_t* dst, int global_ncols, int cta_offset_m,
|
||||
int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
|
||||
{
|
||||
constexpr int total_global_iters = (CTA_N * CTA_K) / 32 / CTA_SIZE;
|
||||
constexpr int NUM_WARPS = CTA_SIZE / WARP_SIZE;
|
||||
constexpr int warps_per_row = CTA_K / 32;
|
||||
constexpr int cta_step_m_or_n = NUM_WARPS / warps_per_row;
|
||||
constexpr int kSmemCol = CTA_K;
|
||||
int8_t* dst_hoisted = dst;
|
||||
int8_t const* src_hoisted = src + global_iter_k * CTA_K * PACK_SIZE;
|
||||
|
||||
#pragma unroll
|
||||
for (int global_iter = 0; global_iter < total_global_iters; ++global_iter)
|
||||
{
|
||||
void* dst_ptr = (void*) (dst_hoisted + global_iter * cta_step_m_or_n * kSmemCol * PACK_SIZE);
|
||||
uint4* src_ptr = (uint4*) (src_hoisted + global_iter * cta_step_m_or_n * global_ncols * PACK_SIZE);
|
||||
if constexpr (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (mask)
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
|
||||
__device__ inline void global_to_share_one_stage_zeros(int8_t const* src, int8_t* dst, int global_ncols,
|
||||
int cta_offset_m, int cta_offset_n, int global_iter_k, int shared_iter_k, bool mask)
|
||||
{
|
||||
constexpr int threads_needed = CTA_N / PACK_SIZE / 1;
|
||||
constexpr int threads_used = threads_needed < CTA_SIZE ? threads_needed : CTA_SIZE;
|
||||
constexpr int total_global_iters = CTA_N / PACK_SIZE / threads_used;
|
||||
constexpr int threads_per_row = CTA_N / PACK_SIZE;
|
||||
constexpr int kSmemCol = CTA_N;
|
||||
bool local_mask = mask & (threadIdx.y * WARP_SIZE + threadIdx.x < threads_used);
|
||||
int g_idx = global_iter_k * CTA_K / G;
|
||||
|
||||
void* dst_ptr = (void*) (dst + (threadIdx.x % threads_per_row) * PACK_SIZE);
|
||||
uint4 const* src_ptr
|
||||
= (uint4 const*) (src + g_idx * global_ncols + cta_offset_n + (threadIdx.x % threads_per_row) * PACK_SIZE);
|
||||
if (STAGES > 1)
|
||||
{
|
||||
uint32_t addr = cast_smem_ptr_to_uint(dst_ptr);
|
||||
cp_async_cg_A(addr, src_ptr, local_mask);
|
||||
}
|
||||
else
|
||||
{
|
||||
if (local_mask)
|
||||
{
|
||||
*(uint4*) dst_ptr = *src_ptr;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES>
|
||||
__device__ inline void share_to_reg_one_stage_A(
|
||||
int8_t* src, int8_t* dst, int warp_offset_m, int warp_offset_n, int k_0_1, int shared_iters)
|
||||
{
|
||||
constexpr int kSmemCol = CTA_K + SMEM_PAD_A;
|
||||
int ld_col = (k_0_1 * INTRIN_K + (threadIdx.x / 16) * 16) / PACK_SIZE;
|
||||
|
||||
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
|
||||
{
|
||||
int ld_row = warp_offset_m + shared_iter * INTRIN_M + (threadIdx.x % 16);
|
||||
int ld_col_swizzled = ld_col ^ (ld_row / 2) & 3;
|
||||
void* addr_ptr = (void*) (src + ld_row * kSmemCol + ld_col_swizzled * PACK_SIZE);
|
||||
uint32_t addr = cast_smem_ptr_to_uint(addr_ptr);
|
||||
ldmatrix_m8n8_x4_b16(dst, shared_iter, addr);
|
||||
}
|
||||
}
|
||||
|
||||
template <int WARP_K, int CTA_N, int CTA_K, int CTA_SIZE, int STAGES, int G>
|
||||
__device__ inline void share_to_reg_one_stage_B(int8_t* src, int8_t* dst, int8_t* zeros, int8_t* scales_i8,
|
||||
int warp_offset_m, int warp_offset_n, int k_0_0, int k_0_1, int shared_iters)
|
||||
{
|
||||
constexpr int kSmemCol = CTA_K + SMEM_PAD_B;
|
||||
#pragma unroll
|
||||
for (int shared_iter = 0; shared_iter < shared_iters; ++shared_iter)
|
||||
{
|
||||
uint4 loaded = *((uint4*) (src) + warp_offset_n / 32 * kSmemCol + shared_iter * 32 / 32 * kSmemCol
|
||||
+ k_0_1 * INTRIN_K + threadIdx.x);
|
||||
uint32_t loaded_0 = loaded.x & 0x0F0F0F0F;
|
||||
uint32_t loaded_4 = (loaded.x & 0xF0F0F0F0) >> 4;
|
||||
uint32_t loaded_2 = loaded.y & 0x0F0F0F0F;
|
||||
uint32_t loaded_6 = (loaded.y & 0xF0F0F0F0) >> 4;
|
||||
uint32_t loaded_1 = loaded.z & 0x0F0F0F0F;
|
||||
uint32_t loaded_5 = (loaded.z & 0xF0F0F0F0) >> 4;
|
||||
uint32_t loaded_3 = loaded.w & 0x0F0F0F0F;
|
||||
uint32_t loaded_7 = (loaded.w & 0xF0F0F0F0) >> 4;
|
||||
|
||||
auto ptr = (uint32_t*) dst + shared_iter * 8;
|
||||
int scales_zeros_offset = warp_offset_n + (threadIdx.x / 4) * 4 + shared_iter * 32;
|
||||
uint32_t packed_scales = *reinterpret_cast<uint32_t*>(scales_i8 + scales_zeros_offset);
|
||||
uint32_t packed_zeros = *reinterpret_cast<uint32_t*>(zeros + scales_zeros_offset);
|
||||
|
||||
uint32_t scale_0 = packed_scales & 0xFF;
|
||||
uint32_t zero_point_0 = __byte_perm(packed_zeros, 0, 0x00000000);
|
||||
uint32_t ptr_0 = loaded_0 * scale_0;
|
||||
uint32_t ptr_1 = loaded_1 * scale_0;
|
||||
ptr[0] = __vadd4(ptr_0, zero_point_0);
|
||||
ptr[1] = __vadd4(ptr_1, zero_point_0);
|
||||
|
||||
uint32_t scale_1 = (packed_scales & 0xFF00) >> 8;
|
||||
uint32_t zero_point_1 = __byte_perm(packed_zeros, 0, 0x00001111);
|
||||
uint32_t ptr_2 = loaded_2 * scale_1;
|
||||
uint32_t ptr_3 = loaded_3 * scale_1;
|
||||
ptr[2] = __vadd4(ptr_2, zero_point_1);
|
||||
ptr[3] = __vadd4(ptr_3, zero_point_1);
|
||||
|
||||
uint32_t scale_2 = (packed_scales & 0xFF0000) >> 16;
|
||||
uint32_t zero_point_2 = __byte_perm(packed_zeros, 0, 0x00002222);
|
||||
uint32_t ptr_4 = loaded_4 * scale_2;
|
||||
uint32_t ptr_5 = loaded_5 * scale_2;
|
||||
ptr[4] = __vadd4(ptr_4, zero_point_2);
|
||||
ptr[5] = __vadd4(ptr_5, zero_point_2);
|
||||
|
||||
uint32_t scale_3 = (packed_scales & 0xFF000000) >> 24;
|
||||
uint32_t zero_point_3 = __byte_perm(packed_zeros, 0, 0x00003333);
|
||||
uint32_t ptr_6 = loaded_6 * scale_3;
|
||||
uint32_t ptr_7 = loaded_7 * scale_3;
|
||||
ptr[6] = __vadd4(ptr_6, zero_point_3);
|
||||
ptr[7] = __vadd4(ptr_7, zero_point_3);
|
||||
}
|
||||
}
|
||||
|
||||
template <int CTA_M, int CTA_N, int CTA_K, int WARP_M, int WARP_N, int WARP_K, int STAGES, int G>
|
||||
__global__ void dense_kernel0(int8_t const* __restrict__ A, int8_t const* __restrict__ B,
|
||||
int8_t const* __restrict__ zeros, int8_t const* __restrict__ scales_i8, half2 const* __restrict__ wscales,
|
||||
half const* __restrict__ ascales, half* __restrict__ C, int M, int64_t N, int64_t K)
|
||||
{
|
||||
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
|
||||
constexpr int SPLITK = 1;
|
||||
constexpr int NUM_WARPS_MN = CTA_M / WARP_M * CTA_N / WARP_N;
|
||||
constexpr int NUM_WARPS = NUM_WARPS_MN * CTA_K / WARP_K;
|
||||
constexpr int CTA_SIZE = NUM_WARPS * WARP_SIZE;
|
||||
constexpr int CTA_SIZE_MN = NUM_WARPS_MN * WARP_SIZE;
|
||||
constexpr int SLICES = CTA_K / WARP_K;
|
||||
int num_blocks_n = (N + CTA_N - 1) / CTA_N;
|
||||
int num_blocks_m = (M + CTA_M - 1) / CTA_M;
|
||||
|
||||
int blockIdx_n = blockIdx.x;
|
||||
int blockIdx_m = blockIdx.y;
|
||||
int const log_tile = get_log_tile<8>((M + CTA_M - 1) / CTA_M);
|
||||
uint2 const block_idx_mapping = get_block_idx_mapping(blockIdx_n, blockIdx_m, log_tile);
|
||||
blockIdx_n = block_idx_mapping.x;
|
||||
blockIdx_m = block_idx_mapping.y;
|
||||
|
||||
int C_warp[CTA_M * CTA_N / CTA_SIZE_MN];
|
||||
constexpr int kSmemPadKA = CTA_K + SMEM_PAD_A;
|
||||
constexpr int kSmemPadKB = CTA_K + SMEM_PAD_B;
|
||||
constexpr int kSmemSizeAPerStage = CTA_M * kSmemPadKA;
|
||||
constexpr int kSmemSizeBPerStage = CTA_N * kSmemPadKB / 2;
|
||||
constexpr int kSmemSizeA = kSmemSizeAPerStage * STAGES;
|
||||
constexpr int kSmemSizeB = kSmemSizeBPerStage * STAGES;
|
||||
|
||||
constexpr int scales_load_interval = G >= CTA_K ? G / CTA_K : 1;
|
||||
constexpr int scales_per_load = G < CTA_K ? CTA_K / G : 1;
|
||||
constexpr int kSmemSizeScales = CTA_N * STAGES;
|
||||
|
||||
extern __shared__ int8_t mem_shared[];
|
||||
int8_t* A_shared = mem_shared;
|
||||
|
||||
int8_t* B_shared = mem_shared + kSmemSizeA;
|
||||
int8_t* zeros_shared = mem_shared + kSmemSizeA + kSmemSizeB;
|
||||
int8_t* scales_i8_shared = mem_shared + kSmemSizeA + kSmemSizeB + kSmemSizeScales;
|
||||
|
||||
int8_t A_shared_warp_[2][WARP_M * WARP_K / WARP_SIZE];
|
||||
int8_t B_shared_warp_[2][WARP_N * WARP_K / WARP_SIZE];
|
||||
constexpr int A_total_global_iters = (CTA_M * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int B_total_global_iters = (CTA_N * CTA_K) / PACK_SIZE / CTA_SIZE;
|
||||
constexpr int A_src_step_m = (CTA_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int A_warp_step_m = (WARP_SIZE * PACK_SIZE) / CTA_K;
|
||||
constexpr int A_threads_per_row = CTA_K / PACK_SIZE;
|
||||
|
||||
constexpr int B_warps_per_row = CTA_K / 32;
|
||||
constexpr int B_src_step_n = NUM_WARPS / B_warps_per_row;
|
||||
|
||||
int cta_offset_m = blockIdx_m * CTA_M;
|
||||
int cta_offset_n = blockIdx_n * CTA_N;
|
||||
int warp_mn = threadIdx.y % NUM_WARPS_MN;
|
||||
int slice_id = threadIdx.y / NUM_WARPS_MN;
|
||||
int warp_offset_m = (warp_mn % (CTA_M / WARP_M)) * WARP_M;
|
||||
int warp_offset_n = (warp_mn / (CTA_M / WARP_M)) * WARP_N;
|
||||
int warp_offset_k = slice_id * WARP_K;
|
||||
|
||||
for (int i = 0; i < CTA_M * CTA_N / CTA_SIZE_MN; i++)
|
||||
C_warp[i] = 0;
|
||||
|
||||
int gemm_iters = (K + CTA_K - 1) / CTA_K;
|
||||
|
||||
int k_0_0_ld = 0;
|
||||
int k_0_0 = 0;
|
||||
constexpr int prologue_stages = STAGES == 1 ? 1 : STAGES - 1;
|
||||
int A_hoisted_row = threadIdx.y * A_warp_step_m + (threadIdx.x / A_threads_per_row);
|
||||
int A_hoisted_col = (threadIdx.x % A_threads_per_row);
|
||||
int A_hoisted_col_swizzled = A_hoisted_col ^ (A_hoisted_row / 2) & 3;
|
||||
|
||||
int8_t* A_shared_hoisted = A_shared + A_hoisted_row * kSmemPadKA + A_hoisted_col_swizzled * PACK_SIZE;
|
||||
int8_t* B_shared_hoisted = B_shared + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
|
||||
+ (threadIdx.y / B_warps_per_row) * kSmemPadKB * PACK_SIZE + threadIdx.x * PACK_SIZE;
|
||||
int8_t const* A_hoisted = A + cta_offset_m * K + A_hoisted_row * K + A_hoisted_col * PACK_SIZE;
|
||||
int8_t const* B_hoisted = B + cta_offset_n / 32 * K * PACK_SIZE + (threadIdx.y % B_warps_per_row) * 32 * PACK_SIZE
|
||||
+ (threadIdx.y / B_warps_per_row) * K * PACK_SIZE + threadIdx.x * PACK_SIZE;
|
||||
|
||||
bool A_g2s_preds[A_total_global_iters];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < A_total_global_iters; i++)
|
||||
{
|
||||
A_g2s_preds[i] = (cta_offset_m + A_hoisted_row + i * A_src_step_m) < M;
|
||||
}
|
||||
|
||||
int* C_shared = reinterpret_cast<int*>(mem_shared);
|
||||
|
||||
#pragma unroll
|
||||
for (k_0_0_ld = 0; k_0_0_ld < prologue_stages; ++k_0_0_ld)
|
||||
{
|
||||
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(A_hoisted,
|
||||
A_shared_hoisted + k_0_0_ld * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true,
|
||||
A_g2s_preds);
|
||||
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, 1, STAGES>(B_hoisted,
|
||||
B_shared_hoisted + k_0_0_ld * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, 0, true);
|
||||
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(
|
||||
zeros, zeros_shared + (k_0_0_ld) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters);
|
||||
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(scales_i8,
|
||||
scales_i8_shared + (k_0_0_ld) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, 0, k_0_0_ld < gemm_iters);
|
||||
|
||||
if constexpr (STAGES > 1)
|
||||
__pipeline_commit();
|
||||
}
|
||||
if constexpr (STAGES > 1)
|
||||
__pipeline_wait_prior(STAGES - 2);
|
||||
__syncthreads();
|
||||
|
||||
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(
|
||||
A_shared + warp_offset_k, A_shared_warp_[0], warp_offset_m, warp_offset_n, 0, WARP_M / INTRIN_M);
|
||||
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared + warp_offset_k * PACK_SIZE,
|
||||
B_shared_warp_[0], zeros_shared, scales_i8_shared, warp_offset_m, warp_offset_n, 0, 0, WARP_N / 32);
|
||||
constexpr int SHARED_K_ITERS = WARP_K / INTRIN_K;
|
||||
|
||||
for (; k_0_0 < gemm_iters; ++k_0_0, ++k_0_0_ld)
|
||||
{
|
||||
int ld_stage = k_0_0_ld % STAGES;
|
||||
int compute_stage = k_0_0 % STAGES;
|
||||
int8_t* A_shared_this_compute_stage;
|
||||
int8_t* B_shared_this_compute_stage;
|
||||
int8_t* zeros_shared_this_compute_stage;
|
||||
int8_t* scales_i8_shared_this_compute_stage;
|
||||
|
||||
for (int iter_k = 0; iter_k < SHARED_K_ITERS; ++iter_k)
|
||||
{
|
||||
A_shared_this_compute_stage = A_shared + compute_stage * kSmemSizeAPerStage + warp_offset_k;
|
||||
B_shared_this_compute_stage = B_shared + compute_stage * kSmemSizeBPerStage + warp_offset_k * PACK_SIZE;
|
||||
zeros_shared_this_compute_stage = zeros_shared + (compute_stage) *CTA_N;
|
||||
scales_i8_shared_this_compute_stage = scales_i8_shared + (compute_stage) *CTA_N;
|
||||
|
||||
share_to_reg_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES>(A_shared_this_compute_stage,
|
||||
A_shared_warp_[(iter_k + 1) % 2], warp_offset_m, warp_offset_n, (iter_k + 1) % SHARED_K_ITERS,
|
||||
WARP_M / INTRIN_M);
|
||||
share_to_reg_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(B_shared_this_compute_stage,
|
||||
B_shared_warp_[(iter_k + 1) % 2], zeros_shared_this_compute_stage, scales_i8_shared_this_compute_stage,
|
||||
warp_offset_m, warp_offset_n, k_0_0 + (iter_k == SHARED_K_ITERS - 1), (iter_k + 1) % SHARED_K_ITERS,
|
||||
WARP_N / 32);
|
||||
int8_t* A_shared_warp = A_shared_warp_[iter_k % 2];
|
||||
int8_t* B_shared_warp = B_shared_warp_[iter_k % 2];
|
||||
|
||||
for (int j_0_4 = 0; j_0_4 < WARP_N / INTRIN_N; ++j_0_4)
|
||||
{
|
||||
for (int i_0_3 = 0; i_0_3 < WARP_M / INTRIN_M; ++i_0_3)
{
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16));
mma_m16n8k32((void*) (C_warp + i_0_3 * WARP_N / INTRIN_N * 8 + j_0_4 * 8 + 4),
(void*) (A_shared_warp + i_0_3 * 16), (void*) (B_shared_warp + j_0_4 * 16 + 8));
}
}

if (iter_k < SHARED_K_ITERS - 1)
{
if constexpr (STAGES == 1)
__syncthreads();
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
}

if (iter_k == SHARED_K_ITERS - 2)
{
if constexpr (STAGES == 1 && SHARED_K_ITERS > 2)
{
__syncthreads();
}
global_to_share_one_stage_A<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(A_hoisted,
A_shared_hoisted + ld_stage * kSmemSizeAPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters, A_g2s_preds);
global_to_share_one_stage_B<CTA_M, CTA_N, CTA_K, CTA_SIZE, WARP_K / INTRIN_K, STAGES>(B_hoisted,
B_shared_hoisted + ld_stage * kSmemSizeBPerStage, K, cta_offset_m, cta_offset_n, k_0_0_ld,
iter_k + 1, k_0_0_ld < gemm_iters);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(zeros,
zeros_shared + (ld_stage) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
global_to_share_one_stage_zeros<CTA_M, CTA_N, CTA_K, CTA_SIZE, STAGES, G>(scales_i8,
scales_i8_shared + (ld_stage) *CTA_N, N, cta_offset_m, cta_offset_n, k_0_0_ld, iter_k,
k_0_0_ld < gemm_iters);
if constexpr (STAGES > 1)
{
__pipeline_commit();
__pipeline_wait_prior(STAGES - 2);
}
compute_stage = (k_0_0 + 1) % STAGES;
__syncthreads();
}
}
}
__pipeline_commit();
__pipeline_wait_prior(0);
__syncthreads();

if constexpr (SLICES > 1)
{
#pragma unroll
for (int z = 0; z < SLICES; ++z)
{
if (slice_id == z)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
if (z > 0)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
+= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n
+ ax1_0_1 * 16 + ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N
+ (local_id / 4) * 8 + (local_id % 2) + (threadIdx.x % 4) * 2];
}
C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2]
= C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id];
};
}
}
}
__syncthreads();
}
if (slice_id == 0)
{
#pragma unroll
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
#pragma unroll
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
#pragma unroll
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; ++local_id)
{
C_warp[ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8 + local_id]
= C_shared[warp_offset_m * CTA_N + ax0_0_1 * OP_M * CTA_N + warp_offset_n + ax1_0_1 * 16
+ ((local_id % 4) / 2 * 8 + (threadIdx.x / 4)) * CTA_N + (local_id / 4) * 8
+ (local_id % 2) + (threadIdx.x % 4) * 2];
};
}
}
}
}

int row_wb_thd = cta_offset_m + warp_offset_m + (threadIdx.x / 4);
int col_wb_thd = cta_offset_n + warp_offset_n + (threadIdx.x % 4) * 2;
if (slice_id == 0)
{
for (int ax0_0_1 = 0; ax0_0_1 < WARP_M / INTRIN_M; ++ax0_0_1)
{
int row_wb_1 = row_wb_thd + ax0_0_1 * OP_M;
for (int ax1_0_1 = 0; ax1_0_1 < WARP_N / INTRIN_N; ++ax1_0_1)
{
int col_wb_1 = col_wb_thd + ax1_0_1 * 16;
int* C_warp_local = C_warp + ax0_0_1 * WARP_N / INTRIN_N * 8 + ax1_0_1 * 8;
for (int local_id = 0; local_id < OP_M * 16 / WARP_SIZE; local_id += 2)
{
int row_wb = row_wb_1 + (local_id % 4) / 2 * 8;
if (row_wb < M)
{
int col_wb = col_wb_1 + (local_id / 4) * 8 + (local_id % 2);
float2 wscale = __half22float2(*(wscales + col_wb / 2));
float ascale = __half2float(ascales[row_wb]);
float2 psums = make_float2(
__int2float_rn(C_warp_local[local_id]), __int2float_rn(C_warp_local[local_id + 1]));
psums.x *= wscale.x * ascale;
psums.y *= wscale.y * ascale;
*reinterpret_cast<half2*>(C + row_wb * N + col_wb) = __float22half2_rn(psums);
}
};
}
}
}
#endif
}
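The write-back loop above dequantizes each pair of INT32 partial sums by the per-output-channel weight scale (wscales, read as half2 so one load covers two adjacent columns) and the per-row activation scale (ascales) before storing FP16. A small host-side sketch of that epilogue arithmetic, with made-up values:

// Illustrative only; mirrors psums.x *= wscale.x * ascale and psums.y *= wscale.y * ascale above.
int32_t acc0 = 1200, acc1 = -340;           // INT32 accumulators for two adjacent output columns
float wscale_x = 0.012f, wscale_y = 0.011f; // per-channel weight scales (from s1_scales)
float ascale = 0.047f;                      // per-row activation scale (from act_scales)
float out0 = static_cast<float>(acc0) * wscale_x * ascale; // ~0.6768, stored as half
float out1 = static_cast<float>(acc1) * wscale_y * ascale; // ~-0.1758, stored as half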

void qserveGemmPerGroupLaunch(ParamsPerGroup const& params, cudaStream_t stream)
{
auto in_feats = params.A;
auto kernel = params.B;
auto zeros = params.s2_zeros;
auto scales_i8 = params.s2_scales;
auto wscales = reinterpret_cast<half2 const*>(params.s1_scales);
auto ascales = params.act_scales;

auto out_feats = params.C;

int num_out_feats = params.m;
int num_out_channels = params.n;
int num_in_feats = params.m;
int num_in_channels = params.k;

constexpr int G = 128;

if (num_out_feats > 128)
{
constexpr int CTA_M = 128;
constexpr int CTA_N = 128;
constexpr int CTA_K = 64;
constexpr int WARP_M = 64;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
else if (num_out_feats >= 128)
{
if (num_in_channels <= 4096)
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 64;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 4;
KERNEL_LAUNCH_CODE
}
else
{
constexpr int CTA_M = 64;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}
else
{
constexpr int CTA_M = 32;
constexpr int CTA_N = 64;
constexpr int CTA_K = 128;
constexpr int WARP_M = 32;
constexpr int WARP_N = 32;
constexpr int WARP_K = 64;
constexpr int STAGES = 3;
KERNEL_LAUNCH_CODE
}
}

void QServeGemmRunner::gemmPerGroup(ParamsPerGroup const& params, cudaStream_t stream)
{
qserveGemmPerGroupLaunch(params, stream);
}

size_t QServeGemmRunner::getWorkspaceSize(int const m, int const n, int const k)
{
// We do not use workspace for now.
return 0;
}

} // namespace qserve
} // namespace kernels
} // namespace tensorrt_llm
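For orientation, a rough usage sketch of the per-group runner; the field names match what qserveGemmPerGroupLaunch reads above, while the buffer types, their packing, and their allocation are assumptions of this example:

// Hypothetical device buffers (not part of the diff): d_act, d_weight, d_zeros, d_scales_s2,
// d_scales_s1, d_act_scales and d_out are assumed to be allocated and filled elsewhere.
using namespace tensorrt_llm::kernels::qserve;
ParamsPerGroup params{};
params.A = d_act;                 // quantized activations, m x k
params.B = d_weight;              // packed low-bit weights
params.s2_zeros = d_zeros;        // per-group zeros
params.s2_scales = d_scales_s2;   // per-group scales
params.s1_scales = d_scales_s1;   // per-channel scales, read as half2 by the kernel
params.act_scales = d_act_scales; // per-row activation scales
params.C = d_out;                 // FP16 output, m x n
params.m = m;
params.n = n;
params.k = k;

QServeGemmRunner runner;
runner.gemmPerGroup(params, stream);                     // picks a tile config from m and k
size_t const wsBytes = runner.getWorkspaceSize(m, n, k); // currently always returns 0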
@ -29,13 +29,13 @@ namespace tensorrt_llm
namespace kernels
{

__global__ void quantizedKernel(char4* dst, float4 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, float4 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
float const scale = __ldg(scalePtr);
char4 tmp;
const float4 floatTmp = __ldg(src + idx);
float4 const floatTmp = __ldg(src + idx);
tmp.x = cuda_cast<int8_t>(floatTmp.x * scale);
tmp.y = cuda_cast<int8_t>(floatTmp.y * scale);
tmp.z = cuda_cast<int8_t>(floatTmp.z * scale);
@ -44,7 +44,7 @@ __global__ void quantizedKernel(char4* dst, float4 const* src, const int64_t siz
}
}

__global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, half2 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
@ -52,10 +52,10 @@ __global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t size
char4 tmp;
int srcId = idx << 1;

const uint2 h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));

const half2 half2Tmp = reinterpret_cast<half2 const&>(h2.x);
const half2 half2Tmp2 = reinterpret_cast<half2 const&>(h2.y);
half2 const half2Tmp = reinterpret_cast<half2 const&>(h2.x);
half2 const half2Tmp2 = reinterpret_cast<half2 const&>(h2.y);

tmp.x = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.x) * scale);
tmp.y = cuda_cast<int8_t>(cuda_cast<float>(half2Tmp.y) * scale);
@ -66,7 +66,7 @@ __global__ void quantizedKernel(char4* dst, half2 const* src, const int64_t size
}

#ifdef ENABLE_BF16
__global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int64_t sizeDiv4, float const* scalePtr)
__global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, int64_t const sizeDiv4, float const* scalePtr)
{
for (int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < sizeDiv4; idx += blockDim.x * gridDim.x)
{
@ -74,10 +74,10 @@ __global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int
char4 tmp;
int srcId = idx << 1;

const uint2 h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));
uint2 const h2 = __ldg(reinterpret_cast<uint2 const*>(src + srcId));

const __nv_bfloat162 bfloat162Tmp = reinterpret_cast<__nv_bfloat162 const&>(h2.x);
const __nv_bfloat162 bfloat162Tmp2 = reinterpret_cast<__nv_bfloat162 const&>(h2.y);
__nv_bfloat162 const bfloat162Tmp = reinterpret_cast<__nv_bfloat162 const&>(h2.x);
__nv_bfloat162 const bfloat162Tmp2 = reinterpret_cast<__nv_bfloat162 const&>(h2.y);

tmp.x = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.x) * scale);
tmp.y = cuda_cast<int8_t>(cuda_cast<float>(bfloat162Tmp.y) * scale);
@ -91,7 +91,7 @@ __global__ void quantizedKernel(char4* dst, __nv_bfloat162 const* src, const int

template <typename T>
void invokeQuantization(
int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize)
int8_t* dst, T const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize)
{
TLLM_CHECK_WITH_INFO(size % 4 == 0, "[ERROR][invokeQuantization] size should be a multiple of 4.\n");

@ -116,13 +116,13 @@ void invokeQuantization(
}

template void invokeQuantization<float>(
int8_t* dst, float const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
int8_t* dst, float const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize);

template void invokeQuantization<half>(
int8_t* dst, half const* src, const int64_t size, float const* scalePtr, cudaStream_t stream, int maxGridSize);
int8_t* dst, half const* src, int64_t const size, float const* scalePtr, cudaStream_t stream, int maxGridSize);

#ifdef ENABLE_BF16
template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const* src, const int64_t size,
template void invokeQuantization<__nv_bfloat16>(int8_t* dst, __nv_bfloat16 const* src, int64_t const size,
float const* scalePtr, cudaStream_t stream, int maxGridSize);
#endif

@ -220,8 +220,8 @@ inline __device__ void quantizeAndStore(
////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T, typename QuantT, bool USE_SMEM>
__global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, bool hasFp8MinScaling)
__global__ void perTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, bool hasFp8MinScaling)
{
// Smem buffer.
extern __shared__ uint4 smemBuffer[];
@ -244,6 +244,8 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu

// The number of elements in the packed uint4 vec.
static constexpr int NUM_ELTS_PER_VEC = sizeof(uint4) / sizeof(T);
static constexpr int NUM_ELTS2_PER_VEC = sizeof(uint4) / sizeof(T2);

// The number of vectors in the column.
int const numColVecs = numCols / NUM_ELTS_PER_VEC;
// The vector pointers for src.
@ -253,10 +255,25 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
// T const* srcRow = src + blockIdx.x * numCols;

T2 localMax2 = cuda_cast<T2, T>(T(1e-6f));
float2 localSum2 = {0.f, 0.f};

for (int i = threadIdx.x; i < numColVecs; i += blockDim.x)
{
uint4 vec = srcVec[i];
clampAndAbsMax(localMax2, vec, clampMin2, clampMax2);

#pragma unroll
for (int i = 0; i < NUM_ELTS2_PER_VEC; ++i)
{
T2& val2 = reinterpret_cast<T2*>(&vec)[i];
val2 = cuda_clamp(val2, clampMin2, clampMax2);
localMax2 = cuda_max(localMax2, cuda_abs(val2));
// TODO: template the version that requires sum to avoid dynamic branching.
if (sumPtr != nullptr)
{
localSum2.x += cuda_cast<float>(val2.x);
localSum2.y += cuda_cast<float>(val2.y);
}
}
// Avoid reloading from global memory.
if constexpr (USE_SMEM)
{
@ -264,13 +281,22 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu
}
}
float const rowMax = blockAllReduceMax(cuda_cast<float>(cuda_max<T, T2>(localMax2)));

if (threadIdx.x == 0)
{
scalePtr[blockIdx.x]
= hasFp8MinScaling ? cuda_max(rowMax / MAX_QUANT_VAL, MIN_SCALING_FACTOR) : (rowMax / MAX_QUANT_VAL);
}

if (sumPtr != nullptr)
{
float rowSum[1] = {cuda_sum<float>(localSum2)};
blockReduceSumV2<float, 1>(rowSum);
if (threadIdx.x == 0)
{
sumPtr[blockIdx.x] = rowSum[0];
}
}

float const scaleOrigQuant
= hasFp8MinScaling ? fminf(MAX_QUANT_VAL / rowMax, MIN_SCALING_FACTOR_RCP) : MAX_QUANT_VAL / rowMax;
for (int i = threadIdx.x; i < numColVecs; i += blockDim.x)
@ -283,12 +309,12 @@ __global__ void perTokenQuantization(QuantT* dst, T const* src, const int64_t nu

// Do per-token (row) quantization from fp16/bf16/fp32 to int8/fp8_e4m3.
template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, QuantMode quantMode, cudaStream_t stream)
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, QuantMode quantMode, cudaStream_t stream)
{
// each block is responsible for a single row
const dim3 block(512);
const dim3 grid(numRows);
dim3 const block(512);
dim3 const grid(numRows);

// The number of elements in the packed uint4 vec.
static constexpr int NUM_ELTS_PER_VEC = sizeof(uint4) / sizeof(T);
@ -311,19 +337,19 @@ void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows
// Do we use smem ?
if (useSmem)
{
perTokenQuantization<T, QuantT, true>
<<<grid, block, dynamicSmemSz, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, hasFp8MinScaling);
perTokenQuantization<T, QuantT, true><<<grid, block, dynamicSmemSz, stream>>>(
dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
}
else
{
perTokenQuantization<T, QuantT, false>
<<<grid, block, 0, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, hasFp8MinScaling);
<<<grid, block, 0, stream>>>(dst, src, numRows, numCols, clampPtr, scalePtr, sumPtr, hasFp8MinScaling);
}
}

#define INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(T, QuantT) \
template void invokePerTokenQuantization(QuantT* dst, const T* src, const int64_t numRows, const int64_t numCols, \
float const* clampPtr, float* scalePtr, QuantMode quantMode, cudaStream_t stream)
float const* clampPtr, float* scalePtr, float* sumPtr, QuantMode quantMode, cudaStream_t stream)

INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(float, int8_t);
INSTANTIATE_INVOKE_PER_TOKEN_QUANTIZATION(half, int8_t);

@ -26,11 +26,12 @@ namespace kernels

template <typename T>
void invokeQuantization(
int8_t* dst, T const* src, const int64_t size, float const* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0);
int8_t* dst, T const* src, int64_t const size, float const* scalePtr, cudaStream_t stream = 0, int maxGirdSize = 0);

template <typename T, typename QuantT>
void invokePerTokenQuantization(QuantT* dst, T const* src, const int64_t numRows, const int64_t numCols,
float const* clampPtr, float* scalePtr, tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream = 0);
void invokePerTokenQuantization(QuantT* dst, T const* src, int64_t const numRows, int64_t const numCols,
float const* clampPtr, float* scalePtr, float* sumPtr, tensorrt_llm::common::QuantMode quantMode,
cudaStream_t stream = 0);

} // namespace kernels
} // namespace tensorrt_llm
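A hedged sketch of calling the extended per-token quantization entry point declared above; the buffer names and the QuantMode value are placeholders of this example:

// d_src: [numRows, numCols] half input, d_dst: int8 output of the same shape,
// d_scales / d_sums: one float per row (hypothetical buffers, allocated elsewhere).
// scalePtr receives rowMax / MAX_QUANT_VAL per row; sums are only accumulated when
// sumPtr != nullptr, so passing nullptr keeps the previous behaviour.
tensorrt_llm::kernels::invokePerTokenQuantization(
    d_dst, d_src, numRows, numCols,
    /*clampPtr=*/nullptr, /*scalePtr=*/d_scales, /*sumPtr=*/d_sums, quantMode, stream);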

@ -53,7 +53,7 @@ __inline__ __device__ Tf compute_rmsnorm(Tf val, float s_variance, T const* gamm
template <typename T, typename QuantT, bool USE_SHMEM>
__global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps,
int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool hasFp8MinScaling)
{
constexpr auto num_elems_T = num_elems<T>::value;
// using int8_packed_t = typename packed_as<int8_t, num_elems_T>::type;
@ -86,13 +86,13 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
int const n_elems = hidden_dim / num_elems_T;
for (int i = tidx; i < n_elems; i += blockDim.x)
{
const T val = input[bidx * n_elems + i];
T const val = input[bidx * n_elems + i];
if (USE_SHMEM)
{
shmem[i] = val;
}

const float_packed_t val_f = cuda_cast<float_packed_t>(val);
float_packed_t const val_f = cuda_cast<float_packed_t>(val);

local_var_sum += cuda_sum<float>(val_f * val_f);
}
@ -110,14 +110,17 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*

bool const with_per_token_scaling = scale_orig_quant_per_token != nullptr;
bool const with_per_tensor_scaling = scale_orig_quant_per_tensor != nullptr;
const float_packed_t scale_orig_quant
bool const with_per_token_sum = sum_per_token != nullptr;

float_packed_t const scale_orig_quant
= cuda_cast<float_packed_t>(with_per_tensor_scaling ? *scale_orig_quant_per_tensor : 0.0f);
T_scalar amax = 1e-6f;
float local_sum = 0.f;

for (int i = tidx; i < n_elems; i += blockDim.x)
{
int const index = bidx * n_elems + i;
const float_packed_t val_f = cuda_cast<float_packed_t>(USE_SHMEM ? shmem[i] : input[index]);
float_packed_t const val_f = cuda_cast<float_packed_t>(USE_SHMEM ? shmem[i] : input[index]);
T val = cuda_cast<T>(compute_rmsnorm(val_f, s_variance, gamma, beta, i));

if (with_per_token_scaling)
@ -139,6 +142,11 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
{
normed_output[index] = val;
}

if (with_per_token_sum)
{
local_sum += cuda_sum<float>(cuda_cast<float_packed_t>(val));
}
}

if (with_per_token_scaling)
@ -165,13 +173,23 @@ __global__ void generalRmsNorm(T const* input, T const* gamma, T const* beta, T*
: abs_max_f / MAX_QUANT_VAL;
}
}

if (with_per_token_sum)
{
float packed_sum[1] = {local_sum};
blockReduceSumV2<float, 1>(packed_sum);
if (tidx == 0)
{
sum_per_token[bidx] = packed_sum[0];
}
}
}

template <typename T, typename QuantT>
void dispatch_rmsnorm_type_square_method(T const* input, T const* gamma, T const* beta, T* normed_output,
float const eps, int tokens, int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor,
float* scale_orig_quant_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling, const dim3 grid,
const dim3 block, const size_t shmem_size, cudaStream_t stream)
float* scale_orig_quant_per_token, float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling,
dim3 const grid, dim3 const block, size_t const shmem_size, cudaStream_t stream)
{
// Do we use shared memory to cache intermediate results.
bool use_shmem = true;
@ -186,32 +204,32 @@ void dispatch_rmsnorm_type_square_method(T const* input, T const* gamma, T const
if (use_shmem)
{
generalRmsNorm<T, QuantT, true><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output, eps,
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant,
hasFp8MinScaling);
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token,
normed_output_quant, hasFp8MinScaling);
}
else
{
generalRmsNorm<T, QuantT, false><<<grid, block, shmem_size, stream>>>(input, gamma, beta, normed_output, eps,
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant,
hasFp8MinScaling);
tokens, hidden_dim, clampPtr, scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token,
normed_output_quant, hasFp8MinScaling);
}
}

template <typename T, typename QuantT>
void dispatch_rmsnorm_type(T const* input, T const* gamma, T const* beta, T* normed_output, float const eps, int tokens,
int hidden_dim, float const* clampPtr, float const* scale_orig_quant_per_tensor, float* scale_orig_quant_per_token,
QuantT* normed_output_quant, bool const hasFp8MinScaling, const dim3 grid, const dim3 block,
const size_t shmem_size, cudaStream_t stream)
float* sum_per_token, QuantT* normed_output_quant, bool const hasFp8MinScaling, dim3 const grid, dim3 const block,
size_t const shmem_size, cudaStream_t stream)
{
dispatch_rmsnorm_type_square_method(input, gamma, beta, normed_output, eps, tokens, hidden_dim, clampPtr,
scale_orig_quant_per_tensor, scale_orig_quant_per_token, normed_output_quant, hasFp8MinScaling, grid, block,
shmem_size, stream);
scale_orig_quant_per_tensor, scale_orig_quant_per_token, sum_per_token, normed_output_quant, hasFp8MinScaling,
grid, block, shmem_size, stream);
}

template <typename T, typename QuantT>
void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens,
int const hidden_dim, QuantMode quantMode, cudaStream_t stream, float const* clampPtr, float const* scale,
float* dynamic_scale, QuantT* normed_output_quant)
float* dynamic_scale, float* sum_per_token, QuantT* normed_output_quant)
{
dim3 grid(tokens);
dim3 block(min(hidden_dim, 1024));
@ -219,7 +237,7 @@ void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta,
block.x = 32 * ((block.x + 31) / 32);

constexpr size_t vec_size = 2;
const size_t shmem_size = hidden_dim * sizeof(T);
size_t const shmem_size = hidden_dim * sizeof(T);
bool const use_vec_type = (hidden_dim % vec_size == 0)
&& (std::is_same<T, half>::value
#ifdef ENABLE_BF16
@ -235,19 +253,19 @@ void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta,
using Tp = typename packed_as<T, vec_size>::type;
dispatch_rmsnorm_type(reinterpret_cast<Tp const*>(input), reinterpret_cast<Tp const*>(gamma),
reinterpret_cast<Tp const*>(beta), reinterpret_cast<Tp*>(out), eps, tokens, hidden_dim, clampPtr, scale,
dynamic_scale, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
dynamic_scale, sum_per_token, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
}
else
{
dispatch_rmsnorm_type(input, gamma, beta, out, eps, tokens, hidden_dim, clampPtr, scale, dynamic_scale,
normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
sum_per_token, normed_output_quant, hasFp8MinScaling, grid, block, shmem_size, stream);
}
}

#define INSTANTIATE_GENERAL_RMSNORM(T, QuantT) \
template void invokeGeneralRmsNorm(T* out, const T* input, const T* gamma, const T* beta, const float eps, \
const int tokens, const int hidden_dim, QuantMode quantMode, cudaStream_t stream, float const* clampPtr, \
const float* scale, float* dynamic_scale, QuantT* normed_output_quant);
const float* scale, float* dynamic_scale, float* sum_per_token, QuantT* normed_output_quant);

INSTANTIATE_GENERAL_RMSNORM(float, int8_t);
INSTANTIATE_GENERAL_RMSNORM(half, int8_t);

@ -31,7 +31,7 @@ template <typename T, typename QuantT>
void invokeGeneralRmsNorm(T* out, T const* input, T const* gamma, T const* beta, float const eps, int const tokens,
int const hidden_dim, tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream = 0,
float const* clampPtr = nullptr, float const* scale = nullptr, float* dynamic_scale = nullptr,
QuantT* out_quant = nullptr);
float* sum_per_token = nullptr, QuantT* out_quant = nullptr);

} // namespace kernels
} // namespace tensorrt_llm
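For intuition, a scalar host-side reference of what generalRmsNorm produces for one token, including the new per-token sum; this simplification skips clamping and the FP8 minimum-scale handling, and assumes compute_rmsnorm follows the usual RMSNorm formula:

// Reference for one row x[0..hidden_dim) with the int8 path (MAX_QUANT_VAL = 127).
float sq = 0.f;
for (int i = 0; i < hidden_dim; ++i) sq += x[i] * x[i];
float const inv_rms = 1.f / sqrtf(sq / hidden_dim + eps);
float amax = 1e-6f, sum = 0.f;
for (int i = 0; i < hidden_dim; ++i)
{
    float const v = x[i] * inv_rms * gamma[i] + (beta ? beta[i] : 0.f);
    normed[i] = v;                   // normed_output
    amax = fmaxf(amax, fabsf(v));
    sum += v;                        // what sum_per_token[token] receives
}
float const scale = amax / 127.f;    // what scale_orig_quant_per_token[token] receives
for (int i = 0; i < hidden_dim; ++i)
    quant[i] = static_cast<int8_t>(roundf(normed[i] / scale)); // rounding mode is an assumption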

@ -231,6 +231,7 @@ __global__ void topKStage2Sampling(SizeType32 const* __restrict topKTmpIdBuf, T*
auto outputId = idx != -1
? topKTmpIdBuf[(batchIdx * maxTokensPerStep + tokenIdx) * stride + idx] % vocabSize
: vocabSize - 1;
outputId = outputId == -1 ? vocabSize - 1 : outputId;
auto const curSeqLen = sequenceLengths == nullptr ? 0 : sequenceLengths[batchSlot];
auto const outIdx = returnAllSelectedTokens ? tokenIdx * maxTopK + ki : curSeqLen + tokenIdx;
outputIdsRequestPtr[outIdx] = outputId;

@ -153,8 +153,7 @@ __global__ void prepareCtxEagleNetInputsKernel(SizeType32* eagleNetSequenceLengt
if (isValid)
{
// Sequence length of the base model (without draft len for gen requests and all prompt tokens for ctx request)
auto const oldSequenceLength
= baseNetSequenceLengths[bid] - (numInputTokens - static_cast<SizeType32>(!isContextRequest));
auto const oldSequenceLength = baseNetSequenceLengths[bid] - numInputTokens;
for (SizeType32 ti = 0; ti < numDecodingTokens; ++ti)
{
TokenIdType token;
@ -249,14 +248,6 @@ __global__ void buildLeafMask(
{
auto const bid = static_cast<SizeType32>(blockIdx.x);
auto const level = static_cast<SizeType32>(blockIdx.y);
// Prefill mask setting all to leaves.
for (auto tid = static_cast<SizeType32>(threadIdx.x); tid < maxDecodingTokens;
tid += static_cast<SizeType32>(blockDim.x))
{
isLeafMask[bid * maxDecodingTokens + tid] = 1;
}

__syncthreads();

// For all paths
for (auto pathIdx = static_cast<SizeType32>(threadIdx.x); pathIdx < maxDecodingTokens;
@ -268,8 +259,8 @@ __global__ void buildLeafMask(
auto const tokenCurLevelOffset = flat_index3(bid, pathIdx, level, maxDecodingTokens, maxPathLen);
// Get token idx in the flattened draft tokens for the of the current level
auto const curNodeTokenIdx = paths[tokenCurLevelOffset];
// If token idx is not -1 (not terminated path)
// And the next level token is not -1 -- path is not terminating and token at current level has child.
// If token idx is not -1 (not terminated path) And the next
// level token is not -1 -- path is not terminating and token at current level has child.
if (curNodeTokenIdx != -1 && paths[tokenNextLevelOffset] != -1)
{
// Mark mask to 0.
@ -303,11 +294,6 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
// Init selected paths for CAS.
selectedPathsSmem[ii] = -1;
}
// Fill mask.
for (auto ii = static_cast<SizeType32>(threadIdx.x); ii < maxDecodingTokens;
ii += static_cast<SizeType32>(blockDim.x))
{
}

__syncthreads();

@ -320,7 +306,7 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
auto const tokenIdxLevel = paths[tokenCurLevelOffset];
// Check if this path is not terminated yet.
// And check if this node is not leaf.
if (tokenIdxLevel >= 0 && !isLeafMask[tokenIdxLevel])
if (tokenIdxLevel >= 0 && !isLeafMask[bid * maxDecodingTokens + tokenIdxLevel])
{
// Set this path as selected for this token.
atomicCAS(&selectedPathsSmem[tokenIdxLevel], -1, pi);
@ -372,7 +358,7 @@ __global__ void getNonLeafEndingSubtree(SizeType32* selectedDraftIndices, SizeTy
}

// If node is not leaf
if (!isLeafMask[ti])
if (!isLeafMask[bid * maxDecodingTokens + ti])
{
// Save its in-level index for output hidden state indices calculation.
nonLeavesInLevelOffsets[bid * maxDecodingTokens + ti] = nonLeavesInLevelCounter;
@ -506,15 +492,13 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
if (bid == 0)
{
// Save max draft length for the mask packing kernel.
maxGenerationLength[0] = maxDecodingTokens;
maxGenerationLength[0] = maxGenLength;
}

if (isValid)
{
// Fill spec decoding gen length.
// We do -1 here as attn expects golden token + draft tokens.
// We do not provide golden token, but just reduce the number of draft tokens.
specDecodingGenLengths[bid] = nextDraftLen - 1;
specDecodingGenLengths[bid] = nextDraftLen;
// Simply copy context len.
nextContextLengths[bid] = prevContextLengths[bid];
auto const sequenceLen = eagleNet0SequenceLengths[bid];
@ -524,6 +508,9 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// Fill cumulative sum for the mask packing kernel.
cumSumGenerationLengths[bid] = genLengthCumSum;

// Pos id is Ctx EagleNet seqLen (prompt + all accepted).
positionIds[bid] = sequenceLen;

SizeType32 lastTokenIdx{0};
for (SizeType32 ti = 0; ti < nextDraftLen; ++ti)
{
@ -535,8 +522,6 @@ __global__ void prepareGenEagleNetInputsKernel(SizeType32* nextSequenceLengths,
// Get draft pos offset.
auto const posOffset = selectedDraftPosIds[bid * maxDecodingDraftTokens + ti];
specDecodingPositionOffsets[bid * maxDecodingTokens + ti] = posOffset;
// Pos id is Ctx EagleNet seqLen (prompt + all accepted) + pos offset
positionIds[outputIndexBase + ti] = sequenceLen + posOffset;

// hiddenStatesIndex is constructed having hidden states layout in mind.
// Hidden states are placed in memory as [maxPathLen - 1, batchSize, numOutputTokens] (this tensor is
@ -589,6 +574,11 @@ inline __device__ __host__ T divUp(T m, T n)
return (m + n - 1) / n;
}

__device__ SizeType32 positivePowerOfTwo(SizeType32 n)
{
return 1 << n;
}

//! @brief Takes mask of size [maxGenerationLength] filled with 1s and 0s defined in the shared memory
//! and packs it to bitmask of size [numPackedMasks] written to outputPtr.
//! numPackedMasks = ceil(maxGenerationLength / 32);
@ -609,20 +599,33 @@ __device__ __forceinline__ void maskToPackedMask(
auto const shMaskIndexEnd = maxGenerationLength - maskId * 32;
auto const validNumBits = shMaskIndexEnd - shMaskIndexStart;

SizeType32 packedMask = 0;
for (SizeType32 ii = 0; ii < validNumBits; ++ii)
auto const firstBit1 = (shMask[shMaskIndexStart] == '1') ? true : false;
SizeType32 mask31bits = 0;
if (validNumBits != 1)
{
auto const index = (validNumBits - 1) - (ii - shMaskIndexStart - 1) - 1;
packedMask += (shMask[shMaskIndexStart + ii] == '1') ? 1 << index : 0;
for (auto i = shMaskIndexStart + 1; i < shMaskIndexEnd; i++)
{
auto const index = (validNumBits - 1) - (i - shMaskIndexStart);
mask31bits += (shMask[i] == '1') ? positivePowerOfTwo(index) : 0;
}
}
outputPtr[maskId] = packedMask;
SizeType32 mask32bits;
if (validNumBits == 32)
{
mask32bits = firstBit1 ? mask31bits - positivePowerOfTwo(validNumBits - 1) : mask31bits;
}
else
{
mask32bits = firstBit1 ? mask31bits + positivePowerOfTwo(validNumBits - 1) : mask31bits;
}
outputPtr[maskId] = mask32bits;
}
}
}
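In the packing above each group of up to 32 mask characters becomes one signed 32-bit word: character j of a group lands on bit (validNumBits - 1 - j), and when a group is full (validNumBits == 32) the leading character maps to the sign bit, hence the subtraction instead of an addition. A worked example with illustrative values:

// Full group, shMask chars "11" followed by thirty '0's:
//   firstBit1 = true, mask31bits = 2^30, mask32bits = 2^30 - 2^31 = -1073741824 (bit pattern 0xC0000000).
// Partial group of validNumBits = 2 with chars "11":
//   mask31bits = 2^0 = 1, mask32bits = 1 + 2^1 = 3 (bits 0 and 1 set).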

__global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLengths,
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask, SizeType32 maxDraftTokens,
SizeType32* __restrict__ packedMask)
SizeType32 const* __restrict__ maxGenerationLengths, bool const* __restrict__ mask,
SizeType32 maxDecodingDraftTokens, SizeType32* __restrict__ packedMask)
{
auto const batchIdx = static_cast<SizeType32>(blockIdx.y);
auto const tokenIdx = static_cast<SizeType32>(blockIdx.x);
@ -635,7 +638,8 @@ __global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLength
}

auto const maxGenerationLength = maxGenerationLengths[0];
auto const numPackedMasks = divUp(maxDraftTokens + 1, 32);
auto const maxDecodingTokens = maxDecodingDraftTokens + 1;
auto const numPackedMasks = divUp(maxDecodingTokens, 32);

auto const outputStartId = ((batchIdx == 0) ? 0 : cumGenerationLengths[batchIdx - 1]);
auto* outputPtr = packedMask + (outputStartId + tokenIdx) * numPackedMasks;
@ -650,8 +654,7 @@ __global__ void getPackedMask(SizeType32 const* __restrict__ cumGenerationLength
}
else
{
bool const* maskPtr
= mask + batchIdx * maxGenerationLength * maxGenerationLength + tokenIdx * maxGenerationLength;
bool const* maskPtr = mask + batchIdx * maxDecodingTokens * maxDecodingTokens + tokenIdx * maxDecodingTokens;
extern __shared__ char shMask[];
for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < maxGenerationLength;
ti += static_cast<SizeType32>(blockDim.x))
@ -1071,7 +1074,7 @@ __global__ void packEagleGenerationLengths(PackEagleParams params)

if (threadIdx.x == 0 && isGenerationRequest)
{
params.outputSpecDecodingGenerationLengths[genIdx] = params.inputNextDraftLens[batchSlot];
params.outputSpecDecodingGenerationLengths[genIdx] = params.inputSpecDecodingGenerationLengths[batchSlot];
}
}
} // namespace
@ -1101,7 +1104,8 @@ __global__ void packEagleTensors(PackEagleParams params)
params.outputRandomDataValidation[batchIdx] = params.inputRandomDataValidation[batchSlot];

// 0 for ctx request and actual draft len for gen requests.
params.outputNextDraftLens[batchIdx] = isGenerationRequest ? params.inputNextDraftLens[batchSlot] : 0;
params.outputNextDraftLens[batchIdx]
= isGenerationRequest ? params.inputSpecDecodingGenerationLengths[batchSlot] - 1 : 0;
}

// Copy draft paths
@ -1186,6 +1190,8 @@ __global__ void unpackEagleData(UnpackEagleDataParams params)
// Copy next draft tokens to the slots.
params.outputNextDraftTokens[batchSlot * maxDecodingDraftTokens + ti]
= params.inputNextDraftTokens[bid * maxDecodingDraftTokens + ti];
params.outputUnpackedNextDraftTokens[batchSlot * maxDecodingDraftTokens + ti]
= params.inputNextDraftTokens[bid * maxDecodingDraftTokens + ti];
}

for (auto ti = static_cast<SizeType32>(threadIdx.x); ti < params.maxDecodingTokens * params.maxPathLength;
@ -1204,6 +1210,8 @@ __global__ void unpackEagleData(UnpackEagleDataParams params)
params.outputNumNewTokens[batchSlot] = acceptedLength;
// One thread copies next draft len to slot
params.outputNextDraftLengths[batchSlot] = params.inputNextDraftLens[bid];
// Set next gen len to slot.
params.outputNextGenerationLength[batchSlot] = params.inputNextDraftLens[bid] + 1;
// Set prev draft lengths needed for kv cache rewind in variable draft len.
params.outputPrevDraftLengths[batchSlot] = params.inputLastDraftLens[bid];
// Set random data for draft sampling kernels.

@ -121,7 +121,7 @@ struct PrepareGenEagleNetInputsParams
//! output buffer [numOutputTokens]
//! Selected tokens ids.
runtime::TokenIdType* outputIds{nullptr};
//! output buffer [numOutputTokens]
//! output buffer [batchSize]
//! Position ids of the selected tokens.
runtime::SizeType32* positionIds{nullptr};
//! output buffer [batchSize]
@ -270,8 +270,6 @@ struct PackEagleParams
float const* inputRandomDataValidation{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType const* inputNextDraftTokens{nullptr};
//! [maxBatchSize]
runtime::SizeType32 const* inputNextDraftLens{nullptr};
//! [maxBatchSize, maxDecodingTokens, maxPathLen]
runtime::SizeType32 const* inputNextDraftPaths{nullptr};
//! [maxBatchSize]
@ -315,7 +313,6 @@ struct PackEagleParams
TLLM_CHECK(inputRandomDataSample);
TLLM_CHECK(inputRandomDataValidation);
TLLM_CHECK(inputNextDraftTokens);
TLLM_CHECK(inputNextDraftLens);
TLLM_CHECK(inputNextDraftPaths);
TLLM_CHECK(inputSpecDecodingGenerationLengths);
TLLM_CHECK(inputSpecDecodingPositionOffsets);
@ -327,9 +324,9 @@ struct PackEagleParams
TLLM_CHECK(outputNextDraftTokens);
TLLM_CHECK(outputNextDraftLens);
TLLM_CHECK(outputNextDraftPaths);
TLLM_CHECK(outputSpecDecodingGenerationLengths);
TLLM_CHECK(outputSpecDecodingPositionOffsets);
TLLM_CHECK(outputSpecDecodingPackedMasks);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingGenerationLengths) || numGenerationRequests == 0);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingPositionOffsets) || numGenerationRequests == 0);
TLLM_CHECK((numGenerationRequests > 0 && outputSpecDecodingPackedMasks) || numGenerationRequests == 0);

TLLM_CHECK(maxGenerationLength);
TLLM_CHECK(cumSumGenerationLengths);
@ -377,6 +374,8 @@ struct UnpackEagleDataParams
//! [maxBatchSize]
runtime::SizeType32* outputSequenceLengths{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType* outputUnpackedNextDraftTokens{nullptr};
//! [maxBatchSize, maxDecodingDraftTokens]
runtime::TokenIdType* outputNextDraftTokens{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputNextDraftLengths{nullptr};
@ -384,6 +383,8 @@ struct UnpackEagleDataParams
runtime::SizeType32* outputNextDraftPaths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputPrevDraftLengths{nullptr};
//! [maxBatchSize]
runtime::SizeType32* outputNextGenerationLength{nullptr};
//! [maxBatchSize, maxDecodingTokens]
runtime::SizeType32* outputPositionIds{nullptr};

@ -428,10 +429,12 @@ struct UnpackEagleDataParams
TLLM_CHECK(outputIds);
TLLM_CHECK(outputNumNewTokens);
TLLM_CHECK(outputSequenceLengths);
TLLM_CHECK(outputUnpackedNextDraftTokens);
TLLM_CHECK(outputNextDraftTokens);
TLLM_CHECK(outputNextDraftLengths);
TLLM_CHECK(outputNextDraftPaths);
TLLM_CHECK(outputPrevDraftLengths);
TLLM_CHECK(outputNextGenerationLength);
TLLM_CHECK(outputPositionIds);

TLLM_CHECK(outputRandDataSample);
1515 cpp/tensorrt_llm/kernels/topkLastDim.cu (new file; diff suppressed because it is too large)
38 cpp/tensorrt_llm/kernels/topkLastDim.h (new file)
@ -0,0 +1,38 @@
/*
 * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"
#include "tensorrt_llm/runtime/common.h"

namespace tensorrt_llm
{
namespace kernels
{

template <typename T>
size_t invokeComputeTopkLastDimWorkspaceSize(
runtime::SizeType32 batchSize, runtime::SizeType32 inputLength, runtime::SizeType32 k, bool is_largest);

template <typename T>
void invokeTopkLastDim(runtime::SizeType32 batchSize, runtime::SizeType32 inputLength, runtime::SizeType32 k,
bool is_largest, void const* __restrict__ input, void* __restrict__ out_val, void* __restrict__ out_ind,
void* workspace, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
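A hedged usage sketch of the new top-k over the last dimension API declared above; the element type, index buffer type, and allocations are assumptions of this example:

// Select the k largest entries of each row of a [batchSize, inputLength] float matrix.
using tensorrt_llm::runtime::SizeType32;
SizeType32 const batchSize = 8, inputLength = 32000, k = 50;
size_t const wsBytes = tensorrt_llm::kernels::invokeComputeTopkLastDimWorkspaceSize<float>(
    batchSize, inputLength, k, /*is_largest=*/true);
void* workspace = nullptr;
cudaMalloc(&workspace, wsBytes);
// d_in: [batchSize, inputLength] values, d_outVal: [batchSize, k] values,
// d_outInd: [batchSize, k] indices (hypothetical device buffers, allocated elsewhere).
tensorrt_llm::kernels::invokeTopkLastDim<float>(
    batchSize, inputLength, k, /*is_largest=*/true, d_in, d_outVal, d_outInd, workspace, stream);
cudaFree(workspace);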
@ -681,27 +681,33 @@ public:
{
}

//! Unpacked draft tokens
TensorPtr unpackedNextDraftTokens; // [maxBatchSize, maxDecodingDraftTokens] on gpu
//! Draft paths for the next iteration.
TensorPtr nextDraftPaths; // [maxBatchSize, maxDecodingTokens, maxPathLen]
TensorPtr nextDraftPaths; // [maxBatchSize, maxDecodingTokens, maxPathLen] on gpu
//! Randomly sampled data (between 0.f and 1.f)
TensorPtr randomDataSample; // [maxBatchSize] on gpu
//! Randomly sampled data (between 0.f and 1.f)
TensorPtr randomDataValidation; // [maxBatchSize] on gpu
//! Sampling temperature.
TensorPtr temperatures; // [maxBatchSize] on gpu
//! Next generation lengths.
TensorPtr generationLengths; // [maxBatchSize] on gpu
//! Next generation lengths.
TensorPtr generationLengthsHost; // [maxBatchSize] on pinned

//! Request types for ctx stage of the EagleNet0 (filled with 0s).
TensorPtr eagleNetCtxRequestTypesHost; //! [maxBatchSize]
TensorPtr eagleNetCtxRequestTypesHost; //! [maxBatchSize] on pinned
//! Context lengths of the context EagleNet0.
TensorPtr eagleNetCtxContextLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetCtxContextLengthsHost; //! [maxBatchSize] on pinned
//! Past kv lengths of the context EagleNet0.
TensorPtr eagleNetCtxPastKeyValueLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetCtxPastKeyValueLengthsHost; //! [maxBatchSize] on pinned
//! Request types for ctx stage of the EagleNetX (filled with 1s).
TensorPtr eagleNetGenRequestTypesHost; //! [maxBatchSize]
TensorPtr eagleNetGenRequestTypesHost; //! [maxBatchSize] on pinned
//! Context lengths of the generation EagleNetX.
TensorPtr eagleNetGenContextLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetGenContextLengthsHost; //! [maxBatchSize] on pinned
//! Past kv lengths of the generation EagleNetX.
TensorPtr eagleNetGenPastKeyValueLengthsHost; //! [maxBatchSize]
TensorPtr eagleNetGenPastKeyValueLengthsHost; //! [maxBatchSize] on pinned
};

} // namespace tensorrt_llm::layers

@ -160,8 +160,9 @@ void DynamicDecodeLayer<T>::forwardAsync(std::shared_ptr<BaseDecodingOutputs> co

auto params = std::dynamic_pointer_cast<DecodingInputs>(baseInputs);

TLLM_CHECK_WITH_INFO(mDecodingMode.isExplicitDraftTokens() || params->logits || params->logitsVec,
"If not explicit Draft Tokens mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
mDecodingMode.isExplicitDraftTokens() || mDecodingMode.isEagle() || params->logits || params->logitsVec,
"If not Explicit Draft Tokens or Eagle mode, either logits or logitsVec have to be specified.");
TLLM_CHECK_WITH_INFO(
baseOutputs->sequenceLength.has_value(), "sequenceLength tensor is required in DynamicDecoderLayer.");


@ -155,8 +155,12 @@ void EagleDecodingLayer<T>::unpackData(EagleOutputs const& outputs, EagleInputs
params.outputIds = bufferCast<TokenIdType>(*outputs.outputIds);
params.outputNumNewTokens = bufferCast<SizeType32>(*outputs.numNewTokens.value());
params.outputSequenceLengths = bufferCast<SizeType32>(*outputs.sequenceLength.value());
// FIXME outputUnpackedNextDraftTokens is the same as outputNextDraftTokens.
// outputUnpackedNextDraftTokens is used in eagleBuffers and outputNextDraftTokens is used in the runtime
params.outputUnpackedNextDraftTokens = bufferCast<TokenIdType>(*outputs.unpackedNextDraftTokens);
params.outputNextDraftTokens = bufferCast<TokenIdType>(*outputs.nextDraftTokens);
params.outputNextDraftLengths = bufferCast<SizeType32>(*outputs.nextDraftLengths);
params.outputNextGenerationLength = bufferCast<SizeType32>(*outputs.generationLengths);
params.outputNextDraftPaths = bufferCast<SizeType32>(*outputs.nextDraftPaths);
params.outputPrevDraftLengths = bufferCast<SizeType32>(*outputs.prevDraftLengths);
params.outputPositionIds = bufferCast<SizeType32>(*outputs.nextDraftPosIds);
@ -189,6 +193,8 @@ void EagleDecodingLayer<T>::unpackData(EagleOutputs const& outputs, EagleInputs
mBufferManager->copy(*mEagleNetGenContextLengths, *outputs.eagleNetGenContextLengthsHost);
mBufferManager->copy(*mEagleNetGenPastKeyValueLengths, *outputs.eagleNetGenPastKeyValueLengthsHost);

mBufferManager->copy(*outputs.generationLengths, *outputs.generationLengthsHost);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}


@ -148,10 +148,10 @@ void ExternalDraftTokensLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWid
{
auto initWorkspaceSizes = getTopKInitWorkspaceSizes(batchSize);
calcAlignedPointers(workspace->getRawWorkspaceDevicePtr(), initWorkspaceSizes)(topKsPtr, topPsPtr);
DecodingLayerWorkspace::copyToWorkspace(*mBufferManager, runtimeTopK,
ITensor::wrap(topKsPtr, {1, {static_cast<ITensor::DimType64>(initWorkspaceSizes[0])}}));
DecodingLayerWorkspace::copyToWorkspace(*mBufferManager, runtimeTopP,
ITensor::wrap(topPsPtr, {1, {static_cast<ITensor::DimType64>(initWorkspaceSizes[1])}}));
DecodingLayerWorkspace::copyToWorkspace(
*mBufferManager, runtimeTopK, IBuffer::wrap(topKsPtr, initWorkspaceSizes[0] / sizeof(*topKsPtr)));
DecodingLayerWorkspace::copyToWorkspace(
*mBufferManager, runtimeTopP, IBuffer::wrap(topPsPtr, initWorkspaceSizes[1] / sizeof(*topPsPtr)));
}
auto const* batchSlotsDevicePtr = workspace->getDeviceBatchSlotsPtr();
auto* skipTopKDecodeDevicePtr = bufferCastOrNull<bool>(mSkipTopKDecodeDevice);

@ -60,7 +60,7 @@ static std::vector<std::unique_ptr<BaseLayer>> createLayers(executor::DecodingMo
std::vector<std::unique_ptr<BaseLayer>> layers;
auto layerTypes = createDecodingLayerTypes(mode);
// Only when draft tokens and predicted and decoded by the engine, we can skip penalty layer.
if (!mode.isExplicitDraftTokens())
if (!mode.isExplicitDraftTokens() && !mode.isEagle())
{
TLLM_CHECK_WITH_INFO(layerTypes.size() && layerTypes[0] == DecodingLayers_t::PENALTY_LAYER,
"Penalty layer is required to be the first layer for any decoder configuration");

@ -61,7 +61,7 @@ LookaheadAlgorithm::LookaheadAlgorithm(
mEncodeMapMax = runtime::BufferManager::cpu(runtime::ITensor::makeShape({maxDraftLen}), nvinfer1::DataType::kINT32);
}

void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g)
void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeType32 n, SizeType32 g, uint64_t seed)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
TLLM_CHECK_WITH_INFO(w <= mMaxW, "lookahead requires setup w (%d) <= max_w (%d)", w, mMaxW);
@ -85,6 +85,9 @@ void LookaheadAlgorithm::setup(TensorConstPtr const& prompt, SizeType32 w, SizeT
BufferRange<TokenIdType> prefillRange(*mPrefills);
BufferRange<TokenIdType> pastRange(*mPastTokens);
BufferRange<TokenIdType> goldRange(*mGoldenTokens);

srand(seed);

auto randToken = [&promptRange](auto& item) { item = promptRange[rand() % promptRange.size()]; };
std::for_each(prefillRange.begin(), prefillRange.end(), randToken);
std::for_each(pastRange.begin(), pastRange.end(), [](auto& a) { a = -1; });

@ -39,7 +39,8 @@ public:
runtime::SizeType32 maxW, runtime::SizeType32 maxN, runtime::SizeType32 maxG, runtime::SizeType32 id = 0);

//! @brief setup per request, fill internal states from @param prompt.
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g);
void setup(TensorConstPtr const& prompt, runtime::SizeType32 w, runtime::SizeType32 n, runtime::SizeType32 g,
uint64_t seed);

//! @brief accept the new generated tokens.
//! LookaheadDecodingLayer need call once for the first token in generation phase.

@ -163,7 +163,13 @@ void LookaheadDecodingLayer<T>::setup(SizeType32 batchSize, SizeType32 beamWidth
"runtime w(%d) n(%d) g(%d) exceeds maxTokensPerStep(%d)", w, n, g,
mDecoderDomain.getMaxDecodingTokens());
PRINT_VALUES(mCpuAlgo->mPrompts[bi]);
mCpuAlgo->mAlgos[gbi].setup(mCpuAlgo->mPrompts[bi], w, n, g);
auto seed = DefaultDecodingParams::getSeed();
if (setupParams->randomSeed)
{
auto& seeds = setupParams->randomSeed.value();
seed = seeds.size() == 1 ? seeds[0] : seeds[bi];
}
mCpuAlgo->mAlgos[gbi].setup(mCpuAlgo->mPrompts[bi], w, n, g, seed);
}

for (runtime::SizeType32 bi = 0; bi < batchSize; bi++)
Some files were not shown because too many files have changed in this diff.