Mirror of https://github.com/NVIDIA/TensorRT-LLM.git, synced 2026-01-13 22:18:36 +08:00

Update TensorRT-LLM (#1427)

* Update TensorRT-LLM

Co-authored-by: meghagarwal <16129366+megha95@users.noreply.github.com>
parent 118b3d7e7b
commit 035b99e0d0
.gitattributes (vendored): 1 change
@@ -1 +1,2 @@
*.a filter=lfs diff=lfs merge=lfs -text
*.lib filter=lfs diff=lfs merge=lfs -text
@@ -83,6 +83,7 @@ python3 prepare_dataset.py \
[--time-delay-dist exponential_dist] \
dataset
--dataset-name <name of the dataset> \
--dataset-split <split of the dataset to use> \
--dataset-input-key <dataset dictionary key for input> \
--dataset-prompt-key <dataset dictionary key for prompt> \
--dataset-output-key <dataset dictionary key for output> \
@@ -99,6 +100,7 @@ python3 prepare_dataset.py \
--output cnn_dailymail.json
dataset
--dataset-name cnn_dailymail \
--dataset-split validation \
--dataset-config-name 3.0.0 \
--dataset-input-key article \
--dataset-prompt "Summarize the following article:" \
@ -95,8 +95,10 @@ private:
|
||||
std::map<uint64_t, std::pair<TensorPtr, TensorPtr>> loras;
|
||||
for (auto const& [id, p] : taskPaths)
|
||||
{
|
||||
TensorPtr loraWeights = utils::loadNpy(mBufferManager, p / "model.lora_weights.npy", MemoryType::kCPU);
|
||||
TensorPtr loraConfig = utils::loadNpy(mBufferManager, p / "model.lora_config.npy", MemoryType::kCPU);
|
||||
TensorPtr loraWeights
|
||||
= utils::loadNpy(mBufferManager, (p / "model.lora_weights.npy").string(), MemoryType::kCPU);
|
||||
TensorPtr loraConfig
|
||||
= utils::loadNpy(mBufferManager, (p / "model.lora_config.npy").string(), MemoryType::kCPU);
|
||||
loras.insert_or_assign(id, std::make_pair(loraWeights, loraConfig));
|
||||
}
|
||||
return loras;
|
||||
@ -136,17 +138,21 @@ private:
|
||||
|
||||
struct BenchmarkParams
|
||||
{
|
||||
std::optional<SizeType> maxTokensInPagedKvCache = std::nullopt;
|
||||
std::optional<float> freeGpuMemoryFraction = std::nullopt;
|
||||
bool enableTrtOverlap = false;
|
||||
bool enableBlockReuse = false;
|
||||
bool enableChunkedContext = false;
|
||||
bool streaming = false;
|
||||
std::optional<SizeType> maxTokensInPagedKvCache{std::nullopt};
|
||||
std::optional<float> freeGpuMemoryFraction{std::nullopt};
|
||||
bool enableTrtOverlap{false};
|
||||
bool enableBlockReuse{false};
|
||||
bool enableChunkedContext{false};
|
||||
bool streaming{false};
|
||||
|
||||
// lora / peft params
|
||||
std::optional<std::string> loraDir = std::nullopt;
|
||||
SizeType loraDeviceNumModLayers = 0;
|
||||
size_t loraHostCacheSize = 1024 * 2024 * 1024;
|
||||
std::optional<std::string> loraDir{std::nullopt};
|
||||
SizeType loraDeviceNumModLayers{0};
|
||||
size_t loraHostCacheSize{1024 * 2024 * 1024};
|
||||
|
||||
// KV cache block offloading
|
||||
size_t kvHostCacheSize{0};
|
||||
bool kvOnboardBlocks{true};
|
||||
};
|
||||
} // namespace
|
||||
|
||||
@ -289,6 +295,8 @@ struct BenchInfo
|
||||
, outputLength(_outputLength)
|
||||
, start(_start)
|
||||
, latency()
|
||||
, firstTokenLatency()
|
||||
, avgGenT2TLatency()
|
||||
{
|
||||
}
|
||||
|
||||
@ -296,7 +304,12 @@ struct BenchInfo
|
||||
int outputLength;
|
||||
std::chrono::time_point<std::chrono::steady_clock> start;
|
||||
std::chrono::time_point<std::chrono::steady_clock> end;
|
||||
std::chrono::time_point<std::chrono::steady_clock> firstTokenTs;
|
||||
float latency; // millisecond
|
||||
bool hasError;
|
||||
float firstTokenLatency;
|
||||
float avgGenT2TLatency;
|
||||
bool firstTokenSeen = false;
|
||||
};
|
||||
|
||||
class Recorder
|
||||
@ -304,8 +317,10 @@ class Recorder
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
|
||||
public:
|
||||
explicit Recorder(std::string opCsvFile, std::string responsesJsonFile = "", bool excludeInputInOutput = false)
|
||||
explicit Recorder(std::string opCsvFile, bool streaming = false, std::string responsesJsonFile = "",
|
||||
bool excludeInputInOutput = false)
|
||||
: mOpCsvFile(std::move(opCsvFile))
|
||||
, mStreaming(streaming)
|
||||
, mRespJsonFile(std::move(responsesJsonFile))
|
||||
, mOutputHasInput(!excludeInputInOutput)
|
||||
{
|
||||
@ -343,17 +358,28 @@ public:
|
||||
mRequestBenchInfos[requestId] = BenchInfo(inputLength, maxNewTokens, start);
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId)
|
||||
void recordEnd(uint64_t requestId, bool hasError)
|
||||
{
|
||||
mRequestBenchInfos[requestId].end = std::chrono::steady_clock::now();
|
||||
mRequestBenchInfos[requestId].latency = std::chrono::duration<float, std::milli>(
|
||||
mRequestBenchInfos[requestId].end - mRequestBenchInfos[requestId].start)
|
||||
.count();
|
||||
mRequestBenchInfos[requestId].hasError = hasError;
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
|
||||
void recordToken(uint64_t requestId)
|
||||
{
|
||||
this->recordEnd(requestId);
|
||||
if (mRequestBenchInfos[requestId].firstTokenSeen)
|
||||
{
|
||||
return;
|
||||
}
|
||||
else
|
||||
{
|
||||
mRequestBenchInfos[requestId].firstTokenTs = std::chrono::steady_clock::now();
|
||||
mRequestBenchInfos[requestId].firstTokenSeen = true;
|
||||
}
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors, bool hasError)
|
||||
{
|
||||
this->recordEnd(requestId, hasError);
|
||||
|
||||
if (mRespJsonFile.empty())
|
||||
return;
|
||||
@ -385,50 +411,152 @@ public:
|
||||
return latencies[index];
|
||||
}
|
||||
|
||||
void calculateLatencies()
|
||||
{
|
||||
for (auto& reqInfo : mRequestBenchInfos)
|
||||
{
|
||||
reqInfo.second.latency
|
||||
= std::chrono::duration<float, std::milli>(reqInfo.second.end - reqInfo.second.start).count();
|
||||
if (mStreaming)
|
||||
{
|
||||
reqInfo.second.firstTokenLatency
|
||||
= std::chrono::duration<float, std::milli>(reqInfo.second.firstTokenTs - reqInfo.second.start)
|
||||
.count();
|
||||
reqInfo.second.avgGenT2TLatency
|
||||
= std::chrono::duration<float, std::milli>(reqInfo.second.end - reqInfo.second.firstTokenTs).count()
|
||||
/ (reqInfo.second.outputLength - 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void calculateMetrics()
|
||||
{
|
||||
mNumSamples = mRequestBenchInfos.size();
|
||||
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
|
||||
|
||||
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
|
||||
mAvgSeqLatency = 0;
|
||||
mAvgSeqLatency = mAvgFtLatency = mAvgGenT2TLatency = 0;
|
||||
|
||||
calculateLatencies();
|
||||
|
||||
std::vector<float> reqLatencies;
|
||||
std::vector<float> ftLatencies;
|
||||
std::vector<float> genT2TLatencies;
|
||||
|
||||
int totalOutputTokens = 0;
|
||||
mNumErrorSamples = 0;
|
||||
mNumSamples = 0;
|
||||
for (auto reqInfo : mRequestBenchInfos)
|
||||
{
|
||||
mAvgSeqLatency += reqInfo.second.latency;
|
||||
reqLatencies.push_back(reqInfo.second.latency);
|
||||
totalOutputTokens += reqInfo.second.outputLength;
|
||||
if (!reqInfo.second.hasError)
|
||||
{
|
||||
mAvgSeqLatency += reqInfo.second.latency;
|
||||
reqLatencies.push_back(reqInfo.second.latency);
|
||||
totalOutputTokens += reqInfo.second.outputLength;
|
||||
|
||||
if (mStreaming)
|
||||
{
|
||||
mAvgFtLatency += reqInfo.second.firstTokenLatency;
|
||||
mAvgGenT2TLatency += reqInfo.second.avgGenT2TLatency;
|
||||
ftLatencies.push_back(reqInfo.second.firstTokenLatency);
|
||||
genT2TLatencies.push_back(reqInfo.second.avgGenT2TLatency);
|
||||
}
|
||||
++mNumSamples;
|
||||
}
|
||||
else
|
||||
{
|
||||
++mNumErrorSamples;
|
||||
}
|
||||
}
|
||||
|
||||
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
|
||||
mAvgSeqLatency /= mNumSamples;
|
||||
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
|
||||
|
||||
std::sort(reqLatencies.begin(), reqLatencies.end());
|
||||
|
||||
mP99SeqLatency = calcPercentile(reqLatencies, 99);
|
||||
mP90SeqLatency = calcPercentile(reqLatencies, 90);
|
||||
mP50SeqLatency = calcPercentile(reqLatencies, 50);
|
||||
mMaxSeqLatency = reqLatencies.back();
|
||||
mMinSeqLatency = reqLatencies.front();
|
||||
|
||||
if (mStreaming)
|
||||
{
|
||||
mAvgFtLatency /= mNumSamples;
|
||||
mAvgGenT2TLatency /= mNumSamples;
|
||||
|
||||
std::sort(ftLatencies.begin(), ftLatencies.end());
|
||||
std::sort(genT2TLatencies.begin(), genT2TLatencies.end());
|
||||
|
||||
mP99FtLatency = calcPercentile(ftLatencies, 99);
|
||||
mP90FtLatency = calcPercentile(ftLatencies, 90);
|
||||
mP50FtLatency = calcPercentile(ftLatencies, 50);
|
||||
mMaxFtLatency = ftLatencies.back();
|
||||
mMinFtLatency = ftLatencies.front();
|
||||
|
||||
mP99GenT2TLatency = calcPercentile(genT2TLatencies, 99);
|
||||
mP90GenT2TLatency = calcPercentile(genT2TLatencies, 90);
|
||||
mP50GenT2TLatency = calcPercentile(genT2TLatencies, 50);
|
||||
mMaxGenT2TLatency = genT2TLatencies.back();
|
||||
mMinGenT2TLatency = genT2TLatencies.front();
|
||||
}
|
||||
}
|
||||
|
||||
void report()
|
||||
{
|
||||
|
||||
printf("[BENCHMARK] num_samples %d\n", mNumSamples);
|
||||
printf("[BENCHMARK] num_error_samples %d\n", mNumErrorSamples);
|
||||
printf("\n[BENCHMARK] num_samples %d\n", mNumSamples);
|
||||
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
|
||||
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n\n", mTokenThroughput);
|
||||
|
||||
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
|
||||
printf("[BENCHMARK] max_sequence_latency(ms) %.2f\n", mMaxSeqLatency);
|
||||
printf("[BENCHMARK] min_sequence_latency(ms) %.2f\n", mMinSeqLatency);
|
||||
printf("[BENCHMARK] p99_sequence_latency(ms) %.2f\n", mP99SeqLatency);
|
||||
printf("[BENCHMARK] p90_sequence_latency(ms) %.2f\n", mP90SeqLatency);
|
||||
printf("[BENCHMARK] p50_sequence_latency(ms) %.2f\n", mP50SeqLatency);
|
||||
printf("[BENCHMARK] p50_sequence_latency(ms) %.2f\n\n", mP50SeqLatency);
|
||||
|
||||
if (mStreaming)
|
||||
{
|
||||
printf("[BENCHMARK] avg_time_to_first_token(ms) %.2f\n", mAvgFtLatency);
|
||||
printf("[BENCHMARK] max_time_to_first_token(ms) %.2f\n", mMaxFtLatency);
|
||||
printf("[BENCHMARK] min_time_to_first_token(ms) %.2f\n", mMinFtLatency);
|
||||
printf("[BENCHMARK] p99_time_to_first_token(ms) %.2f\n", mP99FtLatency);
|
||||
printf("[BENCHMARK] p90_time_to_first_token(ms) %.2f\n", mP90FtLatency);
|
||||
printf("[BENCHMARK] p50_time_to_first_token(ms) %.2f\n\n", mP50FtLatency);
|
||||
|
||||
printf("[BENCHMARK] avg_inter_token_latency(ms) %.2f\n", mAvgGenT2TLatency);
|
||||
printf("[BENCHMARK] max_inter_token_latency(ms) %.2f\n", mMaxGenT2TLatency);
|
||||
printf("[BENCHMARK] min_inter_token_latency(ms) %.2f\n", mMinGenT2TLatency);
|
||||
printf("[BENCHMARK] p99_inter_token_latency(ms) %.2f\n", mP99GenT2TLatency);
|
||||
printf("[BENCHMARK] p90_inter_token_latency(ms) %.2f\n", mP90GenT2TLatency);
|
||||
printf("[BENCHMARK] p50_inter_token_latency(ms) %.2f\n\n", mP50GenT2TLatency);
|
||||
}
|
||||
}
|
||||
|
||||
void writeOpMetricsToCsv()
|
||||
{
|
||||
if (!mOpCsvFile.empty())
|
||||
{
|
||||
std::vector<std::string> headers = {"num_samples", "total_latency(ms)", "seq_throughput(seq/sec)",
|
||||
"token_throughput(token/sec)", "avg_sequence_latency(ms)", "p99_sequence_latency(ms)",
|
||||
std::vector<std::string> headers = {"num_samples", "num_error_samples", "total_latency(ms)",
|
||||
"seq_throughput(seq/sec)", "token_throughput(token/sec)", "avg_sequence_latency(ms)",
|
||||
"max_sequence_latency(ms)", "min_sequence_latency(ms)", "p99_sequence_latency(ms)",
|
||||
"p90_sequence_latency(ms)", "p50_sequence_latency(ms)"};
|
||||
|
||||
if (mStreaming)
|
||||
{
|
||||
std::vector<std::string> streamingHeaders
|
||||
= {"avg_time_to_first_token(ms)", "max_time_to_first_token(ms)", "min_time_to_first_token(ms)",
|
||||
"p99_time_to_first_token(ms)", "p90_time_to_first_token(ms)", "p50_time_to_first_token(ms)",
|
||||
"avg_inter_token_latency(ms)", "max_inter_token_latency(ms)", "min_inter_token_latency(ms)",
|
||||
"p99_inter_token_latency(ms)", "p90_inter_token_latency(ms)", "p50_inter_token_latency(ms)"};
|
||||
|
||||
headers.insert(headers.end(), streamingHeaders.begin(), streamingHeaders.end());
|
||||
}
|
||||
|
||||
std::ofstream outputFile(mOpCsvFile);
|
||||
|
||||
if (outputFile.is_open())
|
||||
@ -438,9 +566,17 @@ public:
|
||||
outputFile << header << ",";
|
||||
}
|
||||
outputFile << "\n";
|
||||
outputFile << mNumSamples << "," << mTotalLatency << "," << mSeqThroughput << "," << mTokenThroughput
|
||||
<< "," << mAvgSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << ","
|
||||
<< mP50SeqLatency;
|
||||
outputFile << mNumSamples << "," << mNumErrorSamples << "," << mTotalLatency << "," << mSeqThroughput
|
||||
<< "," << mTokenThroughput << "," << mAvgSeqLatency << "," << mMaxSeqLatency << ","
|
||||
<< mMinSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << "," << mP50SeqLatency;
|
||||
if (mStreaming)
|
||||
{
|
||||
outputFile << "," << mAvgFtLatency << "," << mMaxFtLatency << "," << mMinFtLatency << ","
|
||||
<< mP99FtLatency << "," << mP90FtLatency << "," << mP50FtLatency << ","
|
||||
<< mAvgGenT2TLatency << "," << mMaxGenT2TLatency << "," << mMinGenT2TLatency << ","
|
||||
<< mP99GenT2TLatency << "," << mP90GenT2TLatency << "," << mP50GenT2TLatency;
|
||||
}
|
||||
|
||||
outputFile << "\n";
|
||||
}
|
||||
else
|
||||
@ -482,14 +618,31 @@ private:
|
||||
std::chrono::time_point<std::chrono::steady_clock> mStart;
|
||||
std::chrono::time_point<std::chrono::steady_clock> mEnd;
|
||||
int mNumSamples{};
|
||||
int mNumErrorSamples{};
|
||||
float mTotalLatency{};
|
||||
float mSeqThroughput{};
|
||||
float mAvgSeqLatency{};
|
||||
float mAvgGenT2TLatency{};
|
||||
float mAvgFtLatency{};
|
||||
float mTokenThroughput{};
|
||||
float mP99SeqLatency{};
|
||||
float mP90SeqLatency{};
|
||||
float mP50SeqLatency{};
|
||||
float mMaxSeqLatency{};
|
||||
float mMinSeqLatency{};
|
||||
float mP99FtLatency{};
|
||||
float mP90FtLatency{};
|
||||
float mP50FtLatency{};
|
||||
float mMaxFtLatency{};
|
||||
float mMinFtLatency{};
|
||||
float mP99GenT2TLatency{};
|
||||
float mP90GenT2TLatency{};
|
||||
float mP50GenT2TLatency{};
|
||||
float mMaxGenT2TLatency{};
|
||||
float mMinGenT2TLatency{};
|
||||
|
||||
std::string mOpCsvFile;
|
||||
bool mStreaming;
|
||||
std::string mRespJsonFile;
|
||||
std::unordered_map<uint64_t, TensorPtr> mResponseTensors;
|
||||
bool mOutputHasInput;
|
||||
@ -512,12 +665,15 @@ public:
|
||||
|
||||
texec::SchedulerConfig schedulerConfig(batch_scheduler::batchManagerToExecSchedPolicy(schedulerPolicy));
|
||||
texec::KvCacheConfig kvCacheConfig(benchmarkParams.enableBlockReuse, benchmarkParams.maxTokensInPagedKvCache,
|
||||
std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction);
|
||||
std::nullopt, std::nullopt, benchmarkParams.freeGpuMemoryFraction, benchmarkParams.kvHostCacheSize,
|
||||
benchmarkParams.kvOnboardBlocks);
|
||||
texec::PeftCacheConfig peftCacheConfig(0, benchmarkParams.loraDeviceNumModLayers, 8, 64, 4, 4, 4, 24, 8,
|
||||
std::nullopt, benchmarkParams.loraHostCacheSize);
|
||||
texec::ExecutorConfig executorConfig(
|
||||
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
|
||||
executorConfig.setPeftCacheConfig(peftCacheConfig);
|
||||
executorConfig.setBatchingType(
|
||||
modelType == TrtGptModelType::V1 ? texec::BatchingType::kSTATIC : texec::BatchingType::kINFLIGHT);
|
||||
|
||||
mExecutor = std::make_unique<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
|
||||
|
||||
@ -572,21 +728,24 @@ public:
|
||||
auto responses = mExecutor->awaitResponses(std::nullopt, mWaitSleep);
|
||||
for (auto const& response : responses)
|
||||
{
|
||||
if (response.hasError())
|
||||
{
|
||||
// This request failed for some reason, get error msg
|
||||
std::string errStr = "Request id " + std::to_string(response.getRequestId()) + " failed with err "
|
||||
+ response.getErrorMsg();
|
||||
TLLM_THROW(errStr);
|
||||
}
|
||||
else if (response.getResult().isFinal)
|
||||
if (response.hasError() || response.getResult().isFinal)
|
||||
{
|
||||
auto reqId = response.getRequestId();
|
||||
mActiveCount--;
|
||||
numFinished++;
|
||||
if (!warmup)
|
||||
if (response.getResult().isFinal)
|
||||
{
|
||||
mRecorder->recordEnd(reqId);
|
||||
mActiveCount--;
|
||||
numFinished++;
|
||||
if (!warmup)
|
||||
{
|
||||
mRecorder->recordEnd(reqId, response.hasError());
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if (!warmup)
|
||||
{
|
||||
mRecorder->recordToken(reqId);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -818,9 +977,13 @@ public:
|
||||
if (final_response)
|
||||
{
|
||||
mWorkItemsQueue.markFinished(requestId);
|
||||
mRecorder->recordEnd(requestId, response_tensors);
|
||||
mRecorder->recordEnd(requestId, response_tensors, !errMsg.empty());
|
||||
mActiveCount--;
|
||||
}
|
||||
else
|
||||
{
|
||||
mRecorder->recordToken(requestId);
|
||||
}
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
@ -854,7 +1017,8 @@ struct Sample
|
||||
|
||||
using Samples = std::vector<Sample>;
|
||||
|
||||
Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSamples)
|
||||
Samples parseWorkloadJson(
|
||||
std::filesystem::path const& datasetPath, int maxNumSamples, std::optional<SizeType> const maxPromptLen)
|
||||
{
|
||||
auto constexpr allowExceptions = true;
|
||||
auto constexpr ignoreComments = true;
|
||||
@ -869,12 +1033,17 @@ Samples parseWorkloadJson(std::filesystem::path const& datasetPath, int maxNumSa
|
||||
if (samples.size() >= maxNumSamples)
|
||||
break;
|
||||
int32_t taskId = sample.count("task_id") ? sample["task_id"].template get<int32_t>() : -1;
|
||||
samples.emplace_back(Sample{sample["input_ids"], sample["output_len"], sample["delay"], taskId});
|
||||
auto input_ids(sample["input_ids"].template get<std::vector<int32_t>>());
|
||||
if (maxPromptLen && (input_ids.size() > maxPromptLen.value()))
|
||||
{
|
||||
input_ids.resize(maxPromptLen.value());
|
||||
}
|
||||
samples.emplace_back(Sample{std::move(input_ids), sample["output_len"], sample["delay"], taskId});
|
||||
}
|
||||
return samples;
|
||||
}
|
||||
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample,
|
||||
std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const& sample, bool streaming,
|
||||
ITensor::SharedPtr const& beamWidthTensor, ITensor::SharedPtr const& eosId, ITensor::SharedPtr const& padId,
|
||||
BufferManager const& bufferManager, ITensor::SharedPtr const& returnContextLogits = nullptr,
|
||||
ITensor::SharedPtr const& returnGenerationLogits = nullptr, ITensor::SharedPtr const& loraWeights = nullptr,
|
||||
@ -916,6 +1085,10 @@ std::shared_ptr<InferenceRequest> makeRequest(std::uint64_t reqId, Sample const&
|
||||
{
|
||||
request->setLoraConfig(loraConfig);
|
||||
}
|
||||
if (streaming)
|
||||
{
|
||||
request->setIsStreaming(true);
|
||||
}
|
||||
return request;
|
||||
}
|
||||
|
||||
@ -941,7 +1114,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
|
||||
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile)
|
||||
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile,
|
||||
std::optional<SizeType> const maxPromptLen)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi();
|
||||
|
||||
@ -963,6 +1137,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
optionalParams.peftCacheManagerConfig.numPutWorkers = 4;
|
||||
optionalParams.peftCacheManagerConfig.numEnsureWorkers = 4;
|
||||
optionalParams.peftCacheManagerConfig.numCopyStreams = 4;
|
||||
optionalParams.kvCacheConfig.hostCacheSize = benchmarkParams.kvHostCacheSize;
|
||||
optionalParams.kvCacheConfig.onboardBlocks = benchmarkParams.kvOnboardBlocks;
|
||||
|
||||
BufferManager bufferManager{std::make_shared<CudaStream>()}; // the stream is not used
|
||||
|
||||
@ -970,11 +1146,12 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
bufferManager.copyFrom(&beamWidth, ITensor::makeShape({1}), MemoryType::kPINNED)};
|
||||
|
||||
// Load dataset
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples, maxPromptLen);
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
int const maxBeamWidth = beamWidth;
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile, responsesJsonFile, excludeInputInOutput);
|
||||
auto recorder
|
||||
= std::make_shared<Recorder>(opCsvFile, benchmarkParams.streaming, responsesJsonFile, excludeInputInOutput);
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
auto gptServer
|
||||
= std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder,
|
||||
@ -1008,8 +1185,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
reqId++;
|
||||
}
|
||||
Sample s{std::vector<int32_t>{1, 2, 3, 4, 5}, 1, 0.f, static_cast<int32_t>(taskId)};
|
||||
auto r = makeRequest(reqId, s, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager, nullptr,
|
||||
nullptr, p.first, p.second);
|
||||
auto r = makeRequest(reqId, s, benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor,
|
||||
bufferManager, nullptr, nullptr, p.first, p.second);
|
||||
gptServer->enqueue(r);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
@ -1026,7 +1203,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
++reqId;
|
||||
if (i == terminateReqId)
|
||||
++reqId;
|
||||
auto request = makeRequest(reqId, samples[0], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
auto request = makeRequest(
|
||||
reqId, samples[0], benchmarkParams.streaming, beamWidthTensor, eosIdTensor, padIdTensor, bufferManager);
|
||||
gptServer->enqueue(request);
|
||||
}
|
||||
gptServer->waitForEmpty();
|
||||
@ -1036,8 +1214,8 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
gptServer->resetBatchDeadline();
|
||||
for (std::size_t i = 0; i < numSamples; ++i)
|
||||
{
|
||||
auto request = makeRequest(i + 1, samples[i], beamWidthTensor, eosIdTensor, padIdTensor, bufferManager,
|
||||
returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
|
||||
auto request = makeRequest(i + 1, samples[i], benchmarkParams.streaming, beamWidthTensor, eosIdTensor,
|
||||
padIdTensor, bufferManager, returnContextLogitsFlagTensor, returnGenerationLogitsFlagTensor);
|
||||
gptServer->enqueue(request);
|
||||
auto delayInMs = static_cast<int>(samples[i].delay * 1000);
|
||||
|
||||
@ -1065,16 +1243,17 @@ void benchmarkExecutor(std::filesystem::path const& engineDir, TrtGptModelType m
|
||||
std::string const& datasetPath, std::string const& opCsvFile, int maxNumSamples, int beamWidth, int warmUp,
|
||||
std::optional<int32_t> const& eosId, std::optional<int32_t> const& padId, BenchmarkParams const& benchmarkParams,
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, std::chrono::milliseconds waitSleep, bool returnContextLogits,
|
||||
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData)
|
||||
bool returnGenerationLogits, std::optional<int> const staticEmulatedBatchSize, bool logIterationData,
|
||||
std::optional<SizeType> const maxPromptLen)
|
||||
{
|
||||
auto const& world = tensorrt_llm::mpi::MpiComm::world();
|
||||
auto worldRank = world.getRank();
|
||||
|
||||
// Load dataset
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples);
|
||||
auto const samples = parseWorkloadJson(datasetPath, maxNumSamples, maxPromptLen);
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile);
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile, benchmarkParams.streaming);
|
||||
|
||||
auto executorServer = std::make_shared<ExecutorServer>(engineDir, modelType, beamWidth, schedulerPolicy,
|
||||
benchmarkParams, recorder, waitSleep, staticEmulatedBatchSize, logIterationData);
|
||||
@ -1244,6 +1423,12 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("lora_dir", "Directory containing LoRAs", cxxopts::value<std::string>()->default_value(""));
|
||||
options.add_options()("lora_host_cache_bytes", "LoRA host cache memory in bytes", cxxopts::value<size_t>());
|
||||
options.add_options()("lora_num_device_mod_layers", "LoRA number 1d cache rows", cxxopts::value<int>());
|
||||
options.add_options()("kv_host_cache_bytes",
|
||||
"Size of secondary memory pool used for offloading kv cache blocks (in bytes).",
|
||||
cxxopts::value<size_t>()->default_value("0"));
|
||||
options.add_options()("kv_dont_onboard_blocks",
|
||||
"If offloaded blocks should be onboarded to primary memory before reuse",
|
||||
cxxopts::value<bool>()->default_value("false"));
|
||||
|
||||
options.add_options()("exclude_input_in_output_seq",
|
||||
"When enabled, GptManager will exclude the input sequence from output. (Only works if --api is gptManager)",
|
||||
@ -1253,6 +1438,9 @@ int main(int argc, char* argv[])
|
||||
"When specified, dumps the responses to JSON file. (only works if --api is gptManager)",
|
||||
cxxopts::value<std::string>()->default_value(""));
|
||||
|
||||
options.add_options()(
|
||||
"max_prompt_len", "Truncate all prompts from dataset to the length specified.", cxxopts::value<SizeType>());
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
if (result.count("help"))
|
||||
@ -1347,6 +1535,12 @@ int main(int argc, char* argv[])
|
||||
benchmarkParams.loraDeviceNumModLayers = result["lora_num_device_mod_layers"].as<SizeType>();
|
||||
}
|
||||
|
||||
// Argument: How many KV cache blocks (as fraction of number of GPU kv cache blocks).
|
||||
benchmarkParams.kvHostCacheSize = result["kv_host_cache_bytes"].as<size_t>();
|
||||
|
||||
// Argument: If offloaded blocks should be onboarded to primary memory before they are reused.
|
||||
benchmarkParams.kvOnboardBlocks = !result["kv_dont_onboard_blocks"].as<bool>();
|
||||
|
||||
std::optional<TokenIdType> padId;
|
||||
// Argument: Padding token id
|
||||
if (result.count("pad_id"))
|
||||
@ -1390,6 +1584,13 @@ int main(int argc, char* argv[])
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Argument: max_prompt_len
|
||||
std::optional<SizeType> maxPromptLen;
|
||||
if (result.count("max_prompt_len"))
|
||||
{
|
||||
maxPromptLen = result["max_prompt_len"].as<SizeType>();
|
||||
}
|
||||
|
||||
// Argument: Log level
|
||||
auto logger = std::make_shared<TllmLogger>();
|
||||
auto const logLevel = result["log_level"].as<std::string>();
|
||||
@ -1429,7 +1630,7 @@ int main(int argc, char* argv[])
|
||||
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
|
||||
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
|
||||
logIterationData, result["exclude_input_in_output_seq"].as<bool>(),
|
||||
result["responses_json_file"].as<std::string>());
|
||||
result["responses_json_file"].as<std::string>(), maxPromptLen);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
@ -1443,7 +1644,7 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
benchmarkExecutor(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile, maxNumSamples,
|
||||
beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy, waitSleep,
|
||||
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData);
|
||||
returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, logIterationData, maxPromptLen);
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
|
||||
@ -125,11 +125,10 @@ def load_dataset_from_hf(dataset_config: DatasetConfig):
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"Dataset config name in HuggingFace (if exists).")
|
||||
@click.option(
|
||||
"--dataset-split",
|
||||
type=str,
|
||||
default=None,
|
||||
help=f"Split of the dataset to use. Default will include all splits.")
|
||||
@click.option("--dataset-split",
|
||||
type=str,
|
||||
required=True,
|
||||
help=f"Split of the dataset to use.")
|
||||
@click.option("--dataset-input-key",
|
||||
required=True,
|
||||
type=str,
|
||||
|
||||
@ -244,6 +244,11 @@ def parse_arguments():
|
||||
help=
|
||||
"Check the estimated memory usage against the total GPU memory. Raise error if the estimated memory requirement is bigger than the total GPU memory"
|
||||
"Warning: only GPT model family is supported for now")
|
||||
parser.add_argument(
|
||||
'--dump_profile',
|
||||
default=False,
|
||||
action='store_true',
|
||||
help="Print profile information per layer (default = disabled)")
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
@ -310,6 +315,9 @@ def main(args):
|
||||
if args.build_only:
|
||||
return
|
||||
|
||||
if args.dump_profile and benchmark_profiler is not None:
|
||||
benchmark_profiler.set_recording_perf_profile(True)
|
||||
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
benchmarker.print_report_header(args.csv,
|
||||
|
||||
@ -21,12 +21,14 @@ class BenchmarkProfiler(object):
|
||||
timer_dict: dict
|
||||
aux_info: dict
|
||||
started: bool
|
||||
is_recording_perf_profile: bool
|
||||
|
||||
def __init__(self):
|
||||
self.cuda_event_dict = {}
|
||||
self.timer_dict = {}
|
||||
self.aux_info = {}
|
||||
self.started = False
|
||||
self.is_recording_perf_profile = False
|
||||
|
||||
def clean(self):
|
||||
self.cuda_event_dict = {}
|
||||
@ -75,3 +77,6 @@ class BenchmarkProfiler(object):
|
||||
if not self.started:
|
||||
return
|
||||
self.aux_info[aux_name] += add_value
|
||||
|
||||
def set_recording_perf_profile(self, value: bool):
|
||||
self.is_recording_perf_profile = value
|
||||
|
||||
@ -64,6 +64,12 @@ class BERTBenchmark(BaseBenchmark):
|
||||
self.session = tensorrt_llm.runtime.Session.from_serialized_engine(
|
||||
engine_buffer)
|
||||
|
||||
# Print context memory size for CI/CD to track.
|
||||
context_mem_size = self.session.context_mem_size
|
||||
print(
|
||||
f"Allocated {context_mem_size / 1048576.0:.2f} MiB for execution context memory."
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
for inlen in self.in_lens:
|
||||
if inlen > self.max_input_len:
|
||||
|
||||
@ -223,6 +223,15 @@ def build_gpt(args):
|
||||
quant_mode = quant_config.quant_mode
|
||||
|
||||
builder = Builder()
|
||||
builder_config_extra_kwargs = {}
|
||||
if get_model_family(args.model) == 'mamba':
|
||||
builder_config_extra_kwargs['mamba_d_state'] = build_config[
|
||||
'mamba_d_state']
|
||||
builder_config_extra_kwargs['mamba_d_conv'] = build_config[
|
||||
'mamba_d_conv']
|
||||
builder_config_extra_kwargs['mamba_expand'] = build_config[
|
||||
'mamba_expand']
|
||||
builder_config_extra_kwargs['max_beam_width'] = max_beam_width
|
||||
builder_config = builder.create_builder_config(
|
||||
name=args.model,
|
||||
precision=args.dtype,
|
||||
@ -246,7 +255,8 @@ def build_gpt(args):
|
||||
quant_mode=quant_mode,
|
||||
use_refit=False,
|
||||
opt_level=build_config['builder_opt'],
|
||||
strongly_typed=strongly_typed)
|
||||
strongly_typed=strongly_typed,
|
||||
**builder_config_extra_kwargs)
|
||||
engine_name = get_engine_name(args.model, args.dtype, world_size,
|
||||
runtime_rank)
|
||||
|
||||
@ -360,7 +370,8 @@ def build_gpt(args):
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.LLaMAForCausalLM(config)
|
||||
|
||||
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
|
||||
use_fused_mlp=True)
|
||||
elif family == "gptj":
|
||||
config = {
|
||||
'architecture': 'GPTJForCausalLM',
|
||||
@ -524,7 +535,9 @@ def build_gpt(args):
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.BloomForCausalLM(config)
|
||||
|
||||
tensorrt_llm_model = optimize_model(
|
||||
tensorrt_llm_model,
|
||||
use_parallel_embedding=config.use_parallel_embedding)
|
||||
elif family == "falcon":
|
||||
config = {
|
||||
'architecture':
|
||||
@ -666,32 +679,45 @@ def build_gpt(args):
|
||||
|
||||
elif family == "qwen":
|
||||
config = {
|
||||
'architecture': 'QWenForCausalLM',
|
||||
'dtype': args.dtype,
|
||||
'num_hidden_layers': build_config['num_layers'],
|
||||
'num_attention_heads': build_config['num_heads'],
|
||||
'hidden_size': build_config['hidden_size'],
|
||||
'intermediate_size': build_config['inter_size'],
|
||||
'num_key_value_heads': num_kv_heads,
|
||||
'vocab_size': build_config['vocab_size'],
|
||||
'position_embedding_type': 'rope_gpt_neox',
|
||||
'max_position_embeddings': build_config['n_positions'],
|
||||
'hidden_act': build_config['hidden_act'],
|
||||
'rotary_base': 10000.0,
|
||||
'norm_epsilon': 1e-06,
|
||||
'architecture':
|
||||
'QWenForCausalLM',
|
||||
'dtype':
|
||||
args.dtype,
|
||||
'num_hidden_layers':
|
||||
build_config['num_layers'],
|
||||
'num_attention_heads':
|
||||
build_config['num_heads'],
|
||||
'num_key_value_heads':
|
||||
build_config['num_heads'] if build_config['num_kv_heads'] is None
|
||||
else build_config['num_kv_heads'],
|
||||
'hidden_size':
|
||||
build_config['hidden_size'],
|
||||
'intermediate_size':
|
||||
build_config['inter_size'],
|
||||
'vocab_size':
|
||||
build_config['vocab_size'],
|
||||
'position_embedding_type':
|
||||
'rope_gpt_neox',
|
||||
'max_position_embeddings':
|
||||
build_config['n_positions'],
|
||||
'hidden_act':
|
||||
build_config['hidden_act'],
|
||||
'quantization': {
|
||||
'group_size': 128,
|
||||
'quant_algo': quant_algo,
|
||||
'kv_cache_quant_algo': kv_cache_quant_algo,
|
||||
'group_size': 128
|
||||
'kv_cache_quant_algo': kv_cache_quant_algo
|
||||
},
|
||||
'mapping': {
|
||||
'world_size': world_size,
|
||||
'tp_size': world_size,
|
||||
'tp_size': world_size
|
||||
},
|
||||
'moe_num_experts':
|
||||
build_config["moe_num_experts"],
|
||||
'moe_top_k':
|
||||
build_config["moe_top_k"],
|
||||
}
|
||||
config = PretrainedConfig.from_dict(config)
|
||||
tensorrt_llm_model = tensorrt_llm.models.QWenForCausalLM(config)
|
||||
|
||||
elif family == "mamba":
|
||||
config = {
|
||||
'architecture': 'MambaLMHeadModel',
|
||||
@ -716,10 +742,6 @@ def build_gpt(args):
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
if family in ['llama']:
|
||||
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
|
||||
use_fused_mlp=True)
|
||||
|
||||
# Module -> Network
|
||||
network = builder.create_network()
|
||||
network.trt_network.name = engine_name
|
||||
@ -1225,14 +1247,21 @@ def build_enc_dec(args):
|
||||
def main(args):
|
||||
logger.set_level(args.log_level)
|
||||
if args.model in get_allowed_models(benchmark_type="gpt"):
|
||||
build_gpt(args)
|
||||
engine = build_gpt(args)[0]
|
||||
engine_size = engine.nbytes
|
||||
elif args.model in get_allowed_models(benchmark_type="bert"):
|
||||
build_bert(args)
|
||||
engine = build_bert(args)[0]
|
||||
engine_size = engine.nbytes
|
||||
elif args.model in get_allowed_models(benchmark_type="enc_dec"):
|
||||
build_enc_dec(args)
|
||||
encoder_engine, decoder_engine = build_enc_dec(args)[:2]
|
||||
engine_size = encoder_engine.nbytes + decoder_engine.nbytes
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
# Print engine size for CI/CD to track.
|
||||
logger.info(
|
||||
f"Total engine size per GPU is {engine_size / 1048576:.2f} MiB.")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
mp.set_start_method('spawn')
|
||||
|
||||
@ -204,6 +204,12 @@ class EncDecBenchmark(BaseBenchmark):
|
||||
self.decoder_runtime_mapping,
|
||||
)
|
||||
|
||||
# Print context memory size for CI/CD to track.
|
||||
context_mem_size = self.encoder_session.context_mem_size + self.decoder_session.context_mem_size
|
||||
print(
|
||||
f"Allocated {context_mem_size / 1048576.0:.2f} MiB for execution context memory."
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
if 'whisper' in self.model_name:
|
||||
print(
|
||||
|
||||
@ -16,6 +16,7 @@ import os
|
||||
from dataclasses import asdict
|
||||
from math import ceil
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
|
||||
import tensorrt_llm
|
||||
@ -93,6 +94,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
# Plugins
|
||||
self.use_gpt_attention_plugin = False
|
||||
self.remove_input_padding = False
|
||||
self.use_mamba_conv1d_plugin = False
|
||||
if args.mode == 'plugin':
|
||||
self.use_gpt_attention_plugin = True
|
||||
self.remove_input_padding = True
|
||||
@ -129,6 +131,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
remove_input_padding=self.remove_input_padding,
|
||||
quant_mode=self.quant_mode,
|
||||
use_custom_all_reduce=self.use_custom_all_reduce,
|
||||
mamba_conv1d_plugin=self.use_mamba_conv1d_plugin,
|
||||
)
|
||||
if args.model == 'chatglm_6b':
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
@ -177,6 +180,12 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.runtime_mapping,
|
||||
cuda_graph_mode=self.cuda_graph_mode)
|
||||
|
||||
# Print context memory size for CI/CD to track.
|
||||
context_mem_size = self.decoder.context_mem_size
|
||||
print(
|
||||
f"Allocated {context_mem_size / 1048576.0:.2f} MiB for execution context memory."
|
||||
)
|
||||
|
||||
def get_config(self):
|
||||
for inlen, outlen in self.in_out_lens:
|
||||
if inlen > self.max_input_len or outlen > self.max_output_len:
|
||||
@ -338,3 +347,56 @@ class GPTBenchmark(BaseBenchmark):
|
||||
kv_pairs = [f"{k} {v}" for k, v in report_dict.items()]
|
||||
line = '[BENCHMARK] ' + " ".join(kv_pairs)
|
||||
print(line)
|
||||
|
||||
if benchmark_profiler is not None and benchmark_profiler.is_recording_perf_profile:
|
||||
perf_profile_data = self.decoder.profiler.results
|
||||
if not perf_profile_data:
|
||||
tensorrt_llm.logger.error("profiler data is empty")
|
||||
return
|
||||
|
||||
ctx_layers = list()
|
||||
generation_layers = list()
|
||||
start = 0
|
||||
ctx_iter_cnt = 0
|
||||
generation_iter_cnt = 0
|
||||
|
||||
# split context/generations layer information
|
||||
for idx, layer_info in enumerate(perf_profile_data):
|
||||
if layer_info[0] == "step":
|
||||
if layer_info[1] == 0:
|
||||
ctx_layers.extend(perf_profile_data[start:idx])
|
||||
ctx_iter_cnt += 1
|
||||
else:
|
||||
generation_layers.extend(perf_profile_data[start:idx])
|
||||
generation_iter_cnt += 1
|
||||
start = idx + 1
|
||||
|
||||
# Reduce all data
|
||||
def reduce_layer_data(layers):
|
||||
layer_infos = dict()
|
||||
for layer in layers:
|
||||
if layer[0] in layer_infos:
|
||||
layer_infos[layer[0]] += layer[1]
|
||||
else:
|
||||
layer_infos[layer[0]] = layer[1]
|
||||
return layer_infos
|
||||
|
||||
# Dump kernel data
|
||||
def dump_kernel_profile_table(name: str, profile_data: list,
|
||||
iter_cnt: int):
|
||||
table = pd.DataFrame(
|
||||
[[k, '{:0.3f}'.format(v)] for k, v in profile_data.items()],
|
||||
columns=['{} Phase LayerName'.format(name), 'times (ms)'])
|
||||
|
||||
def ljust(s):
|
||||
s = s.astype(str).str.strip()
|
||||
return s.str.ljust(s.str.len().max())
|
||||
|
||||
print(table.apply(ljust).to_string(index=False, justify='left'))
|
||||
print("{} phase step iter: {}".format(name, iter_cnt))
|
||||
|
||||
ctx_layer_infos = reduce_layer_data(ctx_layers)
|
||||
generation_layer_infos = reduce_layer_data(generation_layers)
|
||||
dump_kernel_profile_table("Context", ctx_layer_infos, ctx_iter_cnt)
|
||||
dump_kernel_profile_table("Generation", generation_layer_infos,
|
||||
generation_iter_cnt)
|
||||
|
||||
benchmarks/suite/README.md (new file, 81 lines)
@ -0,0 +1,81 @@
|
||||
# TensorRT-LLM Benchmarking
|
||||
|
||||
**WORK IN PROGRESS**
|
||||
|
||||
This package is the official benchmarking suite for TensorRT-LLM. This benchmark will be updated
|
||||
as development of TensorRT-LLM continues.
|
||||
|
||||
## Installation
|
||||
|
||||
From this folder, run `pip install -r requirements.txt` to install the extra dependencies required for this tool.
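For example, a minimal sketch of the install step, assuming this commit's layout in which the suite lives under `benchmarks/suite/`:

```shell
# Assumed layout from this commit; run from the repository root.
cd benchmarks/suite
pip install -r requirements.txt   # pulls in pydantic, click-option-group, aenum
```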
|
||||
|
||||
### Available Model Options
|
||||
|
||||
The following options are available when benchmarking a model.
|
||||
|
||||
| Option | Required | Default | Description |
|
||||
| :- | :-: | :-: | :- |
|
||||
| `--model` | Y | - | The name of the model to benchmark. |
|
||||
| `--dtype` | N | `float16` | The datatype of the weights. |
|
||||
| `--kv-dtype` | N | `float16` | The datatype to store the KV Cache in. |
|
||||
| `--quantization` | N | `None` | The quantization algorithm to be used when benchmarking. See the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html) for more information. |
|
||||
| `--workspace` | N | `/tmp` | The directory to store benchmarking intermediate files. |
|
||||
| `--tensor-parallel-size` | N | `1` | Number of tensor parallel shards to run the benchmark with. |
|
||||
| `--pipeline-parallel-size` | N | `1` | Number of pipeline parallel shards to run the benchmark with. |
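As a hypothetical illustration (not part of the original README), the global options above are combined with a subcommand such as `static`, described below:

```shell
# Hypothetical example; option names come from the table above, values are illustrative.
# Run from benchmarks/suite/tensorrt_llm_bench/ as in the static example below.
python benchmark.py --model meta-llama/Llama-2-7b-hf --dtype float16 \
    --tensor-parallel-size 2 --workspace /tmp/trtllm_bench \
    static --isl 128 --osl 128 --batch 8
```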
|
||||
|
||||
#### Supported Networks for Benchmarking
|
||||
|
||||
- [`tiiuae/falcon-7b`](https://huggingface.co/tiiuae/falcon-7b)
|
||||
- [`tiiuae/falcon-40b`](https://huggingface.co/tiiuae/falcon-40b)
|
||||
- [`tiiuae/falcon-180B`](https://huggingface.co/tiiuae/falcon-180B)
|
||||
- [`meta-llama/Llama-2-7b-hf`](https://huggingface.co/meta-llama/Llama-2-7b-hf)
|
||||
- [`meta-llama/Llama-2-13b-hf`](https://huggingface.co/meta-llama/Llama-2-13b-hf)
|
||||
- [`meta-llama/Llama-2-70b-hf`](https://huggingface.co/meta-llama/Llama-2-70b-hf)
|
||||
- [`EleutherAI/gpt-j-6b`](https://huggingface.co/EleutherAI/gpt-j-6b)
|
||||
|
||||
#### Supported Quantization Modes
|
||||
|
||||
TensorRT-LLM supports a number of quantization modes. For more information about quantization, see the [documentation](https://nvidia.github.io/TensorRT-LLM/precision.html).
|
||||
|
||||
- None (no quantization applied)
|
||||
- W8A16
|
||||
- W4A16
|
||||
- W4A16_AWQ
|
||||
- W4A8_AWQ
|
||||
- W4A16_GPTQ
|
||||
- FP8
|
||||
- INT8
|
||||
|
||||
> [!NOTE] Please see the supported quantization methods for each network [here](https://nvidia.github.io/TensorRT-LLM/precision.html#support-matrix)
|
||||
|
||||
## Static Benchmarking a Network
|
||||
|
||||
In order to benchmark a static batch for a network, run a command like the following:
|
||||
|
||||
```shell
|
||||
cd tensorrt_llm_bench/
|
||||
python benchmark.py --model tiiuae/falcon-7b static --isl 128 --osl 128 --batch 1
|
||||
```
|
||||
|
||||
This command line will build a unique engine for the configuration and run the benchmark using
|
||||
the `gptSessionBenchmark` binary. You need to build the TensorRT-LLM wheel with the `--benchmarks` flag for this binary to be compiled:
|
||||
|
||||
```shell
|
||||
python3 ./scripts/build_wheel.py --benchmarks <other options>
|
||||
```
|
||||
|
||||
The complete list of arguments is given here:
|
||||
| Option | Required | Default | Description |
|
||||
| :- | :-: | :-: | :- |
|
||||
| `--batch` | Y | - | The batch size to benchmark. |
|
||||
| `--isl` | Y | - | The input sequence length to pass in during benchmark. |
|
||||
| `--osl` | Y | - | The output sequence length to generate in the benchmark. |
|
||||
| `--gpt-session-path` | N | `../../cpp/build/benchmarks/gptSessionBenchmark` | The path to the built gptSessionBenchmark binary. |
|
||||
| `--max-tokens-in-kv-cache` | N | `None` | The maximum number of tokens to store in the KV Cache during benchmarking. |
|
||||
| `--kv-cache-mem-percent` | N | `0.9` | The percentage of free memory that the KV cache is allowed to occupy. |
|
||||
| `--warm-up-runs` | N | `2` | The number of warm up runs to run before benchmarking actual results. |
|
||||
| `--num-runs` | N | `10` | The number of runs to generate benchmarking results from. |
|
||||
| `--duration` | N | `60` | The minimum iteration time, in seconds, to measure. |
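A sketch of a fuller invocation that exercises several of the optional arguments above (values are illustrative, not tuned recommendations):

```shell
# Illustrative values only; defaults are listed in the table above.
# Run from benchmarks/suite/tensorrt_llm_bench/ as in the earlier example.
python benchmark.py --model tiiuae/falcon-7b --kv-dtype float16 \
    static --batch 4 --isl 512 --osl 128 \
    --kv-cache-mem-percent 0.85 --warm-up-runs 2 --num-runs 10
```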
|
||||
|
||||
> [!WARNING]
|
||||
> `gptSession` will be deprecated for the 1.0 release of TensorRT-LLM. This command line will change in order to match and update benchmarks accordingly.
|
||||
benchmarks/suite/requirements.txt (new file, 3 lines)
@ -0,0 +1,3 @@
|
||||
pydantic>=2.2.1
|
||||
click-option-group == 0.5.6
|
||||
aenum == 3.1.15
|
||||
benchmarks/suite/tensorrt_llm_bench/__init__.py (new file, 1 line)
@ -0,0 +1 @@
|
||||
"""Module for running TensorRT-LLM benchmarks."""
|
||||
benchmarks/suite/tensorrt_llm_bench/benchmark.py (new file, 95 lines)
@ -0,0 +1,95 @@
|
||||
from pathlib import Path
|
||||
from typing import get_args
|
||||
|
||||
import click
|
||||
from static import static_benchmark
|
||||
from utils import (VALID_CACHE_DTYPES, VALID_COMPUTE_DTYPES, VALID_MODELS,
|
||||
VALID_QUANT_ALGOS)
|
||||
from utils.dataclasses import BenchmarkConfig
|
||||
|
||||
|
||||
@click.group(context_settings={'show_default': True})
|
||||
@click.option(
|
||||
"--model",
|
||||
"-m",
|
||||
required=True,
|
||||
type=click.Choice(tuple(get_args(VALID_MODELS))),
|
||||
help="The Huggingface name of the model to benchmark.",
|
||||
)
|
||||
@click.option(
|
||||
"--kv-dtype",
|
||||
type=click.Choice(tuple(get_args(VALID_CACHE_DTYPES))),
|
||||
default="float16",
|
||||
help="The dtype to store the KV Cache in.",
|
||||
)
|
||||
@click.option(
|
||||
"--dtype",
|
||||
type=click.Choice(tuple(get_args(VALID_COMPUTE_DTYPES))),
|
||||
default="float16",
|
||||
help="Activation and plugin data type.",
|
||||
)
|
||||
@click.option(
|
||||
"--quantization",
|
||||
"-q",
|
||||
type=click.Choice(tuple(get_args(VALID_QUANT_ALGOS))),
|
||||
default="None",
|
||||
help=
|
||||
("The quantization algorithm to be used when benchmarking. See the "
|
||||
"documentations for more information.\n"
|
||||
" - https://nvidia.github.io/TensorRT-LLM/precision.html"
|
||||
" - https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/quantization-in-TRT-LLM.md"
|
||||
))
|
||||
@click.option(
|
||||
"--workspace",
|
||||
required=False,
|
||||
type=click.Path(writable=True, readable=True),
|
||||
default="/tmp",
|
||||
help="The directory to store benchmarking intermediate files.",
|
||||
)
|
||||
@click.option(
|
||||
"--tensor-parallel-size",
|
||||
"-tp",
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
help="Number of tensor parallel shards to run the benchmark with.",
|
||||
)
|
||||
@click.option(
|
||||
"--pipeline-parallel-size",
|
||||
"-pp",
|
||||
type=int,
|
||||
default=1,
|
||||
required=False,
|
||||
help="Number of pipeline parallel shards to run the benchmark with.",
|
||||
)
|
||||
@click.pass_context
|
||||
def benchmark(
|
||||
ctx,
|
||||
model: str,
|
||||
workspace: Path,
|
||||
dtype: str,
|
||||
kv_dtype: str,
|
||||
quantization: str,
|
||||
tensor_parallel_size: int,
|
||||
pipeline_parallel_size: int,
|
||||
):
|
||||
"""Utility for using TRT-LLM for benchmarking networks from Huggingface."""
|
||||
ctx.obj = BenchmarkConfig(
|
||||
model=model,
|
||||
workspace=Path(workspace),
|
||||
dtype=dtype,
|
||||
cache_dtype=kv_dtype,
|
||||
quantization=quantization,
|
||||
tensor_parallel=tensor_parallel_size,
|
||||
pipeline_parallel=pipeline_parallel_size,
|
||||
)
|
||||
|
||||
# Create the workspace where we plan to store intermediate files.
|
||||
ctx.obj.workspace.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
# Add nested subcommands to main benchmark CLI.
|
||||
benchmark.add_command(static_benchmark)
|
||||
|
||||
if __name__ == "__main__":
|
||||
benchmark()
|
||||
benchmarks/suite/tensorrt_llm_bench/static.py (new file, 83 lines)
@ -0,0 +1,83 @@
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from utils.benchmarkers import gptSessionBenchmarker
|
||||
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
|
||||
|
||||
@click.command("static")
|
||||
@click.option(
|
||||
"--batch",
|
||||
required=True,
|
||||
type=int,
|
||||
help="Batch size to build and run the static benchmark with.",
|
||||
)
|
||||
@click.option("--isl",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Input sequence length (in tokens).")
|
||||
@click.option("--osl",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Output sequence length (in tokens).")
|
||||
@click.option(
|
||||
"--gpt-session-path",
|
||||
"-b",
|
||||
type=click.Path(),
|
||||
default=Path(os.path.dirname(os.path.realpath(__file__)), "../../..",
|
||||
"cpp/build/benchmarks/gptSessionBenchmark").absolute(),
|
||||
help="Path to TRT-LLM gptSession benchmark binary.")
|
||||
@click.option("--max-tokens-in-kv-cache",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of tokens to store in KV cache")
|
||||
@click.option(
|
||||
"--kv-cache-mem-percent",
|
||||
type=float,
|
||||
default=0.9,
|
||||
help="The percentage of free memory that the KV Cache is allowed to occupy.",
|
||||
)
|
||||
@click.option("--warm-up-runs",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of warm up runs before benchmarking")
|
||||
@click.option("--num-runs",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of times to run benchmark")
|
||||
@click.option("--duration",
|
||||
type=int,
|
||||
default=60,
|
||||
help="Minimum duration of iteration to measure, in seconds")
|
||||
@click.pass_obj
|
||||
def static_benchmark(benchmark_cfg: BenchmarkConfig, batch: int, isl: int,
|
||||
osl: int, gpt_session_path: Path, warm_up_runs: int,
|
||||
num_runs: int, duration: int, max_tokens_in_kv_cache: int,
|
||||
kv_cache_mem_percent: float):
|
||||
"""Run a static benchmark with a fixed batch size, ISL, and OSL."""
|
||||
if max_tokens_in_kv_cache is None:
|
||||
max_tokens_in_kv_cache = batch * isl
|
||||
|
||||
benchmarker = gptSessionBenchmarker(
|
||||
benchmark_cfg,
|
||||
gpt_session_path,
|
||||
batch,
|
||||
isl,
|
||||
osl,
|
||||
warm_up_runs,
|
||||
num_runs,
|
||||
duration,
|
||||
max_tokens_in_kv_cache,
|
||||
kv_cache_mem_percent,
|
||||
)
|
||||
|
||||
print(f"Building TRT-LLM engine for '{benchmark_cfg.model}'...")
|
||||
benchmarker.build()
|
||||
|
||||
print("Build complete. Running benchmark...")
|
||||
result: BenchmarkResults = benchmarker.benchmark()
|
||||
|
||||
print(f"JSON: {json.dumps(result.model_dump())}")
|
||||
print(result.get_summary(benchmarker.config))
|
||||
benchmarks/suite/tensorrt_llm_bench/utils/__init__.py (new file, 141 lines)
@ -0,0 +1,141 @@
|
||||
import functools
|
||||
import os
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, List, Literal
|
||||
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
VALID_MODELS = Literal["tiiuae/falcon-7b", "tiiuae/falcon-40b",
|
||||
"tiiuae/falcon-180B", "meta-llama/Llama-2-7b-hf",
|
||||
"meta-llama/Llama-2-13b-hf", "meta-llama/Llama-2-70b-hf",
|
||||
"EleutherAI/gpt-j-6b", ]
|
||||
VALID_COMPUTE_DTYPES = Literal["float16", "bfloat16"]
|
||||
VALID_CACHE_DTYPES = Literal["float16", "float8", "int8"]
|
||||
VALID_QUANT_ALGOS = Literal["None", f"{QuantAlgo.W8A16}", f"{QuantAlgo.W4A16}",
|
||||
f"{QuantAlgo.W4A16_AWQ}", f"{QuantAlgo.W4A8_AWQ}",
|
||||
f"{QuantAlgo.W4A16_GPTQ}", f"{QuantAlgo.FP8}",
|
||||
f"{QuantAlgo.INT8}"]
|
||||
|
||||
|
||||
class _MethodFunctionAdapter:
|
||||
"""An adapter class for running decorators on both methods and functions.
|
||||
|
||||
Found here: https://stackoverflow.com/a/1288936 with help of ChatGPT. This
|
||||
works via the following logic.
|
||||
|
||||
1. During function declaration, store the decorator and function in an
|
||||
instance of this class using `detect_methods`. Works for both functions
|
||||
and methods because a method will be a reference to `Class.method` at
|
||||
declaration time.
|
||||
2. The __call__ method makes this class callable. In the case of functions,
|
||||
the wrapper will simply call the decorator as a wrapper function simply
|
||||
passing arguments without accessing a class descriptor.
|
||||
3. The __get__ method is part of the descriptor protocol for Python classes.
|
||||
In the case of running a method, the call will access the property of a
|
||||
class instance which has been wrapped by this class. __get__ overrides are
|
||||
used to control how a class property is returned. In this case, we would
|
||||
like to return the method reference wrapped in the decorator.
|
||||
"""
|
||||
|
||||
def __init__(self, decorator, func):
|
||||
self.decorator = decorator
|
||||
self.func = func
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self.decorator(self.func)(*args, **kwargs)
|
||||
|
||||
def __get__(self, instance, owner):
|
||||
return self.decorator(self.func.__get__(instance, owner))
|
||||
|
||||
|
||||
def detect_methods(decorator):
|
||||
"""Decorator for applying a wrapper to both methods and functions."""
|
||||
|
||||
def apply_wrapper(func):
|
||||
return _MethodFunctionAdapter(decorator, func)
|
||||
|
||||
return apply_wrapper
|
||||
|
||||
|
||||
def command_logger(prefix: str = "") -> Callable:
|
||||
"""Logs the command for functions that call subprocesses.
|
||||
|
||||
NOTE: This helper assumes the command is in the first argument.
|
||||
|
||||
Args:
|
||||
func (Callable): Function whose first argument is a list of arguments.
|
||||
prefix (str, optional): Prefix to prepend to command. Defaults to "".
|
||||
|
||||
Returns:
|
||||
Callable: Function that includes command logging.
|
||||
"""
|
||||
|
||||
@detect_methods
|
||||
def inner_wrapper(func):
|
||||
|
||||
@functools.wraps(func)
|
||||
def wrapped(*args, **kwargs):
|
||||
# Append the prefix and join the command.
|
||||
print(f"{prefix}{' '.join(args[0])}")
|
||||
return func(*args, **kwargs)
|
||||
|
||||
return wrapped
|
||||
|
||||
return inner_wrapper
|
||||
|
||||
|
||||
@detect_methods
|
||||
def process_error_check(func: Callable) -> subprocess.CompletedProcess:
|
||||
"""Logs standard error and raises an exception on failed processes.
|
||||
|
||||
Args:
|
||||
func (Callable): Callable that returns a CompletedProcess.
|
||||
|
||||
Returns:
|
||||
subprocess.CompletedProcess: Returns a completed process just as
|
||||
an unwrapped `subprocess.run` would.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the wrapped function returns a non-zero error code.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def runtime_check(*args, **kwargs):
|
||||
finished_process = func(*args, **kwargs)
|
||||
|
||||
if finished_process.returncode != 0:
|
||||
raise RuntimeError(
|
||||
"Process failed. Output below.\n"
|
||||
"================================================================\n"
|
||||
f"{finished_process.stderr}")
|
||||
|
||||
return finished_process
|
||||
|
||||
return runtime_check
|
||||
|
||||
|
||||
def run_process(cmd: List[Any],
|
||||
run_dir: Path = None,
|
||||
use_environ: bool = False,
|
||||
stderr_on_stdout: bool = False) -> subprocess.CompletedProcess:
|
||||
"""Utility function for launching processes.
|
||||
|
||||
Args:
|
||||
cmd (List[Any]): A list of arguments that must be able to be cast to a string.
|
||||
run_dir (Path, optional): The directory to run the process from. Defaults to None.
|
||||
use_environ (bool, optional): Use the environment of the container to run the process. Necessary for any commands that start with `mpirun`, as mpi4py initializes its own MPI environment
|
||||
stderr_on_stdout (bool, optional): Pipe STDERR to STDOUT. Defaults to False.
|
||||
|
||||
Returns:
|
||||
subprocess.CompletedProcess: _description_
|
||||
"""
|
||||
result = subprocess.run(
|
||||
[str(x) for x in cmd],
|
||||
cwd=run_dir,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT if stderr_on_stdout else subprocess.PIPE,
|
||||
env=os.environ if use_environ else None,
|
||||
text=True,
|
||||
)
|
||||
return result
|
||||
benchmarks/suite/tensorrt_llm_bench/utils/benchmarkers.py (new file, 228 lines)
@ -0,0 +1,228 @@
|
||||
from pathlib import Path
|
||||
from subprocess import CompletedProcess
|
||||
from typing import Dict, List, Protocol
|
||||
|
||||
from utils import command_logger, process_error_check, run_process
|
||||
from utils.dataclasses import BenchmarkConfig, BenchmarkResults
|
||||
from utils.trtllm_config import TRTLLMConfig
|
||||
|
||||
|
||||
class Benchmarker(Protocol):
|
||||
"""Protocol for defining benchmarking classes for building/benchmarking."""
|
||||
|
||||
def build(self) -> None:
|
||||
"""Build a model to be benchmarked."""
|
||||
...
|
||||
|
||||
def benchmark(self) -> BenchmarkResults:
|
||||
"""Benchmark the constructed model container by a benchmarker."""
|
||||
...
|
||||
|
||||
|
||||
class gptSessionBenchmarker:
|
||||
"""Utility class for running static benchmarks with gptSessionBenchmark."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: BenchmarkConfig,
|
||||
benchmark_binary: Path,
|
||||
batch_size: int,
|
||||
isl: int,
|
||||
osl: int,
|
||||
warm_up_runs: int,
|
||||
num_runs: int,
|
||||
duration: int,
|
||||
max_tokens_in_kv_cache: int,
|
||||
kv_cache_free_fraction: float = .9,
|
||||
):
|
||||
"""Initialize a gptSessionBenchmark instance.
|
||||
|
||||
Args:
|
||||
config (BenchmarkConfig): Benchmark configuration for build/run.
|
||||
benchmark_binary (Path): Path to the benchmarking binary.
|
||||
batch_size (int): Batch size to configure the build with.
|
||||
isl (int): Input sequence length to configure the build with.
|
||||
osl (int): Output sequence length to configure the build with.
|
||||
max_tokens_in_kv_cache (int): The maximum number of tokens to store
|
||||
in the KV cache
|
||||
kv_cache_free_fraction (float, optional): The amount of remaining
|
||||
GPU memory after model loading to save for the KV Cache. Defaults
|
||||
to .9.
|
||||
"""
|
||||
self.config: BenchmarkConfig = config
|
||||
self.gpt_session_path = Path(benchmark_binary).absolute()
|
||||
self.batch_size = batch_size
|
||||
self.input_length = isl
|
||||
self.output_length = osl
|
||||
self.warm_up = warm_up_runs
|
||||
self.num_runs = num_runs
|
||||
self.duration = duration
|
||||
self.kv_cache_mem = kv_cache_free_fraction
|
||||
self.max_tokens_in_kv_cache = max_tokens_in_kv_cache
|
||||
self.result = None
|
||||
|
||||
def get_build_command(self) -> List[str]:
|
||||
"""Build the engine command for TRT-LLM.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of command line arguments to run a build command.
|
||||
"""
|
||||
model = self.config.model
|
||||
tp = self.config.tensor_parallel
|
||||
pp = self.config.pipeline_parallel
|
||||
dtype = self.config.dtype
|
||||
kv_dtype = self.config.cache_dtype
|
||||
quant_algo = self.config.quantization.value
|
||||
output_dir = self.config.engine_path
|
||||
max_batch_size = self.batch_size
|
||||
max_isl = self.input_length
|
||||
max_osl = self.output_length
|
||||
workspace = self.config.workspace
|
||||
|
||||
# Generate the TRT-LLM Configuration file using the dataclass
|
||||
# NOTE: This method does not use weights.
|
||||
trtllm_config = TRTLLMConfig.from_hf(model, tp, pp, dtype, quant_algo,
|
||||
kv_dtype)
|
||||
# Write the generated configuration file to the benchmark workspace.
|
||||
trtllm_config.to_json(workspace)
|
||||
|
||||
# Return the full command for building TRT-LLM via subprocess call.
|
||||
cmd = [
|
||||
"trtllm-build",
|
||||
"--output_dir",
|
||||
output_dir,
|
||||
"--model_config",
|
||||
Path(workspace, "generated_config.json"),
|
||||
"--workers",
|
||||
self.config.world_size,
|
||||
# Define the maximums the engine can accept.
|
||||
"--max_batch_size",
|
||||
max_batch_size,
|
||||
"--max_input_len",
|
||||
max_isl,
|
||||
"--max_output_len",
|
||||
max_osl,
|
||||
"--context_fmha",
|
||||
"enable",
|
||||
# Set the attention plugin data type.
|
||||
"--gpt_attention_plugin",
|
||||
dtype.value,
|
||||
# Disable paged cache since we aren't batching on the fly.
|
||||
"--paged_kv_cache",
|
||||
"disable",
|
||||
] + kv_dtype.get_build_options(dtype)
|
||||
|
||||
return [str(arg) for arg in cmd]
|
||||
|
||||
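With illustrative values filled in, get_build_command() renders to a trtllm-build
invocation along these lines (paths and sizes below are placeholders, and the
trailing options come from kv_dtype.get_build_options(dtype)):

    trtllm-build --output_dir /workspace/llama-2-7b-hf \
        --model_config /workspace/generated_config.json \
        --workers 2 \
        --max_batch_size 64 --max_input_len 128 --max_output_len 128 \
        --context_fmha enable --gpt_attention_plugin float16 \
        --paged_kv_cache disable --gemm_plugin float16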
@command_logger(prefix="BUILD COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_build(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Wrapper for calling the build for TRT-LLM.
|
||||
|
||||
Purpose of this wrapper is so that we can decorate it/log it.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments for running.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for parsing and
|
||||
reporting.
|
||||
"""
|
||||
return run_process(
|
||||
cmd,
|
||||
self.config.workspace,
|
||||
)
|
||||
|
||||
def build(self) -> None:
|
||||
"""Build the engine for benchmarking."""
|
||||
self._run_build(self.get_build_command())
|
||||
|
||||
@command_logger(prefix="BENCHMARK COMMAND: ")
|
||||
@process_error_check
|
||||
def _run_benchmark(self, cmd: List[str]) -> CompletedProcess:
|
||||
"""Run the benchmark command in the configured workspace.
|
||||
|
||||
Args:
|
||||
cmd (List[str]): List of command line arguments to run via
|
||||
subprocess.
|
||||
|
||||
Returns:
|
||||
CompletedProcess: Completed process information for reporting.
|
||||
"""
|
||||
return run_process(cmd, run_dir=self.config.workspace, use_environ=True)
|
||||
|
||||
@staticmethod
|
||||
def parse_benchmark_result(benchmark_line: str) -> Dict[str, str]:
|
||||
pass
|
||||
|
||||
def benchmark(self):
|
||||
"""Benchmarks a TRT-LLM for a configured instance."""
|
||||
|
||||
# Compile the command for running
|
||||
cmd = [
|
||||
"mpirun",
|
||||
"-allow-run-as-root",
|
||||
"-n",
|
||||
self.config.world_size,
|
||||
self.gpt_session_path,
|
||||
"--engine_dir",
|
||||
self.config.engine_path,
|
||||
"--batch_size",
|
||||
self.batch_size,
|
||||
"--log_level",
|
||||
"info",
|
||||
"--max_tokens_in_paged_kvcache",
|
||||
self.max_tokens_in_kv_cache,
|
||||
"--kv_cache_free_gpu_mem_fraction",
|
||||
self.kv_cache_mem,
|
||||
"--beam_width",
|
||||
"1",
|
||||
"--warm_up",
|
||||
self.warm_up,
|
||||
"--num_runs",
|
||||
self.num_runs,
|
||||
"--duration",
|
||||
self.duration,
|
||||
"--input_output_len",
|
||||
f"{self.input_length},{self.output_length};{self.input_length},1",
|
||||
]
|
||||
cmd = [str(arg) for arg in cmd]
|
||||
# Run the benchmark using the provided gptSession benchmark binary.
|
||||
bench_return = self._run_benchmark(cmd)
|
||||
results = [
|
||||
x.split(" ") for x in bench_return.stdout.split("\n")
|
||||
if "[BENCHMARK]" in x
|
||||
]
|
||||
|
||||
ttft = float(results[1][8])
|
||||
gen_time = float(results[0][8]) - ttft
|
||||
total_out = int(results[0][2]) * int(results[0][6])
|
||||
total_in = int(results[0][2]) * int(results[0][4])
|
||||
batch_size = int(results[0][2])
|
||||
|
||||
bench_result = BenchmarkResults(
|
||||
model=self.config.model,
|
||||
dtype=self.config.dtype.value,
|
||||
quantization=str(self.config.quantization.value),
|
||||
max_batch_size=batch_size,
|
||||
total_input_tokens=total_in,
|
||||
total_output_tokens=total_out,
|
||||
tp_size=self.config.tensor_parallel,
|
||||
pp_size=self.config.pipeline_parallel,
|
||||
kv_mem_fraction=self.kv_cache_mem,
|
||||
scheduler="Static",
|
||||
max_tokens_in_cache=self.max_tokens_in_kv_cache,
|
||||
inflight_batching=False,
|
||||
total_latency=results[0][8],
|
||||
first_token_latency=ttft,
|
||||
time_per_output_token=gen_time / (total_out - batch_size),
|
||||
latency_units="ms",
|
||||
throughput=results[0][10],
|
||||
throughput_units="tokens/second",
|
||||
peak_gpu_mem=results[0][16],
|
||||
peak_gpu_mem_units="GB",
|
||||
binary=str(self.gpt_session_path),
|
||||
build_cmd=" ".join(self.get_build_command()),
|
||||
benchmark_cmd=" ".join(cmd))
|
||||
|
||||
return bench_result
|
||||
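The index arithmetic in benchmark() above can be summarized with a small helper;
the field positions below are taken directly from the parsing code rather than
from a documented output format:

    def summarize_benchmark_fields(fields: list) -> dict:
        # Positions mirror the indexing used on the split "[BENCHMARK]" lines.
        return {
            "batch_size": int(fields[2]),
            "input_length": int(fields[4]),
            "output_length": int(fields[6]),
            "latency_ms": float(fields[8]),
            "throughput_tokens_per_sec": float(fields[10]),
            "peak_gpu_mem_gb": float(fields[16]),
        }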
102
benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py
Normal file
102
benchmarks/suite/tensorrt_llm_bench/utils/dataclasses.py
Normal file
@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, computed_field
|
||||
from utils import VALID_MODELS
|
||||
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum, QuantizationAlgo
|
||||
|
||||
|
||||
class BenchmarkResults(BaseModel):
|
||||
"""High level report out for a benchmark."""
|
||||
|
||||
benchmark_cmd: str = ""
|
||||
binary: str
|
||||
build_cmd: str = ""
|
||||
first_token_latency: float
|
||||
inflight_batching: bool
|
||||
kv_mem_fraction: float
|
||||
latency_units: str
|
||||
max_batch_size: int
|
||||
max_tokens_in_cache: int
|
||||
model: VALID_MODELS
|
||||
peak_gpu_mem_units: str
|
||||
peak_gpu_mem: float
|
||||
scheduler: Literal["Static", "No evict", "Max Utilization"]
|
||||
throughput_units: str
|
||||
throughput: float
|
||||
time_per_output_token: float
|
||||
total_input_tokens: int
|
||||
total_latency: float
|
||||
total_output_tokens: int
|
||||
|
||||
def get_summary(self, config: BenchmarkConfig) -> str:
|
||||
"""Generate the summary information.
|
||||
|
||||
Args:
|
||||
config (BenchmarkConfig): Configuration for the run that generated
|
||||
this result.
|
||||
|
||||
Returns:
|
||||
str: Summary output for printing.
|
||||
"""
|
||||
return (
|
||||
"===========================================================\n"
|
||||
"= METADATA\n"
|
||||
"===========================================================\n"
|
||||
f"Model:\t\t\t{config.model}\n"
|
||||
f"TP Size:\t\t{config.tensor_parallel}\n"
|
||||
f"PP Size:\t\t{config.pipeline_parallel}\n"
|
||||
f"Scheduling Policy:\t{self.scheduler}\n"
|
||||
f"In-flight Batcher?:\t{self.inflight_batching}\n"
|
||||
f"Dtype:\t\t\t{config.dtype.value}\n"
|
||||
f"KV Cache Dtype:\t\t{config.cache_dtype.value}\n"
|
||||
f"KV Cache Size (tokens):\t{self.max_tokens_in_cache}\n"
|
||||
f"Quantization:\t\t{config.quantization.value}\n"
|
||||
f"KV Memory Percentage:\t{self.kv_mem_fraction * 100}%\n"
|
||||
f"\n"
|
||||
"===========================================================\n"
|
||||
"= ENGINE DETAILS\n"
|
||||
"===========================================================\n"
|
||||
f"Engine Directory:\t{config.engine_path}\n"
|
||||
f"Max Batch Size:\t\t{self.max_batch_size}\n"
|
||||
f"Total Input Length:\t{self.total_input_tokens}\n"
|
||||
f"Total Output Length:\t{self.total_input_tokens}\n"
|
||||
f"\n"
|
||||
"===========================================================\n"
|
||||
"= STATISTICS\n"
|
||||
"===========================================================\n"
|
||||
f"Throughput ({self.throughput_units}):\t{self.throughput}\n"
|
||||
f"Total Latency ({self.latency_units}):\t\t{self.total_latency}\n"
|
||||
f"First Token Latency ({self.latency_units}):\t{self.first_token_latency}\n"
|
||||
f"Token-to-token Latency ({self.latency_units}):\t{self.time_per_output_token}\n"
|
||||
f"Peak GPU Memory Usage ({self.peak_gpu_mem_units}):\t{self.peak_gpu_mem}\n"
|
||||
f"\n"
|
||||
"===========================================================\n"
|
||||
"= COMMANDS\n"
|
||||
"===========================================================\n"
|
||||
f"Build: {self.build_cmd}\n"
|
||||
f"Benchmark: {self.benchmark_cmd}\n")
|
||||
|
||||
|
||||
class BenchmarkConfig(BaseModel):
|
||||
"""Basic configuration of a benchmark."""
|
||||
|
||||
model: VALID_MODELS
|
||||
workspace: Path
|
||||
dtype: ComputeDtypeEnum
|
||||
cache_dtype: KVCacheDtypeEnum
|
||||
quantization: QuantizationAlgo
|
||||
tensor_parallel: int
|
||||
pipeline_parallel: int
|
||||
|
||||
@computed_field
|
||||
def engine_path(self) -> Path:
|
||||
"""Path to the engine workspace."""
|
||||
return Path(self.workspace.absolute(), self.model.lower())
|
||||
|
||||
@computed_field
|
||||
def world_size(self) -> int:
|
||||
"""Total world size needed to run the model."""
|
||||
return self.tensor_parallel * self.pipeline_parallel
|
||||
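A hypothetical construction of these models (the model name and workspace are
placeholders used only for illustration):

    from pathlib import Path

    cfg = BenchmarkConfig(
        model="meta-llama/Llama-2-7b-hf",
        workspace=Path("/tmp/trtllm-bench"),
        dtype=ComputeDtypeEnum("float16"),
        cache_dtype=KVCacheDtypeEnum("fp16"),
        quantization=QuantizationAlgo("None"),
        tensor_parallel=2,
        pipeline_parallel=1,
    )
    # cfg.engine_path -> <workspace>/meta-llama/llama-2-7b-hf
    # cfg.world_size  -> 2 (tensor_parallel * pipeline_parallel)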
58
benchmarks/suite/tensorrt_llm_bench/utils/enums.py
Normal file
58
benchmarks/suite/tensorrt_llm_bench/utils/enums.py
Normal file
@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
|
||||
from aenum import MultiValueEnum
|
||||
|
||||
from tensorrt_llm.quantization.mode import QuantAlgo
|
||||
|
||||
|
||||
class KVCacheDtypeEnum(MultiValueEnum):
|
||||
"""Enumeration of KV Cache precisions in TRT-LLM."""
|
||||
FP8 = "FP8", "fp8", "float8"
|
||||
FP16 = None, "FP16", "fp16", "float16"
|
||||
INT8 = "INT8", "int8"
|
||||
|
||||
def get_build_options(self, dtype: ComputeDtypeEnum) -> List[str]:
|
||||
"""Get the build options for TRT-LLM based on KV Cache precision.
|
||||
|
||||
Args:
|
||||
dtype (ComputeDtypeEnum): The activation dtype for the model. This
|
||||
parameter maps the activation dtype for GEMM plugins for certain
|
||||
KV cache precisions.
|
||||
|
||||
Returns:
|
||||
List[str]: A list of command line arguments to be added to build
|
||||
commands.
|
||||
"""
|
||||
        # Compare enum members directly; comparing self.value against a member
        # never matches, which would drop --strongly_typed for FP8 caches.
        if self == KVCacheDtypeEnum.FP8:
            return ["--strongly_typed"]
        else:
            return ["--gemm_plugin", dtype.value]
|
||||
|
||||
|
||||
class ComputeDtypeEnum(MultiValueEnum):
|
||||
"""Enumeration for activation data type."""
|
||||
|
||||
# FLOAT32 = "float32", "fp32", "FP32"
|
||||
FLOAT16 = "float16", "FLOAT16", "fp16", "FP16"
|
||||
BFLOAT16 = "bfloat16", "BFLOAT16", "bf16", "bfp16", "BF16"
|
||||
|
||||
|
||||
# TODO: use quantization.mode.QuantAlgo eventually
|
||||
class QuantizationAlgo(MultiValueEnum):
|
||||
"""Enumerated type for quantization algorithms for string mapping."""
|
||||
|
||||
W8A16 = QuantAlgo.W8A16.value
|
||||
W4A16 = QuantAlgo.W4A16.value
|
||||
W4A16_AWQ = QuantAlgo.W4A16_AWQ.value
|
||||
W4A8_AWQ = QuantAlgo.W4A8_AWQ.value
|
||||
W4A16_GPTQ = QuantAlgo.W4A16_GPTQ.value
|
||||
FP8 = QuantAlgo.FP8.value
|
||||
INT8 = QuantAlgo.INT8.value
|
||||
W8A8_SQ_PER_CHANNEL = QuantAlgo.W8A8_SQ_PER_CHANNEL.value
|
||||
W8A8_SQ_PER_TENSOR_PLUGIN = QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN.value
|
||||
W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN.value
|
||||
W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN = QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TENSOR_PLUGIN.value
|
||||
W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN = QuantAlgo.W8A8_SQ_PER_TENSOR_PER_TOKEN_PLUGIN.value
|
||||
NONE = None, "None", "FP16", "BF16"
|
||||
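Example lookups, assuming the FP8 comparison fix above (multi-value members let
several spellings resolve to the same option set):

    KVCacheDtypeEnum("fp8").get_build_options(ComputeDtypeEnum("float16"))
    # -> ["--strongly_typed"]
    KVCacheDtypeEnum("float16").get_build_options(ComputeDtypeEnum("bf16"))
    # -> ["--gemm_plugin", "bfloat16"]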
218
benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py
Normal file
218
benchmarks/suite/tensorrt_llm_bench/utils/trtllm_config.py
Normal file
@ -0,0 +1,218 @@
|
||||
import json
|
||||
import os
|
||||
from argparse import ArgumentParser
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import AliasChoices, AliasPath, BaseModel, Field, model_validator
|
||||
from transformers import AutoConfig
|
||||
from utils import VALID_QUANT_ALGOS
|
||||
from utils.enums import ComputeDtypeEnum, KVCacheDtypeEnum
|
||||
|
||||
PET_dict = {
|
||||
"tiiuae/falcon-7b": "rope_gpt_neox",
|
||||
"tiiuae/falcon-40b": "rope_gpt_neox",
|
||||
"tiiuae/falcon-180B": "rope_gpt_neox",
|
||||
"meta-llama/Llama-2-7b-hf": "rope_gpt_neox",
|
||||
"meta-llama/Llama-2-13b-hf": "rope_gpt_neox",
|
||||
"meta-llama/Llama-2-70b-hf": "rope_gpt_neox",
|
||||
"EleutherAI/gpt-j-6b": "rope_gptj",
|
||||
"bigscience/bloom-560m": "alibi",
|
||||
"mistralai/Mistral-7B-v0.1": "rope_gpt_neox",
|
||||
"01-ai/Yi-6B": "rope_gpt_neox",
|
||||
"01-ai/Yi-34B": "rope_gpt_neox",
|
||||
}
|
||||
HA_dict = {
|
||||
"tiiuae/falcon-7b": "gelu",
|
||||
"tiiuae/falcon-40b": "gelu",
|
||||
"tiiuae/falcon-180B": "gelu",
|
||||
"bigscience/bloom-560m": "gelu",
|
||||
}
|
||||
|
||||
|
||||
class TRTLLM_Mapping(BaseModel):
|
||||
world_size: int = 1
|
||||
tp_size: int = 1
|
||||
pp_size: int = 1
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_world_size(self) -> "TRTLLM_Mapping":
|
||||
self.world_size = self.tp_size * self.pp_size
|
||||
return self
|
||||
|
||||
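For example, the validator derives the world size from the parallelism degrees:

    # TRTLLM_Mapping(tp_size=2, pp_size=4).world_size == 8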
|
||||
class TRTLLM_Quantization(BaseModel):
|
||||
quant_algo: Optional[VALID_QUANT_ALGOS] = None
|
||||
|
||||
kv_cache_quant_algo: Optional[Literal[None, "FP8", "INT8"]] = None
|
||||
|
||||
group_size: int = 128
|
||||
has_zero_point: bool = False
|
||||
pre_quant_scale: bool = False
|
||||
exclude_modules: Optional[list] = None
|
||||
|
||||
|
||||
class TRTLLM_CheckpointConfig(BaseModel):
|
||||
"""Dataclass for building TRT-LLM model configurations."""
|
||||
|
||||
_VALID_EMBED_TYPE = Literal["learned_absolute", "rope_gptj",
|
||||
"rope_gpt_neox", "alibi", "alibi_with_scale",
|
||||
"relative", "chatglm", ]
|
||||
|
||||
architecture: str = Field(validation_alias=AliasPath("architectures", 0))
|
||||
num_hidden_layers: int = Field(validation_alias=AliasChoices(
|
||||
"num_hidden_layers", "n_layer", "n_layers"))
|
||||
num_attention_heads: int = Field(validation_alias=AliasChoices(
|
||||
"num_attention_heads", "n_head", "n_heads"))
|
||||
num_key_value_heads: int = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("num_key_value_heads", "num_kv_heads"),
|
||||
)
|
||||
|
||||
hidden_size: int = Field(
|
||||
validation_alias=AliasChoices("hidden_size", "n_embd", "d_model"))
|
||||
norm_epsilon: float = Field(
|
||||
default=1e-5,
|
||||
validation_alias=AliasChoices("norm_epsilon", "layer_norm_epsilon"),
|
||||
)
|
||||
vocab_size: int
|
||||
max_position_embeddings: Optional[int] = Field(
|
||||
default=None,
|
||||
validation_alias=AliasChoices("max_position_embeddings", "n_positions"),
|
||||
)
|
||||
hidden_act: str = Field(
|
||||
validation_alias=AliasChoices("hidden_act", "activation_function"))
|
||||
# falcon options
|
||||
bias: Optional[bool] = None
|
||||
parallel_attention: Optional[bool] = Field(
|
||||
default=None, validation_alias=AliasChoices("parallel_attn"))
|
||||
new_decoder_architecture: Optional[bool] = None
|
||||
# opt options
|
||||
do_layer_norm_before: Optional[bool] = None
|
||||
# gptj options
|
||||
rotary_dim: Optional[int] = None
|
||||
|
||||
# dtype has priority over torch_dtype, the latter of which is usually defined in the HF config
|
||||
dtype: Literal["float16", "bfloat16"] = Field(
|
||||
validation_alias=AliasChoices("dtype", "torch_dtype"))
|
||||
logits_dtype: str = "float32"
|
||||
position_embedding_type: _VALID_EMBED_TYPE = "learned_absolute"
|
||||
use_parallel_embedding: bool = False
|
||||
embedding_sharding_dim: int = 0
|
||||
share_embedding_table: bool = False
|
||||
intermediate_size: Optional[int] = None
|
||||
use_prompt_tuning: bool = False
|
||||
|
||||
mapping: TRTLLM_Mapping
|
||||
quantization: TRTLLM_Quantization
|
||||
|
||||
@model_validator(mode="after")
|
||||
def set_kv_head_default_value(self) -> "TRTLLM_CheckpointConfig":
|
||||
if self.num_key_value_heads is None:
|
||||
self.num_key_value_heads = self.num_attention_heads
|
||||
return self
|
||||
|
||||
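A minimal hand-built payload shows how the aliases resolve (values are
illustrative; real runs populate this from AutoConfig, as in from_hf below):

    ckpt = TRTLLM_CheckpointConfig(
        architectures=["LlamaForCausalLM"],
        num_hidden_layers=32,
        num_attention_heads=32,
        hidden_size=4096,
        vocab_size=32000,
        hidden_act="silu",
        torch_dtype="float16",
        mapping={"tp_size": 2, "pp_size": 1},
        quantization={},
    )
    # num_key_value_heads falls back to num_attention_heads (32) via the validator.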
|
||||
class TRTLLMConfig:
|
||||
|
||||
def __init__(self, trtllm_config, hf_config=None) -> None:
|
||||
self.trtllm_config = trtllm_config
|
||||
self.hf_config = hf_config
|
||||
# self.nemo_config = nemo_config
|
||||
|
||||
@classmethod
|
||||
def from_hf(
|
||||
cls,
|
||||
hf_model_name,
|
||||
tp,
|
||||
pp,
|
||||
dtype=None,
|
||||
quant_dtype=None,
|
||||
kv_cache_quant_dtype=None,
|
||||
):
|
||||
build_config = {
|
||||
"mapping": {
|
||||
"tp_size": tp,
|
||||
"pp_size": pp,
|
||||
},
|
||||
"quantization": {},
|
||||
}
|
||||
if dtype:
|
||||
build_config["dtype"] = ComputeDtypeEnum(dtype).value
|
||||
if quant_dtype:
|
||||
if not kv_cache_quant_dtype:
|
||||
# will throw errors during validation if the type is invalid
|
||||
kv_cache_quant_dtype = KVCacheDtypeEnum(quant_dtype).value
|
||||
build_config["quantization"] = {
|
||||
"quant_algo": quant_dtype,
|
||||
"kv_cache_quant_algo":
|
||||
KVCacheDtypeEnum(kv_cache_quant_dtype).value,
|
||||
}
|
||||
build_config["position_embedding_type"] = PET_dict[hf_model_name]
|
||||
if hf_model_name in HA_dict:
|
||||
build_config["hidden_act"] = HA_dict[hf_model_name]
|
||||
hf_config = AutoConfig.from_pretrained(hf_model_name).to_dict()
|
||||
trtllm_config = TRTLLM_CheckpointConfig(**hf_config,
|
||||
**build_config).model_dump()
|
||||
return cls(trtllm_config, hf_config)
|
||||
|
||||
def to_json(self, output_dir):
|
||||
with open(os.path.join(output_dir, "generated_config.json"), "w") as f:
|
||||
json.dump(self.trtllm_config, f, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
required=True,
|
||||
type=str,
|
||||
help="HF model name",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="TP degree",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pp_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help="PP degree",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
type=str,
|
||||
help="Datatype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--quant_dtype",
|
||||
type=str,
|
||||
help="Quantization datatype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--kv_cache_quant_dtype",
|
||||
type=str,
|
||||
help="KV cache datatype",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--position_embedding_type",
|
||||
type=str,
|
||||
help="TRT-LLM argument",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--hidden_act",
|
||||
type=str,
|
||||
help="TRT-LLM argument",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
trtllm_config = TRTLLMConfig.from_hf(
|
||||
args.model,
|
||||
args.tp_size,
|
||||
args.pp_size,
|
||||
args.dtype,
|
||||
args.quant_dtype,
|
||||
args.kv_cache_quant_dtype,
|
||||
)
|
||||
trtllm_config.to_json(os.getcwd())
|
||||
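When invoked directly, the module writes generated_config.json into the current
working directory; a representative invocation (the model name is illustrative,
and the script path assumes it is run from benchmarks/suite/tensorrt_llm_bench):

    python3 utils/trtllm_config.py \
        --model meta-llama/Llama-2-7b-hf \
        --tp_size 2 \
        --pp_size 1 \
        --dtype float16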
@ -34,20 +34,24 @@ public:
|
||||
explicit KvCacheConfig(std::optional<SizeType> maxTokens = std::nullopt,
|
||||
std::optional<SizeType> maxAttentionWindow = std::nullopt,
|
||||
std::optional<SizeType> sinkTokenLength = std::nullopt,
|
||||
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false, bool useUvm = false)
|
||||
std::optional<float> freeGpuMemoryFraction = std::nullopt, bool enableBlockReuse = false, bool useUvm = false,
|
||||
std::optional<size_t> hostCacheSize = std::nullopt, bool onboardBlocks = true)
|
||||
: maxTokens{maxTokens}
|
||||
, maxAttentionWindow{maxAttentionWindow}
|
||||
, sinkTokenLength{sinkTokenLength}
|
||||
, freeGpuMemoryFraction{freeGpuMemoryFraction}
|
||||
, enableBlockReuse(enableBlockReuse)
|
||||
, useUvm(useUvm)
|
||||
, hostCacheSize(hostCacheSize)
|
||||
, onboardBlocks(onboardBlocks)
|
||||
{
|
||||
}
|
||||
|
||||
explicit KvCacheConfig(executor::KvCacheConfig const& kvCacheConfig)
|
||||
: KvCacheConfig(kvCacheConfig.getMaxTokens(), kvCacheConfig.getMaxAttentionWindow(),
|
||||
kvCacheConfig.getSinkTokenLength(), kvCacheConfig.getFreeGpuMemoryFraction(),
|
||||
kvCacheConfig.getEnableBlockReuse(), false)
|
||||
kvCacheConfig.getEnableBlockReuse(), false, kvCacheConfig.getHostCacheSize(),
|
||||
kvCacheConfig.getOnboardBlocks())
|
||||
{
|
||||
}
|
||||
|
||||
@ -55,7 +59,8 @@ public:
|
||||
{
|
||||
return maxTokens == other.maxTokens && maxAttentionWindow == other.maxAttentionWindow
|
||||
&& sinkTokenLength == other.sinkTokenLength && freeGpuMemoryFraction == other.freeGpuMemoryFraction
|
||||
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm;
|
||||
&& enableBlockReuse == other.enableBlockReuse && useUvm == other.useUvm
|
||||
&& hostCacheSize == other.hostCacheSize && onboardBlocks == other.onboardBlocks;
|
||||
}
|
||||
|
||||
std::optional<SizeType> maxTokens;
|
||||
@ -65,5 +70,7 @@ public:
|
||||
bool enableBlockReuse;
|
||||
static constexpr auto kDefaultGpuMemFraction = 0.9f;
|
||||
bool useUvm;
|
||||
std::optional<size_t> hostCacheSize;
|
||||
bool onboardBlocks;
|
||||
};
|
||||
} // namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
|
||||
#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
|
||||
#include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
|
||||
#include "tensorrt_llm/common/memoryUtils.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/cudaStream.h"
|
||||
@ -88,12 +89,18 @@ struct KvCacheStats
|
||||
class KVCacheBlock
|
||||
{
|
||||
public:
|
||||
explicit KVCacheBlock(SizeType blockIdx);
|
||||
explicit KVCacheBlock(SizeType blockIdx, SizeType blocksInPrimaryPool);
|
||||
|
||||
void startScheduling();
|
||||
|
||||
[[nodiscard]] SizeType getBlockIdx() const;
|
||||
|
||||
[[nodiscard]] SizeType getMemoryPoolBlockOffset() const;
|
||||
|
||||
[[nodiscard]] bool isPrimary() const;
|
||||
|
||||
void swapMemoryPoolBlockOffset(std::shared_ptr<KVCacheBlock> otherBlock);
|
||||
|
||||
void incRefCount();
|
||||
|
||||
void decRefCount();
|
||||
@ -120,6 +127,8 @@ public:
|
||||
|
||||
void removeNextBlock(VecTokens const& tokens);
|
||||
|
||||
static std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree(std::shared_ptr<KVCacheBlock> searchStart);
|
||||
|
||||
static std::shared_ptr<KVCacheBlock> findLeafBlock(std::shared_ptr<KVCacheBlock> searchStart);
|
||||
|
||||
[[nodiscard]] BlockPtr findMatchingBlock(VecTokens const& tokens) const;
|
||||
@ -135,6 +144,9 @@ private:
|
||||
// Linear index of block in pool
|
||||
SizeType mBlockIdx;
|
||||
|
||||
// Block in memory pool backing this block
|
||||
SizeType mMemoryPoolBlockOffset;
|
||||
|
||||
// Number of references to the block
|
||||
SizeType mRefCount;
|
||||
|
||||
@ -155,6 +167,9 @@ private:
|
||||
|
||||
// Flag indicating if block is full
|
||||
bool mIsFull;
|
||||
|
||||
// Flag indicating mMemoryPoolBlockOffset refers to secondary pool
|
||||
static constexpr SizeType secondaryPoolFlag = static_cast<SizeType>(1) << (8 * sizeof(SizeType) - 1);
|
||||
};
|
||||
|
||||
class GenerationRequest
|
||||
@ -271,7 +286,9 @@ class BlockManager
|
||||
public:
|
||||
using SizeType = tensorrt_llm::runtime::SizeType;
|
||||
|
||||
explicit BlockManager(SizeType blocksInPool, SizeType tokensPerBlock);
|
||||
explicit BlockManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock,
|
||||
SizeType blocksInPrimaryPool, SizeType blocksInSecondaryPool, nvinfer1::DataType dtype,
|
||||
std::shared_ptr<runtime::CudaStream> stream, bool useUvm, bool onboardBlocks);
|
||||
|
||||
~BlockManager();
|
||||
|
||||
@ -283,6 +300,10 @@ public:
|
||||
//! \brief Assign blocks for new sequence. Does not try to reuse blocks.
|
||||
void addSequence(GenerationRequest& sequence, SizeType numBlocks, SizeType unsharedBlockIdx);
|
||||
|
||||
//! \brief Release block, which puts it back onto free blocks queue.
|
||||
//! \details Block appended by default, will be put at front if toFront is true.
|
||||
void releaseBlock(std::shared_ptr<KVCacheBlock> block, bool toFront = false);
|
||||
|
||||
//! \brief Allocate new block for each beam of the sequence.
|
||||
//! \details Might free cached blocks if no free blocks are available.
|
||||
void allocateBlock(GenerationRequest& sequence, bool shareAmongBeams = false);
|
||||
@ -300,7 +321,7 @@ public:
|
||||
|
||||
[[nodiscard]] SizeType getNumFreeBlocks() const noexcept
|
||||
{
|
||||
return mFreeBlocks.size();
|
||||
return mFreePrimaryBlocks.size();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getNumReusedBlocks() const noexcept
|
||||
@ -333,6 +354,27 @@ public:
|
||||
return mTokensPerBlock;
|
||||
}
|
||||
|
||||
//! \brief Get size of one field in one layer in one block.
|
||||
[[nodiscard]] SizeType getBlockSize() const
|
||||
{
|
||||
return mBlockSize;
|
||||
}
|
||||
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr getPrimaryPool() const noexcept
|
||||
{
|
||||
return mPrimaryPool;
|
||||
}
|
||||
|
||||
//! \brief Get raw void* pointer to K block.
|
||||
//! \param blockIdx the blockIdx as returned by getBlockIdx()
|
||||
//! \param layerNum layer number.
|
||||
//! \param fieldIdx either 0 (K) or 1 (V),
|
||||
[[nodiscard]] void* getKOrVBlockPointer(SizeType blockIdx, SizeType layerNum, SizeType fieldIdx) const;
|
||||
|
||||
//! \brief Bring offloaded block from secondary to primary memory.
|
||||
//! \details Does nothing if block is already in primary memory.
|
||||
void onboardBlock(BlockPtr offloadBlock);
|
||||
|
||||
private:
|
||||
//! \brief Add single block to beam of sequence and mAllocatedBlocksPerSeq.
|
||||
void addBlockToBeam(BlockPtr& block, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
|
||||
@ -351,6 +393,11 @@ private:
|
||||
SizeType loadOrAllocateBlocks(
|
||||
std::list<VecTokens> const& blockedTokens, GenerationRequest& sequence, SizeType beamIdx, SizeType seqSlotIdx);
|
||||
|
||||
//! \brief Find best primary block to free.
|
||||
//! \details The best primary block to free is the primary block that appears first in the queue and has no primary
//! block descendants.
|
||||
[[nodiscard]] std::shared_ptr<KVCacheBlock> findBestGPUBlockToFree();
|
||||
|
||||
//! \brief Find block least likely to be reused, free it if necessary and return.
|
||||
[[nodiscard]] BlockPtr getFreeBlock();
|
||||
|
||||
@ -360,11 +407,30 @@ private:
|
||||
//! \brief Free block from previous block and claim it from free blocks list.
|
||||
void claimLeafBlock(KVCacheBlock& block);
|
||||
|
||||
//! \brief Compute pointer to raw KV block (K & V, all layers).
|
||||
[[nodiscard]] runtime::ITensor::SharedPtr computeBlockPointer(std::shared_ptr<KVCacheBlock> block) const;
|
||||
|
||||
//! \brief Copy content of src block to dst.
|
||||
void copyBlock(BlockPtr src, BlockPtr dst);
|
||||
|
||||
private:
|
||||
// List of free blocks
|
||||
FreeBlocksQueue mFreeBlocks;
|
||||
// List of free blocks. Blocks are either backed by fast primary memory or slow secondary memory;
// we maintain separate queues for these.
|
||||
FreeBlocksQueue mFreePrimaryBlocks;
|
||||
FreeBlocksQueue mFreeSecondaryBlocks;
|
||||
// List of allocated blocks for each sequence
|
||||
std::vector<std::vector<BlockPtr>> mAllocatedBlocksPerSeq;
|
||||
// Memory pools. Primary is fast memory, secondary is slower memory used for offloading.
|
||||
runtime::ITensor::SharedPtr mPrimaryPool;
|
||||
runtime::ITensor::SharedPtr mSecondaryPool;
|
||||
// Whether offloaded blocks should be onboarded before reuse.
|
||||
bool mOnboardBlocks;
|
||||
// Buffer manager
|
||||
runtime::BufferManager mBufferManager;
|
||||
// Number of layers
|
||||
SizeType mNumLayers;
|
||||
// Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
SizeType mBlockSize;
|
||||
// Used to keep track of number of free blocks during scheduling
|
||||
SizeType mSchedulingNumFreeBlocks;
|
||||
// Number of tokens per one block
|
||||
@ -385,9 +451,9 @@ public:
|
||||
using CudaStreamPtr = std::shared_ptr<runtime::CudaStream>;
|
||||
|
||||
KVCacheManager(SizeType numLayers, SizeType numKvHeads, SizeType sizePerHead, SizeType tokensPerBlock,
|
||||
SizeType maxNumBlocks, SizeType maxNumSequences, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype, CudaStreamPtr stream,
|
||||
bool enableBlockReuse = false, bool useUvm = false);
|
||||
SizeType blocksInPrimaryPool, SizeType blocksInSecondaryPool, SizeType maxNumSequences, SizeType maxBeamWidth,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, bool useOneMoreBlock, nvinfer1::DataType dtype,
|
||||
CudaStreamPtr stream, bool enableBlockReuse = false, bool useUvm = false, bool onboardBlocks = true);
|
||||
|
||||
void startScheduling();
|
||||
|
||||
@ -422,10 +488,10 @@ public:
|
||||
return kvCacheStats;
|
||||
}
|
||||
|
||||
// Volume of [2, numKvHeads, tokensPerBlock, sizePerHead]
|
||||
// Volume of [numKvHeads, tokensPerBlock, sizePerHead]
|
||||
[[nodiscard]] SizeType getBlockSize() const
|
||||
{
|
||||
return mBlockSize;
|
||||
return mBlockManager.getBlockSize();
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType getMaxBlocksPerSeq() const
|
||||
@ -450,11 +516,6 @@ public:
|
||||
/// @return The number of blocks
|
||||
[[nodiscard]] SizeType getNeededBlocksToCompletion(LlmRequest const& req) const;
|
||||
|
||||
[[nodiscard]] std::vector<runtime::ITensor::SharedPtr> const& getMemoryPools() const
|
||||
{
|
||||
return mPools;
|
||||
}
|
||||
|
||||
void addContextTokens(SizeType seqSlotIdx, SizeType numTokens);
|
||||
|
||||
void addToken(SizeType seqSlotIdx);
|
||||
@ -487,9 +548,9 @@ public:
|
||||
* modelConfig.getSizePerHead();
|
||||
}
|
||||
|
||||
[[nodiscard]] static SizeType calculateMaxNumBlocks(KvCacheConfig const& config, nvinfer1::DataType dtype,
|
||||
tensorrt_llm::runtime::GptModelConfig const& modelConfig, tensorrt_llm::runtime::WorldConfig const& worldConfig,
|
||||
runtime::BufferManager const& bufferManager);
|
||||
[[nodiscard]] static std::tuple<SizeType, SizeType> const calculateMaxNumBlocks(KvCacheConfig const& config,
|
||||
nvinfer1::DataType dtype, tensorrt_llm::runtime::GptModelConfig const& modelConfig,
|
||||
tensorrt_llm::runtime::WorldConfig const& worldConfig, runtime::BufferManager const& bufferManager);
|
||||
|
||||
[[nodiscard]] SizeType getNumPrepopulatedTokens(SizeType batchSlotIdx, SizeType beamIdx) const
|
||||
{
|
||||
@ -506,6 +567,9 @@ public:
|
||||
void rewindKVCache(SizeType seqSlotIdx, SizeType rewindLengths);
|
||||
|
||||
private:
|
||||
void setPointers(void** pointersPtr, nvinfer1::Dims const& pointersShape, SizeType layerNum, SizeType seqSlotIdx,
|
||||
SizeType beamIdx, SizeType blockIdx, SizeType blockId);
|
||||
|
||||
void resetBlockPointers(SizeType seqSlotIdx, SizeType beamWidth);
|
||||
void cacheBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
void cacheNewBlockPointers(GenerationRequest const& seq, SizeType seqSlotIdx);
|
||||
@ -513,8 +577,8 @@ private:
|
||||
void updateToken(SizeType seqSlotIdx, bool addToken);
|
||||
|
||||
private:
|
||||
// Number of elements per one block
|
||||
SizeType mBlockSize;
|
||||
// Number of layers
|
||||
SizeType mNumLayers;
|
||||
// Maximum number of sequences
|
||||
SizeType mMaxNumSequences;
|
||||
// Maximum beam width
|
||||
@ -530,16 +594,12 @@ private:
|
||||
SizeType mMaxTokenNum;
|
||||
// Number of tokens in the sink blocks
|
||||
SizeType mSinkBlockTokenLength;
|
||||
// Pools
|
||||
std::vector<runtime::ITensor::SharedPtr> mPools;
|
||||
// Block manager
|
||||
BlockManager mBlockManager;
|
||||
// List of all sequences
|
||||
std::vector<SequencesPtr> mSequences;
|
||||
// buffer for block pointers for all managed sequences
|
||||
runtime::ITensor::SharedPtr mSequenceBlockPointers;
|
||||
// Buffer manager
|
||||
runtime::BufferManager mBufferManager;
|
||||
// Whether to cache KV pages for reuse
|
||||
bool mEnableBlockReuse;
|
||||
};
|
||||
|
||||
@ -334,13 +334,16 @@ public:
|
||||
explicit KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> const& maxTokens = std::nullopt,
|
||||
std::optional<SizeType> const& maxAttentionWindow = std::nullopt,
|
||||
std::optional<SizeType> const& sinkTokenLength = std::nullopt,
|
||||
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt);
|
||||
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
|
||||
std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true);
|
||||
|
||||
[[nodiscard]] bool getEnableBlockReuse() const;
|
||||
[[nodiscard]] std::optional<SizeType> getMaxTokens() const;
|
||||
[[nodiscard]] std::optional<SizeType> getMaxAttentionWindow() const;
|
||||
[[nodiscard]] std::optional<SizeType> getSinkTokenLength() const;
|
||||
[[nodiscard]] std::optional<FloatType> getFreeGpuMemoryFraction() const;
|
||||
[[nodiscard]] std::optional<size_t> getHostCacheSize() const;
|
||||
[[nodiscard]] bool getOnboardBlocks() const;
|
||||
|
||||
private:
|
||||
/// @brief Controls if KV cache blocks can be reused for different requests
|
||||
@ -362,6 +365,13 @@ private:
|
||||
/// If both mMaxTokens and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will be
|
||||
/// allocated.
|
||||
std::optional<FloatType> mFreeGpuMemoryFraction;
|
||||
|
||||
/// @brief Size of secondary memory pool in bytes. Default is 0.
|
||||
/// Having a secondary memory pool increases KV cache block reuse potential.
|
||||
std::optional<size_t> mHostCacheSize;
|
||||
|
||||
/// @brief Controls whether offloaded blocks should be onboarded back into primary memory before being reused.
|
||||
bool mOnboardBlocks;
|
||||
};
|
||||
|
||||
SizeType const kDefaultIterStatsMaxIterations = 1000;
|
||||
|
||||
@ -58,12 +58,18 @@ public:
|
||||
|
||||
static auto constexpr kBYTE_TYPE = nvinfer1::DataType::kUINT8;
|
||||
|
||||
//! \brief Allocates an `IBuffer` of the given size on the GPU.
|
||||
//! \brief Allocates an `IBuffer` of the given size on the GPU, using cudaMallocAsync.
|
||||
[[nodiscard]] IBufferPtr gpu(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE) const;
|
||||
|
||||
//! \brief Allocates an `ITensor` of the given dimensions on the GPU.
|
||||
//! \brief Allocates an `ITensor` of the given dimensions on the GPU, using cudaMallocAsync.
|
||||
[[nodiscard]] ITensorPtr gpu(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE) const;
|
||||
|
||||
//! \brief Allocates an `IBuffer` of the given size on the GPU, using cudaMalloc.
|
||||
[[nodiscard]] static IBufferPtr gpuSync(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE);
|
||||
|
||||
//! \brief Allocates an `ITensor` of the given dimensions on the GPU, using cudaMalloc.
|
||||
[[nodiscard]] static ITensorPtr gpuSync(nvinfer1::Dims dims, nvinfer1::DataType type = kBYTE_TYPE);
|
||||
|
||||
//! \brief Allocates an `IBuffer` of the given size on the CPU.
|
||||
[[nodiscard]] static IBufferPtr cpu(std::size_t size, nvinfer1::DataType type = kBYTE_TYPE);
|
||||
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:2309c96080b61795e03130338d9a023c67c4ecba7b7ba9d32797d2ce8fe170aa
|
||||
size 2869834
|
||||
oid sha256:50298b3d11057edfeb48d0b45098e7fcce4f1c3f7c28389f641e5d45853c9b12
|
||||
size 2895786
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:eb623f435dd920799783d353b5062fd4c3dcaa3529d2007d1550cc21ecae39ee
|
||||
size 2898404
|
||||
oid sha256:29a19dc0fb9826ed6aea1f06cf07f5e253b308ac78ddfa06ab98e43dbb032dc8
|
||||
size 2920988
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
577dc50763c3152388738bcb22ed9f73 libtensorrt_llm_batch_manager_static.a
|
||||
2ae3fab3836c6fd961d82458663888bb libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
9fd0ada1f commit
|
||||
0f15ea17f78179fbb5b3439456361e41 libtensorrt_llm_batch_manager_static.a
|
||||
7ce3a17e5d12c105c7e3a8d14141ff44 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
1b4b3c74a315271a2c3125f7bdae1e2442a5c85b commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:26da47fbe5e5a2246db58152011c78b87a56deafb6df38821638b1c27c00af22
|
||||
size 2796970
|
||||
oid sha256:d46279b3d3adbc1eb50f342a1002e3f10ff51ff30ee506fd1027bd70ad902384
|
||||
size 2820644
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:04243eaf43b74886b6b8ab8ec13482c2c1c80d3468e8f6f40a8e16a8d4cace6e
|
||||
size 2769552
|
||||
oid sha256:b777489f45fba1edc756d946ece9ab3edcf60be41039989f085394f0928927f9
|
||||
size 2793016
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c95bca35d86ddcfd7df9ac81dcd05381b1487c91027e7103ae836761f3c771a1
|
||||
size 19002500
|
||||
@ -0,0 +1,2 @@
|
||||
09960aa8bf288fc4896bebebc83b020e tensorrt_llm_batch_manager_static.lib
|
||||
1b4b3c74a315271a2c3125f7bdae1e2442a5c85b commit
|
||||
@ -77,4 +77,24 @@ int getEnvMmhaBlocksPerSequence()
|
||||
return mmhaBlocksPerSequence;
|
||||
}
|
||||
|
||||
int getEnvMmhaKernelBlockSize()
|
||||
{
|
||||
static bool init = false;
|
||||
static int mmhaKernelBlockSize = 0;
|
||||
if (!init)
|
||||
{
|
||||
init = true;
|
||||
char const* mmhaKernelBlockSizeEnv = std::getenv("TRTLLM_MMHA_KERNEL_BLOCK_SIZE");
|
||||
if (mmhaKernelBlockSizeEnv)
|
||||
{
|
||||
mmhaKernelBlockSize = std::atoi(mmhaKernelBlockSizeEnv);
|
||||
if (mmhaKernelBlockSize <= 0)
|
||||
{
|
||||
TLLM_LOG_WARNING("Invalid value for TRTLLM_MMHA_KERNEL_BLOCK_SIZE. Will use default values instead!");
|
||||
}
|
||||
}
|
||||
}
|
||||
return mmhaKernelBlockSize;
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -28,4 +28,6 @@ bool getEnvMmhaMultiblockDebug();
|
||||
|
||||
int getEnvMmhaBlocksPerSequence();
|
||||
|
||||
int getEnvMmhaKernelBlockSize();
|
||||
|
||||
} // namespace tensorrt_llm::common
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:36f02388a9bd2ae3d45f0d6480bd95cd99f8ea30eebf1a315b8d54e742fed479
|
||||
size 846308
|
||||
oid sha256:183e7b5b112dc6eb5618fb9ac702f5b8fc9ad4a34067c633de3720f501f44a43
|
||||
size 847076
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:c707f67abccca217d81e8d85e361b6d214131045763df5f806cb789157ea4f80
|
||||
size 857730
|
||||
oid sha256:455a376305510cc22e1aeed6ac1c4dcccbfe0731d06de8d0fe631335d29ca3c3
|
||||
size 858658
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
a552c727c128f6ca402ddc119d295ab0 libtensorrt_llm_executor_static.a
|
||||
38ad482b0be0996970bc572622967acf libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
9fd0ada1f commit
|
||||
0b3c4c60fe043285139bad82e4f815d9 libtensorrt_llm_executor_static.a
|
||||
e71f7c14efd873fb7f12aab09973db09 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
1b4b3c74a315271a2c3125f7bdae1e2442a5c85b commit
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5aa174bdc52db4f8f193ccdbc4b764835564204a5576c91fffd36720e9d79fdd
|
||||
size 884870
|
||||
oid sha256:d5e036441b0f1545d591be8dcf743e68a56beaa77594c3784212bac95a858b9b
|
||||
size 885862
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e1117c2c02debd8f373d541c6d96cac5f66a7a9d374a10bbb7c89d6455b869ba
|
||||
size 837988
|
||||
oid sha256:00a15798b0f4d879915b885a5203a1c764ba0c3bbbe34c61730c7b944a2da064
|
||||
size 838884
|
||||
|
||||
@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5d9f137b9202460b8538d035067bb54f52c65737e798f283fa82d9e017fa8cc2
|
||||
size 9770148
|
||||
@ -0,0 +1,2 @@
|
||||
d29485c1ddb2d2d4cb1c1f4476642124 tensorrt_llm_executor_static.lib
|
||||
1b4b3c74a315271a2c3125f7bdae1e2442a5c85b commit
|
||||
69
cpp/tensorrt_llm/kernels/beamSearchKernels.cu
Normal file
69
cpp/tensorrt_llm/kernels/beamSearchKernels.cu
Normal file
@ -0,0 +1,69 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/kernels/beamSearchKernels.h"
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
void topK_softMax_kernelLauncher(
|
||||
T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
#define CASE_K(MAX_K) \
|
||||
topK_softMax_kernelLauncher<T, MAX_K>(logits, bias, workspace, bh, stream); \
|
||||
break;
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream)
|
||||
{
|
||||
switch (padToNextPowerOfTwo(bh.beam_width))
|
||||
{
|
||||
case 1:
|
||||
case 2:
|
||||
case 4: // 0 < beam_width <= 4
|
||||
CASE_K(4)
|
||||
case 8: // 4 < beam_width <= 8
|
||||
CASE_K(8)
|
||||
#ifndef FAST_BUILD // For fast build, skip case 3, 4, 5
|
||||
case 16: // 9 < beam_width <= 16
|
||||
CASE_K(16)
|
||||
case 32: // 16 < beam_width <= 32
|
||||
CASE_K(32)
|
||||
case 64: // 32 < beam_width <= 64
|
||||
CASE_K(64)
|
||||
#endif // FAST_BUILD
|
||||
default:
|
||||
throw std::runtime_error(fmtstr(
|
||||
"%s:%d Topk kernel of beam search does not support beam_width=%d", __FILE__, __LINE__, bh.beam_width));
|
||||
}
|
||||
}
|
||||
|
||||
#undef CASE_K
|
||||
|
||||
template void invokeTopkSoftMax<float>(
|
||||
float const* logits, float const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
template void invokeTopkSoftMax<half>(
|
||||
half const* logits, half const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
112
cpp/tensorrt_llm/kernels/beamSearchKernels.h
Normal file
112
cpp/tensorrt_llm/kernels/beamSearchKernels.h
Normal file
@ -0,0 +1,112 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
static constexpr int nMaxBeamWidth = 64; // max beam width supported now
|
||||
static constexpr int nSmallTopKBlockSize = 256;
|
||||
static constexpr int nSmallTopKMaxVocParts = 128;
|
||||
|
||||
struct BeamHypotheses
|
||||
{
|
||||
// clang-format off
|
||||
|
||||
// BS: batch_size, BM: beam_width, mSL: max_seq_length
|
||||
// %%: parameter name when dynamic_decoder.forward() / gather_tree() are called in [generation.py] (python workflow)
|
||||
|
||||
// Candidate beams: When a beam generates end_id or its sequence length reaches mSL, it becomes a candidate beam to be selected finally.
|
||||
// Candidate-Beam-Array (CBA): Arrays (size: BM*2) to place the candidate beams and related information
|
||||
|
||||
// Scalar values
|
||||
bool is_return_normed_score{true}; // return normed_score / cum_log_probs, not used yet
|
||||
int batch_size{0}; //
|
||||
int beam_width{0}; //
|
||||
int ite{0}; // index of local_batch, always be 0 when pp_size==1
|
||||
int local_batch_size{0}; //
|
||||
int max_seq_len{0}; //
|
||||
int vocab_size{0}; // vocab_size_padded
|
||||
|
||||
// Pointers from SamplingConfig
|
||||
float const* diversity_rates{nullptr}; // [BS]
|
||||
float const* length_penalties{nullptr}; // [BS]
|
||||
int const* early_stoppings{nullptr}; // [BS]
|
||||
|
||||
// Pointers from input
|
||||
int const* input_lengths{nullptr}; // [BS, BM] %% context_length
|
||||
int const* end_ids{nullptr}; // [BS, BM] %% self.end_ids
|
||||
|
||||
// Pointers for output
|
||||
int* final_output_ids{nullptr}; // [BS, BM, mSL] %% self.output_ids
|
||||
float* log_probs{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
|
||||
int* seq_len{nullptr}; // [BS, BM] %% self.sequence_length_buffer
|
||||
float* cum_log_probs{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||
|
||||
// Pointers of CBA
|
||||
int* output_ids_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_output_ids_tgt
|
||||
float* log_probs_cba{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs
|
||||
int* seq_len_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt
|
||||
float* cum_log_probs_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs
|
||||
float* normed_scores_cba{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores
|
||||
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams number of beams in CBA
|
||||
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores worst score in CBA
|
||||
|
||||
// Pointers related to beam search process, they are initialized in those two functions:
|
||||
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
|
||||
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done whether a whole batch is finished
|
||||
FinishedState* finished; // [BS*BM] %% self.finished whether and how a beam is finished
|
||||
|
||||
// Pointers for backtrack of the beams, they are relocated in [dynamicDecodeLayer.cpp] DynamicDecodeLayer<T>::prepareIdsPtrs
|
||||
int** output_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.output_ids
|
||||
int** parent_ids_ptr{nullptr}; // [BS][BM, mSL] %% self.parent_ids
|
||||
|
||||
// Pointers for gather_tree(), read the unfinished beams from them and write to CBA for the final selection
|
||||
int const* output_ids_src{nullptr}; // [BS, BM, mSL] %% self.output_ids
|
||||
int const* parent_ids_src{nullptr}; // [BS, BM, mSL] %% self.parent_ids
|
||||
|
||||
// clang-format on
|
||||
};
|
||||
|
||||
__inline__ int padToNextPowerOfTwo(int const n)
|
||||
{
|
||||
// Pad n up to the nearest power of 2
|
||||
int recursor = n - 1;
|
||||
int res = 2;
|
||||
while (recursor >>= 1)
|
||||
res <<= 1;
|
||||
return res;
|
||||
}
|
||||
|
||||
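For intuition, the padding behaves like this (a Python paraphrase of the loop
above, not part of the CUDA sources):

    def pad_to_next_power_of_two(n: int) -> int:
        recursor = n - 1
        res = 2
        while recursor := recursor >> 1:
            res <<= 1
        return res

    # pad_to_next_power_of_two(3) -> 4, (8) -> 8, (9) -> 16; note n <= 2 maps to 2.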
template <typename T>
|
||||
__device__ __forceinline__ T applyLengthPenalty(T const log_prob, int const length, float const length_penalty)
|
||||
{
|
||||
// score = log(prob) / (length ^ length_penalty)
|
||||
if (length_penalty == 0.0f || length == 1)
|
||||
{
|
||||
return log_prob;
|
||||
}
|
||||
return log_prob / static_cast<T>(powf(static_cast<float>(length), length_penalty));
|
||||
}
|
||||
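For instance, with length_penalty = 1.0, a cumulative log-probability of -6.0 over a
beam of length 3 scores -6.0 / 3^1.0 = -2.0, while length_penalty = 0.0 (or length 1)
leaves the raw log-probability unchanged.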
|
||||
template <typename T>
|
||||
void invokeTopkSoftMax(T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "onlineSoftmaxBeamsearchKernelsTemplate.h"
|
||||
#include "beamSearchKernelsTemplate.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "onlineSoftmaxBeamsearchKernelsTemplate.h"
|
||||
#include "beamSearchKernelsTemplate.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "onlineSoftmaxBeamsearchKernelsTemplate.h"
|
||||
#include "beamSearchKernelsTemplate.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "onlineSoftmaxBeamsearchKernelsTemplate.h"
|
||||
#include "beamSearchKernelsTemplate.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -14,7 +14,7 @@
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "onlineSoftmaxBeamsearchKernelsTemplate.h"
|
||||
#include "beamSearchKernelsTemplate.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
@ -0,0 +1,909 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CUDART_VERSION
|
||||
#error CUDART_VERSION Undefined!
|
||||
#elif (CUDART_VERSION >= 11050)
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include "3rdparty/cub/cub.cuh"
|
||||
#endif
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/kernels/beamSearchKernels.h"
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
#define DO_SPLIT_SMALL_TOP_K_SOFTMAX
|
||||
|
||||
#define TOPK_FP16_STORAGE 0
|
||||
|
||||
#pragma nv_diag_suppress static_var_with_dynamic_init
|
||||
|
||||
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void batchBeamKernel(int const* __restrict topk_id_buffer, T const* __restrict topk_val_buffer, BeamHypotheses bh)
|
||||
{
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
int const gbid{bh.ite * bh.local_batch_size + bid}; // global batch index
|
||||
int const K{bh.beam_width};
|
||||
int const V{bh.vocab_size};
|
||||
int const nCandidate{K * K * 2};
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
float const diversity_rate{bh.diversity_rates[gbid]};
|
||||
float const length_penalty{bh.length_penalties[gbid]};
|
||||
int const early_stopping{bh.early_stoppings[gbid]};
|
||||
int const* input_lengths{bh.input_lengths};
|
||||
|
||||
__shared__ int nBeamForNextStep;
|
||||
__shared__ float smem_cum_log_probs[MAX_K2 / 2];
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
nBeamForNextStep = 0;
|
||||
}
|
||||
if (tid < K)
|
||||
{
|
||||
smem_cum_log_probs[tid] = bh.cum_log_probs[bid * K + tid];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (bh.num_beams != nullptr)
|
||||
{
|
||||
// Beam search is enabled
|
||||
if (bh.num_beams[gbid] == 0 && tid == 0)
|
||||
{
|
||||
// Initialize worst_score in the first time
|
||||
bh.min_normed_scores[gbid] = FLT_MAX;
|
||||
}
|
||||
else if (early_stopping == 1 && bh.num_beams[gbid] == K
|
||||
|| early_stopping != 1 && bh.finished[bid * K].isFinished())
|
||||
{
|
||||
// New but false condition:
|
||||
// else if (early_stopping == 1 && bh.num_beams[gbid] == K || early_stopping != 1 && bh.is_done[bid])
|
||||
// Condition of early return:
|
||||
// 1. In EarlyStopping mode, and we have got enough beams
|
||||
// 2. In NonEarlyStopping mode, and this batch has been marked as done
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// Get top 2K tokens from candidates
|
||||
topk_id_buffer += bid * nCandidate;
|
||||
topk_val_buffer += bid * nCandidate;
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
cub_kvp partial_topk{nCandidate - 1, -MAX_T_VAL};
|
||||
cub::ArgMax arg_max;
|
||||
extern __shared__ char smem[];
|
||||
T* smem_topk = reinterpret_cast<T*>(smem);
|
||||
|
||||
for (int id = tid; id < nCandidate; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
int const index = bh.num_beams == nullptr ? id % K : id / 2 / K;
|
||||
T val = topk_val_buffer[id] + static_cast<T>(diversity_rate * index);
|
||||
cub_kvp new_elem{id, val};
|
||||
partial_topk = arg_max(partial_topk, new_elem);
|
||||
smem_topk[id] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
using BlockReduce = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
__shared__ typename BlockReduce::TempStorage reduce_buffer;
|
||||
__shared__ cub_kvp cta_topk[MAX_K2];
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
cub_kvp total_topk = BlockReduce(reduce_buffer).Reduce(partial_topk, arg_max);
|
||||
if (tid == 0)
|
||||
{
|
||||
cta_topk[i] = total_topk;
|
||||
smem_topk[total_topk.key] = -MAX_T_VAL;
|
||||
thread_requiring_update = total_topk.key % THREADBLOCK_SIZE;
|
||||
}
|
||||
__syncthreads();
|
||||
// Only one thread needs to update the old partial before the next block reduce.
|
||||
// No need to do this in the last iteration.
|
||||
if (tid == thread_requiring_update && i < (2 * K - 1))
|
||||
{
|
||||
partial_topk.key = nCandidate - 1;
|
||||
partial_topk.value = -MAX_T_VAL;
|
||||
for (int index = tid; index < nCandidate; index += THREADBLOCK_SIZE)
|
||||
{
|
||||
cub_kvp new_elem{index, smem_topk[index]};
|
||||
partial_topk = arg_max(partial_topk, new_elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
// Adjust beams or select completed beams sequentially
|
||||
// Reference (may change along with HF in the future):
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L272
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
int const current_key = cta_topk[i].key;
|
||||
T const current_value = cta_topk[i].value;
|
||||
bool const is_end_token = topk_id_buffer[current_key] % V == bh.end_ids[bid];
|
||||
if (i < K && bh.num_beams != nullptr && is_end_token)
|
||||
{
|
||||
// Condition of this branch:
// In beam search mode, this token is end_token and is within the top-K range
|
||||
int const seq_len = bh.seq_len[bid * K + i] + 1 - bh.input_lengths[gbid * K + i];
|
||||
float const normed_score = applyLengthPenalty(current_value, seq_len, length_penalty);
|
||||
int beam_idx = bh.num_beams[gbid];
|
||||
if (beam_idx == K)
|
||||
{
|
||||
// There are already K beams
|
||||
if (normed_score < bh.min_normed_scores[gbid])
|
||||
{
|
||||
// Current score is worse than the worst one in candidate beams
|
||||
if (early_stopping)
|
||||
{
|
||||
// Stop since we have got enough beams
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Continue since there might be longer but better beams
|
||||
continue;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// Current score is better than the worst one in candidate beams
|
||||
// Find the candidate beam index with the worst score and erase it
|
||||
for (int j = 0; j < K; j++)
|
||||
{
|
||||
if (bh.normed_scores_cba[gbid * (K * 2) + j] == bh.min_normed_scores[gbid])
|
||||
{
|
||||
beam_idx = j;
|
||||
bh.num_beams[gbid]--;
|
||||
bh.min_normed_scores[gbid] = FLT_MAX;
|
||||
bh.normed_scores_cba[gbid * (K * 2) + j] = normed_score;
|
||||
for (int l = 0; l < K; l++)
|
||||
{
|
||||
bh.min_normed_scores[gbid]
|
||||
= min(bh.min_normed_scores[gbid], bh.normed_scores_cba[gbid * (K * 2) + l]);
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
int prev_id = (topk_id_buffer[current_key] / V) % K;
|
||||
int const current_step = bh.seq_len[bid * K + prev_id];
|
||||
int const tgt_id_offset = ((bid + bh.ite * bh.local_batch_size) * (K * 2) + beam_idx) * bh.max_seq_len;
|
||||
bh.output_ids_cba[tgt_id_offset + current_step] = bh.end_ids[bid];
|
||||
if (bh.log_probs_cba != nullptr)
|
||||
{
|
||||
bh.log_probs_cba[tgt_id_offset + current_step] = (float) topk_val_buffer[current_key]
|
||||
- smem_cum_log_probs[(topk_id_buffer[current_key] / V) % K];
|
||||
}
|
||||
// Write finished beam from work tree to CBA
|
||||
for (int j = current_step - 1; j >= 0; j--)
|
||||
{
|
||||
bh.output_ids_cba[tgt_id_offset + j] = bh.output_ids_ptr[bid][prev_id * bh.max_seq_len + j];
|
||||
prev_id = bh.parent_ids_ptr[bid][prev_id * bh.max_seq_len + j];
|
||||
}
|
||||
if (bh.log_probs_cba != nullptr && bh.log_probs != nullptr)
|
||||
{
|
||||
prev_id = (topk_id_buffer[current_key] / V) % K;
|
||||
for (int j = current_step - 1; j >= 0; j--)
|
||||
{
|
||||
int const index = j * bh.batch_size * K + bh.ite * bh.local_batch_size * K + bid * K + prev_id;
|
||||
bh.log_probs_cba[tgt_id_offset + j] = bh.log_probs[index];
|
||||
prev_id = bh.parent_ids_ptr[bid][prev_id * bh.max_seq_len + j];
|
||||
}
|
||||
}
|
||||
int const tgt_beam_idx = gbid * (K * 2) + beam_idx;
|
||||
bh.seq_len_cba[tgt_beam_idx] = current_step;
|
||||
bh.normed_scores_cba[tgt_beam_idx] = normed_score;
|
||||
bh.min_normed_scores[gbid] = min(bh.min_normed_scores[gbid], bh.normed_scores_cba[tgt_beam_idx]);
|
||||
bh.num_beams[gbid]++;
|
||||
bh.cum_log_probs_cba[tgt_beam_idx] = (float) topk_val_buffer[current_key];
|
||||
}
|
||||
else if (i < K || bh.num_beams != nullptr && !is_end_token)
|
||||
{
|
||||
// Conditions of this branch:
// 1. bh.num_beams == nullptr && i < K, i.e., beam search is disabled
// 2. bh.num_beams != nullptr && i < K && is_end_token == false, i.e., append the token to the beam
// 3. bh.num_beams != nullptr && i >= K && is_end_token == false, i.e., append the token to the beam
|
||||
int const current_step = bh.seq_len[bid * K + nBeamForNextStep];
|
||||
// Write the selected token to work tree
|
||||
bh.output_ids_ptr[bid][nBeamForNextStep * bh.max_seq_len + current_step] = topk_id_buffer[current_key];
|
||||
if (bh.log_probs != nullptr)
|
||||
{
|
||||
bh.log_probs[current_step * bh.batch_size * K + bid * K + nBeamForNextStep]
|
||||
= (float) topk_val_buffer[current_key]
|
||||
- smem_cum_log_probs[(topk_id_buffer[current_key] / V) % K];
|
||||
}
|
||||
bh.cum_log_probs[bid * K + nBeamForNextStep] = (float) topk_val_buffer[current_key];
|
||||
nBeamForNextStep++;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Conditions of this branch, in which we do nothing:
// 1. bh.num_beams == nullptr && i >= K, i.e., beam search is disabled
// 2. bh.num_beams != nullptr && i >= K && is_end_token == true, i.e., ignore the worse beams
|
||||
}
|
||||
|
||||
// if (early_stopping == 1 && bh.num_beams[gbid] >= K || nBeamForNextStep >= K)
|
||||
if (nBeamForNextStep >= K)
|
||||
{
|
||||
// Conditions of this branch:
// 1. In EarlyStopping mode, we have enough candidate beams
// 2. In EarlyStopping mode, we have enough tokens for the next generation step
// 3. In NonEarlyStopping mode, we have enough tokens for the next generation step
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update bh.is_done
|
||||
if (tid == 0 && bh.num_beams != nullptr)
|
||||
{
|
||||
if (bh.num_beams[bid] < K)
|
||||
{
|
||||
// not enough beams yet
|
||||
bh.is_done[bid] = false;
|
||||
}
|
||||
else if (early_stopping == 1)
|
||||
{
|
||||
// enough candidate beams in EarlyStopping mode
|
||||
bh.is_done[bid] = true;
|
||||
}
|
||||
else
|
||||
{
|
||||
// enough beams in NonEarlyStopping mode
|
||||
int seq_len = bh.seq_len[bid * K] + 1 - input_lengths[gbid * K];
|
||||
float const best_sum_logprobs = cta_topk[0].value;
|
||||
// According to the semantics of HF, cta_topk[0].value is used as best_sum_logprobs
|
||||
// But maybe bh.cum_log_probs[bid * K + i] is more suitable?
|
||||
// https://github.com/huggingface/transformers/blob/main/src/transformers/generation/beam_search.py#L307
|
||||
if (early_stopping != 0 && length_penalty > 0.0f)
|
||||
{
|
||||
// Specialization for early_stopping == "never" and length_penalty > 0 in HF
|
||||
seq_len = bh.max_seq_len - input_lengths[gbid * K];
|
||||
}
|
||||
float const highest_attainable_score = applyLengthPenalty(best_sum_logprobs, seq_len, length_penalty);
|
||||
bh.is_done[bid] = bh.min_normed_scores[gbid] >= highest_attainable_score;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
// Update sequence_lengths, parent_ids, output_ids and finished
|
||||
__shared__ int s_sequence_lengths[MAX_K2 / 2];
|
||||
if (tid < K)
|
||||
{
|
||||
s_sequence_lengths[tid] = bh.seq_len[bid * K + tid];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid < K)
|
||||
{
|
||||
int const bb_index = bid * K + tid;
|
||||
int const current_step = s_sequence_lengths[tid];
|
||||
if (!bh.finished[bb_index].isFinished())
|
||||
{
|
||||
s_sequence_lengths[tid]++;
|
||||
}
|
||||
int const new_id = bh.output_ids_ptr[bid][tid * bh.max_seq_len + current_step];
|
||||
int const new_beam_id = (new_id / V) % K;
|
||||
int const new_word_id = new_id % V;
|
||||
bh.seq_len[bb_index] = s_sequence_lengths[new_beam_id];
|
||||
if (new_word_id == bh.end_ids[bid])
|
||||
{
|
||||
bh.finished[bb_index].setFinishedEOS();
|
||||
}
|
||||
bh.parent_ids_ptr[bid][tid * bh.max_seq_len + current_step] = new_beam_id;
|
||||
bh.output_ids_ptr[bid][tid * bh.max_seq_len + current_step] = new_word_id;
|
||||
if ((early_stopping == 1) && (bh.num_beams != nullptr && bh.num_beams[gbid] == K)
|
||||
|| (early_stopping != 1) && bh.is_done[bid])
|
||||
{
|
||||
bh.is_done[bid] = true;
|
||||
bh.finished[bb_index].setFinished();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct __align__(8) MD
|
||||
{
|
||||
float m;
|
||||
float d;
|
||||
};
|
||||
|
||||
__device__ __forceinline__ MD reduce_md_op(MD a, MD b)
|
||||
{
|
||||
bool const is_a_bigger = a.m > b.m;
|
||||
MD const bigger = is_a_bigger ? a : b;
|
||||
MD const smaller = is_a_bigger ? b : a;
|
||||
MD res{bigger.m, bigger.d + smaller.d * __expf(smaller.m - bigger.m)};
|
||||
return res;
|
||||
}
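// Illustrative example (not part of the original source): merging two partial results
// a = {m = 2, d = 3} and b = {m = 5, d = 1} keeps the larger running max and rescales
// the smaller denominator: res = {5, 1 + 3 * exp(2 - 5)} ≈ {5, 1.149}, so the running
// sum of exponentials stays correct without ever materializing the full softmax.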
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
struct TopKMD
|
||||
{
|
||||
MD md;
|
||||
TopK<T, MAX_K> topk;
|
||||
};
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
__device__ __forceinline__ TopKMD<T, MAX_K> reduce_topk_md_op(TopKMD<T, MAX_K> const& a, TopKMD<T, MAX_K> const& b)
|
||||
{
|
||||
TopKMD<T, MAX_K> res;
|
||||
res.md = reduce_md_op(a.md, b.md);
|
||||
res.topk = reduce_topk_op(a.topk, b.topk);
|
||||
return res;
|
||||
}
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__ void beamKernel(T const* __restrict logits, T const* __restrict bias,
|
||||
float const* __restrict cum_log_probs, FinishedState const* __restrict finished, int* __restrict topk_id_buffer,
|
||||
T* __restrict topk_val_buffer, int V, int K, int const* __restrict end_ids)
|
||||
{
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
TopKMD<float, MAX_K> partial;
|
||||
partial.md.m = -MAX_T_VAL;
|
||||
partial.md.d = 0.0F;
|
||||
partial.topk.init();
|
||||
|
||||
if (finished[bid].isFinished())
|
||||
{
|
||||
for (int id = tid; id < V; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float const val = id == end_ids[bid / K] ? MAX_T_VAL : -MAX_T_VAL;
|
||||
MD new_elem{val, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem);
|
||||
partial.topk.insert(val, id);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
T const* local_logits = logits + bid * V;
|
||||
for (int id = tid; id < V; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float const val = local_logits[id] + bias[id];
|
||||
MD new_elem{val, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem);
|
||||
partial.topk.insert(val, id);
|
||||
}
|
||||
}
|
||||
|
||||
typedef cub::BlockReduce<TopKMD<float, MAX_K>, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage reduce_buffer;
|
||||
|
||||
TopKMD<float, MAX_K> total = BlockReduce(reduce_buffer).Reduce(partial, reduce_topk_md_op<float, MAX_K>);
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
int* local_topk_id = topk_id_buffer + bid * K;
|
||||
T const* local_topk_val = topk_val_buffer + bid * K;
|
||||
float const total_m = total.md.m;
|
||||
float const total_d = logf(total.md.d);
|
||||
float local_cum_log_probs = cum_log_probs[bid];
|
||||
for (int i = 0; i < K; ++i)
|
||||
{
|
||||
local_topk_id[i] = total.topk.p[i] + bid * V;
|
||||
local_topk_val[i] = total.topk.u[i] - total_m - total_d + local_cum_log_probs;
|
||||
}
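// The loop above computes the log-softmax of each top-K candidate, u[i] - m - log(sum_j exp(x_j - m)),
// and adds the beam's running cum_log_probs (explanatory note, not in the original source).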
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beamStage1BaseKernel(T const* __restrict logits,
|
||||
T const* __restrict bias, FinishedState const* __restrict finished, float* __restrict temp_buffer, int V, int K,
|
||||
int const* __restrict end_ids)
|
||||
{
|
||||
// Compared to beamStage1FastKernel, no shared memory is used to store the logits,
// and each ThreadBlock is responsible for `V / voc_parts` elements
|
||||
constexpr int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
int const V_local = (V + gridDim.y - 1) / gridDim.y;
|
||||
int const section_start = V_local * blockIdx.y;
|
||||
int const section_end = std::min(section_start + V_local, V);
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// Load elements from logits to do reduce_md and arg_max along the way
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
TopKMD<__half, MAX_K2> partial;
|
||||
#else
|
||||
TopKMD<T, MAX_K2> partial;
|
||||
#endif
|
||||
partial.md.m = -MAX_T_VAL;
|
||||
partial.md.d = 0.0F;
|
||||
partial.topk.init();
|
||||
|
||||
if (finished[bid].isFinished())
|
||||
{
|
||||
#pragma unroll 1
|
||||
for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float const val = (id == end_ids[bid / K]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
MD const new_elem_md{val, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem_md);
|
||||
partial.topk.insert(val, id);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
T const* local_logits = logits + bid * V;
|
||||
#pragma unroll 1
|
||||
for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
T const b = bias == nullptr ? (T) 0.0f : bias[id];
|
||||
T const val = local_logits[id] + b;
|
||||
MD new_elem_md{val, 1.0F};
|
||||
partial.md = reduce_md_op(partial.md, new_elem_md);
|
||||
partial.topk.insert(val, id);
|
||||
}
|
||||
}
|
||||
|
||||
// Search the top 2K elements among `V` elements and write into smem_output
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
typedef cub::BlockReduce<TopKMD<__half, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage reduce_buffer;
|
||||
TopKMD<__half, MAX_K2> total = BlockReduce(reduce_buffer).Reduce(partial, reduce_topk_md_op<__half, MAX_K2>);
|
||||
#else
|
||||
typedef cub::BlockReduce<TopKMD<T, MAX_K2>, THREADBLOCK_SIZE> BlockReduce;
|
||||
__shared__ typename BlockReduce::TempStorage reduce_buffer;
|
||||
TopKMD<T, MAX_K2> total = BlockReduce(reduce_buffer).Reduce(partial, reduce_topk_md_op<T, MAX_K2>);
|
||||
#endif
|
||||
__shared__ float smem_output[PACKED_TOP_KMD_SIZE];
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
for (int i = 0; i < 2 * K; i++)
|
||||
{
|
||||
int const index = bid * V + total.topk.p[i];
|
||||
reinterpret_cast<int*>(smem_output)[i] = index;
|
||||
smem_output[MAX_K2 + i] = total.topk.u[i];
|
||||
}
|
||||
smem_output[2 * MAX_K2] = total.md.d;
|
||||
smem_output[2 * MAX_K2 + 1] = total.md.m;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write the smem_output into temp_buffer
|
||||
float* local_temp_buffer = temp_buffer + bid * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE;
|
||||
#pragma unroll
|
||||
for (int id = tid; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
local_temp_buffer[id] = smem_output[id];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int ITEMS_PER_THREAD, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE, 1) __global__ void beamStage1FastKernel(T const* __restrict logits,
|
||||
T const* __restrict bias, FinishedState const* __restrict finished, float* __restrict temp_buffer, int V, int K,
|
||||
int const* __restrict end_ids, int const V_local)
|
||||
{
|
||||
constexpr int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
int const tid = threadIdx.x;
|
||||
int const bid = blockIdx.x;
|
||||
int const section_start = V_local * blockIdx.y;
|
||||
int const section_end = std::min(section_start + V_local, V);
|
||||
int const valid_smem_length = section_end - section_start;
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
// Load elements from logits into smem_logprobs, doing reduce_md and arg_max along the way
// Each thread is responsible for `V_local / THREADBLOCK_SIZE` elements
|
||||
extern __shared__ char smem_[];
|
||||
T* smem_logprobs = reinterpret_cast<T*>(smem_);
|
||||
|
||||
MD partial_md{-MAX_T_VAL, 0.0f};
|
||||
|
||||
#if TOPK_FP16_STORAGE == 1
|
||||
using cub_kvp = cub::KeyValuePair<int, __half>;
|
||||
#else
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
#endif
|
||||
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
|
||||
cub::ArgMax arg_max;
|
||||
|
||||
if (finished[bid].isFinished())
|
||||
{
|
||||
#pragma unroll 1
|
||||
for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
float const val = (id == end_ids[bid / K]) ? MAX_T_VAL : -MAX_T_VAL;
|
||||
int const smem_index = id - section_start;
|
||||
smem_logprobs[smem_index] = val;
|
||||
MD const new_elem_md{val, 1.0F};
|
||||
partial_md = reduce_md_op(partial_md, new_elem_md);
|
||||
cub_kvp const new_elem_topk{smem_index, val};
|
||||
partial_topk = arg_max(partial_topk, new_elem_topk);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
T const* local_logits = logits + bid * V;
|
||||
#pragma unroll 1
|
||||
for (int id = section_start + tid; id < section_end; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
T const b = bias == nullptr ? (T) 0.0f : bias[id];
|
||||
T const val = local_logits[id] + b;
|
||||
int const smem_index = id - section_start;
|
||||
smem_logprobs[smem_index] = val;
|
||||
MD new_elem_md{val, 1.0F};
|
||||
partial_md = reduce_md_op(partial_md, new_elem_md);
|
||||
cub_kvp new_elem_topk{smem_index, val};
|
||||
partial_topk = arg_max(partial_topk, new_elem_topk);
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Search the top 2K elements among `V_local` elements of this ThreadBlock and write into smem_output
|
||||
__shared__ float smem_output[PACKED_TOP_KMD_SIZE];
|
||||
__shared__ int thread_requiring_update;
|
||||
|
||||
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
|
||||
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockReduceMD::TempStorage md_smem;
|
||||
typename BlockReduceTopK::TempStorage topk_smem;
|
||||
} reduce_buffer;
|
||||
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
// Pop the element with the largest value into "smem_output" each iteration
|
||||
cub_kvp total_topk = BlockReduceTopK(reduce_buffer.topk_smem).Reduce(partial_topk, arg_max);
|
||||
if (tid == 0)
|
||||
{
|
||||
int const index = bid * V + section_start + total_topk.key;
|
||||
reinterpret_cast<int*>(smem_output)[i] = index;
|
||||
smem_output[MAX_K2 + i] = total_topk.value;
|
||||
smem_logprobs[total_topk.key] = -MAX_T_VAL; // invalidate the value of the popped element
|
||||
thread_requiring_update = total_topk.key % THREADBLOCK_SIZE;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (tid == thread_requiring_update && i < 2 * K - 1)
|
||||
{
|
||||
// The thread that popped the element needs to update its partial_topk
// No need to do this in the last iteration
|
||||
partial_topk.key = V - 1;
|
||||
partial_topk.value = -MAX_T_VAL;
|
||||
for (int index = tid; index < valid_smem_length; index += THREADBLOCK_SIZE)
|
||||
{
|
||||
cub_kvp new_elem{index, smem_logprobs[index]};
|
||||
partial_topk = arg_max(partial_topk, new_elem);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Do reduce_md among the top 2K elements in the smem_output and write into tail of smem_output
|
||||
auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); };
|
||||
MD total_md = BlockReduceMD(reduce_buffer.md_smem).Reduce(partial_md, reduce_md_func);
|
||||
if (tid == 0)
|
||||
{
|
||||
smem_output[2 * MAX_K2] = total_md.d;
|
||||
smem_output[2 * MAX_K2 + 1] = total_md.m;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Write the smem_output into temp_buffer
|
||||
float* local_temp_buffer = temp_buffer + bid * PACKED_TOP_KMD_SIZE * gridDim.y + blockIdx.y * PACKED_TOP_KMD_SIZE;
|
||||
#pragma unroll
|
||||
for (int id = tid; id < PACKED_TOP_KMD_SIZE; id += THREADBLOCK_SIZE)
|
||||
{
|
||||
local_temp_buffer[id] = smem_output[id];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
|
||||
__launch_bounds__(THREADBLOCK_SIZE) __global__
|
||||
void beamStage2Kernel(float const* __restrict temp_buffer, float const* __restrict cum_log_probs,
|
||||
int* __restrict topk_id_buffer, T* __restrict topk_val_buffer, int const K, int const voc_parts, int const V)
|
||||
{
|
||||
constexpr int PACKED_TOP_KMD_SIZE = 2 * MAX_K2 + 2;
|
||||
int const bid = blockIdx.x;
|
||||
int const tid = threadIdx.x;
|
||||
T const MAX_T_VAL = (std::is_same<T, half>::value) ? HALF_FLT_MAX : FLT_MAX;
|
||||
|
||||
using cub_kvp = cub::KeyValuePair<int, T>;
|
||||
using BlockReduceTopK = cub::BlockReduce<cub_kvp, THREADBLOCK_SIZE>;
|
||||
using BlockReduceMD = cub::BlockReduce<MD, THREADBLOCK_SIZE>;
|
||||
|
||||
extern __shared__ char smem[];
|
||||
float* smem_topk = reinterpret_cast<float*>(smem);
|
||||
__shared__ cub_kvp buf_smem_kv[MAX_K2];
|
||||
|
||||
__shared__ union
|
||||
{
|
||||
typename BlockReduceTopK::TempStorage topk_smem;
|
||||
typename BlockReduceMD::TempStorage md_smem;
|
||||
|
||||
} shared_temp_storage;
|
||||
|
||||
cub::ArgMax arg_max;
|
||||
MD partial_md{-MAX_T_VAL, 0.0f};
|
||||
cub_kvp total_topk{V - 1, -MAX_T_VAL};
|
||||
|
||||
auto reduce_md_func = [](const MD& a, const MD& b) { return reduce_md_op(a, b); };
|
||||
|
||||
// Load and unpack into registers through smem
|
||||
float const* local_temp_storage = temp_buffer + PACKED_TOP_KMD_SIZE * bid * voc_parts;
|
||||
for (int idx = tid; idx < PACKED_TOP_KMD_SIZE * voc_parts; idx += THREADBLOCK_SIZE)
|
||||
{
|
||||
smem_topk[idx] = local_temp_storage[idx];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Find the argmax within each voc_parts
|
||||
// Find the topK across all voc_parts
|
||||
for (int k = 0; k < 2 * K; ++k)
|
||||
{
|
||||
cub_kvp partial_topk{V - 1, -MAX_T_VAL};
|
||||
// Only threads responsible for a chunk will do the computation
|
||||
if (tid < voc_parts)
|
||||
{
|
||||
for (int i = 0; i < 2 * K; ++i)
|
||||
{
|
||||
int const current_index = tid * PACKED_TOP_KMD_SIZE + i;
|
||||
T current_value = smem_topk[current_index + MAX_K2];
|
||||
cub_kvp new_elem = {current_index, current_value};
|
||||
partial_topk = arg_max(partial_topk, new_elem);
|
||||
}
|
||||
}
|
||||
|
||||
cub_kvp total_topk = BlockReduceTopK(shared_temp_storage.topk_smem).Reduce(partial_topk, arg_max);
|
||||
__syncthreads();
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
// Store kv pairs in shared mem buffer
|
||||
int temp_offset = total_topk.key;
|
||||
int global_offset = reinterpret_cast<int*>(smem_topk)[temp_offset];
|
||||
total_topk.key = global_offset;
|
||||
buf_smem_kv[k] = total_topk;
|
||||
|
||||
// Invalidate the maximum value within the chunk
|
||||
reinterpret_cast<int*>(smem_topk)[temp_offset] = V - 1; // id in shared memory
smem_topk[temp_offset + MAX_K2] = -MAX_T_VAL; // value in shared memory
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Extract and reduce MD values across the chunks
|
||||
if (tid < voc_parts)
|
||||
{
|
||||
partial_md.d = smem_topk[tid * PACKED_TOP_KMD_SIZE + 2 * MAX_K2];
|
||||
partial_md.m = smem_topk[tid * PACKED_TOP_KMD_SIZE + 2 * MAX_K2 + 1];
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
MD total_md = BlockReduceMD(shared_temp_storage.md_smem).Reduce(partial_md, reduce_md_func);
|
||||
|
||||
if (tid == 0)
|
||||
{
|
||||
float d_total_log = logf(total_md.d);
|
||||
|
||||
for (int i = 0; i < MAX_K2; ++i)
|
||||
{
|
||||
float val = (float) buf_smem_kv[i].value - total_md.m - d_total_log;
|
||||
if (i < 2 * K)
|
||||
{
|
||||
topk_id_buffer[bid * 2 * K + i] = buf_smem_kv[i].key;
|
||||
topk_val_buffer[bid * 2 * K + i] = val + cum_log_probs[bid];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K2>
|
||||
void beamStage2KernelLauncher(float const* temp_buffer, float const* cum_log_probs, int* topk_id_buffer,
|
||||
T* topk_val_buffer, int const batch_size, int const beam_width, int const voc_parts, int const V,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// TODO: rewrite the kernel to remove the dependence on a constant block size, to reduce compilation time
|
||||
size_t const smem_size = sizeof(float) * voc_parts * (2 * MAX_K2 + 2);
|
||||
|
||||
if (voc_parts <= 32)
|
||||
{
|
||||
beamStage2Kernel<T, MAX_K2, 32><<<batch_size * beam_width, 32, smem_size, stream>>>(
|
||||
temp_buffer, cum_log_probs, topk_id_buffer, topk_val_buffer, beam_width, voc_parts, V);
|
||||
return;
|
||||
}
|
||||
if (voc_parts <= 64)
|
||||
{
|
||||
beamStage2Kernel<T, MAX_K2, 64><<<batch_size * beam_width, 64, smem_size, stream>>>(
|
||||
temp_buffer, cum_log_probs, topk_id_buffer, topk_val_buffer, beam_width, voc_parts, V);
|
||||
return;
|
||||
}
|
||||
if (voc_parts <= 128)
|
||||
{
|
||||
beamStage2Kernel<T, MAX_K2, 128><<<batch_size * beam_width, 128, smem_size, stream>>>(
|
||||
temp_buffer, cum_log_probs, topk_id_buffer, topk_val_buffer, beam_width, voc_parts, V);
|
||||
return;
|
||||
}
|
||||
assert(0);
|
||||
}
|
||||
|
||||
template <typename T, int MAX_K>
|
||||
void topK_softMax_kernelLauncher(
|
||||
T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream)
|
||||
{
|
||||
// Workflow of this function (reference: https://github.com/NVIDIA/online-softmax)
|
||||
// Using batch_size (BS) = 2, beam_width (BM) = 5, vocab_size (V) = 32000 as an example:
|
||||
// nPaddedBeamWidth (pBM) = 8 = 2 ^ ceil(log2(BM)), nSmallTopKMaxVocParts (nVP) = 128 (Constant)
|
||||
// MAX_K = 8 = pBM, MAX_K2 = 16 = 2 * pBM
|
||||
// logits.shape = [BS, BM, V]
|
||||
// blockSize = 128, voc_parts = 13, voc_size_chunk = 2462 = ceil(32000/13)
|
||||
|
||||
// The content of workspace (length aligned to 4):
|
||||
// | allocated size | used size | data type |
|
||||
// ┏━━━━━━━━━━━━━━━━━┓ ---------------------------------------------------------------------------
|
||||
// ┃ topk_id_buffer ┃ BS * pBM * pBM * 2 | | int |
|
||||
// ┣━━━━━━━━━━━━━━━━━┫ -------------------------------------- Change "pBM" into "BM" -------------
|
||||
// ┃ topk_val_buffer ┃ BS * pBM * pBM * 2 | | float |
|
||||
// ┣━━━━━━━━━━━━━━━━━┫ -------------------------------------- in the left formulas -------------
|
||||
// ┃ temp_buffer ┃ BS * pBM * nVP * (2 * (pBM * 2) + 2) | | float |
|
||||
// ┗━━━━━━━━━━━━━━━━━┛ ---------------------------------------------------------------------------
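// Worked example for the sizes above (BS = 2, pBM = 8, nVP = 128, MAX_K2 = 16), added for illustration:
//   topk_id_buffer : 2 * 8 * 8 * 2              =   256 int
//   topk_val_buffer: 2 * 8 * 8 * 2              =   256 float
//   temp_buffer    : 2 * 8 * 128 * (2 * 16 + 2) = 69632 float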
|
||||
|
||||
// Stage1: gridDim(BS*BM,voc_parts,1), blockDim(blockSize,1,1)
|
||||
// Each ThreadBlock takes `voc_size_chunk` contiguous elements in logits to do TopK and reduce_md,
|
||||
// then writes output into temp_buffer.
|
||||
// At the end of this kernel, each ThreadBlock holds the indexes and values of its top 2*K elements,
// as well as the m(x) and l(x) of those elements (see the Flash Attention paper, arXiv:2205.14135)
|
||||
// temp_buffer.shape = [BS*BM, voc_parts, 2*MAX_K2+2]
|
||||
// The content of the last dimension of temp_buffer (updated by each ThreadBlock, we call it "Tile"):
|
||||
// ┏━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━┓
|
||||
// ┃ topk_id ┃ topk_val ┃ md ┃
|
||||
// ┗━━━━━━━━━┻━━━━━━━━━━┻━━━━━━━┛
|
||||
// | allocated size | MAX_K2 | MAX_K2 | 2 |
|
||||
// | used size | 2*BM | 2*BM | 2 |
|
||||
// | data type | int | float | float |
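// Illustrative note: for BM = 5 (MAX_K2 = 16), a Tile therefore uses 10 + 10 + 2 = 22 of its
// 16 + 16 + 2 = 34 slots; the remainder is padding so that every beam width up to pBM shares the same layout.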
|
||||
|
||||
// Stage2: gridDim(BS*BM,1,1), blockDim(32/64/128,1,1)
|
||||
// Each ThreadBlock takes `voc_parts` contiguous Tiles in temp_buffer to do reduce_topk and reduce_md,
// writes the output topk_id into topk_id_buffer, and writes topk_value + cum_log_probs into topk_val_buffer.
|
||||
|
||||
// batchBeamKernel: gridDim(BS,1,1), blockDim(128,1,1)
|
||||
// Each ThreadBlock is responsible for one batch and does the work below:
// + moves a beam into the candidate-beam-array if it is finished (generated end_id in this step).
// + selects BM elements for the next generation step otherwise.
// + maintains the related score arrays, min_normed_score / is_done / finished, etc.
|
||||
|
||||
constexpr int items_per_thread = 1;
|
||||
constexpr int blockSize = (MAX_K < 16) ? ((MAX_K < 8) ? nSmallTopKBlockSize : 128) : 64;
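// Illustrative note: for the example at the top of this function (MAX_K = 8), the expression above
// selects blockSize = 128, matching the quoted configuration.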
|
||||
int const batch_size{bh.local_batch_size};
|
||||
int const beam_width{bh.beam_width};
|
||||
int const V{bh.vocab_size};
|
||||
int const* end_ids{bh.end_ids};
|
||||
float* cum_log_probs{bh.cum_log_probs};
|
||||
FinishedState const* finished{bh.finished};
|
||||
|
||||
int const offset = roundUp(batch_size * beam_width * beam_width * 2, 4);
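// Illustrative note: for BS = 2, BM = 5, the offset is roundUp(2 * 5 * 5 * 2, 4) = 100 elements,
// already a multiple of 4, so no extra padding is added.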
|
||||
int* topk_id_buffer = reinterpret_cast<int*>(workspace);
|
||||
T* topk_val_buffer = reinterpret_cast<T*>(topk_id_buffer + offset);
|
||||
float* temp_buffer = reinterpret_cast<float*>(topk_val_buffer + offset);
|
||||
|
||||
#ifdef DO_SPLIT_SMALL_TOP_K_SOFTMAX
|
||||
|
||||
// Upper limit on the number of active ThreadBlocks, obtained assuming no shared memory is used
|
||||
int max_active_blocks = -1;
|
||||
TLLM_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks, beamStage1FastKernel<T, items_per_thread, 2 * MAX_K, blockSize>, blockSize, 0));
|
||||
|
||||
// Find the max smem on the device and use that to determine the vocab parts in the best case.
|
||||
int max_smem_per_sm = -1;
|
||||
int max_smem_per_block = -1;
|
||||
int const device = tensorrt_llm::common::getDevice();
|
||||
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, device));
|
||||
TLLM_CUDA_CHECK(cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
|
||||
cudaFuncAttributes attr;
|
||||
TLLM_CUDA_CHECK(cudaFuncGetAttributes(&attr, beamStage1FastKernel<T, items_per_thread, 2 * MAX_K, blockSize>));
|
||||
|
||||
// One ThreadBlock must have at least `sizeof(T) * V / nSmallTopKMaxVocParts` bytes of shared memory
|
||||
int const static_smem = attr.sharedSizeBytes;
|
||||
int const max_dyn_smem_per_block = max_smem_per_block - static_smem;
|
||||
TLLM_CHECK_WITH_INFO(sizeof(T) * V <= max_dyn_smem_per_block * nSmallTopKMaxVocParts,
|
||||
"Vocab size is too large for split-k TopK beam search fast path.");
|
||||
|
||||
// Find the maximum number of ThreadBlocks (i.e., maximum voc_parts, minimum smem per block)
// satisfying voc_parts <= nSmallTopKMaxVocParts && dyn_smem_size * voc_parts >= sizeof(T) * V
|
||||
int const driver_smem_per_block = max_smem_per_sm - max_smem_per_block;
|
||||
int const extra_smem = driver_smem_per_block + static_smem;
|
||||
int voc_parts = nSmallTopKMaxVocParts + 1;
|
||||
for (int n_block = max_active_blocks - 1; n_block > 0 && voc_parts > nSmallTopKMaxVocParts; --n_block)
|
||||
{
|
||||
int smem_per_block = max_smem_per_sm / n_block;
|
||||
int dyn_smem_size = smem_per_block - extra_smem;
|
||||
dyn_smem_size -= dyn_smem_size % sizeof(T);
|
||||
voc_parts = (sizeof(T) * V + dyn_smem_size - 1) / dyn_smem_size;
|
||||
}
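// Illustrative numbers (assuming T = half and V = 32000 as in the example above): sizeof(T) * V = 64000 B;
// a dynamic-smem budget of roughly 5 KB per block then gives voc_parts = 13 and
// voc_size_chunk = ceil(32000 / 13) = 2462, matching the values quoted at the top of this function.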
|
||||
|
||||
if (voc_parts <= nSmallTopKMaxVocParts)
|
||||
{
|
||||
// Use stage 1 fast kernel
|
||||
int const voc_size_chunk = (V + voc_parts - 1) / voc_parts;
|
||||
int const dyn_smem_size = sizeof(T) * voc_size_chunk;
|
||||
if (dyn_smem_size >= (48 << 10))
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(beamStage1FastKernel<T, items_per_thread, 2 * MAX_K, blockSize>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, dyn_smem_size));
|
||||
}
|
||||
dim3 gridSize(batch_size * beam_width, voc_parts);
|
||||
beamStage1FastKernel<T, items_per_thread, 2 * MAX_K, blockSize><<<gridSize, blockSize, dyn_smem_size, stream>>>(
|
||||
logits, bias, finished, temp_buffer, V, beam_width, end_ids, voc_size_chunk);
|
||||
}
|
||||
else
|
||||
{
|
||||
// use stage 1 base kernel
|
||||
int voc_parts = 4;
|
||||
if (batch_size * beam_width < 256)
|
||||
{
|
||||
// TODO: add heuristics for base stage 1 kernel
|
||||
// Volta has 80 SMs, so we aim for three waves
|
||||
voc_parts = (240 + batch_size * beam_width - 1) / (batch_size * beam_width);
|
||||
voc_parts = std::min(128, voc_parts); // we implement up to 128
|
||||
}
|
||||
cudaFuncSetAttribute(beamStage1BaseKernel<T, items_per_thread, 2 * MAX_K, blockSize>,
|
||||
cudaFuncAttributePreferredSharedMemoryCarveout, cudaSharedmemCarveoutMaxL1);
|
||||
dim3 gridSize(batch_size * beam_width, voc_parts);
|
||||
beamStage1BaseKernel<T, items_per_thread, 2 * MAX_K, blockSize>
|
||||
<<<gridSize, blockSize, 0, stream>>>(logits, bias, finished, temp_buffer, V, beam_width, end_ids);
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
|
||||
beamStage2KernelLauncher<T, 2 * MAX_K>(
|
||||
temp_buffer, cum_log_probs, topk_id_buffer, topk_val_buffer, batch_size, beam_width, voc_parts, V, stream);
|
||||
#else
|
||||
beamKernel<T, items_per_thread, MAX_K, blockSize><<<batch_size * beam_width, blockSize, 0, stream>>>(
|
||||
logits, bias, cum_log_probs, finished, topk_id_buffer, topk_val_buffer, V, beam_width, end_ids);
|
||||
#endif
|
||||
|
||||
sync_check_cuda_error();
|
||||
|
||||
// Keep 2 * beam_width candidates in case K candidates finish in one iteration
|
||||
size_t const smem_size = sizeof(T) * beam_width * beam_width * 2;
|
||||
|
||||
if (smem_size >= (48 << 10))
|
||||
{
|
||||
TLLM_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
batchBeamKernel<T, MAX_K * 2, 32>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
|
||||
}
|
||||
|
||||
batchBeamKernel<T, MAX_K * 2, 32><<<batch_size, 32, smem_size, stream>>>(topk_id_buffer, topk_val_buffer, bh);
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
#define INSTANTIATE_BEAMSEARCH_K(T, MAX_K) \
|
||||
template void topK_softMax_kernelLauncher<T, MAX_K>( \
|
||||
T const* logits, T const* bias, void* workspace, BeamHypotheses& bh, cudaStream_t stream);
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -1,99 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CUDART_VERSION
|
||||
#error CUDART_VERSION Undefined!
|
||||
#elif (CUDART_VERSION >= 11050)
|
||||
#include <cub/cub.cuh>
|
||||
#else
|
||||
#include "3rdparty/cub/cub.cuh"
|
||||
#endif
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/reduceKernelUtils.cuh"
|
||||
#include "tensorrt_llm/common/stringUtils.h"
|
||||
#include "tensorrt_llm/kernels/beamSearchTopkKernels.h"
|
||||
|
||||
using namespace tensorrt_llm::common;
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
|
||||
__global__ void insertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished,
|
||||
float const* cum_log_probs, int const batch_size, int const beam_width)
|
||||
{
|
||||
int const bid = blockIdx.x;
|
||||
int const tgt_start_idx = beam_hyps.num_beams[bid];
|
||||
int const max_seq_len{beam_hyps.max_seq_len};
|
||||
float const length_penalty{beam_hyps.length_penalties == nullptr ? 1.0f : beam_hyps.length_penalties[bid]};
|
||||
if (beam_hyps.is_done[bid])
|
||||
{
|
||||
return;
|
||||
}
|
||||
for (int beam_idx = 0; beam_idx < beam_width; beam_idx++)
|
||||
{
|
||||
if (threadIdx.x == 0)
|
||||
{
|
||||
int const src_beam_idx = bid * beam_width + beam_idx;
|
||||
int const tgt_beam_idx = bid * beam_width * 2 + beam_idx + tgt_start_idx;
|
||||
|
||||
int const last_token_idx = beam_hyps.sequence_lengths_src[src_beam_idx] - 1;
|
||||
|
||||
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + last_token_idx]
|
||||
= beam_hyps.output_ids_src[src_beam_idx * max_seq_len + last_token_idx];
|
||||
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_beam_idx * max_seq_len + last_token_idx]
|
||||
= beam_hyps.log_probs_src[last_token_idx * batch_size * beam_width + src_beam_idx];
|
||||
}
|
||||
int prev_id = beam_hyps.parent_ids_src[src_beam_idx * max_seq_len + last_token_idx];
|
||||
for (int token_idx = last_token_idx - 1; token_idx >= 0; token_idx--)
|
||||
{
|
||||
// output_ids_tgt needs to use max_seq_len + 1 because its shape is
// [bs, beam_width, max_seq_len + 1]
|
||||
beam_hyps.output_ids_tgt[tgt_beam_idx * max_seq_len + token_idx]
|
||||
= beam_hyps.output_ids_src[bid * beam_width * max_seq_len + prev_id * max_seq_len + token_idx];
|
||||
if (beam_hyps.log_probs != nullptr && beam_hyps.log_probs_src != nullptr)
|
||||
{
|
||||
beam_hyps.log_probs[tgt_beam_idx * max_seq_len + token_idx]
|
||||
= beam_hyps.log_probs_src[token_idx * batch_size * beam_width + bid * beam_width + prev_id];
|
||||
}
|
||||
prev_id = beam_hyps.parent_ids_src[bid * beam_width * max_seq_len + prev_id * max_seq_len + token_idx];
|
||||
}
|
||||
beam_hyps.sequence_lengths_tgt[tgt_beam_idx] = last_token_idx + 1;
|
||||
|
||||
// TODO: huggingface uses the total length to normalize the scores, instead of the number of generated tokens.
// Check whether that is reasonable.
|
||||
beam_hyps.normed_scores[tgt_beam_idx] = apply_length_penalty(cum_log_probs[src_beam_idx],
|
||||
finished[src_beam_idx].isFinished() ? last_token_idx + 1 : last_token_idx, length_penalty);
|
||||
beam_hyps.cum_log_probs[tgt_beam_idx] = cum_log_probs[src_beam_idx];
|
||||
|
||||
beam_hyps.num_beams[bid]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||
int const batch_size, int const beam_width, cudaStream_t stream)
|
||||
{
|
||||
insertUnfinishedPath<<<batch_size, 256, 0, stream>>>(beam_hyps, finished, cum_log_probs, batch_size, beam_width);
|
||||
}
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -1,96 +0,0 @@
|
||||
/*
|
||||
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "tensorrt_llm/kernels/decodingCommon.h"
|
||||
#include <cuda_runtime.h>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
// We keep tracking `beam_width` beams during the iterations; once a beam is finished,
// we record its ids and its normed score in output_ids_tgt and normed_scores
|
||||
struct BeamHypotheses
|
||||
{
|
||||
// BS: batch_size
|
||||
// BM: beam_width
|
||||
// mSL: max_seq_length
|
||||
// %%: parameter name when we call [generation.py] dynamic_decoder.forward (python workflow)
|
||||
|
||||
// Pointers initialized in these two functions below:
|
||||
// [gptDecoder.cpp] GptDecoder<T>::forward or [dynamicDecodeOp.cpp] FtDynamicDecode<T>::forward
|
||||
bool* is_done{nullptr}; // [BS] %% self.beam_hyps_is_done
|
||||
float* cum_log_probs{nullptr}; // [BS, BM*2] %% self.beam_hyps_cum_log_probs
|
||||
float* log_probs{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_log_probs
|
||||
float* min_normed_scores{nullptr}; // [BS] %% self.beam_hyps_min_normed_scores
|
||||
float* normed_scores{nullptr}; // [BS, BM*2] %% self.beam_hyps_normed_scores
|
||||
int* num_beams{nullptr}; // [BS] %% self.beam_hyps_num_beams
|
||||
int* output_ids_tgt{nullptr}; // [BS, BM*2, mSL] %% self.beam_hyps_output_ids_tgt
|
||||
int* sequence_lengths_tgt{nullptr}; // [BS, BM*2] %% self.beam_hyps_sequence_lengths_tgt
|
||||
int const* input_lengths{nullptr}; // [BS*BM] %% context_length
|
||||
|
||||
// Pointers initialized in [onlineBeamSearchLayer.cu] invokeSoftMax:
|
||||
int const* end_ids{nullptr}; // [BS*BM] %% self.end_ids
|
||||
FinishedState* finished; // [BS*BM] %% self.finished
|
||||
float* cum_log_probs_src{nullptr}; // [BS, BM] %% self.cum_log_probs
|
||||
float* log_probs_src{nullptr}; // [mSL, BS, BM] %% self.log_probs_tiled
|
||||
int* sequence_lengths_src{nullptr}; // [BS*BM] %% self.sequence_length_buffer
|
||||
// These two pointers are relocated in [dynamicDecodeLayer.cpp] DynamicDecodeLayer<T>::prepareIdsPtrs
|
||||
int** output_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] %% self.output_ids
|
||||
int** parent_ids_tgt_ptr{nullptr}; // [BS][BM, mSL] %% self.parent_ids
|
||||
|
||||
float* diversity_rates{nullptr}; // [BS] from SamplingConfig
|
||||
float* length_penalties{nullptr}; // [BS] from SamplingConfig
|
||||
int* early_stoppings{nullptr}; // [BS] from SamplingConfig
|
||||
|
||||
// Pointers for function gatherTree
|
||||
int const* output_ids_src{nullptr}; //
|
||||
int const* parent_ids_src{nullptr}; //
|
||||
|
||||
// Scalar values
|
||||
bool is_return_normed_score{true}; // return normed_cum_log_probs or cum_log_probs, always true for now
|
||||
int batch_size{0}; //
|
||||
int beam_width{0}; //
|
||||
int ite{0}; // index of local_batch, always 0 if pp_size == 1
|
||||
int local_batch_size{0}; //
|
||||
int max_seq_len{0}; //
|
||||
int step{0}; // only used in [beamSearchTopkKernels.cu], always 0 in [onlineSoftmaxBeamsearchKernels*.cu.h]
|
||||
int vocab_size{0}; // vocab_size_padded
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__device__ __forceinline__ T apply_length_penalty(T log_prob, int length, float length_penalty)
|
||||
{
|
||||
// score = log(prob) / (length ^ length_penalty)
|
||||
if (length_penalty == 0.0f || length == 1)
|
||||
{
|
||||
return log_prob;
|
||||
}
|
||||
return log_prob / static_cast<T>(powf(static_cast<float>(length), length_penalty));
|
||||
}
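// Illustrative example (not part of the original source): log_prob = -10, length = 4, length_penalty = 0.5
//   => score = -10 / 4^0.5 = -10 / 2 = -5, so a positive length_penalty favors longer finished beams.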
|
||||
|
||||
template <typename T>
|
||||
void invokeTopkBeamSearch(void* workspace, size_t& workspace_size, T* log_probs, int* ids, BeamHypotheses* beam_hyps,
|
||||
bool const* finished, int const* sequence_lengths, int const batch_size, int const beam_width,
|
||||
int const vocab_size_padded_, const T diversity_rate, float const length_penalty, int const* end_ids,
|
||||
cudaStream_t stream);
|
||||
|
||||
void invokeInsertUnfinishedPath(BeamHypotheses beam_hyps, FinishedState const* finished, float const* cum_log_probs,
|
||||
int const batch_size, int const beam_width, cudaStream_t stream);
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
@ -63,6 +63,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -72,6 +73,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin[];
|
||||
@ -81,6 +83,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -90,6 +93,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin[];
|
||||
@ -99,6 +103,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -108,6 +113,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_t
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin[];
|
||||
@ -118,6 +124,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin[];
|
||||
@ -128,6 +135,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin[];
|
||||
@ -138,6 +146,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin[];
|
||||
@ -148,6 +157,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin[];
|
||||
@ -158,6 +168,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin[];
|
||||
@ -168,6 +179,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin[];
|
||||
@ -178,6 +190,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin[];
|
||||
@ -188,6 +201,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin[];
|
||||
@ -198,6 +212,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin[];
|
||||
@ -208,6 +223,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin[];
|
||||
@ -218,6 +234,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubi
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin[];
|
||||
@ -228,6 +245,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin[];
@ -238,6 +256,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin[];
@ -248,6 +267,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin[];
@ -258,6 +278,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin[];
@ -268,6 +289,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin[];
@ -278,6 +300,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin[];
@ -288,6 +311,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin[];
@ -361,6 +385,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_tma_ws_sm90_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len;
@ -370,6 +395,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_alibi_tma_ws_sm90
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_32_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_40_tma_ws_sm90_cu_cubin_len;
@ -379,6 +405,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_tma_ws_sm90_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len;
@ -388,6 +415,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_alibi_tma_ws_sm90
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_tma_ws_sm90_cu_cubin_len;
@ -397,6 +425,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_tma_ws_sm90_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len;
@ -406,6 +435,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_alibi_tma_ws
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm89_cu_cubin_len;
@ -416,6 +446,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm89_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm89_cu_cubin_len;
@ -426,6 +457,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm89_cu_cubin_len;
@ -436,6 +468,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm89_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm89_cu_cubin_len;
@ -446,6 +479,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm80_cu_cubin_len;
@ -456,6 +490,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm80_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm80_cu_cubin_len;
@ -466,6 +501,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm80_cu_cubin_len;
@ -476,6 +512,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm80_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm80_cu_cubin_len;
@ -486,6 +523,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_sm86_cu_cubin_len;
@ -496,6 +534,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_sm86_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm86_cu_cubin_len;
@ -506,6 +545,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_sm86_cu_cubin_len;
@ -516,6 +556,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_sm86_cu_cubin_len
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_sm86_cu_cubin_len;
@ -526,6 +567,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm80_cu_cubin_len;
@ -536,6 +578,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm80_cu_cubi
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm80_cu_cubin_len;
@ -546,6 +589,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm80_cu_cubin
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm86_cu_cubin_len;
@ -556,6 +600,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm86_cu_cubi
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm86_cu_cubin_len;
@ -566,6 +611,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm86_cu_cubin
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_sm89_cu_cubin_len;
@ -576,6 +622,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_sm89_cu_cubi
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_sm89_cu_cubin_len;
@ -586,6 +633,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_sm89_cu_cubin
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_sm70_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_40_sm70_cu_cubin_len;
@ -844,6 +892,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, false, false},
@ -871,6 +922,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, false, 2, true, false},
@ -898,6 +952,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
@ -925,6 +982,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
@ -952,6 +1012,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, false, false},
@ -979,6 +1042,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, 2, true, false},
@ -1009,6 +1075,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
@ -1039,6 +1108,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
@ -1069,6 +1141,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
@ -1099,6 +1174,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
@ -1129,6 +1207,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
@ -1159,6 +1240,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
|
||||
@ -1189,6 +1273,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
@ -1219,6 +1306,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
@ -1249,6 +1339,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, 2, true, true},
|
||||
@ -1279,6 +1372,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, false, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, 2, true, false},
|
||||
@ -1309,6 +1405,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
@ -1339,6 +1438,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
@ -1369,6 +1471,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
@ -1399,6 +1504,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
@ -1429,6 +1537,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
@ -1459,6 +1570,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
@ -1489,6 +1603,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 0, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 1, true, true},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, true, 2, true, true},
|
||||
@ -1519,6 +1636,9 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_kernel_nl", 49152, 128, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, true, 2, true, false},
24 file diffs suppressed because they are too large.
@ -714,12 +714,12 @@ bool MHARunner::fmha_supported(int const headSize, int const sm)
    else if (sm == kSM_80 || sm == kSM_86 || sm == kSM_89)
    {
        return (headSize == 16 || headSize == 32 || headSize == 40 || headSize == 64 || headSize == 80 || headSize == 96
-           || headSize == 104 || headSize == 128 || headSize == 160 || headSize == 256);
+           || headSize == 104 || headSize == 128 || headSize == 160 || headSize == 192 || headSize == 256);
    }
    else if (sm == kSM_90)
    {
        return (headSize == 32 || headSize == 40 || headSize == 64 || headSize == 80 || headSize == 96
-           || headSize == 104 || headSize == 128 || headSize == 160 || headSize == 256);
+           || headSize == 104 || headSize == 128 || headSize == 160 || headSize == 192 || headSize == 256);
    }

    return false;
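Effect of the two changed lines above: head size 192 now passes the fused-MHA support check on SM80/86/89 and SM90. A minimal usage sketch (whether fmha_supported is static and the surrounding setup are assumptions, not shown in this diff):

// Hypothetical caller-side check; fmha_supported(headSize, sm) is the function patched above,
// and kSM_80 is the SM constant used in the kernel metadata table.
int const headSize = 192;
bool const canUseFusedMha = MHARunner::fmha_supported(headSize, kSM_80); // true after this change
if (!canUseFusedMha)
{
    // Fall back to an unfused attention path (illustrative only).
}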
@ -33,6 +33,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_tma_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -42,6 +43,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_alib
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_32_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_40_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
@ -51,6 +53,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_tma_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -60,6 +63,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_alib
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
@ -69,6 +73,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
@ -78,6 +83,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm89_cu_cubin[];
|
||||
@ -88,6 +94,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm89
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm89_cu_cubin[];
|
||||
@ -98,6 +105,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm89_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm89_cu_cubin[];
|
||||
@ -108,6 +116,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm89
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm89_cu_cubin[];
|
||||
@ -118,6 +127,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm89_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -128,6 +138,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm80
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -138,6 +149,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm80_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -148,6 +160,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm80
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -158,6 +171,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm80_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm86_cu_cubin[];
|
||||
@ -168,6 +182,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm86
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm86_cu_cubin[];
|
||||
@ -178,6 +193,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm86_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm86_cu_cubin[];
|
||||
@ -188,6 +204,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm86
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm86_cu_cubin[];
|
||||
@ -198,6 +215,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm86_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm90_cu_cubin[];
|
||||
@ -208,6 +226,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm90
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm90_cu_cubin[];
|
||||
@ -218,6 +237,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm90_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm90_cu_cubin[];
|
||||
@ -228,6 +248,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm90
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm90_cu_cubin[];
|
||||
@ -238,6 +259,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm90_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -248,6 +270,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm80_cu_cubin[];
|
||||
@ -258,6 +281,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm86_cu_cubin[];
|
||||
@ -268,6 +292,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm86_cu_cubin[];
@ -278,6 +303,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm89_cu_cubin[];
@ -288,6 +314,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm89_cu_cubin[];
@ -298,6 +325,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm90_cu_cubin[];
@ -308,6 +336,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm90_cu_cubin[];
@ -318,6 +347,7 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin[];
extern unsigned char cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin[];
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_256_S_32_pagedKV_tma_ws_sm90_cu_cubin_len;
@ -328,6 +358,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_tma_ws_sm
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
@ -337,6 +368,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_alibi_tma
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_32_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_40_pagedKV_tma_ws_sm90_cu_cubin_len;
@ -346,6 +378,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_tma_ws_sm
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
@ -355,6 +388,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_alibi_tma
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_pagedKV_tma_ws_sm90_cu_cubin_len;
@ -364,6 +398,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_tma_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_32_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_256_S_40_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
@ -373,6 +408,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_alib
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm89_cu_cubin_len;
@ -383,6 +419,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm89_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm89_cu_cubin_len;
@ -393,6 +430,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm89_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm89_cu_cubin_len;
@ -403,6 +441,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm89_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm89_cu_cubin_len;
@ -413,6 +452,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm89_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm80_cu_cubin_len;
@ -423,6 +463,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm80_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm80_cu_cubin_len;
@ -433,6 +474,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm80_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm80_cu_cubin_len;
@ -443,6 +485,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm80_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm80_cu_cubin_len;
@ -453,6 +496,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm80_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm86_cu_cubin_len;
@ -463,6 +507,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm86_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm86_cu_cubin_len;
@ -473,6 +518,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm86_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm86_cu_cubin_len;
@ -483,6 +529,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm86_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm86_cu_cubin_len;
@ -493,6 +540,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm86_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_128_128_S_32_pagedKV_sm90_cu_cubin_len;
@ -503,6 +551,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_96_pagedKV_sm90_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_64_S_32_pagedKV_sm90_cu_cubin_len;
@ -513,6 +562,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_96_pagedKV_sm90_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_32_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_128_128_S_32_pagedKV_sm90_cu_cubin_len;
@ -523,6 +573,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_96_pagedKV_sm90_cu_c
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_64_S_32_pagedKV_sm90_cu_cubin_len;
@ -533,6 +584,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_96_pagedKV_sm90_cu_cu
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_32_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm80_cu_cubin_len;
@ -543,6 +595,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_sm80
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm80_cu_cubin_len;
@ -553,6 +606,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_sm80_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm86_cu_cubin_len;
@ -563,6 +617,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_sm86
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm86_cu_cubin_len;
@ -573,6 +628,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_sm86_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm89_cu_cubin_len;
@ -583,6 +639,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_sm89
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm89_cu_cubin_len;
@ -593,6 +650,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_sm89_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_128_128_S_32_pagedKV_sm90_cu_cubin_len;
@ -603,6 +661,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_96_pagedKV_sm90
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_16_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_32_pagedKV_sm90_cu_cubin_len;
@ -613,6 +672,7 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_96_pagedKV_sm90_
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_104_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_32_S_128_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin_len;
extern uint32_t cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin_len;
static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
@ -659,6 +719,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, false, false},
@ -686,6 +749,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_160_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_192_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_64_S_256_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, false, 2, true, false},
@ -713,6 +779,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
@ -740,6 +809,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_160_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_192_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_64_S_256_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
@ -767,6 +839,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, false, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_sliding_window_causal_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, false, false},
@ -794,6 +869,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_160_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_192_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_64_S_256_pagedKV_sliding_window_causal_alibi_tma_ws_sm90_kernel", 196864, 384, 64, false, true, true, true, 2, true, false},
@ -824,6 +902,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
@ -854,6 +935,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
@ -884,6 +968,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -914,6 +1001,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -944,6 +1034,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
@ -974,6 +1067,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
@ -1004,6 +1100,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1034,6 +1133,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -1064,6 +1166,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
@ -1094,6 +1199,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
@ -1124,6 +1232,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1154,6 +1265,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -1184,6 +1298,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_128_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, false, 2, true, true},
@ -1214,6 +1331,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, false, 2, true, false},
@ -1244,6 +1364,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_128_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1274,6 +1397,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_BF16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_bf16_64_16_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -1304,6 +1430,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1334,6 +1463,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_80, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm80_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sliding_window_causal_sm80_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -1364,6 +1496,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1394,6 +1529,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_86, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm86_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sliding_window_causal_sm86_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@ -1424,6 +1562,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@ -1454,6 +1595,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_89, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm89_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sliding_window_causal_sm89_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
@@ -1484,6 +1628,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 0, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 1, true, true},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_128_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl_tiled", 81920, 128, 64, false, true, false, true, 2, true, true},
@@ -1514,6 +1661,9 @@ static const struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_160_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 192, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_192_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 0, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 1, true, false},
{ DATA_TYPE_FP16, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin, cubin_fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sm90_cu_cubin_len, "fmha_v2_flash_attention_fp16_fp32_64_16_S_256_pagedKV_sliding_window_causal_sm90_kernel_nl", 49152, 128, 64, false, true, false, true, 2, true, false}
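For readers skimming this generated table: each row is one aggregate initializer for a FusedMultiHeadAttentionPagedKVKernelMetaInfoV2 entry (the struct named in the hunk headers), describing one paged-KV FMHA cubin that the runtime can select by data type, head size, SM architecture, and attention-mask variant. The sketch below only labels the columns; the member names and the two placeholder enums are assumptions inferred from the visible rows, not the actual header. Only the column ordering and the value patterns (mask code 0/1/2 matching the dense/causal/sliding-window kernel names, and the final flag matching the _nl_tiled suffix) come from the table itself.

// Illustrative sketch only: member names and the placeholder enums below are
// assumptions used to label the columns of the rows above; the real header may differ.
enum Data_type { DATA_TYPE_FP16 };
enum { kSM_89 = 89, kSM_90 = 90 };

struct FusedMultiHeadAttentionPagedKVKernelMetaInfoV2
{
    Data_type mDataType;                  // col 1: e.g. DATA_TYPE_FP16
    unsigned int mS;                      // col 2: 0 in every row shown
    unsigned int mD;                      // col 3: head size (160, 192, 256)
    unsigned int mSM;                     // col 4: kSM_89 or kSM_90
    unsigned char const* mCubin;          // col 5: embedded cubin blob
    unsigned int mCubinSize;              // col 6: blob length in bytes
    char const* mFuncName;                // col 7: kernel symbol inside the cubin
    unsigned int mSharedMemBytes;         // col 8: 81920 (tiled) or 49152 (non-tiled)
    unsigned int mThreadsPerCTA;          // col 9: 128 in every row shown
    unsigned int mUnrollStep;             // col 10: 64 in every row shown
    bool mFlagA, mFlagB, mFlagC, mFlagD;  // cols 11-14: feature flags, meaning not recoverable here
    int mAttentionMaskType;               // col 15: 0 = dense, 1 = causal, 2 = sliding-window causal
    bool mFlagE;                          // col 16: another feature flag
    bool mTiled;                          // col 17: true for *_nl_tiled kernels, false for *_nl
};

// One of the rows above, re-expressed with the assumed column labels:
// { DATA_TYPE_FP16, /*mS*/0, /*mD*/192, kSM_89, <cubin>, <cubin_len>,
//   "fmha_v2_flash_attention_fp16_fp32_64_128_S_192_pagedKV_causal_sm89_kernel_nl_tiled",
//   /*smem*/81920, /*threads*/128, /*unroll*/64, false, true, false, true,
//   /*mask*/1, true, /*tiled*/true }

Under those assumptions, kernel selection reduces to matching (mDataType, mD, mSM, mAttentionMaskType) against a request and launching mFuncName from mCubin with mSharedMemBytes of dynamic shared memory.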
File diffs suppressed because they are too large; some files were not shown because too many files have changed in this diff.