mirror of
https://github.com/NVIDIA/TensorRT-LLM.git
synced 2026-01-13 22:18:36 +08:00
Update TensorRT-LLM (#1358)
Co-authored-by: Kaiyu <26294424+kaiyux@users.noreply.github.com>
This commit is contained in:
parent
66ca3378c6
commit
850b6fa1e7
@ -34,7 +34,7 @@
|
||||
- Optimize AllReduce for parallel attention on Falcon and GPT-J
|
||||
- Enable split-k for weight-only cutlass kernel when SM>=75
|
||||
* Documentation
|
||||
- Add [documentation for new builder workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/new_workflow.md)
|
||||
- Add [documentation for convert/build workflow](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/checkpoint.md)
|
||||
|
||||
## Versions 0.6.0 / 0.6.1
|
||||
|
||||
|
||||
@ -282,6 +282,7 @@ The list of supported models is:
|
||||
* [InternLM](examples/internlm)
|
||||
* [LLaMA](examples/llama)
|
||||
* [LLaMA-v2](examples/llama)
|
||||
* [Mamba](examples/mamba)
|
||||
* [mBART](examples/enc_dec)
|
||||
* [Mistral](examples/llama#mistral-v01)
|
||||
* [MPT](examples/mpt)
|
||||
@ -454,7 +455,7 @@ For example: `mpirun -n 1 python3 examples/run.py ...`
|
||||
- Support FP16 fMHA on NVIDIA V100 GPU
|
||||
* API
|
||||
- Add a set of High-level APIs for end-to-end generation tasks (see examples/high-level-api/README.md)
|
||||
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/new_workflow.md)
|
||||
- **[BREAKING CHANGES]** Migrate models to the new build workflow, including LLaMA, Mistral, Mixtral, InternLM, ChatGLM, Falcon, GPT-J, GPT-NeoX, Medusa, MPT, Baichuan and Phi (see docs/source/checkpoint.md)
|
||||
- **[BREAKING CHANGES]** Deprecate `LayerNorm` and `RMSNorm` plugins and removed corresponding build parameters
|
||||
- **[BREAKING CHANGES]** Remove optional parameter `maxNumSequences` for GPT manager
|
||||
* Bug fixes
|
||||
@ -482,7 +483,7 @@ For example: `mpirun -n 1 python3 examples/run.py ...`
|
||||
- Batch manager arguments documentation updates
|
||||
- Add documentation for best practices for tuning the performance of TensorRT-LLM (See docs/source/perf_best_practices.md)
|
||||
- Add documentation for Falcon AWQ support (See examples/falcon/README.md)
|
||||
- Update to the `docs/source/new_workflow.md` documentation
|
||||
- Update to the `docs/source/checkpoint.md` documentation
|
||||
- Update AWQ INT4 weight only quantization documentation for GPT-J
|
||||
- Add blog: Speed up inference with SOTA quantization techniques in TRT-LLM
|
||||
- Refine TensorRT-LLM backend README structure #133
|
||||
|
||||
@ -19,8 +19,10 @@ set(TOP_LEVEL_DIR "${PROJECT_SOURCE_DIR}/..")
|
||||
|
||||
add_custom_target(benchmarks)
|
||||
|
||||
set(CXXOPTS_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/cxxopts)
|
||||
add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)
|
||||
if(NOT TARGET cxxopts::cxxopts)
|
||||
set(CXXOPTS_SRC_DIR ${PROJECT_SOURCE_DIR}/../3rdparty/cxxopts)
|
||||
add_subdirectory(${CXXOPTS_SRC_DIR} ${CMAKE_CURRENT_BINARY_DIR}/cxxopts)
|
||||
endif()
|
||||
|
||||
function(add_benchmark test_name test_src)
|
||||
add_executable(${test_name} ${test_src})
|
||||
|
||||
@ -127,6 +127,7 @@ python prepare_dataset.py \
|
||||
|
||||
For `tokenizer`, specifying the path to the local tokenizer that have already been downloaded, or simply the name of the tokenizer from HuggingFace like `meta-llama/Llama-2-7b` will both work. The tokenizer will be downloaded automatically for the latter case.
|
||||
|
||||
|
||||
#### Prepare TensorRT-LLM engines
|
||||
Please make sure that the engines are built with argument `--use_inflight_batching` and `--remove_input_padding` if you'd like to benchmark inflight batching, for more details, please see the document in TensorRT-LLM examples.
|
||||
|
||||
@ -187,3 +188,130 @@ Take GPT-350M as an example for single GPU with static batching
|
||||
--static_emulated_timeout 100 \
|
||||
--dataset ../../benchmarks/cpp/tokens-fixed-lengths.json
|
||||
```
|
||||
|
||||
#### Benchmarking LoRA
|
||||
|
||||
Using either of the `prepare_dataset.py` methods above, add `--rand-task-id <start-id> <end-id>` to the command. This will add a random `task_id` from `<start-id>` to `<end-id>` inclusive.
|
||||
You can then use `utils/generate_rand_loras.py` to generate random LoRA weights for benchmarking purposes. `utils/generate_rand_loras.py` takes an example LoRA for the model you are benchmarking.
|
||||
Then you can run `gptManagerBenchmark` with `--type IFB` and `--lora_dir /path/to/utils/generate_rand_loras/output`
|
||||
|
||||
End-to-end LoRA benchmarking script
|
||||
|
||||
```
|
||||
git-lfs clone https://huggingface.co/meta-llama/Llama-2-13b-hf
|
||||
git-lfs clone https://huggingface.co/hfl/chinese-llama-2-lora-13b
|
||||
|
||||
MODEL_CHECKPOINT=Llama-2-13b-hf
|
||||
CONVERTED_CHECKPOINT=Llama-2-13b-hf-ckpt
|
||||
TOKENIZER=Llama-2-13b-hf
|
||||
LORA_ENGINE=Llama-2-13b-hf-engine
|
||||
|
||||
DTYPE=float16
|
||||
TP=2
|
||||
PP=1
|
||||
MAX_LEN=1024
|
||||
MAX_BATCH=32
|
||||
MAX_LORA_RANK=32
|
||||
|
||||
SOURCE_LORA=chinese-llama-2-lora-13b
|
||||
CPP_LORA=chinese-llama-2-lora-13b-cpp
|
||||
|
||||
EG_DIR=/tmp/lora-eg
|
||||
|
||||
# Build lora enabled engine
|
||||
python examples/llama/convert_checkpoint.py --model_dir ${MODEL_CHECKPOINT} \
|
||||
--output_dir ${CONVERTED_CHECKPOINT} \
|
||||
--dtype ${DTYPE} \
|
||||
--tp_size ${TP} \
|
||||
--pp_size 1 \
|
||||
--lora_target_modules attn_qkv \
|
||||
--max_lora_rank ${MAX_LORA_RANK}
|
||||
|
||||
${HOME}/.local/bin/trtllm-build \
|
||||
--checkpoint_dir ${CONVERTED_CHECKPOINT} \
|
||||
--output_dir ${LORA_ENGINE} \
|
||||
--max_batch_size ${MAX_BATCH} \
|
||||
--max_input_len $MAX_LEN \
|
||||
--max_output_len $MAX_LEN \
|
||||
--gpt_attention_plugin float16 \
|
||||
--paged_kv_cache enable \
|
||||
--remove_input_padding enable \
|
||||
--gemm_plugin float16 \
|
||||
--lora_plugin float16 \
|
||||
--use_paged_context_fmha enable \
|
||||
--use_custom_all_reduce disable
|
||||
|
||||
NUM_LORAS=(8 16 24 32 64 128 256)
|
||||
NUM_REQUESTS=1024
|
||||
|
||||
# Convert LoRA to cpp format
|
||||
python examples/gpt/nemo_lora_convert.py \
|
||||
-i $SOURCE_LORA \
|
||||
--storage-type $DTYPE \
|
||||
--write-cpp-runtime-tensors \
|
||||
-o $CPP_LORA
|
||||
|
||||
# Prepare datasets
|
||||
mkdir -p $EG_DIR/data
|
||||
|
||||
# Prepare dataset without lora_task_id
|
||||
python benchmarks/cpp/prepare_dataset.py \
|
||||
--output "${EG_DIR}/data/token-norm-dist.json" \
|
||||
--request-rate -1 \
|
||||
--time-delay-dist constant \
|
||||
--tokenizer $TOKENIZER \
|
||||
token-norm-dist \
|
||||
--num-requests $NUM_REQUESTS \
|
||||
--input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
|
||||
|
||||
# Prepare dataset with lora_task_ids from 0 - $nloras
|
||||
for nloras in ${NUM_LORAS[@]}; do
|
||||
python benchmarks/cpp/prepare_dataset.py \
|
||||
--output "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
|
||||
--request-rate -1 \
|
||||
--time-delay-dist constant \
|
||||
--rand-task-id 0 $(( $nloras - 1 )) \
|
||||
--tokenizer $TOKENIZER \
|
||||
token-norm-dist \
|
||||
--num-requests $NUM_REQUESTS \
|
||||
--input-mean 256 --input-stdev 16 --output-mean 128 --output-stdev 24
|
||||
done
|
||||
|
||||
# Generate random lora weights for 256 adapters
|
||||
python benchmarks/cpp/utils/generate_rand_loras.py ${CPP_LORA} ${EG_DIR}/loras 256
|
||||
|
||||
# perform benchmarking
|
||||
|
||||
# First run inference without LoRAs
|
||||
mkdir -p ${EG_DIR}/log-base-lora
|
||||
mpirun -n ${TP} --output-filename ${EG_DIR}/log-base-lora \
|
||||
cpp/build_Debug/benchmarks/gptManagerBenchmark \
|
||||
--engine_dir $LORA_ENGINE \
|
||||
--type IFB \
|
||||
--dataset "${EG_DIR}/data/token-norm-dist.json" \
|
||||
--lora_host_cache_bytes 8589934592 \
|
||||
--lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
|
||||
--kv_cache_free_gpu_mem_fraction 0.80 \
|
||||
--log_level info \
|
||||
--eos_id ${EOS_ID}
|
||||
|
||||
# Now run inference with various numbers or loras
|
||||
# The host cache is set large enough to hold all the LoRAs in lora_dir
|
||||
# GPU cache is set to hold 32 LoRAs
|
||||
# This benchmark will preload all the LoRAs into the host cache
|
||||
# We run inference on a range of active LoRAs exercising different cache miss rates.
|
||||
for nloras in ${NUM_LORAS[@]}; do
|
||||
mkdir -p ${EG_DIR}/log-lora-${nloras}
|
||||
mpirun -n ${TP} --output-filename "${EG_DIR}/log-lora-${nloras}" \
|
||||
cpp/build_Debug/benchmarks/gptManagerBenchmark \
|
||||
--engine_dir $LORA_ENGINE \
|
||||
--type IFB \
|
||||
--dataset "${EG_DIR}/data/token-norm-dist-lora-${nloras}.json" \
|
||||
--lora_host_cache_bytes 8589934592 \
|
||||
--lora_num_device_mod_layers $(( 32 * $NUM_LAYERS * $NUM_LORA_MODS * $MAX_LORA_RANK )) \
|
||||
--kv_cache_free_gpu_mem_fraction 0.80 \
|
||||
--log_level info \
|
||||
--eos_id ${EOS_ID} \
|
||||
--lora_dir ${EG_DIR}/loras
|
||||
done
|
||||
```
|
||||
|
||||
@ -34,6 +34,7 @@
|
||||
#include <cstdint>
|
||||
#include <cxxopts.hpp>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <nlohmann/json.hpp>
|
||||
#include <string>
|
||||
#include <thread>
|
||||
@ -300,9 +301,13 @@ struct BenchInfo
|
||||
|
||||
class Recorder
|
||||
{
|
||||
using TensorPtr = ITensor::SharedPtr;
|
||||
|
||||
public:
|
||||
explicit Recorder(std::string opCsvFile)
|
||||
explicit Recorder(std::string opCsvFile, std::string responsesJsonFile = "", bool excludeInputInOutput = false)
|
||||
: mOpCsvFile(std::move(opCsvFile))
|
||||
, mRespJsonFile(std::move(responsesJsonFile))
|
||||
, mOutputHasInput(!excludeInputInOutput)
|
||||
{
|
||||
}
|
||||
|
||||
@ -328,6 +333,10 @@ public:
|
||||
mRequestBenchInfos[requestId] = BenchInfo(inputLength, outputLength, start);
|
||||
}
|
||||
|
||||
// number of output tokens not calculated from output sequence here, instead set to max_output_len
|
||||
// - if eos_id == -1 (default behavior), this is correct since output seq will have max permissible length.
|
||||
// - However, if eos_id != -1, the token size of output sequence may be less than max_output_len, and token
|
||||
// throughput may be inaccurate
|
||||
void recordStart(SizeType inputLength, SizeType maxNewTokens, uint64_t requestId,
|
||||
std::chrono::time_point<std::chrono::steady_clock> const& start)
|
||||
{
|
||||
@ -342,20 +351,62 @@ public:
|
||||
.count();
|
||||
}
|
||||
|
||||
void recordEnd(uint64_t requestId, std::list<NamedTensor> const& responseTensors)
|
||||
{
|
||||
this->recordEnd(requestId);
|
||||
|
||||
if (mRespJsonFile.empty())
|
||||
return;
|
||||
int32_t outputSeqLen;
|
||||
|
||||
for (auto& tensor : responseTensors)
|
||||
{
|
||||
if (tensor.name == inference_request::kOutputIdsTensorName)
|
||||
{
|
||||
mResponseTensors[requestId] = tensor.tensor;
|
||||
}
|
||||
else if (tensor.name == inference_request::kSequenceLengthTensorName)
|
||||
{
|
||||
// Tensor of shape nBeams, and we only need the first one
|
||||
outputSeqLen = *(bufferCast<int32_t>(*(tensor.tensor)));
|
||||
if (mOutputHasInput)
|
||||
{
|
||||
int inputSeqLen = mRequestBenchInfos[requestId].inputLength;
|
||||
outputSeqLen -= inputSeqLen;
|
||||
}
|
||||
mRequestBenchInfos[requestId].outputLength = outputSeqLen;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float calcPercentile(std::vector<float> const& latencies, int percentile)
|
||||
{
|
||||
int const index = static_cast<int>(std::ceil((percentile / 100.0) * latencies.size())) - 1;
|
||||
return latencies[index];
|
||||
}
|
||||
|
||||
void calculateMetrics()
|
||||
{
|
||||
mNumSamples = mRequestBenchInfos.size();
|
||||
mTotalLatency = std::chrono::duration<float, std::milli>(mEnd - mStart).count();
|
||||
mSeqThroughput = mNumSamples / (mTotalLatency / 1000);
|
||||
mAvgSeqLatency = 0;
|
||||
|
||||
std::vector<float> reqLatencies;
|
||||
int totalOutputTokens = 0;
|
||||
for (auto reqInfo : mRequestBenchInfos)
|
||||
{
|
||||
mAvgSeqLatency += reqInfo.second.latency;
|
||||
reqLatencies.push_back(reqInfo.second.latency);
|
||||
totalOutputTokens += reqInfo.second.outputLength;
|
||||
}
|
||||
mAvgSeqLatency /= mNumSamples;
|
||||
mTokenThroughput = totalOutputTokens / (mTotalLatency / 1000);
|
||||
|
||||
std::sort(reqLatencies.begin(), reqLatencies.end());
|
||||
mP99SeqLatency = calcPercentile(reqLatencies, 99);
|
||||
mP90SeqLatency = calcPercentile(reqLatencies, 90);
|
||||
mP50SeqLatency = calcPercentile(reqLatencies, 50);
|
||||
}
|
||||
|
||||
void report()
|
||||
@ -363,8 +414,11 @@ public:
|
||||
printf("[BENCHMARK] num_samples %d\n", mNumSamples);
|
||||
printf("[BENCHMARK] total_latency(ms) %.2f\n", mTotalLatency);
|
||||
printf("[BENCHMARK] seq_throughput(seq/sec) %.2f\n", mSeqThroughput);
|
||||
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
|
||||
printf("[BENCHMARK] token_throughput(token/sec) %.2f\n", mTokenThroughput);
|
||||
printf("[BENCHMARK] avg_sequence_latency(ms) %.2f\n", mAvgSeqLatency);
|
||||
printf("[BENCHMARK] p99_sequence_latency(ms) %.2f\n", mP99SeqLatency);
|
||||
printf("[BENCHMARK] p90_sequence_latency(ms) %.2f\n", mP90SeqLatency);
|
||||
printf("[BENCHMARK] p50_sequence_latency(ms) %.2f\n", mP50SeqLatency);
|
||||
}
|
||||
|
||||
void writeOpMetricsToCsv()
|
||||
@ -372,7 +426,8 @@ public:
|
||||
if (!mOpCsvFile.empty())
|
||||
{
|
||||
std::vector<std::string> headers = {"num_samples", "total_latency(ms)", "seq_throughput(seq/sec)",
|
||||
"avg_sequence_latency(ms)", "token_throughput(token/sec)"};
|
||||
"token_throughput(token/sec)", "avg_sequence_latency(ms)", "p99_sequence_latency(ms)",
|
||||
"p90_sequence_latency(ms)", "p50_sequence_latency(ms)"};
|
||||
|
||||
std::ofstream outputFile(mOpCsvFile);
|
||||
|
||||
@ -383,8 +438,9 @@ public:
|
||||
outputFile << header << ",";
|
||||
}
|
||||
outputFile << "\n";
|
||||
outputFile << mNumSamples << "," << mTotalLatency << "," << mSeqThroughput << "," << mAvgSeqLatency
|
||||
<< "," << mTokenThroughput;
|
||||
outputFile << mNumSamples << "," << mTotalLatency << "," << mSeqThroughput << "," << mTokenThroughput
|
||||
<< "," << mAvgSeqLatency << "," << mP99SeqLatency << "," << mP90SeqLatency << ","
|
||||
<< mP50SeqLatency;
|
||||
outputFile << "\n";
|
||||
}
|
||||
else
|
||||
@ -394,6 +450,32 @@ public:
|
||||
}
|
||||
}
|
||||
|
||||
void dumpResponseSeqs()
|
||||
{
|
||||
if (mRespJsonFile.empty())
|
||||
return;
|
||||
nlohmann::json jsonResponses = nlohmann::json::array();
|
||||
for (auto const& [respId, respTokensTensor] : mResponseTensors)
|
||||
{
|
||||
int inputLength = mRequestBenchInfos[respId].inputLength;
|
||||
int outputLength = mRequestBenchInfos[respId].outputLength;
|
||||
std::vector<int32_t> outputTokens(outputLength);
|
||||
|
||||
int32_t* outputToksBufferPtr = bufferCast<int32_t>(*respTokensTensor);
|
||||
if (mOutputHasInput)
|
||||
outputToksBufferPtr += inputLength;
|
||||
std::copy(outputToksBufferPtr, outputToksBufferPtr + outputLength, outputTokens.begin());
|
||||
|
||||
nlohmann::json currResp;
|
||||
currResp["response_id"] = respId;
|
||||
currResp["response_tokens"] = outputTokens;
|
||||
jsonResponses.push_back(currResp);
|
||||
}
|
||||
std::ofstream outFile(mRespJsonFile);
|
||||
outFile << jsonResponses;
|
||||
outFile.close();
|
||||
}
|
||||
|
||||
private:
|
||||
std::unordered_map<uint64_t, BenchInfo> mRequestBenchInfos;
|
||||
|
||||
@ -404,7 +486,14 @@ private:
|
||||
float mSeqThroughput{};
|
||||
float mAvgSeqLatency{};
|
||||
float mTokenThroughput{};
|
||||
float mP99SeqLatency{};
|
||||
float mP90SeqLatency{};
|
||||
float mP50SeqLatency{};
|
||||
std::string mOpCsvFile;
|
||||
std::string mRespJsonFile;
|
||||
std::unordered_map<uint64_t, TensorPtr> mResponseTensors;
|
||||
bool mOutputHasInput;
|
||||
|
||||
}; // class Recorder
|
||||
|
||||
class ExecutorServer
|
||||
@ -430,7 +519,7 @@ public:
|
||||
maxBeamWidth, schedulerConfig, kvCacheConfig, benchmarkParams.enableChunkedContext, true);
|
||||
executorConfig.setPeftCacheConfig(peftCacheConfig);
|
||||
|
||||
mExecutor = std::make_shared<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
|
||||
mExecutor = std::make_unique<texec::Executor>(trtEnginePath, texec::ModelType::kDECODER_ONLY, executorConfig);
|
||||
|
||||
if (logIterationData)
|
||||
{
|
||||
@ -519,7 +608,7 @@ public:
|
||||
}
|
||||
|
||||
private:
|
||||
std::shared_ptr<texec::Executor> mExecutor;
|
||||
std::unique_ptr<texec::Executor> mExecutor;
|
||||
std::thread mCollectStatsThread;
|
||||
std::shared_ptr<Recorder> mRecorder;
|
||||
std::chrono::milliseconds mWaitSleep;
|
||||
@ -535,7 +624,7 @@ public:
|
||||
batch_scheduler::SchedulerPolicy schedulerPolicy, TrtGptModelOptionalParams const& optionalParams,
|
||||
std::shared_ptr<Recorder> recorder, std::optional<uint64_t> terminateReqId, std::chrono::milliseconds waitSleep,
|
||||
std::optional<SizeType> const staticEmulatedBatchSize,
|
||||
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData)
|
||||
std::optional<std::chrono::milliseconds> const batchTimeout, bool logIterationData, bool excludeInputInOutput)
|
||||
: mRecorder(std::move(recorder))
|
||||
, mTerminateReqId(terminateReqId)
|
||||
, mWaitSleep(waitSleep)
|
||||
@ -564,7 +653,7 @@ public:
|
||||
[this](uint64_t requestId, std::list<NamedTensor> const& response_tensors, bool final_response,
|
||||
std::string const& errMsg)
|
||||
{ return sendResponse(requestId, response_tensors, final_response, errMsg); },
|
||||
nullptr, iterationDataCallback, optionalParams, terminateReqId);
|
||||
nullptr, iterationDataCallback, optionalParams, terminateReqId, std::nullopt, excludeInputInOutput);
|
||||
}
|
||||
|
||||
~GptServer()
|
||||
@ -729,7 +818,7 @@ public:
|
||||
if (final_response)
|
||||
{
|
||||
mWorkItemsQueue.markFinished(requestId);
|
||||
mRecorder->recordEnd(requestId);
|
||||
mRecorder->recordEnd(requestId, response_tensors);
|
||||
mActiveCount--;
|
||||
}
|
||||
}
|
||||
@ -852,7 +941,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
BenchmarkParams const& benchmarkParams, batch_scheduler::SchedulerPolicy schedulerPolicy,
|
||||
std::chrono::milliseconds waitSleep, bool returnContextLogits, bool returnGenerationLogits,
|
||||
std::optional<SizeType> const staticEmulatedBatchSize, std::optional<std::chrono::milliseconds> const batchTimeout,
|
||||
bool logIterationData)
|
||||
bool logIterationData, bool excludeInputInOutput, std::string const& responsesJsonFile)
|
||||
{
|
||||
auto const worldConfig = WorldConfig::mpi();
|
||||
|
||||
@ -885,10 +974,11 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
auto const numSamples = samples.size();
|
||||
|
||||
int const maxBeamWidth = beamWidth;
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile);
|
||||
auto recorder = std::make_shared<Recorder>(opCsvFile, responsesJsonFile, excludeInputInOutput);
|
||||
uint64_t terminateReqId = numSamples + 1;
|
||||
auto gptServer = std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams,
|
||||
recorder, terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData);
|
||||
auto gptServer
|
||||
= std::make_shared<GptServer>(engineDir, modelType, maxBeamWidth, schedulerPolicy, optionalParams, recorder,
|
||||
terminateReqId, waitSleep, staticEmulatedBatchSize, batchTimeout, logIterationData, excludeInputInOutput);
|
||||
|
||||
ITensor::SharedPtr eosIdTensor{
|
||||
eosId ? bufferManager.copyFrom(&eosId.value(), ITensor::makeShape({1}), MemoryType::kPINNED) : nullptr};
|
||||
@ -962,6 +1052,7 @@ void benchmarkGptManager(std::filesystem::path const& engineDir, TrtGptModelType
|
||||
recorder->calculateMetrics();
|
||||
recorder->report();
|
||||
recorder->writeOpMetricsToCsv();
|
||||
recorder->dumpResponseSeqs();
|
||||
// Send terminateReqId to terminate servers on all ranks
|
||||
// Server on rank 0 will broadcast the terminate signal to other servers on multi-GPU cases
|
||||
gptServer->enqueue(std::make_shared<InferenceRequest>(terminateReqId));
|
||||
@ -1154,6 +1245,14 @@ int main(int argc, char* argv[])
|
||||
options.add_options()("lora_host_cache_bytes", "LoRA host cache memory in bytes", cxxopts::value<size_t>());
|
||||
options.add_options()("lora_num_device_mod_layers", "LoRA number 1d cache rows", cxxopts::value<int>());
|
||||
|
||||
options.add_options()("exclude_input_in_output_seq",
|
||||
"When enabled, GptManager will exclude the input sequence from output. (Only works if --api is gptManager)",
|
||||
cxxopts::value<bool>());
|
||||
|
||||
options.add_options()("responses_json_file",
|
||||
"When specified, dumps the responses to JSON file. (only works if --api is gptManager)",
|
||||
cxxopts::value<std::string>()->default_value(""));
|
||||
|
||||
auto result = options.parse(argc, argv);
|
||||
|
||||
if (result.count("help"))
|
||||
@ -1329,7 +1428,8 @@ int main(int argc, char* argv[])
|
||||
benchmarkGptManager(result["engine_dir"].as<std::string>(), modelType, datasetPath, opCsvFile,
|
||||
maxNumSamples, beamWidth, result["warm_up"].as<int>(), eosId, padId, benchmarkParams, schedulerPolicy,
|
||||
waitSleep, returnContextLogits, returnGenerationLogits, staticEmulatedBatchSize, batchTimeout,
|
||||
logIterationData);
|
||||
logIterationData, result["exclude_input_in_output_seq"].as<bool>(),
|
||||
result["responses_json_file"].as<std::string>());
|
||||
}
|
||||
catch (std::exception const& e)
|
||||
{
|
||||
|
||||
@ -29,7 +29,9 @@ from tensorrt_llm.functional import AllReduceStrategy, allreduce
|
||||
from tensorrt_llm.plugin.plugin import current_all_reduce_helper
|
||||
|
||||
|
||||
def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
|
||||
def allreduce_benchmark(dtype: str,
|
||||
test_range: str = "10,10000000,10",
|
||||
no_header: bool = False):
|
||||
tllm.logger.set_level('error')
|
||||
world_size = tllm.mpi_world_size()
|
||||
rank = tllm.mpi_rank()
|
||||
@ -48,11 +50,15 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
|
||||
|
||||
size = min_size
|
||||
dtype_size = torch.finfo(torch_dtype).bits // 8
|
||||
if mapping.rank == 0 and not no_header:
|
||||
print(
|
||||
f"{'world_size':<15}, {'dtype':<10}, {'message size':<15}, {'strategy':<15}, {'duration (ms)':<10}"
|
||||
)
|
||||
while size < max_size:
|
||||
input = torch.ones(size, dtype=torch_dtype, device="cuda")
|
||||
|
||||
for strategy in [
|
||||
AllReduceStrategy.RING, AllReduceStrategy.ONESHOT,
|
||||
AllReduceStrategy.NCCL, AllReduceStrategy.ONESHOT,
|
||||
AllReduceStrategy.TWOSHOT
|
||||
]:
|
||||
builder = tllm.Builder()
|
||||
@ -95,33 +101,40 @@ def allreduce_benchmark(dtype: str, test_range: str = "10,10000000,10"):
|
||||
session = tllm.runtime.Session.from_engine(build_engine())
|
||||
_, start = cuda.cuEventCreate(0)
|
||||
_, stop = cuda.cuEventCreate(0)
|
||||
runtimes = []
|
||||
with peer_access(mapping):
|
||||
MPI.COMM_WORLD.barrier()
|
||||
|
||||
cuda.cuEventRecord(start, stream.cuda_stream)
|
||||
session.run(inputs=feed_dict,
|
||||
outputs={"output": output},
|
||||
stream=stream.cuda_stream)
|
||||
cuda.cuEventRecord(stop, stream.cuda_stream)
|
||||
torch.cuda.synchronize()
|
||||
_, ms = cuda.cuEventElapsedTime(start, stop)
|
||||
for _ in range(10):
|
||||
cuda.cuEventRecord(start, stream.cuda_stream)
|
||||
session.run(inputs=feed_dict,
|
||||
outputs={"output": output},
|
||||
stream=stream.cuda_stream)
|
||||
cuda.cuEventRecord(stop, stream.cuda_stream)
|
||||
torch.cuda.synchronize()
|
||||
_, ms = cuda.cuEventElapsedTime(start, stop)
|
||||
runtimes.append(ms)
|
||||
|
||||
median_ms = sorted(runtimes)[len(runtimes) // 2]
|
||||
assert torch.allclose(output, (input * world_size)**inner_loop)
|
||||
|
||||
if mapping.rank == 0:
|
||||
print(f"{size=}, {strategy=}, {ms=}")
|
||||
print(
|
||||
f"{mapping.world_size:<15}, {dtype:<10}, {size:<15}, {strategy.name:<15}, {median_ms:<10.2f}"
|
||||
)
|
||||
|
||||
size *= ratio
|
||||
if mapping.rank == 0:
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--dtype", "-t", default="float16")
|
||||
parser.add_argument("--range",
|
||||
"-r",
|
||||
default="256,25600000,10",
|
||||
help="min_size,max_size,multiplicative_ratio")
|
||||
parser.add_argument(
|
||||
"--range",
|
||||
"-r",
|
||||
default="256,256000000,10", # 256 to 256M
|
||||
help="min_size,max_size,multiplicative_ratio")
|
||||
parser.add_argument("--no-header", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
allreduce_benchmark(args.dtype, args.range)
|
||||
allreduce_benchmark(args.dtype, args.range, args.no_header)
|
||||
|
||||
@ -293,6 +293,42 @@ _allowed_configs = {
|
||||
pre_norm=True,
|
||||
do_layer_norm_before=True,
|
||||
)),
|
||||
"starcoder":
|
||||
ModelConfig(name="starcoder_15.5b",
|
||||
family="gpt",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=40,
|
||||
num_heads=48,
|
||||
num_kv_heads=1,
|
||||
hidden_size=6144,
|
||||
vocab_size=49152,
|
||||
hidden_act='gelu',
|
||||
n_positions=8192,
|
||||
max_batch_size=256,
|
||||
max_input_len=512,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"starcoder2_3b":
|
||||
ModelConfig(name="starcoder2_3b",
|
||||
family="gpt",
|
||||
benchmark_type="gpt",
|
||||
build_config=BuildConfig(
|
||||
num_layers=30,
|
||||
num_heads=24,
|
||||
num_kv_heads=2,
|
||||
hidden_size=3072,
|
||||
vocab_size=49152,
|
||||
hidden_act='gelu',
|
||||
n_positions=16384,
|
||||
position_embedding_type='rope_gpt_neox',
|
||||
rotary_pct=1.0,
|
||||
max_batch_size=256,
|
||||
max_input_len=512,
|
||||
max_output_len=200,
|
||||
builder_opt=None,
|
||||
)),
|
||||
"llama_7b":
|
||||
ModelConfig(name="llama_7b",
|
||||
family="llama",
|
||||
|
||||
@ -37,7 +37,7 @@ from tensorrt_llm.models import PretrainedConfig, quantize_model
|
||||
from tensorrt_llm.models.modeling_utils import optimize_model
|
||||
from tensorrt_llm.network import net_guard
|
||||
from tensorrt_llm.plugin.plugin import ContextFMHAType
|
||||
from tensorrt_llm.quantization import QuantMode
|
||||
from tensorrt_llm.quantization import QuantAlgo, QuantMode
|
||||
|
||||
|
||||
def parse_arguments():
|
||||
@ -224,23 +224,23 @@ def get_quant_mode(quantization):
|
||||
|
||||
def get_quant_algo(quantization):
|
||||
if quantization == "fp8":
|
||||
return "FP8", "FP8"
|
||||
return QuantAlgo.FP8, QuantAlgo.FP8
|
||||
elif quantization == "fp8_gemm":
|
||||
return "FP8", None
|
||||
return QuantAlgo.FP8, None
|
||||
elif quantization == "fp8_kv_cache":
|
||||
return None, "FP8"
|
||||
return None, QuantAlgo.FP8
|
||||
elif quantization == "int8_sq_per_tensor":
|
||||
return "W8A8_SQ_PER_TENSOR_PLUGIN", None
|
||||
return QuantAlgo.W8A8_SQ_PER_TENSOR_PLUGIN, None
|
||||
elif quantization == "int8_sq_per_token_channel":
|
||||
return "W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN", None
|
||||
return QuantAlgo.W8A8_SQ_PER_CHANNEL_PER_TOKEN_PLUGIN, None
|
||||
elif quantization == "int8_weight_only":
|
||||
return "W8A16", None
|
||||
return QuantAlgo.W8A16, None
|
||||
elif quantization == "int4_weight_only":
|
||||
return "W4A16", None
|
||||
return QuantAlgo.W4A16, None
|
||||
elif quantization == "int4_weight_only_awq":
|
||||
return "W4A16_AWQ", None
|
||||
return QuantAlgo.W4A16_AWQ, None
|
||||
elif quantization == "int4_weight_only_gptq":
|
||||
return "W4A16_GPTQ", None
|
||||
return QuantAlgo.W4A16_GPTQ, None
|
||||
elif quantization is None:
|
||||
return None, None
|
||||
|
||||
@ -764,13 +764,11 @@ def build_gpt(args):
|
||||
else:
|
||||
raise Exception(f'Unexpected model: {args.model}')
|
||||
|
||||
quant_kwargs = {}
|
||||
if family not in [
|
||||
'gpt', 'opt', 'bloom', 'falcon', 'llama', 'internlm', 'gptneox',
|
||||
'gptj', 'mamba', 'baichuan', 'chatglm', 'chatglm2', 'chatglm3'
|
||||
]:
|
||||
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode,
|
||||
**quant_kwargs)
|
||||
tensorrt_llm_model = quantize_model(tensorrt_llm_model, quant_mode)
|
||||
|
||||
if family in ['llama']:
|
||||
tensorrt_llm_model = optimize_model(tensorrt_llm_model,
|
||||
@ -788,6 +786,7 @@ def build_gpt(args):
|
||||
network.plugin_config.enable_remove_input_padding()
|
||||
network.plugin_config.set_lookup_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_moe_plugin(dtype=args.dtype)
|
||||
network.plugin_config.set_mamba_conv1d_plugin(dtype=args.dtype)
|
||||
|
||||
if args.quantization is None or "fp8" not in args.quantization:
|
||||
network.plugin_config.set_gemm_plugin(dtype=args.dtype)
|
||||
|
||||
@ -96,6 +96,7 @@ class GPTBenchmark(BaseBenchmark):
|
||||
self.use_gpt_attention_plugin = True
|
||||
self.remove_input_padding = True
|
||||
self.use_moe_plugin = True
|
||||
self.use_mamba_conv1d_plugin = True
|
||||
elif args.mode == 'ootb-except-mha':
|
||||
self.use_gpt_attention_plugin = True
|
||||
|
||||
@ -121,6 +122,8 @@ class GPTBenchmark(BaseBenchmark):
|
||||
gpt_attention_plugin=self.use_gpt_attention_plugin,
|
||||
paged_kv_cache=self.paged_kv_cache if hasattr(
|
||||
self, 'paged_kv_cache') else False,
|
||||
paged_state=self.paged_state
|
||||
if hasattr(self, 'paged_state') else False,
|
||||
dtype=self.dtype,
|
||||
remove_input_padding=self.remove_input_padding,
|
||||
quant_mode=self.quant_mode,
|
||||
@ -148,8 +151,6 @@ class GPTBenchmark(BaseBenchmark):
|
||||
model_config.mamba_d_state = self.mamba_d_state
|
||||
model_config.mamba_d_conv = self.mamba_d_conv
|
||||
model_config.mamba_expand = self.mamba_expand
|
||||
self.remove_input_padding = False
|
||||
model_config.remove_input_padding = False
|
||||
self.sampling_config = tensorrt_llm.runtime.SamplingConfig(
|
||||
end_id=0, pad_id=0, top_k=args.top_k, top_p=args.top_p)
|
||||
self.decoder = tensorrt_llm.runtime.MambaLMHeadModelGenerationSession(
|
||||
|
||||
@ -25,6 +25,7 @@
|
||||
namespace tensorrt_llm::batch_manager::kv_cache_manager
|
||||
{
|
||||
|
||||
//! @brief Encapsulates parameters to configure paged KV cache.
|
||||
class KvCacheConfig
|
||||
{
|
||||
public:
|
||||
|
||||
@ -53,7 +53,7 @@ public:
|
||||
using VecLogProbs = std::vector<float>;
|
||||
using BeamTokens = std::vector<VecTokens>;
|
||||
using TensorPtr = TTensor;
|
||||
using LogitsPostProcessor = std::function<TensorPtr(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
|
||||
using LogitsPostProcessor = std::function<void(RequestIdType, TensorPtr&, BeamTokens const&, TStream)>;
|
||||
|
||||
GenericLlmRequest(RequestIdType requestId, SizeType maxNewTokens, std::shared_ptr<VecTokens> inputTokens,
|
||||
runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType> endId = std::nullopt,
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/executor/executor.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/decodingMode.h"
|
||||
#include "tensorrt_llm/runtime/medusaModule.h"
|
||||
|
||||
#include <optional>
|
||||
#include <vector>
|
||||
@ -40,7 +41,8 @@ public:
|
||||
bool enableTrtOverlap = false, std::optional<std::vector<SizeType>> const& deviceIds = std::nullopt,
|
||||
bool normalizeLogProbs = true, bool enableChunkedContext = false,
|
||||
std::optional<runtime::DecodingMode> const& decodingMode = std::nullopt,
|
||||
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{})
|
||||
PeftCacheManagerConfig const& peftCacheManagerConfig = PeftCacheManagerConfig{},
|
||||
std::optional<runtime::MedusaModule::MedusaChoices> const& medusaChoices = std::nullopt)
|
||||
: kvCacheConfig{kvCacheConfig}
|
||||
, enableTrtOverlap{enableTrtOverlap}
|
||||
, deviceIds(deviceIds)
|
||||
@ -48,6 +50,7 @@ public:
|
||||
, enableChunkedContext{enableChunkedContext}
|
||||
, decodingMode{decodingMode}
|
||||
, peftCacheManagerConfig(peftCacheManagerConfig)
|
||||
, medusaChoices(medusaChoices)
|
||||
{
|
||||
}
|
||||
|
||||
@ -55,7 +58,8 @@ public:
|
||||
: TrtGptModelOptionalParams(KvCacheConfig(executorConfig.getKvCacheConfig()), false,
|
||||
executorConfig.getParallelConfig().value_or(executor::ParallelConfig()).getDeviceIds(),
|
||||
executorConfig.getNormalizeLogProbs(), executorConfig.getEnableChunkedContext(), std::nullopt,
|
||||
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig()))
|
||||
PeftCacheManagerConfig(executorConfig.getPeftCacheConfig().value_or(executor::PeftCacheConfig())),
|
||||
executorConfig.getMedusaChoices())
|
||||
{
|
||||
}
|
||||
|
||||
@ -74,6 +78,7 @@ public:
|
||||
bool enableChunkedContext;
|
||||
std::optional<runtime::DecodingMode> decodingMode;
|
||||
PeftCacheManagerConfig peftCacheManagerConfig;
|
||||
std::optional<runtime::MedusaModule::MedusaChoices> medusaChoices;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::batch_manager
|
||||
|
||||
@ -64,15 +64,17 @@ std::string fmtstr(char const* format, ...) __attribute__((format(printf, 1, 2))
|
||||
#define __PRETTY_FUNCTION__ __FUNCSIG__
|
||||
#endif
|
||||
|
||||
auto constexpr kDefaultDelimiter = ", ";
|
||||
|
||||
template <typename U, typename TStream, typename T>
|
||||
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size)
|
||||
inline TStream& arr2outCasted(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
out << "(";
|
||||
if (size > 0)
|
||||
{
|
||||
for (size_t i = 0; i < size - 1; ++i)
|
||||
{
|
||||
out << static_cast<U>(arr[i]) << ", ";
|
||||
out << static_cast<U>(arr[i]) << delim;
|
||||
}
|
||||
out << static_cast<U>(arr[size - 1]);
|
||||
}
|
||||
@ -81,22 +83,22 @@ inline TStream& arr2outCasted(TStream& out, T* arr, size_t size)
|
||||
}
|
||||
|
||||
template <typename TStream, typename T>
|
||||
inline TStream& arr2out(TStream& out, T* arr, size_t size)
|
||||
inline TStream& arr2out(TStream& out, T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
return arr2outCasted<T>(out, arr, size);
|
||||
return arr2outCasted<T>(out, arr, size, delim);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string arr2str(T* arr, size_t size)
|
||||
inline std::string arr2str(T* arr, size_t size, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
std::stringstream ss;
|
||||
return arr2out(ss, arr, size).str();
|
||||
return arr2out(ss, arr, size, delim).str();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
inline std::string vec2str(std::vector<T> vec)
|
||||
inline std::string vec2str(std::vector<T> vec, char const* delim = kDefaultDelimiter)
|
||||
{
|
||||
return arr2str(vec.data(), vec.size());
|
||||
return arr2str(vec.data(), vec.size(), delim);
|
||||
}
|
||||
|
||||
inline bool strStartsWith(std::string const& str, std::string const& prefix)
|
||||
|
||||
@ -27,8 +27,6 @@
|
||||
#include <memory>
|
||||
#include <optional>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <variant>
|
||||
#include <vector>
|
||||
|
||||
namespace tensorrt_llm::executor
|
||||
@ -43,18 +41,19 @@ class SamplingConfig
|
||||
public:
|
||||
/// @brief Constructor for SamplingConfig
|
||||
/// See description of parameters below
|
||||
SamplingConfig(SizeType beamWidth = 1, std::optional<SizeType> topK = std::nullopt,
|
||||
std::optional<FloatType> topP = std::nullopt, std::optional<FloatType> topPMin = std::nullopt,
|
||||
std::optional<SizeType> topPResetIds = std::nullopt, std::optional<FloatType> topPDecay = std::nullopt,
|
||||
std::optional<RandomSeedType> randomSeed = std::nullopt, std::optional<FloatType> temperature = std::nullopt,
|
||||
std::optional<SizeType> minLength = std::nullopt,
|
||||
std::optional<FloatType> beamSearchDiversityRate = std::nullopt,
|
||||
std::optional<FloatType> repetitionPenalty = std::nullopt,
|
||||
std::optional<FloatType> presencePenalty = std::nullopt,
|
||||
std::optional<FloatType> frequencyPenalty = std::nullopt, std::optional<FloatType> lengthPenalty = std::nullopt,
|
||||
std::optional<SizeType> earlyStopping = std::nullopt);
|
||||
|
||||
~SamplingConfig();
|
||||
explicit SamplingConfig(SizeType beamWidth = 1, std::optional<SizeType> const& topK = std::nullopt,
|
||||
std::optional<FloatType> const& topP = std::nullopt, std::optional<FloatType> const& topPMin = std::nullopt,
|
||||
std::optional<SizeType> const& topPResetIds = std::nullopt,
|
||||
std::optional<FloatType> const& topPDecay = std::nullopt,
|
||||
std::optional<RandomSeedType> const& randomSeed = std::nullopt,
|
||||
std::optional<FloatType> const& temperature = std::nullopt,
|
||||
std::optional<SizeType> const& minLength = std::nullopt,
|
||||
std::optional<FloatType> const& beamSearchDiversityRate = std::nullopt,
|
||||
std::optional<FloatType> const& repetitionPenalty = std::nullopt,
|
||||
std::optional<FloatType> const& presencePenalty = std::nullopt,
|
||||
std::optional<FloatType> const& frequencyPenalty = std::nullopt,
|
||||
std::optional<FloatType> const& lengthPenalty = std::nullopt,
|
||||
std::optional<SizeType> const& earlyStopping = std::nullopt);
|
||||
|
||||
bool operator==(SamplingConfig const& other) const;
|
||||
|
||||
@ -91,23 +90,24 @@ private:
|
||||
std::optional<FloatType> mTopPDecay;
|
||||
/// @brief Controls the random seed used by the random number generator in sampling
|
||||
std::optional<RandomSeedType> mRandomSeed;
|
||||
/// @brief Controls the modulation of logits when sampling new tokens. Default is 1.0f
|
||||
/// @brief Controls the modulation of logits when sampling new tokens. It can have values > 0.f. Default is 1.0f
|
||||
std::optional<FloatType> mTemperature;
|
||||
/// @brief Lower bound on the number of tokens to generate
|
||||
/// @brief Lower bound on the number of tokens to generate. Values < 1 have no effect. Default is 1.
|
||||
std::optional<SizeType> mMinLength;
|
||||
/// @brief Controls the diversity in beam search.
|
||||
std::optional<FloatType> mBeamSearchDiversityRate;
|
||||
/// @brief Used to penalize tokens based on how often they appear in the sequence. Default is 0.f
|
||||
/// @brief Used to penalize tokens based on how often they appear in the sequence. It can have any value > 0.f.
|
||||
/// Values < 1.f encourages repetition, values > 1.f discourages it. Default is 1.f
|
||||
std::optional<FloatType> mRepetitionPenalty;
|
||||
/// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances).
|
||||
/// Default is 0.f
|
||||
/// @brief Used to penalize tokens already present in the sequence (irrespective of the number of appearances). It
|
||||
/// can have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
|
||||
std::optional<FloatType> mPresencePenalty;
|
||||
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). Default
|
||||
/// is 0.f
|
||||
/// @brief Used to penalize tokens already present in the sequence (dependent on the number of appearances). It can
|
||||
/// have any values. Values < 0.f encourage repetition, values > 0.f discourage it. Default is 0.f
|
||||
std::optional<FloatType> mFrequencyPenalty;
|
||||
/// @brief Controls how to penalize longer sequences in beam search. Default is 0.f
|
||||
std::optional<FloatType> mLengthPenalty;
|
||||
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (end with
|
||||
/// @brief Controls whether the generation process finishes once beamWidth sentences are generated (ends with
|
||||
/// end_token)
|
||||
std::optional<SizeType> mEarlyStopping;
|
||||
};
|
||||
@ -116,12 +116,12 @@ private:
|
||||
class OutputConfig
|
||||
{
|
||||
public:
|
||||
OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false, bool returnGenerationLogits = false,
|
||||
bool excludeInputFromOutput = false);
|
||||
explicit OutputConfig(bool returnLogProbs = false, bool returnContextLogits = false,
|
||||
bool returnGenerationLogits = false, bool excludeInputFromOutput = false);
|
||||
|
||||
/// @brief Controls if Result should contain log probabilities. Default is false
|
||||
/// @brief Controls if Result should contain log probabilities. Default is false.
|
||||
bool returnLogProbs;
|
||||
/// @brief Controls if Result should contain the context logits. Default is false
|
||||
/// @brief Controls if Result should contain the context logits. Default is false.
|
||||
bool returnContextLogits;
|
||||
/// @brief Controls if Result should contain the generation logits. Default is false.
|
||||
bool returnGenerationLogits;
|
||||
@ -135,9 +135,7 @@ class SpeculativeDecodingConfig
|
||||
{
|
||||
public:
|
||||
explicit SpeculativeDecodingConfig(VecTokens tokens, std::optional<Tensor> logits = std::nullopt,
|
||||
std::optional<FloatType> acceptanceThreshold = std::nullopt);
|
||||
|
||||
~SpeculativeDecodingConfig();
|
||||
std::optional<FloatType> const& acceptanceThreshold = std::nullopt);
|
||||
|
||||
[[nodiscard]] VecTokens getTokens() const;
|
||||
[[nodiscard]] std::optional<Tensor> getLogits() const;
|
||||
@ -147,9 +145,9 @@ private:
|
||||
friend class Serialization;
|
||||
/// @brief The draft tokens
|
||||
VecTokens mTokens;
|
||||
/// @brief The draft logits
|
||||
/// @brief The draft logits. Expected shape: [num_draft_tokens, vocab_size].
|
||||
std::optional<Tensor> mLogits;
|
||||
/// @brief The acceptance threshold
|
||||
/// @brief The acceptance threshold. Must be > 0.f and <= 1.f
|
||||
std::optional<FloatType> mAcceptanceThreshold;
|
||||
};
|
||||
|
||||
@ -157,14 +155,14 @@ private:
|
||||
class PromptTuningConfig
|
||||
{
|
||||
public:
|
||||
PromptTuningConfig(Tensor embeddingTable);
|
||||
~PromptTuningConfig();
|
||||
explicit PromptTuningConfig(Tensor embeddingTable);
|
||||
|
||||
[[nodiscard]] Tensor getEmbeddingTable() const;
|
||||
|
||||
private:
|
||||
friend class Serialization;
|
||||
/// @brief The prompt embedding table
|
||||
/// @brief The prompt embedding table. Expected shape: [task vocab_size, hidden_size]. Data type must match model
|
||||
/// weights.
|
||||
Tensor mEmbeddingTable;
|
||||
};
|
||||
|
||||
@ -172,9 +170,8 @@ private:
|
||||
class LoraConfig
|
||||
{
|
||||
public:
|
||||
LoraConfig(
|
||||
explicit LoraConfig(
|
||||
IdType taskId, std::optional<Tensor> weights = std::nullopt, std::optional<Tensor> config = std::nullopt);
|
||||
~LoraConfig();
|
||||
|
||||
[[nodiscard]] IdType getTaskId() const;
|
||||
[[nodiscard]] std::optional<Tensor> getWeights() const;
|
||||
@ -185,9 +182,9 @@ private:
|
||||
|
||||
/// @brief The Lora task id
|
||||
IdType mTaskId;
|
||||
/// @brief The Lora weights
|
||||
/// @brief The Lora weights. See TRT-LLM documentation for expected shapes and types
|
||||
std::optional<Tensor> mWeights;
|
||||
/// @brief The Lora configuration
|
||||
/// @brief The Lora configuration. See TRT-LLM documentation for detailed description of the config tensor
|
||||
std::optional<Tensor> mConfig;
|
||||
};
|
||||
|
||||
@ -199,7 +196,7 @@ public:
|
||||
|
||||
/// @param inputTokenIds The input token ids
|
||||
/// @param maxNewTokens The maximum number of tokens to generate
|
||||
/// @param streaming Indicates if the responses should be streamed or not
|
||||
/// @param streaming Indicates if the responses should be streamed or not. Default is false.
|
||||
/// @param samplingConfig The sampling configuration
|
||||
/// @param outputConfig The output configuration
|
||||
/// @param endId The end token id
|
||||
@ -213,8 +210,8 @@ public:
|
||||
/// @param logitsPostProcessorName The logits postprocessor name. Must correspond to one of the logits postprocessor
|
||||
/// name provided to the ExecutorConfig.
|
||||
Request(VecTokens inputTokenIds, SizeType maxNewTokens, bool streaming = false,
|
||||
SamplingConfig samplingConfig = SamplingConfig(), OutputConfig outputConfig = OutputConfig(),
|
||||
std::optional<SizeType> endId = std::nullopt, std::optional<SizeType> padId = std::nullopt,
|
||||
SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
|
||||
std::optional<SizeType> const& endId = std::nullopt, std::optional<SizeType> const& padId = std::nullopt,
|
||||
std::optional<std::list<VecTokens>> badWords = std::nullopt,
|
||||
std::optional<std::list<VecTokens>> stopWords = std::nullopt,
|
||||
std::optional<Tensor> embeddingBias = std::nullopt,
|
||||
@ -245,16 +242,16 @@ public:
|
||||
[[nodiscard]] std::optional<std::string> getLogitsPostProcessorName() const;
|
||||
|
||||
void setStreaming(bool streaming);
|
||||
void setSamplingConfig(SamplingConfig config);
|
||||
void setOutputConfig(OutputConfig outputConfig);
|
||||
void setSamplingConfig(SamplingConfig const& config);
|
||||
void setOutputConfig(OutputConfig const& outputConfig);
|
||||
void setEndId(SizeType endId);
|
||||
void setPadId(SizeType padId);
|
||||
void setBadWords(std::list<VecTokens> badWords);
|
||||
void setStopWords(std::list<VecTokens> stopWords);
|
||||
void setEmbeddingBias(Tensor);
|
||||
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig specDecodingConfig);
|
||||
void setPromptTuningConfig(PromptTuningConfig pTuningConfig);
|
||||
void setLoraConfig(LoraConfig loraConfig);
|
||||
void setBadWords(std::list<VecTokens> const& badWords);
|
||||
void setStopWords(std::list<VecTokens> const& stopWords);
|
||||
void setEmbeddingBias(Tensor const& embeddingBias);
|
||||
void setSpeculativeDecodingConfig(SpeculativeDecodingConfig const& specDecodingConfig);
|
||||
void setPromptTuningConfig(PromptTuningConfig const& pTuningConfig);
|
||||
void setLoraConfig(LoraConfig const& loraConfig);
|
||||
void setLogitsPostProcessorName(std::string const& logitsPostProcessorName);
|
||||
|
||||
private:
|
||||
@ -275,7 +272,7 @@ struct Result
|
||||
/// @brief The cumulative log probabilities. Size beamSize.
|
||||
std::optional<VecLogProbs> cumLogProbs;
|
||||
|
||||
/// @brief The log probabilities for each generated token. Size [beamSize, seqLen]
|
||||
/// @brief The log probabilities for each generated token. Size [beamSize, outputLen]
|
||||
std::optional<std::vector<VecLogProbs>> logProbs;
|
||||
|
||||
/// @brief The context logits. Size [promptLen, vocabSizePadded]
|
||||
@ -299,18 +296,18 @@ public:
|
||||
Response& operator=(Response&& other) noexcept;
|
||||
|
||||
/// @brief Get the id of the request for which this response was generated
|
||||
IdType getRequestId() const;
|
||||
[[nodiscard]] IdType getRequestId() const;
|
||||
|
||||
/// @brief Indicates if this response has an error or not
|
||||
bool hasError() const;
|
||||
[[nodiscard]] bool hasError() const;
|
||||
|
||||
/// @brief Get the error msg for this response
|
||||
/// Will throw an exception if hasError is false
|
||||
std::string getErrorMsg() const;
|
||||
[[nodiscard]] std::string getErrorMsg() const;
|
||||
|
||||
/// @brief Get the result for this response
|
||||
/// Will throw an exception if hasResult is true
|
||||
Result getResult() const;
|
||||
[[nodiscard]] Result getResult() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
@ -322,7 +319,6 @@ class SchedulerConfig
|
||||
{
|
||||
public:
|
||||
explicit SchedulerConfig(SchedulerPolicy policy = SchedulerPolicy::kGUARANTEED_NO_EVICT);
|
||||
~SchedulerConfig();
|
||||
|
||||
[[nodiscard]] SchedulerPolicy getPolicy() const;
|
||||
|
||||
@ -335,10 +331,10 @@ private:
|
||||
class KvCacheConfig
|
||||
{
|
||||
public:
|
||||
KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> maxTokens = std::nullopt,
|
||||
std::optional<SizeType> maxAttentionWindow = std::nullopt,
|
||||
std::optional<SizeType> sinkTokenLength = std::nullopt,
|
||||
std::optional<FloatType> freeGpuMemoryFraction = std::nullopt);
|
||||
explicit KvCacheConfig(bool enableBlockReuse = false, std::optional<SizeType> const& maxTokens = std::nullopt,
|
||||
std::optional<SizeType> const& maxAttentionWindow = std::nullopt,
|
||||
std::optional<SizeType> const& sinkTokenLength = std::nullopt,
|
||||
std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt);
|
||||
|
||||
[[nodiscard]] bool getEnableBlockReuse() const;
|
||||
[[nodiscard]] std::optional<SizeType> getMaxTokens() const;
|
||||
@ -383,11 +379,10 @@ public:
|
||||
/// @param deviceIds The IDs of the GPUs involved in the execution of the model
|
||||
/// @param participantIds The participant IDs (MPI ranks if commType == kMPI) involved in the execution of the
|
||||
/// model. The first participant is considered to be the leader.
|
||||
ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
|
||||
explicit ParallelConfig(CommunicationType commType = CommunicationType::kMPI,
|
||||
CommunicationMode commMode = CommunicationMode::kLEADER,
|
||||
std::optional<std::vector<SizeType>> deviceIds = std::nullopt,
|
||||
std::optional<std::vector<SizeType>> participantIds = std::nullopt);
|
||||
~ParallelConfig();
|
||||
|
||||
[[nodiscard]] CommunicationType getCommunicationType() const;
|
||||
[[nodiscard]] CommunicationMode getCommunicationMode() const;
|
||||
@ -396,8 +391,8 @@ public:
|
||||
|
||||
void setCommunicationType(CommunicationType type);
|
||||
void setCommunicationMode(CommunicationMode mode);
|
||||
void setDeviceIds(std::vector<SizeType> deviceIds);
|
||||
void setParticipantIds(std::vector<SizeType> participantIds);
|
||||
void setDeviceIds(std::vector<SizeType> const& deviceIds);
|
||||
void setParticipantIds(std::vector<SizeType> const& participantIds);
|
||||
|
||||
private:
|
||||
/// @brief The type of communication protocol used. Default is MPI.
|
||||
@ -417,10 +412,11 @@ private:
|
||||
class PeftCacheConfig
|
||||
{
|
||||
public:
|
||||
PeftCacheConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0, SizeType optimalAdapterSize = 8,
|
||||
SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1, SizeType numEnsureWorkers = 1,
|
||||
SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24, SizeType maxPagesPerBlockDevice = 8,
|
||||
std::optional<float> deviceCachePercent = std::nullopt, std::optional<size_t> hostCacheSize = std::nullopt);
|
||||
explicit PeftCacheConfig(SizeType numHostModuleLayer = 0, SizeType numDeviceModuleLayer = 0,
|
||||
SizeType optimalAdapterSize = 8, SizeType maxAdapterSize = 64, SizeType numPutWorkers = 1,
|
||||
SizeType numEnsureWorkers = 1, SizeType numCopyStreams = 1, SizeType maxPagesPerBlockHost = 24,
|
||||
SizeType maxPagesPerBlockDevice = 8, std::optional<float> const& deviceCachePercent = std::nullopt,
|
||||
std::optional<size_t> const& hostCacheSize = std::nullopt);
|
||||
|
||||
[[nodiscard]] SizeType getNumHostModuleLayer() const;
|
||||
[[nodiscard]] SizeType getNumDeviceModuleLayer() const;
|
||||
@ -462,16 +458,16 @@ private:
|
||||
/// @brief Configuration class for the model executor
|
||||
class ExecutorConfig
|
||||
{
|
||||
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
|
||||
|
||||
public:
|
||||
ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig schedulerConfig = SchedulerConfig(),
|
||||
KvCacheConfig kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false, bool normalizeLogProbs = true,
|
||||
SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
|
||||
explicit ExecutorConfig(SizeType maxBeamWidth = 1, SchedulerConfig const& schedulerConfig = SchedulerConfig(),
|
||||
KvCacheConfig const& kvCacheConfig = KvCacheConfig(), bool enableChunkedContext = false,
|
||||
bool normalizeLogProbs = true, SizeType iterStatsMaxIterations = kDefaultIterStatsMaxIterations,
|
||||
SizeType requestStatsMaxIterations = kDefaultRequestStatsMaxIterations,
|
||||
BatchingType batchingType = BatchingType::kINFLIGHT,
|
||||
std::optional<ParallelConfig> parallelConfig = std::nullopt,
|
||||
PeftCacheConfig peftCacheConfig = PeftCacheConfig(), LogitsPostProcessorMap = {});
|
||||
std::optional<PeftCacheConfig> const& peftCacheConfig = std::nullopt,
|
||||
std::optional<LogitsPostProcessorMap> logitsPostProcessorMap = std::nullopt,
|
||||
std::optional<MedusaChoices> medusaChoices = std::nullopt);
|
||||
|
||||
[[nodiscard]] SizeType getMaxBeamWidth() const;
|
||||
[[nodiscard]] SchedulerConfig getSchedulerConfig() const;
|
||||
@ -482,20 +478,22 @@ public:
|
||||
[[nodiscard]] SizeType getRequestStatsMaxIterations() const;
|
||||
[[nodiscard]] BatchingType getBatchingType() const;
|
||||
[[nodiscard]] std::optional<ParallelConfig> getParallelConfig() const;
|
||||
[[nodiscard]] PeftCacheConfig getPeftCacheConfig() const;
|
||||
[[nodiscard]] LogitsPostProcessorMap getLogitsPostProcessorMap() const;
|
||||
[[nodiscard]] std::optional<PeftCacheConfig> getPeftCacheConfig() const;
|
||||
[[nodiscard]] std::optional<LogitsPostProcessorMap> getLogitsPostProcessorMap() const;
|
||||
[[nodiscard]] std::optional<MedusaChoices> getMedusaChoices() const;
|
||||
|
||||
void setMaxBeamWidth(SizeType maxBeamWidth);
|
||||
void setSchedulerConfig(SchedulerConfig schedulerConfig);
|
||||
void setKvCacheConfig(KvCacheConfig kvCacheConfig);
|
||||
void setSchedulerConfig(SchedulerConfig const& schedulerConfig);
|
||||
void setKvCacheConfig(KvCacheConfig const& kvCacheConfig);
|
||||
void setEnableChunkedContext(bool enableChunkedContext);
|
||||
void setNormalizeLogProbs(bool normalizeLogProbs);
|
||||
void setIterStatsMaxIterations(SizeType iterStatsMaxIterations);
|
||||
void setRequestStatsMaxIterations(SizeType requestStatsMaxIterations);
|
||||
void setBatchingType(BatchingType batchingType);
|
||||
void setParallelConfig(ParallelConfig parallelConfig);
|
||||
void setPeftCacheConfig(PeftCacheConfig peftCacheConfig);
|
||||
void setLogitsPostProcessorMap(LogitsPostProcessorMap logitsPostProcessorMap);
|
||||
void setParallelConfig(ParallelConfig const& parallelConfig);
|
||||
void setPeftCacheConfig(PeftCacheConfig const& peftCacheConfig);
|
||||
void setLogitsPostProcessorMap(LogitsPostProcessorMap const& logitsPostProcessorMap);
|
||||
void setMedusaChoices(MedusaChoices const& medusaChoices);
|
||||
|
||||
private:
|
||||
/// @brief The beam width value of requests that will be sent to the executor
|
||||
@ -524,14 +522,14 @@ private:
|
||||
|
||||
/// @brief The parallel execution configuration.
|
||||
std::optional<ParallelConfig> mParallelConfig;
|
||||
PeftCacheConfig mPeftCacheConfig;
|
||||
LogitsPostProcessorMap mLogitsPostProcessorMap;
|
||||
std::optional<PeftCacheConfig> mPeftCacheConfig;
|
||||
std::optional<LogitsPostProcessorMap> mLogitsPostProcessorMap;
|
||||
std::optional<MedusaChoices> mMedusaChoices;
|
||||
};
|
||||
|
||||
/// @brief The executor is responsible for receiving new requests and sending responses, and running the inference
|
||||
class Executor
|
||||
{
|
||||
using RequestPtr = std::shared_ptr<Request>;
|
||||
|
||||
public:
|
||||
/// @brief
|
||||
@ -539,38 +537,38 @@ public:
|
||||
/// @param modelType The type of model
|
||||
/// @param executorConfig The configuration for the executor
|
||||
/// @param comm An optional inter-process communicator configuration
|
||||
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig executorConfig);
|
||||
Executor(std::filesystem::path const& modelPath, ModelType modelType, ExecutorConfig const& executorConfig);
|
||||
|
||||
Executor(std::vector<uint8_t> const& engineBuffer, std::string const& jsonConfigStr, ModelType modelType,
|
||||
ExecutorConfig executorConfig);
|
||||
ExecutorConfig const& executorConfig);
|
||||
|
||||
Executor(std::shared_ptr<Model> model, ExecutorConfig executorConfig);
|
||||
Executor(std::shared_ptr<Model> model, ExecutorConfig const& executorConfig);
|
||||
|
||||
~Executor();
|
||||
|
||||
/// @brief Enqueue a new request
/// @param request The LLM request which contains input tokens and request parameters
/// @return A unique id that identifies the request
IdType enqueueRequest(Request request);
[[nodiscard]] IdType enqueueRequest(Request const& request);

/// @brief Enqueue a batch of requests
std::vector<IdType> enqueueRequests(std::vector<Request> requests);
[[nodiscard]] std::vector<IdType> enqueueRequests(std::vector<Request> const& requests);

/// @brief Await for ready responses
/// @param id An optional request id. If not specified, responses for any request can be returned
/// @param timeout The maximum time to wait for new responses
/// @return A vector of responses
std::vector<Response> awaitResponses(
std::optional<IdType> id = std::nullopt, std::optional<std::chrono::milliseconds> timeout = std::nullopt);
[[nodiscard]] std::vector<Response> awaitResponses(std::optional<IdType> const& requestId = std::nullopt,
std::optional<std::chrono::milliseconds> const& timeout = std::nullopt);

/// @brief Get the number of ready responses
/// @param id The request id
/// @param requestId An optional request id
/// @return The number of ready responses
SizeType getNumResponsesReady(std::optional<IdType> id = std::nullopt);
[[nodiscard]] SizeType getNumResponsesReady(std::optional<IdType> const& requestId = std::nullopt) const;
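A minimal caller-side sketch of the enqueue/await API shown above, using only the methods visible in this diff. The `tensorrt_llm::executor` namespace, the header path, and the polling strategy are assumptions, not part of the change itself:

```cpp
#include "tensorrt_llm/executor/executor.h" // assumed header for the executor C++ API
#include <chrono>
#include <vector>

namespace texec = tensorrt_llm::executor; // assumed namespace of the types in this diff

// Enqueue one request, then poll until at least one response for it is ready.
std::vector<texec::Response> submitAndWait(texec::Executor& executor, texec::Request const& request)
{
    texec::IdType const requestId = executor.enqueueRequest(request);

    while (executor.getNumResponsesReady(requestId) == 0)
    {
        // awaitResponses blocks for at most `timeout`, then returns whatever is ready.
        auto responses = executor.awaitResponses(requestId, std::chrono::milliseconds(100));
        if (!responses.empty())
        {
            return responses;
        }
    }
    return executor.awaitResponses(requestId);
}
```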
|
||||
|
||||
/// @brief Cancel the request with provided request id
|
||||
/// @param id The request id for which to cancel the response
|
||||
void cancelRequest(IdType id);
|
||||
void cancelRequest(IdType requestId);
|
||||
|
||||
/// @brief Signals the server to shutdown
|
||||
/// This call is blocking. Only returns when all requests have terminated or timeout has been reached
|
||||
@ -586,6 +584,9 @@ public:
|
||||
/// @return Request stats grouped by iterations
|
||||
std::deque<RequestStatsPerIteration> getLatestRequestStats();
|
||||
|
||||
/// @brief Indicates if the current process is allowed to enqueueRequests
|
||||
[[nodiscard]] bool canEnqueueRequests() const;
|
||||
|
||||
private:
|
||||
class Impl;
|
||||
std::unique_ptr<Impl> mImpl;
|
||||
|
||||
@ -52,7 +52,9 @@ using IterationType = std::uint64_t;
using RandomSeedType = std::uint64_t;
using VecLogProbs = std::vector<FloatType>;
using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
using LogitsPostProcessor = std::function<Tensor(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessor = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr&)>;
using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
using MedusaChoices = std::vector<std::vector<SizeType>>;
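The signature change above means a logits post-processor now mutates the logits tensor in place and returns nothing, instead of returning a new `Tensor`. A hedged sketch, assuming the `tensorrt_llm::executor` namespace and header and using only the types and setter shown in this diff:

```cpp
#include "tensorrt_llm/executor/executor.h" // assumed header for the executor C++ API

namespace texec = tensorrt_llm::executor; // assumed namespace

texec::ExecutorConfig withPostProcessor(texec::ExecutorConfig config)
{
    // New signature: modify `logits` on `stream` in place; no tensor is returned.
    texec::LogitsPostProcessor processor = [](texec::IdType requestId, texec::Tensor& logits,
                                               texec::BeamTokens const& beamTokens, texec::StreamPtr& stream)
    {
        // e.g. add a bias to `logits` or mask banned token ids here (left as a no-op in this sketch)
    };

    texec::LogitsPostProcessorMap map{{"my_post_processor", processor}};
    config.setLogitsPostProcessorMap(map); // setter shown earlier in this diff
    return config;
}
```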

enum class DataType
{
@ -153,15 +155,29 @@ enum class ModelType
kDECODER_ONLY = 0,
};

/// @brief The batching type
enum class BatchingType
{
/// @brief STATIC refers to the traditional batching scheme with a batch of requests running in lockstep until the
/// full generation for all of them is complete. Requests in a batch are all padded up to the maximum input and
/// output sequence length of any member of the batch.
kSTATIC = 0,

/// @brief INFLIGHT refers to a scheme where newly arrived requests are dynamically incorporated into the batch
/// under execution, and requests are returned as soon as the end condition is met without any padding.
kINFLIGHT = 1,
};
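For reference, a configuration opts into in-flight batching through the `setBatchingType` setter shown earlier in this diff. A small sketch; the namespace, header, and default-constructibility of `ExecutorConfig` are assumptions:

```cpp
#include "tensorrt_llm/executor/executor.h" // assumed header

namespace texec = tensorrt_llm::executor; // assumed namespace

texec::ExecutorConfig makeInflightConfig()
{
    texec::ExecutorConfig config;
    // Requests join and leave the running batch dynamically, with no padding,
    // instead of the lockstep kSTATIC scheme described above.
    config.setBatchingType(texec::BatchingType::kINFLIGHT);
    return config;
}
```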

/// @brief The policy used to select the subset of available requests in each iteration of the executor generation loop
enum class SchedulerPolicy
{
/// @brief MAX_UTILIZATION packs as many requests as the underlying TRT engine can support in any iteration of the
/// InflightBatching generation loop. While this is expected to maximize GPU throughput, it might require that some
/// requests be paused and restarted depending on peak KV cache memory availability.
kMAX_UTILIZATION = 0,

/// @brief GUARANTEED_NO_EVICT uses KV cache more conservatively guaranteeing that a request, once started, will run
/// to completion without eviction.
kGUARANTEED_NO_EVICT = 1,
};
|
||||
|
||||
@ -228,7 +244,7 @@ struct IterationStats
|
||||
/// @brief Ending time of this iteration
|
||||
std::string timestamp;
|
||||
/// @brief Iteration id
|
||||
SizeType iter;
|
||||
IterationType iter;
|
||||
/// @brief Number of active requests
|
||||
SizeType numActiveRequests;
|
||||
/// @brief Number of max active requests
|
||||
|
||||
@ -16,27 +16,37 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
|
||||
// Forward declarations
|
||||
namespace nvinfer1
|
||||
{
|
||||
class ILoggerFinder;
|
||||
class IPluginCreator;
|
||||
class ILogger;
|
||||
} // namespace nvinfer1
|
||||
|
||||
namespace tensorrt_llm::plugins::api
|
||||
{
|
||||
|
||||
auto constexpr kDefaultNamespace = "tensorrt_llm";
|
||||
|
||||
class LoggerFinder : public nvinfer1::ILoggerFinder
|
||||
class LoggerManager
|
||||
{
|
||||
public:
|
||||
//! Set the logger finder.
|
||||
void setLoggerFinder(nvinfer1::ILoggerFinder* finder);
|
||||
|
||||
//! Get the logger.
|
||||
nvinfer1::ILogger* findLogger() override;
|
||||
[[maybe_unused]] nvinfer1::ILogger* logger();
|
||||
|
||||
static LoggerFinder& getInstance() noexcept;
|
||||
static LoggerManager& getInstance() noexcept;
|
||||
|
||||
static nvinfer1::ILogger* defaultLogger() noexcept;
|
||||
|
||||
private:
|
||||
LoggerFinder() = default;
|
||||
LoggerManager() = default;
|
||||
|
||||
nvinfer1::ILoggerFinder* mLoggerFinder{nullptr};
|
||||
std::mutex mMutex;
|
||||
@ -47,10 +57,11 @@ private:
|
||||
extern "C"
|
||||
{
|
||||
// This function is used for explicitly registering the TRT-LLM plugins and the default logger.
bool initTrtLlmPlugins(void* logger, char const* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace);
bool initTrtLlmPlugins(void* logger = tensorrt_llm::plugins::api::LoggerManager::defaultLogger(),
char const* libNamespace = tensorrt_llm::plugins::api::kDefaultNamespace);

// The functions below are used by TensorRT when loading a shared plugin library with automatic registration.
// see https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#generating-plugin-library
TENSORRTAPI [[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder);
TENSORRTAPI [[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(int32_t& nbCreators);
[[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder);
[[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(int32_t& nbCreators);
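With the new default arguments on `initTrtLlmPlugins`, explicit registration can be a bare call. A short sketch; the header path and the caller-owned `myLogger` are assumptions:

```cpp
#include "tensorrt_llm/plugins/api/tllmPlugin.h" // assumed header for the declarations above

void registerTrtLlmPlugins(nvinfer1::ILogger* myLogger) // myLogger: hypothetical caller-owned logger
{
    // Common case: rely on the new defaults (default logger + "tensorrt_llm" namespace).
    initTrtLlmPlugins();

    // Or pass an explicit logger while keeping the default namespace.
    initTrtLlmPlugins(myLogger);
}
```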
|
||||
}
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
#include <NvInferRuntime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
@ -107,6 +108,9 @@ public:
|
||||
return allocate(memoryType, ITensor::makeShape({}), type);
|
||||
}
|
||||
|
||||
//! \brief Set the contents of the given `buffer` to value.
|
||||
void setMem(IBuffer& buffer, int32_t value) const;
|
||||
|
||||
//! \brief Set the contents of the given `buffer` to zero.
|
||||
void setZero(IBuffer& buffer) const;
|
||||
|
||||
|
||||
@ -76,6 +76,20 @@ public:
|
||||
|
||||
// parameters for beam search
|
||||
TensorPtr cacheIndirection; // [maxBatchSize, beamWidth, maxSeqLen] - the k/v cache index for beam search, on gpu
|
||||
|
||||
// Medusa
|
||||
class MedusaInputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaPaths; // [maxBatchSize, maxTokensPerStep, maxMedusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
std::vector<std::vector<TensorPtr>>
|
||||
medusaLogits; // [maxBatchSize][maxMedusaHeads][tokensPerStep, vocabSizePadded], on gpu
|
||||
TensorPtr medusaCurTokensPerStep; // [maxBatchSize], on gpu
|
||||
TensorPtr medusaTargetTokensPerStep; // [maxBatchSize], on gpu
|
||||
};
|
||||
|
||||
std::optional<MedusaInputs> medusaInputs;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -19,7 +19,7 @@
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/common.h"
|
||||
#include "tensorrt_llm/runtime/iTensor.h"
|
||||
|
||||
#include <optional>
|
||||
#include <utility>
|
||||
|
||||
namespace tensorrt_llm::runtime
|
||||
@ -84,6 +84,18 @@ public:
|
||||
TensorPtr cacheIndirection; // [batchSize, beamWidth, maxSeqLen], k/v indirection for next generation step, on gpu
|
||||
|
||||
BeamHypotheses beamHypotheses;
|
||||
|
||||
// Medusa
|
||||
class MedusaOutputs
|
||||
{
|
||||
public:
|
||||
TensorPtr medusaNextDraftTokens; // [maxBatchSize, maxTokensPerStep], on gpu
|
||||
TensorPtr medusaAcceptedTokensLen; // [maxBatchSize], on gpu
|
||||
TensorPtr medusaAcceptedLengthsCumSum; // [maxBatchSize + 1], on gpu
|
||||
TensorPtr medusaPathsOffsets; // [maxBatchSize * maxNumHeads], on gpu
|
||||
};
|
||||
|
||||
std::optional<MedusaOutputs> medusaOutputs;
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -26,6 +26,67 @@
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
//! @details
|
||||
//! ***Mandatory inputs***
|
||||
//!
|
||||
//! * `endId`, is the token ID that marks the end of the input sequence (aka `EOS`
|
||||
//! or end-of-sequence). It's `50,256` for the GPT2 model which has a vocabulary
|
||||
//! of `50,257` tokens, for example,
|
||||
//! * `padId`, is the token ID that is used for padding (i.e. fills in the slots
|
||||
//! that are at an index greater-or-equal to the input length for padded
|
||||
//! sequences). It can be set to the same value as `endId`,
|
||||
//! * `ids`, is the tensor of input IDs. That tensor must be allocated on the GPU.
|
||||
//! When the input tensor is padded, the shape of `ids` is `[batchSize,
|
||||
//! maxInputLength]`, where `batchSize` and `maxInputLength` must respect the
|
||||
//! maximum sizes in `sessionConfig` passed to the `GptSession` constructor.
|
||||
//! When the input is packed, the shape of `ids` is `[numTokens]`, where
|
||||
//! `numTokens` is the sum of the lengths of the different sequences in the batch,
|
||||
//! * `lengths`, is the tensor of input sequence lengths. That tensor must be
|
||||
//! allocated on the GPU and contain `batchSize` values,
|
||||
//! * `packed`, indicates if the `ids` tensor is packed or padded. In this
|
||||
//! release, that flag must match the value passed to the constructor through
|
||||
//! the instance of the `ModelConfig` class. In a future release, the session
|
||||
//! may be made more flexible and automatically pad or pack the input,
|
||||
//!
|
||||
//! ***Optional inputs***
|
||||
//!
|
||||
//! * `embeddingBiasOpt`, is a tensor of floating-point values on the GPU that
|
||||
//! contains the bias to add to the logits during sampling (after the projection
|
||||
//! from hidden states to logits as the last step of the model). This tensor
|
||||
//! must have `vocabSize` elements (as defined in the `modelConfig` argument
|
||||
//! passed to the constructor),
|
||||
//! * `badWordsList`, is a tensor of integers on the GPU that encodes the list of
|
||||
//! words that have to be banned from generated sequences. Its shape is `[2,
|
||||
//! badWordsLength]`, as explained below, or `[batchSize, 2, badWordsLength]`
|
||||
//! when there is a different list for each sequence in the batch,
|
||||
//! * `stopWordsList`, is a tensor of integers on the GPU that encodes the list of
|
||||
//! words that trigger the end of the generation for a sequence. Its shape is
|
||||
//! `[2, stopWordsLength]`, as explained below, or `[batchSize, 2,
|
||||
//! stopWordsLength]` when there is a different list for each sequence in the
|
||||
//! batch,
|
||||
//! * `maxNewTokens`, is the maximum number of tokens to generate.
|
||||
//!
|
||||
//! The `badWordsList` and `stopWordsList` tensors have the same shape `[2,
//! length]`. Let's consider an example with three words to describe the
//! representation of those lists. The first word contains tokens `[5, 7, 3]`, the
//! second one contains `[9, 2]` and the third one is composed of tokens `[6, 2, 4,
//! 1]`. In total, there are 9 tokens. That's the length. The shape of the tensor
//! is `[2, 9]`. The first row of the tensor must contain the 9 token IDs and the
//! second row must store the
//! [inclusive prefix-sum](https://en.wikipedia.org/wiki/Prefix_sum)
//! of the word lengths as shown on the following diagram:
//!
//! ```
//! 0 3 5 9
//! | | | |
//! V V V V
//! [ 5, 7, 3, 9, 2, 6, 2, 4, 1]
//! [ 3, 5, 9, -1, -1, -1, -1, -1, -1]
//! ```
//!
//! In case all the words are made of a single token, the inner-most dimension of
//! the tensor must be increased by 1 (i.e. the length for 4 words, each made of a
//! single token, must be 5 instead of 4 -- the shape is `[2, 5]`).
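To make that layout concrete, here is a small standalone sketch that builds the two rows of the `[2, 9]` representation for the three example words above; copying the rows into a GPU tensor is left out:

```cpp
#include <cstdint>
#include <vector>

int main()
{
    // Words from the example: [5, 7, 3], [9, 2], [6, 2, 4, 1]
    std::vector<std::vector<int32_t>> words{{5, 7, 3}, {9, 2}, {6, 2, 4, 1}};

    std::vector<int32_t> tokens;  // first row: all token ids, concatenated
    std::vector<int32_t> offsets; // second row: inclusive prefix-sum of word lengths
    int32_t total = 0;
    for (auto const& w : words)
    {
        tokens.insert(tokens.end(), w.begin(), w.end());
        total += static_cast<int32_t>(w.size());
        offsets.push_back(total);
    }
    offsets.resize(tokens.size(), -1); // pad the second row with -1 to match the first row

    // tokens  == { 5, 7, 3, 9, 2, 6, 2, 4, 1 }
    // offsets == { 3, 5, 9, -1, -1, -1, -1, -1, -1 }
    // Both rows are then copied into a [2, 9] tensor on the GPU.
    return 0;
}
```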
|
||||
template <typename TTensor, typename PromptTuningParams>
|
||||
class GenericGenerationInput
|
||||
{
|
||||
|
||||
@ -25,6 +25,51 @@
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
//! @details
|
||||
//! ***Mandatory outputs***
|
||||
//!
|
||||
//! * `ids`, is a tensor that contains the output token IDs. Its shape is
|
||||
//! `[batchSize, beamWidth, maxSeqLength]` where `maxSeqLength` is the sum of
|
||||
//! `maxInputLength` and `maxNewTokens`. After generation, it contains, for each
|
||||
//! sequence, a copy of the input tokens followed by the output tokens. When a
|
||||
//! sequence is shorter than `maxSeqLength`, padding tokens are added at the end
|
||||
//! of the sequence.
|
||||
//!
|
||||
//! _Note that the shape of that tensor is different in this version of
|
||||
//! TensorRT-LLM from its shape in previous versions where it was `[maxSeqLength,
|
||||
//! batchSize, beamWidth]`_.
|
||||
//!
|
||||
//! ***Optional outputs***
|
||||
//!
|
||||
//! * `logProbs`, is a tensor of floating-point values on the GPU to store the
|
||||
//! log-prob of the generated tokens. Its shape is `[maxNewTokens, batchSize,
|
||||
//! beamWidth]`. Its shape will likely change in a future release to match the
|
||||
//! shape of the output `ids` tensor.
|
||||
//! * `contextLogits`, is a tensor of values on the GPU (same datatype as the
//! computation type) to store the logits for the context. Its shape is
//! `[batchSize, maxSequenceLength, vocabSizePadded]`. If `remove_input_padding` is used, its shape is `[packedSize,
//! vocabSizePadded]`. This buffer will only be filled in if the TensorRT engine was built with the
//! `gather_context_logits` or `gather_all_token_logits` parameter enabled.
//!
//! After inference is complete, the context logits are available in `GenerationOutput.contextLogits`; these are
//! buffers on the GPU. For an example of how to retrieve them, refer to
//! [gptSessionBenchmark.cpp](https://github.com/NVIDIA/TensorRT-LLM/blob/main/benchmarks/cpp/gptSessionBenchmark.cpp).
//!
//! It is important to point out
//! that enabling the computation may have an impact on performance (the language modeling head (LM head) has to
//! perform a matrix multiplication on all the context tokens instead of just the last one).
//! * `generationLogits`, is a tensor of values on the GPU (same datatype as the
//! computation type) to store the logits for the generation. Its shape is
//! `[batchSize, beamWidth, maxOutputLen, vocabSizePadded]`. This buffer will only be
//! filled in if the TensorRT engine was built with the `gather_generation_logits` or
//! `gather_all_token_logits` parameter enabled.
//!
//! Generation logits can also be obtained through `GenerationOutput.generationLogits` after inference is completed.
//! * `onTokenGenerated`, is a callback function invoked in the generation loop to
//! pass newly generated tokens to the caller while the loop continues to
//! execute. An implementation of that callback must accept the output `ids`
//! tensor, the generation `step` and a boolean flag that indicates if the
//! generation is complete.
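A hedged sketch of wiring the streaming callback described above. The exact callback alias is not shown in this diff, so the `(ids, step, finished)` parameter shape, the header, the namespace, and the `TensorPtr` alias are all assumptions based on the comment text:

```cpp
#include "tensorrt_llm/runtime/generationOutput.h" // assumed header

using namespace tensorrt_llm::runtime; // assumed namespace

void attachStreamingCallback(GenerationOutput& generationOutput)
{
    // Assumed callback shape, per the description above: the output `ids` tensor,
    // the current generation `step`, and a flag set when generation is complete.
    generationOutput.onTokenGenerated = [](GenerationOutput::TensorPtr const& ids, SizeType step, bool finished)
    {
        // Stream or log the newly generated tokens here; `finished` is true on the last call.
    };
}
```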
|
||||
template <typename TTensor>
|
||||
class GenericGenerationOutput
|
||||
{
|
||||
|
||||
@ -81,7 +81,8 @@ public:
|
||||
|
||||
static std::unique_ptr<IGptDecoder> create(DecodingMode const& mode, nvinfer1::DataType dtype, size_t maxBatchSize,
|
||||
size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream);
|
||||
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType> maxTokensPerStep = std::nullopt,
|
||||
std::optional<runtime::SizeType> maxNumMedusaHeads = std::nullopt);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@ -93,7 +94,9 @@ public:
|
||||
using TensorPtr = std::shared_ptr<ITensor>;
|
||||
|
||||
GptDecoder(DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
|
||||
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream);
|
||||
size_t vocabSizePadded, size_t maxSequenceLength, CudaStreamPtr const& stream,
|
||||
std::optional<runtime::SizeType> maxTokensPerStep = std::nullopt,
|
||||
std::optional<runtime::SizeType> maxNumMedusaHeads = std::nullopt);
|
||||
|
||||
void setup(SamplingConfig const& samplingConfig, size_t batchSize, SizeType maxSequenceLength,
|
||||
std::optional<TensorPtr> const& batchSlots = std::nullopt) override;
|
||||
@ -119,20 +122,23 @@ private:
|
||||
SamplingConfig mSamplingConfig;
|
||||
|
||||
cudaDeviceProp mProp; // Avoid dangling pointers in mDynamicDecodeLayer
|
||||
|
||||
size_t mMaxBatchSize;
|
||||
};
|
||||
|
||||
inline std::unique_ptr<IGptDecoder> IGptDecoder::create(DecodingMode const& mode, nvinfer1::DataType dtype,
|
||||
size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, size_t maxSequenceLength,
|
||||
BufferManager::CudaStreamPtr const& stream)
|
||||
BufferManager::CudaStreamPtr const& stream, std::optional<runtime::SizeType> maxTokensPerStep,
|
||||
std::optional<runtime::SizeType> maxNumMedusaHeads)
|
||||
{
|
||||
switch (dtype)
|
||||
{
|
||||
case nvinfer1::DataType::kFLOAT:
|
||||
return std::make_unique<GptDecoder<float>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
|
||||
return std::make_unique<GptDecoder<float>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
|
||||
case nvinfer1::DataType::kHALF:
|
||||
return std::make_unique<GptDecoder<half>>(
|
||||
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, maxSequenceLength, stream);
|
||||
return std::make_unique<GptDecoder<half>>(mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded,
|
||||
maxSequenceLength, stream, maxTokensPerStep, maxNumMedusaHeads);
|
||||
default: return nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
@ -43,7 +43,7 @@ public:
|
||||
//! Setup the decoder before calling `forward()`
|
||||
void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth, SizeType maxAttentionWindow,
|
||||
SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep, bool fusedDecoder,
|
||||
nvinfer1::DataType dtype) override;
|
||||
nvinfer1::DataType dtype, GptModelConfig const& modelConfig) override;
|
||||
|
||||
void newBatch(
|
||||
GenerationInput const& inputs, GenerationOutput const& outputs, SamplingConfig const& samplingConfig) override;
|
||||
@ -156,13 +156,19 @@ public:
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokens() const override
|
||||
{
|
||||
return mNextDraftTokens;
|
||||
return mJointDecodingOutput->medusaOutputs->medusaNextDraftTokens;
|
||||
}
|
||||
|
||||
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
|
||||
[[nodiscard]] TensorPtr getNextDraftTokenLengths() const override
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
[[nodiscard]] TensorPtr getMedusaAcceptedLengthsCumSum() const override
|
||||
{
|
||||
return mNextDraftTokenLengths;
|
||||
return mJointDecodingOutput->medusaOutputs->medusaAcceptedLengthsCumSum;
|
||||
}
|
||||
|
||||
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
|
||||
[[nodiscard]] TensorPtr getMedusaAcceptedPackedPaths() const override
|
||||
{
|
||||
return mJointDecodingOutput->medusaOutputs->medusaPathsOffsets;
|
||||
}
|
||||
|
||||
private:
|
||||
@ -172,6 +178,27 @@ private:
|
||||
//! @brief Initialize the decoder at `batchIdx` with a new `request`.
|
||||
void newRequest(SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
//! @brief Allocate buffers for medusa decoding.
|
||||
void allocateMedusaBuffers();
|
||||
|
||||
//! @brief Setup buffers for medusa decoding.
|
||||
void setupMedusa(GptModelConfig const& modelConfig);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new speculative decoding request
|
||||
void newRequestSpeculativeDecoding(
|
||||
SizeType batchIdx, decoder_batch::Request const& request, SamplingConfig const& samplingConfig);
|
||||
|
||||
//! @brief Setups decoder internal tensors for new Medusa request
|
||||
void newRequestMedusa(SizeType batchIdx, decoder_batch::Request const& request);
|
||||
|
||||
//! @brief Asynchronously calls unfused decoder for whole batch in loop
|
||||
void forwardAsyncUnfusedDecoder(
|
||||
SizeType step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
|
||||
|
||||
//! @brief Asynchronously calls fused decoder for whole batch
|
||||
void forwardAsyncFusedDecoder(
|
||||
SizeType step, decoder_batch::Output& output, decoder_batch::Input const& input, CudaEvent const& eventStart);
|
||||
|
||||
private:
|
||||
std::size_t const mVocabSize;
|
||||
std::size_t const mVocabSizePadded;
|
||||
@ -200,7 +227,7 @@ private:
|
||||
TensorPtr mFinishedSum;
|
||||
std::vector<SizeType> mMaxNewTokens;
|
||||
std::vector<SizeType> mBeamWidths;
|
||||
std::vector<SizeType> mGeneratedTokensPerStep;
|
||||
std::vector<SizeType> mGeneratedTokensPerEngineStep;
|
||||
|
||||
TensorPtr mFinishedSteps; // [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState
|
||||
// for each generated token of maxTokensPerStep, on gpu
|
||||
@ -216,16 +243,17 @@ private:
|
||||
TensorPtr mBatchSlotsAcceptTokens; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mBatchSlotsAcceptLogits; // [maxBatchSize], int32_t, address map, pinned
|
||||
TensorPtr mTargetLogitsPtrs; // [maxBatchSize], float*, pointers to target logits, pinned
|
||||
TensorPtr mNextDraftTokens;
|
||||
TensorPtr mNextDraftTokenLengths;
|
||||
SizeType mMaxSequenceLength{};
|
||||
SizeType mMaxAttentionWindow{};
|
||||
SizeType mSinkTokenLength{};
|
||||
SizeType mActualBatchSize{};
|
||||
SizeType mMaxTokensPerStep{};
|
||||
SizeType mMaxTokensPerEngineStep{};
|
||||
SizeType mMaxStopWordsLen{};
|
||||
SizeType mMaxBadWordsLen{};
|
||||
// How many tokens for one request can be processed per mDecoders call
|
||||
SizeType mMaxTokensPerDecoderStep{};
|
||||
|
||||
bool mFusedDecoder{false};
|
||||
bool mUseMedusa{false};
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -25,13 +25,21 @@
|
||||
namespace tensorrt_llm::runtime
|
||||
{
|
||||
|
||||
struct MambaConfig
|
||||
{
|
||||
SizeType dState = 0;
|
||||
SizeType dConv = 0;
|
||||
SizeType expand = 0;
|
||||
};
|
||||
|
||||
class GptModelConfig
|
||||
{
|
||||
public:
|
||||
enum class ModelVariant : std::int32_t
|
||||
{
|
||||
kGpt = 0,
|
||||
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
|
||||
kGlm = 1, // https://github.com/THUDM/GLM and https://github.com/THUDM/ChatGLM-6B
|
||||
kMamba = 2, // https://github.com/state-spaces/mamba
|
||||
};
|
||||
|
||||
explicit GptModelConfig(
|
||||
@ -44,8 +52,10 @@ public:
|
||||
, mSizePerHead(mHiddenSize / mNbHeads)
|
||||
, mDataType(dtype)
|
||||
, mUseGptAttentionPlugin(false)
|
||||
, mUseMambaConv1dPlugin(false)
|
||||
, mInputPacked{false}
|
||||
, mPagedKvCache{false}
|
||||
, mPagedState{false}
|
||||
, mTokensPerBlock{64}
|
||||
, mQuantMode{common::QuantMode::none()}
|
||||
, mMaxBatchSize(0)
|
||||
@ -128,6 +138,16 @@ public:
|
||||
mUseGptAttentionPlugin = useGptAttentionPlugin;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr useMambaConv1dPlugin() const noexcept
|
||||
{
|
||||
return mUseMambaConv1dPlugin;
|
||||
}
|
||||
|
||||
void constexpr useMambaConv1dPlugin(bool useMambaConv1dPlugin) noexcept
|
||||
{
|
||||
mUseMambaConv1dPlugin = useMambaConv1dPlugin;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr usePackedInput() const noexcept
|
||||
{
|
||||
return mInputPacked;
|
||||
@ -148,6 +168,16 @@ public:
|
||||
mPagedKvCache = pagedKvCache;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr usePagedState() const noexcept
|
||||
{
|
||||
return mPagedState;
|
||||
}
|
||||
|
||||
void constexpr usePagedState(bool pagedState) noexcept
|
||||
{
|
||||
mPagedState = pagedState;
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getTokensPerBlock() const noexcept
|
||||
{
|
||||
return mTokensPerBlock;
|
||||
@ -170,7 +200,8 @@ public:
|
||||
|
||||
[[nodiscard]] bool constexpr supportsInflightBatching() const noexcept
|
||||
{
|
||||
return mUseGptAttentionPlugin && mInputPacked && mPagedKvCache;
|
||||
return (isTransformerBased() && mUseGptAttentionPlugin && mInputPacked && mPagedKvCache)
|
||||
|| (isSsmBased() && mUseMambaConv1dPlugin && mInputPacked && mPagedState);
|
||||
}
|
||||
|
||||
[[nodiscard]] SizeType constexpr getMaxBatchSize() const noexcept
|
||||
@ -368,6 +399,47 @@ public:
|
||||
mMedusaModule = medusaModule;
|
||||
}
|
||||
|
||||
[[nodiscard]] nvinfer1::DataType getKvDataType() const noexcept
|
||||
{
|
||||
if (getQuantMode().hasFp8KvCache())
|
||||
{
|
||||
return nvinfer1::DataType::kFP8;
|
||||
}
|
||||
else if (getQuantMode().hasInt8KvCache())
|
||||
{
|
||||
return nvinfer1::DataType::kINT8;
|
||||
}
|
||||
else
|
||||
{
|
||||
return getDataType();
|
||||
}
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isTransformerBased() const noexcept
|
||||
{
|
||||
return mModelVariant == ModelVariant::kGpt || mModelVariant == ModelVariant::kGlm;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool hasMambaConfig() const noexcept
|
||||
{
|
||||
return mMambaConfig.has_value();
|
||||
}
|
||||
|
||||
[[nodiscard]] std::optional<MambaConfig> getMambaConfig() const noexcept
|
||||
{
|
||||
return mMambaConfig;
|
||||
}
|
||||
|
||||
void setMambaConfig(MambaConfig const& mambaConfig) noexcept
|
||||
{
|
||||
mMambaConfig = mambaConfig;
|
||||
}
|
||||
|
||||
[[nodiscard]] bool constexpr isSsmBased() const noexcept
|
||||
{
|
||||
return mModelVariant == ModelVariant::kMamba;
|
||||
}
|
||||
|
||||
private:
|
||||
SizeType mVocabSize;
|
||||
SizeType mNbLayers;
|
||||
@ -377,8 +449,10 @@ private:
|
||||
SizeType mSizePerHead;
|
||||
nvinfer1::DataType mDataType;
|
||||
bool mUseGptAttentionPlugin;
|
||||
bool mUseMambaConv1dPlugin;
|
||||
bool mInputPacked;
|
||||
bool mPagedKvCache;
|
||||
bool mPagedState;
|
||||
SizeType mTokensPerBlock;
|
||||
common::QuantMode mQuantMode;
|
||||
SizeType mMaxBatchSize;
|
||||
@ -404,5 +478,7 @@ private:
|
||||
SizeType mMaxLoraRank;
|
||||
|
||||
std::optional<MedusaModule> mMedusaModule;
|
||||
|
||||
std::optional<MambaConfig> mMambaConfig;
|
||||
};
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -83,13 +83,23 @@ public:
|
||||
{
|
||||
}
|
||||
|
||||
// The maximum number of sequences in a batch
|
||||
SizeType maxBatchSize;
|
||||
// The maximum width of the beams in beam-search
|
||||
SizeType maxBeamWidth;
|
||||
// The length of the longest input sequence
|
||||
SizeType maxSequenceLength;
|
||||
// Whether the session will use a different decoder per request.
|
||||
// It must be set to `true` when running in-flight batching
|
||||
bool decoderPerRequest{false};
|
||||
// Whether the session will use CUDA graphs for the engine execution in generation phase
|
||||
bool cudaGraphMode{false};
|
||||
KvCacheConfig kvCacheConfig{};
|
||||
// The micro batch size to be used in context phase.
|
||||
// Batches entered in `GptSession::generation` will be split into smaller micro batches of this size
|
||||
std::optional<SizeType> ctxMicroBatchSize = std::nullopt;
|
||||
// The micro batch size to be used in generation phase.
|
||||
// Batches entered in `GptSession::generation` will be split into smaller micro batches of this size.
|
||||
std::optional<SizeType> genMicroBatchSize = std::nullopt;
|
||||
std::optional<DecodingMode> decodingMode = std::nullopt;
|
||||
bool normalizeLogProbs = true;
|
||||
@ -134,6 +144,12 @@ public:
|
||||
CudaEvent end;
|
||||
};
|
||||
|
||||
//! @param sessionConfig Configuration of the session,
|
||||
//! @param modelConfig Description of the model,
|
||||
//! @param worldConfig Description of the environment,
|
||||
//! @param engineBuffer The compiled TensorRT engine (const void*),
|
||||
//! @param engineSize The size in bytes of the TensorRT engine (size_t),
|
||||
//! @param logger The optional logger.
|
||||
GptSession(Config const& sessionConfig, GptModelConfig const& modelConfig, WorldConfig const& worldConfig,
|
||||
void const* engineBuffer, std::size_t engineSize, LoggerPtr logger = nullptr);
|
||||
|
||||
@ -176,6 +192,31 @@ public:
|
||||
|
||||
[[nodiscard]] nvinfer1::DataType getLogitDataType() const;
|
||||
|
||||
//! @brief This function performs the generation loop.
//! @details Given input tensors to read from, output tensors to populate, that member function
//! will run the generation loop until it reaches the maximum number of tokens that
//! can be produced or each sequence has reached completion (due to the production
//! of "end-of-sequence" or a word in the list of "stop words"). The pseudo-code of
//! that function looks like (member function names were changed to keep the
//! presentation simple):
//!
//! ```cpp
//! // Have all the sequences in the batch reached completion?
//! bool allFinished = false;
//!
//! // Until all sequences are finished or the number of steps reaches the limit...
//! for (int step = 0; !allFinished && step < maxNewTokens; ++step) {
//!
//! // Trigger the computation of the logits...
//! computeLogits(...);
//!
//! // Run the sampling to produce a token (for each active sequence) from the logits.
//! allFinished = generateTokensFromLogits(...);
//!
//! // Callback to stream the output tokens while the generation loop continues.
//! onTokenGenerated(...);
//! }
//! ```
void generate(GenerationOutput& outputs, GenerationInput const& inputs, SamplingConfig const& samplingConfig,
std::shared_ptr<GenerationProfiler> const generationProfiler = nullptr);
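Putting the pieces together, a minimal sketch of a single `generate` call. It assumes `inputs` and `outputs` have been prepared as described in the `GenerationInput`/`GenerationOutput` comments earlier in this diff, and that `SamplingConfig` can be constructed from a beam width; both are assumptions, not part of this change:

```cpp
#include "tensorrt_llm/runtime/gptSession.h" // assumed header; types as shown in this diff

using namespace tensorrt_llm::runtime; // assumed namespace

void runGeneration(GptSession& session, GenerationOutput& outputs, GenerationInput const& inputs)
{
    SamplingConfig samplingConfig{/* beamWidth = */ 1}; // assumed single-beam sampling setup

    // Executes the loop from the pseudo-code above: compute logits, sample one token per active
    // sequence, fire onTokenGenerated, and stop at completion or at the new-token limit.
    session.generate(outputs, inputs, samplingConfig);
}
```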
|
||||
|
||||
|
||||
@ -46,7 +46,7 @@ public:
|
||||
, endId{endId}
|
||||
, computeCumLogProbs(false)
|
||||
, computeLogProbs(false)
|
||||
, generatedTokensPerStep(1)
|
||||
, generatedTokensPerEngineStep(1)
|
||||
{
|
||||
}
|
||||
|
||||
@ -66,7 +66,9 @@ public:
|
||||
|
||||
bool computeCumLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
bool computeLogProbs; // boolean that controls if cumLogProbs should be computed for that request
|
||||
SizeType generatedTokensPerStep;
|
||||
SizeType generatedTokensPerEngineStep;
|
||||
TensorPtr medusaPaths; // [tokensPerStep, medusaHeads + 1], on gpu
|
||||
TensorPtr medusaTreeIds; // [tokensPerStep], on gpu
|
||||
};
|
||||
|
||||
class Input
|
||||
@ -109,6 +111,8 @@ public:
|
||||
// parameters for beam search
|
||||
TensorConstPtr cacheIndirection; // [batchSize, maxBeamWidth, maxSeqLen] - indices into KV cache of different rays
|
||||
// within one beam for beam search, on gpu
|
||||
std::vector<std::vector<TensorConstPtr>>
|
||||
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSizePadded]
|
||||
};
|
||||
|
||||
using Output = decoder::Output;
|
||||
@ -183,8 +187,11 @@ public:
|
||||
//! @returns [batchSize, maxTokensPerStep-1], predicted draft tokens for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokens() const = 0;
|
||||
|
||||
//! @returns [batchSize], lengths of the predicted draft tokens for next step, on gpu
|
||||
virtual TensorPtr getNextDraftTokenLengths() const = 0;
|
||||
//! @returns [batchSize + 1], exclusive sum of accepted draft token lengths, on gpu
|
||||
virtual TensorPtr getMedusaAcceptedLengthsCumSum() const = 0;
|
||||
|
||||
//! @returns [batchSize * maxMedusaHeads], accepted paths packed into continuous tensor, on gpu
|
||||
virtual TensorPtr getMedusaAcceptedPackedPaths() const = 0;
|
||||
|
||||
protected:
|
||||
IGptDecoderBatch() = default;
|
||||
|
||||
@ -75,7 +75,7 @@ public:
|
||||
//! Setup the decoder before calling `forward()`, also calls reshapeBuffers
|
||||
virtual void setup(DecodingMode const& mode, SizeType maxBatchSize, SizeType maxBeamWidth,
|
||||
SizeType maxAttentionWindow, SizeType sinkTokenLength, SizeType maxSequenceLength, SizeType maxTokensPerStep,
|
||||
bool fusedDecoder, nvinfer1::DataType dtype)
|
||||
bool fusedDecoder, nvinfer1::DataType dtype, GptModelConfig const& modelConfig)
|
||||
= 0;
|
||||
|
||||
//! @brief Initialize the decoder with new batch of inputs.
|
||||
|
||||
@ -34,7 +34,6 @@ private:
|
||||
template <typename T>
|
||||
using OptVec = std::optional<std::vector<T>>;
|
||||
|
||||
private:
|
||||
template <typename T>
|
||||
static OptVec<T> fuseValues(
|
||||
std::vector<SamplingConfig> const& configs, std::function<OptVec<T>(SizeType ci)> accessor)
|
||||
@ -91,6 +90,8 @@ public:
|
||||
earlyStopping = fuseValues<SizeType>(configs, [&configs](SizeType ci) { return configs[ci].earlyStopping; });
|
||||
draftAcceptanceThreshold
|
||||
= fuseValues<FloatType>(configs, [&configs](SizeType ci) { return configs[ci].draftAcceptanceThreshold; });
|
||||
topKMedusaHeads = fuseValues<std::vector<runtime::SizeType>>(
|
||||
configs, [&configs](SizeType ci) { return configs[ci].topKMedusaHeads; });
|
||||
}
|
||||
|
||||
explicit SamplingConfig(executor::SamplingConfig const& samplingConfig,
|
||||
@ -152,7 +153,22 @@ public:
|
||||
// speculative decoding, only the first value is used (in gptDecoderBatch.cpp)
|
||||
OptVec<FloatType> draftAcceptanceThreshold; // [1] or [batch_size]
|
||||
|
||||
// medusa params
|
||||
OptVec<std::vector<runtime::SizeType>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
|
||||
|
||||
std::optional<bool> normalizeLogProbs;
|
||||
|
||||
bool operator==(SamplingConfig const& other) const
|
||||
{
|
||||
return beamWidth == other.beamWidth && temperature == other.temperature && minLength == other.minLength
|
||||
&& repetitionPenalty == other.repetitionPenalty && presencePenalty == other.presencePenalty
|
||||
&& frequencyPenalty == other.frequencyPenalty && topK == other.topK && topP == other.topP
|
||||
&& randomSeed == other.randomSeed && topPDecay == other.topPDecay && topPMin == other.topPMin
|
||||
&& topPResetIds == other.topPResetIds && beamSearchDiversityRate == other.beamSearchDiversityRate
|
||||
&& lengthPenalty == other.lengthPenalty && earlyStopping == other.earlyStopping
|
||||
&& draftAcceptanceThreshold == other.draftAcceptanceThreshold && topKMedusaHeads == other.topKMedusaHeads
|
||||
&& normalizeLogProbs == other.normalizeLogProbs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::runtime
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:fd8e608359009dffbcc5817cd96531254c3ad13df7030b3b7cdf2d609fea99e1
|
||||
size 2408892
|
||||
oid sha256:ba545e1931c9405b75028b019ac3949ec5cec57c304aaa10ea6c854f572225b1
|
||||
size 2856456
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e59449c78d8682be1f0671fa6d8073c71eb37ae452417b70f70bb7db4a68f48b
|
||||
size 2434826
|
||||
oid sha256:8ef69cd446d54a1c876237f812839e6ecd9174c327edc5ff4f6594bb2b203aae
|
||||
size 2885046
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
ae7c209c38b4c343b0fc49decff6fed5 libtensorrt_llm_batch_manager_static.a
|
||||
f2fdaabe328c0eb1e46e8ded7bec4d87 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
d2cce02a8 commit
|
||||
cd113ef1af7d78ac0791d4323a8ef370 libtensorrt_llm_batch_manager_static.a
|
||||
c3583c5524c71f2cd5ae7a3bba864377 libtensorrt_llm_batch_manager_static.pre_cxx11.a
|
||||
45a8cb4ea commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:88e519a38b4172b960083acf12db2ce17c880ce355cc1c9361f1ae85d839551d
|
||||
size 2377646
|
||||
oid sha256:32ca7c2a6701457ecb537a56d9558fb62d35ec5443905d63f1f1a288d8f48f87
|
||||
size 2780748
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:54199fac4bbe94dc314bed8c889753cbb00d2bad1e672384a350dc2b97e4a0b1
|
||||
size 2343620
|
||||
oid sha256:c8fdf3d223bb7e0a5eeffbea0a82e50a8e0ec3815b274cdc95d6fb1c36f2178d
|
||||
size 2755044
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:57a1c54097341e561ae44f5ae69fa6a7e33061e2d0451d2f42a37f22993a22bb
|
||||
size 818584
|
||||
oid sha256:36f02388a9bd2ae3d45f0d6480bd95cd99f8ea30eebf1a315b8d54e742fed479
|
||||
size 846308
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:3d443d55b92501991a6102c523d46ddfdf620fa5ab37abcee3e2d6ee4c4d9e90
|
||||
size 833262
|
||||
oid sha256:c707f67abccca217d81e8d85e361b6d214131045763df5f806cb789157ea4f80
|
||||
size 857730
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
b92b19f8d7eff851dadb8a8e3010a565 libtensorrt_llm_executor_static.a
|
||||
a546902e11b24c1b890fd913c3e844c5 libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
d2cce02a8 commit
|
||||
a552c727c128f6ca402ddc119d295ab0 libtensorrt_llm_executor_static.a
|
||||
38ad482b0be0996970bc572622967acf libtensorrt_llm_executor_static.pre_cxx11.a
|
||||
45a8cb4ea commit
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:9233382570d3c9c5417ed1f279c234d323b4dd465bbdca86612e137fabfb9962
|
||||
size 866182
|
||||
oid sha256:cc2d59c4878e74f7e38a65187ed303a77b43a3b71753b3e4dcc99a937ccbcdf8
|
||||
size 884870
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:03ee314aa8ca65abf013c6e5106b701defb5c1435d5fe8879829952c1d2cab1f
|
||||
size 812078
|
||||
oid sha256:dc7e967c9aa7ef50227a791c670fe71a9bdef907ce45d3282955ebd5e2ead88f
|
||||
size 837988
|
||||
|
||||
@ -303,6 +303,25 @@ extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_80_sm70_c
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_128_sm70_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin[];
|
||||
// FP8
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin[];
|
||||
extern unsigned char cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin[];
|
||||
|
||||
extern uint32_t cubin_fmha_v2_fp16_64_32_ldgsts_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_fp16_128_32_ldgsts_sm90_cu_cubin_len;
|
||||
@ -582,6 +601,25 @@ extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_80_sm70_cu_cub
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_128_sm70_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len;
|
||||
// FP8
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
extern uint32_t cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len;
|
||||
|
||||
static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{
|
||||
@ -1511,7 +1549,62 @@ static const struct FusedMultiHeadAttentionKernelMetaInfoV2
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_kernel", 98304, 128, 0, false, true, false, 1, false, false },
|
||||
{ DATA_TYPE_FP16, 0, 160, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_160_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false },
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel", 98304, 128, 0, false, true, false, 1, false, false },
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false }
|
||||
{ DATA_TYPE_FP16, 0, 256, kSM_70, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin, cubin_fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_cu_cubin_len, "fmha_v2_flash_attention_fp16_64_16_Causal_S_256_sm70_kernel_nl", 98304, 128, 64, false, true, false, 1, false, false },
|
||||
// FP8
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_sliding_window_causal_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_sliding_window_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_sliding_window_causal_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, false, false},
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_causal_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 32, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_32_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_32_sliding_window_causal_alibi_tma_ws_sm90_kernel", 53504, 384, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 40, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_40_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_40_sliding_window_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 64, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_256_S_64_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_256_S_64_sliding_window_causal_alibi_tma_ws_sm90_kernel", 106752, 384, 64, false, true, true, 2, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
|
||||
{ DATA_TYPE_E4M3, 0, 80, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_80_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_80_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 96, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_96_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_96_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 104, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_104_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_104_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 128, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_128_S_128_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_128_S_128_sliding_window_causal_alibi_tma_ws_sm90_kernel", 180480, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 160, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_160_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_160_sliding_window_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 0, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 1, true, false},
{ DATA_TYPE_E4M3, 0, 256, kSM_90, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin, cubin_fmha_v2_flash_attention_e4m3_64_64_S_256_alibi_tma_ws_sm90_cu_cubin_len, "fmha_v2_flash_attention_e4m3_64_64_S_256_sliding_window_causal_alibi_tma_ws_sm90_kernel", 131328, 384, 64, false, true, true, 2, true, false}
};

// clang-format on
@ -86,7 +86,9 @@ public:
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
(sm == kSM_70 || sm == kSM_80 || sm == kSM_86 || sm == kSM_89 || sm == kSM_90), "Unsupported architecture");
|
||||
TLLM_CHECK_WITH_INFO((mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16), "Unsupported data type");
|
||||
TLLM_CHECK_WITH_INFO(
|
||||
(mDataType == DATA_TYPE_FP16 || mDataType == DATA_TYPE_BF16 || mDataType == DATA_TYPE_E4M3),
|
||||
"Unsupported data type");
|
||||
|
||||
pagedKVXmmaKernel = getPagedKVXMMAKernelsV2(mDataType, sm);
|
||||
xmmaKernel = getXMMAKernelsV2(mDataType, sm);
|
||||
@ -117,7 +119,7 @@ public:
|
||||
float const scale_softmax = 1.f; // Seems to be only required for int8
|
||||
float const scale_bmm2 = 1.f;
|
||||
|
||||
Data_type scale_type = mLaunchParams.force_fp32_acc ? DATA_TYPE_FP32 : mDataType;
|
||||
Data_type scale_type = mLaunchParams.force_fp32_acc || mDataType == DATA_TYPE_E4M3 ? DATA_TYPE_FP32 : mDataType;
|
||||
// Use exp2f optimization for warp-specialized ws kernels on Hopper.
|
||||
if (mLaunchParams.useBase2ExpTrick)
|
||||
{
|
||||
@ -130,6 +132,7 @@ public:
|
||||
set_alpha(params.scale_bmm1, scale_bmm1, scale_type);
|
||||
}
|
||||
set_alpha(params.scale_softmax, scale_softmax, scale_type);
|
||||
// Host scale_bmm2 will not be used.
|
||||
set_alpha(params.scale_bmm2, scale_bmm2, scale_type);
|
||||
|
||||
params.b = b;
|
||||
@ -138,7 +141,7 @@ public:
|
||||
params.d = mHeadSize;
|
||||
params.sliding_window_size = sliding_window_size;
|
||||
|
||||
params.o_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
|
||||
params.o_stride_in_bytes = get_size_in_bytes(mNumHeads * mHeadSize, mDataType);
|
||||
|
||||
// Total sequence length needed by TMA descriptor
|
||||
// it should be actual total seq length if non-padded input is given.
|
||||
@ -153,8 +156,8 @@ public:
|
||||
}
|
||||
|
||||
// Support packed QKV.
|
||||
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen, bool const has_alibi,
|
||||
bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
float const* scale_bmm2_d, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
|
||||
// Determine launch parameters.
|
||||
@ -169,7 +172,14 @@ public:
|
||||
bool const isSm90 = (sm == kSM_90);
|
||||
bool const isSm8x = (sm == kSM_86 || sm == kSM_89);
|
||||
bool const isSm80 = (sm == kSM_80);
|
||||
if (isSm70)
|
||||
|
||||
// Only warp-specialized FMHA kernels support FP8 on Hopper.
|
||||
if (isSm90 && mDataType == DATA_TYPE_E4M3)
|
||||
{
|
||||
mLaunchParams.flash_attention = true;
|
||||
mLaunchParams.force_unroll = true;
|
||||
}
|
||||
else if (isSm70)
|
||||
{
|
||||
mLaunchParams.flash_attention = true;
|
||||
mLaunchParams.force_unroll = true; // need more profile
|
||||
@ -234,7 +244,10 @@ public:
|
||||
|
||||
// Set kernel parameters.
|
||||
setup_params(mParams, b, s, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
mParams.qkv_stride_in_bytes = (mNumHeads + 2 * mParams.h_kv) * mHeadSize * sizeof(half);
|
||||
// TODO: move this to setup_params when fp8 paged kv fmha is supported.
|
||||
// TRT doesn't support host scales. Use device scales instead.
|
||||
mParams.scale_bmm2_d = reinterpret_cast<uint32_t const*>(scale_bmm2_d);
|
||||
mParams.qkv_stride_in_bytes = get_size_in_bytes((mNumHeads + 2 * mParams.h_kv) * mHeadSize, mDataType);
|
||||
}
|
||||
|
||||
// Support paged_kv_cache and chunked_attention.
|
||||
@ -309,6 +322,7 @@ public:
|
||||
mLaunchParams.attention_mask_type = ContextAttentionMaskType::SLIDING_WINDOW_CAUSAL;
|
||||
}
|
||||
|
||||
// TODO: add paged kv FP8 FMHA.
|
||||
setup_params(
|
||||
mPagedKVParams, b, s_q, s_kv, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
mPagedKVParams.q_stride_in_bytes = mNumHeads * mHeadSize * sizeof(half);
|
||||
@ -320,7 +334,7 @@ public:
|
||||
void set_tma_descriptors()
|
||||
{
|
||||
// split D into multiple groups in order to match the TMA swizzle mode (128B)
|
||||
const uint32_t d_in_bytes = mLaunchParams.padded_d * sizeof(uint16_t);
|
||||
const uint32_t d_in_bytes = get_size_in_bytes(mLaunchParams.padded_d, mDataType);
|
||||
const uint32_t d_groups = d_in_bytes > 128 ? d_in_bytes / 128 : 1;
|
||||
|
||||
// separate q, k, and v tma descriptors
|
||||
@ -351,9 +365,9 @@ public:
|
||||
|
||||
// stride size in bytes. Assumes least significant dim is 1 (?)
|
||||
uint64_t tensor_stride_qkv[3];
|
||||
tensor_stride_qkv[0] = tensor_size_qkv[0] * sizeof(uint16_t); // d
|
||||
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h
|
||||
tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3
|
||||
tensor_stride_qkv[0] = get_size_in_bytes(tensor_size_qkv[0], mDataType); // d
|
||||
tensor_stride_qkv[1] = tensor_size_qkv[1] * tensor_stride_qkv[0]; // d*h
|
||||
tensor_stride_qkv[2] = tensor_size_qkv[2] * tensor_stride_qkv[1]; // d*h*3
|
||||
|
||||
// traversal stride
|
||||
uint32_t traversal_stride_qkv[4] = {1, 1, 1, 1};
|
||||
@ -365,7 +379,7 @@ public:
|
||||
uint32_t fp32_to_tf32 = 0;
|
||||
|
||||
// gmma descriptor mode
|
||||
const uint32_t d_bytes_per_group = (mLaunchParams.padded_d * sizeof(uint16_t)) / d_groups;
|
||||
const uint32_t d_bytes_per_group = d_in_bytes / d_groups;
|
||||
const cudaTmaDescSwizzle swizzle_mode = (d_bytes_per_group > 64
|
||||
? cudaTmaDescSwizzle::SWIZZLE_128B
|
||||
: (d_bytes_per_group > 32 ? cudaTmaDescSwizzle::SWIZZLE_64B : cudaTmaDescSwizzle::SWIZZLE_32B));
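// Worked example of the swizzle-mode selection above (added for clarity; the values are
// illustrative, not taken from this commit):
//   fp16, padded_d = 64  -> d_in_bytes = 128, d_groups = 1, d_bytes_per_group = 128 -> SWIZZLE_128B
//   e4m3, padded_d = 64  -> d_in_bytes = 64,  d_groups = 1, d_bytes_per_group = 64  -> SWIZZLE_64B
//   e4m3, padded_d = 256 -> d_in_bytes = 256, d_groups = 2, d_bytes_per_group = 128 -> SWIZZLE_128B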
|
||||
@ -388,21 +402,21 @@ public:
|
||||
|
||||
// Q: STEP_Q
|
||||
box_size[3] = q_step;
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
|
||||
&mParams.tma_desc_q);
|
||||
// Desc Format (data type).
|
||||
const cudaTmaDescFormat desc_format
|
||||
= (get_size_in_bytes(1, mDataType) == 1) ? cudaTmaDescFormat::U8 : cudaTmaDescFormat::F16_RN;
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
||||
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
|
||||
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_q);
|
||||
|
||||
// K/V: STEP_KV
|
||||
box_size[3] = kv_step;
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
|
||||
&mParams.tma_desc_k);
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, cudaTmaDescFormat::F16_RN,
|
||||
cudaTmaDescInterleave::INTERLEAVE_DISABLED, swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED,
|
||||
tensor_size_qkv, tensor_stride_qkv, traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32,
|
||||
&mParams.tma_desc_v);
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
||||
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
|
||||
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_k);
|
||||
qkv_tma_descriptor.set_tma_desctriptor(qkv_ptr, desc_format, cudaTmaDescInterleave::INTERLEAVE_DISABLED,
|
||||
swizzle_mode, cudaTmaDescPromotion::PROMOTION_DISABLED, tensor_size_qkv, tensor_stride_qkv,
|
||||
traversal_stride_qkv, box_size, oob_fill, fp32_to_tf32, &mParams.tma_desc_v);
|
||||
}
|
||||
|
||||
// Q are contiguous in the shape of [B, S, H, D]
|
||||
@ -520,8 +534,9 @@ public:
|
||||
|
||||
void setup_flags(bool const force_fp32_acc, bool const is_s_padded, bool const causal_mask, int const num_kv_heads)
|
||||
{
|
||||
// BF16 FMHA only accumulates on FP32
|
||||
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || force_fp32_acc;
|
||||
// BF16 FMHA only accumulates on FP32.
|
||||
// E4M3 FMHA only supports fp32 accumulation currently.
|
||||
mLaunchParams.force_fp32_acc = mDataType == DATA_TYPE_BF16 || mDataType == DATA_TYPE_E4M3 || force_fp32_acc;
|
||||
mLaunchParams.attention_mask_type
|
||||
= causal_mask ? ContextAttentionMaskType::CAUSAL : ContextAttentionMaskType::PADDING;
|
||||
|
||||
@ -646,9 +661,9 @@ FusedMHARunnerV2::FusedMHARunnerV2(
|
||||
FusedMHARunnerV2::~FusedMHARunnerV2() = default;
|
||||
|
||||
void FusedMHARunnerV2::setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
float const* scale_bmm2_d, bool const has_alibi, bool const scale_alibi, int const tp_size, int const tp_rank)
|
||||
{
|
||||
pimpl->setup(b, s, sliding_window_size, total_seqlen, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
pimpl->setup(b, s, sliding_window_size, total_seqlen, scale_bmm2_d, has_alibi, scale_alibi, tp_size, tp_rank);
|
||||
}
|
||||
|
||||
void FusedMHARunnerV2::setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
|
||||
@ -48,7 +48,8 @@ public:
|
||||
virtual ~MHARunner() = default;
|
||||
|
||||
virtual void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1, int const tp_rank = 0)
|
||||
float const* scale_bmm2_d = nullptr, bool const has_alibi = false, bool const scale_alibi = false,
|
||||
int const tp_size = 1, int const tp_rank = 0)
|
||||
= 0;
|
||||
|
||||
virtual void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
@ -91,8 +92,8 @@ public:
|
||||
~FusedMHARunnerV2(); // for pimpl
|
||||
|
||||
void setup(int const b, int const s, int const sliding_window_size, int const total_seqlen,
|
||||
bool const has_alibi = false, bool const scale_alibi = false, int const tp_size = 1,
|
||||
int const tp_rank = 0) override;
|
||||
float const* scale_bmm2_d = nullptr, bool const has_alibi = false, bool const scale_alibi = false,
|
||||
int const tp_size = 1, int const tp_rank = 0) override;
|
||||
|
||||
void setup_paged_kv(int const b, int const s_q, int const s_kv, int const blocks_per_context_sequence,
|
||||
int const tokens_per_kv_block, int const sliding_window_size, int const total_seqlen,
|
||||
|
||||
@ -122,8 +122,9 @@ struct Fused_multihead_attention_params_v2
|
||||
int *counters, *max_barriers, *sum_barriers, *locks;
|
||||
// Scratch buffers to finalize softmax.
|
||||
float *max_scratch_ptr, *sum_scratch_ptr;
|
||||
// Scratch buffer to finalize the output (not needed for FP16).
|
||||
int* o_scratch_ptr;
|
||||
|
||||
// Scale bmm2 in the device memory.
|
||||
uint32_t const* scale_bmm2_d;
|
||||
|
||||
// In multi-query or grouped-query attention (MQA/GQA), several Q heads are associated with one KV head
|
||||
int h_kv;
|
||||
@ -199,6 +200,10 @@ struct Fused_multihead_attention_paged_kv_params_v2
|
||||
// The scaling factors for the kernel.
|
||||
uint32_t scale_bmm1, scale_softmax, scale_bmm2;
|
||||
|
||||
// TODO: add fp8 paged kv fmha later.
|
||||
// Scale bmm2 in the device memory.
|
||||
// uint32_t const* scale_bmm2_d;
|
||||
|
||||
// Do we use Niall's trick to avoid I2F/F2I in the INT8 kernel.
|
||||
// See https://confluence.nvidia.com/pages/viewpage.action?pageId=302779721 for details.
|
||||
bool enable_i2f_trick;
|
||||
|
||||
@ -17,17 +17,19 @@
|
||||
#include "customAllReduceKernels.h"
|
||||
#include "tensorrt_llm/common/cudaBf16Fallbacks.cuh"
|
||||
#include "tensorrt_llm/common/cudaTypeUtils.cuh"
|
||||
#include "tensorrt_llm/common/dataType.h"
|
||||
#include <tuple>
|
||||
#include <type_traits>
|
||||
|
||||
namespace tensorrt_llm::kernels
|
||||
{
|
||||
|
||||
using tensorrt_llm::common::datatype_enum;
|
||||
using tensorrt_llm::common::divUp;
|
||||
using tensorrt_llm::common::roundUp;
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_addr)
|
||||
static inline __device__ void st_flag_release(uint32_t const& flag, uint32_t* flag_addr)
|
||||
{
|
||||
#if __CUDA_ARCH__ >= 700
|
||||
asm volatile("st.global.release.sys.b32 [%1], %0;" ::"r"(flag), "l"(flag_addr));
|
||||
@ -39,13 +41,15 @@ static inline __device__ void st_flag_release(uint32_t& flag, uint32_t* flag_add
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static inline __device__ void ld_flag_acquire(uint32_t& flag, uint32_t* flag_addr)
|
||||
static inline __device__ uint32_t ld_flag_acquire(uint32_t* flag_addr)
|
||||
{
|
||||
uint32_t flag;
|
||||
#if __CUDA_ARCH__ >= 700
|
||||
asm volatile("ld.global.acquire.sys.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
#else
|
||||
asm volatile("ld.global.volatile.b32 %0, [%1];" : "=r"(flag) : "l"(flag_addr));
|
||||
#endif
|
||||
return flag;
|
||||
}
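For reference, a minimal sketch (an illustration added here, not part of the commit) of how the two helpers above are paired; the barrier functions below inline the same release/acquire handshake:
static inline __device__ void publish_and_wait(
    uint32_t const flag, uint32_t* my_slot_on_peer, uint32_t* peer_slot_on_me)
{
    // Publish our arrival so that the peer's acquire load can observe it.
    st_flag_release(flag, my_slot_on_peer);
    // Spin until the peer has published the same flag value.
    while (ld_flag_acquire(peer_slot_on_me) != flag)
    {
    }
}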
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
@ -107,72 +111,155 @@ inline __device__ int4 add128b(T& a, T& b)
|
||||
return c.packed;
|
||||
}
|
||||
|
||||
__inline__ __device__ void multi_gpu_barrier(
|
||||
uint32_t** signals, const uint32_t flag, const size_t rank, const size_t world_size, int const tidx, int const bidx)
|
||||
__inline__ __device__ void multi_gpu_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
|
||||
size_t const world_size, int const tidx, int const bidx)
|
||||
{
|
||||
// At the end of the function, we know that at least block 0 from all the other GPUs has reached that point.
|
||||
uint32_t volatile* my_signals = signals[rank];
|
||||
// After this function, at least one block in each GPU has reached the barrier
|
||||
if (tidx < world_size)
|
||||
{
|
||||
// The 1st block notifies the other ranks.
|
||||
// we can think of signals having the shape [world_size, world_size]
|
||||
// Dimension 0 is the "listening" dimension, dimension 1 is the "emitting" dimension
|
||||
|
||||
// Block 0 broadcasts its flag (local_rank on emitting dimension) to all receivers
|
||||
if (bidx == 0)
|
||||
{
|
||||
signals[tidx][rank] = flag;
|
||||
signals[tidx][local_rank] = flag;
|
||||
}
|
||||
|
||||
// Busy-wait until all ranks are ready.
|
||||
// All blocks check that corresponding block 0 on other GPUs have set the flag
|
||||
// No deadlock because block #0 is always the first block started
|
||||
uint32_t volatile* my_signals = signals[local_rank];
|
||||
while (my_signals[tidx] != flag)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure we can move on...
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
__global__ void multiGpuBarrierKernel(AllReduceParams params)
|
||||
__inline__ __device__ void block_barrier(uint32_t** signals, uint32_t const flag, size_t const local_rank,
|
||||
size_t const world_size, int const tidx, int const bidx)
|
||||
{
|
||||
multi_gpu_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, params.ranks_per_node,
|
||||
threadIdx.x, blockIdx.x);
|
||||
// After this function, the block of id == bidx of each GPU has reached the barrier
|
||||
if (tidx < world_size)
|
||||
{
|
||||
// we can think of signals having the shape [world_size, num_blocks, world_size]
|
||||
// (+ an offset on dim 1 to account for flags used in multi_gpu_barrier)
|
||||
// Dimension 0 is the "listening" dimension, dimension 2 is "emitting" dimension
|
||||
|
||||
// Block broadcast its flag (local_rank on emitting dimension) to all receivers
|
||||
uint32_t flag_block_offset = world_size + bidx * world_size;
|
||||
st_flag_release(flag, signals[tidx] + flag_block_offset + local_rank);
|
||||
|
||||
// Blocks check that corresponding blocks on other GPUs have also set the flag
|
||||
uint32_t* peer_barrier_d = signals[local_rank] + flag_block_offset + tidx;
|
||||
while (ld_flag_acquire(peer_barrier_d) != flag)
|
||||
{
|
||||
}
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
}
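A worked instance of the flag indexing used by block_barrier above (the concrete numbers are illustrative only):
// Assume world_size = 4, bidx = 2, local_rank = 1, tidx = 3.
//   flag_block_offset      = world_size + bidx * world_size = 4 + 2 * 4 = 12
//   slot written (release) = signals[3] + 12 + 1   (announce this rank to rank tidx = 3)
//   slot polled  (acquire) = signals[1] + 12 + 3   (wait for rank 3's matching announcement)
// The leading world_size term skips the first row of flags, which multi_gpu_barrier uses.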
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false>
|
||||
static __global__ void oneShotAllReduceKernel(AllReduceParams params)
|
||||
{
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start four blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
// message
|
||||
// |-------------------|
|
||||
// GPU 0 | B0 | B1 | B2 | B3 |
|
||||
// GPU 1 | B0 | B1 | B2 | B3 |
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 copies the chunk it is responsible for, from local_input to shareable buffer
|
||||
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier)
|
||||
// 3. B0 on GPU 0 pulls and sums the chunk from GPU 1, and writes the result to local_output
|
||||
//
|
||||
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
|
||||
// We only need to know whether the other GPU has arrived at the AR kernel; that would mean that data is ready
|
||||
//
|
||||
// With PUSH_MODE, we consider that the shared buffer is of size:
|
||||
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size]
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 pushes the chunk it is responsible for into all other GPUs:
|
||||
// params.peer_comm_buffer_ptrs[:, local_gpu, B0 slice]
|
||||
// 2. block sync so the block is shared by other GPUs
|
||||
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
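// Illustration added for clarity (hypothetical values): with RANKS_PER_NODE = 2 and local_rank = 0,
// step 1 writes the chunk at iter_offset to buffers[0][0 * elts_total + iter_offset] and
// buffers[1][0 * elts_total + iter_offset], and step 3 reduces buffers[0][0 * elts_total + iter_offset]
// with buffers[0][1 * elts_total + iter_offset], i.e. each GPU reduces, inside its own shared buffer,
// the copies pushed to it by every rank.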
|
||||
|
||||
int const bidx = blockIdx.x;
|
||||
int const tidx = threadIdx.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = 16 / sizeof(T);
|
||||
|
||||
// Packed data type for comms
|
||||
static constexpr int PACKED_ELTS = 16 / sizeof(T);
|
||||
using PackedStruct = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
|
||||
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
T const* src_d[RANKS_PER_NODE];
|
||||
// Start and end offsets of the thread
|
||||
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
|
||||
size_t const chunk_end = std::min((bidx + 1) * params.elts_per_block, params.elts_total);
|
||||
|
||||
T* buffers[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
// buffers[0] is always the local buffers. Helps load balancing reads.
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
}
|
||||
|
||||
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
|
||||
size_t offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
// The end of the segment computed by that block.
|
||||
size_t max_offset = std::min((bidx + 1) * params.elts_per_block, params.elts_per_rank);
|
||||
if constexpr (PUSH_MODE || COPY_INPUT)
|
||||
{
|
||||
// Copy from local buffer to shareable buffer
|
||||
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS)
|
||||
{
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
*reinterpret_cast<int4*>(&buffers[ii][params.local_rank * params.elts_total + iter_offset])
|
||||
= *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[iter_offset])
|
||||
= *reinterpret_cast<int4 const*>(&local_input_buffer[iter_offset]);
|
||||
}
|
||||
}
|
||||
|
||||
// wait for equivalent blocks of other GPUs to have copied data to their shareable buffer
|
||||
block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
}
|
||||
else
|
||||
{
|
||||
// In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed
|
||||
multi_gpu_barrier(
|
||||
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
}
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t iter_offset = offset; iter_offset < max_offset; iter_offset += blockDim.x * NUM_ELTS)
|
||||
for (size_t iter_offset = chunk_start; iter_offset < chunk_end; iter_offset += blockDim.x * PACKED_ELTS)
|
||||
{
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedStruct vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][iter_offset]);
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
vals[ii].packed
|
||||
= *reinterpret_cast<int4 const*>(&buffers[params.local_rank][ii * params.elts_total + iter_offset]);
|
||||
}
|
||||
else
|
||||
{
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][iter_offset]);
|
||||
}
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
@ -185,55 +272,130 @@ static __global__ void oneShotAllReduceKernel(AllReduceParams params)
|
||||
}
|
||||
|
||||
// Store to the destination buffer.
|
||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[iter_offset]) = sums.packed;
|
||||
*reinterpret_cast<int4*>(&local_output_buffer[iter_offset]) = sums.packed;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
template <typename T, int RANKS_PER_NODE, bool COPY_INPUT = true, bool PUSH_MODE = false>
|
||||
static __global__ void twoShotAllReduceKernel(AllReduceParams params)
|
||||
{
|
||||
// Suppose that two GPUs participate in the AR exchange, and we start two blocks.
|
||||
// The message is partitioned into chunks as detailed below:
|
||||
// message
|
||||
// |-------------------|
|
||||
// |--GPU 0--|--GPU 1--| (GPU responsibility parts)
|
||||
// GPU 0 | B0 | B1 | B0 | B1 |
|
||||
// GPU 1 | B0 | B1 | B0 | B1 |
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 copies all chunks it is responsible for, from local_input to shareable buffer
|
||||
// 2. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #0)
|
||||
// 3. B0 on GPU 0 gathers and sums the B0 chunks from GPU 1 that are in the GPU 0 responsibility
|
||||
// part (the first half of the message, see GPU responsibility row above)
|
||||
// 3bis. Likewise, B0 on GPU 1 copies and sums the chunks for GPU 0,
|
||||
// where GPU 1 is responsible: the second half of the message.
|
||||
// 4. B0 on GPU 0 and B0 on GPU 1 wait for each other (block_barrier #1)
|
||||
// 5. B0 writes result to local_output. It gathers each chunk from its responsible GPU.
|
||||
// For example, here it reads the first chunk from GPU 0 and second chunk from GPU 1.
|
||||
//
|
||||
// With COPY_INPUT == false, skip step 1. and use gpu_barrier instead of block barrier during step 2.
|
||||
// We only need to know whether the other GPU has arrived at the AR kernel; that would mean that data is ready
|
||||
// to be read.
|
||||
//
|
||||
// Note that compared to one-shot, one block (CTA) writes multiple input chunks and multiple output chunks.
|
||||
// However, it's only responsible for the summation of a single chunk.
|
||||
//
|
||||
// With PUSH_MODE, we consider that the shared buffer is of size:
|
||||
// params.peer_comm_buffer_ptrs: [world_size, world_size, message_size / world_size]
|
||||
//
|
||||
// Here the step-by-step behavior of one block:
|
||||
// 1. B0 pushes the chunks it is responsible for into the corresponding GPUs:
|
||||
// params.peer_comm_buffer_ptrs[target_gpu, local_gpu, current B0 slice]
|
||||
// 2. block sync so the blocks have been shared by other GPUs
|
||||
// 3. Reduce along second dimension params.peer_comm_buffer_ptrs[local_gpu, :, B0 slice]
|
||||
// 4. block barrier (corresponding blocks have finished reduction)
|
||||
// 5. pull and write on local buffer, by reading params.peer_comm_buffer_ptrs[:, 0, B0 slice] (reduction result is
|
||||
// written at index 0 of 2nd dim)
|
||||
|
||||
// The block index.
|
||||
int const bidx = blockIdx.x;
|
||||
// The thread index with the block.
|
||||
int const tidx = threadIdx.x;
|
||||
|
||||
// The number of elements packed into one for comms
|
||||
static constexpr int NUM_ELTS = 16 / sizeof(T);
|
||||
|
||||
// Packed data type for comms
|
||||
static constexpr int PACKED_ELTS = 16 / sizeof(T);
|
||||
using PackedType = typename PackedOn16Bytes<T>::Type;
|
||||
|
||||
// The location in the destination array (load 8 fp16 or load 4 fp32 using LDG.128).
|
||||
const size_t block_offset = bidx * params.elts_per_block + tidx * NUM_ELTS;
|
||||
const size_t block_start = params.rank_offset + block_offset;
|
||||
// The end of the segment computed by that block.
|
||||
size_t max_offset = min(block_start + params.elts_per_block, params.rank_offset + params.elts_per_rank);
|
||||
T const* local_input_buffer = reinterpret_cast<T const*>(params.local_input_buffer_ptr);
|
||||
T* local_shared_buffer = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[params.local_rank]);
|
||||
T* local_output_buffer = reinterpret_cast<T*>(params.local_output_buffer_ptr);
|
||||
|
||||
multi_gpu_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
size_t const chunk_start = bidx * params.elts_per_block + tidx * PACKED_ELTS;
|
||||
size_t const chunk_end = min(chunk_start + params.elts_per_block, params.elts_per_rank);
|
||||
|
||||
// The source pointers. Distributed round-robin for the different warps.
|
||||
T* src_d[RANKS_PER_NODE];
|
||||
// The destination ranks for round-robin gathering
|
||||
size_t dst_rank[RANKS_PER_NODE];
|
||||
T* buffers[RANKS_PER_NODE];
|
||||
int ranks[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
// A mapping of the ranks to scatter reads as much as possible
|
||||
int rank = (params.local_rank + ii) % RANKS_PER_NODE;
|
||||
src_d[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
dst_rank[ii] = rank;
|
||||
ranks[ii] = rank;
|
||||
buffers[ii] = reinterpret_cast<T*>(params.peer_comm_buffer_ptrs[rank]);
|
||||
}
|
||||
|
||||
if constexpr (PUSH_MODE || COPY_INPUT)
|
||||
{
|
||||
// Copy all blocks from local buffer to shareable buffer
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
size_t offset_rank = ii * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
*reinterpret_cast<int4*>(&buffers[ii][params.local_rank * params.elts_per_rank + local_offset])
|
||||
= *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[offset_rank])
|
||||
= *reinterpret_cast<int4 const*>(&local_input_buffer[offset_rank]);
|
||||
}
|
||||
}
|
||||
}
|
||||
block_barrier(params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
}
|
||||
else
|
||||
{
|
||||
// In the non-copy case, we assume that once the kernel has been started, data is ready to be consumed
|
||||
multi_gpu_barrier(
|
||||
params.peer_barrier_ptrs_in, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
}
|
||||
|
||||
// Each block accumulates the values from the different GPUs on the same node.
|
||||
for (size_t local_offset = block_start; local_offset < max_offset; local_offset += blockDim.x * NUM_ELTS)
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
|
||||
{
|
||||
size_t const responsible_block_offset = local_offset + params.rank_offset;
|
||||
|
||||
// Iterate over the different ranks/devices on the node to load the values.
|
||||
PackedType vals[RANKS_PER_NODE];
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&src_d[ii][local_offset]);
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
vals[ii].packed
|
||||
= *reinterpret_cast<int4 const*>(&local_shared_buffer[ii * params.elts_per_rank + local_offset]);
|
||||
}
|
||||
else
|
||||
{
|
||||
vals[ii].packed = *reinterpret_cast<int4 const*>(&buffers[ii][responsible_block_offset]);
|
||||
}
|
||||
}
|
||||
|
||||
// Sum the values from the different ranks.
|
||||
@ -246,86 +408,77 @@ static __global__ void twoShotAllReduceKernel(AllReduceParams params)
|
||||
}
|
||||
|
||||
// Store to the local buffer.
|
||||
*reinterpret_cast<int4*>(&src_d[0][local_offset]) = sums.packed;
|
||||
}
|
||||
|
||||
// sync threads to make sure all block threads have the sums
|
||||
__syncthreads();
|
||||
|
||||
// barriers among the blocks with the same idx (release-acquire semantics)
|
||||
if (tidx < RANKS_PER_NODE)
|
||||
{
|
||||
// The all blocks notifies the other ranks.
|
||||
uint32_t flag_block_offset = RANKS_PER_NODE + bidx * RANKS_PER_NODE;
|
||||
st_flag_release(params.barrier_flag, params.peer_barrier_ptrs_in[tidx] + flag_block_offset + params.local_rank);
|
||||
|
||||
// Busy-wait until all ranks are ready.
|
||||
uint32_t rank_barrier = 0;
|
||||
uint32_t* peer_barrier_d = params.peer_barrier_ptrs_in[params.local_rank] + flag_block_offset + tidx;
|
||||
do
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
ld_flag_acquire(rank_barrier, peer_barrier_d);
|
||||
} while (rank_barrier != params.barrier_flag);
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[local_offset]) = sums.packed;
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<int4*>(&local_shared_buffer[responsible_block_offset]) = sums.packed;
|
||||
}
|
||||
}
|
||||
|
||||
// sync threads to make sure all other ranks have the final partial results
|
||||
__syncthreads();
|
||||
block_barrier(params.peer_barrier_ptrs_out, params.barrier_flag, params.local_rank, RANKS_PER_NODE, tidx, bidx);
|
||||
|
||||
size_t max_block_offset = min(block_offset + params.elts_per_block, params.elts_per_rank);
|
||||
// Gather all needed elts from other intra-node ranks
|
||||
for (size_t local_offset = block_offset; local_offset < max_block_offset; local_offset += blockDim.x * NUM_ELTS)
|
||||
for (size_t local_offset = chunk_start; local_offset < chunk_end; local_offset += blockDim.x * PACKED_ELTS)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int ii = 0; ii < RANKS_PER_NODE; ++ii)
|
||||
{
|
||||
// use round-robin gathering from other ranks
|
||||
size_t offset_rank = dst_rank[ii] * params.elts_per_rank + local_offset;
|
||||
size_t offset_rank = ranks[ii] * params.elts_per_rank + local_offset;
|
||||
if (offset_rank >= params.elts_total)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
*reinterpret_cast<int4*>(&reinterpret_cast<T*>(params.local_output_buffer_ptr)[offset_rank])
|
||||
= *reinterpret_cast<int4*>(&src_d[ii][offset_rank]);
|
||||
|
||||
if constexpr (PUSH_MODE)
|
||||
{
|
||||
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank])
|
||||
= *reinterpret_cast<int4*>(&buffers[ii][local_offset]);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<int4*>(&local_output_buffer[offset_rank])
|
||||
= *reinterpret_cast<int4*>(&buffers[ii][offset_rank]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type)
|
||||
{
|
||||
size_t elts_per_thread = 16 / common::getDTypeSize(type);
|
||||
int const msg_align = (algo == AllReduceStrategyType::TWOSHOT) ? n_ranks * elts_per_thread : elts_per_thread;
|
||||
bool supported_algo = (algo == AllReduceStrategyType::ONESHOT || algo == AllReduceStrategyType::TWOSHOT);
|
||||
return supported_algo && (msg_size % msg_align == 0);
|
||||
}
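A quick worked example of the alignment rule above (assuming getDTypeSize reports the element size in bytes; the numbers are illustrative):
// kHALF: elts_per_thread = 16 / 2 = 8.
//   ONESHOT                 : msg_size must be a multiple of 8 elements.
//   TWOSHOT with n_ranks = 8: msg_size must be a multiple of 8 * 8 = 64 elements.
// e.g. configurationSupported(AllReduceStrategyType::TWOSHOT, 1024, 8, nvinfer1::DataType::kHALF) is true.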
|
||||
|
||||
std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReduceParams& param, size_t elts_per_thread)
|
||||
{
|
||||
TLLM_CHECK(param.elts_total % elts_per_thread == 0);
|
||||
|
||||
int blocks_per_grid = 1, threads_per_block = DEFAULT_BLOCK_SIZE;
|
||||
|
||||
const size_t total_threads = param.elts_total / elts_per_thread;
|
||||
switch (algo)
|
||||
{
|
||||
case AllReduceStrategyType::ONESHOT:
|
||||
{ // one stage all reduce algo
|
||||
if (total_threads <= DEFAULT_BLOCK_SIZE)
|
||||
{ // local reduce
|
||||
threads_per_block = WARP_SIZE * divUp(total_threads, WARP_SIZE);
|
||||
blocks_per_grid = 1;
|
||||
}
|
||||
else
|
||||
{ // local reduce
|
||||
threads_per_block = DEFAULT_BLOCK_SIZE;
|
||||
blocks_per_grid = divUp(total_threads, DEFAULT_BLOCK_SIZE);
|
||||
blocks_per_grid = std::min(static_cast<int>(MAX_ALL_REDUCE_BLOCKS), blocks_per_grid);
|
||||
}
|
||||
param.elts_per_rank = param.elts_total;
|
||||
param.elts_per_block = elts_per_thread * divUp(param.elts_per_rank, elts_per_thread * blocks_per_grid);
|
||||
{
|
||||
TLLM_CHECK(param.elts_total % elts_per_thread == 0);
|
||||
size_t const total_threads = roundUp(param.elts_total / elts_per_thread, WARP_SIZE);
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
param.elts_per_block = roundUp(divUp(param.elts_total, blocks_per_grid), elts_per_thread);
|
||||
break;
|
||||
}
|
||||
case AllReduceStrategyType::TWOSHOT:
|
||||
{ // two stage all reduce algo
|
||||
const size_t elts_per_rank = param.elts_total / param.ranks_per_node;
|
||||
TLLM_CHECK(elts_per_rank % elts_per_thread == 0);
|
||||
{
|
||||
TLLM_CHECK(param.elts_total % (elts_per_thread * param.ranks_per_node) == 0);
|
||||
size_t const total_threads = roundUp(param.elts_total / (elts_per_thread * param.ranks_per_node), WARP_SIZE);
|
||||
|
||||
size_t total_threads = elts_per_rank / elts_per_thread;
|
||||
total_threads = WARP_SIZE * ((total_threads + WARP_SIZE - 1) / WARP_SIZE);
|
||||
TLLM_CHECK(total_threads % WARP_SIZE == 0);
|
||||
/*
|
||||
threads_per_block = std::min(DEFAULT_BLOCK_SIZE, total_threads);
|
||||
blocks_per_grid = std::min(static_cast<size_t>(MAX_ALL_REDUCE_BLOCKS), divUp(total_threads, threads_per_block));
|
||||
*/
|
||||
|
||||
while (total_threads % blocks_per_grid != 0 || total_threads / blocks_per_grid > DEFAULT_BLOCK_SIZE)
|
||||
{
|
||||
@ -345,9 +498,8 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
|
||||
blocks_per_grid /= iter_factor;
|
||||
}
|
||||
param.elts_per_rank = param.elts_total / param.ranks_per_node;
|
||||
param.elts_per_block = param.elts_per_rank / blocks_per_grid;
|
||||
param.elts_per_block = elts_per_thread * divUp(param.elts_per_block, elts_per_thread);
|
||||
param.rank_offset = param.rank * param.elts_per_rank;
|
||||
param.rank_offset = param.local_rank * param.elts_per_rank;
|
||||
param.elts_per_block = roundUp(divUp(param.elts_per_rank, blocks_per_grid), elts_per_thread);
|
||||
break;
|
||||
}
|
||||
default: TLLM_THROW("Algorithm not supported here.");
|
||||
@ -356,44 +508,74 @@ std::tuple<int, int> kernelLaunchConfig(AllReduceStrategyType algo, AllReducePar
|
||||
return std::make_tuple(blocks_per_grid, threads_per_block);
|
||||
}
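A worked instance of the ONESHOT branch above (assuming DEFAULT_BLOCK_SIZE = 1024 as declared in the header, and that MAX_ALL_REDUCE_BLOCKS, whose value is not shown in this diff, does not clamp the grid):
// elts_total = 1,048,576 half elements, elts_per_thread = 16 / sizeof(half) = 8
//   total_threads     = roundUp(1,048,576 / 8, WARP_SIZE) = 131,072
//   threads_per_block = min(1024, 131,072)                = 1,024
//   blocks_per_grid   = divUp(131,072, 1,024)             = 128
//   elts_per_block    = roundUp(divUp(1,048,576, 128), 8) = 8,192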
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template <typename T, int RANKS_PER_NODE>
|
||||
void dispatchARKernels(
|
||||
AllReduceStrategyType algo, AllReduceParams& param, int blocks_per_grid, int threads_per_block, cudaStream_t stream)
|
||||
template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false, bool USE_MEMCPY = false>
|
||||
void AllReduceDispatchMemcpy(
|
||||
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(!(USE_MEMCPY && PUSH_MODE), "Memcpy cannot be used with PUSH_MODE.");
|
||||
size_t elts_per_thread = 16 / sizeof(T);
|
||||
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(algo, param, elts_per_thread);
|
||||
|
||||
if (USE_MEMCPY)
|
||||
{
|
||||
cudaMemcpyAsync(param.peer_comm_buffer_ptrs[param.local_rank], param.local_input_buffer_ptr,
|
||||
param.elts_total * sizeof(T), cudaMemcpyDeviceToDevice, stream);
|
||||
}
|
||||
|
||||
if (algo == AllReduceStrategyType::ONESHOT)
|
||||
{
|
||||
oneShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
oneShotAllReduceKernel<T, RANKS_PER_NODE, !USE_MEMCPY, PUSH_MODE>
|
||||
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
}
|
||||
else
|
||||
{
|
||||
twoShotAllReduceKernel<T, RANKS_PER_NODE><<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
twoShotAllReduceKernel<T, RANKS_PER_NODE, !USE_MEMCPY, PUSH_MODE>
|
||||
<<<blocks_per_grid, threads_per_block, 0, stream>>>(param);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE, bool PUSH_MODE = false>
|
||||
void AllReduceDispatchPushMode(
|
||||
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
|
||||
{
|
||||
if (static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(config)
|
||||
& static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(AllReduceStrategyConfig::USE_MEMCPY))
|
||||
{
|
||||
AllReduceDispatchMemcpy<T, RANKS_PER_NODE, PUSH_MODE, true>(algo, config, param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
AllReduceDispatchMemcpy<T, RANKS_PER_NODE, PUSH_MODE, false>(algo, config, param, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, int RANKS_PER_NODE> //, bool USE_MEMCPY = false, bool PUSH_MODE = false>
|
||||
void AllReduceDispatchRanksPerNode(
|
||||
AllReduceStrategyType algo, AllReduceStrategyConfig config, AllReduceParams& param, cudaStream_t stream)
|
||||
{
|
||||
if (static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(config)
|
||||
& static_cast<std::underlying_type_t<AllReduceStrategyConfig>>(AllReduceStrategyConfig::PULL_MODE))
|
||||
{
|
||||
AllReduceDispatchPushMode<T, RANKS_PER_NODE, true>(algo, config, param, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
AllReduceDispatchPushMode<T, RANKS_PER_NODE, false>(algo, config, param, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream)
|
||||
void AllReduceDispatchType(
|
||||
AllReduceParams& param, AllReduceStrategyType strat, AllReduceStrategyConfig config, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK(strat == AllReduceStrategyType::ONESHOT || strat == AllReduceStrategyType::TWOSHOT);
|
||||
sync_check_cuda_error();
|
||||
|
||||
size_t elts_per_thread = 16 / sizeof(T);
|
||||
auto [blocks_per_grid, threads_per_block] = kernelLaunchConfig(strat, param, elts_per_thread);
|
||||
switch (param.ranks_per_node)
|
||||
{
|
||||
case 2: dispatchARKernels<T, 2>(strat, param, blocks_per_grid, threads_per_block, stream); break;
|
||||
case 4: dispatchARKernels<T, 4>(strat, param, blocks_per_grid, threads_per_block, stream); break;
|
||||
case 6: dispatchARKernels<T, 6>(strat, param, blocks_per_grid, threads_per_block, stream); break;
|
||||
case 8: dispatchARKernels<T, 8>(strat, param, blocks_per_grid, threads_per_block, stream); break;
|
||||
default: break;
|
||||
case 2: AllReduceDispatchRanksPerNode<T, 2>(strat, config, param, stream); break;
|
||||
case 4: AllReduceDispatchRanksPerNode<T, 4>(strat, config, param, stream); break;
|
||||
case 6: AllReduceDispatchRanksPerNode<T, 6>(strat, config, param, stream); break;
|
||||
case 8: AllReduceDispatchRanksPerNode<T, 8>(strat, config, param, stream); break;
|
||||
default: TLLM_THROW("Custom all reduce only supported on {2, 4, 6, 8} GPUs per node.");
|
||||
}
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream)
|
||||
{
|
||||
multiGpuBarrierKernel<<<1, param.ranks_per_node, 0, stream>>>(param);
|
||||
}
|
||||
|
||||
AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value)
|
||||
@ -425,30 +607,25 @@ AllReduceParams AllReduceParams::deserialize(int32_t const* buffer, size_t tpSiz
|
||||
return params;
|
||||
}
|
||||
|
||||
void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
|
||||
datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream)
|
||||
void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataType, AllReduceStrategyType strat,
|
||||
AllReduceStrategyConfig config, cudaStream_t stream)
|
||||
{
|
||||
params.local_output_buffer_ptr = data;
|
||||
params.elts_total = elts;
|
||||
TLLM_CHECK_WITH_INFO(configurationSupported(strat, params.elts_total, params.ranks_per_node, dataType),
|
||||
"Custom all-reduce configuration unsupported");
|
||||
|
||||
if (dataType == datatype_enum::TYPE_FP32)
|
||||
sync_check_cuda_error();
|
||||
|
||||
switch (dataType)
|
||||
{
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<float>(params, strat, stream);
|
||||
}
|
||||
else if (dataType == datatype_enum::TYPE_FP16)
|
||||
{
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<half>(params, strat, stream);
|
||||
}
|
||||
case nvinfer1::DataType::kFLOAT: AllReduceDispatchType<float>(params, strat, config, stream); break;
|
||||
case nvinfer1::DataType::kHALF: AllReduceDispatchType<half>(params, strat, config, stream); break;
|
||||
#ifdef ENABLE_BF16
|
||||
else if (dataType == datatype_enum::TYPE_BF16)
|
||||
{
|
||||
kernels::invokeOneOrTwoShotAllReduceKernel<__nv_bfloat16>(params, strat, stream);
|
||||
}
|
||||
case nvinfer1::DataType::kBF16: AllReduceDispatchType<__nv_bfloat16>(params, strat, config, stream); break;
|
||||
#endif
|
||||
else
|
||||
{
|
||||
TLLM_THROW("Unsupported dataType for customAllReduce");
|
||||
default: TLLM_THROW("Unsupported dataType for customAllReduce");
|
||||
}
|
||||
|
||||
sync_check_cuda_error();
|
||||
}
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -16,12 +16,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <assert.h>
|
||||
#include <NvInferRuntime.h>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "tensorrt_llm/common/assert.h"
|
||||
#include "tensorrt_llm/common/cudaUtils.h"
|
||||
#include "tensorrt_llm/common/tensor.h"
|
||||
@ -38,12 +36,18 @@ constexpr size_t DEFAULT_BLOCK_SIZE = 1024;
|
||||
// they must be kept in sync
|
||||
enum class AllReduceStrategyType : int8_t
|
||||
{
|
||||
RING = 0,
|
||||
NCCL = 0,
|
||||
ONESHOT = 1,
|
||||
TWOSHOT = 2,
|
||||
AUTO = 3,
|
||||
};
|
||||
|
||||
enum class AllReduceStrategyConfig : int8_t
|
||||
{
|
||||
USE_MEMCPY = 1 << 0,
|
||||
PULL_MODE = 1 << 1,
|
||||
};
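The flags above are combined as a bitmask; a minimal sketch (hasConfigFlag is a hypothetical helper, not part of the commit) of the kind of test the dispatch code performs, assuming <type_traits> is available:
inline bool hasConfigFlag(AllReduceStrategyConfig config, AllReduceStrategyConfig flag)
{
    // Compare on the underlying integer type, mirroring the static_cast checks in the dispatch code.
    using U = std::underlying_type_t<AllReduceStrategyConfig>;
    return (static_cast<U>(config) & static_cast<U>(flag)) != 0;
}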
|
||||
|
||||
struct AllReduceParams
|
||||
{
|
||||
size_t elts_total;
|
||||
@ -56,16 +60,14 @@ struct AllReduceParams
|
||||
uint32_t* peer_barrier_ptrs_out[MAX_RANKS_PER_NODE];
|
||||
void* peer_comm_buffer_ptrs[MAX_RANKS_PER_NODE];
|
||||
void* local_output_buffer_ptr;
|
||||
void const* local_input_buffer_ptr;
|
||||
|
||||
static AllReduceParams deserialize(int32_t const* buffer, size_t tpSize, size_t tpRank, uint32_t flag_value);
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void invokeOneOrTwoShotAllReduceKernel(AllReduceParams& param, AllReduceStrategyType strat, cudaStream_t stream);
|
||||
bool configurationSupported(AllReduceStrategyType algo, size_t msg_size, size_t n_ranks, nvinfer1::DataType type);
|
||||
|
||||
void invokeMultiGpuBarrier(AllReduceParams& param, cudaStream_t stream);
|
||||
|
||||
void customAllReduce(kernels::AllReduceParams& params, void* data, size_t elts, size_t size_per_elem,
|
||||
common::datatype_enum dataType, AllReduceStrategyType strat, cudaStream_t stream);
|
||||
void customAllReduce(kernels::AllReduceParams& params, nvinfer1::DataType dataType, AllReduceStrategyType strat,
|
||||
AllReduceStrategyConfig config, cudaStream_t stream);
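A hedged usage sketch for the new entry point declared above; the buffers, sizes, and flag value are placeholders, and only AllReduceParams members visible in this diff are touched:
// auto params = AllReduceParams::deserialize(workspaceBuffer, tpSize, tpRank, flagValue);
// params.local_input_buffer_ptr  = inputDevicePtr;
// params.local_output_buffer_ptr = outputDevicePtr;
// params.elts_total              = numElements;
// customAllReduce(params, nvinfer1::DataType::kHALF, AllReduceStrategyType::ONESHOT,
//     AllReduceStrategyConfig::USE_MEMCPY, stream);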
|
||||
|
||||
} // namespace tensorrt_llm::kernels
|
||||
|
||||
@ -148,6 +148,7 @@ struct Multihead_attention_params_base
|
||||
float const* qkv_scale_quant_orig = nullptr;
|
||||
float const* attention_out_scale_orig_quant = nullptr;
|
||||
|
||||
// 8 bits kv cache scales.
|
||||
float const* kv_scale_orig_quant = nullptr;
|
||||
float const* kv_scale_quant_orig = nullptr;
|
||||
|
||||
|
||||
@ -1638,13 +1638,13 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
|
||||
{
|
||||
if (HANDLE_KV)
|
||||
{
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
|
||||
params.rotary_embedding_scale, current_pos_idx);
|
||||
apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_embedding_scale, current_pos_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, params.rotary_embedding_base,
|
||||
params.rotary_embedding_scale, current_pos_idx);
|
||||
apply_rotary_embedding(
|
||||
q, tidx, params.rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale, current_pos_idx);
|
||||
}
|
||||
break;
|
||||
}
|
||||
@ -2589,6 +2589,9 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Quantized output only supports fp8 currently, which should be used together with FP8 Context FMHA.
|
||||
using Quantized_t = __nv_fp8_e4m3;
|
||||
using Quantized_vec = typename packed_type<__nv_fp8_e4m3, num_elems<V_vec_accum>::value>::type;
|
||||
auto const bhi = tensorrt_llm::common::flat_index2(batch_beam_idx, hi, num_heads);
|
||||
auto const bhi_seq_len_tile = bhi * params.seq_len_tile;
|
||||
// Output the final values.
|
||||
@ -2596,35 +2599,36 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
|
||||
{
|
||||
auto const bhvi = tensorrt_llm::common::flat_index2(bhi, vi, Dh);
#ifdef MMHA_USE_FP32_ACCUM_FOR_OUT
if (write_attention_quant)
if (!MULTI_BLOCK_FLAG)
{
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_accum>::value>::type;
out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
*reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhvi])) = cast_to_int8(out);
}
else
{
if (!MULTI_BLOCK_FLAG)
if (write_attention_quant)
{
out = mul<V_vec_accum, float>(*params.attention_out_scale_orig_quant, out);
Quantized_vec final_out;
convert_to_fp8(&final_out, out);
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhvi) = final_out;
}
else
{
// This makes sure we have coalesced memory access.
V_vec_k final_out;
convert_from_float(&final_out, out);
*reinterpret_cast<V_vec_k*>(&params.out[bhvi]) = final_out;
}
else
{
// for write partial output to partial_out
int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
// for write partial statistics to partial_max and partial_sum
int partial_stats_offset = bhi_seq_len_tile + c_tile;
}
else
{
// for write partial output to partial_out
int partial_out_offset = c_tile * params.batch_size * num_heads * params.hidden_size_per_head;
// for write partial statistics to partial_max and partial_sum
int partial_stats_offset = bhi_seq_len_tile + c_tile;

// This makes sure we have coalesced memory access.
V_vec_k partial_out;
convert_from_float(&partial_out, out);
*reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
}
// This makes sure we have coalesced memory access.
V_vec_k partial_out;
convert_from_float(&partial_out, out);
*reinterpret_cast<V_vec_k*>(&params.partial_out[partial_out_offset + bhvi]) = partial_out;
convert_from_float(reinterpret_cast<float*>(&params.partial_max[partial_stats_offset]), qk_max);
convert_from_float(reinterpret_cast<float*>(&params.partial_sum[partial_stats_offset]), sum);
}
#else // MMHA_USE_FP32_ACCUM_FOR_OUT
*reinterpret_cast<V_vec_accum*>(&params.out[bhvi]) = out;
@@ -2768,13 +2772,25 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske

if (oo == 0 && (Dh == Dh_MAX || oi < Dh))
{
auto const inv_sum = __fdividef(1.f, final_sum + 1.e-6f);
auto const inv_sum = __fdividef(
write_attention_quant ? *params.attention_out_scale_orig_quant : 1.f, final_sum + 1.e-6f);

Tk inv_sum_compute;
convert_from_float(&inv_sum_compute, inv_sum);

thread_accumulated_out = mul<V_vec_k, Tk, V_vec_k>(inv_sum_compute, thread_accumulated_out);
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_accumulated_out;

if (write_attention_quant)
{
Quantized_vec final_out;
convert_to_fp8(&final_out, thread_accumulated_out);
*reinterpret_cast<Quantized_vec*>(reinterpret_cast<Quantized_t*>(params.out) + bhi * Dh + oi)
= final_out;
}
else
{
*reinterpret_cast<V_vec_k*>(&params.out[bhi * Dh + oi]) = thread_accumulated_out;
}
}

// Reset qk_current_smem and block_counter for the next timestep

@@ -279,7 +279,7 @@ public:
// NOTE: MHA kernels should read kv cache that has already been appended with new tokens' kv cache.
void const* xqa_q_input_ptr = xqaParams.output;
invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, true>(static_cast<T*>(const_cast<void*>(xqaParams.qkv)),
static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
(__nv_fp8_e4m3*) nullptr, static_cast<T*>(const_cast<void*>(xqaParams.output)), kv_cache_buffer,
static_cast<T const*>(xqaParams.qkv_bias), xqaParams.sequence_lengths, nullptr, nullptr,
xqaParams.batch_size, xqaParams.generation_input_length, xqaParams.cyclic_attention_window_size,
xqaParams.sink_token_length, xqaParams.batch_size * beam_width * xqaParams.generation_input_length,
@@ -287,7 +287,7 @@ public:
xqaParams.rotary_embedding_base, xqaParams.rotary_embedding_scale_type, xqaParams.rotary_embedding_scale,
xqaParams.rotary_embedding_max_positions, xqaParams.position_embedding_type,
xqaParams.medusa_position_offsets, xqaParams.position_shift_enabled, (float*) nullptr, 0, cache_type,
xqaParams.kv_scale_orig_quant, true, beam_width, rotary_kernel_launch_cache, stream);
xqaParams.kv_scale_orig_quant, true, false, beam_width, rotary_kernel_launch_cache, stream);

sync_check_cuda_error();

@@ -1838,6 +1838,16 @@ inline __device__ Float8_ mul(float a, uint4 b)

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ uint2 mul(float a, uint2 b)
{
uint16_t h = float_to_half(a);
uint2 c = mul<uint2, uint16_t, uint2>(h, b);
return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ uint4 mul(float a, uint4 b)
{
@@ -1877,6 +1887,14 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b)

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ __nv_bfloat162 mul(float a, __nv_bfloat162 b)
{
return mul<__nv_bfloat162>(__float2bfloat162_rn(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b)
{
@@ -1908,6 +1926,14 @@ inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b)

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ bf16_4_t mul(float a, bf16_4_t b)
{
return mul<bf16_4_t>(__float2bfloat16(a), b);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template <>
inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b)
{

@@ -734,16 +734,18 @@ __device__ __forceinline__ int4 reduceMaxInt4(int4 const& a, int4 const& b)
}

template <typename T, SizeType BLOCK_SIZE>
__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize,
SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads,
SizeType maxTokensPerStep)
__global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, T const** medusaLogits,
T const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const inputLength = sequenceLengths[batchSlot];
auto const endId = endIds[batchSlot];
auto const numTokensPerStep = curTokensPerStep[batchSlot];
auto const maxNumDraftTokens = maxNumHeads + 1;

int4 partialMax{-1, -1, 0, 0};
@@ -761,7 +763,7 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
{
continue;
}
auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto targetToken = targetIds[targetTokenIdx];
auto nextIdx = tokenId;

@@ -775,17 +777,16 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
acceptedLength = ti;
break;
}
auto const targetTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const draftTokenIdx = batchSlot * maxDraftSeqLen + inputLength + tokenId;
auto const draftToken = outputIds[draftTokenIdx];

auto const targetTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const draftTokenIdx = batchSlot * (maxDraftTokens - 1) + tokenId - 1;
// In context phase, no draft tokens are given. Set draft token to -1 to get guaranteed rejection
auto const draftToken = tokenId >= numTokensPerStep ? -1 : draftIds[draftTokenIdx];
// Check if draft tokens are the same as target tokens
bool const accepted = draftToken == targetToken;
hasEnd = targetToken == endId;
if (!accepted || hasEnd)
{
acceptedLength = hasEnd ? ti - 1 : ti;
nextIdx = tokenId;
break;
}
targetToken = targetIds[targetTokenIdx];
@@ -816,16 +817,16 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT

auto const acceptedLength = totalShared.x;
auto const bestPathIdx = totalShared.y;
auto const bestNextIdx = totalShared.w;
auto const bestNextIdx = numTokensPerStep == 1 ? 0 : totalShared.w;
auto const pathOffset = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (auto ti = static_cast<SizeType>(threadIdx.x); ti < acceptedLength; ti += static_cast<SizeType>(blockDim.x))
{
auto const tokenId = paths[pathOffset + ti];
auto const targetSrcTokenIdx = batchSlot * maxTargetSeqLen + tokenId;
auto const draftDstTokenIdx = batchSlot * maxDraftSeqLen + inputLength + ti;
auto const targetSrcTokenIdx = batchSlot * maxDraftTokens + tokenId;
auto const outputTokenIdx = batchSlot * maxSeqLen + inputLength + ti;
auto const targetToken = targetIds[targetSrcTokenIdx];
// Copy accepted tokens to the sequence with draft tokens (outputIds === outputIds)
outputIds[draftDstTokenIdx] = targetToken;
outputIds[outputTokenIdx] = targetToken;
}

// Leading thread reconstructs winning path and sets new data
@@ -840,41 +841,135 @@ __global__ void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdT
// Make correction to the sequence length
sequenceLengths[batchSlot] += acceptedLength;
acceptedLengths[batchSlot] = acceptedLength;
// In Medusa decoding step, number of draft tokens is 0 and must be updated for the next steps
if (numTokensPerStep == 1)
{
curTokensPerStep[batchSlot] = targetTokensPerStep[batchSlot];
}
bestPathIds[batchSlot] = bestPathIdx;
}

// Prepare logits pointers to respective logits from Medusa Heads for the all-top-K sampling kernel
for (auto hi = static_cast<SizeType>(threadIdx.x); hi < maxNumHeads; hi += static_cast<SizeType>(blockDim.x))
{
logitsPtrs[batchIdx * maxNumHeads + hi]
= medusaLogits + flat_index4(hi, batchIdx, bestNextIdx, 0, maxBatchSize, maxTokensPerStep, vocabSize);
= medusaLogits[batchSlot * maxNumHeads + hi] + flat_index2(bestNextIdx, 0, vocabSize);
}
}

template <typename T>
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds, SizeType* sequenceLengths,
SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots, SizeType const* paths,
TokenIdType const* endIds, T const* medusaLogits, T const** logitsPtrs, SizeType batchSize, SizeType vocabSize,
SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen, SizeType maxNumHeads,
void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds, SizeType batchSize,
SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen, SizeType maxNumHeads,
SizeType maxTokensPerStep, cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
dim3 block(BLOCK_SIZE);
dim3 grid(batchSize);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, targetIds, sequenceLengths,
acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs, batchSize, vocabSize,
maxBatchSize, maxDraftSeqLen, maxTargetSeqLen, maxNumHeads, maxTokensPerStep);
acceptDraftTokensByIdsWithPaths<T, BLOCK_SIZE><<<grid, block, 0, stream>>>(outputIds, draftIds, targetIds,
sequenceLengths, acceptedLengths, finishedFinal, batchSlots, paths, endIds, medusaLogits, logitsPtrs,
curTokensPerStep, targetTokensPerStep, bestPathIds, batchSize, vocabSize, maxBatchSize, maxDraftTokens,
maxSeqLen, maxNumHeads, maxTokensPerStep);
}

template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, float const* medusaLogits, float const** logitsPtrs,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftSeqLen, SizeType maxTargetSeqLen,
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, float const** medusaLogits,
float const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream);
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* targetIds,
SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal, SizeType const* batchSlots,
SizeType const* paths, TokenIdType const* endIds, half const* medusaLogits, half const** logitsPtrs,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, int32_t maxDraftSeqLen, SizeType maxTargetSeqLen,
template void acceptDraftTokensByIdsWithPaths(TokenIdType* outputIds, TokenIdType const* draftIds,
TokenIdType const* targetIds, SizeType* sequenceLengths, SizeType* acceptedLengths, FinishedState* finishedFinal,
SizeType const* batchSlots, SizeType const* paths, TokenIdType const* endIds, half const** medusaLogits,
half const** logitsPtrs, SizeType* curTokensPerStep, SizeType const* targetTokensPerStep, SizeType* bestPathIds,
SizeType batchSize, SizeType vocabSize, SizeType maxBatchSize, SizeType maxDraftTokens, SizeType maxSeqLen,
SizeType maxNumHeads, SizeType maxTokensPerStep, cudaStream_t stream);

__global__ void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds,
SizeType const* treeIds, SizeType const* tokensPerStepData, SizeType const* batchSlots, SizeType maxTokensPerStep)
{
auto const batchIdx = static_cast<SizeType>(blockIdx.x);
auto const batchSlot = batchSlots == nullptr ? batchIdx : batchSlots[batchIdx];
auto const tokensPerStep = tokensPerStepData[batchSlot];
auto const maxDraftTokens = maxTokensPerStep - 1;
for (auto index = static_cast<SizeType>(threadIdx.x); index < tokensPerStep - 1;
index += static_cast<SizeType>(blockDim.x))
{
auto const indexInTree = treeIds[batchSlot * maxDraftTokens + index];
auto const treeDraftIdx = batchSlot * maxDraftTokens + index;
auto const sourceDraftIdx = batchSlot * maxTokensPerStep + indexInTree;
treeDraftIds[treeDraftIdx] = sourceDraftIds[sourceDraftIdx];
}
}

void scatterMedusaDraftTokens(TokenIdType* treeDraftIds, TokenIdType const* sourceDraftIds, SizeType const* treeIds,
SizeType const* tokensPerStep, SizeType const* batchSlots, SizeType maxDraftTokens, SizeType batchSize,
cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 256;
scatterMedusaDraftTokens<<<batchSize, BLOCK_SIZE, 0, stream>>>(
treeDraftIds, sourceDraftIds, treeIds, tokensPerStep, batchSlots, maxDraftTokens);
}

template <int32_t BLOCK_SIZE>
__global__ void packAcceptedPaths(SizeType* acceptedLengthsCumSum, SizeType* pathsOffsets,
SizeType const* acceptedLengths, SizeType const* bestPathIds, SizeType const* paths, SizeType const* batchSlots,
SizeType batchSize, SizeType maxTokensPerStep, SizeType maxNumDraftTokens)
{
// Specialize BlockScan for a 1D block of 128 threads of type int
typedef cub::BlockScan<SizeType, BLOCK_SIZE> BlockScan;

// Allocate shared memory for BlockScan
__shared__ typename BlockScan::TempStorage tempStorage;
auto const batchSizeRounded = ((batchSize + BLOCK_SIZE - 1) / BLOCK_SIZE) * BLOCK_SIZE;
__shared__ SizeType currentCumSum;
if (threadIdx.x == 0)
{
currentCumSum = 0;
}

__syncthreads();

for (SizeType bi = static_cast<SizeType>(threadIdx.x); bi < batchSizeRounded;
bi += static_cast<SizeType>(blockDim.x))
{
auto const valid = bi < batchSize;
auto const batchSlot = valid ? batchSlots[bi] : 0;
auto const acceptedLen = valid ? acceptedLengths[batchSlot] - 1 : 0;
SizeType cumSum;
BlockScan(tempStorage).ExclusiveSum(acceptedLen + currentCumSum, cumSum);
if (threadIdx.x == blockDim.x - 1)
{
currentCumSum = cumSum;
}
__syncthreads();

if (valid)
{
acceptedLengthsCumSum[bi] = cumSum;
auto const bestPathIdx = bestPathIds[batchSlot];
auto const pathIdx = flat_index3(batchSlot, bestPathIdx, 0, maxTokensPerStep, maxNumDraftTokens);
for (SizeType ti = 0; ti < acceptedLen; ++ti)
{
pathsOffsets[cumSum + ti] = paths[pathIdx + ti + 1] - 1;
}
}
}
if (threadIdx.x == 0)
{
acceptedLengthsCumSum[batchSize] = currentCumSum;
}
}

void invokePackAcceptedPaths(SizeType* acceptedLengthsCumSum, SizeType* pathsOffsets, SizeType const* acceptedLengths,
SizeType const* bestPathIds, SizeType const* paths, SizeType const* batchSlots, SizeType batchSize,
SizeType maxTokensPerStep, SizeType maxNumDraftTokens, cudaStream_t stream)
{
constexpr SizeType BLOCK_SIZE = 1024;
packAcceptedPaths<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(acceptedLengthsCumSum, pathsOffsets, acceptedLengths,
bestPathIds, paths, batchSlots, batchSize, maxTokensPerStep, maxNumDraftTokens);
}
} // namespace kernels
} // namespace tensorrt_llm

@@ -161,11 +161,9 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
//! accepted tokens. Fills logitsPtrs tensor with the pointers to the respective medusa logits tensor according
//! to the next after the last accepted token.
//!
//! \param outputIds input/output buffer [maxBatchSize, maxDraftSeqLen],
//! input tokens followed by draft tokens to be verified.
//! After accepting tokens, gets overwritten such that input tokens are followed by the accepted tokens
//! and one additional token -- next after the last accepted.
//! \param targetIds input buffer [maxBatchSize, maxTargetSeqLen], tokens predicted from the target medusa head
//! \param outputIds output buffer [maxBatchSize, maxSeqLen], input tokens.
//! \param draftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens
//! \param targetIds input buffer [maxBatchSize, maxDraftTokens], tokens predicted from the target medusa head
//! \param sequenceLengths input/output buffer [maxBatchSize], length of the data in outputIds without draft tokens
//! Incremented according to the accepted length
//! \param acceptedLengths output buffer [maxBatchSize], number of accepted tokens
@@ -179,20 +177,65 @@ void invokeTransposeLogProbs(float* output_log_probs, float* output_log_probs_ti
//! to the logits from medusa heads
//! \param logitsPtrs output buffer [batchSize, maxNumHeads], contains pointers to the
//! respective rows of the medusaLogits for the next after the accepted token
//! \param curTokensPerStep current tokens to compute per step; will be updated to
//! targetTokensPerStep if curTokensPerStep == 1
//! \param targetTokensPerStep target values of tokens to compute per step
//! \param bestPathIds output buffer [maxBatchSize], indices of the selected paths
//! \param batchSize current batch size
//! \param maxBatchSize maximum batch size
//! \param vocabSize vocab size
//! \param maxDraftSeqLen maximum sequence length of the sequence containing draft tokens
//! \param maxTargetSeqLen maximum sequence length predicted from target head
//! \param maxDraftTokens maximum sequence length of the sequence containing draft tokens
//! \param maxSeqLen maximum sequence length of output ids
//! \param maxNumHeads maximum number of medusa heads
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param stream stream
template <typename T>
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* targetIds,
runtime::SizeType* sequenceLengths, runtime::SizeType* acceptedLengths, FinishedState* finishedFinal,
runtime::SizeType const* batchSlots, runtime::SizeType const* paths, runtime::TokenIdType const* endIds,
T const* medusaLogits, T const** logitsPtrs, runtime::SizeType batchSize, runtime::SizeType maxBatchSize,
runtime::SizeType vocabSize, runtime::SizeType maxDraftSeqLen, runtime::SizeType maxTargetSeqLen,
runtime::SizeType maxNumHeads, runtime::SizeType maxTokensPerStep, cudaStream_t stream);
void acceptDraftTokensByIdsWithPaths(runtime::TokenIdType* outputIds, runtime::TokenIdType const* draftIds,
runtime::TokenIdType const* targetIds, runtime::SizeType* sequenceLengths, runtime::SizeType* acceptedLengths,
FinishedState* finishedFinal, runtime::SizeType const* batchSlots, runtime::SizeType const* paths,
runtime::TokenIdType const* endIds, T const** medusaLogits, T const** logitsPtrs,
runtime::SizeType* curTokensPerStep, runtime::SizeType const* targetTokensPerStep, runtime::SizeType* bestPathIds,
runtime::SizeType batchSize, runtime::SizeType maxBatchSize, runtime::SizeType vocabSize,
runtime::SizeType maxDraftTokens, runtime::SizeType maxSeqLen, runtime::SizeType maxNumHeads,
runtime::SizeType maxTokensPerStep, cudaStream_t stream);
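
The reworked signature above threads the draft tokens, per-slot token counts, and best-path indices through explicitly, so a call site has to wire up quite a few device buffers. The following is a minimal host-side sketch of a launch against this declaration; every dXxx name is an illustrative placeholder for a pre-allocated device buffer with the shape given in the parameter documentation, not an identifier taken from the library.

// Hypothetical call site (half precision); buffer names and shape comments are assumptions for illustration only.
using tensorrt_llm::runtime::SizeType;
using tensorrt_llm::runtime::TokenIdType;

tensorrt_llm::kernels::acceptDraftTokensByIdsWithPaths<half>(
    dOutputIds,           // [maxBatchSize, maxSeqLen]
    dDraftIds,            // [maxBatchSize, maxDraftTokens]
    dTargetIds,           // [maxBatchSize, maxDraftTokens]
    dSequenceLengths,     // [maxBatchSize]
    dAcceptedLengths,     // [maxBatchSize]
    dFinishedFinal,       // [maxBatchSize]
    dBatchSlots,          // [batchSize]
    dPaths,               // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1]
    dEndIds,              // [maxBatchSize]
    dMedusaLogitsPtrs,    // [maxBatchSize, maxNumHeads], pointers into the per-head logits
    dLogitsPtrs,          // [batchSize, maxNumHeads], filled by the kernel
    dCurTokensPerStep,    // [maxBatchSize]
    dTargetTokensPerStep, // [maxBatchSize]
    dBestPathIds,         // [maxBatchSize]
    batchSize, maxBatchSize, vocabSize, maxDraftTokens, maxSeqLen, maxNumHeads, maxTokensPerStep, stream);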

//! \brief assembles draft tokens to treeDraftIds from sourceDraftIds using indices of treeIds
//!
//! \param treeDraftIds output buffer [maxBatchSize, maxDraftTokens], output draft tokens
//! scattered from sourceDraftIds according to treeIds
//! \param sourceDraftIds input buffer [maxBatchSize, maxDraftTokens], draft tokens saved linearly after
//! sampling from Medusa heads with TopK.
//! \param treeIds input buffer [maxBatchSize, maxDraftTokens], address map from sourceDraftIds to treeDraftIds
//! [0, uniqueDraftTokens] -> [0, maxDraftTokens], where uniqueDraftTokens = sum(MedusaHeadsTopK)
//! uniqueDraftTokens <= maxDraftTokens
//! \param tokensPerStep input buffer [maxBatchSize], number of output draft tokens
//! \param batchSlots input buffer [maxBatchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param maxDraftTokens maximum number of tokens per step configured in the system
//! \param batchSize current batch size
//! \param stream cuda stream
void scatterMedusaDraftTokens(runtime::TokenIdType* treeDraftIds, runtime::TokenIdType const* sourceDraftIds,
runtime::SizeType const* treeIds, runtime::SizeType const* tokensPerStep, runtime::SizeType const* batchSlots,
runtime::SizeType maxDraftTokens, runtime::SizeType batchSize, cudaStream_t stream);
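
At its core the scatter is a per-sequence gather: for each tree position, treeIds names the index of the corresponding token in the linearly stored top-K output. A small self-contained CPU sketch of the same mapping for a single sequence (the token values are made up for illustration):

#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical linear top-K samples from the Medusa heads for one sequence.
    std::vector<int> sourceDraftIds = {101, 202, 303, 404};
    // treeIds[i] = index into sourceDraftIds for tree position i.
    std::vector<int> treeIds = {0, 0, 2, 1};
    std::vector<int> treeDraftIds(treeIds.size());

    // The same gather scatterMedusaDraftTokens performs per (batchSlot, index) pair on the GPU.
    for (std::size_t i = 0; i < treeIds.size(); ++i)
    {
        treeDraftIds[i] = sourceDraftIds[treeIds[i]];
    }

    for (int id : treeDraftIds)
    {
        std::printf("%d ", id); // prints: 101 101 303 202
    }
    std::printf("\n");
    return 0;
}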

//! \brief Linearly packs accepted paths in memory according to the acceptedLengths and bestPathIds
//!
//! \param acceptedLengthsCumSum output buffer [maxBatchSize + 1], exclusive sum of accepted lengths
//! (indexed linearly in memory).
//! \param pathsOffsets output buffer [maxBatchSize * maxDraftLen], slices of accepted paths packed in memory
//! \param acceptedLengths input buffer [maxBatchSize], number of accepted tokens
//! \param bestPathIds input buffer [maxBatchSize], indices of the selected paths
//! \param paths input buffer [maxBatchSize, maxTokensPerStep, maxNumHeads+1],
//! paths to restore sequences from outputIds and targetIds. Should be filled with -1 for everything that is not a path.
//! \param batchSlots input buffer [batchSize], address map from local index
//! to global index [0, batchSize] -> [0, maxBatchSize]
//! \param batchSize current batch size
//! \param maxTokensPerStep maximum number of tokens per step configured in the system
//! \param maxDraftTokens maximum sequence length of the sequence containing draft tokens
//! \param stream stream
void invokePackAcceptedPaths(runtime::SizeType* acceptedLengthsCumSum, runtime::SizeType* pathsOffsets,
runtime::SizeType const* acceptedLengths, runtime::SizeType const* bestPathIds, runtime::SizeType const* paths,
runtime::SizeType const* batchSlots, runtime::SizeType batchSize, runtime::SizeType maxTokensPerStep,
runtime::SizeType maxNumDraftTokens, cudaStream_t stream);
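
The intended layout of acceptedLengthsCumSum, as described above and in the kernel earlier in this diff, is an exclusive prefix sum over the per-sequence accepted lengths (minus the bonus token), with the grand total stored in the trailing slot so consumers know how many pathsOffsets entries are valid. A host-side sketch of that layout with made-up lengths:

#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical accepted lengths for a batch of 3; each includes the bonus token, hence the -1 below.
    std::vector<int> acceptedLengths = {3, 1, 4};
    std::vector<int> cumSum(acceptedLengths.size() + 1, 0);

    for (std::size_t i = 0; i < acceptedLengths.size(); ++i)
    {
        cumSum[i + 1] = cumSum[i] + (acceptedLengths[i] - 1);
    }

    // cumSum[i] is where sequence i's path offsets start in pathsOffsets,
    // and cumSum.back() is the total number of packed entries.
    for (int v : cumSum)
    {
        std::printf("%d ", v); // prints: 0 2 2 5
    }
    std::printf("\n");
    return 0;
}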
} // namespace kernels
} // namespace tensorrt_llm

1183  cpp/tensorrt_llm/kernels/mambaConv1dKernels.cu  (new file; diff suppressed because it is too large)
64    cpp/tensorrt_llm/kernels/mambaConv1dKernels.h   (new file)
@@ -0,0 +1,64 @@
/*
 * Adapted from https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.h
 * Copyright (c) 2023, Tri Dao.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * Not a contribution
 * Changes made by NVIDIA CORPORATION & AFFILIATES or otherwise documented as
 * NVIDIA-proprietary are not a contribution and subject to the following terms and conditions:
 * SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: LicenseRef-NvidiaProprietary
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include "tensorrt_llm/common/assert.h"
#include "tensorrt_llm/common/cudaUtils.h"

namespace tensorrt_llm
{
namespace kernels
{

struct MambaConv1dParamsBase
{
    int batch, dim, max_seqlen, dconv;
    bool remove_padding;
    void* __restrict__ in_ptr;
    void* state_in_ptr;
    void* state_out_ptr;
    void* __restrict__ weight_ptr;
    void* __restrict__ bias_ptr;
    void* __restrict__ out_ptr;
    int const* __restrict__ last_token_ids_ptr;
    int const* __restrict__ state_slot_mapping_ptr;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template <typename input_t>
void invokeMambaConv1dContext(MambaConv1dParamsBase& params, cudaStream_t stream);

template <typename input_t>
void invokeMambaConv1dGeneration(MambaConv1dParamsBase& params, cudaStream_t stream);

} // namespace kernels
} // namespace tensorrt_llm
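
The new header exposes only the parameter struct plus two launchers, one for the context (prefill) phase and one for single-token generation. A rough usage sketch follows; the buffer names, the dconv value, and the layout comments are assumptions for illustration and not guaranteed by this header (the actual layouts live in the suppressed .cu file).

// Hedged sketch: fill the params struct and pick the phase-specific launcher.
tensorrt_llm::kernels::MambaConv1dParamsBase params{};
params.batch = batchSize;
params.dim = innerDim;            // number of conv channels (assumption)
params.max_seqlen = maxSeqLen;
params.dconv = 4;                 // typical Mamba conv width; value is an assumption here
params.remove_padding = false;
params.in_ptr = dInput;
params.state_in_ptr = dConvStateIn;
params.state_out_ptr = dConvStateOut;
params.weight_ptr = dConvWeight;
params.bias_ptr = dConvBias;
params.out_ptr = dOutput;
params.last_token_ids_ptr = dLastTokenIds;
params.state_slot_mapping_ptr = nullptr; // optional remapping of state slots

// Prefill a whole prompt...
tensorrt_llm::kernels::invokeMambaConv1dContext<half>(params, stream);
// ...or advance one token per sequence during generation.
// tensorrt_llm::kernels::invokeMambaConv1dGeneration<half>(params, stream);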
@@ -16,6 +16,7 @@

#pragma once

#include "tensorrt_llm/common/assert.h"
#include <limits.h>
#include <stdint.h>

@@ -36,6 +37,25 @@ enum Data_type
DATA_TYPE_E5M2
};

////////////////////////////////////////////////////////////////////////////////////////////////////

static inline size_t get_size_in_bytes(size_t n, Data_type dtype)
{
switch (dtype)
{
case DATA_TYPE_FP32: return n * 4;
case DATA_TYPE_FP16: return n * 2;
case DATA_TYPE_INT32: return n * 4;
case DATA_TYPE_INT8: return n;
case DATA_TYPE_BF16: return n * 2;
case DATA_TYPE_E4M3: return n;
case DATA_TYPE_E5M2: return n;
default: TLLM_CHECK_WITH_INFO(false, "FMHA Data Type is not supported."); return 0;
}
}
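
As a quick illustration (the variable names are hypothetical), sizing a buffer of n FP8 elements reduces to a single call:

// numElements is assumed to be defined by the caller; E4M3 is 1 byte per element.
size_t const bufferBytes = get_size_in_bytes(numElements, DATA_TYPE_E4M3);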

////////////////////////////////////////////////////////////////////////////////////////////////////

constexpr int32_t kSM_70 = 70;
constexpr int32_t kSM_72 = 72;
constexpr int32_t kSM_75 = 75;

@@ -41,6 +41,8 @@ static int const SMALL_TOP_K_SOFTMAX_THREADBLOCK_SIZE = 256;

#define TOPK_FP16_STORAGE 0

#pragma nv_diag_suppress static_var_with_dynamic_init

template <typename T, int MAX_K2, int THREADBLOCK_SIZE>
__launch_bounds__(THREADBLOCK_SIZE) __global__ void batch_topk_kernel(
int const* __restrict topk_id, T const* __restrict topk_val, BeamHypotheses bh, int const candidate_size)

@@ -31,7 +31,7 @@ template <typename KVCacheBuffer, int MaxLayerCount, typename MoveEltType>
__global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheBuffer, MaxLayerCount> kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int rewindDraftTokenCommonCount, int const* rewindDraftTokenSeparateAdjustments,
int eltCountPerHead)
int const* seqSlotRemapping, int eltCountPerHead)
{
int seqIdx = blockIdx.x;
int headIdx = blockIdx.y;
@@ -41,16 +41,17 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int laneIdx = threadIdx.x & 0x1f;
int seqDraftTokenStart = seqAcceptedDraftTokenOffsets[seqIdx];
int seqDraftTokenEnd = seqAcceptedDraftTokenOffsets[seqIdx + 1];
auto const seqSlot = seqSlotRemapping == nullptr ? seqIdx : seqSlotRemapping[seqIdx];
int seqDraftCount = seqDraftTokenEnd - seqDraftTokenStart;
if (seqDraftCount == 0)
{
return;
}
KVCacheBuffer& kvCacheBuffer = kvCacheBuffers[layerIdx];
int tokenStartIdx = pastKeyValueLengths[seqIdx] - rewindDraftTokenCommonCount;
int tokenStartIdx = pastKeyValueLengths[seqSlot] - rewindDraftTokenCommonCount;
if (rewindDraftTokenSeparateAdjustments != nullptr)
{
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqIdx];
tokenStartIdx -= rewindDraftTokenSeparateAdjustments[seqSlot];
}
int maxEltCountPerMove = kUpdateKVCacheKernelShmSize / sizeof(MoveEltType) / seqDraftCount;
int eltCountPerMove = min(maxEltCountPerMove, eltCountPerHead);
@@ -65,7 +66,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqIdx, tokenKVPosition));
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@@ -80,7 +81,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqIdx, tokenKVPosition));
auto* kPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getKBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@@ -95,7 +96,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = packedAcceptedDraftTokensIndices[seqDraftTokenStart + tokenIdx];
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenIdx * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqIdx, tokenKVPosition));
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@@ -110,7 +111,7 @@ __global__ void updateKVCacheDraftTokenLocationBatchedKernel(std::array<KVCacheB
int tokenPos = tokenIdx;
auto* tokenSmemBuffer = eltLoadSmemBuffer + tokenPos * eltCountCurrentMove;
int tokenKVPosition = tokenStartIdx + tokenPos;
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqIdx, tokenKVPosition));
auto* vPtr = reinterpret_cast<MoveEltType*>(kvCacheBuffer.getVBlockPtr(seqSlot, tokenKVPosition));
for (int loadChannelIdx = laneIdx; loadChannelIdx < eltCountCurrentMove; loadChannelIdx += 32)
{
int channelIdx = loadChannelIdx + startChannelOffset;
@@ -126,7 +127,8 @@ template <typename KVCacheBuffer, int MaxLayerCount>
void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
int const* seqAcceptedDraftTokenOffsets, IndexType const* packedAcceptedDraftTokensIndices,
int32_t const* pastKeyValueLengths, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
cudaStream_t stream)
{
// make sure launch buffer is enough
static_assert(MaxLayerCount * sizeof(KVCacheBuffer) <= 3072);
@@ -148,8 +150,8 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
{
kvCacheBufferArray[i] = kvCacheBuffers[i];
}
void (*pKernelFunc)(
std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int, int const*, int)
void (*pKernelFunc)(std::array<KVCacheBuffer, MaxLayerCount>, int const*, IndexType const*, int32_t const*, int,
int const*, int const*, int)
= nullptr;
switch (alignedBytes)
{
@@ -182,7 +184,7 @@ void updateKVCacheDraftTokenLocationBatched(KVCacheBuffer const* kvCacheBuffers,
}
pKernelFunc<<<grid, block, 0, stream>>>(kvCacheBufferArray, seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, eltCountPerHead);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, eltCountPerHead);
TLLM_CUDA_CHECK(cudaGetLastError());
}

@@ -207,7 +209,7 @@ template <typename KVCacheBuffer>
void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int layerCount, int seqCount,
int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments,
cudaStream_t stream)
int const* seqSlotRemapping, cudaStream_t stream)
{
int startLayer = 0;
static constexpr int kMaxLayersPerIter = 32;
@@ -217,7 +219,7 @@ void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int co
updateKVCacheDraftTokenLocationBatched<KVCacheBuffer, kMaxLayersPerIter>(kvCacheBuffers + startLayer,
seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices, pastKeyValueLengths, microBatchLayerCount,
seqCount, numKVHeads, sizeInBytesPerKVHead, rewindDraftTokenCommonCount,
rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
startLayer += microBatchLayerCount;
}
}
@@ -225,7 +227,8 @@ void updateKVCacheDraftTokenLocation(KVCacheBuffer const* kvCacheBuffers, int co
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream)
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream)
{
std::vector<KVLinearBuffer> kvLinearBuffers;
kvLinearBuffers.reserve(layerCount);
@@ -237,14 +240,14 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
}
updateKVCacheDraftTokenLocation(kvLinearBuffers.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
}

void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream)
int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream)
{
std::vector<KVBlockArray> kvBlockArrays;
kvBlockArrays.reserve(layerCount);
@@ -256,47 +259,47 @@ void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffset
}
updateKVCacheDraftTokenLocation(kvBlockArrays.data(), seqAcceptedDraftTokenOffsets,
packedAcceptedDraftTokensIndices, pastKeyValueLengths, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, stream);
rewindDraftTokenCommonCount, rewindDraftTokenSeparateAdjustments, seqSlotRemapping, stream);
}

void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream)
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, stream);
}

void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead,
rewindDraftTokenCount, nullptr, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCount, nullptr, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}

void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream)
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream)
{
updateLinearKVCacheDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pastKeyValueList, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, stream);
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, stream);
}

void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream)
{
updateKVBlockArrayDraftTokenLocation(seqAcceptedDraftTokenOffsets, packedAcceptedDraftTokensIndices,
pastKeyValueLengths, pointerArray, layerCount, seqCount, numKVHeads, sizeInBytesPerKVHead, 0,
rewindDraftTokenCounts, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
rewindDraftTokenCounts, seqSlotRemapping, maxKVCacheLen, maxBlocksPerSeq, tokensPerBlock, stream);
}

} // namespace tensorrt_llm::kernels::parallel_decoding

@@ -39,13 +39,17 @@ using IndexType = int;
 * @param numKVHeads : Number of KV heads
 * @param sizeInBytesPerKVHead : Size of each KV head
 * @param rewindDraftTokenCount : Count to rewind
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param stream : CUDA stream to use.
 */
void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCount, int maxKVCacheLen, cudaStream_t stream);
int rewindDraftTokenCount, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);
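
The bracketed example in the comment ([5, 3, 4] -> [1, 2, 0]) pins down the semantics: seqSlotRemapping[i] is the batch index whose seqSlot is the i-th smallest, i.e. an argsort of the seqSlots. This reading is inferred from the documented example rather than from the implementation; a self-contained host-side sketch of building such a buffer:

#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int> seqSlots = {5, 3, 4}; // seqSlot of each request in batch order
    std::vector<int> seqSlotRemapping(seqSlots.size());
    std::iota(seqSlotRemapping.begin(), seqSlotRemapping.end(), 0);
    // Argsort: order batch indices by their seqSlot value.
    std::stable_sort(seqSlotRemapping.begin(), seqSlotRemapping.end(),
        [&](int a, int b) { return seqSlots[a] < seqSlots[b]; });

    for (int v : seqSlotRemapping)
    {
        std::printf("%d ", v); // prints: 1 2 0, matching the documented example
    }
    std::printf("\n");
    return 0;
}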

/*!
 * Update Block KV cache using common rewind count.
@@ -59,6 +63,10 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
 * @param numKVHeads : Number of KV heads
 * @param sizeInBytesPerKVHead : Size of each KV head
 * @param rewindDraftTokenCount : Count to rewind
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
 * @param tokensPerBlock : Tokens per block of Block KV cache
@@ -67,7 +75,7 @@ void updateLinearKVCacheDraftTokenLocationCommonRewind(int const* seqAcceptedDra
void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCount,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);

/*!
 * Update Linear KV cache using separate rewind count for each sequence.
@@ -82,13 +90,17 @@ void updateKVBlockArrayDraftTokenLocationCommonRewind(int const* seqAcceptedDraf
 * @param sizeInBytesPerKVHead : Size of each KV head
 * @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
 * one sequence.
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param stream : CUDA stream to use.
 */
void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int* rewindDraftTokenCounts, int maxKVCacheLen, cudaStream_t stream);
int* rewindDraftTokenCounts, int const* seqSlotRemapping, int maxKVCacheLen, cudaStream_t stream);

/*!
 * Update Block KV cache using separate rewind count for each sequence.
@@ -103,6 +115,10 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
 * @param sizeInBytesPerKVHead : Size of each KV head
 * @param rewindDraftTokenCounts : Pointer to an array of length seqCount, each element indicates the rewind count of
 * one sequence.
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
 * @param tokensPerBlock : Tokens per block of Block KV cache
@@ -111,7 +127,7 @@ void updateLinearKVCacheDraftTokenLocationSeparateRewind(int const* seqAcceptedD
void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int* rewindDraftTokenCounts,
int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);
int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock, cudaStream_t stream);

/*!
 * Update Linear KV cache using both common rewind and separate rewind count for each sequence. The common
@@ -129,13 +145,18 @@ void updateKVBlockArrayDraftTokenLocationSeparateRewind(int const* seqAcceptedDr
 * @param rewindDraftTokenCommonCount : Common token count to rewind
 * @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
 * rewind adjustment for one sequence.
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param stream : CUDA stream to use.
 */
void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths,
int8_t* const* pastKeyValueList, int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead,
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, cudaStream_t stream);
int rewindDraftTokenCommonCount, int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping,
int maxKVCacheLen, cudaStream_t stream);

/*!
 * Update Block KV cache using both common rewind and separate rewind count for each sequence. The common
@@ -153,6 +174,10 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
 * @param rewindDraftTokenCommonCount : Common token count to rewind
 * @param rewindDraftTokenSeparateAdjustments : Pointer to an array of length seqCount, each element indicates the
 * rewind adjustment for one sequence.
 * @param seqSlotRemapping : Mapping from batch index to index of the seqSlot in the sorted seqSlot buffer,
 * e.g. for requests [0, 1, 2] with seqSlots [5, 3, 4], seqSlotRemapping is [1, 2, 0].
 * Required to match seqAcceptedDraftTokenOffsets and packedAcceptedDraftTokensIndices from gptDecoderBatch
 * and pointerArray and pastKeyValueLengths from runtimeBuffers.
 * @param maxKVCacheLen : Maximum length of each KV cache
 * @param maxBlocksPerSeq : Maximum blocks per sequence of Block KV cache.
 * @param tokensPerBlock : Tokens per block of Block KV cache
@@ -161,7 +186,7 @@ void updateLinearKVCacheDraftTokenLocation(int const* seqAcceptedDraftTokenOffse
void updateKVBlockArrayDraftTokenLocation(int const* seqAcceptedDraftTokenOffsets,
IndexType const* packedAcceptedDraftTokensIndices, int32_t const* pastKeyValueLengths, int64_t* const* pointerArray,
int layerCount, int seqCount, int numKVHeads, int sizeInBytesPerKVHead, int rewindDraftTokenCommonCount,
int* rewindDraftTokenSeparateAdjustments, int maxKVCacheLen, int maxBlocksPerSeq, int tokensPerBlock,
cudaStream_t stream);
int* rewindDraftTokenSeparateAdjustments, int const* seqSlotRemapping, int maxKVCacheLen, int maxBlocksPerSeq,
int tokensPerBlock, cudaStream_t stream);

} // namespace tensorrt_llm::kernels::parallel_decoding

@@ -150,7 +150,6 @@ __global__ void topKStage2Sampling(int const* __restrict topKTmpIdBuf, T* topKTm
{
return;
}

if (tokensPerStep != nullptr && tokenIdx >= tokensPerStep[batchSlot])
{
return;

@ -65,6 +65,8 @@ __device__ void convertAndStore(__nv_bfloat16* output, float input)
|
||||
}
|
||||
#endif
|
||||
|
||||
#pragma nv_diag_suppress static_var_with_dynamic_init
|
||||
|
||||
template <typename input_t, typename weight_t, int DSTATE = 16, int CHANNELS_PER_BLOCK = 128, int STAGES = 12,
|
||||
int SEQ_UNROLL = 6>
|
||||
__launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBase params)
|
||||
@ -74,8 +76,8 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
input_t* x = reinterpret_cast<input_t*>(params.u_ptr);
|
||||
input_t* dt = reinterpret_cast<input_t*>(params.delta_ptr);
|
||||
weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr);
|
||||
input_t* B = reinterpret_cast<input_t*>(params.B_ptr);
|
||||
input_t* C = reinterpret_cast<input_t*>(params.C_ptr);
|
||||
input_t* B = reinterpret_cast<input_t*>(params.BC_ptr);
|
||||
input_t* C = reinterpret_cast<input_t*>(params.BC_ptr);
|
||||
weight_t* D = reinterpret_cast<weight_t*>(params.D_ptr);
|
||||
input_t* z = reinterpret_cast<input_t*>(params.z_ptr);
|
||||
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
|
||||
@ -101,11 +103,24 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
int const channel = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
int const sample = blockIdx.y; // batch id
|
||||
|
||||
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
|
||||
int const bc_cols = DSTATE * 2 + params.dt_rank;
|
||||
int const b_offset = params.dt_rank;
|
||||
int const c_offset = params.dt_rank + DSTATE;
|
||||
|
||||
int num_tokens;
|
||||
int start_token_idx;
|
||||
start_token_idx = sample * params.seqlen;
|
||||
num_tokens = params.last_token_ids_ptr[sample];
|
||||
|
||||
if (params.remove_padding)
|
||||
{
|
||||
start_token_idx = sample == 0 ? 0 : params.last_token_ids_ptr[sample - 1];
|
||||
int end_token_idx = params.last_token_ids_ptr[sample];
|
||||
num_tokens = end_token_idx - start_token_idx;
|
||||
}
|
||||
else
|
||||
{
|
||||
start_token_idx = sample * params.max_seqlen;
|
||||
num_tokens = params.last_token_ids_ptr[sample];
|
||||
}
|
||||
int const seq_loops = (num_tokens + SEQ_UNROLL - 1) / SEQ_UNROLL;
|
||||
|
||||
int const input_matrix_row_id = start_token_idx;
|
||||
@ -132,8 +147,8 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
for (int token_id = si * SEQ_UNROLL; token_id < num_tokens && token_id < (si + 1) * SEQ_UNROLL; token_id++)
|
||||
{
|
||||
|
||||
input_t* my_B = &B[input_matrix_row_id * DSTATE + token_id * DSTATE];
|
||||
input_t* my_C = &C[input_matrix_row_id * DSTATE + token_id * DSTATE];
|
||||
input_t* my_B = &B[(input_matrix_row_id + token_id) * bc_cols + b_offset];
|
||||
input_t* my_C = &C[(input_matrix_row_id + token_id) * bc_cols + c_offset];
|
||||
|
||||
int block_channel_per_token = blockIdx.x * blockDim.x;
|
||||
int block_channel
|
||||
@ -304,7 +319,7 @@ __launch_bounds__(256, 1) __global__ void selective_scan_loop_kernel(SSMParamsBa
|
||||
// Write the new state back out to the cache
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
input_t* my_state = &state[sample * num_channels * DSTATE];
|
||||
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
|
||||
int offset = i * num_channels + channel;
|
||||
convertAndStore(&my_state[offset], state_reg[i]);
|
||||
}
|
||||
@ -351,8 +366,8 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
input_t* x = reinterpret_cast<input_t*>(params.u_ptr);
|
||||
input_t* dt = reinterpret_cast<input_t*>(params.delta_ptr);
|
||||
weight_t* A = reinterpret_cast<weight_t*>(params.A_ptr);
|
||||
input_t* B = reinterpret_cast<input_t*>(params.B_ptr);
|
||||
input_t* C = reinterpret_cast<input_t*>(params.C_ptr);
|
||||
input_t* B = reinterpret_cast<input_t*>(params.BC_ptr);
|
||||
input_t* C = reinterpret_cast<input_t*>(params.BC_ptr);
|
||||
weight_t* D = reinterpret_cast<weight_t*>(params.D_ptr);
|
||||
input_t* z = reinterpret_cast<input_t*>(params.z_ptr);
|
||||
weight_t* dt_bias = reinterpret_cast<weight_t*>(params.delta_bias_ptr);
|
||||
@ -363,8 +378,12 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
if (channel >= num_channels)
|
||||
return;
|
||||
int const sample = blockIdx.y;
|
||||
int const slot_idx = params.slot_mapping_ptr == nullptr ? sample : params.slot_mapping_ptr[sample];
|
||||
int const bc_cols = DSTATE * 2 + params.dt_rank;
|
||||
int const b_offset = params.dt_rank;
|
||||
int const c_offset = params.dt_rank + DSTATE;
|
||||
|
||||
input_t* my_state = &state[sample * num_channels * DSTATE];
|
||||
input_t* my_state = &state[slot_idx * num_channels * DSTATE];
|
||||
input_t* my_output = &output[sample * num_channels];
|
||||
|
||||
float rA[DSTATE];
|
||||
@ -377,8 +396,8 @@ __launch_bounds__(128, 2) __global__ void selective_scan_update_kernel(SSMParams
|
||||
for (int i = 0; i < DSTATE; i++)
|
||||
{
|
||||
rA[i] = toFloat(A[i * num_channels + channel]);
|
||||
rB[i] = toFloat(B[sample * DSTATE + i]);
|
||||
rC[i] = toFloat(C[sample * DSTATE + i]);
|
||||
rB[i] = toFloat(B[sample * bc_cols + b_offset + i]);
|
||||
rC[i] = toFloat(C[sample * bc_cols + c_offset + i]);
|
||||
rState[i] = toFloat(my_state[i * num_channels + channel]);
|
||||
}
|
||||
|
||||
@ -40,9 +40,9 @@ namespace kernels

struct SSMParamsBase
{
using index_t = uint32_t;

int batch, dim, seqlen, dstate;
int batch, dim, dstate, dt_rank;
int max_seqlen; // only valid for padded input.
bool remove_padding;
bool is_variable_B;
bool is_variable_C;

@ -50,8 +50,7 @@ struct SSMParamsBase

// Common data pointers.
void* __restrict__ A_ptr;
void* __restrict__ B_ptr;
void* __restrict__ C_ptr;
void* __restrict__ BC_ptr;
void* __restrict__ D_ptr;
void* __restrict__ u_ptr;
void* __restrict__ delta_ptr;
@ -60,6 +59,7 @@ struct SSMParamsBase
void* __restrict__ x_ptr;
void* __restrict__ z_ptr;
int const* __restrict__ last_token_ids_ptr;
int const* __restrict__ slot_mapping_ptr;
};
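For reference, a caller now hands the scan kernels one fused `BC_ptr` plus the optional `slot_mapping_ptr` instead of separate `B_ptr`/`C_ptr`. The fragment below is a minimal, hypothetical setup sketch using a local stand-in with only the fields shown above; buffer allocation and the real header are omitted on purpose.

```cpp
#include <cstdint>

// Stand-in mirror of the fields shown above, just to illustrate how a caller
// wires the fused BC buffer and the slot mapping; not the real header.
struct SsmParamsSketch
{
    int batch, dim, dstate, dt_rank;
    int max_seqlen;
    bool remove_padding;
    void* A_ptr; void* BC_ptr; void* D_ptr; void* u_ptr; void* delta_ptr;
    void* delta_bias_ptr; void* x_ptr; void* z_ptr;
    int const* last_token_ids_ptr;
    int const* slot_mapping_ptr;
};

SsmParamsSketch makeParams(void* fusedDtBC, int const* lastTokenIds, int const* slotMapping)
{
    SsmParamsSketch p{};
    p.batch = 2; p.dim = 2048; p.dstate = 16; p.dt_rank = 160; // illustrative sizes
    p.remove_padding = true;              // packed input
    p.BC_ptr = fusedDtBC;                 // rows of width dt_rank + 2 * dstate
    p.last_token_ids_ptr = lastTokenIds;  // prefix sums when remove_padding is set
    p.slot_mapping_ptr = slotMapping;     // nullptr falls back to batch-position indexing
    return p;
}

int main()
{
    SsmParamsSketch p = makeParams(nullptr, nullptr, nullptr);
    return p.batch == 2 ? 0 : 1;
}
```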
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
@ -107,7 +107,7 @@ void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentI
|
||||
}
|
||||
|
||||
__global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
|
||||
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth)
|
||||
{
|
||||
int32_t threadFinishedCount = 0;
|
||||
auto const batchIdx = blockIdx.x;
|
||||
@ -122,6 +122,7 @@ __global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, u
|
||||
if (sequenceLengths[batchSlotBeamWidthIdx] >= sequenceLimitLength[batchSlot])
|
||||
{
|
||||
finishState.setFinishedMaxLength();
|
||||
sequenceLengths[batchSlotBeamWidthIdx] = sequenceLimitLength[batchSlot];
|
||||
}
|
||||
threadFinishedCount += finishState.isFinished() ? 1 : 0;
|
||||
finished[batchSlotBeamWidthIdx] = finishState;
|
||||
@ -148,8 +149,7 @@ __global__ void lengthCriterion(FinishedState* finished, int32_t* finishedSum, u
|
||||
}
|
||||
|
||||
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
|
||||
cudaStream_t stream)
|
||||
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, cudaStream_t stream)
|
||||
{
|
||||
// Check if we have attained the sequence length limit. If so, stop the
|
||||
// sequence. In addition, check if all sequences are stopped and return the
|
||||
|
||||
@ -54,14 +54,13 @@ void invokeStopWordsCriterion(int32_t const** outputIds, int32_t const** parentI
|
||||
//! \param finishedSum output buffer [1].
|
||||
//! Total sum of finished requests
|
||||
//! \param sequenceLimitLength input buffer [maxBatchSize]. Maximum sequence length.
|
||||
//! \param sequenceLengths input buffer [maxBatchSize, beamWidth].
|
||||
//! \param sequenceLengths input/output buffer [maxBatchSize, beamWidth].
|
||||
//! Current sequence lengths of the request tokens.
|
||||
//! \param batchSlots input buffer[batchSize], optional. Indices of rows of data in memory pool
|
||||
//! \param batchSize batch size
|
||||
//! \param beamWidth beam width
|
||||
//! \param stream stream
|
||||
void invokeLengthCriterion(FinishedState* finished, int32_t* finishedSum, uint32_t const* sequenceLimitLength,
|
||||
int32_t const* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth,
|
||||
cudaStream_t stream);
|
||||
int32_t* sequenceLengths, int32_t const* batchSlots, int32_t batchSize, int32_t beamWidth, cudaStream_t stream);
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
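With `sequenceLengths` now documented as an input/output buffer, the kernel clamps a sequence that reaches its limit to exactly `sequenceLimitLength` while marking it finished by max length. A scalar, host-side sketch of that per-beam check (illustrative function name and types) is:

```cpp
#include <cassert>
#include <cstdint>

// Host-side sketch of the per-beam check done by lengthCriterion:
// returns true if the sequence finishes by length, clamping it in place.
bool finishByLength(int32_t& sequenceLength, uint32_t sequenceLimitLength)
{
    if (sequenceLength >= static_cast<int32_t>(sequenceLimitLength))
    {
        sequenceLength = static_cast<int32_t>(sequenceLimitLength); // clamp, as the new kernel does
        return true;                                                // setFinishedMaxLength()
    }
    return false;
}

int main()
{
    int32_t len = 130;
    assert(finishByLength(len, 128) && len == 128);
    int32_t shorter = 100;
    assert(!finishByLength(shorter, 128) && shorter == 100);
    return 0;
}
```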
|
||||
|
||||
@ -108,15 +108,15 @@ void invokeTranspose4dBatchMajor(T const* k_src, T const* v_src, KVCacheBuffer&
|
||||
|
||||
// NOTE: this kernel is in-place, QKV will be modified, if other kernels need that, may need copy or use before it.
|
||||
template <typename T, typename KVCacheBuffer, bool IsGenerate = false>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
|
||||
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
|
||||
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
|
||||
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
|
||||
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
|
||||
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode,
|
||||
const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha,
|
||||
int const beam_width, int2& grid_block_cache, cudaStream_t stream);
|
||||
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream);
|
||||
|
||||
template <typename T, typename BT>
|
||||
void invokeAddRelativeAttentionBiasUnaligned(T* qk_buf, const BT* relative_attention_bias, int const batch_size,
|
||||
|
||||
@ -46,7 +46,9 @@ template <typename T, int Dh_MAX>
|
||||
struct Rotary_vec_t
|
||||
{
|
||||
using Type = T;
|
||||
using Packed_type = T;
|
||||
// Quantized output type only supports fp8 currently.
|
||||
using Packed_type = __nv_fp8_e4m3;
|
||||
using Quantized_type = void;
|
||||
static constexpr int size = 1;
|
||||
};
|
||||
|
||||
@ -56,6 +58,7 @@ template <>
|
||||
struct Rotary_vec_t<float, 32>
|
||||
{
|
||||
using Type = float;
|
||||
using Quantized_type = __nv_fp8_e4m3;
|
||||
using Packed_type = float;
|
||||
static constexpr int size = 1;
|
||||
};
|
||||
@ -64,6 +67,7 @@ template <>
|
||||
struct Rotary_vec_t<float, 64>
|
||||
{
|
||||
using Type = float2;
|
||||
using Quantized_type = mmha::fp8_2_t;
|
||||
using Packed_type = float2;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
@ -72,6 +76,7 @@ template <>
|
||||
struct Rotary_vec_t<float, 128>
|
||||
{
|
||||
using Type = float4;
|
||||
using Quantized_type = mmha::fp8_4_t;
|
||||
using Packed_type = float2;
|
||||
static constexpr int size = 4;
|
||||
};
|
||||
@ -80,6 +85,7 @@ template <>
|
||||
struct Rotary_vec_t<float, 256>
|
||||
{
|
||||
using Type = mmha::Float8_;
|
||||
using Quantized_type = mmha::fp8_8_t;
|
||||
using Packed_type = float2;
|
||||
static constexpr int size = 8;
|
||||
};
|
||||
@ -90,6 +96,7 @@ template <>
|
||||
struct Rotary_vec_t<half, 32>
|
||||
{
|
||||
using Type = uint16_t;
|
||||
using Quantized_type = __nv_fp8_e4m3;
|
||||
using Packed_type = uint16_t;
|
||||
static constexpr int size = 1;
|
||||
};
|
||||
@ -98,6 +105,7 @@ template <>
|
||||
struct Rotary_vec_t<half, 64>
|
||||
{
|
||||
using Type = uint32_t;
|
||||
using Quantized_type = mmha::fp8_2_t;
|
||||
using Packed_type = uint32_t;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
@ -106,6 +114,7 @@ template <>
|
||||
struct Rotary_vec_t<half, 128>
|
||||
{
|
||||
using Type = uint2;
|
||||
using Quantized_type = mmha::fp8_4_t;
|
||||
using Packed_type = uint32_t;
|
||||
static constexpr int size = 4;
|
||||
};
|
||||
@ -114,6 +123,7 @@ template <>
|
||||
struct Rotary_vec_t<half, 256>
|
||||
{
|
||||
using Type = uint4;
|
||||
using Quantized_type = mmha::fp8_8_t;
|
||||
using Packed_type = uint32_t;
|
||||
static constexpr int size = 8;
|
||||
};
|
||||
@ -126,6 +136,7 @@ template <>
|
||||
struct Rotary_vec_t<__nv_bfloat16, 32>
|
||||
{
|
||||
using Type = __nv_bfloat16;
|
||||
using Quantized_type = __nv_fp8_e4m3;
|
||||
using Packed_type = __nv_bfloat16;
|
||||
static constexpr int size = 1;
|
||||
};
|
||||
@ -134,6 +145,7 @@ template <>
|
||||
struct Rotary_vec_t<__nv_bfloat16, 64>
|
||||
{
|
||||
using Type = __nv_bfloat162;
|
||||
using Quantized_type = mmha::fp8_2_t;
|
||||
using Packed_type = __nv_bfloat162;
|
||||
static constexpr int size = 2;
|
||||
};
|
||||
@ -142,6 +154,7 @@ template <>
|
||||
struct Rotary_vec_t<__nv_bfloat16, 128>
|
||||
{
|
||||
using Type = mmha::bf16_4_t;
|
||||
using Quantized_type = mmha::fp8_4_t;
|
||||
using Packed_type = __nv_bfloat162;
|
||||
static constexpr int size = 4;
|
||||
};
|
||||
@ -150,6 +163,7 @@ template <>
|
||||
struct Rotary_vec_t<__nv_bfloat16, 256>
|
||||
{
|
||||
using Type = mmha::bf16_8_t;
|
||||
using Quantized_type = mmha::fp8_8_t;
|
||||
using Packed_type = __nv_bfloat162;
|
||||
static constexpr int size = 8;
|
||||
};
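Each specialization now also carries a `Quantized_type`, pairing the rotary vector width with an FP8 packed type of the same element count so the quantized store keeps the per-thread vector width unchanged. The sketch below illustrates that pairing with a stand-in trait; it assumes the mmha FP8 packed types are simply 1/2/4/8 e4m3 elements, which is an assumption, not taken from this diff.

```cpp
#include <cuda_fp8.h> // __nv_fp8_e4m3, CUDA 11.8+

// Illustrative stand-in for the pattern above: a rotary vector of N source
// elements maps to an FP8 packed type holding the same N elements.
template <int N>
struct Fp8Packed
{
    __nv_fp8_e4m3 data[N];
};

template <typename T, int N>
struct RotaryVecSketch
{
    static constexpr int size = N;       // elements handled per thread
    using Quantized_type = Fp8Packed<N>; // same element count, one byte each
};

static_assert(sizeof(RotaryVecSketch<float, 8>::Quantized_type) == 8, "fp8 is 1 byte per element");
static_assert(RotaryVecSketch<float, 2>::size == 2, "vector width is preserved");
```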
|
||||
@ -158,13 +172,14 @@ struct Rotary_vec_t<__nv_bfloat16, 256>
|
||||
|
||||
template <typename T, typename T_cache, int Dh_MAX, bool ADD_BIAS, bool STORE_QKV, bool POS_SHIFT,
|
||||
typename KVCacheBuffer, bool IS_GENERATE>
|
||||
__global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBuffer, T const* __restrict qkv_bias,
|
||||
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, float const* kvScaleOrigQuant,
|
||||
int const num_tokens, int const batch_size, int const seq_len, int const cyclic_kv_cache_len,
|
||||
int const sink_token_len, int const head_num, int const kv_head_num, int const qheads_per_kv_head,
|
||||
int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base,
|
||||
__global__ void applyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer kvCacheBuffer,
|
||||
T const* __restrict qkv_bias, int const* seq_lens, int const* kv_seq_lens, int const* padding_offset,
|
||||
float const* kvScaleOrigQuant, int const num_tokens, int const batch_size, int const seq_len,
|
||||
int const cyclic_kv_cache_len, int const sink_token_len, int const head_num, int const kv_head_num,
|
||||
int const qheads_per_kv_head, int const size_per_head, int const rotary_embedding_dim, float rotary_embedding_base,
|
||||
RotaryScalingType const rotary_scale_type, float rotary_embedding_scale, int const rotary_embedding_max_positions,
|
||||
PositionEmbeddingType const position_embedding_type, int const* medusa_position_offsets, int const beam_width)
|
||||
PositionEmbeddingType const position_embedding_type, int const* medusa_position_offsets,
|
||||
bool const quantized_fp8_output, int const beam_width)
|
||||
{
|
||||
// This kernel add bias to QKV, which has shape [batch_size, seq_len, 3, head_num, size_per_head]
|
||||
// Extract the Q input when using paged KV FMHA.
|
||||
@ -190,6 +205,9 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
|
||||
// VEC_SIZE is power of 2.
|
||||
constexpr int VEC_SIZE = Rotary_vec_t<T, Dh_MAX>::size;
|
||||
using Vec_type = typename Rotary_vec_t<T, Dh_MAX>::Type;
|
||||
// Quantized output only supports fp8 currently.
|
||||
using Quantized_elt_type = __nv_fp8_e4m3;
|
||||
using Quantized_type = typename Rotary_vec_t<T, Dh_MAX>::Quantized_type;
|
||||
using Packed_type = typename Rotary_vec_t<T, Dh_MAX>::Packed_type;
|
||||
bool const has_padding = padding_offset == nullptr;
|
||||
|
||||
@ -348,7 +366,16 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
|
||||
|
||||
if constexpr (STORE_QKV)
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_q_idx]) = q;
|
||||
if (quantized_fp8_output)
|
||||
{
|
||||
// use 1.0f scale currently for qkv input of FP8 FMHA.
|
||||
mmha::convert_to_fp8(
|
||||
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_q_idx), q);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_q_idx]) = q;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -358,10 +385,21 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
|
||||
{
|
||||
if constexpr (STORE_QKV)
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_k_idx]) = k;
|
||||
if constexpr (ADD_BIAS)
|
||||
if (quantized_fp8_output)
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_v_idx]) = v;
|
||||
// use 1.0f scale currently for qkv input of FP8 FMHA.
|
||||
mmha::convert_to_fp8(
|
||||
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_k_idx), k);
|
||||
mmha::convert_to_fp8(
|
||||
reinterpret_cast<Quantized_type*>(reinterpret_cast<Quantized_elt_type*>(O) + src_v_idx), v);
|
||||
}
|
||||
else
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_k_idx]) = k;
|
||||
if constexpr (ADD_BIAS)
|
||||
{
|
||||
*reinterpret_cast<Vec_type*>(&QKV[src_v_idx]) = v;
|
||||
}
|
||||
}
|
||||
}
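When `quantized_fp8_output` is set, the rotated q/k/v vectors are converted to FP8 and written into the separate `O` buffer at the same element offsets as the fp16/bf16 QKV tensor, with a fixed 1.0 scale. A minimal standalone sketch of that kind of conversion for a float4-sized vector (not the mmha helper itself) looks like this:

```cpp
#include <cuda_fp8.h>

// Sketch: quantize a 4-element vector to e4m3 and store it at the same
// element offset in the fp8 output buffer, using a unit scale as the diff does.
__device__ inline void storeQuantizedFp8(__nv_fp8_e4m3* out, int elementOffset, float4 v)
{
    out[elementOffset + 0] = __nv_fp8_e4m3(v.x);
    out[elementOffset + 1] = __nv_fp8_e4m3(v.y);
    out[elementOffset + 2] = __nv_fp8_e4m3(v.z);
    out[elementOffset + 3] = __nv_fp8_e4m3(v.w);
}

__global__ void quantizeDemo(float4 const* in, __nv_fp8_e4m3* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        storeQuantizedFp8(out, 4 * i, in[i]); // one float4 becomes four e4m3 bytes
    }
}
```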
|
||||
|
||||
@ -405,22 +443,22 @@ __global__ void applyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer kvCacheBu
|
||||
= std::min((grid_size + head_num - 1) / head_num, (token_num + tokens_per_block - 1) / tokens_per_block); \
|
||||
dim3 grid(blocks_per_sequence, head_num); \
|
||||
applyBiasRopeUpdateKVCache<T, T_cache, Dh_MAX, ADD_BIAS, STORE_QKV, POS_SHIFT, KVCacheBuffer, IS_GENERATE> \
|
||||
<<<grid, block, 0, stream>>>(QKV, Q, kvTable, qkv_bias, seq_lens, kv_seq_lens, padding_offset, \
|
||||
<<<grid, block, 0, stream>>>(QKV, O, Q, kvTable, qkv_bias, seq_lens, kv_seq_lens, padding_offset, \
|
||||
kvScaleOrigQuant, token_num, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, head_num, \
|
||||
kv_head_num, head_num / kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, \
|
||||
rotary_scale_type, updated_rotary_embedding_scale, rotary_embedding_max_positions, \
|
||||
position_embedding_type, medusa_position_offsets, beam_width);
|
||||
position_embedding_type, medusa_position_offsets, quantized_fp8_output, beam_width);
|
||||
|
||||
template <int Dh_MAX, typename T, typename T_cache, typename KVCacheBuffer, bool IS_GENERATE>
|
||||
void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
|
||||
void kernelDispatchHeadSize(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
|
||||
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
|
||||
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
|
||||
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
|
||||
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale,
|
||||
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width,
|
||||
int2& grid_block_cache, cudaStream_t stream)
|
||||
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha,
|
||||
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
|
||||
{
|
||||
bool const add_bias = qkv_bias != nullptr;
|
||||
bool const store_contiguous_qkv = !enable_paged_kv_fmha;
|
||||
@ -482,15 +520,15 @@ void kernelDispatchHeadSize(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_b
|
||||
}
|
||||
|
||||
template <typename T, typename T_cache, typename KVCacheBuffer, bool IS_GENERATE>
|
||||
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
|
||||
void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
|
||||
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
|
||||
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
|
||||
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
|
||||
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale,
|
||||
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha, int const beam_width,
|
||||
int2& grid_block_cache, cudaStream_t stream)
|
||||
float const* kvScaleOrigQuant, int const int8_mode, bool const enable_paged_kv_fmha,
|
||||
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
|
||||
{
|
||||
TLLM_CHECK_WITH_INFO(int8_mode != 2, "w8a8 not yet implemented with RoPE"); // TODO
|
||||
if constexpr (!IS_GENERATE)
|
||||
@ -509,30 +547,30 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTab
|
||||
// GPTJ Rotary embedding needs at least two elements per thread.
|
||||
if (size_per_head <= 64)
|
||||
{
|
||||
kernelDispatchHeadSize<64, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
|
||||
kernelDispatchHeadSize<64, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
|
||||
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
|
||||
grid_block_cache, stream);
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
|
||||
beam_width, grid_block_cache, stream);
|
||||
}
|
||||
else if (size_per_head <= 128)
|
||||
{
|
||||
kernelDispatchHeadSize<128, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
|
||||
kernelDispatchHeadSize<128, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
|
||||
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
|
||||
grid_block_cache, stream);
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
|
||||
beam_width, grid_block_cache, stream);
|
||||
}
|
||||
else if (size_per_head <= 256)
|
||||
{
|
||||
kernelDispatchHeadSize<256, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
|
||||
kernelDispatchHeadSize<256, T, T_cache, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias, seq_lens,
|
||||
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
|
||||
grid_block_cache, stream);
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
|
||||
beam_width, grid_block_cache, stream);
|
||||
}
|
||||
else
|
||||
{
|
||||
@ -541,15 +579,15 @@ void invokeApplyBiasRopeUpdateKVCacheDispatch(T* QKV, T* Q, KVCacheBuffer& kvTab
|
||||
}
|
||||
|
||||
template <typename T, typename KVCacheBuffer, bool IS_GENERATE>
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias, int const* seq_lens,
|
||||
int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
void invokeApplyBiasRopeUpdateKVCache(T* QKV, void* O, T* Q, KVCacheBuffer& kvTable, T const* qkv_bias,
|
||||
int const* seq_lens, int const* kv_seq_lens, int const* padding_offset, int const batch_size, int const seq_len,
|
||||
int const cyclic_kv_cache_len, int const sink_token_len, int const token_num, int const head_num,
|
||||
int const kv_head_num, int const size_per_head, int const rotary_embedding_dim, float const rotary_embedding_base,
|
||||
const RotaryScalingType rotary_scale_type, float const rotary_embedding_scale,
|
||||
int const rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type,
|
||||
int const* medusa_position_offsets, bool const position_shift_enabled, float const* scale, int const int8_mode,
|
||||
const KvCacheDataType cache_type, float const* kvScaleOrigQuant, bool const enable_paged_kv_fmha,
|
||||
int const beam_width, int2& grid_block_cache, cudaStream_t stream)
|
||||
bool const quantized_fp8_output, int const beam_width, int2& grid_block_cache, cudaStream_t stream)
|
||||
{
|
||||
// Block handles both K and V tile.
|
||||
constexpr int x = (sizeof(T) == 4) ? 4 : 8;
|
||||
@ -557,38 +595,38 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T co
|
||||
|
||||
if (cache_type == KvCacheDataType::INT8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias,
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, int8_t, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias,
|
||||
seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num,
|
||||
head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
|
||||
grid_block_cache, stream);
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
|
||||
beam_width, grid_block_cache, stream);
|
||||
}
|
||||
#ifdef ENABLE_FP8
|
||||
else if (cache_type == KvCacheDataType::FP8)
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable,
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, __nv_fp8_e4m3, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable,
|
||||
qkv_bias, seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len,
|
||||
token_num, head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base,
|
||||
rotary_scale_type, rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type,
|
||||
medusa_position_offsets, position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha,
|
||||
beam_width, grid_block_cache, stream);
|
||||
quantized_fp8_output, beam_width, grid_block_cache, stream);
|
||||
}
|
||||
#endif // ENABLE_FP8
|
||||
else
|
||||
{
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer, IS_GENERATE>(QKV, Q, kvTable, qkv_bias, seq_lens,
|
||||
kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num, head_num,
|
||||
kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
invokeApplyBiasRopeUpdateKVCacheDispatch<T, T, KVCacheBuffer, IS_GENERATE>(QKV, O, Q, kvTable, qkv_bias,
|
||||
seq_lens, kv_seq_lens, padding_offset, batch_size, seq_len, cyclic_kv_cache_len, sink_token_len, token_num,
|
||||
head_num, kv_head_num, size_per_head, rotary_embedding_dim, rotary_embedding_base, rotary_scale_type,
|
||||
rotary_embedding_scale, rotary_embedding_max_positions, position_embedding_type, medusa_position_offsets,
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, beam_width,
|
||||
grid_block_cache, stream);
|
||||
position_shift_enabled, scale, kvScaleOrigQuant, int8_mode, enable_paged_kv_fmha, quantized_fp8_output,
|
||||
beam_width, grid_block_cache, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_ADDFUSEDQKVBIAS_TRANSPOSE(T, KVCacheBuffer, IS_GENERATE) \
|
||||
template void invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, IS_GENERATE>(T * QKV, T * Q, \
|
||||
KVCacheBuffer & kvTable, const T* qkv_bias, const int* seq_lens, const int* kv_seq_lens, \
|
||||
template void invokeApplyBiasRopeUpdateKVCache<T, KVCacheBuffer, IS_GENERATE>(T * QKV, void* O, T* Q, \
|
||||
KVCacheBuffer& kvTable, const T* qkv_bias, const int* seq_lens, const int* kv_seq_lens, \
|
||||
const int* padding_offset, const int batch_size, const int seq_len, const int cyclic_kv_cache_len, \
|
||||
const int sink_token_len, const int token_num, const int head_num, const int kv_head_num, \
|
||||
const int size_per_head, const int rotary_embedding_dim, const float rotary_embedding_base, \
|
||||
@ -596,7 +634,8 @@ void invokeApplyBiasRopeUpdateKVCache(T* QKV, T* Q, KVCacheBuffer& kvTable, T co
|
||||
const int rotary_embedding_max_positions, const PositionEmbeddingType position_embedding_type, \
|
||||
const int* medusa_position_offsets, const bool position_shift_enabled, const float* scale, \
|
||||
const int int8_mode, const KvCacheDataType cache_type, const float* kvScaleOrigQuant, \
|
||||
const bool enable_paged_kv_fmha, const int beam_width, int2& grid_block_cache, cudaStream_t stream)
|
||||
const bool enable_paged_kv_fmha, bool const quantized_fp8_output, const int beam_width, \
|
||||
int2& grid_block_cache, cudaStream_t stream)
|
||||
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
|
||||
|
||||
172
cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/int8SQ.cu
Normal file
@ -0,0 +1,172 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/int8SQ.h"
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace smooth_quant
|
||||
{
|
||||
template <typename Type, int CtaM, int CtaN, int Threads, bool PerChannel, bool PerToken>
|
||||
__global__ void int8_sq(int8_t const* act, int8_t const* weight, float const* scale_channels, float const* scale_tokens,
|
||||
Type* output, int m, int n, int k)
|
||||
{
|
||||
using VecType = int4;
|
||||
static constexpr int kStepK = 128 / (8 * sizeof(int8_t));
|
||||
static constexpr int CtaK = kStepK * Threads;
|
||||
int tile_id_m = blockIdx.x * CtaM;
|
||||
int tile_id_n = blockIdx.y * CtaN;
|
||||
int tid = threadIdx.x;
|
||||
int8_t tile_a[kStepK], tile_w[CtaN * kStepK];
|
||||
int acc[CtaM * CtaN];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CtaM * CtaN; ++i)
|
||||
{
|
||||
acc[i] = 0;
|
||||
}
|
||||
act += tile_id_m * k;
|
||||
weight += tile_id_n * k;
|
||||
output += tile_id_m * n + tile_id_n;
|
||||
for (int idx_k = tid * kStepK; idx_k < k; idx_k += CtaK)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CtaN; ++i)
|
||||
{
|
||||
reinterpret_cast<VecType*>(tile_w)[i] = reinterpret_cast<VecType const*>(weight + i * k + idx_k)[0];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CtaM; ++i)
|
||||
{
|
||||
reinterpret_cast<VecType*>(tile_a)[0] = reinterpret_cast<VecType const*>(act + i * k + idx_k)[0];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < CtaN; ++j)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int l = 0; l < kStepK; l += 4)
|
||||
{
|
||||
acc[i * CtaN + j] = __dp4a(reinterpret_cast<int*>(tile_a + l)[0],
|
||||
reinterpret_cast<int*>(tile_w + j * kStepK + l)[0], acc[i * CtaN + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr int kWarpSize = 32;
|
||||
static constexpr int kWarpNum = Threads / kWarpSize;
|
||||
__shared__ int shmem[CtaM * CtaN * kWarpNum];
|
||||
int warp_id = tid / kWarpSize, lane_id = tid % kWarpSize;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < CtaM; ++i)
|
||||
{
|
||||
#pragma unroll
|
||||
for (int j = 0; j < CtaN; ++j)
|
||||
{
|
||||
int val = acc[i * CtaN + j];
|
||||
val += __shfl_xor_sync(~0, val, 16);
|
||||
val += __shfl_xor_sync(~0, val, 8);
|
||||
val += __shfl_xor_sync(~0, val, 4);
|
||||
val += __shfl_xor_sync(~0, val, 2);
|
||||
val += __shfl_xor_sync(~0, val, 1);
|
||||
if (lane_id == 0)
|
||||
{
|
||||
shmem[i * CtaN + j + warp_id * CtaM * CtaN] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int ii = tid; ii < CtaM * CtaN; ii += Threads)
|
||||
{
|
||||
int mid = ii / CtaN, nid = ii % CtaN;
|
||||
float scale_channel, scale_token;
|
||||
if constexpr (PerChannel)
|
||||
{
|
||||
scale_channel = scale_channels[tile_id_n + nid];
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_channel = scale_channels[0];
|
||||
}
|
||||
if constexpr (PerToken)
|
||||
{
|
||||
scale_token = scale_tokens[tile_id_m + mid];
|
||||
}
|
||||
else
|
||||
{
|
||||
scale_token = scale_tokens[0];
|
||||
}
|
||||
int val = 0;
|
||||
#pragma unroll
|
||||
for (int jj = 0; jj < kWarpNum; ++jj)
|
||||
{
|
||||
val += shmem[jj * CtaM * CtaN + ii];
|
||||
}
|
||||
output[mid * n + nid] = static_cast<Type>(static_cast<float>(val) * scale_channel * scale_token);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Type, int CtaM, int CtaN, int Threads, bool PerChannel, bool PerToken>
|
||||
void int8_sq_kernel(Params& params, cudaStream_t s)
|
||||
{
|
||||
dim3 block(Threads);
|
||||
dim3 grid(params.m / CtaM, params.n / CtaN);
|
||||
int8_sq<Type, CtaM, CtaN, Threads, PerChannel, PerToken><<<grid, block, 0, s>>>(params.act, params.weight,
|
||||
params.scale_channels, params.scale_tokens, reinterpret_cast<Type*>(params.output), params.m, params.n,
|
||||
params.k);
|
||||
}
|
||||
|
||||
template <typename Type, bool PerChannel, bool PerToken>
|
||||
void algo_tactic_dispatcher(Params& params, cudaStream_t s)
|
||||
{
|
||||
#define DISPATCH(TargetM, CtaM, CtaN, Threads) \
|
||||
if (params.m == TargetM) \
|
||||
{ \
|
||||
int8_sq_kernel<Type, CtaM, CtaN, Threads, PerChannel, PerToken>(params, s); \
|
||||
return; \
|
||||
}
|
||||
DISPATCH(1, 1, 2, 128);
|
||||
DISPATCH(2, 2, 2, 128);
|
||||
DISPATCH(3, 3, 2, 128);
|
||||
DISPATCH(4, 4, 2, 128);
|
||||
#undef DISPATCH
|
||||
}
|
||||
|
||||
template <typename Type>
|
||||
void int8_sq_launcher(Params& params, cudaStream_t s)
|
||||
{
|
||||
#define DISPATCH(PerChannel, PerToken) \
|
||||
if (per_channel == PerChannel && per_token == PerToken) \
|
||||
{ \
|
||||
algo_tactic_dispatcher<Type, PerChannel, PerToken>(params, s); \
|
||||
return; \
|
||||
}
|
||||
bool per_channel = params.quant_mode.hasPerChannelScaling();
|
||||
bool per_token = params.quant_mode.hasPerTokenScaling();
|
||||
DISPATCH(false, false);
|
||||
DISPATCH(false, true);
|
||||
DISPATCH(true, false);
|
||||
DISPATCH(true, true);
|
||||
#undef DISPATCH
|
||||
}
|
||||
|
||||
template void int8_sq_launcher<float>(Params& params, cudaStream_t s);
|
||||
template void int8_sq_launcher<half>(Params& params, cudaStream_t s);
|
||||
template void int8_sq_launcher<int>(Params& params, cudaStream_t s);
|
||||
} // namespace smooth_quant
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
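The new `int8_sq` kernel computes an int8 x int8 GEMM with `__dp4a`, reduces partial sums across the warp, and rescales each output with a per-channel and a per-token (or single) SmoothQuant scale. A plain host reference of the same math, useful for checking the kernel numerically, might look like the following; matrix contents and sizes are arbitrary test values.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Host reference for the int8 SmoothQuant GEMM: act is m x k (row-major),
// weight is n x k (row-major, i.e. transposed B), output is m x n in float.
void int8SqReference(const int8_t* act, const int8_t* weight, const float* scaleTokens,
    const float* scaleChannels, bool perToken, bool perChannel, float* out, int m, int n, int k)
{
    for (int mi = 0; mi < m; ++mi)
    {
        for (int ni = 0; ni < n; ++ni)
        {
            int32_t acc = 0;
            for (int ki = 0; ki < k; ++ki)
            {
                acc += static_cast<int32_t>(act[mi * k + ki]) * static_cast<int32_t>(weight[ni * k + ki]);
            }
            float st = perToken ? scaleTokens[mi] : scaleTokens[0];
            float sc = perChannel ? scaleChannels[ni] : scaleChannels[0];
            out[mi * n + ni] = static_cast<float>(acc) * sc * st;
        }
    }
}

int main()
{
    const int m = 2, n = 3, k = 4;
    std::vector<int8_t> act = {1, 2, 3, 4, -1, -2, -3, -4};
    std::vector<int8_t> w(n * k, 1);
    std::vector<float> sTok = {0.5f, 2.0f}, sCh = {1.0f, 0.1f, 10.0f};
    std::vector<float> out(m * n);
    int8SqReference(act.data(), w.data(), sTok.data(), sCh.data(), true, true, out.data(), m, n, k);
    std::printf("out[0][0] = %.2f, out[1][2] = %.2f\n", out[0], out[5]); // 5.00 and -200.00
    return 0;
}
```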
|
||||
63
cpp/tensorrt_llm/kernels/weightOnlyBatchedGemv/int8SQ.h
Normal file
@ -0,0 +1,63 @@
|
||||
/*
|
||||
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
#include "tensorrt_llm/common/quantization.h"
|
||||
#include <cassert>
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <cuda_runtime_api.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace tensorrt_llm
|
||||
{
|
||||
namespace kernels
|
||||
{
|
||||
namespace smooth_quant
|
||||
{
|
||||
struct Params
|
||||
{
|
||||
int8_t const* act;
|
||||
int8_t const* weight;
|
||||
float const* scale_tokens;
|
||||
float const* scale_channels;
|
||||
void* output;
|
||||
int m, n, k;
|
||||
tensorrt_llm::common::QuantMode quant_mode;
|
||||
|
||||
Params(int8_t const* _act, int8_t const* _weight, float const* _scale_tokens, float const* _scale_channels,
|
||||
void* _output, int _m, int _n, int _k, tensorrt_llm::common::QuantMode _quant_mode)
|
||||
: act(_act)
|
||||
, weight(_weight)
|
||||
, scale_tokens(_scale_tokens)
|
||||
, scale_channels(_scale_channels)
|
||||
, output(_output)
|
||||
, m(_m)
|
||||
, n(_n)
|
||||
, k(_k)
|
||||
, quant_mode(_quant_mode)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
template <typename>
|
||||
void int8_sq_launcher(Params& params, cudaStream_t s);
|
||||
} // namespace smooth_quant
|
||||
} // namespace kernels
|
||||
} // namespace tensorrt_llm
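A call site for the new header is straightforward: fill `Params` with device pointers and the quant mode, then pick the output type through the launcher template. The fragment below is a sketch; the device buffers, their allocation, and the `QuantMode` value are assumed to exist elsewhere.

```cpp
#include "tensorrt_llm/kernels/weightOnlyBatchedGemv/int8SQ.h"

// Sketch of a call site. All device pointers are assumed to be allocated and
// filled elsewhere; quantMode is whatever the caller already derived for the layer.
void runInt8SqGemm(int8_t const* actDev, int8_t const* weightDev, float const* scaleTokensDev,
    float const* scaleChannelsDev, half* outDev, int m, int n, int k,
    tensorrt_llm::common::QuantMode quantMode, cudaStream_t stream)
{
    using tensorrt_llm::kernels::smooth_quant::Params;
    Params params(actDev, weightDev, scaleTokensDev, scaleChannelsDev, outDev, m, n, k, quantMode);
    // The dispatcher above only has tactics for m in [1, 4]; other m values fall through.
    tensorrt_llm::kernels::smooth_quant::int8_sq_launcher<half>(params, stream);
}
```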
|
||||
@ -41,12 +41,12 @@ public:
|
||||
class SetupParams : public DecodingSetupParams
|
||||
{
|
||||
public:
|
||||
std::optional<std::vector<runtime::SizeType>> runtime_top_k; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<std::int32_t>> top_p_reset_ids; // [batchSize]
|
||||
std::optional<std::vector<runtime::SizeType>> runtime_top_k; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> runtime_top_p; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
|
||||
std::optional<std::vector<float>> top_p_decay; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<float>> top_p_min; // [batchSize], must between [0, 1]
|
||||
std::optional<std::vector<runtime::TokenIdType>> top_p_reset_ids; // [batchSize]
|
||||
std::optional<bool> normalize_log_probs;
|
||||
};
|
||||
|
||||
|
||||
@ -78,8 +78,10 @@ public:
|
||||
tc::Tensor output_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
|
||||
|
||||
// Medusa params
|
||||
std::optional<tc::Tensor> nextDraftTokens; // [batch_size, max_draft_tokens_per_step]
|
||||
std::optional<tc::Tensor> acceptedLengths; // [batch_size]
|
||||
std::optional<tc::Tensor> nextDraftTokens; // [batch_size, max_draft_tokens_per_step]
|
||||
std::optional<tc::Tensor> acceptedLengths; // [batch_size]
|
||||
std::optional<tc::Tensor> acceptedLengthsCumSum; // [batch_size + 1]
|
||||
std::optional<tc::Tensor> medusaPathsOffsets; // [batch_size * max_medusa_heads]
|
||||
};
|
||||
|
||||
} // namespace tensorrt_llm::layers
|
||||
|
||||
@ -292,7 +292,6 @@ void DynamicDecodeLayer<T>::setupLayers(
|
||||
typename MedusaDecodingLayer<T>::MedusaSetupParams medusaSetupParams;
|
||||
medusaSetupParams.runtimeTopK = setupParams.runtime_top_k;
|
||||
medusaSetupParams.runtimeHeadsTopK = setupParams.topKMedusaHeads;
|
||||
medusaSetupParams.tokensPerStep = setupParams.tokensPerStep;
|
||||
medusaSetupParams.randomSeed = setupParams.randomSeed;
|
||||
mMedusaDecodingLayer->setup(batchSize, batchSlots, medusaSetupParams);
|
||||
}
|
||||
@ -555,14 +554,19 @@ void DynamicDecodeLayer<T>::layersForward(Tensor& logits, OutputParams& outputs,
|
||||
typename MedusaDecodingLayer<T>::MedusaForwardParams medusaInputParams(logits, endIds);
|
||||
medusaInputParams.finished = outputs.finished.value();
|
||||
medusaInputParams.batch_slots = params.batch_slots;
|
||||
medusaInputParams.paths = params.paths.value();
|
||||
medusaInputParams.medusaLogits = params.medusaLogits.value();
|
||||
medusaInputParams.paths = params.medusaInputs->medusaPaths;
|
||||
medusaInputParams.medusaLogits = params.medusaInputs->medusaLogits;
|
||||
medusaInputParams.medusaCurTokensPerStep = params.medusaInputs->medusaCurTokensPerStep;
|
||||
medusaInputParams.medusaTargetTokensPerStep = params.medusaInputs->medusaTargetTokensPerStep;
|
||||
medusaInputParams.treeIds = params.medusaInputs->medusaTreeIds;
|
||||
|
||||
DecodingOutputParams medusaOutputParams(outputs.output_ids);
|
||||
medusaOutputParams.sequence_length = outputs.sequence_length.value();
|
||||
medusaOutputParams.finished = outputs.finished.value();
|
||||
medusaOutputParams.nextDraftTokens = outputs.nextDraftTokens.value();
|
||||
medusaOutputParams.acceptedLengths = outputs.acceptedLengths.value();
|
||||
medusaOutputParams.nextDraftTokens = outputs.medusaOutputs->nextDraftTokens;
|
||||
medusaOutputParams.acceptedLengths = outputs.medusaOutputs->acceptedLengths;
|
||||
medusaOutputParams.acceptedLengthsCumSum = outputs.medusaOutputs->medusaAcceptedLengthsCumSum;
|
||||
medusaOutputParams.medusaPathsOffsets = outputs.medusaOutputs->medusaPathsOffsets;
|
||||
|
||||
mMedusaDecodingLayer->forward(medusaOutputParams, medusaInputParams);
|
||||
}
|
||||
@ -614,7 +618,8 @@ void DynamicDecodeLayer<T>::applyPenalties(OutputParams& outputs, ForwardParams
|
||||
|
||||
#undef GET_PENALTIES
|
||||
|
||||
auto const tokensPerStep = params.tokensPerStep ? params.tokensPerStep->template getPtr<SizeType const>() : nullptr;
|
||||
auto const tokensPerStep
|
||||
= params.medusaInputs ? params.medusaInputs->medusaCurTokensPerStep.template getPtr<SizeType const>() : nullptr;
|
||||
InvokeBatchApplyPenaltyParams<T> penaltyParams{reinterpret_cast<T const* const*>(logitsPtrsHostData),
|
||||
mRuntimeLogitsDevice, embeddingBias, mPenaltyWorkspaceDevice, mPenaltyWorkspacePrevDevice, temperatures,
|
||||
repetitionPenalties, presencePenalties, frequencyPenalties,
|
||||
@ -788,8 +793,9 @@ void DynamicDecodeLayer<T>::prepareOutputData(OutputParams& outputs, ForwardPara
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
auto idsPtrHostSlice = ITensor::slice(idsPtrsHost, cyclicStep, 1);
|
||||
auto idsPtrHost = reinterpret_cast<TokenIdType**>(runtime::bufferCast<int64_t>(*idsPtrHostSlice));
|
||||
auto const numNewTokens = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType const>()
|
||||
: static_cast<SizeType*>(nullptr);
|
||||
auto const numNewTokens = outputs.medusaOutputs
|
||||
? outputs.medusaOutputs->acceptedLengths.template getPtr<SizeType const>()
|
||||
: static_cast<SizeType*>(nullptr);
|
||||
invokeCopyNextStepIds(outputs.newTokens.template getPtr<TokenIdType>(), idsPtrHost,
|
||||
outputs.sequence_length->template getPtr<SizeType>(), numNewTokens, batchSlots, batchSize, maxBatchSize,
|
||||
beamWidth, maxSeqLen, maxTokensPerStep, stream);
|
||||
|
||||
@ -84,7 +84,6 @@ public:
|
||||
|
||||
// Medusa params
|
||||
std::optional<std::vector<std::vector<runtime::SizeType>>> topKMedusaHeads; // [batchSize, maxMedusaHeads]
|
||||
std::optional<std::vector<runtime::SizeType>> tokensPerStep; // [batchSize]
|
||||
};
|
||||
|
||||
void setup(runtime::SizeType batch_size, runtime::SizeType beam_width, int const* batch_slots,
|
||||
@ -139,10 +138,18 @@ public:
|
||||
std::optional<tc::Tensor> batch_slots; // [batch_size], in pinned memory
|
||||
|
||||
// Medusa inputs
|
||||
std::optional<tc::Tensor> tokensPerStep; // [batch_size], optional, on gpu
|
||||
std::optional<tc::Tensor> paths; // [batch_size, max_tokens_per_step, max_num_heads + 1], optional, on gpu
|
||||
std::optional<tc::Tensor>
|
||||
medusaLogits; // [max_num_heads, batch_size, max_tokens_per_step, vocab_size], optional, on gpu
|
||||
class MedusaInputs
|
||||
{
|
||||
public:
|
||||
tc::Tensor medusaCurTokensPerStep; // [batch_size], optional, on gpu
|
||||
tc::Tensor medusaTargetTokensPerStep; // [batch_size], optional, on gpu
|
||||
tc::Tensor medusaPaths; // [batch_size, max_tokens_per_step, max_num_heads + 1], optional, on gpu
|
||||
tc::Tensor medusaTreeIds; // [batch_size, max_tokens_per_step], optional, on gpu
|
||||
std::vector<std::vector<tc::Tensor>>
|
||||
medusaLogits; // [max_batch_size][max_num_heads][tokens_per_step, vocab_size], optional, on gpu
|
||||
};
|
||||
|
||||
std::optional<MedusaInputs> medusaInputs;
|
||||
};
|
||||
|
||||
class OutputParams
|
||||
@ -173,9 +180,16 @@ public:
|
||||
tc::Tensor parent_ids_ptr; // [batch_size] int* (2-d array), each int* has [beam_width, max_seq_len]
|
||||
|
||||
// Medusa outputs
|
||||
std::optional<tc::Tensor>
|
||||
nextDraftTokens; // [batch_size, max_tokens_per_step], draft tokens predicted by Medusa heads
|
||||
std::optional<tc::Tensor> acceptedLengths; // [batch_size], lengths of the accepted draft tokens + 1
|
||||
class MedusaOutputs
|
||||
{
|
||||
public:
|
||||
tc::Tensor nextDraftTokens; // [batch_size, max_tokens_per_step], draft tokens predicted by Medusa heads
|
||||
tc::Tensor acceptedLengths; // [batch_size], lengths of the accepted draft tokens + 1
|
||||
tc::Tensor medusaAcceptedLengthsCumSum; // [batch_size + 1]
|
||||
tc::Tensor medusaPathsOffsets; // [batch_size * max_medusa_heads]
|
||||
};
|
||||
|
||||
std::optional<MedusaOutputs> medusaOutputs;
|
||||
};
|
||||
|
||||
void forward(OutputParams& outputs, ForwardParams const& params);
|
||||
|
||||
@ -22,6 +22,7 @@
|
||||
#include "tensorrt_llm/kernels/decodingKernels.h"
|
||||
#include "tensorrt_llm/kernels/samplingTopKKernels.h"
|
||||
#include "tensorrt_llm/runtime/bufferManager.h"
|
||||
#include "tensorrt_llm/runtime/iBuffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
@ -68,7 +69,8 @@ void MedusaDecodingLayer<T>::allocateBuffer()
|
||||
|
||||
// Get sampling workspace size
|
||||
{
|
||||
auto samplingSizePrimarySampling = getTopKWorkspaceSize<T>(mMaxBatchSize, 1, TOP_K_MAX, mVocabSizePadded);
|
||||
auto samplingSizePrimarySampling
|
||||
= getTopKWorkspaceSize<T>(mMaxBatchSize, mMaxTokensPerStep, TOP_K_MAX, mVocabSizePadded);
|
||||
|
||||
auto const maxBatchSizeHeadNums = mMaxBatchSize * mMaxNumHeads;
|
||||
auto samplingSizeMedusaHeadsSampling
|
||||
@ -82,35 +84,40 @@ void MedusaDecodingLayer<T>::allocateBuffer()
|
||||
runtime::TRTDataType<TokenIdType*>::value);
|
||||
mCummulativeTopK.resize(mMaxBatchSize * mMaxNumHeads);
|
||||
|
||||
std::array<size_t, 10> deviceBufferSizes;
|
||||
std::array<size_t, 11> deviceBufferSizes;
|
||||
deviceBufferSizes[0] = mMaxBatchSize * sizeof(curandState_t);
|
||||
deviceBufferSizes[1] = mMaxBatchSize * sizeof(SizeType);
|
||||
deviceBufferSizes[2] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
|
||||
deviceBufferSizes[3] = mSamplingWorkspaceSize;
|
||||
deviceBufferSizes[4] = mMaxBatchSize * sizeof(SizeType);
|
||||
deviceBufferSizes[5] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
|
||||
deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t);
|
||||
deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(T*);
|
||||
deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t);
|
||||
deviceBufferSizes[9] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
|
||||
deviceBufferSizes[1] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
|
||||
deviceBufferSizes[2] = mSamplingWorkspaceSize;
|
||||
deviceBufferSizes[3] = mMaxBatchSize * sizeof(SizeType);
|
||||
deviceBufferSizes[4] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
|
||||
deviceBufferSizes[5] = mMaxBatchSize * mMaxNumHeads * sizeof(uint64_t);
|
||||
deviceBufferSizes[6] = mMaxBatchSize * mMaxNumHeads * sizeof(T*);
|
||||
deviceBufferSizes[7] = mMaxBatchSize * mMaxNumHeads * sizeof(curandState_t);
|
||||
deviceBufferSizes[8] = mMaxBatchSize * mMaxNumHeads * sizeof(SizeType);
|
||||
deviceBufferSizes[9] = mMaxBatchSize * mMaxTokensPerStep * sizeof(TokenIdType);
|
||||
deviceBufferSizes[10] = mMaxBatchSize * sizeof(SizeType);
|
||||
|
||||
mCurandStatesDevice = mAllocator->reMalloc(mCurandStatesDevice, deviceBufferSizes[0], false);
|
||||
mTokensPerStepDevice = mAllocator->reMalloc(mTokensPerStepDevice, deviceBufferSizes[1], false);
|
||||
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[2], false);
|
||||
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[3], false);
|
||||
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[4], false);
|
||||
mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[5], false);
|
||||
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[6], false);
|
||||
mMedusaLogitsPtrsDevice = mAllocator->reMalloc(mMedusaLogitsPtrsDevice, deviceBufferSizes[7], false);
|
||||
mSetupWorkspaceDevice = mAllocator->reMalloc(mSetupWorkspaceDevice, deviceBufferSizes[1], false);
|
||||
mSamplingWorkspaceDevice = mAllocator->reMalloc(mSamplingWorkspaceDevice, deviceBufferSizes[2], false);
|
||||
mRuntimeTopKDevice = mAllocator->reMalloc(mRuntimeTopKDevice, deviceBufferSizes[3], false);
|
||||
mTargetTokensDevice = mAllocator->reMalloc(mTargetTokensDevice, deviceBufferSizes[4], false);
|
||||
mRandomSeedsDevice = mAllocator->reMalloc(mRandomSeedsDevice, deviceBufferSizes[5], false);
|
||||
mMedusaSelectedLogitsPtrsDevice
|
||||
= mAllocator->reMalloc(mMedusaSelectedLogitsPtrsDevice, deviceBufferSizes[6], false);
|
||||
mCurandStatesMedusaLogitsDevice
|
||||
= mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[8], false);
|
||||
= mAllocator->reMalloc(mCurandStatesMedusaLogitsDevice, deviceBufferSizes[7], false);
|
||||
mRuntimeTopKPerRequestPerMedusaHeadDevice
|
||||
= mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[9], false);
|
||||
= mAllocator->reMalloc(mRuntimeTopKPerRequestPerMedusaHeadDevice, deviceBufferSizes[8], false);
|
||||
mNewDraftTokensDevice = mAllocator->reMalloc(mNewDraftTokensDevice, deviceBufferSizes[9], false);
|
||||
mBestPathIdsDevice = mAllocator->reMalloc(mBestPathIdsDevice, deviceBufferSizes[10], false);
|
||||
|
||||
mTiledBatchSlotsSetup = BufferManager::pinnedPool(
|
||||
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
|
||||
mTiledBatchSlotsForward = BufferManager::pinnedPool(
|
||||
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), nvinfer1::DataType::kINT32);
|
||||
mMedusaInputLogitsPtrs = BufferManager::pinnedPool(
|
||||
ITensor::makeShape({static_cast<SizeType>(mMaxBatchSize * mMaxNumHeads)}), TRTDataType<T*>::value);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -121,15 +128,16 @@ void MedusaDecodingLayer<T>::freeBuffer()
|
||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
|
||||
|
||||
mAllocator->free((void**) (&mCurandStatesDevice));
|
||||
mAllocator->free((void**) (&mTokensPerStepDevice));
|
||||
mAllocator->free((void**) (&mSetupWorkspaceDevice));
|
||||
mAllocator->free((void**) (&mSamplingWorkspaceDevice));
|
||||
mAllocator->free((void**) (&mRuntimeTopKDevice));
|
||||
mAllocator->free((void**) (&mTargetTokensDevice));
|
||||
mAllocator->free((void**) (&mRandomSeedsDevice));
|
||||
mAllocator->free((void**) (&mMedusaLogitsPtrsDevice));
|
||||
mAllocator->free((void**) (&mMedusaSelectedLogitsPtrsDevice));
|
||||
mAllocator->free((void**) (&mCurandStatesMedusaLogitsDevice));
|
||||
mAllocator->free((void**) (&mRuntimeTopKPerRequestPerMedusaHeadDevice));
|
||||
mAllocator->free((void**) (&mNewDraftTokensDevice));
|
||||
mAllocator->free((void**) (&mBestPathIdsDevice));
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
@ -191,17 +199,6 @@ void MedusaDecodingLayer<T>::setup(SizeType batchSize, SizeType const* batchSlot
|
||||
}
|
||||
initCurandStates({tiledRandomSeed}, batchSizeMaxNumHeads, tiledBatchSlots, mCurandStatesMedusaLogitsDevice);
|
||||
|
||||
// Prepare tokens per step
|
||||
{
|
||||
auto tokensPerStep = setupParams.tokensPerStep.value_or(std::vector<SizeType>{batchSize, mMaxTokensPerStep});
|
||||
TLLM_CHECK_WITH_INFO(tokensPerStep.size() == batchSize,
|
||||
fmtstr("tokensPerStep.size() (%lu) == batchSize (%d) is not satisfied!", tokensPerStep.size(), batchSize));
|
||||
|
||||
cudaAutoCpy(reinterpret_cast<SizeType*>(mSetupWorkspaceDevice), tokensPerStep.data(), batchSize, mStream);
|
||||
invokeScatterDecodingParams(
|
||||
reinterpret_cast<SizeType*>(mSetupWorkspaceDevice), mTokensPerStepDevice, batchSlots, batchSize, mStream);
|
||||
}
|
||||
|
||||
// Prepare runtime top K
|
||||
auto prepareRuntimeTopK = [this](std::vector<SizeType> const& runtimeTopK, SizeType batchSize,
|
||||
SizeType const* batchSlots, SizeType* runtimeTopKDevice)
|
||||
@ -221,7 +218,7 @@ void MedusaDecodingLayer<T>::setup(SizeType batchSize, SizeType const* batchSlot
|
||||
|
||||
auto constexpr defaultTopK = 1u;
|
||||
{
|
||||
auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector<SizeType>{batchSize, defaultTopK});
|
||||
auto runtimeTopK = setupParams.runtimeTopK.value_or(std::vector<SizeType>(batchSize, defaultTopK));
|
||||
auto const curMaxTopK = prepareRuntimeTopK(runtimeTopK, batchSize, batchSlots, mRuntimeTopKDevice);
|
||||
mRuntimeMaxTopK = std::max(mRuntimeMaxTopK, curMaxTopK);
|
||||
}
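The one-character change above matters: with braces, `std::vector<SizeType>{batchSize, defaultTopK}` builds a two-element vector `{batchSize, defaultTopK}`, whereas the parenthesized form builds `batchSize` copies of `defaultTopK`, which is what the per-request default actually needs. A tiny standalone check (stand-in `SizeType`, illustrative values):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main()
{
    using SizeType = std::int32_t; // stand-in for runtime::SizeType
    SizeType batchSize = 8;
    SizeType defaultTopK = 1;

    std::vector<SizeType> braced{batchSize, defaultTopK}; // two elements: {8, 1}
    std::vector<SizeType> filled(batchSize, defaultTopK); // eight elements, all 1

    assert(braced.size() == 2);
    assert(filled.size() == 8 && filled.front() == 1 && filled.back() == 1);
    return 0;
}
```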
|
||||
@ -279,6 +276,10 @@ void MedusaDecodingLayer<T>::forward(DecodingOutputParams& outputs, MedusaForwar
|
||||
|
||||
sampleNewDraftTokens(outputs, inputs);
|
||||
|
||||
scatterNewDraftTokens(outputs, inputs);
|
||||
|
||||
packAcceptedPaths(outputs, inputs);
|
||||
|
||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
|
||||
}
|
||||
|
||||
@ -290,10 +291,9 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(DecodingOutputParams& outputs
|
||||
auto const batchSize = inputs.logits.shape[0];
|
||||
|
||||
auto logits = inputs.logits.template getPtr<T>();
|
||||
auto batchSlots
|
||||
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
|
||||
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>()
|
||||
: static_cast<SizeType*>(nullptr);
|
||||
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
|
||||
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;
|
||||
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
|
||||
|
||||
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
|
||||
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
|
||||
@ -302,11 +302,11 @@ void MedusaDecodingLayer<T>::samplePrimeHeadTokens(DecodingOutputParams& outputs
|
||||
// Sequence length is not modified, endIds is not checked, outputLogProbs are not supported.
|
||||
// Finished state is not set.
|
||||
invokeBatchTopKSampling(mSamplingWorkspaceDevice, logits, /* logProbsPtrs */ static_cast<T const* const*>(nullptr),
|
||||
/* outputIdsPtrs */ nullptr, mTargetTokensDevice, sequenceLengths,
/* outputIdsPtrs */ nullptr, mTargetTokensDevice, /* sequenceLengths */ nullptr,
/* finishedInput */ nullptr, /* finishedOutput */ nullptr,
/* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesDevice, mRuntimeMaxTopK,
mRuntimeTopKDevice, 1.0f, /* runtimeTopPDevice */ nullptr, mVocabSizePadded, /* endIds */ nullptr, batchSlots,
mStream, batchSize, mMaxBatchSize, mTokensPerStepDevice, mMaxTokensPerStep, mMaxTokensPerStep,
mStream, batchSize, mMaxBatchSize, tokensPerStepDevice, mMaxTokensPerStep, mMaxTokensPerStep,
/* skipDecode */ nullptr, /* normalizeLogProbs */ false,
/* probsComputed */ false, /* return all Top-K*/ false);

@ -324,7 +324,6 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(DecodingOutputParams& outputs, Me
auto outputIds = outputs.output_ids.template getPtr<TokenIdType>();
auto endIds = inputs.end_ids.template getPtr<TokenIdType const>();
auto paths = inputs.paths.template getPtr<SizeType const>();
auto medusaLogits = inputs.medusaLogits.template getPtr<T const>();

auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
@ -332,20 +331,42 @@ void MedusaDecodingLayer<T>::acceptDraftTokens(DecodingOutputParams& outputs, Me
: static_cast<SizeType*>(nullptr);
auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>()
: static_cast<SizeType*>(nullptr);
auto curTokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
auto targetTokensPerStepDevice = inputs.medusaTargetTokensPerStep.template getPtr<SizeType>();

auto medusaInputLogitsPtrs = BufferRange<T*>(*mMedusaInputLogitsPtrs);
for (SizeType bi = 0; bi < batchSize; ++bi)
{
auto const slot = batchSlots[bi];
for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
{
medusaInputLogitsPtrs[slot * mMaxNumHeads + hi] = inputs.medusaLogits[slot][hi].template getPtr<T>();
}
}

auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;

TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
curTokensPerStepDevice != nullptr, "Current tokens per step must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(
targetTokensPerStepDevice != nullptr, "Target tokens per step must be provided for MedusaDecoding");

auto finishedStates
= reinterpret_cast<FinishedState*>(outputs.finished->template getPtr<FinishedState::UnderlyingType>());

// Compare draft tokens from outputIds with sampled target tokens at mTargetTokensDevice using paths.
// Select the longest accepted path, modify outputIds in-place, increment sequenceLengths accordingly.
// Fill mMedusaLogitsPtrsDevice with respective Medusa logits
acceptDraftTokensByIdsWithPaths(outputIds, mTargetTokensDevice, sequenceLengths, acceptedLengths, finishedStates,
batchSlots, paths, endIds, medusaLogits, const_cast<T const**>(mMedusaLogitsPtrsDevice), batchSize, mVocabSize,
mMaxBatchSize, maxSeqLen, mMaxTokensPerStep, mMaxNumHeads, mMaxTokensPerStep, mStream);
// Fill mMedusaSelectedLogitsPtrsDevice with respective Medusa logits
acceptDraftTokensByIdsWithPaths(outputIds, draftIds, mTargetTokensDevice, sequenceLengths, acceptedLengths,
finishedStates, batchSlots, paths, endIds,
reinterpret_cast<T const**>(bufferCast<int64_t>(*mMedusaInputLogitsPtrs)),
const_cast<T const**>(mMedusaSelectedLogitsPtrsDevice), curTokensPerStepDevice, targetTokensPerStepDevice,
mBestPathIdsDevice, batchSize, mVocabSize, mMaxBatchSize, mMaxTokensPerStep, maxSeqLen, mMaxNumHeads,
mMaxTokensPerStep, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
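The comments above describe the acceptance step: draft tokens are compared against the target model's sampled tokens along every candidate tree path, and the longest accepted path wins. Below is a minimal host-side C++ sketch of that rule; the names, the parent/child indexing, and the tie-breaking are illustrative assumptions and may differ from the actual `acceptDraftTokensByIdsWithPaths` kernel.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative only: path entries are tree-node indices, -1 marks unused slots.
// Assumption: the draft token stored at node path[i] counts as accepted when the
// target model, sampled after node path[i - 1], produced the same token.
struct Acceptance
{
    int bestPath = 0;
    int acceptedLen = 0;
};

Acceptance acceptByPaths(std::vector<std::vector<int>> const& paths,
    std::vector<int32_t> const& draftTokens,  // draft token per tree node
    std::vector<int32_t> const& targetTokens) // target-sampled token per tree node
{
    Acceptance best;
    for (int p = 0; p < static_cast<int>(paths.size()); ++p)
    {
        int accepted = 0;
        for (std::size_t i = 1; i < paths[p].size(); ++i) // node 0 is the already verified token
        {
            int const node = paths[p][i];
            int const parent = paths[p][i - 1];
            if (node < 0 || draftTokens[node] != targetTokens[parent])
            {
                break;
            }
            ++accepted;
        }
        if (accepted > best.acceptedLen)
        {
            best = {p, accepted};
        }
    }
    return best;
}
```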
@ -358,8 +379,7 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
auto const batchSize = inputs.logits.shape[0];
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>()
: static_cast<SizeType*>(nullptr);
auto sequenceLengths = (outputs.sequence_length) ? outputs.sequence_length->template getPtr<SizeType>() : nullptr;

TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(sequenceLengths != nullptr, "Sequence lengths must be provided for MedusaDecoding");
@ -378,9 +398,6 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
}

auto draftIdsPtrs = reinterpret_cast<TokenIdType**>(bufferCast<int64_t>(*mDraftIdsPtrHost));
auto draftIds = (outputs.nextDraftTokens) ? outputs.nextDraftTokens->template getPtr<TokenIdType>()
: static_cast<TokenIdType*>(nullptr);
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");

for (SizeType bi = 0; bi < batchSize; ++bi)
{
@ -388,12 +405,13 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
for (SizeType hi = 0; hi < mMaxNumHeads; ++hi)
{
draftIdsPtrs[slot * mMaxNumHeads + hi]
= draftIds + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi];
= mNewDraftTokensDevice + slot * mMaxTokensPerStep + mCummulativeTopK[slot * mMaxNumHeads + hi];
}
}

invokeBatchTopKSampling(mSamplingWorkspaceDevice,
/* logits */ static_cast<T const*>(nullptr), const_cast<T const* const*>(mMedusaLogitsPtrsDevice), draftIdsPtrs,
/* logits */ static_cast<T const*>(nullptr), const_cast<T const* const*>(mMedusaSelectedLogitsPtrsDevice),
draftIdsPtrs,
/* outputIds */ nullptr, /* sequenceLength */ nullptr,
/* finishedInput */ nullptr, /* finishedOutput */ nullptr,
/* cumLogProbs */ nullptr, /* outputLogProbs */ nullptr, mCurandStatesMedusaLogitsDevice,
@ -408,6 +426,54 @@ void MedusaDecodingLayer<T>::sampleNewDraftTokens(DecodingOutputParams& outputs,
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}
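The per-head pointer setup above relies on `mCummulativeTopK`: each Medusa head samples its own Top-K candidates, and head `hi` writes them at the running offset of the previous heads' K values inside the per-slot draft-token buffer. A standalone sketch of that bookkeeping (illustrative names, not the layer's real helper):

```cpp
#include <cstddef>
#include <vector>

// Returns, for each head, the offset of its candidate tokens inside one
// sequence's draft-token buffer, i.e. an exclusive prefix sum of topKPerHead.
std::vector<int> cumulativeTopK(std::vector<int> const& topKPerHead)
{
    std::vector<int> offsets(topKPerHead.size());
    int running = 0;
    for (std::size_t hi = 0; hi < topKPerHead.size(); ++hi)
    {
        offsets[hi] = running; // head hi starts where the previous heads ended
        running += topKPerHead[hi];
    }
    return offsets;
}
```

For example, with `topKPerHead = {10, 5, 3}` the offsets are `{0, 10, 15}`.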
template <typename T>
void MedusaDecodingLayer<T>::scatterNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batchSize = inputs.logits.shape[0];
auto batchSlots
= inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : static_cast<SizeType*>(nullptr);

TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");

auto draftIds = outputs.nextDraftTokens ? outputs.nextDraftTokens->template getPtr<TokenIdType>() : nullptr;
auto tokensPerStepDevice = inputs.medusaCurTokensPerStep.template getPtr<SizeType>();
auto treeIds = inputs.treeIds.template getPtr<SizeType>();
TLLM_CHECK_WITH_INFO(draftIds != nullptr, "Draft ids must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(tokensPerStepDevice != nullptr, "Tokens per step must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(treeIds != nullptr, "Tree ids must be provided for MedusaDecoding");

scatterMedusaDraftTokens(draftIds, mNewDraftTokensDevice, treeIds, tokensPerStepDevice, batchSlots,
mMaxTokensPerStep, batchSize, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template <typename T>
void MedusaDecodingLayer<T>::packAcceptedPaths(DecodingOutputParams& outputs, MedusaForwardParams& inputs)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const batchSize = inputs.logits.shape[0];
auto paths = inputs.paths.template getPtr<SizeType const>();
auto batchSlots = inputs.batch_slots ? inputs.batch_slots->template getPtr<SizeType const>() : nullptr;
auto acceptedLengths = outputs.acceptedLengths ? outputs.acceptedLengths->template getPtr<SizeType>() : nullptr;
auto acceptedLengthsCumSum
= outputs.acceptedLengthsCumSum ? outputs.acceptedLengthsCumSum->template getPtr<SizeType>() : nullptr;
auto medusaPathsOffsets
= outputs.medusaPathsOffsets ? outputs.medusaPathsOffsets->template getPtr<SizeType>() : nullptr;

TLLM_CHECK_WITH_INFO(batchSlots != nullptr, "Batch slots must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengths != nullptr, "Accepted lengths must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(acceptedLengthsCumSum != nullptr, "acceptedLengthsCumSum must be provided for MedusaDecoding");
TLLM_CHECK_WITH_INFO(medusaPathsOffsets != nullptr, "medusaPathsOffsets must be provided for MedusaDecoding");
invokePackAcceptedPaths(acceptedLengthsCumSum, medusaPathsOffsets, acceptedLengths, mBestPathIdsDevice, paths,
batchSlots, batchSize, mMaxTokensPerStep, mMaxNumHeads + 1, mStream);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

template class MedusaDecodingLayer<float>;
template class MedusaDecodingLayer<half>;
@ -44,11 +44,10 @@ public:
class MedusaSetupParams : public DecodingSetupParams
{
public:
std::optional<std::vector<runtime::SizeType>> runtimeTopK; // [1] or [batchSize] on cpu
std::optional<std::vector<runtime::SizeType>> runtimeTopK; // [1] or [batchSize] on cpu
std::optional<std::vector<std::vector<runtime::SizeType>>>
runtimeHeadsTopK; // [batchSize, maxMedusaHeads] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
std::optional<std::vector<runtime::SizeType>> tokensPerStep; // [1] or [batchSize] on cpu
runtimeHeadsTopK; // [batchSize, maxMedusaHeads] on cpu
std::optional<std::vector<uint64_t>> randomSeed; // [1] or [batchSize] on cpu
};

class MedusaForwardParams : public DecodingParams
@ -59,8 +58,12 @@ public:
{
}

tc::Tensor paths; // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1] on gpu
tc::Tensor medusaLogits; // [maxNumHeads, maxBatchSize, maxTokensPerStep, vocabSize] on gpu
tc::Tensor paths; // [maxBatchSize, maxTokensPerStep, maxNumHeads + 1] on gpu
std::vector<std::vector<tc::Tensor>>
medusaLogits; // [maxBatchSize][maxNumHeads][tokensPerStep, vocabSize] on gpu
tc::Tensor medusaCurTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor medusaTargetTokensPerStep; // [maxBatchSize] on gpu
tc::Tensor treeIds; // [maxBatchSize, maxTokensPerStep] on gpu
};

MedusaDecodingLayer(runtime::SizeType maxBatchSize, runtime::SizeType vocabSize, runtime::SizeType vocabSizePadded,
@ -80,6 +83,8 @@ private:
void samplePrimeHeadTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void acceptDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void sampleNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void scatterNewDraftTokens(DecodingOutputParams& outputs, MedusaForwardParams& inputs);
void packAcceptedPaths(DecodingOutputParams& outputs, MedusaForwardParams& inputs);

private:
using Base::mStream;
@ -97,19 +102,21 @@ private:
runtime::SizeType mRuntimeMaxTopKPerRequestPerMedusaHead{0};

curandState_t* mCurandStatesDevice{nullptr};
runtime::SizeType* mTokensPerStepDevice{nullptr};
void* mSetupWorkspaceDevice{nullptr};
void* mSamplingWorkspaceDevice{nullptr};
runtime::SizeType* mRuntimeTopKDevice{nullptr};
runtime::TokenIdType* mTargetTokensDevice{nullptr};
uint64_t* mRandomSeedsDevice{nullptr};
T** mMedusaLogitsPtrsDevice{nullptr};
T** mMedusaSelectedLogitsPtrsDevice{nullptr};
curandState_t* mCurandStatesMedusaLogitsDevice{nullptr};
runtime::SizeType* mRuntimeTopKPerRequestPerMedusaHeadDevice{nullptr};
runtime::TokenIdType* mNewDraftTokensDevice{nullptr};
runtime::SizeType* mBestPathIdsDevice{nullptr};

runtime::ITensor::UniquePtr mTiledBatchSlotsSetup;
runtime::ITensor::UniquePtr mTiledBatchSlotsForward;
runtime::ITensor::UniquePtr mDraftIdsPtrHost;
runtime::ITensor::UniquePtr mMedusaInputLogitsPtrs;

std::vector<runtime::SizeType> mCummulativeTopK;
};
@ -66,7 +66,7 @@ void OnlineBeamSearchLayer<T>::invokeSoftMax(BeamSearchOutputParams& outputs, So
bh.length_penalties = length_penalties_buf_;
bh.early_stoppings = early_stoppings_buf_;

bh.batch_size = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[0]);
bh.batch_size = static_cast<std::int32_t>(params.end_ids.shape[0]);
bh.beam_width = static_cast<std::int32_t>(outputs.output_ids_ptr.shape[1]);
bh.ite = params.ite;
bh.local_batch_size = params.logits.shape[0];
@ -45,7 +45,8 @@ set(PLUGIN_LISTS
lookupPlugin
loraPlugin
mixtureOfExperts
selectiveScanPlugin)
selectiveScanPlugin
mambaConv1dPlugin)

foreach(PLUGIN_ITER ${PLUGIN_LISTS})
include_directories(${PLUGIN_ITER})
@ -26,6 +26,7 @@
#include "tensorrt_llm/plugins/layernormQuantizationPlugin/layernormQuantizationPlugin.h"
#include "tensorrt_llm/plugins/lookupPlugin/lookupPlugin.h"
#include "tensorrt_llm/plugins/loraPlugin/loraPlugin.h"
#include "tensorrt_llm/plugins/mambaConv1dPlugin/mambaConv1dPlugin.h"
#include "tensorrt_llm/plugins/mixtureOfExperts/mixtureOfExpertsPlugin.h"
#if ENABLE_MULTI_DEVICE
#include "tensorrt_llm/plugins/ncclPlugin/allgatherPlugin.h"
@ -73,7 +74,7 @@ public:
GlobalLoggerFinder gGlobalLoggerFinder{};

#if !defined(_MSC_VER)
__attribute__((constructor))
[[maybe_unused]] __attribute__((constructor))
#endif
void initOnLoad()
{
@ -89,6 +90,40 @@ bool pluginsInitialized = false;

} // namespace

namespace tensorrt_llm::plugins::api
{

LoggerManager& tensorrt_llm::plugins::api::LoggerManager::getInstance() noexcept
{
static LoggerManager instance;
return instance;
}

void LoggerManager::setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder == nullptr && finder != nullptr)
{
mLoggerFinder = finder;
}
}

[[maybe_unused]] nvinfer1::ILogger* LoggerManager::logger()
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder != nullptr)
{
return mLoggerFinder->findLogger();
}
return nullptr;
}

nvinfer1::ILogger* LoggerManager::defaultLogger() noexcept
{
return gLogger;
}
} // namespace tensorrt_llm::plugins::api

// New Plugin APIs

extern "C"
@ -127,7 +162,7 @@ extern "C"

[[maybe_unused]] void setLoggerFinder([[maybe_unused]] nvinfer1::ILoggerFinder* finder)
{
tensorrt_llm::plugins::api::LoggerFinder::getInstance().setLoggerFinder(finder);
tensorrt_llm::plugins::api::LoggerManager::getInstance().setLoggerFinder(finder);
}

[[maybe_unused]] nvinfer1::IPluginCreator* const* getPluginCreators(std::int32_t& nbCreators)
@ -155,6 +190,7 @@ extern "C"
static tensorrt_llm::plugins::LookupPluginCreator lookupPluginCreator;
static tensorrt_llm::plugins::LoraPluginCreator loraPluginCreator;
static tensorrt_llm::plugins::SelectiveScanPluginCreator selectiveScanPluginCreator;
static tensorrt_llm::plugins::MambaConv1dPluginCreator mambaConv1DPluginCreator;

static std::array pluginCreators
= { creatorPtr(identityPluginCreator),
@ -179,37 +215,10 @@ extern "C"
creatorPtr(lookupPluginCreator),
creatorPtr(loraPluginCreator),
creatorPtr(selectiveScanPluginCreator),
creatorPtr(mambaConv1DPluginCreator),
};
nbCreators = pluginCreators.size();
return pluginCreators.data();
}

} // extern "C"

namespace tensorrt_llm::plugins::api
{
LoggerFinder& tensorrt_llm::plugins::api::LoggerFinder::getInstance() noexcept
{
static LoggerFinder instance;
return instance;
}

void LoggerFinder::setLoggerFinder(nvinfer1::ILoggerFinder* finder)
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder == nullptr && finder != nullptr)
{
mLoggerFinder = finder;
}
}

nvinfer1::ILogger* LoggerFinder::findLogger()
{
std::lock_guard<std::mutex> lk(mMutex);
if (mLoggerFinder != nullptr)
{
return mLoggerFinder->findLogger();
}
return nullptr;
}
} // namespace tensorrt_llm::plugins::api
@ -33,21 +33,22 @@ PluginFieldCollection GemmPluginCreator::mFC{};
std::vector<nvinfer1::PluginField> GemmPluginCreator::mPluginAttributes;

void getProblemParams(cublasOperation_t& transa, cublasOperation_t& transb, int& m, int& n, int& k, int& lda, int& ldb,
int& ldc, bool transA, bool transB, int M, int N, int K)
int& ldc, bool transA, bool transB, int M, int N, int K, int padLda, int padLdb)
{
transa = transB ? CUBLAS_OP_T : CUBLAS_OP_N;
transb = transA ? CUBLAS_OP_T : CUBLAS_OP_N;
m = N;
n = M;
k = K;
lda = transB ? K : N;
ldb = transA ? M : K;
lda = transB ? K + padLdb : N + padLdb;
ldb = transA ? M + padLda : K + padLda;
ldc = N;
}
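The hunk above threads two new leading-dimension pads (padLda, padLdb) through the cuBLAS problem setup: the row-major activation and weight shapes still define m, n, k, but the leading dimensions now include the padding columns. A standalone host-side sketch of that mapping follows (no CUDA headers; `Op` stands in for `cublasOperation_t`, and the shapes in `main` are made-up example values):

```cpp
#include <cstdio>

enum class Op { N, T }; // stand-in for CUBLAS_OP_N / CUBLAS_OP_T

struct GemmDesc
{
    Op transa, transb;
    int m, n, k, lda, ldb, ldc;
};

// Activation A is [M, K] row-major, weight B is [K, N] row-major; padLda/padLdb
// pad the leading dimension of A/B. Mirrors getProblemParams in the hunk above.
GemmDesc problemParams(bool transA, bool transB, int M, int N, int K, int padLda, int padLdb)
{
    GemmDesc d;
    d.transa = transB ? Op::T : Op::N;
    d.transb = transA ? Op::T : Op::N;
    d.m = N;
    d.n = M;
    d.k = K;
    d.lda = transB ? K + padLdb : N + padLdb;
    d.ldb = transA ? M + padLda : K + padLda;
    d.ldc = N;
    return d;
}

int main()
{
    // Example: 8 tokens, 4096 output features, 4096 input features,
    // with 16 extra columns padding the activation's leading dimension.
    GemmDesc const d = problemParams(/*transA=*/false, /*transB=*/false, 8, 4096, 4096, /*padLda=*/16, /*padLdb=*/0);
    std::printf("m=%d n=%d k=%d lda=%d ldb=%d ldc=%d\n", d.m, d.n, d.k, d.lda, d.ldb, d.ldc);
    return 0;
}
```

In this example only ldb changes (4096 + 16), which matches the intent of the change: a padded buffer can be handed to cuBLAS directly while the logical GEMM stays M x N x K.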
void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, const nvinfer1::DataType type,
CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act, void const* weight, void* output,
std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic, void* workspace, cudaStream_t stream)
void runGemm(int const M, int const N, int const K, bool const transA, bool const transB, int const padLda,
int const padLdb, const nvinfer1::DataType type, CublasGemmWrapperPtr const& cublasWrapperPtr, void const* act,
void const* weight, void* output, std::optional<cublasLtMatmulHeuristicResult_t> const& heuristic, void* workspace,
cudaStream_t stream)
{
cublasWrapperPtr->setStream(stream);
cublasWrapperPtr->setWorkspace(workspace);
@ -55,7 +56,7 @@ void runGemm(int const M, int const N, int const K, bool const transA, bool cons
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, transA, transB, M, N, K, padLda, padLdb);

cublasWrapperPtr->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
cublasWrapperPtr->Gemm(transa, transb, m, n, k, weight, lda, act, ldb, output, ldc, heuristic);
@ -78,7 +79,8 @@ void CublasLtGemmPluginProfiler::runTactic(
nextWorkspacePtrWithAlignment(reinterpret_cast<int8_t*>(weightPtr), n * k * dataSize, ALIGNMENT));
char* workspacePtr = reinterpret_cast<char*>(
nextWorkspacePtrWithAlignment(reinterpret_cast<int8_t*>(outputPtr), m * n * dataSize, ALIGNMENT));
runGemm(m, n, k, mTransA, mTransB, mType, mRunner, actPtr, weightPtr, outputPtr, {tactic}, workspacePtr, stream);
runGemm(m, n, k, mTransA, mTransB, mPadLda, mPadLdb, mType, mRunner, actPtr, weightPtr, outputPtr, {tactic},
workspacePtr, stream);
}

bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const& tactic) const
@ -86,7 +88,7 @@ bool CublasLtGemmPluginProfiler::checkTactic(int m, int n, int k, Config const&
cublasOperation_t transa, transb;
int M = m, N = n, K = k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K, mPadLda, mPadLdb);

mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);

@ -117,7 +119,7 @@ std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getT
cublasOperation_t transa, transb;
int m, n, k;
int lda, ldb, ldc;
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K);
getProblemParams(transa, transb, m, n, k, lda, ldb, ldc, mTransA, mTransB, M, N, K, mPadLda, mPadLdb);

mRunner->createDescriptors(transa, transb, m, n, k, lda, ldb, ldc);
auto const heruistics = mRunner->getTactics(transa, transb, m, n, k, lda, ldb, ldc);
@ -126,10 +128,12 @@ std::vector<CublasLtGemmPluginProfiler::Config> CublasLtGemmPluginProfiler::getT
return heruistics;
}

GemmPlugin::GemmPlugin(
int transA, int transB, nvinfer1::DataType type, bool useFp8, GemmPlugin::PluginProfilerPtr const& pluginProfiler)
GemmPlugin::GemmPlugin(int transA, int transB, int padLda, int padLdb, nvinfer1::DataType type, bool useFp8,
GemmPlugin::PluginProfilerPtr const& pluginProfiler)
: mTransA(transA)
, mTransB(transB)
, mPadLda(padLda)
, mPadLdb(padLdb)
, mType(type)
, mUseFp8(useFp8)
, mPluginProfiler(pluginProfiler)
@ -145,6 +149,8 @@ GemmPlugin::GemmPlugin(void const* data, size_t length, GemmPlugin::PluginProfil
char const *d = reinterpret_cast<char const*>(data), *a = d;
read(d, mTransA);
read(d, mTransB);
read(d, mPadLda);
read(d, mPadLdb);
read(d, mType);
read(d, mUseFp8);
read(d, mDims);
@ -169,6 +175,7 @@ void GemmPlugin::init()

mPluginProfiler->setTranspose(mTransA, mTransB);
mPluginProfiler->setOutputType(mOutputType);
mPluginProfiler->setPadLd(mPadLda, mPadLdb);

mGemmId = GemmIdCublas(mDims.n, mDims.k, mType, mTransA, mTransB, mOutputType);
}
@ -363,13 +370,16 @@ int GemmPlugin::enqueue(nvinfer1::PluginTensorDesc const* inputDesc, nvinfer1::P

int const nbDimsA = inputDesc[0].dims.nbDims;
int const nbDimsB = inputDesc[1].dims.nbDims;
auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d);
auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d);
int const K = mTransA ? inputDesc[0].dims.d[0] : inputDesc[0].dims.d[nbDimsA - 1];
int const padM = mTransA ? mPadLda : 0;
int const padN = mTransB ? 0 : mPadLdb;
int const padK = mTransA ? 0 : mPadLda;
auto const M = computeMDimension(mTransA, nbDimsA, inputDesc[0].dims.d) - padM;
auto const N = computeNDimension(mTransB, nbDimsB, inputDesc[1].dims.d) - padN;
int const K = mTransA ? inputDesc[0].dims.d[0] - padK : inputDesc[0].dims.d[nbDimsA - 1] - padK;

auto bestTactic = mPluginProfiler->getBestConfig(M, mGemmId);
runGemm(M, N, K, mTransA, mTransB, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0], bestTactic, workspace,
stream);
runGemm(M, N, K, mTransA, mTransB, mPadLda, mPadLdb, mType, mCublasWrapper, inputs[0], inputs[1], outputs[0],
bestTactic, workspace, stream);
return 0;
}

@ -411,8 +421,9 @@ void GemmPlugin::destroy() noexcept

size_t GemmPlugin::getSerializationSize() const noexcept
{
return sizeof(mTransA) + sizeof(mTransB) + sizeof(mType) + sizeof(mDims) + sizeof(mUseFp8)
+ mPluginProfiler->getSerializationSize(mGemmId) + sizeof(mOutputType); // selected tactics container size
return sizeof(mTransA) + sizeof(mTransB) + sizeof(mPadLda) + sizeof(mPadLdb) + sizeof(mType) + sizeof(mDims)
+ sizeof(mUseFp8) + mPluginProfiler->getSerializationSize(mGemmId)
+ sizeof(mOutputType); // selected tactics container size
}

void GemmPlugin::serialize(void* buffer) const noexcept
@ -420,6 +431,8 @@ void GemmPlugin::serialize(void* buffer) const noexcept
char *d = static_cast<char*>(buffer), *a = d;
write(d, mTransA);
write(d, mTransB);
write(d, mPadLda);
write(d, mPadLdb);
write(d, mType);
write(d, mUseFp8);
write(d, mDims);
@ -439,6 +452,8 @@ GemmPluginCreator::GemmPluginCreator()
mPluginAttributes.clear();
mPluginAttributes.emplace_back(PluginField("transA", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("transB", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("padLda", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("padLdb", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("type_id", nullptr, PluginFieldType::kINT32, 1));
mPluginAttributes.emplace_back(PluginField("use_fp8", nullptr, PluginFieldType::kINT32, 0));
mFC.nbFields = mPluginAttributes.size();
@ -463,7 +478,7 @@ PluginFieldCollection const* GemmPluginCreator::getFieldNames() noexcept
IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollection const* fc) noexcept
{
PluginField const* fields = fc->fields;
int transA, transB;
int transA, transB, padLda, padLdb;
nvinfer1::DataType type;
int useFp8;
// Read configurations from each fields
@ -480,6 +495,16 @@ IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollecti
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
transB = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "pad_lda"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
padLda = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "pad_ldb"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
padLdb = static_cast<int>(*(static_cast<int const*>(fields[i].data)));
}
else if (!strcmp(attrName, "type_id"))
{
TLLM_CHECK(fields[i].type == PluginFieldType::kINT32);
@ -497,7 +522,7 @@ IPluginV2* GemmPluginCreator::createPlugin(char const* name, PluginFieldCollecti
// Create plugin profiler with shared tactics map
// FIXME enable tactic profiler
auto pluginProfiler = gemmPluginProfileManager.createGemmPluginProfiler(/* inference */ false, /* skip */ true);
auto* obj = new GemmPlugin(transA, transB, type, useFp8, pluginProfiler);
auto* obj = new GemmPlugin(transA, transB, padLda, padLdb, type, useFp8, pluginProfiler);
obj->setPluginNamespace(mNamespace.c_str());
return obj;
}
@ -42,6 +42,12 @@ public:
mTransB = transposeB;
}

void setPadLd(int padLda, int padLdb)
{
mPadLda = padLda;
mPadLdb = padLdb;
}

void setOutputType(nvinfer1::DataType type)
{
mOutputType = type;
@ -59,6 +65,8 @@ protected:
private:
bool mTransA;
bool mTransB;
int mPadLda;
int mPadLdb;
nvinfer1::DataType mOutputType;

static constexpr size_t ALIGNMENT = 256;
@ -71,7 +79,8 @@ public:

GemmPlugin() = delete;

GemmPlugin(int transA, int transB, nvinfer1::DataType type, bool useFp8, PluginProfilerPtr const& profiler);
GemmPlugin(int transA, int transB, int padLda, int padLdb, nvinfer1::DataType type, bool useFp8,
PluginProfilerPtr const& profiler);

GemmPlugin(void const* data, size_t length, PluginProfilerPtr const& profiler);

@ -114,6 +123,8 @@ private:

int mTransA;
int mTransB;
int mPadLda;
int mPadLdb;
nvinfer1::DataType mType;
nvinfer1::DataType mOutputType;
@ -98,6 +98,7 @@ struct FusedQKVMaskedAttentionDispatchParams
T const* ia3_key_weights;
T const* ia3_value_weights;
float const* qkv_scale_out;
bool fp8_context_fmha;
float const* attention_out_scale;
bool mUnfuseQkvGemm;
tc::QuantMode quant_option;
@ -320,9 +321,12 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_params<T_MMHA, CROSS
params.ia3_key_weights = reinterpret_cast<DataType const*>(input_params.ia3_key_weights);
params.ia3_value_weights = reinterpret_cast<DataType const*>(input_params.ia3_value_weights);

if (input_params.quant_option.hasStaticActivationScaling())
if (input_params.quant_option.hasStaticActivationScaling() || input_params.fp8_context_fmha)
{
// qkv_scale_out is nullptr currently (no scale).
params.qkv_scale_quant_orig = input_params.qkv_scale_out;
TLLM_CHECK_WITH_INFO(!input_params.fp8_context_fmha || input_params.attention_out_scale != nullptr,
"attention output scale should be provided.");
params.attention_out_scale_orig_quant = input_params.attention_out_scale;
}

@ -374,7 +378,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_medusa_enabled)
: mLayerIdx(layer_idx)
, mNumHeads(num_heads)
, mNumKVHeads(num_kv_heads)
@ -408,6 +412,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
, mPosShiftEnabled(pos_shift_enabled)
, mDenseContextFMHA(dense_context_fmha)
, mPagedContextFMHA(use_paged_context_fmha)
, mFP8ContextFMHA(use_fp8_context_fmha)
, mUseKVCache(use_cache)
, mIsMedusaEnabled(is_medusa_enabled)
{
@ -442,6 +447,14 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(int layer_idx, int num_heads,
}
}

// Pre-Check of FP8 Context FMHA.
if (mFP8ContextFMHA)
{
TLLM_CHECK_WITH_INFO(mEnableContextFMHA, "FP8 FMHA cannot be enabled because Context FMHA is not supported.");
TLLM_CHECK_WITH_INFO(mSM == 90, "FP8 FMHA cannot be enabled on Pre-Hopper Arch.");
TLLM_CHECK_WITH_INFO(!mPagedContextFMHA, "FP8 Context Paged KV FMHA hasn't been implemented yet.");
}

TLLM_CHECK(isRoPE() == (rotary_embedding_dim != 0));
TLLM_CHECK_WITH_INFO((mSM >= 80) || (mType != nvinfer1::DataType::kBF16),
"Unsupported data type, pre SM 80 GPUs do not support bfloat16");
@ -511,6 +524,7 @@ GPTAttentionPluginCommon::GPTAttentionPluginCommon(void const* data, size_t leng
read(d, mPosShiftEnabled);
read(d, mDenseContextFMHA);
read(d, mPagedContextFMHA);
read(d, mFP8ContextFMHA);
read(d, mUseKVCache);
read(d, mIsMedusaEnabled);

@ -558,13 +572,16 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
: sizeof(float) * batch_size * mNumHeads * input_seq_length
* (isCrossAttention() ? cross_qkv_length : input_seq_length);
const size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA
? batch_size * input_seq_length * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
const size_t padding_offset_size = sizeof(int) * batch_size * input_seq_length;
// It is assumed that the number of tokens per paged kv block should be >= 128.
const size_t paged_kv_tma_desc_size = mPagedKVCache && mPagedContextFMHA
? batch_size * 2 * TMA_DESC_SIZE_IN_BYTE * tc::divUp(max_attention_window, mTokensPerBlock)
: 0;

int const NUM_BUFFERS = 12;
int const NUM_BUFFERS = 13;
size_t workspaces[NUM_BUFFERS];
workspaces[0] = CUBLAS_WORKSPACE_SIZE;
workspaces[1] = attention_mask_size;
@ -576,8 +593,9 @@ size_t GPTAttentionPluginCommon::getWorkspaceSizeForContext(nvinfer1::DataType t
workspaces[7] = qk_buf_size;
workspaces[8] = qkv_buf_2_size;
workspaces[9] = qk_buf_float_size;
workspaces[10] = padding_offset_size;
workspaces[11] = paged_kv_tma_desc_size;
workspaces[10] = fp8_qkv_buffer_size;
workspaces[11] = padding_offset_size;
workspaces[12] = paged_kv_tma_desc_size;
context_workspace_size = tc::calculateTotalWorkspaceSize(workspaces, NUM_BUFFERS);
return context_workspace_size;
}
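The workspace sizing above now accounts for a 13th buffer (the FP8 QKV staging buffer) and folds all per-buffer sizes into one allocation. A minimal sketch of that accumulation, assuming each buffer is rounded up to a fixed 256-byte alignment before the next one starts (the real tc::calculateTotalWorkspaceSize helper may differ in detail):

```cpp
#include <cstddef>

// Sums independently sized buffers into one workspace, padding each one up to
// kAlignment so every buffer start stays aligned. Assumption: 256-byte alignment.
constexpr std::size_t kAlignment = 256;

std::size_t totalWorkspaceSize(std::size_t const* sizes, int count)
{
    std::size_t total = 0;
    for (int i = 0; i < count; ++i)
    {
        total += (sizes[i] + kAlignment - 1) / kAlignment * kAlignment;
    }
    return total;
}
```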
@ -725,6 +743,9 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
const size_t qk_buf_float_size = mEnableContextFMHA ? 0
: sizeof(float) * params.batch_size * mNumHeads
* params.input_seq_length * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
const size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA
? params.batch_size * params.input_seq_length * (local_hidden_units_qo + 2 * local_hidden_units_kv)
: 0;
const size_t padding_offset_size
= sizeof(int) * params.batch_size * (isCrossAttention() ? params.cross_qkv_length : params.input_seq_length);
const size_t paged_kv_tma_desc_size = mPagedKVCache && mPagedContextFMHA
@ -746,6 +767,8 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
T* qk_buf_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_size));
T* qkv_buf_2_ = reinterpret_cast<T*>(nextWorkspacePtr(workspace_byte_ptr, offset, qkv_buf_2_size));
float* qk_buf_float_ = reinterpret_cast<float*>(nextWorkspacePtr(workspace_byte_ptr, offset, qk_buf_float_size));
__nv_fp8_e4m3* fp8_qkv_buffer
= reinterpret_cast<__nv_fp8_e4m3*>(nextWorkspacePtr(workspace_byte_ptr, offset, fp8_qkv_buffer_size));
int* padding_offset = reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, padding_offset_size));
void* paged_kv_tma_desc
= reinterpret_cast<int*>(nextWorkspacePtr(workspace_byte_ptr, offset, paged_kv_tma_desc_size));
@ -826,13 +849,15 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
// Paged Context FMHA doesn't work with fp8/int8 kv cache currently.
TLLM_CHECK_WITH_INFO(cache_type == KvCacheDataType::BASE || !enablePagedKVContextFMHA,
"Paged Context FMHA doesn't work with fp8/int8 kv cache currently.");
invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), q_buf_2_, kv_cache_buffer,
const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,

invokeApplyBiasRopeUpdateKVCache(const_cast<T*>(params.attention_input), fp8_qkv_buffer, q_buf_2_,
kv_cache_buffer, const_cast<T*>(params.qkv_bias), params.q_seq_lengths, params.kv_seq_lengths,
mRemovePadding ? padding_offset : nullptr, params.batch_size, params.input_seq_length,
params.cyclic_attention_window_size, params.sink_token_length, params.num_tokens, mNumHeads, mNumKVHeads,
getHeadSize(), mRotaryEmbeddingDim, mRotaryEmbeddingBase, mRotaryEmbeddingScaleType, mRotaryEmbeddingScale,
mRotaryEmbeddingMaxPositions, position_embedding_type, (int*) nullptr, mPosShiftEnabled, (float*) nullptr,
0, cache_type, params.kv_scale_orig_quant, enablePagedKVContextFMHA, 1, mLaunchGridBlockCache, stream);
mRotaryEmbeddingMaxPositions, position_embedding_type, (int*) nullptr, mPosShiftEnabled, nullptr, 0,
cache_type, params.kv_scale_orig_quant, enablePagedKVContextFMHA, mFP8ContextFMHA, 1, mLaunchGridBlockCache,
stream);
sync_check_cuda_error();

// It is not needed with packed QKV input.
@ -855,6 +880,7 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
{
TLLM_LOG_ERROR("Cannot support StreamingLLM now when enabling paged KV context FMHA.");
}
// TODO: add support for fp8 paged kv fmha later.
mFMHARunner->setup_paged_kv(params.batch_size, params.input_seq_length, params.max_past_kv_len,
params.max_blocks_per_sequence, mTokensPerBlock, params.cyclic_attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
@ -869,8 +895,10 @@ int GPTAttentionPluginCommon::enqueueContext(EnqueueContextParams<T, KVCacheBuff
int const attention_window_size
= mDenseContextFMHA ? params.num_tokens : params.cyclic_attention_window_size;
mFMHARunner->setup(params.batch_size, params.input_seq_length, attention_window_size, params.num_tokens,
isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
mFMHARunner->run(const_cast<T*>(params.attention_input), cu_q_seqlens, params.context_buf, stream);
params.attention_output_orig_quant, isALiBi(), isAliBiWithScale(), mTpSize, mTpRank);
void const* fmha_input_tensor = mFP8ContextFMHA ? reinterpret_cast<void const*>(fp8_qkv_buffer)
: reinterpret_cast<void const*>(params.attention_input);
mFMHARunner->run(fmha_input_tensor, cu_q_seqlens, params.context_buf, stream);
}
sync_check_cuda_error();
}
@ -1177,7 +1205,6 @@ int GPTAttentionPluginCommon::enqueueGeneration(

auto const quant_option = tc::QuantMode::fromDescription();
float const* qkv_scale_out = nullptr;
float const* attention_out_scale = nullptr;

int const* ia3_tasks = nullptr;
T const* ia3_key_weights = nullptr;
@ -1323,7 +1350,8 @@ int GPTAttentionPluginCommon::enqueueGeneration(
dispatch_params.ia3_key_weights = ia3_key_weights;
dispatch_params.ia3_value_weights = ia3_value_weights;
dispatch_params.qkv_scale_out = qkv_scale_out;
dispatch_params.attention_out_scale = attention_out_scale;
dispatch_params.fp8_context_fmha = mFP8ContextFMHA;
dispatch_params.attention_out_scale = params.attention_output_orig_quant;
dispatch_params.quant_option = quant_option;
dispatch_params.multi_block_mode = enable_multi_block;
dispatch_params.max_seq_len_tile = max_num_seq_len_tiles;
@ -1447,6 +1475,12 @@ int GPTAttentionPluginCommon::initialize() noexcept
TLLM_CHECK_WITH_INFO(false, "GPTAttentionPlugin received wrong data type.");
}

// FP8 FMHA should be used with fp8 workflow together.
if (mFP8ContextFMHA)
{
data_type = DATA_TYPE_E4M3;
}

// Load kernels for contiguous cache and paged kv cache at the same time.
mFMHARunner.reset(new FusedMHARunnerV2(data_type, mNumHeads, getHeadSize(false), mQScaling));
// Set flags: force_fp32_acc, is_s_padded, causal_mask, num_kv_heads.
@ -1504,8 +1538,8 @@ size_t GPTAttentionPluginCommon::getCommonSerializationSize() noexcept
+ sizeof(unsigned int) // mKVCacheQuantMode
+ sizeof(mRemovePadding) + sizeof(mMaskType) + sizeof(mPagedKVCache) + sizeof(mTokensPerBlock) + sizeof(mType)
+ sizeof(mMaxContextLength) + sizeof(mQKVBiasEnabled) + sizeof(mCrossAttention) + sizeof(mMaxDistance)
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mUseKVCache)
+ sizeof(mUnfuseQkvGemm) + sizeof(mIsMedusaEnabled);
+ sizeof(mPosShiftEnabled) + sizeof(mDenseContextFMHA) + sizeof(mPagedContextFMHA) + sizeof(mFP8ContextFMHA)
+ sizeof(mUseKVCache) + sizeof(mUnfuseQkvGemm) + sizeof(mIsMedusaEnabled);
}

void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
@ -1543,6 +1577,7 @@ void GPTAttentionPluginCommon::serializeCommon(void* buffer) const noexcept
write(d, mPosShiftEnabled);
write(d, mDenseContextFMHA);
write(d, mPagedContextFMHA);
write(d, mFP8ContextFMHA);
write(d, mUseKVCache);
write(d, mIsMedusaEnabled);
assert(d == a + getCommonSerializationSize());
@ -1589,6 +1624,7 @@ GPTAttentionPluginCreatorCommon::GPTAttentionPluginCreatorCommon()
mPluginAttributes.emplace_back(PluginField("pos_shift_enabled", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("dense_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_paged_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_fp8_context_fmha", nullptr, PluginFieldType::kINT8, 0));
mPluginAttributes.emplace_back(PluginField("use_cache", nullptr, PluginFieldType::kINT32, 0));
mPluginAttributes.emplace_back(PluginField("is_medusa_enabled", nullptr, PluginFieldType::kINT8, 0));
mFC.nbFields = mPluginAttributes.size();
@ -46,8 +46,8 @@ public:
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false,
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true,
bool is_medusa_enabled = false);
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false,
bool use_cache = true, bool is_medusa_enabled = false);

GPTAttentionPluginCommon(void const* data, size_t length);

@ -102,6 +102,7 @@ protected:
int32_t const* kv_seq_lengths;
float const* kv_scale_orig_quant;
float const* kv_scale_quant_orig;
float const* attention_output_orig_quant;
T const* alibi_slopes;
T* context_buf;
void* key_value_cache;
@ -137,6 +138,7 @@ protected:
int32_t const* context_lengths;
float const* kv_scale_orig_quant;
float const* kv_scale_quant_orig;
float const* attention_output_orig_quant;
T const* alibi_slopes;
T* context_buf;
void* key_value_cache;
@ -241,6 +243,7 @@ protected:
int mMaxDistance = 0;
bool mPosShiftEnabled = false;
bool mPagedContextFMHA = false;
bool mFP8ContextFMHA = false;
bool mDenseContextFMHA = false;
bool mIsMedusaEnabled = false;
@ -47,13 +47,13 @@ GPTAttentionPlugin::GPTAttentionPlugin(int layer_idx, int num_heads, int num_kv_
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention, int max_distance, bool pos_shift_enabled, bool dense_context_fmha,
bool use_paged_context_fmha, bool use_cache, bool is_medusa_enabled)
bool use_paged_context_fmha, bool use_fp8_context_fmha, bool use_cache, bool is_medusa_enabled)
: GPTAttentionPluginCommon(layer_idx, num_heads, num_kv_heads, head_size, unidirectional, q_scaling,
position_embedding_type, rotary_embedding_dim, rotary_embedding_base, rotary_embedding_scale_type,
rotary_embedding_scale, rotary_embedding_max_positions, tp_size, tp_rank, unfuse_qkv_gemm, context_fmha_type,
multi_block_mode, enable_xqa, kv_cache_quant_mode, remove_input_padding, mask_type, paged_kv_cache,
tokens_per_block, type, max_context_length, qkv_bias_enabled, cross_attention, max_distance, pos_shift_enabled,
dense_context_fmha, use_paged_context_fmha, use_cache, is_medusa_enabled)
dense_context_fmha, use_paged_context_fmha, use_fp8_context_fmha, use_cache, is_medusa_enabled)
{
initEntryIdx();
}
@ -83,6 +83,7 @@ bool GPTAttentionPlugin::isEntryUsed(IdxEntry const& entry) const
case IdxEntry::PAST_KEY_VALUE: return useKVCache() && !mPagedKVCache;
case IdxEntry::KV_CACHE_QUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
case IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE: return useKVCache() && mKVCacheQuantMode.hasKvCacheQuant();
case IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE: return mFP8ContextFMHA && mKVCacheQuantMode.hasFp8Qdq();
case IdxEntry::ALIBI_SLOPES: return isALiBi();
case IdxEntry::RELATIVE_ATTENTION_BIAS: return isRelativePosition();
case IdxEntry::CROSS_QKV: return isCrossAttention();
@ -186,6 +187,10 @@ bool GPTAttentionPlugin::supportsFormatCombination(
// kv_scale for mType->int8/fp8 and int8/fp8->mType conversion
return inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mFP8ContextFMHA && pos == getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE))
{
return inOut[pos].type == nvinfer1::DataType::kFLOAT && inOut[pos].format == TensorFormat::kLINEAR;
}
else if (mPagedKVCache
&& (pos == getIdx(IdxEntry::KV_CACHE_BLOCK_POINTERS) || pos == getIdx(IdxEntry::HOST_KV_CACHE_BLOCK_POINTERS)))
{
@ -213,6 +218,11 @@ bool GPTAttentionPlugin::supportsFormatCombination(
{
return inOut[pos].type == nvinfer1::DataType::kINT32;
}
else if (pos == nbInputs && mFP8ContextFMHA)
{
// Output tensor now supports fp8 data type.
return (inOut[pos].type == nvinfer1::DataType::kFP8) && (inOut[pos].format == TensorFormat::kLINEAR);
}
else
{
return (inOut[pos].type == mType) && (inOut[pos].format == TensorFormat::kLINEAR);
@ -247,6 +257,7 @@ void GPTAttentionPlugin::configurePluginImpl(nvinfer1::DynamicPluginTensorDesc c
/*context_lengths=*/nullptr,
/*kv_scale_orig_quant=*/nullptr,
/*kv_scale_quant_orig=*/nullptr,
/*attention_out_orig_quant=*/nullptr,
/*alibi_slopes=*/nullptr,
/*context_buf_=*/nullptr,
/*key_value_cache=*/nullptr,
@ -498,6 +509,14 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
kv_scale_quant_orig = reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::KV_CACHE_DEQUANTIZATION_SCALE)]);
}

float const* attention_output_orig_quant = nullptr;
if (mFP8ContextFMHA)
{
assert(inputDesc[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)].type == nvinfer1::DataType::kFLOAT);
attention_output_orig_quant
= reinterpret_cast<float const*>(inputs[getIdx(IdxEntry::ATTENTION_OUTPUT_QUANTIZATION_SCALE)]);
}

int max_blocks_per_sequence = 0;
void* block_pointers = nullptr;
void* host_block_pointers = nullptr;
@ -584,9 +603,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32

EnqueueContextParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, max_context_q_len,
max_context_kv_len, max_attention_window_size, cyclic_attention_window_size, sink_token_length,
context_q_lengths, sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig, alibi_slopes, context_buf_,
key_value_cache, block_pointers, host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence,
workspace};
context_q_lengths, sequence_kv_length, kv_scale_orig_quant, kv_scale_quant_orig,
attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache, block_pointers,
host_block_pointers, batch_size, localNbTokens, max_blocks_per_sequence, workspace};
if (isRelativePosition())
{
enqueue_params.relative_attention_bias
@ -627,9 +646,9 @@ int GPTAttentionPlugin::enqueueSome(int32_t seqIdxBeg, int32_t localNbSeq, int32
TLLM_CHECK_WITH_INFO(input_seq_length == num_medusa_tokens + 1, "The generation input length is not expected.");
EnqueueGenerationParams<T, KVCacheBuffer> enqueue_params{attention_input, qkv_bias, input_seq_length,
sequence_kv_length, max_context_kv_len, beamWidth, context_q_lengths, kv_scale_orig_quant,
kv_scale_quant_orig, alibi_slopes, context_buf_, key_value_cache, block_pointers, max_attention_window_size,
cyclic_attention_window_size, sink_token_length, num_requests, max_blocks_per_sequence, cache_indir,
workspace, max_context_kv_len_list};
kv_scale_quant_orig, attention_output_orig_quant, alibi_slopes, context_buf_, key_value_cache,
block_pointers, max_attention_window_size, cyclic_attention_window_size, sink_token_length, num_requests,
max_blocks_per_sequence, cache_indir, workspace, max_context_kv_len_list};
enqueue_params.host_context_lengths = host_context_lengths;
if (isRelativePosition())
{
@ -699,7 +718,8 @@ nvinfer1::DataType GPTAttentionPlugin::getOutputDataType(
TLLM_CHECK(index == 0 || (!mPagedKVCache && index == 1));
if (index == 0)
{
return inputTypes[getIdx(IdxEntry::QKV_TENSOR)];
return mFP8ContextFMHA && mEnableContextFMHA ? nvinfer1::DataType::kFP8
: inputTypes[getIdx(IdxEntry::QKV_TENSOR)];
}
else
{
@ -794,6 +814,7 @@ IPluginV2* GPTAttentionPluginCreator::createPlugin(char const* name, PluginField
static_cast<bool>(p.getScalar<int8_t>("pos_shift_enabled").value()),
static_cast<bool>(p.getScalar<int8_t>("dense_context_fmha").value()),
static_cast<bool>(p.getScalar<int8_t>("use_paged_context_fmha").value()),
static_cast<bool>(p.getScalar<int8_t>("use_fp8_context_fmha").value()),
static_cast<bool>(p.getScalar<int32_t>("use_cache").value()),
static_cast<bool>(p.getScalar<int8_t>("is_medusa_enabled").value()));
obj->setPluginNamespace(mNamespace.c_str());

@ -80,8 +80,8 @@ public:
int kv_cache_quant_mode, bool remove_input_padding, tensorrt_llm::kernels::AttentionMaskType mask_type,
bool paged_kv_cache, int tokens_per_block, nvinfer1::DataType type, int32_t max_context_length,
bool qkv_bias_enabled, bool cross_attention = false, int max_distance = 0, bool pos_shift_enabled = false,
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_cache = true,
bool is_medusa_enabled = false);
bool dense_context_fmha = false, bool use_paged_context_fmha = false, bool use_fp8_context_fmha = false,
bool use_cache = true, bool is_medusa_enabled = false);

GPTAttentionPlugin(void const* data, size_t length);

@ -164,6 +164,7 @@ private:
PAST_KEY_VALUE,
KV_CACHE_QUANTIZATION_SCALE,
KV_CACHE_DEQUANTIZATION_SCALE,
ATTENTION_OUTPUT_QUANTIZATION_SCALE,
ALIBI_SLOPES,
RELATIVE_ATTENTION_BIAS,
CROSS_QKV,
cpp/tensorrt_llm/plugins/mambaConv1dPlugin/CMakeLists.txt (new file, 21 lines)
@ -0,0 +1,21 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2022 NVIDIA CORPORATION &
# AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy of
# the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
#
file(GLOB SRCS *.cpp)
set(PLUGIN_SOURCES ${PLUGIN_SOURCES} ${SRCS})
set(PLUGIN_SOURCES
${PLUGIN_SOURCES}
PARENT_SCOPE)
Some files were not shown because too many files have changed in this diff.